In [None]:
import torch
from transformers import BartForSequenceClassification, BartTokenizer

# Run on the first CUDA GPU when torch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

print(DEVICE)


class BartZeroShot:
    """Zero-shot text classifier built on an NLI sequence-pair model.

    Builds a hypothesis from the candidate label (by default "This example is
    <label>") and scores the (sentence, hypothesis) pair with the NLI model.
    The softmax probability of entailment over the {contradiction,
    entailment} logits is used as the probability that the label applies.
    The neutral logit is ignored.
    """

    def __init__(self, model_name="facebook/bart-large-mnli", device=None):
        """Load the NLI model and its tokenizer.

        Args:
            model_name: Hugging Face model id or local path of a BART
                sequence-classification checkpoint fine-tuned for NLI.
                Defaults to "facebook/bart-large-mnli".
            device: torch device to run the model on. When None, the
                module-level DEVICE is used.
        """
        self.device = DEVICE if device is None else device
        self.nli_model = BartForSequenceClassification.from_pretrained(
            model_name
        ).to(self.device)
        # Inference only: make sure dropout and similar layers are disabled.
        self.nli_model.eval()
        self.tokenizer = BartTokenizer.from_pretrained(model_name)

        # Resolve the logit columns from the model config rather than assuming
        # a fixed label order. The fallbacks (0 = contradiction,
        # 2 = entailment) are the indices this class previously hard-coded.
        label2id = self.nli_model.config.label2id
        self._contradiction_id = self._find_label_id(label2id, "contra", 0)
        self._entailment_id = self._find_label_id(label2id, "entail", 2)

    @staticmethod
    def _find_label_id(label2id, prefix, default):
        """Return the id of the first label whose lowercased name starts with
        ``prefix``, or ``default`` if no label matches (or the mapping is empty).
        """
        for name, idx in (label2id or {}).items():
            if name.lower().startswith(prefix):
                return int(idx)
        return default

    def predict(self, sentence, label, hypothesis_template="This example is {}"):
        """Return the probability that ``label`` applies to ``sentence``.

        Args:
            sentence: Text to classify; used as the NLI premise.
            label: Candidate label substituted into ``hypothesis_template``.
            hypothesis_template: Format string with a single ``{}``
                placeholder for the label. The default reproduces the
                hypothesis "This example is <label>".

        Returns:
            float in [0, 1]: softmax probability of entailment computed over
            the contradiction and entailment logits only.
        """
        input_ids = self.tokenizer.encode(
            sentence,
            hypothesis_template.format(label),
            return_tensors="pt",
            # Truncate only the premise so the hypothesis is never cut off.
            truncation="only_first",
        )
        # No gradients are needed for scoring; skip building the autograd graph.
        with torch.no_grad():
            logits = self.nli_model(input_ids.to(self.device))[0]

        # Columns are ordered [contradiction, entailment], so column 1 of the
        # softmax is the entailment probability.
        contradiction_entailment_logits = logits[
            :, [self._contradiction_id, self._entailment_id]
        ]
        probs = contradiction_entailment_logits.softmax(dim=1)
        return probs[:, 1].item()

In [None]:
# Instantiate the classifier once; construction loads the
# "facebook/bart-large-mnli" model and tokenizer via from_pretrained.
bz = BartZeroShot()

In [None]:
# Sanity check with a clearly negative sentence: the returned entailment
# probability (a float in [0, 1]) for "positive" should presumably be near 0.
bz.predict("I really really hate my life", "positive")

In [None]:
# Sanity check with a clearly positive sentence: the returned entailment
# probability (a float in [0, 1]) for "positive" should presumably be near 1.
bz.predict("I really really love my life", "positive")