In [None]:
import csv
import logging
import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from transformers import BartForSequenceClassification, BartTokenizer, BartConfig

# internal libraries
from ressources import target_to_label

# set a seed value
torch.manual_seed(555)

logging.basicConfig(level=logging.WARNING)

In [None]:
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)
now = time.time()

label_to_target = {v: k for k, v in target_to_label.items()}

with open(results_dir / Path(f"{now}.csv"), "w") as f:
    f.write("text,exec_time,12.1,12.2,12.3,12.4,12.5,12.6,12.7,12.8,12.a,12.b,12.c\n")

In [None]:
config = BartConfig.from_pretrained("valhalla/distilbart-mnli-12-9")
tokenizer = BartTokenizer.from_pretrained("valhalla/distilbart-mnli-12-9")
model = BartForSequenceClassification.from_pretrained("valhalla/distilbart-mnli-12-9")

In [None]:
data = pd.read_csv("osdg-data.csv")

df = data[(data["sdg"] == 12) & (data["label_osdg"] == "accepted")]

text = df["text"].iloc[0]
print(text)

In [None]:
def predict(premise, hypothesis):
    # run through model pre-trained on MNLI
    input_ids = tokenizer.encode(premise, hypothesis, truncation=True, return_tensors="pt")
    logits = model(input_ids)[0]

    # we throw away "neutral" (dim 1) and take the probability of
    # "entailment" (2) as the probability of the label being true
    entail_contradiction_logits = logits[:, [0, 2]]

    probs = entail_contradiction_logits.softmax(dim=1)
    true_prob = probs[:, 1].item() * 100
    #logging.info(f"Probability that '{hypothesis}' is true: {true_prob:0.2f}%")

    return true_prob

In [None]:
labels = list(target_to_label.values())

for text in df["text"]:

    results = {**{"text": [], "exec_time": []}, **{k: [] for k in target_to_label.keys()}}

    start_time = time.time()
    results["text"].append(text)

    for label in labels:
        # Build hypothesis
        hypothesis = "The context is " + label

        # Run prediction
        true_prob = predict(text, hypothesis)

        target_id = label_to_target[label]
        results[target_id].append(true_prob)

    total_time = time.time() - start_time
    #logging.info(f"Total prediction time : {total_time:0.2f}s")

    results["exec_time"].append(total_time)

    with open(results_dir / Path(f"{now}.csv"), "a") as f:
        for i in range(len(results["exec_time"])):
            text = results["text"][i]
            exec_time = results["exec_time"][i]
            new_line = (
                f'"{text}",'
                + ",".join([f"{v[i]:.2f}" for k, v in results.items() if k != "text"])
                + "\n"
            )
            f.write(new_line)

    del results