# Transformer






In [60]:
import argparse
from pathlib import Path
import torch
from torch import nn
import torch.distributed as dist
from torch.utils.data import TensorDataset, DataLoader
import wandb  # Add Weights & Biases for tracking
import os
import sys
import transformer_lens
# Import AutoTokenizer
from transformers import AutoTokenizer
from datasets import load_dataset
from transformer_lens.utils import tokenize_and_concatenate
# import tokenize_and_concatenate




# Make the project's packages (eigenmodel, utils, toy_models) importable by
# appending the repository directory to sys.path.
# NOTE: os.path.expanduser only expands a leading "~", so this relative path is
# passed through unchanged and is resolved against the kernel's working directory.
parent_dir = os.path.expanduser('../eigenestimation')
sys.path.append(parent_dir)

from eigenmodel.trainer import Trainer
from eigenmodel.eigenmodel import EigenModel
from utils.utils import TransformDataLoader, DeleteParams, RetrieveWandBArtifact
from utils.loss import KLDivergenceVectorLoss


from toy_models.transformer_wrapper import TransformerWrapper
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

from cycling_utils import TimestampedTimer

# Constructing the timer prints a timestamped line (see this cell's output).
timer = TimestampedTimer("Imported TimestampedTimer")
from utils.uniform_models import ZeroOutput

[ 2025-02-19 02:10:31 ] Imported TimestampedTimer                                                    0.000 ms,         0.00 s total


## Set up

In [None]:
# @title Load pretrained TinyStories-1M and the tokenized TinyStories data
# Disable fused kernels (FlashAttention and memory-efficient attention).
# We have to disable these to compute second-order gradients on transformer models.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

# Ensure the math kernel is enabled (it is True by default)
torch.backends.cuda.enable_math_sdp(True)
# Load TinyStories-1M directly from Hugging Face (not through HookedTransformer).
# https://huggingface.co/roneneldan/TinyStories-1M

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tinystories_1m = AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-1M')

# The tokenizer comes from the GPT-Neo 125M checkpoint; it has no pad token by
# default here, so the end-of-sequence token is reused for padding.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token
transformer_model = TransformerWrapper(tinystories_1m, tokenizer, outputs_logits=False)


# Inventory of (parameter name, element count) before anything is removed.
print( [(name, param.numel()) for name, param in transformer_model.named_parameters()])


# Make the eigenestimation a little smaller by only looking at a subset of the parameters.
# Keep a single, hand-picked tensor (h.3.mlp.c_fc.weight) as a parameter and hand every
# other parameter name to DeleteParams (per the original note, these become frozen buffers).
params_to_delete = [name for name, param in transformer_model.named_parameters()]

# Debug probe: is the LM head among the registered parameter names?
print('transformer.lm_head.weight' in params_to_delete)

# Substring test on each name: everything except the kept tensor stays in the delete list.
params_to_delete = [p for p in params_to_delete if 
   'transformer.transformer.h.3.mlp.c_fc.weight' not in p]

# Disabled experiment: additionally keep every 20th tensor by dropping it from the delete list.
#for p in (params_to_delete[::20]):
#  params_to_delete.remove(p)

DeleteParams(transformer_model, params_to_delete)

# Total number of scalar parameters left, followed by one line per remaining tensor.
print(sum([p.numel() for p in transformer_model.parameters()]))
for n,p in transformer_model.named_parameters(): print(n, p.shape, p.numel())

# Load in data: the first 1% of the TinyStories validation split, tokenized and
# concatenated into rows of max_length = 8 tokens without a BOS token.
dataset = load_dataset('roneneldan/TinyStories', split="validation[:1%]")
X_transformer = tokenize_and_concatenate(dataset, transformer_model.tokenizer, max_length = 8, add_bos_token=False)['tokens']


