In [None]:
import numpy as np

import json
import math
import os
import re
from collections import Counter
from typing import Any, List, Optional, Tuple, Dict
from datetime import datetime
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# Pick the best available torch backend: Apple-silicon MPS first, then CUDA,
# and plain CPU as the last resort.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Number of passes over the corpus used by every training run in the notebook.
n_epochs = 10

In [None]:
def normalize_and_split(raw_text: str) -> List[str]:
    """Lower-case ``raw_text`` and break it into word and punctuation tokens.

    Each of the characters ``. , ! ? ; : ( ) " '`` is padded with spaces so it
    becomes a token of its own; every run of whitespace is then collapsed and
    the text is split on spaces.

    Args:
        raw_text: Arbitrary text to tokenize.

    Returns:
        The list of tokens, empty when the text has no non-whitespace content.
    """
    punctuation_spaced = re.sub(
        r"([.,!?;:()\"'])", r" \1 ", raw_text.lower()
    )
    single_spaced = re.sub(r"\s+", " ", punctuation_spaced)
    return single_spaced.strip().split()


def read_poetry_file(path: str = "poems.txt") -> str:
    """Return the entire training corpus as a single string.

    Args:
        path: Location of the UTF-8 text file to load. Defaults to
            ``"poems.txt"``, resolved against the current working directory,
            so existing zero-argument calls keep reading the same file.

    Returns:
        The full contents of the file, decoded as UTF-8.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(path, "r", encoding="utf-8") as file_handle:
        return file_handle.read()


def create_vocabulary(token_stream: List[str],
                      min_occurrence: int = 1
                      ) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Build token<->index lookup tables from a token stream.

    The four reserved tokens ``<pad>``, ``<unk>``, ``<bos>`` and ``<eos>``
    always occupy indices 0-3. Every other token that occurs at least
    ``min_occurrence`` times follows, in order of first appearance.

    Args:
        token_stream: Tokens of the corpus, in order.
        min_occurrence: Minimum frequency a token needs to get its own index.

    Returns:
        A ``(token_to_index, index_to_token)`` pair of mutually inverse dicts.
    """
    frequency_map = Counter(token_stream)

    base_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]
    # Counter keys are already unique, so the only possible duplicates are the
    # reserved tokens themselves. Checking against a set keeps this linear in
    # the vocabulary size instead of scanning the growing list for each token.
    reserved = set(base_tokens)

    vocab_list = base_tokens + [
        token
        for token, freq in frequency_map.items()
        if freq >= min_occurrence and token not in reserved
    ]

    token_to_index = {token: idx for idx, token in enumerate(vocab_list)}
    index_to_token = dict(enumerate(vocab_list))

    return token_to_index, index_to_token


def encode_tokens(token_stream: List[str],
                  token_to_index: Dict[str, int]
                  ) -> List[int]:
    """Translate tokens into their vocabulary indices.

    Tokens missing from ``token_to_index`` are mapped to the index of
    ``"<unk>"``, which therefore must be present in the mapping.

    Args:
        token_stream: Tokens to encode, in order.
        token_to_index: Vocabulary lookup table.

    Returns:
        One index per input token, in the same order.
    """
    fallback_id = token_to_index["<unk>"]

    id_stream: List[int] = []
    for token in token_stream:
        id_stream.append(token_to_index.get(token, fallback_id))

    return id_stream


def build_sequence_pairs(id_stream: List[int],
                         window: int
                         ) -> Tuple[List[List[int]], List[List[int]]]:
    """Slice an id stream into overlapping next-token training pairs.

    For every start offset ``s`` in ``range(len(id_stream) - window)`` the
    input is ``id_stream[s:s + window]`` and the target is the same slice
    shifted one position to the right.

    Args:
        id_stream: Encoded corpus.
        window: Length of each input/target sequence.

    Returns:
        ``(input_sequences, target_sequences)``; both are empty when the
        stream is not longer than ``window``.
    """
    start_offsets = range(len(id_stream) - window)

    input_sequences = [
        id_stream[offset:offset + window] for offset in start_offsets
    ]
    target_sequences = [
        id_stream[offset + 1:offset + window + 1] for offset in start_offsets
    ]

    return input_sequences, target_sequences


