# Training a simple linear probe on a transformer model

In this introductory notebook, we train a simple linear probe on top of a frozen transformer model. The goal is to check whether the hidden states of an intermediate layer already contain the information needed for causal language modelling on the wikitext dataset.

In [1]:
# Standard imports
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    GPT2Model,
    GPT2LMHeadModel,
)
from datasets import load_dataset
from peft import LoraConfig
import torch

# Imports from the transformer_heads library
from transformer_heads import load_headed
from transformer_heads.util.helpers import DataCollatorWithPadding, get_model_params
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import evaluate_head_wise, get_top_n_preds

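Set the model and its training parameters. The default is GPT-2, but you can also use a 7B model such as Mistral 7B or Llama 2 7B if you have enough GPU memory (and are willing to wait longer for training to complete). The outputs in this notebook come from a run that used the Llama-2-7b settings in the `# Parameters` cell, which overrides the GPT-2 defaults.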

In [2]:
model_path = "gpt2"
train_epochs = 1
eval_epochs = 1
logging_steps = 100

In [3]:
# Parameters for the run shown in this notebook (override the GPT-2 defaults above)
model_path = "meta-llama/Llama-2-7b-hf"
train_epochs = 0.1
eval_epochs = 0.1
logging_steps = 40


In [4]:
model_params = get_model_params(model_path)
model_class = model_params["model_class"]
hidden_size = model_params["hidden_size"]
vocab_size = model_params["vocab_size"]
print(model_params)

{'model_class': <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>, 'hidden_size': 4096, 'vocab_size': 32000}


Start by configuring the linear probing head. In this example we hook at layer -4, using Python's negative indexing. For example, GPT-2 has 12 transformer blocks: hooking at layer -4 means that the linear probe processes the hidden state after the 9th block, while hooking at layer -1 would mean processing the hidden state after the last (12th) block. For Llama-2-7b, which has 32 blocks, layer -4 is the output of the 29th block.
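To make the index arithmetic concrete, here is a small sketch that maps a `layer_hook` value to the number of blocks applied before the probe reads the hidden state. It assumes the hidden states are laid out like Hugging Face's `output_hidden_states=True` tuple (embedding output at index 0, then one entry per block), which is consistent with the description above; the helper `blocks_before_hook` is hypothetical and not part of transformer_heads.

```python
def blocks_before_hook(layer_hook: int, num_blocks: int) -> int:
    """Number of transformer blocks applied to the hidden state selected by `layer_hook`.

    Hypothetical helper. Assumes num_blocks + 1 hidden states, as in Hugging Face's
    `output_hidden_states`: entry 0 is the embedding output, entry i is the output of block i.
    """
    num_states = num_blocks + 1
    if not -num_states <= layer_hook < num_states:
        raise IndexError(f"layer_hook {layer_hook} is out of range for {num_blocks} blocks")
    # Python's modulo maps a negative index to the equivalent positive one.
    return layer_hook % num_states


print(blocks_before_hook(-4, 12))  # GPT-2 (12 blocks) -> 9
print(blocks_before_hook(-1, 12))  # GPT-2, last block -> 12
print(blocks_before_hook(-4, 32))  # Llama-2-7b (32 blocks) -> 29
```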

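We also set *num_layers* to 1 to make sure that we actually train a *linear* probe and not an MLP. With *is_causal_lm* we specify the type of task the head is supposed to learn: causal language modelling, i.e. predicting the next token at every position. Accordingly, *num_outputs* is the vocabulary size and the loss is cross-entropy.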
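Conceptually, with `num_layers=1`, `output_bias=False` and a linear output activation, the head reduces to a single bias-free matrix multiplication from the hidden size to the vocabulary size. The sketch below illustrates this in plain PyTorch with toy sizes; the class `LinearProbe` is hypothetical and not transformer_heads' internal implementation. It also shows that such a probe has exactly `hidden_size * vocab_size` parameters.

```python
import torch
from torch import nn


class LinearProbe(nn.Module):
    """Hypothetical stand-in for a single-layer, bias-free probing head."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lin = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, hidden_size) -> (batch, seq_len, vocab_size) logits
        return self.lin(hidden_states)


probe = LinearProbe(hidden_size=16, vocab_size=100)
hidden = torch.randn(2, 5, 16)  # stand-in for hidden states read at the hooked layer
logits = probe(hidden)
num_params = sum(p.numel() for p in probe.parameters())
print(logits.shape, num_params)  # torch.Size([2, 5, 100]) 1600
```

With the real sizes of Llama-2-7b (4096 × 32000), the same formula gives 131,072,000 parameters.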

In [5]:
heads_configs = [
    HeadConfig(
        name="wikitext_head",
        layer_hook=-4,  # Hook to layer [-4] (Drop 3 layers from the end)
        in_size=hidden_size,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
    )
]

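Now we load and format our dataset. The dataset needs to store labels for each head that we want to train, under that head's name. For causal language modelling, these labels are simply copies of the input_ids: the one-position shift that turns them into next-token targets is applied when the loss is computed, as with Hugging Face's own causal LM models.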
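Copying the input IDs works because a causal LM loss scores the prediction made at position t against the label at position t + 1. A minimal pure-Python sketch of that alignment follows; the helper `next_token_pairs` is hypothetical, and we assume transformer_heads' `is_causal_lm` heads shift labels the same way Hugging Face's `*ForCausalLM` models do. It also skips the label value -100, which the collator uses for padding later on.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def next_token_pairs(input_ids: list[int], labels: list[int]) -> list[tuple[int, int]]:
    """Return the (context token, target token) pairs a causal LM loss is computed on.

    The prediction at position t is compared with labels[t + 1], so the last
    position has no target. Pairs whose target is IGNORE_INDEX are dropped.
    """
    pairs = []
    for t in range(len(input_ids) - 1):
        target = labels[t + 1]
        if target != IGNORE_INDEX:
            pairs.append((input_ids[t], target))
    return pairs


ids = [11, 22, 33, 44]
labels = ids.copy()  # what tokenize_function stores under the head's name
print(next_token_pairs(ids, labels))  # [(11, 22), (22, 33), (33, 44)]
print(next_token_pairs(ids + [0], labels + [IGNORE_INDEX]))  # padded position is ignored
```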

In [6]:
dd = load_dataset("wikitext", "wikitext-2-v1")

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    out = tokenizer(examples["text"], padding=False)
    out[heads_configs[0].name] = out["input_ids"].copy()
    return out


for split in dd.keys():
    dd[split] = dd[split].map(tokenize_function, batched=True)
dd.set_format(
    type="torch", columns=["input_ids", "attention_mask", heads_configs[0].name]
)
for split in dd.keys():
    dd[split] = dd[split].remove_columns("text")

In [8]:
dd["train"]

Dataset({
    features: ['input_ids', 'attention_mask', 'wikitext_head'],
    num_rows: 36718
})

Now it is time to load our model. The load_headed function from transformer_heads loads the frozen base model together with the heads we configured. To save GPU memory, we load the model in a 4-bit quantized state.
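A back-of-the-envelope estimate shows why quantization matters for a 7B model. The sketch counts weight storage only, for Llama-2-7b's 6,738,415,616 parameters, and ignores activations, gradients, optimizer state and the small overhead of 4-bit quantization constants. It also pretends that every weight is quantized, whereas in this run some tensors, such as the token embeddings, stay in float32.

```python
def weight_memory_gib(num_params: int, bits_per_param: float) -> float:
    """Approximate memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 2**30


LLAMA2_7B_PARAMS = 6_738_415_616

for name, bits in [("float32", 32), ("float16", 16), ("4-bit (nf4)", 4)]:
    print(f"{name:>12}: {weight_memory_gib(LLAMA2_7B_PARAMS, bits):5.1f} GiB")
```

Going from float32 to 4 bits cuts the weight storage by a factor of eight, which is what makes a 7B model fit on a single consumer-sized GPU.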

In [9]:
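quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # store the frozen base model's linear layers in 4 bit
    load_in_8bit=False,
    llm_int8_threshold=6.0,  # only used for 8-bit loading
    llm_int8_has_fp16_weight=False,  # only used for 8-bit loading
    bnb_4bit_compute_dtype=torch.float32,  # dtype of the matmuls on dequantized weights
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",  # 4-bit NormalFloat data type
)

model = load_headed(
    model_class,
    model_path,
    head_configs=heads_configs,
    quantization_config=quantization_config,
    device_map={"": torch.cuda.current_device()},  # place the whole model on the current GPU
)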

Some weights of TransformerWithHeads were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['heads.wikitext_head.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The warning about weights that were not initialized from the model checkpoint is exactly what we want to see: the linear probe (heads.wikitext_head.lins.0.weight) is newly initialized, because it is not part of the pretrained checkpoint.
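To double-check that only the newly initialized probe will receive gradient updates, you can list the parameters that still have `requires_grad=True`. The sketch demonstrates the pattern on a toy PyTorch model with a frozen "backbone" and a trainable "head" (the module names are made up); on the headed model from this notebook we would expect the list to contain only the probe weight named in the warning above.

```python
from torch import nn


def trainable_parameter_names(model: nn.Module) -> list[str]:
    """Names of all parameters that will receive gradients."""
    return [name for name, param in model.named_parameters() if param.requires_grad]


# Toy model: a frozen "backbone" followed by a trainable, bias-free linear "head".
toy = nn.ModuleDict({"backbone": nn.Linear(8, 8), "head": nn.Linear(8, 4, bias=False)})
for param in toy["backbone"].parameters():
    param.requires_grad_(False)  # freeze everything except the head

print(trainable_parameter_names(toy))  # ['head.weight']
```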

Let's inspect the parameter counts of our model using the convenience function *print_trainable_parameters*.

In [10]:
print_trainable_parameters(model)

all params: 3500412928 || trainable params: 131072000 || trainable%: 3.7444725149866662
params by dtype: defaultdict(<class 'int'>, {torch.float32: 262410240, torch.uint8: 3238002688})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 131072000})


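Every parameter in the model that is not part of the linear probe is frozen (has requires_grad set to False). Our single probe has hidden_size × vocab_size = 4096 × 32000 = 131,072,000 parameters, which is a lot for one linear layer. For a small model with a large vocabulary such as GPT-2 (768 × 50257 ≈ 38.6M probe parameters versus roughly 124M in the whole model), the probe is even a sizeable fraction of the total.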
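The printed counts can be reproduced from Llama-2-7b's published dimensions (hidden size 4096, MLP size 11008, 32 blocks, vocabulary 32000). The sketch assumes that `print_trainable_parameters` counts stored tensor elements and that bitsandbytes packs two 4-bit weights into each `uint8` element. Under these assumptions the numbers match the output exactly, which means the reported "all params" understates the model size and "trainable%" overstates the probe's share. The original `lm_head` apparently does not appear in either count.

```python
HIDDEN, MLP, LAYERS, VOCAB = 4096, 11008, 32, 32000  # Llama-2-7b dimensions

# Weights inside each block that bitsandbytes quantizes (all bias-free linear layers).
attention = 4 * HIDDEN * HIDDEN  # q, k, v and o projections
mlp = 3 * HIDDEN * MLP  # gate, up and down projections
quantized_weights = LAYERS * (attention + mlp)

# Tensors that stay in float32: token embeddings, RMSNorm weights and the probe.
embeddings = VOCAB * HIDDEN
norms = HIDDEN * (2 * LAYERS + 1)  # two norms per block plus the final norm
probe = HIDDEN * VOCAB  # num_outputs=vocab_size, output_bias=False
float32_params = embeddings + norms + probe

uint8_elements = quantized_weights // 2  # assumption: two 4-bit weights per uint8
print(uint8_elements, float32_params, probe)  # 3238002688 262410240 131072000

true_total = quantized_weights + float32_params
print(f"trainable share of all weights: {100 * probe / true_total:.2f}%")  # 1.95%
```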

Let's see how our linear probe performs before it is trained.

In [11]:
print(
    get_top_n_preds(
        n=5, model=model, text="The historical significance of", tokenizer=tokenizer
    )
)

{'wikitext_head': ['TAC', 'Украї', 'utf', 'Mik', 'Lond']}


As expected, the untrained probe's predictions are essentially random.
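A top-n prediction like the one above boils down to taking the logits at the last position of the prompt and picking the n largest values. The sketch shows that idea with a toy vocabulary and hand-written logits; the helper `top_n_tokens` is hypothetical, and the real `get_top_n_preds` additionally tokenizes the text and runs the model.

```python
import torch


def top_n_tokens(logits: torch.Tensor, vocab: list[str], n: int) -> list[str]:
    """Return the n most likely next tokens given logits of shape (seq_len, vocab_size)."""
    last_position = logits[-1]  # only the last position predicts the token after the prompt
    top_ids = torch.topk(last_position, k=n).indices  # sorted from most to least likely
    return [vocab[i] for i in top_ids.tolist()]


vocab = ["the", "these", "church", "those", "utf"]
logits = torch.tensor([
    [0.1, 0.0, 0.3, 0.2, 0.9],  # position 0 (ignored)
    [2.0, 1.5, 0.5, 1.0, -1.0],  # last position
])
print(top_n_tokens(logits, vocab, n=3))  # ['the', 'these', 'those']
```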

Let's now train the linear probe using Hugging Face's easy-to-use Trainer class. Note that we use a custom collator here, which correctly pads the labels stored under the names from heads_configs.
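The per-feature padding values matter: input IDs are padded with the pad token, the attention mask with 0 so padded positions are not attended to, and the labels with -100, the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not contribute to the loss. Below is a hypothetical pure-Python re-implementation of that idea (not the library's code, and assuming right-side padding); the value 0 stands in for `tokenizer.pad_token_id`.

```python
def pad_batch(
    features: list[dict[str, list[int]]], padding_values: dict[str, int]
) -> dict[str, list[list[int]]]:
    """Right-pad every feature to the longest sequence in the batch.

    Each feature gets its own padding value, mirroring the idea behind
    `feature_name_to_padding_value`.
    """
    batch = {}
    for name, pad_value in padding_values.items():
        sequences = [example[name] for example in features]
        max_len = max(len(seq) for seq in sequences)
        batch[name] = [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return batch


examples = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "wikitext_head": [5, 6, 7]},
    {"input_ids": [8], "attention_mask": [1], "wikitext_head": [8]},
]
padded = pad_batch(examples, {"input_ids": 0, "attention_mask": 0, "wikitext_head": -100})
print(padded["wikitext_head"])  # [[5, 6, 7], [8, -100, -100]]
```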

In [12]:
args = TrainingArguments(
    output_dir="linear_probe_test",
    learning_rate=0.0002,
    num_train_epochs=train_epochs,
    logging_steps=logging_steps,
    do_eval=False,
    remove_unused_columns=False,  # Important: otherwise Trainer drops columns whose names are not forward() arguments, which can remove the head's label column
)
collator = DataCollatorWithPadding(
    feature_name_to_padding_value={
        "input_ids": tokenizer.pad_token_id,
        heads_configs[0].name: -100,
        "attention_mask": 0,
    }
)
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...

| Step | Training Loss |
|-----:|--------------:|
| 40 | 10.2816 |
| 80 | 7.2887 |
| 120 | 6.33 |
| 160 | 5.6379 |
| 200 | 5.4088 |
| 240 | 5.1548 |
| 280 | 4.8757 |
| 320 | 4.9828 |
| 360 | 4.8183 |
| 400 | 4.6848 |


TrainOutput(global_step=459, training_loss=5.78458249958512, metrics={'train_runtime': 2347.1457, 'train_samples_per_second': 1.564, 'train_steps_per_second': 0.196, 'total_flos': 3.79067003362345e+16, 'train_loss': 5.78458249958512, 'epoch': 0.1})
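The `global_step=459` in this output can be reproduced from the fractional epoch setting. The sketch assumes the `TrainingArguments` default of `per_device_train_batch_size=8`, a single GPU and no gradient accumulation, and rounds up the way the Trainer does when it converts a fractional `num_train_epochs` into a step count.

```python
import math

num_train_examples = 36_718  # dd["train"].num_rows
batch_size = 8  # assumption: TrainingArguments default per_device_train_batch_size
train_epochs = 0.1

steps_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = math.ceil(train_epochs * steps_per_epoch)
print(steps_per_epoch, total_steps)  # 4590 459
```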

This is encouraging: the training loss drops from 10.28 to about 4.7, so the probe is clearly learning something. But how does it perform on the validation set?
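A useful reference point for these numbers: a predictor that assigns equal probability to all 32,000 tokens of the Llama-2 vocabulary has a cross-entropy of ln(32000) ≈ 10.37 nats. The average loss over the first 40 steps (10.28) is therefore close to chance level, and the drop to about 4.7 is substantial. A quick check, assuming the loss is reported in nats (natural log), as PyTorch's cross-entropy does:

```python
import math


def uniform_cross_entropy(vocab_size: int) -> float:
    """Cross-entropy (in nats) of a predictor that spreads probability evenly over the vocab."""
    return -math.log(1 / vocab_size)


print(round(uniform_cross_entropy(32_000), 2))  # Llama-2 vocabulary -> 10.37
print(round(uniform_cross_entropy(50_257), 2))  # GPT-2 vocabulary -> 10.82
```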

In [13]:
print(evaluate_head_wise(model, dd["validation"], collator, epochs=eval_epochs))

(4.309780701994896, {'wikitext_head': 4.309780701994896})





The validation loss (4.31) is in the same range as, and even slightly below, the last logged training losses (around 4.7 to 5.0), so there is no sign of overfitting. How do things look in a practical example?
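Cross-entropy losses are often easier to interpret as perplexity, exp(loss), which is roughly the number of tokens the model is effectively choosing between at each step. The sketch converts the validation loss, treating the reported number as a mean per-token cross-entropy in nats.

```python
import math


def perplexity(mean_cross_entropy: float) -> float:
    """Perplexity corresponding to a mean per-token cross-entropy in nats."""
    return math.exp(mean_cross_entropy)


print(f"{perplexity(4.309780701994896):.1f}")  # validation loss -> 74.4
```

For comparison, a uniform guess over Llama-2's 32,000-token vocabulary has a perplexity of 32,000.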

In [14]:
print(get_top_n_preds(5, model, "The historical significance of", tokenizer))

{'wikitext_head': ['the', '', 'these', 'church', 'those']}


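The linear probe has learned to predict tokens that plausibly follow the prompt ('the', 'these', 'those', 'church'). Since the probe only sees the hidden state before the last three transformer blocks, this suggests that an intermediate layer already carries much of the information needed for next-token prediction.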