In [10]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")
model.to(device)

def test_entailment(text1, text2):
    batch = tokenizer(text1, text2, return_tensors='pt').to(model.device)
    with torch.no_grad():
        logits = model(**batch).logits
        proba = torch.softmax(logits, -1)
    return proba.cpu().numpy()[0, model.config.label2id['ENTAILMENT']]

def test_equivalence(text1, text2):
    return max(test_entailment(text1, text2), test_entailment(text2, text1))

print(test_equivalence("I'm a good person", "I'm not a good person"))  # 2.0751484e-07
print(test_equivalence("I'm a good person", "You are a good person"))  # 0.49342492
print(test_equivalence("I'm a good person", "I'm not a bad person"))   # 0.94236994

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


0.00070681685
0.747932
0.9860415


In [11]:
sentence1 = '(1) Any dispute between the Contracting Parties concerning the interpretation or application of this Agreement shall, whenever possible, be settled through diplomatic channels.'
sentence2 = '(1) Any dispute between the Contracting Parties concerning the interpretation or application of this Agreement shall, whenever possible, be settled amicably through consultations.'

print(test_equivalence(sentence1, sentence2))  # 0.9

0.6988367
