# Imports

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import dill as pickle
import os
from random import shuffle
from random import randint
from random import uniform
from math import floor
import shutil
import codecs
import math
from collections import Counter
from collections import defaultdict

import numpy as np

import nltk
from nltk import word_tokenize
from nltk.util import ngrams
from nltk.corpus import words

from tqdm import tqdm

In [2]:
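# One-time setup: uncomment to download the NLTK resources used by this notebook
# ('punkt' backs nltk.word_tokenize; 'words' backs nltk.corpus.words).
#nltk.download('punkt')
#nltk.download('words')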

# Corpus selection (Brown / Gutenberg train and test paths)

In [4]:
# Corpus configuration: exactly one TRAIN_PATH / TEST_PATH pair should be active.
# To run a different experiment, comment out the active pair at the bottom and
# uncomment one of the alternatives. Paths are relative to the notebook directory.

# Option 1: Brown train, Brown test
#TRAIN_PATH = '../data/brown/train/train_and_validate.txt'
#TEST_PATH = '../data/brown/test/test.txt'

# Option 2: Gutenberg train, Gutenberg test
#TRAIN_PATH = '../data/gutenberg/train/train_and_validate.txt'
#TEST_PATH = '../data/gutenberg/test/test.txt'

# Option 3: combined ("both") train, Brown test
#TRAIN_PATH = '../data/both_train/both_train.txt'
#TEST_PATH = '../data/brown/test/test.txt'

# Option 4 (active): combined ("both") train, Gutenberg test
TRAIN_PATH = '../data/both_train/both_train.txt'
TEST_PATH = '../data/gutenberg/test/test.txt'



# Modified KN

In [5]:
class Modified_kn:
    
    def __init__(self,order):
        self.order = order
        self.ngram_counters = {}
        #self.unique_continuations_ge_1 = defaultdict(int) 
        #self.unique_continuations_e_1 = defaultdict(int) 
        #self.unique_continuations_e_2 = defaultdict(int) 
        #self.unique_continuations_ge_3 = defaultdict(int) 
        #self.unique_contexts_ge_1 = defaultdict(int) 
        
        #self.unique_continuations = {'ge1':{}, 'e1':{}, 'e2':{}, 'ge3':{}}
        #self.unique_contexts = {'ge1':{}}
        
        self.unique_continuations = {}
        self.unique_contexts = {}
        self.D = {}
        for i in range(1,order):
            self.unique_continuations[i] = {'ge1':{}, 'e1':{}, 'e2':{}, 'ge3':{}}
            self.unique_contexts[i] = {'ge1':{}}
        for i in range(order):    
            self.D[i+1] = {1:0.0,2:0.0,3:0.0}
        self.n1plus_dot_ngram_dot = {}
    
    def set_training_data(self,train_file):
        print("Setting training data")
        self.train_file = train_file
        f = codecs.open(self.train_file, encoding='utf-8')
        tokens = nltk.word_tokenize(f.read())
        f.close()
        tokens_original = tokens
        unigram_counter = Counter(ngrams(tokens,1))
        num_unk = 0
        for i in range(len(tokens_original)):
            if(  unigram_counter[tuple([tokens_original[i]])] < 2):  # <1 is bad
                tokens[i] = '<unk>'
                num_unk = num_unk + 1
        print("Num unk={}".format(num_unk))
        self.tokens = ["<pad>"]*max((self.order)-1,0) + tokens
        self.vocabulary = set(self.tokens)
        self.vocabulary_list = list(self.vocabulary)
        print("Tokens + padding ={}".format(len(self.tokens)))
    
    def compute_ngram_counts(self):
        print("Computing ngram counts")
        for i in range(self.order):
            self.ngram_counters[i+1] = Counter(ngrams(self.tokens,i+1))
    
    def compute_D(self):
        print("Computing discounts")
        n = {}
        for i in range(self.order):
            n[i+1] = {1:0.0,2:0.0,3:0.0,4:0.0}
        for i in range(self.order):
            for j in range(4):
                n[i+1][j+1] = Counter(self.ngram_counters[i+1].values())[j+1]
                
        print("n :{}".format(n))
        for i in range(1,self.order+1):
            y = (n[i][1] / (n[i][1] + 2.0*n[i][2]))# if n[i][2]>0 else 0
            #self.D[i][1] = 1 - ((2.0*y*(n[i][2]/n[i][1])))# if n[i][1] > 0 else 0)
            self.D[i][1] = 1 - (2.0*(n[i][2]/(n[i][1]+ (2.0*n[i][2]))))
            self.D[i][2] = 2 - ((3.0*y*(n[i][3]/n[i][2])))# if n[i][2] > 0 else 0)
            self.D[i][3] = 3 - ((4.0*y*(n[i][4]/n[i][3])))# if n[i][3] > 0 else 0)
        
        
    def compute_single_word_continuations_and_contexts(self):
        print("Computing context and continuation counts")
        for n in range(1,self.order):
            ngram_and_its_continuation_words = defaultdict(list)
            ngram_and_its_context_words = defaultdict(list)
            
            for ngram in self.ngram_counters[n+1].keys():
                ngram_and_its_continuation_words[ngram[:len(ngram)-1]].append(ngram[-1:])
                ngram_and_its_context_words[ngram[1:]].append(ngram[0:1])
            
            for ngram in ngram_and_its_continuation_words.keys():
                continuation_counts_of_each_word = Counter(ngram_and_its_continuation_words[ngram])
                number_of_each_continuation_count = Counter(continuation_counts_of_each_word.values())
                continuations_ge1 = len(ngram_and_its_continuation_words[ngram])
                continuations_e1 = number_of_each_continuation_count[1]
                continuations_e2 = number_of_each_continuation_count[2]
                continuations_ge3 = len([i for i in number_of_each_continuation_count.values() if i>=3])
                
                self.unique_continuations[n]['ge1'][ngram] = continuations_ge1
                self.unique_continuations[n]['e1'][ngram] = continuations_e1
                self.unique_continuations[n]['e2'][ngram] = continuations_e2
                self.unique_continuations[n]['ge3'][ngram] = continuations_ge3
                
            for ngram in ngram_and_its_context_words.keys():
                context_ge1 = len(set(ngram_and_its_context_words[ngram]))
                self.unique_contexts[n]['ge1'][ngram] = context_ge1
        for i in range(1,self.order):
            token = tuple(['<pad>'] * i)
            self.unique_contexts[i]['ge1'][(token)] = 1
    
    def compute_n1plus_dot_ngram_dot(self):
        print("Computing 'that' count ... i dont know what to call it")
        used_continuations = defaultdict(set)
        dot_ngram_dot_counts = defaultdict(int)
        for n in range(2,self.order+1):
            for ngram in self.ngram_counters[n]:
                mid = ngram[1:-1]
                last = ngram[-1]
                if(last not in used_continuations[mid]):
                    to_count = ngram[1:]
                    count = self.unique_contexts[len(to_count)]['ge1'][to_count]
                    dot_ngram_dot_counts[mid] = dot_ngram_dot_counts[mid] + count
                    used_continuations[mid].add(last)
        self.n1plus_dot_ngram_dot = dot_ngram_dot_counts
                
    def get_D(self,c,ngram):
        if(c == 0):
            return 0
        if(c == 1):
            return self.D[len(ngram)][1]
        if(c == 2):
            return self.D[len(ngram)][2]
        if(c >= 3):
            return self.D[len(ngram)][3]
    
    def train(self):
        print("Training now")
        self.compute_ngram_counts()
        self.compute_D()
        self.compute_single_word_continuations_and_contexts()
        self.compute_n1plus_dot_ngram_dot()
        print("done")
    
    def __preproc_test_input(self,tokens):
        '''converts words not in vocab to <unk>'''
        print("preprocessing test data")
        unk_count = 0
        for i in range(len(tokens)):
            if(tokens[i] not in  self.vocabulary):
                tokens[i] = '<unk>'
                unk_count = unk_count + 1
        tokens = ['<pad>']*max((self.order)-1,0) + tokens
        print("data length={}, number of <unk>={}".format(len(tokens),unk_count))
        return tokens
    
    def gamma(self, ngram):
        t1 = self.get_D(1,ngram) * self.unique_continuations[len(ngram)]['e1'].get(ngram,0)
        t2 = self.get_D(2,ngram) * self.unique_continuations[len(ngram)]['e2'].get(ngram,0)
        t3 = self.get_D(3,ngram) * self.unique_continuations[len(ngram)]['ge3'].get(ngram,0)
        #print("gamma:: {},{},{}".format(t1,t2,t3))
        t4 = 0
        if((len(ngram) == self.order) or (len(ngram) == self.order-1)):
            t4 = self.unique_contexts[len(ngram)]['ge1'].get(ngram,0)
            #if t4 == 0:
            #    print("gamma HIT,t4={},ngram={}".format(t4,ngram)) 
        else:
            t4 = self.n1plus_dot_ngram_dot[ngram]
            #if t4 == 0:
            #    print("gamma hit,t4={},ngram={}".format(t4,ngram)) 
            
        val = 1
        if(t1+t2+t3 > 0 and t4!=0):
            val = (t1+t2+t3)/float(t4)
        return val
    
    def get_prob(self,ngram):
        #print("ngram={}".format(ngram))
        if(len(ngram) == 1):
            val = self.unique_contexts[1]['ge1'][ngram] / self.n1plus_dot_ngram_dot[()]
            #print("get_prob:unigram prob for {} is {}".format(ngram,val))
        elif((len(ngram) == self.order) or (ngram[:self.order-1] == (['<pad>']*(self.order-1) ))):
            t1 = self.ngram_counters[len(ngram)].get(ngram,0)
            t2 = self.get_D(self.ngram_counters[len(ngram)].get(ngram,0) , ngram)
            t3 = max((t1-t2),0)
            #if(t1-t2<=0):
            #    print("t1-t2<=0, ngram={} , t1={}, t2={}".format(ngram,t1,t2))
            t4 = self.ngram_counters[len(ngram[:-1])][ngram[:-1]]
            t5 = (t3/float(t4)) if t4!=0 else 0
            t6 = self.gamma(ngram[:-1])
            #if(t6 <= 0):
            #    print("gamma for {} is <=0".format(ngram[:-1]))
            t7 = self.get_prob(ngram[1:])
            #if(t6 == 0):
            #    print("get_prob HIT")
            #    return t7
            val = t5 + t6*t7
        else:
            t1 = self.unique_contexts[len(ngram)]['ge1'].get(ngram,0)
            t2 = self.get_D(self.ngram_counters[len(ngram)].get(ngram,0) , ngram)
            t3 = max((t1-t2),0)
            #if(t1-t2<=0):
            #    print("t1-t2<=0, ngram={} , t1={}, t2={}".format(ngram,t1,t2))
            t4 = self.n1plus_dot_ngram_dot[ngram[1:-1]]
            #if(t4 == 0):
            #    print(ngram[1:-1])
            t5 = (float(t3)/t4) if t4!=0 else 0
            t6 = self.gamma(ngram[:-1])
            #if(t6 <= 0):
            #    print("gamma for {} is <=0".format(ngram[:-1]))
            t7 = self.get_prob(ngram[1:])
            #if(t6 == 0):
            #    print("get_prob hit")
            #    return t7
            val = t5 + t6*t7
        #if(val ==0 ):
        #    print("ngram={}".format(ngram))
        return val

    
    def get_perplexity(self,test_file):
        f = codecs.open(test_file, encoding='utf-8')
        tokens = nltk.word_tokenize(f.read())
        tokens = self.__preproc_test_input(tokens)
        f.close()
        n_grams = ngrams(tokens,self.order)
        sum_log_prob = 0.0
        print("Calculating perplexity")
        iter_count = 0
        for ngram in n_grams: 
            value = self.get_prob(ngram)
            if(value == 0):
                print("get_perplexity::ngram={}".format(ngram))
                return math.nan
            sum_log_prob = sum_log_prob + math.log(value)
            iter_count += 1
        print("num iterations={}".format(iter_count))
        return math.exp( -(1.0/len(tokens)) * sum_log_prob )
    
    def get_random_word(self):
        index = randint(0,len(self.vocabulary_list)-1)
        return self.vocabulary_list[index]
    
    def generate_text(self,num_tokens=20):
        #sentence = ['<pad>'] * (self.order-1) 
        sentence = ['The']
        while(len(sentence) < num_tokens+1):
            try:
                random_word = self.get_random_word()
                if(random_word == '<unk>'):
                    continue
                ngram = tuple(sentence[-self.order+1:] + [random_word])
                prob_ngram = self.get_prob(ngram)
                #print("bigram={}, prob={}".format(bigram,prob_bigram))
                random_number = uniform(0,1)
                if(random_number < prob_ngram):
                    sentence = sentence + [random_word]
                    print(sentence)
            except:
            #    continue
                raise 
        return ' '.join(sentence)
    
    

In [6]:
# Build an untrained model of order 4 (4-grams are the longest n-grams used).
lm = Modified_kn(4)

In [7]:
# Tokenize the training file, map words seen fewer than twice to '<unk>', and pad the stream.
lm.set_training_data(TRAIN_PATH)

Setting training data
Num unk=32770
Tokens + padding =2731302


In [8]:
# Compute n-gram counts, discounts, and the context/continuation statistics.
lm.train()

Training now
Computing ngram counts
Computing discounts
n :{1: {1: 0, 2: 9427, 3: 5236, 4: 3382}, 2: {1: 577745, 2: 99142, 3: 39509, 4: 21465}, 3: {1: 1655201, 2: 134417, 3: 42362, 4: 19663}, 4: {1: 2320380, 2: 86357, 3: 20157, 4: 8205}}
Computing context and continuation counts
Computing 'that' count ... i dont know what to call it
done


In [9]:
# Spot check: a bigram is shorter than the model order (4), so this exercises the lower-order path of get_prob.
lm.get_prob(('who','is'))

0.00026433308359387754

In [10]:
# Perplexity of the trained model on the held-out file at TEST_PATH.
lm.get_perplexity(TEST_PATH)

preprocessing test data
data length=276979, number of <unk>=7984
Calculating perplexity
num iterations=276976


13905.56346081903

In [20]:
# Sample 10 tokens after the fixed first word 'The'; the result is stochastic (no RNG seed is set in this notebook).
lm.generate_text(10)

['The', 'nine']
['The', 'nine', 'isnt']
['The', 'nine', 'isnt', 'good']
['The', 'nine', 'isnt', 'good', 'looks']
['The', 'nine', 'isnt', 'good', 'looks', 'concentration']
['The', 'nine', 'isnt', 'good', 'looks', 'concentration', 'begat']
['The', 'nine', 'isnt', 'good', 'looks', 'concentration', 'begat', 'sweet']
['The', 'nine', 'isnt', 'good', 'looks', 'concentration', 'begat', 'sweet', 'sovran']
['The', 'nine', 'isnt', 'good', 'looks', 'concentration', 'begat', 'sweet', 'sovran', 'curls']
['The', 'nine', 'isnt', 'good', 'looks', 'concentration', 'begat', 'sweet', 'sovran', 'curls', 'continually']


'The nine isnt good looks concentration begat sweet sovran curls continually'