In [10]:
import nltk
nltk.download('reuters')
nltk.download('punkt')

[nltk_data] Downloading package reuters to /root/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [2]:
from nltk.corpus import reuters

In [6]:
text = reuters.raw()

In [8]:
text[:500]

"ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RIFT\n  Mounting trade friction between the\n  U.S. And Japan has raised fears among many of Asia's exporting\n  nations that the row could inflict far-reaching economic\n  damage, businessmen and officials said.\n      They told Reuter correspondents in Asian capitals a U.S.\n  Move against Japan might boost protectionist sentiment in the\n  U.S. And lead to curbs on American imports of their products.\n      But some exporters said that while the conflict wo"

In [87]:
from collections import defaultdict

START_TOKEN = '<s>'
END_TOKEN = '<e>'

def preprocess_sentence(sentence):
    return [word.lower() for word in word_tokenize(sentence)]

def preprocess_text(text):
    return [preprocess_sentence(sentence) for sentence in sent_tokenize(text)]

def build_n_grams(sentences, n=2):
    n_grams = defaultdict(int)

    for sentence in sentences:
        sentence = [START_TOKEN] * (n-1) + sentence + [END_TOKEN]

        for i in range(len(sentence)-n+1):
            n_grams[tuple(sentence[i:i+n])] += 1
    
    return n_grams

def estimate_prob_n_gram(n_gram, n_grams, n_minus_1_grams, vocab_size, k=1):
    n_minus_1_gram = tuple(n_gram[:-1])
    return (n_grams.get(tuple(n_gram), 0) + k) / (n_minus_1_grams.get(n_minus_1_gram, 0) + k * vocab_size)

In [60]:
sentences = preprocess_text(text)

In [61]:
unigrams = build_n_grams(sentences, n=1)
bigrams = build_n_grams(sentences, n=2)

In [85]:
vocab = [key[0] for key in unigrams.keys()]

In [86]:
len(vocab)

52708

In [91]:
print(estimate_prob_n_gram(['they', 'told'], bigrams, unigrams, len(vocab)))

0.0001627457008010705


In [90]:
N_GRAMS_MAP = [unigrams, bigrams]

def suggest_word(text, vocab, context_size=1):
    sentence = preprocess_sentence(text)
    max_prob = None
    probable_word = None

    for word in vocab:
        context = sentence[-context_size:]
        n_gram = context + (context_size - len(context)) * [START_TOKEN] + [word]
        n_gram_prob = estimate_prob_n_gram(n_gram,
                                           n_grams=N_GRAMS_MAP[context_size],
                                           n_minus_1_grams=N_GRAMS_MAP[context_size-1],
                                           vocab_size=len(vocab))
        if max_prob is None or max_prob < n_gram_prob:
            max_prob = n_gram_prob
            probable_word = word
    
    return probable_word

In [98]:
print(suggest_word('move against the', vocab))

company


In [104]:
def calculate_perplexity(sentence, n_grams, n_minus_1_grams, vocab):
    n = len(list(n_grams.keys())[0])
    sentence = preprocess_sentence(sentence)
    sentence = [START_TOKEN] * (n-1) + sentence + [END_TOKEN]
    N = len(sentence)
    sentence_prob = 1

    for i in range(len(sentence)-n+1):
        prob = estimate_prob_n_gram(sentence[i:i+n], n_grams, n_minus_1_grams, len(vocab))
        sentence_prob *= prob
    
    return (1 / sentence_prob) ** (1 / N)

In [109]:
print(calculate_perplexity('Move against Japan might boost protectionist sentiment', bigrams, unigrams, vocab))

4409.841652024751


In [108]:
print(calculate_perplexity('Move against Japan might asian be export', bigrams, unigrams, vocab))

7317.823146207935
