# **GLUE Benchmark**

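Fine-tunes a BERT sequence-classification model on each GLUE task with the Hugging Face `transformers` `Trainer`, then writes the test-set predictions to tab-separated submission files.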

# Libraries 

In [None]:
import os
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding, AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification
from datasets import load_dataset
import gc

import wandb
import pandas as pd

# Start Weights & Biases in disabled mode so that no run is tracked.
wandb.init(mode="disabled", allow_val_change=True)

In [None]:
# Process-wide switches, applied to the environment in a single update.
os.environ.update(
    {
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_USE_CUDA_DSA": "1",
        "WANDB_MODE": "disabled",
    }
)

# Variables

In [None]:
# GLUE task identifiers, in the order the main loop fine-tunes them.
TASKS = [
    "rte",
    "sst2",
    "mrpc",
    "stsb",
    "cola",
    "wnli",
    "mnli_matched",
    "mnli_mismatched",
    "ax",
    "qnli",
    "qqp",
]

# Text column(s) of each GLUE task that are turned into the model input.
# A one-element tuple is a single-sentence task; a two-element tuple is a
# sentence-pair task (first segment, second segment).
FEATURES = {
    "cola": ("sentence",),
    "sst2": ("sentence",),
    "mrpc": ("sentence1", "sentence2"),
    "stsb": ("sentence1", "sentence2"),
    "rte": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
    "mnli": ("premise", "hypothesis"),
    "mnli_matched": ("premise", "hypothesis"),
    "mnli_mismatched": ("premise", "hypothesis"),
    "ax": ("premise", "hypothesis"),
    # QNLI pairs a question with a candidate answer sentence; using
    # "question" for both slots would hide the answer from the model.
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
}

# Name of the prediction file written for each task; get_predictions joins
# it onto the submission directory.
filenames = {
    "cola": "CoLA.tsv",
    "sst2": "SST-2.tsv",
    "mrpc": "MRPC.tsv",
    "qqp": "QQP.tsv",
    "stsb": "STS-B.tsv",
    "mnli_matched": "MNLI-m.tsv",
    "mnli_mismatched": "MNLI-mm.tsv",
    "qnli": "QNLI.tsv",
    "rte": "RTE.tsv",
    "wnli": "WNLI.tsv",
    "ax": "AX.tsv",
}

# String labels written to the prediction file for tasks that need them.
# Position i in a list is the string emitted for predicted class id i;
# tasks missing from this mapping are written as plain integers/scores.
_THREE_WAY_LABELS = ("entailment", "neutral", "contradiction")
_TWO_WAY_LABELS = ("entailment", "not_entailment")

labelnames = {
    # Every task receives its own list object, as in a hand-written literal.
    **{name: list(_THREE_WAY_LABELS) for name in ("mnli_matched", "mnli_mismatched", "ax")},
    **{name: list(_TWO_WAY_LABELS) for name in ("qnli", "rte")},
}

# Functions

In [None]:
def get_datasets(task_name):
    """Load the train, validation and test splits used for one GLUE task.

    The MNLI-derived tasks ("ax", "mnli_matched", "mnli_mismatched") all
    train on the "mnli" training split. "ax" takes only its test split from
    its own config and keeps MNLI's matched validation split; the two MNLI
    variants take validation and test from their own config. Every other
    task uses the three splits of its own config.

    Args:
        task_name: GLUE config name, e.g. "rte" or "mnli_matched".

    Returns:
        A ``(train_ds, validation_ds, test_ds)`` tuple.

    Raises:
        ValueError: If ``task_name`` is the bare "mnli" config, for which no
            validation/test selection is defined here.
    """
    if task_name == "mnli":
        # Fail early with a clear message instead of an UnboundLocalError
        # at the return statement.
        raise ValueError(
            'task "mnli" is not supported directly; '
            'use "mnli_matched" or "mnli_mismatched"'
        )

    if task_name in ("ax", "mnli_matched", "mnli_mismatched"):
        mnli = load_dataset("glue", "mnli")
        train_ds, validation_ds = mnli["train"], mnli["validation_matched"]

        if task_name == "ax":
            test_ds = load_dataset("glue", "ax")["test"]
        else:
            # Config name equals the task name for both MNLI variants.
            subset = load_dataset("glue", task_name)
            validation_ds, test_ds = subset["validation"], subset["test"]
    else:
        dataset = load_dataset("glue", task_name)
        train_ds = dataset["train"]
        validation_ds = dataset["validation"]
        test_ds = dataset["test"]

    return train_ds, validation_ds, test_ds

def preprocess_input_features(sample, task_name):
    """Build the single text string fed to the tokenizer for one example.

    Single-sentence tasks use the sentence unchanged. Sentence-pair tasks
    are flattened into one string of the form
    ``"<task> <col1>: <text1>, <col2>: <text2>"``.

    Args:
        sample: One dataset row, indexable by the column names listed in
            ``FEATURES[task_name]``.
        task_name: Key into the module-level ``FEATURES`` mapping.

    Returns:
        ``{'input_text': <str>}``, suitable as the return value of a
        non-batched ``Dataset.map`` callback.

    Raises:
        ValueError: If the task declares neither one nor two text columns.
    """
    feature_names = FEATURES[task_name]

    if len(feature_names) == 1:
        input_text = sample[feature_names[0]]
    elif len(feature_names) == 2:
        first, second = feature_names
        # Interpolate the actual task name, and label each segment with its
        # own column name (the second segment was mislabelled as the first).
        input_text = f"{task_name} {first}: {sample[first]}, {second}: {sample[second]}"
    else:
        raise ValueError(
            f"task {task_name!r} declares {len(feature_names)} text columns; expected 1 or 2"
        )

    return {'input_text': input_text}

def preprocess_function(examples, task_name):
    """Tokenize a batch of examples for use with a batched ``Dataset.map``.

    Relies on the module-level ``tokenizer``, ``TRUNCATION`` and ``PADDING``
    set up by the main program.

    Args:
        examples: Batch whose ``'input_text'`` entry holds the input strings.
        task_name: Unused; kept so existing callers keep working.

    Returns:
        The tokenizer output for the batch.
    """
    max_length = 512
    # No return_tensors here: the map step stores plain columns, and
    # finetune/get_predictions convert to torch later via set_format().
    # Requesting tensors would also fail with a non-fixed padding strategy.
    return tokenizer(
        examples['input_text'],
        truncation=TRUNCATION,
        padding=PADDING,
        max_length=max_length,
    )

def get_num_labels(task_name):
    """Return the size of the classification head for a GLUE task.

    Args:
        task_name: GLUE task identifier.

    Returns:
        1 for "stsb" (a single output), 3 for the MNLI-family tasks
        (matching the three label names used for them), and 2 otherwise.
    """
    if task_name == "stsb":
        return 1
    # The bare "mnli" config is also three-way; it previously fell through
    # to the two-label default.
    if task_name in ("mnli", "mnli_matched", "mnli_mismatched", "ax"):
        return 3
    return 2

def finetune(task_name, train_ds, validation_ds):
    """Fine-tune the module-level ``model`` on one task and save it.

    Both splits get an ``input_text`` column, are tokenized, and are exposed
    as torch tensors before being handed to a ``Trainer``. Uses the
    module-level ``model``, ``data_collator`` and ``CUDA`` set up by the
    main program.

    Args:
        task_name: GLUE task identifier; also names the output directory.
        train_ds: Training split.
        validation_ds: Validation split (passed to the Trainer as eval set).

    Returns:
        The ``Trainer`` after training, for later prediction.
    """
    print("\n ▶ Starting Finetuning... \n")

    def add_input_text(sample):
        return preprocess_input_features(sample, task_name)

    def tokenize_batch(batch):
        return preprocess_function(batch, task_name)

    # Build the text column first, then tokenize it.
    print("\n\t  ▶ Preparing datasets... \n")
    train_ds = train_ds.map(add_input_text, batched=False)
    validation_ds = validation_ds.map(add_input_text, batched=False)
    train_ds = train_ds.map(tokenize_batch, batched=True)
    validation_ds = validation_ds.map(tokenize_batch, batched=True)
    print("\n\t  ▶ End of Preparing datasets. \n")

    # Expose only the model inputs and the label, as torch tensors.
    tensor_columns = ["input_ids", "attention_mask", "label"]
    for split in (train_ds, validation_ds):
        split.set_format(type="torch", columns=tensor_columns)

    results_dir = f'./results/{task_name}'
    training_args = TrainingArguments(
        output_dir=results_dir,
        logging_dir='./logs',
        evaluation_strategy="no",
        learning_rate=2e-5,
        max_steps=200,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
    )

    # Place the model on the GPU when one is available.
    target_device = torch.device("cuda") if CUDA else torch.device("cpu")
    model.to(target_device)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=validation_ds,
        data_collator=data_collator,
    )

    print("\n\t ▶ Starting Training \n")
    trainer.train()
    print("\n\t ▶ Finished  Training \n")

    # Persist the fine-tuned weights under the task's results directory.
    trainer.save_model(f'{results_dir}/model')

    return trainer

def get_predictions(task_name, test_ds, trainer):
    """Predict on the test split and write a tab-separated prediction file.

    The file is written to ``glue_submissions/<filenames[task_name]>`` with
    the columns ``index`` and ``prediction``.

    Args:
        task_name: GLUE task identifier; selects the output file name and,
            through ``labelnames``, an optional id-to-string mapping.
        test_ds: Test split to predict on.
        trainer: Trained ``Trainer`` returned by ``finetune``.

    Returns:
        None. The result is the file written to disk.
    """
    print("\n\t  ▶ Starting Test Predictions \n")
    submission_directory = "glue_submissions"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(submission_directory, exist_ok=True)

    filename = os.path.join(submission_directory, filenames[task_name])
    labelname = labelnames.get(task_name)

    # Build the text column, tokenize it, and expose model inputs as tensors.
    test_ds = test_ds.map(lambda x: preprocess_input_features(x, task_name), batched=False)
    test_ds = test_ds.map(lambda x: preprocess_function(x, task_name), batched=True)
    test_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    predictions = trainer.predict(test_ds)
    model_outputs = torch.tensor(predictions.predictions)

    if task_name == "stsb":
        # stsb has a single output per example (see get_num_labels), so an
        # argmax over it is always 0. Use the output itself as the score,
        # clamped to [0, 5] as before.
        scores = model_outputs.reshape(-1).clamp(0, 5).tolist()
        predictions_list = [f"{score:.3f}" for score in scores]
    else:
        class_ids = torch.argmax(model_outputs, dim=-1).tolist()
        if labelname:
            # Map class ids to their string labels for this task.
            predictions_list = [labelname[class_id] for class_id in class_ids]
        else:
            predictions_list = class_ids

    # Save predictions to file
    output = pd.DataFrame({
        'index': range(len(predictions_list)),
        'prediction': predictions_list
    })

    output.to_csv(filename, index=False, sep="\t")
    print("\n\t  ▶ End of Generating Predictions \n")


# Main Program

In [None]:
# Define Constants
# Tokenizer padding / truncation strategies read by preprocess_function.
PADDING = 'max_length'
TRUNCATION = 'longest_first'
CUDA = torch.cuda.is_available()  # Whether to use GPU or CPU

wandb.init(mode="disabled")

# Checkpoint that is fine-tuned for every task; the commented line below is
# an alternative checkpoint kept for switching by hand.
model_name = "princeton-nlp/mabel-bert-base-uncased"
# model_name = "bert-base-uncased"

if __name__ == "__main__":
    # Collect garbage first so the cache release can return freed tensors.
    gc.collect()
    torch.cuda.empty_cache()

    # The tokenizer and collator depend only on model_name, so load them
    # once rather than once per task. They stay module-level names because
    # preprocess_function and finetune read them as globals.
    tokenizer = BertTokenizer.from_pretrained(model_name)
    data_collator = DataCollatorWithPadding(tokenizer)

    for task_name in TASKS:
        print(f"\n ▶ FINETUNING FOR TASK: {task_name} \n")

        # A fresh model per task: the head size differs between tasks.
        num_labels = get_num_labels(task_name)
        model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

        train_ds, validation_ds, test_ds = get_datasets(task_name)

        trainer = finetune(task_name, train_ds, validation_ds)

        get_predictions(task_name, test_ds, trainer)

        gc.collect()
        torch.cuda.empty_cache()