In [7]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "3"

import torch

# from trans import DecoderTransformer
# from utils import generate_data, TextReversalDataset

from modeling_gpt_neox import GPTNeoXForCausalLM
from transformers import (
    AutoTokenizer, AutoConfig
)


In [None]:
# Device used for the LiSSA section (hard-coded GPU index).
device = "cuda:1"

# Local, machine-specific directory holding the Pythia-70m checkpoint.
pretrained_model_path = "/data/home/Model/Pythia/pythia-70m"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)

config = AutoConfig.from_pretrained(pretrained_model_path)
# Custom config attribute; presumably consumed by the local modeling_gpt_neox module -- TODO confirm.
config.use_hook = False
model = GPTNeoXForCausalLM.from_pretrained(pretrained_model_path, config=config).to(device)

# Total number of parameter elements in the loaded model.
print(sum([p.numel() for p in model.parameters()]))
def count_mlp_params(model):
    """Sum the element counts of every layer's ``mlp.dense_4h_to_h`` parameters.

    Only the ``dense_4h_to_h`` sub-module of each entry in
    ``model.gpt_neox.layers`` is counted; no other MLP weights are included.
    """
    total = 0
    for block in model.gpt_neox.layers:
        for weight in block.mlp.dense_4h_to_h.parameters():
            total += weight.numel()
    return total
print(count_mlp_params(model))


In [7]:
# from datasets import load_dataset

# ds = load_dataset("/home/Dataset/ptb_text_only", trust_remote_code=True)
# ds

In [9]:
from torch.utils.data import DataLoader, TensorDataset

def collate_fn(batch):
    """Stack the first field of every dataset item into one batch tensor."""
    first_fields = [item[0] for item in batch]
    return torch.stack(first_fields)

def get_dataloader(tokenized_split, batch_size, shuffle):
    """Wrap the ``input_ids`` tensor of a tokenized split in a DataLoader.

    Args:
        tokenized_split: mapping that provides an ``'input_ids'`` tensor.
        batch_size: number of rows per batch.
        shuffle: whether the loader reshuffles the rows.

    Returns:
        A DataLoader whose batches are built by the module-level ``collate_fn``.
    """
    return DataLoader(
        TensorDataset(tokenized_split['input_ids']),
        batch_size,
        shuffle,
        collate_fn=collate_fn,
    )


tokenized_ds = torch.load("tokenized_ds.pt")


## LiSSA

In [9]:
# from utils import get_tokenized_dataset, get_dataloader

# tokenized_train_ds = get_tokenized_dataset(ds['train'], tokenizer, config.max_seq_len)
# train_loader = get_dataloader(tokenized_train_ds, 8, True)


In [None]:
for n, p in model.named_parameters():
    if "dense_h_to_4h" in n:
        print(n)

In [11]:
from torch.func import functional_call, jvp, grad
import torch.nn.functional as F
import einops

# device = "cuda:0"
# model = DecoderTransformer(
#     d_model=config.d_model,
#     n_heads=config.n_heads,
#     d_mlp=config.d_mlp,
#     n_layers=config.n_layers,
#     vocab_size=config.vocab_size,
#     max_seq_len=config.max_seq_len,
#     device=device,
#     use_hook=False,
# ).to(device)
# model.load_state_dict(torch.load("model.pth", map_location=device))
# Name -> parameter mapping used by functional_call in the cells below.
params = dict(model.named_parameters())

# First test sequence, split for next-token prediction:
# `targets` is the sequence shifted by one, `ids` drops the last token and gains a batch axis.
ids = tokenized_ds['test']['input_ids'][0].to(device)
targets = ids[1:]
ids = ids[:-1].unsqueeze(0)


def loss_with_logits(logits, targets):
    """Cross-entropy between ``logits`` and ``targets`` with F.cross_entropy's default reduction."""
    loss = F.cross_entropy(input=logits, target=targets)
    return loss

def hvp_fn(loss_fn, params, v):
    """Hessian-vector product of ``loss_fn`` at ``params`` in direction ``v``.

    Computed forward-over-reverse: the tangent output of ``jvp`` applied to
    ``grad(loss_fn)``.
    """
    gradient_fn = grad(loss_fn)
    _, tangent_out = jvp(gradient_fn, (params,), (v,))
    return tangent_out


In [12]:
def out_f_blocks(mlp_params, i):
    """Run the model with a subset of parameters overridden and return flat logits.

    Args:
        mlp_params: mapping from parameter name to tensor; these entries replace
            the same-named entries of the module-level ``params`` for this call.
        i: input ids passed to the model as its single positional argument.

    Returns:
        Logits rearranged from ``(b, s, v)`` to ``(b*s, v)``.
    """
    # Overlay on a fresh dict instead of writing into the global `params`:
    # under torch.func transforms `mlp_params` holds wrapped tensors, and
    # storing them globally would leak them out of the transform.
    merged = {**params, **mlp_params}

    out = functional_call(model, merged, (i,)).logits
    out = einops.rearrange(out, "b s v -> (b s) v")
    return out


def loss_f_blocks(mlp_params, i, t):
    """Cross-entropy of ``out_f_blocks(mlp_params, i)`` against targets ``t``."""
    return F.cross_entropy(out_f_blocks(mlp_params, i), t)


def ravel_named_params(d: dict[str, torch.Tensor]):
    """Flatten a name -> tensor mapping into one vector (column-major per matrix).

    Every tensor is transposed with ``.t()`` before being flattened, so 2-D
    tensors are laid out column by column; 1-D tensors keep their order.

    Returns:
        ``(flat, unravel)``: ``flat`` is the detached concatenation of the
        flattened tensors in the mapping's iteration order; ``unravel(vec)``
        rebuilds a dict with the same names and shapes from a vector that uses
        the same layout.
    """
    keys = list(d.keys())
    shapes = [tensor.shape for tensor in d.values()]
    pieces = [tensor.t().reshape(-1) for tensor in d.values()]

    def unravel(params_flattened: torch.Tensor):
        rebuilt = {}
        offset = 0
        for key, shape in zip(keys, shapes):
            if len(shape) == 2:
                count = shape[0] * shape[1]
                # Stored transposed, so view as (cols, rows) and transpose back.
                stored_shape = (shape[1], shape[0])
            else:
                count = shape[0]
                stored_shape = shape
            chunk = params_flattened[offset:offset + count]
            rebuilt[key] = chunk.view(stored_shape).t().detach()
            offset += count
        return rebuilt

    return torch.cat(pieces).detach(), unravel


# Subset of `params` whose name contains "dense_h_to_4h".
mlp_params = {}
for n, p in params.items():
    if "dense_h_to_4h" in n:
        mlp_params[n] = p


# Gradient of the query loss w.r.t. the selected parameters only.
grads = grad(loss_f_blocks)(mlp_params, ids, targets)
# Forward-mode product of the logits' Jacobian with that gradient (sanity check of the jvp path).
f = lambda p: out_f_blocks(p, ids)
Jv = jvp(f, (mlp_params,), (grads,))[1]


In [None]:
from typing import Callable
from tqdm import tqdm

from torch.func import jvp, vjp
from torch.utils.data import DataLoader


def sample_labels(logits: torch.Tensor):
    """Draw one class index per row from the categorical distribution softmax(logits).

    For input with more than one dimension the result has one index per row;
    for 1-D input a single index is returned.
    """
    probs = logits.softmax(dim=-1)
    draws = probs.multinomial(num_samples=1, replacement=True)
    if probs.dim() > 1:
        return draws[:, 0]
    return draws[0]


