# An In-depth Evaluation of Approaches to Text Classification (IDEATC)

## IV. Zero-shot Experiments

_This notebook is used for experimenting with DeBERTa for zero-shot classification._

### Libraries

In [1]:
# standard library
import os
from pathlib import Path

# deep learning
import datasets
import pandas as pd

# local packages
import src

# other settings
LOAD_PATH_DATASET = Path(os.pardir, 'data', 'processed')
SAVE_PATH_RESULTS = Path(os.pardir, 'data', 'results')
CHECKPOINTS = [
    'cross-encoder/nli-deberta-v3-xsmall',
    'cross-encoder/nli-deberta-v3-small',
    'cross-encoder/nli-deberta-v3-base',
]

### I. Zero-shot Baselines

In [3]:
for checkpoint in CHECKPOINTS:
    for path in LOAD_PATH_DATASET.glob('*processed*'):
        print(path.name)
        dataset = datasets.load_from_disk(path)
        labels = dataset['test'].features['label'].names
        model_size = checkpoint.split('-')[-1]  # i.e., 'xsmall', 'small' or 'base'
        src.experiments.zeroshot.run_experiment(
            dataset_dict=dataset,
            candidate_labels=labels,
            checkpoint=checkpoint,
            device='mps',
            # hypothesis_template='This example is {}',
            progress_bar=True,
            experiment_id=f'deberta_v3_{model_size}_zeroshot',
            save_path=SAVE_PATH_RESULTS.joinpath(path.name),
        )
print('Done!')

20_newsgroups_processed


100%|██████████| 7532/7532 [54:52<00:00,  2.29it/s]  


ag_news_processed


100%|██████████| 7600/7600 [10:14<00:00, 12.36it/s] 


web_of_science_processed
yelp_polarity_processed


100%|██████████| 1000/1000 [13:52<00:00,  1.20it/s]


dynabench_dynasent_processed


100%|██████████| 720/720 [02:09<00:00,  5.57it/s]


imdb_processed


100%|██████████| 1000/1000 [05:45<00:00,  2.89it/s]


setfit_sst5_processed


100%|██████████| 2210/2210 [04:44<00:00,  7.77it/s]


dbpedia_14_processed


100%|██████████| 7000/7000 [14:33<00:00,  8.02it/s] 


rotten_tomatoes_processed


100%|██████████| 1066/1066 [00:48<00:00, 21.93it/s]


yelp_review_full_processed


100%|██████████| 2500/2500 [28:46<00:00,  1.45it/s] 


20_newsgroups_processed


 84%|████████▍ | 6353/7532 [1:30:19<07:49,  2.51it/s]  

## II. Prompting Experiments

In [14]:
SAVE_PATH_RESULTS = Path(os.pardir, 'data', 'prompting')
dataset_name = 'dynabench_dynasent_processed'
dataset = datasets.load_from_disk(LOAD_PATH_DATASET.joinpath(dataset_name))
labels = dataset['test'].features['label'].names
print('Labels:', labels)
dataset

Labels: ['Negative', 'Neutral', 'Positive']


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'text_clean'],
        num_rows: 13065
    })
    validation: Dataset({
        features: ['text', 'label', 'text_clean'],
        num_rows: 720
    })
    test: Dataset({
        features: ['text', 'label', 'text_clean'],
        num_rows: 720
    })
})

In [None]:
prompts = [
    'This example is {}.',
    '{}',
    'This example expresses a {} sentiment.',
    'This example expresses a {} feeling.',
    'This example expresses a {} attitude.',
    'This example expresses a {} opinion.',
]

for checkpoint in CHECKPOINTS:
    model_size = checkpoint.split('-')[-1]
    for idx, prompt in enumerate(prompts):
        src.experiments.zeroshot.run_experiment(
            dataset_dict=dataset,
            candidate_labels=labels,
            checkpoint=checkpoint,
            device='mps',
            hypothesis_template=prompt,
            progress_bar=True,
            experiment_id=f'deberta_v3_{model_size}_zeroshot_prompt_{idx}',
            save_path=SAVE_PATH_RESULTS.joinpath(dataset_name),
        )
print('Done!')

