In [1]:
import transformer_lens
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn as nn
from sklearn.decomposition import PCA
import numpy as np

import torch.nn.functional as F

from tqdm import tqdm
from scipy.stats import spearmanr

import random

from copy import deepcopy

In [2]:
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

In [3]:
# Load datasets
train_set = load_dataset("sentence-transformers/stsb", split="train")
test_set  = load_dataset("sentence-transformers/stsb", split="test")

first_train  = torch.load('gpt2_medium_train_acts_1.npy')
second_train = torch.load('gpt2_medium_train_acts_2.npy')

first_test = torch.load('gpt2_medium_test_acts_1.npy')
second_test = torch.load('gpt2_medium_test_acts_2.npy')

train_scores = torch.Tensor(train_set['score'])
test_scores = torch.Tensor(test_set['score'])

In [4]:
class SiameseNetwork(nn.Module):
    def __init__(self, d_in=1024, d_hidden=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.LayerNorm(d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_hidden)
        )

    def forward(self, x1, x2):
        h1 = self.mlp(x1)
        h2 = self.mlp(x2)
        return F.cosine_similarity(h1, h2, dim=-1)

class LayerwiseSiameseNetworks(nn.Module):
    def __init__(self, n_layers=24, d_in=1024, d_hidden=256):
        super().__init__()
        self.n_layers = n_layers
        self.layer_nets = nn.ModuleList([SiameseNetwork(d_in, d_hidden) for _ in range(n_layers)])

    def forward(self, x1, x2):
        similarities = []
        for layer in range(self.n_layers):
            net_at_layer = self.layer_nets[layer].cuda()
            # Get the activations at the given layer
            x1_l = x1[:, layer].cuda()
            x2_l = x2[:, layer].cuda()

            sim_at_layer = net_at_layer(x1_l, x2_l)
            # print(sim_at_layer.shape)
            similarities.append(sim_at_layer)

        return torch.stack(similarities, dim=-1)

class SimilarityLearner(nn.Module):
    def __init__(self, n_layers=24, d_in=1024, d_hidden=256, d_out_hidden=512):
        super().__init__()
        # This is going to output a set of cosine sims. The tensor is of shape
        # [bn, n_layers]
        self.layer_nets = LayerwiseSiameseNetworks(n_layers, d_in, d_hidden)

        self.mlp = nn.Sequential(
            nn.Linear(n_layers, d_out_hidden),
            nn.LayerNorm(d_out_hidden),
            nn.GELU(),
            nn.Linear(d_out_hidden, 1),
            nn.Hardsigmoid()
        )

    def forward(self, x1, x2):
        layerwise_similarities = self.layer_nets(x1, x2)
        similarity = self.mlp(layerwise_similarities)

        return similarity.squeeze()

In [5]:
similarity_learner = SimilarityLearner(d_hidden=512).to('cuda:0')

In [6]:
def train_epoch(model, first_acts, second_acts, scores, optimizer, batch_size=32):
    model.train()
    total_loss = 0
    n_batches = first_acts.shape[0] // batch_size
    
    for batch in range(n_batches):
        start_idx = batch * batch_size
        end_idx = (batch + 1) * batch_size
        
        x1_batch = first_acts[start_idx:end_idx].cuda()
        x2_batch = second_acts[start_idx:end_idx].cuda()
        scores_batch = scores[start_idx:end_idx].cuda()
        
        optimizer.zero_grad()
        
        sims = model(x1_batch, x2_batch)
        
        loss = F.mse_loss(sims, scores_batch)        
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / n_batches

best_score = 0
best_model = None
optimizer = torch.optim.AdamW(similarity_learner.parameters(), lr=3e-5)
for epoch in tqdm(range(100)):
    # Training
    loss = train_epoch(similarity_learner, first_train, second_train, train_scores, optimizer)
    
    # Evaluation
    similarity_learner.eval()
    with torch.no_grad():
        # Get test correlations
        test_preds = similarity_learner(first_test.cuda(), second_test.cuda())
        test_corr = torch.corrcoef(torch.stack([test_preds.cpu(), test_scores]))[0,1]

        if test_corr.item() > best_score:
            best_score = test_corr.item()
            best_model = deepcopy(similarity_learner)
        
        # Print progress every 10 epochs
        if (epoch + 1) % 10 == 0:
            # Get training correlations
            train_preds = similarity_learner(first_train.cuda(), second_train.cuda())
            train_corr = torch.corrcoef(torch.stack([train_preds.cpu(), train_scores]))[0,1]
            

            print(f"Epoch {epoch+1}")
            print(f"Loss: {loss:.4f}")
            print(f"Train correlation: {train_corr:.4f}")
            print(f"Test correlation: {test_corr:.4f}\n")

 10%|█████████████████▌                                                                                                                                                             | 10/100 [00:57<08:47,  5.86s/it]

Epoch 10
Loss: 0.0128
Train correlation: 0.9107
Test correlation: 0.6449



 20%|███████████████████████████████████                                                                                                                                            | 20/100 [02:00<08:08,  6.10s/it]

Epoch 20
Loss: 0.0065
Train correlation: 0.9297
Test correlation: 0.6664



 30%|████████████████████████████████████████████████████▌                                                                                                                          | 30/100 [02:57<06:37,  5.67s/it]

Epoch 30
Loss: 0.0023
Train correlation: 0.9318
Test correlation: 0.6707



 40%|██████████████████████████████████████████████████████████████████████                                                                                                         | 40/100 [03:56<06:20,  6.35s/it]

Epoch 40
Loss: 0.0022
Train correlation: 0.9578
Test correlation: 0.6799



 50%|███████████████████████████████████████████████████████████████████████████████████████▌                                                                                       | 50/100 [04:53<04:42,  5.66s/it]

Epoch 50
Loss: 0.0024
Train correlation: 0.9386
Test correlation: 0.6508



 60%|█████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                      | 60/100 [05:56<04:11,  6.30s/it]

Epoch 60
Loss: 0.0014
Train correlation: 0.9874
Test correlation: 0.7114



 70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                    | 70/100 [06:56<02:58,  5.94s/it]

Epoch 70
Loss: 0.0010
Train correlation: 0.9869
Test correlation: 0.6972



 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                   | 80/100 [07:58<02:06,  6.34s/it]

Epoch 80
Loss: 0.0013
Train correlation: 0.9692
Test correlation: 0.6773



 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                 | 90/100 [08:59<01:02,  6.23s/it]

Epoch 90
Loss: 0.0009
Train correlation: 0.9877
Test correlation: 0.7014



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [10:01<00:00,  6.02s/it]

Epoch 100
Loss: 0.0010
Train correlation: 0.9871
Test correlation: 0.7082






In [7]:
test_preds = best_model.cuda()(first_test.cuda(), second_test.cuda()).detach().to('cpu')
                          
print(torch.corrcoef(torch.stack([test_preds, test_scores])))
print(spearmanr(test_preds, test_scores))

train_preds = best_model.cuda()(first_train.cuda(), second_train.cuda()).detach().to('cpu')

print(torch.corrcoef(torch.stack([train_preds, train_scores])))
print(spearmanr(train_preds, train_scores))

tensor([[1.0000, 0.7163],
        [0.7163, 1.0000]])
SignificanceResult(statistic=0.6972712637102803, pvalue=2.3869221388570265e-201)
tensor([[1.0000, 0.9775],
        [0.9775, 1.0000]])
SignificanceResult(statistic=0.976859642602135, pvalue=0.0)


In [8]:
sum([w.numel() for w in similarity_learner.parameters()])

25241601