In [1]:
%load_ext autoreload
%autoreload 2

In [268]:
from nltk import ngrams
from utils import get_tokenized_sentences
from math import log, isclose

In [3]:
def log2(x):
    """Return the base-2 logarithm of ``x``.

    Uses ``math.log2`` rather than ``log(x, 2)``: the two-argument form is
    evaluated as a quotient of two natural logarithms and can be off by one
    ulp (exact powers of two may not come back as whole numbers), whereas the
    standard library documents ``math.log2`` as usually more accurate.

    Args:
        x: A positive real number.

    Returns:
        The base-2 logarithm of ``x`` as a float.

    Raises:
        ValueError: If ``x`` is zero or negative.
    """
    # Local import: the notebook's import cell only brings in `log` and `isclose`.
    import math

    return math.log2(x)

In [4]:
class WordCounter:
    """Accumulate sentence, token and n-gram counts over an iterable of sentences.

    Attributes:
        sentence_generator: The iterable of sentences passed to the constructor.
            Each sentence must itself be re-iterable, because it is traversed
            once for the token count and once per n-gram length.
        max_ngram_length: Largest n-gram length that is counted.
        sentence_count: Number of non-empty sentences seen by ``count()``.
        token_count: Total number of tokens seen by ``count()``.
        all_ngram_counts: Maps each n-gram length ``1..max_ngram_length`` to a
            dict ``{ngram: {'start': int, 'all': int}}`` where ``'all'`` is the
            total number of occurrences and ``'start'`` the number of times the
            n-gram was the first n-gram of a sentence.
    """

    def __init__(self, sentence_generator, max_ngram_length=5):
        """Set up empty counters.

        Args:
            sentence_generator: Iterable yielding sentences (sequences of tokens).
            max_ngram_length: Count n-grams of every length from 1 up to and
                including this value. Defaults to 5, the previously fixed range.

        Raises:
            ValueError: If ``max_ngram_length`` is smaller than 1.
        """
        if max_ngram_length < 1:
            raise ValueError('max_ngram_length must be at least 1')
        self.sentence_generator = sentence_generator
        self.max_ngram_length = max_ngram_length
        self.sentence_count = 0
        self.token_count = 0
        self.all_ngram_counts = {
            ngram_length: {} for ngram_length in range(1, max_ngram_length + 1)
        }

    def count(self):
        """Consume ``sentence_generator`` and add its statistics to the counters.

        Counts are accumulated, not reset: if the underlying iterable can be
        traversed again, calling this method a second time counts everything
        twice; if it is a one-shot generator, a second call adds nothing.
        """
        for sentence in self.sentence_generator:
            # Empty sentences are skipped for the sentence total only.
            if sentence:
                self.sentence_count += 1
            self.token_count += sum(1 for _ in sentence)
            # Iterate the tables created in __init__ so the set of lengths is
            # defined in exactly one place.
            for ngram_length, ngram_counts in self.all_ngram_counts.items():
                for i, sentence_ngram in enumerate(ngrams(sentence, ngram_length)):
                    ngram_count = ngram_counts.setdefault(sentence_ngram, {'start': 0, 'all': 0})
                    # Position 0 is the n-gram that opens the sentence.
                    if i == 0:
                        ngram_count['start'] += 1
                    ngram_count['all'] += 1

In [338]:
# Count sentences, tokens and 1- to 5-grams for the trainW corpus.
# get_tokenized_sentences comes from the local utils module; presumably it
# yields one tokenized sentence at a time -- confirm against utils.
trainW_generator = get_tokenized_sentences('data/trainW_token_end.txt')
trainW = WordCounter(trainW_generator)
trainW.count()

In [339]:
# Count sentences, tokens and 1- to 5-grams for the trainT corpus.
trainT_generator = get_tokenized_sentences('data/trainT_token_end.txt')
trainT = WordCounter(trainT_generator)
trainT.count()

In [340]:
# Count sentences, tokens and 1- to 5-grams for the test1 corpus.
test1_generator = get_tokenized_sentences('data/test1_token_end.txt')
test1 = WordCounter(test1_generator)
test1.count()

In [341]:
# Count sentences, tokens and 1- to 5-grams for the test2 corpus.
test2_generator = get_tokenized_sentences('data/test2_token_end.txt')
test2 = WordCounter(test2_generator)
test2.count()

In [342]:
# Count sentences, tokens and 1- to 5-grams for the debug training corpus
# (used below to sanity-check the unigram model).
train_debug_generator = get_tokenized_sentences('data/train_debug_token_end.txt')
train_debug = WordCounter(train_debug_generator)
train_debug.count()

In [343]:
# Count sentences, tokens and 1- to 5-grams for the debug test corpus
# (used below to sanity-check the unigram model).
test_debug_generator = get_tokenized_sentences('data/test_debug_token_end.txt')
test_debug = WordCounter(test_debug_generator)
test_debug.count()

# ngram model

### Unigram model

