<a href="https://colab.research.google.com/github/martinloretzzz/vector-index-layer/blob/main/GPTVectorIndexLmEval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install lm_eval
!pip install hnswlib



In [4]:
import lm_eval
import torch
import torch.nn.functional as F
from lm_eval.api.model import LM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Model

# lm_eval --model hf --model_args pretrained=openai-community/gpt2 --tasks hellaswag --device cuda:0 --batch_size 8

In [5]:
class LMWrapper(LM):
    """Minimal lm_eval adapter around a callable that maps token ids to logits.

    Only ``loglikelihood`` requests are implemented; ``generate_until`` and
    ``loglikelihood_rolling`` raise ``NotImplementedError``.

    Args:
        tokenizer: object exposing ``encode(text) -> list[int]``.
        model: callable taking a ``(1, seq)`` long tensor of token ids and
            returning a logits tensor whose last dimension is the vocabulary.
        device: device the input tensors are created on.
    """

    def __init__(self, tokenizer, model, device='cuda'):
        super().__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.model = model

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        """Score every request; returns a list of ``(log_likelihood, is_greedy)`` tuples."""
        return [
            self.calculate_loglikelihood(req.arguments[0], req.arguments[1])
            for req in tqdm(requests, disable=disable_tqdm)
        ]

    def _prefix_token_id(self):
        """Token used as conditioning when the context encodes to nothing."""
        prefix = getattr(self.tokenizer, "bos_token_id", None)
        if prefix is None:
            prefix = getattr(self.tokenizer, "eos_token_id", None)
        if prefix is None:
            raise ValueError("tokenizer has neither a BOS nor an EOS token to use as an empty-context prefix")
        return prefix

    def _encode_pair(self, context, continuation):
        """Tokenize ``context + continuation`` once and split it into the two id lists.

        Encoding the continuation on its own can yield a different number of
        tokens than it occupies inside the full text, which would misalign the
        targets with the logits. Splitting a single tokenization avoids that.
        """
        # Trailing whitespace of the context is scored as part of the continuation.
        n_trailing = len(context) - len(context.rstrip())
        if n_trailing > 0:
            continuation = context[-n_trailing:] + continuation
            context = context[:-n_trailing]

        whole_ids = self.tokenizer.encode(context + continuation)
        context_ids = self.tokenizer.encode(context)
        continuation_ids = whole_ids[len(context_ids):]

        # The first continuation token needs at least one token to be predicted from.
        if len(context_ids) == 0:
            context_ids = [self._prefix_token_id()]
        return context_ids, continuation_ids

    def calculate_loglikelihood(self, context, continuation):
        """Return ``(log p(continuation | context), is_greedy)`` for one pair.

        ``is_greedy`` is True when the arg-max prediction at every continuation
        position equals the actual continuation token.
        """
        context_ids, continuation_ids = self._encode_pair(context, continuation)

        # Nothing to score: the empty continuation has probability 1.
        if len(continuation_ids) == 0:
            return (0.0, True)

        token_ids = context_ids + continuation_ids
        # Position i of the model output predicts token i + 1, so drop the last token.
        model_inputs = torch.tensor(token_ids[:-1], dtype=torch.long).unsqueeze(0).to(self.device)
        targets = torch.tensor(continuation_ids, dtype=torch.long).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(model_inputs)

        continuation_length = targets.size(-1)
        continuation_logits = logits[:, -continuation_length:]
        # Summed cross-entropy over the continuation is the negative log-likelihood.
        loss = F.cross_entropy(
            continuation_logits.reshape(-1, continuation_logits.size(-1)).float(),
            targets.reshape(-1),
            reduction="sum",
        )
        log_likelihood = -loss.item()

        greedy_ids = torch.argmax(continuation_logits, dim=-1)
        is_greedy = torch.equal(greedy_ids, targets)

        return (log_likelihood, is_greedy)

    def generate_until(self, requests):
        raise NotImplementedError()

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError()

In [6]:
# Use the GPU when one is present, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
# Head-less model: reuse the backbone of the LM loaded above instead of loading
# a second copy of the same checkpoint; it exposes `last_hidden_state` the same way.
model_no_head = model.transformer

In [7]:
import hnswlib