def gnhvp_on_sample(f, L):
    """Build a Gauss-Newton Hessian-vector-product step for model ``f`` and loss ``L``.

    The returned ``gnhvp_step(primals, tangents, inputs)`` computes
    ``J^T H_L J v`` where ``J`` is the Jacobian of ``f(., inputs)`` at
    ``primals``, ``v`` is ``tangents`` and ``H_L`` is the Hessian of ``L``
    w.r.t. the model output, evaluated with labels sampled from the output.
    """
    def gnhvp_step(primals, tangents, inputs):
        def model_out(p):
            return f(p, inputs)

        # J v: push the tangent forward through the model.
        z, R_z = jvp(model_out, (primals,), (tangents,))

        # Labels are drawn from the model's own predictive distribution.
        sampled_labels = sample_labels(z)

        def loss_of_output(y):
            return L(y, sampled_labels)

        # H_L (J v)
        R_gz = hvp_fn(loss_of_output, z, R_z)
        # J^T (H_L J v): pull the result back to parameter space.
        pullback = vjp(model_out, primals)[1]
        return pullback(R_gz)[0]
    return gnhvp_step


def create_gnhvp_estimator(gnhvp_step_fn, parameters, data_loader, device):
    """Bind a GNHVP step to ``parameters`` and to one fixed batch.

    The first batch yielded by ``data_loader`` is moved to ``device`` once and
    reused for every call of the returned function.
    """
    fixed_batch = next(iter(data_loader)).to(device)

    def compute_fn(vec):
        return gnhvp_step_fn(parameters, vec, fixed_batch)
    return compute_fn


def lissa_ignhvp(mvp: Callable,
          vec,
          n_iters: int,
          damping: float,
          alpha: float):
    """Iterate the LiSSA-style recursion for an inverse (GN-)Hessian-vector product.

    Each step computes ``h <- v + (1 - damping*alpha) * h - alpha * mvp(h)``,
    starting from ``h = v``.

    Args:
        mvp: maps a name -> tensor dict to a dict of the same structure.
        vec: name -> tensor dict holding the right-hand side ``v``.
        n_iters: number of recursion steps.
        damping: damping coefficient used in the shrink factor.
        alpha: step scale applied to the matrix-vector product.

    Returns:
        ``(h, logs)``: the final flattened iterate and the norm of the update
        made at every step.
    """
    vec_flattened, unravel_fn = ravel_named_params(vec)
    estimate = vec_flattened.clone().detach()
    shrink = 1 - damping * alpha

    logs = []
    for _ in tqdm(range(n_iters)):
        Ap = ravel_named_params(mvp(unravel_fn(estimate)))[0].detach()
        updated = vec_flattened + shrink * estimate - alpha * Ap
        logs.append(torch.linalg.vector_norm(updated - estimate).item())
        estimate = updated

    return estimate, logs


# LiSSA hyper-parameters.
n_iters = 10
damping = 1e-4
alpha = 1/40
batch_size = 32

# The estimator is bound to a single shuffled training batch (see create_gnhvp_estimator).
train_loader = get_dataloader(tokenized_ds['train'], batch_size, True)
gnhvp_step_fn = gnhvp_on_sample(out_f_blocks, loss_with_logits)
gnhvp_estimator = create_gnhvp_estimator(gnhvp_step_fn, mlp_params, train_loader, device)
ihvp_lissa, logs = lissa_ignhvp(gnhvp_estimator, grads, n_iters, damping, alpha)
# Size of the last update; a small value suggests the recursion has settled.
print(logs[-1])

%matplotlib inline
import matplotlib.pyplot as plt
plt.plot(range(len(logs)), logs)


In [None]:
from scipy.stats import pearsonr


pearsonr(ihvp_lissa.cpu(), ravel_named_params(grads)[0].cpu())


## EK-FAC

In [None]:
device = "cuda:0"
# model = DecoderTransformer(
#     d_model=d_model,
#     n_heads=n_heads,
#     d_mlp=d_mlp,
#     n_layers=n_layers,
#     vocab_size=vocab_size,
#     max_seq_len=max_seq_len,
#     device=device,
#     use_hook=True,
# ).to(device)
# model.load_state_dict(torch.load("model.pth", map_location=device))

config.use_hook = False
model = GPTNeoXForCausalLM.from_pretrained(pretrained_model_path, config=config).to(device)

In [None]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

from tqdm import tqdm
%matplotlib inline
import matplotlib.pyplot as plt

from utils import generate_data, TextReversalDataset


def sample_labels(logits: torch.Tensor):
    """Sample class indices from softmax(logits): one per row, or a single one for 1-D input."""
    probs = torch.softmax(logits, dim=-1)
    picks = probs.multinomial(num_samples=1, replacement=True)
    return picks[:, 0] if len(probs.shape) > 1 else picks[0]


def ekfac_fit_covariance(model: "DecoderTransformer",
                         device: str,
                         dataloader: DataLoader,
                         n_iters: int):
    """Estimate per-block K-FAC factors and return their eigenvector bases.

    For every block's MLP the function accumulates ``A`` (second moment of the
    layer inputs) and ``G`` (second moment of the pre-activation
    pseudo-gradients, obtained with labels sampled from the model output),
    plots the running norms, and eigendecomposes the averaged matrices.

    Args:
        model: network exposing ``blocks``; each ``block.mlp`` must provide
            ``get_dims()``, ``get_a_l_minus_1()`` and ``get_d_s_l()``.
            The annotation is a string because ``DecoderTransformer`` is not
            imported in this notebook.
        device: device on which the accumulators are allocated.
        dataloader: yields ``(inputs, labels)`` pairs; labels are ignored.
        n_iters: number of sampled batches.

    Returns:
        ``(QA, QG)``: lists with one eigenvector matrix per block.
    """
    n_blocks = len(model.blocks)
    batch_size = dataloader.batch_size

    A = []  # Covariance of inputs
    G = []  # Covariance of preactivation pseudo-gradients
    for block in model.blocks:
        out_dim, in_dim = block.mlp.get_dims()
        # A is sized in_dim+1; presumably the stored inputs carry a bias column -- TODO confirm.
        A.append(torch.zeros((in_dim+1, in_dim+1), device=device))
        G.append(torch.zeros((out_dim, out_dim), device=device))

    logs = [{"A": [], "G": []} for _ in range(n_blocks)]
    for i in tqdm(range(n_iters), desc="Fitting covariance matrices A&G"):
        # A fresh iterator per step draws one batch from a newly shuffled pass.
        img, _ = next(iter(dataloader))
        img = img.to(device)

        model.zero_grad()
        outputs = model(img)
        outputs = einops.rearrange(outputs, "b s v -> (b s) v")
        sampled_labels = sample_labels(outputs)
        loss = F.cross_entropy(outputs, sampled_labels, reduction="sum")
        loss.backward()

        for block_idx, block in enumerate(model.blocks):
            mlp = block.mlp
            inputs_ = mlp.get_a_l_minus_1()
            d_s_l = mlp.get_d_s_l()
            seq_len = inputs_.shape[1]
            # Sum of per-token outer products == X^T X over the flattened
            # (batch, position) rows; one matmul instead of a Python double loop.
            a_rows = inputs_[:batch_size].reshape(-1, inputs_.shape[-1])
            g_rows = d_s_l[:batch_size, :seq_len].reshape(-1, d_s_l.shape[-1])
            A[block_idx] += a_rows.T @ a_rows
            G[block_idx] += g_rows.T @ g_rows

            logs[block_idx]["A"].append(torch.linalg.norm(A[block_idx]/((i+1)*batch_size)).detach().cpu())
            logs[block_idx]["G"].append(torch.linalg.norm(G[block_idx]/((i+1)*batch_size)).detach().cpu())

    # squeeze=False keeps `axes` two-dimensional even when there is one block.
    fig, axes = plt.subplots(2, n_blocks, figsize=(12,8), squeeze=False)
    for i in range(n_blocks):
        axes[0, i].plot(range(n_iters), logs[i]["A"])
        axes[0, i].set_title(f"A[{i}]")
        axes[1, i].plot(range(n_iters), logs[i]["G"])
        axes[1, i].set_title(f"G[{i}]")

    QA = []
    QG = []
    for block_idx in range(n_blocks):
        A[block_idx] /= n_iters*batch_size
        G[block_idx] /= n_iters*batch_size

        # Only the eigenvectors are kept; eigenvalues are refitted separately.
        _, qa = torch.linalg.eigh(A[block_idx])
        _, qg = torch.linalg.eigh(G[block_idx])
        QA.append(qa.detach())
        QG.append(qg.detach())
    return QA, QG


