# Probe Training Tutorial

We use TransformerLens to load models, embeddings, etc.
We use regex and NLTK to select tokens that are natural-language English words starting with particular letters.
We train linear probes on a balanced dataset and track binary classification quality.

For now, almost all of this happens "under the hood" (in the `src` package).

To do:
- Find optimal hyperparameters.
- Explore internal probes. 
- Understand the directions in more detail.

In [2]:
import sys

# Presumably makes the parent directory's `src` package importable from this
# notebook (later cells import `src.probe_training` and `src.dataset`).
sys.path.append("..")

from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES

# A bare expression is only displayed when it is the last line of a cell, so
# print explicitly to actually see a sample of the supported model names.
print(OFFICIAL_MODEL_NAMES[1:10])

# NOTE: this checkpoint is a ~24 GB download (see the progress bar output).
MODEL_NAME = "EleutherAI/gpt-j-6B"
model = HookedTransformer.from_pretrained(MODEL_NAME)

pytorch_model.bin:   0%|          | 0.00/24.2G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
from src.probe_training import all_probe_training_runner

# Data and optimisation settings for the probe-training run.
probe_hparams = {
    "criteria_mode": "starts",
    "probe_type": "linear",
    "num_epochs": 4,
    "batch_size": 32,
    "learning_rate": 0.005,
    "train_test_split": 0.95,
    "rebalance": True,
    "use_wandb": False,
}

vocab = model.tokenizer.get_vocab()
# Train probes for the letters in `alphabet` on the detached token-embedding
# matrix; the result is later indexed by letter (e.g. probe_weights_tensor["A"]).
probe_weights_tensor = all_probe_training_runner(
    embeddings=model.W_E.detach(),
    vocab=vocab,
    alphabet="CAR",
    **probe_hparams,
)

# To visualize predictions, we can do something like this:

In [None]:
import nltk

# Fetch the NLTK English word list up front; presumably `src.dataset` uses it
# to decide which tokens count as real words. `quiet=True` suppresses the
# progress chatter on every re-run.
nltk.download("words", quiet=True)

from src.dataset import get_letter_dataset

train_loader, test_loader = get_letter_dataset(
    criterion="starts",
    target="A",
    # Detached, matching the training cell, so the dataset (and any probe
    # outputs computed from it) hold no autograd references into model weights.
    embeddings=model.W_E.detach(),
    vocab=vocab,
    batch_size=16,
    rebalance=True,
)

# Peek at one training batch. `batch` is deliberately left as a global: the
# plotting code below reads it.
batch = next(iter(train_loader))
print(batch[0][0])  # token ids (decoded via the reverse vocab further down)
print(batch[0][1].shape)  # embeddings fed to the probe
print(batch[1])  # labels
# Fraction of positive labels in this batch; should be near 0.5 when the
# dataset is rebalanced.
print(batch[1].float().mean())

import plotly.express as px
import pandas as pd


def plot_batch_predictions(words, embeddings, probe, labels=None, letter="A"):
    """Plot a probe's outputs for one batch as a horizontal bar chart.

    Args:
        words: Display strings, one per row of ``embeddings``.
        embeddings: Batch of inputs passed to ``probe``.
        probe: Callable returning one logit per input row (assumed shape
            ``(batch,)`` or ``(batch, 1)`` -- TODO confirm).
        labels: Ground-truth label tensor for the batch; bars are coloured by
            its boolean value. If omitted, falls back to the notebook-global
            ``batch[1]`` so existing three-argument calls keep working.
        letter: Target letter; used only in the title and legend text.

    Returns:
        A plotly ``Figure`` with one bar per word, sorted by logit.
    """
    if labels is None:
        # Backward-compatible fallback to the notebook global; prefer passing
        # ``labels`` explicitly so the plot cannot silently use a stale batch.
        labels = batch[1]

    # reshape(-1) instead of squeeze(): squeeze() turns a batch of one into a
    # 0-d tensor, which pandas cannot line up with the list of words.
    predictions = probe(embeddings).reshape(-1)

    df = pd.DataFrame(
        {
            "words": words,
            "predictions": predictions.cpu().detach().numpy(),
            "labels": labels.cpu().detach().numpy().astype(bool),
        }
    )
    df = df.sort_values(by=["predictions"], ascending=True)

    fig = px.bar(
        df,
        y="words",
        x="predictions",
        color="labels",
        title=f"Probe predictions for letter {letter}",
        # Keys must be DataFrame column names for plotly to apply them.
        labels={
            "words": "Word",
            "predictions": "Probe Logit",
            "labels": f"Starts with {letter}",
        },
        # Annotate each bar with its logit.
        text_auto=".2f",
        height=800,
        template="plotly_white",
    )
    # `color=` splits the rows into one trace per label, which would otherwise
    # group the y-axis by label; order categories globally by value instead.
    # NOTE(review): identical words share a single y-axis category.
    fig.update_yaxes(categoryorder="total ascending")
    # Make the hovered word prominent.
    fig.update_traces(hovertemplate="<b>%{y}</b><br><br>%{x}")

    return fig


# Map token ids back to their vocabulary strings to label the bars.
reverse_vocab = {token_id: token for token, token_id in vocab.items()}
# "\u0120" ("Ġ") marks a leading space in GPT-2-style byte-level BPE
# vocabularies; it is written as an escape so the file's text encoding
# cannot garble it (the previous literal "Ä " was its mis-decoded form).
words = [reverse_vocab[token_id.item()].strip("\u0120 ") for token_id in batch[0][0]]
plot_batch_predictions(words, batch[0][1], probe_weights_tensor["A"])

ModuleNotFoundError: No module named 'nltk'