class HNSWIndexEmbedding():
    """Approximate top-k inner-product lookup over the rows of a weight matrix.

    Args:
        weight: 2-D torch tensor of shape (num_vectors, dim); may live on any
            device and may require grad.
        k: number of neighbours returned per query.
        M: HNSW graph connectivity passed to ``init_index``.
        ef: search-time candidate list size; raised to ``k`` when smaller.
        ef_construction: build-time candidate list size.
    """

    def __init__(self, weight, k, M=32, ef=100, ef_construction=100):
        self.k = k
        # detach/cpu/float make parameters and GPU tensors usable; all three are
        # no-ops for a detached float32 CPU tensor.
        vectors = weight.detach().cpu().float().numpy()
        xd, d = vectors.shape
        self.index = hnswlib.Index(space='ip', dim=d)
        self.index.init_index(max_elements=xd, ef_construction=ef_construction, M=M)
        # hnswlib documents that ef must not be lower than the queried k.
        self.index.set_ef(max(ef, k))
        self.index.add_items(vectors)

    def forward(self, x):
        """Query the index; returns ``(scores, indices)``, each with ``k`` columns per query row.

        With the 'ip' space hnswlib reports ``1 - <q, v>`` as the distance, so
        ``1 - distances`` recovers the inner products.
        """
        indices, distances = self.index.knn_query(x, k=self.k)
        return 1 - distances, indices

In [57]:
# CPU snapshot of the token-embedding matrix, detached from autograd.
out_emb_weight = model.transformer.wte.weight.detach().clone().cpu()

# HNSW index over the embedding rows; each query returns the 500 best matches.
# NOTE(review): ef=100 is below k=500 - confirm this is the intended search setting.
out_emb_vector = HNSWIndexEmbedding(
    out_emb_weight,
    k=500,
    ef=100,
    ef_construction=200,
)

In [58]:
# Encode a short prompt as a (1, seq) batch of token ids on the target device.
inputs = torch.tensor(
    tokenizer.encode("Hello World how are you?"),
    dtype=torch.long,
).unsqueeze(0).to(device)

# Feed everything except the final token to the model.
model_inputs = inputs[:, :-1]

print(model_inputs.shape)

def model_ref(x):
    """Reference forward pass: return the `.logits` of the full language model for `x`."""
    # (A torch.cuda.synchronize() call was left disabled here by the author.)
    return model(x).logits

def model_vector_index(x, fill_value=-10000.0):
    """Forward pass whose LM head is replaced by a top-k lookup in the HNSW index.

    Args:
        x: (batch, seq) tensor of token ids.
        fill_value: logit assigned to every vocabulary entry that the index did
            not return for a position. It is finite so that a softmax /
            cross-entropy over the row stays finite.

    Returns:
        float32 tensor of shape (batch, seq, vocab_size) on ``x.device``. Only the
        entries returned by the index hold real scores; the rest hold ``fill_value``.
    """
    vocab_size = out_emb_weight.shape[0]
    hidden = model_no_head(x).last_hidden_state

    batch_size, seq_len = hidden.shape[0], hidden.shape[1]
    # reshape (not view) also works if the hidden states are non-contiguous.
    flat_hidden = hidden.reshape(-1, hidden.shape[-1])
    scores, indices = out_emb_vector.forward(flat_hidden.detach().cpu().numpy())

    # hnswlib returns unsigned 64-bit labels; convert in numpy first because
    # torch.from_numpy does not accept uint64 on every torch version.
    scores = torch.from_numpy(scores).to(torch.float32).to(x.device)
    indices = torch.from_numpy(indices.astype("int64")).to(x.device)

    # Every tensor is placed on x.device so scatter_ never mixes devices.
    out = torch.full((flat_hidden.shape[0], vocab_size), fill_value, dtype=torch.float32, device=x.device)
    out.scatter_(-1, indices, scores)
    out = out.view((batch_size, seq_len, vocab_size))

    # torch.cuda.synchronize()
    return out


# Smoke test: display the output shape of the index-based forward pass for the sample prompt.
model_vector_index(model_inputs).shape

torch.Size([1, 5])


torch.Size([1, 5, 50257])

In [59]:
# Compare the k highest-scoring tokens of the reference head and the index-based head.
k = 100
topk_ref = torch.topk(model_ref(model_inputs), k=k)
topk_index = torch.topk(model_vector_index(model_inputs), k=k)

# Overlap is computed per position: a reference token only counts as recovered
# when it appears in the index top-k of the SAME position. (torch.isin on the
# whole tensors would also match ids found at any other position.)
ref_ids = topk_ref.indices.unsqueeze(-1)      # (batch, seq, k, 1)
index_ids = topk_index.indices.unsqueeze(-2)  # (batch, seq, 1, k)
common = (ref_ids == index_ids).any(dim=-1)   # (batch, seq, k)

print(topk_ref)
print(topk_index)
# number of recovered reference tokens, out of the total number compared
print(common.sum().item(), topk_ref[1].numel())