# EK-FAC fitting hyper-parameters.
n_iters = 5000
damping = 1e-4
batch_size = 1

# NOTE(review): `trainset_size`, `testset_size` and `max_seq_len` are not assigned in any
# visible cell (only commented-out assignments exist) -- confirm where they come from.
train_data, train_targets, test_data, test_targets = generate_data(0, trainset_size+testset_size, testset_size, max_seq_len)
trainset = TextReversalDataset(train_data, train_targets, max_seq_len)
testset = TextReversalDataset(test_data, test_targets, max_seq_len)
train_loader = DataLoader(trainset, batch_size, True)

# Eigenvector bases of the per-block input / pseudo-gradient covariances.
QA, QG = ekfac_fit_covariance(model, device, train_loader, n_iters)


In [21]:
def vectorize(x: torch.Tensor):
    """Stack the columns of a matrix into a single vector (column-major order).

    For ``x`` of shape ``(M, N)`` the result is
    ``[x[:, 0], x[:, 1], ..., x[:, N-1]]`` concatenated into length ``M*N``.
    """
    transposed = x.t()
    return transposed.reshape(-1)

def unvectorize(x: torch.Tensor, rows, cols):
    """Rebuild a ``(rows, cols)`` matrix from its column-stacked vector."""
    as_columns = torch.reshape(x, (cols, rows))
    return as_columns.t()

In [None]:
import torch.nn.functional as F

def ekfac_fit_diagonal(model: "DecoderTransformer",
                       device: str,
                       dataloader: DataLoader,
                       n_iters: int,
                       QA: list[torch.Tensor],
                       QG: list[torch.Tensor]):
    """Fit the EK-FAC diagonal in the eigenbasis given by ``QA`` and ``QG``.

    For each sampled batch the weight pseudo-gradient of every block's MLP
    (labels sampled from the model output) is rotated as ``QG^T dW QA``,
    column-stacked and squared; the squares are averaged over
    ``n_iters * batch_size``.

    Args:
        model: network exposing ``blocks``; each ``block.mlp`` must provide
            ``get_dims()`` and ``get_d_w_l()``. The annotation is a string
            because ``DecoderTransformer`` is not imported in this notebook.
        device: device on which the accumulators are allocated.
        dataloader: yields ``(inputs, labels)`` pairs; labels are ignored.
        n_iters: number of sampled batches.
        QA: per-block eigenvectors of the input covariance.
        QG: per-block eigenvectors of the pseudo-gradient covariance.

    Returns:
        List with one vector of length ``(in_dim+1)*out_dim`` per block.
    """
    n_blocks = len(model.blocks)
    batch_size = dataloader.batch_size

    Lambda = []
    for block in model.blocks:
        out_dim, in_dim = block.mlp.get_dims()
        Lambda.append(torch.zeros(((in_dim+1)*out_dim), device=device))

    for _ in tqdm(range(n_iters), desc="Fitting diagonal"):
        # A fresh iterator per step draws one batch from a newly shuffled pass.
        img, _lbl = next(iter(dataloader))
        img = img.to(device)

        model.zero_grad()
        outputs = model(img)
        outputs = einops.rearrange(outputs, "b s v -> (b s) v")
        sampled_labels = sample_labels(outputs)
        loss = F.cross_entropy(outputs, sampled_labels, reduction="sum")
        loss.backward()

        for block_idx, block in enumerate(model.blocks):
            dw = block.mlp.get_d_w_l()
            # Rotate the gradient into the Kronecker eigenbasis before squaring.
            rotated = QG[block_idx].T @ dw @ QA[block_idx]
            Lambda[block_idx] += vectorize(rotated).pow(2)

    for block_idx in range(n_blocks):
        Lambda[block_idx] /= n_iters*batch_size
    return Lambda

# Rebuilds the same datasets/loader as the covariance-fitting cell.
# NOTE(review): `trainset_size`, `testset_size` and `max_seq_len` are not assigned in any
# visible cell -- confirm where they come from.
train_data, train_targets, test_data, test_targets = generate_data(0, trainset_size+testset_size, testset_size, max_seq_len)
trainset = TextReversalDataset(train_data, train_targets, max_seq_len)
testset = TextReversalDataset(test_data, test_targets, max_seq_len)
train_loader = DataLoader(trainset, batch_size, True)

# Per-block diagonal fitted in the (QA, QG) eigenbasis.
Lambda = ekfac_fit_diagonal(model, device, train_loader, n_iters, QA, QG)


In [25]:
torch.save(QA, "QA.pt")
torch.save(QG, "QG.pt")
torch.save(Lambda, "Lambda.pt")

In [26]:
QA = torch.load("QA.pt", map_location=device)
QG = torch.load("QG.pt", map_location=device)
Lambda = torch.load("Lambda.pt", map_location=device)

In [38]:
def example_grads(net,
                  inputs: torch.Tensor,
                  targets: torch.Tensor):
    """Back-propagate the cross-entropy of ``net(inputs)`` and collect MLP weight gradients.

    Returns:
        List with ``block.mlp.get_d_w_l()`` for every block in ``net.blocks``,
        read after the backward pass.
    """
    net.zero_grad()
    logits = einops.rearrange(net(inputs), "b s v -> (b s) v")
    F.cross_entropy(logits, targets).backward()

    return [block.mlp.get_d_w_l() for block in net.blocks]


def ekfac_ihvp_single_block(qa: torch.Tensor,
                            qg: torch.Tensor,
                            diagonal: torch.Tensor,
                            damping: float,
                            v: torch.Tensor):
    """Apply the damped EK-FAC inverse of one block to the matrix ``v``.

    Computes ``qg @ ((qg^T v qa) / (Lambda + damping)) @ qa^T`` where
    ``Lambda`` is ``diagonal`` reshaped column-major to the shape of ``v``.

    Args:
        qa: eigenvectors of the input covariance.
        qg: eigenvectors of the pseudo-gradient covariance.
        diagonal: column-stacked diagonal with ``v.numel()`` entries;
            it is NOT modified.
        damping: scalar added to every diagonal entry.
        v: matrix to precondition.

    Returns:
        Tensor with the same shape as ``v``.
    """
    qg_v_qa = qg.T @ v @ qa
    # Out-of-place add: an in-place `+=` would alter the caller's tensor and
    # make the damping accumulate across repeated calls.
    damped = diagonal + damping
    # Column-major un-vectorisation (inverse of `x.t().reshape(-1)`).
    damped = damped.reshape(v.shape[1], v.shape[0]).t()
    result = qg_v_qa / damped
    ihvp = qg @ result @ qa.T
    return ihvp


