<a href="https://colab.research.google.com/github/ek0212/comparing-rnn-lstm-transformer/blob/main/comparing_rnn_lstm_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- Required Libraries ---
import numpy as np
import torch
import torch.nn.functional as F

# --- Banner and roadmap ---
print(
    "----------------------------------------------------------------------",
    "--- Welcome to the Neural Network Sequence Processor Explorer! ---",
    "----------------------------------------------------------------------",
    "\nWe're going to see how different types of neural networks 'read' and 'remember' a simple sequence: 'a b c'.",
    "We'll explore:",
    "1. RNN (Recurrent Neural Network): Reads step-by-step, basic memory.",
    "2. LSTM (Long Short-Term Memory): Smarter step-by-step reading, better memory with 'gates'.",
    "3. Transformer (Self-Attention): Looks at the whole sequence at once to see how parts relate.",
    "\n--- Setting up our 'Ingredients' ---",
    sep="\n",
)

# Seed the generator before any tensor is drawn so every run prints the same numbers.
torch.manual_seed(42)
print(
    "\n[SETUP] Setting a 'random seed' (42). This ensures that if we run this code again,",
    "        any 'random' numbers generated (like initial weights) will be the exact same.",
    "        It's like starting a card game with the deck shuffled the same way every time for an experiment.",
    sep="\n",
)

# Lookup tables: character -> index and index -> character.
vocab = {char: index for index, char in enumerate("abc")}
vocab_inv = {index: char for char, index in vocab.items()}
print(
    "\n[SETUP] Creating a 'vocabulary'. Computers like numbers, not letters directly.",
    f"        Our vocab: {vocab} (maps letters to numbers)",
    f"        And an inverse vocab: {vocab_inv} (maps numbers back to letters)",
    sep="\n",
)

# The toy sequence "a b c", encoded through the vocabulary.
input_sequence = torch.tensor([vocab[char] for char in "abc"])
print("\n[SETUP] Our input sequence is 'a b c', represented by numbers:", input_sequence.tolist())
print("        So, 'a' is 0, 'b' is 1, 'c' is 2.")

# Model dimensions shared by the RNN, LSTM and attention sections.
vocab_size = len(vocab)
embedding_size = hidden_size = 2
print(
    f"\n[SETUP] Vocab size: {vocab_size} (we have 3 unique characters: a, b, c)",
    f"[SETUP] Embedding size: {embedding_size}. Each character will be represented by a list (vector) of {embedding_size} numbers.",
    "        Think of these as 'features' or 'traits' for each character.",
    f"[SETUP] Hidden size: {hidden_size}. This is like the 'memory capacity' or 'thinking space' for our RNN/LSTM.",
    sep="\n",
)

# Recap of the input, both as indices and as the characters they stand for.
print("\n===== Toy Input Sequence (Recap) =====")
print("Input indices:", input_sequence.tolist())
print("Input tokens:", [vocab_inv[index] for index in input_sequence.tolist()])

# One random (untrained) embedding row per vocabulary entry.
embedding_matrix = torch.randn(vocab_size, embedding_size)
print(
    "\n===== Embedding Matrix (Randomly Initialized) =====",
    "[SETUP] Creating an 'Embedding Matrix'. This is a lookup table.",
    "        Row 0 = vector for 'a', Row 1 = vector for 'b', Row 2 = vector for 'c'.",
    "        In real life, these vectors are *learned*. Here, they are just random numbers to start.",
    sep="\n",
)
print(embedding_matrix)
print(f"        For example, the random vector for 'a' (index 0) is: {embedding_matrix[0]}")

### --- RNN (Recurrent Neural Network) --- ###
# A plain RNN walks the sequence once, folding each token embedding into a
# single hidden vector: h_new = tanh(x @ W_in + h_old @ W_hh + b).
print(
    "\n\n===============================================",
    "===== RNN (Recurrent Neural Network) =====",
    "===============================================",
    "\n[RNN INFO] RNNs process sequences one item at a time, like reading a sentence word by word.",
    "           They have a 'hidden state' (memory) that tries to capture what's been seen so far.",
    "           Let's see how it processes 'a b c'.",
    sep="\n",
)

