## Text Generation using RNN

This project aims to go beyond basic sequence generation by implementing and comparing advanced Recurrent Neural Network (RNN) architectures, such as LSTMs and GRUs. To experiment with the different architectures and text generation methods, including temperature-controlled sampling, a dataset of Shakespeare's writings will be used. More specifically, the Tiny Shakespeare dataset, curated by Andrej Karpathy is used (karpathy/tiny_shakespeare).

#### Load libraries

In [None]:
#! pip install datasets
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import re
import os
import random
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, processors
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Union
from torch.nn.utils import clip_grad_norm_
from copy import deepcopy
! pip install tqdm
from tqdm import tqdm

#### Load the dataset

In [None]:
# Load and save dataset
dataset = load_dataset('karpathy/tiny_shakespeare', trust_remote_code=True)
with open('tiny_shakespeare.txt', 'w') as f:
    f.write(dataset['train'][0]['text'])  # Save full text to file

### Data Processing
To process the data, 2 different tokenization methods are used - BPE and WordPiece. This decision is driven by the need to balance between capturing meaningful subword units and ensuring the model can handle out-of-vocabulary words effectively.

#### BPE Tokenization
BPE is a subword tokenization technique that iteratively merges the most frequent pair of characters or subwords in the training corpus. This method is particularly effective in handling rare and unseen words by breaking them down into smaller, more frequent subword units. Main benefits include capturing subword structure, reducing vocabulary size and handling rare or compound words.

In [None]:
class ShakespeareBPE(Dataset):
    def __init__(self, file_path, seq_length=100, split='train'):
        self.seq_length = seq_length
        self.tokenizer = self._init_tokenizer(file_path)

        with open(file_path, 'r', encoding='utf-8') as f:
            self.raw_text = f.read()

        # Split data (80-10-10)
        train_end = int(0.8 * len(self.raw_text))
        val_end = int(0.9 * len(self.raw_text))
        self.text = {
            'train': self.raw_text[:train_end],
            'val': self.raw_text[train_end:val_end],
            'test': self.raw_text[val_end:]
        }[split]

        self.data = self.tokenizer.encode(self.text).ids

    def _init_tokenizer(self, file_path):
        tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))

        # Enhanced pre-tokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
             pre_tokenizers.Split(r'(\n\n)', behavior='isolated'),
            pre_tokenizers.Split(r'(?<=\n)|(?=\n)', behavior='isolated'),  # Isolate newlines
            pre_tokenizers.WhitespaceSplit(),
            pre_tokenizers.Split(r"('(?:[sstdm]|ll|re|ve|st)\b)", behavior='isolated')
        ])

        # Comprehensive special tokens
        special_tokens = [
            "[SOS]", "[EOS]", "[UNK]", "[PAD]",  # Control tokens
            "\n", "\n\n", "\t",  # Whitespace
            *[f"'{x}" for x in ['s', 't', 'll', 'd', 'm', 're', 've', 'st']]  # Contractions
        ]

        trainer = trainers.BpeTrainer(
            special_tokens=special_tokens,
            min_frequency=2,
            vocab_size=12000,
            show_progress=True
        )

        tokenizer_path = 'shakespeare-bpe-enhanced.json'
        if not os.path.exists(tokenizer_path):
            tokenizer.train([file_path], trainer=trainer)
            tokenizer.save(tokenizer_path)
        else:
            tokenizer = Tokenizer.from_file(tokenizer_path)

        # Post-processor for sentence structure
        tokenizer.post_processor = processors.TemplateProcessing(
            single="[SOS] $A [EOS]",
            pair="[SOS] $A [EOS] $B [EOS]",
            special_tokens=[
                ("[SOS]", tokenizer.token_to_id("[SOS]")),
                ("[EOS]", tokenizer.token_to_id("[EOS]")),
            ]
        )

        return tokenizer

    def decode(self, token_ids):
        text = self.tokenizer.decode(token_ids)
        # Post-process to restore Shakespearean formatting
        text = text.replace(" '", "'")  # Fix contractions
        text = re.sub(r' ([,.!?;:])', r'\1', text)  # Remove spaces before punctuation
        return text

    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        inputs = self.data[idx:idx + self.seq_length]
        targets = self.data[idx + 1:idx + self.seq_length + 1]
        return {
            'input': torch.tensor(inputs, dtype=torch.long),
            'target': torch.tensor(targets, dtype=torch.long)
        }

    def test_tokenization(self, sample_text):
        """Test with full formatting"""
        encoded = self.tokenizer.encode(sample_text)
        decoded = self.decode(encoded.ids)

        print("Original:")
        print(repr(sample_text))
        print("\nTokens:")
        for token, id in zip(encoded.tokens, encoded.ids):
            print(f"{repr(token):<10} (ID: {id})")
        print("\nDecoded:")
        print(repr(decoded))
        print("\nExact Match:", sample_text == decoded)
        return sample_text == decoded

#### Test BPE Tokenization

In [None]:
dataset = ShakespeareBPE("shakespeare.txt")

# Contraction test
sample1 = "I'll die,\nAnd if thou wilt, thou may'st be merciful."
print("=== Contraction Test ===")
dataset.test_tokenization(sample1)

# Whitespace test
sample2 = "  ACT I\n\tScene 1:\n  Enter ROMEO\n\nROMEO:\n  'Tis morning!"
print("\n=== Whitespace Test ===")
dataset.test_tokenization(sample2)

# Complex test
sample3 = "' are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you'"
print("\n=== Complex Test ===")
dataset.test_tokenization(sample3)

=== Contraction Test ===
Original:
"I'll die,\nAnd if thou wilt, thou may'st be merciful."

Tokens:
'[SOS]'    (ID: 0)
'I'        (ID: 34)
"'ll"      (ID: 9)
'die,'     (ID: 2487)
'\n'       (ID: 4)
'And'      (ID: 131)
'if'       (ID: 229)
'thou'     (ID: 142)
'wilt,'    (ID: 8250)
'thou'     (ID: 142)
'may'      (ID: 323)
"'st"      (ID: 14)
'be'       (ID: 106)
'merciful' (ID: 8783)
'.'        (ID: 21)
'[EOS]'    (ID: 1)

Decoded:
'I die, And if thou wilt, thou may be merciful.'

Exact Match: False

=== Whitespace Test ===
Original:
"  ACT I\n\tScene 1:\n  Enter ROMEO\n\nROMEO:\n  'Tis morning!"

Tokens:
'[SOS]'    (ID: 0)
'A'        (ID: 26)
'C'        (ID: 28)
'T'        (ID: 45)
'I'        (ID: 34)
'\n'       (ID: 4)
'\t'       (ID: 6)
'S'        (ID: 44)
'c'        (ID: 54)
'ene'      (ID: 930)
'[UNK]'    (ID: 2)
':'        (ID: 23)
'\n'       (ID: 4)
'En'       (ID: 2827)
'ter'      (ID: 198)
'ROME'     (ID: 634)
'O'        (ID: 40)
'\n\n'     (ID: 5)
'ROMEO:'   (ID: 636)
'\n' 

False

#### Observation:

Despite multiple attempts at refining the tokenisation process, there is still a difference between the original and decoded output. This discrepancy can be attributed to the following reasons
1. Subword tokenisation artifacts: BPE intentionally splits rare words/contractions ("may'st" transformed into "may" + "'st") to improve vocabulary coverage. As a result, this causes reconstructed text to sometimes merge or separate components differently than the original.
2. Whitespace normalisation: During tokenisation, newlines are often converted to single spaces, spaces are added/removed around puncutation and multiple spaces are usually collapsed. This loses original formatting fidelity but helps models learn patterns.
3. Contraction handling: Modern tokenizers treat apostrophes as word separators by default. In Shakespearean style, contractions ("'tis", "'ll") get split despite being single semantic units unlike in Early Modern English.
4. Vocabulary limitations: The predefined vocabulary is likely to lack certain archaic word forms thus forced to decompose them into suboptimal subword pieces.

Conventionally, BPE was designed for efficient machine processing rather than text preservation. Thus, the algorithm makes 3 intentional tradeoffs in the above process: (i) Prioritises consistent tokenization over original formatting (ii) Optimizes for coverage of modern language patterns and (iii) Sacrifices exact reconstruction to enable better generalization.

Thus, the differences observed are not necessarily considered 'errors' as they are inherent characteristics of how BPE fundamentally operates when applied to historical text with non-standard orthography.

#### Wordpiece tokenization
WordPiece, a tokenization method similar to BPE, is typically used in models like BERT. It also breaks down words into subwords, but it differs from BPE in how it merges subwords. The main goal of WordPiece is to minimize the number of unknown tokens by maintaining a vocabulary of subword units that are most likely to appear in the text. The reasons for WordPiece Tokenization method include the ability to handle complex words, improve generalization and more robust handling of unseen words.

In [None]:
class ShakespeareWordPiece(Dataset):
    def __init__(self, file_path, seq_length=100, split='train'):
        self.seq_length = seq_length
        self.tokenizer = self._init_tokenizer(file_path)

        with open(file_path, 'r', encoding='utf-8') as f:
            self.raw_text = f.read()

        # Split data (80-10-10)
        train_end = int(0.8 * len(self.raw_text))
        val_end = int(0.9 * len(self.raw_text))
        self.text = {
            'train': self.raw_text[:train_end],
            'val': self.raw_text[train_end:val_end],
            'test': self.raw_text[val_end:]
        }[split]

        self.data = self.tokenizer.encode(self.text).ids

    def _init_tokenizer(self, file_path):
        tokenizer_path = 'shakespeare-wordpiece-uncapped.json'  # Renamed to reflect uncapped vocab

        # Define special tokens (treated as mandatory, but not counted against a limit)
        special_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "[PAD]",
            *[f"'{x}" for x in ['s', 't', 'll', 'd', 'm', 're', 've', 'st', 'Tis']],
            "\n", "\t", "  ", "   ", ":", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0"
        ]

        if os.path.exists(tokenizer_path):
            tokenizer = Tokenizer.from_file(tokenizer_path)
        else:
            tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))

            tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
                pre_tokenizers.Split(r'(?<=\n)|(?=\n)', behavior='isolated'),
                pre_tokenizers.Split(r'(?<=\t)|(?=\t)', behavior='isolated'),
                pre_tokenizers.Split(r' +', behavior='isolated'),
                pre_tokenizers.Split(r"('(?:[sstdm]|ll|re|ve|st|Tis)\b)", behavior='isolated')
            ])

            # Trainer with NO vocab_size limit (learn all subwords)
            trainer = trainers.WordPieceTrainer(
                continuing_subword_prefix="##",
                special_tokens=special_tokens,  # Special tokens are added but don't cap the total
                show_progress=True,
                min_frequency=2  # Optional: Skip very rare tokens
            )

            tokenizer.train([file_path], trainer=trainer)
            tokenizer.save(tokenizer_path)

        # Report actual vocabulary size (now uncapped)
        vocab_size = len(tokenizer.get_vocab())
        print(f'Vocabulary size: {vocab_size} (Uncapped, learned from data)')

        # Post-processing (unchanged)
        cls_id = tokenizer.token_to_id("[CLS]")
        sep_id = tokenizer.token_to_id("[SEP]")
        if cls_id is not None and sep_id is not None:
            tokenizer.post_processor = processors.TemplateProcessing(
                single="[CLS] $A [SEP]",
                pair="[CLS] $A [SEP] $B [SEP]",
                special_tokens=[
                    ("[CLS]", cls_id),
                    ("[SEP]", sep_id),
                ]
            )

        return tokenizer

    def decode(self, token_ids):
        tokens = [self.tokenizer.id_to_token(id) for id in token_ids]
        text = ""
        for i, token in enumerate(tokens):
            if token in ["[CLS]", "[SEP]", "[PAD]"]:
                continue
            elif token.startswith("##"):
                text += token[2:]
            elif token in ["\n", "\t", "  ", "   "]:
                text += token
            elif (
                len(text) > 0
                and not text.endswith((" ", "\n", "\t"))
                and not token.startswith("'")
            ):
                text += " " + token
            else:
                text += token

        # Post-processing for contractions and punctuation spacing
        #text = text.replace(" '", "'")
        #text = text.replace(" :", ":").replace(" ,", ",").replace(" .", ".")
        #text = text.replace(" ?", "?").replace(" !", "!")
        #text = re.sub(r"\b([A-Za-z]+) '([a-z]{1,4})\b", r"\1'\2", text)
            # Post-processing fixes
        text = re.sub(r" ([.,;:!?])", r"\1", text)           # space before punctuation
        #text = re.sub(r"\b' ([a-zA-Z]+)", r"'\1", text)       # fix ' Tis → 'Tis
        text = re.sub(r"'\s?([a-zA-Z]+)", r"'\1", text)
        text = text.replace("[UNK]", "<??>")                 # optional: make unknowns visible
        text = re.sub(r" +", " ", text)                      # remove excessive spaces
        return text.strip()


    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        inputs = self.data[idx:idx + self.seq_length]
        targets = self.data[idx + 1:idx + self.seq_length + 1]
        return {
            'input': torch.tensor(inputs, dtype=torch.long),
            'target': torch.tensor(targets, dtype=torch.long)
        }

    def test_tokenization(self, sample_text):
        encoded = self.tokenizer.encode(sample_text)
        decoded = self.decode(encoded.ids)

        print("Original:")
        print(repr(sample_text))
        print("\nTokens:")
        for token, id in zip(encoded.tokens, encoded.ids):
            print(f"{repr(token):<15} (ID: {id})")
        print("\nDecoded:")
        print(repr(decoded))
        print("\nMatch:", sample_text == decoded)
        return sample_text == decoded


