- https://huggingface.co/docs/peft/main/en/quicktour
- Prompting
  - Prompt tuning : https://huggingface.co/docs/peft/main/en/task_guides/clm-prompt-tuning
  - Prefix tuning : https://huggingface.co/docs/peft/main/en/task_guides/seq2seq-prefix-tuning
  - P-tuning : https://huggingface.co/docs/peft/main/en/task_guides/ptuning-seq-classification

# Prompt tuning

In [159]:
import warnings
warnings.filterwarnings("ignore")

In [118]:
from transformers import (
    AutoModelForCausalLM, 
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer, 
    LlamaForCausalLM, 
    LlamaTokenizer, 
    default_data_collator,
    get_linear_schedule_with_warmup,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer
)
from peft import (
    get_peft_config, 
    get_peft_model, 
    PromptTuningInit, 
    PromptTuningConfig,
    PrefixTuningConfig,
    PromptEncoderConfig,
    LoraConfig, 
    TaskType,
    PeftType
)
import torch
from datasets import load_dataset
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
import evaluate

## * Config

In [160]:
device = "cuda:3"
my_root_directory=""
root_path = f"{my_root_directory}/llm/model_output/"

# https://huggingface.co/bigscience/bloomz-560m
model_name_or_path = "bigscience/bloomz-560m"
# model_name_or_path = f"{my_root_directory}/llm/llama/llama-2-7b-chat-hf"

In [162]:
# peft_config = LoraConfig(
#     task_type=TaskType.CAUSAL_LM,
#     inference_mode=False, 
#     r=8, 
#     lora_alpha=32, 
#     lora_dropout=0.1
# )

# For the best results, 
# the prompt_tuning_init_text should have the same number of tokens that should be predicted.
# To do this, 
# you can set num_virtual_tokens to the number of tokens of the prompt_tuning_init_text :
prompt_tuning_init_text="Classify if the tweet is a complaint or not:"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=len(tokenizer(prompt_tuning_init_text)["input_ids"]),
    prompt_tuning_init_text=prompt_tuning_init_text,
    tokenizer_name_or_path=model_name_or_path
)

## * Dataset

In [4]:
dataset_name = "twitter_complaints"
checkpoint_name = f"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt".replace(
    "/", "_"
)
text_column = "Tweet text"
label_column = "text_label"
max_length = 64
lr = 3e-2
num_epochs = 50
batch_size = 8

In [5]:
# https://huggingface.co/datasets/ought/raft
dataset = load_dataset("ought/raft", dataset_name)

Found cached dataset raft (/home/irteam/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84)
100%|██████████| 2/2 [00:00<00:00, 307.16it/s]


In [6]:
dataset.column_names

{'train': ['Tweet text', 'ID', 'Label'], 'test': ['Tweet text', 'ID', 'Label']}

In [7]:
dataset["train"]

Dataset({
    features: ['Tweet text', 'ID', 'Label'],
    num_rows: 50
})

In [8]:
dataset["test"]

Dataset({
    features: ['Tweet text', 'ID', 'Label'],
    num_rows: 3399
})

In [9]:
dataset["train"][0]

{'Tweet text': '@HMRCcustomers No this is my first job', 'ID': 0, 'Label': 2}

In [10]:
classes = [k.replace("_", " ") for k in dataset["train"].features["Label"].names]
dataset = dataset.map(
    lambda x: {"text_label": [classes[label] for label in x["Label"]]},
    batched=True,
    num_proc=1
)

Loading cached processed dataset at /home/irteam/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-083b7bfcf2ab7905.arrow
Loading cached processed dataset at /home/irteam/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-18b457bd4e4c4ea0.arrow


In [11]:
dataset["train"][0]

{'Tweet text': '@HMRCcustomers No this is my first job',
 'ID': 0,
 'Label': 2,
 'text_label': 'no complaint'}

In [161]:
# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)

In [13]:
print(tokenizer.pad_token_id)
print(tokenizer.eos_token_id)

3
2


In [14]:
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    
target_max_length = max([len(tokenizer(class_label)["input_ids"]) for class_label in classes])
print(target_max_length)

3


In [15]:
tokenizer(classes[0])

{'input_ids': [3074, 4762, 60943], 'attention_mask': [1, 1, 1]}

