In [None]:
import math
import random
from collections import Counter
from nltk.corpus import europarl_raw
from nltk import ngrams
from sklearn.cross_validation import train_test_split
from nltk.corpus import words

alpha = 0.01 #alpha value is chosen using the validation set
corpus = europarl_raw.english
delta = 0.75
P_cont_denom_bigrams = 0

#split(sentences) into train-->60% validation-->20% and test-->20%
train_init, test = train_test_split(corpus.sents(), train_size = 0.8,random_state=4542)
train, validation = train_test_split(train_init, train_size = 0.75,random_state=4572)

def getCount(countdict,key):   
#Helper function that returns # of uni/bi/tri -grams(key) occurencies in corpus.
#If there are not such occurencies in corpus returns 0
    try:
        return countdict[key]
    except KeyError:
        return 0  
    
cnt = Counter()
for word in [item for sublist in train for item in sublist]:
    cnt[word] += 1

#Remove rare words 
unigrams = { k:v for k, v in cnt.items() if v >= 10}
unigrams['<s>'] = 1
unigrams_size = sum(v for v in unigrams.values())
V = len(unigrams)


bigrams =  (ngram for sent in train for ngram in ngrams(sent, 2,
    pad_left=True, pad_right=True, pad_symbol='<s>'))

cnt2 = Counter()
for bigram in bigrams:
    cnt2[bigram] += 1


    
#remove bigrams containing rare words
bigrams_final = { k:v for k, v in cnt2.items() if ((k[0] in unigrams.keys()) and(k[1] in unigrams.keys())) }
bigrams_size = sum(v for v in bigrams_final.values())

#Auxiliary data structure for bigrams that  used in order to improve Knesser-Ney performance
bigrams_KN ={}
for k, v in bigrams_final.items():
    bigrams_KN[k[0]] = {}
for k, v in bigrams_final.items():
    bigrams_KN[k[0]][k[1]] = v

#Auxiliary data structure for bigrams that used in order to improve Knesser-Ney performance  
bigrams_KN_reverse ={}
for k, v in bigrams_final.items():
    bigrams_KN_reverse[k[1]] = {}
for k, v in bigrams_final.items():
    bigrams_KN_reverse[k[1]][k[0]] = v

#compute P(continuation) denominator for bigrams
wordsNum = []
for k,v in bigrams_final.items():
    wordsNum.append(k[0])
P_cont_denom_bigrams = len(set(wordsNum))
trigrams = (ngram for sent in train for ngram in ngrams(sent, 3,
            pad_left=True, pad_right=True, pad_symbol='<s>'))


cnt3 = Counter()
for trigram in trigrams:
    cnt3[trigram] += 1

#remove trigrams containing rare words
trigrams_final = { k:v for k, v in cnt3.items() if ((k[0] in unigrams.keys())
    and(k[1] in unigrams.keys()) and(k[2] in unigrams.keys())) }
trigrams_size = sum(v for v in trigrams_final.values())




def unigram_logprob_ls(unigram):
# laplace smoothing
    global alpha
    return math.log2(((getCount(unigrams,unigram))+alpha) / (unigrams_size+(alpha*V)) )

def bigram_logprob_ls(bigram):
# laplace smoothing
    global alpha
    return math.log2(((getCount(bigrams_final,bigram))+alpha ) / ((getCount(unigrams,(bigram[0])))+(alpha*V))) 
    
def bigram_logprob_Mod_KN(bigram):
# Modified Knesser ney 
    global delta
    try:
        return math.log2(((max((getCount(bigrams_final,bigram) -delta),0)) / getCount(unigrams,(bigram[0]))) 
                     +((bigram_Mod_KN_lamda(bigram[0])) * bigram_Mod_KN_Pcont(bigram[1])))
    
    except (ZeroDivisionError,ValueError):
        # return 0 if both bigram does not exist and P(continuation) is  zero
        return 0
    

    
def bigram_Mod_KN_lamda(word):
#Compute Modified Knesser ney  interpolation term(lamda)
    global delta
    # Compute The number of word types that can follow word
    try:
        times = len(bigrams_KN[word])
    except KeyError:
        return 0
    lamda = (delta/getCount(unigrams,word)) * times
    return lamda

def bigram_Mod_KN_Pcont(word):
#Compute P continuation
    global P_cont_denom_bigrams
    #Count distinct vocabulary words seen to proceede word
    try:
        count = len(bigrams_KN_reverse[word])
    except KeyError:
        return 0
    return count/P_cont_denom_bigrams
    

def trigram_logprob_ls(trigram):
# laplace smoothing
    global alpha
    return math.log2(((getCount(trigrams_final,trigram)) +alpha ) /
                     ((getCount(bigrams_final,(trigram[0],trigram[1])))+(alpha*V)))   

def logprob_sentence_bigram(sentence):
    sumprob = 0
    for i in range(len(sentence)-1):
        sumprob += bigram_logprob_ls((sentence[i],sentence[i+1]))
    return sumprob

def logprob_sentence_bigram_Mod_KN(sentence):
    sumprob = 0
    for i in range(len(sentence)-1):
        sumprob += bigram_logprob_Mod_KN((sentence[i],sentence[i+1]))
    return sumprob       
    
def logprob_sentence_trigram(sentence):
    sumprob = 0
    for i in range(len(sentence)-2):
        sumprob += trigram_logprob_ls((sentence[i],sentence[i+1],sentence[i+2]))
    return sumprob


    
