<a href="https://colab.research.google.com/github/martinloretzzz/vector-index-layer/blob/main/VectorIndexLayer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install hnswlib



In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Model
import hnswlib
import torch
import torch.nn.functional as F
import timeit
import time
import functools

In [3]:
model = AutoModelForCausalLM.from_pretrained("gpt2")
model_headless = GPT2Model.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_text = "The quick brown fox"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

output = model.generate(
    input_ids,
    do_sample=True,
    max_length=50,
    return_dict_in_generate=True,
    output_hidden_states=True,
    output_scores=True,
    output_logits=True
)

gen_tokens = output.sequences

gen_text = tokenizer.batch_decode(gen_tokens)[0]
print(gen_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


The quick brown fox-like movements of his arm, which look just like those of a bull and a bear, will make me look like a tiger. But look what he has for me!"

Bobby's voice had become muffled.


In [4]:
hidden = output.hidden_states
logits = output.logits

last_hidden = hidden[-1][-1].squeeze(0)
last_logits = logits[-1]

print(len(hidden), len(logits), len(hidden[-1]))
print(last_hidden.shape)
print(last_logits.shape)

46 46 13
torch.Size([1, 768])
torch.Size([1, 50257])


### HNSW Index Embedding Layer

In [5]:
class HNSWIndexEmbedding():
    def __init__(self, weight, k, M=32, ef=100, ef_construction=100):
        self.k, self.vocab_size, self.dim = k, weight.shape[0], weight.shape[1]
        self.index = hnswlib.Index(space='ip', dim=self.dim)
        self.index.init_index(max_elements=self.vocab_size, M=M, ef_construction=ef_construction, random_seed=42)
        self.index.add_items(weight.numpy())
        self.index.set_ef(ef)

    def forward(self, x):
        indices, distances = self.index.knn_query(x.detach().cpu().numpy(), k=self.k)
        return torch.from_numpy(1 - distances).to(torch.float32).to(x.device), torch.from_numpy(indices).to(torch.int64).to(x.device)

In [6]:
k = 50

out_emb_weight = model.transformer.wte.weight.detach().clone()
print(out_emb_weight.shape)

out_emb_vector = HNSWIndexEmbedding(out_emb_weight, k=k, ef_construction=150)

torch.Size([50257, 768])


In [7]:
out_logits, out_indices = out_emb_vector.forward(last_hidden)
out_logits, out_indices

(tensor([[-106.2140, -106.8267, -106.8412, -107.0850, -107.6768, -108.0547,
          -108.6532, -108.7733, -108.8456, -109.0358, -109.3222, -109.5707,
          -109.6564, -109.9158, -109.9889, -110.0014, -110.0058, -110.0616,
          -110.0690, -110.1507, -110.1909, -110.2026, -110.2058, -110.2102,
          -110.3968, -110.4023, -110.4185, -110.4547, -110.6066, -110.7092,
          -110.7136, -110.7649, -110.9780, -111.0564, -111.1580, -111.1598,
          -111.2218, -111.2827, -111.2983, -111.3177, -111.3209, -111.4116,
          -111.4208, -111.4266, -111.5473, -111.6009, -111.6026, -111.6071,
          -111.6134, -111.6419]]),
 tensor([[  416,    13,    11,   355,   290,   287,   351,   422,   329,   618,
            772,   379,   783,   832,   284,   739,    26,   757,   257,   319,
            625,   780,  2157,   475,   878,  1752,   523,   588,   706,    25,
            477,  2029,  3690,  2048,   656,   981,   517,   655,  1165,   262,
          11061,  1626,   503,  1201,

### Generate & Prediction Similarity Measurement

In [8]:
positions = range(len(hidden))
# positions = [6]
k_options = [50] # [1, 3, 5, 10, 50]

for pos in positions:
    last_layer_hidden = hidden[pos][-1].squeeze(0)[0,:]
    last_layer_logits = logits[pos].squeeze(0)
    position_topk_indices = torch.topk(last_layer_logits, k)[1]

    exp_logits = torch.exp(last_layer_logits.to(torch.float64))

    token_id = position_topk_indices[0]
    token = tokenizer.decode(token_id)

    out_logits, out_indices = out_emb_vector.forward(last_layer_hidden)
    for j in k_options:
        subset = position_topk_indices[0:j]
        common_indices = subset[torch.isin(subset, out_indices)]

        exp_logits_all = exp_logits[position_topk_indices.squeeze(0)[0:j]]
        exp_logits_common = exp_logits[common_indices]
        logits_percentage = exp_logits_common.sum() / exp_logits_all.sum()
        color = "\033[33m" if logits_percentage < 0.9 else ""

        print(f"{color}{pos}: {len(common_indices)}/{j} ({len(common_indices)/j:0.2f}), logits: {logits_percentage:0.4f}, {token} {token_id}\033[0m")
    if len(k_options) > 1: print()

[33m0: 15/50 (0.30), logits: 0.4504, es 274[0m
1: 48/50 (0.96), logits: 0.9810, like 2339[0m
2: 44/50 (0.88), logits: 0.9229,  creature 7185[0m
3: 49/50 (0.98), logits: 0.9909,  of 286[0m
4: 50/50 (1.00), logits: 1.0000,  the 262[0m
[33m5: 40/50 (0.80), logits: 0.8770,  tail 7894[0m
6: 48/50 (0.96), logits: 0.9853,  made 925[0m
7: 50/50 (1.00), logits: 1.0000,  which 543[0m
8: 48/50 (0.96), logits: 0.9912,  he 339[0m
9: 47/50 (0.94), logits: 0.9799,  like 588[0m
10: 44/50 (0.88), logits: 0.9970,  like 588[0m
11: 47/50 (0.94), logits: 0.9951,  a 257[0m
12: 46/50 (0.92), logits: 0.9951,  of 286[0m
13: 49/50 (0.98), logits: 0.9990,  a 257[0m
14: 44/50 (0.88), logits: 0.9522,  fox 21831[0m
15: 40/50 (0.80), logits: 0.9763, dog 9703[0m
[33m16: 48/50 (0.96), logits: 0.8570,  a 257[0m
[33m17: 43/50 (0.86), logits: 0.8090,  dog 3290[0m
18: 46/50 (0.92), logits: 0.9917, , 11[0m
19: 48/50 (0.96), logits: 0.9889,  are 389[0m
20: 46/50 (0.92), logits: 0.9715,  make 787[0m


In [9]:
# Samples simultaneously from the hnsw and the reference(full matrix multiplication) distribution
# and logs all the positions where a different token are smapled.

def double_multinomial(p1, p2):
    assert len(p1.shape) == 1 and len(p2.shape) == 1

    p1 = p1 / p1.sum()
    p2 = p2 / p2.sum()

    p1_cumsum = torch.cumsum(p1, dim=0)
    p2_cumsum = torch.cumsum(p2, dim=0)

    random_number = torch.rand(1).item()

    i1 = torch.searchsorted(p1_cumsum, random_number).item()
    i2 = torch.searchsorted(p2_cumsum, random_number).item()

    return i1, i2, random_number

out_emb_vector.index.set_ef(100)

n_different_sample = 0
max_length = 64
tokens = tokenizer.encode("Hello, I'm a language model,")
xgen = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)

while xgen.size(1) < max_length:
    with torch.no_grad():
        last_hidden_state = model_headless(xgen).last_hidden_state
        last_hidden_state = last_hidden_state[:, -1, :]

        logits_vec, indices_vec = out_emb_vector.forward(last_hidden_state)

        logits_ref = last_hidden_state @ out_emb_weight.T
        probs_ref = F.softmax(logits_ref, dim=-1)
        topk_probs_ref, topk_indices_ref = torch.topk(probs_ref, 50, dim=-1)

        exp_logits = F.softmax(logits_vec.to(torch.float64), dim=-1)

        i1, i2, ran = double_multinomial(exp_logits[0, :], topk_probs_ref[0, :])
        i1 = torch.gather(indices_vec[0,:], -1, torch.tensor(i1))
        i2 = torch.gather(topk_indices_ref[0,:], -1, torch.tensor(i2))
        xcol = i1.view(1, 1)

        if i1 != i2:
            print(tokenizer.decode(xgen[0, -32:max_length].tolist()), f"'{tokenizer.decode(i1.tolist())}'/'{tokenizer.decode(i2.tolist())}'")
            n_different_sample += 1

        xgen = torch.cat((xgen, xcol), dim=1)

print("\n\nGenerated Text:")
tokens = xgen[0, :max_length].tolist()
print(tokenizer.decode(tokens))

print(f"{n_different_sample/max_length:.4f}")

Hello, I'm a language model, and this is the most important one for any ' programming'/' of'
Hello, I'm a language model, and this is the most important one for any programming language '.'/','
Hello, I'm a language model, and this is the most important one for any programming language.

 'No'/'Hello'
Hello, I'm a language model, and this is the most important one for any programming language.

No I am not an ' editor'/' engineer'
, I'm a language model, and this is the most important one for any programming language.

No I am not an editor and I do not write ' this'/' languages'
 the most important one for any programming language.

No I am not an editor and I do not write this code. I should only write this code so ' it'/' others'
 for any programming language.

No I am not an editor and I do not write this code. I should only write this code so it will have a ' meaning'/' higher'
 I am not an editor and I do not write this code. I should only write this code so it will have a meanin

### Generate & Performance Measurement

In [10]:
out_emb_vector.index.set_ef(100)

forward_time = timeit.timeit(lambda: out_emb_vector.forward(last_hidden), number=10)
forward_ref_time = timeit.timeit(lambda: last_hidden @ out_emb_weight.T, number=10)

print(f"Average time taken (forward): {forward_time:.6f} seconds")
print(f"Average time taken (matrix multiplication): {forward_ref_time:.6f} seconds")
print(f"Speedup: {forward_ref_time / forward_time:.4f}")

Average time taken (forward): 0.006952 seconds
Average time taken (matrix multiplication): 0.160113 seconds
Speedup: 23.0304


In [11]:
# Generate text using top-k sampling from a GPT-2 model without the LM head,
# utilizing a vector index to get the top-k elements (or without the index if method=ref)
def generate(method="vec-index", num_return_sequences=4, max_length=64):
    tokens = tokenizer.encode("Hello, I'm a language model,")
    tokens = torch.tensor(tokens, dtype=torch.long)
    xgen = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
    while xgen.size(1) < max_length:
        with torch.no_grad():
            last_hidden_state = model_headless(xgen).last_hidden_state
            last_hidden_state = last_hidden_state[:, -1, :]

            if method == "vec-index":
                logits, indices = out_emb_vector.forward(last_hidden_state)
                exp_logits = F.softmax(logits.to(torch.float64), dim=-1)
            else:
                logits_ref = last_hidden_state @ out_emb_weight.T
                probs_ref = F.softmax(logits_ref, dim=-1)
                exp_logits, indices = torch.topk(probs_ref, 50, dim=-1)

            ix = torch.multinomial(exp_logits, 1)
            xcol = torch.gather(indices, -1, ix)
            xgen = torch.cat((xgen, xcol), dim=1)
    return xgen

out_emb_vector.index.set_ef(100)

start = time.time()
xgen = generate("vec-index", max_length=32, num_return_sequences=4)
vec_time = time.time() - start
print(f"Vec took {vec_time:.2f} seconds")

start = time.time()
xgen = generate("ref", max_length=32, num_return_sequences=4)
ref_time = time.time() - start
print(f"Ref took {ref_time:.2f} seconds")

print(f"Speedup: {ref_time / vec_time:.2f}")

# for i in range(num_return_sequences):
#    print(tokenizer.decode(xgen[i, :max_length].tolist()))

Vec took 7.65 seconds
Ref took 9.55 seconds
Speedup: 1.25


In [12]:
if False:
    data = torch.cat([hidden[i][-1].squeeze(0) for i in range(len(hidden))], dim=0).repeat(6, 1)
    time_repeat, time_num = 10, 10

    print("| B   | ef  | Speedup |")
    print("| --: | --: | ------: |")
    for ef in [100, 200]:
      for B in [1, 8, 54, 256]:
        out_emb_vector.index.set_ef(ef)
        batch = data[0:B, :]

        forward_time = min(timeit.repeat(lambda: out_emb_vector.forward(batch), number=time_num, repeat=time_repeat))
        forward_ref_time = min(timeit.repeat(lambda: batch @ out_emb_weight.T, number=time_num, repeat=time_repeat))

        print(f"|  {B}  | {ef} | {forward_ref_time / forward_time:.1f}x |")