def ekfac_ihvp(QA: list[torch.Tensor],
               QG: list[torch.Tensor],
               Lambda: list[torch.Tensor],
               damping: float,
               vec: list[torch.Tensor]):
    """Apply the per-block EK-FAC inverse to ``vec`` and concatenate the results.

    Each block's result is column-stacked with ``vectorize`` before the
    blocks are joined into one flat tensor.
    """
    per_block = [
        vectorize(ekfac_ihvp_single_block(qa, qg, diagonal, damping, v))
        for qa, qg, diagonal, v in zip(QA, QG, Lambda, vec)
    ]
    return torch.cat(per_block)


# Use the first training example as the vector to precondition.
img, lbl = trainset[0]
img, lbl = img.to(device), lbl.to(device)
# Per-block MLP weight gradients of that example's loss.
vec = example_grads(model, img, lbl)
# Flat EK-FAC inverse-Hessian-vector product over all blocks.
ihvp_ekfac = ekfac_ihvp(QA, QG, Lambda, damping, vec)


In [None]:
from scipy.stats import pearsonr


pearsonr(ihvp_lissa.cpu(), ihvp_ekfac.cpu())


In [None]:
from scipy.stats import pearsonr


pearsonr(ihvp_lissa.cpu(), ravel_named_params(grads)[0].cpu())


## CG

In [13]:
from torch.func import functional_call, jvp, grad
import torch.nn.functional as F
import einops

params = dict(model.named_parameters())


def out_f_all(params, i):
    """Run the model functionally with ``params`` and return logits flattened to ``(b*s, v)``."""
    logits = functional_call(model, params, (i,)).logits
    return einops.rearrange(logits, "b s v -> (b s) v")


def loss_with_logits(logits, targets):
    """Return F.cross_entropy(logits, targets) with the default (mean) reduction."""
    value = F.cross_entropy(logits, targets)
    return value

def hvp_fn(loss_fn, params, v):
    """Hessian-vector product: directional derivative of ``grad(loss_fn)`` at ``params`` along ``v``."""
    primal_and_tangent = jvp(grad(loss_fn), (params,), (v,))
    return primal_and_tangent[1]


In [None]:
from typing import Callable
from tqdm import tqdm

from torch.func import jvp, grad, vjp


def sample_labels(logits: torch.Tensor):
    """Draw labels from softmax(logits): a vector of indices for batched input, a scalar for 1-D input."""
    probs = torch.nn.functional.softmax(logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1, replacement=True)
    if probs.ndim > 1:
        return sampled[:, 0]
    return sampled[0]


def gnhvp_on_sample(f, L):
    """Return a function computing the Gauss-Newton product ``J^T H_L J v`` on one batch.

    ``f(params, batch)`` is the model output and ``L(output, labels)`` the
    loss; labels are sampled from the model output inside every step.
    """
    def gnhvp_step(primals, tangents, batch):
        def forward(p):
            return f(p, batch)

        # Model output and its directional derivative J v.
        z, R_z = jvp(forward, (primals,), (tangents,))

        sampled_labels = sample_labels(z)

        def output_loss(y):
            return L(y, sampled_labels)

        # Loss Hessian (w.r.t. the output) applied to J v.
        R_gz = hvp_fn(output_loss, z, R_z)
        # Transposed Jacobian maps the result back to parameter space.
        _, f_vjp = vjp(forward, primals)
        (param_space,) = f_vjp(R_gz)
        return param_space
    return gnhvp_step


def create_gnhvp_estimator(gnhvp_step_fn, parameters, data_loader, device):
    """Freeze one batch from ``data_loader`` and return ``vec -> gnhvp_step_fn(parameters, vec, batch)``."""
    batch_iter = iter(data_loader)
    ids = next(batch_iter).to(device)

    return lambda vec: gnhvp_step_fn(parameters, vec, ids)


def conjugate_gradient(mvp_fn, b, damping, max_iter):
    """Approximately solve ``(A + damping * I) x = b`` with conjugate gradient.

    Args:
        mvp_fn: maps a name -> tensor dict ``p`` to a dict holding ``A p``.
        b: name -> tensor dict with the right-hand side.
        damping: scalar added to the diagonal of ``A``.
        max_iter: number of CG iterations (no early stopping).

    Returns:
        ``(x, logs)``: the flattened solution estimate (same layout as
        ``ravel_named_params(b)``) and the norm of each iteration's update.
    """
    b_flattened, unravel_fn = ravel_named_params(b)

    # x0 = 0, so the initial residual r0 = b - A x0 equals b.
    x = torch.zeros_like(b_flattened)
    r = b_flattened.clone().detach()
    p = r.clone().detach()
    rdotr = r.dot(r)

    logs = []
    for _ in tqdm(range(max_iter)):
        Ap = ravel_named_params(mvp_fn(unravel_fn(p)))[0]
        Ap += damping * p
        # Small epsilon guards against division by zero once converged.
        v = rdotr / (p.dot(Ap)+1e-11)
        # Step along +p: the residual recurrence below (r -= v * Ap) is only
        # consistent with x += v * p; subtracting would yield the negated solution.
        x_new = x + v * p
        logs.append(torch.linalg.vector_norm(x - x_new).item())
        x = x_new
        r -= v * Ap
        newrdotr = r.dot(r)
        mu = newrdotr / rdotr
        p = r + mu * p
        rdotr = newrdotr
    return x, logs


# CG hyper-parameters.
n_iters = 100
damping = 1e-4
batch_size = 32

train_loader = get_dataloader(tokenized_ds['train'], batch_size, True)
# gnhvp_step_fn = gnhvp_on_sample(out_f_blocks, loss_with_logits)
# gnhvp_estimator = create_gnhvp_estimator(gnhvp_step_fn, mlp_params, train_loader, device)

# Estimator over ALL model parameters (the MLP-only variant is commented out above).
gnhvp_step_fn = gnhvp_on_sample(out_f_all, loss_with_logits)
gnhvp_estimator = create_gnhvp_estimator(gnhvp_step_fn, params, train_loader, device)

# NOTE(review): `grads` as computed in the LiSSA section only covers the "dense_h_to_4h"
# parameters, while this estimator differentiates w.r.t. every entry of `params`;
# jvp expects primals and tangents with matching structure -- confirm `grads` is
# recomputed over all parameters before this cell runs.
ihvp_cg, logs = conjugate_gradient(gnhvp_estimator, grads, damping, n_iters)
print(logs[-1])

%matplotlib inline
import matplotlib.pyplot as plt
plt.plot(range(len(logs)), logs)


In [None]:
from scipy.stats import pearsonr


pearsonr(ihvp_lissa.cpu(), ihvp_cg.cpu())


In [None]:
from scipy.stats import pearsonr


pearsonr(ihvp_ekfac.cpu(), ihvp_cg.cpu())


In [None]:
from scipy.stats import pearsonr


pearsonr(ravel_named_params(grads)[0].cpu(), ihvp_cg.cpu())


## TRAK

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "3"

