### Side Notebook: Classification and Masked Language Models

This notebook is inspired by and aligned with the paper **“Exploiting Cloze Questions for Few Shot Text Classification and Natural Language Inference”** by Timo Schick and Hinrich Schuetze (https://arxiv.org/pdf/2001.07676.pdf). See also Joachim's paper reading session on Thursday, 03/18, noon.

The idea is to use the knowledge already stored in a pre-trained language model to 'prime' the classification task. This approach can improve classification performance when only a few training examples are available.

In [1]:
from transformers import BertTokenizer, TFBertForMaskedLM
import tensorflow as tf
import numpy as np

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = TFBertForMaskedLM.from_pretrained('bert-base-cased')

All model checkpoint layers were used when initializing TFBertForMaskedLM.

All the layers of TFBertForMaskedLM were initialized from the model checkpoint at bert-base-cased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.


Let's first look at an example of how the masked language model works:

In [3]:
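# Tokenize a sentence that contains BERT's mask token.
inputs = tokenizer("The capital of France is [MASK].", return_tensors="tf")
# Locate the mask token; tokenizer.mask_token_id is 103 for this vocabulary.
mask_position = np.where(inputs['input_ids'].numpy() == tokenizer.mask_token_id)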

In [4]:
mask_position

(array([0]), array([6]))
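
The result is a tuple of (row indices, column indices): here the mask sits in row 0 at position 6. With a padded batch, `np.where` returns one row index and one column index per match, which is what the batched cells later in this notebook rely on. A minimal sketch with toy ids (only 0, 101, 102, and 103 follow the `bert-base-cased` ids for `[PAD]`, `[CLS]`, `[SEP]`, and `[MASK]`; all other ids are arbitrary placeholders):

```python
import numpy as np

MASK_ID = 103  # [MASK] in bert-base-cased; the word ids below are placeholders

# Two padded "sentences"; 0 is padding, 101/102 stand for [CLS]/[SEP].
toy_ids = np.array([
    [101, 11, 12, 103, 13, 102, 0, 0],
    [101, 21, 22, 23, 24, 25, 103, 102],
])

rows, cols = np.where(toy_ids == MASK_ID)

# One (row, column) pair per [MASK] token, in row order.
for row, col in zip(rows, cols):
    print(f"example {row}: [MASK] at position {col}")
```

This assumes exactly one `[MASK]` per example, so that the i-th column index belongs to the i-th example.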

In [7]:
np.argmax(model(inputs)[0][0, mask_position[1][0]])

2123

In [8]:
tokenizer.convert_ids_to_tokens(2123)

'Paris'

Cool! The language model figured that out.

Let's now look at three fake tweets:

In [9]:
tweets = ["This flick was really bad. Poor acting.", 
          "This joint was absolutely terrific. Great! Really nice flavors!", 
          "My meal at the place was terribly prepared. Way too spicey."]

**Questions**:
 * Can you classify these tweets without further information?
 * How would we generally use BERT for these types of classification tasks?
 * What is the challenge with this approach, particularly given there are few examples?
 * Is there maybe a way to use masked language models to help?
 
 
Obviously, one cannot simply say "I'm going to classify". Classify what? You need to know what the task actually is: asking whether a tweet is about restaurants or movies is quite different from asking about its sentiment.

So... can we use masked LMs? 

Yes. The idea is to augment the actual text with a **probing question** that includes a **masked token that is to be predicted**, and then to compare the likelihoods of the **words that represent the classes** at the masked position. That is, for each task you need to define one or more probing questions and the words that represent the classes.

Examples:

 * Topic classification: **"This movie was really bad. Poor acting."** $\rightarrow$ **"This movie was really bad. Poor acting. *This sentence talks about [MASK]*"**, where you want the LM to tell you which of your class representatives, **'restaurants' or 'movies'**, is the more likely prediction for the [MASK] token.

 * Sentiment classification: **"This movie was really bad. Poor acting."** $\rightarrow$ **"This movie was really bad. Poor acting. *I had a [MASK] experience.*"**, where you want the LM to tell you which of your class representatives, **'good' or 'bad'**, is the more likely prediction for the [MASK] token.
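
The two ingredients of each task can be written down as plain data. The cited paper calls the probing question a *pattern* and the mapping from classes to words a *verbalizer*; the sketch below borrows those terms, while the names `TASKS` and `apply_pattern` and the label strings are made up for illustration:

```python
# Each task pairs a pattern (text appended to the input, with one [MASK])
# with a verbalizer (class label -> single word that stands for the class).
TASKS = {
    "topic": {
        "pattern": " This sentence talks about [MASK].",
        "verbalizer": {"restaurant": "restaurants", "movie": "movies"},
    },
    "sentiment": {
        "pattern": " I had a [MASK] experience.",
        "verbalizer": {"positive": "good", "negative": "bad"},
    },
}


def apply_pattern(text: str, task: str) -> str:
    """Append the task's probing question to the raw text."""
    return text + TASKS[task]["pattern"]


example = "This movie was really bad. Poor acting."
print(apply_pattern(example, "topic"))
print(apply_pattern(example, "sentiment"))
```

Keeping the pattern and the verbalizer together makes it harder to accidentally combine a probing question with the word pair of a different task.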



Let's try easy examples for two classification tasks for our tweets:

1) Does the tweet refer to restaurants or movies?   
2) Is the sentiment positive or negative?

In [8]:
#test_phrase = " I had a [MASK] experience."
test_phrase = " The previous phrase talked about [MASK]."

#pair = ['good', 'bad']
pair = ['restaurants', 'movies']

tweets_pet = [x + test_phrase for x in tweets]

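# Pad to a common length so the three tweets form one batch.
inputs = tokenizer(tweets_pet, padding=True, return_tensors="tf")

# Row and column indices of the [MASK] token in each example.
mask_positions = np.where(inputs['input_ids'].numpy() == tokenizer.mask_token_id)

out = model(inputs)

# Vocabulary ids of the two class words (each must be a single token in the vocabulary).
pair_tokens = tokenizer.convert_tokens_to_ids(pair)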


for example_nr, example in enumerate(tweets_pet):
    print('"' + example + '"')
    print('Logits:')

    print('\t' + pair[0] + ': ', out[0][example_nr, mask_positions[1][example_nr]][pair_tokens[0]].numpy())
    print('\t' + pair[1] + ': ', out[0][example_nr, mask_positions[1][example_nr]][pair_tokens[1]].numpy())
    print()


"This flick was really bad. Poor acting. The previous phrase talked about [MASK]."
Logits:
	restaurants:  0.0086269155
	movies:  5.537849

"This joint was absolutely terrific. Great! Really nice flavors! The previous phrase talked about [MASK]."
Logits:
	restaurants:  3.5222921
	movies:  2.9667614

"My meal at the place was terribly prepared. Way too spicey. The previous phrase talked about [MASK]."
Logits:
	restaurants:  3.9967456
	movies:  2.4576795



Nice! All three tweets are classified correctly in this zero-shot setup, just by exploiting the knowledge of the masked language model. (That is, for each tweet the logit of the correct class was the higher one.)
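
Raw logits are hard to compare across examples. One option, similar in spirit to how the cited paper turns scores into a distribution over labels, is to apply a softmax over just the class words. The sketch below does this with the logits printed above; the helper name `pair_probabilities` is made up for illustration:

```python
import math


def pair_probabilities(logits):
    """Softmax restricted to the class words (works for any number of classes)."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# (restaurants, movies) logits copied from the output above.
topic_logits = [
    (0.0086269155, 5.537849),
    (3.5222921, 2.9667614),
    (3.9967456, 2.4576795),
]

for restaurants, movies in topic_logits:
    p_restaurants, p_movies = pair_probabilities([restaurants, movies])
    print(f"restaurants: {p_restaurants:.3f}  movies: {p_movies:.3f}")
```

Note that these are probabilities over the two candidate words only, not over the full vocabulary.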

Now we do the same for the sentiment classification. For that, we need to swap the probing question and the answer pair that represents the task:

In [9]:
test_phrase = " I had a [MASK] experience."
#test_phrase = " The previous phrase talked about [MASK]."

pair = ['good', 'bad']
#pair = ['restaurants', 'movies']

tweets_pet = [x + test_phrase for x in tweets]

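# Pad to a common length so the three tweets form one batch.
inputs = tokenizer(tweets_pet, padding=True, return_tensors="tf")

# Row and column indices of the [MASK] token in each example.
mask_positions = np.where(inputs['input_ids'].numpy() == tokenizer.mask_token_id)

out = model(inputs)

# Vocabulary ids of the two class words (each must be a single token in the vocabulary).
pair_tokens = tokenizer.convert_tokens_to_ids(pair)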


for example_nr, example in enumerate(tweets_pet):
    print('"' + example + '"')
    print('Logits:')

    print('\t' + pair[0] + ': ', out[0][example_nr, mask_positions[1][example_nr]][pair_tokens[0]].numpy())
    print('\t' + pair[1] + ': ', out[0][example_nr, mask_positions[1][example_nr]][pair_tokens[1]].numpy())
    print()


"This flick was really bad. Poor acting. I had a [MASK] experience."
Logits:
	good:  9.79245
	bad:  13.258743

"This joint was absolutely terrific. Great! Really nice flavors! I had a [MASK] experience."
Logits:
	good:  9.051356
	bad:  8.38979

"My meal at the place was terribly prepared. Way too spicey. I had a [MASK] experience."
Logits:
	good:  8.767116
	bad:  10.673337



And again, all of the classes are correct!

So for this simple setup we classified all three tweets correctly in a zero-shot fashion for both tasks, simply by selecting a suitable probing question and a suitable word pair that represents the two classes.

But be careful: results can be much worse, even in situations where one would expect the task to be a no-brainer for the LM.

Let's play around a bit.

**Further Questions**:
 * Any surprises in the above outcomes? 
 * Are there other ways to encode the information + test phrase for BERT? 
 * Would this also work for more than 2 classes? What would change?
 * What difficulties can arise with this approach?