def analyze_text_statistics(sample_text: str) -> Dict[str, Any]:
    """Summarise the lexical diversity and repetitiveness of a text sample.

    The text is tokenized with ``normalize_and_split``. The result reports the
    token count, the number and ratio of distinct tokens, the fraction of
    2-gram and 3-gram occurrences that repeat an earlier identical n-gram, and
    the ten most frequent tokens with their counts.

    Args:
        sample_text: Text to analyse.

    Returns:
        A dict with keys ``token_count``, ``unique_tokens``, ``unique_ratio``,
        ``repeat_2gram_ratio``, ``repeat_3gram_ratio`` and ``top_tokens``.
        All numeric fields are zero and ``top_tokens`` is empty when the text
        yields no tokens.
    """
    tokens = normalize_and_split(sample_text)
    total_tokens = len(tokens)

    if total_tokens == 0:
        return {
            "token_count": 0,
            "unique_tokens": 0,
            "unique_ratio": 0.0,
            "repeat_2gram_ratio": 0.0,
            "repeat_3gram_ratio": 0.0,
            "top_tokens": [],
        }

    def repeated_ngram_fraction(n: int) -> float:
        # Number of length-n windows that fit in the token list.
        window_count = total_tokens - n + 1
        if window_count <= 0:
            return 0.0

        ngrams = [tuple(tokens[i:i + n]) for i in range(window_count)]
        # Every occurrence beyond the first of each distinct n-gram is a repeat.
        repeats = len(ngrams) - len(set(ngrams))
        return repeats / len(ngrams)

    token_freq = Counter(tokens)
    unique_count = len(token_freq)

    summary: Dict[str, Any] = {}
    summary["token_count"] = int(total_tokens)
    summary["unique_tokens"] = int(unique_count)
    summary["unique_ratio"] = float(unique_count / total_tokens)
    summary["repeat_2gram_ratio"] = float(repeated_ngram_fraction(2))
    summary["repeat_3gram_ratio"] = float(repeated_ngram_fraction(3))
    summary["top_tokens"] = [
        (word, int(count)) for word, count in token_freq.most_common(10)
    ]
    return summary


In [None]:
def stable_softmax(vec):
    """Softmax over all entries of ``vec``.

    The maximum is subtracted before exponentiating so large logits do not
    overflow; the result is unchanged by that shift.
    """
    weights = np.exp(vec - np.max(vec))
    weights_total = np.sum(weights)
    return weights / weights_total


def vector_one_hot(index, vocab_dim):
    """Return a ``(vocab_dim, 1)`` float column that is 1.0 at row ``index``."""
    column = np.zeros(shape=(vocab_dim, 1))
    column[index, :] = 1.0
    return column


