In [1]:
!pip install torch torchvision torchaudio
!pip install transformers
!pip install hnswlib

[0mCollecting transformers
  Downloading transformers-4.47.1-py3-none-any.whl.metadata (44 kB)
Collecting huggingface-hub<1.0,>=0.24.0 (from transformers)
  Downloading huggingface_hub-0.27.0-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting safetensors>=0.4.1 (from transformers)
  Downloading safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Downloading transformers-4.47.1-py3-none-any.whl (10.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m93.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading huggingface_hub-0.27.0-py3-none-any.whl (450 kB)
Downloading regex-2024.11.6-cp311-cp311-manylinux_2_1

In [1]:
import torch
from transformers import pipeline, set_seed
import torch.nn as nn
import hnswlib
import os

In [2]:
class HNSWIndexEmbedding():
    """Approximate maximum-inner-product search over the rows of an embedding matrix.

    The rows of ``weight`` are stored in an hnswlib index in inner-product space,
    row ``i`` under label ``i``. ``forward`` returns, for each query vector, the ``k``
    rows with (approximately) the largest inner product and those scores, i.e. the
    top-k entries of ``weight @ x`` without computing the full product.

    Args:
        weight: 2-D tensor of shape ``(vocab_size, dim)`` (e.g. an ``lm_head`` weight).
        k: number of neighbours returned per query.
        M: HNSW graph out-degree; only used when a new index is built.
        ef: query-time candidate list size (recall / speed trade-off).
        ef_construction: build-time candidate list size; only used when building.
        index_file: optional cache path. If the file exists the index is loaded from
            it, otherwise the index is built from ``weight`` and saved there.

    Raises:
        ValueError: if a cached index does not hold exactly ``vocab_size`` items.
    """

    def __init__(self, weight, k, M=32, ef=100, ef_construction=100, index_file=None):
        self.k, self.vocab_size, self.dim = k, weight.shape[0], weight.shape[1]
        self.index = hnswlib.Index(space='ip', dim=self.dim)

        if index_file is None or not os.path.exists(index_file):
            self.index.init_index(max_elements=self.vocab_size, M=M, ef_construction=ef_construction, random_seed=42)
            # hnswlib needs float32 numpy data; .float() also makes bf16/fp16 weights convertible.
            self.index.add_items(weight.detach().float().cpu().numpy())
            if index_file is not None:
                self.index.save_index(index_file)
        else:
            print(f"Loading index from file: {index_file}")
            # load_index creates a fresh graph itself; calling init_index first would only
            # allocate a full-size empty index to be thrown away.
            self.index.load_index(index_file)
            loaded_count = self.index.get_current_count()
            if loaded_count != self.vocab_size:
                raise ValueError(
                    f"Cached index {index_file!r} holds {loaded_count} items, expected "
                    f"{self.vocab_size}; delete it or use a different index_file."
                )

        # Must come after building/loading: load_index replaces the underlying graph,
        # so an ef set earlier would not apply to the index actually queried.
        self.index.set_ef(ef)

    def forward(self, x):
        """Return ``(scores, indices)`` for the ``k`` best rows of each query.

        Args:
            x: tensor ``(n, dim)`` of query vectors (any float dtype, any device).

        Returns:
            scores: float32 tensor ``(n, k)`` of inner products with the matched rows.
            indices: int64 tensor ``(n, k)`` of the matched row ids.
            Both tensors are placed on ``x.device``.
        """
        # hnswlib works on float32 numpy arrays; .float() also handles bf16/fp16 inputs.
        labels, distances = self.index.knn_query(x.detach().float().cpu().numpy(), k=self.k)
        # In space='ip' hnswlib reports distance = 1 - <a, b>, so 1 - distance is the score.
        scores = torch.from_numpy(1.0 - distances).to(device=x.device, dtype=torch.float32)
        # Labels are uint64; convert in numpy so torch builds without uint64 support work too.
        indices = torch.from_numpy(labels.astype("int64")).to(x.device)
        return scores, indices


class HNSWLogitsEmbedding(nn.Module):
    """Sparse top-k replacement for a dense ``lm_head``.

    ``layer`` must expose ``vocab_size`` and ``forward(x_2d) -> (scores, indices)``
    with both results shaped ``(n, k)``, as ``HNSWIndexEmbedding`` does. For every
    hidden vector the ``k`` scores are written at their token ids and every other
    vocabulary entry is ``-inf``, so softmax / sampling can only choose among the
    retrieved candidates.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        """Map hidden states ``(..., dim)`` to float32 logits ``(..., vocab_size)``."""
        # reshape rather than view: the incoming hidden states are not guaranteed contiguous.
        x_flat = x.reshape(-1, x.shape[-1])
        scores, indices = self.layer.forward(x_flat)

        logits = torch.full((x_flat.shape[0], self.layer.vocab_size), float("-inf"), dtype=torch.float32, device=x.device)
        # scatter_ requires the source to have the destination's dtype.
        logits.scatter_(-1, indices, scores.to(logits.dtype))
        # Restore every leading dimension instead of assuming a 3-D (batch, seq, dim) input.
        return logits.view(*x.shape[:-1], self.layer.vocab_size)

In [56]:
# Model under test; smaller alternatives: "gpt2", "meta-llama/Llama-3.2-1B".
model_name = "meta-llama/Llama-3.2-3B"

# Two independent pipelines over the same checkpoint: `generator` gets its lm_head
# swapped for the HNSW approximation in a later cell, while `generator_ref` keeps
# the dense lm_head and serves as the timing baseline.
generator, generator_ref = (
    pipeline('text-generation', model=model_name, device="cpu") for _ in range(2)
)
set_seed(42)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cpu


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cpu


In [57]:
# HNSW hyper-parameters for the approximate lm_head (M=32 is HNSWIndexEmbedding's default).
HNSW_K, HNSW_M, HNSW_EF_CONSTRUCTION = 50, 32, 150
# Key the cached index on the model and build parameters, so changing `model_name`,
# M or ef_construction never silently reuses an index built for something else.
hnsw_index_file = f"./{model_name.replace('/', '--')}-M{HNSW_M}-efc{HNSW_EF_CONSTRUCTION}-hnsw.index"

# Guard keeps the cell idempotent: once replaced, lm_head is the HNSW wrapper, which has
# no `.weight`, so re-running an unguarded version of this cell raised AttributeError.
if hasattr(generator.model.lm_head, "weight"):
    # The HNSW index only reproduces weight @ x, so a biased head cannot be replaced.
    assert getattr(generator.model.lm_head, "bias", None) is None, "lm_head bias is not supported"
    # No .clone(): hnswlib copies the vectors into its own storage in add_items.
    weight = generator.model.lm_head.weight.detach()
    generator.model.lm_head = HNSWLogitsEmbedding(
        HNSWIndexEmbedding(
            weight,
            k=HNSW_K,
            M=HNSW_M,
            ef_construction=HNSW_EF_CONSTRUCTION,
            index_file=hnsw_index_file,
        )
    )

In [63]:
%%timeit -n 1 -r 4
generator("Hello, I'm a language model,", max_new_tokens=64, num_return_sequences=1)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


25.2 s ± 120 ms per loop (mean ± std. dev. of 4 runs, 1 loop each)


In [64]:
%%timeit -n 1 -r 4
generator_ref("Hello, I'm a language model,", max_new_tokens=64, num_return_sequences=1)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


26.9 s ± 682 ms per loop (mean ± std. dev. of 4 runs, 1 loop each)
