# Classify contradictory sentences using an LLM
We want to classify contradictory sentences, as per the [Contradictory, My dear Watson](https://www.kaggle.com/competitions/contradictory-my-dear-watson/overview) challenge on Kaggle.

Here we preprocess and split the data, storing the results as Artifacts and Tables in Weights & Biases (W&B).




In [1]:
# imports
import os
import random
import warnings
from pathlib import Path

import kaggle
import numpy as np
import pandas as pd
import torch
import tqdm
import wandb
from sklearn.model_selection import StratifiedKFold

# Silence every warning category for cleaner notebook output
# (note: this also hides pandas/torch warnings that may matter while debugging).
warnings.filterwarnings(action='ignore')

# --- Configuration constants ---
# local directory the Kaggle competition files are downloaded into
DATA_PATH = "data"
# W&B project that receives the runs and artifacts created below
WANDB_PROJECT = "contradictory"
# W&B artifact names: the raw EDA table, and the train/valid/test split
RAW_DATA_AT, PROCESSED_DATA_AT = "contra_raw", "contra_split"

Found MPS, may not work on some torch ops!


## Preprocess the data
The first step is to download and preprocess the data.

In [3]:
# load the data, downloading if it doesnt already exist

!pushd data && kaggle competitions download -c contradictory-my-dear-watson \
&& unzip -o contradictory-my-dear-watson && popd
TRAIN_PATH = os.path.join(DATA_PATH, "train.csv")
TEST_PATH = os.path.join(DATA_PATH, "test.csv")
SUBMISSION_PATH = os.path.join(DATA_PATH, "sample_submission.csv")

contradictory-my-dear-watson.zip: Skipping, found more recently modified local copy (use --force to force download)
Archive:  contradictory-my-dear-watson.zip
  inflating: sample_submission.csv   
  inflating: test.csv                
  inflating: train.csv               


In [4]:
# Competition files: labelled training data, unlabelled test ("submit") data,
# and a sample submission that only shows the expected 'id'/'prediction' format.
raw_data_df, submit_df, sample_submission_df = (
    pd.read_csv(csv_path) for csv_path in (TRAIN_PATH, TEST_PATH, SUBMISSION_PATH)
)

# label id -> name, and the inverse name -> id mapping
id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
label2id = dict(zip(id2label.values(), id2label.keys()))

for title, frame in (("raw data shape:", raw_data_df), ("Submit shape:", submit_df)):
    print(title, frame.shape)

raw data shape: (12120, 6)
Submit shape: (5195, 5)


In [5]:
# Add a human-readable label column; an unknown label id raises KeyError.
raw_data_df["label_str"] = [id2label[int(label_id)] for label_id in raw_data_df["label"]]
raw_data_df.head()

Unnamed: 0,id,premise,hypothesis,lang_abv,language,label,label_str
0,5130fd2cb5,and these comments were considered in formulat...,The rules developed in the interim were put to...,en,English,0,entailment
1,5b72532a0b,These are issues that we wrestle with in pract...,Practice groups are not permitted to work on t...,en,English,2,contradiction
2,3931fbe82a,Des petites choses comme celles-là font une di...,J'essayais d'accomplir quelque chose.,fr,French,0,entailment
3,5622f0c60b,you know they can't really defend themselves l...,They can't defend themselves because of their ...,en,English,0,entailment
4,86aaa48b45,ในการเล่นบทบาทสมมุติก็เช่นกัน โอกาสที่จะได้แสด...,เด็กสามารถเห็นได้ว่าชาติพันธุ์แตกต่างกันอย่างไร,th,Thai,1,neutral


In [6]:
SEED = 98765

def seed_everything(seed):
    """Seed the RNGs used in this notebook so results are reproducible.

    Seeds Python's `random`, NumPy's global RNG and torch (all devices), and
    exports PYTHONHASHSEED.

    Args:
        seed: integer seed applied to every generator.
    """
    # The running interpreter's hash randomization is fixed at startup, so this
    # only affects child processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

seed_everything(SEED)

In [7]:
# peek at the first premise/hypothesis pair and its label
for column in ('premise', 'hypothesis', 'label'):
    print(f"{column.capitalize()}: {raw_data_df[column].values[0]}")

Premise: and these comments were considered in formulating the interim rules.
Hypothesis: The rules developed in the interim were put together with these comments in mind.
Label: 0


In [8]:
# create a W&B Table from the full raw training frame (every column of raw_data_df)
table = wandb.Table(dataframe=raw_data_df)


In [9]:
# Open an "upload" run (entity left at W&B's default) plus an empty
# "raw_data" artifact that will hold the raw EDA table.
run = wandb.init(project=WANDB_PROJECT, job_type="upload")
raw_data_at = wandb.Artifact(RAW_DATA_AT, type="raw_data")

[34m[1mwandb[0m: Currently logged in as: [33mmpesavento[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# upload table: store the EDA table in the artifact under the name "eda_table",
# log the artifact to the current "upload" run, then close that run.
raw_data_at.add(table, "eda_table")
run.log_artifact(raw_data_at)
run.finish()

In [11]:
# number of training examples per language
raw_data_df['language'].value_counts()

language
English       6870
Chinese        411
Arabic         401
French         390
Swahili        385
Urdu           381
Vietnamese     379
Russian        376
Hindi          374
Greek          372
Thai           371
Spanish        366
Turkish        351
German         351
Bulgarian      342
Name: count, dtype: int64

In [12]:
# distinct (non-null) languages in the training data
print("train data lang count:", raw_data_df['language'].nunique(dropna=True))

train data lang count: 15


In [13]:
# Are there any languages in the submit/test set that aren't in the training set?
# (Those would be unseen at training time.) The reverse direction is reported too.
train_languages = set(raw_data_df['language'].unique())
submit_languages = set(submit_df['language'].unique())
language_set_diff = submit_languages - train_languages
print("Languages in submit/test not in raw/train:", language_set_diff)
print("Languages in raw/train not in submit/test:", train_languages - submit_languages)

Languages in raw/train not in submit/test: set()


# Create the train/validation/test data split

In [14]:
# Open a "data_split" run, declare the latest raw artifact as its input,
# and download that artifact locally.
run = wandb.init(project=WANDB_PROJECT, job_type="data_split")
raw_data_at = run.use_artifact(RAW_DATA_AT + ':latest')
path = Path(raw_data_at.download())

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016752521533392913, max=1.0…

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [15]:
# Fetch the "eda_table" entry back out of the raw artifact used by this run.
# NOTE(review): the same value is fetched again (and reassigned) in the JoinedTable cell below.
orig_eda_table = raw_data_at.get("eda_table")

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [16]:
# Keep only the columns needed downstream; `fold` is a placeholder filled by the CV split.
# reset_index guarantees a 0..n-1 RangeIndex so positional fold indices line up with rows.
# NOTE(review): the split is built from the in-memory raw_data_df rather than from the
# downloaded raw_data_at artifact; confirm they are the same data.
df = (
    raw_data_df[['id', 'premise', 'hypothesis', 'lang_abv', 'label']]
    .reset_index(drop=True)
    .assign(fold=-1)
)

y = df['label']

In [17]:
# Assign each row to one of 10 folds, stratified on `label`.
# cv.split yields *positional* indices, so folds are written positionally
# instead of via label-based .loc (which only works by accident with a RangeIndex).
# NOTE(review): shuffle=False keeps file order within each class; consider
# shuffle=True, random_state=SEED if train.csv is not already randomly ordered.
cv = StratifiedKFold(n_splits=10)
fold_ids = np.full(len(df), -1)
for i, (train_idxs, test_idxs) in enumerate(cv.split(df, y)):
    fold_ids[test_idxs] = i
assert (fold_ids >= 0).all(), "every row must be assigned to a fold"
df['fold'] = fold_ids

In [18]:
# Fold 0 -> test, fold 1 -> valid, every other fold -> train; then drop the fold column.
df['Stage'] = df['fold'].map({0: 'test', 1: 'valid'}).fillna('train')
df = df.drop(columns='fold')
df['Stage'].value_counts()

Stage
train    9696
test     1212
valid    1212
Name: count, dtype: int64

In [19]:
# persist the split (without the pandas index) next to the raw CSVs
df.to_csv(Path(DATA_PATH) / 'data_split.csv', index=False)

In [20]:
# Create the "split_data" artifact and attach the split CSV written above.
processed_data_at = wandb.Artifact(PROCESSED_DATA_AT, type="split_data")
processed_data_at.add_file(local_path=os.path.join(DATA_PATH, 'data_split.csv'))

<wandb.sdk.artifacts.artifact_manifest_entry.ArtifactManifestEntry at 0x187e308d0>

Joining with a W&B Table isn't strictly necessary, but it's good practice: the joined table references the existing EDA table by `id` instead of duplicating its data in the new artifact.

In [21]:
# Sanity-check the split frame before uploading: the JoinedTable below joins on 'id'
# and exposes 'Stage', and the temporary 'fold' column must already be gone.
assert {'id', 'Stage'} <= set(df.columns), "split frame is missing 'id' or 'Stage'"
assert 'fold' not in df.columns, "'fold' should have been dropped"
assert df['id'].is_unique, "'id' must be unique to serve as the join key"
df.columns

Index(['id', 'premise', 'hypothesis', 'lang_abv', 'label', 'Stage'], dtype='object')

In [22]:
# Build an 'id' -> 'Stage' table and join it to the raw EDA table on 'id', so the
# artifact references the existing rows instead of copying the text again.
data_split_table = wandb.Table(dataframe=df[['id', 'Stage']])
orig_eda_table = raw_data_at.get("eda_table")
join_table = wandb.JoinedTable(orig_eda_table, data_split_table, "id")
processed_data_at.add(join_table, "eda_table_data_split")

[34m[1mwandb[0m:   1 of 1 files downloaded.  


<wandb.sdk.artifacts.artifact_manifest_entry.ArtifactManifestEntry at 0x1882280d0>

In [23]:
# Log the split artifact (CSV file + joined table) to the "data_split" run and close it.
run.log_artifact(processed_data_at)
run.finish()