### Imports

In [None]:
import os
from dotenv import load_dotenv

# Load variables from a .env file (if one is found) into the process
# environment, then read the Hugging Face key; it is None when unset.
load_dotenv()
huggingface_login_key = os.environ.get("HUGGINGFACE_LOGIN_KEY")

In [None]:
# Enable IPython's autoreload extension in mode 2, so edited project modules
# are re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
!pip install -q pyarrow==12.0.0
!pip install -q -U accelerate
!pip install -q transformers
!pip install -q tdqm
!pip install -q torch
!pip install -q -U bitsandbytes
!pip install -q -U evaluate
!pip install -q cohere
!pip install -q huggingface_hub
!pip install -q trl

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the key read from the
# environment in the first cell. When that key is None, login() falls back
# to its interactive prompt, so the previous behaviour is preserved.
login(token=huggingface_login_key)

In [None]:
import torch

# Builds a device object for CUDA device 0. The result is not assigned, so
# this cell only displays it; later cells move tensors with .to('cuda').
torch.device("cuda:0")

In [None]:
# Release cached GPU memory that PyTorch's allocator is holding but not using.
torch.cuda.empty_cache()

In [None]:
# torch.cuda.set_per_process_memory_fraction(0.8, 0)

In [None]:
# import os
# import sys
# from google.colab import userdata

# # ADD A PAT https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
# os.environ['GITHUB_TOKEN'] = userdata.get('GITHUB_TOKEN')  # put the token in colab secret keys
# os.environ['REPOSITORY'] =  "github.com/jth500/maet-pln.git"
# !git clone https://${{GITHUB_TOKEN}}@${{REPOSITORY}}
# !mv maet-pln maet_pln # rename to remove the invalid dash
# sys.path.append("maet_pln/src")


In [None]:
from data_handler import BARTDatasetHandler
from model_builder import BARTModelBuilder
from tokenization import TokenizationHandler

### Process Data

In [None]:
# Create the project's tokenization helper for the "facebook/bart-large"
# checkpoint and display the model id stored on it.
bart_tk_handler = TokenizationHandler("facebook/bart-large")
bart_tk_handler.model_id

In [None]:
# Ask the handler to create its tokenizer, then keep a direct reference to it
# under `tokenizer` for the cells below.
bart_tk_handler.create_tokenizer()
tokenizer = bart_tk_handler.tokenizer

In [None]:
# Display the tokenizer's end-of-sequence token.
tokenizer.eos_token

In [None]:
# Dataset setup: hand the XSum dataset id and the tokenizer to the project's
# dataset handler, then unpack the three splits that process_data returns.
dataset_name = "EdinburghNLP/xsum"
data_handler = BARTDatasetHandler(dataset_name, tokenizer)
sft_train_data, rlaif_train_data, val_data = data_handler.process_data(
    input_label="document",
    target_label="summary",
)

In [None]:
# Display the SFT training split.
sft_train_data

In [None]:
# Display the RLAIF training split.
rlaif_train_data

In [None]:
# Display the validation split.
val_data

In [None]:
# Spot check: decode the label ids of one randomly chosen SFT training example.
import random

idx = random.randrange(len(sft_train_data))
tokenizer.decode(sft_train_data[idx]["labels"])

### SFT

In [None]:
from sft import SFT

In [None]:
# SET UP BASE MODEL
# Build the BART-large base model through the project's builder (with
# rlaif=False) and keep a reference to the model object it exposes.
base_model_id = "facebook/bart-large"
model_builder = BARTModelBuilder(model_id=base_model_id, tokenizer=tokenizer, rlaif=False)
base_model = model_builder.base_model

In [None]:
from transformers import GenerationConfig

# Sanity check before fine-tuning: sample a generation from the base model for
# the first validation example. unsqueeze(0) adds the batch dimension.
input_ids = torch.tensor(val_data["input_ids"][0]).unsqueeze(0).to('cuda')
# attention_mask = torch.tensor(val_data["attention_mask"][0]).unsqueeze(0).to('cuda')

# Sampling settings match the generation cell in the Inference section.
# The previous repetition_penalty=0.8 was removed: values below 1.0 reward
# repeated tokens (1.0 is neutral, values above 1.0 penalise repetition), so it
# degraded output and made this baseline incomparable with the later cell.
generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.8,
    top_p=0.3,
    num_beams=1,
    max_new_tokens=50,
)
with torch.no_grad():  # inference only; gradients are not needed
    generation_output = base_model.generate(
        input_ids=input_ids,
        # attention_mask=attention_mask,
        pad_token_id=base_model.config.pad_token_id,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
    )
# Decode the single generated sequence back to text, dropping special tokens.
s = generation_output.sequences[0]
output = tokenizer.decode(s, skip_special_tokens=True)

In [None]:
# Display the text decoded from the base model's generation.
output

In [None]:
# Configure supervised fine-tuning of the base model on the SFT split.
# train_epochs=0.2 asks for a fraction of one epoch (presumably a quick run;
# confirm how sft.py interprets it).
# NOTE(review): the variable is spelled `stf_trainer`, not `sft_trainer`; the
# cells below use the same spelling, so rename them together if at all.
stf_trainer = SFT(
    save_dir="sft-bart-xsum-2303",
    tokenizer=tokenizer,
    base_model=model_builder.base_model,
    train_dataset=sft_train_data,
    train_epochs=0.2,
    )

