In [27]:
import torch
from tqdm import tqdm
import numpy as np
import gc

from transformers import AutoTokenizer, AutoModelForCausalLM
from emb_vectors_functions import find_self_embeds, get_shadow_ratios
from model_loading import get_weight_by_name

In [103]:
# Candidate checkpoints:
# "meta-llama/Llama-3.1-70B"
# "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-7B-Instruct"
# "Qwen/Qwen2.5-14B-Instruct", "Qwen/Qwen2.5-32B-Instruct", "Qwen/Qwen2.5-72B-Instruct"
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# Prefer the GPU, but fall back to CPU instead of crashing on `.cuda()`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): for this checkpoint the loader reports "head was not found" and
# falls back to loading the full model; confirm which weight matrix
# get_weight_by_name returns in that case (presumably the tied
# input-embedding / lm_head matrix).
embeddings = get_weight_by_name(model_name, "head").to(DEVICE)

Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

head was not found. Trying to load the full model.


In [104]:
# Set to True to repeat the analysis on a random Gaussian matrix of the same
# shape as `embeddings` (e.g. as a chance-level baseline).
USE_RANDOM_BASELINE = False
if USE_RANDOM_BASELINE:
    embeddings = torch.randn_like(embeddings)

# The embedding matrix is treated as a constant from here on. `detach()` works
# whether or not `embeddings` is a leaf tensor, whereas assigning
# `.requires_grad = False` raises on a non-leaf (e.g. the result of `.cuda()`
# on a parameter that requires grad).
embeddings = embeddings.detach()

In [105]:
# Run the project's self-embedding check. NOTE(review): `fail_indices` is used
# below as row indices into `embeddings`; presumably these are the tokens whose
# embedding fails find_self_embeds' criterion — confirm in emb_vectors_functions.
(fail_indices, failed_res_emb, failed_pairs) = find_self_embeds(embeddings, tokenizer)
print("Number of bad embeddings = {}".format(len(fail_indices)))

100%|██████████| 16/16 [00:01<00:00, 12.63it/s]


Number of bad embeddings = 8913


In [106]:
# shadow_ratios = get_shadow_ratios(fail_indices, embeddings)
# shadow_ratios_sorted = sorted(shadow_ratios, key=lambda x: x[1], reverse=True);

In [107]:
def batch_list(lst, batch_size):
    """Split ``lst`` into consecutive slices of at most ``batch_size`` items.

    Slicing is used, so a list yields lists and an array/tensor yields slices
    of the same type. The last slice may be shorter than ``batch_size``.
    """
    batches = []
    for start in range(0, len(lst), batch_size):
        stop = start + batch_size
        batches.append(lst[start:stop])
    return batches


def calc_loss(x, self_emb, X, mask, epsilon=1e-4):
    """Hinge loss pushing each row of ``x`` to score its own embedding at
    least ``epsilon`` above every other embedding.

    For row i and embedding j the score gap is
    ``x_i · self_emb_i - x_i · X_j``; every gap below ``epsilon`` contributes
    ``epsilon - gap``. Entries where ``mask`` is 0 contribute nothing.

    Args:
        x: (n, d) query vectors being optimised.
        self_emb: (n, d) target embedding for each row of ``x``.
        X: (V, d) embedding matrix the targets compete against.
        mask: (n, V) weights; 0 excludes an entry (e.g. the target's own row).
        epsilon: required margin.

    Returns:
        Scalar tensor: summed hinge violations over unmasked entries. It is 0
        exactly when every unmasked constraint holds with margin ``epsilon``.
    """
    xself = torch.einsum("ij,ij->i", x, self_emb)  # (n,) score of own embedding
    xX = x @ X.T  # (n, V) scores against every embedding
    gap = xself[:, None] - xX

    # Apply the mask *after* the hinge: masking the gap first turned each
    # excluded entry into relu(epsilon), adding a constant epsilon per excluded
    # entry so the loss could never reach 0. Gradients are identical either way.
    hinge = torch.relu(epsilon - gap)
    return torch.sum(hinge * mask)