In [16]:
def preprocess_function(examples):
    batch_size = len(examples[text_column])
    inputs = [f"{text_column} : {x} Label : " for x in examples[text_column]]
    targets = [str(x) for x in examples[label_column]]
    model_inputs = tokenizer(inputs)
    labels = tokenizer(targets)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i] + [tokenizer.pad_token_id]
        # print(i, sample_input_ids, label_input_ids)
        model_inputs["input_ids"][i] = sample_input_ids + label_input_ids
        labels["input_ids"][i] = [-100] * len(sample_input_ids) + label_input_ids
        model_inputs["attention_mask"][i] = [1] * len(model_inputs["input_ids"][i])
    # print(model_inputs)
    # print(labels)
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i]
        model_inputs["input_ids"][i] = [tokenizer.pad_token_id] * (
            max_length - len(sample_input_ids)
        ) + sample_input_ids
        model_inputs["attention_mask"][i] = [0] * (
            max_length - len(sample_input_ids)
        ) + model_inputs["attention_mask"][i]
        labels["input_ids"][i] = [-100] * (
            max_length - len(sample_input_ids)
        ) + label_input_ids
        # if i in [0, 1]:
        #     print(model_inputs["input_ids"][i])
        #     print(model_inputs["attention_mask"][i])
        #     print(labels["input_ids"][i])
        model_inputs["input_ids"][i] = torch.tensor(model_inputs["input_ids"][i][:max_length])
        model_inputs["attention_mask"][i] = torch.tensor(model_inputs["attention_mask"][i][:max_length])
        labels["input_ids"][i] = torch.tensor(labels["input_ids"][i][:max_length])
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [17]:
temp_input_data = preprocess_function(dataset["train"])

In [18]:
dataset["train"].column_names

['Tweet text', 'ID', 'Label', 'text_label']

In [19]:
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset"
)

                                                                                          

## * DataLoader

In [20]:
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["test"]

In [21]:
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 50
})

In [22]:
train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)
eval_dataloader = DataLoader(
    eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)

## * Train

In [23]:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
# model = LlamaForCausalLM.from_pretrained(model_name_or_path)

In [24]:
# model과 config를 wrapping
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 8192 || all params: 559222784 || trainable%: 0.0014648902430985358


In [25]:
# Before prompt tuning

inputs = tokenizer(
    f'{text_column} : {"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?"} Label : ',
    return_tensors="pt",
)

model.to(device)

with torch.no_grad():
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # generate 호출
    outputs = model.generate(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=10, eos_token_id=3
    )
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

['Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this? Label :  NoThe present invention relates to a method of']


In [26]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs)
)

In [27]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        # 굳이 items?
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    
    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"epoch: {epoch}, train_ppl: {train_ppl}, train_epoch_loss: {train_epoch_loss}, eval_ppl: {eval_ppl}, eval_epoch_loss: {eval_epoch_loss}")

100%|██████████| 7/7 [00:01<00:00,  6.90it/s]
100%|██████████| 425/425 [00:29<00:00, 14.25it/s]


epoch: 0, train_ppl: 3.380147851886592e+17, train_epoch_loss: 40.36186599731445, eval_ppl: 16679.0, eval_epoch_loss: 9.721905708312988


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.24it/s]


epoch: 1, train_ppl: 538404.0, train_epoch_loss: 13.196364402770996, eval_ppl: 13381.1865234375, eval_epoch_loss: 9.501605033874512


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.25it/s]


epoch: 2, train_ppl: 284034.09375, train_epoch_loss: 12.556849479675293, eval_ppl: 9123.810546875, eval_epoch_loss: 9.118642807006836


100%|██████████| 7/7 [00:00<00:00,  7.18it/s]
100%|██████████| 425/425 [00:29<00:00, 14.26it/s]


epoch: 3, train_ppl: 121827.3984375, train_epoch_loss: 11.710360527038574, eval_ppl: 5693.92333984375, eval_epoch_loss: 8.647154808044434


100%|██████████| 7/7 [00:00<00:00,  7.26it/s]
100%|██████████| 425/425 [00:29<00:00, 14.25it/s]


epoch: 4, train_ppl: 21584.353515625, train_epoch_loss: 9.979723930358887, eval_ppl: 3951.930908203125, eval_epoch_loss: 8.281959533691406


100%|██████████| 7/7 [00:00<00:00,  7.29it/s]
100%|██████████| 425/425 [00:29<00:00, 14.25it/s]