def bigram_model_perplexity(corpus):
    #Compute cross entropy and perplexity of our bigram model
    sumprob = 0
    bigram_count = 0
    for sentence in corpus:
        sentence = ['<s>'] + sentence + ['<s>'] 
        bigram_count += (len(sentence) -1)
        sumprob += logprob_sentence_bigram(sentence)
    cross_entropy = -sumprob/bigram_count
    perpl = math.pow(2,cross_entropy)
    return cross_entropy,perpl


def bigram_model_perplexity_Mod_KN(corpus):
    #Compute cross entropy and perplexity of our bigram model with Modified Knesser-Ney Smoothing
    sumprob = 0
    bigram_count = 0
    for sentence in corpus:
        sentence = ['<s>'] + sentence + ['<s>'] 
        bigram_count += (len(sentence) -1)
        sumprob += logprob_sentence_bigram_Mod_KN(sentence)
    cross_entropy = -sumprob/bigram_count
    perpl = math.pow(2,cross_entropy)
    return cross_entropy,perpl

def trigram_model_perplexity(corpus):
    #Compute cross entropy and perplexity of our trigram model
    sumprob = 0
    trigram_count = 0
    for sentence in corpus:
        sentence = ['<s>','<s>'] + sentence + ['<s>','<s>'] 
        trigram_count += (len(sentence) -2)
        sumprob += logprob_sentence_trigram(sentence)
    cross_entropy = -sumprob/trigram_count
    perpl = math.pow(2,cross_entropy)
    return cross_entropy,perpl
    
def score_sentence_vs_random(corpus):
    sent= random.choice(corpus)
    sent = ['<s>','<s>'] + sent + ['<s>','<s>'] 
    score_sent = logprob_sentence_trigram(sent)
    print(sent,"score:",score_sent)
    random_sent = []
    for i in range(len(sent)-4):
        random_sent.append(random.choice(words.words()))
    random_sent =['<s>','<s>']+ random_sent + ['<s>','<s>']
    score_random_sent = logprob_sentence_trigram(random_sent)
    print(random_sent,"score:",score_random_sent)

def predict_next_word(sentence):
    #predict next word based on most frequent relevant trigrams/ bigrams
    suggestions = []
    if (len(sentence) >= 2):
        trigrams = { k:v for k, v in trigrams_final.items() if ((k[0] == sentence[-2])
            and(k[1] == sentence[-1]))  }
        if (len(trigrams)>0):
            sortdict = [(k, trigrams[k]) for k in sorted(trigrams, key=trigrams.get, reverse=True)]
            for k,v in sortdict[:3]:
                suggestions.append(k[2])
            
    else:
        bigrams = { k:v for k, v in bigrams_final.items() if ((k[0] == sentence[-1])) }
        if (len(bigrams)>0):
            sortdict = [(k, bigrams[k]) for k in sorted(bigrams, key=bigrams.get, reverse=True)]
            for k,v in sortdict[:3]:
                suggestions.append(k[1])
    
    
    return suggestions
       
def bigram_trigram_interpolation_prob(trigram,lamda1,lamda2):
    #interpolate bigrams and trigrams
    assert(lamda1+lamda2==1.0),"error interpolation coefs must sum to 1!"
    return ((lamda1*trigram_logprob_ls((trigram[0],trigram[1],trigram[2]))) + 
             (lamda2*bigram_logprob_ls((trigram[1],trigram[2]))))

def interpolated_sentence(sentence,lamda1,lamda2):
    #propability of sentence using the interpolation of the bigram/trigram models
    sumprob = 0
    for i in range(len(sentence)-2):
        sumprob += bigram_trigram_interpolation_prob((sentence[i],sentence[i+1],sentence[i+2]),lamda1,lamda2)
    return sumprob

def interpolated_model_perplexity(corpus,lamda1,lamda2):
    #Compute cross entropy and perplexity of our trigram model
    sumprob = 0
    trigram_count = 0
    for sentence in corpus:
        sentence = ['<s>','<s>'] + sentence + ['<s>','<s>'] 
        trigram_count += (len(sentence) -2)
        sumprob += interpolated_sentence(sentence,lamda1,lamda2)
    cross_entropy = -sumprob/trigram_count
    perpl = math.pow(2,cross_entropy)
    return perpl


# for i in range(0, 101, 1):
#     lamda1 = i/100
#     lamda2 = 1-lamda1
#     print("lamda1: ",lamda1," lamda2: ",lamda2," perplexity--> "
#           ,interpolated_model_perplexity(validation,lamda1,lamda2))
#Best model where lamda1=0 and lamda2 =1
    


#Use validation data to tune alpha parameter
# for i in range(1,101,1):
#     alpha = i/1000
#     print("alpha:", alpha)
#     print(bigram_model_perplexity(validation))

#     print(trigram_model_perplexity(validation))


print(bigram_model_perplexity_Mod_KN(train))
print(bigram_model_perplexity(train))
print(trigram_model_perplexity(train))
print(bigram_model_perplexity_Mod_KN(test))
print(bigram_model_perplexity(test))
print(trigram_model_perplexity(test))

score_sentence_vs_random(test)
print(predict_next_word(["treaty"]))

#Results @ TEST(Cross Entropy, perplexity)
# (3.74168230957505, 13.376996376432324) --> bigram Knesser-Kney 
# (7.3882090675850725, 167.52226781094902) --> bigram laplace smoothing
# (8.203994032606882, 294.8820167238192) --> --> trigram laplace smoothing



