# Better not Bigger

Case study from snorkel.

## Dataset

IMDb dataset from Hugging Face [datasets](https://huggingface.co/datasets/imdb)

## Setup

In [1]:
# !pip install snorkel transformers datasets -q

## Loading saved data

In [2]:
import glob
import pandas as pd
import numpy as np

import datasets
from datasets import concatenate_datasets, load_dataset, Dataset

datasets.logging.set_verbosity_error()

In [3]:
zero_shot_models = ["facebook/bart-large-mnli", 
                    "joeddav/xlm-roberta-large-xnli", 
                    "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", 
                    "BaptisteDoyen/camembert-base-xnli"]

In [4]:
train_dfs = []
valid_dfs = []
for model in zero_shot_models:

    train_dfs.append(load_dataset('json', 
                                  data_files=f'imdb_data/'
                                  f'{model}/train.json', 
                                  split='train').to_pandas())
    valid_dfs.append(load_dataset('json', 
                                  data_files=f'imdb_data/'
                                  f'{model}/test.json', 
                                  split='train').to_pandas())

In [5]:
train_ds = train_dfs[0].copy()

for df in train_dfs[1:]:
    train_ds = train_ds.merge(right=df, on=['text', 'label'])

valid_ds = valid_dfs[0].copy()

for df in valid_dfs[1:]:
    valid_ds = valid_ds.merge(right=df, on=['text', 'label'])

In [6]:
train_ds.head()

Unnamed: 0,text,label,facebook/bart-large-mnli,joeddav/xlm-roberta-large-xnli,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,BaptisteDoyen/camembert-base-xnli
0,There is no relation at all between Fortier an...,1,-1,-1,-1,-1
1,This movie is a great. The plot is very true t...,1,1,-1,-1,-1
2,"George P. Cosmatos' ""Rambo: First Blood Part I...",0,1,-1,-1,-1
3,In the process of trying to establish the audi...,1,-1,-1,-1,-1
4,"Yeh, I know -- you're quivering with excitemen...",0,-1,-1,-1,-1


In [7]:
train_ds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 26544 entries, 0 to 26543
Data columns (total 6 columns):
 #   Column                                   Non-Null Count  Dtype 
---  ------                                   --------------  ----- 
 0   text                                     26544 non-null  object
 1   label                                    26544 non-null  int64 
 2   facebook/bart-large-mnli                 26544 non-null  int64 
 3   joeddav/xlm-roberta-large-xnli           26544 non-null  int64 
 4   MoritzLaurer/mDeBERTa-v3-base-mnli-xnli  26544 non-null  int64 
 5   BaptisteDoyen/camembert-base-xnli        26544 non-null  int64 
dtypes: int64(5), object(1)
memory usage: 1.4+ MB


In [8]:
valid_ds.head()

Unnamed: 0,text,label,facebook/bart-large-mnli,joeddav/xlm-roberta-large-xnli,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,BaptisteDoyen/camembert-base-xnli
0,<br /><br />When I unsuspectedly rented A Thou...,1,1,-1,-1,-1
1,This is the latest entry in the long series of...,1,-1,1,-1,-1
2,This movie was so frustrating. Everything seem...,0,-1,-1,0,-1
3,"I was truly and wonderfully surprised at ""O' B...",1,1,-1,-1,-1
4,This movie spends most of its time preaching t...,0,-1,-1,-1,-1


In [9]:
valid_ds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 28710 entries, 0 to 28709
Data columns (total 6 columns):
 #   Column                                   Non-Null Count  Dtype 
---  ------                                   --------------  ----- 
 0   text                                     28710 non-null  object
 1   label                                    28710 non-null  int64 
 2   facebook/bart-large-mnli                 28710 non-null  int64 
 3   joeddav/xlm-roberta-large-xnli           28710 non-null  int64 
 4   MoritzLaurer/mDeBERTa-v3-base-mnli-xnli  28710 non-null  int64 
 5   BaptisteDoyen/camembert-base-xnli        28710 non-null  int64 
dtypes: int64(5), object(1)
memory usage: 1.5+ MB


## Analysis

In [10]:
L_train = train_ds[zero_shot_models].to_numpy()
L_valid = valid_ds[zero_shot_models].to_numpy()

In [11]:
from snorkel.labeling import LFAnalysis
from snorkel.labeling.model import LabelModel
from snorkel.labeling import labeling_function

In [12]:
@labeling_function()
def bart_large():
    pass

@labeling_function()
def xlm_roberta():
    pass

@labeling_function()
def mdeberta():
    pass

@labeling_function()
def camembert():
    pass

In [13]:
lfs = [bart_large, xlm_roberta, mdeberta, camembert]

LFAnalysis(L_train, lfs=lfs).lf_summary()

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts
bart_large,0,"[0, 1]",0.28918,0.068603,0.006744
xlm_roberta,1,"[0, 1]",0.05764,0.040763,0.006668
mdeberta,2,"[0, 1]",0.139956,0.048636,0.023357
camembert,3,"[0, 1]",0.096444,0.048712,0.015747


In [14]:
LFAnalysis(L_valid, lfs=lfs).lf_summary(valid_ds.label.values)

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.
bart_large,0,"[0, 1]",0.272971,0.069801,0.007036,6681,1156,0.852495
xlm_roberta,1,"[0, 1]",0.055556,0.038906,0.006409,1092,503,0.684639
mdeberta,2,"[0, 1]",0.141344,0.050923,0.025775,3757,301,0.925826
camembert,3,"[0, 1]",0.098433,0.050017,0.018008,1297,1529,0.458953


## Label Model

In [15]:
label_model = LabelModel(cardinality=2, verbose=True).to('cuda')
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=42)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|          | 0/500 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.061]
INFO:root:[100 epochs]: TRAIN:[loss=0.001]
 31%|███       | 155/500 [00:00<00:00, 1544.90epoch/s]INFO:root:[200 epochs]: TRAIN:[loss=0.001]
INFO:root:[300 epochs]: TRAIN:[loss=0.000]
 78%|███████▊  | 388/500 [00:00<00:00, 2004.45epoch/s]INFO:root:[400 epochs]: TRAIN:[loss=0.000]
100%|██████████| 500/500 [00:00<00:00, 2012.61epoch/s]
INFO:root:Finished Training


In [16]:
valid_gold = valid_ds.label.to_numpy()

In [17]:
label_model.score(L_valid, valid_gold)



{'accuracy': 0.8345207998161343}

In [18]:
train_ds['weak_labels'] = label_model.predict(L_train)
valid_ds['weak_labels'] = label_model.predict(L_valid)

In [19]:
train_ds.weak_labels.value_counts()

-1    14038
 1     8076
 0     4430
Name: weak_labels, dtype: int64

In [20]:
valid_ds.weak_labels.value_counts()

-1    15657
 1     8209
 0     4844
Name: weak_labels, dtype: int64

In [21]:
train_ds = train_ds[train_ds['weak_labels'] != -1]
valid_ds = valid_ds[valid_ds['weak_labels'] != -1]

In [22]:
train_ds.weak_labels.value_counts()

1    8076
0    4430
Name: weak_labels, dtype: int64

In [23]:
valid_ds.weak_labels.value_counts()

1    8209
0    4844
Name: weak_labels, dtype: int64

In [24]:
train_ds[['text', 'label', 'weak_labels']].to_json('./train.json')
valid_ds[['text', 'label', 'weak_labels']].to_json('./valid.json')