This notebook was created by Sylvain Gugger

In [149]:
from fastai.text import *

# What is the BLEU metric?

The BLEU metric was introduced in [this article](https://www.aclweb.org/anthology/P02-1040) as a way to evaluate the performance of translation models. It's based on the precision of the n-grams in your prediction compared to your target. Let's see this on an example. Imagine you have the target sentence
```
the cat is walking in the garden
```
and your model gives you the following output
```
the cat is running in the fields
```
We are going to compute the precision, which is the number of correctly predicted n-grams divided by the number of predicted n-grams for n going from 1 to 4.

For 1-grams (or tokens, more simply), we have predicted 5 correct words out of 7, so we get a precision of 5/7. Note that word order doesn't matter for 1-grams; for instance, predicting
```
she read the book because she was interested in world history
```
instead of
```
she was interested in world history because she read the book
```
would give a precision of 1 for the 1-grams.

For 2-grams, in the first example, we have 3 correct 2-grams ("the cat", "cat is" and "in the") out of 6, so a precision of 3/6. In the second example, the precision for 2-grams is 9/10.

For 3-grams, in the first example, we have only 1 correct 3-gram ("the cat is") out of 5, so a precision of 1/5. In the second example, the precision for 3-grams is 6/9.

For 4-grams, in the first example, we don't have any 4-gram that is correct, so the precision is 0. In the second example, it's 4/8.
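The counts above can be checked with a few lines of plain Python. This is a minimal sketch that assumes whitespace tokenization and represents each n-gram as a tuple; the helpers `ngrams` and `raw_precision` are our own names, not part of fastai. Like the text so far, it does not clip repeated words yet.

```python
def ngrams(tokens, n):
    """Return the list of n-grams of `tokens`, each one as a tuple."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def raw_precision(pred, targ, n):
    """(predicted n-grams found anywhere in the target, predicted n-grams); no clipping."""
    targ_grams = set(ngrams(targ, n))
    pred_grams = ngrams(pred, n)
    correct = sum(1 for g in pred_grams if g in targ_grams)
    return correct, len(pred_grams)

targ1 = "the cat is walking in the garden".split()
pred1 = "the cat is running in the fields".split()
targ2 = "she was interested in world history because she read the book".split()
pred2 = "she read the book because she was interested in world history".split()

print([raw_precision(pred1, targ1, n) for n in range(1, 5)])
# [(5, 7), (3, 6), (1, 5), (0, 4)]
print([raw_precision(pred2, targ2, n) for n in range(1, 5)])
# [(11, 11), (9, 10), (6, 9), (4, 8)]
```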

There is one big drawback in just taking the raw precision: a seq2seq model will very often predict the same easy word over and over. If we take the prediction
```
the the the the the the the
```
for our first example, it would score a precision of 1 for 1-grams (because `the` is in the target sentence, all the words are considered correct). To avoid that, we cap the number of times a given word can be counted as correct at the number of times that word appears in the target sentence. So in our example, only 2 of the 7 `the` are considered correct, and this clipped precision gives us 2/7 for 1-grams.
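A minimal sketch of the clipped unigram precision on this degenerate prediction, using `collections.Counter` (the function name `clipped_unigram_precision` is ours, for illustration):

```python
from collections import Counter

def clipped_unigram_precision(pred, targ):
    """Each predicted word counts as correct at most as many times as it appears in targ."""
    pred_cnt, targ_cnt = Counter(pred), Counter(targ)
    correct = sum(min(c, targ_cnt[w]) for w, c in pred_cnt.items())
    return correct, len(pred)

targ = "the cat is walking in the garden".split()
pred = ["the"] * 7

# Without clipping, all 7 predicted words are found in the target
raw_correct = sum(1 for w in pred if w in targ)
print(raw_correct, len(pred))                 # 7 7
print(clipped_unigram_precision(pred, targ))  # (2, 7)
```

`Counter` returns 0 for a word absent from the target, so wrong words contribute nothing to the sum.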

One thing to note is that when we deal with a whole corpus, we take the sum of the correct n-grams over all the sentences, then divide by the sum of the predicted n-grams over all the sentences (we don't average the precisions of the individual sentences).
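To see why pooling matters, here is a sketch on a hypothetical two-sentence corpus, given as `(n_correct, n_predicted)` pairs for one n-gram size (the numbers are made up for illustration, and the function names are ours):

```python
def corpus_precision(stats):
    """Pool the counts over all sentences, then divide (what BLEU does)."""
    return sum(c for c, _ in stats) / sum(t for _, t in stats)

def mean_sentence_precision(stats):
    """Average of per-sentence precisions (NOT what BLEU does)."""
    return sum(c / t for c, t in stats) / len(stats)

# A 1-word sentence that is fully correct, and a 9-word sentence with nothing correct
stats = [(1, 1), (0, 9)]
print(corpus_precision(stats))         # 0.1
print(mean_sentence_precision(stats))  # 0.5
```

Pooling weighs each sentence by how many n-grams it predicts, so a single short, perfect sentence cannot dominate the score.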

To compute the BLEU score, the final formula is then
```
BLEU = length_penalty * ((p1 * p2 * p3 * p4) ** 0.25)
```
which is the geometric mean of `p1`, `p2`, `p3` and `p4` (our n-gram precision scores) multiplied by a penalty for the length. Long predictions are already penalized by the precision scores, since every extra wrong n-gram lowers them, but short ones get off easier, especially if they only contain correct words. So we apply the following brevity penalty:
```
length_penalty = 1 if len(pred) >= len(targ) else exp(1 - len(targ)/len(pred))
```
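The formula and the brevity penalty can be written as two small Python functions. This sketch (function names are ours) plugs in the precisions worked out for the two examples above:

```python
from math import exp

def brevity_penalty(pred_len, targ_len):
    """1 if the prediction is at least as long as the target, else exp(1 - targ_len/pred_len)."""
    return 1.0 if pred_len >= targ_len else exp(1 - targ_len / pred_len)

def bleu_from_precisions(precisions, pred_len, targ_len):
    """Geometric mean of the four n-gram precisions, times the brevity penalty."""
    p1, p2, p3, p4 = precisions
    return brevity_penalty(pred_len, targ_len) * (p1 * p2 * p3 * p4) ** 0.25

# Second example: both sentences have 11 words, so there is no penalty
print(bleu_from_precisions([1, 9/10, 6/9, 4/8], 11, 11))  # about 0.7401, i.e. 0.3 ** 0.25
# First example: no correct 4-gram, so the whole score is 0
print(bleu_from_precisions([5/7, 3/6, 1/5, 0], 7, 7))     # 0.0
# A 5-word prediction for a 7-word target is scaled by exp(-0.4)
print(brevity_penalty(5, 7))                              # about 0.6703
```

Note how a single zero precision (no correct 4-gram) brings the whole score to 0.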

And if we are taking the BLEU score of a whole corpus, we use the sum of the lengths of the predicted sentences and the sum of the lengths of the target sentences.

# Let's code this

There is an implementation of BLEU in nltk, but the problem is that it's designed to work on lists of tokenized texts, and is therefore very slow (it announced 5 hours on the validation set of the translation notebook to compute the average of the sentence BLEU scores). We have numericalized text, so it's easier to reimplement BLEU and only deal with integers.
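To make "numericalized" concrete, here is a sketch with a hypothetical toy vocabulary builder, `build_vocab`, that assigns ids in order of first appearance. fastai builds its own vocabulary differently, so the actual ids in the notebook would differ:

```python
def build_vocab(sentences):
    """Hypothetical toy vocabulary: map each new token to the next free integer id."""
    stoi = {}
    for sent in sentences:
        for tok in sent.split():
            stoi.setdefault(tok, len(stoi))
    return stoi

targ = "the cat is walking in the garden"
pred = "the cat is running in the fields"
stoi = build_vocab([targ, pred])

targ_ids = [stoi[t] for t in targ.split()]
pred_ids = [stoi[t] for t in pred.split()]
print(targ_ids)  # [0, 1, 2, 3, 4, 0, 5]
print(pred_ids)  # [0, 1, 2, 6, 4, 0, 7]
```

From here on, comparing n-grams only means comparing small integers.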

Specifically, we are going to use the `Counter` class, which counts the number of instances of each n-gram in the predicted sentence and in the target one:

## 1. BLEU score for unigrams

### Suppose we have two lists of integers that represent word sequences

In [150]:
targ = [1,2,3,4,5,1,2]
pred = [1,2,3,7,5,1,1]

### Make a Counter object for each list, which is essentially a dictionary of items and number of occurrences

In [151]:
cnt_pred,cnt_targ = Counter(pred),Counter(targ)

In [152]:
cnt_pred,cnt_targ

(Counter({1: 3, 2: 1, 3: 1, 7: 1, 5: 1}),
 Counter({1: 2, 2: 2, 3: 1, 4: 1, 5: 1}))

#### The number of correct unigrams is the number of words in `pred` that are also in `targ`, with each word's count capped at the number of times it appears in `targ`.

In [153]:
corrects = sum([min(c, cnt_targ[g]) for g,c in cnt_pred.items()])
corrects

5

In [154]:
cnt_pred.items()

dict_items([(1, 3), (2, 1), (3, 1), (7, 1), (5, 1)])

In [155]:
cnt_pred.values()

dict_values([3, 1, 1, 1, 1])

This works for unigrams, which are represented as ints. 
But it won't work for an ngram of more than one word, which is represented as a list of ints, since a `Counter` requires its keys to be hashable and Python lists are not.
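A quick check of that claim with the standard library. As an aside, converting each n-gram to a tuple is a simpler alternative that also works; the notebook instead defines its own `NGram` class in section 2.1.

```python
from collections import Counter

bigrams = [[1, 2], [2, 3], [1, 2]]

# Lists are unhashable, so Counter cannot use them as keys
try:
    Counter(bigrams)
    lists_ok = True
except TypeError as err:
    lists_ok = False
    print("TypeError:", err)

# Tuples are hashable, so converting each n-gram to a tuple works
cnt = Counter(tuple(g) for g in bigrams)
print(cnt)  # Counter({(1, 2): 2, (2, 3): 1})
```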

## 2. BLEU score for ngrams
Here we develop the machinery needed to calculate a BLEU score for ngrams.

### 2.1 Define a class that `hashes` an ngram, i.e. maps the ngram to a (hopefully) unique integer

In [156]:
# NGram class contains methods __eq__ and __hash__, so that ngrams can be used as Counter keys
# Takes inputs ngram and max_n
#   ngram is a list of the integer ids (numericalized tokens) of a sequence of words
#   max_n is the base used by the hash; it should be at least the vocab size
class NGram():
    def __init__(self, ngram, max_n=5000): 
        self.ngram,self.max_n = ngram,max_n
        
    # Test ngram equivalence
    # Input: another ngram
    # Output: True if both ngrams contain the same tokens in the same order, False otherwise
    def __eq__(self, other):
        if len(self.ngram) != len(other.ngram): 
            return False
        equivalence_indicator = np.all(np.array(self.ngram) == np.array(other.ngram))
        return equivalence_indicator
    
    # Generate an integer hash for each ngram
    def __hash__(self):
        # suppose the ngram is "I see"
        #      then enumerate(self.ngram) yields (0,stoi["I"]), (1,stoi["see"])
        #      and the hash is stoi["I"]*5000**0 + stoi["see"]*5000**1
        hash_value = int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))
        return hash_value

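#### A *collision* occurs if the hash function maps two *different* ngrams to the *same* integer. If `max_n` is at least the vocab size, every token id is a valid digit in base `max_n`, so two different ngrams of the same length can never collide. Even if they did, `Counter` would still count correctly, because it falls back on `__eq__`; collisions only slow it down.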
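The hash is just the n-gram read as a number written in base `max_n`. This sketch reimplements the same formula as a standalone function (`ngram_hash` and the toy `CollidingKey` class are hypothetical helpers) to check both points: no collisions when every id is below `max_n`, and correct counts even when every hash collides.

```python
from collections import Counter

def ngram_hash(ngram, max_n):
    """Same formula as NGram.__hash__: the ids are the digits of a base-max_n number."""
    return int(sum(o * max_n**i for i, o in enumerate(ngram)))

# Every id is below max_n: different n-grams of the same length get different hashes
h1, h2 = ngram_hash([1, 5, 6], 5000), ngram_hash([6, 5, 1], 5000)
print(h1, h2)  # 150025001 25025006

# An id equal to max_n (3 with max_n=3) makes [3, 0] and [0, 1] collide
c1, c2 = ngram_hash([3, 0], 3), ngram_hash([0, 1], 3)
print(c1, c2)  # 3 3

class CollidingKey:
    """Toy n-gram key whose hash always collides; __eq__ still tells keys apart."""
    def __init__(self, ngram):
        self.ngram = tuple(ngram)

    def __eq__(self, other):
        return self.ngram == other.ngram

    def __hash__(self):
        return 0  # every key gets the same hash

cnt = Counter([CollidingKey([1, 2]), CollidingKey([2, 3]), CollidingKey([1, 2])])
counts = sorted((k.ngram, c) for k, c in cnt.items())
print(counts)  # [((1, 2), 2), ((2, 3), 1)]
```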

### 2.2 Define a function to gather all the possible ngrams of a sequence of words

In [157]:
# generate all the ngrams in a sequence of words
def get_grams(x, n, max_n=5000):
    # input:
    #   x is a sequence of words or their integer representations
    #   n is the length of the ngram: if n = 3 then we are processing a 3-gram
    # output:
    #   a list of all the ngrams that can be constructed from x
    return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]

### 2.3 Example

In [158]:
ngram_list = get_grams(["I","see","that","you","like","to","run","fast"],2,max_n=5000)
print(len(ngram_list))
ngram_list

7


[<__main__.NGram at 0x14a8cde33c8>,
 <__main__.NGram at 0x14a8cde3240>,
 <__main__.NGram at 0x14a8cde3438>,
 <__main__.NGram at 0x14a8cde39b0>,
 <__main__.NGram at 0x14a8cde3a58>,
 <__main__.NGram at 0x14a8cde3ac8>,
 <__main__.NGram at 0x14a8cde3b38>]

In [159]:
# ngram_list is a list object whose elements are ngrams
print(type(ngram_list))

# an element of ngram_list is an NGram object
print(type(ngram_list[0]))

<class 'list'>
<class '__main__.NGram'>


In [160]:
# An object of the NGram class has the instance attributes max_n and ngram
# There are also `dunder` methods, including 
#     __eq__ to test equivalence between two ngrams
#     __hash__ to compute the hash of the ngram
dir(ngram_list[0])

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'max_n',
 'ngram']

#### Let's see all the ngrams in ngram_list

In [161]:
[ngram_list[i].ngram for i in range(len(ngram_list))]

[['I', 'see'],
 ['see', 'that'],
 ['that', 'you'],
 ['you', 'like'],
 ['like', 'to'],
 ['to', 'run'],
 ['run', 'fast']]

#### Test the `__eq__` method

In [162]:
# == calls NGram.__eq__ (comparing the .ngram lists directly would only test list equality)
print(ngram_list[0] == ngram_list[1])
print(ngram_list[0] == ngram_list[0])

False
True


#### Test the `__hash__` method
In order to `hash` each ngram to an integer, the input `x` to `get_grams()` must be a sequence of integer ids representing the words, not the words themselves, since `__hash__` multiplies each element by a power of `max_n` and sums the results.

In [163]:
# make all possible 3-grams from an input sequence
ngram_list = get_grams([1,5,6,10,2,1,10,6],n=3,max_n=5000)
ngram_list

[<__main__.NGram at 0x14a8cdeb5c0>,
 <__main__.NGram at 0x14a8cdeb630>,
 <__main__.NGram at 0x14a8cdeb710>,
 <__main__.NGram at 0x14a8cdeb518>,
 <__main__.NGram at 0x14a8cdeb0f0>,
 <__main__.NGram at 0x14a8cdeb358>]

#### Here are all the 3-grams

In [164]:
[ngram_list[i].ngram for i in range(len(ngram_list))]

[[1, 5, 6], [5, 6, 10], [6, 10, 2], [10, 2, 1], [2, 1, 10], [1, 10, 6]]

#### Here are the hashes of the 3-grams

In [165]:
[ngram_list[i].__hash__() for i in range(len(ngram_list))]

[150025001, 250030005, 50050006, 25010010, 250005002, 150050001]

### 2.4 Compute the number of correctly predicted ngrams:

In [166]:
def get_correct_ngrams(pred, targ, n, max_n=5000):
    # inputs pred, targ are predicted and target word sequences (integer ids)
    # returns (number of correct ngrams, each capped at its count in targ; number of predicted ngrams)
    pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)
    pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)
    return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)
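For comparison, the same count can be sketched with tuples as n-gram keys, letting Python hash them instead of `NGram` (the name `correct_ngrams_tuples` is ours). It reproduces the 3-gram example of this section and the unigram count from section 1:

```python
from collections import Counter

def correct_ngrams_tuples(pred, targ, n):
    """Returns (clipped number of correct n-grams, number of predicted n-grams)."""
    pred_grams = [tuple(pred[i:i + n]) for i in range(len(pred) - n + 1)]
    targ_grams = [tuple(targ[i:i + n]) for i in range(len(targ) - n + 1)]
    pred_cnt, targ_cnt = Counter(pred_grams), Counter(targ_grams)
    return sum(min(c, targ_cnt[g]) for g, c in pred_cnt.items()), len(pred_grams)

print(correct_ngrams_tuples([1, 5, 4, 10, 1, 1, 10, 6], [1, 5, 6, 10, 2, 1, 10, 6], 3))  # (1, 6)
print(correct_ngrams_tuples([1, 2, 3, 7, 5, 1, 1], [1, 2, 3, 4, 5, 1, 2], 1))            # (5, 7)
```

With tuples there is no `max_n` to choose, since Python's built-in tuple hashing and equality are used.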

#### example

In [167]:
get_correct_ngrams([1,5,4,10,1,1,10,6],[1,5,6,10,2,1,10,6],3)

(1, 6)

## 3. The BLEU metric for a single predicted sentence and its target is implemented as follows:

Assume we include sequences of up to 4 words, i.e., 1-grams, 2-grams, 3-grams, and 4-grams.

Note that the BLEU metric imposes a penalty if the predicted sentence is *shorter* than the target sentence.

In [168]:
from math import exp

def sentence_bleu(pred, targ, max_n=5000):
    
    # We count up to 4-grams: range(1,5) gives ngram sizes 1, 2, 3 and 4
    correct_ngrams = [get_correct_ngrams(pred, targ, ngram_size, max_n=max_n) for ngram_size in range(1,5)]
    
    # list with the fraction of ngrams recovered, for ngram sizes 1, 2, 3 and 4
    # (assumes pred has at least 4 tokens, otherwise n_total is 0 for 4-grams)
    frac_recovered = [n_recovered/n_total for n_recovered,n_total in correct_ngrams]
    
    # compute the penalty if predicted sentence is *shorter* than target sentence (between 0 and 1)
    len_penalty = exp(1 - len(targ)/len(pred)) if len(pred) < len(targ) else 1
    
    # raw score is the product of the fractions recovered for each ngram size
    raw_score = (frac_recovered[0]*frac_recovered[1]*frac_recovered[2]*frac_recovered[3])
    
    # bleu score is len_penalty times the geometric mean of the 4 fractions, i.e. raw_score ** 0.25
    bleu = len_penalty * raw_score**0.25 
    return bleu
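As a sanity check, here is a self-contained sketch of the same sentence-level computation, using tuple n-gram keys instead of `NGram` and whitespace-split words instead of ids (the names are ours). Returning 0 for an n-gram size with no predicted n-grams is our own assumption for very short predictions; the function above would divide by zero there.

```python
from collections import Counter
from math import exp

def clipped_counts(pred, targ, n):
    """(clipped number of correct n-grams, number of predicted n-grams)."""
    pred_cnt = Counter(tuple(pred[i:i + n]) for i in range(len(pred) - n + 1))
    targ_cnt = Counter(tuple(targ[i:i + n]) for i in range(len(targ) - n + 1))
    correct = sum(min(c, targ_cnt[g]) for g, c in pred_cnt.items())
    return correct, sum(pred_cnt.values())

def sentence_bleu_sketch(pred, targ):
    """Sentence-level BLEU with 1- to 4-grams and a brevity penalty."""
    if not pred:
        return 0.0
    precisions = []
    for n in range(1, 5):
        correct, total = clipped_counts(pred, targ, n)
        # assumption: a prediction too short to contain any n-gram scores 0 for that n
        precisions.append(correct / total if total > 0 else 0.0)
    len_penalty = exp(1 - len(targ) / len(pred)) if len(pred) < len(targ) else 1.0
    p1, p2, p3, p4 = precisions
    return len_penalty * (p1 * p2 * p3 * p4) ** 0.25

targ1 = "the cat is walking in the garden".split()
pred1 = "the cat is running in the fields".split()
targ2 = "she was interested in world history because she read the book".split()
pred2 = "she read the book because she was interested in world history".split()

print(sentence_bleu_sketch(pred1, targ1))  # 0.0 (no correct 4-gram)
print(sentence_bleu_sketch(pred2, targ2))  # about 0.7401
print(sentence_bleu_sketch(targ1, targ1))  # 1.0
```

The first example scores exactly 0 because it has no correct 4-gram, even though most of its words are right; the second scores 0.3 ** 0.25, about 0.74.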

## 4. The BLEU metric over a text corpus

In [171]:
from math import exp

def corpus_bleu(preds, targs, max_n=5000):
    
    # Analysis includes 1-grams, 2-grams, 3-grams, 4-grams
    
    # Initialize the total lengths (in tokens) and, for each ngram size,
    # the number of correct ngrams and the number of predicted ngrams
    pred_len,targ_len,n_recovered,n_total = 0,0,[0]*4,[0]*4
    
    # Loop over the sentences in the corpus
    for pred,targ in zip(preds,targs):
        
        # accumulate the lengths of the corpus, in tokens
        pred_len += len(pred)
        targ_len += len(targ)
        
        # count the correct and predicted ngrams of each size in the current sentence
        for i in range(4):
            count,total = get_correct_ngrams(pred, targ, i+1, max_n=max_n)
            n_recovered[i] += count
            n_total[i] += total
            
    # fraction of ngrams recovered over the whole corpus, for each ngram size
    frac_recovered = [count/total for count,total in zip(n_recovered,n_total)]
    
    # compute a single length penalty for the whole corpus
    len_penalty = exp(1 - targ_len/pred_len) if pred_len < targ_len else 1
    
    # compute raw score
    raw_score = frac_recovered[0]*frac_recovered[1]*frac_recovered[2]*frac_recovered[3]
    
    # compute and return bleu metric
    bleu = len_penalty * raw_score ** 0.25
    return bleu
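A self-contained sketch of the corpus version (again with tuples, and with our own function names) on the two example pairs from the beginning of the notebook shows that pooling the counts is not the same as averaging sentence scores:

```python
from collections import Counter
from math import exp

def clipped_counts(pred, targ, n):
    """(clipped number of correct n-grams, number of predicted n-grams)."""
    pred_cnt = Counter(tuple(pred[i:i + n]) for i in range(len(pred) - n + 1))
    targ_cnt = Counter(tuple(targ[i:i + n]) for i in range(len(targ) - n + 1))
    return sum(min(c, targ_cnt[g]) for g, c in pred_cnt.items()), sum(pred_cnt.values())

def corpus_bleu_sketch(preds, targs):
    """Corpus BLEU: pool lengths and n-gram counts over all sentences, then combine once."""
    pred_len = targ_len = 0
    correct, total = [0] * 4, [0] * 4
    for pred, targ in zip(preds, targs):
        pred_len += len(pred)
        targ_len += len(targ)
        for i in range(4):
            c, t = clipped_counts(pred, targ, i + 1)
            correct[i] += c
            total[i] += t
    if pred_len == 0:
        return 0.0
    p1, p2, p3, p4 = [c / t if t > 0 else 0.0 for c, t in zip(correct, total)]
    len_penalty = exp(1 - targ_len / pred_len) if pred_len < targ_len else 1.0
    return len_penalty * (p1 * p2 * p3 * p4) ** 0.25

preds = ["the cat is running in the fields".split(),
         "she read the book because she was interested in world history".split()]
targs = ["the cat is walking in the garden".split(),
         "she was interested in world history because she read the book".split()]
score = corpus_bleu_sketch(preds, targs)
print(score)  # about 0.5774, i.e. (1/9) ** 0.25
```

The pooled precisions are 16/18, 12/16, 7/14 and 4/12, whose product is 1/9, so the corpus score is (1/9) ** 0.25, about 0.577. Averaging the two sentence scores (0 and about 0.74) would give about 0.37 instead.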

#### This takes 11s to run on our (Sylvain and Jeremy's) validation set (instead of 5 hours with nltk), so we can even use it as a metric during training. 
#### We define a `Callback` to do this:

In [170]:
from math import exp

class CorpusBLEU(Callback):
    def __init__(self, vocab_sz):
        self.vocab_sz = vocab_sz
        self.name = 'bleu'
    
    def on_epoch_begin(self, **kwargs):
        # reset the corpus statistics at the start of each epoch
        self.pred_len,self.targ_len,self.n_precs,self.counts = 0,0,[0]*4,[0]*4
    
    def on_batch_end(self, last_output, last_target, **kwargs):
        # turn the model's scores over the vocab into predicted token ids
        last_output = last_output.argmax(dim=-1)
        for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):
            self.pred_len += len(pred)
            self.targ_len += len(targ)
            for i in range(4):
                c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
                self.n_precs[i] += c
                self.counts[i] += t
    
    def on_epoch_end(self, last_metrics, **kwargs):
        # the statistics are stored on the callback, so they are read through self
        n_precs = [c/t for c,t in zip(self.n_precs,self.counts)]
        len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1
        bleu = len_penalty * ((n_precs[0]*n_precs[1]*n_precs[2]*n_precs[3]) ** 0.25)
        return add_metrics(last_metrics, bleu)
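The callback follows an accumulate-then-finalize pattern. Here is a framework-free sketch of the same idea; the `BleuAccumulator` class and its `reset`/`update`/`value` methods are hypothetical stand-ins for `on_epoch_begin`/`on_batch_end`/`on_epoch_end`, and it takes token ids directly, skipping the `argmax` step:

```python
from collections import Counter
from math import exp

def clipped_counts(pred, targ, n):
    """(clipped number of correct n-grams, number of predicted n-grams)."""
    pred_cnt = Counter(tuple(pred[i:i + n]) for i in range(len(pred) - n + 1))
    targ_cnt = Counter(tuple(targ[i:i + n]) for i in range(len(targ) - n + 1))
    return sum(min(c, targ_cnt[g]) for g, c in pred_cnt.items()), sum(pred_cnt.values())

class BleuAccumulator:
    """Hypothetical analogue of CorpusBLEU: accumulate counts per batch, score once at the end."""
    def reset(self):                      # ~ on_epoch_begin
        self.pred_len, self.targ_len = 0, 0
        self.correct, self.total = [0] * 4, [0] * 4

    def update(self, preds, targs):       # ~ on_batch_end, after argmax
        for pred, targ in zip(preds, targs):
            self.pred_len += len(pred)
            self.targ_len += len(targ)
            for i in range(4):
                c, t = clipped_counts(pred, targ, i + 1)
                self.correct[i] += c
                self.total[i] += t

    def value(self):                      # ~ on_epoch_end
        if self.pred_len == 0:
            return 0.0
        p = [c / t if t > 0 else 0.0 for c, t in zip(self.correct, self.total)]
        len_penalty = exp(1 - self.targ_len / self.pred_len) if self.pred_len < self.targ_len else 1.0
        return len_penalty * (p[0] * p[1] * p[2] * p[3]) ** 0.25

preds = [[1, 2, 3, 7, 5, 1, 8], [4, 5, 6, 7]]
targs = [[1, 2, 3, 4, 5, 1, 6], [4, 5, 6, 7]]

acc = BleuAccumulator()
acc.reset()
acc.update(preds, targs)             # one batch containing both sentences
one_batch = acc.value()

acc.reset()
acc.update(preds[:1], targs[:1])     # the same data split into two batches
acc.update(preds[1:], targs[1:])
two_batches = acc.value()
print(one_batch == two_batches)      # True
```

Because only integer counts and lengths are accumulated, the final score does not depend on how the corpus is split into batches, which is why the callback stores counts rather than per-batch BLEU scores.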