# Better not Bigger

Case study from snorkel.

## Dataset

IMDb dataset from Hugging Face [datasets](https://huggingface.co/datasets/imdb)

## Setup

In [1]:
!pip install snorkel transformers datasets -q

[0m

## Reading Dataset

### Loading saved data

In [2]:
import glob
import pandas as pd
import numpy as np

import datasets
from datasets import concatenate_datasets, load_dataset, Dataset

datasets.logging.set_verbosity_error()

In [3]:
zero_shot_models = ["facebook/bart-large-mnli", 
                    "joeddav/xlm-roberta-large-xnli", 
                    "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", 
                    "BaptisteDoyen/camembert-base-xnli"]

In [4]:
train_dfs = []
valid_dfs = []
for model in zero_shot_models:

    train_dfs.append(load_dataset('json', 
                                  data_files=f'imdb_data/'
                                  f'{model}/train.json', 
                                  split='train').to_pandas())
    valid_dfs.append(load_dataset('json', 
                                  data_files=f'imdb_data/'
                                  f'{model}/test.json', 
                                  split='train').to_pandas())

In [5]:
train_ds = train_dfs[0].copy()

for df in train_dfs[1:]:
    train_ds = train_ds.merge(right=df, on=['text', 'label'])

valid_ds = valid_dfs[0].copy()

for df in valid_dfs[1:]:
    valid_ds = valid_ds.merge(right=df, on=['text', 'label'])

In [6]:
train_ds.head()

Unnamed: 0,text,label,facebook/bart-large-mnli,joeddav/xlm-roberta-large-xnli,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,BaptisteDoyen/camembert-base-xnli
0,There is no relation at all between Fortier an...,1,1,1,1,1
1,This movie is a great. The plot is very true t...,1,1,1,1,0
2,"George P. Cosmatos' ""Rambo: First Blood Part I...",0,1,1,1,1
3,In the process of trying to establish the audi...,1,1,1,0,0
4,"Yeh, I know -- you're quivering with excitemen...",0,1,1,0,1


In [7]:
train_ds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 26544 entries, 0 to 26543
Data columns (total 6 columns):
 #   Column                                   Non-Null Count  Dtype 
---  ------                                   --------------  ----- 
 0   text                                     26544 non-null  object
 1   label                                    26544 non-null  int64 
 2   facebook/bart-large-mnli                 26544 non-null  int64 
 3   joeddav/xlm-roberta-large-xnli           26544 non-null  int64 
 4   MoritzLaurer/mDeBERTa-v3-base-mnli-xnli  26544 non-null  int64 
 5   BaptisteDoyen/camembert-base-xnli        26544 non-null  int64 
dtypes: int64(5), object(1)
memory usage: 1.4+ MB


In [8]:
valid_ds.head()

Unnamed: 0,text,label,facebook/bart-large-mnli,joeddav/xlm-roberta-large-xnli,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,BaptisteDoyen/camembert-base-xnli
0,<br /><br />When I unsuspectedly rented A Thou...,1,1,1,0,1
1,This is the latest entry in the long series of...,1,1,1,0,1
2,This movie was so frustrating. Everything seem...,0,1,1,0,1
3,"I was truly and wonderfully surprised at ""O' B...",1,1,1,1,1
4,This movie spends most of its time preaching t...,0,1,1,0,1


In [9]:
valid_ds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 28710 entries, 0 to 28709
Data columns (total 6 columns):
 #   Column                                   Non-Null Count  Dtype 
---  ------                                   --------------  ----- 
 0   text                                     28710 non-null  object
 1   label                                    28710 non-null  int64 
 2   facebook/bart-large-mnli                 28710 non-null  int64 
 3   joeddav/xlm-roberta-large-xnli           28710 non-null  int64 
 4   MoritzLaurer/mDeBERTa-v3-base-mnli-xnli  28710 non-null  int64 
 5   BaptisteDoyen/camembert-base-xnli        28710 non-null  int64 
dtypes: int64(5), object(1)
memory usage: 1.5+ MB


## Label Model

In [10]:
L_train = train_ds[zero_shot_models].to_numpy()
L_valid = valid_ds[zero_shot_models].to_numpy()

L_valid[:10]

array([[1, 1, 0, 1],
       [1, 1, 0, 1],
       [1, 1, 0, 1],
       [1, 1, 1, 1],
       [1, 1, 0, 1],
       [1, 1, 1, 1],
       [1, 1, 0, 1],
       [1, 1, 0, 1],
       [1, 1, 0, 1],
       [1, 1, 0, 1]])

In [11]:
from snorkel.labeling.model import LabelModel

label_model = LabelModel(cardinality=100, verbose=True).to('cuda')
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=42)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|          | 0/500 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=9.234]
 15%|█▌        | 76/500 [00:00<00:01, 271.51epoch/s]INFO:root:[100 epochs]: TRAIN:[loss=7.238]
 34%|███▍      | 172/500 [00:00<00:01, 275.00epoch/s]INFO:root:[200 epochs]: TRAIN:[loss=2.490]
 58%|█████▊    | 289/500 [00:01<00:00, 284.03epoch/s]INFO:root:[300 epochs]: TRAIN:[loss=0.134]
 76%|███████▌  | 381/500 [00:01<00:00, 298.64epoch/s]INFO:root:[400 epochs]: TRAIN:[loss=0.020]