class NumpyRNNCore:
    """Single-layer tanh RNN language model written with NumPy only.

    Inputs are one-hot column vectors over the vocabulary. The model keeps
    three weight matrices and two bias columns, is trained with plain SGD on
    gradients from backpropagation through time, and can sample token
    sequences autoregressively.

    Attributes:
        vocab_dim: Size of the vocabulary (one-hot dimension).
        hidden_dim: Size of the hidden state.
        lr: SGD learning rate used by ``apply_gradients``.
        W_input: ``(hidden_dim, vocab_dim)`` input-to-hidden weights.
        W_hidden: ``(hidden_dim, hidden_dim)`` hidden-to-hidden weights.
        W_output: ``(vocab_dim, hidden_dim)`` hidden-to-logits weights.
        b_hidden: ``(hidden_dim, 1)`` hidden bias.
        b_output: ``(vocab_dim, 1)`` output bias.
    """

    def __init__(self, vocab_dim, hidden_dim=64,
                 learning_rate=1e-2, rng_seed=42):
        """Initialise weights from N(0, 0.01) and biases to zero.

        Args:
            vocab_dim: Number of distinct token ids.
            hidden_dim: Width of the hidden state.
            learning_rate: Step size for ``apply_gradients``.
            rng_seed: Seed for the generator that draws the initial weights.
        """
        rng = np.random.default_rng(rng_seed)

        self.vocab_dim = vocab_dim
        self.hidden_dim = hidden_dim
        self.lr = learning_rate

        # Draw order (input, hidden, output) is kept fixed so a given seed
        # always produces the same initial parameters.
        self.W_input = rng.normal(0, 0.01,
                                  (hidden_dim, vocab_dim))
        self.W_hidden = rng.normal(0, 0.01,
                                   (hidden_dim, hidden_dim))
        self.W_output = rng.normal(0, 0.01,
                                   (vocab_dim, hidden_dim))

        self.b_hidden = np.zeros((hidden_dim, 1))
        self.b_output = np.zeros((vocab_dim, 1))

    # ---------- forward pass ----------

    def forward_pass(self, token_ids, h_prev):
        """Run the RNN over ``token_ids`` starting from hidden state ``h_prev``.

        Args:
            token_ids: Sequence of input token ids.
            h_prev: ``(hidden_dim, 1)`` hidden state preceding the first token.

        Returns:
            ``(x_cache, h_cache, prob_cache)``: dicts keyed by time step that
            hold the one-hot inputs, the hidden states and the softmax output
            distributions. ``h_cache[-1]`` is ``h_prev``.
        """
        x_cache, h_cache, prob_cache = {}, {}, {}
        # Key -1 lets step 0 look up its predecessor like every other step.
        h_cache[-1] = h_prev

        for step, tok_id in enumerate(token_ids):

            x_cache[step] = vector_one_hot(
                tok_id, self.vocab_dim
            )

            h_cache[step] = np.tanh(
                self.W_input @ x_cache[step]
                + self.W_hidden @ h_cache[step - 1]
                + self.b_hidden
            )

            logits = self.W_output @ h_cache[step] + self.b_output
            probs = stable_softmax(logits.ravel()).reshape(-1, 1)

            prob_cache[step] = probs

        return x_cache, h_cache, prob_cache

    # ---------- loss + gradients ----------

    def backward_pass(self, token_ids, targets, h_prev):
        """Compute the summed cross-entropy loss and its clipped gradients.

        Args:
            token_ids: Input token ids for the window.
            targets: Target token id for every position of ``token_ids``.
            h_prev: ``(hidden_dim, 1)`` hidden state preceding the window.

        Returns:
            ``(total_loss, grads, last_hidden)`` where ``total_loss`` is the
            negative log-likelihood summed (not averaged) over the window,
            ``grads`` is ``(g_Wi, g_Wh, g_Wo, g_bh, g_bo)`` with every entry
            clipped to ``[-5, 5]``, and ``last_hidden`` is the hidden state
            after the final token (``h_prev`` itself for an empty window).
        """
        xs, hs, ps = self.forward_pass(token_ids, h_prev)

        total_loss = 0.0
        for t in range(len(token_ids)):
            # The small epsilon guards against log(0).
            total_loss += -np.log(
                ps[t][targets[t], 0] + 1e-12
            )

        g_Wi = np.zeros_like(self.W_input)
        g_Wh = np.zeros_like(self.W_hidden)
        g_Wo = np.zeros_like(self.W_output)
        g_bh = np.zeros_like(self.b_hidden)
        g_bo = np.zeros_like(self.b_output)

        # Gradient flowing into the hidden state from later time steps.
        dh_next = np.zeros((self.hidden_dim, 1))

        for t in reversed(range(len(token_ids))):

            # d(loss)/d(logits) for softmax + cross-entropy: probs - one_hot.
            grad_out = ps[t].copy()
            grad_out[targets[t]] -= 1.0

            g_Wo += grad_out @ hs[t].T
            g_bo += grad_out

            dh = self.W_output.T @ grad_out + dh_next
            # Back through tanh: derivative is 1 - tanh(z)^2.
            dh_raw = (1 - hs[t] * hs[t]) * dh

            g_bh += dh_raw
            g_Wi += dh_raw @ xs[t].T
            g_Wh += dh_raw @ hs[t - 1].T

            dh_next = self.W_hidden.T @ dh_raw

        # Element-wise clipping in place.
        for grad in [g_Wi, g_Wh, g_Wo, g_bh, g_bo]:
            np.clip(grad, -5, 5, out=grad)

        last_hidden = hs[len(token_ids) - 1]

        return total_loss, (g_Wi, g_Wh, g_Wo, g_bh, g_bo), last_hidden

    # ---------- parameter update ----------

    def apply_gradients(self, grads):
        """Take one SGD step in place using ``grads`` from ``backward_pass``."""
        g_Wi, g_Wh, g_Wo, g_bh, g_bo = grads

        self.W_input -= self.lr * g_Wi
        self.W_hidden -= self.lr * g_Wh
        self.W_output -= self.lr * g_Wo
        self.b_hidden -= self.lr * g_bh
        self.b_output -= self.lr * g_bo

    # ---------- sampling ----------

    def generate_text(self, start_token, id_to_token,
                      steps=30, temp=1.0, rng=None):
        """Sample ``steps`` tokens autoregressively, starting from a zero state.

        Args:
            start_token: Id of the token fed at the first step. It is not part
                of the returned text.
            id_to_token: Mapping from token id to token string.
            steps: Number of tokens to generate.
            temp: Softmax temperature; values below ``1e-6`` are treated as
                ``1e-6``.
            rng: Optional ``numpy.random.Generator`` used for sampling, which
                makes the output reproducible. When omitted, NumPy's global
                random state is used, exactly as before this parameter existed.

        Returns:
            The generated tokens joined by single spaces.
        """
        hidden = np.zeros((self.hidden_dim, 1))
        current = vector_one_hot(start_token,
                                 self.vocab_dim)

        # Both callables accept an int population size and a ``p`` vector.
        draw = np.random.choice if rng is None else rng.choice

        output_tokens = []

        for _ in range(steps):

            hidden = np.tanh(
                self.W_input @ current
                + self.W_hidden @ hidden
                + self.b_hidden
            )

            logits = self.W_output @ hidden + self.b_output
            probs = stable_softmax(
                logits.ravel() / max(temp, 1e-6)
            )

            # Passing the int avoids rebuilding an array from range() each step.
            sampled = int(draw(self.vocab_dim, p=probs))

            output_tokens.append(id_to_token[sampled])
            current = vector_one_hot(sampled,
                                     self.vocab_dim)

        return " ".join(output_tokens)


# ----------------------------------------------------------

