# Better not Bigger

A case study using Snorkel to combine the predictions of several zero-shot classifiers into weak labels.

## Dataset

The IMDb dataset from Hugging Face [datasets](https://huggingface.co/datasets/imdb).

## Setup

In [1]:
# !pip install snorkel transformers datasets -q

## Loading saved data

In [2]:
import glob
import pandas as pd
import numpy as np

import datasets
from datasets import concatenate_datasets, load_dataset, Dataset

datasets.logging.set_verbosity_error()

In [3]:
zero_shot_models = ["facebook/bart-large-mnli", 
                    "joeddav/xlm-roberta-large-xnli", 
                    "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"] 
                    # "BaptisteDoyen/camembert-base-xnli"]

In [4]:
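train_dfs = []
valid_dfs = []
for model in zero_shot_models:
    # Each model's saved predictions live under imdb_data/<model>/.
    # A single JSON file passed as data_files is exposed as the 'train' split,
    # which is why split='train' is also used for test.json.
    train_dfs.append(load_dataset('json',
                                  data_files=f'imdb_data/{model}/train.json',
                                  split='train').to_pandas())
    valid_dfs.append(load_dataset('json',
                                  data_files=f'imdb_data/{model}/test.json',
                                  split='train').to_pandas())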

In [5]:
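# Join the per-model prediction columns on the shared text/label keys.
# Note: if the same review text appears more than once, the rows match
# many-to-many and the merged frame can end up with more rows than the inputs.
train_ds = train_dfs[0].copy()

for df in train_dfs[1:]:
    train_ds = train_ds.merge(right=df, on=['text', 'label'])

valid_ds = valid_dfs[0].copy()

for df in valid_dfs[1:]:
    valid_ds = valid_ds.merge(right=df, on=['text', 'label'])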
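
The row counts reported by `info()` further down (25,624 and 26,368) exceed the 25,000 reviews in each IMDb split. One plausible explanation is that some review texts occur more than once, so the inner merge on `['text', 'label']` matches them many-to-many. The sketch below reproduces the effect on toy data; the frames and the column names `model_a` and `model_b` are hypothetical, and the de-duplication step is only one possible guard, not what the notebook does.

```python
import pandas as pd

# Toy stand-ins for two models' prediction files (hypothetical data).
# The review "same text" appears twice in each frame.
left = pd.DataFrame({
    "text": ["same text", "same text", "unique text"],
    "label": [1, 1, 0],
    "model_a": [1, 1, 0],
})
right = pd.DataFrame({
    "text": ["same text", "same text", "unique text"],
    "label": [1, 1, 0],
    "model_b": [1, 0, 0],
})

# Inner merge on (text, label): the duplicated key matches 2 x 2 = 4 times,
# plus 1 match for the unique row.
merged = left.merge(right=right, on=["text", "label"])
print(len(left), len(merged))  # 3 5

# One possible guard: keep a single row per key before merging.
keys = ["text", "label"]
deduped = left.drop_duplicates(keys).merge(right.drop_duplicates(keys), on=keys)
print(len(deduped))  # 2
```

Dropping duplicates keeps only the first prediction for a repeated review, which may or may not be acceptable depending on how the prediction files were produced.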

In [6]:
train_ds.head()

| | text | label | facebook/bart-large-mnli | joeddav/xlm-roberta-large-xnli | MoritzLaurer/mDeBERTa-v3-base-mnli-xnli |
|---|---|---|---|---|---|
| 0 | There is no relation at all between Fortier an... | 1 | 1 | 1 | 1 |
| 1 | This movie is a great. The plot is very true t... | 1 | 1 | 1 | 1 |
| 2 | George P. Cosmatos' "Rambo: First Blood Part I... | 0 | 1 | 1 | 1 |
| 3 | In the process of trying to establish the audi... | 1 | 1 | 1 | 0 |
| 4 | Yeh, I know -- you're quivering with excitemen... | 0 | 1 | 1 | 0 |


In [7]:
train_ds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 25624 entries, 0 to 25623
Data columns (total 5 columns):
 #   Column                                   Non-Null Count  Dtype 
---  ------                                   --------------  ----- 
 0   text                                     25624 non-null  object
 1   label                                    25624 non-null  int64 
 2   facebook/bart-large-mnli                 25624 non-null  int64 
 3   joeddav/xlm-roberta-large-xnli           25624 non-null  int64 
 4   MoritzLaurer/mDeBERTa-v3-base-mnli-xnli  25624 non-null  int64 
dtypes: int64(4), object(1)
memory usage: 1.2+ MB


In [8]:
valid_ds.head()

| | text | label | facebook/bart-large-mnli | joeddav/xlm-roberta-large-xnli | MoritzLaurer/mDeBERTa-v3-base-mnli-xnli |
|---|---|---|---|---|---|
| 0 | `<br /><br />When I unsuspectedly rented A Thou...` | 1 | 1 | 1 | 0 |
| 1 | This is the latest entry in the long series of... | 1 | 1 | 1 | 0 |
| 2 | This movie was so frustrating. Everything seem... | 0 | 1 | 1 | 0 |
| 3 | I was truly and wonderfully surprised at "O' B... | 1 | 1 | 1 | 1 |
| 4 | This movie spends most of its time preaching t... | 0 | 1 | 1 | 0 |


In [9]:
valid_ds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 26368 entries, 0 to 26367
Data columns (total 5 columns):
 #   Column                                   Non-Null Count  Dtype 
---  ------                                   --------------  ----- 
 0   text                                     26368 non-null  object
 1   label                                    26368 non-null  int64 
 2   facebook/bart-large-mnli                 26368 non-null  int64 
 3   joeddav/xlm-roberta-large-xnli           26368 non-null  int64 
 4   MoritzLaurer/mDeBERTa-v3-base-mnli-xnli  26368 non-null  int64 
dtypes: int64(4), object(1)
memory usage: 1.2+ MB


## Analysis

In [10]:
L_train = train_ds[zero_shot_models].to_numpy()
L_valid = valid_ds[zero_shot_models].to_numpy()

In [11]:
from snorkel.labeling import LFAnalysis
from snorkel.labeling.model import LabelModel
from snorkel.labeling import labeling_function

In [12]:
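# Placeholder labeling functions. The zero-shot predictions are already stored
# in L_train / L_valid, so these are never applied to data; they only supply
# readable names for the LFAnalysis summary.
@labeling_function()
def bart_large(x):
    pass

@labeling_function()
def xlm_roberta(x):
    pass

@labeling_function()
def mdeberta(x):
    pass

# @labeling_function()
# def camembert(x):
#     pass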

In [13]:
# Keep the same order as zero_shot_models so names line up with the columns of L_train.
lfs = [bart_large, xlm_roberta, mdeberta]

LFAnalysis(L_train, lfs=lfs).lf_summary()

| | j | Polarity | Coverage | Overlaps | Conflicts |
|---|---|---|---|---|---|
| bart_large | 0 | [0, 1] | 1.0 | 1.0 | 0.688885 |
| xlm_roberta | 1 | [0, 1] | 1.0 | 1.0 | 0.688885 |
| mdeberta | 2 | [0, 1] | 1.0 | 1.0 | 0.688885 |
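
All three labelers report the same Conflicts value. With full coverage and binary labels, any row that is not unanimous puts every labeler in disagreement with at least one other, so each per-labeler conflict rate collapses to the fraction of non-unanimous rows. The sketch below computes that fraction with NumPy on a toy label matrix; `conflict_fraction` and `L_toy` are illustrative names, it assumes no abstains, and it is not Snorkel's implementation.

```python
import numpy as np

def conflict_fraction(L: np.ndarray) -> float:
    """Fraction of rows whose votes are not all equal (assumes no abstains)."""
    # Compare every column against the first one; any mismatch marks the row.
    disagree = (L != L[:, [0]]).any(axis=1)
    return float(disagree.mean())

# Toy label matrix: 4 examples, 3 labelers (hypothetical data).
L_toy = np.array([
    [1, 1, 1],  # unanimous
    [1, 1, 0],  # conflict
    [0, 0, 0],  # unanimous
    [0, 1, 0],  # conflict
])
print(conflict_fraction(L_toy))  # 0.5
```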


In [14]:
LFAnalysis(L_valid, lfs=lfs).lf_summary(valid_ds.label.values)

| | j | Polarity | Coverage | Overlaps | Conflicts | Correct | Incorrect | Emp. Acc. |
|---|---|---|---|---|---|---|---|---|
| bart_large | 0 | [0, 1] | 1.0 | 1.0 | 0.67745 | 16243 | 10125 | 0.616012 |
| xlm_roberta | 1 | [0, 1] | 1.0 | 1.0 | 0.67745 | 15485 | 10883 | 0.587265 |
| mdeberta | 2 | [0, 1] | 1.0 | 1.0 | 0.67745 | 19204 | 7164 | 0.728307 |
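
With full coverage, the empirical accuracy of a labeler is Correct / (Correct + Incorrect). The sketch below checks that arithmetic for the `bart_large` row and shows how the same per-labeler number could be computed directly from a label matrix; `per_lf_accuracy`, `L_toy`, and `gold_toy` are illustrative names, and the toy data is made up.

```python
import numpy as np

# Arithmetic check against the bart_large row of the summary.
print(round(16243 / (16243 + 10125), 6))  # 0.616012

def per_lf_accuracy(L: np.ndarray, gold: np.ndarray) -> np.ndarray:
    """Column-wise accuracy of a label matrix against gold labels (no abstains)."""
    return (L == gold[:, None]).mean(axis=0)

# Toy label matrix: 4 examples, 3 labelers (hypothetical data).
L_toy = np.array([
    [1, 1, 0],
    [0, 1, 1],
    [0, 0, 0],
    [1, 1, 1],
])
gold_toy = np.array([1, 0, 0, 1])

acc = per_lf_accuracy(L_toy, gold_toy)
print(acc.tolist())  # [1.0, 0.75, 0.5]
```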


## Label Model

In [15]:
label_model = LabelModel(cardinality=2, verbose=True).to('cuda')
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=42)

INFO:root:Computing O...
INFO:root:Estimating \mu...
INFO:root:[0 epochs]: TRAIN:[loss=3.556]
INFO:root:[100 epochs]: TRAIN:[loss=0.002]
INFO:root:[200 epochs]: TRAIN:[loss=0.000]
INFO:root:[300 epochs]: TRAIN:[loss=0.000]
INFO:root:[400 epochs]: TRAIN:[loss=0.000]
INFO:root:Finished Training


In [16]:
valid_gold = valid_ds.label.to_numpy()

In [17]:
label_model.score(L_valid, valid_gold)



{'accuracy': 0.6753640776699029}
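
A useful reference point for the label model score is an unweighted majority vote over the same label matrix. The sketch below implements one in plain NumPy on toy data; `majority_vote`, `L_toy`, and `gold_toy` are illustrative names, and the function assumes binary labels, no abstains, and an odd number of labelers so that ties cannot occur. Snorkel also provides a `MajorityLabelVoter` in `snorkel.labeling.model`, which is not used here.

```python
import numpy as np

def majority_vote(L: np.ndarray) -> np.ndarray:
    """Unweighted majority vote for binary labels in {0, 1} without abstains."""
    # The row mean is the share of labelers voting 1; above one half means 1 wins.
    return (L.mean(axis=1) > 0.5).astype(int)

# Toy label matrix: 4 examples, 3 labelers (hypothetical data).
L_toy = np.array([
    [1, 1, 0],
    [0, 1, 1],
    [0, 0, 0],
    [1, 1, 1],
])
gold_toy = np.array([1, 0, 0, 1])

preds = majority_vote(L_toy)
print(preds.tolist())                     # [1, 1, 0, 1]
print(float((preds == gold_toy).mean()))  # 0.75
```

Applied to `L_valid` and `valid_gold`, the same two lines would give a baseline accuracy to set next to the label model score and the per-labeler accuracies in the summary table.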

In [18]:
train_ds['weak_labels'] = label_model.predict(L_train)
valid_ds['weak_labels'] = label_model.predict(L_valid)

In [19]:
train_ds.weak_labels.value_counts()

1    18673
0     6951
Name: weak_labels, dtype: int64

In [20]:
valid_ds.weak_labels.value_counts()

1    19278
0     7090
Name: weak_labels, dtype: int64

In [21]:
train_ds[['text', 'label', 'weak_labels']].to_json('./train.json')
valid_ds[['text', 'label', 'weak_labels']].to_json('./valid.json')
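
The saved files can be read back with `pd.read_json`, since `DataFrame.to_json` with default arguments writes a column-oriented JSON object that `read_json` understands. The sketch below round-trips a small frame through a temporary directory; the toy rows and the temporary path are assumptions for illustration.

```python
import os
import tempfile

import pandas as pd

# Toy frame with the same three columns the notebook saves (hypothetical rows).
df = pd.DataFrame({
    "text": ["good film", "bad film"],
    "label": [1, 0],
    "weak_labels": [1, 1],
})

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.json")
    df.to_json(path)                # default orient: {column: {index: value}}
    restored = pd.read_json(path)   # reads the same layout back

print(list(restored.columns))
print(restored["weak_labels"].tolist())
```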