# Imports

In [1]:
# Standard library imports
import importlib
import gc
import copy

# Third-party imports
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import einops
import matplotlib.pyplot as plt
from transformers import GPT2Model, GPT2Config, GPT2Tokenizer
from datasets import load_dataset
from transformer_lens.utils import tokenize_and_concatenate

# Local imports
import toy_models.xornet
import toy_models.tms
import toy_models.train
import toy_models.transformer_wrapper
import eigenestimation_algorithm.train
import eigenestimation_algorithm.eigenestimation
import eigenestimation_algorithm.evaluation

# Reload modules for interactive sessions
importlib.reload(toy_models.xornet)
importlib.reload(toy_models.tms)
importlib.reload(toy_models.train)
importlib.reload(toy_models.transformer_wrapper)
importlib.reload(eigenestimation_algorithm.train)
importlib.reload(eigenestimation_algorithm.eigenestimation)
importlib.reload(eigenestimation_algorithm.evaluation)

# Specific imports from local modules
from toy_models.xornet import XORNet, GenerateXORData
from toy_models.tms import Autoencoder, GenerateTMSData
from toy_models.train import TrainModel
from toy_models.transformer_wrapper import TransformerWrapper, DeleteParams, KLDivergenceLoss
from eigenestimation_algorithm.eigenestimation import EigenEstimation
from eigenestimation_algorithm.train import TrainEigenEstimation
from eigenestimation_algorithm.evaluation import (
    PrintFeatureVals,
    ActivatingExamples,
    PrintFeatureValsTransformer,
    PrintActivatingExamplesTransformer,
)

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)


  from .autonotebook import tqdm as notebook_tqdm


# Toy Models

Setup toy models, generate/pull data, and train.

## XORNet

In [None]:
X_xornet, Y_xornet, dataloader_xornet = GenerateXORData(n_repeats=100, batch_size=24)
model_xornet = XORNet()

_, _, _ =TrainModel(
    model=model_xornet,
    criterion=nn.MSELoss(),
    learning_rate=.01,
    dataloader=dataloader_xornet,
    n_epochs=1000
)

## TMS

In [None]:
#@title Train TMS
n_features = 5
hidden_dim = 2
n_datapoints = 1024
sparsity = .075

batch_size = 12
learning_rate = .1
n_epochs = 1000

X_tms, Y_tms, dataloader_TMS = GenerateTMSData(
    num_features=n_features, num_datapoints=n_datapoints, sparsity=sparsity, batch_size=batch_size)
tms_model = Autoencoder(n_features, hidden_dim)
_, _, _ = TrainModel(tms_model, nn.MSELoss(), learning_rate, dataloader_TMS, n_epochs=n_epochs)


# Plot TMS representations.
en = copy.deepcopy(tms_model.W).detach().cpu().numpy()

for i in range(en.shape[1]):
  plt.plot([0, en[0,i]], [0,en[1,i]], 'b-')
plt.show()

## 2-layer transformer




In [2]:
# @title Import pretrained gpt2 (2 layers)
# Disable fused kernels (FlashAttention and memory-efficient attention)
# We have to disable this to compute second-order gradients on transformer models.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

# Ensure the math kernel is enabled (it is True by default)
torch.backends.cuda.enable_math_sdp(True)

# Load in a 2-L GPT2.
config = GPT2Config.from_pretrained('gpt2', n_layer=2)
gpt2 = GPT2Model.from_pretrained('gpt2', config=config)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2",)
tokenizer.pad_token = tokenizer.eos_token
transformer_model = TransformerWrapper(gpt2, tokenizer)

# Make the eigenestimation a little smaller but only looking at a subset of the parameters.
# Pick a random subset of tensors to include in paramters, and turn the rest into frozen buffers.
params_to_delete = [name for name, param in transformer_model.named_parameters()]
params_to_delete = [p for p in params_to_delete if p!='transformer.h.1.ln_1.weight']

# Delete 3/4 of the parameters.
#for p in (params_to_delete[::20]):
#  params_to_delete.remove(p)

DeleteParams(transformer_model, params_to_delete)

print(sum([p.numel() for p in transformer_model.parameters()]))
for n,_ in transformer_model.named_parameters(): print(n)