def scratch_rnn_notebook_run():
    """Train the NumPy RNN on the poetry corpus and print per-epoch progress.

    The corpus is read, tokenized, wrapped in ``<bos>``/``<eos>`` and encoded.
    A ``NumpyRNNCore`` is then trained for ``n_epochs`` epochs (module-level
    setting) on consecutive, non-overlapping windows of 25 tokens, carrying
    the hidden state from one window to the next within an epoch. After each
    epoch a 30-token sample is generated and printed with the average loss
    and the elapsed time.

    Returns:
        ``(timing_log, last_sample)``: the per-epoch wall-clock durations in
        seconds (each includes that epoch's sample generation) and the sample
        generated after the final epoch, or ``""`` when no epoch was run.
    """
    raw_text = read_poetry_file()
    token_stream = (
        ["<bos>"]
        + normalize_and_split(raw_text)
        + ["<eos>"]
    )

    token_to_id, id_to_token = create_vocabulary(
        token_stream, min_occurrence=1
    )

    encoded_stream = encode_tokens(
        token_stream, token_to_id
    )

    model = NumpyRNNCore(
        vocab_dim=len(token_to_id),
        hidden_dim=128,
        learning_rate=0.05
    )

    window = 25

    timing_log = []
    generated_samples = []

    for epoch_idx in range(n_epochs):

        start_time = time.perf_counter()
        cumulative_loss = 0.0
        batch_count = 0

        # Start every pass over the corpus from a clean state: the hidden
        # state left over from the end of the text is not valid context for
        # its beginning.
        hidden_state = np.zeros((model.hidden_dim, 1))

        # The target slice needs one token past the input window, so the last
        # usable start is len - window - 1; range() excludes its stop value,
        # hence the stop is len - window.
        for pos in range(
            0,
            len(encoded_stream) - window,
            window
        ):

            inp = encoded_stream[pos:pos + window]
            tgt = encoded_stream[pos + 1:pos + window + 1]

            loss, grads, hidden_state = model.backward_pass(
                inp, tgt, hidden_state
            )

            model.apply_gradients(grads)

            cumulative_loss += loss
            batch_count += 1

        # Each window's loss is a sum over its tokens, so this is the mean
        # summed loss per window rather than a per-token value.
        avg_loss = cumulative_loss / max(batch_count, 1)

        sample_text = model.generate_text(
            token_to_id["<bos>"],
            id_to_token,
            steps=30,
            temp=0.9
        )

        generated_samples.append(sample_text)

        end_time = time.perf_counter()
        timing_log.append(float(end_time - start_time))

        print(f"Epoch {epoch_idx+1} | avg loss: {avg_loss:.4f}")
        print("Sample:", sample_text)
        print(f"Epoch time: {timing_log[-1]:.3f}s")

    # With n_epochs == 0 nothing was sampled; avoid indexing an empty list.
    last_sample = generated_samples[-1] if generated_samples else ""

    return timing_log, last_sample


# Train the NumPy RNN; keep its per-epoch timings and final generated sample.
scratch_epoch_times, scratch_last_output = scratch_rnn_notebook_run()

Epoch 1 | avg loss: 301.3461
Sample: masts : in our my in our my in forgetting this in emma this in emma this in sea this may emma ' in when ' seen our this in
Epoch time: 85.784s
Epoch 2 | avg loss: 323.4304
Sample: pains tis our emma and wild our emma are wild our emma are wild our emma are wild our emma are wild our emma are wild our emma are wild
Epoch time: 87.461s
Epoch 3 | avg loss: 367.3105
Sample: days death ' wild our green and wild ; green and wild our green this wild our green this wild our green this wild our green and wild our emma
Epoch time: 94.385s
Epoch 4 | avg loss: 365.7155
Sample: tea about our , and name of , and call of , seen myself of , and name of , and told are , and call to , seen name
Epoch time: 88.021s
Epoch 5 | avg loss: 366.0527
Sample: high side , we they course had thee it have , we it have , we sky have , we it have , thee they , had name it have
Epoch time: 96.653s
Epoch 6 | avg loss: 364.6420
Sample: lance name swift , seen wild they , seen wild th

In [None]:
class OneHotSequenceDataset(Dataset):
    """Dataset of (one-hot input sequence, target id sequence) pairs.

    Sequences are stored as integer id tensors; the one-hot expansion of an
    input is produced only when the item is requested.
    """

    def __init__(self, input_sequences, target_sequences, vocab_dim):
        """Store the id sequences as ``torch.long`` tensors.

        Args:
            input_sequences: Equal-length sequences of input token ids.
            target_sequences: Matching sequences of target token ids.
            vocab_dim: Width of the one-hot encoding.
        """
        self.vocab_dim = vocab_dim
        self.inputs = torch.tensor(input_sequences, dtype=torch.long)
        self.targets = torch.tensor(target_sequences, dtype=torch.long)

    def __len__(self):
        """Number of stored sequence pairs."""
        return self.inputs.shape[0]

    def __getitem__(self, index):
        """Return ``(one_hot_inputs, target_ids)`` for one pair.

        ``one_hot_inputs`` is a float32 tensor of shape
        ``(sequence_length, vocab_dim)``.
        """
        source_ids = self.inputs[index]

        encoded = torch.zeros(
            (source_ids.shape[0], self.vocab_dim),
            dtype=torch.float32,
        )
        # Write a 1.0 into the column named by each position's token id.
        encoded.scatter_(1, source_ids[:, None], 1.0)

        return encoded, self.targets[index]