100%|██████████| 500/500 [00:01<00:00, 286.52epoch/s]
INFO:root:Finished Training


In [12]:
valid_gold = valid_ds.label.to_numpy()

In [13]:
label_model.score(L_valid, valid_gold)



{'accuracy': 0.8401598401598401}

In [14]:
train_ds['snorkel'] = label_model.predict(L_train)
valid_ds['snorkel'] = label_model.predict(L_valid)

## Training ML Model on gold data

In [15]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV


pipe = make_pipeline(TfidfVectorizer(min_df=5, norm=None), 
                     LogisticRegression(max_iter=1e4))
param_grid = {'logisticregression__C': [0.1]}
grid = GridSearchCV(pipe, param_grid, cv=5)
grid.fit(train_ds['text'], train_ds['label'])
print("Best cross-validation score: {:.2f}".format(grid.best_score_))

Best cross-validation score: 0.88


In [16]:
from sklearn.metrics import accuracy_score
print(f"Validation Accuracy:\n{accuracy_score(valid_ds['label'], grid.predict(valid_ds['text']))}")

Validation Accuracy:
0.8597701149425288


## Training ML Model on noisy labels

In [17]:
train_ds = train_ds[train_ds['snorkel'] != -1]
valid_ds = valid_ds[valid_ds['snorkel'] != -1]

In [18]:
train_ds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 15327 entries, 0 to 26542
Data columns (total 7 columns):
 #   Column                                   Non-Null Count  Dtype 
---  ------                                   --------------  ----- 
 0   text                                     15327 non-null  object
 1   label                                    15327 non-null  int64 
 2   facebook/bart-large-mnli                 15327 non-null  int64 
 3   joeddav/xlm-roberta-large-xnli           15327 non-null  int64 
 4   MoritzLaurer/mDeBERTa-v3-base-mnli-xnli  15327 non-null  int64 
 5   BaptisteDoyen/camembert-base-xnli        15327 non-null  int64 
 6   snorkel                                  15327 non-null  int64 
dtypes: int64(6), object(1)
memory usage: 957.9+ KB


In [19]:
valid_ds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 17017 entries, 3 to 28709
Data columns (total 7 columns):
 #   Column                                   Non-Null Count  Dtype 
---  ------                                   --------------  ----- 
 0   text                                     17017 non-null  object
 1   label                                    17017 non-null  int64 
 2   facebook/bart-large-mnli                 17017 non-null  int64 
 3   joeddav/xlm-roberta-large-xnli           17017 non-null  int64 
 4   MoritzLaurer/mDeBERTa-v3-base-mnli-xnli  17017 non-null  int64 
 5   BaptisteDoyen/camembert-base-xnli        17017 non-null  int64 
 6   snorkel                                  17017 non-null  int64 
dtypes: int64(6), object(1)
memory usage: 1.0+ MB


In [20]:
grid.fit(train_ds['text'], train_ds['snorkel'])
print("Best cross-validation score: {:.2f}".format(grid.best_score_))
print(f"Validation Accuracy:\n{accuracy_score(valid_ds['label'], grid.predict(valid_ds['text']))}")

Best cross-validation score: 0.78
Validation Accuracy:
0.8113651054827525
