In [None]:
%pip install datasets
%pip install -U nltk
%pip install transformers tiktoken

%nltk.download("punkt")
%nltk.download("punkt_tab")

In [29]:
import nltk
from nltk import sent_tokenize
from itertools import chain
from pprint import pprint
from datasets import load_dataset
from nltk.util import ngrams
import re
from collections import Counter
import random
import math
import numpy as np

In [2]:
# Load the full dataset
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Convert splits into plain text strings
corpus = {
    "train": " ".join(raw_dataset["train"]["text"]),
    "dev": " ".join(raw_dataset["validation"]["text"]),
    "test":  " ".join(raw_dataset["test"]["text"]),
}

# Display token counts
for split in corpus:
    print(f"{split.capitalize()} set #tokens: {len(corpus[split].split())}")

Train set #tokens: 2051910
Dev set #tokens: 213886
Test set #tokens: 241211


In [3]:
def print_sents(sents: list, n_first:int) -> None:
  for sent in sents[:n_first]:
    print(sent)
    print("___________________")

In [4]:
corpus["train_sentences"] = sent_tokenize(corpus["train"])
corpus["dev_sentences"] = sent_tokenize(corpus["dev"])
corpus["test_sentences"] = sent_tokenize(corpus["test"])

print_sents(corpus["train_sentences"], 5)

  = Valkyria Chronicles III = 
   Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit .
___________________
Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable .
___________________
Released in January 2011 in Japan , it is the third game in the Valkyria series .
___________________
Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " .
___________________
The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II .
___________________


In [5]:
def custom_tokenize(text):
    return re.findall(r"\b\w+\b", text.lower())

In [6]:
corpus["train_filtered"] = [custom_tokenize(text) for text in corpus["train_sentences"]]
corpus["test_filtered"] = [custom_tokenize(text) for text in corpus["test_sentences"]]  
corpus["dev_filtered"] = [custom_tokenize(text) for text in corpus["dev_sentences"]]

print_sents(corpus["train_filtered"], 5)

