#### Trigram Model

In [1]:
import nltk
from nltk.corpus import reuters
from nltk.util import trigrams
from nltk.probability import FreqDist, ConditionalFreqDist
from collections import defaultdict
import random

# Make sure the NLTK data packages used below are installed locally.
# "reuters" backs the `reuters` corpus reader imported above; "punkt" is
# presumably needed for its sentence splitting -- TODO confirm for the
# installed NLTK version.
nltk.download("reuters")
nltk.download("punkt")

[nltk_data] Downloading package reuters to /Users/zs74qz/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt to /Users/zs74qz/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

Load the dataset

In [2]:
# Corpus as a sequence of sentences; the preprocessing step below iterates
# each sentence as a sequence of string tokens.
sentences = reuters.sents()

Preprocess the data

In [3]:
def preprocess(sentences):
    """Return a lowercased copy of a tokenized corpus.

    Args:
        sentences: Iterable of sentences, each an iterable of string tokens.

    Returns:
        A list containing one list per input sentence, holding that
        sentence's tokens converted with ``str.lower``. Sentence and token
        order are preserved; the input is not modified.
    """
    return [[token.lower() for token in tokens] for tokens in sentences]

# Rebind `sentences` to the lowercased copy; every later cell works on this version.
sentences = preprocess(sentences)

Build trigram frequencies

In [4]:
# For every two-word context (w1, w2), count how often each word w3 follows it.
# Unseen contexts and unseen following words both default to a count of 0.
trigram_model = defaultdict(lambda: defaultdict(int))

for sentence in sentences:
    # Padding on both sides lets sentence starts and ends be counted as contexts too.
    padded = trigrams(sentence, pad_left=True, pad_right=True)
    for w1, w2, w3 in padded:
        trigram_model[w1, w2][w3] += 1

Convert counts to probabilities

In [5]:
# Normalize each context's counts into a conditional distribution P(w3 | w1, w2).
# The per-context total is computed once, rather than once per following word.
trigram_probs = {}
for context, context_counts in trigram_model.items():
    context_total = sum(context_counts.values())
    trigram_probs[context] = {
        w3: count / context_total for w3, count in context_counts.items()
    }

Generate text using the trigram model

In [6]:
def generate_text(trigram_probs, start_words, num_words=20, rng=None):
    """Sample a word sequence from a trigram model.

    Args:
        trigram_probs: Mapping from a ``(w1, w2)`` context tuple to a mapping
            of candidate next words and their sampling weights.
        start_words: Pair ``(w1, w2)`` used as the initial context. Both
            entries open the output unless they are None.
        num_words: Maximum number of words to sample after the start words.
        rng: Optional object exposing a ``choices`` method (for example a
            ``random.Random`` instance) for reproducible sampling without
            touching global state. Defaults to the global ``random`` module.

    Returns:
        The start words followed by the sampled words, joined by single
        spaces, with None entries (padding markers) left out. Generation
        stops early as soon as the current context has no candidates.
    """
    sampler = random if rng is None else rng
    w1, w2 = start_words
    generated_text = [w1, w2]

    for _ in range(num_words):
        next_word_candidates = trigram_probs.get((w1, w2), {})
        if not next_word_candidates:
            break
        next_word = sampler.choices(
            population=list(next_word_candidates),
            weights=list(next_word_candidates.values()),
            k=1,
        )[0]
        generated_text.append(next_word)
        w1, w2 = w2, next_word  # Shift context window

    # Identity check: only the None padding marker is dropped, never a token
    # that merely compares equal to None.
    return ' '.join(str(word) for word in generated_text if word is not None)

In [7]:
random.seed(0)  # fixed seed so the sampled continuation is reproducible

# Example: generate text from a two-word seed.
# Seed words must be lowercase because preprocessing lowercased the corpus;
# pass (None, None) instead to begin from the sentence-start padding.
start_words = ("the", "course")
generated = generate_text(trigram_probs, start_words, num_words=50)
print("Generated text:")
print(generated)


Generated text:
The course
