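# Zero-shot significance classification

Classify the abstracts in the test split of `paul-ww/ei-abstract-significance` with an MNLI-based zero-shot pipeline and log the evaluation metrics to Weights & Biases.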

In [1]:
import torch

# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
# `torch.backends.mps` is missing on older torch builds, hence the getattr guard
# instead of touching the attribute directly.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

### Loading the dataset

In [2]:
from datasets import load_dataset

In [3]:
# Hugging Face Hub id of the labelled abstracts (train/test splits with a "label" column).
DATASET_ID = "paul-ww/ei-abstract-significance"
ds = load_dataset(DATASET_ID)

Found cached dataset parquet (/Users/paul/.cache/huggingface/datasets/paul-ww___parquet/paul-ww--ei-abstract-significance-1c087dddb8b05c98/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Class-name <-> integer-id lookups derived from the dataset's label feature.
class_labels = ds["train"].features["label"]
label2id = dict(
    (label_name, class_labels.str2int(label_name)) for label_name in class_labels.names
)
id2label = {label_id: label_name for label_name, label_id in label2id.items()}

### Tracking with Weights & Biases

In [5]:
import os

# W&B integration settings, set as plain strings. The previous `%env X='end'` form
# stored the quote characters as part of the value (`'end'`, `'all'`), which would
# not match the expected option names.
# NOTE(review): these look like options for the transformers Trainer W&B callback;
# this notebook only runs a pipeline, so they may have no effect here — confirm.
os.environ["WANDB_LOG_MODEL"] = "end"
os.environ["WANDB_WATCH"] = "all"

env: WANDB_LOG_MODEL='end'
env: WANDB_WATCH='all'


In [6]:
# Run configuration recorded with the W&B run.
# Other MNLI checkpoints that can be swapped in for "model":
#   "valhalla/distilbart-mnli-12-3", "facebook/bart-large-mnli"
config = dict(
    model="valhalla/distilbart-mnli-12-1",
    candidate_labels=class_labels.names,  # score every dataset class name
    batch_size=1,
    seed=42,
)

In [7]:
import wandb

# Authenticate, then open a run; `config` is attached to the run and read back
# through `wandb.config` in later cells.
wandb.login()
run = wandb.init(
    config=config,
    group="transformer_zero_shot",
    project="significance_classification",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpaul_ww[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
from transformers import set_seed

# Seed the RNGs via transformers' helper, using the seed stored in the run config.
run_seed = wandb.config["seed"]
set_seed(run_seed)

### Model Setup

In [9]:
# Optional: load the tokenizer explicitly to cap inputs at 512 tokens.
# Currently disabled — the pipeline below is built without an explicit tokenizer.
# from transformers import AutoTokenizer

# tokenizer = AutoTokenizer.from_pretrained(wandb.config["model"], model_max_length=512)

In [10]:
from transformers import pipeline

# Zero-shot classifier for the configured checkpoint. Candidate labels and device are
# fixed at construction, so each call scores a text against all dataset class names.
pipe_zs = pipeline(
    task="zero-shot-classification",
    model=wandb.config["model"],
    device=DEVICE,
    candidate_labels=wandb.config["candidate_labels"],
    # tokenizer=tokenizer,  # enable together with the tokenizer cell above
)

In [11]:
from tqdm.auto import tqdm
from transformers.pipelines.pt_utils import KeyDataset
import math

# Stream the test texts through the zero-shot pipeline and collect one result per text.
# - Device and candidate labels were fixed when `pipe_zs` was built, so they are not
#   repeated here. `return_all_scores` is a text-classification option; the zero-shot
#   pipeline already returns a score for every candidate label.
# - When fed a dataset, the pipeline yields one output per input even for batch_size > 1
#   (batches are unrolled), so the progress total is the number of texts, not batches.
#   The old ceil(len / batch_size) total only matched for batch_size == 1.
n_test_texts = len(ds["test"])
predictions = list(
    tqdm(
        pipe_zs(
            KeyDataset(ds["test"], "text"),
            batch_size=wandb.config["batch_size"],
            truncation="only_first",
        ),
        desc="Running inference",
        total=n_test_texts,
    )
)

Running inference:   0%|          | 0/123 [00:00<?, ?it/s]

### Evaluation

In [12]:
import numpy as np

# Re-order each prediction's scores into dataset label order, so column i holds the
# score for class_labels.names[i]. The pipeline reports `labels` ordered by score,
# not by class id. This works for any number of classes, not only the first two.
y_pred_proba = []
for prediction in predictions:
    score_by_label = dict(zip(prediction["labels"], prediction["scores"]))
    y_pred_proba.append([score_by_label[name] for name in class_labels.names])

In [13]:
from classification.utils import log_metrics_to_wandb

# Hand predicted probabilities and integer ground-truth labels to the project helper,
# which (per its name) logs the evaluation metrics to the active W&B run.
proba_matrix = np.array(y_pred_proba)
true_label_ids = ds["test"]["label"]
log_metrics_to_wandb(
    run=run,
    labels=class_labels.names,
    id2label=id2label,
    y_true_num=true_label_ids,
    y_pred_proba=proba_matrix,
)