Hypothesis: We might be representation-bottlenecked, and that might be contributing to our less-than-ideal performance. What if we start by pretraining our network as an autoencoder, chop off the decoder, and then do our metric learning on the encoder that remains?

In [1]:
import transformer_lens
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn as nn
import numpy as np

import torch.nn.functional as F

from tqdm import tqdm

import random

In [2]:
# Single seed shared by every RNG this notebook touches (PyTorch, stdlib
# `random`, NumPy) so that sampling below is repeatable across runs.
seed = 42
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(seed)

In [3]:
# `load_dataset` is already imported in the notebook's import cell, so it is
# not re-imported here.

# Identifiers of the pretrained model and the dataset/config to load, named
# so they can be changed in one place.
MODEL_NAME = "gpt2-medium"
DATASET_NAME = "sentence-transformers/all-nli"
DATASET_CONFIG = "pair-class"

# Load the pretrained model wrapped as a HookedTransformer so activations can
# be cached during forward passes.
model = transformer_lens.HookedTransformer.from_pretrained(MODEL_NAME)

# Load the dataset and keep a handle on its training split; later cells read
# the 'premise', 'hypothesis' and 'label' columns from it.
ds = load_dataset(DATASET_NAME, DATASET_CONFIG)
ds_train = ds['train']



Loaded pretrained model gpt2-medium into HookedTransformer


In [4]:
def get_acts(model, prompts):
    """Collect the last-token residual-stream activation of every layer, per prompt.

    Each prompt is run through the model on its own (batch of one). For every
    layer ``j`` the activation stored under ``blocks.{j}.hook_resid_post`` is
    read at batch index 0 and the final token position.

    Args:
        model: Object exposing ``cfg.n_layers``, ``cfg.d_model`` and
            ``run_with_cache(prompt, names_filter=...)`` returning
            ``(output, cache)``, e.g. a transformer_lens ``HookedTransformer``.
        prompts: Sized sequence of prompts accepted by ``model.run_with_cache``.

    Returns:
        ``torch.Tensor`` of shape ``[len(prompts), n_layers, d_model]`` that
        does not require grad.
    """
    # Imports are local so the function is self-contained.
    import torch
    from tqdm import tqdm

    # The number of layers our model has. GPT2-medium has 24.
    n_layers = model.cfg.n_layers

    # The only cache entries we read: one residual-stream hook per layer.
    hook_names = [f'blocks.{j}.hook_resid_post' for j in range(n_layers)]

    # Holds all of our activations. Shape: [n_prompts, n_layers, d_model].
    data = torch.zeros((len(prompts), n_layers, model.cfg.d_model))

    # We only extract features, so disable autograd: otherwise writing
    # grad-tracking activations into `data` ties `data` to the autograd graph
    # of the forward passes.
    with torch.no_grad():
        # Wrap `prompts` itself (not `enumerate(...)`) so tqdm can see the
        # length and display a proper total / ETA.
        for i, prompt in enumerate(tqdm(prompts)):
            # Forward pass that caches activations; `names_filter` restricts
            # the cache to the hooks we actually use.
            _, activations = model.run_with_cache(prompt, names_filter=hook_names)

            # "[0, -1]" selects the single batch element and then the last
            # token position of the residual stream. Whether the last token
            # is the best position to read from is an open question left for
            # later experimentation.
            for j, name in enumerate(hook_names):
                data[i, j] = activations[name][0, -1]

    return data

In [5]:
# Draw a random sample (without replacement) of 10k sentence-pair indices from
# the training split -- not the first 10k rows. `random` is seeded earlier in
# the notebook, so the same indices are drawn on a fresh top-to-bottom run.
N_SAMPLES = 10_000

idxs = random.sample(range(len(ds_train)), N_SAMPLES)

In [9]:
# Restrict the training split to the sampled rows, then pull out the three
# columns we need as parallel sequences (same row order in each).
subset = ds_train.select(idxs)
premises, hypotheses, labels = (
    subset[column] for column in ('premise', 'hypothesis', 'label')
)

In [10]:
# Extract per-layer activations for both sides of every pair: premises first,
# then hypotheses. Each call runs one forward pass per sentence.
premise_acts, hypothesis_acts = (
    get_acts(model, sentences) for sentences in (premises, hypotheses)
)

10000it [09:06, 18.29it/s]
10000it [09:07, 18.27it/s]


In [12]:
# Persist the extracted activations and the labels to the working directory
# so later steps can load them instead of recomputing.
for _filename, _payload in (
    ("premise_acts.pt", premise_acts),
    ("hypothesis_acts.pt", hypothesis_acts),
    ("labels.pt", labels),
):
    torch.save(_payload, _filename)