In [1]:
from nltk.lm import NgramCounter
from nltk import word_tokenize 
from nltk.util import ngrams
import nltk
from nltk.probability import FreqDist
from nltk.lm.api import LanguageModel
from nltk.corpus import gutenberg
from collections import Counter
from nltk.lm.preprocessing import padded_everygram_pipeline
from nltk.lm.preprocessing import flatten
from nltk.lm import Vocabulary
from itertools import chain
import math
import numpy as np



class MyStupidBackoff(LanguageModel):
    
    def __init__(self, context, fdist, alpha=0.4,  *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.context = context
        self.fdist = fdist

    def compute_ngram(self, order, text):
        token = word_tokenize(text)
        ngram_words = list(ngrams(token, order))
        return ngram_words

    def compute_ppl(self, model, data):
        highest_ngram = model.order
        scores = []
        for sentence in data:   
            ngrams, flat_text = padded_everygram_pipeline(highest_ngram, [sentence])
            ngrams = chain.from_iterable(ngrams)
            scores.extend([-1 * model.logscore(w[-1], w[0:-1]) for w in ngrams if len(w) == highest_ngram])
        return math.pow(2.0, np.asarray(scores).mean())

    def stupid_backoff(self, word, context=None):              
        padded_ngrams, flat_text = padded_everygram_pipeline(self.order, context)
        unigrams = self.compute_ngram(1, word)
        if not context:
            return len(unigrams) 
        if fdist.freq(word)<=0:
            return self.alpha*self.stupid_backoff(word, context[1:])
        else: 
            return fdist.freq(word)/NgramCounter(padded_ngrams).N()

    def unmasked_score(self, word, context=None):
        return self.stupid_backoff(word, context)

macbeth_sents = [[w.lower() for w in sent] for sent in gutenberg.sents('shakespeare-macbeth.txt')]
macbeth_words = flatten(macbeth_sents)
# Compute vocab 
lex = Vocabulary(macbeth_words, unk_cutoff=2)
# Handeling OOV
macbeth_oov_sents = [list(lex.lookup(sent)) for sent in macbeth_sents]
padded_ngrams_oov, flat_text_oov = padded_everygram_pipeline(2, macbeth_oov_sents)
# Train the model 
fdist = FreqDist(chain.from_iterable(macbeth_oov_sents))
stupid = MyStupidBackoff(macbeth_oov_sents, fdist, 0.4, 2)
stupid.fit(padded_ngrams_oov, flat_text_oov)

ngrms, flat_text = padded_everygram_pipeline(stupid.order, macbeth_sents)
ngrms = chain.from_iterable(ngrms)

print("\033[1mMyStupidBackoff StupidBackoff : Perplexity\033[0m")
print("{:.3}".format(stupid.perplexity([x for x in ngrms if len(x) == stupid.order])))

print("\033[1mMyStupidBackoff StupidBackoff : Manual Function Perplexity\033[0m")
print("{:.3}".format(stupid.compute_ppl(stupid, macbeth_sents)))

[1mMyStupidBackoff StupidBackoff : Perplexity[0m
1.04e+03
[1mMyStupidBackoff StupidBackoff : Manual Function Perplexity[0m
1.04e+03
