In [1]:
import torch
from torch import nn
import math
from google.colab import files
import torch.nn.functional as F

In [3]:
# Path of the text file to load; the error message below hints at a Colab upload / Drive mount.
file_path = "/content/pride_and_prejudice.txt"

try:
  with open(file_path, "r", encoding='utf-8') as f:
    text = f.read()
except FileNotFoundError:
  # `text` is not assigned on this path, so later cells that read it will fail.
  print(f"Error: File not found at {file_path}. Did you upload it or mount Google Drive correctly?")
else:
  # Only reached when the file was opened and read without a FileNotFoundError.
  print(f"Successfully read file from: {file_path}")
  print(f"Dataset length: {len(text)} Characters")

Successfully read file from: /content/pride_and_prejudice.txt
Dataset length: 748135 Characters


In [4]:
@torch.no_grad()
def generate_text(model, output_layer, stoi, itos, start_char="H",
                  length=300, temperature=1.0, device="cpu"):
    """
    Sample text one character at a time from a recurrent model.

    Args:
        model: Called as ``model(x, H_C)`` and expected to return
            ``(outputs, H_C)``. It is switched to eval mode and left there.
        output_layer: Maps the last element of ``outputs`` to vocabulary logits.
        stoi: Mapping from character to integer index.
        itos: Mapping from integer index to character.
        start_char: Seed character; must be a key of ``stoi``.
        length: Number of characters to sample after the seed.
        temperature: Strictly positive divisor applied to the logits before
            the softmax.
        device: Device on which the input tensors are created.

    Returns:
        The seed character followed by ``length`` sampled characters.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
        KeyError: If ``start_char`` is not in ``stoi``.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    model.eval()
    vocab_size = len(stoi)

    def encode(index):
        # (1,) index -> (num_steps=1, batch_size=1, vocab_size) one-hot input.
        # A single unsqueeze is enough; a second one would produce a 4-D tensor
        # that only works through accidental broadcasting inside the model.
        return F.one_hot(index, num_classes=vocab_size).float().unsqueeze(0)

    idx = torch.tensor([stoi[start_char]], device=device)
    x = encode(idx)

    H_C = None  # let the model create its own initial state on the first step
    generated = [start_char]

    for _ in range(length):
        outputs, H_C = model(x, H_C)

        # Flatten to (vocab_size,); a single sequence is generated at a time.
        logits = output_layer(outputs[-1]).reshape(-1)
        probs = F.softmax(logits / temperature, dim=-1)

        idx = torch.multinomial(probs, 1)       # (1,)
        generated.append(itos[idx.item()])

        # Feed the sampled character back in as the next input.
        x = encode(idx)

    return "".join(generated)





# @torch.no_grad()
# def generate_text(model, output_layer, stoi, itos, start_char="H",
#                   length=300, temperature=1.0, device="cpu"):
#     model.eval()

#     idx = torch.tensor([stoi[start_char]], device=device)
#     x = F.one_hot(idx, num_classes=len(stoi)).float().unsqueeze(0).unsqueeze(1)

#     H_C = None
#     generated = [start_char]

#     for _ in range(length):
#         outputs, H_C = model(x, H_C)
#         logits = output_layer(outputs[-1]).squeeze(0)
#         logits /= temperature
#         probs = F.softmax(logits, dim=-1)

#         idx = torch.multinomial(probs, 1)
#         char = itos[idx.item()]
#         generated.append(char)

#         x = F.one_hot(idx, num_classes=len(stoi)).float().unsqueeze(0).unsqueeze(1)

#     return "".join(generated)

In [6]:
class LSTMScratch(nn.Module):
    """
    A minimal, from-scratch implementation of the Long Short-Term Memory (LSTM) unit.
    The implementation handles the forward pass for an entire sequence.
    """
    def __init__(self, num_inputs: int, num_hiddens: int, sigma: float = 0.01):
        """
        Initializes all weight matrices and bias vectors for the four gates/nodes:
        Input (I), Forget (F), Output (O), and Candidate Cell (C_tilde).

        Args:
            num_inputs: Size of each input vector.
            num_hiddens: Size of the hidden state and of the cell state.
            sigma: Scale applied to the standard-normal weight initialisation.
        """
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_inputs = num_inputs

        def init_weights_and_bias():
            # Input-to-hidden weights, recurrent hidden-to-hidden weights, zero bias.
            W_x = nn.Parameter(torch.randn(num_inputs, num_hiddens) * sigma)
            W_h = nn.Parameter(torch.randn(num_hiddens, num_hiddens) * sigma)
            b = nn.Parameter(torch.zeros(num_hiddens))
            return W_x, W_h, b

        # Input gate: how much of the candidate value is written to the cell state.
        self.W_xi, self.W_hi, self.b_i = init_weights_and_bias()

        # Forget gate: how much of the previous cell state is retained.
        self.W_xf, self.W_hf, self.b_f = init_weights_and_bias()

        # Output gate: how much of the cell state is exposed as the hidden state.
        self.W_xo, self.W_ho, self.b_o = init_weights_and_bias()

        # Candidate cell: the new information proposed for the cell state.
        self.W_xc, self.W_hc, self.b_c = init_weights_and_bias()

    def forward(self, inputs: torch.Tensor,
                H_C: "tuple[torch.Tensor, torch.Tensor] | None" = None):
        """
        Performs the forward pass over the sequence.

        Args:
            inputs: Tensor of shape (num_steps, batch_size, num_inputs).
            H_C: Optional tuple (H, C) of initial hidden state and cell state.
                When omitted, both start as zeros matching the dtype and
                device of ``inputs``.

        Returns:
            A tuple (outputs, (H, C)) containing:
            - outputs: List with the hidden state H of every time step.
            - (H, C): The final hidden and cell states of the sequence.
        """
        if H_C is None:
            batch_size = inputs.shape[1]
            # new_zeros inherits dtype and device from `inputs`, so the default
            # state also works for non-float32 models (e.g. after .double()).
            H = inputs.new_zeros((batch_size, self.num_hiddens))
            C = inputs.new_zeros((batch_size, self.num_hiddens))
        else:
            H, C = H_C

        outputs = []

        # Iterating a tensor walks its first dimension: one time step per pass.
        for X in inputs:
            input_gate = torch.sigmoid(
                torch.matmul(X, self.W_xi) + torch.matmul(H, self.W_hi) + self.b_i)

            # Not named `F`, to avoid shadowing the usual torch.nn.functional alias.
            forget_gate = torch.sigmoid(
                torch.matmul(X, self.W_xf) + torch.matmul(H, self.W_hf) + self.b_f)

            output_gate = torch.sigmoid(
                torch.matmul(X, self.W_xo) + torch.matmul(H, self.W_ho) + self.b_o)

            candidate = torch.tanh(
                torch.matmul(X, self.W_xc) + torch.matmul(H, self.W_hc) + self.b_c)

            # Cell state: keep part of the old memory, add part of the candidate.
            C = forget_gate * C + input_gate * candidate

            # Hidden state: gated, squashed view of the cell state.
            H = output_gate * torch.tanh(C)

            outputs.append(H)

        return outputs, (H, C)

def load_text(path):
    """Read the file at ``path`` as UTF-8 text and return its full contents."""
    with open(path, "r", encoding='utf-8') as source:
        contents = source.read()
    return contents

def build_vocab(text):
    """
    Build character/index lookup tables from ``text``.

    Returns:
        A tuple ``(stoi, itos, vocab_size)`` where ``stoi`` maps each distinct
        character to its rank in sorted order, ``itos`` is the inverse mapping,
        and ``vocab_size`` is the number of distinct characters.
    """
    alphabet = sorted(set(text))
    itos = dict(enumerate(alphabet))
    stoi = {symbol: index for index, symbol in itos.items()}
    return stoi, itos, len(alphabet)

def one_hot(indices, vocab_size):
    """
    One-hot encode ``indices`` by selecting rows of a ``vocab_size`` identity matrix.

    The result has the shape of ``indices`` plus a trailing ``vocab_size``
    dimension; the identity matrix is created with torch's defaults.
    """
    identity = torch.eye(vocab_size)
    return identity[indices]

In [7]:
# Training loop: fits the LSTM and its output layer, reporting perplexity and a text sample each epoch.
def train(text, epochs=20, seq_len=40, hidden_size=128, lr=0.003):
    """
    Train a character-level LSTM language model on ``text``.

    The first 90% of the characters are used for training and the rest for
    validation. Each epoch prints periodic batch perplexities, the epoch's
    train/validation perplexity, and a generated text sample.

    Args:
        text: Training corpus as a single string.
        epochs: Number of passes over the training split.
        seq_len: Characters per chunk (truncated back-propagation length).
        hidden_size: Hidden-state size of the LSTM.
        lr: Adam learning rate.

    Returns:
        None. The trained model and output layer are not returned.

    Raises:
        ValueError: If either split is too short to produce one chunk of
            ``seq_len`` inputs plus the shifted targets.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    stoi, itos, vocab_size = build_vocab(text)
    data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

    # ---- train / validation split ----
    split = int(0.9 * len(data))
    train_data = data[:split]
    val_data = data[split:]

    # Fail fast instead of dividing by a zero token count after a full epoch.
    if len(train_data) <= seq_len or len(val_data) <= seq_len:
        raise ValueError(
            f"text is too short for seq_len={seq_len}: "
            f"train split has {len(train_data)} chars, "
            f"validation split has {len(val_data)} chars"
        )

    model = LSTMScratch(vocab_size, hidden_size).to(device)
    output_layer = nn.Linear(hidden_size, vocab_size).to(device)

    # One shared list so the optimizer and gradient clipping cover the same parameters.
    params = list(model.parameters()) + list(output_layer.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)

    # Summed (not averaged) loss so per-token loss can be accumulated exactly
    # across chunks and turned into a perplexity at the end of the epoch.
    loss_fn = nn.CrossEntropyLoss(reduction="sum")

    def chunks(split_data):
        # Yield (inputs, targets): inputs are one-hot with shape
        # (seq_len, 1, vocab_size); targets are the inputs shifted by one char.
        # The last valid start is len - seq_len - 1, hence the exclusive bound.
        for i in range(0, len(split_data) - seq_len, seq_len):
            x = one_hot(split_data[i:i + seq_len], vocab_size).to(device).unsqueeze(1)
            y = split_data[i + 1:i + seq_len + 1].to(device)
            yield x, y

    def forward_loss(x, y, H_C):
        # Run one chunk through the LSTM and output layer; return the summed loss
        # and the final recurrent state.
        outputs, H_C = model(x, H_C)
        logits = output_layer(torch.stack(outputs)).squeeze(1)
        return loss_fn(logits, y), H_C

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_tokens = 0
        H_C = None

        for step, (x, y) in enumerate(chunks(train_data)):
            loss, H_C = forward_loss(x, y, H_C)
            # Carry the state values into the next chunk but cut the graph, so
            # back-propagation stops at the chunk boundary.
            H_C = (H_C[0].detach(), H_C[1].detach())

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()

            total_loss += loss.item()
            total_tokens += y.numel()

            if step % 500 == 0:
                print(
                    f"Epoch {epoch:02d} | Step {step:06d} | "
                    f"Batch ppl: {math.exp(loss.item()/y.numel()):.2f}",
                    flush=True
                )

        train_ppl = math.exp(total_loss / total_tokens)

        # ---- VALIDATION ----
        model.eval()
        val_loss = 0.0
        val_tokens = 0
        H_C = None

        with torch.no_grad():
            for x, y in chunks(val_data):
                loss, H_C = forward_loss(x, y, H_C)
                val_loss += loss.item()
                val_tokens += y.numel()

        val_ppl = math.exp(val_loss / val_tokens)

        print(f"\nEpoch {epoch:02d} DONE")
        print(f"Train Perplexity: {train_ppl:.2f}")
        print(f"Valid Perplexity: {val_ppl:.2f}")

        sample = generate_text(
            model, output_layer, stoi, itos,
            start_char=text[0],
            device=device
        )
        print("\n--- Sample ---")
        print(sample)
        print("-" * 60)