epoch: 5, train_ppl: 4086.43310546875, train_epoch_loss: 8.315427780151367, eval_ppl: 6859.73193359375, eval_epoch_loss: 8.833423614501953


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.25it/s]


epoch: 6, train_ppl: 878.7350463867188, train_epoch_loss: 6.7784833908081055, eval_ppl: 13165.3671875, eval_epoch_loss: 9.485344886779785


100%|██████████| 7/7 [00:00<00:00,  7.26it/s]
100%|██████████| 425/425 [00:29<00:00, 14.24it/s]


epoch: 7, train_ppl: 261.6167907714844, train_epoch_loss: 5.566880702972412, eval_ppl: 19448.931640625, eval_epoch_loss: 9.875547409057617


100%|██████████| 7/7 [00:00<00:00,  7.29it/s]
100%|██████████| 425/425 [00:29<00:00, 14.24it/s]


epoch: 8, train_ppl: 147.39532470703125, train_epoch_loss: 4.9931182861328125, eval_ppl: 17187.388671875, eval_epoch_loss: 9.751931190490723


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.23it/s]


epoch: 9, train_ppl: 104.95393371582031, train_epoch_loss: 4.653521537780762, eval_ppl: 26680.7109375, eval_epoch_loss: 10.191696166992188


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.23it/s]


epoch: 10, train_ppl: 81.39569854736328, train_epoch_loss: 4.399322509765625, eval_ppl: 33229.58203125, eval_epoch_loss: 10.411195755004883


100%|██████████| 7/7 [00:00<00:00,  7.26it/s]
100%|██████████| 425/425 [00:29<00:00, 14.25it/s]


epoch: 11, train_ppl: 67.08550262451172, train_epoch_loss: 4.205967903137207, eval_ppl: 48844.83203125, eval_epoch_loss: 10.796403884887695


100%|██████████| 7/7 [00:00<00:00,  7.31it/s]
100%|██████████| 425/425 [00:29<00:00, 14.25it/s]


epoch: 12, train_ppl: 53.68946075439453, train_epoch_loss: 3.9832167625427246, eval_ppl: 59047.09765625, eval_epoch_loss: 10.986090660095215


100%|██████████| 7/7 [00:00<00:00,  7.29it/s]
100%|██████████| 425/425 [00:29<00:00, 14.24it/s]


epoch: 13, train_ppl: 43.468894958496094, train_epoch_loss: 3.772045612335205, eval_ppl: 83257.5234375, eval_epoch_loss: 11.329693794250488


100%|██████████| 7/7 [00:00<00:00,  7.29it/s]
100%|██████████| 425/425 [00:29<00:00, 14.20it/s]


epoch: 14, train_ppl: 38.822174072265625, train_epoch_loss: 3.658991575241089, eval_ppl: 86719.3125, eval_epoch_loss: 11.370431900024414


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.24it/s]


epoch: 15, train_ppl: 29.82989501953125, train_epoch_loss: 3.3955111503601074, eval_ppl: 121658.234375, eval_epoch_loss: 11.70897102355957


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.26it/s]


epoch: 16, train_ppl: 23.94160270690918, train_epoch_loss: 3.1756176948547363, eval_ppl: 286042.65625, eval_epoch_loss: 12.563896179199219


100%|██████████| 7/7 [00:00<00:00,  7.18it/s]
100%|██████████| 425/425 [00:29<00:00, 14.23it/s]


epoch: 17, train_ppl: 18.19940948486328, train_epoch_loss: 2.9013891220092773, eval_ppl: 554802.25, eval_epoch_loss: 13.226366996765137


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.27it/s]


epoch: 18, train_ppl: 14.414392471313477, train_epoch_loss: 2.668227195739746, eval_ppl: 711650.1875, eval_epoch_loss: 13.475341796875


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.24it/s]


epoch: 19, train_ppl: 11.717874526977539, train_epoch_loss: 2.4611153602600098, eval_ppl: 1778658.25, eval_epoch_loss: 14.391369819641113


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.23it/s]


epoch: 20, train_ppl: 11.812820434570312, train_epoch_loss: 2.4691853523254395, eval_ppl: 825788.5, eval_epoch_loss: 13.624094009399414


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.23it/s]


