In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformer_lens import HookedTransformer
from interpolated_ffn import ModelWithBilinearLayer, load_layer

In [2]:
device = "cuda"
dtype = torch.bfloat16
model_name = "gemma-2-2b"
model_pretrained = HookedTransformer.from_pretrained_no_processing(model_name, device = device, dtype=dtype)
layer = 18

# Bilinear layer with original weights 
model_bilinear = ModelWithBilinearLayer(model_pretrained, layer)

# Bilinear layers with different finetuning strategies
model_bilinear_logit_mse = ModelWithBilinearLayer(model_pretrained, layer)
load_layer(model_bilinear_logit_mse, "layer-18-step-20000-logit-mse.pt")
model_bilinear_output_mse = ModelWithBilinearLayer(model_pretrained, layer)
load_layer(model_bilinear_output_mse, "layer-18-step-20000-output-mse.pt")
model_bilinear_ce = ModelWithBilinearLayer(model_pretrained, layer)
load_layer(model_bilinear_ce, "layer-18-step-20000-ce.pt")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model gemma-2-2b into HookedTransformer


  model.ffn.load_state_dict(torch.load(name))


In [3]:
data = np.memmap('data_train.bin', dtype=np.uint32, mode='r')

def sample_batch(size, seq_len):
    indices = torch.randint(len(data) - seq_len - 1, (size,))
    xs = torch.stack([torch.from_numpy(data[i:i+seq_len].astype(np.int64)) for i in indices])
    ys = torch.stack([torch.from_numpy(data[i+1:i+seq_len+1].astype(np.int64)) for i in indices])
    return xs.to(device), ys.to(device)

In [4]:
def eval_model(model):
    torch.manual_seed(12345)
    seq_len = 1024
    batch_size = 2
    batches = 500
    loss = 0.0
    with torch.no_grad():
        for i in range(batches):
            tokens, next_tokens = sample_batch(batch_size, seq_len)
            logits = model(tokens)
            loss += F.cross_entropy(logits.view(-1, logits.size(-1)), next_tokens.view(-1)).item()
    return loss / batches

print("Pretrained:", eval_model(model_pretrained))
print("Bilinear untuned:", eval_model(model_bilinear))
print("Bilinear logit reconstruction:", eval_model(model_bilinear_logit_mse))
print("Bilinear layer output reconstruction:", eval_model(model_bilinear_output_mse))
print("Bilinear cross-entropy on next token prediction:", eval_model(model_bilinear_ce))

Pretrained: 3.656546875
Bilinear untuned: 3.829
Bilinear logit reconstruction: 3.689296875
Bilinear layer output reconstruction: 3.733953125
Bilinear cross-entropy on next token prediction: 3.1923125
