<a href="https://colab.research.google.com/github/goddoe/hacking-llms-for-low-res-settings/blob/main/p_tuning_qa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import argparse
import os

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    TaskType,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
)

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from tqdm import tqdm

In [None]:
# Run configuration read by the cells below.
batch_size = 1  # examples per DataLoader batch
model_name_or_path = "EleutherAI/polyglot-ko-1.3b"  # base causal LM checkpoint
# The adapter configured in this notebook is a PromptEncoderConfig (P-tuning),
# so record the matching PEFT type instead of PREFIX_TUNING.
peft_type = PeftType.P_TUNING
device = "cuda"
num_epochs = 5

dataset_name = "heegyu/korquad-chat-v1"
max_length = 2048  # intended per-example token limit

In [None]:
# Learning rate handed to the optimizer.
lr = 1e-2

# Prompt-encoder adapter settings for a causal language model:
# 20 virtual tokens produced by an encoder with hidden size 128.
peft_config = PromptEncoderConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    encoder_hidden_size=128,
)

In [None]:
# Load the tokenizer published with the base checkpoint chosen above.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

In [None]:
def tokenize_function(examples):
    """Tokenize the ``text`` field of a batch of dataset examples.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw strings, as
            supplied by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for those strings, truncated to ``max_length``
        tokens.
    """
    # Honour the notebook-level ``max_length`` setting. Passing ``None`` here
    # left that setting unused and deferred to the tokenizer's own limit.
    return tokenizer(examples["text"], truncation=True, max_length=max_length)

def collate_fn(examples):
    """Pad tokenized examples to the longest one and return PyTorch tensors."""
    padded_batch = tokenizer.pad(examples, padding="longest", return_tensors="pt")
    return padded_batch

# Load the raw corpus.
dataset = load_dataset(dataset_name)

# Tokenize every split in batches and drop the original string columns.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["source", "text"],
)

# Keep 80% of the train split for training; the remainder is used for evaluation.
td = tokenized_dataset['train'].train_test_split(train_size=0.8)


def _make_loader(split, shuffle=False):
    """Wrap a tokenized split in a DataLoader that pads batches with collate_fn."""
    return DataLoader(split,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      collate_fn=collate_fn)


train_dataloader = _make_loader(td['train'], shuffle=True)
eval_dataloader = _make_loader(td['test'])

In [None]:
# Load the base causal LM, then wrap it with the adapter described by peft_config.
pretrained_lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(pretrained_lm, peft_config)

# Report how many parameters are trainable after wrapping.
model.print_trainable_parameters()

# Trailing bare expression so the notebook displays the wrapped model.
model

In [None]:
def get_grouped_params(model, no_decay=("bias", "LayerNorm.weight"), weight_decay=0.1):
    """Split a model's parameters into decayed and non-decayed optimizer groups.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
        no_decay: Name fragments; a parameter whose name contains any of them
            is placed in the group with zero weight decay. The default is an
            immutable tuple so it cannot be altered between calls.
        weight_decay: Decay coefficient assigned to every other parameter.

    Returns:
        A two-element list of parameter-group dicts: the decayed parameters
        first, then the exempt ones with ``weight_decay`` set to ``0.0``.
    """
    params_with_wd, params_without_wd = [], []
    for name, param in model.named_parameters():
        # Substring match: "bias" also catches names such as "LayerNorm.bias".
        if any(fragment in name for fragment in no_decay):
            params_without_wd.append(param)
        else:
            params_with_wd.append(param)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]
    