epoch: 21, train_ppl: 8.01941204071045, train_epoch_loss: 2.081865072250366, eval_ppl: 1308490.625, eval_epoch_loss: 14.08438491821289


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.22it/s]


epoch: 22, train_ppl: 6.380003452301025, train_epoch_loss: 1.8531686067581177, eval_ppl: 1455277.5, eval_epoch_loss: 14.190707206726074


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.22it/s]


epoch: 23, train_ppl: 5.108733177185059, train_epoch_loss: 1.6309514045715332, eval_ppl: 1851730.375, eval_epoch_loss: 14.431631088256836


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.22it/s]


epoch: 24, train_ppl: 3.9735682010650635, train_epoch_loss: 1.3796645402908325, eval_ppl: 2006133.375, eval_epoch_loss: 14.511719703674316


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.24it/s]


epoch: 25, train_ppl: 3.627457618713379, train_epoch_loss: 1.288532018661499, eval_ppl: 963477.1875, eval_epoch_loss: 13.778304100036621


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.23it/s]


epoch: 26, train_ppl: 2.8793156147003174, train_epoch_loss: 1.0575525760650635, eval_ppl: 2760612.75, eval_epoch_loss: 14.830963134765625


100%|██████████| 7/7 [00:00<00:00,  7.26it/s]
100%|██████████| 425/425 [00:29<00:00, 14.22it/s]


epoch: 27, train_ppl: 2.9725444316864014, train_epoch_loss: 1.0894182920455933, eval_ppl: 1459396.875, eval_epoch_loss: 14.193533897399902


100%|██████████| 7/7 [00:00<00:00,  7.29it/s]
100%|██████████| 425/425 [00:29<00:00, 14.23it/s]


epoch: 28, train_ppl: 2.3122549057006836, train_epoch_loss: 0.8382232189178467, eval_ppl: 1980808.875, eval_epoch_loss: 14.499015808105469


100%|██████████| 7/7 [00:00<00:00,  7.26it/s]
100%|██████████| 425/425 [00:29<00:00, 14.20it/s]


epoch: 29, train_ppl: 1.8611103296279907, train_epoch_loss: 0.6211733222007751, eval_ppl: 1943938.375, eval_epoch_loss: 14.480226516723633


100%|██████████| 7/7 [00:00<00:00,  7.24it/s]
100%|██████████| 425/425 [00:29<00:00, 14.18it/s]


epoch: 30, train_ppl: 1.8047007322311401, train_epoch_loss: 0.5903947353363037, eval_ppl: 1572628.625, eval_epoch_loss: 14.268259048461914


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.21it/s]


epoch: 31, train_ppl: 1.675833821296692, train_epoch_loss: 0.5163108110427856, eval_ppl: 1104459.25, eval_epoch_loss: 13.91486644744873


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.20it/s]


epoch: 32, train_ppl: 1.6426419019699097, train_epoch_loss: 0.4963058531284332, eval_ppl: 1383934.375, eval_epoch_loss: 14.140440940856934


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.21it/s]


epoch: 33, train_ppl: 1.6532964706420898, train_epoch_loss: 0.5027711391448975, eval_ppl: 1149193.25, eval_epoch_loss: 13.954570770263672


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.21it/s]


epoch: 34, train_ppl: 1.5835336446762085, train_epoch_loss: 0.45965877175331116, eval_ppl: 911221.125, eval_epoch_loss: 13.722540855407715


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.22it/s]


epoch: 35, train_ppl: 1.4889373779296875, train_epoch_loss: 0.39806264638900757, eval_ppl: 1150890.0, eval_epoch_loss: 13.956046104431152


100%|██████████| 7/7 [00:00<00:00,  7.25it/s]
100%|██████████| 425/425 [00:30<00:00, 14.17it/s]


epoch: 36, train_ppl: 1.4481050968170166, train_epoch_loss: 0.37025585770606995, eval_ppl: 721718.5, eval_epoch_loss: 13.48939037322998


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.20it/s]


epoch: 37, train_ppl: 1.5175349712371826, train_epoch_loss: 0.41708728671073914, eval_ppl: 1019304.3125, eval_epoch_loss: 13.834630966186523


100%|██████████| 7/7 [00:00<00:00,  7.26it/s]
100%|██████████| 425/425 [00:29<00:00, 14.22it/s]