768
transformer.h.1.ln_1.weight


In [3]:
# Load in data.
imdb_dataset = load_dataset("imdb", split="test[:1%]")
X_transformer= tokenize_and_concatenate(imdb_dataset, tokenizer, max_length = 24, add_bos_token=False)['tokens']
transformer_dataloader = DataLoader(X_transformer, batch_size=24, shuffle=True,
                                    generator=torch.Generator(device='cuda'))

# Eigenestimation

# Tests on Toy Models


## Xornet

In [None]:
n_u_vectors = 3
batch_size = 12
lambda_penalty = 1
repeats = 24
n_epochs = 100
learning_rate = .01

torch.cuda.empty_cache()
gc.collect()
eigenmodel_xornet = EigenEstimation(model_xornet, nn.MSELoss, n_u_vectors)

dataloader_xornet_eigen = DataLoader(
    einops.repeat(X_xornet, 's f -> (s r) f', r=repeats), batch_size=batch_size, shuffle=True,
    generator=torch.Generator(device=device))

TrainEigenEstimation(eigenmodel_xornet, dataloader_xornet_eigen, learning_rate, n_epochs, lambda_penalty)

# Clear cuda cache
torch.cuda.empty_cache()
gc.collect()


In [None]:
4*10*10*700

In [None]:
# Look at features
PrintFeatureVals(X_xornet, eigenmodel_xornet)

for f_idx in range(eigenmodel_xornet.n_u_vectors):
  sample, val = ActivatingExamples(X_xornet, eigenmodel_xornet, f_idx, 3)
  print('feature', f_idx)
  for s, v in zip(sample, val):
    print(s, '->', v)

## TMS

In [None]:
#@title Train Eigenmodel
n_u_vectors = 5
batch_size = 24
lambda_penalty = 1
n_epochs = 100
learning_rate = .01


dataloader = DataLoader(X_tms, batch_size=batch_size, shuffle=True,
                               generator=torch.Generator(device='cuda'))
eigenmodel_tms = EigenEstimation(tms_model, nn.MSELoss, n_u_vectors)
TrainEigenEstimation(eigenmodel_tms, dataloader, learning_rate, n_epochs, lambda_penalty)

In [None]:
#@title Look at features
X = X_tms[:10]
PrintFeatureVals(X_tms[:10], eigenmodel_tms)

for f_idx in range(eigenmodel_tms.n_u_vectors):
  sample, val = ActivatingExamples(X_tms[:100,], eigenmodel_tms, f_idx, 3)
  print('feature', f_idx)
  for s, v in zip(sample, val):
    print(s.round(3), '->', v.round(3))

## 2L Transformer

In [None]:
#@title Train Eigenmodel
#model = transformer_model
n_u_vectors = 10
batch_size = 4
lambda_penalty = 1
n_epochs = 10
learning_rate = .01

torch.cuda.empty_cache()
gc.collect()

transformer_dataloader = DataLoader(X_transformer[:100,:4], batch_size=batch_size, shuffle=True,
                                    generator=torch.Generator(device='cuda'))
eigenmodel_transformer = EigenEstimation(transformer_model, KLDivergenceLoss, n_u_vectors)
#TrainEigenEstimation(eigenmodel_transformer, transformer_dataloader, learning_rate, n_epochs, lambda_penalty)

# SCRATCH

In [None]:
eigenmodel_transformer

In [None]:
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.func import vmap

# Assume model, x, targets, and loss_fn are defined
# model = ...
# x = ...
# targets = ...
# loss_fn = ...

def get_flat_params(model):
    return parameters_to_vector(model.parameters())

def set_flat_params(model, flat_params):
    vector_to_parameters(flat_params, model.parameters())

# Get flat parameters w0
w0 = get_flat_params(model).detach()
w0.requires_grad_(True)

def f(w):
    set_flat_params(model, w)
    outputs = model(x)
    loss = loss_fn(outputs, targets)
    return loss

def hvp_func(u_i):
    hvp = torch.autograd.functional.hvp(f, w0, u_i)[1]
    return hvp