# device = "cuda:3"
device = "cuda"
pretrained_model_path = "/data/home/Model/Pythia/pythia-70m"

In [None]:
import torch
from torch import nn

from modeling_gpt_neox import GPTNeoXForCausalLM
from transformers import (
    AutoTokenizer, AutoConfig
)


class CustomModel(nn.Module):
    """Thin wrapper exposing a positional ``(input_ids, attention_mask)`` forward.

    The wrapped model is switched to eval mode on construction and is called
    with both tensors passed as keyword arguments.
    """

    def __init__(self, model):
        super().__init__()
        # nn.Module.eval() returns the module itself.
        self.model = model.eval()

    def forward(self, input_ids, attention_mask):
        keyword_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        return self.model(**keyword_inputs)


# Left padding; the EOS token doubles as the pad token.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path, padding_side='left')
tokenizer.pad_token_id = tokenizer.eos_token_id

# The model is built from a local config file; weights are loaded later per checkpoint.
config = AutoConfig.from_pretrained("config.json")
config.use_hook = False
_model = GPTNeoXForCausalLM(config=config)
# _model.load_state_dict(torch.load("model.ckpt", map_location=device))
_model = _model.to(device)
# Wrapper with the positional (input_ids, attention_mask) forward used by the TRAK task below.
model = CustomModel(_model)


In [3]:
from datasets import load_dataset


ds = load_dataset("/home/Dataset/ptb_text_only", trust_remote_code=True)


In [None]:
import pickle


with open("choices_candidates.pkl", "rb") as fp:
    choices_candidates = pickle.load(fp)

with open("choices_queries.pkl", "rb") as fp:
    choices_queries = pickle.load(fp)


import numpy as np
choices_candidates = np.array(choices_candidates)
choices_candidates.size, len(set(choices_candidates.ravel().tolist()))

In [5]:
ds_candidates = ds['train'].select(choices_candidates[0])
ds_queries = ds['test'].select([choices_queries[0]])

In [6]:
from torch.utils.data import DataLoader


max_length = 512

def tokenize_and_pad(examples):
    """Tokenize ``examples['sentence']``, truncating and padding every row to ``max_length``.

    Uses the module-level ``tokenizer`` and ``max_length`` and returns
    PyTorch tensors.
    """
    sentences = examples['sentence']
    return tokenizer(
        sentences,
        max_length=max_length,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )


ds['train'] = ds_candidates
ds['test'] = ds_queries
tokenized_ds = ds.map(tokenize_and_pad, batched=True, remove_columns=['sentence'])


In [None]:
from typing import Iterable

from torch import Tensor
import torch.nn.functional as F

from trak import TRAKer
from trak.modelout_functions import AbstractModelOutput

import utils


class CustomModelOutput(AbstractModelOutput):
    """TRAK task definition that uses the summed next-token cross-entropy as model output."""

    @staticmethod
    def get_output(model, weights, buffers, input_ids, attention_mask, label):
        """Return the summed cross-entropy of one example.

        The inputs arrive without a batch axis; one is added before the
        functional call. Logits and labels are flattened over positions.
        """
        # kw_inputs = {
        #     "input_ids": input_ids,
        #     "attention_mask": attention_mask,
        # }
        input_ids = input_ids.unsqueeze(0)
        attention_mask = attention_mask.unsqueeze(0)
        outputs = torch.func.functional_call(model, (weights, buffers), args=(input_ids, attention_mask))
        logits = outputs.logits
        logits = logits.reshape(-1, logits.size(-1))
        label = label.reshape(-1)
        # NOTE(review): every position contributes to the loss, including padded ones
        # (no ignore_index / mask is applied here) -- confirm this is intended.
        loss = F.cross_entropy(logits, label, reduction="sum")
        return loss

    @staticmethod
    def get_out_to_loss_grad(model, weights, buffers, batch: Iterable[Tensor]) -> Tensor:
        """Return a ``(batch_size, 1)`` tensor: per-example sum of d(loss)/d(logits)."""
        input_ids, attention_mask, labels = batch
        # kw_inputs = {
        #     "input_ids": input_ids,
        #     "attention_mask": attention_mask,
        # }
        outputs = torch.func.functional_call(model, (weights, buffers), args=(input_ids, attention_mask))
        logits = outputs.logits
        batch_size = logits.size(0)
        logits = logits.reshape(-1, logits.size(-1))
        labels = labels.reshape(-1)
        # Gradient of the summed cross-entropy w.r.t. the flattened logits.
        out_grads = torch.func.grad(lambda logits, labels: F.cross_entropy(logits, labels, reduction='sum'))(logits, labels)
        # out_grads = torch.sum(out_grads, dim=1, keepdim=True).clone().detach()
        out_grads = out_grads.reshape(batch_size, -1)
        # NOTE(review): the cross-entropy gradient w.r.t. logits is softmax - one_hot, whose
        # entries sum to zero at every position, so this per-example sum looks like it is
        # always ~0 -- confirm this is the intended out-to-loss gradient.
        out_grads = torch.sum(out_grads, dim=1, keepdim=True)
        # print(out_grads.shape)
        return out_grads


traker = TRAKer(model=model,
                task=CustomModelOutput,
                train_set_size=len(ds['train']),
                device=device,
                proj_dim=4096,
                lambda_reg=1e-2)


In [8]:
from tqdm import tqdm
import os


ckpt_dir = "checkpoints/"
n_ckpts = 10

batch_size = 48

def collate_fn(batch):
    """Collate tokenized rows into ``(input_ids, attention_mask, labels)`` for next-token prediction.

    Inputs and mask drop the last position; labels are the ids shifted left by one.
    """
    input_ids = torch.stack([torch.tensor(row['input_ids'][:-1]) for row in batch])
    attention_mask = torch.stack([torch.tensor(row['attention_mask'][:-1]) for row in batch])
    labels = torch.stack([torch.tensor(row['input_ids'][1:]) for row in batch])
    return (input_ids, attention_mask, labels)

train_loader = DataLoader(tokenized_ds['train'], batch_size, collate_fn=collate_fn)


In [None]:
import time


# Accumulated wall-clock seconds spent in featurization (checkpoint loading excluded).
time_featurize = 0
for model_id in tqdm(range(n_ckpts), desc="Checkpoints"):
    # NOTE(review): this loop reads "<id>.ckpt" while the scoring cell reads "<id>.pkl" --
    # confirm both refer to the same checkpoints.
    ckpt = torch.load(os.path.join(ckpt_dir, f"{model_id}.ckpt"), map_location=device)
    model.model.load_state_dict(ckpt)
    traker.load_checkpoint(model.state_dict(), model_id=model_id)

    start = time.time()
    for batch in tqdm(train_loader, desc=f"Ckpt [{model_id}]"):
        batch = [x.to(device) for x in batch]
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])
    end = time.time()
    time_featurize += end-start

# Debugging aid: list every registered logger and switch it to DEBUG.
import logging
loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
for logger in loggers:
    print(logger.name)
    logger.setLevel(logging.DEBUG)

# Finalization is timed as part of featurization.
start = time.time()
traker.finalize_features()
end = time.time()
time_featurize += end-start
print(f"Time on featurizing: {time_featurize:.3f} s")


In [None]:
test_loader = DataLoader(tokenized_ds['test'], batch_size, shuffle=False, collate_fn=collate_fn)

