In [51]:
from datasets import load_from_disk
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM,
    set_seed
)
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from src.readers.collators import SimpleCollator
from torch.utils.data import DataLoader
from torch.optim import AdamW 
import lightning as pl


class PeftCALMT5(pl.LightningModule):

    def __init__(self, model_alias: str, tokenizer_alias: str):

        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_alias)

        self.peft_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM, 
            inference_mode=False, 
            target_modules=["q", "k", "v"],
            r=8, 
            lora_alpha=32, 
            lora_dropout=0.5
        )

        model = AutoModelForSeq2SeqLM.from_pretrained(model_alias)
        self.model = get_peft_model(model, self.peft_config)
        

    def training_step(self, batch, batch_idx): 
        outputs = self.model.forward(**batch, return_dict=True)
        loss = outputs["loss"]  
        
        self.log("train_loss", loss,  prog_bar=True, on_step=True, on_epoch=True)     
        return loss
    
    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss = outputs["loss"]  
        
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True) 
        
    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=5e-4)
        return optimizer

tokenizer_alias = "google/flan-t5-base"
#model_alias = "google/flan-t5-base"
#model_alias = "models/pretraining/flan-t5-concept"

ablations_map = [
    "models/pretraining/flan-t5-synthetic-qa-ablation-causal",
    "models/pretraining/flan-t5-synthetic-qa-ablation-contributory",
    "models/pretraining/flan-t5-synthetic-qa-ablation-impact",
    "models/pretraining/flan-t5-synthetic-qa-ablation-intent",
    "models/pretraining/flan-t5-synthetic-qa-ablation-temporal"

]


model_alias = ablations_map[]


pl.seed_everything(42)
set_seed(42)
model = PeftCALMT5(model_alias, tokenizer_alias)
tokenizer = model.tokenizer

Seed set to 42


# Pretrain

In [None]:
from src.readers.collators import ConceptCollator, SimpleCollator, RandomMLMCollator

synthetic_collator = SimpleCollator(tokenizer, {"input_column": "input", "output_column": "output"})
mlm_collator = RandomMLMCollator(tokenizer)
concept_collator = ConceptCollator(tokenizer)
pretrain_ds = load_from_disk("data/pretraining/synthetic-qa")

pretrain_train = pretrain_ds["train"]
pretrain_train_dl = DataLoader(
    pretrain_train,
    batch_size=32, 
    pin_memory=True,
    shuffle=True,
    num_workers=4,
    collate_fn=synthetic_collator
)

pretrain_eval = pretrain_ds["test"]
pretrain_eval_dl = DataLoader(
    pretrain_train,
    batch_size=32, 
    pin_memory=True,
    shuffle=False,
    num_workers=4,
    collate_fn=synthetic_collator
)

In [None]:
trainer = pl.Trainer(
  max_epochs=3,
  devices=1, 
  accelerator="gpu",
)

trainer.fit(model, pretrain_train_dl, pretrain_eval_dl)

In [None]:
peft_model = model.model
merged = peft_model.merge_and_unload()
merged.save_pretrained("models/pretraining/flan-t5-concept-3-epoch")

# Task Finetuning

In [52]:
collator = SimpleCollator(tokenizer, {"input_column": "input", "output_column": "output_text"})
ds = load_from_disk("data/calm-bench/datasets/copa")
train_ds = ds.filter(lambda x: x['split'] == 'train')
train_dl = DataLoader(
    train_ds, 
    batch_size=32, 
    pin_memory=True, 
    shuffle=True,
    num_workers=4, 
    collate_fn=collator
)

val_ds = ds.filter(lambda x: x['split'] == 'val')
val_dl = DataLoader(
    val_ds, 
    batch_size=16, 
    pin_memory=True, 
    shuffle=False,
    num_workers=4, 
    collate_fn=collator
)


test_ds = ds.filter(lambda x: x['split'] == 'test')
test_dl = DataLoader(
    test_ds, 
    batch_size=32, 
    pin_memory=True, 
    shuffle=False,
    num_workers=4, 
    
    collate_fn=collator 
)

test_df = test_ds.to_pandas()

In [53]:
trainer = pl.Trainer(
  max_epochs=3,
  devices=1, 
  accelerator="gpu",
)

trainer.fit(model, train_dl, val_dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                  | Params
------------------------------------------------
0 | model | PeftModelForSeq2SeqLM | 248 M 
------------------------------------------------
1.3 M     Trainable params
247 M     Non-trainable params
248 M     Total params
995.620   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/ubuntu/anaconda3/envs/nlp-env/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (25) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [54]:
def get_choice(input_string):
    input_string = input_string.strip().lower()
    
    
    if input_string.startswith('a') or input_string.startswith('a)'):
        return 'a'
    elif input_string.startswith('b') or input_string.startswith('b)'):
        return 'b'
    elif input_string.startswith('c)') or input_string.startswith('c'):
        return 'c'
    elif input_string.startswith('d)') or input_string.startswith('d'):         
        return 'd'
    else:
        return 'a'


# Eval Wiqa

In [None]:
predictor = model.model
predictor.eval()
predictor = predictor.to("cuda")

In [None]:
import pandas as pd 
import tqdm.notebook as tqdm
from ast import literal_eval
import numpy as np 
import re

test = pd.read_csv("data/calm-bench/tasks/wiqa-updated-test2.csv")


rows = []
for i, row in tqdm.tqdm(test.iterrows()):

    try:
        steps = literal_eval(row["question_para_step"])
        context = [". ".join(tup) for tup in zip(np.arange(1, len(steps)+1).astype(str), steps) ]
        context = " ".join(context)

        question = row["question"]
        if question[-1] != "?":
            question += "?"
        
        input_ = f"context: {context}\nquestion: {question}\noptions: a) more b) less c) no effect".lower()
        output_text = f"a) more" if row["answer_label"] == "more" else f"b) less" if row["answer_label"] == "less" else f"c) no effect"
        output_label = "a" if row["updated_answer"] == "more" else "b" if row["updated_answer"] == "less" else "c"


        ii = tokenizer.encode(input_, return_tensors="pt").to("cuda")
        pred = predictor.generate(input_ids = ii)
        pred = tokenizer.batch_decode(pred, skip_special_tokens=True)[0]


        pred_label = get_choice(pred)

        r = row.to_dict()
        r["pred"] = pred
        r["pred_label"] = pred_label
        r["is_correct"] = int(pred_label == output_label)

        rows.append(r)
    except:
        pass

p = pd.DataFrame(rows)
p["is_correct"].mean() 

# Eval

In [55]:
predictor = model.model
predictor.eval()
predictor = predictor.to("cuda")

all_preds = []
for batch in test_dl:

    batch = {k: v.to("cuda") for k, v in batch.items()}

    preds = predictor.generate(**batch)
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    all_preds.extend(preds)



In [56]:
test_df["preds"] = all_preds
test_df["choice"] = test_df["preds"].apply(get_choice)
test_df["is_correct"] = test_df.apply(lambda x: int(x["preds"] == x["output_text"]), axis=1)

test_df["is_correct"].mean()

0.81