# NLI for Zero-Shot Classification Demo

First, load the model. In principle, any model fine-tuned on an NLI task whose labels include `"entailment"` and `"contradiction"` will work, as long as those names match the keys of the model's `config.label2id`. Search the [Hugging Face Hub](https://huggingface.co/models?sort=downloads&search=nli) for more.

In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# load the model and tokenizer
MODEL = "cross-encoder/nli-distilroberta-base"
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# look up the ids of "entailment" (positive) and "contradiction" (negative);
# each id is the index of that label's logit in the model's output.
pos_id = model.config.label2id["entailment"]
neg_id = model.config.label2id["contradiction"]
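Label names vary between NLI checkpoints. Some use upper-case names such as `ENTAILMENT`, in which case the exact-match lookup above raises a `KeyError`. Below is a minimal sketch of a more forgiving lookup. It assumes only that `label2id` is a plain `dict` mapping label names to ids, and `find_label_id` is a hypothetical helper, not part of `transformers`:

```python
from typing import Dict


def find_label_id(label2id: Dict[str, int], name: str) -> int:
    """Return the id for `name`, ignoring case (hypothetical helper).

    Raises KeyError listing the available labels if `name` is absent,
    e.g. for checkpoints that only expose generic names like "LABEL_0".
    """
    lowered = {label.lower(): idx for label, idx in label2id.items()}
    if name.lower() not in lowered:
        raise KeyError(f"label {name!r} not found; available: {sorted(label2id)}")
    return lowered[name.lower()]


# Example mapping with upper-case label names, as some NLI checkpoints use.
example = {"CONTRADICTION": 0, "NEUTRAL": 1, "ENTAILMENT": 2}
print(find_label_id(example, "entailment"))  # 2
print(find_label_id(example, "contradiction"))  # 0
```

With the model above, this would be used as `pos_id = find_label_id(model.config.label2id, "entailment")`. A checkpoint that only exposes generic names like `LABEL_0` cannot be handled this way. Its model card has to say which index means entailment.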

In [2]:
import torch


def predict(context, candidates):
    """Given a context, print the entailment score for each candidate label."""

    for c in candidates:

        # tokenize context/candidate as a (premise, hypothesis) text pair
        inputs = tokenizer(context, c, return_tensors="pt")

        # run through the encoder/classifier model; inference needs no gradients
        with torch.no_grad():
            output = model(**inputs)

        # output.logits holds the classification head's raw scores, one per label.
        # Take a softmax over only the "entailment" and "contradiction" logits
        # (the "neutral" logit is dropped); `score` is the resulting
        # "probability of entailment", converted to a plain Python float.
        score = output.logits[:, [pos_id, neg_id]].softmax(-1)[0, 0].item()
        print(f"{c}: {score}")
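The `score` line keeps only two of the three logits (entailment and contradiction) and applies a softmax to them. For two logits, that softmax reduces to a sigmoid of their difference: `p = 1 / (1 + exp(z_contra - z_entail))`. The plain-Python sketch below shows that arithmetic. The function names are hypothetical and the logits are made up for illustration:

```python
import math


def entailment_probability(entail_logit: float, contra_logit: float) -> float:
    """Two-way softmax over [entailment, contradiction]; returns the entailment share.

    The neutral logit never enters the calculation.
    """
    m = max(entail_logit, contra_logit)  # subtract the max for numerical stability
    e_pos = math.exp(entail_logit - m)
    e_neg = math.exp(contra_logit - m)
    return e_pos / (e_pos + e_neg)


def entailment_probability_sigmoid(entail_logit: float, contra_logit: float) -> float:
    """Equivalent closed form: a sigmoid of the logit gap."""
    return 1.0 / (1.0 + math.exp(contra_logit - entail_logit))


# Illustrative logits, not taken from the model above.
print(entailment_probability(1.0, 1.0))  # 0.5
print(entailment_probability(2.0, -1.0), entailment_probability_sigmoid(2.0, -1.0))
```

Because the neutral logit is discarded, the score only measures how entailment compares with contradiction. It says nothing about how much probability the model put on "neutral".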


In [3]:
predict("I backed up into a telephone pole.", ["Theft", "Collision Damage"])

Theft: 0.16164545714855194
Collision Damage: 0.379823237657547
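Each candidate above is scored independently, so the scores do not sum to 1 (0.16 + 0.38 is about 0.54). Hugging Face's `zero-shot-classification` pipeline does two things differently by default. First, it wraps each label in a hypothesis template, `"This example is {}."`. Second, with `multi_label=False` it takes a softmax over the entailment logits of all candidates, so the labels compete for a single answer. The per-candidate scoring in `predict` is closer to the pipeline's `multi_label=True` mode. Below is a minimal plain-Python sketch of both steps. The helper names are hypothetical and the logits are made up:

```python
import math
from typing import List


def build_hypotheses(labels: List[str], template: str = "This example is {}.") -> List[str]:
    """Turn bare labels into hypothesis sentences (hypothetical helper)."""
    return [template.format(label) for label in labels]


def softmax_across_candidates(entailment_logits: List[float]) -> List[float]:
    """Normalize one entailment logit per candidate so the scores sum to 1."""
    m = max(entailment_logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in entailment_logits]
    total = sum(exps)
    return [e / total for e in exps]


labels = ["Theft", "Collision Damage"]
print(build_hypotheses(labels))
# ['This example is Theft.', 'This example is Collision Damage.']

# Made-up entailment logits, one per candidate, for illustration only.
scores = softmax_across_candidates([-1.0, 1.0])
print(scores, sum(scores))
```

With the model loaded above, the entailment logit for one hypothesis could be read as `model(**tokenizer(context, hypothesis, return_tensors="pt")).logits[0, pos_id].item()`. You would collect one such logit per candidate before normalizing.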