100%|██████████| 720/720 [00:27<00:00, 25.82it/s]
100%|██████████| 720/720 [00:27<00:00, 26.32it/s]
100%|██████████| 720/720 [00:26<00:00, 27.01it/s]
100%|██████████| 720/720 [00:26<00:00, 26.95it/s]
100%|██████████| 720/720 [00:27<00:00, 26.25it/s]
100%|██████████| 720/720 [00:28<00:00, 25.44it/s]
100%|██████████| 720/720 [00:16<00:00, 43.36it/s]
100%|██████████| 720/720 [00:16<00:00, 43.83it/s]
100%|██████████| 720/720 [00:16<00:00, 43.71it/s]
100%|██████████| 720/720 [00:16<00:00, 43.81it/s]
100%|██████████| 720/720 [00:16<00:00, 44.37it/s]
100%|██████████| 720/720 [00:16<00:00, 44.71it/s]
100%|██████████| 720/720 [00:28<00:00, 25.15it/s]
100%|██████████| 720/720 [00:28<00:00, 25.50it/s]
100%|██████████| 720/720 [00:28<00:00, 25.19it/s]
100%|██████████| 720/720 [00:28<00:00, 25.23it/s]
 48%|████▊     | 348/720 [00:14<00:15, 23.74it/s]

In [2]:
SAVE_PATH_RESULTS = Path(os.pardir, 'data', 'prompting')
dataset_name = 'ag_news_processed'
dataset = datasets.load_from_disk(LOAD_PATH_DATASET.joinpath(dataset_name))
labels = dataset['test'].features['label'].names
print('Labels:', labels)
dataset

Labels: ['World', 'Sports', 'Business', 'Sci/Tech']


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'text_clean'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label', 'text_clean'],
        num_rows: 7600
    })
})

In [19]:
prompts2labels = [
    ('This example is {}.', labels),
    ('{}', labels),
    ('This example is about {}.', labels),
    ('This main topic of this text is {}.', labels),
    ('This example is {}.', ['World News', 'Sports', 'Business, Company, Entrepreneurship', 'Technology']),
]

for checkpoint in CHECKPOINTS:
    model_size = checkpoint.split('-')[-1]
    for idx, (prompt, labels) in enumerate(prompts2labels):
        src.experiments.zeroshot.run_experiment(
            dataset_dict=dataset,
            candidate_labels=labels,
            checkpoint=checkpoint,
            device='mps',
            hypothesis_template=prompt,
            progress_bar=True,
            experiment_id=f'deberta_v3_{model_size}_zeroshot_prompt_{idx}',
            save_path=SAVE_PATH_RESULTS.joinpath(dataset_name),
        )
print('Done!')

100%|██████████| 7600/7600 [05:49<00:00, 21.74it/s]
100%|██████████| 7600/7600 [05:00<00:00, 25.33it/s]
100%|██████████| 7600/7600 [04:52<00:00, 26.02it/s]
100%|██████████| 7600/7600 [04:51<00:00, 26.11it/s]
100%|██████████| 7600/7600 [04:49<00:00, 26.25it/s]
100%|██████████| 7600/7600 [04:22<00:00, 28.97it/s]
100%|██████████| 7600/7600 [03:09<00:00, 40.17it/s]
100%|██████████| 7600/7600 [03:00<00:00, 42.16it/s]
100%|██████████| 7600/7600 [02:55<00:00, 43.20it/s]
100%|██████████| 7600/7600 [02:52<00:00, 44.05it/s]
100%|██████████| 7600/7600 [04:46<00:00, 26.56it/s]
100%|██████████| 7600/7600 [04:42<00:00, 26.87it/s]
100%|██████████| 7600/7600 [04:44<00:00, 26.75it/s]
100%|██████████| 7600/7600 [04:41<00:00, 27.01it/s]
100%|██████████| 7600/7600 [04:40<00:00, 27.09it/s]


Done!


## III. Sanity Check

In [5]:
src.experiments.utils.show_best_results(SAVE_PATH_RESULTS)

Unnamed: 0,rotten_tomatoes,imdb,yelp_polarity,yelp_review_full,setfit_sst5,dynabench_dynasent,ag_news,20_newsgroups,dbpedia_14,web_of_science
dummy_classifier,47.84,48.9,46.7,20.4,20.1,32.69,25.32,5.34,7.22,0.72
complement_naive_bayes,76.54,85.6,86.4,49.34,36.01,52.0,86.74,75.1,94.3,68.6
sgd_classifier,75.64,88.4,91.8,56.84,36.27,51.67,87.96,75.76,97.31,75.98
fasttext,78.26,87.9,94.9,62.58,41.31,54.49,92.03,78.47,98.59,64.49
cnn,74.41,85.19,94.4,60.01,33.61,53.23,89.59,68.56,97.61,69.56
deberta_v3_small_finetuned,88.74,96.5,97.4,66.84,55.11,67.42,93.78,82.9,99.02,79.76
deberta_v3_small_zeroshot,75.89,79.08,86.24,34.98,34.2,42.48,54.26,42.78,59.2,57.3
deberta_v3_xsmall_zeroshot,76.17,82.53,88.89,35.18,35.73,43.99,64.29,45.09,64.5,62.25
deberta_v3_base_zeroshot,75.14,85.09,89.29,31.15,30.89,42.62,61.58,47.24,49.25,46.99
