# Imports

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import dill as pickle
import os
from random import shuffle
from random import randint
from random import uniform
from math import floor
import shutil
import codecs
import math
from collections import Counter
from collections import defaultdict

import numpy as np

import nltk
from nltk import word_tokenize
from nltk.util import ngrams
#from nltk.corpus import words

from tqdm import tqdm

In [2]:
#nltk.download('punkt')
#nltk.download('words')

In [3]:
print(words.readme())

Wordlists

en: English, http://en.wikipedia.org/wiki/Words_(Unix)
en-basic: 850 English words: C.K. Ogden in The ABC of Basic English (1932)



# Number of words in each dataset

In [4]:
PATH_brown ='../data/brown/brown_all/'
PATH_gutenberg ='../data/gutenberg/gutenberg_all/'

## Number of words in Brown

In [5]:
!find {PATH_brown} -name '*.*' | xargs cat | wc -w

1005088


## Number of words in Gutenberg

In [6]:
!find {PATH_gutenberg} -name '*.*' | xargs cat | wc -w

2102546


# Split data into test and train

In [7]:
class DatasetSplitter:

    def __init__(self):
        pass
    
    def __get_file_list_from_dir(self,datadir):
        all_files = os.listdir(os.path.abspath(datadir))
        data_files = list(filter(lambda file: file.endswith('.txt'), all_files))
        return data_files
    
    def __copy_files(self,src_path,target_path,file_list):
        for file_name in file_list:
            full_file_name = os.path.join(src_path, file_name)
            if (os.path.isfile(full_file_name)):
                shutil.copy(full_file_name, target_path)
        
    def __make_dir(self,directory):
        try:
            os.makedirs(directory)
            return True
        except OSError as e:
            if e.errno != errno.EEXIST:
                return False
        
    def get_train_and_test_sets(self,datadir,percentage_train=0.9):
        file_list = self.__get_file_list_from_dir(datadir)
        shuffle(file_list)
        split = percentage_train
        split_index = floor(len(file_list) * split)
        training = file_list[:split_index]
        testing = file_list[split_index:]
        return training, testing 
    
    def get_train_validate_test_sets(self,datadir,percentage_train=0.9,percentage_validate=0.1):
        train,test = self.get_train_and_test_sets(datadir,percentage_train=percentage_train)
        split = 1-percentage_validate
        split_index = floor(len(train) * split)
        validate = train[split_index:]
        train = train[:split_index]
        return train,validate,test

    def copy_to_dirs(self,src_path,target_path,train,validate,test):
        if self.__make_dir(target_path+"train"):
            self.__copy_files(src_path,target_path+"train/",train)
        if self.__make_dir(target_path+"validate"):
            self.__copy_files(src_path,target_path+"validate/",validate)
        if self.__make_dir(target_path+"test"):
            self.__copy_files(src_path,target_path+"test/",test)
            

#data_splitter = DatasetSplitter()
#PATH = PATH_brown
#test,validate,train = data_splitter.get_train_validate_test_sets(PATH)
#data_splitter.copy_to_dirs(PATH,PATH+"../",test,validate,train)

# Train on Brown corpus

In [8]:
TRAIN_PATH = '../data/brown/train/train.txt'
TEST_PATH = '../data/brown/validate/validate.txt'

#TRAIN_PATH = '../data/brown/sample/good_sent.txt'
#TEST_PATH = '../data/brown/sample/good_sent.txt'

f = codecs.open(TRAIN_PATH, encoding='utf-8')
tokens = nltk.word_tokenize(f.read())
tokens_original = tokens

type(tokens)

for token in tokens:
    # print(token)
    pass

bigrams = nltk.ngrams(['a','b'],2)
type(bigrams)

# replace "few" rare words with < unk >  (no spaces inside the unk)

unigram_counter = Counter(ngrams(tokens,1))
tuple(["it",'is'])
#unigram_counter[('it',)]

convert_to_unk = []
for item in unigram_counter.keys():
    if unigram_counter[item] < 2:
        convert_to_unk.append(item[0])
num_unks_needed = max(floor(0.01 * unigram_counter.most_common(1)[0][1]),2)
convert_to_unk = convert_to_unk[:num_unks_needed]
for i in range(len(tokens_original)):
    if(tokens_original[i] in convert_to_unk):
        tokens[i] = '<unk>'

unigram_counter = Counter(ngrams(tokens,1))
print(unigram_counter[('<unk>',)])


In [37]:
class KNLM:
    def __init__(self,order=2):
        self.order = 2
    
    def set_training_data(self,train_file):
        '''
        reads file, tokenizes, replaces "some" rare words with <unk>. stores these in the instance 
        variable tokens
        '''
        self.train_file = train_file
        f = codecs.open(self.train_file, encoding='utf-8')
        tokens = nltk.word_tokenize(f.read())
        f.close()
        tokens_original = tokens
        unigram_counter = Counter(ngrams(tokens,1))
#        convert_to_unk = []
#        for item in unigram_counter.keys():
#            if unigram_counter[item] < 2:
#                convert_to_unk.append(item[0])               
#        num_unks_needed = max(floor(1 * unigram_counter.most_common(1)[0][1]),2)
#        convert_to_unk = convert_to_unk[:num_unks_needed]
#        for i in range(len(tokens_original)):
#            if(tokens_original[i] in convert_to_unk):
#                tokens[i] = '<unk>'
#        print("Num unk={}".format(num_unks_needed))
        
        num_unk = 0
        for i in range(len(tokens_original)):
            if(  unigram_counter[tuple([tokens_original[i]])] < 3):
                tokens[i] = '<unk>'
                num_unk = num_unk + 1
        print("Num unk={}".format(num_unk))
        
        
        self.tokens = ["<pad>"] + tokens