# Prepare U tensor
k_vectors = eigenmodel.n_u_vectors  # Number of u vectors
n_params = w0.numel()               # Number of parameters
# Assuming eigenmodel._parameters contains the u vectors
U = torch.cat([param.view(k_vectors, -1) for param in eigenmodel._parameters.values()], dim=1)

# Compute HVPs
HVPs = vmap(hvp_func)(U)

# HVPs is of shape (k_vectors, n_params)
print("Hessian-Matrix Product H @ U:")
print(HVPs)


In [None]:
# @title Import pretrained gpt2 (2 layers)
# Disable fused kernels (FlashAttention and memory-efficient attention)
# We have to disable this to compute second-order gradients on transformer models.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

# Ensure the math kernel is enabled (it is True by default)
torch.backends.cuda.enable_math_sdp(True)

# Load in a 2-L GPT2.
config = GPT2Config.from_pretrained('gpt2', n_layer=2)
gpt2 = GPT2Model.from_pretrained('gpt2', config=config)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2",)
tokenizer.pad_token = tokenizer.eos_token
transformer_model = TransformerWrapper(gpt2, tokenizer)

# Make the eigenestimation a little smaller but only looking at a subset of the parameters.
# Pick a random subset of tensors to include in paramters, and turn the rest into frozen buffers.
params_to_delete = [name for name, param in transformer_model.named_parameters()]
params_to_delete = [p for p in params_to_delete if p!='transformer.h.1.ln_1.weight']

# Delete 3/4 of the parameters.
#for p in (params_to_delete[::20]):
#  params_to_delete.remove(p)

DeleteParams(transformer_model, params_to_delete)

print(sum([p.numel() for p in transformer_model.parameters()]))
for n,_ in transformer_model.named_parameters(): print(n)

In [None]:
from torch.func import functional_call, vmap, grad

def compute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    loss = eigenmodel_transformer.compute_loss()
    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, x, y)

print(ft_per_sample_grads)

In [247]:
transformer_model

TransformerWrapper(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-1): 2 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)

In [348]:
from torch import autograd
import time
#model = nn.Linear(10, 20)


torch.cuda.reset_max_memory_allocated()
gc.collect()
# Run your code

#x = torch.randn(8, 10) # 10 elements long

model = transformer_model
#x = X_transformer[[1],:4]
x.requires_grad=False

t  = time.time()
vecs = tuple([einops.repeat(torch.randn_like(p, requires_grad=True), '...->k ...', k=k) for p in model.parameters()])

x = X_transformer[1,:]
vec = tuple(k[0] for k in vecs)
for _ in range(10):

    out_model = model(x)
    loss = nn.MSELoss(reduction='none')(out_model, out_model.detach()).mean(dim=-1)

    dims = ' '.join([f'd{i}' for i in range(len(loss.shape))])

    out = loss.flatten()#einops.rearrange(loss, f'{dims} -> ({dims})') # Flatten.
    outputs = [i for i in out]
    print(time.time()-t, 'timex!')

    grads = autograd.grad(outputs=outputs, inputs=model.parameters(), create_graph=True, grad_outputs = [i for i in torch.eye(len(outputs))], is_grads_batched=True) # 160 x 20 x 10 - (batch x outputs) x (params)
    print(time.time()-t, 'timey!')

    #print(len(grads), grads[0].shape, 'grads')
    #print(time.time()-t, 'time0!')

    p=sum([einops.einsum(g, v, 'o ... , ... -> o') for g,v in zip(grads, vec)]) # (batch x outputs) x k.
    #print(p.shape, 'p')

    p_dims = ' '
    p_reshape = [i for i in einops.rearrange(p, 'o->(o)')] # (batch k)
    #print(len(p_reshape))
    grads2 = autograd.grad(p_reshape, vec, create_graph=True, grad_outputs = [i for i in torch.eye(len(p_reshape))], is_grads_batched=True) # for each grad - (batch x outputs x k) x (params)
    print(time.time()-t, 'time1!')
    #grads2 = autograd.grad(p_reshape, vecs, create_graph=True, grad_outputs = [i for i in torch.eye(len(p_reshape))], is_grads_batched=True) # for each grad - (batch x outputs x k) x (params)
    print(time.time()-t, 'time2!!')
    print(len(grads2), grads2[0].shape, 'grads2')

    #dims_grad2 =  [' '.join([f'd{i}' for idd in range(len(g.shape[1:]))]) for g in grads2]
    #grads2_v = tuple((g**2).sum(dim=-1) for g in grads2)
    grads2_v = tuple((g**2).sum(dim=-1) for g in grads2)
    #print(time.time()-t, 'time3!!')

    print(len(grads2_v), grads2_v[0].shape, 'grads2_v')

    #grads2_v = sum([einops.einsum(g,v, 'batch outputs k ... , k ... -> batch outputs k') for g,v,d in zip(grads2_rearranged, vecs, dims)])


    peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
    

