## Initialize PolyModel

In [None]:
import os
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import load_dataset, concatenate_datasets
from peft import PolyConfig, get_peft_model, TaskType, PeftModel, PeftConfig

# Set the TOKENIZERS_PARALLELISM switch to "false" for this process before any
# tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Base checkpoint that the Poly adapter is attached to.
model_name_or_path = "google/flan-t5-large"

# Training hyperparameters (consumed by the training-arguments cell below).
num_epochs = 5
batch_size = 16
lr = 5e-5

In [None]:
# Load the tokenizer and the base seq2seq model for the chosen checkpoint.
#
# `trust_remote_code=True` is intentionally NOT passed: it opts in to executing
# Python code shipped inside the model repository, which is only needed for
# architectures that are not implemented in `transformers` itself. The reload
# in the "Load and infer" section already loads this same checkpoint without
# the flag, so it is not required here either.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

In [None]:
# Poly adapter configuration for a sequence-to-sequence language model.
peft_config = PolyConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    poly_type="poly",
    n_tasks=4,   # number of tasks
    n_skills=2,  # number of skills (LoRA modules)
    n_splits=4,  # number of heads
    r=8,         # rank of the LoRA modules used inside Poly
)

# Wrap the base model with the adapter and report the trainable parameter count.
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
# Last expression of the cell: displays the wrapped model's module structure.
model

## Prepare datasets

For this example, we use four `SuperGLUE` benchmark datasets: `boolq`, `multirc`, `rte`, and `wic`. From each of them we randomly sample 1,000 examples for training and 100 examples for evaluation.

In [None]:
def _build_task_dataset(task_name, make_prompt, answers):
    """Load one SuperGLUE task and convert every example to a text-to-text pair.

    Args:
        task_name: SuperGLUE configuration name; it is also stored verbatim in
            the `task_name` column of every example.
        make_prompt: callable taking a raw example (dict) and returning the
            prompt string stored in the `input` column.
        answers: sequence indexed by the example's integer `label`, giving the
            answer letter stored in the `output` column.

    Returns:
        The object returned by `load_dataset` after mapping, restricted to the
        columns `input`, `output` and `task_name` (all available splits).
    """

    def to_example(x):
        return {
            "input": make_prompt(x),
            # NOTE(review): plain sequence indexing, so a label of -1 (which
            # unlabeled splits are believed to carry - confirm) selects the
            # last entry instead of failing.
            "output": answers[int(x["label"])],
            "task_name": task_name,
        }

    return (
        load_dataset("super_glue", task_name)
        .map(to_example)
        .select_columns(["input", "output", "task_name"])
    )


# boolq: label 0 - False, 1 - True
boolq_dataset = _build_task_dataset(
    "boolq",
    lambda x: f"task_boolq {x['passage']}\nQuestion: {x['question']}\nA. Yes\nB. No\nAnswer:",
    ["B", "A"],
)
print("boolq example: ")
print(boolq_dataset['train'][0])

# multirc: label 0 - False, 1 - True
multirc_dataset = _build_task_dataset(
    "multirc",
    lambda x: (
        f"task_multirc {x['paragraph']}\nQuestion: {x['question']}\nAnswer: {x['answer']}\nIs it"
        " true?\nA. Yes\nB. No\nAnswer:"
    ),
    ["B", "A"],
)
print("multirc example: ")
print(multirc_dataset['train'][0])

# rte: label 0 - entailment, 1 - not_entailment
rte_dataset = _build_task_dataset(
    "rte",
    lambda x: (
        f"task_rte {x['premise']}\n{x['hypothesis']}\nIs the sentence below entailed by the"
        " sentence above?\nA. Yes\nB. No\nAnswer:"
    ),
    ["A", "B"],
)
print("rte example: ")
print(rte_dataset['train'][0])

# wic: label 0 - False, 1 - True
wic_dataset = _build_task_dataset(
    "wic",
    lambda x: (
        f"task_wic Sentence 1: {x['sentence1']}\nSentence 2: {x['sentence2']}\nAre '{x['word']}'"
        " in the above two sentences the same?\nA. Yes\nB. No\nAnswer:"
    ),
    ["B", "A"],
)
print("wic example: ")
print(wic_dataset['train'][0])

In [None]:
# Map each task name to the integer id that is passed to the model as `task_ids`.
# Ids follow the listing order: boolq=0, multirc=1, rte=2, wic=3.
TASK2ID = {name: idx for idx, name in enumerate(("boolq", "multirc", "rte", "wic"))}

def tokenize(examples):
    """Tokenize a batch of prompt/answer pairs and attach labels and task ids.

    Args:
        examples: batch with the columns `input` (prompts), `output`
            (answers) and `task_name` (keys of `TASK2ID`).

    Returns:
        The tokenizer output for the prompts (padded/truncated to 512 tokens),
        extended with:
          * `labels`: answer token ids padded/truncated to 2 tokens, with
            every padding position replaced by -100;
          * `task_ids`: one integer per example, looked up in `TASK2ID`.
    """
    shared_kwargs = dict(padding="max_length", truncation=True, return_tensors="pt")

    model_inputs = tokenizer(examples["input"], max_length=512, **shared_kwargs)
    target_ids = tokenizer(examples["output"], max_length=2, **shared_kwargs)["input_ids"]

    # Padding positions must not count as targets; -100 is the conventional
    # "ignore" label value (compute_metrics strips it again before decoding).
    model_inputs["labels"] = target_ids.masked_fill(target_ids == tokenizer.pad_token_id, -100)
    model_inputs["task_ids"] = list(map(TASK2ID.__getitem__, examples["task_name"]))
    return model_inputs

