<a href="https://colab.research.google.com/github/martinloretzzz/vector-index-layer/blob/main/GPTVectorIndexLmEval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install lm_eval
!pip install hnswlib

Collecting hnswlib
  Using cached hnswlib-0.8.0.tar.gz (36 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: hnswlib
  Building wheel for hnswlib (pyproject.toml) ... [?25l[?25hdone
  Created wheel for hnswlib: filename=hnswlib-0.8.0-cp310-cp310-linux_x86_64.whl size=2364561 sha256=21985e88c4b45dd9091c5aa3512fd352942a1de3b9d5b0ac1b8bac128248c2ec
  Stored in directory: /root/.cache/pip/wheels/af/a9/3e/3e5d59ee41664eb31a4e6de67d1846f86d16d93c45f277c4e7
Successfully built hnswlib
Installing collected packages: hnswlib
Successfully installed hnswlib-0.8.0


In [2]:
import lm_eval
import torch
import torch.nn.functional as F
from lm_eval.api.model import LM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Model
import hnswlib

# lm_eval --model hf --model_args pretrained=openai-community/gpt2 --tasks hellaswag --device cuda:0 --batch_size 8

In [3]:
class LMWrapper(LM):
    def __init__(self, tokenizer, model, device='cuda'):
        super().__init__()
        self.device = device
        self.tokenizer = tokenizer
        self.model = model

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        return [self.calculate_loglikelihood(req.arguments[0], req.arguments[1]) for req in tqdm(requests, disable=disable_tqdm)]

    def calculate_loglikelihood(self, context, continuation):
        input_text = context + continuation
        inputs = self.tokenizer.encode(input_text)
        inputs = torch.tensor(inputs, dtype=torch.long).unsqueeze(0).to(self.device)
        model_inputs = inputs[:, :-1]

        continuation_ids = self.tokenizer.encode(continuation)
        continuation_ids = torch.tensor(continuation_ids, dtype=torch.long).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(model_inputs)

        continuation_length = continuation_ids.size(-1)
        continuation_logits = logits[:, -continuation_length:]
        loss = F.cross_entropy(continuation_logits.view(-1, continuation_logits.size(-1)), continuation_ids.view(-1), reduction="sum")
        log_likelihood = -loss.item()

        greedy_ids = torch.argmax(continuation_logits, dim=-1)
        is_greedy = torch.equal(greedy_ids, continuation_ids)

        return (log_likelihood, is_greedy)

    def generate_until(self, requests):
        raise NotImplementedError()

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError()

In [4]:
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
model_no_head = GPT2Model.from_pretrained("gpt2").to(device)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [5]:
class HNSWIndexEmbedding():
    def __init__(self, weight, k, M=32, ef=100, ef_construction=100):
        self.k = k
        xd, d = weight.shape
        self.index = hnswlib.Index(space='ip', dim=d)
        self.index.init_index(max_elements=xd, ef_construction=ef_construction, M=M)
        self.index.set_ef(ef)
        self.index.add_items(weight.numpy())

    def forward(self, x):
        indices, distances = self.index.knn_query(x.detach().cpu().numpy(), k=self.k)
        return torch.from_numpy(1 - distances).to(torch.float32).to(x.device), torch.from_numpy(indices).to(torch.int64).to(x.device)

In [6]:
out_emb_weight = model.transformer.wte.weight.detach().clone().cpu()
out_emb_vector = HNSWIndexEmbedding(out_emb_weight, k=50, M=32, ef=100, ef_construction=250)

In [7]:
inputs = tokenizer.encode("Hello World how are you?")
inputs = torch.tensor(inputs, dtype=torch.long).unsqueeze(0).to(device)
model_inputs = inputs[:, :-1]

print(model_inputs.shape)

k = 50
nan_value = -10000.0

def mask_non_topk(x, k, fill):
    _, indices = torch.topk(x, k, dim=-1)
    mask = torch.full_like(x, fill)
    return mask.scatter(-1, indices, x.gather(-1, indices))


def model_ref(x):
    logits = model(x).logits
    # torch.cuda.synchronize()
    return mask_non_topk(logits, k=k, fill=nan_value)


def model_vector_index(x):
    vocab_size = out_emb_weight.shape[0]
    last_hidden_state = model_no_head(x).last_hidden_state

    orig_shape = last_hidden_state.shape
    last_hidden_state = last_hidden_state.view(-1, last_hidden_state.shape[-1])
    logits, indices = out_emb_vector.forward(last_hidden_state)

    out = torch.full((last_hidden_state.shape[0], vocab_size), nan_value, dtype=torch.float32, device=x.device)
    out.scatter_(-1, indices, logits)
    out = out.view((orig_shape[0], orig_shape[1], vocab_size))

    # torch.cuda.synchronize()
    return out


model_vector_index(model_inputs).shape

torch.Size([1, 5])


torch.Size([1, 5, 50257])

In [8]:
import timeit

forward_time = timeit.timeit(lambda: model_vector_index(model_inputs), number=50)
forward_ref_time = timeit.timeit(lambda: model_ref(model_inputs), number=50)

print(f"Vector: {forward_time:.6f} seconds")
print(f"Ref: {forward_ref_time:.6f} seconds")
print(f"Speedup: {forward_ref_time / forward_time:.4f}")

Vector: 0.545084 seconds
Ref: 0.438792 seconds
Speedup: 0.8050


In [9]:
models = {
    "VEC": LMWrapper(tokenizer, model_vector_index, device="cuda"),
    "REF": LMWrapper(tokenizer, model_ref, device="cuda"),
}

for name, wrapper in models.items():
    task_manager = lm_eval.tasks.TaskManager()
    model_result = lm_eval.simple_evaluate(
        model=wrapper,
        # "hellaswag", "piqa", "arc_easy", "winogrande", "lambada_openai"
        tasks=["arc_easy"],
        task_manager=task_manager,
    )

    print(f"{name}:", model_result["results"])

INFO:lm-eval:The tag 'arc_ca' is already registered as a group, this tag will not be registered. This may affect tasks you want to call.
INFO:lm-eval:The tag 'arc_ca' is already registered as a group, this tag will not be registered. This may affect tasks you want to call.
INFO:lm-eval:Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
INFO:lm-eval:Using pre-initialized model


README.md:   0%|          | 0.00/9.00k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/331k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/346k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/86.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2251 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2376 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/570 [00:00<?, ? examples/s]

INFO:lm-eval:Building contexts for arc_easy on rank 0...
100%|██████████| 2376/2376 [00:06<00:00, 389.53it/s]
INFO:lm-eval:Running loglikelihood requests
100%|██████████| 9501/9501 [03:37<00:00, 43.62it/s]


VEC: {'arc_easy': {'alias': 'arc_easy', 'acc,none': 0.3345959595959596, 'acc_stderr,none': 0.00968213772432791, 'acc_norm,none': 0.31691919191919193, 'acc_norm_stderr,none': 0.009547254611446386}}


INFO:lm-eval:The tag 'arc_ca' is already registered as a group, this tag will not be registered. This may affect tasks you want to call.
INFO:lm-eval:The tag 'arc_ca' is already registered as a group, this tag will not be registered. This may affect tasks you want to call.
INFO:lm-eval:Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
INFO:lm-eval:Using pre-initialized model
INFO:lm-eval:Building contexts for arc_easy on rank 0...
100%|██████████| 2376/2376 [00:02<00:00, 827.53it/s] 
INFO:lm-eval:Running loglikelihood requests
100%|██████████| 9501/9501 [01:35<00:00, 99.04it/s] 


REF: {'arc_easy': {'alias': 'arc_easy', 'acc,none': 0.3463804713804714, 'acc_stderr,none': 0.009763542075695733, 'acc_norm,none': 0.3324915824915825, 'acc_norm_stderr,none': 0.009666892606130122}}


In [10]:
# REF k=inf: {'alias': 'arc_easy', 'acc,none': 0.43813131313131315, 'acc_stderr,none': 0.010180937100600076, 'acc_norm,none': 0.3947811447811448, 'acc_norm_stderr,none': 0.010030038935883607}
# VEC k=500, ef=200, ef_construction=200: {'alias': 'arc_easy', 'acc,none': 0.36363636363636365, 'acc_stderr,none': 0.009870849346011757, 'acc_norm,none': 0.3371212121212121, 'acc_norm_stderr,none': 0.009700146509130083}

# REF k=100: {'alias': 'arc_easy', 'acc,none': 0.35984848484848486, 'acc_stderr,none': 0.009848484848484846, 'acc_norm,none': 0.3442760942760943, 'acc_norm_stderr,none': 0.009749495321590819}
# VEC k=100, M=32, k=100, ef=100, ef_construction=200: {'alias': 'arc_easy', 'acc,none': 0.33375420875420875, 'acc_stderr,none': 0.00967606568357547, 'acc_norm,none': 0.3181818181818182, 'acc_norm_stderr,none': 0.009557408782506372}
# VEC kx=100, k=150, M=32, k=100, ef=100, ef_construction=200: {'alias': 'arc_easy', 'acc,none': 0.34385521885521886, 'acc_stderr,none': 0.009746660584852448, 'acc_norm,none': 0.3253367003367003, 'acc_norm_stderr,none': 0.0096134277089962}
# VEC kx=100, k=500, M=32, k=100, ef=100, ef_construction=200: {'alias': 'arc_easy', 'acc,none': 0.3531144781144781, 'acc_stderr,none': 0.009807078935467612, 'acc_norm,none': 0.3354377104377104, 'acc_norm_stderr,none': 0.009688175165829607}
# VEC k=100, M=32, k=100, ef=200, ef_construction=200: {'alias': 'arc_easy', 'acc,none': 0.3493265993265993, 'acc_stderr,none': 0.009782853449399291, 'acc_norm,none': 0.33080808080808083, 'acc_norm_stderr,none': 0.009654540125986119}
# VEC k=100, M=32, ef=100, ef_construction=300: {'alias': 'arc_easy', 'acc,none': 0.3392255892255892, 'acc_stderr,none': 0.009714917207765848, 'acc_norm,none': 0.32112794612794615, 'acc_norm_stderr,none': 0.009580787536986797}
# VEC k=100, M=32, ef=200, ef_construction=300: {'alias': 'arc_easy', 'acc,none': 0.3514309764309764, 'acc_stderr,none': 0.00979639558281772, 'acc_norm,none': 0.33375420875420875, 'acc_norm_stderr,none': 0.009676065683575473}

# REF k=50:                                    {'alias': 'arc_easy', 'acc,none': 0.3463804713804714, 'acc_stderr,none': 0.009763542075695733, 'acc_norm,none': 0.3324915824915825, 'acc_norm_stderr,none': 0.009666892606130122}
# VEC k=50, M=32, ef=100, ef_construction=250: {'alias': 'arc_easy', 'acc,none': 0.3341750841750842, 'acc_stderr,none': 0.009679106032919065, 'acc_norm,none': 0.31776094276094274, 'acc_norm_stderr,none': 0.009554033064443064}
# VEC k=50, M=32, ef=150, ef_construction=250: {'alias': 'arc_easy', 'acc,none': 0.33964646464646464, 'acc_stderr,none': 0.009717845628687471, 'acc_norm,none': 0.3253367003367003, 'acc_norm_stderr,none': 0.009613427708996196}
# VEC k=50, M=32, ef=200, ef_construction=250: {'alias': 'arc_easy', 'acc,none': 0.33964646464646464, 'acc_stderr,none': 0.009717845628687471, 'acc_norm,none': 0.3265993265993266, 'acc_norm_stderr,none': 0.009623047038267647}
# VEC k=50, M=24, ef=100, ef_construction=250: {'alias': 'arc_easy', 'acc,none': 0.33291245791245794, 'acc_stderr,none': 0.009669958978395335, 'acc_norm,none': 0.3181818181818182, 'acc_norm_stderr,none': 0.009557408782506372}
# VEC k=50, M=16, ef=100, ef_construction=250: {'alias': 'arc_easy', 'acc,none': 0.3228114478114478, 'acc_stderr,none': 0.009593950220366743, 'acc_norm,none': 0.31776094276094274, 'acc_norm_stderr,none': 0.009554033064443064}
# VEC k=50, M=32, ef=100, ef_construction=500: {'alias': 'arc_easy', 'acc,none': 0.3341750841750842, 'acc_stderr,none': 0.009679106032919067, 'acc_norm,none': 0.3228114478114478, 'acc_norm_stderr,none': 0.00959395022036675}