epoch: 38, train_ppl: 1.5402865409851074, train_epoch_loss: 0.43196845054626465, eval_ppl: 798447.5, eval_epoch_loss: 13.590424537658691


100%|██████████| 7/7 [00:00<00:00,  7.25it/s]
100%|██████████| 425/425 [00:29<00:00, 14.21it/s]


epoch: 39, train_ppl: 1.3242266178131104, train_epoch_loss: 0.2808285355567932, eval_ppl: 862893.0, eval_epoch_loss: 13.668045997619629


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.21it/s]


epoch: 40, train_ppl: 1.2781380414962769, train_epoch_loss: 0.24540437757968903, eval_ppl: 1085222.875, eval_epoch_loss: 13.897295951843262


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.21it/s]


epoch: 41, train_ppl: 1.270923137664795, train_epoch_loss: 0.23974347114562988, eval_ppl: 1156074.5, eval_epoch_loss: 13.960540771484375


100%|██████████| 7/7 [00:00<00:00,  7.29it/s]
100%|██████████| 425/425 [00:29<00:00, 14.22it/s]


epoch: 42, train_ppl: 1.2374060153961182, train_epoch_loss: 0.21301725506782532, eval_ppl: 1096227.125, eval_epoch_loss: 13.907384872436523


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:30<00:00, 14.17it/s]


epoch: 43, train_ppl: 1.2273577451705933, train_epoch_loss: 0.20486363768577576, eval_ppl: 1292886.25, eval_epoch_loss: 14.0723876953125


100%|██████████| 7/7 [00:00<00:00,  7.29it/s]
100%|██████████| 425/425 [00:29<00:00, 14.21it/s]


epoch: 44, train_ppl: 1.2194174528121948, train_epoch_loss: 0.1983732134103775, eval_ppl: 1372642.0, eval_epoch_loss: 14.132247924804688


100%|██████████| 7/7 [00:00<00:00,  7.26it/s]
100%|██████████| 425/425 [00:29<00:00, 14.19it/s]


epoch: 45, train_ppl: 1.1868008375167847, train_epoch_loss: 0.17126137018203735, eval_ppl: 1172581.25, eval_epoch_loss: 13.97471809387207


100%|██████████| 7/7 [00:00<00:00,  7.27it/s]
100%|██████████| 425/425 [00:29<00:00, 14.20it/s]


epoch: 46, train_ppl: 1.1878573894500732, train_epoch_loss: 0.17215116322040558, eval_ppl: 1166293.5, eval_epoch_loss: 13.969341278076172


100%|██████████| 7/7 [00:00<00:00,  7.29it/s]
100%|██████████| 425/425 [00:29<00:00, 14.21it/s]


epoch: 47, train_ppl: 1.1761916875839233, train_epoch_loss: 0.16228191554546356, eval_ppl: 1250548.75, eval_epoch_loss: 14.039093017578125


100%|██████████| 7/7 [00:00<00:00,  7.28it/s]
100%|██████████| 425/425 [00:29<00:00, 14.23it/s]


epoch: 48, train_ppl: 1.1723730564117432, train_epoch_loss: 0.15902996063232422, eval_ppl: 1266396.875, eval_epoch_loss: 14.05168628692627


100%|██████████| 7/7 [00:00<00:00,  7.26it/s]
100%|██████████| 425/425 [00:29<00:00, 14.19it/s]

epoch: 49, train_ppl: 1.1648969650268555, train_epoch_loss: 0.1526326984167099, eval_ppl: 1270319.5, eval_epoch_loss: 14.054779052734375





## * Inference

In [28]:
# inputs = tokenizer(
#     f'{text_column} : {"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?"} Label : ',
#     return_tensors="pt",
# )

In [29]:
model.to(device)

with torch.no_grad():
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # generate 호출
    outputs = model.generate(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=10, eos_token_id=3
    )
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

['Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this? Label : complaint']


# Prifix tuning

## * Config

In [31]:
# https://huggingface.co/google-t5/t5-large
# (Text-To-Text-Transfer Transformer - T5)
model_name_or_path = "t5-large"

text_column = "sentence"
label_column = "text_label"
max_length = 128
lr = 1e-2
num_epochs = 5
batch_size = 8

In [71]:
peft_config = PrefixTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM, 
    inference_mode=False, 
    num_virtual_tokens=20
)

## * Dataset

In [None]:
dataset_name = "sentences_allagree"