def evaluate():
    """Compute the mean loss and perplexity of ``model`` over ``eval_dataloader``.

    Returns:
        A ``(loss, perplexity)`` tuple of Python floats. ``loss`` is the mean
        of the per-batch losses and ``perplexity`` is ``exp(loss)``, or
        ``inf`` when the exponential overflows.

    Raises:
        ValueError: If ``eval_dataloader`` yields no batches.
    """
    import math

    model.eval()
    # Send each batch to wherever the model's parameters currently live.
    model_device = next(model.parameters()).device
    batch_losses = []
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch["input_ids"].to(model_device)
            outputs = model(input_ids, labels=input_ids)
            # This notebook never creates an Accelerator, so the loss is
            # recorded directly rather than gathered across processes.
            batch_losses.append(outputs.loss.item())

    if not batch_losses:
        raise ValueError("eval_dataloader produced no batches")

    loss = sum(batch_losses) / len(batch_losses)
    try:
        # math.exp raises OverflowError on huge inputs; torch.exp does not.
        perplexity = math.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss, perplexity

In [None]:
# AdamW over the decayed / non-decayed parameter groups.
param_groups = get_grouped_params(model)
optimizer = AdamW(param_groups, lr=lr)

# Linear schedule covering every optimizer step of the run, without warm-up.
total_training_steps = len(train_dataloader) * num_epochs
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,  # alternative: 0.06 * total_training_steps
    num_training_steps=total_training_steps,
)

In [None]:
# Train for num_epochs, evaluating after each epoch and keeping the checkpoint
# with the lowest validation perplexity seen so far.
model.to(device)

best_model_path = "./outputs/best_p_tuning_model"
min_valid_ppl = 9999999.

for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_dataloader):
        batch = batch.to(device)
        # Causal-LM objective: the inputs double as the labels.
        outputs = model(batch['input_ids'], labels=batch['input_ids'])
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    eval_loss, perplexity = evaluate()
    eval_metric = {"loss/eval": eval_loss, "perplexity": perplexity}

    print(f"epoch {epoch}:", eval_metric)
    if eval_metric['perplexity'] <= min_valid_ppl:
        # Record the new best score; otherwise every later epoch would pass
        # this check and overwrite the saved checkpoint with a worse model.
        min_valid_ppl = eval_metric['perplexity']
        model.save_pretrained(best_model_path)
        tokenizer.save_pretrained(best_model_path)

# Load the Saved Adapter and Run Inference

In [None]:
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

In [None]:
# The saved adapter config records which base checkpoint the adapter extends.
config = PeftConfig.from_pretrained(best_model_path)
base_checkpoint = config.base_model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
inference_model = AutoModelForCausalLM.from_pretrained(base_checkpoint)

# Attach the saved PEFT adapter to the base model and switch to eval mode.
inference_model = PeftModel.from_pretrained(inference_model, best_model_path)
inference_model.eval()

In [None]:
# Text-generation pipeline around the adapted model, run on device index 0.
generator = pipeline(
    "text-generation",
    model=inference_model,
    tokenizer=tokenizer,
    device=0,
)

In [None]:
# Prompt prefix: a <sys> context passage followed by an opening <usr> tag.
# The user question and the <bot> tag are appended where the generator is called.
prompt = "<sys>1839년 바그너는 괴테의 파우스트을 처음 읽고 그 내용에 마음이 끌려 이를 소재로 해서 하나의 교향곡을 쓰려는 뜻을 갖는다.\n<usr>"

In [None]:
# Append the user question and the <bot> tag to the prompt, then ask the
# pipeline for up to 128 new tokens, returning only the generated continuation.
full_prompt = f"{prompt} 바그너가 1839년에 파우스트를 소재로 한 교향곡 작곡을 시작했다는데, 왜 이 소재에 마음이 끌렸을까?\n<bot>"
bot_text = generator(full_prompt, max_new_tokens=128, return_full_text=False)

In [None]:
# Show the raw pipeline output.
print(bot_text)

In [None]:
# Keep only the text generated before any new <usr> turn begins.
first_generation = bot_text[0]['generated_text']
bot_reply = first_generation.partition("<usr>")[0]
print(bot_reply)

In [None]:
# Display the raw text of the first training example for reference.
dataset['train'][0]['text']