# Contradictory sentences - baseline model
Create a baseline model for contradiction classification

Because this dataset is multi-lingual, we need to choose the best in class language model that is readily trainable (on kaggle TPUs?). One possibility is the [`XLM-RoBERTa`](https://huggingface.co/tomaarsen/span-marker-xlm-roberta-base-multinerd) model, but this model has fallen out of favor due to major tokenization limitations. The preferred model for multilanguage NER is this SpanMarker model using xlm-roberta-base as the underlying encoder, trained on the multinerd dataset: [`span-marker-xlm-roberta-base-multinerd`](https://huggingface.co/tomaarsen/span-marker-xlm-roberta-base-multinerd). The problem is that I wont want to be doing NER, I want to be doing sentence comparison.

A reasonable starting point is just the base [`XLM-RoBERTa`](https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/xlm-roberta) model.

This kaggle challenge was started as a reason to learn to use TPUs. You can use TPUs in PyTorch with the [`torch_xla`](https://pytorch.org/xla/release/2.0/index.html) package. See how to use it in this example kaggle code [here](https://www.kaggle.com/code/tanlikesmath/the-ultimate-pytorch-tpu-tutorial-jigsaw-xlm-r).  
For now, we will stick with CPU/GPU. Double check Apple silicon MPS devices [[ref]](https://developer.apple.com/metal/pytorch/).



In [1]:
# imports
import os
from pathlib import Path
import warnings
import time
from types import SimpleNamespace  # a wrapper around a datadict

import pandas as pd
import numpy as np
import kaggle
import wandb
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import lr_scheduler
from transformers import XLMRobertaForSequenceClassification
from transformers import (
    TrainingArguments, Trainer, DataCollatorWithPadding,
    XLMRobertaTokenizer, XLMRobertaModel, XLMRobertaConfig)
from datasets import Dataset, DatasetDict
import evaluate

from utils import *

warnings.filterwarnings('ignore')

In [2]:
# Constants
DATA_PATH = "data"
WANDB_PROJECT = "contradictory"
RAW_DATA_AT = "contra_raw"
PROCESSED_DATA_AT = "contra_split"

In [3]:
device = "cpu"
if torch.cuda.is_available():
    print("Found GPU: ", torch.cuda.device_count())
    device = "cuda"
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    mps_device = torch.device("mps")
    print("Found MPS, may not work on some torch ops!" )
    device = "mps"

torch.device(device)

Found GPU:  1


device(type='cuda')

In [4]:
train_config = SimpleNamespace(
    framework="torch",
    batch_size=16,
    num_epochs=1,
    lr=1e-5,
    arch="xlm-roberta-base",
    seed=SEED,
    log_preds=True,
    classifier_dropout=0.0,
    id2label={0: "entailment", 1: "neutral", 2: "contradiction"},
    label2id={v:k for k,v in id2label.items()}
)

In [5]:

# load the HF accuracy fn
accuracy_fn = evaluate.load("accuracy")


In [6]:
tokenizer = XLMRobertaTokenizer.from_pretrained(train_config.arch)

In [7]:
def seed_everything(seed):
    os.environ['PYTHONHASHSEED']=str(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    

In [8]:
def download_data():
    processed_data_at = wandb.use_artifact(f'{PROCESSED_DATA_AT}:latest')
    processed_dataset_dir = Path(processed_data_at.download())
    return processed_dataset_dir

def get_df(processed_dataset_dir, is_test=False):
    df = pd.read_csv(processed_dataset_dir / 'data_split.csv')
    if not is_test:
        # drop test for now, split in valid & train
        df = df[df.Stage != 'test'].reset_index(drop=True)
        df['is_valid'] = df.Stage == 'valid'
    else:
        df = df[df.Stage == 'test'].reset_index(drop=True)
    return df


In [9]:
def tokenize_function_batch(examples):
    tokenized_examples = tokenizer(examples["premise"], examples["hypothesis"], 
                                   truncation=True, padding=True, return_tensors="pt",)
    return tokenized_examples

def get_data(df):
    """
    Load the data from df into a dataset
    This is a bit more important if we are loading images/labels
    """
    train_dataset = Dataset.from_pandas(df[df["is_valid"]!=True])
    valid_dataset = Dataset.from_pandas(df[df["is_valid"]])
    datasets = DatasetDict({"train": train_dataset, "validation": valid_dataset})
    tokenized_datasets = datasets.map(tokenize_function_batch, batched=True)
    return tokenized_datasets

In [10]:
def create_predictions_table(dataset, trainer, id2label):
    """Creates a wandb table with predictions and targets side by side"""
    predictions = trainer.predict(dataset, metric_key_prefix="validate")
    X_pred = np.argmax(predictions.predictions, axis=1)
    y_labels = predictions.label_ids
    if not np.array_equal(y_labels, [dataset[i]["label"] for i in range(len(dataset))]):
        raise ValueError("prediction labels do not match dataset labels")
    
    col_names = ["id", "premise", "hypothesis", "lang_abv", "label", "predict"]

    data_out = []
    for i, sample in tqdm(enumerate(dataset)):
        data_out.append({
            col:sample[col] for col in col_names[:-1]})
        data_out[-1][col_names[-1]] = X_pred[i]  # add the predict field
    data_df = pd.DataFrame.from_records(data_out)

    # add the positive match column, True if matched target label, false otherwise
    data_df["is_correct"] = (data_df["label"]==data_df["predict"]).astype(int)
    
    table = wandb.Table(data=data_df)
    wandb.log({"pred_table":table})
    return table

In [11]:
def log_final_metrics(trainer):
    scores = trainer.evaluate()
    for k,v in scores.items():
        wandb.summary[k] = v

In [12]:

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    ent_ix = np.where(labels==label2id["entailment"])[0]
    neut_ix = np.where(labels==label2id["neutral"])[0]
    contra_ix = np.where(labels==label2id["contradiction"])[0]
    metrics = {
        "accuracy": accuracy_fn.compute(
            predictions=predictions, references=labels)["accuracy"],
        "acc_entailment": accuracy_fn.compute(
            predictions=predictions[ent_ix], references=labels[ent_ix])["accuracy"],
        "acc_neutral": accuracy_fn.compute(
            predictions=predictions[neut_ix], references=labels[neut_ix])["accuracy"],
        "acc_contradiction": accuracy_fn.compute(
            predictions=predictions[contra_ix], references=labels[contra_ix])["accuracy"],
    }
    return metrics

In [13]:
def train(config):
    seed_everything(SEED)
    # init wandb
    run = wandb.init(project=WANDB_PROJECT, entity=None, job_type="training", config=config)

    processed_dataset_dir = download_data()
    df = get_df(processed_dataset_dir)
    tokenized_datasets = get_data(df)  # more space in this for hyperparameters

    config = wandb.config  # reload the instance config

    num_labels = len(np.unique(tokenized_datasets['train']["label"]))
    # fixed model arch for now
    xlm_roberta_config = XLMRobertaConfig.from_pretrained(config.arch)
    # set dropout prob
    # xlm_roberta_config.classifier_dropout = config.classifier_dropout

    # import pdb; pdb.set_trace()
    model = XLMRobertaForSequenceClassification.from_pretrained(config.arch, num_labels=num_labels)

    output_dir = os.path.join(DATA_PATH, f"contradiction-training-{str(int(time.time()))}")

    trainer_config = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.lr,
        num_train_epochs=config.num_epochs,
        weight_decay=0.01,
        logging_steps=1,
        report_to="wandb",  # enable logging to W&B
        # run_name=f"{MODEL_NAME}-baseline",  # name of the W&B run (optional)
    )

    # set up the trainer
    trainer = Trainer(
        model=model,
        args=trainer_config,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['validation'],
        tokenizer=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )
    
    # train it!
    model_trained = trainer.train()
    
    table = create_predictions_table(tokenized_datasets['validation'], trainer, id2label)
    
    
    wandb.finish()
    return trainer

## Run the training

In [14]:
trainer = train(train_config)

[34m[1mwandb[0m: Currently logged in as: [33mmpesavento[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   4 of 4 files downloaded.  


Map:   0%|          | 0/9696 [00:00<?, ? examples/s]

Map:   0%|          | 0/1212 [00:00<?, ? examples/s]

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.bias', 'classifier.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Acc Entailment,Acc Neutral,Acc Contradiction
1,0.8647,0.972319,0.541254,0.543062,0.579897,0.502463


1212it [00:00, 4014.34it/s]


VBox(children=(Label(value='0.824 MB of 0.824 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/acc_contradiction,▁
eval/acc_entailment,▁
eval/acc_neutral,▁
eval/accuracy,▁
eval/loss,▁
eval/runtime,▁
eval/samples_per_second,▁
eval/steps_per_second,▁
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███

0,1
eval/acc_contradiction,0.50246
eval/acc_entailment,0.54306
eval/acc_neutral,0.5799
eval/accuracy,0.54125
eval/loss,0.97232
eval/runtime,8.4666
eval/samples_per_second,143.151
eval/steps_per_second,17.953
train/epoch,1.0
train/global_step,1212.0


In [15]:
run = wandb.init(project=WANDB_PROJECT, entity=None, job_type="evaluation", config=train_config)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668295099952955, max=1.0…

In [16]:
trainer.evaluate()

{'eval_loss': 0.9723190665245056,
 'eval_accuracy': 0.5412541254125413,
 'eval_acc_entailment': 0.5430622009569378,
 'eval_acc_neutral': 0.5798969072164949,
 'eval_acc_contradiction': 0.5024630541871922,
 'eval_runtime': 8.4409,
 'eval_samples_per_second': 143.586,
 'eval_steps_per_second': 18.008,
 'epoch': 1.0}