In [None]:
!pip install transformers[sentencepiece]

In [None]:
from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

# Identifier of the pretrained checkpoint to load.
# NOTE(review): the id suggests an mT5-large model fine-tuned on MNLI/XNLI — inferred from the name only.
model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"
# Load the matching tokenizer and conditional-generation model for that checkpoint.
# NOTE(review): presumably downloaded on first use and cached afterwards — confirm for offline runs.
tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)

In [None]:
# Text to classify (German: "Who will you vote for in the next election?").
# The trailing space is part of the original example and is kept as-is.
sequence_to_classify = "Wen werden Sie bei der nächsten Wahl wählen? "

# Candidate topics (Turkish: sports, economy, politics). The labels are in a
# different language from the text and template — presumably a deliberate
# cross-lingual example.
candidate_labels = [
    "spor",
    "ekonomi",
    "politika",
]

# Hypothesis sentence; `{}` is filled with each candidate label
# (German: "This example is {}.").
hypothesis_template = "Dieses Beispiel ist {}."

In [None]:
# Vocabulary tokens standing for the three NLI classes. Each begins with "▁"
# (U+2581), presumably the SentencePiece word-boundary marker followed by the
# class digit — confirm against the checkpoint's documentation.
ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL = "▁0", "▁1", "▁2"

# Vocabulary ids of those tokens, in (entailment, neutral, contradiction) order.
# Later cells index score columns by position in this list.
label_inds = [
    tokenizer.convert_tokens_to_ids(token)
    for token in (ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL)
]

In [None]:
def process_nli(premise, hypothesis):
    """Build the text-to-text prompt for one premise/hypothesis pair.

    Args:
        premise: Text whose content is being judged.
        hypothesis: Statement to check against the premise.

    Returns:
        The string ``"xnli: premise: <premise> hypothesis: <hypothesis>"``.
    """
    template = "xnli: premise: {} hypothesis: {}"
    return template.format(premise, hypothesis)

In [None]:
# One (premise, hypothesis) tuple per candidate label: the premise is always the
# text to classify, the hypothesis is the template filled with that label.
pairs = [
    (sequence_to_classify, hypothesis_template.format(candidate))
    for candidate in candidate_labels
]

# Turn every pair into the prompt string produced by `process_nli`.
seqs = [process_nli(premise=text, hypothesis=claim) for text, claim in pairs]

In [None]:
# Show the formatted prompts (one per candidate label) before they are tokenized.
print(seqs)

In [None]:
# Tokenize all prompts as one padded batch of PyTorch tensors. Calling the
# tokenizer directly is the supported API; `batch_encode_plus` is deprecated.
inputs = tokenizer(seqs, return_tensors="pt", padding=True)

# Generate with a single beam and ask for a dict-style result that includes the
# per-step scores, which the following cells read.
# NOTE(review): only `out.scores[0]` is read later in this notebook; if nothing
# else needs `out.sequences`, `max_new_tokens=1` would skip the extra decoding steps.
out = model.generate(
    **inputs,
    output_scores=True,
    return_dict_in_generate=True,
    num_beams=1,
)

In [None]:
# Take the first entry of the generation scores and keep only the columns of the
# three label tokens, in `label_inds` order.
# NOTE(review): assumes `out.scores[0]` is the (batch, vocab) score tensor of the
# first generated step, per the `generate` API — confirm.
scores = out.scores[0][:, label_inds]

In [None]:
# Show the scores restricted to the three label-token columns
# (rows presumably correspond to the input prompts — confirm).
print(scores)

In [None]:
# Column positions inside `scores`, whose columns follow the `label_inds` order
# (entailment, neutral, contradiction).
entailment_ind, contradiction_ind = 0, 2

# Keep just the entailment and contradiction columns, dropping neutral.
entail_vs_contra_scores = scores[
    :, [entailment_ind, contradiction_ind]
]

In [None]:
# Softmax along dim 1, i.e. over the (entailment, contradiction) pair of each row,
# so every row sums to 1.
entail_vs_contra_probas = entail_vs_contra_scores.softmax(dim=1)

In [None]:
# Show the per-row softmax over the entailment and contradiction scores.
print(entail_vs_contra_probas)

In [None]:
# Entailment column only: one score per row of `scores`.
entail_scores = scores[:, entailment_ind]

# Softmax along dim 0, i.e. across rows rather than across NLI classes, so the
# entailment scores of all rows are normalized against each other and sum to 1.
entail_probas = entail_scores.softmax(dim=0)

In [None]:
# Show the entailment scores after the softmax along dim 0.
print(entail_probas)

In [None]:
# Pair each candidate label with its normalized entailment value and print the
# resulting mapping.
print(
    {
        label: proba
        for label, proba in zip(candidate_labels, entail_probas.tolist())
    }
)