# Better not Bigger

A case study from Snorkel.

## Dataset

LEDGAR, from the [LexGLUE](https://huggingface.co/datasets/lex_glue) benchmark.

## Setup

In [1]:
!pip install snorkel transformers datasets -q


## Reading Dataset

In [2]:
from datasets import load_dataset

ledgar = load_dataset('lex_glue', name='ledgar')

ledgar

Downloading and preparing dataset lex_glue/ledgar (download: 15.50 MiB, generated: 54.69 MiB, post-processed: Unknown size, total: 70.19 MiB) to /root/.cache/huggingface/datasets/lex_glue/ledgar/1.0.0/c3c0bd7433b636dc39ae49a84dc401190c73156617efc415b04e9835a93a7043...

Dataset lex_glue downloaded and prepared to /root/.cache/huggingface/datasets/lex_glue/ledgar/1.0.0/c3c0bd7433b636dc39ae49a84dc401190c73156617efc415b04e9835a93a7043. Subsequent calls will reuse this data.

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
})

In [3]:
ledgar['train'].features

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(num_classes=100, names=['Adjustments', 'Agreements', 'Amendments', 'Anti-Corruption Laws', 'Applicable Laws', 'Approvals', 'Arbitration', 'Assignments', 'Assigns', 'Authority', 'Authorizations', 'Base Salary', 'Benefits', 'Binding Effects', 'Books', 'Brokers', 'Capitalization', 'Change In Control', 'Closings', 'Compliance With Laws', 'Confidentiality', 'Consent To Jurisdiction', 'Consents', 'Construction', 'Cooperation', 'Costs', 'Counterparts', 'Death', 'Defined Terms', 'Definitions', 'Disability', 'Disclosures', 'Duties', 'Effective Dates', 'Effectiveness', 'Employment', 'Enforceability', 'Enforcements', 'Entire Agreements', 'Erisa', 'Existence', 'Expenses', 'Fees', 'Financial Statements', 'Forfeitures', 'Further Assurances', 'General', 'Governing Laws', 'Headings', 'Indemnifications', 'Indemnity', 'Insurances', 'Integration', 'Intellectual Property', 'Interests', 'Interpretations', 'Jurisdictions', 'Liens', 'Litigations',

In [4]:
labels = ledgar['train'].features['label'].names

print(f'Number of labels = {len(labels)}')

Number of labels = 100


In [5]:
int2str = {i: j for i, j in enumerate(labels)}
str2int = {j: i for i, j in enumerate(labels)}
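The two dictionaries are inverse lookups between class indices and class names, which only works because the names are unique. A quick round-trip check, using a few of the LEDGAR names from the feature listing above purely for illustration, confirms they agree. `datasets.ClassLabel` also exposes equivalent `int2str` and `str2int` methods.

```python
# A handful of LEDGAR class names copied from the feature listing above,
# used here only to illustrate the mapping.
labels = ['Adjustments', 'Agreements', 'Amendments', 'Anti-Corruption Laws']

int2str = {i: name for i, name in enumerate(labels)}
str2int = {name: i for i, name in enumerate(labels)}

# Names must be unique, or str2int silently drops duplicates.
names_unique = len(str2int) == len(labels)

# Every index maps to a name and back to the same index.
round_trip_ok = all(str2int[int2str[i]] == i for i in int2str)
print(names_unique, round_trip_ok)  # True True
```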

In [6]:
example = ledgar['train'][1]

print(example['text'])
print(int2str[example['label']])

No ERISA Event has occurred or is reasonably expected to occur that, when taken together with all other such ERISA Events for which liability is reasonably expected to occur, could reasonably be expected to result in a Material Adverse Effect. Neither Borrower nor any ERISA Affiliate maintains or contributes to or has any obligation to maintain or contribute to any Multiemployer Plan or Plan, nor otherwise has any liability under Title IV of ERISA.
Erisa


In [7]:
# Work with a much smaller subset: 1,000 training and 200 validation examples

ledgar['train'] = ledgar['train'].shuffle(seed=42).select(range(1000))
ledgar['validation'] = ledgar['validation'].shuffle(seed=42).select(range(200))

ledgar

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 200
    })
})
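With 100 classes and only 1,000 training examples, the subset averages ten examples per class, and rarer classes can end up with very few or none at all. A minimal sketch of a per-class coverage check, assuming the labels are available as a plain list of integers; `class_coverage` is a hypothetical helper and the toy list below is synthetic, not LEDGAR data.

```python
from collections import Counter

def class_coverage(label_ids, num_classes):
    """Count examples per class and list the classes with no examples."""
    counts = Counter(label_ids)
    # Counter returns 0 for unseen keys without inserting them.
    missing = [c for c in range(num_classes) if counts[c] == 0]
    return counts, missing

# Synthetic stand-in for ledgar['train']['label'] (not real data).
toy_labels = [0, 0, 1, 3, 3, 3]
counts, missing = class_coverage(toy_labels, num_classes=5)
print(counts.most_common(1))  # [(3, 3)]
print(missing)                # [2, 4]
```

In the notebook this could be called as `class_coverage(ledgar['train']['label'], len(labels))`.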

## Dummy Labels

In [8]:
from transformers import pipeline

classifier = pipeline("zero-shot-classification",
                      model="typeform/distilbert-base-uncased-mnli")

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.

In [9]:
label = classifier(example['text'], labels)['labels'][0]

In [10]:
str2int[label]

39
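The zero-shot pipeline returns a dict holding the input `sequence`, the candidate `labels` sorted from most to least likely, and the matching `scores`, so `['labels'][0]` is the top prediction. In the default single-label mode the scores are a softmax over all candidates and sum to 1, so with 100 candidate labels a top score above 0.5 is a fairly strict bar. The dict below is hand-written in that shape; its numbers are illustrative, not real model output.

```python
# Hand-written example in the shape returned by the zero-shot pipeline;
# the numbers are illustrative, not real model output.
output = {
    'sequence': 'No ERISA Event has occurred ...',
    'labels': ['Erisa', 'Benefits', 'Compliance With Laws'],
    'scores': [0.42, 0.35, 0.23],
}

# Labels are sorted by descending score, so index 0 is the top prediction.
top_label, top_score = output['labels'][0], output['scores'][0]

# In single-label mode the scores form a softmax over all candidates.
total = sum(output['scores'])
print(top_label, top_score)  # Erisa 0.42
print(round(total, 6))       # 1.0
```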

In [11]:
from snorkel.preprocess import preprocessor
from snorkel.labeling import labeling_function
from snorkel.labeling import PandasLFApplier

In [12]:
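dbert_cls = pipeline("zero-shot-classification",
                     model="typeform/distilbert-base-uncased-mnli",
                     device=0)  # device=0 runs the model on the first GPU

# Preprocessor: run the zero-shot model once per example (memoized) and
# attach its top label and score to the data point.
@preprocessor(memoize=True)
def get_label_dbert(example):
    output = dbert_cls(example['text'], labels)
    # The pipeline returns candidate labels sorted by descending score.
    example.label_ = output['labels'][0]
    example.score_ = output['scores'][0]
    return example


# Labeling function: vote for the top label only when its score exceeds
# 0.5; otherwise abstain with -1.
@labeling_function(pre=[get_label_dbert])
def label_dbert(example):
    if example.score_ > 0.5:
        return str2int[example.label_]
    else:
        return -1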

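Each labeling function votes for the top zero-shot label only when its score exceeds 0.5 and otherwise returns -1, which Snorkel treats as an abstain. Pulled out as a plain function (the name `vote_from_output` is hypothetical) and exercised on made-up pipeline outputs, the rule looks like this:

```python
ABSTAIN = -1  # Snorkel's convention for "this LF has no opinion"

def vote_from_output(output, str2int, threshold=0.5):
    """Return the class index of the top zero-shot label, or ABSTAIN
    if its score does not exceed the threshold."""
    label, score = output['labels'][0], output['scores'][0]
    return str2int[label] if score > threshold else ABSTAIN

# Indices match the LEDGAR label list ('Benefits' = 12, 'Erisa' = 39).
str2int = {'Benefits': 12, 'Erisa': 39}

confident = {'labels': ['Erisa', 'Benefits'], 'scores': [0.81, 0.19]}
unsure = {'labels': ['Benefits', 'Erisa'], 'scores': [0.50, 0.50]}

print(vote_from_output(confident, str2int))  # 39
print(vote_from_output(unsure, str2int))     # -1
```

Because the comparison is strict, a top score of exactly 0.5 abstains.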

In [13]:
dbart_129_cls = pipeline("zero-shot-classification",
                         model="valhalla/distilbart-mnli-12-9", device=0)

@preprocessor(memoize=True)
def get_label_dbart_129(example):
    output = dbart_129_cls(example['text'], labels)
    label = output['labels'][0]
    score = output['scores'][0]
    example.label_ = label
    example.score_ = score
    return example


@labeling_function(pre=[get_label_dbart_129])
def label_dbart_129(example):
    if example.score_ > 0.5:
        return str2int[example.label_]
    else:
        return -1

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 15.74 GiB total capacity; 1.46 GiB already allocated; 17.56 MiB free; 1.48 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
dbart_121_cls = pipeline("zero-shot-classification",
                         model="valhalla/distilbart-mnli-12-1", device=0)

@preprocessor(memoize=True)
def get_label_dbart_121(example):
    output = dbart_121_cls(example['text'], labels)
    label = output['labels'][0]
    score = output['scores'][0]
    example.label_ = label
    example.score_ = score
    return example


@labeling_function(pre=[get_label_dbart_121])
def label_dbart_121(example):
    if example.score_ > 0.5:
        return str2int[example.label_]
    else:
        return -1
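The three preprocessor/labeling-function pairs differ only in the pipeline they call. One way to avoid the copy-paste, sketched here in plain Python rather than with Snorkel's decorators, is a factory that closes over a classifier and caches its output per text, similar in spirit to `memoize=True`. `make_zero_shot_lf` and `fake_classifier` are hypothetical names; in the notebook the callable would be one of the `pipeline` objects, and the returned function (which takes raw text rather than a data point) would still need wrapping as a Snorkel labeling function, for example with `snorkel.labeling.LabelingFunction`.

```python
from functools import lru_cache

ABSTAIN = -1

def make_zero_shot_lf(classify, candidate_labels, str2int, threshold=0.5):
    """Build an LF-style function around a zero-shot classifier.

    `classify(text, candidate_labels)` must return a dict with
    'labels' and 'scores' sorted by descending score.
    """
    @lru_cache(maxsize=None)
    def top_prediction(text):
        # Cached per text, so each model runs at most once per example.
        output = classify(text, candidate_labels)
        return output['labels'][0], output['scores'][0]

    def lf(text):
        label, score = top_prediction(text)
        return str2int[label] if score > threshold else ABSTAIN

    return lf

calls = []

def fake_classifier(text, candidate_labels):
    # Stand-in for a transformers pipeline; records each invocation.
    calls.append(text)
    return {'labels': ['Erisa', 'Benefits'], 'scores': [0.9, 0.1]}

lf = make_zero_shot_lf(fake_classifier, ['Erisa', 'Benefits'],
                       {'Erisa': 39, 'Benefits': 12})
print(lf('ERISA clause'), lf('ERISA clause'), len(calls))  # 39 39 1
```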

In [None]:
train_df = ledgar['train'].to_pandas()
valid_df = ledgar['validation'].to_pandas()

In [None]:
applier = PandasLFApplier([label_dbert, label_dbart_129, label_dbart_121])

L_train = applier.apply(train_df)
L_valid = applier.apply(valid_df)

L_valid[:10]
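`applier.apply` returns a label matrix with one row per example and one column per labeling function, each entry a class index or -1 for abstain. Snorkel's `LFAnalysis(L, lfs).lf_summary()` reports per-LF coverage, overlaps and conflicts; the NumPy sketch below computes those rates by hand on a made-up 4×3 matrix so the definitions are explicit. `lf_stats` is a hypothetical helper, and all rates are fractions of all rows.

```python
import numpy as np

ABSTAIN = -1

def lf_stats(L):
    """Per-LF coverage, overlap and conflict rates for a label matrix L
    (rows = examples, columns = LFs, ABSTAIN = no vote)."""
    voted = L != ABSTAIN                   # which LFs voted on each row
    multi = voted.sum(axis=1) > 1          # rows with two or more votes
    # A row is conflicted if its non-abstain votes are not all the same.
    conflicted = np.array([len(set(row[row != ABSTAIN])) > 1 for row in L])
    n = L.shape[0]
    coverage = voted.sum(axis=0) / n
    overlaps = (voted & multi[:, None]).sum(axis=0) / n
    conflicts = (voted & conflicted[:, None]).sum(axis=0) / n
    return coverage, overlaps, conflicts

# Made-up votes from three LFs on four examples.
L = np.array([
    [39, 39, -1],
    [12, -1, -1],
    [-1, 39, 12],
    [-1, -1, -1],
])
coverage, overlaps, conflicts = lf_stats(L)
print(coverage.tolist())   # [0.5, 0.5, 0.25]
print(overlaps.tolist())   # [0.25, 0.5, 0.25]
print(conflicts.tolist())  # [0.0, 0.25, 0.25]
```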

In [None]:
from snorkel.labeling.model import LabelModel

label_model = LabelModel(cardinality=len(labels), device='cuda', verbose=True)
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=42)
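A common sanity check for the `LabelModel` is a plain majority vote over the non-abstaining labeling functions; Snorkel ships one as `snorkel.labeling.model.MajorityLabelVoter`. The NumPy sketch below (hypothetical helper `majority_vote`) spells out that rule, abstaining on rows with no votes or with a tie for the most common label, which is one reasonable tie policy among several.

```python
import numpy as np
from collections import Counter

ABSTAIN = -1

def majority_vote(L):
    """Predict one label per row of L by majority over non-abstain votes.

    Rows with no votes, or with a tie for the most common label, abstain.
    """
    preds = []
    for row in L:
        ranked = Counter(int(v) for v in row if v != ABSTAIN).most_common()
        tied = len(ranked) > 1 and ranked[0][1] == ranked[1][1]
        preds.append(ABSTAIN if not ranked or tied else ranked[0][0])
    return np.array(preds)

L = np.array([
    [39, 39, 12],   # 39 wins 2-1
    [12, -1, -1],   # single vote
    [39, 12, -1],   # tie -> abstain
    [-1, -1, -1],   # no votes -> abstain
])
print(majority_vote(L).tolist())  # [39, 12, -1, -1]
```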

In [None]:
import numpy as np
gold = np.array(ledgar['validation'][:]['label'])

In [None]:
label_model.score(L_valid, gold)
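`label_model.score` returns a dict of metrics (accuracy by default). When reading that number it helps to separate how often the model commits to a label from how accurate it is when it does. The sketch below computes both, treating -1 as an abstain; `coverage_and_accuracy` is a hypothetical helper and the arrays are made up. In the notebook the predictions could come from `label_model.predict(L_valid)` and the gold labels from `gold`.

```python
import numpy as np

ABSTAIN = -1

def coverage_and_accuracy(preds, gold):
    """Fraction of examples labeled, and accuracy on just those examples."""
    preds, gold = np.asarray(preds), np.asarray(gold)
    labeled = preds != ABSTAIN
    coverage = labeled.mean()
    if labeled.any():
        accuracy = (preds[labeled] == gold[labeled]).mean()
    else:
        accuracy = float('nan')  # nothing labeled, accuracy undefined
    return float(coverage), float(accuracy)

preds = [39, -1, 12, 12]   # made-up predictions
gold = [39, 39, 12, 5]     # made-up gold labels
print(coverage_and_accuracy(preds, gold))  # (0.75, 0.6666666666666666)
```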