#### Test WordPiecee tokenisation

In [None]:
dataset = ShakespeareWordPiece("shakespeare.txt")

# Contraction test
sample1 = "I'll die,\nAnd if thou wilt, thou may'st be merciful."
print("=== Contraction Test ===")
dataset.test_tokenization(sample1)

# Whitespace test
sample2 = "  ACT I\n\tScene 1:\n  Enter ROMEO\n\nROMEO:\n  'Tis morning!"
print("\n=== Whitespace Test ===")
dataset.test_tokenization(sample2)

# Complex test
sample3 = "' are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you'"
print("\n=== Complex Test ===")
dataset.test_tokenization(sample3)

Vocabulary size: 30000 (Uncapped, learned from data)
=== Contraction Test ===
Original:
"I'll die,\nAnd if thou wilt, thou may'st be merciful."

Tokens:
'[CLS]'         (ID: 1)
'I'             (ID: 46)
"'ll"           (ID: 6)
' '             (ID: 28)
'##di'          (ID: 336)
'##e'           (ID: 95)
'##,'           (ID: 100)
'\n'            (ID: 13)
'And if thou'   (ID: 21535)
'## w'          (ID: 185)
'##ilt, '       (ID: 12594)
'##thou may'    (ID: 22168)
"'st"           (ID: 11)
' '             (ID: 28)
'##be mer'      (ID: 22301)
'##ci'          (ID: 477)
'##ful'         (ID: 616)
'##.'           (ID: 116)
'[SEP]'         (ID: 2)

Decoded:
"I'll die,\nAnd if thou wilt, thou may'st be merciful."

Match: True

=== Whitespace Test ===
Original:
"  ACT I\n\tScene 1:\n  Enter ROMEO\n\nROMEO:\n  'Tis morning!"

