In [1]:
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets

In [2]:
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# The model and tokenizer come from different hub repos; this assumes the
# model's vocabulary matches the gpt-neo-125M tokenizer (TODO confirm).
model = AutoModelForCausalLM.from_pretrained(
    'roneneldan/TinyStories-33M',
    cache_dir="data/",
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-neo-125M",
    cache_dir="data/",
)

In [4]:
# Pre-tokenized dataset; the cells below read its "text", "preds_len" and
# "split_by_token" columns.
dataset = datasets.load_from_disk(dataset_path="data/tinystories_tokenized")



In [5]:
# starting_indexes[k] is the flat prediction index of sample k's first
# prediction; the last entry is the total number of predictions.
starting_indexes = np.cumsum([0, *dataset["preds_len"]])
def loss_idx_to_dataset_idx(idx, starts=None):
    """Map a flat prediction index to (sample index, pred-in-sample index).

    Flat indices enumerate every next-token prediction, sample after sample.
    For this dataset that is range(0, 10658635) over 20000 samples, with
    pred-in-sample indices in range(0, 1023). Note the token-in-sample index
    is exactly pred-in-sample + 1.

    Args:
        idx: flat prediction index, must satisfy 0 <= idx < starts[-1].
        starts: cumulative start offsets ``[0, n0, n0 + n1, ...]``; defaults
            to the module-level ``starting_indexes``.

    Returns:
        ``(sample_index, pred_in_sample_index)`` as Python ints.

    Raises:
        IndexError: if ``idx`` is outside ``[0, starts[-1])``. Without this
            check a negative index silently maps to sample -1 and an
            index past the end maps to a non-existent sample.
    """
    if starts is None:
        starts = starting_indexes
    total = int(starts[-1])
    if not 0 <= idx < total:
        raise IndexError(f"prediction index {idx} out of range [0, {total})")
    # side="right" skips over empty samples (repeated start offsets) so the
    # index lands in the sample that actually owns it.
    sample_index = np.searchsorted(starts, idx, side="right") - 1
    pred_in_sample_index = idx - starts[sample_index]
    return int(sample_index), int(pred_in_sample_index)

def get_context(idx):
    """Fetch the dataset sample that contains flat prediction ``idx``.

    ``idx`` is in range(0, 10658635). Returns ``(sample, token_idx)``, where
    ``token_idx`` (in range(1, 1024)) is the position of the predicted token
    inside the sample, one past the pred-in-sample index.
    """
    sample_idx, pred_idx = loss_idx_to_dataset_idx(idx)
    token_idx = pred_idx + 1
    return dataset[sample_idx], token_idx

def print_context(idx):
    """Print the text preceding flat prediction ``idx`` (in range(0, 10658635)).

    The predicted token itself is appended with a red ANSI background so it
    stands out from the prompt.
    """
    sample, token_idx = get_context(idx)
    pieces = sample["split_by_token"]
    prefix = "".join(pieces[:token_idx])
    target = pieces[token_idx]
    print(prefix + "\033[41m" + target + "\033[0m")


In [6]:
# Flat per-prediction losses. map_location="cpu" lets this load on a
# CPU-only machine (the notebook falls back to CPU above) even if the
# tensor was saved from a GPU; in this notebook it is only thresholded.
losses = torch.load("data/losses.pt", map_location="cpu")

In [7]:
# Threshold ≈ ln(2): assuming `losses` holds per-token cross-entropy in nats
# (TODO confirm), this keeps predictions where the model assigned the
# correct token more than ~50% probability.
LOWLOSS_THRESHOLD = 0.693
lowloss_idxs = (losses < LOWLOSS_THRESHOLD).nonzero().flatten().tolist()

In [8]:
# Names of all model parameters, in named_parameters() order.
param_names = list(dict(model.named_parameters()))

In [9]:
# Exclude parameters whose names contain 'ln', 'wte' or 'wpe' (presumably
# layer norms and token/position embeddings). This is a substring match.
highsignal_names = [
    name
    for name in param_names
    if not any(tag in name for tag in ('ln', 'wte', 'wpe'))
]

