In [1]:
from transformer_heads import load_headed
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    GPT2Model,
    GPT2LMHeadModel
)
from transformer_heads.util.helpers import DataCollatorWithPadding
from peft import LoraConfig
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import evaluate_head_wise,get_top_n_preds
import torch

In [2]:
# model_class, model_path = MistralForCausalLM, "mistralai/Mistral-7B-v0.1", 4096, 32000
model_class, model_path, hidden_size, vocab_size = GPT2LMHeadModel, "gpt2", 768, 50257

In [3]:
head_configs = [
    HeadConfig(
        name=f"wikitext_head_{i}",
        layer_hook=-i,
        in_size=hidden_size,
        hidden_size=0,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
    ) for i in range(1,6)
]

In [4]:
dd = load_dataset("wikitext", "wikitext-2-v1")

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    out = tokenizer(examples["text"], padding="max_length", truncation=True)
    for hc in head_configs:
        out[hc.name] = out["input_ids"].copy()
    return out


for split in dd.keys():
    dd[split] = dd[split].map(tokenize_function, batched=True)
dd.set_format(
    type="torch", columns=["input_ids", "attention_mask"]+[x.name for x in head_configs]
)
for split in dd.keys():
    dd[split] = dd[split].remove_columns("text")

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [7]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = load_headed(
    model_class,
    model_path,
    head_configs=head_configs,
    quantization_config=quantization_config,
    device_map={"": torch.cuda.current_device()},
)

Some weights of TransformerWithHeads were not initialized from the model checkpoint at gpt2 and are newly initialized: ['heads.wikitext_head_1.lins.0.weight', 'heads.wikitext_head_2.lins.0.weight', 'heads.wikitext_head_3.lins.0.weight', 'heads.wikitext_head_4.lins.0.weight', 'heads.wikitext_head_5.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
print_trainable_parameters(model)

all params: 274959360 || trainable params: 192986880 || trainable%: 70.18741969722362
params by dtype: defaultdict(<class 'int'>, {torch.float32: 232492032, torch.uint8: 42467328})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 192986880})


In [9]:
dd["train"]

Dataset({
    features: ['input_ids', 'attention_mask', 'wikitext_head_1', 'wikitext_head_2', 'wikitext_head_3', 'wikitext_head_4', 'wikitext_head_5'],
    num_rows: 36718
})

In [10]:
print(get_top_n_preds(5,model,"The historical significance of",tokenizer))

{'wikitext_head_1': ['dule', 'usp', ' Mars', ' TY', '),'], 'wikitext_head_2': ['iona', ' Ben', ' Farm', 'fits', 'FTWARE'], 'wikitext_head_3': ['trump', ' slit', 'abuse', ' frightened', ' Madonna'], 'wikitext_head_4': ['Cos', 'aniel', '644', 'Rect', '\x1d'], 'wikitext_head_5': [' SpaceX', 'TT', ' AUTHOR', ' orbiting', 'rait']}


In [11]:
args = TrainingArguments(
    output_dir="linear_probe_test",
    learning_rate=0.0002,
    num_train_epochs=0.05,  # To speed things up set to 0.1, set to 1 for better performance
    logging_steps=20,
    do_eval=False,
    remove_unused_columns=False,
)
collator = DataCollatorWithPadding(
    feature_name_to_padding_value={
        "input_ids": tokenizer.pad_token_id,
        "attention_mask": 0,
        **{key:-100 for key in head_configs}
    }
)
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)
trainer.train()

NameError: name 'heads_configs' is not defined

In [None]:
print(evaluate_head_wise(model, dd["validation"], collator, epochs=0.1))

In [None]:
print(get_top_n_preds(5,model,"The historical significance of",tokenizer))