## Fine Tuning
I had to switch from my lab VM to Google Colab for this because I needed a GPU.

In [None]:
!pip install transformers datasets evaluate transformers[torch] py7zr peft wandb

### Full fine-tuning for summarization

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import wandb

In [None]:
## Checkpoint name shared by the model and its tokenizer
BASE_MODEL = "facebook/bart-large-cnn"

## Both objects are loaded from the same checkpoint so they stay in sync
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

#### Load dataset

In [None]:
## Load the dialogue/summary corpus
dataset = load_dataset("knkarthick/samsum")

## Clean dataset: drop the id column and any row without a dialogue
dataset = dataset.remove_columns(['id']).filter(
    lambda row: row['dialogue'] is not None
)

## Shrink dataset for training
PERCENT = 0.3

## Each split is shuffled with its own fixed seed, then cut to the first PERCENT of rows
for split_name, split_seed in (('train', 42), ('test', 37), ('validation', 4)):
    rows_to_keep = int(len(dataset[split_name]) * PERCENT)
    shuffled_split = dataset[split_name].shuffle(seed=split_seed)
    dataset[split_name] = shuffled_split.select(range(rows_to_keep))

dataset

#### Test summarization of base model

In [None]:
## First row of the shuffled, subsampled test split; reused as the spot-check example for each model
SAMPLE_DATA = dataset['test'][0]

def generate_summary(input, model, tokenizer, isPeft=False):
    """Summarize one example with ``model`` and print it next to the reference.

    Args:
        input: Mapping with a ``'dialogue'`` entry (the text to summarize) and
            a ``'summary'`` entry (the reference summary).
        model: Object exposing a ``device`` attribute and a ``generate``
            method that accepts keyword arguments.
        tokenizer: Callable that encodes the prompt and exposes ``decode``.
        isPeft: Unused; kept only so existing callers keep working.

    Returns:
        None. The dialogue, the generated summary and the reference summary
        are printed to stdout.
    """
    sample = input['dialogue']
    label = input['summary']

    # Use the same prompt layout that tokenize_inputs builds for training.
    # An indented triple-quoted literal would leak a leading newline and the
    # source indentation into the text the model sees.
    prompt = f"Summarize the following conversation.\n\n{sample}\n\nSummary: "

    # truncation=True keeps an over-long dialogue within the tokenizer's
    # configured maximum length instead of overflowing the model.
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    # Pass the attention mask explicitly so generation never has to guess it.
    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=200,
    )
    output = tokenizer.decode(generated[0], skip_special_tokens=True)

    print("Sample")
    print(sample)
    print("----------------------------------------")
    print("Model Generated Summary")
    print(output)
    print("Correct Summary")
    print(label)

In [None]:
## Baseline: summarize the sample with the pretrained model, before any training in this notebook
generate_summary(SAMPLE_DATA, base_model, base_tokenizer)

#### Prepare the dataset

In [None]:
def tokenize_inputs(example):
    """Tokenize a batch of dialogues and summaries for seq2seq training.

    Relies on the module-level ``tokenizer`` and is meant to be used with
    ``dataset.map(..., batched=True)``.

    Args:
        example: Batch mapping with ``"dialogue"`` and ``"summary"`` entries,
            each a list of strings.

    Returns:
        The tokenized prompts, padded/truncated to 200 tokens, with an added
        ``"labels"`` entry holding the summary token ids in which every
        padding position is replaced by -100.
    """
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '

    # Tokenize inputs
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    model_inputs = tokenizer(prompt, padding="max_length", max_length=200, truncation=True)

    # Tokenize labels
    labels = tokenizer(
        example["summary"],
        padding="max_length",
        max_length=200,
        truncation=True,
        return_attention_mask=True,
    )

    # Mask padding by position (attention mask) rather than by token id.
    # When the pad token shares its id with EOS, an id comparison would also
    # hide the real end-of-sequence token from the loss.
    model_inputs["labels"] = [
        [(token_id if is_real else -100) for token_id, is_real in zip(label_seq, mask_seq)]
        for label_seq, mask_seq in zip(labels["input_ids"], labels["attention_mask"])
    ]
    return model_inputs

## tokenize_inputs reads this module-level name
tokenizer = base_tokenizer

## Only fall back to EOS when the tokenizer has no pad token of its own.
## Overwriting an existing pad token with EOS makes padding indistinguishable
## from the real end-of-sequence token, and the change would also leak into
## base_tokenizer because both names refer to the same object.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

## Tokenize every split in batches, then drop the raw text columns
tokenized_dataset = dataset.map(tokenize_inputs, batched=True)
tokenized_dataset = tokenized_dataset.remove_columns(['dialogue', 'summary'])
tokenized_dataset