# Untrained parameters: two random projection matrices and a zero bias.
input_to_hidden_weights = torch.randn(embedding_size, hidden_size)
hidden_to_hidden_weights = torch.randn(hidden_size, hidden_size)
hidden_bias = torch.zeros(hidden_size)
print(
    "\n[RNN SETUP] Initializing RNN 'weights' and 'bias'.",
    "            These are numbers the RNN would learn if we were training it.",
    "            - 'input_to_hidden_weights': Transform the current character's vector.",
    "            - 'hidden_to_hidden_weights': Transform the previous memory.",
    "            - 'hidden_bias': A small adjustment.",
    sep="\n",
)

# The memory starts out empty (all zeros).
hidden_state = torch.zeros(hidden_size)
print("\n[RNN SETUP] Initializing the RNN's 'hidden state' (memory) to all zeros.")
print("            Before seeing any input, the memory is blank:", hidden_state)

# Feed the tokens in order; each iteration overwrites hidden_state.
for step, token_index in enumerate(input_sequence):
    symbol = vocab_inv[token_index.item()]
    embedded = embedding_matrix[token_index]  # row lookup in the embedding table

    print(f"\n--- RNN Step {step+1}: Processing Token '{symbol}' ---")
    print(f"Input vector (embedding) for '{symbol}':", embedded)
    print("Hidden state (memory from PREVIOUS step):", hidden_state)

    print(
        "\n[RNN CALC] Updating hidden state using the formula:",
        "           new_hidden = tanh( (current_input_vector @ W_input) + (previous_hidden_state @ W_hidden) + bias )",
        sep="\n",
    )
    # Input term + recurrent term + bias, then squash into (-1, 1).
    hidden_state = torch.tanh(
        embedded @ input_to_hidden_weights
        + hidden_state @ hidden_to_hidden_weights
        + hidden_bias
    )
    print(f"Updated RNN hidden state after '{symbol}':", hidden_state)
    print("           The 'tanh' function squashes the result between -1 and 1, keeping numbers manageable.")

print("\n[RNN RESULT] Final RNN hidden state after processing 'a b c':", hidden_state)
print(
    "             This vector is the RNN's attempt to summarize the entire sequence 'a b c'.",
    "             A challenge for RNNs: they can struggle to remember things from far back in long sequences.",
    sep="\n",
)

### --- LSTM (Long Short-Term Memory) --- ###
# One fused weight matrix produces the raw values for the input gate, forget
# gate, cell candidate and output gate; the cell state carries the long-term
# memory and the hidden state is its gated read-out.
print(
    "\n\n=====================================================",
    "===== LSTM (Long Short-Term Memory) =====",
    "=====================================================",
    "\n[LSTM INFO] LSTMs are a type of RNN, but with a more sophisticated memory system.",
    "            They use 'gates' to control the flow of information, helping them remember things for longer.",
    "            The key components are:",
    "            - A 'cell state' (C_t): The long-term memory conveyor belt.",
    "            - Forget Gate (f_t): Decides what old information to discard from C_t.",
    "            - Input Gate (i_t): Decides what new information to store in C_t.",
    "            - Output Gate (o_t): Decides what part of C_t to output as the current hidden state (h_t).",
    "\n[LSTM INFO] IMPORTANT: In this demo, the LSTM's weights are random, so it doesn't 'know' what to forget yet.",
    "             It 'learns' what's important to forget/remember during a separate 'training' phase, not shown here.",
    "             We're just seeing the *mechanics* of how the gates operate with some random weights.",
    sep="\n",
)

# Fused parameters: the rows take [token embedding ; previous hidden state],
# the columns are four hidden_size-wide slices (input, forget, candidate, output).
lstm_weights = torch.randn(embedding_size + hidden_size, 4 * hidden_size)
lstm_bias = torch.zeros(4 * hidden_size)
print(
    "\n[LSTM SETUP] Initializing LSTM 'weights' and 'bias'. LSTMs have more weights because of their gates.",
    "             These weights combine the current input and previous hidden state to control 4 parts.",
    sep="\n",
)