def calc_bad_embeds(x_optim, self_emb, embeddings, mask):
    """Count rows of ``x_optim`` that do not strictly prefer their own embedding.

    Row i is "good" when ``x_i · self_emb_i > x_i · embeddings_j`` for every j
    with ``mask[i, j] != 0``; masked entries are ignored.

    Args:
        x_optim: (n, d) query vectors.
        self_emb: (n, d) target embedding for each row.
        embeddings: (V, d) embedding matrix.
        mask: (n, V); 0 excludes an entry from the check.

    Returns:
        (bad_embeds, bad_embeds_ratio, is_good_embed): number of failing rows,
        that number divided by n (0.0 when n == 0), and an (n,) numpy bool
        array marking the rows that pass.
    """
    xself = torch.einsum("ij,ij->i", x_optim, self_emb)
    xX = x_optim @ embeddings.T
    xA = xself[:, None] - xX
    # Excluded entries must never count as violations, however negative the
    # gap (the former `+ 1e10` offset could be overwhelmed by a large gap).
    xA = xA.masked_fill(mask == 0, float("inf"))
    is_good_embed = torch.all(xA > 0, dim=1)

    n_rows = is_good_embed.numel()
    # Tensor reduction rather than builtin sum(), which iterates a (GPU)
    # tensor element by element.
    bad_embeds = n_rows - int(is_good_embed.sum().item())
    bad_embeds_ratio = bad_embeds / n_rows if n_rows else 0.0

    return bad_embeds, bad_embeds_ratio, is_good_embed.cpu().numpy()


def train_vectors(
    n_lst,
    embeddings,
    x_optim_start=None,
    n_steps=100,
    verbose=False,
    use_tqdm=True,
    lr=0.01,
    eval_every=10,
):
    """Optimise one query vector per target row so that it scores its own
    embedding strictly higher than every other row of ``embeddings``.

    Args:
        n_lst: sequence (list / numpy array) of row indices into ``embeddings``.
        embeddings: (V, d) embedding matrix; read only, never updated.
        x_optim_start: optional (len(n_lst), d) warm start; defaults to the
            target embeddings themselves.
        n_steps: maximum number of optimiser steps.
        verbose: print initial / final loss and failure counts.
        use_tqdm: show a per-step progress bar.
        lr: AdamW learning rate.
        eval_every: positive int; failing rows are re-counted every this many
            steps, and optimisation stops early once no row fails.

    Returns:
        (loss, x_optim, self_emb, mask, is_good_embed): final loss; the
        optimised (len(n_lst), d) vectors (a leaf with requires_grad=True);
        the target embeddings; the (len(n_lst), V) constraint mask; and a
        numpy bool array marking rows whose constraints all hold.
    """
    X = embeddings
    n_targets = len(n_lst)
    target_ids = torch.as_tensor(n_lst, dtype=torch.long, device=X.device)
    self_emb = X[target_ids]

    # mask[i, j] = 0 excludes row i's own embedding from its constraints.
    mask = torch.ones((n_targets, len(X)), device=X.device)
    mask[torch.arange(n_targets, device=X.device), target_ids] = 0

    start = self_emb if x_optim_start is None else x_optim_start
    x_optim = start.detach().clone().requires_grad_(True)
    # NB: AdamW also applies its default decoupled weight decay to x_optim.
    optimizer = torch.optim.AdamW([x_optim], lr=lr)

    with torch.no_grad():
        loss = calc_loss(x_optim, self_emb, X, mask)
        bad_embeds, bad_embeds_ratio, is_good_embed = calc_bad_embeds(
            x_optim, self_emb, X, mask
        )
    min_bad = min(n_targets, bad_embeds)
    if verbose:
        print(f"Initial\nloss = {loss.item()}")
        print(f"Bad embeds = {bad_embeds}/{n_targets}, ratio = {bad_embeds_ratio}")

    steps_run = 0  # defined even when n_steps == 0 (the loop body never runs)
    # Context manager closes the bar even when we break out early.
    with tqdm(range(n_steps), disable=not use_tqdm) as pbar:
        for step in pbar:
            optimizer.zero_grad()
            loss = calc_loss(x_optim, self_emb, X, mask)
            loss.backward()
            optimizer.step()
            steps_run = step + 1

            if steps_run % eval_every == 0:
                with torch.no_grad():
                    bad_embeds, bad_embeds_ratio, is_good_embed = calc_bad_embeds(
                        x_optim, self_emb, X, mask
                    )
                pbar.set_postfix_str(f"Bad embeds: {bad_embeds}/{n_targets}")
                min_bad = min(min_bad, bad_embeds)
                if bad_embeds == 0:
                    break

    # Final evaluation so the returned loss / is_good_embed reflect the last step.
    with torch.no_grad():
        loss = calc_loss(x_optim, self_emb, X, mask)
        bad_embeds, bad_embeds_ratio, is_good_embed = calc_bad_embeds(
            x_optim, self_emb, X, mask
        )
    min_bad = min(min_bad, bad_embeds)
    if verbose:
        print("Final")
        print(f"steps = {steps_run}, loss = {loss.item()}")
        print(
            f"Bad embeds = {bad_embeds}/{n_targets}, ratio = {bad_embeds_ratio:.4f}, Minimal bad = {min_bad}"
        )

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return loss, x_optim, self_emb, mask, is_good_embed