#### Start training

In [None]:
from huggingface_hub import notebook_login
## Log in to the Hugging Face Hub from the notebook (needed for the push_to_hub calls later on)
notebook_login()

## Hub account and repository name; combined as HF_USER + "/" + FT_MODEL for the fine-tuned model
HF_USER = "shayharding"
FT_MODEL = "bart-samsum-finetuned"

In [None]:
from transformers import TrainingArguments, Trainer

## Settings for the full fine-tuning run: one epoch, evaluated once per
## epoch, with metrics logged every 10 steps
full_ft_settings = {
    "output_dir": f"./{FT_MODEL}",
    "hub_model_id": f"{HF_USER}/{FT_MODEL}",
    "learning_rate": 5e-6,
    "num_train_epochs": 1,
    "weight_decay": 0.01,
    "auto_find_batch_size": True,
    "eval_strategy": "epoch",
    "logging_steps": 10,
}
training_args = TrainingArguments(**full_ft_settings)

## Train the base model on the tokenized train split, evaluate on validation
trainer = Trainer(
    args=training_args,
    model=base_model,
    processing_class=base_tokenizer,
    eval_dataset=tokenized_dataset["validation"],
    train_dataset=tokenized_dataset["train"],
)

## Start a Weights & Biases run in a project named after the model
wandb.init(project=FT_MODEL)

In [None]:
## Run the full fine-tuning loop with the trainer configured above
trainer.train()

In [None]:
## Publish the fine-tuned model to the Hub repo set via hub_model_id; it is loaded back from there in the next section
trainer.push_to_hub()

#### Test the full fine-tuned model

In [None]:
## Load the fine-tuned checkpoint and its tokenizer back from the Hub
ft_repo_id = f"{HF_USER}/{FT_MODEL}"
ft_model = AutoModelForSeq2SeqLM.from_pretrained(ft_repo_id)
ft_tokenizer = AutoTokenizer.from_pretrained(ft_repo_id)

## Same sample as the baseline, so the two outputs can be compared directly
generate_summary(SAMPLE_DATA, ft_model, ft_tokenizer)

### Create PEFT model using LoRA

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

## Repository name used for the LoRA adapter
PEFT_MODEL = "bart-samsum-peft"

## LoRA settings: rank 16, alpha 16, 10% dropout, no bias terms, seq2seq task
lora_settings = dict(
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)
lora_config = LoraConfig(**lora_settings)

## Wrap the fine-tuned model with the LoRA configuration
peft_model = get_peft_model(ft_model, lora_config)

In [None]:
## Settings for the LoRA run: two epochs, evaluated once per epoch, with
## metrics logged every 10 steps
peft_settings = {
    "output_dir": f"./{PEFT_MODEL}",
    "hub_model_id": f"{HF_USER}/{PEFT_MODEL}",
    "learning_rate": 5e-6,
    "num_train_epochs": 2,
    "weight_decay": 0.01,
    "auto_find_batch_size": True,
    "eval_strategy": "epoch",
    "logging_steps": 10,
}
training_args = TrainingArguments(**peft_settings)

## Same tokenized splits as the full fine-tuning run, now with the PEFT model
trainer = Trainer(
    args=training_args,
    model=peft_model,
    eval_dataset=tokenized_dataset["validation"],
    train_dataset=tokenized_dataset["train"],
)

## Start a Weights & Biases run in a project named after the adapter
wandb.init(project=PEFT_MODEL)

In [None]:
## Report how many of the wrapped model's parameters are trainable
peft_model.print_trainable_parameters()

In [None]:
## Run the LoRA training loop with the trainer configured above
trainer.train()

In [None]:
## Publish the result of the PEFT run to the Hub repo set via hub_model_id
trainer.push_to_hub()

#### Test the PEFT model

In [None]:
from peft import PeftModel

## get_peft_model() modified ft_model in place when the adapter was created,
## so ft_model already carries LoRA layers. Reload an untouched copy of the
## fine-tuned checkpoint and attach the published adapter to that instead,
## so the evaluation reflects exactly what is stored on the Hub.
clean_ft_model = AutoModelForSeq2SeqLM.from_pretrained(HF_USER + "/" + FT_MODEL)

## is_trainable=False loads the adapter for inference only
loaded_peft_model = PeftModel.from_pretrained(clean_ft_model, HF_USER + "/" + PEFT_MODEL, is_trainable=False)

In [None]:
## Same sample as before, now summarized by the model with the loaded LoRA adapter
generate_summary(SAMPLE_DATA, loaded_peft_model, ft_tokenizer)