# ----------------------------------------------------------

class OneHotLanguageRNN(nn.Module):
    """Language model: ``nn.RNN`` over one-hot inputs plus a linear read-out."""

    def __init__(self, vocab_dim, hidden_dim=256):
        """Create the recurrent layer and the hidden-to-vocabulary projection.

        Args:
            vocab_dim: One-hot input width and number of output logits.
            hidden_dim: Size of the recurrent hidden state.
        """
        super().__init__()

        self.recurrent = nn.RNN(vocab_dim, hidden_dim, batch_first=True)
        self.projection = nn.Linear(
            in_features=hidden_dim, out_features=vocab_dim
        )

    def forward(self, x_encoded, h_state=None):
        """Return ``(logits, next_state)`` for a batch-first one-hot input.

        ``h_state`` is the optional initial hidden state handed to the
        recurrent layer; ``next_state`` is the state that layer returns.
        """
        hidden_sequence, next_state = self.recurrent(x_encoded, h_state)
        return self.projection(hidden_sequence), next_state


# ----------------------------------------------------------

@torch.no_grad()
def sample_sequence(model,
                    token_to_id,
                    id_to_token,
                    seed="<bos>",
                    steps=40,
                    temp=1.0):
    """Sample text from a one-hot-input language model.

    The model is switched to eval mode. The whole seed is fed on the first
    step so the hidden state reflects every seed token; each later step feeds
    only the most recently sampled token together with the carried state.

    Args:
        model: Callable as ``model(one_hot, hidden) -> (logits, hidden)`` with
            batch-first one-hot input of shape ``(1, length, vocab)``.
        token_to_id: Token-to-index mapping; must contain ``"<unk>"``.
        id_to_token: Index-to-token mapping.
        seed: Whitespace-separated seed tokens. Unknown tokens map to
            ``"<unk>"``.
        steps: Number of tokens to sample.
        temp: Softmax temperature; values below ``1e-6`` are treated as
            ``1e-6``.

    Returns:
        The seed tokens followed by the sampled tokens, joined by spaces.

    Raises:
        ValueError: If ``seed`` contains no tokens.
    """
    model.eval()

    seed_tokens = seed.split()
    if not seed_tokens:
        raise ValueError("seed must contain at least one token")

    unknown_id = token_to_id["<unk>"]
    id_stream = [
        token_to_id.get(tok, unknown_id)
        for tok in seed_tokens
    ]

    # Build inputs on the device the model lives on; fall back to the
    # notebook-wide DEVICE only for a model without parameters.
    first_param = next(model.parameters(), None)
    device = (
        first_param.device if first_param is not None
        else torch.device(DEVICE)
    )

    vocab_dim = len(token_to_id)
    hidden = None

    # First pass primes the hidden state with the entire seed.
    pending_ids = list(id_stream)

    for _ in range(steps):

        step_ids = torch.tensor(
            pending_ids,
            dtype=torch.long,
            device=device
        )

        encoded = torch.zeros(
            1, step_ids.numel(), vocab_dim,
            device=device
        )

        encoded.scatter_(2,
                         step_ids.view(1, -1, 1),
                         1.0)

        logits, hidden = model(encoded, hidden)

        # Only the prediction after the last fed token is sampled from.
        scaled_logits = logits[0, -1] / max(temp, 1e-6)
        probabilities = torch.softmax(scaled_logits, dim=0)

        sampled_id = torch.multinomial(probabilities, 1).item()
        id_stream.append(sampled_id)
        pending_ids = [sampled_id]

    words = [id_to_token[i] for i in id_stream]
    return " ".join(words)


# ----------------------------------------------------------