In [10]:
def get_flattened_gradient(model, param_subset):
    """Concatenate the gradients of selected parameters into one 1-D tensor.

    Parameters are visited in ``model.named_parameters()`` order, so the
    layout of the result is the same on every call for a given model and
    subset.

    Args:
        model: module whose selected parameters have populated ``.grad``
            (i.e. called after a backward pass).
        param_subset: iterable of parameter names to include.

    Returns:
        1-D tensor with every selected gradient flattened and concatenated.

    Raises:
        ValueError: if a selected parameter has no gradient (e.g. it did not
            take part in the backward pass).
    """
    wanted = set(param_subset)  # O(1) membership instead of a list scan per parameter
    grads = []
    for name, p in model.named_parameters():
        if name not in wanted:
            continue
        if p.grad is None:
            raise ValueError(f"parameter {name!r} has no gradient; run backward() first")
        grads.append(p.grad.flatten())
    return torch.cat(grads)

In [11]:
# Every 100th low-loss prediction, capped at 10000 of them. A single slice
# with stop 100 * 10000 selects exactly the elements of [::100][:10000].
token_idxs = lowloss_idxs[:100 * 10000:100]

In [12]:
# Length of one flattened gradient vector. Summed over the same
# named_parameters() selection that get_flattened_gradient uses, rather than
# rebuilding model.state_dict() once per name.
len_g = sum(p.numel() for name, p in model.named_parameters() if name in highsignal_names)
# Number of sampled predictions, i.e. the side length of the similarity matrix.
S = len(token_idxs)

In [13]:
block_len = 200
# Consecutive chunks of token_idxs with at most block_len items each; slicing
# clamps at the end, so the last chunk may be shorter.
blocks = [token_idxs[start:start + block_len] for start in range(0, len(token_idxs), block_len)]

In [14]:
# S x S matrix of pairwise gradient cosine similarities, filled block by block.
C = torch.zeros(S, S, device=device)

In [15]:
def _normalized_block_gradients(block):
    """Build the unit-norm gradient matrix for one block of prediction indices.

    For each flat prediction index in ``block``, run a forward pass over the
    sample's text (tokenized with truncation to 1024 tokens), backprop the
    cross-entropy loss of that single prediction, and store the gradient of
    the ``highsignal_names`` parameters as one row.

    Returns:
        Tensor of shape (len(block), len_g) on ``device``. Rows have unit L2
        norm (all-zero rows stay zero), matching F.normalize(p=2, dim=1).
    """
    G = torch.zeros((len(block), len_g), device=device)
    for row, idx in enumerate(block):
        model.zero_grad()
        document, l = get_context(idx)
        tokens = tokenizer(document['text'], return_tensors='pt', max_length=1024, truncation=True).to(device)
        logits = model(**tokens).logits
        targets = tokens.input_ids
        # ls[k] scores logits[k] against token k+1, so the loss of the token at
        # position l is ls[l-1].
        ls = F.cross_entropy(logits[0, :-1, :], targets[0, 1:], reduction='none')
        ls[l - 1].backward()
        G[row] = get_flattened_gradient(model, highsignal_names)
    # Same values as F.normalize(G, p=2, dim=1) with its default eps=1e-12,
    # but done in place so a second full-size copy of G is never allocated.
    G /= G.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
    return G


iouter = 0
for iblock_num, iblock in enumerate(tqdm(blocks, desc="outer loop")):
    Gi = _normalized_block_gradients(iblock)
    # Only blocks j >= i are computed; each product is mirrored into the lower
    # triangle. The first j-block is the i-block, so it starts at row iouter.
    jouter = iouter
    # Only portable tqdm kwargs here: the notebook-only `display=` kwarg raises
    # TqdmKeyError when tqdm.auto falls back to console tqdm.
    for offset, jblock in enumerate(tqdm(blocks[iblock_num:], leave=False, desc="inner loop")):
        # The diagonal block's gradients are the ones already in Gi; reuse them
        # instead of repeating len(iblock) forward/backward passes.
        Gj = Gi if offset == 0 else _normalized_block_gradients(jblock)
        Cij = torch.matmul(Gi, Gj.T)
        C[iouter:iouter + len(iblock), jouter:jouter + len(jblock)] = Cij
        C[jouter:jouter + len(jblock), iouter:iouter + len(iblock)] = Cij.T
        jouter += len(jblock)
        # Release this block before the next one is allocated to cap peak memory.
        del Gj, Cij
    del Gi
    iouter += len(iblock)

outer loop:   0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# C lives on `device`; loading it on a different machine may need map_location.
torch.save(obj=(token_idxs, C), f="data/C-2.pt")