In [1]:
import torch

from emb_vectors_functions import (
    find_self_embeds,
    get_model_and_embed,
    get_shadow_ratios,
)

In [2]:
model_name = "meta-llama/Llama-3.2-1B"  # "meta-llama/Meta-Llama-3.1-8B"  # "meta-llama/Meta-Llama-3.1-8B" , "gpt2", "meta-llama/Llama-2-7b-hf"
model, embedding, head, model_norm, mean_norm, tokenizer = get_model_and_embed(
    model_name
)

del model, embedding
embeddings = head
embeddings.requires_grad = False

In [3]:
fail_indices, failed_res_emb, failed_pairs = find_self_embeds(embeddings, tokenizer)

100%|██████████| 13/13 [00:02<00:00,  6.31it/s]


In [4]:
shadow_ratios = get_shadow_ratios(fail_indices, embeddings)
shadow_ratios_sorted = sorted(shadow_ratios, key=lambda x: x[1], reverse=True);

In [75]:
n = 124
embeddings_cut = embeddings  # embeddings[:10000]
embeddings_others = torch.cat((embeddings_cut[:n], embeddings_cut[n + 1 :]), dim=0)
self_emb = embeddings_cut[n]

A = self_emb - embeddings_others
A.requires_grad = False
torch.all(A @ self_emb > 0)

embeddings_others = embeddings_others.detach()
self_emb = self_emb.detach()
A = A.detach()
x = self_emb.detach().clone()
x.requires_grad = True
optimizer = torch.optim.SGD([x], lr=0.01)
optimizer = torch.optim.AdamW([x], lr=0.01)

epsilon = 1e-4

with torch.no_grad():
    loss = torch.sum(torch.relu(-A @ x + epsilon))
    print(f"Initial loss = {loss}")

for step in range(1000):
    optimizer.zero_grad()
    loss = torch.sum(torch.relu(-A @ x + epsilon))
    loss.backward()
    optimizer.step()

    if (step + 1) % 100 == 0:
        print(f"Step {step+1}, Loss: {loss.item()}")
    if loss.item() <= 0:
        break

print("Optimization finished. Final x:", loss.item())
print(f"Target embedding - {torch.all(A@x>0).item()}, min = {torch.min(A@x)}")

Initial loss = 6.937034606933594
Step 100, Loss: 0.00045663019409403205
Step 200, Loss: 0.0001864949445007369
Step 300, Loss: 0.00011208559590158984
Step 400, Loss: 7.662771531613544e-05
Step 500, Loss: 6.027078052284196e-05
Step 600, Loss: 3.7434889236465096e-05
Step 700, Loss: 1.3979071809444577e-05
Step 800, Loss: 4.358837031759322e-06
Step 900, Loss: 0.0
Optimization finished. Final x: 0.0
Target embedding - True, min = 0.00010008066601585597


In [67]:
def calc_loss(x, self_emb, X, mask, epsilon = 1e-4):

    xself = torch.einsum('ij,ij->i', x, self_emb)
    xX = x@X.T

    xA = xself[:, None]-xX
    xA = xA*mask
    loss = torch.sum(torch.relu(-xA + epsilon))

    return loss

def train_vectors(n_lst, embeddings, n_steps=100, verbose=False):
    X = embeddings
    self_emb = X[n_lst]
    mask = torch.ones((len(n_lst), len(X)), requires_grad=False, device=X.device)
    indices = torch.arange(len(n_lst))
    mask[indices, n_lst] = 0

    x_optim = self_emb.detach().clone()
    x_optim.requires_grad = True
    optimizer = torch.optim.AdamW([x_optim], lr=0.01)

    with torch.no_grad():
        loss = calc_loss(x_optim, self_emb, X, mask)
        print(f"Initial loss = {loss.item()}")

    for step in range(n_steps):
        optimizer.zero_grad()

        loss = calc_loss(x_optim, self_emb, X, mask)
        loss.backward()
        optimizer.step()
        if verbose and (step+1) % 100 == 0:
            print(f"Step {step + 1}, Loss: {loss.item()}")

    with torch.no_grad():
        loss = calc_loss(x_optim, self_emb, X, mask)
        print(f"Final loss = {loss.item()}")

    return loss, x_optim, self_emb, mask

In [68]:
n_lst = fail_indices[:4]

In [69]:
loss, x_optim, self_emb, mask = train_vectors(n_lst, embeddings, n_steps=1000);

Initial loss = 27.730876922607422
Final loss = 0.0004165483987890184


In [70]:
xself = torch.einsum('ij,ij->i', x_optim, self_emb)
xX = x_optim@embeddings.T
xA = xself[:, None]-xX
xA1 = xA + 1e10*(1-mask)

In [79]:
xqq = torch.min(xA1, dim=1)[0]
xqq

tensor([9.3132e-05, 9.7319e-05, 1.0000e-04, 9.3073e-05], device='cuda:0',
       grad_fn=<MinBackward0>)

In [93]:
for n, x in zip(n_lst, x_optim):
    embeddings_others = torch.cat((embeddings[:n], embeddings[n + 1 :]), dim=0)
    x0 = embeddings[n]
    A = x0 - embeddings_others
    print(f"Target embedding - {torch.all(A@x>0).item()}, min = {torch.min(A@x)}")

Target embedding - True, min = 9.319287346443161e-05
Target embedding - True, min = 9.72848865785636e-05
Target embedding - True, min = 0.00010005751391872764
Target embedding - True, min = 9.309838060289621e-05


In [85]:

    # if loss.item() <= 0:
    #     break

# print("Optimization finished. Final loss:", loss.item())
#
# # Check final condition for each embedding in X
# with torch.no_grad():
#     target_embeddings_check = torch.all(torch.matmul(A, x_init.unsqueeze(2)).squeeze(2) > 0, dim=1)
#     min_values = torch.min(torch.matmul(A, x_init.unsqueeze(2)).squeeze(2), dim=1).values
#
# print(f"Target embeddings satisfied for all n? - {torch.all(target_embeddings_check).item()}")
# print(f"Minimum values per embedding: {min_values}")

In [36]:
n, d = 5, 6  # Example sizes
self_emb = torch.randn(n, d)
n_lst = [1, 3, 0, 4, 2]

mask = torch.ones_like(self_emb, requires_grad=False)

indices = torch.arange(len(n_lst))  # [0, 1, 2, 3, 4]
mask[indices, n_lst] = 0

print("Generated mask:")
print(mask)

Generated mask:
tensor([[1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1.],
        [0., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 1.],
        [1., 1., 0., 1., 1., 1.]])