In [None]:
# Run training on the loaded corpus: 5 epochs, 40-character chunks, 128 hidden units, lr 0.003.
train(text, epochs=5, seq_len=40, hidden_size=128, lr=0.003)

Epoch 00 | Step 000000 | Batch ppl: 97.40
Epoch 00 | Step 000500 | Batch ppl: 11.93
Epoch 00 | Step 001000 | Batch ppl: 11.84
Epoch 00 | Step 001500 | Batch ppl: 9.64
Epoch 00 | Step 002000 | Batch ppl: 9.57
Epoch 00 | Step 002500 | Batch ppl: 8.29
Epoch 00 | Step 003000 | Batch ppl: 5.80
Epoch 00 | Step 003500 | Batch ppl: 8.41
Epoch 00 | Step 004000 | Batch ppl: 6.12
Epoch 00 | Step 004500 | Batch ppl: 5.51
Epoch 00 | Step 005000 | Batch ppl: 4.71
Epoch 00 | Step 005500 | Batch ppl: 4.24
Epoch 00 | Step 006000 | Batch ppl: 5.40
Epoch 00 | Step 006500 | Batch ppl: 4.55
Epoch 00 | Step 007000 | Batch ppl: 3.60
Epoch 00 | Step 007500 | Batch ppl: 5.61
Epoch 00 | Step 008000 | Batch ppl: 4.61
Epoch 00 | Step 008500 | Batch ppl: 3.02
Epoch 00 | Step 009000 | Batch ppl: 6.17
Epoch 00 | Step 009500 | Batch ppl: 4.46
Epoch 00 | Step 010000 | Batch ppl: 4.70
Epoch 00 | Step 010500 | Batch ppl: 6.93
Epoch 00 | Step 011000 | Batch ppl: 5.19
Epoch 00 | Step 011500 | Batch ppl: 4.62
Epoch 00 | St