In [1]:
import torch
from tqdm import tqdm
import numpy as np
import gc

from transformers import AutoTokenizer, AutoModelForCausalLM
from emb_vectors_functions import find_self_embeds, get_shadow_ratios
from model_loading import get_weight_by_name
from find_bed_embed_min import train_vectors, batch_list

In [4]:
# "meta-llama/Llama-3.1-70B", "meta-llama/Llama-3.1-8B"
# "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-7B-Instruct"
# "Qwen/Qwen2.5-14B-Instruct", "Qwen/Qwen2.5-32B-Instruct", "Qwen/Qwen2.5-72B-Instruct"
model_name = "meta-llama/Llama-3.1-70B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weight registered under the name "head" for this checkpoint and move
# it to the GPU in a single expression.
embeddings = get_weight_by_name(model_name, "head").cuda()
# embeddings = torch.randn_like(embeddings)  # uncomment to use same-shaped Gaussian noise instead
# Gradient tracking is switched off on the loaded matrix.
embeddings.requires_grad = False

Fetching 50 files:   0%|          | 0/50 [00:00<?, ?it/s]

In [5]:
# `find_self_embeds` returns three values; only `fail_indices` is consumed by
# the visible cells below (batching and the summary print).
self_embed_report = find_self_embeds(embeddings, tokenizer)
fail_indices, failed_res_emb, failed_pairs = self_embed_report
print(f"Number of bad embeddings = {len(fail_indices)}")

100%|██████████| 13/13 [00:02<00:00,  5.98it/s]


Number of bad embeddings = 1060


In [6]:
# Disabled diagnostic: compute a ratio per failing index with `get_shadow_ratios`
# and order the (index, ratio) pairs by the ratio, largest first.
# shadow_ratios = get_shadow_ratios(fail_indices, embeddings)
# shadow_ratios_sorted = sorted(shadow_ratios, key=lambda x: x[1], reverse=True);

In [7]:
# Disabled one-off run of `train_vectors` with verbose output.
# NOTE(review): `n_lst` is not defined in any visible cell — define it before re-enabling.
# loss, x_optim, self_emb, mask, bad_indices = train_vectors(n_lst, embeddings, n_steps=100, verbose=True)

In [8]:
# Split the failing indices into chunks that are optimised independently below.
# The recorded run (1060 indices) produced chunks of 1000 and 60 entries.
indices_per_batch = 1000
ind_batched_list = batch_list(fail_indices, indices_per_batch)
# ind_batched_list = [[183]]  # debug: restrict the run to the single index 183

In [9]:
# Iteratively re-optimise the vectors for the failing indices, one batch at a
# time. After every `train_vectors` call the entries it flags as good are
# dropped, and the surviving vectors are fed back in as the starting point of
# the next round. A batch finishes when nothing is left or the round budget is
# exhausted.
MAX_ROUNDS = 300  # upper bound on `train_vectors` calls per batch
STEPS_PER_ROUND = 200  # forwarded as `n_steps` on every call

print(
    f"Number of bad embeddings = {len(fail_indices)}, dimension = {embeddings.shape[1]}, dict size = {embeddings.shape[0]}"
)

for ind_list_ in ind_batched_list:
    x_optim = None  # the first round of each batch has no warm start
    ind_list = ind_list_
    # The context manager closes the progress bar even when the loop is left
    # early via `break`, so the next batch starts with a clean bar.
    with tqdm(range(MAX_ROUNDS)) as pbar:
        # Not named `_`: assigning to `_` would shadow IPython's last-output variable.
        for round_idx in pbar:
            loss, x_optim, self_emb, mask, is_good_embed = train_vectors(
                ind_list,
                embeddings,
                x_optim_start=x_optim,
                n_steps=STEPS_PER_ROUND,
                verbose=False,
                use_tqdm=False,
            )
            # Boolean mask of the entries that still need work; computed once
            # and applied to both the index list and the optimised vectors so
            # the two stay aligned.
            still_bad = ~is_good_embed
            ind_list = np.asarray(ind_list)[still_bad]
            x_optim = x_optim[still_bad]
            pbar.set_postfix_str(f"Bad embeds: {len(ind_list)}/{len(ind_list_)}")

            if len(ind_list) == 0:
                break

Number of bad embeddings = 1060, dimension = 8192, dict size = 128256


 10%|▉         | 29/300 [00:44<06:56,  1.54s/it, Bad embeds: 0/1000] 
  5%|▌         | 16/300 [00:19<05:42,  1.21s/it, Bad embeds: 0/60] 


In [95]:
# Sample random unit-norm Gaussian directions, score each one against every row
# of `embeddings`, and record which row attains the largest score. The
# collected winners are counted in the following cells.
m = 30000  # random vectors drawn per iteration
d = embeddings.shape[1]
num_iterations = 1000
argmax_results = []

# Dedicated seeded generator so that re-running the cell reproduces the same
# draws without touching the global RNG state.
sampling_seed = 0
sampling_rng = torch.Generator(device=embeddings.device)
sampling_rng.manual_seed(sampling_seed)

# No gradients are needed for pure sampling/scoring.
with torch.no_grad():
    for iteration in tqdm(range(num_iterations)):
        # Device and dtype follow `embeddings` so the matmul below is valid
        # regardless of how the weights were loaded.
        X = torch.randn(
            m,
            d,
            device=embeddings.device,
            dtype=embeddings.dtype,
            generator=sampling_rng,
        )
        X = X / torch.norm(X, dim=1, keepdim=True)
        logits = X @ embeddings.T  # (m, n)
        argmax_indices = logits.argmax(dim=1)  # (m,)
        argmax_results.append(argmax_indices.cpu())
        # Drop the large buffers before the next iteration allocates new ones;
        # the caching allocator then reuses the freed blocks. Calling
        # torch.cuda.empty_cache() here would force a fresh device allocation
        # on every iteration, so it is done once after the loop instead.
        del X, logits

argmax_results = torch.cat(argmax_results)

torch.cuda.empty_cache()
gc.collect()

100%|██████████| 1000/1000 [05:25<00:00,  3.08it/s]


751

In [96]:
# Distinct winning row indices and how many sampled vectors selected each one.
unique_indices, counts = argmax_results.unique(return_counts=True)

In [None]:
# Order the winners from most to least frequently selected and reorder the
# corresponding row indices with the same permutation.
sorted_counts, sorted_idx = torch.sort(counts, descending=True)
sorted_indices = unique_indices.index_select(0, sorted_idx)