# Accumulated wall-clock seconds spent scoring (checkpoint loading excluded).
time_score = 0
for model_id in range(n_ckpts):
    # NOTE(review): the featurizing cell reads "<id>.ckpt" while this one reads "<id>.pkl" --
    # confirm both refer to the same checkpoints.
    ckpt = torch.load(os.path.join(ckpt_dir, f"{model_id}.pkl"), map_location=device)
    model.model.load_state_dict(ckpt)

    start = time.time()
    traker.start_scoring_checkpoint(exp_name="test",
                                    checkpoint=model.state_dict(),
                                    model_id=model_id,
                                    num_targets=len(ds['test']))

    for batch in tqdm(test_loader):
        batch = [x.to(device) for x in batch]
        traker.score(batch, num_samples=batch[0].shape[0])
    end = time.time()
    time_score += end-start

start = time.time()
scores = traker.finalize_scores(exp_name="test")
end = time.time()
time_score += end-start
# Report the accumulated total (per-checkpoint scoring + finalization), mirroring the
# featurizing cell; `end-start` alone would only cover finalize_scores.
print(f"Time on scoring: {time_score:.3f} s")


In [None]:
import os
os.makedirs("infls/trak", exist_ok=True)

n_queries = len(choices_queries)
n_candidates = len(choices_candidates[0])

for i in range(n_queries):
    infls = scores[:, i].tolist()
    with open(f"infls/trak/{i}.pkl", "wb") as fp:
        pickle.dump(infls, fp)


## MHA/MLP parameter count

In [None]:
from modeling_gpt_neox import GPTNeoXForCausalLM
from transformers import (
    AutoTokenizer, AutoConfig
)
import torch
config = AutoConfig.from_pretrained("config.json")
config.use_hook = False
model = GPTNeoXForCausalLM(config=config)
model.load_state_dict(torch.load("model.ckpt", map_location="cpu"))

In [None]:
# Tally parameter elements per sub-module family using the sub-module path,
# so that the attention output projection (`attention.dense`) is attributed
# to MHA instead of being caught by a generic "dense" substring match.
total = 0
n_mha_params = 0
n_mlp_params = 0
for n, p in model.named_parameters():
    print(n)
    total += p.numel()
    if ".attention." in n:
        # query_key_value and the attention output projection
        n_mha_params += p.numel()
    elif ".mlp." in n:
        # dense_h_to_4h and dense_4h_to_h
        n_mlp_params += p.numel()
total, n_mha_params, n_mlp_params, n_mha_params/(n_mha_params+n_mlp_params), n_mlp_params/(n_mha_params+n_mlp_params)

## Sample influence estimation targets and candidates

In [3]:
import numpy as np
# from utils import generate_data, TextReversalDataset
from utils import get_tokenized_dataset, get_dataloader
from datasets import load_dataset


def sample_uniform(seed: int, n_total: int, n_choices: int):
    """Pick ``n_choices`` distinct indices from ``range(n_total)`` reproducibly.

    Re-seeds NumPy's global random state with ``seed`` before sampling and
    returns the picks as a plain Python list.
    """
    np.random.seed(seed)
    population = np.arange(n_total)
    picked = np.random.choice(population, size=n_choices, replace=False)
    return picked.tolist()

n_queries = 10
n_candidates = 500
seed = 42

# trainset_size = 100_000
# testset_size = 1_000

ds = load_dataset("/home/Dataset/ptb_text_only", trust_remote_code=True)
# train_data, train_targets, test_data, test_targets = generate_data(0, trainset_size+testset_size, testset_size, max_seq_len)
# trainset = TextReversalDataset(train_data, train_targets, max_seq_len)
# testset = TextReversalDataset(test_data, test_targets, max_seq_len)

choices_queries = sample_uniform(seed, len(ds['test']), n_queries)

choices_candidates = []
for i in range(n_queries):
    choices_candidates.append(sample_uniform(seed+1+i, len(ds['train']), n_candidates))


In [None]:
ds['test']

In [4]:
import pickle

with open("choices_candidates.pkl", "wb") as fp:
    pickle.dump(choices_candidates, fp)

with open("choices_queries.pkl", "wb") as fp:
    pickle.dump(choices_queries, fp)


## Influence correlation

In [None]:
import os
import pickle

def read_infls(infl_dir: str):
    """Load every pickle in ``infl_dir`` and return the objects in file order.

    Files whose stem is an integer (``0.pkl``, ``1.pkl``, ...) are ordered
    numerically, so ``10.pkl`` follows ``9.pkl`` rather than ``1.pkl``; any
    other files come afterwards in lexicographic order.

    Args:
        infl_dir: directory containing one pickle file per query.

    Returns:
        List with one unpickled object per file.
    """
    def file_order(name):
        stem = os.path.splitext(name)[0]
        if stem.isdecimal():
            return (0, int(stem), name)
        return (1, 0, name)

    infls = []
    for f in sorted(os.listdir(infl_dir), key=file_order):
        with open(os.path.join(infl_dir, f), "rb") as fp:
            # pickle executes arbitrary code on load: only read trusted files.
            infl = pickle.load(fp)
        infls.append(infl)
    return infls


infls = {}
algorithms = sorted(os.listdir("infls"))
for alg in algorithms:
    infls[alg] = read_infls(f"infls/{alg}")
    print(f"{alg}: {len(infls[alg])}")


In [40]:
import numpy as np

def fix_tensor_type(infls: list):
    """Convert each per-query influence list into a NumPy array.

    Lists whose first element is a ``torch.Tensor`` are detached and moved to
    CPU element by element before conversion; everything else (including an
    empty list) goes straight through ``np.array``.

    Args:
        infls: list with one sequence of influence scores per query.

    Returns:
        List of NumPy arrays in the same order.
    """
    infls_fixed = []
    for infl in infls:
        # The class is compared by name so torch need not be imported here.
        # The emptiness check avoids indexing into a list with no scores.
        holds_tensors = len(infl) > 0 and str(infl[0].__class__) == "<class 'torch.Tensor'>"
        if holds_tensors:
            infls_fixed.append(np.array([i.detach().cpu() for i in infl]))
        else:
            infls_fixed.append(np.array(infl))
    return infls_fixed

for alg in algorithms:
    infls[alg] = fix_tensor_type(infls[alg])


In [41]:
for i in range(len(infls['trak'])):
    infls['trak'][i] = infls['trak'][i][i*500:(i+1)*500]


### Pearson Correlation

In [42]:
import numpy as np
from scipy.stats import pearsonr

# Algorithm whose influences serve as the reference.
baseline = "cg-all"

# results[alg] = (mean, std) of the per-query Pearson correlation against the baseline.
results = {}
for alg in algorithms:
    if alg != baseline:
        ps = []
        for infls1, infls2 in zip(infls[alg], infls[baseline]):
            ps.append(pearsonr(infls1, infls2).statistic)
        std = np.std(ps)
        # `ps` is rebound from the list of correlations to their mean.
        ps = np.mean(ps)
        
        results[alg] = ps, std


In [None]:
from tabulate import tabulate

table_content = [(n, p) for n, p in results.items()]
print(tabulate(table_content, headers=["Algorithm", "Pearson correlation"], tablefmt="pipe"))


In [None]:
results_sorted = dict(sorted(results.items(), key=lambda item: np.abs(item[1][0]), reverse=True))
table_content = [(n, p) for n, p in results_sorted.items()]
print(tabulate(table_content, headers=["Algorithm", "Pearson correlation"], tablefmt="pipe"))


In [None]:
from pytablewriter import MarkdownTableWriter
from pytablewriter.style import Cell, Style