In [None]:
def get_superglue_dataset(
    split="train",
    n_samples=500,
    seed=42,
):
    """Build a tokenized multi-task dataset from the four SuperGLUE tasks.

    For each of boolq, multirc, rte and wic (in that order) the requested
    split is shuffled, its first `n_samples` rows are kept, and the four
    subsets are concatenated and tokenized with `tokenize`.

    Args:
        split: name of the split to draw from in every task dataset.
        n_samples: number of examples taken from each task. It must not
            exceed the size of the smallest task split (not checked here).
        seed: seed forwarded to every `shuffle` call so the sampled subset is
            the same on every run. Pass `None` for an unseeded shuffle.

    Returns:
        The concatenated dataset after `tokenize` has been mapped over it; the
        raw `input`, `output` and `task_name` columns are removed.
    """
    task_datasets = (boolq_dataset, multirc_dataset, rte_dataset, wic_dataset)
    ds = concatenate_datasets(
        [
            task_ds[split].shuffle(seed=seed).select(range(n_samples))
            for task_ds in task_datasets
        ]
    )
    ds = ds.map(
        tokenize,
        batched=True,
        remove_columns=["input", "output", "task_name"],
        # Always re-tokenize instead of reusing a cached result.
        load_from_cache_file=False,
    )
    return ds

In [None]:
# Train on 1,000 examples per task.
superglue_train_dataset = get_superglue_dataset(split="train", n_samples=1000)
# Evaluate on 100 examples per task drawn from the *validation* split.
# SuperGLUE's public test splits are distributed without gold labels
# (label == -1), so answers derived from them are placeholders and any accuracy
# measured against them would be meaningless.
superglue_eval_dataset = get_superglue_dataset(split="validation", n_samples=100)

## Train and evaluate

In [None]:
# training and evaluation

def compute_metrics(eval_preds):
    """Compute exact-match accuracy between generated and reference answers.

    Args:
        eval_preds: a `(predictions, labels)` pair of token-id sequences.
            Entries equal to -100 are dropped from both before decoding.

    Returns:
        `{"accuracy": float}` - the fraction of pairs whose decoded texts are
        equal after stripping surrounding whitespace. Returns 0.0 when there
        is nothing to compare, instead of dividing by zero.
    """
    preds, labels = eval_preds

    def decode(sequences):
        # -100 is a placeholder rather than a vocabulary id, so remove it
        # before handing the sequences to the tokenizer.
        kept = [[tok for tok in seq if tok != -100] for seq in sequences]
        return tokenizer.batch_decode(kept, skip_special_tokens=True)

    pairs = list(zip(decode(preds), decode(labels)))
    if not pairs:
        return {"accuracy": 0.0}

    correct = sum(pred.strip() == true.strip() for pred, true in pairs)
    return {"accuracy": correct / len(pairs)}


# NOTE(review): newer `transformers` releases are believed to rename
# `evaluation_strategy` to `eval_strategy` and the trainer's `tokenizer`
# argument to `processing_class` - confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    # Optimisation settings come from the configuration cell.
    learning_rate=lr,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    # Evaluate and log once per epoch; write no checkpoints; report nowhere.
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="no",
    report_to=[],
    # Evaluate by generating text, capped at two tokens.
    predict_with_generate=True,
    generation_max_length=2,
    # Presumably needed so the extra `task_ids` column reaches the model -
    # TODO confirm.
    remove_unused_columns=False,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=default_data_collator,
    train_dataset=superglue_train_dataset,
    eval_dataset=superglue_eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

In [None]:
# Save the trained adapter to a local directory whose name combines the base
# checkpoint name, the PEFT type and the task type.
# `safe_serialization=False` is passed to `save_pretrained`; the next cell
# looks for `adapter_model.bin` inside this directory.
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
model.save_pretrained(peft_model_id, safe_serialization=False)

In [None]:
# Path of the adapter weights file inside the saved adapter directory.
ckpt = f"{peft_model_id}/adapter_model.bin"
# IPython shell escape: show the file's on-disk size in human-readable units.
!du -h $ckpt

## Load and infer

In [None]:
# Rebuild the adapter directory name with the same pattern used when saving.
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# The saved adapter config names the base checkpoint; load that base model
# fresh and attach the saved adapter weights to it.
config = PeftConfig.from_pretrained(peft_model_id)
model = PeftModel.from_pretrained(
    AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path),
    peft_model_id,
)

In [None]:
# Run a single RTE validation example through the reloaded model.
model.eval()
i = 200
# Fetch the row once instead of indexing whole columns repeatedly: selecting a
# column first can materialise every value of that column just to read one.
example = rte_dataset["validation"][i]
# One task id per input sequence, telling the model which task this prompt is.
task_ids = torch.tensor([TASK2ID["rte"]], dtype=torch.long)
inputs = tokenizer(example["input"], return_tensors="pt")
print(example["input"])
print(example["output"])
print(inputs)

# No gradients are needed for generation.
with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        task_ids=task_ids,
        max_new_tokens=2,
    )
    print(outputs)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))