# Long Short-Term Memory RNNs (LSTMs)

## Definition
Long Short-Term Memory networks (LSTMs) are recurrent neural networks designed to model long-term dependencies in sequential data. Introduced by Hochreiter and Schmidhuber in 1997, LSTMs mitigate the vanishing-gradient problem of traditional RNNs with a gating mechanism that regulates how information is written to, kept in, and read from a persistent cell state. The forget gate used in the formulation below was added later by Gers et al. (2000).

## Mathematical Formulation

The LSTM cell operates through the following equations at time step $t$:

**Input Gate:**
$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$

**Forget Gate:**
$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

**Candidate Cell State:**
$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$

**Cell State Update:**
$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

**Output Gate:**
$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$

**Hidden State Update:**
$$h_t = o_t \odot \tanh(C_t)$$

Where:
- $x_t$ represents the input vector at time $t$
- $h_{t-1}$ is the previous hidden state
- $C_{t-1}$ is the previous cell state
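- $[h_{t-1}, x_t]$ denotes the concatenation of the previous hidden state and the current input
- $W_i, W_f, W_C, W_o$ are trainable weight matrices applied to that concatenation
- $b_i, b_f, b_C, b_o$ are trainable bias vectors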
- $\sigma$ denotes the sigmoid function
- $\odot$ represents element-wise multiplication
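
The concatenated form $W \cdot [h_{t-1}, x_t]$ used in these equations is equivalent to keeping separate input-to-hidden and hidden-to-hidden matrices, which is how PyTorch stores them. The sketch below is a minimal check of that equivalence, assuming PyTorch's `nn.LSTMCell` and its gate ordering (input, forget, candidate, output). The helper name `lstm_step_concat` and the tensor sizes are illustrative choices, not part of the original material.

```python
import torch
import torch.nn as nn

def lstm_step_concat(x_t, h_prev, c_prev, W, b):
    """One LSTM step in the concatenated form W . [h_{t-1}, x_t] + b.

    W has shape (4 * hidden, hidden + input) and stacks W_i, W_f, W_C, W_o
    row-wise; b has shape (4 * hidden,).
    """
    z = torch.cat([h_prev, x_t], dim=1) @ W.t() + b
    i_t, f_t, g_t, o_t = z.chunk(4, dim=1)
    i_t, f_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.sigmoid(o_t)
    g_t = torch.tanh(g_t)                # candidate cell state
    c_t = f_t * c_prev + i_t * g_t       # cell state update
    h_t = o_t * torch.tanh(c_t)          # hidden state update
    return h_t, c_t

torch.manual_seed(0)
input_size, hidden_size, batch_size = 3, 5, 2   # illustrative sizes
ref = nn.LSTMCell(input_size, hidden_size)

# Build the concatenated parameters from PyTorch's split parameters.
W = torch.cat([ref.weight_hh, ref.weight_ih], dim=1).detach()
b = (ref.bias_ih + ref.bias_hh).detach()

x_t = torch.randn(batch_size, input_size)
h_prev = torch.randn(batch_size, hidden_size)
c_prev = torch.randn(batch_size, hidden_size)

h_t, c_t = lstm_step_concat(x_t, h_prev, c_prev, W, b)
with torch.no_grad():
    h_ref, c_ref = ref(x_t, (h_prev, c_prev))

print(torch.allclose(h_t, h_ref, atol=1e-5), torch.allclose(c_t, c_ref, atol=1e-5))
```

Both comparisons should agree up to floating-point rounding, since the two forms compute the same affine map with the columns arranged differently.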

## Core Principles

### Cell Architecture
- **Memory Cell (Cell State)**: Central component that acts as an information highway running through the sequence
- **Three Gates**: Input, forget, and output gates that regulate information flow
- **Controlled Information Flow**: Selective reading, writing, and forgetting operations

### Gating Mechanisms
- **Forget Gate**: Determines what information to discard from the cell state
- **Input Gate**: Controls what new information to incorporate into the cell state
- **Output Gate**: Filters what parts of the cell state to output as the hidden state

## How LSTM Solves Vanishing Gradients

### The Problem
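In standard RNNs, gradients flowing backward through time are multiplied at every step by the recurrent weight matrix and the derivative of the activation function. When the norm of this per-step factor is below 1, the gradient shrinks exponentially with the number of steps (and when it is above 1, it can explode), causing: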
- Loss of long-range dependencies
- Stalled learning for early time steps
- Unstable training dynamics
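
The decay can be observed directly with autograd. The following is a minimal sketch under simplifying assumptions that are not in the original text: an input-free recurrence $h_t = \tanh(W h_{t-1})$, and a recurrent matrix built as a random orthogonal matrix scaled by 0.5 so that its spectral norm is exactly 0.5. The function name `grad_norm_through_time` is hypothetical.

```python
import torch

def grad_norm_through_time(num_steps, hidden_size=8, scale=0.5, seed=0):
    """Norm of d(sum(h_T)) / d(h_0) for the recurrence h_t = tanh(W h_{t-1})."""
    gen = torch.Generator().manual_seed(seed)
    # Random orthogonal matrix scaled so that its spectral norm equals `scale`.
    q, _ = torch.linalg.qr(torch.randn(hidden_size, hidden_size, generator=gen))
    W = scale * q
    h0 = torch.randn(hidden_size, generator=gen, requires_grad=True)
    h = h0
    for _ in range(num_steps):
        h = torch.tanh(W @ h)
    h.sum().backward()
    return h0.grad.norm().item()

for T in (1, 5, 10, 20, 40):
    print(f"T={T:2d}  grad norm={grad_norm_through_time(T):.3e}")
```

Because the tanh derivative never exceeds 1, each backward step shrinks the gradient by at least the factor 0.5 in this setup, so the printed norms fall off at least as fast as $0.5^T$.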

### LSTM's Solutions

#### 1. Constant Error Carousel (CEC)
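- The cell state provides a nearly uninterrupted gradient pathway through time
- Along the direct cell-state path (ignoring the indirect dependence of the gates on $h_{t-1}$), the Jacobian is:
$$\frac{\partial C_t}{\partial C_{t-1}} = \mathrm{diag}(f_t)$$

- When $f_t \approx 1$, gradients flow backward along this path with minimal decay; over $k$ steps they are scaled element-wise by $f_t \odot f_{t-1} \odot \cdots \odot f_{t-k+1}$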
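
The forget-gate relation above can be checked numerically. This is a minimal sketch, assuming a scalar cell state, a forget gate held at a constant value, and a constant written term standing in for $i_t \odot \tilde{C}_t$; the name `cell_state_gradient` is hypothetical.

```python
import torch

def cell_state_gradient(forget_value, num_steps):
    """d C_T / d C_0 along the direct cell-state path with a constant forget gate."""
    c0 = torch.tensor(1.0, requires_grad=True)
    f = torch.tensor(forget_value)
    write = torch.tensor(0.1)   # stands in for i_t * C~_t, held constant here
    c = c0
    for _ in range(num_steps):
        c = f * c + write       # C_t = f_t * C_{t-1} + i_t * C~_t
    c.backward()
    return c0.grad.item()

for f in (0.5, 0.9, 0.99, 1.0):
    print(f"f={f:<4}  dC_50/dC_0 = {cell_state_gradient(f, 50):.4e}")
```

In this simplified setting the gradient equals the forget value raised to the number of steps, so it stays at 1 when the gate is fully open and decays geometrically otherwise.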

#### 2. Additive Update Structure
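- A standard RNN recomputes its hidden state at every step by passing the previous state through a weight matrix and a squashing non-linearity; an LSTM instead adds new content to a gated copy of the previous cell state:
$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$
- The backward pass along the cell state is therefore scaled only by $f_t$, not by a weight matrix and an activation derivative at every step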

#### 3. Gated Information Flow
- Gates control which information persists through time
- Forget-gate biases are often initialized to a positive value (commonly 1), so that $f_t$ starts close to 1 and the cell-state update is close to an identity mapping early in training
- This establishes gradient highways that prevent vanishing
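
The forget-gate bias initialization mentioned above might look as follows for PyTorch's built-in `nn.LSTM`. This is a sketch, assuming PyTorch's parameter layout, in which each bias vector packs the gates in the order input, forget, candidate, output, and in which `bias_ih` and `bias_hh` are summed. The helper `init_forget_bias` is a hypothetical name.

```python
import torch
import torch.nn as nn

def init_forget_bias(lstm: nn.LSTM, value: float = 1.0) -> None:
    """Set the effective forget-gate bias of every layer to `value`.

    bias_ih and bias_hh are added together inside the LSTM, so bias_ih
    carries `value` and the matching slice of bias_hh is zeroed.
    """
    h = lstm.hidden_size
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            if name.startswith("bias_ih"):
                param[h:2 * h].fill_(value)   # forget-gate slice
            elif name.startswith("bias_hh"):
                param[h:2 * h].zero_()

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
init_forget_bias(lstm)
print(lstm.bias_ih_l0[20:40])
```

Zeroing the `bias_hh` slice is a design choice made here so that the effective bias is exactly `value`; splitting the value across the two vectors would work equally well.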

#### 4. Protected Memory Cell
- The cell-state recurrence itself passes through no squashing non-linearity; $\tanh$ is applied only when the cell state is read out into $h_t$
- This limits the compounding of activation derivatives that contributes to vanishing gradients

## Is Vanishing/Exploding Gradient Just an RNN Problem?

No, these issues affect various neural architectures, though they are most pronounced in recurrent networks.

### Other Affected Architectures

#### Deep Feedforward Networks
- Very deep networks suffer from gradient decay/explosion across layers
- Each additional layer compounds the gradient transformation problem

#### Convolutional Neural Networks
- Pre-ResNet, deep CNNs faced significant gradient issues
- Networks beyond ~20 layers became increasingly difficult to train

#### Transformers
- Despite self-attention, very deep transformers can experience gradient problems
- Mitigated through normalization techniques and residual connections

### Universal Solutions
- **Architectural shortcuts**: Residual connections, highway networks, dense connections
- **Normalization techniques**: Batch normalization, layer normalization
- **Initialization strategies**: Careful weight initialization (Xavier/Glorot, He)
- **Gradient stabilization**: Gradient clipping, gradient scaling
- **Activation functions**: ReLU variants reduce vanishing compared to sigmoid/tanh
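
As an illustration of the gradient-stabilization item, the sketch below applies gradient-norm clipping with `torch.nn.utils.clip_grad_norm_`. The model sizes, the artificially inflated loss, and the threshold `max_norm = 1.0` are arbitrary choices made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(2, 30, 4)

# Deliberately inflated loss so that the raw gradient norm is large.
output, _ = model(x)
loss = 1e3 * output.pow(2).sum()
loss.backward()

max_norm = 1.0
# clip_grad_norm_ rescales all gradients in place and returns the total
# norm measured *before* clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

print(f"before: {norm_before.item():.3e}  after: {norm_after.item():.3e}")
```

In a training loop the clipping call would sit between `loss.backward()` and `optimizer.step()`.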

## Pros and Cons of LSTMs

### Advantages
- Effective modeling of long-term dependencies
- Robust gradient flow through time
- Selective memory retention through gating
- Superior performance on sequential tasks compared to vanilla RNNs
- Gate activations can be inspected, giving some insight into what the network retains or discards

### Disadvantages
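- Computationally expensive (roughly 4× the parameters and per-step compute of a vanilla RNN with the same hidden size, since three gates plus the candidate each need their own weights)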
- Sequential computation limits parallelization
- Still struggles with very long sequences (thousands of steps)
- Complex architecture increases training difficulty
- Outperformed by Transformers in many modern NLP tasks
- High memory requirements for backpropagation through time
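
The parameter overhead can be checked by counting parameters in PyTorch's built-in `nn.RNN` and `nn.LSTM` modules. This is a small sketch; the sizes and the helper name `count_parameters` are illustrative.

```python
import torch.nn as nn

def count_parameters(module: nn.Module) -> int:
    """Total number of trainable and non-trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

input_size, hidden_size = 128, 256   # illustrative sizes
rnn = nn.RNN(input_size, hidden_size)
lstm = nn.LSTM(input_size, hidden_size)

n_rnn = count_parameters(rnn)     # H*I + H*H + 2*H
n_lstm = count_parameters(lstm)   # 4 * (H*I + H*H + 2*H)
print(n_rnn, n_lstm, n_lstm / n_rnn)
```

For a single layer with the same input and hidden sizes, the ratio is exactly 4 because the LSTM holds one weight block per gate plus one for the candidate.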

## Recent Advancements

### Architectural Innovations
- **Peephole Connections**: Allow gates to access cell state directly
- **ConvLSTM**: Incorporate convolutional operations for spatial-temporal data
- **Bidirectional LSTMs**: Process sequences in both forward and backward directions
- **Attention-augmented LSTMs**: Combine recurrent processing with attention mechanisms
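
Of the variants listed, the bidirectional LSTM is directly available in PyTorch through the `bidirectional=True` flag of `nn.LSTM`. The sketch below, with illustrative sizes, shows how the output concatenates the two directions.

```python
import torch
import torch.nn as nn

batch_size, seq_length, input_size, hidden_size = 4, 7, 10, 16  # illustrative
bilstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)

x = torch.randn(batch_size, seq_length, input_size)
output, (h_n, c_n) = bilstm(x)

# The last dimension holds the forward states followed by the backward states.
forward_out = output[..., :hidden_size]
backward_out = output[..., hidden_size:]

print(output.shape)   # (batch_size, seq_length, 2 * hidden_size)
print(h_n.shape)      # (2, batch_size, hidden_size): one entry per direction
```

The final forward state is found at the last time step of `forward_out`, while the final backward state is found at the first time step of `backward_out`, because the backward pass reads the sequence in reverse.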

### Training Improvements
- **Layer normalization**: Stabilizes LSTM training
- **Chrono initialization**: Specialized initialization for forget gates
- **Zoneout regularization**: Alternative to dropout for recurrent connections

### Hardware Optimizations
- **Quantized LSTMs**: Reduced precision for efficiency
- **Sparse LSTM variants**: Reduced computational overhead through pruning
- **Optimized CUDA implementations**: Hardware-specific optimizations

### Hybrid Architectures
- **LSTM-Transformer hybrids**: Leveraging strengths of both architectures
- **Hierarchical LSTMs**: Capturing information at multiple time scales
- **Neural ODE-based recurrent models (e.g., ODE-RNN, ODE-LSTM)**: Continuous-time variants of recurrent networks, suited to irregularly sampled sequences



import torch
import torch.nn as nn
import math

class LSTMCell(nn.Module):
    """
    Custom implementation of LSTM cell from scratch.
    Attributes:
        input_size (int): Size of input vector
        hidden_size (int): Size of hidden state vector
        weight_ih (Tensor): Input-hidden weights
        weight_hh (Tensor): Hidden-hidden weights
        bias_ih (Tensor): Input-hidden bias
        bias_hh (Tensor): Hidden-hidden bias
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Xavier/Glorot-style scale, std = sqrt(2 / (fan_in + fan_out)),
        # with input_size and hidden_size standing in for the fan sizes
        std = math.sqrt(2.0 / (input_size + hidden_size))

        # Combined weights for input-hidden transformations (4 gates)
        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size) * std)
        # Combined weights for hidden-hidden transformations (4 gates)
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size) * std)
        # Combined biases
        self.bias_ih = nn.Parameter(torch.zeros(4 * hidden_size))
        self.bias_hh = nn.Parameter(torch.zeros(4 * hidden_size))

    def forward(self, x, hidden):
        """
        Forward pass of LSTM cell.
        Args:
            x (Tensor): Input tensor of shape (batch_size, input_size)
            hidden (tuple): Tuple of (h_prev, c_prev) previous hidden and cell states
        Returns:
            tuple: (h_t, c_t) new hidden and cell states
        """
        h_prev, c_prev = hidden

        # Linear transformations
        gates = (torch.mm(x, self.weight_ih.t()) + self.bias_ih +
                torch.mm(h_prev, self.weight_hh.t()) + self.bias_hh)

        # Split into four chunks along the feature dimension:
        # input gate, forget gate, candidate cell state, output gate
        i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)

        # Apply activation functions
        i_t = torch.sigmoid(i_t)
        f_t = torch.sigmoid(f_t)
        g_t = torch.tanh(g_t)
        o_t = torch.sigmoid(o_t)

        # Update cell state
        c_t = f_t * c_prev + i_t * g_t

        # Update hidden state
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t

    def init_hidden(self, batch_size, device):
        """
        Initialize hidden and cell states.
        Args:
            batch_size (int): Batch size
            device (torch.device): Device to create tensors on
        Returns:
            tuple: (h_0, c_0) initialized hidden and cell states
        """
        h_0 = torch.zeros(batch_size, self.hidden_size, device=device)
        c_0 = torch.zeros(batch_size, self.hidden_size, device=device)
        return h_0, c_0


class LSTMNetwork(nn.Module):
    """
    LSTM network wrapper for multiple time steps.
    Attributes:
        input_size (int): Size of each input vector
        hidden_size (int): Size of hidden state
        num_layers (int): Number of stacked LSTM layers
        lstm_cells (nn.ModuleList): One LSTMCell per layer
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Create LSTM cells for each layer
        self.lstm_cells = nn.ModuleList([
            LSTMCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])

    def forward(self, x, hidden=None):
        """
        Forward pass of LSTM network.
        Args:
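            x (Tensor): Input tensor of shape (batch_size, seq_length, input_size)
            hidden (list, optional): One (h, c) tuple per layer; zeros if omitted
        Returns:
            tuple: (output, hidden), where output holds the top layer's hidden
                state at every time step, with shape (batch_size, seq_length,
                hidden_size), and hidden is the list of final (h, c) tuples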
        """
        batch_size, seq_length, _ = x.size()
        device = x.device

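        # Initialize hidden states if not provided; otherwise copy the list so
        # the caller's list is not modified in place
        if hidden is None:
            hidden = [
                self.lstm_cells[i].init_hidden(batch_size, device)
                for i in range(self.num_layers)
            ]
        else:
            hidden = list(hidden)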

        output = []

        # Process sequence
        for t in range(seq_length):
            x_t = x[:, t, :]  # Current time step input

            # Process through each layer
            for layer_idx in range(self.num_layers):
                hidden[layer_idx] = self.lstm_cells[layer_idx](
                    x_t, hidden[layer_idx]
                )
                x_t = hidden[layer_idx][0]  # Hidden state becomes input to next layer

            output.append(x_t)

        # Stack outputs
        output = torch.stack(output, dim=1)
        return output, hidden


# Example usage and testing
if __name__ == "__main__":
    # No PyTorch version check is performed here: comparing version strings
    # with >= is unreliable (lexicographically, "1.10.0" < "1.9.0").

    # Parameters
    input_size = 10
    hidden_size = 20
    num_layers = 2  # a small stack is enough to exercise the multi-layer path
    batch_size = 32
    seq_length = 15

    # Create sample data
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(batch_size, seq_length, input_size).to(device)

    # Initialize network
    model = LSTMNetwork(input_size, hidden_size, num_layers).to(device)

    # Forward pass
    output, hidden = model(x)

    # Verify output shapes
    assert output.shape == (batch_size, seq_length, hidden_size)
    assert len(hidden) == num_layers
    assert all(h.shape == (batch_size, hidden_size) for h, c in hidden)
    assert all(c.shape == (batch_size, hidden_size) for h, c in hidden)

    print("All tests passed!")
    print(f"Output shape: {output.shape}")
    print(f"Number of layers: {len(hidden)}")