## Attempt using a modified linear probe and training loop from othello_world

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F


# taken from othello_world repo
class LinearProbe(nn.Module):
    """Linear probe predicting ``num_task`` independent categorical targets.

    Adapted from the othello_world repo. One ``nn.Linear`` emits
    ``probe_class * num_task`` logits, which are reshaped so several probes
    can be trained together as a single module.

    Args:
        device: device the probe's parameters are moved to.
        probe_class: number of classes each task predicts.
        num_task: number of probes folded into this module.
        input_dim: size of the activation vectors fed to the probe.
    """

    def __init__(self, device, probe_class, num_task, input_dim=512):
        super().__init__()
        self.input_dim = input_dim
        self.probe_class = probe_class
        self.num_task = num_task
        self.proj = nn.Linear(self.input_dim, self.probe_class * self.num_task, bias=True)
        self.apply(self._init_weights)
        self.to(device)

    def forward(self, act, y=None):
        """Return ``(logits, loss)``.

        Args:
            act: activations of shape [B, input_dim].
            y: optional targets holding B * num_task entries (e.g. [B, num_task],
                or [B] when num_task == 1). They are cast to long; entries equal
                to -100 are ignored by the loss.

        Returns:
            logits of shape [B, num_task, probe_class], and the mean
            cross-entropy loss (None when ``y`` is None).
        """
        logits = self.proj(act).reshape(-1, self.num_task, self.probe_class)  # [B, #task, C]
        if y is None:
            return logits, None
        targets = y.to(torch.long)
        # reshape (not view) so non-contiguous label tensors are accepted too
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)
        return logits, loss

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02), zero biases, LayerNorm to identity."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """Build an Adam optimizer and a ReduceLROnPlateau scheduler.

        Parameters are split into two groups:
          * decayed (``train_config.weight_decay``): Linear weights AND all biases
            (biases are intentionally decayed in this probe);
          * not decayed: LayerNorm / Embedding weights.
        Every parameter must land in exactly one group, otherwise an
        AssertionError is raised.

        Args:
            train_config: object exposing ``weight_decay``, ``learning_rate`` and ``betas``.

        Returns:
            ``(optimizer, scheduler)``; the scheduler is stepped with a loss
            value (mode='min') and scales the LR by 0.75 on a plateau.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            # recurse=False: each parameter is classified once, by the module that owns it
            for pn, p in m.named_parameters(recurse=False):
                fpn = f'{mn}.{pn}' if mn else pn  # full param name
                if pn.endswith('bias'):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )
        print("Decayed:", decay)
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0)
        return optimizer, scheduler

In [2]:
"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""
import os
import logging

from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader

# Module-level logger; Trainer.train reports per-epoch test loss/accuracy through logger.info.
logger = logging.getLogger(__name__)

class TrainerConfig:
    """Bag of training hyperparameters with class-level defaults.

    Every keyword given to the constructor becomes an instance attribute,
    overriding a default below (or adding a new setting).
    """

    # --- optimisation ---
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1  # used for the optimizer's "decay" parameter group

    # --- learning-rate decay (GPT-3 style warmup + cosine schedule) ---
    # NOTE: Trainer.train as written never reads these three settings.
    lr_decay = False
    warmup_tokens = 375e6
    final_tokens = 260e9

    # --- checkpointing / data loading ---
    ckpt_path = None
    num_workers = 0  # DataLoader worker processes

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