In [33]:
# https://huggingface.co/datasets/takala/financial_phrasebank
dataset = load_dataset("financial_phrasebank", dataset_name)

Downloading builder script: 100%|██████████| 6.04k/6.04k [00:00<00:00, 7.82MB/s]
Downloading readme: 100%|██████████| 8.88k/8.88k [00:00<00:00, 8.35MB/s]


Downloading and preparing dataset financial_phrasebank/sentences_allagree to /home/irteam/.cache/huggingface/datasets/financial_phrasebank/sentences_allagree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141...


Downloading data: 100%|██████████| 682k/682k [00:00<00:00, 31.3MB/s]
                                                                                     

Dataset financial_phrasebank downloaded and prepared to /home/irteam/.cache/huggingface/datasets/financial_phrasebank/sentences_allagree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 315.62it/s]


In [37]:
# dataset에 train만 존재
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 2264
    })
})

In [39]:
dataset = dataset["train"].train_test_split(test_size=0.1)

In [40]:
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 2037
    })
    test: Dataset({
        features: ['sentence', 'label'],
        num_rows: 227
    })
})

In [41]:
dataset["validation"] = dataset["test"]

In [42]:
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 2037
    })
    test: Dataset({
        features: ['sentence', 'label'],
        num_rows: 227
    })
    validation: Dataset({
        features: ['sentence', 'label'],
        num_rows: 227
    })
})

In [43]:
del(dataset["test"])

In [44]:
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 2037
    })
    validation: Dataset({
        features: ['sentence', 'label'],
        num_rows: 227
    })
})

In [51]:
classes = dataset["train"].features["label"].names
classes

['negative', 'neutral', 'positive']

In [52]:
dataset["train"][0]

{'sentence': 'The company is studying the feasibility of focusing most of its processed meat production in the Vantaa facilities and the processing of fresh meat in the Forssa facilities .',
 'label': 1}

In [56]:
dataset = dataset.map(
    lambda x: {"text_label": [classes[label] for label in x["label"]]},
    batched=True,
    num_proc=1
)

                                                    

In [57]:
dataset["train"][0]

{'sentence': 'The company is studying the feasibility of focusing most of its processed meat production in the Vantaa facilities and the processing of fresh meat in the Forssa facilities .',
 'label': 1,
 'text_label': 'neutral'}

In [58]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

Downloading config.json: 100%|██████████| 1.21k/1.21k [00:00<00:00, 181kB/s]
Downloading spiece.model: 100%|██████████| 792k/792k [00:00<00:00, 1.14MB/s]
Downloading tokenizer.json: 100%|██████████| 1.39M/1.39M [00:00<00:00, 2.66MB/s]
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [65]:
def preprocess_function(examples):
    inputs = examples[text_column]
    targets = examples[label_column]
    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    labels = tokenizer(targets, max_length=2, padding="max_length", truncation=True, return_tensors="pt")
    labels = labels["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs

In [66]:
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset"
)

                                                                                           

In [68]:
processed_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2037
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 227
    })
})

In [69]:
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"]

train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)

## * Training

In [72]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)

Downloading pytorch_model.bin: 100%|██████████| 2.95G/2.95G [05:51<00:00, 8.40MB/s]
Downloading generation_config.json: 100%|██████████| 147/147 [00:00<00:00, 13.2kB/s]


In [75]:
model.print_trainable_parameters()

trainable params: 983040 || all params: 738651136 || trainable%: 0.13308583065659835


In [77]:
inputs = tokenizer(
    "The Lithuanian beer market made up 14.41 million liters in January , a rise of 0.8 percent from the year-earlier figure , the Lithuanian Brewers ' Association reporting citing the results from its members .",
    return_tensors="pt",
)

In [78]:
model.to(device)

with torch.no_grad():
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

['the Association of Lithuanian Brewers ']


In [76]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs)
)

In [80]:
model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    
    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )
    
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = eval_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"epoch: {epoch}, train_ppl: {train_ppl:.4f}, train_epoch_loss: {train_epoch_loss:.4f}, eval_ppl: {eval_ppl:.4f}, eval_epoch_loss: {eval_epoch_loss:.4f}")

100%|██████████| 255/255 [00:44<00:00,  5.76it/s]
100%|██████████| 29/29 [00:02<00:00, 11.40it/s]


