# Linear probing of sentiment classification in a transformer trained for causal language modeling
In this notebook we'll try to find out whether, and how, a transformer trained for causal language modeling processes information about the sentiment of a sentence. We'll use the IMDb movie-review dataset for this.

In [1]:
from transformer_heads import load_headed
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    GPT2Model,
    GPT2LMHeadModel,
)
from transformer_heads.util.helpers import DataCollatorWithPadding, get_model_params
from peft import LoraConfig
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import (
    evaluate_head_wise,
    get_top_n_preds,
    get_some_preds,
)
import torch
import pandas as pd

In [None]:
model_path = "gpt2"
train_epochs = 1
eval_epochs = 1
logging_steps = 100
full_finetuning = False
num_heads = 6

In [None]:
model_params = get_model_params(model_path)
model_class = model_params["model_class"]
hidden_size = model_params["hidden_size"]
vocab_size = model_params["vocab_size"]
print(model_params)

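We are doing sequence-level text classification, so we set *pred_for_sequence* to True: each head makes one prediction per review rather than one per token. The imdb dataset has only two labels, 0 for negative and 1 for positive, so we set *num_outputs* to 2. We attach one such head to every second layer, counting back from the last one (*layer_hook* = -1, -3, ..., -13), so we can compare how well sentiment can be read out at different depths.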
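To make these two settings concrete, here is a minimal numpy sketch of what a sequence-level linear probe computes. It assumes, without checking the *transformer_heads* source, that a head with *pred_for_sequence* reads the hidden state of the last non-padding token (with causal attention, the only position that has seen the whole review) and that batches are right-padded. *last_token_states* and *linear_probe_loss* are hypothetical names used only for illustration.

```python
import math

import numpy as np


def last_token_states(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Return the hidden state of the last real token of each sequence.

    hidden: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len) of 0/1.
    Assumes right padding, i.e. real tokens come first and padding after.
    """
    last_idx = attention_mask.sum(axis=1) - 1
    return hidden[np.arange(hidden.shape[0]), last_idx]


def linear_probe_loss(hidden, attention_mask, labels, weight, bias):
    """Mean cross-entropy of a linear probe; weight has shape (hidden_size, num_outputs)."""
    feats = last_token_states(hidden, attention_mask)  # (batch, hidden_size)
    logits = feats @ weight + bias  # (batch, num_outputs)
    logits = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


# Toy batch: two reviews of length 3 and 5, GPT-2-sized hidden states, 2 classes.
rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 5, 768))
mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
labels = np.array([0, 1])

# An all-zero probe predicts 50/50 for every review, so its loss is ln 2.
zero_loss = linear_probe_loss(hidden, mask, labels, np.zeros((768, 2)), np.zeros(2))
print(zero_loss, math.log(2))
```

Only the weight matrix and bias of the probe would be trained; the hidden states come from the frozen base model.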

In [3]:
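head_configs = [
    HeadConfig(
        # One linear probe on every second layer, counting back from the last:
        # imdb_head_1, imdb_head_3, ..., imdb_head_13 with layer_hook -1, -3, ..., -13.
        name=f"imdb_head_{1 + 2 * i}",
        layer_hook=-(1 + 2 * i),
        in_size=hidden_size,
        output_activation="linear",
        pred_for_sequence=True,  # one prediction per review, not one per token
        loss_fct="cross_entropy",
        num_outputs=2,  # 0 = negative, 1 = positive
    )
    for i in range(num_heads + 1)  # builds num_heads + 1 = 7 heads
]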
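The list comprehension builds *num_heads* + 1 = 7 heads. Which layer each one probes depends on how *transformer_heads* resolves *layer_hook*; the sketch below assumes it indexes the tuple of hidden states that Hugging Face models return with `output_hidden_states=True`, which for 12-layer GPT-2 has 13 entries (the embedding output followed by one entry per transformer block). *probe_layers* is a hypothetical helper.

```python
N_LAYER = 12  # number of transformer blocks in GPT-2 small
num_heads = 6  # same value as in the notebook


def probe_layers(num_heads: int, n_layer: int = N_LAYER) -> dict:
    """Map each head name to the non-negative hidden-state index its layer_hook selects.

    Assumes hidden_states has n_layer + 1 entries: index 0 is the embedding
    output and index k is the output of block k.
    """
    n_states = n_layer + 1
    mapping = {}
    for i in range(num_heads + 1):  # the notebook creates num_heads + 1 heads
        layer_hook = -(1 + 2 * i)
        mapping[f"imdb_head_{1 + 2 * i}"] = n_states + layer_hook
    return mapping


for name, idx in probe_layers(num_heads).items():
    print(name, "->", "embedding output" if idx == 0 else f"output of block {idx}")
```

Under this assumption, *imdb_head_13* sees only the token embeddings, before any attention layer has mixed information across the review.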

In [4]:
dd = load_dataset("imdb")

In [5]:
dd["test"][0]

{'text': 'I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\'t match the background, and painfully one-dimensional characters cannot be overcome with a \'sci-fi\' setting. (I\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\'s not. It\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\'s rubbish as 

In *tokenize_function*, we copy the dataset's *label* column into one column per head, named after that head, so every probe is trained on the same sentiment labels.
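With `batched=True`, *map* passes *tokenize_function* a dict of lists, so each head column ends up as a copy of the *label* list. The toy run below swaps the GPT-2 tokenizer for a hypothetical whitespace tokenizer (*toy_tokenizer*) so that it runs without downloading anything; only the label-copying logic mirrors the notebook.

```python
def toy_tokenizer(texts):
    """Hypothetical stand-in for the Hugging Face tokenizer: one fake id per word."""
    input_ids = [[len(word) for word in text.split()] for text in texts]
    return {
        "input_ids": input_ids,
        "attention_mask": [[1] * len(ids) for ids in input_ids],
    }


head_names = [f"imdb_head_{1 + 2 * i}" for i in range(7)]  # head names used in the notebook


def tokenize_function(examples):
    out = toy_tokenizer(examples["text"])
    for name in head_names:
        out[name] = examples["label"]  # every probe is trained on the same labels
    return out


batch = {"text": ["A wonderful film", "Dull and far too long"], "label": [1, 0]}
out = tokenize_function(batch)
print(out["input_ids"], out["imdb_head_1"], out["imdb_head_13"])
```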

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    out = tokenizer(examples["text"], padding=False, truncation=True)
    for hc in head_configs:
        out[hc.name] = examples["label"]
    return out


for split in dd.keys():
    dd[split] = dd[split].filter(function=lambda example: len(example["text"]) > 10)
    dd[split] = dd[split].shuffle()
    dd[split] = dd[split].map(tokenize_function, batched=True)

dd.set_format(
    type="torch",
    columns=["input_ids", "attention_mask"] + [x.name for x in head_configs],
)
for split in dd.keys():
    dd[split] = dd[split].remove_columns(["text", "label"])


In [7]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = load_headed(
    model_class,
    model_path,
    head_configs=head_configs,
    quantization_config=None if full_finetuning else quantization_config,
    freeze_base_model=not full_finetuning,
    device_map={"": torch.cuda.current_device()},
)

Some weights of TransformerWithHeads were not initialized from the model checkpoint at gpt2 and are newly initialized: ['heads.imdb_head_1.lins.0.weight', 'heads.imdb_head_11.lins.0.weight', 'heads.imdb_head_13.lins.0.weight', 'heads.imdb_head_3.lins.0.weight', 'heads.imdb_head_5.lins.0.weight', 'heads.imdb_head_7.lins.0.weight', 'heads.imdb_head_9.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
print_trainable_parameters(model)

all params: 81983232 || trainable params: 10752 || trainable%: 0.01311487695435086
params by dtype: defaultdict(<class 'int'>, {torch.float32: 39515904, torch.uint8: 42467328})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 10752})


Each head is a single linear layer with only two outputs, and the base model is frozen, so just 10,752 parameters (about 0.013% of the total) are trainable.
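The trainable count can be checked by hand. GPT-2 small has a hidden size of 768, each of the 7 heads maps that to 2 logits, and the initialization warning above lists only a *lins.0.weight* per head, so this sketch assumes the probes have no bias term.

```python
hidden_size = 768       # GPT-2 small
num_outputs = 2         # negative / positive
num_probe_heads = 7     # imdb_head_1, imdb_head_3, ..., imdb_head_13
all_params = 81983232   # "all params" as printed by print_trainable_parameters

# Assumption: each head is a single bias-free linear layer (only lins.0.weight is listed).
params_per_head = hidden_size * num_outputs
trainable = num_probe_heads * params_per_head
trainable_pct = 100 * trainable / all_params

print(params_per_head, trainable)  # 1536 10752
print(f"{trainable_pct:.4f}%")  # 0.0131%
```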

In [9]:
ins, preds, ground_truths = get_some_preds(
    model, dd["test"], tokenizer, n=5, classification=True
)
print(
    pd.DataFrame(
        list(zip(ins, preds["imdb_head_5"], ground_truths["imdb_head_5"])),
        columns=["review", "label", "ground_truth"],
    )
)

Predicting: 100%|██████████| 5/5 [00:00<00:00, 24.19it/s]

                                              review  label
0  If you repeat a lie enough number of times wil...      1
1  the first time I saw this movie, I just though...      0
2  Having spent all of her money caring for her t...      1
3  This story focuses on the birth defect known a...      1
4  I have viewed this cartoon as a child, a fathe...      1
5  'Sleight of Hand' is my favorite Rockford File...      1





The heads are still untrained, so these predictions are essentially random.

In [10]:
collator = DataCollatorWithPadding(
    feature_name_to_padding_value={
        "input_ids": tokenizer.pad_token_id,
        "attention_mask": 0,
    }
)
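The collator pads every sequence in a batch to the same length, using the pad token for *input_ids* and 0 for *attention_mask* so that padded positions are masked out. Below is a pure-Python sketch of that behavior; *pad_batch* is hypothetical, not the *transformer_heads* implementation, and it assumes right padding and passes the per-review label columns through unchanged.

```python
def pad_batch(features, padding_values):
    """Right-pad the features named in padding_values to the batch's longest sequence.

    features: list of dicts (one per example); padding_values: name -> fill value.
    Features without a padding value (e.g. per-review labels) are passed through.
    """
    batch = {}
    for name in features[0]:
        values = [example[name] for example in features]
        if name in padding_values:
            max_len = max(len(v) for v in values)
            values = [list(v) + [padding_values[name]] * (max_len - len(v)) for v in values]
        batch[name] = values
    return batch


GPT2_EOS_ID = 50256  # GPT-2 has no pad token, so the notebook reuses EOS for padding
examples = [  # arbitrary token ids for illustration
    {"input_ids": [40, 1842, 340], "attention_mask": [1, 1, 1], "imdb_head_1": 1},
    {"input_ids": [35, 377], "attention_mask": [1, 1], "imdb_head_1": 0},
]
padded = pad_batch(examples, {"input_ids": GPT2_EOS_ID, "attention_mask": 0})
print(padded)
```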

In [11]:
print(evaluate_head_wise(model, dd["test"], collator, epochs=eval_epochs))

Evaluating: 100%|██████████| 3125/3125 [05:20<00:00,  9.75it/s]

(7.46833487121582, {'imdb_head_1': 0.8291657885360718, 'imdb_head_3': 1.140816629960537, 'imdb_head_5': 1.822834258645177, 'imdb_head_7': 0.7859305122661591, 'imdb_head_9': 0.9702462933826447, 'imdb_head_11': 1.2253565692448616, 'imdb_head_13': 0.69398482006073})
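For scale: a head that always predicts 50/50 has cross-entropy ln 2 ≈ 0.693, and every untrained head above does worse than that. The first number in the printed tuple matches the sum of the per-head losses. The sketch below recomputes both from the printed values, copied from the output above.

```python
import math

# Per-head evaluation losses of the untrained heads, copied from the output above.
untrained = {
    "imdb_head_1": 0.8291657885360718,
    "imdb_head_3": 1.140816629960537,
    "imdb_head_5": 1.822834258645177,
    "imdb_head_7": 0.7859305122661591,
    "imdb_head_9": 0.9702462933826447,
    "imdb_head_11": 1.2253565692448616,
    "imdb_head_13": 0.69398482006073,
}
printed_total = 7.46833487121582

chance = math.log(2)  # cross-entropy of predicting probability 0.5 for both classes
above_chance = [name for name, loss in untrained.items() if loss > chance]

print(f"chance-level loss: {chance:.4f}")  # 0.6931
print(len(above_chance), "of", len(untrained), "heads are above chance")  # 7 of 7
print(abs(sum(untrained.values()) - printed_total) < 1e-6)  # True
```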





In [12]:
args = TrainingArguments(
    output_dir="imdb_linear_probe",
    learning_rate=0.0002,
    num_train_epochs=train_epochs,  # set train_epochs to 0.1 to speed things up; 1 gives better results
    logging_steps=logging_steps,
    do_eval=False,
    remove_unused_columns=False,
)
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)
trainer.train()



`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss
20,5.6135
40,5.5969
60,5.9386
80,5.4176
100,5.2001
120,5.4372
140,5.0923
160,5.0335
180,5.0779
200,5.0335




TrainOutput(global_step=3125, training_loss=4.580280744628906, metrics={'train_runtime': 1120.7626, 'train_samples_per_second': 22.306, 'train_steps_per_second': 2.788, 'total_flos': 8481206894985216.0, 'train_loss': 4.580280744628906, 'epoch': 1.0})
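The 3,125 optimizer steps in the output follow from the size of the training split. The sketch assumes a single GPU, the *TrainingArguments* default per-device batch size of 8 and no gradient accumulation (none of these are set in the cell above), and rounds fractional epochs up to whole steps. The 3,125 evaluation iterations over the 25,000 test reviews are likewise consistent with batches of 8.

```python
import math

n_train = 25_000            # reviews in the IMDb train split
per_device_batch_size = 8   # TrainingArguments default (assumed unchanged)
num_devices = 1             # assumption: training on a single GPU
grad_accum_steps = 1        # TrainingArguments default


def total_optimizer_steps(num_epochs: float) -> int:
    """Optimizer steps for a run of num_epochs over the training split."""
    batch_per_step = per_device_batch_size * num_devices * grad_accum_steps
    steps_per_epoch = math.ceil(n_train / batch_per_step)
    return math.ceil(num_epochs * steps_per_epoch)


print(total_optimizer_steps(1))    # 3125, matching global_step above
print(total_optimizer_steps(0.1))  # 313 under this sketch's rounding
```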

In [13]:
print(evaluate_head_wise(model, dd["test"], collator, epochs=eval_epochs))

Evaluating: 100%|██████████| 3125/3125 [05:35<00:00,  9.32it/s]

(4.326077914962768, {'imdb_head_1': 0.6165544888973236, 'imdb_head_3': 0.5710003126907348, 'imdb_head_5': 0.5809044548034668, 'imdb_head_7': 0.6037158410215377, 'imdb_head_9': 0.6250181088924408, 'imdb_head_11': 0.6354757823753356, 'imdb_head_13': 0.6934089240074157})
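To read this result, the sketch below ranks the heads by their post-training loss (values copied from the output above) and compares each with the chance level ln 2. *imdb_head_3* reaches the lowest loss, while *imdb_head_13* stays within 0.001 of ln 2, i.e. at chance. The summed loss of 4.326 is also below the 7 ln 2 ≈ 4.852 that seven chance-level heads would give.

```python
import math

# Per-head evaluation losses after training, copied from the output above.
trained = {
    "imdb_head_1": 0.6165544888973236,
    "imdb_head_3": 0.5710003126907348,
    "imdb_head_5": 0.5809044548034668,
    "imdb_head_7": 0.6037158410215377,
    "imdb_head_9": 0.6250181088924408,
    "imdb_head_11": 0.6354757823753356,
    "imdb_head_13": 0.6934089240074157,
}
chance = math.log(2)  # loss of always predicting 50/50

ranking = sorted(trained, key=trained.get)  # lowest loss first
for name in ranking:
    print(f"{name:>13}  loss={trained[name]:.4f}  gain over chance={chance - trained[name]:+.4f}")

print(ranking[0])  # imdb_head_3
print(round(sum(trained.values()), 3), round(7 * chance, 3))  # 4.326 4.852
```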





In [14]:
ins, preds, ground_truths = get_some_preds(
    model, dd["test"], tokenizer, n=5, classification=True
)
print(
    pd.DataFrame(
        list(zip(ins, preds["imdb_head_5"], ground_truths["imdb_head_5"])),
        columns=["review", "label", "ground_truth"],
    )
)

Predicting: 100%|██████████| 5/5 [00:00<00:00, 33.91it/s]

                                              review  label
0  If you repeat a lie enough number of times wil...      0
1  the first time I saw this movie, I just though...      0
2  Having spent all of her money caring for her t...      1
3  This story focuses on the birth defect known a...      1
4  I have viewed this cartoon as a child, a fathe...      0
5  'Sleight of Hand' is my favorite Rockford File...      1



