# Probe Training Tutorial

We use TransformerLens to load models and their embeddings.
We use regex and nltk to select tokens that are natural-language English words starting with a particular letter.
We train linear probes on a balanced dataset and track binary classification quality.
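
As a rough illustration of the token selection and balancing described above, here is a minimal sketch. The word set and vocabulary are tiny hypothetical stand-ins for nltk's `words` corpus and the model tokenizer's vocabulary, the helper names are invented for this example, and downsampling the larger class is only one assumed way of rebalancing; the actual logic in `src.probe_training` is not shown in this notebook and may differ.

```python
import random
import re

# Hypothetical stand-ins: the notebook uses nltk's `words` corpus and the
# model tokenizer's vocabulary instead of these hand-written samples.
ENGLISH_WORDS = {"apple", "class", "cat", "dog", "count"}
VOCAB = {" apple": 0, " class": 1, "cat": 2, " dog": 3, "##x": 4, " 123": 5, " count": 6}

# Optional leading space (byte-level BPE vocabularies often write it as "Ġ"),
# followed by letters only.
WORD_RE = re.compile(r"^[ Ġ]?([A-Za-z]+)$")


def tokens_starting_with(letter, vocab, words):
    """Split word-like tokens into (positive, negative) token ids for one letter."""
    positives, negatives = [], []
    for token, idx in vocab.items():
        match = WORD_RE.match(token)
        if match is None or match.group(1).lower() not in words:
            continue  # skip fragments, numbers, and non-dictionary strings
        word = match.group(1).lower()
        (positives if word.startswith(letter) else negatives).append(idx)
    return positives, negatives


def rebalance(positives, negatives, seed=0):
    """Downsample the larger class so both classes have the same size (assumed strategy)."""
    rng = random.Random(seed)
    n = min(len(positives), len(negatives))
    return rng.sample(positives, n), rng.sample(negatives, n)


pos, neg = tokens_starting_with("c", VOCAB, ENGLISH_WORDS)
pos_bal, neg_bal = rebalance(pos, neg)
```

With a balanced dataset, a probe that always predicts the same class scores 50% accuracy, which makes the reported classification quality easier to interpret.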

For now, almost all of the logic lives "under the hood" in `src.probe_training`.

To do:
- Find optimal hyper-parameters. 
- Explore internal probes. 
- Understand the directions in more detail.

In [None]:
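import sys

import torch

# Make the repository root importable so that `src.probe_training` resolves.
sys.path.append("..")

from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES

# Peek at a few supported model names (a bare expression mid-cell displays nothing).
print(OFFICIAL_MODEL_NAMES[1:10])
model = HookedTransformer.from_pretrained("pythia-2.8b")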

In [None]:
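import nltk

from src.probe_training import all_probe_training_runner

nltk.download("words")  # English word list used to filter the vocabulary

# Effective embedding: the token embedding plus the output of the first MLP,
# with the first attention layer replaced by its output bias. `W_E[None]` adds
# a leading batch dimension, so the result is indexed with `eff_embed[0]` below.
block0 = model.blocks[0]
eff_embed = model.W_E + block0.mlp(block0.ln2(model.W_E[None] + block0.attn.b_O))
vocab = model.tokenizer.get_vocab()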
probe_weights_tensor = all_probe_training_runner(
    embeddings=eff_embed[0].detach(),
    vocab=vocab,
    criteria_mode="starts",
    probe_type="linear",
    num_epochs=30,
    batch_size=128,
    learning_rate=0.005,
    train_test_split=0.90,
    rebalance=True,
    use_wandb=True,
    device="mps",  # Apple-silicon GPU backend; use "cuda" or "cpu" on other hardware
)
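
The internals of `all_probe_training_runner` are not shown here. As a hedged sketch of what training one binary linear probe amounts to, the following fits a logistic-regression classifier with plain SGD in pure Python. It assumes the probe is a single linear layer trained with a log loss, and it uses toy two-dimensional data in place of token embeddings; the function names and the per-example update are illustrative choices, not the notebook's implementation (which also batches examples via `batch_size`).

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logit(w, b, x):
    """Linear probe output: dot(w, x) + b."""
    return sum(wj * xj for wj, xj in zip(w, x)) + b


def train_linear_probe(xs, ys, num_epochs=30, learning_rate=0.005, seed=0):
    """Fit w and b so that sigmoid(dot(w, x) + b) predicts the 0/1 label y."""
    rng = random.Random(seed)
    w, b = [0.0] * len(xs[0]), 0.0
    order = list(range(len(xs)))
    for _ in range(num_epochs):
        rng.shuffle(order)
        for i in order:  # plain SGD, one example at a time
            error = sigmoid(logit(w, b, xs[i])) - ys[i]  # d(log loss)/d(logit)
            w = [wj - learning_rate * error * xj for wj, xj in zip(w, xs[i])]
            b -= learning_rate * error
    return w, b


def accuracy(w, b, xs, ys):
    """Fraction of examples whose logit sign matches the label."""
    hits = sum((logit(w, b, x) > 0) == bool(y) for x, y in zip(xs, ys))
    return hits / len(xs)


# Toy, linearly separable data standing in for token embeddings.
data_rng = random.Random(0)
xs = [[data_rng.gauss(2.0 if i % 2 else -2.0, 0.5), data_rng.gauss(0.0, 1.0)] for i in range(200)]
ys = [i % 2 for i in range(200)]

split = int(0.90 * len(xs))  # same 90/10 ratio as `train_test_split=0.90`
w, b = train_linear_probe(xs[:split], ys[:split])
test_accuracy = accuracy(w, b, xs[split:], ys[split:])
```

The learned weight vector `w` is the "direction" the probe reads from the embedding: here only the first coordinate carries label information, so most of the weight should end up on it.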

# Visualize Predictions

In [57]:
import plotly.express as px

# Stack the parameters of every trained probe (one probe per key).
probe_keys = list(probe_weights_tensor.keys())
probe_weights = torch.stack([probe_weights_tensor[key].fc.weight for key in probe_keys])
probe_biases = torch.stack([probe_weights_tensor[key].fc.bias for key in probe_keys])

# Effective embedding of the single token " class".
token_eff_embed = eff_embed[0, model.to_single_token(" class")]
print(token_eff_embed.shape)

# Apply every probe to this one token and plot the raw outputs, one bar per probe.
token_probe_activations = probe_weights @ token_eff_embed + probe_biases
px.bar(x=probe_keys, y=token_probe_activations.squeeze().detach().cpu().numpy())

torch.Size([2560])


ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed
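
The bar chart is meant to show the raw output of each probe for one token. If the probes were trained with a logistic (binary cross-entropy) loss, which is an assumption here, each output is a logit and a sigmoid turns it into a probability. The sketch below ranks probes by that probability; `rank_probes` is a hypothetical helper and the logits are made-up numbers for illustration, not results from this notebook.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def rank_probes(logits_by_label, k=3):
    """Return the k (label, probability) pairs with the highest probability, best first."""
    scored = [(label, sigmoid(value)) for label, value in logits_by_label.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Made-up logits for illustration only.
example_logits = {"a": -3.1, "b": -2.4, "c": 4.2, "k": 0.3}
top = rank_probes(example_logits, k=2)
for label, prob in top:
    print(f"{label}: {prob:.3f}")
```

A positive logit corresponds to a probability above 0.5, so the sign of each bar already indicates which probes fire for the token.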