In [3]:
from nltk.corpus import reuters
from nltk import bigrams, trigrams
from collections import Counter, defaultdict

# Create a placeholder for model
model = defaultdict(lambda: defaultdict(lambda: 0))

# Count frequency of co-occurance  
for sentence in reuters.sents():
    for w1, w2, w3 in trigrams(sentence, pad_right=True, pad_left=True):
        model[(w1, w2)][w3] += 1

model

defaultdict(<function __main__.<lambda>()>,
            {(None,
              None): defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>, {'ASIAN': 4,
                          'They': 446,
                          'But': 1054,
                          'The': 8839,
                          'Unofficial': 1,
                          '"': 3589,
                          'In': 1380,
                          'Threat': 2,
                          'Taiwan': 38,
                          'Retaliation': 3,
                          'A': 764,
                          'Last': 202,
                          'Much': 8,
                          'He': 1586,
                          'Meanwhile': 41,
                          'Japan': 111,
                          'Deputy': 8,
                          'CHINA': 50,
                          'It': 1768,
                          'JAPAN': 164,
                          'MITI': 12,
                          'Nuclear': 1,
               

In [8]:
import numpy as np

class Trigram:
    
    def __init__(self, corpus: list):
        """ Initialized the class """
        self.corpus = corpus
        self._trigrams = self._get_trigrams()
        self.model = self._gen_model()
        
    def _get_trigrams(self):
        """ Gets the bigram count """
        trigram_count = defaultdict(lambda: defaultdict(lambda: 0))
        for tokens in self.corpus:
            for w1, w2, w3 in zip(tokens, tokens[1:], tokens[2:]):
                trigram_count[(w1, w2)][w3] += 1
        return trigram_count
    
    def _gen_model(self):
        """ Generates the trigram model """
        model = defaultdict(lambda: defaultdict(lambda: 0))
        for bigram, seq in self._trigrams.items():
            vals = np.array(list(seq.values()))
            count = float(np.sum(vals))
            vals = vals / count
            model[bigram] = dict(zip(seq.keys(), vals))
        return model


In [6]:
sents = list()
for s in reuters.sents():
    s.insert(0, '<s>')
    s.append('</s>')
    sents.append(s)
sents

[['<s>',
  'ASIAN',
  'EXPORTERS',
  'FEAR',
  'DAMAGE',
  'FROM',
  'U',
  '.',
  'S',
  '.-',
  'JAPAN',
  'RIFT',
  'Mounting',
  'trade',
  'friction',
  'between',
  'the',
  'U',
  '.',
  'S',
  '.',
  'And',
  'Japan',
  'has',
  'raised',
  'fears',
  'among',
  'many',
  'of',
  'Asia',
  "'",
  's',
  'exporting',
  'nations',
  'that',
  'the',
  'row',
  'could',
  'inflict',
  'far',
  '-',
  'reaching',
  'economic',
  'damage',
  ',',
  'businessmen',
  'and',
  'officials',
  'said',
  '.',
  '</s>'],
 ['<s>',
  'They',
  'told',
  'Reuter',
  'correspondents',
  'in',
  'Asian',
  'capitals',
  'a',
  'U',
  '.',
  'S',
  '.',
  'Move',
  'against',
  'Japan',
  'might',
  'boost',
  'protectionist',
  'sentiment',
  'in',
  'the',
  'U',
  '.',
  'S',
  '.',
  'And',
  'lead',
  'to',
  'curbs',
  'on',
  'American',
  'imports',
  'of',
  'their',
  'products',
  '.',
  '</s>'],
 ['<s>',
  'But',
  'some',
  'exporters',
  'said',
  'that',
  'while',
  'the',
  'con

In [9]:
trigram = Trigram(sents)

In [14]:
sorted(trigram.model[("the", "price")], key=trigram.model[("the", "price")].get, reverse=True)

['of',
 'it',
 'to',
 'for',
 '.',
 'at',
 'adjustment',
 'is',
 ',',
 'paid',
 'increases',
 'per',
 'the',
 'will',
 'cut',
 'cuts',
 '(',
 'differentials',
 'has',
 'stayed',
 'was',
 'freeze',
 'increase',
 'would',
 'yesterday',
 'effect',
 'used',
 'climate',
 'reductions',
 'limit',
 'now',
 'moved',
 'adjustments',
 'slumped',
 'move',
 'evolution',
 'went',
 'factor',
 'Royal',
 'again',
 'changes',
 'holds',
 'fall',
 '-',
 'from',
 'base',
 'on',
 'review',
 'while',
 'collapse',
 'being',
 'outlook',
 'rises',
 'drop',
 'guaranteed',
 ',"',
 'structure',
 'and',
 'could',
 'related',
 'hike',
 'we',
 'policy',
 'revision',
 'led',
 'action',
 'zone',
 'slump',
 'had',
 'difference',
 'in',
 'raise',
 'support',
 'gap',
 'projected',
 'approached',
 'instability']

In [13]:
trigram.model[("the", "price")]

{'yesterday': 0.004651162790697674,
 'of': 0.3209302325581395,
 'it': 0.05581395348837209,
 'effect': 0.004651162790697674,
 'cut': 0.009302325581395349,
 'for': 0.05116279069767442,
 'paid': 0.013953488372093023,
 'to': 0.05581395348837209,
 'increases': 0.013953488372093023,
 'used': 0.004651162790697674,
 'climate': 0.004651162790697674,
 '.': 0.023255813953488372,
 'cuts': 0.009302325581395349,
 'reductions': 0.004651162790697674,
 'limit': 0.004651162790697674,
 'now': 0.004651162790697674,
 'moved': 0.004651162790697674,
 'per': 0.013953488372093023,
 'adjustments': 0.004651162790697674,
 '(': 0.009302325581395349,
 'slumped': 0.004651162790697674,
 'is': 0.018604651162790697,
 'move': 0.004651162790697674,
 'evolution': 0.004651162790697674,
 'differentials': 0.009302325581395349,
 'went': 0.004651162790697674,
 'the': 0.013953488372093023,
 'factor': 0.004651162790697674,
 'Royal': 0.004651162790697674,
 ',': 0.018604651162790697,
 'again': 0.004651162790697674,
 'changes': 0.0

In [22]:
from nltk.corpus import gutenberg
from nltk.util import ngrams

gut_ngrams = (
    ngram for sent in gutenberg.sents() for ngram in ngrams(sent, 3,
    pad_left=True, pad_right=True, left_pad_symbol="<s>", right_pad_symbol="</s>"))

n = 0
for ngram in gut_ngrams:
    print(ngram)
    n += 1
    if n == 7:
        break

('<s>', '<s>', '[')
('<s>', '[', 'Emma')
('[', 'Emma', 'by')
('Emma', 'by', 'Jane')
('by', 'Jane', 'Austen')
('Jane', 'Austen', '1816')
('Austen', '1816', ']')


In [27]:
import math
import random
from collections import Counter, defaultdict

class KneserNeyLM:

    def __init__(self, highest_order, ngrams, start_pad_symbol='<s>',
            end_pad_symbol='</s>'):
        """
        Constructor for KneserNeyLM.
        Params:
            highest_order [int] The order of the language model.
            ngrams [list->tuple->string] Ngrams of the highest_order specified.
                Ngrams at beginning / end of sentences should be padded.
            start_pad_symbol [string] The symbol used to pad the beginning of
                sentences.
            end_pad_symbol [string] The symbol used to pad the beginning of
                sentences.
        """
        self.highest_order = highest_order
        self.start_pad_symbol = start_pad_symbol
        self.end_pad_symbol = end_pad_symbol
        self.lm = self.train(ngrams)

    def train(self, ngrams):
        """
        Train the language model on the given ngrams.
        Params:
            ngrams [list->tuple->string] Ngrams of the highest_order specified.
        """
        kgram_counts = self._calc_adj_counts(Counter(ngrams))
        probs = self._calc_probs(kgram_counts)
        return probs

    def highest_order_probs(self):
        return self.lm[0]

    def _calc_adj_counts(self, highest_order_counts):
        """
        Calculates the adjusted counts for all ngrams up to the highest order.
        Params:
            highest_order_counts [dict{tuple->string, int}] Counts of the highest
                order ngrams.
        Returns:
            kgrams_counts [list->dict] List of dict from kgram to counts
                where k is in descending order from highest_order to 0.
        """
        kgrams_counts = [highest_order_counts]
        for i in range(1, self.highest_order):
            last_order = kgrams_counts[-1]
            new_order = defaultdict(int)
            for ngram in last_order.keys():
                suffix = ngram[1:]
                new_order[suffix] += 1
            kgrams_counts.append(new_order)
        return kgrams_counts

    def _calc_probs(self, orders):
        """
        Calculates interpolated probabilities of kgrams for all orders.
        """
        backoffs = []
        for order in orders[:-1]:
            backoff = self._calc_order_backoff_probs(order)
            backoffs.append(backoff)
        orders[-1] = self._calc_unigram_probs(orders[-1])
        backoffs.append(defaultdict(int))
        self._interpolate(orders, backoffs)
        return orders

    def _calc_unigram_probs(self, unigrams):
        sum_vals = sum(v for v in unigrams.values())
        unigrams = dict((k, math.log(v/sum_vals)) for k, v in unigrams.items())
        return unigrams

    def _calc_order_backoff_probs(self, order):
        num_kgrams_with_count = Counter(
            value for value in order.values() if value <= 4)
        discounts = self._calc_discounts(num_kgrams_with_count)
        prefix_sums = defaultdict(int)
        backoffs = defaultdict(int)
        for key in order.keys():
            prefix = key[:-1]
            count = order[key]
            prefix_sums[prefix] += count
            discount = self._get_discount(discounts, count)
            order[key] -= discount
            backoffs[prefix] += discount
        for key in order.keys():
            prefix = key[:-1]
            order[key] = math.log(order[key]/prefix_sums[prefix])
        for prefix in backoffs.keys():
            backoffs[prefix] = math.log(backoffs[prefix]/prefix_sums[prefix])
        return backoffs

    def _get_discount(self, discounts, count):
        if count > 3:
            return discounts[3]
        return discounts[count]

    def _calc_discounts(self, num_with_count):
        """
        Calculate the optimal discount values for kgrams with counts 1, 2, & 3+.
        """
        common = num_with_count[1]/(num_with_count[1] + 2 * num_with_count[2])
        # Init discounts[0] to 0 so that discounts[i] is for counts of i
        discounts = [0]
        for i in range(1, 4):
            if num_with_count[i] == 0:
                discount = 0
            else:
                discount = (i - (i + 1) * common
                        * num_with_count[i + 1] / num_with_count[i])
            discounts.append(discount)
        if any(d for d in discounts[1:] if d <= 0):
            raise Exception(
                '***Warning*** Non-positive discounts detected. '
                'Your dataset is probably too small.')
        return discounts

    def _interpolate(self, orders, backoffs):
        """
        """
        for last_order, order, backoff in zip(
                reversed(orders), reversed(orders[:-1]), reversed(backoffs[:-1])):
            for kgram in order.keys():
                prefix, suffix = kgram[:-1], kgram[1:]
                order[kgram] += last_order[suffix] + backoff[prefix]

    def logprob(self, ngram):
        for i, order in enumerate(self.lm):
            if ngram[i:] in order:
                return order[ngram[i:]]
        return None

    def score_sent(self, sent):
        """
        Return log prob of the sentence.
        Params:
            sent [tuple->string] The words in the unpadded sentence.
        """
        padded = (
            (self.start_pad_symbol,) * (self.highest_order - 1) + sent +
            (self.end_pad_symbol,))
        sent_logprob = 0
        for i in range(len(sent) - self.highest_order + 1):
            ngram = sent[i:i+self.highest_order]
            sent_logprob += self.logprob(ngram)
        return sent_logprob

    def generate_sentence(self, min_length=4):
        """
        Generate a sentence using the probabilities in the language model.
        Params:
            min_length [int] The mimimum number of words in the sentence.
        """
        sent = []
        probs = self.highest_order_probs()
        while len(sent) < min_length + self.highest_order:
            sent = [self.start_pad_symbol] * (self.highest_order - 1)
            # Append first to avoid case where start & end symbal are same
            sent.append(self._generate_next_word(sent, probs))
            while sent[-1] != self.end_pad_symbol:
                sent.append(self._generate_next_word(sent, probs))
        sent = ' '.join(sent[(self.highest_order - 1):-1])
        return sent

    def _get_context(self, sentence):
        """
        Extract context to predict next word from sentence.
        Params:
            sentence [tuple->string] The words currently in sentence.
        """
        return sentence[(len(sentence) - self.highest_order + 1):]

    def _generate_next_word(self, sent, probs):
        context = tuple(self._get_context(sent))
        pos_ngrams = list(
            (ngram, logprob) for ngram, logprob in probs.items()
            if ngram[:-1] == context)
        # Normalize to get conditional probability.
        # Subtract max logprob from all logprobs to avoid underflow.
        _, max_logprob = max(pos_ngrams, key=lambda x: x[1])
        pos_ngrams = list(
            (ngram, math.exp(prob - max_logprob)) for ngram, prob in pos_ngrams)
        total_prob = sum(prob for ngram, prob in pos_ngrams)
        pos_ngrams = list(
            (ngram, prob/total_prob) for ngram, prob in pos_ngrams)
        rand = random.random()
        for ngram, prob in pos_ngrams:
            rand -= prob
            if rand < 0:
                return ngram[-1]
        return ngram[-1]

In [31]:
gut_ngrams = (
    ngram for sent in gutenberg.sents() for ngram in ngrams(sent, 3,
    pad_left=True, pad_right=True, left_pad_symbol="<s>", right_pad_symbol="</s>"))
Counter(gut_ngrams)

Counter({('<s>', '<s>', '['): 81,
         ('<s>', '[', 'Emma'): 1,
         ('[', 'Emma', 'by'): 1,
         ('Emma', 'by', 'Jane'): 1,
         ('by', 'Jane', 'Austen'): 3,
         ('Jane', 'Austen', '1816'): 1,
         ('Austen', '1816', ']'): 1,
         ('1816', ']', '</s>'): 1,
         (']', '</s>', '</s>'): 78,
         ('<s>', '<s>', 'VOLUME'): 3,
         ('<s>', 'VOLUME', 'I'): 1,
         ('VOLUME', 'I', '</s>'): 1,
         ('I', '</s>', '</s>'): 13,
         ('<s>', '<s>', 'CHAPTER'): 276,
         ('<s>', 'CHAPTER', 'I'): 8,
         ('CHAPTER', 'I', '</s>'): 4,
         ('<s>', '<s>', 'Emma'): 212,
         ('<s>', 'Emma', 'Woodhouse'): 1,
         ('Emma', 'Woodhouse', ','): 4,
         ('Woodhouse', ',', 'handsome'): 1,
         (',', 'handsome', ','): 6,
         ('handsome', ',', 'clever'): 1,
         (',', 'clever', ','): 1,
         ('clever', ',', 'and'): 4,
         (',', 'and', 'rich'): 4,
         ('and', 'rich', ','): 1,
         ('rich', ',', 'with'): 1,


In [25]:
lm.highest_order

3

In [26]:
lm.lm

[Counter({('1816', ']', '</s>'): -9.58744743822726,
          (']', '</s>', '</s>'): -16.989049875687268,
          ('<s>', '<s>', 'VOLUME'): -37.39646113320956,
          ('<s>', 'VOLUME', 'I'): -11.695919703497736,
          ('VOLUME', 'I', '</s>'): -16.789597499327503,
          ('I', '</s>', '</s>'): -15.290032852708608,
          ('<s>', '<s>', 'CHAPTER'): -30.89081331634736,
          ('<s>', 'CHAPTER', 'I'): -15.723677090821855,
          ('CHAPTER', 'I', '</s>'): -17.51082318178849,
          ('<s>', '<s>', 'Emma'): -27.65961381078666,
          ('<s>', 'Emma', 'Woodhouse'): -25.7678103823588,
          ('Emma', 'Woodhouse', ','): -8.73027774289052,
          ('Woodhouse', ',', 'handsome'): -28.172273395510345,
          (',', 'handsome', ','): -7.476058939160024,
          ('handsome', ',', 'clever'): -28.60914249746205,
          (',', 'clever', ','): -8.77951473694336,
          ('clever', ',', 'and'): -11.201016372200446,
          (',', 'and', 'rich'): -30.220856031286594,