print(peak_memory, 'memory')    
print(time.time()-t, 'time!')

0.004847526550292969 timex!
0.010789155960083008 timey!
0.012555360794067383 time1!
0.01257777214050293 time2!!
1 torch.Size([24, 768]) grads2
1 torch.Size([24]) grads2_v
0.016208648681640625 timex!
0.021465778350830078 timey!
0.023014307022094727 time1!
0.02303600311279297 time2!!
1 torch.Size([24, 768]) grads2
1 torch.Size([24]) grads2_v
0.026692867279052734 timex!
0.03182673454284668 timey!
0.03330254554748535 time1!
0.03332662582397461 time2!!
1 torch.Size([24, 768]) grads2
1 torch.Size([24]) grads2_v
0.03707098960876465 timex!
0.04202914237976074 timey!
0.04360079765319824 time1!
0.043624162673950195 time2!!
1 torch.Size([24, 768]) grads2
1 torch.Size([24]) grads2_v
0.047182559967041016 timex!
0.05252408981323242 timey!
0.05401182174682617 time1!
0.05403304100036621 time2!!
1 torch.Size([24, 768]) grads2
1 torch.Size([24]) grads2_v
0.057550907135009766 timex!
0.06273174285888672 timey!
0.06419849395751953 time1!
0.06426095962524414 time2!!
1 torch.Size([24, 768]) grads2
1 torch.Si

In [345]:
out_model.shape

torch.Size([4, 768])

In [316]:
x.shape

torch.Size([10, 24])

In [314]:
print(len(grads), grads[0].shape, 'grads')

1 torch.Size([240, 768]) grads


In [210]:
 f'(b o k {dims}) -> b o k {dims}'

'(b o k d0 d1) -> b o k d0 d1'

In [196]:
vecs[0].shape

torch.Size([7, 20, 10])

In [32]:
from torch import autograd
linear = nn.Linear(10, 20)

x = torch.randn(1, 10)
out = [i for i in (linear(x)**2)[0]]


flat_params = torch.cat([p.view(-1) for p in linear.parameters()])# flatten.requires_grad=True
v = torch.randn([10, flat_params.shape[0]], requires_grad=True)

p=einops.einsum(out, v, )

p.backward(create_graph=True)

print(v.grad)

linear.zero_grad()

grads = autograd.grad(out, linear.parameters(), create_graph=True)
grads

NameError: name 'p' is not defined

In [30]:
flat_params.shape[0]

220

In [17]:
out

