In [34]:
from torch import nn

class Encoder(nn.Module):
    def __init__(self, hidden_dim = 128, input_dim = 1):
        super().__init__()
        # input_dim is the size of each element’s feature vector—what the encoder LSTM sees at every time-step.
        # In the little “sort a list of real numbers” demo we built, each token is just a single scalar (e.g., 0.42).
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, batch_first=True, bidirectional=True)
        # With batch_first=True: The input and output tensors are expected to be of shape (batch, seq_len, input_size). 
        # With bidirectional=True: The LSTM processes the input sequence in both forward and backward directions. 
        #   This means for each time step, the output contains information from both past and future contexts.
        # The output of the bidirectional LSTM is twice the hidden dimension.

    def forward(self, x):
        x = x.unsqueeze(-1) # add feature dimension -> (B, T, 1)
        # print(x)
        # x is the input to the LSTM, which is a tensor of shape (batch_size, seq_len, input_size).
        # LSTM expects a 3D input tensor of shape (batch_size, seq_len, input_size).
        h, _ = self.lstm(x)
        # h is the output of the LSTM, which is a tensor of shape (batch_size, seq_len, hidden_dim * 2).
        # The final layer of the encoder is a linear layer that maps the output of the LSTM to a vector of size hidden_dim.
        return h

class Decoder(nn.Module):
    def __init__(self, hidden_dim = 128):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm_cell = nn.LSTMCell(input_size=hidden_dim * 2, hidden_size=hidden_dim)
        self.W1 = nn.Linear(2 * hidden_dim, hidden_dim, bias=False) # input is bidirectional
        self.W2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, enc_out, targets=None):
        """
        Parameters
        ----------
        enc_out : (B, T, 2H)
        targets  : (B, T) or None  (teacher forcing indices)

        B = batch size
        T = sequence length
        H = hidden dimension

        Returns
        -------
        logits : (B, T, T)  --- unnormalised pointer scores
        """

        B, T, _ = enc_out.size()
        h_t = torch.zeros(B, self.hidden_dim, device=enc_out.device)
        c_t = torch.zeros_like(h_t)

        logits = []
        mask = torch.zeros(B, T, device=enc_out.device)

        enc_proj = self.W1(enc_out) # (B, T, H)
        enc_out_idx = torch.arange(B, device=enc_out.device) # (B) This creates a tensor [0, 1, ..., B-1] to index each batch.

        for i in range(T):
            if i == 0:
                # first time step, use the mean of the encoder outputs
                ctx = enc_out.mean(dim=1) # (B, 2H)
            else:
                # subsequent time steps, use the previously chosen embedding as context
                idx = targets[:, i-1] if targets is not None else prev_idx
                ctx = enc_out[enc_out_idx, idx] # (B, H) For each item in the batch, select the encoder output at the position given by idx.
                                                #        This gives you a context vector for each batch item, based on the previously selected position.
            
            # LSTMCell
            # Inputs: input, (h_0, c_0)
            # input of shape (batch, input_size) or (input_size): tensor containing input features
            # h_0 of shape (batch, hidden_size) or (hidden_size): tensor containing the initial hidden state
            # c_0 of shape (batch, hidden_size) or (hidden_size): tensor containing the initial cell state

            # Outputs: (h_1, c_1)
            # h_1 of shape (batch, hidden_size) or (hidden_size): tensor containing the next hidden state
            # c_1 of shape (batch, hidden_size) or (hidden_size): tensor containing the next cell state

            # current decoder hidden- and cell-state BEFORE we look at position i
            h_t, c_t = self.lstm_cell(ctx, (h_t, c_t))   # (B, H)
            W2_proj = self.W2(h_t).unsqueeze(1) # (B, 1, H)
            tanh = torch.tanh(enc_proj + W2_proj) # (B, T, H)
            u_i = self.v(tanh).squeeze(-1) # (B, T)
            u_i = u_i - 1e9 * mask
            # print ("u_i", u_i)
            logits.append(u_i)

            prev_idx = torch.argmax(u_i, dim=-1)
            # print ("prev_idx", prev_idx)
            mask[enc_out_idx, prev_idx] = 1 # mask the selected index for future predictions

        # In this context, torch.stack(logits, dim=1) is used to combine a list of tensors (each representing the logits at a different decoding step) into a single tensor with a new dimension.
        # If you have T decoding steps, and each u_i is shape (B, T), then after stacking, you get a tensor of shape (B, T, T).
        logits = torch.stack(logits, dim=1) # (B, T, T)
        return logits

class PointerNetwork(nn.Module):
    def __init__(self, hidden_dim = 128):
        super().__init__()
        self.encoder = Encoder(hidden_dim)
        self.decoder = Decoder(hidden_dim)

    def forward(self, x, targets=None):
        enc_out = self.encoder(x) # (B, T, 2H)
        # print ("inputs", x)
        # print ("targets", targets)
        logits = self.decoder(enc_out, targets) # (B, T, T)
        return logits    

In [35]:
# --------------------------
# Synthetic dataset helpers
# --------------------------
def gen_batch(batch_sz, seq_len = 5):
    """
    Returns
    -------
    inputs : (B, T)  float32  -- unsorted numbers
    targets: (B, T)  long     -- permutation (indices) that would sort each row ascending
    """
    inputs = torch.rand(batch_sz, seq_len)
    targets = torch.argsort(inputs, dim=1)  # ascending order indices
    return inputs.to(DEVICE), targets.to(DEVICE)



In [37]:
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

import torch.nn.functional as F

if True:
    # --------------------------
    # Model
    # --------------------------
    model = PointerNetwork(hidden_dim=128).to(DEVICE)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    steps = 10000
    batch_sz = 32
    seq_len = 5

    for step in range(1, steps + 1):
        # print ("step", step)
        model.train() 
        inputs, targets = gen_batch(batch_sz, seq_len)
        logits = model(inputs, targets)
        logits = logits.view(-1, seq_len) # (B*T, T) samples by logits
        targets = targets.view(-1) # (B*T) by class
        # print ("logits", logits, logits.shape)
        # print ("targets", targets, targets.shape)
        loss = F.cross_entropy(logits, targets)
        # print ("loss", loss)
        optim.zero_grad()
        loss.backward()
        optim.step()

        if step % 100 == 0:
            with torch.no_grad():
                model.eval()
                val_inp, val_tgt = gen_batch(batch_sz, seq_len)
                val_logits = model(val_inp)  # no teacher forcing
                preds = val_logits.argmax(-1)
                accuracy = (preds == val_tgt).float().mean().item()
                print(
                    f"step {step:>4} | loss {loss.item():.3f} | val acc {accuracy*100:5.1f}%"
                )

# inputs, targets, logits

step  100 | loss 412499968.000 | val acc  17.5%
step  200 | loss 400000000.000 | val acc  18.8%
step  300 | loss 431249984.000 | val acc  25.6%
step  400 | loss 406250016.000 | val acc  11.9%
step  500 | loss 399999936.000 | val acc  24.4%
step  600 | loss 350000000.000 | val acc  20.0%
step  700 | loss 449999968.000 | val acc  20.0%
step  800 | loss 418750016.000 | val acc  25.0%
step  900 | loss 387500000.000 | val acc  24.4%


KeyboardInterrupt: 