# Replication Pipeline for BART results

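This notebook contains the code to reproduce the BART results presented in our paper. To run it, install the required packages (first cell), select one of the four case studies in the configuration section of the second cell, and then execute the cells in order.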

In [None]:
# Install packages (only required if not already installed)
# !pip install transformers
# !pip install torch
# !pip install torchmetrics
# !pip install numpy pandas tqdm

In [None]:
import transformers
import torch
import numpy as np
import pandas  as pd
from transformers import pipeline
import torch
from tqdm import tqdm

from src.finetuning import compute_and_print_metrics_for_dataset_b

# ************************************************ #
# TODO: Choose the dataset to reproduce. Exactly one of the four case-study
# blocks below should be active: uncomment the block you want and comment out
# the others. Each block sets
#   DATASET           - name of the sub-directory under ./data/
#   zero_shot_labels  - candidate labels passed to the zero-shot classifier;
#                       a label's position in this list is the class index
#                       stored for each prediction
#   dataset_sentences - path of the file holding the input texts
# ************************************************ #

# case study 1 (active by default)
DATASET = "01-nyt-sentiment"
zero_shot_labels = ["negative sentiment", "positive sentiment"]
dataset_sentences = f"./data/{DATASET}/all-x.csv"

# ************************************************ #

# case study 2
# DATASET = "02-twitter-stance"
# zero_shot_labels = ["negative attitudinal stance towards", "positive attitudinal stance towards"]
# dataset_sentences = f"./data/{DATASET}/all-x.csv"

# ************************************************ #

# case study 3
# Two alternative input files are listed; if both `dataset_sentences` lines
# are uncommented, the later assignment wins.
# NOTE(review): all-x-translated.csv is presumably a translated version of the
# same texts - confirm against the data directory.
# DATASET = "03-emotion-angry"
# zero_shot_labels = ["Angry", "Non-Angry"]
# dataset_sentences = f"./data/{DATASET}/all-x.csv"
# dataset_sentences = f"./data/{DATASET}/all-x-translated.csv"

# ************************************************ #

# case study 4 (three candidate labels instead of two)
# DATASET = "04-brexit-stance"
# zero_shot_labels = ["Neutral towards Leave demands", "Pro-Leave demands", "Very Pro-Leave demands"]
# dataset_sentences = f"./data/{DATASET}/all-x.csv"

# ************************************************ #

# Model identifier handed to the transformers zero-shot-classification
# pipeline further down.
MODEL_NAME = "facebook/bart-large-mnli"

def label2idx(label_Name):
    """Map a candidate label string to its position in ``zero_shot_labels``.

    Raises ``ValueError`` (from ``list.index``) when the label is not one of
    the configured candidate labels.
    """
    position = zero_shot_labels.index(label_Name)
    return position

# Ground-truth labels are read from the same directory as the selected dataset.
dataset_labels = f"./data/{DATASET}/all-y.csv"

# The input file is split on a double tab. A separator longer than one
# character is treated by pandas as a regular expression and is only supported
# by the Python parsing engine, so request that engine explicitly instead of
# relying on the implicit fallback (which emits a ParserWarning).
# np.atleast_1d keeps a one-row file as a length-1 array; np.squeeze alone
# would collapse it to a 0-d array and break the `.shape[0]` lookup below.
all_x = np.atleast_1d(np.squeeze(np.array(
    pd.read_csv(dataset_sentences, header=None, sep='\t\t', engine='python')
)))
all_y = np.atleast_1d(np.squeeze(np.array(
    pd.read_csv(dataset_labels, dtype=np.float32, header=None)
)))

# Run on the current CUDA device when one is available; otherwise fall back to
# the CPU (-1) instead of failing inside torch.cuda.current_device().
device = torch.cuda.current_device() if torch.cuda.is_available() else -1
pipe = pipeline("zero-shot-classification", model=MODEL_NAME, device=device)

n_samples = all_x.shape[0]
# One predicted class index per input text, stored as floats.
all_y_pred = np.zeros(n_samples)

for idx, curr_x in enumerate(tqdm(all_x)):
    output = pipe(curr_x, zero_shot_labels)

    # The pipeline returns the candidate labels ordered from most to least
    # likely, so the first entry is the predicted label.
    curr_y_pred = label2idx(output["labels"][0])

    all_y_pred[idx] = curr_y_pred

# Predictions are wrapped in a list because the helper is called with a list
# of prediction arrays here. NOTE(review): the meaning of the remaining
# positional arguments is defined in src.finetuning - confirm there.
compute_and_print_metrics_for_dataset_b(all_y, [ all_y_pred ], None, "", False, True)