In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from modify_gptj import modify_gptj

# Checkpoint shared by the tokenizer and both model copies.
MODEL_NAME = "NbAiLab/nb-gpt-j-6B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Pick one device per model copy. The availability check is
# torch.cuda.is_available() (torch has no `cuda_is_available` attribute), and
# "cuda:1" is only a valid device when a second GPU is actually present, so
# fall back to the first device (or CPU) otherwise.
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
cuda0 = "cuda:0" if n_gpus >= 1 else "cpu"
cuda1 = "cuda:1" if n_gpus >= 2 else cuda0

# Unmodified reference model, switched to inference mode.
normal_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(cuda0)
normal_model.eval()

# Second copy of the same checkpoint, passed through the project's
# modify_gptj helper. The helper is called only for its effect on the model;
# its return value is not used.
pos_shift_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(cuda1)
modify_gptj(pos_shift_model)
pos_shift_model.eval()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
model.safetensors:  40%|███▉      | 9.57G/24.2G [12:42<18:51, 12.9MB/s]

In [None]:
# Absolute import: notebook cells execute as __main__ with no parent package,
# so `from .cache_regimes import ...` raises ImportError. This assumes
# cache_regimes.py sits next to the notebook, like modify_gptj -- confirm.
from cache_regimes import SlidingWindowCache, AttentionSinkCache

# Load the report and keep its first 10,000 space-separated words.
# The encoding is explicit so the read does not depend on the platform default
# (assumes report.txt is UTF-8 -- confirm).
with open("report.txt", encoding="utf-8") as f:
    book_text = f.read()
n_words = 10_000
text = " ".join(book_text.split(" ")[:n_words])

# Instantiate the different cache regimes.
# The cache holds one entry fewer than the GPT-J context length
# (presumably to leave room for the token being predicted -- confirm).
GPTJ_CTX_LEN = 2048
cache_size = GPTJ_CTX_LEN - 1

# Plain sliding window over the most recent `cache_size` entries.
sliding_window = SlidingWindowCache(cache_size=cache_size)
# Attention sink: 4 start entries plus recent entries, `cache_size` in total.
attention_sink = AttentionSinkCache(start_size=4, recent_size=cache_size - 4)
# Same cache class with no start entries, i.e. recent entries only.
attention_sink_no_start = AttentionSinkCache(start_size=0, recent_size=cache_size)

In [None]:
# Absolute import: notebook cells execute as __main__ with no parent package,
# so `from .evaluation import ...` raises ImportError. This assumes
# evaluation.py sits next to the notebook, like modify_gptj -- confirm.
from evaluation import evalute_on_text

# Release cached GPU memory held by PyTorch before the evaluation run.
torch.cuda.empty_cache()

# Evaluate the unmodified model with the sliding-window cache.
# `nlls_sw` is presumably a tensor of per-token negative log-likelihoods
# (inferred from the name) -- confirm against evalute_on_text.
nlls_sw = evalute_on_text(normal_model, tokenizer, text, sliding_window)
# Perplexity is the exponential of the mean NLL.
ppl_sw = torch.exp(nlls_sw.mean())

print(f"PPL for sliding window {ppl_sw:.2f}")

In [None]:
# Release cached GPU memory held by PyTorch before the evaluation run.
torch.cuda.empty_cache()

# Evaluate the position-shifted model with the attention-sink cache
# (4 start entries + recent entries).
nlls_as = evalute_on_text(pos_shift_model, tokenizer, text, attention_sink)
# Perplexity is the exponential of the mean of the returned values.
ppl_as = torch.exp(nlls_as.mean())

# Format to two decimals, matching the sliding-window report, instead of
# printing the raw tensor repr (which includes device information).
print(f"PPL for attention sink {ppl_as:.2f}")

In [None]:
# Release cached GPU memory held by PyTorch before the evaluation run.
torch.cuda.empty_cache()

# Evaluate the unmodified model without passing any cache regime, so
# evalute_on_text falls back to its own default for that argument.
nlls_no_cache = evalute_on_text(normal_model, tokenizer, text)
# Perplexity is the exponential of the mean of the returned values.
ppl_no_cache = torch.exp(nlls_no_cache.mean())
# Format to two decimals, matching the sliding-window report, instead of
# printing the raw tensor repr (which includes device information).
print(f"PPL without cache {ppl_no_cache:.2f}")

In [None]:
torch.cuda.empty_cache()

# Evaluate the position-shifted model with a recent-entries-only cache
# (AttentionSinkCache with start_size=0).
# The tokenizer must be the second argument, as in every other
# evalute_on_text call; without it `text` is passed where the tokenizer belongs.
nlls_sw_pos_shift = evalute_on_text(
    pos_shift_model, tokenizer, text, attention_sink_no_start
)
# Perplexity is the exponential of the mean of the returned values.
ppl_sw_pos_shift = torch.exp(nlls_sw_pos_shift.mean())

# Format to two decimals, matching the sliding-window report, instead of
# printing the raw tensor repr (which includes device information).
print(f"PPL for sliding window, recomputing cache {ppl_sw_pos_shift:.2f}")

In [None]:
# Absolute import: notebook cells execute as __main__ with no parent package,
# so `from .plots import ...` raises ImportError. This assumes plots.py sits
# next to the notebook, like modify_gptj -- confirm.
from plots import graph_sliding_window_ppl

# Plot the first 5000 values from the position-shift run.
# The label previously read "Embeddins" (typo).
graph_sliding_window_ppl(
    nlls_sw_pos_shift[:5000],
    cache_regime="Window Attention with Recomputation of Positional Embeddings",
    window_size=16,
)

In [None]:
# Absolute import: notebook cells execute as __main__ with no parent package,
# so `from .plots import ...` raises ImportError. This assumes plots.py sits
# next to the notebook, like modify_gptj -- confirm.
from plots import graph_attentions

n_sentences = 100
n_tokens = 20

# Keep the first n_sentences "."-delimited chunks of the text that tokenize to
# exactly n_tokens ids, so every attention map has the same size and the maps
# can be stacked.
sentences = [
    chunk
    for chunk in book_text.split(".")
    if len(tokenizer(chunk).input_ids) == n_tokens
][:n_sentences]
# Fail early with a readable message instead of an opaque torch.stack error.
assert sentences, f"no sentence in the text tokenizes to exactly {n_tokens} tokens"

# Collect the attention weights of every layer for each sentence.
attns = []
with torch.no_grad():
    for sentence in sentences:
        input_ids = tokenizer(sentence, return_tensors="pt").input_ids.to(
            normal_model.device
        )
        outputs = normal_model(input_ids, output_attentions=True)
        # Stack the per-layer attention tensors along a new leading axis.
        attns.append(torch.stack(outputs.attentions, dim=0))

# Average over sentences, drop size-1 axes, then take index 0 on the second
# remaining axis. NOTE(review): assuming each per-layer tensor is
# (batch=1, heads, seq, seq), this leaves (layers, seq, seq) for head 0, in
# which case `head_idxs` below would index layers -- confirm against
# graph_attentions.
sentence_attns = torch.stack(attns).mean(0).squeeze()[:, 0, :, :].to("cpu")
graph_attentions(sentence_attns, head_idxs=[0, 1, 5, 16], n_cols=2)