epoch: 0, train_ppl: 1.0080, train_epoch_loss: 0.0080, eval_ppl: 1.0730, eval_epoch_loss: 0.0705


100%|██████████| 255/255 [00:44<00:00,  5.76it/s]
100%|██████████| 29/29 [00:02<00:00, 11.30it/s]


epoch: 1, train_ppl: 1.0077, train_epoch_loss: 0.0076, eval_ppl: 1.0695, eval_epoch_loss: 0.0672


100%|██████████| 255/255 [00:44<00:00,  5.77it/s]
100%|██████████| 29/29 [00:02<00:00, 11.34it/s]


epoch: 2, train_ppl: 1.0044, train_epoch_loss: 0.0044, eval_ppl: 1.0391, eval_epoch_loss: 0.0384


100%|██████████| 255/255 [00:44<00:00,  5.69it/s]
100%|██████████| 29/29 [00:02<00:00, 11.31it/s]


epoch: 3, train_ppl: 1.0047, train_epoch_loss: 0.0047, eval_ppl: 1.0420, eval_epoch_loss: 0.0411


100%|██████████| 255/255 [00:44<00:00,  5.73it/s]
100%|██████████| 29/29 [00:02<00:00, 11.26it/s]

epoch: 4, train_ppl: 1.0053, train_epoch_loss: 0.0053, eval_ppl: 1.0473, eval_epoch_loss: 0.0462





In [82]:
correct = 0
total = 0
for pred, true in zip(eval_preds, dataset["validation"]["text_label"]):
    if pred.strip() == true.strip():
        correct += 1
    total += 1
accuracy = correct / total * 100
print(f"accuracy: {accuracy} % on the evaluation dataset")
print(f"{eval_preds[:10]}")
print(f"{dataset['validation']['text_label'][:10]}")