class Trainer:
    """Epoch-based train/evaluate loop for a probe model.

    Expects ``model(x, y)`` to return ``(logits, loss)`` and the unwrapped model
    to provide ``configure_optimizers(config) -> (optimizer, scheduler)``; the
    scheduler's ``step`` is called with each epoch's test loss.

    Args:
        model: module to train; wrapped in ``DataParallel`` when CUDA is available.
        train_dataset: dataset yielding ``(x, y)`` pairs for training.
        test_dataset: dataset evaluated after every epoch, or None to skip evaluation.
        config: ``TrainerConfig``-like object (batch_size, num_workers, max_epochs,
            grad_norm_clip, ckpt_path, plus whatever configure_optimizers reads).
    """

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        """Save the unwrapped model's ``state_dict`` to ``<ckpt_path>/checkpoint.ckpt``."""
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        os.makedirs(self.config.ckpt_path, exist_ok=True)
        torch.save(raw_model.state_dict(), os.path.join(self.config.ckpt_path, "checkpoint.ckpt"))

    def train(self, prt=True):
        """Train for ``config.max_epochs`` epochs, evaluating after each one.

        Whenever the test loss improves on every earlier epoch, the model is
        checkpointed — but only if ``config.ckpt_path`` is set.

        Args:
            prt: show a tqdm progress bar during training and log test metrics.
        """
        model, config = self.model, self.config
        raw_model = model.module if hasattr(model, "module") else model
        optimizer, scheduler = raw_model.configure_optimizers(config)

        def run_epoch(split, epoch):
            """One pass over ``split``; returns the mean batch loss for 'test'."""
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(data, shuffle=True, pin_memory=True,
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

            losses = []
            n_hits = 0   # correct predictions so far this epoch
            n_total = 0  # scored (non-ignored) targets so far this epoch
            pbar = tqdm(loader, total=len(loader), disable=not prt) if is_train else loader
            for x, y in pbar:
                x = x.to(self.device)  # [B, f]
                y = y.to(self.device)  # [B, #task], or [B] when #task == 1

                with torch.set_grad_enabled(is_train):
                    logits, loss = model(x, y)
                    loss = loss.mean()  # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                with torch.no_grad():
                    y_hat = torch.argmax(logits, dim=-1)  # [B, #task]
                    # Bring y to y_hat's shape: comparing [B, 1] with a 1-D [B]
                    # would broadcast to [B, B] and yield a meaningless accuracy.
                    targets = y.reshape(y_hat.shape).to(torch.long)
                    # -100 is the ignore_index of the probe's cross-entropy loss
                    valid = targets != -100
                    n_hits += (y_hat == targets)[valid].sum().item()
                    n_total += valid.sum().item()

                if is_train:
                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()
                    mean_loss = float(np.mean(losses))
                    mean_acc = n_hits / max(n_total, 1)  # running accuracy over the epoch
                    lr = optimizer.param_groups[0]['lr']
                    pbar.set_description(f"epoch {epoch+1}: train loss {mean_loss:.5f}; lr {lr:.2e}; train acc {mean_acc*100:.2f}%")

            if is_train:
                return None
            test_loss = float(np.mean(losses))
            scheduler.step(test_loss)
            # accuracy over the whole test split, not only its last batch
            test_acc = n_hits / max(n_total, 1)
            if prt:
                logger.info(f"test loss {test_loss:.5f}; test acc {test_acc*100:.2f}%")
            return test_loss

        best_loss = float('inf')
        self.tokens = 0  # reserved for learning-rate decay; not advanced by this loop

        for epoch in range(config.max_epochs):
            run_epoch('train', epoch)
            if self.test_dataset is not None:
                test_loss = run_epoch('test', epoch)
                if test_loss < best_loss:
                    best_loss = test_loss
                    # ckpt_path defaults to None in TrainerConfig: skip saving instead of crashing
                    if config.ckpt_path is not None:
                        self.save_checkpoint()

In [3]:
import os

from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset

In [4]:
class ProbingDataset(Dataset):
    """Map-style dataset of (activation, label) pairs for probe training.

    Args:
        act: indexable activations; row ``i`` is the probe input for sample ``i``.
        y: labels aligned with ``act`` (same length).
    """

    def __init__(self, act, y):
        assert len(act) == len(y)
        print(f"dataset: {len(act)} pairs loaded...")
        self.act, self.y = act, y
        # show the label histogram so class imbalance is visible immediately
        print("y:", np.unique(y, return_counts=True))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # torch.tensor copies, so returned tensors never alias the backing arrays
        sample = (self.act[idx], self.y[idx])
        return tuple(torch.tensor(part) for part in sample)

In [33]:
LAYER = 2

# the first axis of the saved array indexes layers; keep only LAYER's activations
act = np.load('63k_X_alllayers.npy')[LAYER]
labels = np.load('63k_Y.npy')
print(f"Loaded act: {act.shape}")
print(f"Loaded labels: {labels.shape}")

probing_dataset = ProbingDataset(act, labels)
# 80/20 train/test split (random, unseeded)
train_size = int(len(probing_dataset) * 0.8)
test_size = len(probing_dataset) - train_size
probe_train_dataset, probe_test_dataset = torch.utils.data.random_split(
    probing_dataset, [train_size, test_size]
)

Loaded act: (63820, 768)
Loaded labels: (63820,)
dataset: 63820 pairs loaded...
y: (array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([6343, 6350, 6445, 6423, 6355, 6327, 6444, 6378, 6414, 6341]))


In [34]:
# Fall back to CPU so `device` is always defined; it is passed to LinearProbe below.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
print(device)

0


In [None]:
probe = LinearProbe(device, 10, 1, input_dim=act.shape[1])

max_epochs = 10
tconf = TrainerConfig(
    max_epochs=max_epochs,
    batch_size=1024,
    learning_rate=1e-3,
    betas=(.9, .999),
    # NOTE: lr_decay / warmup_tokens / final_tokens are not read by Trainer.train as written
    lr_decay=True,
    warmup_tokens=len(probe_train_dataset)*5,
    # the decay horizon is measured in training samples, like warmup_tokens
    final_tokens=len(probe_train_dataset)*max_epochs,
    num_workers=0,
    weight_decay=0.,
    ckpt_path=f"./ckpts/testprobe_layer{LAYER}",
)
trainer = Trainer(probe, probe_train_dataset, probe_test_dataset, tconf)
# train() already writes the lowest-test-loss model to ckpt_path; an extra
# save_checkpoint() here would overwrite it with the last-epoch weights.
trainer.train(prt=True)

## Super-simple linear probe

In [36]:
class ProbingDataset(Dataset):
    """Map-style dataset of (activation, integer label) pairs.

    Labels are returned as int64 tensors, the dtype nn.CrossEntropyLoss
    expects for class-index targets.

    Args:
        act: indexable activations; row ``i`` is the probe input for sample ``i``.
        y: labels aligned with ``act`` (same length).
    """

    def __init__(self, act, y):
        assert len(act) == len(y)
        print(f"dataset: {len(act)} pairs loaded...")
        self.act, self.y = act, y
        # label histogram, to spot class imbalance early
        print("y:", np.unique(y, return_counts=True))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        features = torch.tensor(self.act[idx])
        label = torch.tensor(self.y[idx]).long()
        return features, label

In [37]:
LAYER = 3

# the first axis of the saved array indexes layers; keep only LAYER's activations
act = np.load('63k_X_alllayers.npy')[LAYER]
labels = np.load('63k_Y.npy')
print(f"Loaded act: {act.shape}")
print(f"Loaded labels: {labels.shape}")

probing_dataset = ProbingDataset(act, labels)
# 80/20 train/test split (random, unseeded)
train_size = int(len(probing_dataset) * 0.8)
test_size = len(probing_dataset) - train_size
probe_train_dataset, probe_test_dataset = torch.utils.data.random_split(
    probing_dataset, [train_size, test_size]
)
print(f"split into [test/train], [{test_size}/{train_size}]")

Loaded act: (63820, 768)
Loaded labels: (63820,)
dataset: 63820 pairs loaded...
y: (array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([6343, 6350, 6445, 6423, 6355, 6327, 6444, 6378, 6414, 6341]))
split into [test/train], [12764/51056]


In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class LinearProbe(nn.Module):
    """A single affine map from activations to raw class logits (no softmax)."""

    def __init__(self, num_input_features, num_classes):
        super().__init__()
        # the attribute name `linear` fixes the state_dict keys (linear.weight / linear.bias)
        self.linear = nn.Linear(num_input_features, num_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits


# input width follows the loaded activations instead of a hard-coded 768
probe = LinearProbe(act.shape[1], 10)

config = {
    'learning_rate': 0.001,
    'weight_decay': 1e-3,
    'batch_size': 1024,
    'num_epochs': 50,
}

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(probe.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

dataloader = DataLoader(probe_train_dataset, batch_size=config['batch_size'], shuffle=True)

# training loop
probe.train()  # the evaluation cell switches to eval(); re-running this cell must undo that
bar = tqdm(range(config['num_epochs']))
for epoch in bar:
    running_loss = 0.0
    correct = 0
    total = 0
    # `batch_labels`, not `labels`: keep the full label array loaded earlier intact
    for inputs, batch_labels in dataloader:
        optimizer.zero_grad()

        outputs = probe(inputs)
        loss = criterion(outputs, batch_labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # running train accuracy over the epoch
        predicted = outputs.argmax(dim=1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

    bar.set_description(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.6f}, Acc: {correct/total:.6f}')


Epoch 50, Loss: 0.685244, Acc: 0.739521: 100%|██████████| 50/50 [00:41<00:00,  1.22it/s]


### Test accuracy

In [39]:
from torch.utils.data import DataLoader

# shuffle=False keeps y_pred in the same order as probe_test_dataset.indices
test_dataloader = DataLoader(probe_test_dataset, batch_size=config['batch_size'], shuffle=False)

total = 0
correct = 0

y_pred = []

probe.eval()
with torch.no_grad():
    # `batch_labels`, not `labels`: keep the full label array loaded earlier intact
    for inputs, batch_labels in tqdm(test_dataloader):
        outputs = probe(inputs)
        predicted = outputs.argmax(dim=1)

        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

        y_pred.append(predicted.cpu().numpy())

print(f'Test Accuracy: {correct/total:.5f}')

y_pred = np.concatenate(y_pred)

100%|██████████| 13/13 [00:00<00:00, 38.37it/s]

Test Accuracy: 0.68482





In [40]:
from sklearn.metrics import classification_report

y_full = np.load('63k_Y.npy')
# probe_test_dataset.indices lists the test rows in the order the test loader (shuffle=False) produced y_pred
print(classification_report(y_true=y_full[probe_test_dataset.indices], y_pred=y_pred))

              precision    recall  f1-score   support

         0.0       0.97      0.98      0.98      1227
         1.0       0.93      0.93      0.93      1269
         2.0       0.86      0.86      0.86      1350
         3.0       0.72      0.70      0.71      1279
         4.0       0.57      0.68      0.62      1271
         5.0       0.57      0.48      0.52      1288
         6.0       0.52      0.50      0.51      1303
         7.0       0.50      0.42      0.46      1251
         8.0       0.49      0.54      0.52      1290
         9.0       0.70      0.77      0.74      1236

    accuracy                           0.68     12764
   macro avg       0.68      0.69      0.68     12764
weighted avg       0.68      0.68      0.68     12764

