In [1]:
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset

In [2]:
# Select the compute device used by every later cell.
# With CUDA available this keeps the original value (the integer index of the
# current CUDA device). Otherwise fall back to the CPU so that `device` is
# always defined and the notebook also runs on machines without a GPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = torch.device("cpu")
print(device)

0


## super simple linear probe

In [3]:
class ProbingDataset(Dataset):
    """Map-style dataset pairing each activation entry with its label.

    Args:
        act: indexable collection of activations, one entry per example.
        y: indexable collection of labels; must have the same length as ``act``.

    On construction the number of pairs and the unique label values with
    their counts are printed.
    """

    def __init__(self, act, y):
        n_pairs = len(act)
        assert n_pairs == len(y)
        print(f"dataset: {n_pairs} pairs loaded...")
        self.act, self.y = act, y
        # Label distribution: (unique values, occurrences of each value).
        print("y:", np.unique(y, return_counts=True))

    def __len__(self):
        """Return the number of (activation, label) pairs."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(activation, label)`` at ``idx`` as tensors; the label is cast to int64."""
        features = torch.tensor(self.act[idx])
        target = torch.tensor(self.y[idx]).long()
        return features, target

In [4]:
LAYER = 8        # index of the layer whose activations are probed
SPLIT_SEED = 0   # seed for the train/test split, so reruns give the same split

# The activation file holds every layer; indexing axis 0 with LAYER keeps one
# layer, leaving a 2-D array (see the printed shape below).
act = np.load('63k_X_alllayers.npy')
labels = np.load('63k_Y.npy')

act = act[LAYER, :, :]

print(f"Loaded act: {act.shape}")
print(f"Loaded labels: {labels.shape}")

probing_dataset = ProbingDataset(act, labels)

# 80% of the examples for training, the remainder for testing.
train_size = int(0.8 * len(probing_dataset))
test_size = len(probing_dataset) - train_size

# A dedicated, seeded generator makes the split reproducible across kernel
# restarts and independent of how much global RNG state was consumed earlier.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
probe_train_dataset, probe_test_dataset = torch.utils.data.random_split(
    probing_dataset, [train_size, test_size], generator=split_generator
)
print(f"split into [test/train], [{test_size}/{train_size}]")

Loaded act: (63820, 768)
Loaded labels: (63820,)
dataset: 63820 pairs loaded...
y: (array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([6343, 6350, 6445, 6423, 6355, 6327, 6444, 6378, 6414, 6341]))
split into [test/train], [12764/51056]


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class LinearProbe(nn.Module):
    """Probe made of a single ``nn.Linear`` layer.

    Args:
        num_input_features: size of each input vector.
        num_classes: number of outputs produced per input vector.
    """

    def __init__(self, num_input_features, num_classes):
        super().__init__()
        self.linear = nn.Linear(in_features=num_input_features, out_features=num_classes)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        logits = self.linear(x)
        return logits


# Build the probe (768 input features, 10 outputs) and move it to `device`.
probe = LinearProbe(768, 10).to(device)

# Training hyperparameters, read by the optimizer, the data loaders and the
# training loop.
config = dict(
    learning_rate=0.001,
    weight_decay=1e-3,
    batch_size=1024,
    num_epochs=50,
)

# Shuffled mini-batches over the training split.
dataloader = DataLoader(
    probe_train_dataset,
    shuffle=True,
    batch_size=config['batch_size'],
)

# Adam over the probe's parameters, with cross-entropy as the training loss.
optimizer = optim.Adam(
    probe.parameters(),
    weight_decay=config['weight_decay'],
    lr=config['learning_rate'],
)
criterion = nn.CrossEntropyLoss()

# Training loop: one full pass over `dataloader` per epoch. The progress-bar
# description shows the epoch's mean training loss and training accuracy.

# Make sure the probe is in training mode even if this cell is re-run after
# the evaluation cell switched it to eval mode.
probe.train()

bar = tqdm(range(config['num_epochs']))
for epoch in bar:
    running_loss = 0.0  # sum of per-example losses seen this epoch
    correct = 0
    total = 0
    # The loop variables are deliberately NOT named `inputs` / `labels`: at
    # notebook top level they are globals, and `labels` would overwrite the
    # label array loaded from disk in the data-loading cell.
    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        batch_count = batch_labels.size(0)

        optimizer.zero_grad()

        # Forward pass
        outputs = probe(batch_inputs)

        # Compute loss (CrossEntropyLoss defaults to the mean over the batch)
        loss = criterion(outputs, batch_labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Weight the batch mean by the batch size so that a shorter final
        # batch is not over-represented in the epoch average.
        running_loss += loss.item() * batch_count

        # train accuracy: predicted class = index of the largest output
        predicted = outputs.detach().argmax(dim=1)
        total += batch_count
        correct += (predicted == batch_labels).sum().item()

    bar.set_description(f'Epoch {epoch+1}, Loss: {running_loss/total:.6f}, Acc: {correct/total:.6f}')


Epoch 50, Loss: 0.746705, Acc: 0.678667: 100%|██████████| 50/50 [00:42<00:00,  1.18it/s]


### test accuracy

In [10]:
from torch.utils.data import DataLoader

# Evaluate the trained probe on the held-out split.
# shuffle=False keeps the batches in the order of `probe_test_dataset`, so the
# concatenated `y_pred` lines up element-for-element with
# `probe_test_dataset.indices` (used by the classification-report cell).
test_dataloader = DataLoader(probe_test_dataset, batch_size=config['batch_size'], shuffle=False)

total = 0
correct = 0

y_pred = []  # per-batch prediction arrays, concatenated after the loop

probe.eval()
with torch.no_grad():
    # Loop variables are named `batch_*` so the module-level `labels` array
    # loaded in the data-loading cell is not overwritten by a batch tensor.
    for batch_inputs, batch_labels in tqdm(test_dataloader):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        outputs = probe(batch_inputs)
        # Predicted class = index of the largest output per row.
        predicted = outputs.argmax(dim=1)

        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

        y_pred.append(predicted.cpu().numpy())

print(f'Test Accuracy: {correct/total:.5f}')

y_pred = np.concatenate(y_pred)

100%|██████████| 13/13 [00:00<00:00, 67.23it/s]

Test Accuracy: 0.62191





In [14]:
from sklearn.metrics import classification_report

# Reload the full label array and pick out the rows that belong to the test
# split, in the split's own order, so they pair up with `y_pred`.
y_full = np.load('63k_Y.npy')
y_test = y_full[probe_test_dataset.indices]

# Per-class precision / recall / F1, printed with three decimals.
report = classification_report(y_test, y_pred, digits=3)
print(report)

              precision    recall  f1-score   support

         0.0      0.985     0.992     0.989      1279
         1.0      0.931     0.924     0.927      1287
         2.0      0.777     0.810     0.794      1266
         3.0      0.628     0.650     0.639      1293
         4.0      0.543     0.516     0.529      1290
         5.0      0.485     0.351     0.407      1269
         6.0      0.385     0.602     0.469      1264
         7.0      0.411     0.344     0.374      1329
         8.0      0.437     0.266     0.330      1258
         9.0      0.617     0.773     0.686      1229

    accuracy                          0.622     12764
   macro avg      0.620     0.623     0.615     12764
weighted avg      0.620     0.622     0.614     12764

