In [46]:
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch.nn as nn
import hnswlib
import os
from datasets import load_dataset

In [47]:
# Checkpoint name shared by both models and the tokenizer.
model_name = "gpt2"

# Tokenizer for the same checkpoint; the end-of-sequence token doubles as the padding token.
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Two independent copies of the same pretrained network:
#   model     - the copy whose output head is replaced later in the notebook
#   model_ref - left untouched, used as the exact reference
model, model_ref = (GPT2LMHeadModel.from_pretrained(model_name) for _ in range(2))

In [48]:
class HNSWIndexEmbedding():
    """Approximate top-k inner-product lookup over the rows of a weight matrix.

    The rows of ``weight`` are inserted into an hnswlib index that uses the
    inner-product space (``space='ip'``). ``forward`` then returns, for every
    query vector, the ``k`` rows the index considers closest.

    Args:
        weight: 2-D tensor; ``weight.shape[0]`` is stored as ``vocab_size`` and
            ``weight.shape[1]`` as ``dim``.
        k: number of neighbours returned per query.
        M: hnswlib graph parameter forwarded to ``init_index``.
        ef: query-time parameter applied with ``set_ef``.
        ef_construction: build-time parameter forwarded to ``init_index``.
        index_file: optional path prefix for caching. When given, the index is
            stored at / loaded from ``f"{index_file}-{M}-{ef_construction}.index"``.
            The cache key does not include the weights themselves, so a stale
            file for different weights would be loaded as-is.
    """

    def __init__(self, weight, k, M=32, ef=100, ef_construction=100, index_file=None):
        self.k, self.vocab_size, self.dim = k, weight.shape[0], weight.shape[1]
        self.index = hnswlib.Index(space='ip', dim=self.dim)

        # Only build a cache path when caching was actually requested.
        index_path = None if index_file is None else f"{index_file}-{M}-{ef_construction}.index"

        if index_path is not None and os.path.exists(index_path):
            # load_index initialises the index itself, so init_index is skipped here;
            # calling it first would allocate a graph that is immediately discarded.
            print(f"Loading index from file: {index_path}")
            self.index.load_index(index_path, max_elements=self.vocab_size)
        else:
            self.index.init_index(max_elements=self.vocab_size, M=M, ef_construction=ef_construction, random_seed=42)
            # detach()/cpu() so weights that track gradients or live on an accelerator work too.
            self.index.add_items(weight.detach().cpu().numpy())
            if index_path is not None:
                self.index.save_index(index_path)
        self.index.set_ef(ef)

    def forward(self, x):
        """Query the index for the ``k`` nearest rows of every vector in ``x``.

        Args:
            x: 2-D tensor of query vectors, one per row.

        Returns:
            ``(scores, indices)`` on ``x.device``. ``scores`` is float32 and equals
            ``1 - distance`` as reported by the index; ``indices`` is int64 and
            holds the matching row labels.
        """
        queries = x.detach().cpu().float().numpy()
        labels, distances = self.index.knn_query(queries, k=self.k)
        scores = torch.from_numpy(1 - distances).to(torch.float32).to(x.device)
        # Cast the labels in numpy first: torch.from_numpy may not accept the
        # unsigned 64-bit label array directly.
        indices = torch.from_numpy(labels.astype("int64")).to(x.device)
        return scores, indices


