In [8]:
import torch
from tqdm import tqdm
import gc

from transformers import AutoTokenizer
from emb_vectors_functions import find_self_embeds, get_shadow_ratios
from model_loading import get_weight_by_name


In [5]:
model_name = "meta-llama/Llama-3.1-70B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prefer the GPU but fall back to CPU, so the notebook still runs on a machine
# without CUDA instead of failing on an unconditional `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# "head" is the weight name handed to the project helper -- presumably the
# LM output-head matrix; confirm against model_loading.get_weight_by_name.
embeddings = get_weight_by_name(model_name, "head").to(device)

Fetching 50 files:   0%|          | 0/50 [00:00<?, ?it/s]

In [6]:
# Uncomment to swap the weights for Gaussian noise of the same shape/dtype/device.
# embeddings = torch.randn_like(embeddings)
# Cut the matrix out of any autograd graph. Unlike assigning
# `requires_grad = False`, `detach()` also works when the tensor is a non-leaf
# (e.g. the result of moving a grad-requiring parameter to another device),
# where the flag assignment raises a RuntimeError.
embeddings = embeddings.detach()

In [7]:
# Run the project helper over the weight matrix and tokenizer. It returns three
# values; `fail_indices` is later used to index rows of `embeddings`.
# NOTE(review): the exact meaning of `failed_res_emb` / `failed_pairs` is defined
# in emb_vectors_functions.find_self_embeds -- confirm there.
fail_indices, failed_res_emb, failed_pairs = find_self_embeds(embeddings, tokenizer)

100%|██████████| 13/13 [00:01<00:00,  6.94it/s]


In [10]:
# Collect unreachable Python objects first so the tensors they hold are freed,
# then hand the now-unused cached blocks back to the CUDA driver. Doing it in
# the opposite order leaves memory freed by the collector stuck in the
# caching allocator.
gc.collect()
torch.cuda.empty_cache()

0

In [11]:
# Compute the ratios for the failing indices with the project helper, then
# build a copy ordered by the second element of each entry, largest first.
shadow_ratios = get_shadow_ratios(fail_indices, embeddings)
shadow_ratios_sorted = list(shadow_ratios)
shadow_ratios_sorted.sort(key=lambda entry: entry[1], reverse=True)

In [24]:
def calc_loss(x, self_emb, X, mask, epsilon = 1e-4):
    """Hinge (margin) loss pushing each row's own score above all other scores.

    For row ``i`` the "own" score is ``x[i] . self_emb[i]`` and the competing
    scores are ``x[i] . X[j]`` for every row ``j`` of ``X``. Each pair whose
    margin ``own - competing`` falls below ``epsilon`` contributes
    ``epsilon - margin`` to the loss.

    Args:
        x: 2-D tensor of vectors being optimised, one per row.
        self_emb: 2-D tensor with the same shape as ``x``; row ``i`` is the
            target vector paired with ``x[i]``.
        X: 2-D tensor of all competing vectors (same feature size as ``x``).
        mask: tensor broadcastable to ``(len(x), len(X))``; entries equal to 0
            are excluded from the loss, entries equal to 1 are kept.
        epsilon: required margin.

    Returns:
        Scalar tensor with the summed hinge penalties over unmasked pairs.
    """
    xself = torch.einsum('ij,ij->i', x, self_emb)
    xX = x@X.T

    # Margin of the own score over every competing score, shape (len(x), len(X)).
    xA = xself[:, None]-xX
    # Apply the mask AFTER the hinge: masking the margin first would turn every
    # excluded entry into relu(epsilon) == epsilon, adding a constant floor to
    # the loss so that it could never reach zero.
    loss = torch.sum(torch.relu(-xA + epsilon) * mask)

    return loss

def calc_bad_embeds(x_optim, self_emb, embeddings, mask):
    """Count rows whose own score does not strictly beat every unmasked score.

    A row ``i`` is "good" when ``x_optim[i] . self_emb[i]`` is strictly greater
    than ``x_optim[i] . embeddings[j]`` for every ``j`` with ``mask[i, j] == 1``.

    Args:
        x_optim: 2-D tensor of candidate vectors, one per row.
        self_emb: 2-D tensor with the same shape as ``x_optim``.
        embeddings: 2-D tensor of all competing vectors.
        mask: tensor broadcastable to ``(len(x_optim), len(embeddings))`` with
            0 at entries to ignore and 1 elsewhere.

    Returns:
        ``(bad_embeds, bad_embeds_ratio)``: the number of rows that are not
        good and that number divided by the row count. For zero rows the
        result is ``(0, 0.0)``.
    """
    xself = torch.einsum('ij,ij->i', x_optim, self_emb)
    xX = x_optim@embeddings.T
    xA = xself[:, None]-xX
    # Push ignored entries far above zero so they can never fail the check.
    xA = xA + 1e10*(1-mask)
    is_good_embed = torch.all(xA>0, dim=1)

    total = len(is_good_embed)
    if total == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0, 0.0

    # Reduce on the tensor's device in one call instead of iterating over the
    # elements with Python's builtin sum().
    bad_embeds = total - int(is_good_embed.sum().item())
    bad_embeds_ratio = bad_embeds/total

    return bad_embeds, bad_embeds_ratio