In [108]:
# Optimisation targets: the indices reported as bad by find_self_embeds, used
# downstream as row indices into `embeddings`.
n_lst = fail_indices

In [109]:
# loss, x_optim, self_emb, mask, bad_indices = train_vectors(n_lst, embeddings, n_steps=100, verbose=True)

In [110]:
# Targets optimised jointly per train_vectors call; each call builds a
# (batch, vocab) mask, so this bounds memory use.
BATCH_SIZE = 1000
ind_batched_list = batch_list(n_lst, BATCH_SIZE)

In [111]:
print(f"Number of bad embeddings = {len(n_lst)}, dimension = {embeddings.shape[1]}, dict size = {embeddings.shape[0]}")

# Each batch is optimised in up to N_ROUNDS rounds of STEPS_PER_ROUND steps.
# Every round restarts AdamW from the current vectors of the still-failing rows.
N_ROUNDS = 100
STEPS_PER_ROUND = 100

solved_vectors = {}    # target index -> optimised vector (CPU) that passes the check
unsolved_indices = []  # target indices still failing after N_ROUNDS rounds

for batch_ids in ind_batched_list:
    ind_list = np.asarray(batch_ids)
    x_optim = None
    with tqdm(range(N_ROUNDS)) as pbar:
        for _ in pbar:
            loss, x_optim, self_emb, mask, is_good_embed = train_vectors(
                ind_list,
                embeddings,
                x_optim_start=x_optim,
                n_steps=STEPS_PER_ROUND,
                verbose=False,
                use_tqdm=False,
            )
            good = torch.as_tensor(is_good_embed, device=x_optim.device)

            # Keep the vectors that now pass; previously they were discarded.
            solved_ids = ind_list[is_good_embed].tolist()
            for token_id, vec in zip(solved_ids, x_optim[good].detach().cpu()):
                solved_vectors[token_id] = vec

            # Continue only with the rows that still fail.
            ind_list = ind_list[~is_good_embed]
            x_optim = x_optim[~good].detach()
            pbar.set_postfix_str(f"Bad embeds: {len(ind_list)}/{len(batch_ids)}")

            if len(ind_list) == 0:
                break
    unsolved_indices.extend(ind_list.tolist())

print(f"Solved {len(solved_vectors)}/{len(n_lst)} embeddings; {len(unsolved_indices)} still failing")

Number of bad embeddings = 8913, dimension = 896, dict size = 151936


  6%|▌         | 6/100 [00:05<01:26,  1.08it/s, Bad embeds: 0/1000]  
 33%|███▎      | 33/100 [00:20<00:40,  1.64it/s, Bad embeds: 0/1000] 
 21%|██        | 21/100 [00:14<00:54,  1.45it/s, Bad embeds: 0/1000] 
 22%|██▏       | 22/100 [00:14<00:50,  1.53it/s, Bad embeds: 0/1000] 
 71%|███████   | 71/100 [00:30<00:12,  2.36it/s, Bad embeds: 0/1000] 
 11%|█         | 11/100 [00:07<00:58,  1.51it/s, Bad embeds: 0/1000]
 27%|██▋       | 27/100 [00:13<00:35,  2.03it/s, Bad embeds: 0/1000]
  4%|▍         | 4/100 [00:05<02:00,  1.25s/it, Bad embeds: 0/1000]  
 27%|██▋       | 27/100 [00:15<00:41,  1.75it/s, Bad embeds: 0/913] 


In [102]:
# NOTE(review): leftover scratch cell (just evaluates 1+1); unrelated to the
# analysis and a candidate for deletion.
1+1

2