# Better not Bigger

Case study from snorkel.

## Dataset

LEDGAR from [LEX GLUE](https://huggingface.co/datasets/lex_glue)

## Setup

In [161]:
!pip install snorkel transformers datasets -q

[0m

## Reading Dataset

### Loading saved data

In [157]:
import glob
import pandas as pd
import numpy as np

import datasets
from datasets import concatenate_datasets, load_dataset, Dataset

datasets.logging.set_verbosity_error()

In [131]:
zero_shot_models = ["facebook/bart-large-mnli", 
                    "joeddav/xlm-roberta-large-xnli", 
                    "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", 
                    "BaptisteDoyen/camembert-base-xnli"]

batch_sizes = list(range(0, 5500+1, 500))
inds = len(batch_sizes) - 1
batch_sizes = [(batch_sizes[i], batch_sizes[i + 1]) for i in range(inds)]
batch_sizes

[(0, 500),
 (500, 1000),
 (1000, 1500),
 (1500, 2000),
 (2000, 2500),
 (2500, 3000),
 (3000, 3500),
 (3500, 4000),
 (4000, 4500),
 (4500, 5000),
 (5000, 5500)]

In [132]:
train_dfs = []
valid_dfs = []
for model in zero_shot_models:
    
    train_dfs_ = []
    valid_dfs_ = []

    for batch in batch_sizes:
        train_dfs_.append(load_dataset('json', 
                                       data_files=f'data/'
                                       f'{model}/train_{batch[0]}_{batch[1]}.json', 
                                       split='train').to_pandas())
        valid_dfs_.append(load_dataset('json', 
                                       data_files=f'data/'
                                       f'{model}/train_{batch[0]}_{batch[1]}.json', 
                                       split='train').to_pandas())
    train_dfs.append(pd.concat(train_dfs_, axis=0))
    valid_dfs.append(pd.concat(valid_dfs_, axis=0))

In [136]:
train_ds = train_dfs[0].copy()

for df in train_dfs[1:]:
    train_ds = train_ds.merge(right=df, on=['text', 'label'])

valid_ds = valid_dfs[0].copy()

for df in valid_dfs[1:]:
    valid_ds = valid_ds.merge(right=df, on=['text', 'label'])

In [137]:
train_ds.head()

Unnamed: 0,text,label,facebook/bart-large-mnli,joeddav/xlm-roberta-large-xnli,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,BaptisteDoyen/camembert-base-xnli
0,Except as otherwise set forth in this Debentur...,97,-1,66,-1,-1
1,No ERISA Event has occurred or is reasonably e...,39,-1,90,-1,-1
2,This Amendment may be executed by one or more ...,26,-1,2,-1,-1
3,"From time to time, as and when required by the...",45,-1,-1,-1,-1
4,"Commencing March 7, 2016 and during the Employ...",11,-1,-1,-1,-1


In [138]:
valid_ds.head()

Unnamed: 0,text,label,facebook/bart-large-mnli,joeddav/xlm-roberta-large-xnli,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,BaptisteDoyen/camembert-base-xnli
0,Except as otherwise set forth in this Debentur...,97,-1,66,-1,-1
1,No ERISA Event has occurred or is reasonably e...,39,-1,90,-1,-1
2,This Amendment may be executed by one or more ...,26,-1,2,-1,-1
3,"From time to time, as and when required by the...",45,-1,-1,-1,-1
4,"Commencing March 7, 2016 and during the Employ...",11,-1,-1,-1,-1


### Load test data

In [147]:
from datasets import load_dataset

test_ds = load_dataset('lex_glue', name='ledgar', split='test').to_pandas()

test_ds.head()

Unnamed: 0,text,label
0,Executive agrees to be employed with the Compa...,35
1,Participant agrees that in the event of a brea...,75
2,"For purposes of this Amendment, all terms used...",55
3,"So long as this as this Note is outstanding, u...",16
4,"As of the Closing Date, Schedule 5.12 sets for...",83


In [148]:
labels = test_ds.label.values

print(f'Number of labels = {len(labels)}')

Number of labels = 10000


In [149]:
int2str = {i: j for i, j in enumerate(labels)}
str2int = {j: i for i, j in enumerate(labels)}

## Label Model

In [156]:
L_train = train_ds[zero_shot_models].to_numpy()
L_valid = valid_ds[zero_shot_models].to_numpy()

L_valid[:10]

array([[-1, 66, -1, -1],
       [-1, 90, -1, -1],
       [-1,  2, -1, -1],
       [-1, -1, -1, -1],
       [-1, -1, -1, -1],
       [-1,  1, -1, -1],
       [-1, -1, -1, -1],
       [-1, -1, -1, -1],
       [-1, -1, 98, -1],
       [-1,  1,  1, -1]])

In [162]:
from snorkel.labeling.model import LabelModel

label_model = LabelModel(cardinality=100, verbose=True).to('cuda')
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=42)

INFO:root:Computing O...
INFO:root:Estimating \mu...
  0%|          | 0/500 [00:00<?, ?epoch/s]INFO:root:[0 epochs]: TRAIN:[loss=0.030]
 19%|█▉        | 97/500 [00:00<00:01, 244.16epoch/s]INFO:root:[100 epochs]: TRAIN:[loss=0.022]
 35%|███▌      | 175/500 [00:00<00:01, 248.73epoch/s]INFO:root:[200 epochs]: TRAIN:[loss=0.016]
 56%|█████▋    | 282/500 [00:01<00:00, 255.24epoch/s]INFO:root:[300 epochs]: TRAIN:[loss=0.011]
 77%|███████▋  | 387/500 [00:01<00:00, 242.11epoch/s]INFO:root:[400 epochs]: TRAIN:[loss=0.008]
100%|██████████| 500/500 [00:01<00:00, 251.50epoch/s]
INFO:root:Finished Training


In [163]:
valid_gold = valid_ds.label.to_numpy()

In [165]:
label_model.score(L_valid, valid_gold)



{'accuracy': 0.36574420344053854}

## Training on ML Model