Tokens:
'[CLS]'         (ID: 1)
'  '            (ID: 15)
'A'             (ID: 38)
'##C'           (ID: 130)
'##T'           (ID: 90)
'## I'          (ID: 231)
'\n'            (ID

False

#### Observation:

Similar to BPE, there still remains discrepancy between the original and decoded output despite multiple attempts at refining the tokenisation process. However, this process gave a match in contraction test, which was a really pleasant surprise! This difference in performance between BPE and WordPiece can be attributed to the following factors:
1. Subword Segmentation: WordPiece is designed to handle subword segmentation, which means it can break down rare words or contractions into smaller meaningful parts. For example, "shakespearean" might get broken down into something like ['shakes', '##peare', '##an'], which is particularly effective for languages with many inflected forms or compound words.
    
    BPE also does subword segmentation but with a more general approach that may not capture the intricacies of specific languages or contractions as effectively. BPE tries to merge frequently occurring character pairs, but it may not be as good at handling rare tokens or handling contractions and punctuation efficiently.
2. Handling Rare Words and Vocabulary: WordPiece often works better when there is a high variance in the vocabulary (e.g., rare or novel words, contractions), which is common in datasets like Shakespeare’s works. WordPiece builds a more structured vocabulary where the most common subwords are preserved, while less frequent subwords (or characters) are split further. This can help in dealing with rare words and contractions, which is why the contraction test worked well.

    BPE on the other hand, may not always split words as meaningfully, especially if it is trained on a dataset that does not include the specific vocabulary or contractions it is testing against. BPE simply merges frequent character pairs into tokens, which might not capture contractions and punctuation as cleanly.
3. Dealing with Contractions: WordPiece can be more accurate in handling contractions like 'Tis or 's, since it learns them explicitly during training. If these contractions are treated as single tokens or split into smaller meaningful units, it can preserve the integrity of these words better than BPE, which might not handle them as well during subword merging.

    BPE may also split contractions incorrectly if the tokenization rules are not specific or precise enough (especially for edge cases), which could lead to issues in the decoded text.

Considering the better performance of WordPiece Tokenisation method, it will be the tokenisation choice for the subsequent segments.

#### Custom vocabulary reduction enhancement to WordPiece

**Rationale for implementation:**
1. Increase Efficiency: Reducing the vocabulary size helps to minimize memory usage, model size, and computation during training and inference. This can make training faster and more efficient.
2. Remove Unnecessary Tokens: Some tokens may have very low frequency in the training corpus and may not contribute much to the model’s learning. By removing these rare tokens, the model can focus on more frequent, meaningful tokens.
3. Improve Generalization: By limiting the vocabulary, overfitting to rare or outlier tokens can be reduced, which could lead to better generalization across different datasets or real-world inputs.
4. Prevent OOV (Out-Of-Vocabulary) Issues: When dealing with very large vocabularies, a model may frequently encounter tokens that are rare and outside the training distribution. Pruning the vocabulary to remove very rare tokens can reduce this problem by focusing on a more robust set of tokens.

In [None]:
class ShakespeareWordPieceVocab(Dataset):
    def __init__(self, file_path: str, seq_length: int = 100, split: str = 'train', vocab_size = None):
        """
        Args:
            file_path: Path to Shakespeare text file
            seq_length: Length of sequences to generate
            split: One of ['train', 'val', 'test']
            vocab_size: Desired vocabulary size (including special tokens)
        """
        self.seq_length = seq_length
        self.vocab_size = vocab_size

        # First read the file to get text statistics
        with open(file_path, 'r', encoding='utf-8') as f:
            self.raw_text = f.read()

        # Initialize tokenizer (with proper text statistics)
        self.tokenizer = self._init_tokenizer(file_path, self.raw_text)

        # Split data (80-10-10)
        train_end = int(0.8 * len(self.raw_text))
        val_end = int(0.9 * len(self.raw_text))
        self.text = {
            'train': self.raw_text[:train_end],
            'val': self.raw_text[train_end:val_end],
            'test': self.raw_text[val_end:],
        }[split]

        # Tokenize the selected split
        self.data = self.tokenizer.encode(self.text).ids

        @property
        def vocab_size(self):
            return len(self.tokenizer.get_vocab())

    def _init_tokenizer(self, file_path: str, text: str) -> Tokenizer:
        """Initialize or load a WordPiece tokenizer with controlled vocabulary size."""
        tokenizer_path = f'shakespeare-wp-{self.vocab_size}.json'

        # Define special tokens (must include at least [UNK], [CLS], [SEP], [PAD])
        special_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]",
            *[f"'{x}" for x in ['s', 't', 'll', 'd', 'm', 're', 've', 'st', 'Tis']],
            "\n", "\t", "  ", "   ", ":", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0"
        ]

        if os.path.exists(tokenizer_path):
            tokenizer = Tokenizer.from_file(tokenizer_path)
        else:
            # Initialize new tokenizer
            tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))

            # Custom pre-tokenization for Shakespeare text
            tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Split(r"('(?:[sstdm]|ll|re|ve|st|Tis)\b)", behavior='isolated'),
                pre_tokenizers.Split(r'([\n\t])', behavior='isolated')
            ])

            # Calculate how many slots we have for regular tokens
            num_regular_tokens = self.vocab_size - len(special_tokens)

            # Train with frequency-aware sampling
            trainer = trainers.WordPieceTrainer(
                vocab_size=num_regular_tokens,
                continuing_subword_prefix="##",
                show_progress=True,
                special_tokens=special_tokens,
                min_frequency=2  # Ignore very rare tokens
            )

            # Train on the text
            tokenizer.train_from_iterator(
                [text],
                trainer=trainer,
                length=len(text)
            )

            # Verify vocabulary size
            actual_vocab_size = len(tokenizer.get_vocab())
            if actual_vocab_size > self.vocab_size:
                print(f"Warning: Actual vocab size {actual_vocab_size} exceeds target {self.vocab_size}")

            tokenizer.save(tokenizer_path)

        print(f"Vocabulary size: {len(tokenizer.get_vocab())} (Target: {self.vocab_size})")

        # Configure post-processing
        tokenizer.post_processor = processors.TemplateProcessing(
            single="[CLS] $A [SEP]",
            pair="[CLS] $A [SEP] $B [SEP]",
            special_tokens=[
                ("[CLS]", tokenizer.token_to_id("[CLS]")),
                ("[SEP]", tokenizer.token_to_id("[SEP]")),
            ]
        )

        return tokenizer

    def decode(self, token_ids: Union[List[int], torch.Tensor]) -> str:
        """Convert token IDs back to text with proper formatting."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        tokens = []
        for id in token_ids:
            token = self.tokenizer.id_to_token(id)
            if token is None:
                token = "[UNK]"
            tokens.append(token)

        # Reconstruct text with proper spacing
        text = ""
        for i, token in enumerate(tokens):
            if token in ["[CLS]", "[SEP]", "[PAD]"]:
                continue

            # Handle subword tokens
            if token.startswith("##"):
                text += token[2:]
            # Handle special whitespace tokens
            elif token in ["\n", "\t"]:
                text += token
            # Handle contractions and punctuation
            elif token.startswith("'") or token in {',', '.', '!', '?', ';', ':'}:
                text += token
            # Normal case - add space if needed
            elif i > 0 and not tokens[i-1] in ["\n", "\t", " "]:
                text += " " + token
            else:
                text += token

        # Post-processing cleanup
        text = re.sub(r'\s+([.,!?;:])', r'\1', text)  # Fix punctuation spacing
        text = re.sub(r' +', ' ', text).strip()       # Collapse multiple spaces
        return text

    def __len__(self) -> int:
        return len(self.data) - self.seq_length

    def __getitem__(self, idx: Union[int, slice]) -> Dict[str, torch.Tensor]:
        """Fixed to handle both integer and slice indexing"""
        if isinstance(idx, slice):
            # Handle slice case
            start = idx.start if idx.start is not None else 0
            stop = idx.stop if idx.stop is not None else len(self)
            step = idx.step if idx.step is not None else 1

            inputs = []
            targets = []
            for i in range(start, stop, step):
                seq_input = self.data[i:i + self.seq_length]
                seq_target = self.data[i + 1:i + self.seq_length + 1]
                inputs.append(seq_input)
                targets.append(seq_target)

            return {
                'input': torch.tensor(inputs, dtype=torch.long),
                'target': torch.tensor(targets, dtype=torch.long)
            }
        else:
            # Handle integer index case
            inputs = self.data[idx:idx + self.seq_length]
            targets = self.data[idx + 1:idx + self.seq_length + 1]
            return {
                'input': torch.tensor(inputs, dtype=torch.long),
                'target': torch.tensor(targets, dtype=torch.long)
            }

    def test_tokenization(self, sample_text: str) -> bool:
        """Test round-trip tokenization for debugging."""
        encoded = self.tokenizer.encode(sample_text)
        decoded = self.decode(encoded.ids)

        print("[Original]:", repr(sample_text))
        print("\n[Tokens]:")
        for token, id in zip(encoded.tokens, encoded.ids):
            print(f"{repr(token):<20} (ID: {id})")
        print("\n[Decoded]:", repr(decoded))
        print("\n[Match]:", sample_text == decoded)
        return sample_text == decoded

# Initialize with custom vocabulary size
dataset = ShakespeareWordPieceVocab(
    file_path='/content/tiny_shakespeare.txt',
    vocab_size=15000,
    seq_length=128
)

# Contraction test
sample1 = "I'll die,\nAnd if thou wilt, thou may'st be merciful."
print("=== Contraction Test ===")
dataset.test_tokenization(sample1)

# Whitespace test
sample2 = "  ACT I\n\tScene 1:\n  Enter ROMEO\n\nROMEO:\n  'Tis morning!"
print("\n=== Whitespace Test ===")
dataset.test_tokenization(sample2)

# Complex test
sample3 = "' are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you'"
print("\n=== Complex Test ===")
dataset.test_tokenization(sample3)

# Get a batch
batch = dataset[0:5]  # Gets first 5 sequences
print("Input shape:", batch['input'].shape)
print("Target shape:", batch['target'].shape)

Vocabulary size: 14971 (Target: 15000)
=== Contraction Test ===
[Original]: "I'll die,\nAnd if thou wilt, thou may'st be merciful."

[Tokens]:
'[CLS]'              (ID: 1)
'I'                  (ID: 46)
"'ll"                (ID: 7)
'die,'               (ID: 2627)
'\n'                 (ID: 14)
'And'                (ID: 207)
'if'                 (ID: 354)
'thou'               (ID: 220)
'wilt,'              (ID: 9165)
'thou'               (ID: 220)
'may'                (ID: 414)
"'st"                (ID: 12)
'be'                 (ID: 179)
'merci'              (ID: 8508)
'##ful.'             (ID: 12639)
'[SEP]'              (ID: 2)

[Decoded]: "I'll die,\nAnd if thou wilt, thou may'st be merciful."

[Match]: True

=== Whitespace Test ===
[Original]: "  ACT I\n\tScene 1:\n  Enter ROMEO\n\nROMEO:\n  'Tis morning!"

[Tokens]:
'[CLS]'              (ID: 1)
'  '                 (ID: 16)
'A'                  (ID: 38)
'##C'                (ID: 137)
'##T'                (ID: 131)
'I'                

#### **Observation:**

When vocabulary reduction was implemented, the complex test resulted in a match, unexpectedly, an improvement from before! Logically speaking, by not restricting the vocabulary size, it should enable the tokenizer to better adapt to the specific characters and word constructions found in the data (like contractions and specific punctuation). This means that rare or unusual contractions should be treated more accurately as tokens and are not split further, which should better retain the proper spacing and punctuation.

Instead, the reverse was observed (at least based on the above examples), which can be contributed by the following factors:
1. Generalization vs. Overfitting:
In the uncapped version, the large vocabulary size allowed the tokenizer to split words into many small pieces (subwords), sometimes even for common words or contractions. This is typically more flexible, as it can handle previously unseen words. However, when this flexibility is pushed too far, the tokenizer ends up creating subwords that might not reflect the natural linguistic structure of the data (like splitting contractions into pieces such as don##' t or I##' m).

    On the other hand, limiting the vocabulary forces the tokenizer to generalize better. With a smaller set of tokens, the model avoids overfitting on specific subword patterns and instead learns to combine subwords in ways that are more natural and coherent in a wider context. This leads to more consistent tokenization of common words, contractions, and punctuation marks, improving the quality of the output during decoding.

2. Handling of Contractions and Punctuation:
Contractions like don't, I'm, you're, etc., are a special challenge for tokenizers because they contain apostrophes or other punctuation marks. In the uncapped vocabulary version, the tokenizer might treat the apostrophe (') as a separate entity, leading to splits like don ##' t or I ##' m. This can break the natural structure of the word and cause the decoding process to misalign the apostrophe or create unwanted spaces, making the output look unnatural.

    With vocab reduction, the tokenizer likely learned to treat the contractions more holistically, recognizing don't as one or two units (don and 't or as a whole depending on frequency), keeping the punctuation intact. This avoids issues where the apostrophe gets split into separate tokens or wrongly attached to surrounding characters.

3. Impact of Special Tokens and Frequent Patterns:
In the uncapped version, with a larger vocabulary, the tokenizer might have been overly focused on learning rare tokens or creating subword fragments for every character combination it encountered. As a result, common special tokens like [UNK], [PAD], and punctuation marks (such as ,, ., !, etc.) might not have been learned as effectively or might have been given less importance.

    When the vocabulary is reduced, the tokenizer is forced to focus on the most frequent patterns and words. As a result, it is likely to learn the most important subwords that make up common contractions, words, and punctuation. This ensures that it retains the structure of the text more effectively, preserving meaningful relationships between tokens (like apostrophes) rather than splitting them in unnatural ways.

4. Balance Between Flexibility and Robustness:
By reducing the vocabulary size, the tokenizer strikes a balance between flexibility (the ability to learn unseen words) and robustness (the ability to process known patterns correctly). It avoids the over-flexibility that leads to excessive token splitting, which is common in an uncapped vocabulary.

    A smaller vocabulary pushes the tokenizer to learn more abstract, larger subwords, which can represent entire words or parts of words (like contractions) without excessive fragmentation. This leads to better reassembly of the text during decoding, improving the final output quality.

Thus, instances from the ShakespeareWordPieceVocab class will be used for subsequent analyses, specifically with a vocabulary reduction size of 15000.

#### Sequence length variation
Experimenting with sequence length variation is meaningful in the context of training models, especially for tasks like text generation or sequence prediction, because it directly influences how the model processes and learns from the data. Sequence length determines how much context the model can capture at each time step and affects the model's ability to learn temporal or sequential dependencies is highly dependent on the sequence length. Thus, in this section, the aim is to get an idea of how the segmented words may look like, and decide what sequence length may be ideal for the project.

In [None]:
class ShakespeareWordPieceSegment(Dataset):
    def __init__(self, file_path, seq_length=100, split='train'):
        self.seq_length = seq_length
        self.tokenizer = self._init_tokenizer(file_path)

        with open(file_path, 'r', encoding='utf-8') as f:
            self.raw_text = f.read()

        # Split data (80-10-10)
        train_end = int(0.8 * len(self.raw_text))
        val_end = int(0.9 * len(self.raw_text))
        self.text = {
            'train': self.raw_text[:train_end],
            'val': self.raw_text[train_end:val_end],
            'test': self.raw_text[val_end:],
        }[split]

        self.data = self.tokenizer.encode(self.text).ids

    def _init_tokenizer(self, file_path):
        tokenizer_path = 'shakespeare-wordpiece-uncapped.json'  # Renamed to reflect uncapped vocab

        # Define special tokens (treated as mandatory, but not counted against a limit)
        special_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "[PAD]",
            *[f"'{x}" for x in ['s', 't', 'll', 'd', 'm', 're', 've', 'st', 'Tis']],
            "\n", "\t", "  ", "   ", ":", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0"
        ]

        if os.path.exists(tokenizer_path):
            tokenizer = Tokenizer.from_file(tokenizer_path)
        else:
            tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))
            tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
                pre_tokenizers.Split(r'(?<=\n)|(?=\n)', behavior='isolated'),
                pre_tokenizers.Split(r'(?<=\t)|(?=\t)', behavior='isolated'),
                pre_tokenizers.Split(r' +', behavior='isolated'),
                pre_tokenizers.Split(r"('(?:[sstdm]|ll|re|ve|st|Tis)\b)", behavior='isolated')
            ])
            trainer = trainers.WordPieceTrainer(
                continuing_subword_prefix="##",
                special_tokens=special_tokens,
                show_progress=True,
                min_frequency=2
            )
            tokenizer.train([file_path], trainer=trainer)
            tokenizer.save(tokenizer_path)

        # Report actual vocabulary size (now uncapped)
        vocab_size = len(tokenizer.get_vocab())
        print(f'Vocabulary size: {vocab_size} (Uncapped, learned from data)')

        # Post-processing (unchanged)
        cls_id = tokenizer.token_to_id("[CLS]")
        sep_id = tokenizer.token_to_id("[SEP]")
        if cls_id is not None and sep_id is not None:
            tokenizer.post_processor = processors.TemplateProcessing(
                single="[CLS] $A [SEP]",
                pair="[CLS] $A [SEP] $B [SEP]",
                special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)]
            )

        return tokenizer

    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        # Use seq_length for variable sequences
        inputs = self.data[idx:idx + self.seq_length]
        targets = self.data[idx + 1:idx + self.seq_length + 1]
        return {
            'input': torch.tensor(inputs, dtype=torch.long),
            'target': torch.tensor(targets, dtype=torch.long)
        }

    def print_varied_length_sequences(self):
        """
        Print sequences at three distinct length levels to showcase segmentation
        - Short (50-70 tokens): Shows basic word-level segmentation
        - Medium (100-120 tokens): Shows phrase-level patterns
        - Long (150-200 tokens): Shows paragraph-level structure
        """
        length_levels = [
            ("SHORT", random.randint(50, 70)),
            ("MEDIUM", random.randint(100, 120)),
            ("LONG", random.randint(150, 200))
        ]

        for level_name, seq_len in length_levels:
            max_start = len(self.data) - seq_len - 1
            if max_start <= 0:
                print(f"Skip {level_name} - dataset too small")
                continue

            start_idx = random.randint(0, max_start)
            input_seq = self.data[start_idx:start_idx + seq_len]
            decoded_seq = self.decode(input_seq)

            print(f"\n[{level_name}] Sequence Length: {seq_len}")
            print("=" * 80)
            print(decoded_seq)
            print("=" * 80)
            print("\nToken Breakdown:")
            print("-" * 40)
            for i, token_id in enumerate(input_seq):
                token = self.tokenizer.id_to_token(token_id)
                print(f"{i:2d}: {token_id:4d} -> {token}")
            print(f"{seq_len} tokens\n")

    def decode(self, token_ids):
        tokens = [self.tokenizer.id_to_token(id) for id in token_ids]
        text = ""
        for i, token in enumerate(tokens):
            if token in ["[CLS]", "[SEP]", "[PAD]"]:
                continue
            elif token.startswith("##"):
                text += token[2:]
            elif token in ["\n", "\t", "  ", "   "]:
                text += token
            elif (
                len(text) > 0
                and not text.endswith((" ", "\n", "\t"))
                and not token.startswith("'")
            ):
                text += " " + token
            else:
                text += token
        text = re.sub(r" ([.,;:!?])", r"\1", text)  # Remove space before punctuation
        text = re.sub(r" +", " ", text)  # Remove excessive spaces
        return text.strip()

dataset = ShakespeareWordPieceSegment(file_path='shakespeare.txt', split='train')
dataset.print_varied_length_sequences()

Vocabulary size: 30000 (Uncapped, learned from data)

[SHORT] Sequence Length: 58
wolves.
Had I been there, which am a silly woman,
The soldiers should have toss'd me on their pikes
Before I would have granted to that act.
But thou preferr'st thy life before thine honour:
And seeing thou dost, I here divorce myself
Both from thy

Token Breakdown:
----------------------------------------
 0: 7287 -> ##wol
 1: 5114 -> ##ves
 2:  116 -> ##.
 3:   13 -> 

 4: 9815 -> Had I
 5: 9507 -> ## been 
 6: 6851 -> ##there, 
 7: 3517 -> ##which 
 8: 4508 -> ##am 
 9: 1290 -> ##a s
10: 9244 -> ##illy 
11: 2123 -> ##woman
12:  100 -> ##,
13:   13 -> 

14: 15832 -> The s
15: 1573 -> ##oldi
16: 1034 -> ##ers 
17: 4790 -> ##should have 
18: 16174 -> ##toss
19:    7 -> 'd
20:   28 ->  
21:  273 -> ##me 
22: 7365 -> ##on their
23: 18705 -> ## pik
24:  177 -> ##es
25:   13 -> 

26: 2749 -> Before 
27: 3068 -> ##I would 
28:  259 -> ##have 
29: 9965 -> ##granted 
30: 15194 -> ##to that 
31: 6973 -> ##act
32:

#### Insight

The choice of sequence length is a critical factor in model performance. For Shakespeare's text, which contains rich poetic and narrative structures, it is important to strike a balance between short and long sequences. Short sequences (50-70 tokens) may help the model learn more granular relationships at the word level, but they might not capture enough context for understanding longer sentence structures or dialogues. On the other hand, longer sequences (150-200 tokens) allow the model to capture more comprehensive context, which is crucial for understanding the flow of sentences and paragraphs. However, longer sequences also increase memory usage and training time, which may pose challenges for large datasets and computational constraints.

Based on the nature of the text, a sequence length of 128 will be used to strike a balance between capturing sufficient context and keeping the model's complexity manageable.

### **Architecture**
The ShakespeareRNNBase architecture is designed for sequence generation tasks, such as text generation, and incorporates several key components to handle the complexities of modeling language.

#### Embedding Layer:
The architecture starts with an embedding layer that converts each token in the vocabulary into a dense, continuous vector. These embeddings are learned during training and allow the model to represent words in a lower-dimensional space. This helps the model capture semantic relationships between words, where similar words are represented by vectors that are close together in this space. The embedding layer is crucial because it transforms sparse, high-dimensional representations (like one-hot encodings) into more efficient and meaningful representations.

#### Recurrent Neural Network (RNN):
At the core of the architecture is a Recurrent Neural Network (RNN), which processes the sequence of word embeddings. RNNs are specifically designed to handle sequential data, making them suitable for tasks like text generation where the model needs to consider previous words (or tokens) to predict the next. The RNN layer can be an LSTM (Long Short-Term Memory) or a GRU (Gated Recurrent Unit), which are both advanced types of RNNs designed to address the issue of vanishing gradients, making them more effective at capturing long-range dependencies in sequences.

The RNN is designed to capture temporal dependencies in the data, processing each token in sequence while maintaining a hidden state that summarizes the context of the sequence so far. This hidden state is updated at each time step as the model processes new tokens, allowing it to "remember" previous context and make more informed predictions. The architecture can also include multiple layers of RNNs, which helps the model capture more complex patterns and representations at different levels of abstraction.

#### Bidirectionality:
A significant aspect of the architecture is the option to make the RNN bidirectional. In a bidirectional RNN, the sequence is processed in both the forward and backward directions. This means that the model can access information from both past and future context at each time step, providing a richer understanding of the input sequence. However, while bidirectional RNNs are powerful for certain tasks like text classification, they are not always suitable for autoregressive tasks, where each token should only depend on previous tokens in the sequence. In such cases, bidirectional RNNs can introduce information leakage from future tokens, which can interfere with the autoregressive nature of the model. This is why, in some configurations, bidirectional processing may be deactivated.

#### Output Layer:
The output of the RNN is then passed through a fully connected (linear) layer, which maps the hidden states to a vocabulary-sized output. This layer generates a probability distribution over the vocabulary for each token in the sequence, allowing the model to predict the next word or token. The final output is a sequence of predicted token probabilities, which can be used to generate text or perform other sequence-based tasks.

#### Dropout:
To prevent overfitting, the model includes dropout layers. Dropout is a regularization technique where a random subset of the neurons is "dropped" (set to zero) during training. This forces the model to not rely too heavily on specific neurons, improving its generalization ability. Dropout is applied both in the embedding layer and between RNN layers, ensuring that the model learns more robust features.

#### Hidden State Initialization:
The model also defines an initialization function for the RNN’s hidden state. Before processing each sequence, the hidden state is initialized to zeros (or to zeros and ones for LSTM). This ensures that the model starts with no prior information about the sequence and is ready to process new data. For sequences with multiple time steps, maintaining and updating this hidden state is crucial for the RNN’s ability to remember past context.

#### Autoregressive Forward Pass:
For text generation tasks, the model implements an autoregressive forward pass, which generates one token at a time based on previous predictions. Initially, the first token in the sequence is given as input, and subsequent tokens are generated based on the model's own predictions. The model can optionally use teacher forcing, where the true previous token is fed as input during training rather than the model's prediction. Teacher forcing can help the model learn faster and avoid compounding errors, though it is typically turned off during inference to allow the model to generate text on its own. In the case of bidirectional RNNs, teacher forcing must be disabled because the model uses both past and future context, violating the autoregressive principle.

In [5]:
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class ShakespeareRNNBase(nn.Module):
    def __init__(self, wp_vocab_size, embed_dim=128, hidden_dim=256, num_layers=2,
                 rnn_type='lstm', bidirectional=False, dropout=0.2):
        super().__init__()
        self.wp_vocab_size = wp_vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type.lower()

        # Embedding layer
        self.wp_embed = nn.Embedding(wp_vocab_size, embed_dim, padding_idx=0)

        # RNN layer
        self.rnn = getattr(nn, rnn_type.upper())(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        # Output layer
        self.fc = nn.Linear(hidden_dim * (2 if bidirectional else 1), wp_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def init_hidden(self, batch_size):
        """Initialize hidden state for the RNN"""
        device = next(self.parameters()).device
        num_directions = 2 if self.bidirectional else 1
        shape = (self.num_layers * num_directions, batch_size, self.hidden_dim)

        if self.rnn_type == 'lstm':
            return (torch.zeros(shape, device=device),
                    torch.zeros(shape, device=device))
        return torch.zeros(shape, device=device)

    def forward(self, x, hidden=None):
        """
        Standard forward pass (batch processing).
        Args:
            x: input tensor [batch_size, seq_len]
            hidden: previous hidden state
        """
        embedded = self.dropout(self.wp_embed(x))  # [batch_size, seq_len, embed_dim]
        rnn_out, hidden = self.rnn(embedded, hidden)
        outputs = self.fc(self.dropout(rnn_out))  # [batch_size, seq_len, vocab_size]
        return outputs, hidden

    def forward_autoregressive(self, x, hidden=None, targets=None, tf_ratio=0.5):
        """
        Autoregressive forward pass with optional teacher forcing.
        Args:
            x: input tensor [batch_size, seq_len]
            hidden: previous hidden state
            targets: target tensor [batch_size, seq_len] (for teacher forcing)
            tf_ratio: teacher forcing ratio (0.0-1.0)
        """
        # Fall back to standard forward pass if bidirectional
        if self.bidirectional:
            return self.forward(x, hidden)

        if targets is not None and x.size(1) != targets.size(1):
            raise ValueError("Input and targets must have the same sequence length")

        batch_size = x.size(0)
        seq_len = x.size(1)
        outputs = []

        if hidden is None:
            hidden = self.init_hidden(batch_size)

        # Process first token
        input_step = x[:, 0:1]  # [batch_size, 1]

        for t in range(seq_len):
            embedded = self.dropout(self.wp_embed(input_step))  # [batch_size, 1, embed_dim]
            rnn_out, hidden = self.rnn(embedded, hidden)
            output = self.fc(self.dropout(rnn_out))  # [batch_size, 1, vocab_size]
            outputs.append(output)

            # Decide next input (if not last step)
            if t < seq_len - 1:
                if targets is not None and random.random() < tf_ratio:
                    input_step = targets[:, t+1:t+2]  # Teacher forcing
                else:
                    input_step = output.argmax(-1)  # Model prediction

        outputs = torch.cat(outputs, dim=1)  # [batch_size, seq_len, vocab_size]
        return outputs, hidden

class ShakespeareRNN(ShakespeareRNNBase):
    def __init__(self, wp_vocab_size, embed_dim=128, hidden_dim=256, num_layers=2,
                 bidirectional=False, dropout=0.2):
        super().__init__(wp_vocab_size, embed_dim, hidden_dim, num_layers,
                        rnn_type='rnn', bidirectional=bidirectional, dropout=dropout)

class ShakespeareLSTM(ShakespeareRNNBase):
    def __init__(self, wp_vocab_size, embed_dim=128, hidden_dim=256, num_layers=2,
                 bidirectional=False, dropout=0.2):
        super().__init__(wp_vocab_size, embed_dim, hidden_dim, num_layers,
                        rnn_type='lstm', bidirectional=bidirectional, dropout=dropout)

class ShakespeareGRU(ShakespeareRNNBase):
    def __init__(self, wp_vocab_size, embed_dim=128, hidden_dim=256, num_layers=2,
                 bidirectional=False, dropout=0.2):
        super().__init__(wp_vocab_size, embed_dim, hidden_dim, num_layers,
                        rnn_type='gru', bidirectional=bidirectional, dropout=dropout)

#### Training loop
The training loop implements sequence prediction with teacher forcing, which is particularly useful for autoregressive tasks like text generation. At each training step, the model is set to training mode, and the loss and token counters are initialized to track performance across the epoch.

For each batch, input and target sequences are moved to the appropriate device (e.g., GPU). The loop uses teacher forcing, where the model sometimes receives the ground truth token as its next input rather than relying solely on its own predictions. This technique helps stabilize training and improve convergence by reducing error accumulation over time. As mentioned above, in the event bidirectional=True, teacher forcing will need to be deactivated.

Loss is computed token by token within the sequence, which allows for fine-grained control and the opportunity to vary inputs dynamically with or without teacher forcing at each timestep. The accumulated loss is then normalized by sequence length to ensure consistency across batches of different sizes.

Gradient accumulation is used to effectively simulate a larger batch size without increasing memory usage, by dividing the loss and performing an optimizer step only after a set number of batches. To prevent exploding gradients, gradient clipping is applied before the optimizer update.

Throughout training, loss is accumulated and normalized by the total number of tokens to give a meaningful measure of performance over the entire epoch.

In [None]:
def calculate_perplexity(loss: torch.Tensor):
    """Calculate perplexity from the loss value."""
    return math.exp(loss.item()) if loss.item() < 100 else float('inf')

def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    teacher_forcing_ratio: float = 0.5,
    max_grad_norm: float = 5.0,
    accumulation_steps: int = 1
):
    """Training epoch with teacher forcing for sequence prediction."""
    model.train()
    total_loss = 0
    total_tokens = 0
    optimizer.zero_grad()

    # Wrap the DataLoader with tqdm to show a progress bar
    for i, batch in enumerate(tqdm(loader, desc="Training Batch", leave=False), 1):
        inputs = batch['input'].to(device)  # [batch_size, seq_len]
        targets = batch['target'].to(device)  # [batch_size, seq_len]

        # Initialize with start token (assuming first token is start)
        decoder_input = inputs[:, 0:1]  # [batch_size, 1]
        hidden = None
        batch_loss = 0

        # Autoregressive generation with teacher forcing
        for t in range(targets.size(1)):
            # Forward pass
            outputs, hidden = model(decoder_input, hidden)

            # Calculate loss for current timestep
            loss = criterion(outputs.view(-1, model.wp_vocab_size),
                           targets[:, t])
            batch_loss += loss

            # Teacher forcing decision
            use_teacher_forcing = random.random() < teacher_forcing_ratio
            if use_teacher_forcing and t < targets.size(1) - 1:
                decoder_input = targets[:, t:t+1]  # Teacher forcing
            else:
                decoder_input = outputs.argmax(-1).detach()  # Model prediction

        # Normalise loss by sequence length
        batch_loss = batch_loss / targets.size(1)

        # Backward pass
        (batch_loss / accumulation_steps).backward()

        if i % accumulation_steps == 0:
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

        total_loss += batch_loss.item() * targets.size(1)
        total_tokens += inputs.size(0) * targets.size(1)

    return total_loss / total_tokens

def evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
):
    """Enhanced evaluation with metrics tracking."""
    model.eval()
    total_loss = 0
    total_tokens = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)

            outputs, _ = model(inputs)
            loss = criterion(outputs.view(-1, model.wp_vocab_size),
                           targets.view(-1))
            total_loss += loss.item() * targets.numel()
            total_tokens += targets.numel()

    avg_loss = total_loss / total_tokens
    perplexity = calculate_perplexity(torch.tensor(avg_loss))
    return avg_loss, perplexity

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 10,
    lr: float = 0.001,
    patience: int = 3,
    min_delta: float = 0.01,
    model_name: str = 'model',
    teacher_forcing_ratio: float = 0.5,
    max_grad_norm: float = 5.0,
    accumulation_steps: int = 1
):
    """Enhanced training with early stopping and learning rate scheduling."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Initialize optimizers
    optimizers = {
        'Adam': optim.Adam(model.parameters(), lr=lr),
        'RMSprop': optim.RMSprop(model.parameters(), lr=lr)
    }

    # Initialize schedulers
    schedulers = {
        name: optim.lr_scheduler.ReduceLROnPlateau(
            opt, mode='min', factor=0.5, patience=2
        )
        for name, opt in optimizers.items()
    }

    criterion = nn.CrossEntropyLoss(ignore_index=0)

    best_model = None
    best_perplexity = float('inf')
    best_optimizer = None
    no_improve = 0
    history = {
        'train_loss': [],
        'val_loss': [],
        'perplexity': [],
        'lr': [],
        'optimizer': []
    }

    # Wrap the epoch loop with tqdm for progress tracking
    for epoch in tqdm(range(epochs), desc="Epochs", leave=True):
        epoch_results = {}

        # Train and evaluate with each optimizer
        for opt_name, optimizer in optimizers.items():
            # Training
            print(f'Training with {opt_name} optimizer...')
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device, teacher_forcing_ratio=teacher_forcing_ratio, max_grad_norm=max_grad_norm, accumulation_steps=accumulation_steps)

            # Validation
            val_loss, perplexity = evaluate(model, val_loader, criterion, device)

            # Update learning rate
            schedulers[opt_name].step(val_loss)

            # Store results
            epoch_results[opt_name] = {
                'train_loss': train_loss,
                'val_loss': val_loss,
                'perplexity': perplexity,
                'lr': optimizer.param_groups[0]['lr']
            }

        # Find best optimizer for this epoch
        best_opt_epoch = min(
            epoch_results.keys(),
            key=lambda x: epoch_results[x]['perplexity']
        )

        # Update history
        history['train_loss'].append(epoch_results[best_opt_epoch]['train_loss'])
        history['val_loss'].append(epoch_results[best_opt_epoch]['val_loss'])
        history['perplexity'].append(epoch_results[best_opt_epoch]['perplexity'])
        history['lr'].append(epoch_results[best_opt_epoch]['lr'])
        history['optimizer'].append(best_opt_epoch)

        # Print losses for each optimizer in this epoch
        print(f"Epoch {epoch+1}/{epochs}:")
        for opt_name, results in epoch_results.items():
            print(f"  {opt_name} - "
                  f"Train Loss: {results['train_loss']:.4f} | "
                  f"Val Loss: {results['val_loss']:.4f} | "
                  f"Perplexity: {results['perplexity']:.2f} | "
                  f"LR: {results['lr']:.2e}")

        # Check for improvement
        current_perplexity = epoch_results[best_opt_epoch]['perplexity']
        if current_perplexity < (best_perplexity - min_delta):
            best_perplexity = current_perplexity
            best_model = deepcopy(model.state_dict())
            best_optimizer = best_opt_epoch
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"No improvement for {patience} epochs, stopping early")
                break

    # Load best model weights
    if best_model:
        model.load_state_dict(best_model)
        torch.save(best_model, f"best_{model_name}.pth")
        print(f'Saved best {model_name} weights to best_{model_name}.pth')

    return {
        'model': model,
        'history': history,
        'best_perplexity': best_perplexity,
        'best_optimizer': best_optimizer
    }