[tensor(1.0687, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.1070, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.8654, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0577, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0269, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0672, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0040, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.1864, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0787, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.8409, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0864, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0186, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0904, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0019, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.0746, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.3169, device='cuda:0', grad_fn=<UnbindBackward0>),
 tensor(0.1385, device='

In [31]:
torch.cuda.empty_cache()

gc.collect()

#from torch.autograd
from torch.func import functional_call, vmap, grad

def myloss(X, w0):
    return eigenmodel_transformer.compute_loss(X, w0)




def myloss2(X, w0, u):

    grads = torch.func.jacrev(myloss, argnums=1)(X, w0)##torch.autograd.grad(loss, w0.values(), create_graph=True)
    grad_u = torch.cat([einops.einsum(g,u, 'batch p, p->batch') for g, u in zip(grads.values(), u.values())])
    # Flatten gradients and u vectors
#    grads_flat = torch.cat([g.view(-1) for g in grads.values()])
#    u_flat = torch.cat([u_i.view(-1) for u_i in u.values()])
    
    # Compute dot product
#    grad_u = torch.dot(grads_flat, u_flat)
    return grad_u

def hvp_inhouse(params, X, w0):
    u = {name:v[0] for name, v in params.items()}

    #loss = eigenmodel_transformer.compute_loss(X, w0).sum()  # Sum over batch

    # Compute gradient of loss w.r.t. parameters
    #grads = torch.func.grad(myloss, argnums=1)(X, w0)##torch.autograd.grad(loss, w0.values(), create_graph=True)
    
    #print(grads)
    # Flatten gradients and u vectors
    #grads_flat = torch.cat([g.view(-1) for g in grads])
    #u_flat = torch.cat([u_i.view(-1) for u_i in u.values()])
    
    #print(grads_flat.shape)
    #print(u_flat.shape)
    # Compute dot product
    #grad_u = torch.dot(grads_flat, u_flat)
    

    # Compute Hessian-vector product
    hvp = torch.func.grad(myloss2, argnums=1, is_grads_batched=True)(X, w0, u)

    # Flatten hvp and u vectors
    hvp_flat = torch.cat([h.view(-1) for h in hvp.values()])
    u_flat = torch.cat([u_i.view(-1) for u_i in u.values()])

    # Compute dot product
    dH_du = torch.dot(hvp_flat, u_flat)
    return dH_du


torch.cuda.reset_max_memory_allocated()
# Run your code
peak_memory = torch.cuda.max_memory_allocated() / 1024**2



#test = hvp_inhouse(X_transformer[:3,], eigenmodel_transformer._parameters)



# Compute HVPs for all vectors in U
peak_memory = torch.cuda.max_memory_allocated() / 1024**2

params = eigenmodel_transformer._parameters
w0 = eigenmodel_transformer.w0
HVPs = torch.vmap(hvp_inhouse, in_dims=(None, 0, None))(params, X_transformer[:3,:], w0)
print(peak_memory)




GradTrackingTensor(lvl=2, value=
    BatchedTensor(lvl=1, bdim=0, value=
        tensor([[-3.4078e-02, -2.3414e-01,  1.1105e+00,  2.0623e-01, -9.2142e-02,
                  3.2440e-01,  2.4454e-03,  1.2116e-05,  5.1514e-02,  2.8423e-01,
                 -3.2002e-01, -1.8888e+00, -3.6671e-11,  1.4318e+00, -1.5549e-01,
                 -1.2171e+00,  8.1368e-03,  6.8094e+00,  7.2591e-03,  1.2072e+00,
                  1.1051e+00,  4.1215e-02, -2.8333e-01, -1.7296e-09],
                [ 2.5315e-02, -1.8952e-10, -9.9086e-08, -7.0067e-02,  5.0342e-04,
                  7.4225e-04,  9.2458e-01,  5.8505e-03, -1.2408e+00, -3.8720e+00,
                  1.2356e+01,  8.9047e-06, -7.4990e-01, -1.7808e+00,  5.2451e-02,
                  2.3128e+00, -3.5312e-01,  1.1556e-01, -5.5430e-01,  7.9101e-01,
                  1.8138e-01,  2.9986e-01,  1.3191e+00, -9.7872e-01],
                [-9.1007e-02, -6.1624e-01,  2.2481e-02,  9.8138e-01, -2.8843e-02,
                 -9.4954e-02,  1.1853e-01,  2.512

RuntimeError: grad_and_value(f)(*args): Expected f(*args) to return a scalar Tensor, got tensor with 1 dims. Maybe you wanted to use the vjp or jacrev APIs instead?

In [35]:
torch.cuda.empty_cache()

gc.collect()

def hvp_inhouse(X, params, w0):
    u = {name:v[0] for name, v in params.items()}

    loss = eigenmodel_transformer.compute_loss(X, w0).sum()  # Sum over batch

    # Compute gradient of loss w.r.t. parameters
    grads = torch.autograd.grad(loss, w0.values(), create_graph=True)
    
    print(grads)
    # Flatten gradients and u vectors
    grads_flat = torch.cat([g.view(-1) for g in grads])
    u_flat = torch.cat([u_i.view(-1) for u_i in u.values()])
    
    print(grads_flat.shape)
    print(u_flat.shape)
    # Compute dot product
    grad_u = torch.dot(grads_flat, u_flat)
    

    # Compute Hessian-vector product
    hvp = torch.autograd.grad(grad_u, w0.values(), retain_graph=True)

    # Flatten hvp and u vectors
    hvp_flat = torch.cat([h.view(-1) for h in hvp])
    u_flat = torch.cat([u_i.view(-1) for u_i in u.values()])

    # Compute dot product
    dH_du = torch.dot(hvp_flat, u_flat)
    return dH_du


torch.cuda.reset_max_memory_allocated()
# Run your code


#test = hvp_inhouse(X_transformer[:3,], eigenmodel_transformer._parameters)
peak_memory = torch.cuda.max_memory_allocated() / 1024**2
print(peak_memory)

from torch.func import vmap

# Compute HVPs for all vectors in U
params = eigenmodel_transformer._parameters
w0 = eigenmodel_transformer.w0
HVPs = vmap(hvp_inhouse, in_dims=(0, None, None))(X_transformer[:3,:], params, w0)

NameError: name 'orch' is not defined

In [26]:
from torch.func import functional_call, vmap, grad


tensor([ 938.7214, 1162.7638,  806.2974], device='cuda:0',
       grad_fn=<MvBackward0>)

In [None]:
eigenmodel_transformer.compute_loss(X_transformer[:3,:], w0)

In [None]:
def attempt(self, x, w0, u):
    # Compute loss
    loss = eigenestimation_algorithm.compute_loss(x, w0).sum()  # Sum over batch

    # Compute gradient of loss w.r.t. parameters
    grads = torch.autograd.grad(loss, w0.values(), create_graph=True)

    # Flatten gradients and u vectors
    grads_flat = torch.cat([g.view(-1) for g in grads])
    u_flat = torch.cat([u_i.view(-1) for u_i in u.values()])

    # Compute dot product
    grad_u = torch.dot(grads_flat, u_flat)
    return grad_u


    # Compute grad along u
    grad_u = self.grad_along_u(x, self.w0, u)

    # Compute Hessian-vector product
    hvp = torch.autograd.grad(grad_u, self.w0.values(), retain_graph=True)

    # Flatten hvp and u vectors
    hvp_flat = torch.cat([h.view(-1) for h in hvp])
    u_flat = torch.cat([u_i.view(-1) for u_i in u.values()])

    # Compute dot product
    dH_du = torch.dot(hvp_flat, u_flat)
    return dH_du



In [None]:
def call(X, params):
    params_dict = {k:v for k,v in zip(params.keys(), param_tuple)}
    return functional_call(transformer_model, X[:2,:2], params_dict)
##eigenmodel_transformer(X_transformer[:2,:2])

In [None]:
from torch.func import jacrev, functional_call

param_tuple = tuple([v[0] for v in eigenmodel_transformer._parameters.values()])
v_tuple = tuple([torch.rand_like(v[0]) for v in eigenmodel_transformer._parameters.values()])

X = X_transformer[:2,:2]
def Call(*params):
    params_ordered_dict = eigenmodel_transformer._parameters
    param_tuple = params
    params_dict = {k:v for k,v in zip(params_ordered_dict.keys(), param_tuple)}
    out = functional_call(transformer_model, params_dict,  X)
    return out

torch.autograd.functional.hvp(Call, tuple(X.float()) + param_tuple, v=tuple(torch.zeros(X.shape))+v_tuple)

In [None]:
param_tuple

In [None]:
eigenmodel_transformer._parameters.keys()

In [None]:
tuple(X.float()) + param_tuple

In [None]:
tuple(torch.zeros(X.shape))+v_tuple

In [None]:
with torch.no_grad():
    for i in range(n_u_vectors):
        print(f'-----{i}-----')
        PrintActivatingExamplesTransformer(eigenmodel_transformer, X_transformer[::100,:4], 1,top_k=5, batch_size=4)

In [None]:
eigenmodel_transformer.double_grad_along_u(X_transformer[:2,], u=eigenmodel_transformer._parameters).shape


In [None]:
(list(eigenmodel_transformer._parameters.values())[0]**2).sum(dim=1)

In [None]:
with torch.no_grad():
    PrintFeatureValsTransformer(eigenmodel_transformer, X_transformer[::100,:10], 4, 1)