In [1]:
import pickle
from math import log10
from pathlib import Path
from tqdm import tqdm

## Save Train Data as List of Sentences for Preprocessing

In [2]:
train = Path('../data/train')
train_data = []

In [3]:
with open(train, 'r') as f:
    for line in f:
        train_data.append(line[:-1])

In [4]:
with open('../data/train_data.pkl', 'wb') as f:
    pickle.dump(train_data, f)

## Train Data

In [5]:
with open('../data/train_data.pkl', 'rb') as f:
    train_data = pickle.load(f)

In [6]:
data = train_data

In [7]:
data

['goran ivanišević is a retired professional tennis player from croatia .',
 'labials or labial consonants are consonants made with the lips .',
 'there are %NUMBER% pupils attending the school .',
 'navarrenx is a commune of the pyrénées - atlantiques département in the southwestern part of france .',
 'this is a list of users whose changes may be ignored while searching for vandalism .',
 'with a population of approximately %NUMBER% it is the third-largest city in the country .',
 'vladimír šmicer is a football player . he plays for slavia prague .',
 'she has a shrine on the island of enoshima in sagami bay , about %NUMBER% kilometers south of tokyo ; and a five-headed dragon and her are the main people in the enoshima engi , a history of the shrines on enoshima written by the japanese buddhist monk kokei .',
 'since the television series , there have been live tours of it .',
 "scuderia ferrari is the name for the gestione sportiva , the division of the ferrari automobile company c

## Dummy Data for Sanity Checking

In [8]:
# data = ['a a b b',
#        'a c a b',
#        'b a b a']

## Preprocess Data

In [9]:
vocab = set(['<s>', '</s>', '<UNK>'])
seen_words = set()
new_data = []
for line in data:
    words = line.split()
    sentence = ""
    for word in words:
        if word not in seen_words:
            seen_words.add(word)
            word = "<UNK>"
        else:
            if word not in vocab:
                vocab.add(word)
        sentence += word + " "
    new_data.append("<s> " + sentence + "</s>")
new_data, vocab

(['<s> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> </s>',
  '<s> <UNK> <UNK> <UNK> <UNK> <UNK> consonants <UNK> <UNK> <UNK> <UNK> . </s>',
  '<s> <UNK> are <UNK> <UNK> <UNK> the <UNK> . </s>',
  '<s> <UNK> is a <UNK> <UNK> the <UNK> <UNK> <UNK> <UNK> <UNK> the <UNK> <UNK> of <UNK> . </s>',
  '<s> <UNK> is a <UNK> of <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> . </s>',
  '<s> with a <UNK> of <UNK> %NUMBER% <UNK> is the <UNK> <UNK> in the <UNK> . </s>',
  '<s> <UNK> <UNK> is a <UNK> player . <UNK> <UNK> for <UNK> <UNK> . </s>',
  '<s> <UNK> <UNK> a <UNK> <UNK> the <UNK> of <UNK> in <UNK> <UNK> <UNK> <UNK> %NUMBER% <UNK> <UNK> of <UNK> <UNK> <UNK> a <UNK> <UNK> and <UNK> are the <UNK> <UNK> in the enoshima <UNK> , a <UNK> of the <UNK> on enoshima <UNK> <UNK> the <UNK> <UNK> <UNK> <UNK> . </s>',
  '<s> <UNK> the <UNK> <UNK> , there <UNK> <UNK> <UNK> <UNK> of it . </s>',
  '<s> <UNK> <UNK> is the <UNK> for the <UNK> <UNK> , the <UNK> of the ferrari <UNK

## Calculate Probabilities

In [10]:
def get_ngram_tables(new_data):
    num_words = 0
    unigrams = {}
    bigrams = {}
    trigrams = {}
    unigram_counts = 0
    bigram_counts = {}
    trigram_counts = {}
    
    for line in new_data:
        words = line.split()
        num_words += len(words)
        
        unigrams[words[0]] = unigrams.get(words[0], 0) + 1
        unigrams[words[1]] = unigrams.get(words[1], 0) + 1
        
        unigram_counts += 1
        unigram_counts += 1
        
        for i in range(2, len(words)):
            w0 = words[i-2]
            w1 = words[i-1]
            w2 = words[i]
            
            if w0 not in bigrams.keys():
                bigrams[w0] = {}
            if w0 not in trigrams.keys():
                trigrams[w0] = {}
                trigram_counts[w0] = {}
            if w1 not in trigrams[w0].keys():
                trigrams[w0][w1] = {}
            
            unigrams[w2] = unigrams.get(w2, 0) + 1
            bigrams[w0][w1] = bigrams[w0].get(w1, 0) + 1
            trigrams[w0][w1][w2] = trigrams[w0][w1].get(w2, 0) + 1
            
            unigram_counts += 1
            bigram_counts[w0] = bigram_counts.get(w0, 0) + 1
            trigram_counts[w0][w1] = trigram_counts[w0].get(w1, 0) + 1
            
        if w1 not in bigrams.keys():
            bigrams[w1] = {}
        bigrams[w1][w2] = bigrams[w1].get(w2, 0) + 1
        bigram_counts[w1] = bigram_counts.get(w1, 0) + 1
        
    return unigrams, bigrams, trigrams, unigram_counts, bigram_counts, trigram_counts
            

In [11]:
def unigram_probs(unigrams, unigram_counts):
    probs = {}
    for word in unigrams.keys():
        probs[word] = unigrams[word]/unigram_counts
    return probs

In [12]:
def bigram_probs(bigrams, bigram_counts):
    probs = {}
    for word1 in bigrams.keys():
        word2_dict = bigrams[word1].keys()
        for word2 in word2_dict:
            probs[word1] = probs.get(word1, {})
            probs[word1][word2] = bigrams[word1][word2] / bigram_counts[word1]
    return probs

In [13]:
def trigram_probs(trigrams, trigram_counts):
    probs = {}
    for word1 in trigrams.keys():
        word2_dict = trigrams[word1].keys()
        for word2 in word2_dict:
            word3_dict = trigrams[word1][word2].keys()
            for word3 in word3_dict:
                probs[word1] = probs.get(word1, {})
                probs[word1][word2] = probs[word1].get(word2, {})
                probs[word1][word2][word3] = trigrams[word1][word2].get(word3, 0) / trigram_counts[word1][word2]
    return probs

In [14]:
unigrams, bigrams, trigrams, unigram_counts, bigram_counts, trigram_counts = get_ngram_tables(new_data)
uni_probs = unigram_probs(unigrams, unigram_counts)
bi_probs = bigram_probs(bigrams, bigram_counts)
tri_probs = trigram_probs(trigrams, trigram_counts)

In [15]:
# Example
# trigram_counts['der']['sloot']
# trigrams['der']['sloot']

## Calculate $P_{abs}$

For all $xyz$ in vocab, calculate $P_{abs}(z | xy)$ as

$$P_{abs} = \begin{cases}\dfrac{C(xyz) - D}{C(xy)} & C(xyz)>0\\\alpha(xy)P_{abs}(z|y) & \text{otherwise}\end{cases}$$

where


$$\alpha(xy) = \dfrac{reserved\_mass(xy)}{1 - \sum\limits_{W: C(xyW) > 0}p(W | y)}$$

and

$$reserved\_mass(xy) = \dfrac{\text{# of types starting with $xy$} * D}{C(xyW)}$$

In [16]:
# Save calculated alphas and probabilities for reuse
alphas = {}
calculated_probs = {}

In [17]:
def get_reserved_mass_uni(word1, bigrams, D):
    if (word1 in bigrams.keys()):
        total_count = 0
        word2_lst = bigrams[word1]
        for word2 in word2_lst:
            total_count += bigrams[word1][word2]
        return (len(word2_lst) * D) / total_count

In [18]:
def get_reserved_mass_bi(word1, word2, trigrams, D):
    if (word1 in trigrams.keys()) and (word2 in trigrams[word1].keys()):
        total_count = 0
        word3_lst = trigrams[word1][word2]
        for word3 in word3_lst:
            total_count += trigrams[word1][word2][word3]
        return (len(word3_lst) * D) / total_count
    else:
        return 0

In [19]:
def get_alpha_uni(word1, vocab, uni_probs, bigrams, D):
    res_mass = get_reserved_mass_uni(word1, bigrams, D)
    if res_mass == 0:
        return 0
    
    denom = 1
    for word in vocab:
        if (word1 in bigrams.keys()) and (word in bigrams[word1].keys()):
            denom -= uni_probs[word]
    if denom <= 0:  # python rounding causes probs to be slightly less than 0 sometimes
        denom = 1e-9
    return res_mass / denom

In [20]:
def get_alpha_bi(word1, word2, vocab, bi_probs, trigrams, D):
    res_mass = get_reserved_mass_bi(word1, word2, trigrams, D)
    if res_mass == 0:
        return 0
    
    denom = 1
    for word in vocab:
        if (word1 in trigrams.keys()) and (word2 in trigrams[word1].keys()) and (word in trigrams[word1][word2].keys()):
            denom -= bi_probs[word2][word]
    if denom <= 0:  # python rounding causes probs to be slightly less than 0 sometimes
        denom = 1e-9
    return res_mass / denom

In [21]:
def get_p_abs_bi(word1, word2, unigrams, bigrams, uni_probs, D):
    if (word1 in bigrams.keys()) and (word2 in bigrams[word1].keys()):
        total = 0
        for word in bigrams[word1].keys():
            total += bigrams[word1][word]
        return max((bigrams[word1][word2] - D) / total, 0)
    else:
        if (word1, word2) in alphas.keys():
            alpha = alphas[(word1, word2)]
        else:
            alpha = get_alpha_uni(word1, vocab, uni_probs, bigrams, D)
            alphas[(word1, word2)] = alpha
        return uni_probs[word2] if alpha==0 else alpha * uni_probs[word2]

In [22]:
def get_p_abs_tri(word1, word2, word3, unigrams, bigrams, trigrams, uni_probs, bi_probs, tri_probs, vocab, D):
    if (word1 in trigrams.keys()) and (word2 in trigrams[word1].keys()) and (word3 in trigrams[word1][word2].keys()):
        total = 0
        for word in trigrams[word1][word2]:
            total += trigrams[word1][word2][word]
        return max((trigrams[word1][word2][word3] - D) / total, 0)
    else:
        if (word1, word2, word3) in alphas.keys():
            alpha = alphas[(word1, word2, word3)]
        else:
            alpha = get_alpha_bi(word1, word2, vocab, bi_probs, trigrams, D)
            alphas[(word1, word2, word3)] = alpha
            
        if (word2, word3) in calculated_probs:
            p_abs_2 = calculated_probs[(word2, word3)]
        else:
            p_abs_2 = get_p_abs_bi(word2, word3, unigrams, bigrams, uni_probs, D)
            calculated_probs[(word2, word3)] = p_abs_2
        return p_abs_2 if alpha==0 else alpha * p_abs_2

Below is an example we hand-calculated the probability for. The output matches our calculation.

In [23]:
# Example
get_p_abs_tri('<s>', 'to', 'this', unigrams, bigrams, trigrams, uni_probs, bi_probs, tri_probs, vocab, D=0.5)

0.015625

## Perplexity

In [24]:
def perplexity(test_sentences, unigrams, bigrams, trigrams, uni_probs, bi_probs, tri_probs, vocab, D=0.5):
    total_log_prob = 0
    total_trigrams = 0

    with open(test_sentences) as f:
        for line in tqdm(f):
            words = ['<s>'] + line.split() + ['</s>']

            for i in range(2, len(words)):
                a = words[i - 2] if words[i - 2] in vocab else '<UNK>'
                b = words[i - 1] if words[i - 1] in vocab else '<UNK>'
                c = words[i] if words[i] in vocab else '<UNK>'

                trigram_prob = get_p_abs_tri(a,b,c, unigrams, bigrams, trigrams, uni_probs, bi_probs, tri_probs, vocab, D)

                if trigram_prob == 0:
                    trigram_prob = 1e-9
                
                try:
                    total_log_prob += log10(trigram_prob)
                except:
                    print(a,b,c)
                    print(trigram_prob)
                    print(log10(trigram_prob))

            total_trigrams += len(words)

    return 10 ** ((-1) * (total_log_prob / total_trigrams))

In [25]:
test_sentences = '../data/dev'

In [26]:
perplexity(test_sentences, unigrams, bigrams, trigrams, uni_probs, bi_probs, tri_probs, vocab, D=0.5)

10000it [16:23, 10.17it/s]


74.23728334863344