# Initialize dataset to creates/loads the tokenizer
print('Train set:')
train_dataset = ShakespeareWordPieceVocab(
    file_path='/content/tiny_shakespeare.txt',
    vocab_size=15000,
    split='train'
)

# Get vocab size from the trained tokenizer
wp_vocab_size = train_dataset.vocab_size

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
print('\nValidation set:')
val_loader = DataLoader(
    ShakespeareWordPieceVocab('/content/tiny_shakespeare.txt', vocab_size=15000, split='val'),
    batch_size=128
)

# --- UNIDIRECTIONAL MODEL TRAINING (GRU) ---
print("\n=== TRAINING GRU MODEL ===")

# Initialize GRU model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gru_model = ShakespeareGRU(wp_vocab_size).to(device)

# Train GRU
gru_results_ft = train_model(
    gru_model,
    train_loader,
    val_loader,
    epochs=2,
    lr=0.001,
    patience=3,
    model_name= 'gru_tf'
)

# Show GRU results
print(f"\nGRU Results with Teacher Forcing:")
print(f"Best Perplexity: {gru_results_ft['best_perplexity']:.2f}")
print(f"Best Optimizer: {gru_results_ft['best_optimizer']}")

Train set:
Vocabulary size: 14971 (Target: 15000)

Validation set:
Vocabulary size: 14971 (Target: 15000)

