In [1]:
import transformer_lens
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn as nn
from sklearn.decomposition import PCA
import numpy as np

import torch.nn.functional as F

from tqdm import tqdm
from scipy.stats import spearmanr

In [2]:
# Load the STS-B train and test splits; only their 'score' column is used below.
train_set, test_set = (
    load_dataset("sentence-transformers/stsb", split=split_name)
    for split_name in ("train", "test")
)

# Pre-computed activation tensors. The files carry a .npy suffix but are read
# with torch.load, so they must be in torch's serialization format.
# NOTE(review): presumably one tensor per sentence of each pair, laid out as
# [n_pairs, n_layers, d_model] (the models below index dim 1 by layer) —
# confirm against the script that produced these files.
first_train, second_train, first_test, second_test = map(
    torch.load,
    (
        'gpt2_medium_train_acts_1.npy',
        'gpt2_medium_train_acts_2.npy',
        'gpt2_medium_test_acts_1.npy',
        'gpt2_medium_test_acts_2.npy',
    ),
)

# Regression targets, one float per pair, in the same order as the activations.
train_scores, test_scores = (
    torch.Tensor(split['score']) for split in (train_set, test_set)
)

In [31]:
class SiameseNetwork(nn.Module):
    """Shared-weight encoder whose two outputs are compared by cosine similarity.

    Both inputs go through the same MLP (Linear -> LayerNorm -> GELU ->
    Linear -> GELU -> Linear, mapping ``d_in`` to ``d_hidden``), and the two
    resulting embeddings are compared along their last dimension.

    Args:
        d_in: Size of the last dimension of each input.
        d_hidden: Width of the hidden layers and of the output embedding.
    """

    def __init__(self, d_in=1024, d_hidden=256):
        super().__init__()
        # Layers are created in this exact order so that parameter names
        # (mlp.0, mlp.1, mlp.3, mlp.5) and seeded initialisation stay stable.
        encoder_layers = [
            nn.Linear(d_in, d_hidden),
            nn.LayerNorm(d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_hidden),
        ]
        self.mlp = nn.Sequential(*encoder_layers)

    def forward(self, x1, x2):
        """Encode ``x1`` and ``x2`` with the shared MLP and return their cosine similarity.

        The similarity is taken over the last dimension, so the result has the
        shape of the inputs with that dimension removed.
        """
        embedded_first, embedded_second = self.mlp(x1), self.mlp(x2)
        return F.cosine_similarity(embedded_first, embedded_second, dim=-1)

class LayerwiseSiameseNetworks(nn.Module):
    """One independent ``SiameseNetwork`` per layer of the input activations.

    Args:
        n_layers: Number of per-layer networks to build.
        d_in: Size of the last dimension of each per-layer activation.
        d_hidden: Hidden/embedding width passed to every ``SiameseNetwork``.
    """

    def __init__(self, n_layers=24, d_in=1024, d_hidden=256):
        super().__init__()
        self.n_layers = n_layers
        self.layer_nets = nn.ModuleList([SiameseNetwork(d_in, d_hidden) for _ in range(n_layers)])

    def forward(self, x1, x2):
        """Compute one similarity per layer for each pair in the batch.

        Args:
            x1: Activations of the first items, indexed as ``x1[:, layer]``.
            x2: Activations of the second items, same layout as ``x1``.

        Returns:
            Tensor whose last dimension has one entry per layer network
            (``[batch, n_layers]`` for ``[batch, n_layers, d_in]`` inputs).
        """
        # Run on whatever device this module currently lives on instead of
        # forcing CUDA: the previous hard-coded ``.cuda()`` calls crashed on
        # CPU-only hosts and re-moved every sub-network on each forward pass.
        # Inputs are moved to the module's device, so CPU-resident activations
        # can still be fed to a GPU-resident model.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else x1.device

        similarities = []
        for layer, net_at_layer in enumerate(self.layer_nets):
            # Slice out the activations belonging to this layer.
            x1_l = x1[:, layer].to(device)
            x2_l = x2[:, layer].to(device)
            similarities.append(net_at_layer(x1_l, x2_l))

        return torch.stack(similarities, dim=-1)