# `results` maps algorithm -> (mean correlation, std). Rank by |mean| and show
# |mean| in the table; using the whole tuple as sort key or inside abs() fails.
results_sorted = dict(sorted(results.items(), key=lambda item: np.abs(item[1][0]), reverse=True))
table_content = [(n, f"{abs(mean):.6f}") for n, (mean, _std) in results_sorted.items()]


def style_filter(cell: Cell, **kwargs):
    """Center every body cell; bold the value in row 0 and italicize the value in row 1.

    Header cells get no custom style (``None``). The emphasis applies to
    column 1 only.
    """
    if cell.is_header_row():
        return None

    emphasis_by_row = {
        0: {"font_weight": "bold"},
        1: {"font_style": "italic"},
    }
    style = Style(align="center")
    if cell.col == 1 and cell.row in emphasis_by_row:
        style.update(**emphasis_by_row[cell.row])
    return style


table_writer = MarkdownTableWriter(
    table_name="Scalability validation (MLP)",
    headers=["Algorithm", "Pearson correlation"],
    value_matrix=table_content,
    flavor="github",
)
table_writer.add_style_filter(style_filter)
table_writer.write_table()


### Spearman Correlation

In [45]:
import numpy as np
from scipy.stats import spearmanr

# Algorithm whose influences serve as the reference.
baseline = "cg-all"

# results[alg] = (mean, std) of the per-query Spearman correlation against the baseline.
results = {}
for alg in algorithms:
    if alg != baseline:
        ps = []
        for infls1, infls2 in zip(infls[alg], infls[baseline]):
            ps.append(spearmanr(infls1, infls2).statistic)
        std = np.std(ps)
        # `ps` is rebound from the list of correlations to their mean.
        ps = np.mean(ps)
        results[alg] = ps, std


In [None]:
from tabulate import tabulate

table_content = [(n, p) for n, p in results.items()]
print(tabulate(table_content, headers=["Algorithm", "Spearman correlation"], tablefmt="pipe"))


In [None]:
results_sorted = dict(sorted(results.items(), key=lambda item: np.abs(item[1][0]), reverse=True))
table_content = [(n, p) for n, p in results_sorted.items()]
print(tabulate(table_content, headers=["Algorithm", "Spearman correlation"], tablefmt="pipe"))


In [None]:
from pytablewriter import MarkdownTableWriter
from pytablewriter.style import Cell, Style

# `results` maps algorithm -> (mean correlation, std). Rank by |mean| and show
# |mean| in the table; using the whole tuple as sort key or inside abs() fails.
results_sorted = dict(sorted(results.items(), key=lambda item: np.abs(item[1][0]), reverse=True))
table_content = [(n, f"{abs(mean):.6f}") for n, (mean, _std) in results_sorted.items()]


def style_filter(cell: Cell, **kwargs):
    """Style body cells: centered, with the column-1 value bold in row 0 and italic in row 1.

    Returns ``None`` for header cells so they keep the writer's defaults.
    """
    if cell.is_header_row():
        return None

    style = Style(align="center")
    in_value_column = cell.col == 1
    if in_value_column and cell.row == 0:
        style.update(font_weight="bold")
    elif in_value_column and cell.row == 1:
        style.update(font_style="italic")
    return style


table_writer = MarkdownTableWriter(
    table_name="Scalability validation (MLP)",
    headers=["Algorithm", "Spearman correlation"],
    value_matrix=table_content,
    flavor="github",
)
table_writer.add_style_filter(style_filter)
table_writer.write_table()


In [None]:
with open("/home/Dataset/wikitext-103/wiki.valid.tokens", "r") as fp:
    valid = fp.read()
valid

## Plotting

### Pearson correlation

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
import seaborn as sns
sns.reset_orig()
sns.set_theme(context='paper', style='whitegrid')

import numpy as np
import json

SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