#        self.text = nltk.Text(self.tokens)
#        self.token_searcher = nltk.TokenSearcher(self.tokens)
        self.unigram_counter = Counter(ngrams(self.tokens,1))
        self.bigram_counter = Counter(ngrams(self.tokens,2))
#        self.vocabulary = self.text.vocab()
        self.vocabulary = set(self.tokens)
        self.vocabulary_list = list(self.vocabulary)
        #print(unigram_counter[('<unk>',)])
        
    def train(self):
        print("Training : Getting counts from training data")
        self.unique_continuations = defaultdict(set)
        self.unique_contexts = defaultdict(set)
        for bigram in tqdm(self.bigram_counter.keys(),total=len(self.bigram_counter.keys())):
            self.unique_continuations[bigram[0]].add(bigram[1])
            self.unique_contexts[bigram[1]].add(bigram[0])
    
    def get_log_prob_kn(self,ngram):
        abs_discount = 0.85
        count_input_ngram = max(self.bigram_counter[ngram] - abs_discount , 0)
        count_context = self.unigram_counter[tuple([ngram[0]])]
        num_unique_continuations = len(self.unique_continuations[ngram[0]])
        num_unique_contexts = len(self.unique_continuations[ngram[1]])
#        for bigram in self.bigram_counter.keys():
#            if(bigram[0] == ngram[0]):
#                unique_continuations.add(bigram[-1])    
#            if(bigram[-1] == ngram[-1]):
#                unique_contexts.add(bigram[0])
#        print("ngram:{}, count;{}".format(tuple([ngram[0]]),count_context))
        interpolation_weight = (abs_discount / count_context) * num_unique_continuations
        continuation_probability = num_unique_contexts / len(self.bigram_counter)
        p_kn = (count_input_ngram / count_context) + (interpolation_weight * continuation_probability)
        #return math.log(p_kn)
        return p_kn
    
    def __preproc_test_input(self,tokens):
        '''converts words not in vocab to <unk>'''
        print("preprocessing training data")
        unk_count = 0
        for i in range(len(tokens)):
            if(tokens[i] not in  self.vocabulary):
                tokens[i] = '<unk>'
                unk_count = unk_count + 1
        tokens = ['<pad>'] + tokens
        print("data length={}, number of <unk>={}".format(len(tokens),unk_count))
        return tokens
    
    def get_perplexity(self,test_file):
        f = codecs.open(test_file, encoding='utf-8')
        tokens = nltk.word_tokenize(f.read())
        tokens = self.__preproc_test_input(tokens)
        f.close()
        bigrams = nltk.ngrams(tokens,2)
        sum_log_prob = 0
        print("Calculating perplexity")
        iter_count = 0
        for bigram in tqdm(bigrams,total=len(tokens)-1): 
            prob = math.log(self.get_log_prob_kn(bigram))
            sum_log_prob = sum_log_prob + prob
            iter_count += 1
        print("Num iterations = {}".format(iter_count))
        return math.exp( -(1.0/len(tokens)) * sum_log_prob )
    
    def get_random_word(self):
        index = randint(0,len(self.vocabulary_list)-1)
        return kn_lm.vocabulary_list[index]

    
    def generate_text(self,num_tokens=10):
        sentence = ['<pad>']
        for i in range(num_tokens+1):
            random_word = self.get_random_word()
            bigram = tuple(sentence[-1:] + [random_word])
            #prob_bigram = math.exp(self.get_log_prob_kn(bigram))
            prob_bigram = self.get_log_prob_kn(bigram)
            #print("bigram={}, prob={}".format(bigram,prob_bigram))
            random_number = uniform(0,1)
            if(random_number < prob_bigram):
                sentence = sentence + [random_word]
        return ' '.join(sentence[1:])
        
    

In [38]:
%%time
kn_lm = KNLM()

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 10.3 µs


In [39]:
%%time
kn_lm.set_training_data(TRAIN_PATH)

Num unk=29786
CPU times: user 6.19 s, sys: 32 ms, total: 6.22 s
Wall time: 6.22 s


In [40]:
%%time
kn_lm.train()

 21%|██▏       | 70413/331239 [00:00<00:00, 703971.22it/s]

Training : Getting counts from training data


100%|██████████| 331239/331239 [00:00<00:00, 461111.38it/s]

CPU times: user 708 ms, sys: 16 ms, total: 724 ms
Wall time: 720 ms





In [41]:
%%time
kn_lm.get_perplexity(TEST_PATH)

 29%|██▉       | 26649/90646 [00:00<00:00, 266443.88it/s]

preprocessing training data
data length=90647, number of <unk>=5977
Calculating perplexity


100%|██████████| 90646/90646 [00:00<00:00, 267576.20it/s]

Num iterations = 90646
CPU times: user 780 ms, sys: 4 ms, total: 784 ms
Wall time: 780 ms





416.2034752846155

In [42]:
kn_lm.generate_text()

''

In [43]:
sentence = ['<pad>']
while(len(sentence) < 20):
    random_word = kn_lm.get_random_word()
    if(random_word == '<unk>'):
        continue
    bigram = tuple(sentence[-1:] + [random_word])
    prob_bigram = kn_lm.get_log_prob_kn(bigram)
    #print("bigram={}, prob={}".format(bigram,prob_bigram))
    random_number = uniform(0,1)
    if(random_number < prob_bigram):
        sentence = sentence + [random_word]
' '.join(sentence[1:])

'he could be no you and destroy the membership established meticulously blue place redoute lot of the only the'

In [44]:
a = ['a','b','c']
b = ['d']
a[1:]


['b', 'c']

In [35]:
kn_lm.get_log_prob_kn(('who','is'))

0.042908161111036014

In [36]:
kn_lm.get_log_prob_kn(('what','is'))

0.09084592740529628