# Both recurrent states start at zero (h_0 and C_0).
hidden_state = torch.zeros(hidden_size)
cell_state = torch.zeros(hidden_size)
print("\n[LSTM SETUP] Initializing LSTM 'hidden state' (h_t, short-term output) and 'cell state' (C_t, long-term memory) to zeros.")
print("            Initial hidden state (h_0):", hidden_state)
print("            Initial cell state (C_0):", cell_state)

for step, token_index in enumerate(input_sequence):
    symbol = vocab_inv[token_index.item()]
    embedded = embedding_matrix[token_index]

    # Show the state carried in from the previous step before anything is updated.
    print(f"\n--- LSTM Step {step+1}: Processing Token '{symbol}' ---")
    print(f"Input vector (embedding for '{symbol}'):", embedded)
    print(f"Hidden state from PREVIOUS step (h_{{{step}}}):", hidden_state)
    print(f"Cell state (long-term memory) from PREVIOUS step (C_{{{step}}}):", cell_state)

    # All of the step's tensor math is done first; the narration below only reads the results.
    fused_input = torch.cat([embedded, hidden_state], dim=0)
    raw_gates = fused_input @ lstm_weights + lstm_bias
    # Slice order: input gate, forget gate, cell candidate, output gate.
    raw_input, raw_forget, raw_candidate, raw_output = raw_gates.split(hidden_size, dim=0)

    forget_gate = torch.sigmoid(raw_forget)       # share of the old cell state to keep
    input_gate = torch.sigmoid(raw_input)         # share of the candidate to admit
    cell_candidate = torch.tanh(raw_candidate)    # proposed new content
    output_gate = torch.sigmoid(raw_output)       # share of the cell state to expose

    forgotten_part = forget_gate * cell_state
    new_info_part = input_gate * cell_candidate
    cell_state = forgotten_part + new_info_part
    hidden_state = output_gate * torch.tanh(cell_state)

    print(
        "\n[LSTM CALC] Combining current input vector and previous hidden state (h_t-1).",
        "[LSTM CALC] Calculating values for all 4 gates/parts using the combined vector and LSTM weights.",
        "\n[LSTM CALC] Applying activation functions to gate values:",
        "            - 'sigmoid' (outputs 0 to 1) is used for gates. This is key!",
        "              A value near 0 means 'close the gate / forget / ignore'.",
        "              A value near 1 means 'open the gate / remember / allow'.",
        "            - 'tanh' (outputs -1 to 1) is used for new candidate values to add to memory.",
        sep="\n",
    )

    # --- Forget Gate (f_t) ---
    print("\nForget Gate (f_t) values:", forget_gate)
    print(
        "           Interpretation: Each number (0 to 1) decides how much of the corresponding",
        f"           part of the *old* cell state (C_{{{step}}}) to keep. 0 = forget, 1 = keep.",
        sep="\n",
    )

    # --- Input Gate (i_t) & Cell Candidate (g_t) ---
    print("\nInput Gate (i_t) values:", input_gate)
    print("           Interpretation: Decides how much of the 'new candidate info' (below) to add to memory.")
    print("Cell Candidate (g_t) values (new potential info):", cell_candidate)
    print("           Interpretation: These are new values that *could* be added to the cell state.")

    # --- Output Gate (o_t) ---
    print("\nOutput Gate (o_t) values:", output_gate)
    print(
        "           Interpretation: Decides how much of the *newly updated* cell state (C_t) to output as the hidden state (h_t).",
        "\n[LSTM CALC] Updating the cell state (C_t) - the LSTM's long-term memory:",
        f"           Formula: C_{{{step+1}}} = (Forget_Gate * C_{{{step}}}) + (Input_Gate * Cell_Candidate)",
        "           Breaking it down:",
        f"             1. `Forget_Gate * C_{{{step}}}`: {forgotten_part} (Old memory filtered by forget gate)",
        f"             2. `Input_Gate * Cell_Candidate`: {new_info_part} (New candidate info filtered by input gate)",
        sep="\n",
    )
    print(f"Updated cell state (C_{{{step+1}}}):", cell_state)
    print(
        f"           This new cell state C_{{{step+1}}} now holds a combination of filtered old memories and filtered new information.",
        "\n[LSTM CALC] Updating the hidden state (h_t) - the output for this step:",
        f"           Formula: h_{{{step+1}}} = Output_Gate * tanh(C_{{{step+1}}})",
        sep="\n",
    )
    print(f"Updated hidden state (h_{{{step+1}}}):", hidden_state)
    print(f"           The hidden state h_{{{step+1}}} is a filtered version of the LSTM's internal long-term memory (C_{{{step+1}}}).")


