<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [12]</a>'.</span>

# Joint multitask

The point of this notebook is not to do anything useful, but to show what is possible with relatively little effort using the *transformer_heads* library. In this example, we train four heads on a transformer model while using QLoRA to finetune the transformer block weights. Layer hooks are negative indices counted from the last layer; the layer numbers in parentheses are those of the 12-layer GPT-2 model. The first head is hooked at layer -4 (9) and predicts the sentiment of IMDb reviews (text classification). The second head is hooked at the last layer (-1, i.e. 12) and does causal language modelling on IMDb reviews. The third head is hooked at layer -7 (6) and learns to count the occurrences of each letter of the alphabet in IMDb reviews (text-level regression). The final head is also hooked at layer -7 (6) and predicts, for each token of a review, how many tokens follow before the review ends (token-level regression). Unlike the others, the final head is a small MLP instead of a linear head.

All trainable heads and the QLoRA adapter parameters are trained jointly (multi-task learning). The model's original LM head is also attached, but kept frozen for comparison.
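Because hooks are given as negative indices, the same configuration lands on different absolute layers depending on model depth. Note that the parameters cell below overrides `model_path` with `mistralai/Mistral-7B-v0.1`, which has 32 layers rather than GPT-2's 12. A minimal sketch, assuming `layer_hook=-k` means the k-th layer from the end (1-indexed, matching the GPT-2 numbers above); `absolute_layer` is a hypothetical helper, not part of *transformer_heads*.

```python
def absolute_layer(layer_hook: int, num_layers: int) -> int:
    """Map a negative layer hook to a 1-indexed layer number (hypothetical helper)."""
    if not -num_layers <= layer_hook <= -1:
        raise ValueError(f"layer_hook must be in [-{num_layers}, -1]")
    return num_layers + layer_hook + 1


# Hooks used by the trainable heads in this notebook
hooks = {
    "sentiment_head": -4,
    "causal_lm": -1,
    "alphabet_regression": -7,
    "num_tokens_regression": -7,
}
for name, hook in hooks.items():
    # GPT-2 has 12 transformer layers, Mistral-7B has 32
    print(f"{name}: GPT-2 layer {absolute_layer(hook, 12)}, "
          f"Mistral-7B layer {absolute_layer(hook, 32)}")
```

With Mistral-7B, the sentiment head therefore reads layer 29 and both regression heads read layer 26, assuming the indexing convention above.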

In [1]:
from transformer_heads import create_headed_qlora, load_lora_with_heads
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    GPT2Model,
    GPT2LMHeadModel,
)
from transformer_heads.util.helpers import DataCollatorWithPadding, get_model_params
from peft import LoraConfig
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import (
    evaluate_head_wise,
)
import torch
import pandas as pd

In [2]:
model_path = "gpt2"
train_epochs = 1
eval_epochs = 1
logging_steps = 100

In [3]:
# Parameters
model_path = "mistralai/Mistral-7B-v0.1"
train_epochs = 1
eval_epochs = 1
logging_steps = 100


In [4]:
model_params = get_model_params(model_path)
model_class = model_params["model_class"]
hidden_size = model_params["hidden_size"]
vocab_size = model_params["vocab_size"]
print(model_params)

{'model_class': <class 'transformers.models.mistral.modeling_mistral.MistralForCausalLM'>, 'hidden_size': 4096, 'vocab_size': 32000}


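Next, define the heads. The heads use different loss functions (cross-entropy for classification and language modelling, MSE for regression), and their labels differ by orders of magnitude (binary sentiment labels versus letter and token counts that can run into the hundreds or thousands). It is therefore important to weight each head's loss via *loss_weight* so that all tasks are given similar importance during training.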
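To illustrate why the weights span several orders of magnitude, here is a minimal sketch assuming the joint objective is the weighted sum of the per-head losses, `sum_i w_i * L_i`. The weights are the ones from the configs below; the per-head loss values are hypothetical numbers chosen only to show the effect of the weights, not measurements from this notebook.

```python
# Loss weights taken from the head configs in this notebook
loss_weights = {
    "sentiment_head": 2.0,
    "causal_lm": 1.0,
    "alphabet_regression": 0.002,
    "num_tokens_regression": 0.0002,
}


def combined_loss(head_losses: dict, weights: dict) -> float:
    """Weighted sum of per-head losses (assumed combination rule)."""
    return sum(weights[name] * loss for name, loss in head_losses.items())


# Hypothetical raw losses: MSE on counts is far larger than cross-entropy
example_losses = {
    "sentiment_head": 0.7,
    "causal_lm": 2.5,
    "alphabet_regression": 1500.0,
    "num_tokens_regression": 20000.0,
}

weighted = {n: loss_weights[n] * l for n, l in example_losses.items()}
print(weighted)
print(combined_loss(example_losses, loss_weights))
```

Without the weights, the regression heads would dominate the gradient; with them, every term ends up in a similar range.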

In [5]:
head_configs = [
    HeadConfig(
        name=f"sentiment_head",
        layer_hook=-4,
        in_size=hidden_size,
        output_activation="linear",
        pred_for_sequence=True,
        loss_fct="cross_entropy",
        num_outputs=2,
        loss_weight=2.0,
    ),
    HeadConfig(
        name=f"causal_lm",
        layer_hook=-1,
        in_size=hidden_size,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
        loss_weight=1.0,
    ),
    HeadConfig(
        name=f"alphabet_regression",
        layer_hook=-7,
        in_size=hidden_size,
        output_activation="linear",
        is_causal_lm=False,
        pred_for_sequence=True,
        loss_fct="mse",
        num_outputs=26,  # 26 letters in the alphabet
        is_regression=True,
        loss_weight=0.002,
    ),
    HeadConfig(
        name=f"num_tokens_regression",
        layer_hook=-7,
        hidden_size=128,  # MLP hidden size
        num_layers=3,  # 2 hidden layers in MLP
        in_size=hidden_size,
        output_activation="linear",
        is_causal_lm=False,
        pred_for_sequence=False,
        loss_fct="mse",
        num_outputs=1,
        is_regression=True,
        loss_weight=0.0002,
    ),
    HeadConfig(
        name=f"lm_head",  # Let's also keep the original lm head for comparison
        layer_hook=-1,
        in_size=hidden_size,
        output_activation="linear",
        is_causal_lm=True,
        pred_for_sequence=False,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        trainable=False,  # Keep it in its pretrained state
    ),
]

In [6]:
dd = load_dataset("imdb")

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def processing_function(examples):
    out = tokenizer(examples["text"], padding=False, truncation=True)
    out["sentiment_head"] = examples["label"]
    out["causal_lm"] = out["lm_head"] = out["input_ids"].copy()
    out["num_tokens_regression"] = [
        list(map(float, range(len(ids) - 1, -1, -1))) for ids in out["input_ids"]
    ]
    out["alphabet_regression"] = [
        [
            float(text.count(x) + text.count(x.upper()))
            for x in "abcdefghijklmnopqrstuvwxyz"
        ]
        for text in examples["text"]
    ]
    return out


for split in dd.keys():
    dd[split] = dd[split].filter(function=lambda example: len(example["text"]) > 10)
    dd[split] = dd[split].shuffle()
    dd[split] = dd[split].map(processing_function, batched=True)

dd.set_format(
    type="torch",
    columns=["input_ids", "attention_mask"] + [x.name for x in head_configs],
)
for split in dd.keys():
    dd[split] = dd[split].remove_columns(["text", "label"])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.

In [8]:
dd["test"]

Dataset({
    features: ['input_ids', 'attention_mask', 'sentiment_head', 'causal_lm', 'lm_head', 'num_tokens_regression', 'alphabet_regression'],
    num_rows: 25000
})
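The regression labels produced by *processing_function* can be illustrated without a tokenizer. A minimal sketch, using whitespace splitting as a hypothetical stand-in for the real tokenizer: the token-level target counts down how many tokens remain after each position, and the text-level target counts each letter case-insensitively, as in the cell above. `make_labels` is a hypothetical helper.

```python
import string


def make_labels(text: str) -> dict:
    # Hypothetical stand-in for the tokenizer: split on whitespace
    tokens = text.split()
    return {
        # For each position: how many tokens still follow before the end
        "num_tokens_regression": [
            float(i) for i in range(len(tokens) - 1, -1, -1)
        ],
        # Case-insensitive count of each letter a-z in the raw text
        "alphabet_regression": [
            float(text.count(c) + text.count(c.upper()))
            for c in string.ascii_lowercase
        ],
    }


labels = make_labels("A great movie")
print(labels["num_tokens_regression"])   # [2.0, 1.0, 0.0]
print(labels["alphabet_regression"][0])  # 2.0 ('A' once, 'a' once)
```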

Setting *target_modules=None* in the LoRA config makes *create_headed_qlora* create LoRA modules for all linear layers in the transformer.
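Each LoRA module leaves the frozen (here 4-bit quantized) weight `W` untouched and adds a trainable low-rank update, computing `W x + (lora_alpha / r) * B (A x)`, where `A` has shape r × in and `B` has shape out × r. In PEFT, `B` is zero-initialized by default, so training starts from the pretrained behaviour. A minimal pure-Python sketch of this forward pass, with toy rank-1 matrices chosen only for illustration and the scaling 16 / 32 = 0.5 implied by the `lora_config` in the next cell; `lora_forward` is a hypothetical helper, not part of PEFT or *transformer_heads*.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, scaling):
    """y = W x + scaling * B (A x); W is frozen, only A and B are trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + scaling * u for b, u in zip(base, update)]


# Scaling from the notebook's lora_config: lora_alpha / r
scaling = 16 / 32  # 0.5

# Toy example: 2x2 frozen weight with a rank-1 adapter (the real config uses r=32)
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]    # shape r x in = 1 x 2
B = [[0.0], [0.0]]  # shape out x r = 2 x 1, zero-initialized
x = [3.0, 4.0]
print(lora_forward(W, A, B, x, scaling))  # [3.0, 4.0]: adapter starts as a no-op

B = [[1.0], [2.0]]  # hypothetical values after some training
print(lora_forward(W, A, B, x, scaling))  # [6.5, 11.0]
```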

In [9]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=None,
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)
model = create_headed_qlora(
    base_model_class=model_class,
    model_name=model_path,
    quantization_config=quantization_config,
    lora_config=lora_config,
    head_configs=head_configs,
    fully_trained_heads=True,
    device_map={"": torch.cuda.current_device()},
)


Some weights of TransformerWithHeads were not initialized from the model checkpoint at mistralai/Mistral-7B-v0.1 and are newly initialized: ['heads.alphabet_regression.lins.0.weight', 'heads.causal_lm.lins.0.weight', 'heads.num_tokens_regression.lins.0.bias', 'heads.num_tokens_regression.lins.0.weight', 'heads.num_tokens_regression.lins.1.bias', 'heads.num_tokens_regression.lins.1.weight', 'heads.num_tokens_regression.lins.2.weight', 'heads.sentiment_head.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
print_trainable_parameters(model)

all params: 3967684992 || trainable params: 215613824 || trainable%: 5.434247538167465
params by dtype: defaultdict(<class 'int'>, {torch.float32: 478024064, torch.uint8: 3489660928})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 215613824})
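The printed counts can be reproduced by hand. A minimal sketch, assuming Mistral-7B's published dimensions (32 decoder layers, hidden size 4096, key/value width 1024 from grouped-query attention, MLP width 14336, vocabulary 32000), LoRA with r=32 on all seven linear projections of every decoder layer, and head shapes inferred from the weight names in the initialization warning above. A LoRA module on an `in -> out` layer adds `r * (in + out)` parameters.

```python
# Mistral-7B dimensions (assumed from its published config)
NUM_LAYERS = 32
HIDDEN = 4096
KV_DIM = 1024          # 8 key/value heads x 128 (grouped-query attention)
INTERMEDIATE = 14336
VOCAB = 32000
R = 32                 # LoRA rank from lora_config

# (in_features, out_features) of the linear layers in one decoder layer
linear_shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# LoRA adds A (r x in) and B (out x r) to each targeted linear layer
lora_params = NUM_LAYERS * sum(R * (i + o) for i, o in linear_shapes.values())

# Trainable heads; biases inferred from the weight names printed at load time
head_params = {
    "sentiment_head": HIDDEN * 2,
    "causal_lm": HIDDEN * VOCAB,
    "alphabet_regression": HIDDEN * 26,
    # MLP: 4096 -> 128 (+bias) -> 128 (+bias) -> 1 (no bias)
    "num_tokens_regression": (HIDDEN * 128 + 128) + (128 * 128 + 128) + 128,
}
trainable = lora_params + sum(head_params.values())

# Frozen float32 tensors: original lm_head, token embeddings, RMSNorm weights
frozen_fp32 = HIDDEN * VOCAB + VOCAB * HIDDEN + (2 * NUM_LAYERS + 1) * HIDDEN

# Quantized base weights: two 4-bit values packed into each uint8
quantized_weights = NUM_LAYERS * sum(i * o for i, o in linear_shapes.values())

print(f"LoRA params:      {lora_params:,}")
print(f"head params:      {sum(head_params.values()):,}")
print(f"trainable params: {trainable:,}")
print(f"float32 params:   {trainable + frozen_fp32:,}")
print(f"uint8 params:     {quantized_weights // 2:,}")
```

Under these assumptions, every figure matches the output above: about 84M LoRA parameters, and roughly 131M of the 216M trainable parameters belong to the vocabulary-sized `causal_lm` head alone.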


In [11]:
collator = DataCollatorWithPadding(
    feature_name_to_padding_value={
        "input_ids": tokenizer.pad_token_id,
        "attention_mask": 0,
        "causal_lm": -100,
        "lm_head": -100,
        "num_tokens_regression": 0,
    }
)
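Padding the causal LM labels with -100 keeps padded positions out of those losses, because -100 is the default `ignore_index` of PyTorch's cross-entropy loss. A minimal pure-Python sketch of what padding each feature with its own value amounts to, assuming right padding; `pad_batch` is a hypothetical helper, not the library's *DataCollatorWithPadding*, and the pad id 2 is an assumption (Mistral's `</s>` id, reused as the pad token above).

```python
def pad_batch(features, padding_values):
    """Right-pad each listed feature to the longest sequence in the batch (hypothetical helper)."""
    batch = {}
    for name, pad_value in padding_values.items():
        seqs = [f[name] for f in features]
        max_len = max(len(s) for s in seqs)
        batch[name] = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return batch


features = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "causal_lm": [5, 6, 7]},
    {"input_ids": [8], "attention_mask": [1], "causal_lm": [8]},
]
PAD_ID = 2  # assumed pad token id
padded = pad_batch(
    features, {"input_ids": PAD_ID, "attention_mask": 0, "causal_lm": -100}
)
print(padded["input_ids"])       # [[5, 6, 7], [8, 2, 2]]
print(padded["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
print(padded["causal_lm"])       # [[5, 6, 7], [8, -100, -100]]
```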

<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [12]:
print(evaluate_head_wise(model, dd["test"], collator, epochs=eval_epochs))

Evaluating:   9%|▉         | 293/3125 [27:24<4:24:58,  5.61s/it]




OutOfMemoryError: CUDA out of memory. Tried to allocate 3.01 GiB. GPU 0 has a total capacity of 39.39 GiB of which 255.06 MiB is free. Including non-PyTorch memory, this process has 39.13 GiB memory in use. Of the allocated memory 34.75 GiB is allocated by PyTorch, and 3.89 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
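The warning printed while mapping the dataset ("Default to no truncation") suggests that *truncation=True* had no effect, so long reviews were kept at full length. A rough back-of-the-envelope check, under two assumptions: the failed 3.01 GiB allocation is the float32 logits tensor of one vocabulary-sized head (batch × sequence length × 32000), and the evaluation batch size is 8 (25,000 test reviews / 3,125 steps). Under these assumptions it corresponds to a sequence of roughly 3,150 tokens. Note that the model has two such heads (`causal_lm` and `lm_head`).

```python
GIB = 2**30


def logits_bytes(batch_size, seq_len, vocab_size=32000, bytes_per_value=4):
    """Size in bytes of one float32 logits tensor (batch x seq_len x vocab)."""
    return batch_size * seq_len * vocab_size * bytes_per_value


def seq_len_for(alloc_gib, batch_size=8, vocab_size=32000, bytes_per_value=4):
    """Sequence length at which one logits tensor reaches alloc_gib."""
    return alloc_gib * GIB / (batch_size * vocab_size * bytes_per_value)


print(round(seq_len_for(3.01)))    # tokens implied by the failed allocation
print(logits_bytes(8, 512) / GIB)  # one head's logits with a hypothetical max_length of 512
```

If this reading is right, passing an explicit length cap in *processing_function*, e.g. `tokenizer(examples["text"], padding=False, truncation=True, max_length=512)`, would bound the per-head logits to about 0.5 GiB per batch. The value 512 is an arbitrary example, not a tuned one. The error message's own suggestion, `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, only addresses fragmentation.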

In [None]:
args = TrainingArguments(
    output_dir="imdb_linear_probe",
    learning_rate=0.0002,
    num_train_epochs=train_epochs,  # To speed things up set to 0.1, set to 1 for better performance
    logging_steps=logging_steps,
    do_eval=False,
    remove_unused_columns=False,
)
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)
trainer.train()

In [None]:
evals = evaluate_head_wise(model, dd["test"], collator, epochs=eval_epochs)
print(evals)

# Saving and loading
Now let's see how to save a complex multi-headed model and load it again for inference. Saving is simple: calling *save_pretrained* saves all trained parameters correctly, and this also works for checkpoints created during training.

In [None]:
model.save_pretrained("qlora_multitask_imdb")
del model

When loading the model, all heads must be attached and initialized correctly, which does not work out of the box with the Hugging Face API. Instead, *transformer_heads* provides the *load_lora_with_heads* function. Passing a quantization config is optional here: we could also pass a different quantization config or none at all.

In [None]:
model = load_lora_with_heads(
    model_class,
    "qlora_multitask_imdb",
    quantization_config,
    device_map={"": torch.cuda.current_device()},
)

Let's now find out if the loaded model behaves the same as the saved model:

In [None]:
new_evals = evaluate_head_wise(model, dd["test"], collator, epochs=eval_epochs)
print(new_evals)
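Quantization and GPU arithmetic may introduce small numerical differences between the saved and reloaded model, so comparing the two evaluations with a tolerance is more robust than exact equality. A minimal sketch, assuming the results can be flattened into a dict mapping metric names to floats (the actual structure returned by *evaluate_head_wise* may differ); `metrics_match` is a hypothetical helper and the values below are dummies.

```python
import math


def metrics_match(before, after, rel_tol=1e-3, abs_tol=1e-6):
    """True if both dicts share the same keys and all values agree within tolerance."""
    if before.keys() != after.keys():
        return False
    return all(
        math.isclose(before[k], after[k], rel_tol=rel_tol, abs_tol=abs_tol)
        for k in before
    )


# Dummy values for illustration only
a = {"sentiment_head_loss": 0.25, "causal_lm_loss": 2.0}
b = {"sentiment_head_loss": 0.2501, "causal_lm_loss": 2.0}
print(metrics_match(a, b))                             # True
print(metrics_match(a, {**b, "causal_lm_loss": 2.5}))  # False
```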