In [1]:
import os, dill
from dev.constants import gdrive_path

import numpy as np
import pandas as pd
import torch as t
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import Subset, Dataset, DataLoader
from sklearn.linear_model import SGDClassifier

from jaxtyping import Float
from tqdm.notebook import tqdm, trange

In [2]:
class IPCCDataset(Dataset):

    def __init__(self, device: str="cuda"):
        self.device = device
        # all prompts used - corresponds to activations harvested
        self.files = os.listdir(f"{gdrive_path}/ipcc/long/prompts")
        # we index over all harvested activations as one dataset
        self.file_ranges = []
        # we will cache all activations and labels for faster loading
        self.cache, self.labels = {}, {}
        ix = 0
        for file in tqdm(self.files, desc="caching data"):
            labels = pd.read_json(f"{gdrive_path}/ipcc/long/prompts/{file}", orient="records", lines=True)
            path = f"{gdrive_path}/ipcc/long/activations"
            f = file.replace('prompts', 'PART')
            choice1 = t.load(f"{path}/{f.replace('.jsonl', '_CHOICE1.pt')}", pickle_module=dill).to(self.device)
            choice2 = t.load(f"{path}/{f.replace('.jsonl', '_CHOICE2.pt')}", pickle_module=dill).to(self.device)
            data = t.concat([choice1-choice2, choice2-choice1])
            labels = t.from_numpy(labels["P(S1)"].values).to(data.dtype).to(device)
            labels = t.concat([labels, 1-labels])
            self.cache[file] = data
            self.labels[file] = labels
            self.file_ranges.append((ix, ix+data.size(0)))
            ix += data.size(0)

    def __len__(self):
        return self.file_ranges[-1][1]

    def __getitem__(self, ix):
        file_ix = next(i for i, (start, end) in enumerate(self.file_ranges) if start <= ix < end)
        file = self.files[file_ix]
        data_ix = ix - self.file_ranges[file_ix][0]
        data = self.cache[file][data_ix]
        label = self.labels[file][data_ix]
        return data, label

In [3]:
dataset = IPCCDataset()
# train/val/test split
trainval_split, valtest_split = int(0.7*len(dataset)), int(0.85*len(dataset))
# shuffle indices and create subsets
perm = t.randperm(len(dataset))
train_dataset = Subset(dataset, perm[:trainval_split])
val_dataset = Subset(dataset, perm[trainval_split:valtest_split])
test_dataset = Subset(dataset, perm[valtest_split:])
# dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

caching data:   0%|          | 0/29 [00:00<?, ?it/s]

In [4]:
lr = SGDClassifier(
    loss="log_loss",
    learning_rate="constant",
    eta0=1e-2
)

n_epoch = 10
desc = "epoch 1"
first = True
for epoch in range(n_epoch):
    bar = tqdm(train_loader)
    bar.set_description(desc)
    for batch, labels in bar:
        batch, labels = batch.cpu().numpy(), labels.cpu().numpy() > 0.5
        if first:
            first = False
            lr.partial_fit(batch, labels, classes=[False, True])
        else: lr.partial_fit(batch, labels)

    scores = []
    for batch, labels in val_loader:
        batch, labels = batch.cpu().numpy(), labels.cpu().numpy() > 0.5
        score = lr.score(batch, labels)
        scores.append(score)
    score = sum(scores) / len(scores)
    desc = f"epoch {epoch+2} ({round(score, 3)})"
scores = []
for batch, labels in test_loader:
    batch, labels = batch.cpu().numpy(), labels.cpu().numpy() > 0.5
    score = lr.score(batch, labels)
    scores.append(score)
score = sum(scores) / len(scores)
print(f"test accuracy: {round(score, 3)}")

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

test accuracy: 0.734


In [10]:
from sklearn.metrics import confusion_matrix


y_pred, y_true = [], []
for batch, labels in test_loader:
    batch, labels = batch.cpu().numpy(), labels.cpu().numpy() > 0.5
    y_pred.append(lr.predict(batch))
    y_true.append(labels)
y_pred = np.concatenate(y_pred)
y_true = np.concatenate(y_true)
cm = confusion_matrix(y_true, y_pred, normalize="true")
cm

In [5]:
lr = SGDClassifier(
    loss="log_loss",
    learning_rate="constant",
    eta0=1e-3
)

n_epoch = 10
desc = "epoch 1"
first = True
for epoch in range(n_epoch):
    bar = tqdm(train_loader)
    bar.set_description(desc)
    for batch, labels in bar:
        batch, labels = batch.cpu().numpy(), labels.cpu().numpy() * 2
        if first:
            first = False
            lr.partial_fit(batch, labels, classes=[0, 1, 2])
        else: lr.partial_fit(batch, labels)

    scores = []
    for batch, labels in val_loader:
        batch, labels = batch.cpu().numpy(), labels.cpu().numpy() * 2
        score = lr.score(batch, labels)
        scores.append(score)
    score = sum(scores) / len(scores)
    desc = f"epoch {epoch+2} ({round(score, 3)})"
scores = []
for batch, labels in test_loader:
    batch, labels = batch.cpu().numpy(), labels.cpu().numpy() * 2
    score = lr.score(batch, labels)
    scores.append(score)
score = sum(scores) / len(scores)
print(f"test accuracy: {round(score, 3)}")

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

  0%|          | 0/11814 [00:00<?, ?it/s]

test accuracy: 0.72