def run_onehot_rnn_pipeline():
    """Train the one-hot-input PyTorch RNN on the poetry corpus.

    The corpus is read, tokenized, wrapped in ``<bos>``/``<eos>`` and encoded,
    then cut into overlapping 25-token windows. An ``OneHotLanguageRNN`` is
    trained for ``n_epochs`` epochs (module-level setting) with Adam,
    cross-entropy loss and gradient-norm clipping at 1.0. After each epoch a
    40-token sample is generated and printed with the mean batch loss and the
    elapsed time.

    Returns:
        ``(timing_log, last_sample)``: the per-epoch wall-clock durations in
        seconds (each includes that epoch's sample generation) and the sample
        generated after the final epoch, or ``""`` when no epoch was run.
    """
    corpus_text = read_poetry_file()

    token_stream = (
        ["<bos>"]
        + normalize_and_split(corpus_text)
        + ["<eos>"]
    )

    token_to_id, id_to_token = create_vocabulary(
        token_stream,
        min_occurrence=1
    )

    encoded_stream = encode_tokens(
        token_stream,
        token_to_id
    )

    window = 25
    seq_inputs, seq_targets = build_sequence_pairs(
        encoded_stream,
        window
    )

    dataset = OneHotSequenceDataset(
        seq_inputs,
        seq_targets,
        vocab_dim=len(token_to_id)
    )

    # drop_last discards the final partial batch so every step sees 64 windows.
    loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
        drop_last=True
    )

    model = OneHotLanguageRNN(
        vocab_dim=len(token_to_id),
        hidden_dim=256
    ).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1e-3
    )

    criterion = nn.CrossEntropyLoss()

    print("Training One-Hot RNN on", DEVICE)

    timing_log = []
    generated_samples = []

    for epoch_idx in range(n_epochs):

        # Sampling at the end of the previous epoch left the model in eval mode.
        model.train()
        start_time = time.perf_counter()

        loss_sum = 0.0
        step_count = 0

        for batch_x, batch_y in loader:

            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)

            logits, _ = model(batch_x)

            # Flatten (batch, time) so every position is one classification.
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                batch_y.reshape(-1)
            )

            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(
                model.parameters(), 1.0
            )

            optimizer.step()

            loss_sum += loss.item()
            step_count += 1

        avg_loss = loss_sum / max(step_count, 1)
        print(f"Epoch {epoch_idx+1} | loss: {avg_loss:.4f}")

        sample_text = sample_sequence(
            model,
            token_to_id,
            id_to_token,
            seed="<bos>",
            steps=40,
            temp=0.9
        )

        generated_samples.append(sample_text)

        end_time = time.perf_counter()
        timing_log.append(float(end_time - start_time))

        print("Sample:", sample_text)
        print(f"Epoch time: {timing_log[-1]:.3f}s")

    print(
        f"Total training time (one-hot): "
        f"{sum(timing_log):.2f}s"
    )

    # With n_epochs == 0 nothing was sampled; avoid indexing an empty list.
    last_sample = generated_samples[-1] if generated_samples else ""

    return timing_log, last_sample


# Train the one-hot PyTorch RNN; keep its per-epoch timings and final sample.
onehot_epoch_times, onehot_last_output = run_onehot_rnn_pipeline()

Training One-Hot RNN on mps
Epoch 1 | loss: 6.2165
Sample: <bos> lungs . long i rich be watches . borne with the lift of balanced , and or trouble , i who , house antique to ? about the pen and first . ? d that is bride it more dissatisfied
Epoch time: 25.517s
Epoch 2 | loss: 5.0680
Sample: <bos> noon about -- ! so deny , their of circle following or rest the mine chemist when before emanations well , to give our white i ooze for the loves , you trusted i , i ' the little ,
Epoch time: 23.657s
Epoch 3 | loss: 3.9790
Sample: <bos> are your object with her between us , who is not , now we thousand come trouble than others it is limitless to stronger upon ! retreat you die for ! i love , the sign ' d by the
Epoch time: 23.614s
Epoch 4 | loss: 2.9035
Sample: <bos> to nothing , these thirty i know --great a knoll ago , no heart farewell no less part , and what is not ? never might ride . . " " o heart , i am friend powers ? )
Epoch time: 23.741s
Epoch 5 | loss: 1.8531
Sample: <bos> of the e

In [None]:
class IndexedSequenceDataset(Dataset):
    """Dataset of (input id sequence, target id sequence) tensor pairs."""

    def __init__(self, input_ids, target_ids):
        """Convert both collections of id sequences to ``torch.long`` tensors."""
        long_kwargs = dict(dtype=torch.long)
        self.inputs = torch.tensor(input_ids, **long_kwargs)
        self.targets = torch.tensor(target_ids, **long_kwargs)

    def __len__(self):
        """Number of stored sequence pairs."""
        return self.inputs.shape[0]

    def __getitem__(self, index):
        """Return the ``(input_ids, target_ids)`` pair at ``index``."""
        source_row = self.inputs[index]
        target_row = self.targets[index]
        return source_row, target_row


# ----------------------------------------------------------

class EmbeddingLanguageRNN(nn.Module):
    """Language model: token embedding, ``nn.RNN``, then a linear read-out."""

    def __init__(self, vocab_dim,
                 embedding_dim=128,
                 hidden_dim=256):
        """Create the embedding table, recurrent layer and output projection.

        Args:
            vocab_dim: Number of token ids and number of output logits.
            embedding_dim: Size of each token embedding.
            hidden_dim: Size of the recurrent hidden state.
        """
        super().__init__()

        self.embedding_layer = nn.Embedding(vocab_dim, embedding_dim)
        self.recurrent = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.output_layer = nn.Linear(
            in_features=hidden_dim, out_features=vocab_dim
        )

    def forward(self, token_ids, hidden_state=None):
        """Return ``(logits, next_state)`` for a batch-first tensor of token ids.

        ``hidden_state`` is the optional initial state handed to the recurrent
        layer; ``next_state`` is the state that layer returns.
        """
        token_vectors = self.embedding_layer(token_ids)
        hidden_sequence, next_state = self.recurrent(
            token_vectors, hidden_state
        )
        return self.output_layer(hidden_sequence), next_state