[('transformer.transformer.wte.weight', 3216448), ('transformer.transformer.wpe.weight', 131072), ('transformer.transformer.h.0.ln_1.weight', 64), ('transformer.transformer.h.0.ln_1.bias', 64), ('transformer.transformer.h.0.attn.attention.k_proj.weight', 4096), ('transformer.transformer.h.0.attn.attention.v_proj.weight', 4096), ('transformer.transformer.h.0.attn.attention.q_proj.weight', 4096), ('transformer.transformer.h.0.attn.attention.out_proj.weight', 4096), ('transformer.transformer.h.0.attn.attention.out_proj.bias', 64), ('transformer.transformer.h.0.ln_2.weight', 64), ('transformer.transformer.h.0.ln_2.bias', 64), ('transformer.transformer.h.0.mlp.c_fc.weight', 16384), ('transformer.transformer.h.0.mlp.c_fc.bias', 256), ('transformer.transformer.h.0.mlp.c_proj.weight', 16384), ('transformer.transformer.h.0.mlp.c_proj.bias', 64), ('transformer.transformer.h.1.ln_1.weight', 64), ('transformer.transformer.h.1.ln_1.bias', 64), ('transformer.transformer.h.1.attn.attention.k_proj.wei

In [163]:
# Wrap the Hugging Face model again so the deletion below starts from a fresh wrapper.
transformer_model = TransformerWrapper(tinystories_1m, tokenizer, outputs_logits=False)

# Inventory of (parameter name, element count) as currently registered.
print([
    (param_name, tensor.numel())
    for param_name, tensor in transformer_model.named_parameters()
])

# Start from every registered parameter name ...
params_to_delete = [param_name for param_name, _ in transformer_model.named_parameters()]
print('transformer.lm_head.weight' in params_to_delete)

# ... and keep only the hand-picked h.3 MLP c_fc weight out of the delete list.
params_to_delete = [
    param_name
    for param_name in params_to_delete
    if 'transformer.transformer.h.3.mlp.c_fc.weight' not in param_name
]
print('transformer.lm_head.weight' in params_to_delete)

# Remove everything on the list, then show what is still registered as a parameter.
DeleteParams(transformer_model, params_to_delete)
for n, p in transformer_model.named_parameters():
    print(f'{n} {p.shape} {p.numel()}')

[('transformer.transformer.h.3.mlp.c_fc.weight', 16384), ('transformer.lm_head.weight', 3216448)]
True
True
transformer.transformer.h.3.mlp.c_fc.weight torch.Size([256, 64]) 16384


In [143]:
# One line per registered parameter: name, shape, element count.
for n, p in transformer_model.named_parameters():
    print(f'{n} {p.shape} {p.numel()}')


transformer.transformer.h.3.mlp.c_fc.weight torch.Size([256, 64]) 16384
transformer.lm_head.weight torch.Size([50257, 64]) 3216448


In [147]:
# Every registered parameter name except the h.3 MLP c_fc weight that is kept.
params_to_delete = [
    param_name
    for param_name, _ in transformer_model.named_parameters()
    if 'transformer.transformer.h.3.mlp.c_fc.weight' not in param_name
]

# Cell output: is the LM head scheduled for deletion?
'transformer.lm_head.weight' in params_to_delete

True

In [148]:
# Schedule every currently registered parameter for deletion (nothing is kept).
params_to_delete = list(dict(transformer_model.named_parameters()))
DeleteParams(transformer_model, params_to_delete)



In [149]:
# `named_parameters()` yields (name, tensor) pairs, so testing a bare string
# against it can never match; look the name up among the parameter names instead.
'transformer.lm_head.weight' in dict(transformer_model.named_parameters())

False

In [150]:
# One line per registered parameter, then the total number of scalar parameters.
for n, p in transformer_model.named_parameters():
    print(f'{n} {p.shape} {p.numel()}')
print(sum(tensor.numel() for tensor in transformer_model.parameters()))


0


In [66]:
# Two disjoint subsamples of the token rows: rows 0, 10, 20, ... for training and
# rows 1, 11, 21, ... for evaluation, each limited to the first 8 token positions.
train_dataset, eval_dataset = (X_transformer[offset::10, :8] for offset in (0, 1))

In [134]:
# Sanity check: run the Hugging Face model on the first two 8-token rows and
# display its logits. The original cell called an undefined name `model`, which
# raises NameError on a fresh kernel; `tinystories_1m` is the object loaded in this
# notebook whose forward output exposes `.logits`.
a = tinystories_1m(X_transformer[:2,:8]).logits
a

tensor([[[ 7.7580,  2.3624, -7.0801,  ..., -8.2345,  1.1284,  1.7170],
         [-0.2748,  6.6603, -1.8956,  ..., -4.4409, -1.2141,  3.7374],
         [ 6.9880,  3.1286, -4.4858,  ..., -6.7440,  2.3033, -0.4636],
         ...,
         [ 7.2013, -1.0819, -7.8657,  ..., -6.3906, -5.0056, -2.4216],
         [10.4197,  0.0953, -5.4691,  ..., -8.8656, -4.5735, -2.2194],
         [ 1.0557,  1.6034, -6.9371,  ..., -7.0611,  5.1501,  0.2257]],

        [[ 6.8507,  3.1315, -8.0478,  ..., -8.5349, -3.7631,  3.2066],
         [ 2.7750,  4.4037, -4.5321,  ..., -7.9039, -5.9186,  3.0265],
         [-4.1229,  9.9840,  2.3348,  ..., -1.8931,  4.5458,  2.5225],
         ...,
         [14.9818,  4.9248, -8.1071,  ..., -8.0346, -2.0046, -3.2659],
         [ 3.9872,  1.0837, -8.7434,  ..., -7.4886, -3.3030, -1.9253],
         [ 3.8538, -2.9488, -5.6578,  ..., -5.1305, -5.4858, -0.1017]]],
       grad_fn=<UnsafeViewBackward0>)

## Eigenestimation

In [117]:
# Build the EigenModel around the wrapped transformer, using ZeroOutput and a
# KL-divergence vector loss.
# NOTE(review): the meaning of the two trailing `10` arguments is not visible in
# this notebook — confirm against EigenModel's signature.
eigenmodel = EigenModel(transformer_model, ZeroOutput, KLDivergenceVectorLoss(), 10, 10)

# Smoke test over the training rows, two rows per batch. Each iteration overwrites
# `predictions`, `loss` and `grads`, so only the last batch's values survive.
# NOTE(review): the recorded run of this cell raised
#   AttributeError: 'CausalLMOutputWithPast' object has no attribute 'softmax'
# presumably because the model returned the raw Hugging Face output object rather
# than a tensor — confirm what TransformerWrapper returns with outputs_logits=False.
for X_batch in DataLoader(train_dataset, batch_size=2):
    predictions = eigenmodel.model(X_batch)
    # The loss is evaluated with the predictions as both arguments.
    loss = eigenmodel.loss(predictions, predictions)
    # Gradients are computed on at most the first two rows and first two token positions.
    grads = eigenmodel.compute_gradients(X_batch[:2,:2])


HERE!


AttributeError: 'CausalLMOutputWithPast' object has no attribute 'softmax'

In [118]:
# Sanity check: (number of training rows, token positions per row).
train_dataset.shape

torch.Size([5854, 8])

In [119]:
# Shape of `loss` as last assigned in the eigenestimation loop cell.
# NOTE(review): that cell's recorded run raised, so this value may come from stale kernel state.
loss.shape

torch.Size([80])

In [120]:
import einops

# NOTE(review): `jvp` and `gradients` are not defined in any visible cell, and
# `compute_reconstruction_loss` is defined in a later cell of this notebook; all
# three must already exist in the kernel before this cell is run.

# Fraction of JVP entries, ranked by absolute value, kept by each mask.
top_k = .1

jvp_flattened = einops.rearrange(jvp, '... -> (...)')

# The thresholds are plain cut-off values, so they are computed without autograd.
with torch.no_grad():
    n_entries = len(jvp_flattened)
    # Clamp to [1, n_entries]: a count of 0 would make `sorted_values[-0]` pick the
    # largest magnitude for the "lowest" threshold.
    n_keep = min(max(round(top_k * n_entries), 1), n_entries)
    sorted_values, _ = torch.sort(abs(jvp_flattened), descending=True)
    # With a descending sort, index n_keep - 1 is the n_keep-th largest magnitude
    # (index n_keep would let n_keep + 1 entries through the mask).
    nth_highest_value = sorted_values[n_keep - 1]
    nth_lowest_value = sorted_values[-n_keep]

# Zero out everything except the largest / smallest magnitudes (ties are kept).
jvp_topk = jvp*(abs(jvp)>=nth_highest_value).float()
jvp_bottomk = jvp*(abs(jvp)<=nth_lowest_value).float()

reconstruction = eigenmodel.reconstruct(jvp_topk)

compute_reconstruction_loss(reconstruction, gradients)

tensor(1., device='cuda:0', grad_fn=<DivBackward0>)

In [111]:
def _batched_inner_product(left, right, names):
    """Per-row inner product of two name -> tensor mappings.

    For every name in `names`, multiplies `left[name]` and `right[name]`
    element-wise, flattens all axes after the leading one, concatenates the
    results across names and sums them.

    Returns:
        A 1-D tensor with one value per leading-axis row.
    """
    flattened = []
    for name in names:
        product = left[name] * right[name]
        # Keep the leading axis, merge the rest (also valid for 1-D inputs).
        flattened.append(product.reshape(product.shape[0], -1))
    return torch.cat(flattened, dim=-1).sum(dim=-1)


def compute_reconstruction_loss(reconstruction, gradients):
    """Relative error between reconstructed and reference gradients.

    With a = flattened reconstruction and b = flattened gradients for one row,
    the per-row numerator is (a.a)^2 - 2 (a.b)^2 + (b.b)^2, which is
    algebraically the squared Frobenius norm of (a a^T - b b^T). The numerators
    are summed over rows and divided by the summed (b.b)^2 (plus a small
    epsilon per row to avoid division by zero).

    Args:
        reconstruction: mapping from parameter name to tensor; must contain
            every key of `gradients`, with shapes compatible for element-wise
            multiplication with the matching gradient tensor.
        gradients: mapping from parameter name to tensor; its keys define
            which tensors are compared. The leading axis of every tensor is
            treated as the row (batch) axis.

    Returns:
        A scalar tensor: 0 when reconstruction equals gradients, 1 when the
        reconstruction is all zeros.
    """
    names = list(gradients)

    A_dot_A = _batched_inner_product(reconstruction, reconstruction, names)
    B_dot_B = _batched_inner_product(gradients, gradients, names)
    A_dot_B = _batched_inner_product(reconstruction, gradients, names)

    eps = 1e-10

    L2_error = (A_dot_A*A_dot_A - 2*A_dot_B*A_dot_B + B_dot_B*B_dot_B).sum()/(B_dot_B*B_dot_B + eps).sum()

    return L2_error

In [11]:
# Fetch the "TopActivatingTexts" artifact for the run from Weights & Biases.
# NOTE(review): `run_name` is not defined in the visible part of this notebook; the
# artifact path in the recorded output suggests "tiny_stores_transformer_1m" — confirm.
top_texts = RetrieveWandBArtifact(project_path=f"brianna-chrisman-2024/Eigenestimation/eigenestimation_{run_name}", metric_name="TopActivatingTexts")
# Print one header per feature index, then each stored text with its value
# rounded to 5 decimals (the second and third fields of each entry are ignored).
for feature_idx in top_texts:
    print(f'-----f{feature_idx}------')
    for val, _, _, text in top_texts[feature_idx]:
        print(f'{text}->{round(val, 5)}')

[34m[1mwandb[0m:   1 of 1 files downloaded.  


/root/workspace/eigenestimation/notebooks/artifacts/eigenestimation_tiny_stores_transformer_1m_TopActivatingTexts:v5
-----f0------
 with Max* and* have fun. They ran->0.0
 Max* and* Sue knew they had to help->0.0
 children in the classroom. The end*.*->0.0
-----f1------
 at mom. He looked* at* dad.->0.0
newline"Tom* and* Anna, you are->0.0
 Lily and Ben. They looked* at* Mom->0.0
-----f2------
newlinenewlineBut then Lily remembered what* her*->0.0
 and lots* of* fun.Once upon a->0.0
. What should they do?*newline*newline->0.0
-----f3------
 and their owners. *newline*newlineSuddenly->0.0
 pick them* and* hold them close to her->0.0
, he found a new truck* that*->0.0
-----f4------
.newlinenewlineWhen he got to* the*->0.0
 it. The ball goes faster* and* faster->0.0
, Jack went outside to play* and* saw->0.0
-----f5------
 mom bought him a trumpet*,* but he->0.0
. She says, "I* love* you->0.0
 talk. His mom found him* and* cried->0.0
-----f6------
 bear. They tell them that they* have*->0