<a href="https://colab.research.google.com/github/martinloretzzz/vector-index-layer/blob/main/GPTVectorIndexOnOutEmbedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install hnswlib



In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Model
import hnswlib
import torch
import torch.nn.functional as F
import timeit
import functools

In [3]:
model = AutoModelForCausalLM.from_pretrained("gpt2")
model_headless = GPT2Model.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_text = "The quick brown fox"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

output = model.generate(
    input_ids,
    do_sample=True,
    max_length=50,
    return_dict_in_generate=True,
    output_hidden_states=True,
    output_scores=True,
    output_logits=True
)

gen_tokens = output.sequences

gen_text = tokenizer.batch_decode(gen_tokens)[0]
print(gen_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


The quick brown fox-like creatures, known as vortigaunts called Vortigaunts, were introduced to the United States in the late 1940s.

They are a type of fox that can grow large in size and is capable of


In [4]:
hidden = output.hidden_states
logits = output.logits

last_hidden = hidden[-1][-1].squeeze(0)
last_logits = logits[-1]

print(len(hidden), len(logits), len(hidden[-1]))
print(last_hidden.shape)
print(last_logits.shape)

46 46 13
torch.Size([1, 768])
torch.Size([1, 50257])


In [5]:
class HNSWIndexEmbedding():
    def __init__(self, weight, k, M=32, ef=100, ef_construction=100):
        self.k, self.vocab_size, self.dim = k, weight.shape[0], weight.shape[1]
        self.index = hnswlib.Index(space='ip', dim=self.dim)
        self.index.init_index(max_elements=self.vocab_size, M=M, ef_construction=ef_construction, random_seed=42)
        self.index.add_items(weight.numpy())
        self.index.set_ef(ef)

    def forward(self, x):
        indices, distances = self.index.knn_query(x.detach().cpu().numpy(), k=self.k)
        return torch.from_numpy(1 - distances).to(torch.float32).to(x.device), torch.from_numpy(indices).to(torch.int64).to(x.device)

In [6]:
k = 50

out_emb_weight = model.transformer.wte.weight.detach().clone()
print(out_emb_weight.shape)

out_emb_vector = HNSWIndexEmbedding(out_emb_weight, k=k, ef_construction=150)

torch.Size([50257, 768])


In [7]:
out_logits, out_indices = out_emb_vector.forward(last_hidden)
out_logits, out_indices

(tensor([[-226.4115, -233.0608, -233.6529, -233.8399, -234.1039, -234.7538,
          -235.2826, -235.3414, -235.3539, -235.3834, -235.5113, -235.6125,
          -235.6129, -235.7076, -235.7217, -235.7992, -235.9429, -236.0075,
          -236.0421, -236.0859, -236.2032, -236.2416, -236.2963, -236.3992,
          -236.4150, -236.5225, -236.7359, -236.7570, -236.7627, -236.7924,
          -236.8477, -236.9031, -236.9309, -237.2243, -237.3063, -237.3725,
          -237.3989, -237.5249, -237.5824, -237.5862, -237.6090, -237.6250,
          -237.6516, -237.6921, -237.7494, -237.7602, -237.7992, -237.8032,
          -237.8066, -237.8293]]),
 tensor([[ 286,  284,   11,  329,  287,  290,  379,  351,  772,  355,  510,  326,
           319,  357, 1111,  422,  198,   13,  503,  257,  262,  517, 2048,  366,
           618,  416,  407,  691,  393,  611,  532,  625,  460,  475,  572, 2035,
           546, 1576,  481,  655, 2592,   12, 3016,  256,  477,  523,  739,   25,
           428,  706]]))

In [8]:
if False:
    data = torch.cat([hidden[i][-1].squeeze(0) for i in range(len(hidden))], dim=0).repeat(6, 1)
    time_repeat, time_num = 10, 10

    print("| B   | ef  | Speedup |")
    print("| --: | --: | ------: |")
    for ef in [100, 200]:
      for B in [1, 8, 64, 256]:
        out_emb_vector.index.set_ef(ef)
        batch = data[0:B, :]

        forward_time = min(timeit.repeat(lambda: out_emb_vector.forward(batch), number=time_num, repeat=time_repeat))
        forward_ref_time = min(timeit.repeat(lambda: batch @ out_emb_weight.T, number=time_num, repeat=time_repeat))

        print(f"|  {B}  | {ef} | {forward_ref_time / forward_time:.1f}x |")

In [10]:
# TODO measure on the GPU

out_emb_vector.index.set_ef(100)

forward_time = timeit.timeit(lambda: out_emb_vector.forward(last_hidden), number=10)
forward_ref_time = timeit.timeit(lambda: last_hidden @ out_emb_weight.T, number=10)

print(f"Average time taken (forward): {forward_time:.6f} seconds")
print(f"Average time taken (matrix multiplication): {forward_ref_time:.6f} seconds")
print(f"Speedup: {forward_ref_time / forward_time:.4f}")

Average time taken (forward): 0.012672 seconds
Average time taken (matrix multiplication): 0.199741 seconds
Speedup: 15.7622


In [11]:
positions = range(len(hidden))
# positions = [6]
k_options = [10, 50] # [1, 3, 5, 10, 50]

for pos in positions:
    last_layer_hidden = hidden[pos][-1].squeeze(0)[0,:]
    last_layer_logits = logits[pos].squeeze(0)
    position_topk_indices = torch.topk(last_layer_logits, k)[1]

    exp_logits = torch.exp(last_layer_logits.to(torch.float64))

    token_id = position_topk_indices[0]
    token = tokenizer.decode(token_id)

    out_logits, out_indices = out_emb_vector.forward(last_layer_hidden)
    for j in k_options:
        subset = position_topk_indices[0:j]
        common_indices = subset[torch.isin(subset, out_indices)]
        # print(common_indices.shape, exp_logits.shape, position_topk_indices.shape)
        # print(common_indices)

        exp_logits_all = exp_logits[position_topk_indices.squeeze(0)[0:j]]
        exp_logits_common = exp_logits[common_indices]
        logits_percentage = exp_logits_common.sum() / exp_logits_all.sum()
        color = "\033[33m" if logits_percentage < 0.9 else ""

        print(f"{color}{pos}: {len(common_indices)}/{j} ({len(common_indices)/j:0.2f}), logits: {logits_percentage:0.4f}, {token} {token_id}\033[0m")
    if len(k_options) > 1: print()

    if len(positions) == 1:
        print(exp_logits_common / exp_logits_all.sum())

[33m0: 6/10 (0.60), logits: 0.5174, es 274[0m
[33m0: 15/50 (0.30), logits: 0.4504, es 274[0m

1: 10/10 (1.00), logits: 1.0000, like 2339[0m
1: 48/50 (0.96), logits: 0.9810, like 2339[0m

2: 10/10 (1.00), logits: 1.0000,  creature 7185[0m
2: 44/50 (0.88), logits: 0.9229,  creature 7185[0m

3: 10/10 (1.00), logits: 1.0000,  are 389[0m
3: 50/50 (1.00), logits: 1.0000,  are 389[0m

4: 10/10 (1.00), logits: 1.0000,  which 543[0m
4: 48/50 (0.96), logits: 0.9864,  which 543[0m

5: 9/10 (0.90), logits: 0.9880,  as 355[0m
5: 44/50 (0.88), logits: 0.9782,  as 355[0m

6: 9/10 (0.90), logits: 0.9263,  the 262[0m
6: 44/50 (0.88), logits: 0.9226,  the 262[0m

7: 10/10 (1.00), logits: 1.0000, ult 586[0m
7: 48/50 (0.96), logits: 0.9779, ult 586[0m

8: 10/10 (1.00), logits: 1.0000, ices 1063[0m
8: 47/50 (0.94), logits: 0.9924, ices 1063[0m

[33m9: 8/10 (0.80), logits: 0.1246, aunts 43981[0m
[33m9: 45/50 (0.90), logits: 0.1655, aunts 43981[0m

10: 10/10 (1.00), logits: 1.0000, , 

In [None]:
def double_multinomial(p1, p2):
    assert len(p1.shape) == 1 and len(p2.shape) == 1

    p1 = p1 / p1.sum()
    p2 = p2 / p2.sum()

    p1_cumsum = torch.cumsum(p1, dim=0)
    p2_cumsum = torch.cumsum(p2, dim=0)

    random_number = torch.rand(1).item()

    i1 = torch.searchsorted(p1_cumsum, random_number).item()
    i2 = torch.searchsorted(p2_cumsum, random_number).item()

    return i1, i2, random_number

out_emb_vector.index.set_ef(100)

n_different_sample = 0

num_return_sequences = 1
max_length = 64
tokens = tokenizer.encode("Hello, I'm a language model,")
tokens = torch.tensor(tokens, dtype=torch.long)
xgen = tokens.unsqueeze(0).repeat(num_return_sequences, 1)

while xgen.size(1) < max_length:
    with torch.no_grad():
        last_hidden_state = model_headless(xgen).last_hidden_state
        last_hidden_state = last_hidden_state[:, -1, :]

        logits, indices = out_emb_vector.forward(last_hidden_state)

        logits_ref = last_hidden_state @ out_emb_weight.T
        probs_ref = F.softmax(logits_ref, dim=-1)
        topk_probs_ref, topk_indices_ref = torch.topk(probs_ref, 50, dim=-1)
        # print(topk_probs_ref.shape, logits.shape)

        # print(topk_probs_ref[0,0:10], topk_indices_ref[0,0:10])
        # print(logits[0,0:10], indices[0,0:10])

        # topk_probs_ref = F.softmax(topk_probs_ref.to(torch.float64), dim=-1)

        exp_logits = F.softmax(logits.to(torch.float64), dim=-1)

        xcol = []
        for i in range(num_return_sequences):
            i1, i2, ran = double_multinomial(exp_logits[i, :], topk_probs_ref[i, :])
            i1 = torch.gather(indices[i,:], -1, torch.tensor(i1))
            i2 = torch.gather(topk_indices_ref[i,:], -1, torch.tensor(i2))
            xcol.append(i1)
            if i1 != i2:

                print(round(ran, 4), topk_probs_ref[i,0:10] / topk_probs_ref[i,0:10].sum())
                print(round(ran, 4), exp_logits[i,0:10] / exp_logits[i,0:10].sum())

                print(tokenizer.decode(xgen[i, -32:max_length].tolist()), f"'{tokenizer.decode(i1.tolist())}'/'{tokenizer.decode(i2.tolist())}'")
                n_different_sample += 1

        xcol = torch.tensor(xcol).unsqueeze(0)
        # ix = torch.multinomial(exp_logits, 1)
        # xcol = torch.gather(indices, -1, ix)
        xgen = torch.cat((xgen, xcol), dim=1)

print()
print("Results:")
for i in range(num_return_sequences):
    tokens = xgen[i, :max_length].tolist()
    print(tokenizer.decode(tokens))

print(f"{n_different_sample/(max_length * num_return_sequences):.4f}")