In [1]:
from contextlib import contextmanager
import copy
from typing import Iterator

import tabulate
from torch import nn

import llminference as L
import modify_gptneox as h2o

In [2]:
# Load the model and build a small few-shot SQuAD evaluation set.
N_EXAMPLES = 20  # number of SQuAD examples evaluated by every method below
FEW_SHOT_K = 1  # passed as `k` to L.qa.add_few_shot_prompt
adapter = L.Adapter.from_pretrained("EleutherAI/pythia-1b")
# Keep a handle on the unmodified model: later cells restore it or convert from it.
original_model = adapter.model
data = L.qa.SQuAD.data()
examples = [L.qa.add_few_shot_prompt(data[i], k=FEW_SHOT_K) for i in range(N_EXAMPLES)]
# method name -> list of per-example results from L.qa.evaluate
# (each result is read later via its "match" and "output" keys)
out = {}

Found cached dataset squad (/nethome/douglaso/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /nethome/douglaso/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-03400b5d173ab796.arrow
Loading cached shuffled indices for dataset at /nethome/douglaso/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-c374d8616db3e9e0.arrow


In [3]:
# Baseline: evaluate the unmodified model. Re-assigning `original_model` keeps
# this cell correct even if a later cell has already swapped `adapter.model`.
adapter.model = original_model
out["baseline"] = list(
    L.qa.evaluate(
        adapter,
        examples,
        batch_size=10,
        use_cache=False,
    )
)

Evaluating EleutherAI/pythia-1b: 100%|██████████| 2/2 [02:00<00:00, 60.08s/it]


In [4]:
# "Ours": convert the original model to eviction attention and evaluate it.
# NOTE(review): k / local_k presumably set the total and local attention budgets;
# see L.eviction_attention.Settings to confirm.
adapter.model = L.eviction_attention.convert_gptneox(
    original_model,
    L.eviction_attention.Settings(k=512, local_k=256),
)
out["ours"] = list(
    L.qa.evaluate(
        adapter,
        examples,
        batch_size=10,
        use_cache=False,
        generation_context=L.eviction_attention.generation_context,
    )
)

Evaluating EleutherAI/pythia-1b: 100%|██████████| 2/2 [01:55<00:00, 57.75s/it]


In [5]:
@contextmanager
def h20_generation_context(model: nn.Module) -> Iterator[nn.Module]:
    """Yield `model` for generation, then reset every H2O attention mask.

    On exit, `_reset_masks()` is called on each `h2o.GPTNeoXAttention_Mask`
    submodule of `model`. NOTE(review): this presumably stops mask state built
    up during one generation from leaking into the next; confirm in h2o.

    The reset runs in a `finally` block, so it also happens when the body
    raises (e.g. an interrupt or OOM mid-evaluation). The exception still
    propagates to the caller.

    (The name is kept as `h20_...` for existing callers.)
    """
    try:
        yield model
    finally:
        for m in model.modules():
            if isinstance(m, h2o.GPTNeoXAttention_Mask):
                m._reset_masks()

# H2O comparison: convert a *copy* so `original_model` stays unmodified.
adapter.model = copy.deepcopy(original_model)
# Budgets are written as fractions of 2048, mirroring k=512 / local_k=256 above.
# NOTE(review): 2048 is presumably pythia's context length; confirm how h2o
# interprets these ratios (e.g. relative to the actual sequence length).
adapter.model.config.heavy_ratio = 512/2048
adapter.model.config.recent_ratio = 256/2048
h2o.convert_kvcache_gpt_neox_heavy_recent(adapter.model, adapter.model.config)
# Load the original weights into the converted model. NOTE(review): presumably
# needed because conversion replaces attention modules; confirm in h2o.
# strict=False tolerates key mismatches, so surface them rather than ignore them.
load_report = adapter.model.load_state_dict(original_model.state_dict(), strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f"h2o weight reload: missing={load_report.missing_keys}, "
          f"unexpected={load_report.unexpected_keys}")
# NOTE(review): batch_size=1 presumably because the H2O masks hold per-sequence
# state that h20_generation_context resets after each generation; confirm.
out["h2o"] = list(L.qa.evaluate(adapter, examples, batch_size=1, use_cache=False, generation_context=h20_generation_context))

Evaluating EleutherAI/pythia-1b: 100%|██████████| 20/20 [01:49<00:00,  5.46s/it]


In [10]:
# Exact-match accuracy per evaluated method.
print("Overall exact-match results:")
for name, results in out.items():
    n_correct = sum(example["match"] for example in results)
    print(f"   {name}: {n_correct}/{len(results)}")

Overall exact-match results:
   baseline: 8/20
   ours: 3/20
   h2o: 3/20


In [8]:
# Side-by-side preview of the first `chars` characters each method generated.
# `ours_matches_h2o` compares only these truncated prefixes, not full outputs.
chars = 20
# zip() silently drops rows when lengths differ, so require aligned results.
assert len(out["baseline"]) == len(out["ours"]) == len(out["h2o"]), (
    "all methods must be evaluated on the same examples"
)
print(tabulate.tabulate([
    dict(baseline=repr(b["output"][:chars]),
         ours=repr(o["output"][:chars]),
         h2o=repr(h["output"][:chars]),
         ours_matches_h2o=o["output"][:chars] == h["output"][:chars])
    for b, o, h in zip(out["baseline"], out["ours"], out["h2o"])
], headers="keys", showindex=True))

    baseline                 ours                     h2o                      ours_matches_h20
--  -----------------------  -----------------------  -----------------------  ------------------
 0  ' steam turbines\nQues'  ' coal, oil, natural '   ' coal, oil, natural '   True
 1  ' 1 July 1851\nQuestio'  ' 1851\nQuestion: What'  ' 1851\nQuestion: What'  True
 2  ' English MPs can onl'   ' yes\nQuestion: MPs r'  ' yes\nQuestion: MPs r'  True
 3  ' Nepalese\nQuestion: '  ' Nepalese\nQuestion: '  ' Nepalese\nQuestion: '  True
 4  ' The French killed m'   ' French scounting pa'   ' French scounting pa'   True
 5  ' Committees\nQuestion'  ' The City Council is'   ' The City Council is'   True
 6  ' The Mission Council'   ' The Mission Council'   ' The Mission Council'   True
 7  ' Chest pains\nQuestio'  ' 2 a.m\nQuestion: Wha'  ' 2 a.m\nQuestion: Wha'  True
 8  ' Newton\nQuestion: Wh'  ' Miller\nQuestion: Wh'  ' Newton\nQuestion: Wh'  False
 9  ' The United States\nQ'  ' The United States\