In [2]:
import xml.etree.ElementTree as ET
import random
import math
import os

In [8]:
def split_corpus(directory, test_ratio=0.2):
    file_list = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith(".xml"):
                file_list.append(os.path.join(root, file))

    num_test_files = int(len(file_list) * test_ratio)
    test_files = random.sample(file_list, num_test_files)
    train_files = [file for file in file_list if file not in test_files]
    return train_files, test_files

def parse_xml(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    words = []
    for sentence in root.findall(".//s"):
        words.append("<s>")
        for wtext in sentence.findall(".//w"):
            if wtext.text is not None and wtext.text.strip().lower() != "":
                words.append(wtext.text.strip().lower())
        words.append("</s>")
    return words

def split_list(lst, delimiter):
    result = []
    sublist = []
    for item in lst:
        if item == delimiter:
            if sublist:
                result.append(sublist)
                sublist = []
        else:
            sublist.append(item)
    if sublist:
        result.append(sublist)
    return result

root_dir = os.path.join("British National Corpus, Baby edition", "Texts")

In [4]:
train_files, test_files = split_corpus(root_dir)

print("Number of files in training set:", len(train_files))
print("Number of files in test set:", len(test_files))

train_words = sum([parse_xml(xml_file) for xml_file in train_files], [])
test_words = sum([parse_xml(xml_file) for xml_file in test_files], [])
all_words = train_words + test_words

print("Number of word in training set:" , len(train_words))
print("Number of word in test set:", len(test_words))
print("Number of words in all corpus:", len(train_words) + len(test_words))

print("Number of unique words in training set:", len(set(train_words)))
print("Number of unique words in test set:", len(set(test_words)))
print("Number of unique words in all corpus:", len(set(train_words + test_words)))

Number of files in training set: 146
Number of files in test set: 36
Number of word in training set: 3989656
Number of word in test set: 678057
Number of words in all corpus: 4667713
Number of unique words in training set: 76562
Number of unique words in test set: 33934
Number of unique words in all corpus: 85768


In [None]:
class VanillaLanguageModel:
    def __init__(self, xml_files: list = None, words: list[str] = None):
        self.xml_files = xml_files
        if words is not None:
            self.words = words
        elif xml_files is not None:
            self.words = self._parse_xml_files()
        self.gram = {
            1: {},
            2: {},
            3: {},
        }

    def _parse_xml_files(self):
        words = sum([parse_xml(xml_file) for xml_file in self.xml_files], [])
        return words

    def _parse_xml(self, xml_file):
        return parse_xml(xml_file)

    def _count_ngrams(self, words: str, ngram: int):
        grams = self._tokenize(words, ngram)
        gram_counts: dict = {}
        for gram in grams:
            if gram in gram_counts:
                gram_counts[gram] += 1
            else:
                gram_counts[gram] = 1
        return gram_counts

    def _tokenize(self, words: str, ngram: int):
        ngrams = [tuple(words[i:i+ngram]) for i in range(len(words)-ngram+1)]
        return ngrams

    def _calculate_probability(self, ngram: int, word: str, context: tuple):
        numerator = self.gram[ngram].get(context + (word,), 0)
        if ngram - 1 >= 1:
            denominator = self.gram[ngram - 1].get(context, 0)
        else:
            denominator = len(self.gram[ngram])
        if numerator == 0:
            return 0
        return math.log(numerator / denominator)

    def _calculate_sentence_probability(self, tokens: list, ngram: int):
        n = len(tokens)
        probability: float = 0.0

        for i in range(0, n):
            if ngram == 1:
                probability += self._calculate_probability(1, tokens[i], tuple())
            else:
                if i < ngram - 1:
                    _prob = self._calculate_probability(i+1, tokens[i], tuple(tokens[:i]))
                    probability += _prob
                else:
                    _prob = self._calculate_probability(ngram, tokens[i], tuple(tokens[i-ngram+1:i]))
                    probability += _prob
        return probability

    def _check_ngram(self, ngram: int):
        if ngram in self.gram:
            return True
        else:
            raise ValueError(f"Invalid ngram. Choose from 1 to {len(self.gram)}.")

    def train(self, ngram: int = 3):
        for i in range(1, ngram+1):
            self.gram[i] = self._count_ngrams(self.words, i)

    def get_top_ngrams(self, ngram, top_n=10):
        self._check_ngram(ngram)

        ngram_counts: dict = self.gram[ngram]

        return sorted(ngram_counts.items(), key=lambda x: x[1], reverse=True)[:top_n]

    def get_ngam_probability(self, ngram: int):
        total_count = sum(self.gram[ngram].values())
        ngram_probabilities = {k: v / total_count for k, v in self.gram[ngram].items()}
        return ngram_probabilities

    def get_top_ngram_probability(self, ngram: int, top_n=10):
        ngram_probabilities = self.get_ngam_probability(ngram)
        return sorted(ngram_probabilities.items(), key=lambda x: x[1], reverse=True)[:top_n]

    def calculate_perplexity(self, words: list, ngram: int):
        self._check_ngram(ngram)

        log_sentence_probability = self._calculate_sentence_probability(words, ngram)
        perplexity = math.exp(-log_sentence_probability / len(words))
        return perplexity

In [None]:
vanilla_model = VanillaLanguageModel(words=train_words)
vanilla_model.train()

print("Top 10 unigrams:", vanilla_model.get_top_ngrams(1))
print("Top 10 bigrams:", vanilla_model.get_top_ngrams(2))
print("Top 10 trigrams:", vanilla_model.get_top_ngrams(3))
print("Top 10 unigram probabilities:", vanilla_model.get_top_ngram_probability(1))
print("Top 10 bigram probabilities:", vanilla_model.get_top_ngram_probability(2))
print("Top 10 trigram probabilities:", vanilla_model.get_top_ngram_probability(3))

Top 10 unigrams: [(('<s>',), 258418), (('</s>',), 258418), (('the',), 164216), (('of',), 78712), (('and',), 72920), (('to',), 72602), (('a',), 67912), (('in',), 53041), (('it',), 48151), (('i',), 48014)]
Top 10 bigrams: [(('</s>', '<s>'), 258417), (('of', 'the'), 18274), (('<s>', 'i'), 16951), (('<s>', 'the'), 15522), (('in', 'the'), 14247), (('<s>', 'yeah'), 9449), (('<s>', 'it'), 9295), (('it', "'s"), 8655), (('<s>', 'he'), 8521), (('yeah', '</s>'), 7476)]
Top 10 trigrams: [(('</s>', '<s>', 'i'), 16951), (('</s>', '<s>', 'the'), 15522), (('</s>', '<s>', 'yeah'), 9449), (('</s>', '<s>', 'it'), 9295), (('</s>', '<s>', 'he'), 8521), (('yeah', '</s>', '<s>'), 7476), (('it', '</s>', '<s>'), 7289), (('</s>', '<s>', 'oh'), 7171), (('</s>', '<s>', 'you'), 6482), (('</s>', '<s>', 'no'), 6382)]
Top 10 unigram probabilities: [(('<s>',), 0.07175365775167356), (('</s>',), 0.07175365775167356), (('the',), 0.04559705075245852), (('of',), 0.02185557472370241), (('and',), 0.02024733851067664), (('to'

In [None]:
class LaplaceLanguageModel(VanillaLanguageModel):
    def __init__(self, xml_files: list = None, words: list[str] = None):
        super().__init__(xml_files, words)

    def _calculate_probability(self, ngram: int, word: str, context: tuple):
        numerator = self.gram[ngram].get(context + (word,), 0)
        if ngram - 1 >= 1:
            denominator = self.gram[ngram - 1].get(context, 0)
        else:
            denominator = len(self.gram[ngram])
        return math.log((numerator + 1) / (denominator + len(self.gram[ngram])))

In [None]:
laplace_model = LaplaceLanguageModel(words=train_words)
laplace_model.train()

print("Top 10 Unigrams with Laplace Smoothing:", laplace_model.get_top_ngrams(1))
print("Top 10 Bigrams with Laplace Smoothing:", laplace_model.get_top_ngrams(2))
print("Top 10 Trigrams with Laplace Smoothing:", laplace_model.get_top_ngrams(3))

Top 10 Unigrams with Laplace Smoothing: [(('<s>',), 258418), (('</s>',), 258418), (('the',), 164216), (('of',), 78712), (('and',), 72920), (('to',), 72602), (('a',), 67912), (('in',), 53041), (('it',), 48151), (('i',), 48014)]
Top 10 Bigrams with Laplace Smoothing: [(('</s>', '<s>'), 258417), (('of', 'the'), 18274), (('<s>', 'i'), 16951), (('<s>', 'the'), 15522), (('in', 'the'), 14247), (('<s>', 'yeah'), 9449), (('<s>', 'it'), 9295), (('it', "'s"), 8655), (('<s>', 'he'), 8521), (('yeah', '</s>'), 7476)]
Top 10 Trigrams with Laplace Smoothing: [(('</s>', '<s>', 'i'), 16951), (('</s>', '<s>', 'the'), 15522), (('</s>', '<s>', 'yeah'), 9449), (('</s>', '<s>', 'it'), 9295), (('</s>', '<s>', 'he'), 8521), (('yeah', '</s>', '<s>'), 7476), (('it', '</s>', '<s>'), 7289), (('</s>', '<s>', 'oh'), 7171), (('</s>', '<s>', 'you'), 6482), (('</s>', '<s>', 'no'), 6382)]


In [None]:
class UNKLanguageModel(VanillaLanguageModel):
    def __init__(self, xml_files: list = None, words: list[str] = None):
        super().__init__(xml_files, words)
        self.unk_threshold = 2
        self._replace_rare_words()

    def _replace_rare_words(self):
        word_counts = {}
        for word in self.words:
            if word in word_counts:
                word_counts[word] += 1
            else:
                word_counts[word] = 1
        for i, word in enumerate(self.words):
            if word_counts[word] <= self.unk_threshold:
                self.words[i] = '<UNK>'

In [None]:
unk_model = UNKLanguageModel(train_files)
unk_model.train()

print("Top 10 Unigrams after replacing rare words with <UNK>:", unk_model.get_top_ngrams(1))
print("Top 10 Bigrams after replacing rare words with <UNK>:", unk_model.get_top_ngrams(2))
print("Top 10 Trigrams after replacing rare words with <UNK>:", unk_model.get_top_ngrams(3))

Top 10 Unigrams after replacing rare words with <UNK>: [(('<s>',), 258418), (('</s>',), 258418), (('the',), 164216), (('of',), 78712), (('and',), 72920), (('to',), 72602), (('a',), 67912), (('in',), 53041), (('<UNK>',), 50801), (('it',), 48151)]
Top 10 Bigrams after replacing rare words with <UNK>: [(('</s>', '<s>'), 258417), (('of', 'the'), 18274), (('<s>', 'i'), 16951), (('<s>', 'the'), 15522), (('in', 'the'), 14247), (('<s>', 'yeah'), 9449), (('<s>', 'it'), 9295), (('it', "'s"), 8655), (('<s>', 'he'), 8521), (('yeah', '</s>'), 7476)]
Top 10 Trigrams after replacing rare words with <UNK>: [(('</s>', '<s>', 'i'), 16951), (('</s>', '<s>', 'the'), 15522), (('</s>', '<s>', 'yeah'), 9449), (('</s>', '<s>', 'it'), 9295), (('</s>', '<s>', 'he'), 8521), (('yeah', '</s>', '<s>'), 7476), (('it', '</s>', '<s>'), 7289), (('</s>', '<s>', 'oh'), 7171), (('<UNK>', '</s>', '<s>'), 6636), (('</s>', '<s>', 'you'), 6482)]
Is there any trigram with empty string after replacing rare words with <UNK>: Fal

In [None]:
laplace_unk_model = LaplaceLanguageModel(words=unk_model.words)
laplace_unk_model.train()

print("Top 10 Unigrams with Laplace Smoothing after replacing rare words with <UNK>:", laplace_unk_model.get_top_ngrams(1))
print("Top 10 Bigrams with Laplace Smoothing after replacing rare words with <UNK>:", laplace_unk_model.get_top_ngrams(2))
print("Top 10 Trigrams with Laplace Smoothing after replacing rare words with <UNK>:", laplace_unk_model.get_top_ngrams(3))

Top 10 Unigrams with Laplace Smoothing after replacing rare words with <UNK>: [(('<s>',), 258418), (('</s>',), 258418), (('the',), 164216), (('of',), 78712), (('and',), 72920), (('to',), 72602), (('a',), 67912), (('in',), 53041), (('<UNK>',), 50801), (('it',), 48151)]
Top 10 Bigrams with Laplace Smoothing after replacing rare words with <UNK>: [(('</s>', '<s>'), 258417), (('of', 'the'), 18274), (('<s>', 'i'), 16951), (('<s>', 'the'), 15522), (('in', 'the'), 14247), (('<s>', 'yeah'), 9449), (('<s>', 'it'), 9295), (('it', "'s"), 8655), (('<s>', 'he'), 8521), (('yeah', '</s>'), 7476)]
Top 10 Trigrams with Laplace Smoothing after replacing rare words with <UNK>: [(('</s>', '<s>', 'i'), 16951), (('</s>', '<s>', 'the'), 15522), (('</s>', '<s>', 'yeah'), 9449), (('</s>', '<s>', 'it'), 9295), (('</s>', '<s>', 'he'), 8521), (('yeah', '</s>', '<s>'), 7476), (('it', '</s>', '<s>'), 7289), (('</s>', '<s>', 'oh'), 7171), (('<UNK>', '</s>', '<s>'), 6636), (('</s>', '<s>', 'you'), 6482)]


In [None]:
def linear_interpolation_probability(language_model: VanillaLanguageModel, sentence: list):
    trigram_lambda = 0.6
    bigram_lambda = 0.3
    unigram_lambda = 0.1

    trigram_prob = language_model.calculate_perplexity(sentence, 3)
    bigram_prob = language_model.calculate_perplexity(sentence, 2)
    unigram_prob = language_model.calculate_perplexity(sentence, 1)

    total_prob = (
        trigram_lambda * trigram_prob +
        bigram_lambda * bigram_prob +
        unigram_lambda * unigram_prob
    )

    return total_prob


In [None]:
vanilla_unigram_perplexity = vanilla_model.calculate_perplexity(test_words, ngram=1)
vanilla_bigram_perplexity = vanilla_model.calculate_perplexity(test_words, ngram=2)
vanilla_trigram_perplexity = vanilla_model.calculate_perplexity(test_words, ngram=3)
vanilla_linear_interpolation = linear_interpolation_probability(vanilla_model, test_words)

laplace_unigram_perplexity = laplace_model.calculate_perplexity(test_words, ngram=1)
laplace_bigram_perplexity = laplace_model.calculate_perplexity(test_words, ngram=2)
laplace_trigram_perplexity = laplace_model.calculate_perplexity(test_words, ngram=3)
laplace_linear_interpolation = linear_interpolation_probability(laplace_model, test_words)

unk_unigram_perplexity = unk_model.calculate_perplexity(test_words, ngram=1)
unk_bigram_perplexity = unk_model.calculate_perplexity(test_words, ngram=2)
unk_trigram_perplexity = unk_model.calculate_perplexity(test_words, ngram=3)
unk_linear_interpolation = linear_interpolation_probability(unk_model, test_words)

print("|            | Unigram | Bigram | Trigram | Linear Interpolation |")
print("|------------|---------|--------|---------|----------------------|")
print(f"| Vanilla    | {vanilla_unigram_perplexity:.2f}   | {vanilla_bigram_perplexity:.2f}  | {vanilla_trigram_perplexity:.2f}  | {vanilla_linear_interpolation:.2f} |")
print(f"| Laplace    | {laplace_unigram_perplexity:.2f}   | {laplace_bigram_perplexity:.2f}  | {laplace_trigram_perplexity:.2f}  | {laplace_linear_interpolation:.2f} |")
print(f"| UNK        | {unk_unigram_perplexity:.2f}   | {unk_bigram_perplexity:.2f}  | {unk_trigram_perplexity:.2f}  | {unk_linear_interpolation:.2f} |")

|            | Unigram | Bigram | Trigram | Linear Interpolation |
|------------|---------|--------|---------|----------------------|
| Vanilla    | 15.49   | 24.26  | 3.70  | 11.05 |
| Laplace    | 38.78   | 18920.83  | 398664.74  | 244878.97 |
| UNK        | 6.06   | 23.88  | 3.69  | 9.99 |


In [None]:
def remove_start_end_tokens(sentence: str):
    return sentence.replace("<s>", "\n").replace("</s>", "").strip().capitalize()

def generate_sentence(language_model: VanillaLanguageModel, starting_phrase: str, ngram: int = 3):
    sentence = starting_phrase.lower().split()
    i = 0
    while True:
        i += 1
        context = tuple(sentence[-(ngram - 1):])
        next_word = generate_next_word(language_model, context, ngram)
        print("next_word: ", next_word)
        if next_word[0] == "</s>":
            if i < 4 and next_word[1]:
                sentence.append(next_word[1])
            else:
                sentence.append(next_word[0])
                break
        sentence.append(next_word[0])

    return remove_start_end_tokens(' '.join(sentence))


def generate_next_word(language_model: VanillaLanguageModel, context, ngram: int):
    candidates = language_model.gram[ngram]

    if ngram == 1:
        candidates = {k: v for k, v in candidates.items() if k != ''}
    else:
        candidates = {k[-1]: v for k, v in candidates.items() if k[:-1] == context}

    if not candidates:
        return "</s>"

    total_count = sum(candidates.values())

    probabilities = {word: count / total_count for word, count in candidates.items()}
    sorted_data = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)

    top_3_elements = sorted_data[:3]

    top_3_text = [item[0] for item in top_3_elements]
    return top_3_text

In [None]:
selected_model = input("Select a language model (Vanilla, Laplace, UNK): ").lower()

if selected_model == "vanilla":
    model = vanilla_model
elif selected_model == "laplace":
    model = laplace_model
elif selected_model == "unk":
    model = unk_model
else:
    raise ValueError("Invalid model selection.")

starting_phrase = input("Enter a starting phrase: ")

generated_sentence = generate_sentence(model, f"<s> {starting_phrase}", ngram=3)
print("Generated sentence: ", generated_sentence)

ValueError: Invalid model selection.