class HNSWLogitsEmbedding(nn.Module):
    """Drop-in output head that produces sparse logits from a nearest-neighbour layer.

    ``layer`` must expose ``vocab_size`` and a ``forward(x_flat)`` method returning
    ``(scores, indices)`` for a 2-D batch of vectors. Every vocabulary entry that
    the layer does not return is assigned ``-inf``.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        """Map hidden states ``(..., dim)`` to logits ``(..., vocab_size)``.

        Args:
            x: tensor whose last dimension is the feature dimension; any number
               of leading dimensions is accepted.

        Returns:
            float32 tensor on ``x.device`` with the leading dimensions of ``x``
            and a final dimension of ``self.layer.vocab_size``.
        """
        leading_shape = x.shape[:-1]
        # reshape (not view) so non-contiguous hidden states are accepted as well.
        x_flat = x.reshape(-1, x.shape[-1])
        scores, indices = self.layer.forward(x_flat)

        # Start from -inf everywhere, then fill in the scores of the returned neighbours.
        logits = torch.full((x_flat.shape[0], self.layer.vocab_size), float("-inf"), dtype=torch.float32, device=x.device)
        logits.scatter_(-1, indices, scores)
        # Restore whatever leading dimensions the input had instead of assuming exactly two.
        return logits.view(*leading_shape, self.layer.vocab_size)

In [49]:
# Read the exact output-projection weights from the untouched reference model
# (loaded from the same checkpoint) instead of from `model`: after the assignment
# below `model.lm_head` is an HNSWLogitsEmbedding without a `.weight`, so reading
# it from `model` would make this cell fail when it is run a second time.
weight = model_ref.lm_head.weight.detach().clone()
# Swap the dense output head for the approximate top-k head; the index is cached
# on disk under a name derived from `model_name`, M and ef_construction.
model.lm_head = HNSWLogitsEmbedding(HNSWIndexEmbedding(weight, k=50, ef=200, M=32, ef_construction=1000, index_file=model_name))

In [50]:
# First 200 rows of the WikiText-2 (raw) test split.
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test").select(range(200))

# Join the rows with blank lines and tokenise the result as one long sequence.
encodings = tokenizer(
    "\n\n".join(test["text"]),
    return_tensors="pt",
)

Token indices sequence length is longer than the specified maximum sequence length for this model (12122 > 1024). Running this sequence through the model will result in indexing errors


In [51]:
def batch_list(tensor, batch_size):
    """Cut ``tensor`` into consecutive slices along dimension 1.

    Each slice holds at most ``batch_size`` positions of dimension 1; the final
    slice may be shorter. Returns the slices as a list, in order.
    """
    width = tensor.size(1)
    pieces = []
    for start in range(0, width, batch_size):
        stop = start + batch_size
        pieces.append(tensor[:, start:stop])
    return pieces

# Number of token positions per chunk handed to batch_list (slices along dimension 1).
batch_size = 128
# Split the tokenised text into consecutive chunks of at most `batch_size` positions.
input_ids_list = batch_list(encodings["input_ids"], batch_size)

In [52]:
# Compare the approximate head (`model`) with the exact head (`model_ref`):
# for every token position, sum exp(logit) over each model's own top-k logits
# and take the ratio approximate / exact.
top_k = 50              # logits compared per position; keep <= the k used to build the index
ratio_threshold = 0.75  # positions whose ratio falls below this are counted as errors

# Running totals over *tokens*. Averaging per-chunk means would give the last,
# shorter chunk the same weight as a full one and bias the final numbers.
ratio_accum, error_accum, token_count = 0.0, 0, 0

for input_ids in input_ids_list:
    with torch.no_grad():
        # float64 keeps the exp/sum below numerically comfortable.
        logits = model(input_ids).logits.squeeze(0).to(torch.float64)
        logits_ref = model_ref(input_ids).logits.squeeze(0).to(torch.float64)

        # Each model is scored on its own top-k set.
        topk_logits = torch.topk(logits, k=top_k, dim=-1).values
        topk_logits_ref = torch.topk(logits_ref, k=top_k, dim=-1).values

        ratios = torch.exp(topk_logits).sum(-1) / torch.exp(topk_logits_ref).sum(-1)
        below_threshold = (ratios < ratio_threshold).sum()

        ratio_accum += ratios.sum().item()
        error_accum += below_threshold.item()
        token_count += ratios.numel()

        print(f"Ratio: {ratios.mean().item():.4f}, Lower 75%: {below_threshold.item()}/{ratios.numel()}")

print(f"Average ratio: {ratio_accum / token_count:.4f}, Average lower 75: {error_accum / token_count:.4f}")

Ratio: 0.9909, Lower 75%: 1/128
Ratio: 0.9876, Lower 75%: 0/128
Ratio: 0.9918, Lower 75%: 1/128
Ratio: 0.9916, Lower 75%: 0/128
Ratio: 0.9806, Lower 75%: 3/128
Ratio: 0.9859, Lower 75%: 2/128
Ratio: 0.9932, Lower 75%: 0/128
Ratio: 0.9924, Lower 75%: 1/128
Ratio: 0.9884, Lower 75%: 2/128
Ratio: 0.9966, Lower 75%: 0/128
Ratio: 0.9963, Lower 75%: 0/128
Ratio: 0.9829, Lower 75%: 4/128
Ratio: 0.9941, Lower 75%: 0/128
Ratio: 0.9964, Lower 75%: 0/128
Ratio: 0.9978, Lower 75%: 0/128
Ratio: 0.9973, Lower 75%: 0/128
Ratio: 0.9985, Lower 75%: 0/128
Ratio: 0.9944, Lower 75%: 0/128
Ratio: 0.9988, Lower 75%: 0/128
Ratio: 0.9895, Lower 75%: 2/128
Ratio: 0.9967, Lower 75%: 0/128
Ratio: 0.9980, Lower 75%: 0/128
Ratio: 0.9974, Lower 75%: 0/128
Ratio: 0.9948, Lower 75%: 1/128
Ratio: 0.9981, Lower 75%: 0/128
Ratio: 0.9987, Lower 75%: 0/128
Ratio: 0.9960, Lower 75%: 0/128
Ratio: 0.9964, Lower 75%: 0/128
Ratio: 0.9969, Lower 75%: 0/128
Ratio: 0.9958, Lower 75%: 0/128
Ratio: 0.9980, Lower 75%: 0/128
Ratio: 0

In [53]:
# Read back the query-time `ef` of the index behind the replaced LM head
# (the value applied via `set_ef` in HNSWIndexEmbedding.__init__).
model.lm_head.layer.index.ef

200

In [54]:
# M=40 ef=150 err=1.8%
# M=32 ef=150 err=1.64%
# M=48 ef=150 err=1.47%
# Method: take the top-k logits from both models and compare the ratio of their summed exp() mass
# M=32 ef=150 err=1.39%
# M=64 ef=200 err=0.64
# M=64 ef=150 err=0.97
# M=48 ef=200 err=0.83
# M=32 ef=200 err=0.92
# M=32 ef=200 err=0.61 ef_construction=300
# M=32 ef=200 err=0.44 ef_construction=500