In [2]:
from contextlib import contextmanager
import copy
from typing import Iterator

import tabulate
import torch
from torch import nn

import llminference as L
import modify_gptneox as h2o

# Size of PyTorch's intra-op CPU thread pool; tune to the host machine.
NUM_THREADS = 32
torch.set_num_threads(NUM_THREADS)

In [19]:
# Results of each evaluation run, keyed by method name ("dense", "eviction_ours", ...).
out = {}
adapter = L.Adapter.from_pretrained("EleutherAI/pythia-1.4b")
# Keep a handle on the unmodified model: later cells swap `adapter.model` for converted variants.
original_model = adapter.model
data = L.qa.SQuAD.data()

N_EXAMPLES = 200  # number of SQuAD examples every method is evaluated on
N_SHOTS = 1  # few-shot examples prepended to each prompt (used for both `k` and `shots`)

# The template depends only on the model name and shot count, so build it once
# instead of once per example inside the comprehension.
prompt_template = L.qa.get_default_prompt_template(adapter.model.config._name_or_path, shots=N_SHOTS)
examples = [
    L.qa.add_few_shot_prompt(data[i], k=N_SHOTS, prompt_template=prompt_template)
    for i in range(N_EXAMPLES)
]

Found cached dataset squad (/nethome/douglaso/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /nethome/douglaso/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-03400b5d173ab796.arrow
Loading cached shuffled indices for dataset at /nethome/douglaso/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-c374d8616db3e9e0.arrow


In [20]:
# Baseline: evaluate the unmodified model (no KV-cache eviction).
adapter.model = original_model
out["dense"] = [*L.qa.evaluate(adapter, examples, batch_size=5)]

Evaluating EleutherAI/pythia-1.4b: 100%|██████████| 40/40 [32:20<00:00, 48.51s/it]


In [21]:
# Our eviction attention applied to the original weights.
# NOTE(review): presumably `k` is the total KV-cache budget and `local_k` the always-kept
# recent window (the H2O cell derives heavy = k - local_k, recent = local_k) — confirm
# against `L.eviction_attention.Settings`.
eviction_settings = L.eviction_attention.Settings(
    k=256,
    local_k=64,
    strategy="sum_weight",
)
adapter.model = L.eviction_attention.convert(original_model, eviction_settings)
out["eviction_ours"] = [*L.qa.evaluate(adapter, examples, batch_size=5)]

Evaluating EleutherAI/pythia-1.4b: 100%|██████████| 40/40 [36:01<00:00, 54.04s/it]


In [22]:
@contextmanager
def h2o_generation_context(model: nn.Module) -> Iterator[nn.Module]:
    yield model
    for m in model.modules():
        if isinstance(m, h2o.GPTNeoXAttention_Mask):
            m._reset_masks()

# H2O baseline, configured to approximate the same KV-cache budget as `eviction_settings`.
# The H2O config takes `heavy_ratio` / `recent_ratio` instead of absolute counts, so the
# absolute budgets (k - local_k heavy, local_k recent) are divided by an assumed length.
# NOTE(review): if H2O multiplies these ratios by the *actual* prompt length, prompts
# shorter than 2048 tokens get a smaller cache than k=256 — confirm the comparison
# with `eviction_ours` is like-for-like.
approximate_sequence_length = 2048
# Work on a deep copy so `original_model` stays intact for the other cells.
adapter.model = copy.deepcopy(original_model)
adapter.model.config.heavy_ratio = (eviction_settings.k - eviction_settings.local_k) / approximate_sequence_length
adapter.model.config.recent_ratio = eviction_settings.local_k / approximate_sequence_length
# Converts the model in place (the return value is not used). Presumably this swaps in
# `h2o.GPTNeoXAttention_Mask` attention layers built from the config.
h2o.convert_kvcache_gpt_neox_heavy_recent(adapter.model, adapter.model.config)
# Copy the original weights back into the converted model. strict=False tolerates key
# mismatches introduced by the conversion.
# NOTE(review): the returned missing/unexpected keys are discarded; consider checking them.
adapter.model.load_state_dict(original_model.state_dict(), strict=False)
# Presumably the adapter wraps each generation in this context so the H2O masks are reset
# between examples — confirm in `L.Adapter`.
adapter.model.generation_context = h2o_generation_context
# batch_size=1, unlike the other runs — presumably the H2O masking does not support
# batched/padded inputs; confirm.
out["eviction_h2o"] = list(L.qa.evaluate(adapter, examples, batch_size=1))

Evaluating EleutherAI/pythia-1.4b: 100%|██████████| 200/200 [33:16<00:00,  9.98s/it]


In [23]:
# Count, per method, the examples whose `match` flag is set.
print("Overall exact-match results:")
for method, method_results in out.items():
    n_matched = sum(r["match"] for r in method_results)
    print(f"{method:>15}: {n_matched}/{len(method_results)}")

Overall exact-match results:
          dense: 74/200
  eviction_ours: 60/200
   eviction_h2o: 57/200


In [24]:
# Compare only the first `chars` characters of each generated output.
chars = 20
# zip() silently truncates to its shortest input, and the two counts below zip different
# subsets of the lists; require every method to have been evaluated on the same examples.
assert len(out["dense"]) == len(out["eviction_ours"]) == len(out["eviction_h2o"]), (
    f"result lengths differ: { {k: len(v) for k, v in out.items()} }"
)
# Examples where at least one eviction method's prefix differs from dense.
different_from_dense = sum(
    b["output"][:chars] != o["output"][:chars] or b["output"][:chars] != h["output"][:chars]
    for b, o, h in zip(out["dense"], out["eviction_ours"], out["eviction_h2o"])
)
# Examples where the two eviction methods disagree with each other.
# NOTE: "h20" is a typo for "h2o"; the name is kept in case later cells reference it.
different_ours_h20 = sum(
    o["output"][:chars] != h["output"][:chars]
    for o, h in zip(out["eviction_ours"], out["eviction_h2o"])
)
# If ours != h2o, they cannot both equal dense, so every example counted in
# `different_ours_h20` is also counted in `different_from_dense`. The difference is therefore
# the number of "errors" (deviations from dense) where both eviction methods produced the
# same prefix.
print(f"{different_from_dense-different_ours_h20}/{different_from_dense} error match")

41/79 error match


In [25]:
# Per-example side-by-side view of the first `chars` characters of each method's output.
# The last column (key spelled "h20", meaning h2o) flags agreement between the two eviction methods.
chars = 20
comparison_rows = []
for dense_res, ours_res, h2o_res in zip(out["dense"], out["eviction_ours"], out["eviction_h2o"]):
    dense_text = dense_res["output"][:chars]
    ours_text = ours_res["output"][:chars]
    h2o_text = h2o_res["output"][:chars]
    comparison_rows.append({
        "dense": repr(dense_text),
        "eviction_ours": repr(ours_text),
        "eviction_h2o": repr(h2o_text),
        "ours_matches_h20": ours_text == h2o_text,
    })
print(tabulate.tabulate(comparison_rows, headers="keys", showindex=True))

     dense                    eviction_ours            eviction_h2o             ours_matches_h20
---  -----------------------  -----------------------  -----------------------  ------------------
  0  ' steam plants\nQuesti'  ' steam plants\nQuesti'  ' steam plants\nQuesti'  True
  1  ' 1 July 1851\nQuestio'  ' 1 July 1851\nQuestio'  ' 1 July 1851\nQuestio'  True
  2  ' English MPs\nQuestio'  ' English\nQuestion: M'  ' English MPs can onl'   False
  3  ' Nepali\nQuestion: Wh'  ' Nepali\nQuestion: Wh'  ' Nepali\nQuestion: Wh'  True
  4  ' Joseph Coulon de Ju'   ' Joseph Coulon de Vi'   ' Joseph Brant, a Fre'   False
  5  ' 5 committees\nQuesti'  ' 5\nQuestion: What is'  ' 5\nQuestion: What is'  True
  6  ' the Mission Council'   ' the Mission Council'   ' the Mission Council'   True
  7  ' He prayed, "Into yo'   ' He prayed, "Into yo'   ' He prayed to God to'   False
  8  ' Miller\nQuestion: Wh'  ' Miller\nQuestion: Wh'  ' Miller\nQuestion: Wh'  True
  9  ' The British Empire\n'  ' The 