=== TRAINING GRU MODEL ===


Epochs:   0%|          | 0/2 [00:00<?, ?it/s]

Training with Adam optimizer...



Training Batch:   0%|          | 0/1609 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1609 [00:00<06:53,  3.89it/s][A
Training Batch:   0%|          | 2/1609 [00:00<06:11,  4.32it/s][A
Training Batch:   0%|          | 3/1609 [00:00<05:54,  4.54it/s][A
Training Batch:   0%|          | 4/1609 [00:00<05:45,  4.65it/s][A
Training Batch:   0%|          | 5/1609 [00:01<05:37,  4.75it/s][A
Training Batch:   0%|          | 6/1609 [00:01<05:36,  4.76it/s][A
Training Batch:   0%|          | 7/1609 [00:01<05:33,  4.80it/s][A
Training Batch:   0%|          | 8/1609 [00:01<05:34,  4.79it/s][A
Training Batch:   1%|          | 9/1609 [00:01<05:30,  4.84it/s][A
Training Batch:   1%|          | 10/1609 [00:02<05:30,  4.84it/s][A
Training Batch:   1%|          | 11/1609 [00:02<05:29,  4.85it/s][A
Training Batch:   1%|          | 12/1609 [00:02<05:28,  4.87it/s][A
Training Batch:   1%|          | 13/1609 [00:02<05:33,  4.79it/s][A
Training Batch:   1%|          | 14/1609 [00:03<05:

Training with RMSprop optimizer...



Training Batch:   0%|          | 0/1609 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1609 [00:00<06:51,  3.91it/s][A
Training Batch:   0%|          | 2/1609 [00:00<06:07,  4.38it/s][A
Training Batch:   0%|          | 3/1609 [00:00<05:59,  4.47it/s][A
Training Batch:   0%|          | 4/1609 [00:00<05:49,  4.59it/s][A
Training Batch:   0%|          | 5/1609 [00:01<05:45,  4.64it/s][A
Training Batch:   0%|          | 6/1609 [00:01<05:45,  4.64it/s][A
Training Batch:   0%|          | 7/1609 [00:01<05:42,  4.68it/s][A
Training Batch:   0%|          | 8/1609 [00:01<05:39,  4.71it/s][A
Training Batch:   1%|          | 9/1609 [00:01<05:38,  4.72it/s][A
Training Batch:   1%|          | 10/1609 [00:02<05:37,  4.74it/s][A
Training Batch:   1%|          | 11/1609 [00:02<05:36,  4.75it/s][A
Training Batch:   1%|          | 12/1609 [00:02<05:36,  4.74it/s][A
Training Batch:   1%|          | 13/1609 [00:02<05:37,  4.73it/s][A
Training Batch:   1%|          | 14/1609 [00:03<05:

Epoch 1/2:
  Adam - Train Loss: 0.0495 | Val Loss: 5.9245 | Perplexity: 374.08 | LR: 1.00e-03
  RMSprop - Train Loss: 0.0411 | Val Loss: 6.1165 | Perplexity: 453.28 | LR: 1.00e-03
Training with Adam optimizer...



Training Batch:   0%|          | 0/1609 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1609 [00:00<06:09,  4.35it/s][A
Training Batch:   0%|          | 2/1609 [00:00<05:51,  4.58it/s][A
Training Batch:   0%|          | 3/1609 [00:00<05:48,  4.60it/s][A
Training Batch:   0%|          | 4/1609 [00:00<05:45,  4.65it/s][A
Training Batch:   0%|          | 5/1609 [00:01<05:47,  4.61it/s][A
Training Batch:   0%|          | 6/1609 [00:01<05:44,  4.65it/s][A
Training Batch:   0%|          | 7/1609 [00:01<05:42,  4.67it/s][A
Training Batch:   0%|          | 8/1609 [00:01<05:43,  4.66it/s][A
Training Batch:   1%|          | 9/1609 [00:01<05:41,  4.68it/s][A
Training Batch:   1%|          | 10/1609 [00:02<05:41,  4.68it/s][A
Training Batch:   1%|          | 11/1609 [00:02<05:40,  4.70it/s][A
Training Batch:   1%|          | 12/1609 [00:02<05:39,  4.70it/s][A
Training Batch:   1%|          | 13/1609 [00:02<05:38,  4.71it/s][A
Training Batch:   1%|          | 14/1609 [00:02<05:

Training with RMSprop optimizer...



Training Batch:   0%|          | 0/1609 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1609 [00:00<06:06,  4.38it/s][A
Training Batch:   0%|          | 2/1609 [00:00<05:53,  4.55it/s][A
Training Batch:   0%|          | 3/1609 [00:00<05:52,  4.56it/s][A
Training Batch:   0%|          | 4/1609 [00:00<05:49,  4.59it/s][A
Training Batch:   0%|          | 5/1609 [00:01<05:48,  4.60it/s][A
Training Batch:   0%|          | 6/1609 [00:01<05:44,  4.66it/s][A
Training Batch:   0%|          | 7/1609 [00:01<05:43,  4.67it/s][A
Training Batch:   0%|          | 8/1609 [00:01<05:42,  4.68it/s][A
Training Batch:   1%|          | 9/1609 [00:01<05:41,  4.68it/s][A
Training Batch:   1%|          | 10/1609 [00:02<05:40,  4.70it/s][A
Training Batch:   1%|          | 11/1609 [00:02<06:11,  4.30it/s][A
Training Batch:   1%|          | 12/1609 [00:02<06:26,  4.14it/s][A
Training Batch:   1%|          | 13/1609 [00:02<06:35,  4.04it/s][A
Training Batch:   1%|          | 14/1609 [00:03<06:

Epoch 2/2:
  Adam - Train Loss: 0.0372 | Val Loss: 6.2139 | Perplexity: 499.66 | LR: 1.00e-03
  RMSprop - Train Loss: 0.0351 | Val Loss: 6.3447 | Perplexity: 569.44 | LR: 1.00e-03
Saved best gru_tf weights to best_gru_tf.pth

GRU Results with Teacher Forcing:
Best Perplexity: 374.08
Best Optimizer: Adam


#### Observation & Insights

Based on the GRU model implemented with teacher forcing for Shakespeare text generation, a comparison between the Adam and RMSprop optimisers reveals that Adam consistently yields lower validation loss and perplexity across both epochs. This suggests that Adam may be more effective in navigating the loss landscape for this particular task. Adam combines the advantages of both momentum and adaptive learning rate mechanisms, which may give it an edge over RMSprop especially in scenarios with sparse gradients or varying curvature. This makes it particularly suitable for text generation tasks with complex loss surfaces, as seen here. Consequently, to streamline subsequent model comparisons, Adam will be retained as the primary optimiser.

**Epoch Comparison Insights:**

Interestingly, from epoch 1 to epoch 2, both validation loss and perplexity increased for both optimisers. This counterintuitive behavior may be attributed to several factors:
1. Limited Training Epochs: With only two epochs, early fluctuations in validation metrics are expected. The model is still undergoing rapid weight adjustments and hasn't yet stabilized. It is possible that the optimiser initially found a local minimum that seemed promising (as seen in epoch 1), but subsequent updates caused it to overshoot or explore a less optimal region before converging.


2. Instability During Early Training: GRUs, while capable of capturing sequential patterns, might not fully stabilize within the first few epochs—especially in character-level generation tasks where the sequence dependencies are subtle and span across long ranges. Early instability could therefore manifest as inconsistent validation performance.


3. Ineffective Scheduler Impact: If a learning rate scheduler was implemented, its effects would be negligible within such a short training window. Typically, schedulers require several epochs to adjust the learning rate in response to plateaus or loss trends. Here, with only two epochs, its presence may not yet yield tangible benefits.

Since both validation loss and perplexity with Adam optimiser are lower than with RMSprop, to streamline subsequent training processes, only Adam optimiser will be used for comparison of subsequent models.

*Disclaimer:*
  - Epochs = 2 was used as the above code chunks were rather computationally heavy on the system and resulted in several crashes even before 2 epochs could be completed. Thus, higher epoch numbers were omitted in all the experiments.
  - There is also a notable difference in magnitude between training andvalidation loss. This is likely the result of training and validation processes using fundamentally different processed. Thus, the relative trends across epochs will be used as a basis of comparison rather than the absolute numbers.

#### GRU Model with no teacher forcing

In [22]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = 10,
    lr: float = 0.001,
    patience: int = 3,
    min_delta: float = 0.01,
    model_name: str = 'model',
    teacher_forcing_ratio: float = 0.5,
    max_grad_norm: float = 5.0,
    accumulation_steps: int = 1,
    optimizer_name: str = 'Adam'  # choose between 'Adam' or 'RMSprop' depending on the better performance found above
):
    """
    Streamlined training loop with chosen optimiser based on results above for direct comparison with teacher forcing
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Select single optimizer
    if optimizer_name == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optimizer_name == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(), lr=lr)
    else:
        raise ValueError("Unsupported optimizer. Choose 'Adam' or 'RMSprop'")

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    criterion = nn.CrossEntropyLoss(ignore_index=0)

    best_model = None
    best_perplexity = float('inf')
    no_improve = 0
    history = {
        'train_loss': [],
        'val_loss': [],
        'perplexity': [],
        'lr': [],
    }

    for epoch in tqdm(range(epochs), desc="Epochs", leave=True):
        print(f'Training with {optimizer_name} optimizer...')
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, device,
            teacher_forcing_ratio=teacher_forcing_ratio,
            max_grad_norm=max_grad_norm,
            accumulation_steps=accumulation_steps
        )

        val_loss, perplexity = evaluate(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['perplexity'].append(perplexity)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Perplexity: {perplexity:.2f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

        if perplexity < (best_perplexity - min_delta):
            best_perplexity = perplexity
            best_model = deepcopy(model.state_dict())
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"No improvement for {patience} epochs, stopping early")
                break

    if best_model:
        model.load_state_dict(best_model)
        torch.save(best_model, f"best_{model_name}.pth")
        print(f'Saved best {model_name} weights to best_{model_name}.pth')

    return {
        'model': model,
        'history': history,
        'best_perplexity': best_perplexity,
        'optimizer': optimizer_name
    }

# GRU without teacher forcing
gru_results = train_model(
    gru_model,
    train_loader,
    val_loader,
    epochs=2,
    teacher_forcing_ratio=0,
    optimizer_name='Adam',
    model_name='gru_norm'
)

# Show GRU results
print(f"\nGRU Results with no Teacher Forcing:")
print(f"Best Perplexity: {gru_results['best_perplexity']:.2f}")

Epochs:   0%|          | 0/2 [00:00<?, ?it/s]

Training with Adam optimizer...



Training Batch:   0%|          | 0/1609 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1609 [00:00<07:47,  3.44it/s][A
Training Batch:   0%|          | 2/1609 [00:00<06:30,  4.12it/s][A
Training Batch:   0%|          | 3/1609 [00:00<06:52,  3.90it/s][A
Training Batch:   0%|          | 4/1609 [00:01<06:53,  3.88it/s][A
Training Batch:   0%|          | 5/1609 [00:01<06:53,  3.88it/s][A
Training Batch:   0%|          | 6/1609 [00:01<06:55,  3.86it/s][A
Training Batch:   0%|          | 7/1609 [00:01<06:52,  3.89it/s][A
Training Batch:   0%|          | 8/1609 [00:02<06:57,  3.83it/s][A
Training Batch:   1%|          | 9/1609 [00:02<07:01,  3.79it/s][A
Training Batch:   1%|          | 10/1609 [00:02<07:06,  3.75it/s][A
Training Batch:   1%|          | 11/1609 [00:02<06:43,  3.96it/s][A
Training Batch:   1%|          | 12/1609 [00:03<06:20,  4.20it/s][A
Training Batch:   1%|          | 13/1609 [00:03<06:05,  4.37it/s][A
Training Batch:   1%|          | 14/1609 [00:03<05:

Epoch 1/2 - Train Loss: 0.0521 | Val Loss: 6.3911 | Perplexity: 596.50 | LR: 1.00e-03
Training with Adam optimizer...



Training Batch:   0%|          | 0/1609 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1609 [00:00<06:16,  4.28it/s][A
Training Batch:   0%|          | 2/1609 [00:00<08:37,  3.11it/s][A
Training Batch:   0%|          | 3/1609 [00:01<13:21,  2.00it/s][A
Training Batch:   0%|          | 4/1609 [00:02<16:19,  1.64it/s][A
Training Batch:   0%|          | 5/1609 [00:02<17:43,  1.51it/s][A
Training Batch:   0%|          | 6/1609 [00:03<14:44,  1.81it/s][A
Training Batch:   0%|          | 7/1609 [00:03<12:18,  2.17it/s][A
Training Batch:   0%|          | 8/1609 [00:03<10:11,  2.62it/s][A
Training Batch:   1%|          | 9/1609 [00:03<08:47,  3.03it/s][A
Training Batch:   1%|          | 10/1609 [00:04<07:51,  3.39it/s][A
Training Batch:   1%|          | 11/1609 [00:04<07:13,  3.68it/s][A
Training Batch:   1%|          | 12/1609 [00:04<06:44,  3.95it/s][A
Training Batch:   1%|          | 13/1609 [00:04<06:25,  4.14it/s][A
Training Batch:   1%|          | 14/1609 [00:04<06:

Epoch 2/2 - Train Loss: 0.0519 | Val Loss: 6.5007 | Perplexity: 665.60 | LR: 1.00e-03
Saved best gru_norm weights to best_gru_norm.pth

GRU Results with no Teacher Forcing:
Best Perplexity: 596.50


KeyError: 'best_optimizer'

#### Observations & Insights

2 GRU models were trained with identical hyperparameters but differed in their teacher forcing ratio: one with teacher forcing enabled (ratio > 0), and one with it disabled (ratio = 0). Across both epochs, the model trained without teacher forcing exhibited higher training loss, validation loss, and perplexity. This outcome is expected and reflects how teacher forcing influences both model performance and the training process, particularly in autoregressive sequence models.

When teacher forcing is used, the model receives the ground truth tokens at each time step during training. This stabilizes the input context, helping the model produce more accurate predictions and faster convergence, as seen by the more pronounced drop in training loss (from 0.0495 to 0.0372). The model benefits from a cleaner optimization path early on, which helps it learn useful sequence-level patterns more efficiently.

In contrast, the model without teacher forcing must use its own predictions as inputs, making it more susceptible to error propagation—a single incorrect token can misguide the entire sequence. This leads to a slower reduction in training loss (only marginally decreasing from 0.0521 to 0.0519) and higher perplexity due to accumulated errors over time steps.

**Important Consideration:**

Although the model with teacher forcing shows lower validation loss and perplexity, the observed difference across only two training epochs may not fully capture the long-term impact of teacher forcing. In early epochs, especially for sequence models, loss values can fluctuate significantly as the model adjusts its weights. It is also possible that the model with teacher forcing simply had a more favorable initialization or learning trajectory in these early steps.

Given more epochs, the model without teacher forcing may gradually improve as it learns to handle noisy inputs, and the performance gap may either widen or narrow depending on how well it generalizes. Therefore, the increase in validation loss and perplexity in the no-teacher-forcing model should be interpreted with caution, as two epochs is insufficient to observe stable convergence behavior.

**Learning points future training strategy:**
1. Teacher forcing provides a strong initial learning signal, helping the model stabilize early in training.


2. Slower convergence without teacher forcing suggests a harder optimization problem, but not necessarily a worse final outcome if trained longer.


3. A scheduled teacher forcing ratio to gradually reduce the reliance on ground truth can help balance early training stability with long-term generalization, allowing the model to transition from guided learning to realistic inference conditions.

Overall, while teacher forcing clearly benefits early training in terms of stability and loss reduction, its full impact, especially relative to models trained without it requires longer training durations to assess fairly.

### Train and Evaluate all 3 architctures - LSTM, RNN, GRU

In [None]:
# --- MULTI-MODEL TRAINING ---
print("\n=== TRAINING ALL MODELS ===")

models = {
    'RNN': ShakespeareRNN(wp_vocab_size, bidirectional=True).to(device),
    'LSTM': ShakespeareLSTM(wp_vocab_size, bidirectional=True).to(device),
    'GRU': ShakespeareGRU(wp_vocab_size, bidirectional=True).to(device)
}

results = {}
for name, model in models.items():
    print(f"\nTraining {name} model...")

    # Train
    model_results = train_model(  # Adam optimizer used by default
        model,
        train_loader,
        val_loader,
        epochs=2,
        lr=0.001,
        patience=3,
        model_name=name,
        teacher_forcing_ratio=0  # Teacher forcing becomes tricky with bidirectional, thus deactivated
    )

    # Store results
    results[name] = {
        'model': model_results['model'],
        'perplexity': model_results['best_perplexity'],
        'optimizer': model_results['optimizer'],
        'history': model_results['history']
    }

    print(f"{name} complete. Best perplexity: {model_results['best_perplexity']:.2f}")


# Compare results
print("\n=== FINAL RESULTS ===")
for model_name, res in results.items():
    print(f"{model_name}:")
    print(f"  Best Perplexity: {res['perplexity']:.2f}")


=== TRAINING ALL MODELS ===

Training RNN model...


Epochs:   0%|          | 0/2 [00:00<?, ?it/s]

Training with Adam optimizer...



Training Batch:   0%|          | 0/1611 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1611 [00:00<15:52,  1.69it/s][A
Training Batch:   0%|          | 2/1611 [00:01<13:27,  1.99it/s][A
Training Batch:   0%|          | 3/1611 [00:01<12:40,  2.12it/s][A
Training Batch:   0%|          | 4/1611 [00:01<12:19,  2.17it/s][A
Training Batch:   0%|          | 5/1611 [00:02<12:05,  2.21it/s][A
Training Batch:   0%|          | 6/1611 [00:02<11:56,  2.24it/s][A
Training Batch:   0%|          | 7/1611 [00:03<11:52,  2.25it/s][A
Training Batch:   0%|          | 8/1611 [00:03<11:49,  2.26it/s][A
Training Batch:   1%|          | 9/1611 [00:04<11:47,  2.26it/s][A
Training Batch:   1%|          | 10/1611 [00:04<11:47,  2.26it/s][A
Training Batch:   1%|          | 11/1611 [00:04<11:46,  2.27it/s][A
Training Batch:   1%|          | 12/1611 [00:05<11:46,  2.26it/s][A
Training Batch:   1%|          | 13/1611 [00:05<11:47,  2.26it/s][A
Training Batch:   1%|          | 14/1611 [00:06<11:

Epoch 1/2 - Train Loss: 0.0526 | Val Loss: 6.6366 | Perplexity: 762.46 | LR: 1.00e-03
Training with Adam optimizer...



Training Batch:   0%|          | 0/1611 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1611 [00:00<12:23,  2.17it/s][A
Training Batch:   0%|          | 2/1611 [00:00<12:18,  2.18it/s][A
Training Batch:   0%|          | 3/1611 [00:01<12:21,  2.17it/s][A
Training Batch:   0%|          | 4/1611 [00:01<12:22,  2.16it/s][A
Training Batch:   0%|          | 5/1611 [00:02<12:22,  2.16it/s][A
Training Batch:   0%|          | 6/1611 [00:02<12:23,  2.16it/s][A
Training Batch:   0%|          | 7/1611 [00:03<12:21,  2.16it/s][A
Training Batch:   0%|          | 8/1611 [00:03<12:19,  2.17it/s][A
Training Batch:   1%|          | 9/1611 [00:04<12:18,  2.17it/s][A
Training Batch:   1%|          | 10/1611 [00:04<12:20,  2.16it/s][A
Training Batch:   1%|          | 11/1611 [00:05<12:21,  2.16it/s][A
Training Batch:   1%|          | 12/1611 [00:05<12:21,  2.16it/s][A
Training Batch:   1%|          | 13/1611 [00:06<12:20,  2.16it/s][A
Training Batch:   1%|          | 14/1611 [00:06<12:

Epoch 2/2 - Train Loss: 0.0523 | Val Loss: 6.3796 | Perplexity: 589.71 | LR: 1.00e-03
Saved best RNN weights to best_RNN.pth
RNN complete. Best perplexity: 589.71

Training LSTM model...


Epochs:   0%|          | 0/2 [00:00<?, ?it/s]

Training with Adam optimizer...



Training Batch:   0%|          | 0/1611 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1611 [00:00<10:50,  2.48it/s][A
Training Batch:   0%|          | 2/1611 [00:00<10:15,  2.61it/s][A
Training Batch:   0%|          | 3/1611 [00:01<10:23,  2.58it/s][A
Training Batch:   0%|          | 4/1611 [00:01<10:13,  2.62it/s][A
Training Batch:   0%|          | 5/1611 [00:01<10:05,  2.65it/s][A
Training Batch:   0%|          | 6/1611 [00:02<10:04,  2.65it/s][A
Training Batch:   0%|          | 7/1611 [00:02<10:02,  2.66it/s][A
Training Batch:   0%|          | 8/1611 [00:03<09:59,  2.67it/s][A
Training Batch:   1%|          | 9/1611 [00:03<10:00,  2.67it/s][A
Training Batch:   1%|          | 10/1611 [00:03<09:56,  2.68it/s][A
Training Batch:   1%|          | 11/1611 [00:04<09:54,  2.69it/s][A
Training Batch:   1%|          | 12/1611 [00:04<09:54,  2.69it/s][A
Training Batch:   1%|          | 13/1611 [00:04<09:53,  2.69it/s][A
Training Batch:   1%|          | 14/1611 [00:05<09:

Epoch 1/2 - Train Loss: 0.0525 | Val Loss: 6.8209 | Perplexity: 916.79 | LR: 1.00e-03
Training with Adam optimizer...



Training Batch:   0%|          | 0/1611 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1611 [00:00<10:04,  2.66it/s][A
Training Batch:   0%|          | 2/1611 [00:00<10:04,  2.66it/s][A
Training Batch:   0%|          | 3/1611 [00:01<10:09,  2.64it/s][A
Training Batch:   0%|          | 4/1611 [00:01<10:05,  2.65it/s][A
Training Batch:   0%|          | 5/1611 [00:01<10:01,  2.67it/s][A
Training Batch:   0%|          | 6/1611 [00:02<10:01,  2.67it/s][A
Training Batch:   0%|          | 7/1611 [00:02<09:59,  2.68it/s][A
Training Batch:   0%|          | 8/1611 [00:02<09:58,  2.68it/s][A
Training Batch:   1%|          | 9/1611 [00:03<09:59,  2.67it/s][A
Training Batch:   1%|          | 10/1611 [00:03<09:57,  2.68it/s][A
Training Batch:   1%|          | 11/1611 [00:04<09:58,  2.67it/s][A
Training Batch:   1%|          | 12/1611 [00:04<09:57,  2.68it/s][A
Training Batch:   1%|          | 13/1611 [00:04<09:57,  2.68it/s][A
Training Batch:   1%|          | 14/1611 [00:05<09:

Epoch 2/2 - Train Loss: 0.0522 | Val Loss: 6.8430 | Perplexity: 937.29 | LR: 1.00e-03
Saved best LSTM weights to best_LSTM.pth
LSTM complete. Best perplexity: 916.79

Training GRU model...


Epochs:   0%|          | 0/2 [00:00<?, ?it/s]

Training with Adam optimizer...



Training Batch:   0%|          | 0/1611 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1611 [00:00<09:34,  2.80it/s][A
Training Batch:   0%|          | 2/1611 [00:00<09:27,  2.83it/s][A
Training Batch:   0%|          | 3/1611 [00:01<09:43,  2.76it/s][A
Training Batch:   0%|          | 4/1611 [00:01<09:37,  2.78it/s][A
Training Batch:   0%|          | 5/1611 [00:01<09:34,  2.79it/s][A
Training Batch:   0%|          | 6/1611 [00:02<09:36,  2.78it/s][A
Training Batch:   0%|          | 7/1611 [00:02<09:33,  2.79it/s][A
Training Batch:   0%|          | 8/1611 [00:02<09:30,  2.81it/s][A
Training Batch:   1%|          | 9/1611 [00:03<09:30,  2.81it/s][A
Training Batch:   1%|          | 10/1611 [00:03<09:30,  2.81it/s][A
Training Batch:   1%|          | 11/1611 [00:03<09:28,  2.81it/s][A
Training Batch:   1%|          | 12/1611 [00:04<09:29,  2.81it/s][A
Training Batch:   1%|          | 13/1611 [00:04<09:28,  2.81it/s][A
Training Batch:   1%|          | 14/1611 [00:04<09:

Epoch 1/2 - Train Loss: 0.0526 | Val Loss: 6.8830 | Perplexity: 975.57 | LR: 1.00e-03
Training with Adam optimizer...



Training Batch:   0%|          | 0/1611 [00:00<?, ?it/s][A
Training Batch:   0%|          | 1/1611 [00:00<09:37,  2.79it/s][A
Training Batch:   0%|          | 2/1611 [00:00<09:35,  2.80it/s][A
Training Batch:   0%|          | 3/1611 [00:01<09:39,  2.77it/s][A
Training Batch:   0%|          | 4/1611 [00:01<09:35,  2.79it/s][A
Training Batch:   0%|          | 5/1611 [00:01<09:34,  2.79it/s][A
Training Batch:   0%|          | 6/1611 [00:02<09:35,  2.79it/s][A
Training Batch:   0%|          | 7/1611 [00:02<09:34,  2.79it/s][A
Training Batch:   0%|          | 8/1611 [00:02<09:32,  2.80it/s][A
Training Batch:   1%|          | 9/1611 [00:03<09:30,  2.81it/s][A
Training Batch:   1%|          | 10/1611 [00:03<09:31,  2.80it/s][A
Training Batch:   1%|          | 11/1611 [00:03<09:30,  2.81it/s][A
Training Batch:   1%|          | 12/1611 [00:04<09:29,  2.81it/s][A
Training Batch:   1%|          | 13/1611 [00:04<09:28,  2.81it/s][A
Training Batch:   1%|          | 14/1611 [00:04<09:

Epoch 2/2 - Train Loss: 0.0522 | Val Loss: 6.8051 | Perplexity: 902.48 | LR: 1.00e-03
Saved best GRU weights to best_GRU.pth
GRU complete. Best perplexity: 902.48

=== FINAL RESULTS ===
RNN:
  Best Perplexity: 589.71
LSTM:
  Best Perplexity: 916.79
GRU:
  Best Perplexity: 902.48





#### Observation & Insight

Though teacher forcing led to lower validation loss and perplexity in the earlier unidirectional GRU models, it had to be removed once bidirectional RNNs were introduced in the subsequent models. This was necessary due to the fundamental characteristics of bidirectional RNNs, which conflicted with the autoregressive nature required for text generation.

A bidirectional RNN processes the input sequence in both forward and backward directions, meaning that, at each time step, the hidden state of the model is influenced by both the past and the future of the input sequence. In other words, while processing a token, a bidirectional RNN has access to information from both directions, allowing it to understand the full context of the sequence. This is highly beneficial for tasks that require contextual understanding of the entire sequence at once, such as:
- Text classification, where the full sequence context is useful for determining the label.
- Named Entity Recognition (NER) or Part-of-Speech (POS) tagging, where understanding both the preceding and succeeding words improves accuracy.

However, this bidirectional context can be problematic in autoregressive text generation.

**Rationale for removing teacher forcing in bidirectional RNN:**

In an autoregressive model, the goal is to generate one token at a time, conditioning each prediction solely on the previous tokens, to mimic real-time generation. The ShakespeareRNNBase class was designed to generate text step by step, using only the preceding tokens as context, in line with how text generation occurs in production settings.
When a bidirectional RNN is introduced, the model receives both future and past context at each time step. This allows it to "cheat" by having access to information it wouldn’t have in a real-world scenario (where only previous tokens are available). For example, when generating the token "the", the model would know the words after it (like "quick") and may make predictions that are not consistent with how the model would behave during actual generation. This breaks the causal flow of information, which is why it’s critical to remove bidirectional processing when generating text autoregressively.

Thus, teacher forcing had to be deactivated, as it would further distort the training process by providing ground-truth tokens during training, which could exacerbate the issue of data leakage caused by bidirectionality.

#### Analysis of performance of the 3 bidirectional architectures
All three models — GRU, LSTM, and vanilla RNN — were trained under the same conditions. While their training losses are nearly identical, the validation losses and perplexities reveal differences in how well each generalizes.

Surprisingly, the vanilla RNN achieved the lowest validation loss (6.38) and perplexity (589.71), suggesting it performed best in modeling the structure of the Shakespeare text on unseen data, despite being a simpler architecture. This could be attributed to its reduced complexity, which may have helped avoid overfitting early in training.

In contrast, GRU and LSTM, though more advanced with gating mechanisms designed to capture long-range dependencies, exhibited higher validation losses (6.80+), with LSTM slightly underperforming compared to GRU. This suggests that in this context — with limited training (possibly few epochs) — the added complexity of GRU and LSTM might not have had enough time or data to fully leverage their strengths.

It may also imply that RNNs can still perform competitively in character-level generation tasks when training depth or data is constrained. However, with longer training or more data, GRUs and LSTMs may show improved performance due to their better handling of long-term dependencies.

**Comparison with Unidirectional GRU**

Although the vanilla bidirectional RNN achieved relatively strong performance without teacher forcing, its validation loss (6.3796) and perplexity (589.71) were ultimately outperformed by the unidirectional GRU with teacher forcing, which achieved a validation loss of 5.924 and perplexity of 374 in the very first epoch. This result suggests that teacher forcing plays a substantial role in stabilizing and guiding the learning process, especially in early training stages.

**Discussion on perplexity**
Perplexity, as a metric, is a direct measure of how well a language model predicts the next token in a sequence. Lower perplexity indicates the model is more confident and accurate in its predictions, which translates to better performance in sequence generation tasks like text completion or language modeling.

In this context, the unidirectional GRU with teacher forcing achieved the lowest perplexity (374), indicating it had the most precise token-level predictions early on. This aligns with the role of teacher forcing in stabilizing training by supplying the true previous token, thereby allowing the model to focus on learning the output distribution without being penalized for compounding past errors.

By contrast, the bidirectional GRU, although designed to capture both past and future context, showed higher perplexity (≈902). This can be partly attributed to the inability to use teacher forcing in an autoregressive setting — a mismatch that makes it harder for the model to predict the next token based solely on previous ones. It might also reflect some over-reliance on future context, which is unavailable at inference time, reducing generalization. Even more telling is the highest perplexity (~665) seen in the unidirectional GRU without teacher forcing, confirming how the absence of guidance during training can cause error accumulation over the sequence.

So, overall, perplexity reveals not just predictive performance, but also the training dynamics and model compatibility with the task. It's especially useful for evaluating how different modeling decisions (directionality, teacher forcing, etc.) affect generative quality.

### Text generation with temperature experimentation

Rationale:

In the context of text generation, temperature sampling is a crucial mechanism for controlling the randomness and creativity of the model’s outputs. By adjusting the temperature parameter, we scale the predicted logits before applying the softmax function, which directly influences the probability distribution over the vocabulary at each step.

Experimenting with different temperature values is meaningful because it allows us to balance coherence and diversity in generated text. A low temperature (e.g., 0.5 or lower) sharpens the probability distribution, making the model more confident and deterministic — leading to safer, more predictable, and often repetitive outputs. This is beneficial when fluency and grammatical correctness are prioritized.

On the other hand, a higher temperature (e.g., 1.0 or above) flattens the distribution, encouraging the model to sample less likely tokens. This introduces novelty, variety, and creativity, which can be desirable for generating expressive or stylistically rich text like Shakespearean dialogue — but it may also increase the risk of incoherence or grammatical errors.

In [16]:
def generate_shakespeare(
    model,
    vocab,
    start_text: str,
    max_length: int = 100,
    temperature: float = 0.8,
    device: str = "cpu",
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    min_length: int = 20
):
    """Improved Shakespeare generator with better controls"""
    model.eval()
    model.to(device)

    # Encode initial text
    input_ids = vocab.tokenizer.encode(start_text).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    # Initialize hidden state
    hidden = model.init_hidden(1)
    if isinstance(hidden, tuple):
        hidden = tuple(h.to(device) for h in hidden)
    else:
        hidden = hidden.to(device)

    generated = input_ids.copy()

    with torch.no_grad():
        while len(generated) < max_length:
            outputs, hidden = model.forward_autoregressive(input_tensor, hidden)
            logits = outputs[0, -1, :]

            # Apply repetition penalty
            for token in set(generated[-10:]):  # Look back 10 tokens
                logits[token] /= repetition_penalty

            # Apply temperature and top-k
            scaled_logits = logits / temperature
            if top_k > 0:
                values, indices = torch.topk(scaled_logits, top_k)
                scaled_logits = torch.full_like(scaled_logits, -float('Inf'))
                scaled_logits.scatter_(-1, indices, values)

            probs = torch.nn.functional.softmax(scaled_logits, dim=-1)

            # Sample with retries
            for _ in range(5):
                next_token = torch.multinomial(probs, 1).item()
                token_str = vocab.tokenizer.id_to_token(next_token)
                if token_str not in ["[UNK]", "[PAD]"] and not token_str.startswith("##"):
                    break

            generated.append(next_token)
            input_tensor = torch.tensor([[next_token]], device=device)

            # Early stopping if we have reasonable output
            current_text = vocab.decode(generated)
            if len(generated) > min_length and any(punct in current_text for punct in [".", "!", "?"]):
                break

    # Advanced post-processing
    text = vocab.decode(generated)

    # Fix spacing around punctuation
    text = re.sub(r'\s([?.!,](?:\s|$))', r'\1', text)
    text = re.sub(r" '([ts])", r"'\1", text)  # Fix 't and 's
    text = re.sub(r'\s{2,}', ' ', text).strip()

    # Capitalize first letter and after punctuation
    sentences = re.split(r'([.!?] )', text)
    text = ''.join([sentences[i].capitalize() if i == 0 or i % 2 == 1 else
                   sentences[i] for i in range(len(sentences))])

    return text

# Load LSTM model to generate text
# Define model architecture (must match the saved model)
model_lstm = ShakespeareLSTM(
    wp_vocab_size=15000,  # Must match original vocab size
    bidirectional=True
)
# Load saved weights
model_lstm.load_state_dict(torch.load("/content/best_LSTM.pth"))  # Path to your saved model
model_lstm.eval() #Set to evaluation mode
device = "cuda" if torch.cuda.is_available() else "cpu"
model_lstm.to(device)  # Move to device

# Repeat for RNN
model_rnn = ShakespeareRNN(
    wp_vocab_size=15000,  # Must match original vocab size
    bidirectional=True
)

# Load saved weights
model_rnn.load_state_dict(torch.load("/content/best_RNN.pth"))  # Path to your saved model
model_rnn.eval() #Set to evaluation mode
device = "cuda" if torch.cuda.is_available() else "cpu"
model_rnn.to(device)  # Move to device

# Repeat for GRU bidirectional
model_grubi = ShakespeareGRU(
    wp_vocab_size=15000,  # Must match original vocab size
    bidirectional=True
)

# Load saved weights
model_grubi.load_state_dict(torch.load("/content/best_GRU.pth"))  # Path to your saved model
model_grubi.eval() #Set to evaluation mode
device = "cuda" if torch.cuda.is_available() else "cpu"
model_grubi.to(device)  # Move to device

# Load GRU with teacher forcing unidirectional
model_gruni = ShakespeareGRU(
    wp_vocab_size=15000,  # Must match original vocab size
    bidirectional=False
)

# Load saved weights
model_gruni.load_state_dict(torch.load("/content/best_gru_tf.pth"))  # Path to your saved model
model_gruni.eval() #Set to evaluation mode
device = "cuda" if torch.cuda.is_available() else "cpu"
model_gruni.to(device)  # Move to device

# For coherent, conservative output
print('--------------------------------------------- Temperature = 0.8 -------------------------------------------------------')
print('LSTM bidirectional:\n', generate_shakespeare(model_lstm, train_dataset, "To be or not to be",
                         temperature=0.8, top_k=50, device=device), '\n')
print('RNN bidirectional:\n', generate_shakespeare(model_rnn, train_dataset, "To be or not to be",
                         temperature=0.8, top_k=50, device=device), '\n')
print('GRU bidirectional:\n', generate_shakespeare(model_grubi, train_dataset, "To be or not to be",
                         temperature=0.8, top_k=50, device=device), '\n')
print('GRU unidirectional:\n', generate_shakespeare(model_gruni, train_dataset, "To be or not to be",
                         temperature=0.8, top_k=50, device=device), '\n')

# For more creative output
print('--------------------------------------------- Temperature = 1.5 -------------------------------------------------------')
print('LSTM bidirectional:\n', generate_shakespeare(model_lstm, train_dataset, "To be or not to be",
                         temperature=1.5, top_k=100,device=device), '\n')
print('RNN bidirectional:\n', generate_shakespeare(model_rnn, train_dataset, "To be or not to be",
                         temperature=1.5, top_k=50, device=device), '\n')
print('GRU bidirectional:\n', generate_shakespeare(model_grubi, train_dataset, "To be or not to be",
                         temperature=1.5, top_k=50, device=device), '\n')
print('GRU unidirectional:\n', generate_shakespeare(model_gruni, train_dataset, "To be or not to be",
                         temperature=1.5, top_k=50, device=device), '\n')

--------------------------------------------- Temperature = 0.8 -------------------------------------------------------
LSTM bidirectional:
 To be or not to be of
his's: i no and your: of the's so all
that
of our but thou your the to a with my
you i it what my
and my to this him your:: but and the in
his
not his be::
to: it i
as my your
of me'd and you my
i the:: this
we what be 

RNN bidirectional:
 To be or not to be
's king and the'd will in a man to me and
it to hear your blood
i may not a to be the and take's the that you his 'll i would to this the thou art with'd
we'll this thy my
at a father of the to the
'd, our is my heart and to this's
a blood
and we was my of the
a more that that's this noble all 

GRU bidirectional:
 To be or not to be and you my:: not to all and your
the'd i a i so i is
our he
of all and
i he'd this the and: it
not to so shall i'd he
the's thou with what you: for in a to and thy for from my the
he: in to and of a be and my in thy your
will: to is a him 



#### Observation:
These outputs clearly highlight the impact of temperature on the creativity and coherence of generated text. At a lower temperature (0.8), the models tend to generate more structured, grammatically plausible, and thematically Shakespearean phrases — though still somewhat repetitive and safe. For example, repeated use of common words like “the,” “to,” and “my” reflects the model’s preference for high-probability tokens, resulting in more fluent but less diverse output.

At a higher temperature (1.5), the text becomes more unpredictable and expressive, with increased variety in word choice and phrasing — yet at the cost of syntactic and semantic coherence. The GRU and LSTM models in particular start producing disjointed or nonsensical sequences, such as “are’ll to no the i is” or “would not so been be this all your.”

This contrast suggests that temperature tuning is a valuable tool for exploring the trade-off between linguistic accuracy and creative exploration in generative models, which is especially important when the goal is to mimic artistic styles like Shakespearean drama, where both fluency and imagination matter.

#### Beam search
Another method employed in text generation is beam search, which balances exploration and optimality. Unlike greedy decoding, which selects the highest-probability word at each step, beam search keeps multiple candidate sequences (beams) and expands them in parallel, allowing the model to consider alternative, potentially better sequences. This often leads to more coherent and contextually appropriate output, as it avoids early commitment to suboptimal word choices. It is especially valuable in structured or stylistic generation tasks like Shakespearean text, where maintaining thematic consistency and fluency is key.

In [None]:
def robust_generation(
    model,
    vocab,
    prompt: str,
    max_length: int = 50,
    strategy: str = "beam",  # "beam" or "greedy"
    beam_width: int = 3,
    temperature: float = 0.8,
    device: str = "cpu"
):
    """Bulletproof text generation with sanity checks"""
    model.eval()
    model.to(device)

    # 1. Verify tokenizer
    input_ids = vocab.tokenizer.encode(prompt).ids
    if not input_ids:
        raise ValueError("Prompt failed to tokenize")

    # 2. Initialize generation
    input_tensor = torch.tensor([input_ids], device=device)
    hidden = model.init_hidden(1)
    if isinstance(hidden, tuple):
        hidden = tuple(h.to(device) for h in hidden)
    else:
        hidden = hidden.to(device)

    generated = input_ids.copy()
    last_tokens = []

    # 3. Generation loop with protections
    for _ in range(max_length):
        with torch.no_grad():
            output, hidden = model(input_tensor, hidden)
            logits = output[0, -1, :] / max(temperature, 0.1)

            # Penalize repeating tokens
            for token in set(last_tokens[-4:]):
                logits[token] -= 2.0

            if strategy == "greedy":
                next_token = torch.argmax(logits).item()
            else:  # beam
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()

            # Prevent infinite loops
            last_tokens.append(next_token)
            if len(set(last_tokens[-4:])) == 1:  # Repeating pattern
                logits[next_token] = -float('Inf')
                next_token = torch.argmax(logits).item()

            generated.append(next_token)
            input_tensor = torch.tensor([[next_token]], device=device)

    # 4. Post-processing with strict checks
    text = vocab.tokenizer.decode(generated)

    # If output is degenerate, fall back to prompt
    if (len(set(text.split())) < 3 or text.count("I") > len(text)/2):
        return prompt

    # Basic cleaning
    text = text.replace(" ##", "").replace("  ", " ")
    return text if text.strip() else prompt

print('---------------------------------------- Beam search using beam strategy: --------------------------------------------')
print('LSTM Bidirectional:\n', robust_generation(model_lstm, train_dataset, prompt='To be or not to be', strategy='beam', device=device))
print('\nRNN Bidirectional:\n', robust_generation(model_rnn, train_dataset, prompt='To be or not to be', strategy='beam', device=device))
print('\nGRU Bidirectional:\n', robust_generation(model_grubi, train_dataset, prompt='To be or not to be', strategy='beam', device=device))
print('\nGRU Unidirectional:\n', robust_generation(model_gruni, train_dataset, prompt='To be or not to be', strategy='beam', device=device))

print('---------------------------------------- Beam search using greedy strategy: --------------------------------------------')
print('\nLSTM Bidirectional:\n', robust_generation(model_lstm, train_dataset, prompt='To be or not to be', strategy='greedy', device=device))
print('\nRNN Bidirectional:\n', robust_generation(model_rnn, train_dataset, prompt='To be or not to be', strategy='greedy', device=device))
print('\nGRU Bidirectional:\n', robust_generation(model_grubi, train_dataset, prompt='To be or not to be', strategy='greedy', device=device))
print('\nGRU Uniidirectional:\n', robust_generation(model_gruni, train_dataset, prompt='To be or not to be', strategy='greedy', device=device))

---------------------------------------- Beam search using beam strategy: --------------------------------------------
LSTM Bidirectional:
 To be or not to be the let say, alack that the have and Hark! I childish to some willn, what measure and myed thy can theK beUS made with in Dorset to for Henry I When been. to Hereford that I

RNN Bidirectional:
 To be or not to be I a gain you, ease and come you have, in distra from God mighty father friends! I slew and do? as done to thy desc deep think you dream again. wiTER Yea, I the you to be

GRU Bidirectional:
 To be or not to be and I the him byly to not way,ation And. POLIX of what lend was and it should I a- am use the have in soldier a for their is her way to the

GRU Unidirectional:
 To be or not to be you, I well no little: that I would have at repent QUEEN I know, my lord,-- JOHN OF GAUYJ How he I know, and he are gone. DUKE OF YORERLT
---------------------------------------- Beam search using greedy strategy: ----------------------

#### Obervation
The comparison between beam search with different strategies and temperature sampling reveals important insights into the trade-offs between diversity, creativity, and coherence in text generation. Beam search with a beam strategy tends to produce outputs that are more syntactically coherent and semantically richer than greedy decoding, where the model is likely to generate repetitive or nonsensical sequences (e.g., "I not the to the" or "the the the"). This is because beam search explores multiple hypotheses at each step and chooses the most promising ones, allowing for a more structured and fluent sentence construction.

In contrast, greedy decoding, while efficient, often collapses into repetitive sequences as it always chooses the most likely next word without considering other possibilities, leading to lack of diversity. This issue is particularly evident in models like GRU, where outputs like "I the I the I" lack meaningful variation.

**Comparison with temperature sampling**

When comparing this to temperature sampling alone, it is clear that higher temperatures introduce more randomness, encouraging the model to explore less likely but more creative word choices. At a temperature of 1.5, for example, the outputs show greater diversity but may sometimes lack fluency and coherence, as the model is encouraged to sample from a wider range of possibilities.

Beam search strikes a balance between diversity and fluency, producing more structured outputs than temperature sampling at high values, but at the cost of creativity. Temperature sampling provides a way to fine-tune the exploration-exploitation trade-off, where low temperatures prioritize coherence (similar to beam search) and high temperatures foster diversity but with some potential loss in grammatical correctness.

Overall, beam search is beneficial when coherent and grammatically correct outputs are required, whereas temperature sampling and beam search with greedy strategy offer trade-offs between creativity and structure.

#### Discussion of generated text vs original text

**Set 1: Beam Search Strategy**

LSTM Bidirectional:

The output here appears somewhat erratic (e.g., "I childish to some willn, what measure and myed thy can theK beUS"). Beam search tends to generate a broader set of possible sequences, and this sometimes leads to nonsensical or incomplete phrases. While there are some Shakespearean-like phrases ("to Hereford," "for Henry"), the meaning is fragmented, and it does not maintain a coherent narrative. There’s a clear attempt to replicate the formal language, but the result lacks fluency and structure.

RNN Bidirectional:

This output seems disjointed (e.g., “I a gain you, ease and come you have”) and shows evidence of over-generation, as the beam search strategy explores different hypotheses. It feels like the model is trying to generate Shakespearean-esque dialogue but fails to maintain logical progression. While the output starts with a Shakespearean tone (e.g., "To be or not to be"), the sentences quickly become ungrammatical and hard to follow.

GRU Bidirectional:

Similar to the others, this output includes random and sometimes incomplete phrases (e.g., "and I the him byly to not way,ation"). There is an attempt to mimic the lexicon of Shakespeare’s time, but the coherence and fluency are lacking. The choice of words (like “way,ation” and “POLIX”) gives the text a sense of disjointedness, with some phrases trying to mimic the rhythm, but it is ultimately unclear and hard to interpret.

GRU Unidirectional:

This output is somewhat more structured compared to the others, but still suffers from being over-generated and unfocused (e.g., "To be or not to be you, I well no little:"). It uses the character names ("QUEEN," "DUKE") and some words that evoke Shakespearean drama, but the overall structure is awkward. The text tends to repeat itself and lacks the dramatic tension or fluidity that Shakespeare’s original works deliver.

**Set 2: Temperature = 0.8**

LSTM Bidirectional:

The output here is more coherent than the beam search version. While it is still a bit inconsistent, it feels closer to Shakespearean dialogue (e.g., “his’s: i no and your: of the’s so all”). The temperature of 0.8 results in some variance, making the output more diverse, but still within a range that can mimic Shakespeare's syntactic structures. The vocabulary feels a bit more natural, even though the grammar is still quite disjointed.

RNN Bidirectional:

Here, the text starts resembling a Shakespearean tone much more than in the beam search output (e.g., "To be or not to be / 's king and the'd will in a man to me"). There is still a repetitive nature, but the output at least has some semblance of rhythm and attempts to follow the word structure that Shakespeare often used. The temperature setting encourages diversity while avoiding completely random combinations of words, allowing for more coherence compared to the beam search.

GRU Bidirectional:

The temperature of 0.8 here helps the model find some degree of consistency in its word choices, and it still generates phrases that feel reminiscent of Shakespearean dialogue (e.g., "To be or not to be and I the him"). Though the meaning might still be unclear, the syntax and tone are closer to Shakespeare’s style. It is less chaotic than the beam search results, as the temperature introduces some variety without making it completely random.

GRU Unidirectional:

The GRU unidirectional output seems like it is trying to simulate a dramatic monologue, as it incorporates proper nouns like "KING RICHARD II" and other Shakespearean elements. The use of temperature allows the model to explore different word possibilities, creating a more natural progression than beam search. However, it still suffers from repetition and a lack of deep coherence, which makes it feel artificial but more fluid than the beam search output.


#### Challenges & Limitations

During the course of model development and experimentation, several hardware-related challenges arose, which had notable implications for the overall workflow and results.

One of the primary issues stemmed from the use of Google Colab for model training. While Colab offers a convenient environment for running experiments, the limited hardware resources frequently led to system crashes. This disruption forced me to repeatedly re-execute the code, which not only consumed additional time but also introduced an element of unpredictability into the training process. These interruptions made it challenging to maintain a smooth, continuous workflow, and at times, required me to start from scratch after a crash.

Another significant challenge was the duration of each training epoch. Depending on the complexity of the model and the dataset, each epoch took up to an hour to complete. As a result, the total training time for each model became quite extensive, making it impractical to run a high number of epochs. In response to this, I made the decision to limit the training to just two epochs rather than a higher value, which is typically recommended for thorough model evaluation. This reduction in training time inherently limited the potential for the model to fully converge and may have affected the overall performance. However, given the constraints of time and the necessity to analyze and compare the effectiveness of multiple models, this was a necessary compromise.

While these limitations undoubtedly impacted the thoroughness of the model training, they were considered an unavoidable trade-off to allow for the timely analysis and evaluation of the models' comparative effectiveness. Despite these constraints, the experimentation still provided valuable insights into model performance, and the results obtained were sufficient for the scope of this project. 

One important area for future exploration is the use of character-level models instead of word-level models. While word-level models have been effective in generating coherent text, character-level models offer the potential for even more nuanced control over text generation. By predicting one character at a time, character-level models are capable of handling out-of-vocabulary words and generating text with more flexibility and creativity. This shift would not only address some of the limitations observed in the current approach but also open up new possibilities for text generation, especially for languages with complex morphology or unconventional word structures.

The lessons learned from the current study lay the foundation for these future directions. Refining the model architecture and incorporating character-level features could improve the overall text generation process, leading to more natural, contextually rich, and stylistically accurate output. Additionally, further experimentation with optimization techniques and more advanced sampling strategies will likely yield even more promising results, pushing the boundaries of what is achievable in text generation tasks.