# MNLI

In [1]:
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification

### bert-base

In [2]:
# Load the BERT-base checkpoint fine-tuned on MNLI; tokenizer and model are
# pulled from the same Hub repository.
tokenizer = AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-MNLI"
)
model = AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-MNLI"
)

# The "sentiment-analysis" task is used here as a generic sequence-classification
# pipeline: it returns one {'label', 'score'} dict per input string.
mnli = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=630.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=48.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=437988463.0, style=ProgressStyle(descri…




### roberta-base

In [3]:
# Load the RoBERTa-base checkpoint fine-tuned on MNLI; tokenizer and model are
# pulled from the same Hub repository.
roberta_tokenizer = AutoTokenizer.from_pretrained(
    "textattack/roberta-base-MNLI"
)
roberta_model = AutoModelForSequenceClassification.from_pretrained(
    "textattack/roberta-base-MNLI"
)

# Same generic sequence-classification pipeline as the other models, so the
# outputs can be formatted by the same helper.
mnli_roberta = pipeline(
    "sentiment-analysis", model=roberta_model, tokenizer=roberta_tokenizer
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=678.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898822.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=150.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=25.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501006086.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at textattack/roberta-base-MNLI were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### distilbert-base

In [4]:
# Load the DistilBERT-base checkpoint fine-tuned on MNLI; tokenizer and model
# are pulled from the same Hub repository.
distilbert_tokenizer = AutoTokenizer.from_pretrained(
    "textattack/distilbert-base-uncased-MNLI"
)
distilbert_model = AutoModelForSequenceClassification.from_pretrained(
    "textattack/distilbert-base-uncased-MNLI"
)

# Same generic sequence-classification pipeline as the other models.
mnli_distilbert = pipeline(
    "sentiment-analysis", model=distilbert_model, tokenizer=distilbert_tokenizer
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=639.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=48.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=267848226.0, style=ProgressStyle(descri…




In [5]:
def load_sentences(filename):
    '''
    Read tab-separated sentence pairs from a text file.

    params : name of file; every non-blank line must hold at least two
             tab-separated columns (first sentence, second sentence).
             Columns after the second are ignored.
    return : list of sentences, one 'first[SEP]second' string per non-blank line
    raises : IndexError if a non-blank line has no second tab-separated column
    '''
    data = []
    # The context manager guarantees the handle is closed, even when a
    # malformed line raises part-way through the file.
    with open(filename) as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Blank lines (such as an empty last line) carry no pair.
                continue
            sents = line.split('\t')
            if len(sents) < 2:
                # Same exception type as the bare sents[1] lookup would raise,
                # but pointing at the offending line.
                raise IndexError(
                    f"{filename}:{lineno}: expected two tab-separated sentences, got {line!r}"
                )
            data.append(sents[0].strip()+'[SEP]'+sents[1].strip())
    return data

In [6]:
def mnli_result(sents, outputs):
    '''
    Print every sentence pair together with its predicted label and score.

    params : sents   - list of 'text[SEP]hypo' strings
             outputs - list of dicts with 'label' and 'score' keys, aligned
                       with sents (extra items on either side are ignored by zip)
    return : None; the result is written to stdout
    raises : IndexError if a sentence contains no '[SEP]' marker
    '''
    # NOTE(review): this index-to-class order is assumed to match the loaded
    # checkpoints - confirm against each model card.
    id2label = {
        'LABEL_0':'contradiction',
        'LABEL_1':'neutral',
        'LABEL_2':'entailment'
    }
    for s, o in zip(sents, outputs):
        # Split on the first marker only, so a hypothesis that itself contains
        # '[SEP]' is printed in full rather than truncated.
        parts = s.split('[SEP]', 1)
        text = parts[0]
        hypo = parts[1]
        # Labels outside the generic LABEL_n scheme are printed unchanged
        # instead of raising KeyError.
        label = id2label.get(o['label'], o['label'])
        print(f"text : {text}\nhypo : {hypo}\n{label}({o['score']:.2f})\n")

## 1. bert-base-uncased

### test with a file

In [7]:
# Run the BERT pipeline over every sentence pair in the sample file.
filename = "mnli_sample.txt"
sents = load_sentences(filename)

# One prediction dict per sentence, printed next to the pair it belongs to.
outputs = mnli(sents)
mnli_result(sents, outputs)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)

text : This is a test sentence.
hypo : This is not a test sentence.
neutral(0.69)



### test with a sentence

In [8]:
# Classify a single hand-written pair with the BERT pipeline.
text = "This is a test sentence."
hypo = "This is test."

# Same 'text[SEP]hypo' layout as the file loader produces, wrapped in a list
# because the result printer iterates over sentences.
sent = [f"{text}[SEP]{hypo}"]
mnli_result(sent, mnli(sent))

text : This is a test sentence.
hypo : This is test.
neutral(0.72)



## 2. roberta-base

### test with a file

In [9]:
# Run the RoBERTa pipeline over every sentence pair in the sample file.
# NOTE(review): the pairs are joined with the literal '[SEP]' marker; RoBERTa
# tokenizers normally use '</s>' as separator - confirm the model really sees
# two segments here.
filename = "mnli_sample.txt"
sents = load_sentences(filename)

# One prediction dict per sentence, printed next to the pair it belongs to.
outputs = mnli_roberta(sents)
mnli_result(sents, outputs)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)

text : This is a test sentence.
hypo : This is not a test sentence.
contradiction(1.00)



### test with a sentence

In [10]:
# Classify a single hand-written pair with the RoBERTa pipeline.
# NOTE(review): '[SEP]' may not be RoBERTa's separator token (its tokenizers
# normally use '</s>') - confirm the pair is encoded as two segments.
text = "This is a test sentence."
hypo = "This is test."

# 'text[SEP]hypo' layout, wrapped in a list because the result printer
# iterates over sentences.
sent = [f"{text}[SEP]{hypo}"]
mnli_result(sent, mnli_roberta(sent))

text : This is a test sentence.
hypo : This is test.
entailment(0.98)

