In [None]:
import csv
import logging
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm
from sklearn.metrics import classification_report, ConfusionMatrixDisplay
from transformers import BartForSequenceClassification, BartTokenizer, BartConfig

# internal libraries
from ressources import target_to_label

# Reproducibility: pin torch's global RNG to a fixed value.
SEED = 555
torch.manual_seed(SEED)

# Configure the root logger at WARNING level (kept last so the cell shows no output).
logging.basicConfig(level=logging.WARNING)

In [None]:
# Per-run output: scores are written to results/zs-nli-<timestamp>.csv.
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)
now = time.time()

# Reverse lookup from human-readable label back to its target key.
label_to_target = {label: target for target, label in target_to_label.items()}

# Header = fixed metadata columns, then one score column per target key.
header_columns = ["text", "target1", "target2", "exec_time", *target_to_label]
with (results_dir / f"zs-nli-{now}.csv").open("w") as f:
    f.write(",".join(header_columns) + "\n")

In [None]:
# Label set used for the report and confusion matrix: every target key, plus "0".
# NOTE(review): "0" presumably marks texts without a matching target — confirm.
targets = [*target_to_label, "0"]

In [None]:
# MNLI-finetuned DistilBART checkpoint, used as a zero-shot NLI classifier.
MODEL_NAME = "valhalla/distilbart-mnli-12-9"

config = BartConfig.from_pretrained(MODEL_NAME)
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForSequenceClassification.from_pretrained(MODEL_NAME)

# Switch to inference mode; as the cell's last expression, the model is displayed.
model.eval()

In [None]:
# Disabled exploration snippet: print one accepted OSDG example for SDG 12.
# data = pd.read_csv("osdg-data.csv")

# df = data[(data["sdg"] == 12) & (data["label_osdg"] == "accepted")]

# text = df["text"].iloc[0]
# print(text)

In [None]:
def _nli_label_indices(default=(0, 2)):
    """Return the (contradiction, entailment) logit indices of the loaded model.

    Looks the indices up by name in ``model.config.label2id`` (case-insensitive
    prefix match, as the Hugging Face zero-shot pipeline does). Falls back to
    ``default``, the layout previously hard-coded here, when the config does not
    name both labels (e.g. generic "LABEL_0"-style names).
    """
    found = {}
    label2id = getattr(model.config, "label2id", None) or {}
    for name, idx in label2id.items():
        lowered = str(name).lower()
        if lowered.startswith("contra"):
            found["contradiction"] = int(idx)
        elif lowered.startswith("entail"):
            found["entailment"] = int(idx)
    if len(found) == 2:
        return found["contradiction"], found["entailment"]
    return default


def predict(premise, hypothesis):
    """Return the probability, in percent, that ``premise`` entails ``hypothesis``.

    The pair is run through the MNLI model. The "neutral" logit is discarded and
    a softmax is taken over (contradiction, entailment) only, so the value is
    P(entailment | not neutral) * 100.

    Args:
        premise: Text being classified.
        hypothesis: Candidate statement, e.g. "The context is <label>".

    Returns:
        float in [0, 100].
    """
    contradiction_idx, entailment_idx = _nli_label_indices()

    input_ids = tokenizer.encode(
        premise, hypothesis, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(input_ids)[0]

    # Keep only the contradiction and entailment columns, in that order, so
    # column 1 of the softmax is the entailment probability.
    entail_contradiction_logits = logits[:, [contradiction_idx, entailment_idx]]
    probs = entail_contradiction_logits.softmax(dim=1)
    # .item() assumes a single premise/hypothesis pair (batch of 1).
    return probs[:, 1].item() * 100

In [None]:
labels = list(target_to_label.values())

# Ground-truth annotations. Target columns are read as strings so they compare
# equal to the target keys used for predictions and to the "0" label.
# NOTE(review): assumes target_to_label keys are strings — confirm.
df_true = pd.read_csv(
    "data/afd_targets_odd_12_15_16.csv",
    dtype={"target1": str, "target2": str},
)

y_true = []
y_pred = []


def _score_text(text):
    """Score ``text`` against every target label with the NLI model.

    Returns:
        (scores, best_target): ``scores`` maps each target key (in
        ``target_to_label`` order) to its entailment probability in percent.
        ``best_target`` is the highest-scoring key; on exact ties the later
        label wins, because ``>=`` is used.
    """
    scores = {}
    best_target, best_score = None, -1
    for target, label in target_to_label.items():
        prob = predict(text, "The context is " + label)
        scores[target] = prob
        if prob >= best_score:
            best_target, best_score = target, prob
    return scores, best_target


results_path = results_dir / f"zs-nli-{now}.csv"

# Open once, and use the csv module so texts containing quotes, commas or
# newlines are escaped correctly (manual '"{text}"' quoting corrupted such rows).
with open(results_path, "a", newline="", encoding="utf-8") as results_file:
    writer = csv.writer(results_file, lineterminator="\n")
    for _, row in tqdm.tqdm(df_true.iterrows(), total=len(df_true)):
        text = row["text"]

        start_time = time.perf_counter()
        scores, target_max = _score_text(text)
        total_time = time.perf_counter() - start_time

        # Record the arg-max prediction alongside the primary true target.
        y_pred.append(target_max)
        y_true.append(row["target1"])

        # Column order matches the header: text, target1, target2, exec_time,
        # then one score per target key.
        writer.writerow(
            [text, row["target1"], row["target2"], f"{total_time:.2f}"]
            + [f"{scores[target]:.2f}" for target in target_to_label]
        )
        # Flush each row so partial results survive an interrupted run.
        results_file.flush()

In [None]:
# Per-target precision/recall/F1. zero_division=0 gives the same numbers as the
# default but avoids warnings for targets (e.g. "0") never predicted or present.
print(classification_report(y_true, y_pred, labels=targets, zero_division=0))

# Confusion matrix over the full target set. Rotated ticks keep the many class
# labels readable, and plt.show() avoids echoing the Display object's repr.
fig, ax = plt.subplots(figsize=(20, 15))
ConfusionMatrixDisplay.from_predictions(
    y_true, y_pred, labels=targets, ax=ax, xticks_rotation="vertical"
)
ax.set_title("Zero-shot NLI: predicted vs. true target")
plt.show()