torch.return_types.topk(
values=tensor([[[-32.3920, -32.5917, -33.2036, -33.8366, -33.8767, -33.9498, -33.9652,
          -34.2312, -34.3258, -34.6203, -34.6742, -34.7565, -34.8086, -34.8105,
          -34.9256, -34.9388, -35.0174, -35.2291, -35.2363, -35.3051, -35.3266,
          -35.3627, -35.3825, -35.4458, -35.4827, -35.4879, -35.5665, -35.5951,
          -35.6385, -35.6445, -35.6953, -35.7356, -35.7832, -35.8885, -35.9249,
          -35.9285, -35.9415, -36.0249, -36.0916, -36.0979, -36.1494, -36.1663,
          -36.1847, -36.1910, -36.1964, -36.1981, -36.2357, -36.2589, -36.3537,
          -36.3674, -36.3738, -36.4240, -36.4263, -36.4391, -36.4423, -36.4580,
          -36.4771, -36.4842, -36.5401, -36.5597, -36.6037, -36.6086, -36.6176,
          -36.6682, -36.6686, -36.6808, -36.6907, -36.6995, -36.7062, -36.7072,
          -36.7504, -36.7513, -36.7717, -36.7810, -36.7811, -36.7819, -36.8051,
          -36.8231, -36.8242, -36.8436, -36.8499, -36.9246, -36.9914, -36.9958,
        

In [60]:
import timeit


def _time_forward(forward_fn, inputs, number=50):
    """Return the total wall-clock seconds for `number` calls of `forward_fn(inputs)`.

    CUDA kernels are launched asynchronously, so each call is followed by a
    synchronize when the inputs live on the GPU; otherwise a forward pass that
    never copies its result to the host would be under-measured. One untimed
    warm-up call keeps one-off initialisation cost out of the measurement.
    """

    def run_once():
        # Inference only: no autograd graph is needed for timing.
        with torch.no_grad():
            forward_fn(inputs)
        if inputs.is_cuda:
            torch.cuda.synchronize()

    run_once()  # warm-up
    return timeit.timeit(run_once, number=number)


forward_time = _time_forward(model_vector_index, model_inputs)
forward_ref_time = _time_forward(model_ref, model_inputs)

print(f"Vector: {forward_time:.6f} seconds")
print(f"Ref: {forward_ref_time:.6f} seconds")
print(f"Speedup: {forward_ref_time / forward_time:.4f}")

Vector: 0.919128 seconds
Ref: 0.440307 seconds
Speedup: 0.4790


In [61]:
# Forward functions to evaluate, each wrapped for lm_eval.
# The wrappers use the notebook-wide `device` instead of a second hard-coded string.
models = {
    "hf-vector": LMWrapper(tokenizer, model_vector_index, device=device),
    # "hf-ref": LMWrapper(tokenizer, model_ref, device=device),
}

# Other task names noted here: "hellaswag", "piqa", "arc_easy", "winogrande", "lambada_openai"
eval_tasks = ["arc_easy"]

# The task manager does not depend on the model, so it is built once for all runs.
task_manager = lm_eval.tasks.TaskManager()

# Keep every model's metrics so several entries can be compared after the loop.
eval_results = {}
for name, wrapper in models.items():
    model_result = lm_eval.simple_evaluate(
        model=wrapper,
        tasks=eval_tasks,
        task_manager=task_manager,
    )
    eval_results[name] = model_result["results"]

    print(name)
    print(model_result["results"])

INFO:lm-eval:The tag 'arc_ca' is already registered as a group, this tag will not be registered. This may affect tasks you want to call.
INFO:lm-eval:The tag 'arc_ca' is already registered as a group, this tag will not be registered. This may affect tasks you want to call.
INFO:lm-eval:Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
INFO:lm-eval:Using pre-initialized model
INFO:lm-eval:Building contexts for arc_easy on rank 0...
100%|██████████| 2376/2376 [00:03<00:00, 722.96it/s]
INFO:lm-eval:Running loglikelihood requests
100%|██████████| 9501/9501 [08:46<00:00, 18.05it/s]


hf-vector
{'arc_easy': {'alias': 'arc_easy', 'acc,none': 0.36363636363636365, 'acc_stderr,none': 0.009870849346011757, 'acc_norm,none': 0.3371212121212121, 'acc_norm_stderr,none': 0.009700146509130083}}


In [44]:
# VEC: {'arc_easy': {'alias': 'arc_easy', 'acc,none': 0.2680976430976431, 'acc_stderr,none': 0.009089526578213707, 'acc_norm,none': 0.2680976430976431, 'acc_norm_stderr,none': 0.009089526578213707}}
# REF: {'arc_easy': {'alias': 'arc_easy', 'acc,none': 0.43813131313131315, 'acc_stderr,none': 0.010180937100600076, 'acc_norm,none': 0.3947811447811448, 'acc_norm_stderr,none': 0.010030038935883607}}
# VEC k=500, ef=200, ef_construction=200: {'arc_easy': {'alias': 'arc_easy', 'acc,none': 0.36363636363636365, 'acc_stderr,none': 0.009870849346011757, 'acc_norm,none': 0.3371212121212121, 'acc_norm_stderr,none': 0.009700146509130083}}