# ----------------------------------------------------------

@torch.no_grad()
def sample_embedding_model(model,
                           token_to_id,
                           id_to_token,
                           seed="<bos>",
                           steps=40,
                           temp=1.0):
    """Sample text from a language model that takes token ids as input.

    The model is switched to eval mode. The whole seed is fed on the first
    step so the hidden state reflects every seed token; each later step feeds
    only the most recently sampled token together with the carried state.

    Args:
        model: Callable as ``model(token_ids, hidden) -> (logits, hidden)``
            with batch-first ``torch.long`` ids of shape ``(1, length)``.
        token_to_id: Token-to-index mapping; must contain ``"<unk>"``.
        id_to_token: Index-to-token mapping.
        seed: Whitespace-separated seed tokens. Unknown tokens map to
            ``"<unk>"``.
        steps: Number of tokens to sample.
        temp: Softmax temperature; values below ``1e-6`` are treated as
            ``1e-6``.

    Returns:
        The seed tokens followed by the sampled tokens, joined by spaces.

    Raises:
        ValueError: If ``seed`` contains no tokens.
    """
    model.eval()

    seed_tokens = seed.split()
    if not seed_tokens:
        raise ValueError("seed must contain at least one token")

    unknown_id = token_to_id["<unk>"]
    id_stream = [
        token_to_id.get(tok, unknown_id)
        for tok in seed_tokens
    ]

    # Build inputs on the device the model lives on; fall back to the
    # notebook-wide DEVICE only for a model without parameters.
    first_param = next(model.parameters(), None)
    device = (
        first_param.device if first_param is not None
        else torch.device(DEVICE)
    )

    hidden = None

    # First pass primes the hidden state with the entire seed.
    pending_ids = list(id_stream)

    for _ in range(steps):

        step_input = torch.tensor(
            [pending_ids],
            dtype=torch.long,
            device=device
        )

        logits, hidden = model(step_input, hidden)

        # Only the prediction after the last fed token is sampled from.
        scaled = logits[0, -1] / max(temp, 1e-6)
        probs = torch.softmax(scaled, dim=0)

        next_id = torch.multinomial(probs, 1).item()
        id_stream.append(next_id)
        pending_ids = [next_id]

    return " ".join(
        id_to_token[i] for i in id_stream
    )


# ----------------------------------------------------------

def run_embedding_rnn_pipeline():
    """Train the embedding-input PyTorch RNN on the poetry corpus.

    The corpus is read, tokenized, wrapped in ``<bos>``/``<eos>`` and encoded,
    then cut into overlapping 25-token windows. An ``EmbeddingLanguageRNN`` is
    trained for ``n_epochs`` epochs (module-level setting) with Adam,
    cross-entropy loss and gradient-norm clipping at 1.0. After each epoch a
    40-token sample is generated and printed with the mean batch loss and the
    elapsed time.

    Returns:
        ``(timing_log, last_sample)``: the per-epoch wall-clock durations in
        seconds (each includes that epoch's sample generation) and the sample
        generated after the final epoch, or ``""`` when no epoch was run.
    """
    corpus_text = read_poetry_file()

    token_stream = (
        ["<bos>"]
        + normalize_and_split(corpus_text)
        + ["<eos>"]
    )

    token_to_id, id_to_token = create_vocabulary(
        token_stream,
        min_occurrence=1
    )

    encoded_stream = encode_tokens(
        token_stream,
        token_to_id
    )

    window = 25
    seq_inputs, seq_targets = build_sequence_pairs(
        encoded_stream,
        window
    )

    dataset = IndexedSequenceDataset(
        seq_inputs,
        seq_targets
    )

    # drop_last discards the final partial batch so every step sees 64 windows.
    loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
        drop_last=True
    )

    model = EmbeddingLanguageRNN(
        vocab_dim=len(token_to_id),
        embedding_dim=128,
        hidden_dim=256
    ).to(DEVICE)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=1e-3
    )

    criterion = nn.CrossEntropyLoss()

    print("Training Embedding RNN on", DEVICE)

    timing_log = []
    generated_samples = []

    for epoch_idx in range(n_epochs):

        # Sampling at the end of the previous epoch left the model in eval mode.
        model.train()
        start_time = time.perf_counter()

        loss_sum = 0.0
        step_count = 0

        for batch_x, batch_y in loader:

            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)

            logits, _ = model(batch_x)

            # Flatten (batch, time) so every position is one classification.
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                batch_y.reshape(-1)
            )

            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(
                model.parameters(), 1.0
            )

            optimizer.step()

            loss_sum += loss.item()
            step_count += 1

        avg_loss = loss_sum / max(step_count, 1)
        print(f"Epoch {epoch_idx+1} | loss: {avg_loss:.4f}")

        sample_text = sample_embedding_model(
            model,
            token_to_id,
            id_to_token,
            seed="<bos>",
            steps=40,
            temp=0.9
        )

        generated_samples.append(sample_text)

        end_time = time.perf_counter()
        timing_log.append(float(end_time - start_time))

        print("Sample:", sample_text)
        print(f"Epoch time: {timing_log[-1]:.3f}s")

    print(
        f"Total training time (embedding): "
        f"{sum(timing_log):.2f}s"
    )

    # With n_epochs == 0 nothing was sampled; avoid indexing an empty list.
    last_sample = generated_samples[-1] if generated_samples else ""

    return timing_log, last_sample


