In [16]:
import ast
import math
import re
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import transformers
import yaml
from bs4 import BeautifulSoup
from lightning.pytorch.loggers import WandbLogger
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    multilabel_confusion_matrix,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler


In [25]:
SEED = 42
DATASET = "boolq"
MODEL_NAME = "answerdotai/ModernBERT-base"
MINIBATCH_SIZE = 64
N_EPOCHS = 50
TEST_VAL_SET_SIZE = 0.15

# Seed Python, NumPy and torch up front so weight init and sampling are reproducible.
pl.seed_everything(SEED)

df = pd.read_csv("datasets/csv/all_data.csv", low_memory=False, index_col=0)
# df = df.loc[df["dataset"] == DATASET]

models_to_remove = {"llama_03B_32", "llama_70B_31"}

# Drop the excluded model columns while keeping the CSV's column order. Selecting
# columns via list(set(...)) reorders them by set iteration order, which varies between
# interpreter sessions (str hashing is randomised), and the surviving "llama_*" column
# order later becomes the router's model_list.
df = df.drop(columns=list(models_to_remove), errors="ignore")
display(df["dataset"].unique())
display(df.columns)

array(['mmlu_high_school_geography', 'mmlu_high_school_macroeconomics',
       'mmlu_philosophy', 'mmlu_professional_law', 'mmlu_world_religions',
       'mmlu_college_biology', 'mmlu_global_facts',
       'mmlu_high_school_statistics', 'mmlu_clinical_knowledge',
       'mmlu_high_school_psychology', 'mmlu_moral_disputes',
       'mmlu_professional_medicine', 'mmlu_us_foreign_policy',
       'mmlu_prehistory', 'mmlu_college_physics', 'boolq',
       'mmlu_machine_learning', 'mmlu_high_school_mathematics',
       'mmlu_college_computer_science', 'mmlu_logical_fallacies',
       'mmlu_professional_accounting',
       'mmlu_high_school_government_and_politics',
       'mmlu_high_school_world_history', 'mmlu_college_medicine',
       'mmlu_electrical_engineering', 'mmlu_high_school_chemistry',
       'mmlu_high_school_european_history', 'mmlu_conceptual_physics',
       'mmlu_college_chemistry', 'mmlu_human_sexuality',
       'mmlu_moral_scenarios', 'logiqa2', 'mmlu_econometrics',
       '

Index(['correct_response', 'dataset', 'choices', 'subject', 'input_text',
       'llama_08B_31', 'llama_01B_32', 'label', 'llama_70B_33'],
      dtype='object')

In [3]:
# Pretrained encoder backbone and its matching tokenizer, both resolved from MODEL_NAME.
model = transformers.AutoModel.from_pretrained(MODEL_NAME)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

In [4]:
def pre_process(text, parser: str = "html.parser"):
    """Normalise raw (possibly HTML) text to lowercase letter-only words joined by single spaces.

    Markup is stripped with BeautifulSoup, every non-ASCII-letter character becomes a
    space, the text is lowercased and runs of whitespace are collapsed.

    Args:
        text: raw input string, possibly containing HTML.
        parser: BeautifulSoup parser name. Naming it explicitly avoids bs4's
            "no parser was explicitly specified" warning and keeps the output independent
            of which optional parsers happen to be installed.

    Returns:
        The cleaned string.
    """
    text = BeautifulSoup(text, parser).get_text()
    # Keep only ASCII letters; everything else becomes a word separator.
    text = re.sub(r"[^a-zA-Z]", " ", text)
    # split()/join collapses repeated whitespace and trims both ends.
    return " ".join(text.lower().split())

In [5]:
# Hold out the test split first, then carve the validation split out of what remains.
# The first split already shuffled the rows, so the second one can take them in order.
# Note: validation is TEST_VAL_SET_SIZE of the *remaining* rows, not of the full frame.
x_train, x_test, y_train, y_test = train_test_split(
    df["input_text"], df["label"], test_size=TEST_VAL_SET_SIZE, random_state=SEED, shuffle=True
)
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=TEST_VAL_SET_SIZE, random_state=SEED, shuffle=False
)

x_train = x_train.tolist()
x_val = x_val.tolist()
x_test = x_test.tolist()

# The "label" column holds the text of a Python literal (it parses to a list).
# ast.literal_eval accepts literals only, whereas eval would execute any code
# that happens to be stored in the CSV.
y_train = [ast.literal_eval(label) for label in y_train.tolist()]
y_val = [ast.literal_eval(label) for label in y_val.tolist()]
y_test = [ast.literal_eval(label) for label in y_test.tolist()]

In [6]:
# Sanity-check the parsed labels before building datasets. Each one must be a list, and all
# of them must share one length so the DataLoader can stack them into (batch, n_labels).
# That length must also equal the number of "llama_*" columns, because the router gets
# one output per such column and BCELoss needs matching shapes.
display(type(y_test[0]))
all_labels = y_train + y_val + y_test
assert all(isinstance(label, list) for label in all_labels), "every label should parse to a list"
label_lengths = {len(label) for label in all_labels}
assert len(label_lengths) == 1, f"labels have inconsistent lengths: {label_lengths}"
n_model_columns = sum("llama_" in col for col in df.columns)
assert label_lengths == {n_model_columns}, (
    f"label length {label_lengths} != number of llama_* columns ({n_model_columns})"
)

list

In [7]:
class ClassificationDataset(Dataset):
    """Map-style dataset that tokenizes one text per item and pairs it with a float label tensor.

    Tokenization happens lazily in ``__getitem__``, one example at a time.

    Args:
        x: input texts.
        y: labels, one per text; each is converted with
            ``torch.tensor(..., dtype=torch.float)``.
        tokenizer: Hugging Face-style tokenizer (called directly).
        max_length: every example is padded/truncated to this many tokens.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """

    def __init__(self, x: list, y: list, tokenizer, max_length: int = 512):
        if len(x) != len(y):
            raise ValueError(f"x and y must have the same length, got {len(x)} and {len(y)}")
        self.x = x
        self.y = y
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int) -> dict:
        # Calling the tokenizer directly replaces the deprecated ``encode_plus``.
        encoded = self.tokenizer(
            self.x[idx],
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            return_token_type_ids=False,
            return_attention_mask=True,
            truncation=True,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a leading batch dimension of 1; flatten it away so the
        # default collate function stacks items into (batch, max_length).
        return {
            "input_ids": encoded["input_ids"].flatten(),
            "attention_mask": encoded["attention_mask"].flatten(),
            "label": torch.tensor(self.y[idx], dtype=torch.float),
        }

# trainset = ClassificationDataset(x_train, y_train, tokenizer)
# display(next(iter(trainset)))

In [8]:
class MESSLightningDataloader(pl.LightningDataModule):
    """LightningDataModule over pre-split train/validation/test texts and labels.

    Args:
        x_train, y_train, x_val, y_val, x_test, y_test: texts and labels of each split.
        tokenizer: forwarded to ``ClassificationDataset``.
        batch_size: mini-batch size used by every loader.
        num_workers: DataLoader worker processes (0 loads in the main process, as before).
    """

    def __init__(self, x_train: list, y_train: list, x_val: list, y_val: list, x_test: list, y_test: list, tokenizer, batch_size: int = 64, num_workers: int = 0):
        super().__init__()
        self.x_train = x_train
        self.y_train = y_train
        self.x_val = x_val
        self.y_val = y_val
        self.x_test = x_test
        self.y_test = y_test

        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: str):
        """Build all three datasets regardless of ``stage`` (tokenization is per item)."""
        self.trainset = ClassificationDataset(x=self.x_train, y=self.y_train, tokenizer=self.tokenizer)
        self.valset = ClassificationDataset(x=self.x_val, y=self.y_val, tokenizer=self.tokenizer)
        self.testset = ClassificationDataset(x=self.x_test, y=self.y_test, tokenizer=self.tokenizer)

    def train_dataloader(self):
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        # Evaluation needs no shuffling. A fixed order keeps runs comparable and avoids
        # Lightning's warning about shuffled evaluation loaders.
        return DataLoader(self.valset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
    

In [9]:
def compute_scores(probs, labels, stage: str, class_indices: list, include_confusion_matrix: bool = True, threshold: float = 0.5) -> Tuple[Dict, Any]:
    """Threshold predicted probabilities and score them against the targets.

    Args:
        probs: tensor of predicted probabilities (any device; detached before conversion).
        labels: tensor of targets with the same shape as ``probs``, expected to hold 0/1 values.
        stage: prefix for the metric keys, e.g. ``"val"`` -> ``"val/accuracy"``.
        class_indices: currently unused; kept so existing call sites keep working.
        include_confusion_matrix: whether to also compute a confusion matrix.
        threshold: probabilities strictly greater than this count as positive.

    Returns:
        ``(metrics, confusion)``. ``metrics`` maps ``"{stage}/accuracy"``, ``"{stage}/f1"``,
        ``"{stage}/precision"`` and ``"{stage}/recall"`` to scores; F1, precision and recall
        are support-weighted averages. ``confusion`` is ``None`` when not requested, a stack
        of per-label 2x2 matrices for multi-label targets, or a single matrix otherwise.
    """
    # detach() also allows calling this on outputs that still carry autograd history.
    preds = (probs.detach() > threshold).double().cpu().numpy()
    targets = labels.detach().cpu().numpy()

    # zero_division=0 gives the same value as sklearn's default ("warn"), but without a
    # warning whenever some label has no predicted or true positives in the batch.
    metrics = {
        f"{stage}/accuracy": accuracy_score(targets, preds, normalize=True),
        f"{stage}/f1": f1_score(targets, preds, average="weighted", zero_division=0),
        f"{stage}/precision": precision_score(targets, preds, average="weighted", zero_division=0),
        f"{stage}/recall": recall_score(targets, preds, average="weighted", zero_division=0),
    }

    confusion_mx = None
    if include_confusion_matrix:
        # confusion_matrix rejects multilabel-indicator targets; use the per-label variant there.
        if targets.ndim == 2 and targets.shape[1] > 1:
            confusion_mx = multilabel_confusion_matrix(targets, preds)
        else:
            confusion_mx = confusion_matrix(targets, preds)

    return metrics, confusion_mx


class MESSPlusMLP(nn.Module):
    """Feed-forward head: a Linear->ReLU block per hidden width, then a linear output layer.

    Args:
        hidden_size: width of the input features.
        num_labels: number of output units (raw scores, no final activation).
        hidden_dim_shapes: widths of the hidden layers, in order; may be empty. Not modified.
        dropout: if > 0, an ``nn.Dropout`` with this probability follows every hidden ReLU.
            The default 0.0 yields exactly the Linear/ReLU layout (and state_dict keys)
            of the original head.
    """

    def __init__(self, hidden_size, num_labels: int, hidden_dim_shapes: list, dropout: float = 0.0):
        super().__init__()

        # Build a new list. The old code inserted into and appended to hidden_dim_shapes
        # itself, which silently mutated the caller's list (e.g. a sweep config value).
        layer_dims = [hidden_size, *hidden_dim_shapes, num_labels]
        output_position = len(layer_dims) - 2  # index of the output Linear among all Linears

        layers = []
        for position, (in_dim, out_dim) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
            layers.append(nn.Linear(in_dim, out_dim))
            # Decide on the activation by position, not by comparing out_dim with
            # num_labels. The old comparison dropped the ReLU after a later hidden layer
            # whose width equals num_labels, and put a ReLU after the output layer when
            # there were no hidden layers.
            if position < output_position:
                layers.append(nn.ReLU())
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; the last dimension of the result has size ``num_labels``."""
        return self.mlp(x)


class MESSRouter(pl.LightningModule):
    """Multi-label router: frozen text encoder plus a trainable MLP head with sigmoid outputs.

    The backbone encodes the input. Token embeddings are mean-pooled over non-padding
    positions, and ``MESSPlusMLP`` maps the pooled vector to ``n_classes`` independent
    probabilities, trained with binary cross-entropy.

    Args:
        base_model: pretrained encoder exposing ``config.hidden_size`` and returning an
            object with ``last_hidden_state``. All of its parameters are frozen.
        model_list: names of the candidate models; stored on the module.
        hidden_layer_shape: hidden widths of the MLP head. A copy is passed on, so the
            caller's list is never modified.
        optim_name: ``"sgd"`` (momentum 0.9) or ``"adam"``. Any other value makes
            ``configure_optimizers`` raise ``NotImplementedError``.
        n_classes: number of output probabilities.
        steps_per_epoch: stored only (reserved for a learning-rate schedule).
        n_epochs: stored only (reserved for a learning-rate schedule).
        lr: optimizer learning rate.
    """

    def __init__(self, base_model, model_list: list, hidden_layer_shape: list, optim_name: str, n_classes=10, steps_per_epoch=None, n_epochs=3, lr=0.0001):
        super().__init__()

        self.backbone = base_model
        # Freeze each parameter. Setting ``requires_grad`` on the module object itself
        # (as this loop used to) only creates an unused attribute and freezes nothing.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        self.classifier = MESSPlusMLP(
            self.backbone.config.hidden_size,
            n_classes,
            hidden_dim_shapes=list(hidden_layer_shape),
        )

        self.steps_per_epoch = steps_per_epoch
        self.n_epochs = n_epochs
        self.lr = lr
        self.criterion = nn.BCELoss()

        self.model_list = model_list
        self.optim_name = optim_name

        # Validation outputs collected over one epoch; flushed in on_validation_epoch_end.
        self._val_probs = []
        self._val_labels = []

        # The backbone is already part of the checkpoint's state_dict. Keep the nn.Module
        # out of the hyper-parameters, which Lightning warns about.
        self.save_hyperparameters(ignore=["base_model"])

    def train(self, mode: bool = True):
        """Set train/eval mode on the module while keeping the frozen backbone in eval mode."""
        super().train(mode)
        # Lightning switches the whole module back to train mode during fitting. The
        # backbone is never trained, so keep any dropout in it disabled.
        if hasattr(self, "backbone"):
            self.backbone.eval()
        return self

    def forward(self, input_ids, attn_mask):
        """Return per-class probabilities of shape (batch, n_classes) for tokenized inputs."""
        with torch.no_grad():
            out = self.backbone(input_ids=input_ids, attention_mask=attn_mask)

        # Pool to one vector per example: (batch, hidden_size).
        out = self.mean_pooling(out.last_hidden_state, attn_mask)
        out = self.classifier(out)
        out = torch.sigmoid(out)
        return out

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings over positions where ``attention_mask`` is 1."""
        # We need to make sure we only pool actual tokens and not padding.
        # Please see for reference: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
        return torch.sum(model_output * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["label"]

        probs = self(input_ids, attention_mask)
        loss = self.criterion(probs, labels)
        self.log("train/loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": probs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["label"]

        probs = self(input_ids, attention_mask)
        loss = self.criterion(probs, labels)

        # Buffer raw outputs so epoch metrics are computed over all samples at once.
        # Averaging per-batch scores over-weights a short final batch, and growing a
        # DataFrame with pd.concat on every step is quadratic.
        self._val_probs.append(probs.detach().cpu())
        self._val_labels.append(labels.detach().cpu())

        self.log("val_loss", loss, prog_bar=True, logger=False)
        return loss

    def on_validation_epoch_end(self):
        stage = "val"
        if not self._val_probs:
            return

        probs = torch.cat(self._val_probs)
        labels = torch.cat(self._val_labels)
        self._val_probs = []
        self._val_labels = []

        epoch_metrics, _ = compute_scores(
            probs,
            labels,
            stage=stage,
            include_confusion_matrix=False,
            class_indices=[],
        )
        epoch_metrics[f"{stage}/loss"] = self.criterion(probs, labels).item()

        display(epoch_metrics)
        self.log_dict(epoch_metrics)

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["label"]

        outputs = self(input_ids, attention_mask)
        loss = self.criterion(outputs, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        # Only the head is trainable. Passing frozen backbone parameters would just add
        # optimizer bookkeeping for tensors that never receive gradients.
        trainable = [p for p in self.parameters() if p.requires_grad]

        if self.optim_name == "sgd":
            return torch.optim.SGD(trainable, lr=self.lr, momentum=0.9)
        if self.optim_name == "adam":
            return torch.optim.Adam(trainable, lr=self.lr)
        # TODO: a linear warmup schedule could be built from steps_per_epoch and n_epochs.
        raise NotImplementedError("Optimizer not configured.")

In [15]:
sweep_config = {
    "method": "random",
    "metric": {
        # Must be a key the router actually logs: "val/loss" is logged at the end of every
        # validation epoch, whereas a bare "loss" is never logged.
        "name": "val/loss",
        "goal": "minimize",
    },
    "parameters": {
        "optimizer": {
            "values": ["adam", "sgd"]
        },
        "hidden_layer_shape": {
            "values": [
                [128],
                [256],
                [512],
                [512, 256],
                [256, 128],
                [512, 256, 128]
            ]
        },
        # NOTE(review): dropout is sampled here but make_experiment does not pass it to the model.
        "dropout": {"values": [0.3, 0.4, 0.5]},
        "epoch": {"values": [1, 5, 10, 25, 50]},
        "learning_rate": {
            # Log-uniform between 1e-5 and 0.1 spreads trials evenly across orders of
            # magnitude and, unlike uniform(0, 0.1), can never draw a learning rate of 0.
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 0.1,
        },
        "minibatch_size": {
            # integers between 32 and 256
            # with evenly-distributed logarithms
            "distribution": "q_log_uniform_values",
            "q": 8,
            "min": 32,
            "max": 256,
        },
    },
}

# Create the target directory if needed so the dump does not fail on a fresh checkout.
sweep_path = Path("config/classifier_training/sweep_v1.yaml")
sweep_path.parent.mkdir(parents=True, exist_ok=True)
with sweep_path.open("w") as f:
    yaml.dump(sweep_config, f, default_flow_style=False)

In [None]:
def make_experiment(config: dict, accelerator: str = "gpu"):
    """Train and validate one router configuration, logging the run to Weights & Biases.

    Args:
        config: hyper-parameters read by attribute (``minibatch_size``, ``epoch``,
            ``learning_rate``, ``hidden_layer_shape``, ``optimizer``). Despite the ``dict``
            annotation it needs attribute access; presumably a ``wandb.config``, but
            confirm at the call site.
        accelerator: accelerator passed to ``pl.Trainer``.

    Relies on notebook globals: ``SEED``, ``df``, ``model``, ``tokenizer`` and the
    ``x_*`` / ``y_*`` splits.
    """
    pl.seed_everything(SEED, workers=True)

    model_list = [col for col in df.columns if "llama_" in col]

    router = MESSRouter(
        base_model=model,
        model_list=model_list,
        n_classes=len(model_list),
        # Whole number of optimizer steps per epoch (a final partial batch is still a step).
        steps_per_epoch=math.ceil(len(x_train) / config.minibatch_size),
        n_epochs=config.epoch,
        lr=config.learning_rate,
        hidden_layer_shape=config.hidden_layer_shape,
        optim_name=config.optimizer,
    )

    datamodule = MESSLightningDataloader(
        x_train=x_train,
        y_train=y_train,
        x_val=x_val,
        y_val=y_val,
        x_test=x_test,
        y_test=y_test,
        tokenizer=tokenizer,
        batch_size=config.minibatch_size,
    )

    # NOTE(review): WandbLogger is imported from ``lightning.pytorch`` while the Trainer is
    # from ``pytorch_lightning``. Mixing the two packages can trip type checks, so
    # consider importing it from ``pytorch_lightning.loggers`` instead.
    wandb_logger = WandbLogger(project="mess-plus-classifier-hp-sweep-v01")

    # The logger has to be handed to the Trainer. Previously it was created but never
    # passed in, so nothing reached this W&B project.
    trainer = pl.Trainer(
        max_epochs=config.epoch,
        accelerator=accelerator,
        logger=wandb_logger,
    )
    try:
        trainer.fit(router, datamodule)
    finally:
        # ``wandb`` is never imported in this notebook, so the old wandb.login()/wandb.finish()
        # calls raised NameError. Close the run through the logger's handle instead; wandb
        # asks for credentials itself when the logger starts the run, if needed.
        wandb_logger.experiment.finish()