print("\n[LSTM RESULT] Final LSTM hidden state (h_N) after processing 'a b c':", hidden_state)
print("[LSTM RESULT] Final LSTM cell state (C_N) after processing 'a b c':", cell_state)
print(
    "              LSTMs, through their gates (especially the forget gate), can learn what information",
    "              to retain or discard over long sequences. This 'learning' happens during training,",
    "              where weights are adjusted to minimize errors on a specific task.",
    sep="\n",
)

### --- Transformer Self-Attention --- ###
# Single-head scaled dot-product self-attention over the whole sequence at
# once: every token's query is compared against every token's key, the scores
# are softmax-normalised per row, and the result weights the value vectors.
print("\n\n===============================================")
print("===== Transformer Self-Attention =====")
print("===============================================")
print("\n[TRANSFORMER INFO] Transformers (specifically, their 'self-attention' mechanism) work differently.")
print("                   Instead of step-by-step, they look at ALL tokens in the sequence AT ONCE.")
print("                   Self-attention helps each token figure out how relevant other tokens are to it.")
print("                   Analogy: When reading, you might glance at other words in a sentence to understand a specific word's context.")

# Look up every token's embedding in one indexing operation (one row per token).
input_vectors = embedding_matrix[input_sequence]
print("\n[TRANSFORMER SETUP] First, get the embedding vectors for our entire input sequence 'a b c' at once.")
print("Input vectors (embeddings for 'a', 'b', 'c'):\n", input_vectors)

# Random (untrained) projection matrices; drawn in query, key, value order.
W_query = torch.randn(embedding_size, embedding_size)
W_key = torch.randn(embedding_size, embedding_size)
W_value = torch.randn(embedding_size, embedding_size)
print("\n[TRANSFORMER SETUP] For each input token, we create three special versions using learned weights:")
print("                   1. Query (Q): What is this token 'looking for' in other tokens?")
print("                   2. Key (K): What kind of information does this token 'represent' or 'offer'?")
print("                   3. Value (V): What actual information does this token 'contribute' if it's deemed relevant?")

# Project the same inputs three ways.
queries = input_vectors @ W_query
keys = input_vectors @ W_key
values = input_vectors @ W_value

print("\n[TRANSFORMER CALC] Generating Query, Key, and Value vectors for each token:")
print("Query vectors (Q) - one row per token ('a','b','c'):\n", queries)
print("Key vectors (K) - one row per token:\n", keys)
print("Value vectors (V) - one row per token:\n", values)

print("\n[TRANSFORMER CALC] Calculating 'Attention Scores':")
print("                   How much should token_i 'pay attention' to token_j?")
print("                   This is done by: (Query_i @ Key_j.transposed) / sqrt(dimension_of_key)")
# Take the scale from the key vectors themselves so the code matches the
# formula printed above even if W_key's output width is changed, and keep the
# arithmetic in plain Python/torch rather than routing a NumPy scalar through it.
key_dimension = keys.shape[-1]
attention_scores = (queries @ keys.T) / key_dimension ** 0.5
print("Attention scores (raw, after scaling):\n", attention_scores)
print("Interpretation: attention_scores[row_i, col_j] = score of token_i attending to token_j.")
print("e.g., scores[0,1] is 'a' attending to 'b'. Scaling helps stabilize numbers.")