In [None]:
# Free cached, unused GPU memory before training starts. torch is re-imported
# here although an earlier cell already imports it.
import torch
torch.cuda.empty_cache()

In [None]:
# Run supervised fine-tuning with the trainer configured above.
stf_trainer.train_model()

In [None]:
# Call the trainer's push_model_to_hub() (presumably uploads the fine-tuned
# model to the Hugging Face Hub; confirm in sft.py).
stf_trainer.push_model_to_hub()

In [None]:
# Release cached, unused GPU memory before moving on to the RLAIF stage.
torch.cuda.empty_cache()

### RLAIF

In [None]:
from rlaif import RLAIF

# Hub repository ids passed to the RLAIF trainer together with the tokenizer
# and the RLAIF training split (presumably: the SFT checkpoint to start from
# and the target name for the RLAIF model; confirm against rlaif.py).
# NOTE: `sft_model` and `rlaif_model` are plain strings here; the Inference
# section later rebinds both names to loaded model objects.
sft_model = "ijwatson98/sft-bart-xsum-2303"
rlaif_model = "ijwatson98/rlaif-bart-xsum-2303"
rlaif_trainer = RLAIF(sft_model, tokenizer, rlaif_model, rlaif_train_data)

In [None]:
import warnings

# Silence known-noisy warnings. Each entry is installed as the `message`
# pattern of an "ignore" filter, i.e. a regular expression that must match
# the start of the warning text. Filters are registered in the order listed.
for noisy_message in (
    "A <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'> model is loaded from",
    "The attention mask and the pad token id were not set.",
    "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.",
    "A decoder-only architecture is being used, but right-padding was detected!",
    "Your text contains a trailing whitespace, which has been trimmed to ensure high quality generations.",
):
    warnings.filterwarnings("ignore", message=noisy_message)

In [None]:
# Run RLAIF training with max_ppo_steps=10 and keep the three values the call
# returns (named here kl, returns and advantages).
kl, returns, advantages = rlaif_trainer.train_model(max_ppo_steps=10)

In [None]:
# Call the RLAIF trainer's push_model_to_hub() (presumably uploads the tuned
# model to the Hugging Face Hub; confirm in rlaif.py).
rlaif_trainer.push_model_to_hub()

### Inference

In [None]:
# Load the SFT checkpoint by its Hub id through the project's model builder.
# `sft_model` is rebound here from the repo-id string used in the RLAIF
# section to the loaded model object.
sft_model_id = "ijwatson98/sft-bart-xsum-2303"
sft_model_builder = BARTModelBuilder(model_id=sft_model_id, tokenizer=tokenizer)
sft_model = sft_model_builder.base_model

In [None]:
# Load the RLAIF checkpoint by its Hub id through the project's model builder.
# `rlaif_model` is rebound here from the repo-id string used in the RLAIF
# section to the loaded model object.
rlaif_model_id = "ijwatson98/rlaif-bart-xsum-2303"
rlaif_model_builder = BARTModelBuilder(model_id=rlaif_model_id, tokenizer=tokenizer)
rlaif_model = rlaif_model_builder.base_model

In [None]:
from inference import Inference

In [None]:
# Inference helper for the SFT model, given the tokenizer and validation split.
sft_inference = Inference(sft_model_builder.base_model, tokenizer, val_data)

In [None]:
# Inference helper for the RLAIF model, given the same tokenizer and validation split.
rlaif_inference = Inference(rlaif_model_builder.base_model, tokenizer, val_data)

In [None]:
from transformers import GenerationConfig

# Sample a generation from the SFT model for validation example 2.
# unsqueeze(0) adds the batch dimension.
input_ids = torch.tensor(val_data["input_ids"][2]).unsqueeze(0).to('cuda')
# attention_mask = torch.tensor(val_data["attention_mask"][2]).unsqueeze(0).to('cuda')
generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.8,
    top_p=0.3,
    num_beams=1,
    max_new_tokens=50,
)
with torch.no_grad():  # inference only; gradients are not needed
    generation_output = sft_model.generate(
        input_ids=input_ids,
        # attention_mask=attention_mask,
        # Read the pad token id from the model that is generating. This used
        # to reference `base_model`, which is undefined when the SFT section
        # is skipped and is the wrong model's config otherwise.
        pad_token_id=sft_model.config.pad_token_id,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
    )
# Decode the single generated sequence back to text, dropping special tokens.
s = generation_output.sequences[0]
output = tokenizer.decode(s, skip_special_tokens=True)

In [None]:
# Display the "input" field of validation example 2, the same index that was
# fed to the SFT model in the generation cell.
val_data["input"][2]

In [None]:
# Display the text decoded from the SFT model's generation.
output

In [None]:
# Run sample_inference with sample_size=10 on the SFT helper and unpack its
# three return values as posts, model summaries and reference summaries.
posts, sft_model_summaries, true_summaries = sft_inference.sample_inference(sample_size=10)

In [None]:
# Same call on the RLAIF helper; only the model summaries are kept.
# NOTE(review): whether both helpers sample the same examples depends on
# sample_inference; confirm before comparing outputs index by index.
_, rlaif_model_summaries, _ = rlaif_inference.sample_inference(sample_size=10)

In [None]:
# Display entry 5 of the reference summaries.
true_summaries[5]

In [None]:
# Display entry 5 of the SFT model's summaries.
sft_model_summaries[5]

In [None]:
# Display entry 5 of the RLAIF model's summaries.
rlaif_model_summaries[5]