# Data preparation

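This notebook prepares a random subset of the Twitter dataset annotated with zero-shot BART predictions for the case study. Download the data using this [link](https://drive.google.com/file/d/13mAaFqCrscUYkoITf4rZ6qG9ptAlIJVb/view?usp=sharing) and place it in the `data` directory as `twitter_dataset_full.csv`; the resulting subset is written to `data/twitter_dataset_small_w_bart_preds.csv`.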

In [None]:
import pandas as pd
from tqdm import tqdm
import torch
from transformers import BartForSequenceClassification, BartTokenizer

# Run on the first GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Seed passed as random_state when sampling, so the subset is reproducible.
SEED = 42
# Number of rows drawn from the full dataset.
NUM_SAMPLES = 20_000

In [None]:
class BartZeroShot:
    """Zero-shot classifier that scores a label via natural-language inference.

    The input sentence is used as the NLI premise, and a hypothesis built from
    the label (by default ``"This example is <label>"``) is scored against it.
    The softmax over the contradiction and entailment logits (neutral is
    dropped) gives the probability that the label applies.
    """

    def __init__(
        self,
        model_name="facebook/bart-large-mnli",
        hypothesis_template="This example is {}",
        device=None,
    ):
        """Load the NLI model and its tokenizer.

        Args:
            model_name: Hugging Face model id (or local path) of a BART NLI
                checkpoint. Defaults to the checkpoint used originally.
            hypothesis_template: Format string with a single ``{}`` placeholder
                that receives the label in ``predict``.
            device: Device to run the model on; ``None`` means the
                notebook-level ``DEVICE``.
        """
        self.device = DEVICE if device is None else device
        self.hypothesis_template = hypothesis_template
        self.nli_model = BartForSequenceClassification.from_pretrained(
            model_name
        ).to(self.device)
        # Inference only: make sure dropout is disabled explicitly.
        self.nli_model.eval()
        self.tokenizer = BartTokenizer.from_pretrained(model_name)

        # Resolve logit columns from the model config rather than relying on
        # fixed positions; fall back to the original hard-coded indices
        # (0 = contradiction, 2 = entailment) if the config lacks these names.
        label2id = {
            name.lower(): idx
            for name, idx in (self.nli_model.config.label2id or {}).items()
        }
        self._contradiction_id = label2id.get("contradiction", 0)
        self._entailment_id = label2id.get("entailment", 2)

    def predict(self, sentence, label):
        """Return the probability that ``label`` applies to ``sentence``.

        Args:
            sentence: Text to classify, used as the NLI premise. Only this
                sequence is truncated if the pair is too long
                (``truncation="only_first"``).
            label: Candidate label inserted into the hypothesis template.

        Returns:
            float: entailment probability after renormalizing over the
            contradiction and entailment logits only.
        """
        hypothesis = self.hypothesis_template.format(label)
        input_ids = self.tokenizer.encode(
            sentence,
            hypothesis,
            return_tensors="pt",
            truncation="only_first",
        )
        # Scoring needs no gradients; skipping autograd avoids building a graph
        # on every call, which saves memory and time across many predictions.
        with torch.no_grad():
            logits = self.nli_model(input_ids.to(self.device))[0]

        # Keep only [contradiction, entailment] and renormalize between them.
        pair_logits = logits[:, [self._contradiction_id, self._entailment_id]]
        probs = pair_logits.softmax(dim=1)
        return probs[:, 1].item()

In [None]:
# Load the full dataset downloaded from the link above.
data_full = pd.read_csv(filepath_or_buffer="data/twitter_dataset_full.csv")

In [None]:
# Reproducible random subset: NUM_SAMPLES rows drawn (without replacement, the
# pandas default) with SEED fixing the draw.
data_small = data_full.sample(
    random_state=SEED,
    n=NUM_SAMPLES,
)

In [None]:
# Score every message with the BartZeroShot model: the stored value is the
# model's probability that the message is "positive".
model = BartZeroShot()
tqdm.pandas()  # registers progress_apply on pandas objects

# Apply over the single text column instead of row-wise (axis=1): this avoids
# materializing a Series for every row just to read one field, and produces
# the same index-aligned result.
data_small["bart_is_positive"] = data_small["message"].progress_apply(
    lambda message: model.predict(message, "positive")
)

In [None]:
# Write the subset, including its BART scores, to the data folder.
# The index left over from sampling is not written out.
data_small.to_csv(
    path_or_buf="data/twitter_dataset_small_w_bart_preds.csv",
    index=False,
)