In [1]:
from transformer_heads import load_headed
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorWithPadding,
)
from peft import LoraConfig
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
import torch

In [2]:
model_class, model_path = MistralForCausalLM, "mistralai/Mistral-7B-v0.1"

In [3]:
heads_configs = [
    HeadConfig(
        name="wikitext_head",
        layer_hook=-4,  # Hook to layer [-4] (Drop 3 layers from the end)
        in_size=4096,
        hidden_size=0,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=32000,
        is_regression=False,
        output_bias=False,
    )
]

In [4]:
dd = load_dataset("wikitext", "wikitext-2-v1")

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    out = tokenizer(examples["text"], padding="max_length", truncation=True)
    out[heads_configs[0].name] = out["input_ids"].copy()
    return out


for split in dd.keys():
    dd[split] = dd[split].map(tokenize_function, batched=True)
dd.set_format(
    type="torch", columns=["input_ids", "attention_mask", heads_configs[0].name]
)
for split in dd.keys():
    dd[split] = dd[split].remove_columns("text")

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [6]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = load_headed(
    model_class,
    model_path,
    head_configs=heads_configs,
    quantization_config=quantization_config,
    device_map={"": torch.cuda.current_device()},
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of TransformerWithHeads were not initialized from the model checkpoint at mistralai/Mistral-7B-v0.1 and are newly initialized: ['heads.wikitext_head.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
print_trainable_parameters(model)

all params: 3752071168 || trainable params: 131072000 || trainable%: 3.493323930469754
params by dtype: defaultdict(<class 'int'>, {torch.float32: 262410240, torch.uint8: 3489660928})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 131072000})


In [8]:
dd["train"]

Dataset({
    features: ['input_ids', 'attention_mask', 'wikitext_head'],
    num_rows: 36718
})

In [9]:
args = TrainingArguments(
    output_dir="linear_probe_test",
    learning_rate=0.0002,
    num_train_epochs=0.1,
    logging_steps=5,
    do_eval=False,
    remove_unused_columns=False,
)
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    eval_dataset=dd["validation"],
    data_collator=DataCollatorWithPadding(tokenizer),
)
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mykeller[0m ([33mchm-hci[0m). Use [1m`wandb login --relogin`[0m to force relogin


[{'input_ids': tensor([    1, 28705,   327,   327,   327,  9444, 19886,   327,   327,   327,
        28705,    13]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'wikitext_head': tensor([    1, 28705,   327,   327,   327,  9444, 19886,   327,   327,   327,
        28705,    13])}, {'input_ids': tensor([1]), 'attention_mask': tensor([1]), 'wikitext_head': tensor([1])}, {'input_ids': tensor([    1, 28705,   415,  5189, 13316,   884,   604,   354,   272, 28705,
            0, 28705, 10746,  5212, 28705, 28770, 28783, 28705,     0, 28705,
         8996,   304, 28705, 28740,  4949,   802, 28733, 28818, 10435,  5582,
          304, 28705, 28740, 28750, 28782, 28705,     0, 28705,   304, 28705,
        28750,  4949,   802, 28733, 28818,  8544,  3693, 17084,   842,   330,
         7681,   604,   354,   272,  5567, 10746,   659, 28705, 28740, 28734,
         5582,   304, 28705, 28740, 28781, 17084,  1200,   354,   264,  3487,
         3102,   302, 28705, 28781, 28774,  5582, 

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`wikitext_head` in this case) have excessive nesting (inputs type `list` where type `int` is expected).