In [461]:
class UnigramModel():
    """Unigram language model with add-k smoothing and an ``<UNK>`` class.

    The model is built from an object exposing ``all_ngram_counts[1]`` (a dict
    ``{unigram: {'all': int, 'start': int}}``) and ``token_count``. Every
    unigram seen in training, plus ``('<UNK>',)``, gets probability
    ``(count + k) / (token_count + vocabulary_size * k)``.
    """

    # Key under which all unigrams unseen in training are pooled.
    UNK_UNIGRAM = ('<UNK>',)

    def __init__(self, train, k=1):
        """Estimate smoothed unigram probabilities from ``train``.

        Args:
            train: Counted training corpus (``all_ngram_counts`` and
                ``token_count`` attributes are read).
            k: Additive smoothing constant.

        Raises:
            AssertionError: If the estimated probabilities do not sum to 1.
        """
        self.k = k
        # Shallow copy, so adding the <UNK> entry leaves the corpus' dict untouched.
        self.train_unigram_counts = train.all_ngram_counts[1].copy()
        # Captured before <UNK> is added: only genuinely observed unigrams count as "seen".
        self.train_unigrams = set(self.train_unigram_counts.keys())
        # setdefault rather than assignment: if the corpus already contains a
        # literal '<UNK>' token its real count is kept, otherwise the
        # numerators would no longer add up to the denominator.
        self.train_unigram_counts.setdefault(self.UNK_UNIGRAM, {'all': 0, 'start': 0})

        self.train_prob_denom = train.token_count + len(self.train_unigram_counts) * self.k
        self.train_prob_noms = {}
        self.train_probs = {}

        for unigram, unigram_count in self.train_unigram_counts.items():
            prob_nom = unigram_count['all'] + self.k
            self.train_prob_noms[unigram] = prob_nom
            self.train_probs[unigram] = prob_nom / self.train_prob_denom
        assert isclose(sum(self.train_probs.values()), 1, rel_tol=1e-5)

    def calculate_avg_ll(self, test):
        """Return the average per-token log2-likelihood of ``test``.

        Unigrams that were not observed in training are scored with, and
        pooled under, the ``<UNK>`` probability. Intermediate results are kept
        on the instance (``test_ll``, ``test_unigram_counts``,
        ``test_modified_unigram_counts``, ``test_unigram_lls``,
        ``avg_test_ll``) and are overwritten by each call.

        Args:
            test: Counted test corpus (``all_ngram_counts`` and ``token_count``
                attributes are read).

        Returns:
            Total log2-likelihood of the test unigrams divided by
            ``test.token_count``.

        Raises:
            ZeroDivisionError: If ``test.token_count`` is 0.
            ValueError: If a scored unigram has probability 0 (possible with
                ``k=0`` and a unigram unseen in training).
            AssertionError: If the internal consistency checks fail.
        """
        self.test_ll = 0
        self.test_unigram_counts = test.all_ngram_counts[1].copy()
        self.test_modified_unigram_counts = {}
        self.test_unigram_lls = {}

        for unigram, unigram_count in self.test_unigram_counts.items():
            if unigram not in self.train_unigrams:
                unigram = self.UNK_UNIGRAM
            occurrences = unigram_count['all']
            unigram_ll = occurrences * log2(self.train_probs[unigram])
            self.test_ll += unigram_ll
            # Several unseen unigrams can map to <UNK>, hence accumulate instead of assign.
            self.test_modified_unigram_counts[unigram] = (
                self.test_modified_unigram_counts.get(unigram, 0) + occurrences
            )
            self.test_unigram_lls[unigram] = self.test_unigram_lls.get(unigram, 0) + unigram_ll

        # The two totals differ only by floating-point summation order. A
        # rel_tol of 1 would accept any pair of same-sign numbers, so a tight
        # tolerance (same as in __init__) is used to make the check meaningful.
        assert isclose(self.test_ll, sum(self.test_unigram_lls.values()), rel_tol=1e-5)
        assert sum(self.test_modified_unigram_counts.values()) == test.token_count

        self.avg_test_ll = self.test_ll / test.token_count
        return self.avg_test_ll

In [478]:
# Sanity check on the small debug corpora: add-1 smoothed unigram model,
# average log2-likelihood per token of test_debug (shown as the cell output).
unigram_model_debug = UnigramModel(train_debug, k=1)
unigram_model_debug.calculate_avg_ll(test_debug)

-3.051738873949783

In [484]:
# Add-1 smoothed unigram model trained on trainT, scored on test1
# (average log2-likelihood per token, shown as the cell output).
unigram_model_trainT = UnigramModel(trainT, k=1)
unigram_model_trainT.calculate_avg_ll(test1)

-9.561905664153294

In [468]:
# Add-1 smoothed unigram model trained on trainW.
unigram_model_trainW = UnigramModel(trainW, k=1)

# Average log2-likelihood per token on each corpus, printed in the order
# trainW, trainT, test1, test2.
for text in [trainW, trainT, test1, test2]:
    print(unigram_model_trainW.calculate_avg_ll(text))

-18.497573151470494
-18.555918913789192
-18.54761854416354
-18.54868310979827


In [488]:
# Add-1 smoothed unigram model trained on trainT.
# NOTE(review): rebinds unigram_model_trainT, which an earlier cell also
# builds with the same arguments.
unigram_model_trainT = UnigramModel(trainT, k=1)

# Average log2-likelihood per token on each corpus, printed in the order
# trainW, trainT, test1, test2.
for text in [trainW, trainT, test1, test2]:
    print(unigram_model_trainT.calculate_avg_ll(text))

-18.474849865464517
-9.321327346878743
-9.561905664153294
-10.221738332232865


### Bigram model