print("\n[TRANSFORMER CALC] Converting scores to 'Attention Weights' (probabilities) using Softmax.")
print("                   Softmax makes sure that for each token, its attention to all other tokens sums up to 1 (like percentages).")
# Normalise along the last axis: each query's row of weights sums to 1.
attention_weights = F.softmax(attention_scores, dim=-1)
print("Attention weights (after softmax):\n", attention_weights)
print("Interpretation: weights[row_i, col_j] = how much % token_i focuses on token_j.")

print("\n[TRANSFORMER CALC] Creating new, 'context-aware' representations for each token.")
print("                   Each token's new vector = weighted sum of ALL Value vectors.")
print("                   The weights are the attention_weights we just calculated.")
# Row i of the output is the attention-weighted mix of all value rows for token i.
attention_output = attention_weights @ values
print("\nAttention output vectors (new representations for 'a','b','c' based on context):\n", attention_output)
print("Each row is a new vector for 'a', 'b', or 'c', now enriched with context from other tokens it 'attended' to.")
print("Unlike RNN/LSTM's single final state, Attention gives an updated representation for *each* token.")


### --- Summary --- ###
# Closing recap: one print call per architecture, each line its own argument.
print(
    "\n\n==========================",
    "===== Quick Summary =====",
    "==========================",
    sep="\n",
)

print(
    "\n--- RNN (Recurrent Neural Network) ---",
    "- Reads input one piece at a time (sequential).",
    "- Uses a 'hidden state' as its memory, updated at each step.",
    "- Analogy: Reading a book word-by-word, trying to remember the plot.",
    "- Challenge: Can struggle with long-term memory (forgetting early parts of long sequences).",
    sep="\n",
)

print(
    "\n--- LSTM (Long Short-Term Memory) ---",
    "- Also sequential, but with a smarter memory system using 'gates' (input, forget, output).",
    "- Has a 'cell state' for better long-term memory retention. The 'forget gate' specifically learns",
    "  what old information in the cell state is no longer relevant and should be discarded.",
    "- Analogy: Reading with a good note-taking system – deciding what to write down, what to erase (forget gate's job!), and what to refer to.",
    "- Advantage: Better at handling long sequences and remembering important details over time than basic RNNs because it can learn to manage its memory.",
    sep="\n",
)

print(
    "\n--- Transformer (Self-Attention) ---",
    "- Looks at ALL input tokens at once (parallel processing).",
    "- 'Self-attention' lets each token weigh the importance of all other tokens (including itself) to understand its context.",
    "- Uses Query, Key, Value mechanism to calculate these attention weights.",
    "- Analogy: Looking at an entire picture at once, and for each object, seeing how it relates to all other objects in the scene.",
    "- Advantages: Often better for long sequences, can capture complex relationships, good for parallel computation (training faster).",
    sep="\n",
)

# Sign-off banner.
print(
    "\n\n----------------------------------------------------------------------",
    "--- End of Demo! ---",
    "----------------------------------------------------------------------",
    "This was a tiny peek. Real models are much bigger and are 'trained' on lots of data to learn their weights,",
    "including how the LSTM's forget gate should behave for specific tasks!",
    sep="\n",
)

----------------------------------------------------------------------
--- Welcome to the Neural Network Sequence Processor Explorer! ---
----------------------------------------------------------------------

We're going to see how different types of neural networks 'read' and 'remember' a simple sequence: 'a b c'.
We'll explore:
1. RNN (Recurrent Neural Network): Reads step-by-step, basic memory.
2. LSTM (Long Short-Term Memory): Smarter step-by-step reading, better memory with 'gates'.
3. Transformer (Self-Attention): Looks at the whole sequence at once to see how parts relate.

--- Setting up our 'Ingredients' ---

[SETUP] Setting a 'random seed' (42). This ensures that if we run this code again,
        any 'random' numbers generated (like initial weights) will be the exact same.
        It's like starting a card game with the deck shuffled the same way every time for an experiment.

[SETUP] Creating a 'vocabulary'. Computers like numbers, not letters directly.
        Our vocab: 