# Kneser-Ney Smoothed Trigram
### Fetching and Organizing Data

In [28]:
from datasets import load_dataset
from tqdm import tqdm
import regex as re
import math

dataset = load_dataset("coastalcph/tydi_xor_rc")

languages = ['ar', 'ko', 'te']
train_dataset = dataset["train"].filter(lambda example: example['lang'] in languages)
val_dataset = dataset["validation"].filter(lambda example: example['lang'] in languages)

In [29]:
ko_questions = list(train_dataset.filter(lambda x: x["lang"] == "ko")["question"])
ar_questions =  list(train_dataset.filter(lambda x: x["lang"] == "ar")["question"])
te_questions = list(train_dataset.filter(lambda x: x["lang"] == "te")["question"])
en_context =  list(train_dataset["context"])

ko_questions_val = list(val_dataset.filter(lambda x: x["lang"] == "ko")["question"])
ar_questions_val =  list(val_dataset.filter(lambda x: x["lang"] == "ar")["question"])
te_questions_val = list(val_dataset.filter(lambda x: x["lang"] == "te")["question"])
en_context_val =  list(val_dataset["context"])

def UnfoldSentences(l):
    return [re.findall(r'\w+', sentence) for sentence in l]
    

ko_questions = UnfoldSentences(ko_questions)
ar_questions = UnfoldSentences(ar_questions)
te_questions = UnfoldSentences(te_questions)
en_context= UnfoldSentences(en_context)

ko_questions_val = UnfoldSentences(ko_questions_val)
ar_questions_val = UnfoldSentences(ar_questions_val)
te_questions_val = UnfoldSentences(te_questions_val)
en_context_val = UnfoldSentences(en_context_val)


### Model Training and Validation

In [None]:

  
class KneserNey:
    trigram = {}
    bigram = {}
    unigram = {}
    V = set()

    def __init__(self, name):
        self.model = {}
        self.name = name

    def Train(self, texts):
        self.V = set()
        #Unigram
        for sentence in tqdm(texts, desc="Training Unigram"):
            for w in range(len(sentence)):
                self.V.add(w)
                if sentence[w] in self.unigram.keys():
                    self.unigram[sentence[w]] += 1
                else:
                    self.unigram[sentence[w]] = 1

        #Bigram
        for sentence in tqdm(texts, desc="Training Bigram"):
            for w in range(len(sentence) - 1):
                contextKey = (sentence[w])
                followUpKey = sentence[w+1]

                if contextKey in self.bigram.keys():
                    if followUpKey in self.bigram[contextKey].keys():
                        self.bigram[contextKey][followUpKey] += 1
                    else:
                        self.bigram[contextKey][followUpKey] = 1
                else:
                    self.bigram[sentence[w]] = {followUpKey : 1}
        
        #Trigram
        for sentence in tqdm(texts, desc="Training Trigram"):
            for w in range(len(sentence) - 2):
                contextKey = (sentence[w], sentence[w+1])
                followUpKey = sentence[w+2]

                if contextKey in self.trigram.keys():
                    if followUpKey in self.bigram[contextKey].keys():
                        self.trigram[contextKey][followUpKey] += 1
                    else:
                        self.trigram[contextKey][followUpKey] = 1
                else:
                    self.trigram[sentence[w]] = {followUpKey : 1}

    def P(self, context, followup, type="Trigram"):
        #Discount
        D = 0.75

        if type=="Trigram":
            #Trigram with Discount
            if context in self.trigram.keys() and followup in self.trigram[context].keys():
                lmb = D * len(self.trigram[context]) /  sum(self.trigram[context].values())
                return max(0 , self.trigram[context][followup] - D) / sum(self.trigram[context].values()) + lmb * self.P((context[1]), followup, "Bigram")
            
            #Bigram BackOff
            else:
                return self.P((context[1]), followup, "Bigram")
        elif type=="Bigram":
            #Bigram with Discount
            if context in self.bigram.keys() and followup in self.bigram[context].keys():
                lmb = D * len(self.bigram[context]) /  sum(self.bigram[context].values())

                return self.bigram[context][followup] / sum(self.bigram[context].values()) + lmb * self.P(followup, followup, "Unigram")
            #Unigram BackOff
            else:
                return self.P(followup, followup, "Unigram")
            
        elif type=="Unigram":
            #Unigram
            if followup in self.unigram.keys():
                return self.unigram[followup] / sum(self.unigram.values())
            #Out of Vocabulary
            else:
                return 1/len(self.V)
            
                
    def Perplexity(self, wordset):
        Sum = 0
        for i in range(len(wordset) - 2):
            inside = self.P((wordset[i], wordset[i+1]), wordset[i+2])
            
            Sum += math.log(inside) if inside > 0 else float("-inf")

        return math.exp((-1/len(wordset)) * Sum)

    def AvgPerplexity(self, sentences):
        i = 0
        p = 0
        for sentence in tqdm(sentences):
            p += self.Perplexity(sentence)
            i+=1
        return p/i
    

    

In [30]:
ko_kneser = KneserNey("Korean Kneser")
ar_kneser = KneserNey("Arabic Kneser")
te_kneser = KneserNey("Telugu Kneser")
en_kneser = KneserNey("English Kneser")

ko_kneser.Train(ko_questions)
ar_kneser.Train(ar_questions)
te_kneser.Train(te_questions)
en_kneser.Train(en_context)     


Training Unigram: 100%|██████████| 2422/2422 [00:00<00:00, 397177.32it/s]
Training Bigram: 100%|██████████| 2422/2422 [00:00<00:00, 194226.03it/s]
Training Trigram: 100%|██████████| 2422/2422 [00:00<00:00, 344196.12it/s]
Training Unigram: 100%|██████████| 2558/2558 [00:00<00:00, 271930.80it/s]
Training Bigram: 100%|██████████| 2558/2558 [00:00<00:00, 319012.54it/s]
Training Trigram: 100%|██████████| 2558/2558 [00:00<00:00, 289817.12it/s]
Training Unigram: 100%|██████████| 1355/1355 [00:00<00:00, 266283.18it/s]
Training Bigram: 100%|██████████| 1355/1355 [00:00<00:00, 331618.74it/s]
Training Trigram: 100%|██████████| 1355/1355 [00:00<00:00, 330231.37it/s]
Training Unigram: 100%|██████████| 6335/6335 [00:00<00:00, 29742.34it/s]
Training Bigram: 100%|██████████| 6335/6335 [00:00<00:00, 16950.50it/s]
Training Trigram: 100%|██████████| 6335/6335 [00:00<00:00, 17530.46it/s]


In [31]:
def PrintPerplexity(model, val):
    print("-"*30)
    print(model.name)
    print(f"Validation - {model.AvgPerplexity(val):0.02f}")

PrintPerplexity(ko_kneser, ko_questions_val)
PrintPerplexity(ar_kneser, ar_questions_val)
PrintPerplexity(te_kneser, te_questions_val)
PrintPerplexity(en_kneser, en_context_val)
print("-------------------------------------------------------------------")


------------------------------
Korean Kneser


100%|██████████| 356/356 [00:00<00:00, 993.08it/s] 


Validation - 66.41
------------------------------
Arabic Kneser


100%|██████████| 415/415 [00:00<00:00, 751.49it/s]


Validation - 319.76
------------------------------
Telugu Kneser


100%|██████████| 384/384 [00:00<00:00, 639.20it/s]


Validation - 76.48
------------------------------
English Kneser


100%|██████████| 1155/1155 [00:54<00:00, 21.24it/s]

Validation - 366.84
-------------------------------------------------------------------



