In [2]:
!nvidia-smi

Fri Dec 20 18:46:05 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.100                Driver Version: 550.100        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  |   00000000:C3:00.0 Off |                  N/A |
| 53%   42C    P5            101W /  350W |       1MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
conda install -c conda-forge -c nvidia -c rapidsai-nightly cuvs=25.02

Channels:
 - conda-forge
 - nvidia
 - rapidsai-nightly
Platform: linux-64
doneecting package metadata (repodata.json): - 
doneing environment: \ 

# All requested packages already installed.


Note: you may need to restart the kernel to use updated packages.


In [3]:
!pip install transformers
!pip install hnswlib
!pip install cupy-cuda12x

Collecting transformers
  Downloading transformers-4.47.1-py3-none-any.whl.metadata (44 kB)
Collecting huggingface-hub<1.0,>=0.24.0 (from transformers)
  Downloading huggingface_hub-0.27.0-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting safetensors>=0.4.1 (from transformers)
  Downloading safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Downloading transformers-4.47.1-py3-none-any.whl (10.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m180.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading huggingface_hub-0.27.0-py3-none-any.whl (450 kB)
Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x

In [18]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Model
import hnswlib
import torch
import torch.nn.functional as F
import timeit
import time
import functools
import os
from cuvs.neighbors import cagra
import cupy as cp
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
# model_headless = GPT2Model.from_pretrained("gpt2").to(device)

# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B").to(device)

input_text = "Hello, I'm an"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

output = model.generate(
    input_ids,
    do_sample=True,
    max_length=50,
    return_dict_in_generate=True,
    output_hidden_states=True,
    output_scores=True,
    output_logits=True
)

gen_tokens = output.sequences

gen_text = tokenizer.batch_decode(gen_tokens)[0]
print(gen_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, I'm an engineer and one of the lead development engineers of Evernote, I am so happy that you are one of the people asking for my help. You can find me on Twitter, if not contact me at twitter.com/


In [20]:
hidden = output.hidden_states
logits = output.logits

last_hidden = hidden[-1][-1].squeeze(0)
last_logits = logits[-1]

print(len(hidden), len(logits), len(hidden[-1]))
print(last_hidden.shape)
print(last_logits.shape)

45 45 13
torch.Size([1, 768])
torch.Size([1, 50257])


### HNSW Index Embedding Layer

In [21]:
class HNSWIndexEmbedding():
    def __init__(self, weight, k, M=32, ef=100, ef_construction=100, index_file=None):
        self.k, self.vocab_size, self.dim = k, weight.shape[0], weight.shape[1]
        self.index = hnswlib.Index(space='ip', dim=self.dim)
        self.index.init_index(max_elements=self.vocab_size, M=M, ef_construction=ef_construction, random_seed=42)
        self.index.set_ef(ef)

        if index_file is None or not os.path.exists(index_file):
            self.index.add_items(weight.numpy())
            if index_file is not None:
                self.index.save_index(index_file)
        else:
            print(f"Loading index from file: {index_file}")
            self.index.load_index(index_file)

    def forward(self, x):
        indices, distances = self.index.knn_query(x.detach().cpu().numpy(), k=self.k)
        return torch.from_numpy(1 - distances).to(torch.float32).to(x.device), torch.from_numpy(indices).to(torch.int64).to(x.device)

In [57]:
class CagraIndexEmbedding(nn.Module):
    def __init__(self, weight, k, intermediate_graph_degree=128, graph_degree=32, index_file=None):
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.k, self.vocab_size, self.dim = k, weight.shape[0], weight.shape[1]
        self.params = cagra.IndexParams(metric="inner_product", intermediate_graph_degree=intermediate_graph_degree, graph_degree=graph_degree)
        self.search_params = cagra.SearchParams(itopk_size=100, max_iterations=24) # (itopk_size=64, max_iterations=40)

        index_path = f"{index_file}-{intermediate_graph_degree}-{graph_degree}.index"
        if index_file is None or not os.path.exists(index_path):
            self.index = cagra.build(self.params, weight)
            if index_file is not None:
                cagra.save(index_path, self.index)
        else:
            print(f"Loading index from file: {index_path}")
            self.index = cagra.load(index_path)

    
    def forward_index(self, x):
        n = x.shape[0]
        distances = cp.empty((n, self.k), dtype="float32")
        indices = cp.empty((n, self.k), dtype="uint32")
        cagra.search(self.search_params, self.index, x, self.k, indices, distances)
        distances = torch.as_tensor(distances, device=x.device)
        indices = torch.as_tensor(indices, device=x.device)
        return distances, indices

    
    def forward(self, x, targets=None):
        distances, indices = self.forward_index(x)
        indices = indices.to(torch.long)
        loss = None
        if targets is not None:
            print(indices.shape, x.shape, targets.shape)
            has_target = torch.any(indices == targets.unsqueeze(-1), dim=-1)
            print(has_target.shape, has_target.dtype, indices[:,-1].shape, torch.where(has_target, targets, indices[:,-1]).shape)
            indices[:,-1] = torch.where(has_target, targets, indices[:,-1])
            print(indices.shape, x.shape, x.unsqueeze(1).shape, self.weight[indices].shape)
            logits = (x.unsqueeze(1) * self.weight[indices]).sum(-1)
            print(logits.shape)
            loss = F.cross_entropy(logits, targets.view(-1))
        return logits, indices, loss


    def update_index(self):
        self.index = cagra.build(self.params, self.weight.weight)


k = 100

out_emb_weight = model.transformer.wte.weight.detach().clone()
# out_emb_weight = model.get_input_embeddings().weight.detach().clone()
print(out_emb_weight.shape)

out_emb_vector = CagraIndexEmbedding(out_emb_weight, k=k, index_file="index-gpt2")

print(last_hidden.shape)

targets = torch.tensor([42, 34], dtype=torch.long, device=device)

out_logits, out_indices, loss = out_emb_vector.forward(last_hidden.repeat(2,1), targets)
out_logits, out_indices, loss

torch.Size([50257, 768])
Loading index from file: index-gpt2-128-32.index
torch.Size([1, 768])
torch.Size([2, 100]) torch.Size([2, 768]) torch.Size([2])
torch.Size([2]) torch.bool torch.Size([2]) torch.Size([2])
torch.Size([2, 100]) torch.Size([2, 768]) torch.Size([2, 1, 768]) torch.Size([2, 100, 768])
torch.Size([2, 100])


(tensor([[-29.2514, -33.3873, -34.9027, -35.8461, -35.9987, -36.8335, -37.0154,
          -37.3537, -37.3606, -37.4043, -37.7783, -37.9624, -38.1702, -38.2616,
          -38.4445, -38.6041, -38.7200, -38.7357, -38.7809, -38.8268, -38.9144,
          -38.9363, -38.9747, -39.2869, -39.3602, -39.4034, -39.5476, -39.6078,
          -39.6367, -39.6784, -39.7424, -39.7868, -39.8291, -39.8703, -39.8999,
          -39.9018, -39.9133, -39.9230, -40.0153, -40.0275, -40.0284, -40.0705,
          -40.1092, -40.1202, -40.1393, -40.1484, -40.1511, -40.1647, -40.1749,
          -40.1963, -40.1994, -40.2076, -40.2910, -40.3538, -40.3635, -40.3850,
          -40.4385, -40.4613, -40.4713, -40.4947, -40.5126, -40.6431, -40.6517,
          -40.6589, -40.6795, -40.7312, -40.7623, -40.7746, -40.7873, -40.7892,
          -40.8075, -40.8138, -40.8382, -40.8776, -40.9042, -40.9176, -40.9375,
          -41.0038, -41.0218, -41.0642, -41.0654, -41.0757, -41.0835, -41.0863,
          -41.0901, -41.0911, -41.0960, 

In [26]:
k = 100

out_emb_weight = model.transformer.wte.weight.detach().clone()
# out_emb_weight = model.get_input_embeddings().weight.detach().clone()
print(out_emb_weight.shape)

torch.Size([50257, 768])


In [27]:
out_emb_vector = DifferentialCagraLayer(CagraIndexEmbedding(out_emb_weight, k=k, index_file="index-gpt2"))

Loading index from file: index-gpt2-128-32.index


In [28]:
out_logits, out_indices = out_emb_vector.forward(last_hidden)
out_logits, out_indices

AttributeError: 'bool' object has no attribute 'sum'

In [9]:
ref_out = last_hidden @ out_emb_weight.T
ref_logits, ref_indices = torch.topk(ref_out, k=k)
print(torch.isin(out_indices.to(torch.long)[0, 0:100], ref_indices[0, 0:100]).sum())
ref_logits, ref_indices

tensor(81, device='cuda:0')


(tensor([[-126.6614, -127.1938, -128.6376, -129.1087, -129.2057, -129.2185,
          -129.3537, -129.3808, -129.5098, -129.7573, -129.8551, -129.9439,
          -130.0248, -130.0328, -130.3503, -130.4384, -130.5805, -130.5928,
          -130.6030, -130.6158, -130.6413, -130.7489, -130.8294, -130.9059,
          -130.9340, -130.9808, -131.0147, -131.0730, -131.1241, -131.1581,
          -131.2195, -131.2692, -131.3570, -131.3711, -131.4168, -131.5092,
          -131.5485, -131.6041, -131.7452, -131.8310, -131.9424, -131.9483,
          -131.9522, -131.9538, -131.9659, -132.0049, -132.0702, -132.0872,
          -132.0903, -132.1017, -132.1630, -132.2622, -132.2875, -132.3112,
          -132.3194, -132.3295, -132.3509, -132.3565, -132.3696, -132.3816,
          -132.5125, -132.5275, -132.6001, -132.6261, -132.6556, -132.6936,
          -132.7185, -132.7505, -132.7521, -132.7674, -132.7725, -132.8896,
          -132.9072, -132.9092, -132.9217, -132.9271, -132.9387, -132.9417,
          -1

In [14]:
# out_emb_hnsw = HNSWIndexEmbedding(out_emb_weight.cpu(), k=k, ef_construction=150, index_file="./hnsw.index")

In [15]:
# out_logits, out_indices = out_emb_hnsw.forward(last_hidden)
# out_logits, out_indices

In [16]:
# out_emb_vector.index.set_ef(100)

last_hidden_repeat = last_hidden.repeat(16 * 1024, 1)
print(last_hidden_repeat.shape)

def time_forward():
    ol, oi = out_emb_vector.forward(last_hidden_repeat)
    ol = F.softmax(ol, dim=-1)
    torch.cuda.synchronize()
    del ol, oi

def time_ref():
    out = last_hidden_repeat @ out_emb_weight.T
    out = F.softmax(out, dim=-1)
    # out_topk = torch.topk(out, k=50)
    torch.cuda.synchronize() 
    del out

forward_time = timeit.timeit(time_forward, number=10)
forward_ref_time = timeit.timeit(time_ref, number=10)

print(f"Average time taken (forward): {forward_time:.6f} seconds")
print(f"Average time taken (matrix multiplication): {forward_ref_time:.6f} seconds")
print(f"Speedup: {forward_ref_time / forward_time:.4f}")

torch.Size([16384, 768])
Average time taken (forward): 0.167872 seconds
Average time taken (matrix multiplication): 0.764478 seconds
Speedup: 4.5539


### Generate & Prediction Similarity Measurement

In [17]:
positions = range(len(hidden))
# positions = [6]
k_options = [100] # [1, 3, 5, 10, 50]

for pos in positions:
    last_layer_hidden = hidden[pos][-1].squeeze(0)[0,:]
    last_layer_logits = logits[pos].squeeze(0)
    position_topk_indices = torch.topk(last_layer_logits, k)[1]

    exp_logits = torch.exp(last_layer_logits.to(torch.float64))

    token_id = position_topk_indices[0]
    token = tokenizer.decode(token_id)
    out_logits, out_indices = out_emb_vector.forward(last_layer_hidden.unsqueeze(0))
    for j in k_options:
        subset = position_topk_indices[0:j]
        common_indices = subset[torch.isin(subset, out_indices.to(torch.long))]

        # common_indices = out_indices.to(torch.long)[0]

        exp_logits_all = exp_logits[position_topk_indices.squeeze(0)[0:j]].sum()
        exp_logits_common = exp_logits[common_indices].sum()
        # Add "label" if missing, only doable in training
        if token_id.item() not in common_indices.tolist():
            exp_logits_common += exp_logits[token_id]

        logits_percentage = exp_logits_common / exp_logits_all
        color = "\033[33m" if logits_percentage < 0.95 else ""

        print(f"{color}{pos}: {len(common_indices)}/{j} ({len(common_indices)/j:0.2f}), logits: {logits_percentage:0.4f}, {token} {token_id}\033[0m")
    if len(k_options) > 1: print()

[33m0: 2/100 (0.02), logits: 0.0817,  old 1468[0m
1: 87/100 (0.87), logits: 0.9950,  on 319[0m
[33m2: 83/100 (0.83), logits: 0.9295,  the 262[0m
3: 91/100 (0.91), logits: 0.9519,  subject 2426[0m
4: 84/100 (0.84), logits: 0.9964,  of 286[0m
[33m5: 87/100 (0.87), logits: 0.9067,  gaming 7776[0m
6: 88/100 (0.88), logits: 0.9880, . 13[0m
7: 97/100 (0.97), logits: 0.9935,  I 314[0m
8: 79/100 (0.79), logits: 0.9999, 
 198[0m
[33m9: 74/100 (0.74), logits: 0.9358, I 40[0m
10: 74/100 (0.74), logits: 0.9598, 'm 1101[0m
11: 95/100 (0.95), logits: 0.9926,  a 257[0m
[33m12: 85/100 (0.85), logits: 0.8909,  big 1263[0m
[33m13: 81/100 (0.81), logits: 0.9462,  expert 5887[0m
14: 92/100 (0.92), logits: 0.9922, , 11[0m
15: 79/100 (0.79), logits: 0.9645,  the 262[0m
[33m16: 75/100 (0.75), logits: 0.9020,  of 286[0m
17: 96/100 (0.96), logits: 0.9994,  the 262[0m
18: 77/100 (0.77), logits: 0.9690,  most 749[0m
[33m19: 77/100 (0.77), logits: 0.8729,  cryptocurrencies 29760[0m
20:

In [18]:
# Samples simultaneously from the hnsw and the reference(full matrix multiplication) distribution
# and logs all the positions where a different token are smapled.

def double_multinomial(p1, p2):
    assert len(p1.shape) == 1 and len(p2.shape) == 1

    p1 = p1 / p1.sum()
    p2 = p2 / p2.sum()

    p1_cumsum = torch.cumsum(p1, dim=0)
    p2_cumsum = torch.cumsum(p2, dim=0)

    random_number = torch.rand(1).item()

    i1 = torch.searchsorted(p1_cumsum, random_number).item()
    i2 = torch.searchsorted(p2_cumsum, random_number).item()

    return i1, i2, random_number

# out_emb_vector.index.set_ef(100)

n_different_sample = 0
max_length = 64
tokens = tokenizer.encode("Hello, I'm a language model,")
xgen = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

while xgen.size(1) < max_length:
    with torch.no_grad():
        last_hidden_state = model_headless(xgen).last_hidden_state
        last_hidden_state = last_hidden_state[:, -1, :]

        logits_vec, indices_vec = out_emb_vector.forward(last_hidden_state)
        indices_vec = indices_vec.to(torch.long)
        
        logits_ref = last_hidden_state @ out_emb_weight.T
        probs_ref = F.softmax(logits_ref, dim=-1)
        topk_probs_ref, topk_indices_ref = torch.topk(probs_ref, 50, dim=-1)

        exp_logits = F.softmax(logits_vec.to(torch.float64), dim=-1)

        i1, i2, ran = double_multinomial(exp_logits[0, :], topk_probs_ref[0, :])
        i1 = torch.gather(indices_vec[0,:], -1, torch.tensor(i1, device=device))
        i2 = torch.gather(topk_indices_ref[0,:], -1, torch.tensor(i2, device=device))
        xcol = i1.view(1, 1)

        if i1 != i2:
            print(tokenizer.decode(xgen[0, -32:max_length].tolist()), f"'{tokenizer.decode(i1.tolist())}'/'{tokenizer.decode(i2.tolist())}'")
            n_different_sample += 1

        xgen = torch.cat((xgen, xcol), dim=1)

print("\n\nGenerated Text:")
tokens = xgen[0, :max_length].tolist()
print(tokenizer.decode(tokens))

print(f"{n_different_sample/max_length:.4f}")

NameError: name 'model_headless' is not defined

### Generate & Performance Measurement

In [79]:
# Generate text using top-k sampling from a GPT-2 model without the LM head,
# utilizing a vector index to get the top-k elements (or without the index if method=ref)
def generate(method="vec-index", num_return_sequences=4, max_length=64):
    tokens = tokenizer.encode("Hello, I'm a language model,")
    tokens = torch.tensor(tokens, dtype=torch.long)
    xgen = tokens.unsqueeze(0).repeat(num_return_sequences, 1).to(device)
    while xgen.size(1) < max_length:
        with torch.no_grad():
            last_hidden_state = model_headless(xgen).last_hidden_state
            last_hidden_state = last_hidden_state[:, -1, :]

            if method == "vec-index":
                logits, indices = out_emb_vector.forward(last_hidden_state)
                exp_logits = F.softmax(logits, dim=-1)
            else:
                logits_ref = last_hidden_state @ out_emb_weight.T
                probs_ref = F.softmax(logits_ref, dim=-1)
                exp_logits, indices = torch.topk(probs_ref, 50, dim=-1)

            ix = torch.multinomial(exp_logits, 1)
            xcol = torch.gather(indices.to(torch.long), -1, ix)
            xgen = torch.cat((xgen, xcol), dim=1)
    return xgen

# out_emb_vector.index.set_ef(100)

start = time.time()
xgen = generate("vec-index", max_length=128, num_return_sequences=16)
vec_time = time.time() - start
print(f"Vec took {vec_time:.2f} seconds")

start = time.time()
xgen = generate("ref", max_length=128, num_return_sequences=16)
ref_time = time.time() - start
print(f"Ref took {ref_time:.2f} seconds")

print(f"Speedup: {ref_time / vec_time:.2f}")

# for i in range(num_return_sequences):
#    print(tokenizer.decode(xgen[i, :max_length].tolist()))

NameError: name 'model_headless' is not defined

In [22]:
if False:
    data = torch.cat([hidden[i][-1].squeeze(0) for i in range(len(hidden))], dim=0).repeat(6, 1)
    time_repeat, time_num = 10, 10

    print("| B   | ef  | Speedup |")
    print("| --: | --: | ------: |")
    for ef in [100, 200]:
      for B in [1, 8, 54, 256]:
        out_emb_vector.index.set_ef(ef)
        batch = data[0:B, :]

        forward_time = min(timeit.repeat(lambda: out_emb_vector.forward(batch), number=time_num, repeat=time_repeat))
        forward_ref_time = min(timeit.repeat(lambda: batch @ out_emb_weight.T, number=time_num, repeat=time_repeat))

        print(f"|  {B}  | {ef} | {forward_ref_time / forward_time:.1f}x |")