class SimilarityLearner(nn.Module):
    """Predict a similarity score in [0, 1] from per-layer siamese similarities.

    Args:
        n_layers: Number of layers in the input activations (one siamese
            network is built per layer).
        d_in: Size of the last dimension of each per-layer activation.
        d_hidden: Hidden/embedding width of every per-layer siamese network.
        d_out_hidden: Hidden width of the head that combines the per-layer
            similarities into a single score.
    """

    def __init__(self, n_layers=24, d_in=1024, d_hidden=256, d_out_hidden=512):
        super().__init__()
        # Produces one cosine similarity per layer: a tensor of shape
        # [batch, n_layers].
        self.layer_nets = LayerwiseSiameseNetworks(n_layers, d_in, d_hidden)

        # Head mapping the n_layers similarities to one score; Hardsigmoid
        # bounds the output to [0, 1].
        self.mlp = nn.Sequential(
            nn.Linear(n_layers, d_out_hidden),
            nn.LayerNorm(d_out_hidden),
            nn.GELU(),
            nn.Linear(d_out_hidden, 1),
            nn.Hardsigmoid()
        )

    def forward(self, x1, x2):
        """Return one predicted score per pair, shape ``[batch]``.

        Args:
            x1: Activations of the first items, as expected by ``layer_nets``.
            x2: Activations of the second items, same layout as ``x1``.
        """
        layerwise_similarities = self.layer_nets(x1, x2)
        similarity = self.mlp(layerwise_similarities)

        # Drop only the trailing size-1 feature dimension. A bare ``squeeze()``
        # would also remove the batch dimension when the batch holds a single
        # pair, yielding a 0-d tensor that then broadcasts against the targets.
        return similarity.squeeze(-1)

In [32]:
# Build the model with 512-wide per-layer siamese networks (the class default
# is 256), then place all of its parameters on the first CUDA device.
similarity_learner = SimilarityLearner(d_hidden=512)
similarity_learner = similarity_learner.to('cuda:0')