# Train the embedding PyTorch RNN; keep its per-epoch timings and final sample.
embedding_epoch_times, embedding_last_output = run_embedding_rnn_pipeline()

Training Embedding RNN on mps
Epoch 1 | loss: 5.0160
Sample: <bos> dazzle , theology—but ' things them dim-descried ' s best , but i know , perhaps what is his brother at the wife with him on brotherly and invite , not one is more far hot made ; and am
Epoch time: 14.533s
Epoch 2 | loss: 2.4822
Sample: <bos> sat and straw , over the long-leav ' d slave is lit with the iris sheen of the north , where the voice of enjoyment picks ' d , when thou art gone , my dear , i find him
Epoch time: 14.467s
Epoch 3 | loss: 1.1327
Sample: <bos> sullen , one then stand on a heap ' d tale of nature ' s immortality , a venerable thing ? of hay the best , or through those drain ' d from your eyes , my burst rest in
Epoch time: 14.395s
Epoch 4 | loss: 0.6131
Sample: <bos> text and come , with thee , thou : white peacocks , songs at eve , and antique maps of america . farewell at know that i love thee better . i could not die – here we go to
Epoch time: 14.004s
Epoch 5 | loss: 0.4282
Sample: <bos> text yo

In [None]:
# Collapse each model's per-epoch wall-clock times into a single total.
scratch_total_runtime = sum(scratch_epoch_times)
onehot_total_runtime = sum(onehot_epoch_times)
embedding_total_runtime = sum(embedding_epoch_times)

# One line per model, in the order the models were trained above.
print(f"Total training time (Scratch RNN): {scratch_total_runtime:.2f}s")
print(f"Total training time (One-Hot RNN): {onehot_total_runtime:.2f}s")
print(f"Total training time (Embedding RNN): {embedding_total_runtime:.2f}s")

Total training time (Scratch RNN): 848.27s
Total training time (One-Hot RNN): 245.56s
Total training time (Embedding RNN): 136.47s


In [None]:
# Text statistics for the final sample generated by each model.
scratch_metrics_summary = analyze_text_statistics(scratch_last_output)
onehot_metrics_summary = analyze_text_statistics(onehot_last_output)
embedding_metrics_summary = analyze_text_statistics(embedding_last_output)

# Heading and metrics dict on consecutive lines; the leading "\n" in the later
# headings leaves a blank line between models.
print("Scratch RNN Metrics:", scratch_metrics_summary, sep="\n")
print("\nOne-Hot RNN Metrics:", onehot_metrics_summary, sep="\n")
print("\nEmbedding RNN Metrics:", embedding_metrics_summary, sep="\n")

Scratch RNN Metrics:
{'token_count': 30, 'unique_tokens': 9, 'unique_ratio': 0.3, 'repeat_2gram_ratio': 0.5862068965517241, 'repeat_3gram_ratio': 0.4642857142857143, 'top_tokens': [('and', 7), ('our', 6), ('wild', 6), ('green', 5), ('of', 2), ('agony', 1), ('dear', 1), ('least', 1), ('they', 1)]}

One-Hot RNN Metrics:
{'token_count': 41, 'unique_tokens': 34, 'unique_ratio': 0.8292682926829268, 'repeat_2gram_ratio': 0.025, 'repeat_3gram_ratio': 0.0, 'top_tokens': [('the', 2), ('!', 2), ('shall', 2), ("'", 2), ('d', 2), ('yet', 2), ('and', 2), ('<bos>', 1), ('fords', 1), ('park', 1)]}

Embedding RNN Metrics:
{'token_count': 41, 'unique_tokens': 30, 'unique_ratio': 0.7317073170731707, 'repeat_2gram_ratio': 0.0, 'repeat_3gram_ratio': 0.0, 'top_tokens': [('the', 7), (',', 4), (';', 2), ('his', 2), ('<bos>', 1), ('text', 1), ('out', 1), ('from', 1), ('crowd', 1), ('steps', 1)]}
