<a href="https://colab.research.google.com/github/kumar-abhishek/handson-ml2/blob/master/zero_shot_learning_with_nli.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import os
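# number GPUs by PCI bus ID so indices match nvidia-smi
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# make only the first GPU visible to this notebook
os.environ["CUDA_VISIBLE_DEVICES"] = "0"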

In [10]:
!pip3 install -U ktrain


# Zero Shot Learning Using Natural Language Inference

In this notebook, we demonstrate **zero-shot** topic classification. **Zero-Shot Learning (ZSL)** is the ability to solve a task without having received any training examples for that task. The `ZeroShotClassifier` class in *ktrain* can perform topic classification with no training examples. The technique is based on **Natural Language Inference (NLI)**, as described in [this interesting blog post](https://joeddav.github.io/blog/2020/05/29/ZSL.html) by Joe Davison.
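The core idea can be sketched in a few lines. The document serves as the NLI *premise*, and each candidate label is slotted into a *hypothesis* sentence. An NLI model then judges whether the premise entails each hypothesis. The template `"This text is about {}."` is an assumption that follows the phrasing in the blog post, and `make_nli_pairs` is a hypothetical helper, not part of *ktrain*.

```python
from typing import List, Tuple

def make_nli_pairs(doc: str, labels: List[str],
                   template: str = "This text is about {}.") -> List[Tuple[str, str]]:
    """Pair the document (premise) with one hypothesis per candidate label."""
    return [(doc, template.format(label)) for label in labels]

doc = "What is your favorite sitcom of all time?"
pairs = make_nli_pairs(doc, ["politics", "television"])
for premise, hypothesis in pairs:
    # each (premise, hypothesis) pair would be fed to an NLI model
    print(f"{premise!r} -> {hypothesis!r}")
```

Because the labels are only used to fill in a sentence, any list of strings can serve as the label set without retraining.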

## STEP 1: Set Up the Zero-Shot Classifier and Describe Topics

We first instantiate the zero-shot classifier and then describe the topic labels for our classifier as plain strings.

In [11]:
from ktrain import text 

In [12]:
# the first call downloads a pretrained model (see the progress bars)
zsl = text.ZeroShotClassifier()
topic_strings = ['politics', 'elections', 'sports', 'films', 'television']


## STEP 2: Predict

There is no training involved here, since we are using **zero-shot learning**. We simply supply the document to classify and the `topic_strings` defined earlier. The `predict` method uses NLI to infer a probability for each topic.
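For each (document, hypothesis) pair, an NLI model outputs logits for *contradiction*, *neutral*, and *entailment*. The sketch below shows two common ways to turn those logits into label scores, following the approach in the blog post: for independent (multi-label) scores, take a softmax over only the contradiction and entailment logits of each pair; for mutually exclusive (single-label) scores, take a softmax over the entailment logits across all labels. The logit values and the column order `[contradiction, neutral, entailment]` are made-up assumptions for illustration, and *ktrain*'s internals may differ in detail.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # subtract the max for numerical stability before exponentiating
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def nli_label_scores(logits: np.ndarray, multilabel: bool = True) -> np.ndarray:
    """logits: shape (num_labels, 3), columns assumed [contradiction, neutral, entailment]."""
    contradiction, entailment = logits[:, 0], logits[:, 2]
    if multilabel:
        # each label scored on its own: entailment vs. contradiction, neutral dropped
        pair = np.stack([contradiction, entailment], axis=1)
        return softmax(pair, axis=1)[:, 1]
    # labels compete: softmax of entailment logits across labels, so scores sum to 1
    return softmax(entailment, axis=0)

# hypothetical logits for three labels (illustrative values, not real model output)
logits = np.array([[ 2.0, 0.5, -1.0],
                   [-1.5, 0.0,  2.5],
                   [-1.0, 0.2,  2.0]])
print(nli_label_scores(logits, multilabel=True))
print(nli_label_scores(logits, multilabel=False))
```

In the multi-label form, scores for different labels need not sum to 1, so several topics can score close to 1 at the same time.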

In [17]:
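doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
# with include_labels=True, predict returns a list of (label, score) tuples
L = zsl.predict(doc, topic_strings=topic_strings, include_labels=True)
# print(L)  # uncomment to see the score for every topic
# print only the label with the highest score
print(max(L, key=lambda item: item[1])[0])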


elections


The text pertains to both `politics` and `elections`, and the model picks `elections` as the single best label. Uncomment `print(L)` to see the scores for all topics, including `politics`.

Let's try some other examples.

#### document about `television`

In [12]:
doc = 'What is your favorite sitcom of all time?'
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('politics', 0.00011597130651352927),
 ('elections', 0.00015142725897021592),
 ('sports', 0.00011554655065992847),
 ('films', 0.03586330637335777),
 ('television', 0.9755581617355347)]
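With `include_labels=True`, `predict` returns a list of `(label, score)` tuples, as shown above. A small helper (hypothetical, not part of *ktrain*) can rank them so the top few topics are easy to read. The scores below are copied from the output above.

```python
from typing import List, Tuple

def top_k(preds: List[Tuple[str, float]], k: int = 2) -> List[Tuple[str, float]]:
    """Return the k (label, score) pairs with the highest scores, best first."""
    return sorted(preds, key=lambda item: item[1], reverse=True)[:k]

sitcom_preds = [('politics', 0.00011597130651352927),
                ('elections', 0.00015142725897021592),
                ('sports', 0.00011554655065992847),
                ('films', 0.03586330637335777),
                ('television', 0.9755581617355347)]
print(top_k(sitcom_preds, k=2))
# → [('television', 0.9755581617355347), ('films', 0.03586330637335777)]
```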

#### document about both `politics` and `television`

In [13]:
doc = """
President Donald Trump's senior adviser and son-in-law, Jared Kushner, praised 
the administration's response to the coronavirus pandemic as a \"great success story\" on Wednesday -- 
less than a day after the number of confirmed coronavirus cases in the United States topped 1 million. 
Kushner painted a rosy picture for \"Fox and Friends\" Wednesday morning, 
saying that \"the federal government rose to the challenge and 
this is a great success story and I think that that's really what needs to be told.\"
"""
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('politics', 0.8382051587104797),
 ('elections', 0.009549472481012344),
 ('sports', 0.003681211732327938),
 ('films', 0.04510315880179405),
 ('television', 0.9293774366378784)]
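Here two topics, `politics` and `television`, both score well above 0.5, so taking only the single maximum would hide one of them. A minimal sketch of a threshold rule follows; the helper name and the 0.5 cutoff are assumptions, not *ktrain* defaults, and the scores are copied from the output above.

```python
from typing import List, Tuple

def labels_above(preds: List[Tuple[str, float]], threshold: float = 0.5) -> List[str]:
    """Return every label whose score meets the threshold, highest score first."""
    kept = [item for item in preds if item[1] >= threshold]
    return [label for label, _ in sorted(kept, key=lambda item: item[1], reverse=True)]

kushner_preds = [('politics', 0.8382051587104797),
                 ('elections', 0.009549472481012344),
                 ('sports', 0.003681211732327938),
                 ('films', 0.04510315880179405),
                 ('television', 0.9293774366378784)]
print(labels_above(kushner_preds))  # → ['television', 'politics']
```

A fixed cutoff is simple but arbitrary; a suitable value depends on the labels and documents, so it is worth checking against a few hand-labeled examples.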

#### document about `sports`, `television`, and `films`

In [14]:
doc = "The Last Dance is a 2020 American basketball documentary miniseries co-produced by ESPN Films and Netflix."
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('politics', 0.0003102553600911051),
 ('elections', 0.00048394937766715884),
 ('sports', 0.9848700761795044),
 ('films', 0.9717175364494324),
 ('television', 0.9505333304405212)]
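All three of `sports`, `films`, and `television` score above 0.95 here, so these scores are not a probability distribution over topics: each topic is scored independently of the others. Summing the scores copied from the output above makes this concrete. (Recent *ktrain* versions appear to expose a `multilabel` argument on `predict` for scores that sum to 1; treat that as an assumption and check the docstring of your installed version.)

```python
last_dance_preds = [('politics', 0.0003102553600911051),
                    ('elections', 0.00048394937766715884),
                    ('sports', 0.9848700761795044),
                    ('films', 0.9717175364494324),
                    ('television', 0.9505333304405212)]

# independent per-label scores can add up to more than 1
total = sum(score for _, score in last_dance_preds)
print(f"sum of scores: {total:.3f}")  # → sum of scores: 2.908
```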

#### emotions as labels

The candidate labels do not have to be topics. Here we reuse the same classifier with a list of emotions as the labels.

In [15]:
topic_strings = ["happy", "sad", "surprise", "disgust", "anger", "fear", "shame"]
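With emotion words as labels, a topic-style hypothesis such as `"This text is about {}."` reads awkwardly (for example, "This text is about happy."). The sketch below only compares the hypotheses produced by two templates. The second template, `"This text expresses {}."`, is a hypothetical alternative, and whether it improves accuracy would need to be checked empirically. If your *ktrain* version accepts an `nli_template` argument on `predict` (an assumption to verify), a template like this could be passed there.

```python
from typing import List

def hypotheses_for(labels: List[str], template: str) -> List[str]:
    """Fill each label into the NLI hypothesis template."""
    return [template.format(label) for label in labels]

emotion_labels = ["happy", "sad", "surprise", "disgust", "anger", "fear", "shame"]
topic_style = hypotheses_for(emotion_labels, "This text is about {}.")
# hypothetical emotion-style template, not a ktrain default
emotion_style = hypotheses_for(emotion_labels, "This text expresses {}.")

for a, b in zip(topic_style[:3], emotion_style[:3]):
    print(f"{a:<30} | {b}")
```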

In [17]:
doc = "He felt guilty as he thought of Maeve's sweet face , and embarrassed that he should be so powerfully attracted to a woman dedicated to God"
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('happy', 0.00045107901678420603),
 ('sad', 0.7498875856399536),
 ('surprise', 0.342402845621109),
 ('disgust', 0.0009430521167814732),
 ('anger', 0.0019389173248782754),
 ('fear', 0.04364996775984764),
 ('shame', 0.9690190553665161)]

In [22]:
doc="David is concerned, at the length of time he says it took for an ambulance to arrive. He was happy in the end"
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('happy', 0.052014824002981186),
 ('sad', 0.0016243212157860398),
 ('surprise', 0.0031807036139070988),
 ('disgust', 0.0025987790431827307),
 ('anger', 0.02350490354001522),
 ('fear', 0.7439681887626648),
 ('shame', 0.025683118030428886)]