In [1]:
import random
from datasets import load_dataset
from transformers import set_seed

total_samples = 10000
set_seed(42)


def get_sst2(split: str):
    """Build shuffled text-to-text examples from one split of SST-2.

    Every example is a dict with two string fields:

    * ``"input"``  - the stripped sentence followed by ``"</s>"``
    * ``"output"`` - ``"positive"`` when the label equals 1, otherwise
      ``"negative"``, followed by ``"</s>"``

    Args:
        split: Name of the dataset split to read.

    Returns:
        A shuffled list holding at most ``total_samples // 2`` examples.
    """
    dataset = load_dataset("sst2")[split]

    pairs = [
        {
            "input": row["sentence"].strip() + "</s>",
            "output": ("positive" if row["label"] == 1 else "negative") + "</s>",
        }
        for row in dataset
    ]

    # Shuffle before truncating so the kept subset is a random sample.
    random.shuffle(pairs)
    return pairs[: total_samples // 2]


def get_mnli(split: str):
    """Build shuffled text-to-text examples from one split of MultiNLI.

    Every example is a dict with two string fields:

    * ``"input"``  - stripped premise, a single space, stripped hypothesis,
      followed by ``"</s>"``
    * ``"output"`` - ``"entailment"`` for label 0, ``"neutral"`` for label 1
      and ``"contradiction"`` for any other label, followed by ``"</s>"``

    Args:
        split: Name of the dataset split to read.

    Returns:
        A shuffled list holding at most ``total_samples // 2`` examples.
    """
    # Any label missing from this table falls back to "contradiction".
    label_names = {0: "entailment", 1: "neutral"}

    dataset = load_dataset("multi_nli")[split]

    pairs = []
    for row in dataset:
        text = " ".join((row["premise"].strip(), row["hypothesis"].strip()))
        verdict = label_names.get(row["label"], "contradiction")
        pairs.append({"input": text + "</s>", "output": verdict + "</s>"})

    # Shuffle before truncating so the kept subset is a random sample.
    random.shuffle(pairs)
    return pairs[: total_samples // 2]

In [12]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from peft import get_peft_model, MultitaskPromptTuningConfig, TaskType

# One checkpoint name is shared by the tokenizer, the base model and the
# PEFT config's tokenizer_name_or_path.
model_name = "google/flan-t5-base"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Multitask prompt tuning over two tasks (the notebook mixes SST-2 and MNLI).
# prompt_tuning_init="TEXT" is paired with the instruction in
# prompt_tuning_init_text below.
peft_config = MultitaskPromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_tasks=2,
    num_virtual_tokens=50,
    num_transformer_submodules=1,
    prompt_tuning_init="TEXT",
    prompt_tuning_init_text="classify the following into either positive or negative, or entailment, neutral or contradiction:",
    tokenizer_name_or_path=model_name,
)

# Load the base seq2seq model and wrap it with the PEFT adapter in one step.
model = get_peft_model(AutoModelForSeq2SeqLM.from_pretrained(model_name), peft_config)

bin /Users/mayankmishra/miniconda3/envs/ai/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cpu.so


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [19]:
from typing import Tuple
from torch.utils.data import Dataset, DataLoader
import torch

class MyDataset(Dataset):
    """In-memory dataset of text-to-text classification examples.

    The ``"train"`` split mixes the SST-2 and MNLI examples and shuffles the
    combined list; every other split contains SST-2 examples only.
    """

    def __init__(self, split: str) -> None:
        """Load the examples for ``split`` into ``self.examples``.

        Args:
            split: Dataset split name, forwarded to ``get_sst2`` and, for
                ``"train"``, also to ``get_mnli``.
        """
        super().__init__()
        if split == "train":
            self.examples = get_sst2(split) + get_mnli(split)
            # Interleave the two tasks so batches are not single-task runs.
            random.shuffle(self.examples)
        else:
            self.examples = get_sst2(split)

    def __getitem__(self, index) -> dict:
        """Return the example dict stored at ``index``."""
        return self.examples[index]

    def __len__(self) -> int:
        """Return the number of stored examples."""
        return len(self.examples)

    @classmethod
    def collate_fn(cls, batch: list) -> dict:
        """Tokenize a list of examples into a padded model batch.

        Args:
            batch: List of example dicts, each with ``"input"`` and
                ``"output"`` strings.

        Returns:
            A dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``
            tensors. Padding positions in ``"labels"`` are replaced by -100.
        """
        # NOTE(review): PEFT multitask prompt tuning appears to require a
        # "task_ids" tensor in the forward call; the examples carry no task id,
        # so it cannot be added here - confirm against the PEFT docs.

        # add_special_tokens=False: the example strings already end in "</s>".
        encoded_inputs = tokenizer(
            [example["input"] for example in batch],
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
        )

        labels = tokenizer(
            [example["output"] for example in batch],
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
        ).input_ids
        # -100 is the label value the loss ignores, so padding does not
        # contribute to training.
        labels[labels == tokenizer.pad_token_id] = -100

        return {
            "input_ids": encoded_inputs.input_ids,
            "attention_mask": encoded_inputs.attention_mask,
            "labels": labels,
        }


# Training batches are reshuffled by the loader on each pass.
train = DataLoader(
    dataset=MyDataset("train"),
    batch_size=8,
    shuffle=True,
    collate_fn=MyDataset.collate_fn,
)
# Validation batches keep the dataset's stored order.
val = DataLoader(
    dataset=MyDataset("validation"),
    batch_size=8,
    shuffle=False,
    collate_fn=MyDataset.collate_fn,
)

Found cached dataset sst2 (/Users/mayankmishra/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset multi_nli (/Users/mayankmishra/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset sst2 (/Users/mayankmishra/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)


  0%|          | 0/3 [00:00<?, ?it/s]

In [1]:
from torch.optim.adamw import AdamW
from tqdm import tqdm

# NOTE: this cell needs `model`, `train` and `val` from the cells above; the
# recorded NameError comes from running it on a fresh kernel.
optimizer = AdamW(model.parameters(), lr=1e-3)

# Number of training steps between validation passes.
n = 1000
step = 0
train_ = tqdm(train)

# from_pretrained returns the model in eval mode; switch to training mode.
model.train()

for batch in train_:
    step += 1

    # Reset gradients so they do not accumulate across steps.
    optimizer.zero_grad()
    # The forward pass returns an output object; the scalar loss is `.loss`.
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    # set_postfix expects a mapping/kwargs; set_postfix_str takes plain text.
    train_.set_postfix_str(f"train loss = {loss.item()}")

    if step % n == 0:
        # Validate in eval mode with autograd off. Separate names keep the
        # training `loss` and `batch` intact.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for val_batch in val:
                val_loss += model(**val_batch).loss.item()
        print("val loss =", val_loss / len(val))
        model.train()

model.save_pretrained("a.pt")

NameError: name 'model' is not defined