In [33]:
def train_epoch(model, first_acts, second_acts, scores, optimizer, batch_size=32):
    """Run one pass of mini-batch MSE training over the given pairs.

    Batches are taken in storage order (no shuffling). Every sample is used:
    the final batch may be smaller than ``batch_size``.

    Args:
        model: Module called as ``model(x1_batch, x2_batch)`` that returns one
            prediction per pair.
        first_acts: Inputs for the first item of each pair, indexed along dim 0.
        second_acts: Inputs for the second item of each pair, same length.
        scores: Target score per pair, same length.
        optimizer: Optimizer that updates ``model``'s parameters.
        batch_size: Maximum number of pairs per optimisation step.

    Returns:
        The mean per-sample MSE over the epoch as a Python float. This equals
        the mean of the batch losses whenever the sample count is a multiple
        of ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` is not positive, the inputs disagree in
            length, or there are no samples.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_samples = first_acts.shape[0]
    if second_acts.shape[0] != n_samples or scores.shape[0] != n_samples:
        raise ValueError(
            "first_acts, second_acts and scores must have the same length, got "
            f"{n_samples}, {second_acts.shape[0]} and {scores.shape[0]}"
        )
    if n_samples == 0:
        raise ValueError("cannot train on an empty dataset")

    model.train()

    # Follow the model's device rather than assuming CUDA, so the same loop
    # runs on CPU-only hosts and on whichever GPU the model was placed on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else first_acts.device

    total_loss = 0.0
    # Stepping by batch_size (instead of iterating n_samples // batch_size
    # times) keeps the trailing partial batch and avoids a division by zero
    # when there are fewer samples than batch_size.
    for start_idx in range(0, n_samples, batch_size):
        end_idx = start_idx + batch_size

        x1_batch = first_acts[start_idx:end_idx].to(device)
        x2_batch = second_acts[start_idx:end_idx].to(device)
        scores_batch = scores[start_idx:end_idx].to(device)

        optimizer.zero_grad()

        # Align predictions with the targets explicitly: this fails loudly on
        # an element-count mismatch instead of letting mse_loss broadcast.
        sims = model(x1_batch, x2_batch).reshape(scores_batch.shape)

        loss = F.mse_loss(sims, scores_batch)

        loss.backward()
        optimizer.step()

        # Weight by batch length so a short final batch is not over-counted.
        total_loss += loss.item() * scores_batch.shape[0]

    return total_loss / n_samples

optimizer = torch.optim.AdamW(similarity_learner.parameters(), lr=3e-5)
for epoch in tqdm(range(30)):
    # One optimisation pass over the training pairs; `loss` is the epoch's
    # training MSE as returned by train_epoch.
    loss = train_epoch(similarity_learner, first_train, second_train, train_scores, optimizer)

    # train_epoch switches the model to train mode, so switch back after
    # every epoch; the model is therefore left in eval mode when the loop ends.
    similarity_learner.eval()

    # Correlations are only computed and reported on every 10th epoch.
    if (epoch + 1) % 10 != 0:
        continue

    with torch.no_grad():
        # Pearson correlation between predictions and gold scores, read from
        # the off-diagonal entry of the 2x2 correlation matrix.
        train_preds = similarity_learner(first_train.cuda(), second_train.cuda())
        train_corr = torch.corrcoef(torch.stack([train_preds.cpu(), train_scores]))[0,1]

        test_preds = similarity_learner(first_test.cuda(), second_test.cuda())
        test_corr = torch.corrcoef(torch.stack([test_preds.cpu(), test_scores]))[0,1]

    print(f"Epoch {epoch+1}")
    print(f"Loss: {loss:.4f}")
    print(f"Train correlation: {train_corr:.4f}")
    print(f"Test correlation: {test_corr:.4f}\n")

 33%|██████████████████████████████████████████████████████████▋                                                                                                                     | 10/30 [00:59<01:56,  5.84s/it]

Epoch 10
Loss: 0.0142
Train correlation: 0.9028
Test correlation: 0.6378



 67%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                          | 20/30 [01:58<00:58,  5.81s/it]

Epoch 20
Loss: 0.0054
Train correlation: 0.9539
Test correlation: 0.6737



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [02:57<00:00,  5.92s/it]

Epoch 30
Loss: 0.0031
Train correlation: 0.9646
Test correlation: 0.6816






In [34]:
def report_correlations(model, first_acts, second_acts, scores):
    """Predict scores for all pairs, print the correlations, and return the predictions.

    Prints the 2x2 Pearson correlation matrix between predictions and
    ``scores``, followed by the Spearman rank-correlation result.

    Args:
        model: Module called as ``model(first, second)`` returning one
            prediction per pair; it is moved to CUDA for the forward pass.
        first_acts: Inputs for the first item of each pair.
        second_acts: Inputs for the second item of each pair.
        scores: CPU tensor of gold scores, one per pair.

    Returns:
        CPU tensor of predictions, one per pair.
    """
    model.eval()
    # Inference only: without no_grad the forward pass over the whole dataset
    # would retain an autograd graph for every intermediate activation.
    with torch.no_grad():
        preds = model.cuda()(first_acts.cuda(), second_acts.cuda()).to('cpu')

    print(torch.corrcoef(torch.stack([preds, scores])))
    print(spearmanr(preds, scores))
    return preds


# Held-out split first, then the training split, matching the printed order.
test_preds = report_correlations(similarity_learner, first_test, second_test, test_scores)

train_preds = report_correlations(similarity_learner, first_train, second_train, train_scores)

tensor([[1.0000, 0.6816],
        [0.6816, 1.0000]])
SignificanceResult(statistic=0.6650154085482199, pvalue=8.614542964874223e-177)
tensor([[1.0000, 0.9646],
        [0.9646, 1.0000]])
SignificanceResult(statistic=0.9622781159031129, pvalue=0.0)