accuracy: 96.0352422907489 % on the evaluation dataset
['neutral', 'neutral', 'negative', 'positive', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'positive']
['neutral', 'neutral', 'positive', 'positive', 'neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'positive']


In [83]:
with torch.no_grad():
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

['positive']


# P-tuning

## * Config

In [102]:
model_name_or_path = "bigscience/bloomz-560m"
task = "mrpc"
num_epochs = 20
lr = 1e-3
batch_size = 32

In [94]:
# https://huggingface.co/datasets/nyu-mll/glue
dataset = load_dataset("SetFit/mrpc")

Downloading readme: 100%|██████████| 316/316 [00:00<00:00, 856kB/s]


Downloading and preparing dataset json/SetFit--mrpc to /home/irteam/.cache/huggingface/datasets/SetFit___json/SetFit--mrpc-cf983d02a5b947c7/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]
Downloading data:   0%|          | 0.00/1.14M [00:00<?, ?B/s][A
Downloading data:   6%|▌         | 66.6k/1.14M [00:00<00:02, 406kB/s][A
Downloading data:  20%|█▉        | 224k/1.14M [00:00<00:01, 719kB/s] [A
Downloading data: 100%|██████████| 1.14M/1.14M [00:00<00:00, 2.17MB/s][A
Downloading data files:  33%|███▎      | 1/3 [00:01<00:02,  1.03s/it]
Downloading data:   0%|          | 0.00/127k [00:00<?, ?B/s][A
Downloading data: 100%|██████████| 127k/127k [00:00<00:00, 725kB/s] [A
Downloading data files:  67%|██████▋   | 2/3 [00:01<00:00,  1.26it/s]
Downloading data:   0%|          | 0.00/533k [00:00<?, ?B/s][A
Downloading data:  10%|▉         | 51.2k/533k [00:00<00:01, 302kB/s][A
Downloading data:  31%|███       | 165k/533k [00:00<00:00, 508kB/s] [A
Downloading data: 100%|██████████| 533k/533k [00:00<00:00, 1.01MB/s][A
Downloading data files: 100%|██████████| 3/3 [00:02<00:00,  1.14it/s]
Extracting data files: 100%|█

Dataset json downloaded and prepared to /home/irteam/.cache/huggingface/datasets/SetFit___json/SetFit--mrpc-cf983d02a5b947c7/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 541.57it/s]


## * Dataset

In [95]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text1', 'text2', 'label', 'idx', 'label_text'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['text1', 'text2', 'label', 'idx', 'label_text'],
        num_rows: 408
    })
    test: Dataset({
        features: ['text1', 'text2', 'label', 'idx', 'label_text'],
        num_rows: 1725
    })
})

In [96]:
dataset["train"][0]

{'text1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .',
 'text2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .',
 'label': 1,
 'idx': 0,
 'label_text': 'equivalent'}

In [124]:
# accuracy =  evaluate.load("accuracy")
# accuracy.compute(predictions=[0, 1, 1, 0], references=[0, 1, 0, 1])

Downloading builder script: 100%|██████████| 4.20k/4.20k [00:00<00:00, 4.47MB/s]


{'accuracy': 0.5}

In [99]:
metric = evaluate.load("glue", task)

Downloading builder script: 100%|██████████| 5.75k/5.75k [00:00<00:00, 6.50MB/s]


In [101]:
import numpy as np

# Return dict or None:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [103]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_size="left")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [112]:
def tokenize_function(examples):
    outputs = tokenizer(examples["text1"], examples["text2"], truncation=True, max_length=None)
    return outputs

In [113]:
tokenize_function(dataset["train"][0])

{'input_ids': [15144, 350, 3786, 41939, 3868, 44163, 630, 43944, 1683, 9487, 567, 368, 53134, 567, 630, 461, 188210, 4396, 656, 386, 3868, 27602, 503, 76110, 35414, 427, 6371, 661, 3804, 567, 368, 53134, 567, 630, 2883, 350, 3786, 41939, 3868, 44163, 461, 188210, 4396, 656, 386, 3868, 27602, 503], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [114]:
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text1", "text2", "idx"]
)

                                                                  

In [116]:
print(next(iter(tokenized_datasets["train"])))

{'label': 1, 'label_text': 'equivalent', 'input_ids': [15144, 350, 3786, 41939, 3868, 44163, 630, 43944, 1683, 9487, 567, 368, 53134, 567, 630, 461, 188210, 4396, 656, 386, 3868, 27602, 503, 76110, 35414, 427, 6371, 661, 3804, 567, 368, 53134, 567, 630, 2883, 350, 3786, 41939, 3868, 44163, 461, 188210, 4396, 656, 386, 3868, 27602, 503], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [119]:
# Pad the examples in the batches to the longest sequence in the batch:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding="longest")

## * Train

In [153]:
# P-tuning uses a prompt encoder to optimize the prompt parameters:
#   num_virtual_tokens is the number of virtual tokens to use, or in other words, the prompt
peft_config = PromptEncoderConfig(
    task_type=TaskType.SEQ_CLS,
    num_virtual_tokens=20,
    encoder_hidden_size=128
)

In [154]:
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config).to(device)
model.print_trainable_parameters()

Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloomz-560m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."


trainable params: 304384 || all params: 559518976 || trainable%: 0.054401014631539506


In [155]:
# Before P-tuning
classes = ["not equivalent", "equivalent"]

sentence1 = "Coast redwood trees are the tallest trees on the planet and can grow over 300 feet tall."
sentence2 = "The coast redwood trees, which can attain a height of over 300 feet, are the tallest trees on earth."

inputs = tokenizer(sentence1, sentence2, truncation=True, padding="longest", return_tensors="pt")

with torch.no_grad():
    inputs = {k: v.to(device) for k,v in inputs.items()}
    outputs = model(**inputs).logits
    print(outputs)
    
paraphrased_text = torch.softmax(outputs, dim=1).tolist()[0]
for i in range(len(classes)):
    print(f"{classes[i]}: {int(round(paraphrased_text[i] * 100))}%")

BloomForSequenceClassification will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`


tensor([[42.1092, -5.3616]], device='cuda:3')
not equivalent: 100%
equivalent: 0%


In [156]:
training_args = TrainingArguments(
    output_dir=root_path + "/bloomz-peft-p-tuning",
    learning_rate=1e-3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True
)

In [157]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

In [148]:
model = model.to(device)

In [151]:
# After P-tuing :
with torch.no_grad():
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model(**inputs).logits
    print(outputs)
    
paraphrased_text = torch.softmax(outputs, dim=1).tolist()[0]
for i in range(len(classes)):
    print(f"{classes[i]}: {int(round(paraphrased_text[i] * 100))}%")

BloomForSequenceClassification will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`


tensor([[-13.7025, -11.5503]], device='cuda:3')
not equivalent: 10%
equivalent: 90%
