In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

In [2]:
# NOTE: despite its name, this class is a full LSTM *layer*: forward() unrolls
# one LSTM cell over every timestep of the input sequence.
class LSTMCell(nn.Module):
    """Single-layer LSTM written out gate by gate.

    Args:
        input_dim: size of each input vector x_t.
        hidden_dim: size of the hidden state h_t and of the cell state c_t.
    """

    def __init__(self, input_dim, hidden_dim):
        super(LSTMCell, self).__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim

        # Forget gate parameters
        self.W_xf, self.W_hf, self.b_f = self.create_gate_parameters()

        # Input gate parameters
        self.W_xi, self.W_hi, self.b_i = self.create_gate_parameters()

        # Output gate parameters
        self.W_xo, self.W_ho, self.b_o = self.create_gate_parameters()

        # Candidate cell parameters
        self.W_xg, self.W_hg, self.b_g = self.create_gate_parameters()

    def create_gate_parameters(self):
        """Create the (W_x, W_h, b) parameters of one gate.

        Returns:
            W_x: [input_dim, hidden_dim], Xavier-uniform initialised.
            W_h: [hidden_dim, hidden_dim], Xavier-uniform initialised.
            b:   [hidden_dim], zero initialised.
        """
        W_x = nn.Parameter(torch.zeros(self.input_dim, self.hidden_dim))   # input -> hidden
        W_h = nn.Parameter(torch.zeros(self.hidden_dim, self.hidden_dim))  # hidden -> hidden
        nn.init.xavier_uniform_(W_x)
        nn.init.xavier_uniform_(W_h)
        b = nn.Parameter(torch.zeros(self.hidden_dim))
        return W_x, W_h, b

    def forward(self, x, h_t_1, c_t_1):
        """Run the LSTM over the whole sequence.

        Args:
            x: [batch_size, seq_len, input_dim]
            h_t_1: initial hidden state, [batch_size, hidden_dim]
            c_t_1: initial cell state, [batch_size, hidden_dim]

        Returns:
            (h_seq, c_seq), each [batch_size, seq_len, hidden_dim]: the hidden and
            cell state at every timestep. For seq_len == 0 both are empty
            [batch_size, 0, hidden_dim] tensors.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        if seq_len == 0:
            # Concatenating an empty list raises, so build the empty result explicitly.
            empty = h_t_1.new_zeros((batch_size, 0, self.hidden_dim))
            return empty, empty.clone()

        # x_t @ W_x + b does not depend on the recurrence, so compute it for all
        # timesteps at once: [B, T, input_dim] @ [input_dim, H] -> [B, T, H].
        x_f = x @ self.W_xf + self.b_f
        x_i = x @ self.W_xi + self.b_i
        x_g = x @ self.W_xg + self.b_g
        x_o = x @ self.W_xo + self.b_o

        output_h, output_c = [], []
        for t in range(seq_len):
            f_t = torch.sigmoid(x_f[:, t] + h_t_1 @ self.W_hf)  # forget gate
            i_t = torch.sigmoid(x_i[:, t] + h_t_1 @ self.W_hi)  # input gate
            g_t = torch.tanh(x_g[:, t] + h_t_1 @ self.W_hg)     # candidate cell (g_t or c~_t)
            o_t = torch.sigmoid(x_o[:, t] + h_t_1 @ self.W_ho)  # output gate

            c_t = (f_t * c_t_1) + (i_t * g_t)  # cell state update
            h_t = torch.tanh(c_t) * o_t        # hidden state update

            output_h.append(h_t)
            output_c.append(c_t)

            # Update for next timestep
            h_t_1, c_t_1 = h_t, c_t

        # Stack the per-step [B, H] states along a new time axis -> [B, T, H].
        return torch.stack(output_h, dim=1), torch.stack(output_c, dim=1)



class MultiLayerLSTM(nn.Module):
    """Stack of ``LSTMCell`` layers followed by a linear projection.

    Layer 0 maps input_dim -> hidden_dim; every further layer maps
    hidden_dim -> hidden_dim. Dropout is applied to the sequence passed between
    layers and to the top layer's output before ``proj`` (hidden_dim -> input_dim).

    Args:
        input_dim: feature size of the input and of the projected output.
        hidden_dim: hidden/cell state size of every layer.
        num_layers: number of stacked layers; must be >= 1.
        dropout: dropout probability used between layers and before ``proj``.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        super(MultiLayerLSTM, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # First layer (input_dim -> hidden_dim), then (hidden_dim -> hidden_dim) layers.
        self.layers = nn.ModuleList(
            [LSTMCell(input_dim, hidden_dim)]
            + [LSTMCell(hidden_dim, hidden_dim) for _ in range(num_layers - 1)]
        )

        # Dropout between layers
        self.dropout = nn.Dropout(dropout)

        # Final linear projection back to input dimension
        self.proj = nn.Linear(hidden_dim, input_dim)
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, state=None):
        """Run every layer over the sequence and project the top layer's output.

        Args:
            x: [batch_size, seq_len, input_dim]
            state: optional (h, c), each [num_layers, batch_size, hidden_dim].
                Each layer gets its own independent (h0, c0). ``None`` starts
                every layer from zero states.

        Returns:
            logits: [batch_size, seq_len, input_dim]
            (h_new, c_new): each [num_layers, batch_size, hidden_dim], the
                last-timestep states of every layer; pass them back as ``state``
                to continue the same sequence.
        """
        if state is None:
            zeros = x.new_zeros((self.num_layers, x.shape[0], self.hidden_dim))
            state = (zeros, zeros.clone())
        h_prev, c_prev = state

        layer_input = x
        h_final, c_final = [], []
        for layer_idx, layer in enumerate(self.layers):
            # Dropout only between layers: never on the raw input x.
            if layer_idx > 0:
                layer_input = self.dropout(layer_input)
            h_out, c_out = layer(layer_input, h_prev[layer_idx], c_prev[layer_idx])
            h_final.append(h_out[:, -1])
            c_final.append(c_out[:, -1])
            # This layer's hidden-state sequence is the next layer's input.
            layer_input = h_out

        # Output projection
        logits = self.proj(self.dropout(layer_input))

        # Stack the per-layer [B, H] final states -> [num_layers, B, H].
        return logits, (torch.stack(h_final, dim=0), torch.stack(c_final, dim=0))


# With detailed explanation


### How Inputs Flow Through a Multi-Layer LSTM

The **output sequence of layer 0 becomes the input sequence of layer 1.**  
**Not the original `x`.**

---

### Layer-wise Input Flow

- **Input to layer 0:** `x` (original data)  
- **Input to layer 1:** `h_out` (output of layer 0)  
- **Input to layer 2:** `h_out` from layer 1  
- **Input to layer 3:** output from layer 2  
- … and so on.

This is standard in all multi-layer LSTMs.

---

### 📌 Shape Always Remains the Same

Each layer outputs:

```
h_out: [batch, seq_len, hidden_dim]
```

So the input to the next layer is always:

```python
next_input = h_out   # the code applies self.dropout(h_out) first, which keeps the same shape
```


In [2]:
# NOTE: despite its name, this class is a full LSTM *layer*: forward() unrolls
# one LSTM cell over every timestep of the input sequence.
class LSTMCell(nn.Module):
    """Single-layer LSTM written gate by gate, with shape annotations.

    Args:
        input_dim: size of each input vector x_t.
        hidden_dim: size of the hidden state h_t and of the cell state c_t.
    """

    def __init__(self, input_dim, hidden_dim):
        super(LSTMCell, self).__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim

        # Forget gate parameters
        self.W_xf, self.W_hf, self.b_f = self.create_gate_parameters()

        # Input gate parameters
        self.W_xi, self.W_hi, self.b_i = self.create_gate_parameters()

        # Output gate parameters
        self.W_xo, self.W_ho, self.b_o = self.create_gate_parameters()

        # Candidate cell parameters
        self.W_xg, self.W_hg, self.b_g = self.create_gate_parameters()

    def create_gate_parameters(self):
        """Create the (W_x, W_h, b) parameters of one gate (Xavier-uniform weights, zero bias)."""
        W_x = nn.Parameter(torch.zeros(self.input_dim, self.hidden_dim))   # (input_dim, hidden_dim)
        W_h = nn.Parameter(torch.zeros(self.hidden_dim, self.hidden_dim))  # (hidden_dim, hidden_dim)
        nn.init.xavier_uniform_(W_x)
        nn.init.xavier_uniform_(W_h)
        b = nn.Parameter(torch.zeros(self.hidden_dim))                     # (hidden_dim,)
        return W_x, W_h, b

    def forward(self, x, h_t_1, c_t_1):
        """Run the LSTM over the whole sequence.

        x:      [batch_size, seq_len, input_dim]
        h_t_1:  [batch_size, hidden_dim]   initial hidden state
        c_t_1:  [batch_size, hidden_dim]   initial cell state

        Returns (h_seq, c_seq), each [batch_size, seq_len, hidden_dim].
        For seq_len == 0 both are empty [batch_size, 0, hidden_dim] tensors.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]

        # ---------------------- Empty sequence -------------------------
        # Concatenating an empty list raises, so return empty tensors directly.
        if seq_len == 0:
            empty = h_t_1.new_zeros((batch_size, 0, self.hidden_dim))
            return empty, empty.clone()

        # ---------------- Input projections (all timesteps) -----------
        # x @ W_x does not depend on h, so compute it once for every step:
        # [batch, seq_len, input_dim] @ [input_dim, hidden_dim] → [batch, seq_len, hidden_dim]
        x_f = x @ self.W_xf + self.b_f
        x_i = x @ self.W_xi + self.b_i
        x_g = x @ self.W_xg + self.b_g
        x_o = x @ self.W_xo + self.b_o

        output_h, output_c = [], []

        for t in range(seq_len):
            # x_f[:, t], x_i[:, t], ...: [batch_size, hidden_dim]

            # ------------------------- Forget gate -------------------------
            # h_t_1 @ W_hf → [batch, hidden_dim] @ [hidden_dim, hidden_dim]
            f_t = torch.sigmoid(x_f[:, t] + h_t_1 @ self.W_hf)
            # f_t: [batch_size, hidden_dim]

            # ------------------------- Input gate --------------------------
            i_t = torch.sigmoid(x_i[:, t] + h_t_1 @ self.W_hi)
            # i_t: [batch_size, hidden_dim]

            # ---------------------- Candidate update ------------------------
            g_t = torch.tanh(x_g[:, t] + h_t_1 @ self.W_hg)
            # g_t: [batch_size, hidden_dim]

            # ------------------------ Cell state ----------------------------
            c_t = (f_t * c_t_1) + (i_t * g_t)
            # c_t: [batch_size, hidden_dim]

            # ------------------------- Output gate --------------------------
            o_t = torch.sigmoid(x_o[:, t] + h_t_1 @ self.W_ho)
            # o_t: [batch_size, hidden_dim]

            # ------------------------ Hidden state --------------------------
            h_t = torch.tanh(c_t) * o_t
            # h_t: [batch_size, hidden_dim]

            # ------------------------- Store outputs ------------------------
            output_h.append(h_t)   # [batch_size, hidden_dim]
            output_c.append(c_t)   # [batch_size, hidden_dim]

            # Update for next timestep
            h_t_1 = h_t   # [batch_size, hidden_dim]
            c_t_1 = c_t   # [batch_size, hidden_dim]

        # Stack along a new time axis (dim=1) to restore the sequence layout of x:
        # seq_len × [batch_size, hidden_dim] → [batch_size, seq_len, hidden_dim]
        return (
            torch.stack(output_h, dim=1),
            torch.stack(output_c, dim=1),
        )


In [3]:
class MultiLayerLSTM(nn.Module):
    """Stack of LSTM layers (``LSTMCell``) with dropout between layers and a
    final linear projection hidden_dim → input_dim.

    Args:
        input_dim: feature size of x and of the projected output.
        hidden_dim: hidden/cell state size of every layer.
        num_layers: number of stacked layers; must be >= 1.
        dropout: dropout probability (between layers and before ``proj``).

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        super(MultiLayerLSTM, self).__init__()
        # A stack needs at least one layer (layer 0 is always built below).
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # ---------------------------------------------------------
        # We build the LSTM stack manually.
        # Layer 0 takes input_dim → hidden_dim.
        # All higher layers take hidden_dim → hidden_dim.
        # ---------------------------------------------------------
        self.layers = nn.ModuleList()

        # First layer (input data enters here)
        self.layers.append(LSTMCell(input_dim, hidden_dim))

        # Remaining layers (receive output from previous layer)
        for _ in range(num_layers - 1):
            self.layers.append(LSTMCell(hidden_dim, hidden_dim))

        # Dropout between layers (NOT inside time steps)
        self.dropout = nn.Dropout(dropout)

        # Final linear projection to bring hidden_dim → input_dim
        # (like predicting next token embedding)
        self.proj = nn.Linear(hidden_dim, input_dim)
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, state=None):
        """
        x:
            Shape: [batch_size, seq_len, input_dim]
            Meaning: entire input sequence for ALL time steps.

        state = (h, c), optional
            h: [num_layers, batch_size, hidden_dim]
            c: [num_layers, batch_size, hidden_dim]
            If None, every layer starts from zero states.

        Returns:
            logits: [batch_size, seq_len, input_dim]
            (h_new, c_new): each [num_layers, batch_size, hidden_dim]

        Intuition:
        - You have a stack of LSTM layers.
        - Each layer has its OWN separate (h0, c0).
        - Each layer processes the whole sequence.
        - The output of layer L becomes the input to layer L+1.
        """
        # No state given: start every layer from zeros (same default as nn.LSTM).
        if state is None:
            zeros = x.new_zeros((self.num_layers, x.shape[0], self.hidden_dim))
            state = (zeros, zeros.clone())

        # Unpack previous hidden & cell states for ALL layers
        h_prev, c_prev = state

        # =========================================================
        # ----------------------- LAYER STACK ---------------------
        # Layer 0 receives the original input sequence x.
        # Every later layer receives:
        #   - the entire output sequence of the previous layer
        #     (after dropout),
        #   - its own (h0, c0).
        #
        # Why dropout here?
        #   - Dropout is applied BETWEEN layers,
        #   - NOT between timesteps (would break time-dependency).
        # =========================================================
        layer_input = x
        h_list, c_list = [], []
        for layer_idx, layer in enumerate(self.layers):
            if layer_idx > 0:
                # Dropout prevents layers from depending too heavily
                # on the exact output of the previous layer.
                layer_input = self.dropout(layer_input)

            # h_out, c_out shapes = [batch, seq_len, hidden_dim]
            h_out, c_out = layer(
                layer_input,
                h_prev[layer_idx],   # each layer has its OWN memory
                c_prev[layer_idx],
            )

            # We collect ONLY the final hidden & cell states (last timestep)
            # because the next forward() call needs these to continue the sequence.
            h_list.append(h_out[:, -1])   # [batch, hidden_dim]
            c_list.append(c_out[:, -1])   # [batch, hidden_dim]

            # This layer's output sequence is the next layer's input.
            layer_input = h_out

        # =========================================================
        # -------- FINAL OUTPUT PROJECTION (hidden → input) -------
        # Applied only on the final layer output.
        #
        # layer_input shape: [batch, seq_len, hidden_dim]
        # After projection:
        # logits shape: [batch, seq_len, input_dim]
        # =========================================================
        logits = self.proj(self.dropout(layer_input))

        # =========================================================
        # -------- RETURN UPDATED STATES FOR NEXT CALL ------------
        # h_new, c_new:
        #     [num_layers, batch, hidden_dim]
        #
        # These are the "final" states of each layer’s last timestep.
        # Perfect for continuing the sequence.
        # =========================================================
        h_new = torch.stack(h_list, dim=0)
        c_new = torch.stack(c_list, dim=0)

        # logits has one entry per input timestep (output length = input length).
        return logits, (h_new, c_new)


**These `h_new` and `c_new` are *not* per-timestep outputs.**  
They summarize each layer at the **last timestep** of the sequence.


They are returned so that:

- you can continue the sequence later,  
- keep the states between batches, or  
- use them for autoregressive decoding.

But they do **not** change the fact that your `logits` are length-$T$ sequences (one output per input timestep).


# Understanding `h_new` and `c_new` in a Multi-Layer LSTM

---

## ✔ What `h_new` and `c_new` Actually Mean

You have a **multi-layer LSTM**, so each layer has its **own hidden state (h)** and **cell state (c)**.

Each layer produces its final states at the **last timestep**:

- **Layer 0:** `h0_final`, `c0_final`  
- **Layer 1:** `h1_final`, `c1_final`  
- **Layer 2:** `h2_final`, `c2_final`  
- …
- **Layer L−1:** `h{L-1}_final`, `c{L-1}_final`

These represent the LSTM‚Äôs memory **after finishing the entire sequence**.

## ✔ Summary (Very Short)

`h_new` and `c_new` =  
**“The final hidden and cell states (last timestep) of EVERY layer, stacked into a single tensor.”**

These become the starting states for the next forward call whenever the next batch continues the same sequence (e.g., the next chunk of a long text):

```python
forward(x_next, (h_new, c_new))
```

---

## ✔ Visual Intuition

Imagine you have **3 layers** (`L=3`):

Final states at the **last timestep**:

```

Layer 0: h0 → shape [B, H]
Layer 1: h1 → shape [B, H]
Layer 2: h2 → shape [B, H]

```

After applying `unsqueeze(0)`:

```

Layer 0: [1, B, H]
Layer 1: [1, B, H]
Layer 2: [1, B, H]

```

After stacking with `torch.cat(..., dim=0)`:

```

h_new =
[
[layer0_state]
[layer1_state]
[layer2_state]
]

Result shape = [3, B, H]

````

So `h_new` and `c_new` are simply:

> All final states of all layers stacked into one tensor.

---

## ✔ Why Only the Last Timestep?

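Because an LSTM's "memory" after reading a sequence is exactly the state pair at its final timestep:

- the hidden state `h_T` at the **final timestep**, and  
- the cell state `c_T` at the **final timestep**

Everything earlier in the sequence has already been folded into these two tensors, so they are all the next timestep needs.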

We only carry:

- **last `h_t`**
- **last `c_t`**

for each layer.

These final states are needed for:

- continuation into next batch chunks  
- autoregressive generation  
- teacher forcing  
- inference continuation  

---


# Understanding “Next Forward Call”

---

## ✔ Meaning of “next forward() call”

It means the **next time the model runs its forward function** — NOT a new epoch.

There are three common cases:

---

## ✔ Case 1 — Next Batch in the Same Epoch

```python
logits, (h_new, c_new) = model(x_batch, (h0, c0))   # independent batch: start from fresh (usually zero) states
```

Next batch:

```python
logits, (h_new, c_new) = model(x_next_batch, (h0, c0))
# usually reset to zeros
```

This is a **new forward call**, but **not a new epoch**.

---

## ✔ Case 2 — Sequence Continuation (autoregressive or long sequence)

```
chunk 1 → forward()
chunk 2 → forward()
chunk 3 → forward()
```

We pass states forward:

```python
logits, (h_new, c_new) = model(chunk1, (h0, c0))
logits, (h_new, c_new) = model(chunk2, (h_new, c_new))
logits, (h_new, c_new) = model(chunk3, (h_new, c_new))
```

Here:

> "Next forward call" = **next chunk of the SAME sequence**.

Still **not** a new epoch.

---

## ✔ Case 3 — Inference (Text Generation)

```
Step 1 → forward()
Step 2 → forward()
Step 3 → forward()
```

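Each generation step is one forward call. The `(h_new, c_new)` returned by one step is passed in as the state of the next step, together with the token just generated.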

---

## ✔ Summary (Simple)

**“Next forward() call” means:**
The next time the LSTM processes data through its `forward()` method.

**It does NOT mean “next epoch.”**

---

# ✔ When Is “Next Epoch”?

A new epoch happens only after ALL batches are processed once:

```
Epoch 1:
    forward(batch 1)
    forward(batch 2)
    ...
    forward(batch N)

Epoch 2:
    forward(batch 1)
    forward(batch 2)
    ...
```

Between epochs, models usually **reset hidden states to zero**
(unless you are training a stateful LSTM).

---

```
```


In [8]:
tst3 = torch.randint(low=0, high=9, size=(2, 3, 4))  # two 3x4 matrices; integers drawn from [0, 9)
tst3

tensor([[[2, 2, 4, 4],
         [7, 6, 3, 8],
         [1, 1, 1, 5]],

        [[8, 8, 4, 8],
         [3, 6, 0, 4],
         [5, 1, 4, 4]]])

In [11]:
tst3[:, 1]  # row 1 of every matrix (like x[:, i] in LSTMCell) -> shape [2, 4]; same as tst3[:, 1, :]

tensor([[7, 6, 3, 8],
        [3, 6, 0, 4]])

In [7]:
tst2 = torch.randint(low=0, high=9, size=(3, 4))  # one 3x4 matrix; integers drawn from [0, 9)
tst2

tensor([[4, 4, 1, 7],
        [4, 3, 0, 7],
        [6, 8, 7, 4]])