plt.rc('font', size=MEDIUM_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=BIGGER_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=MEDIUM_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=BIGGER_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


with open("results.json", "r") as fp:
    results = json.load(fp)

results = [r for r in results if "all" in r['alg'] or "ekfac" in r['alg']]

for i in range(len(results)):
    results[i]["overhead"] /= 3600
    results[i]["job"] /= 3600*10
    results[i]["time"] = results[i]["overhead"] + results[i]["job"]

fig = plt.figure()
axes = []
for rec in results:
    ax = plt.scatter(rec['time'], rec['pc'], color=rec['color'], marker=rec['marker'], s=64)
    axes.append(ax)
    if 'ekfac-all' == rec['alg'] or 'ekfac-mlp' == rec['alg'] or 'arnoldi' in rec['alg'] or 'trak' in rec['alg']:
        ax = plt.scatter(rec['job'], rec['pc'], color=rec['color'], marker=rec['marker'], s=64, facecolors='none')
        axes.append(ax)


# Number of consecutive scatter handles merged into each legend entry.
# NOTE(review): the sizes are hard-coded and must match the order and count of the
# scatter calls made above for the filtered `results` -- confirm after changing the filter.
groups = [4, 2, 2, 4, 1, 1, 2, 2]
axes_legend = []
p = 0
for group in groups:
    if group > 1:
        # Several handles share one legend row (rendered side by side via HandlerTuple).
        axes_legend.append(tuple(axes[p:p+group]))
    else:
        axes_legend.append(axes[p])
    p += group
# NOTE(review): `groups` yields 8 handle entries but 9 labels are listed below
# ("Arnoldi" appears twice) -- confirm which label is superfluous.
plt.legend(axes_legend,
    [
    "CG-all",
    "Ours",
    "Ours-mlp",
    "LiSSA",
    "Arnoldi",
    "GDP",
    "CKA",
    "Arnoldi",
    "TRAK",
    ],
    handler_map={tuple: HandlerTuple(ndivide=None)},
    loc=9,
    bbox_to_anchor=(1.2, 0.663)
    )
# Axis decoration for the Pearson-correlation scatter plot, followed by export
# of the current figure as a PDF.
plt.xlabel("Time (h)")
# Alternative label kept from an earlier variant of the plot:
# plt.xlabel("Time (s)")
# NOTE(review): the x scale is switched to log a few lines below, after these
# ticks are set. Changing the scale may reset the tick locator, and a tick at 0
# cannot be placed on a log axis -- confirm the ticks render as intended.
plt.xticks(np.arange(0, 21, 2))
plt.ylabel("Pearson correlation")
# Only the upper y limit is pinned; the lower limit is left to autoscaling.
plt.ylim(top=1.02)
plt.title("Ground truth: CG (All parameters)")
plt.xscale('log')

import os
figures_dir = "paper-figures/"
# Create the output directory on first use; a no-op when it already exists.
os.makedirs(figures_dir, exist_ok=True)
figure_name = "pearson_corr_transformer_all.pdf"
# bbox_inches='tight' lets the saved bounding box grow to include artists
# outside the axes (the legend is anchored beyond the right edge).
plt.savefig(os.path.join(figures_dir, figure_name), format="pdf", bbox_inches='tight')


### Spearman correlation

In [None]:
import json
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
import seaborn as sns
sns.reset_orig()
sns.set_theme(context='paper', style='whitegrid')
import numpy as np

# Font-size presets (in points) for the figure text. SMALL_SIZE is defined
# for completeness but is not referenced by the settings below.
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

# One (rc group, option, size) row per setting; every row is forwarded to
# plt.rc as `plt.rc(group, option=size)`, in the order listed.
for rc_group, rc_option, rc_size in (
    ('font', 'size', MEDIUM_SIZE),         # default text size
    ('axes', 'titlesize', BIGGER_SIZE),    # axes title
    ('axes', 'labelsize', BIGGER_SIZE),    # x and y axis labels
    ('xtick', 'labelsize', MEDIUM_SIZE),   # x tick labels
    ('ytick', 'labelsize', MEDIUM_SIZE),   # y tick labels
    ('legend', 'fontsize', BIGGER_SIZE),   # legend entries
    ('figure', 'titlesize', BIGGER_SIZE),  # figure-level title
):
    plt.rc(rc_group, **{rc_option: rc_size})

# Read the benchmark records and keep only the algorithms whose name mentions
# "all" or "ekfac".
with open("results.json", "r") as fp:
    results = [
        entry for entry in json.load(fp)
        if "all" in entry['alg'] or "ekfac" in entry['alg']
    ]

# Rescale the timings in place and derive the total plotted on the x axis.
for rec in results:
    rec["overhead"] /= 3600     # presumably seconds -> hours (axis label is "Time (h)")
    rec["job"] /= 3600*10       # NOTE(review): extra factor 10 is undocumented -- confirm what it averages over
    rec["time"] = rec["overhead"] + rec["job"]

fig = plt.figure()
axes = []  # scatter handles in drawing order; the legend code slices this list
for rec in results:
    marker_style = dict(color=rec['color'], marker=rec['marker'], s=64)
    # Filled marker: Spearman correlation ('sc') against overhead + job time.
    axes.append(plt.scatter(rec['time'], rec['sc'], **marker_style))
    has_job_only_marker = (
        rec['alg'] in ('ekfac-all', 'ekfac-mlp')
        or any(tag in rec['alg'] for tag in ('arnoldi', 'trak'))
    )
    if has_job_only_marker:
        # Hollow marker for the same record, placed at the job time alone.
        axes.append(plt.scatter(rec['job'], rec['sc'], facecolors='none', **marker_style))


# Number of consecutive scatter handles in `axes` that belong to each legend
# row, in the same order as `legend_labels` below.
groups = [4, 2, 2, 4, 1, 1, 2, 2]
legend_labels = [
    "CG-all",
    "Ours",
    "Ours-mlp",
    "LiSSA",
    "GDP",
    "CKA",
    "Arnoldi",
    "TRAK",
]

# Slice `axes` group by group: several handles are bundled into a tuple
# (rendered together by HandlerTuple), a lone handle is passed as-is.
axes_legend = []
p = 0
for group in groups:
    chunk = axes[p:p+group]
    axes_legend.append(tuple(chunk) if group > 1 else axes[p])
    p += group

legend_options = dict(
    handler_map={tuple: HandlerTuple(ndivide=None)},
    loc=9,
    bbox_to_anchor=(1.2, 0.663),
)
plt.legend(axes_legend, legend_labels, **legend_options)
# Axis decoration for the Spearman-correlation scatter plot, followed by export
# of the current figure as a PDF.
plt.xlabel("Time (h)")
# Alternative label kept from an earlier variant of the plot:
# plt.xlabel("Time (s)")
# NOTE(review): the x scale is switched to log a few lines below, after these
# ticks are set. Changing the scale may reset the tick locator, and a tick at 0
# cannot be placed on a log axis -- confirm the ticks render as intended.
plt.xticks(np.arange(0, 21, 2))
plt.ylabel("Spearman correlation\n($\\rightarrow$Better)")
# Only the upper y limit is pinned; the lower limit is left to autoscaling.
plt.ylim(top=1.02)
# The backslash is escaped explicitly: "\l" is not a valid Python escape
# sequence and triggers a SyntaxWarning on recent interpreters. The resulting
# string (and therefore the rendered mathtext) is unchanged.
plt.title("($\\leftarrow$Better)")
plt.xscale('log')

import os
figures_dir = "paper-figures/"
# Create the output directory on first use; a no-op when it already exists.
os.makedirs(figures_dir, exist_ok=True)
figure_name = "spearman_corr_transformer_all.pdf"
# bbox_inches='tight' lets the saved bounding box grow to include artists
# outside the axes (the legend is anchored beyond the right edge).
plt.savefig(os.path.join(figures_dir, figure_name), format="pdf", bbox_inches='tight')


## Time trend with scale

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.reset_orig()
sns.set_theme(context='paper', style='whitegrid')

# Font-size presets (in points) for the figure text. SMALL_SIZE is defined
# for completeness but is not referenced by the settings below.
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

# One (rc group, option, size) row per setting; every row is forwarded to
# plt.rc as `plt.rc(group, option=size)`, in the order listed.
for rc_group, rc_option, rc_size in (
    ('font', 'size', MEDIUM_SIZE),         # default text size
    ('axes', 'titlesize', BIGGER_SIZE),    # axes title
    ('axes', 'labelsize', BIGGER_SIZE),    # x and y axis labels
    ('xtick', 'labelsize', MEDIUM_SIZE),   # x tick labels
    ('ytick', 'labelsize', MEDIUM_SIZE),   # y tick labels
    ('legend', 'fontsize', BIGGER_SIZE),   # legend entries
    ('figure', 'titlesize', BIGGER_SIZE),  # figure-level title
):
    plt.rc(rc_group, **{rc_option: rc_size})


# Hard-coded measurements, one entry per model size. The axis labels set
# later in this cell read "Number of parameters" and "Run time (s)".
n_params = [268800, 1861632, 10014720, 56494080, 108876800, 217845760]
overhead = [90.72, 93.4, 134, 820.82, 1376.82, 5201.11]
job = [7.62, 7.65, 9.46, 16.95, 47.54, 52.49]

# Parameter count -> overhead lookup (not used further within this cell).
pd_table = dict(zip(n_params, overhead))

# Total run time per model size: overhead plus job time.
time = [t_overhead + t_job for t_overhead, t_job in zip(overhead, job)]

# Overhead in colour C0 and total time in C1: connecting lines are drawn
# first, then the measured points on top in the matching colour.
series = ((overhead, 'C0'), (time, 'C1'))
for values, colour in series:
    plt.plot(n_params, values, c=colour)
for values, colour in series:
    plt.scatter(n_params, values, c=colour)

def line(x, a):
    """Evaluate a straight line through the origin, ``y = a * x``.

    Used as the model function handed to ``scipy.optimize.curve_fit``.

    Args:
        x: Input value(s); a scalar or anything supporting ``a * x``.
        a: Slope of the line.

    Returns:
        The product ``a * x``.
    """
    scaled = a * x
    return scaled

import os
import numpy as np
from scipy import optimize

# Least-squares fit of the one-parameter model `line` (slope only, no
# intercept) to the overhead measurements.
parameters, covariance = optimize.curve_fit(line, n_params, overhead)

# Evaluate the fitted line on a regular grid from 0 up to (but excluding) the
# largest model size, and overlay it as a dashed red line.
x = np.arange(0, n_params[-1], 1e4)
fitted = line(x, *parameters)
plt.plot(x, fitted, 'r--')

plt.xlabel("Number of parameters")
plt.ylabel("Run time (s)")

# Export the current figure as a PDF, creating the target directory if needed.
figures_dir = "paper-figures/"
os.makedirs(figures_dir, exist_ok=True)
figure_name = "infl-time.pdf"
figure_path = os.path.join(figures_dir, figure_name)
plt.savefig(figure_path, format="pdf")