def train_vectors(n_lst, embeddings, n_steps=100, verbose=False, lr=0.01, check_every=10):
    """Optimise one vector per index in ``n_lst`` with AdamW on ``calc_loss``.

    Each optimised vector starts as a copy of ``embeddings[n_lst[i]]`` and is
    trained so that its dot product with that row exceeds its dot product with
    every other row of ``embeddings``. Progress is measured with
    ``calc_bad_embeds`` and training stops early once no bad rows remain.

    Args:
        n_lst: indices into the rows of ``embeddings`` to train vectors for.
        embeddings: 2-D tensor of all vectors; it is not modified.
        n_steps: maximum number of optimiser steps.
        verbose: if true, print loss and bad-row counts at every evaluation.
        lr: AdamW learning rate (AdamW's default weight decay stays in effect).
        check_every: evaluate bad rows (and possibly stop) every this many steps.

    Returns:
        ``(loss, x_optim, self_emb, mask)``: the final loss tensor, the
        optimised vectors (still requiring grad), the target rows
        ``embeddings[n_lst]``, and the 0/1 mask that excludes each row's own
        index.

    Raises:
        ValueError: if ``check_every`` is smaller than 1.
    """
    if check_every < 1:
        raise ValueError("check_every must be a positive integer")

    X = embeddings
    self_emb = X[n_lst]
    # mask[i, j] == 0 only where j is row i's own index, so a vector is never
    # compared against its own target.
    mask = torch.ones((len(n_lst), len(X)), requires_grad=False, device=X.device)
    indices = torch.arange(len(n_lst))
    mask[indices, n_lst] = 0

    # Independent trainable copy; the original matrix stays untouched.
    x_optim = self_emb.detach().clone()
    x_optim.requires_grad = True
    optimizer = torch.optim.AdamW([x_optim], lr=lr)

    with torch.no_grad():
        loss = calc_loss(x_optim, self_emb, X, mask)
        bad_embeds, bad_embeds_ratio = calc_bad_embeds(x_optim, self_emb, embeddings, mask)
    print(f"Initial\nloss = {loss.item()}")
    print(f"Bad embeds = {bad_embeds}/{len(n_lst)}, ratio = {bad_embeds_ratio}")

    # Seed the running minimum with the starting point so it is meaningful even
    # when fewer than `check_every` steps are executed.
    min_bad = bad_embeds

    # Tracked explicitly: the loop variable would be undefined for n_steps == 0.
    steps_done = 0
    pbar = tqdm(range(n_steps))
    for step in pbar:
        optimizer.zero_grad()

        loss = calc_loss(x_optim, self_emb, X, mask)
        loss.backward()
        optimizer.step()
        steps_done = step + 1

        if steps_done % check_every == 0:
            with torch.no_grad():
                bad_embeds, bad_embeds_ratio = calc_bad_embeds(x_optim, self_emb, embeddings, mask)
            pbar.set_postfix_str(f"Bad embeds: {bad_embeds}/{len(n_lst)}")
            min_bad = min(min_bad, bad_embeds)
            if bad_embeds_ratio == 0.0:
                break
            if verbose:
                print(f"Step {step + 1}, Loss: {loss.item()}, bad embeds = {bad_embeds}/{len(n_lst)}, ratio = {bad_embeds_ratio}")

    with torch.no_grad():
        loss = calc_loss(x_optim, self_emb, X, mask)
        bad_embeds, bad_embeds_ratio = calc_bad_embeds(x_optim, self_emb, embeddings, mask)
    min_bad = min(min_bad, bad_embeds)
    print("Final")
    print(f"steps = {steps_done}, loss = {loss.item()}")
    print(f"Bad embeds = {bad_embeds}/{len(n_lst)}, ratio = {bad_embeds_ratio:.4f}, Minimal bad = {min_bad}")

    return loss, x_optim, self_emb, mask

In [25]:
# Train vectors for every index returned by find_self_embeds as `fail_indices`.
n_lst = fail_indices

In [None]:
# Long-running cell: up to 100000 optimiser steps, stopping early once the
# periodic check inside train_vectors reports zero bad embeddings.
loss, x_optim, self_emb, mask = train_vectors(n_lst, embeddings, n_steps=100000)

Initial
loss = 19570.24609375
Bad embeds = 1051/1060, ratio = 0.9915094339622641


  0%|          | 230/100000 [00:09<1:12:00, 23.09it/s, Bad embeds: 321/1060]