['valkyria', 'chronicles', 'iii', 'senjō', 'no', 'valkyria', '3', 'unrecorded', 'chronicles', 'japanese', '戦場のヴァルキュリア3', 'lit']
___________________
['valkyria', 'of', 'the', 'battlefield', '3', 'commonly', 'referred', 'to', 'as', 'valkyria', 'chronicles', 'iii', 'outside', 'japan', 'is', 'a', 'tactical', 'role', 'playing', 'video', 'game', 'developed', 'by', 'sega', 'and', 'media', 'vision', 'for', 'the', 'playstation', 'portable']
___________________
['released', 'in', 'january', '2011', 'in', 'japan', 'it', 'is', 'the', 'third', 'game', 'in', 'the', 'valkyria', 'series']
___________________
['employing', 'the', 'same', 'fusion', 'of', 'tactical', 'and', 'real', 'time', 'gameplay', 'as', 'its', 'predecessors', 'the', 'story', 'runs', 'parallel', 'to', 'the', 'first', 'game', 'and', 'follows', 'the', 'nameless', 'a', 'penal', 'military', 'unit', 'serving', 'the', 'nation', 'of', 'gallia', 'during', 'the', 'second', 'europan', 'war', 'who', 'perform', 'secret', 'black', 'operations', 'a

In [7]:
# Frequency distribution
train_tokens = list(chain.from_iterable(corpus["train_filtered"]))
count = nltk.FreqDist(train_tokens)
print('Top 20 most frequent tokens in train sentences: \n')

pprint(count.most_common(20))

Top 20 most frequent tokens in train sentences: 

[('the', 130771),
 ('of', 57032),
 ('and', 50738),
 ('in', 45019),
 ('to', 39522),
 ('a', 36567),
 ('was', 21008),
 ('on', 15141),
 ('as', 15058),
 ('s', 14982),
 ('that', 14351),
 ('for', 13795),
 ('with', 13012),
 ('by', 12718),
 ('is', 11692),
 ('it', 9277),
 ('from', 9229),
 ('at', 9071),
 ('his', 9020),
 ('he', 8709)]


In [8]:
# Frequency distribution
test_tokens = list(chain.from_iterable(corpus["test_filtered"]))
count = nltk.FreqDist(test_tokens)
print('Top 20 most frequent tokens in train sentences: \n')

pprint(count.most_common(20))

Top 20 most frequent tokens in train sentences: 

[('the', 16083),
 ('of', 6789),
 ('and', 5885),
 ('in', 5079),
 ('to', 4787),
 ('a', 4090),
 ('was', 2575),
 ('on', 1903),
 ('as', 1605),
 ('that', 1522),
 ('s', 1516),
 ('for', 1490),
 ('with', 1485),
 ('by', 1478),
 ('he', 1329),
 ('at', 1240),
 ('his', 1179),
 ('is', 1137),
 ('were', 1085),
 ('from', 1061)]


In [9]:
# Frequency distribution
dev_tokens = list(chain.from_iterable(corpus["dev_filtered"]))
count = nltk.FreqDist(dev_tokens)
print('Top 20 most frequent tokens in train sentences: \n')

pprint(count.most_common(20))

Top 20 most frequent tokens in train sentences: 

[('the', 14717),
 ('of', 5926),
 ('and', 5345),
 ('in', 4755),
 ('to', 4160),
 ('a', 3659),
 ('was', 2324),
 ('on', 1680),
 ('s', 1565),
 ('as', 1476),
 ('that', 1395),
 ('by', 1340),
 ('for', 1273),
 ('with', 1234),
 ('at', 1005),
 ('from', 973),
 ('is', 964),
 ('were', 886),
 ('it', 865),
 ('he', 817)]


In [10]:
""" Function that returns a Unigram Counter of a given corpus."""
def calc_unigrams(tokens):
  unigram_counter = Counter()
  for sent in tokens:
    unigram_counter.update([(gram,) for gram in ["<s>"] + sent])
  
  return unigram_counter

""" Function that returns a Biagram Counter of a given corpus."""
def calc_bigrams(tokens):
  bigram_counter = Counter()
  for sent in tokens:
    bigram_pad_sent = ["<s>"] + sent +  ['<e>']
    bigram_counter.update([(gram1, gram2) for gram1, gram2 in zip(bigram_pad_sent, bigram_pad_sent[1:])])  

  return bigram_counter

""" Function that returns a Trigram Counter of a given corpus."""
def calc_trigrams(tokens):
  trigram_counter = Counter()
  for sent in tokens:
    trigram_pad_sent = ["<s>"] + sent +  ['<e>']
    trigram_counter.update([(gram1, gram2, gram3) for gram1, gram2, gram3 in zip(trigram_pad_sent, trigram_pad_sent[1:], trigram_pad_sent[2:])])  

  return trigram_counter


In [11]:
def replace_oov_words_train(corpus):
    """ Function that calculates and replaces OOV words.
    INPUT: Train corpus (list)
    OUTPUT: 
      OOV_word: dict with key containing OOC words and value the str 'UNK' -> dict
      clean_corpus: the original corpus having the OOV words replaced by 'UNK' -> list
      vocabulary: the words contained in the vocabulary -> set
    """

    unigram_counter = calc_unigrams(corpus)
    OOV_words = {k[0]:"UNK" for k, v in unigram_counter.items() if v < 10}
    clean_corpus = []
    for sentence in corpus:
        clean_corpus.append([OOV_words.get(n,n) for n in sentence])
    vocabulary = [f[0] for f in unigram_counter.keys() if f[0] not in OOV_words]
    vocabulary = set(vocabulary) # set for unique words
    return OOV_words, clean_corpus, vocabulary


In [12]:
oov_words, clean_corpus, vocabulary = replace_oov_words_train(corpus["train_filtered"])
random.sample(list(oov_words.keys()), 10)

['1322',
 'yoganna',
 'maxims',
 'newsweek',
 'frances',
 'understandable',
 'holyrood',
 'anatol',
 'qxd4',
 '2125']

In [None]:
def replace_oov_words_dev_test(corpus, vocabulary, oov_words):
    clean_corpus = []
    for sentence in corpus:
        updated_sentence = ['UNK' if ((word not in vocabulary) or (word in oov_words)) else word for word in sentence]
        clean_corpus.append(updated_sentence)
    return clean_corpus

corpus["dev_filtered"] = replace_oov_words_dev_test(corpus["dev_filtered"], vocabulary, oov_words)
corpus["test_filtered"] = replace_oov_words_dev_test(corpus["test_filtered"], vocabulary, oov_words)


In [None]:
vocabulary_length = len(vocabulary)
unigram_counter = calc_unigrams(corpus["train_filtered"])
bigram_counter = calc_bigrams(corpus["train_filtered"])
trigram_counter = calc_trigrams(corpus["train_filtered"])


In [None]:
print(f'Vocabulary Length: {vocabulary_length}')
print('='*25, '\nUnigram 10 most common:')
pprint(unigram_counter.most_common(10))
print('='*25, '\nBigram 10 most common:')
pprint(bigram_counter.most_common(10))
print('='*25, '\nTrigram 10 most common:')
pprint(trigram_counter.most_common(10))


Vocabulary Length: 13130
Unigram 10 most common:
[(('the',), 130771),
 (('<s>',), 78453),
 (('of',), 57032),
 (('and',), 50738),
 (('in',), 45019),
 (('to',), 39522),
 (('a',), 36567),
 (('was',), 21008),
 (('on',), 15141),
 (('as',), 15058)]
Bigram 10 most common:
[(('of', 'the'), 17353),
 (('<s>', 'the'), 13695),
 (('in', 'the'), 11841),
 (('to', 'the'), 6031),
 (('<s>', 'in'), 4960),
 (('on', 'the'), 4518),
 (('and', 'the'), 4394),
 (('for', 'the'), 3729),
 (('at', 'the'), 3197),
 (('from', 'the'), 3014)]
Trigram 10 most common:
[(('<s>', 'in', 'the'), 979),
 (('one', 'of', 'the'), 869),
 (('<s>', 'it', 'was'), 734),
 (('the', 'united', 'states'), 667),
 (('as', 'well', 'as'), 604),
 (('part', 'of', 'the'), 534),
 (('the', 'end', 'of'), 510),
 (('<s>', 'it', 'is'), 499),
 (('<s>', 'according', 'to'), 450),
 (('<s>', 'at', 'the'), 405)]


In [19]:
def calc_bi_prob(word1: str, word2: str, alpha: float, bigram_counter: Counter, unigram_counter: Counter, vocabulary_length: int) -> float:
    """ Function that calculates the Bigram model's probabilities using Laplace & a-smoothing."""
    #Bigram prob + laplace smoothing
    bigram_prob = (bigram_counter[(word1, word2)] + alpha) / (unigram_counter[(word1,)] + alpha * vocabulary_length)
    return bigram_prob

In [31]:
def bigram_LM(
    corpus: list[list[str]],
    unigram_counter: dict,
    bigram_counter: dict
) -> None:
    vocab_size = len(unigram_counter)
    alpha_list = np.linspace(0.001, 0.1, 100)

    best_entropy = float("inf")
    best_alpha = None
    best_perplexity = None

    for alpha in alpha_list:
        sum_log_prob = 0
        bigram_cnt = 0

        for sent in corpus:
            sent = ['<s>'] + sent + ['<e>']

            for i in range(1, len(sent)):
                w1, w2 = sent[i - 1], sent[i]

                # Skip prediction of start tokens
                if w2 in {'<s>', '<s1>', '<s2>'}:
                    continue

                bigram_count = bigram_counter.get((w1, w2), 0)
                unigram_count = unigram_counter.get(w1, 0)

                prob = (bigram_count + alpha) / (unigram_count + alpha * vocab_size)
                sum_log_prob += math.log2(prob)
                bigram_cnt += 1

        cross_entropy = -sum_log_prob / bigram_cnt
        perplexity = math.pow(2, cross_entropy)

        if cross_entropy < best_entropy:
            best_entropy = cross_entropy
            best_perplexity = perplexity
            best_alpha = alpha

    print(f"Best Laplace-smoothed Bigram Model:")
    print(f"  Cross Entropy: {best_entropy:.3f}")
    print(f"  Perplexity:    {best_perplexity:.3f}")
    print(f"  Best alpha:    {best_alpha:.3f}")

In [None]:
bigram_LM(corpus["dev_filtered"], unigram_counter, bigram_counter)

Best Laplace-smoothed Bigram Model:
  Cross Entropy: 6.125
  Perplexity:    69.796
  Best alpha:    0.001


In [33]:
def calc_tri_prob(word1: str, word2: str, word3:str, alpha: float, trigram_counter:Counter, bigram_counter:Counter, vocabulary_length: int) -> float:
    """ Function that calculates the Trigram model's probabilities using Laplace & a-smoothing."""
    #Bigram prob + laplace smoothing
    trigram_prob = (trigram_counter[(word1, word2, word3)] +alpha) / (bigram_counter[(word1, word2)] + alpha * vocabulary_length)
    return trigram_prob

In [41]:
def trigram_LM(
    corpus: list[list[str]],
    bigram_counter: dict,
    trigram_counter: dict,
    unigram_counter: dict
) -> None:
    vocab_size = len(unigram_counter)
    alpha_list = np.linspace(0.001, 0.1, 100)

    # Initialize best metrics
    best_entropy = float("inf")
    best_perplexity = None
    best_alpha = None

    # Add pseudo bigram (<s>, <s>)
    bigram_counter[("<s>", "<s>")] = len(corpus)

    for alpha in alpha_list:
        sum_log_prob = 0
        trigram_cnt = 0

        for sent in corpus:
            sent = ['<s>', '<s>'] + sent + ['<e>']

            for i in range(2, len(sent)):
                w1, w2, w3 = sent[i-2], sent[i-1], sent[i]

                # Skip if predicting a start token
                if w3 in {'<s>', '<s1>', '<s2>'}:
                    continue

                trigram_count = trigram_counter.get((w1, w2, w3), 0)
                bigram_count = bigram_counter.get((w1, w2), 0)

                prob = (trigram_count + alpha) / (bigram_count + alpha * vocab_size)
                sum_log_prob += math.log2(prob)
                trigram_cnt += 1

        cross_entropy = -sum_log_prob / trigram_cnt
        perplexity = math.pow(2, cross_entropy)

        # Track best alpha
        if cross_entropy < best_entropy:
            best_entropy = cross_entropy
            best_perplexity = perplexity
            best_alpha = alpha

    # Final results
    print("Trigram LM (Laplace-smoothed) — Alpha Tuning:")
    print(f"  Best alpha:      {best_alpha:.3f}")
    print(f"  Cross Entropy:   {best_entropy:.3f}")
    print(f"  Perplexity:      {best_perplexity:.3f}")

In [42]:
trigram_LM(corpus["dev_filtered"], bigram_counter, trigram_counter, unigram_counter)

Trigram LM (Laplace-smoothed) — Alpha Tuning:
  Best alpha:      0.001
  Cross Entropy:   13.866
  Perplexity:      14928.122
