In [2]:
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset

In [5]:
# Select the compute device. With CUDA present this keeps the index of the
# currently selected GPU (an int); otherwise fall back to the CPU so that
# `device` is always defined and the cell also runs on CPU-only machines.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = "cpu"
print(device)

0


## super simple linear probe

In [6]:
class ProbingDataset(Dataset):
    """Map-style dataset that pairs activation rows with their labels.

    Args:
        act: indexable collection of activations; ``act[idx]`` is turned into a
            tensor on access.
        y: indexable collection of labels with the same length as ``act``;
            ``y[idx]`` is turned into an int64 tensor on access.

    Construction prints the number of pairs and the label histogram.
    """

    def __init__(self, act, y):
        n_act, n_y = len(act), len(y)
        assert n_act == n_y
        print(f"dataset: {n_act} pairs loaded...")
        self.act, self.y = act, y
        # Show the distinct label values and how often each one occurs.
        print("y:", np.unique(y, return_counts=True))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        features = torch.tensor(self.act[idx])
        # Cast the label to int64 regardless of how it is stored in ``y``.
        target = torch.tensor(self.y[idx]).to(torch.long)
        return features, target

In [7]:
LAYER = 8
# Fixed seed for the train/test split so that the split (and every metric
# computed on it) is identical after Restart Kernel -> Run All.
SPLIT_SEED = 0

# Memory-map the activation file instead of reading all of it into RAM; only
# the slice selected below is copied into memory.
# NOTE(review): the first axis is presumably the layer index -- confirm
# against the script that produced this file.
act_all_layers = np.load('63k_X_alllayers.npy', mmap_mode='r')
labels = np.load('63k_Y.npy')

# np.array(...) copies the selected slice into a regular in-memory ndarray,
# so the memory map can be released straight away.
act = np.array(act_all_layers[LAYER, :, :])
del act_all_layers

print(f"Loaded act: {act.shape}")
print(f"Loaded labels: {labels.shape}")

probing_dataset = ProbingDataset(act, labels)

# 80/20 split; the test set takes the remainder so the sizes always add up.
train_size = int(0.8 * len(probing_dataset))
test_size = len(probing_dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
probe_train_dataset, probe_test_dataset = torch.utils.data.random_split(
    probing_dataset, [train_size, test_size], generator=split_generator
)
print(f"split into [test/train], [{test_size}/{train_size}]")

Loaded act: (63820, 768)
Loaded labels: (63820,)
dataset: 63820 pairs loaded...
y: (array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([6343, 6350, 6445, 6423, 6355, 6327, 6444, 6378, 6414, 6341]))
split into [test/train], [12764/51056]


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class LinearProbe(nn.Module):
    """Linear classifier: a single ``nn.Linear`` layer producing class logits.

    Args:
        num_input_features: size of each input feature vector.
        num_classes: number of output logits per input.
    """

    def __init__(self, num_input_features, num_classes):
        super().__init__()
        self.linear = nn.Linear(
            in_features=num_input_features,
            out_features=num_classes,
        )

    def forward(self, x):
        """Return the raw (un-normalised) logits for ``x``."""
        logits = self.linear(x)
        return logits


# 768 input features and 10 classes match the shapes printed when the data was
# loaded (activations of width 768, label values 0..9).
probe = LinearProbe(768, 10)

# Hyper-parameters for training; `batch_size` is reused by the evaluation cell.
config = {
    'learning_rate': 0.001,
    'weight_decay': 1e-3,
    'batch_size': 1024,
    'num_epochs': 50,
}

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(probe.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

dataloader = DataLoader(probe_train_dataset, batch_size=config['batch_size'], shuffle=True)

# training loop
# Make sure the probe is in training mode, e.g. when this cell is re-run after
# the evaluation cell has called probe.eval().
probe.train()
bar = tqdm(range(config['num_epochs']))
for epoch in bar:
    running_loss = 0.0
    correct = 0
    total = 0
    # The batch variables are deliberately NOT called `labels`: that name holds
    # the full label array loaded earlier and must not be overwritten here.
    for batch_inputs, batch_labels in dataloader:
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = probe(batch_inputs)

        # Compute loss
        loss = criterion(outputs, batch_labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Running train accuracy, measured on the logits computed before this
        # step's weight update; detach() keeps it out of the autograd graph.
        predicted = outputs.detach().argmax(dim=1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

    # Loss is the mean of the per-batch losses; accuracy is over all samples.
    bar.set_description(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.6f}, Acc: {correct/total:.6f}')


  from .autonotebook import tqdm as notebook_tqdm
Epoch 50, Loss: 0.745541, Acc: 0.679391: 100%|██████████| 50/50 [00:38<00:00,  1.29it/s]


### test accuracy

In [9]:
from torch.utils.data import DataLoader

# shuffle=False keeps the predictions in the same order as the test subset,
# which lets later cells align `y_pred` with the subset's indices.
test_dataloader = DataLoader(probe_test_dataset, batch_size=config['batch_size'], shuffle=False)

total = 0
correct = 0

# Per-batch prediction arrays; concatenated into `y_pred` once the loop is done.
pred_batches = []

probe.eval()
with torch.no_grad():
    # The batch variables are deliberately NOT called `labels`, so the full
    # label array loaded earlier in the notebook is not overwritten.
    for batch_inputs, batch_labels in tqdm(test_dataloader):

        outputs = probe(batch_inputs)
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)

        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

        pred_batches.append(predicted.cpu().numpy())

print(f'Test Accuracy: {correct/total:.5f}')

# Flat array of predicted class ids, one per test sample, in dataloader order.
y_pred = np.concatenate(pred_batches)

100%|██████████| 13/13 [00:00<00:00, 74.86it/s]

Test Accuracy: 0.62159





In [10]:
from sklearn.metrics import classification_report

# Reload the complete label array from disk and select the rows belonging to
# the test subset. `probe_test_dataset.indices` lists the subset's positions in
# the full dataset, in the same order the unshuffled test dataloader yielded
# them, so `y_true` lines up element-wise with `y_pred`.
y_full = np.load('63k_Y.npy')
y_true = y_full[probe_test_dataset.indices]
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

         0.0       0.99      0.98      0.99      1256
         1.0       0.91      0.93      0.92      1274
         2.0       0.79      0.78      0.78      1259
         3.0       0.61      0.71      0.66      1289
         4.0       0.58      0.38      0.45      1284
         5.0       0.44      0.65      0.53      1238
         6.0       0.42      0.29      0.35      1334
         7.0       0.42      0.37      0.39      1259
         8.0       0.41      0.42      0.42      1293
         9.0       0.63      0.72      0.68      1278

    accuracy                           0.62     12764
   macro avg       0.62      0.62      0.62     12764
weighted avg       0.62      0.62      0.61     12764

