# Better not Bigger

A case study with Snorkel: labeling data programmatically with small zero-shot classifiers.

## Dataset

LEDGAR, a contract-provision classification task with 100 labels, from the [LexGLUE](https://huggingface.co/datasets/lex_glue) benchmark.

## Setup

In [None]:
!pip install snorkel transformers datasets -q

## Reading Dataset

In [71]:
from datasets import load_dataset

ledgar = load_dataset('lex_glue', name='ledgar')

ledgar


DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
})

In [72]:
ledgar['train'].features

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(num_classes=100, names=['Adjustments', 'Agreements', 'Amendments', 'Anti-Corruption Laws', 'Applicable Laws', 'Approvals', 'Arbitration', 'Assignments', 'Assigns', 'Authority', 'Authorizations', 'Base Salary', 'Benefits', 'Binding Effects', 'Books', 'Brokers', 'Capitalization', 'Change In Control', 'Closings', 'Compliance With Laws', 'Confidentiality', 'Consent To Jurisdiction', 'Consents', 'Construction', 'Cooperation', 'Costs', 'Counterparts', 'Death', 'Defined Terms', 'Definitions', 'Disability', 'Disclosures', 'Duties', 'Effective Dates', 'Effectiveness', 'Employment', 'Enforceability', 'Enforcements', 'Entire Agreements', 'Erisa', 'Existence', 'Expenses', 'Fees', 'Financial Statements', 'Forfeitures', 'Further Assurances', 'General', 'Governing Laws', 'Headings', 'Indemnifications', 'Indemnity', 'Insurances', 'Integration', 'Intellectual Property', 'Interests', 'Interpretations', 'Jurisdictions', 'Liens', 'Litigations',

In [73]:
labels = ledgar['train'].features['label'].names

print(f'Number of labels = {len(labels)}')

Number of labels = 100


In [74]:
int2str = {i: j for i, j in enumerate(labels)}
str2int = {j: i for i, j in enumerate(labels)}

In [75]:
example = ledgar['train'][1]

print(example['text'])
print(int2str[example['label']])

No ERISA Event has occurred or is reasonably expected to occur that, when taken together with all other such ERISA Events for which liability is reasonably expected to occur, could reasonably be expected to result in a Material Adverse Effect. Neither Borrower nor any ERISA Affiliate maintains or contributes to or has any obligation to maintain or contribute to any Multiemployer Plan or Plan, nor otherwise has any liability under Title IV of ERISA.
Erisa


In [104]:
# Use a much smaller subset: zero-shot inference over 100 candidate labels is slow

ledgar['train'] = ledgar['train'].shuffle(seed=42).select(range(5))
ledgar['validation'] = ledgar['validation'].shuffle(seed=42).select(range(5))

ledgar

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 5
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 5
    })
})
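`shuffle(seed=42).select(range(5))` draws a reproducible random subset: a seeded permutation of the row indices, then the first k of them. A plain-Python sketch of the same idea; it uses `random.Random`, not the generator `datasets` uses internally, so the rows it picks will differ from the notebook's.

```python
import random


def seeded_subset(rows, k, seed=42):
    """Return k rows chosen by a seeded shuffle, mirroring shuffle(seed).select(range(k))."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation every run
    return [rows[i] for i in indices[:k]]


rows = [f"provision {i}" for i in range(100)]
first = seeded_subset(rows, 5)
second = seeded_subset(rows, 5)
print(first == second)  # True: the subset is reproducible
```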

## Dummy Labels

In [105]:
from transformers import pipeline

classifier = pipeline("zero-shot-classification",
                      model="typeform/distilbert-base-uncased-mnli")

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


In [106]:
label = classifier(example['text'], labels)['labels'][0]

In [107]:
str2int[label]

39
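In single-label mode, the `zero-shot-classification` pipeline pairs the text (the NLI premise) with one hypothesis per candidate label, built from its default template `"This example is {}."`, takes each pair's entailment logit, and applies a softmax across all candidate labels; `labels` and `scores` come back sorted by that probability. With 100 labels, every provision therefore costs 100 premise-hypothesis pairs. A pure-Python sketch of the scoring step, with made-up logits standing in for the NLI model:

```python
import math


def zero_shot_scores(entailment_logits):
    """Softmax entailment logits across candidate labels and sort, like the pipeline's output."""
    peak = max(entailment_logits.values())  # subtract the max for numerical stability
    exps = {label: math.exp(logit - peak) for label, logit in entailment_logits.items()}
    total = sum(exps.values())
    ranked = sorted(exps, key=exps.get, reverse=True)
    return {"labels": ranked, "scores": [exps[label] / total for label in ranked]}


template = "This example is {}."
candidates = ["Erisa", "Benefits", "Liens"]
hypotheses = [template.format(c) for c in candidates]  # one NLI pair per label
print(hypotheses[0])  # This example is Erisa.

# Hypothetical logits standing in for the NLI model's entailment outputs
output = zero_shot_scores({"Erisa": 3.0, "Benefits": 1.0, "Liens": -1.0})
print(output["labels"][0], round(sum(output["scores"]), 6))  # Erisa 1.0
```

Because the probability mass is shared among all 100 labels, a top score above the 0.75 threshold used below only occurs when one label clearly dominates.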

In [108]:
from snorkel.preprocess import preprocessor
from snorkel.labeling import labeling_function
from snorkel.labeling import PandasLFApplier

In [114]:
dbert_cls = pipeline("zero-shot-classification",
                     model="typeform/distilbert-base-uncased-mnli")

@preprocessor(memoize=True)
def get_label_dbert(example):
    output = dbert_cls(example['text'], labels)
    label = output['labels'][0]
    score = output['scores'][0]
    example.label_ = label
    example.score_ = score
    return example


@labeling_function(pre=[get_label_dbert])
def label_dbert(example):
    if example.score_ > 0.75:
        return str2int[example.label_]
    return -1  # -1 is Snorkel's ABSTAIN value

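`memoize=True` asks Snorkel to cache each preprocessor's output so that repeated calls on the same data point skip the expensive pipeline call. The sketch below shows the same idea with `functools.lru_cache`, keyed on the text; Snorkel's own cache is keyed on a hashable form of the whole data point, and `slow_zero_shot` is a hypothetical stand-in for a pipeline call.

```python
from functools import lru_cache

calls = {"count": 0}


def slow_zero_shot(text):
    """Hypothetical stand-in for an expensive zero-shot pipeline call."""
    calls["count"] += 1
    return ("Erisa", 0.9) if "ERISA" in text else ("General", 0.2)


@lru_cache(maxsize=None)
def cached_zero_shot(text):
    # Results are cached per distinct text, so each provision is classified once
    return slow_zero_shot(text)


texts = ["No ERISA Event has occurred.", "This Agreement is general.", "No ERISA Event has occurred."]
results = [cached_zero_shot(t) for t in texts]
print(calls["count"])  # 2: the repeated text hit the cache
```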


In [115]:
dbart_129_cls = pipeline("zero-shot-classification",
                         model="valhalla/distilbart-mnli-12-9")

@preprocessor(memoize=True)
def get_label_dbart_129(example):
    output = dbart_129_cls(example['text'], labels)
    label = output['labels'][0]
    score = output['scores'][0]
    example.label_ = label
    example.score_ = score
    return example


@labeling_function(pre=[get_label_dbart_129])
def label_dbart_129(example):
    if example.score_ > 0.75:
        return str2int[example.label_]
    return -1
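Each labeling function votes only when the model's top score exceeds 0.75 and otherwise returns -1, Snorkel's abstain value. A Snorkel-free sketch of that rule, useful for checking the threshold on a few stored outputs before paying for a full `PandasLFApplier` run; the output dicts are hypothetical, while the two ids come from the notebook's `str2int`.

```python
ABSTAIN = -1  # Snorkel's convention for "no vote"


def vote(output, str2int, threshold=0.75):
    """Turn a zero-shot pipeline output into a label id, or ABSTAIN if not confident."""
    if output["scores"][0] > threshold:
        return str2int[output["labels"][0]]
    return ABSTAIN


label_ids = {"Erisa": 39, "Benefits": 12}  # two entries from the notebook's str2int
confident = {"labels": ["Erisa", "Benefits"], "scores": [0.81, 0.19]}
unsure = {"labels": ["Benefits", "Erisa"], "scores": [0.55, 0.45]}
print(vote(confident, label_ids), vote(unsure, label_ids))  # 39 -1
```

Lowering the threshold trades precision for coverage: more provisions receive a vote, but more of those votes are wrong.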

In [116]:
dbart_121_cls = pipeline("zero-shot-classification",
                         model="valhalla/distilbart-mnli-12-1")

@preprocessor(memoize=True)
def get_label_dbart_121(example):
    output = dbart_121_cls(example['text'], labels)
    label = output['labels'][0]
    score = output['scores'][0]
    example.label_ = label
    example.score_ = score
    return example


@labeling_function(pre=[get_label_dbart_121])
def label_dbart_121(example):
    if example.score_ > 0.75:
        return str2int[example.label_]
    return -1
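The three cells above differ only in the model name, so one factory could build all of them. A sketch of that refactor: `make_zero_shot_lf` is a hypothetical helper, and `classify` is any callable returning a pipeline-style dict with `labels` and `scores` (in the notebook, a `pipeline("zero-shot-classification", model=...)` object). A stub classifier keeps the example runnable without downloading models.

```python
from functools import lru_cache

ABSTAIN = -1  # Snorkel's abstain value


def make_zero_shot_lf(classify, labels, str2int, threshold=0.75):
    """Hypothetical helper: build a labeling-function body around one zero-shot classifier."""

    @lru_cache(maxsize=None)  # caches per text, similar in spirit to memoize=True
    def top_prediction(text):
        output = classify(text, labels)
        return output["labels"][0], output["scores"][0]

    def lf(example):
        label, score = top_prediction(example["text"])
        return str2int[label] if score > threshold else ABSTAIN

    return lf


demo_labels = ["Erisa", "Benefits"]
demo_str2int = {name: i for i, name in enumerate(demo_labels)}


def stub_classifier(text, candidate_labels):
    """Stand-in for a Hugging Face zero-shot pipeline, returning the same dict shape."""
    top = "Erisa" if "ERISA" in text else "Benefits"
    rest = [c for c in candidate_labels if c != top]
    return {"labels": [top] + rest, "scores": [0.9] + [0.1 / len(rest)] * len(rest)}


lf_demo = make_zero_shot_lf(stub_classifier, demo_labels, demo_str2int)
print(lf_demo({"text": "No ERISA Event has occurred."}))  # 0
```

In the notebook, something like `make_zero_shot_lf(dbart_129_cls, labels, str2int)` would replace one hand-written preprocessor/LF pair; wrapping the result as `LabelingFunction(name="label_dbart_129", f=...)` from `snorkel.labeling` should give `PandasLFApplier` a named LF to apply.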

In [117]:
train_df = ledgar['train'].to_pandas()
valid_df = ledgar['validation'].to_pandas()

In [None]:
applier = PandasLFApplier([label_dbert, label_dbart_129, label_dbart_121])

L_train = applier.apply(train_df)
L_valid = applier.apply(valid_df)

L_train

100%|██████████| 5/5 [05:18<00:00, 63.69s/it]
100%|██████████| 5/5 [07:15<00:00, 114.30s/it]
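`applier.apply` returns a label matrix with one row per provision and one column per labeling function, where -1 marks an abstain. Snorkel's `LFAnalysis(L=L_train, lfs=...).lf_summary()` reports per-LF statistics such as coverage, overlaps, and conflicts; the sketch below computes coverage and conflict rate by hand on a hypothetical 4×3 matrix to show what those columns measure.

```python
ABSTAIN = -1

# Hypothetical label matrix: rows are provisions, columns are the three LFs
L = [
    [39, 39, -1],
    [-1, -1, -1],
    [12, 7, 12],
    [-1, 20, -1],
]


def lf_coverage(L, j):
    """Fraction of rows where labeling function j did not abstain."""
    return sum(row[j] != ABSTAIN for row in L) / len(L)


def lf_conflicts(L, j):
    """Fraction of rows where LF j voted and some other LF voted for a different label."""
    hits = 0
    for row in L:
        if row[j] == ABSTAIN:
            continue
        others = [v for k, v in enumerate(row) if k != j and v != ABSTAIN]
        hits += any(v != row[j] for v in others)
    return hits / len(L)


coverages = [lf_coverage(L, j) for j in range(3)]
conflicts = [lf_conflicts(L, j) for j in range(3)]
print(coverages)  # [0.5, 0.75, 0.25]
print(conflicts)  # [0.25, 0.25, 0.25]
```

With a 0.75 threshold over 100 labels, low coverage is the main thing to watch: rows where every LF abstains carry no training signal.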

In [65]:
from snorkel.labeling.model import LabelModel

label_model = LabelModel(cardinality=100, verbose=True)
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=42)

INFO:root:Computing O...
INFO:root:Estimating \mu...
INFO:root:[0 epochs]: TRAIN:[loss=0.953]
INFO:root:[100 epochs]: TRAIN:[loss=0.842]
INFO:root:[200 epochs]: TRAIN:[loss=0.744]
INFO:root:[300 epochs]: TRAIN:[loss=0.669]
INFO:root:[400 epochs]: TRAIN:[loss=0.616]
INFO:root:Finished Training
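`LabelModel` learns how much to trust each labeling function from their agreements and disagreements, without using gold labels. The usual baseline is a plain majority vote, which Snorkel ships as `MajorityLabelVoter`; the sketch below implements that baseline in pure Python, abstaining on empty rows and on ties (one of the tie policies Snorkel supports), on a hypothetical label matrix.

```python
from collections import Counter

ABSTAIN = -1


def majority_vote(row):
    """Most common non-abstain label in a row; ABSTAIN if there are no votes or a tie."""
    votes = Counter(v for v in row if v != ABSTAIN)
    if not votes:
        return ABSTAIN
    ranked = votes.most_common(2)
    if len(ranked) == 2 and ranked[0][1] == ranked[1][1]:
        return ABSTAIN  # tie between the top two labels
    return ranked[0][0]


# Hypothetical label matrix with one row per provision and one column per LF
L = [[39, 39, -1], [-1, -1, -1], [12, 7, 12], [7, 20, -1]]
print([majority_vote(row) for row in L])  # [39, -1, 12, -1]
```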


In [68]:
import numpy as np
gold = np.array(ledgar['validation'][:]['label'])

In [69]:
# Accuracy of the label model on the validation label matrix
label_model.score(L_valid, gold)



{'accuracy': 0.5}
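With the default `tie_break_policy="abstain"`, `LabelModel.score` computes metrics only over data points where the model made a non-abstain prediction, so on a five-row validation set the denominator can be smaller than five. A pure-Python sketch of that metric with hypothetical predictions:

```python
ABSTAIN = -1


def accuracy_ignoring_abstains(preds, golds):
    """Accuracy over points with a non-abstain prediction; None if every point abstained."""
    scored = [(p, g) for p, g in zip(preds, golds) if p != ABSTAIN]
    if not scored:
        return None
    return sum(p == g for p, g in scored) / len(scored)


golds = [39, 12, 7, 3, 1]
preds = [39, -1, 5, -1, -1]  # hypothetical: two votes, one correct
print(accuracy_ignoring_abstains(preds, golds))  # 0.5
```

Reporting the number of scored rows next to the accuracy makes small-sample results like this one easier to interpret.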