In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def demonstrate_packing():
    """Show how ``pack_padded_sequence`` lays out a ragged batch.

    Builds three single-feature sequences of lengths 3, 2 and 4, pads them
    into one ``batch_first`` tensor, packs that tensor, and prints the padded
    layout, the packed metadata, and the packed values grouped by timestep.

    Returns:
        tuple: ``(packed, padded, lengths)`` -- the packed sequence, the
        padded tensor it was built from, and the 1-D tensor holding each
        sequence's true length.
    """
    # Local import: the caller of this block only needs it here.
    from torch.nn.utils.rnn import pad_sequence

    print("=== PACKED SEQUENCE INTERNALS ===\n")

    # Sample sequences with different lengths, one feature per timestep.
    sequences = [
        torch.tensor([[1.0], [2.0], [3.0]]),           # length 3
        torch.tensor([[4.0], [5.0]]),                  # length 2
        torch.tensor([[6.0], [7.0], [8.0], [9.0]]),    # length 4
    ]

    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    # Derive the lengths from the data so they cannot drift out of sync
    # with ``sequences`` when the example is edited.
    lengths = torch.tensor([seq.shape[0] for seq in sequences])

    print("1. ORIGINAL PADDED DATA:")
    print("Shape:", padded.shape)  # [3, 4, 1] - batch_size, max_len, features
    print("Data:")
    for i, seq in enumerate(padded):
        # squeeze(-1) removes only the feature axis, never the time axis.
        print(f"  Batch {i}: {seq.squeeze(-1).tolist()} (length: {lengths[i].item()})")
    print()

    # enforce_sorted=False lets the batch stay in its original order; the
    # permutation applied internally is exposed via sorted_indices.
    packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

    print("2. AFTER PACKING:")
    print("Packed data shape:", packed.data.shape)  # [9, 1] - total_elements, features
    print("Packed batch_sizes:", packed.batch_sizes.tolist())
    print("Packed sorted_indices:", packed.sorted_indices.tolist())
    print("Packed unsorted_indices:", packed.unsorted_indices.tolist())
    print()

    print("3. HOW THE DATA IS REORGANIZED:")
    print("Original padded layout:")
    batch_count = padded.shape[0]
    for t in range(padded.shape[1]):
        column = padded[:, t, 0].tolist()
        # A sequence is still "active" at timestep t when its length exceeds t.
        active = int((lengths > t).sum())
        if active == batch_count:
            note = f"all {batch_count} sequences"
        else:
            noun = "sequence" if active == 1 else "sequences"
            note = f"{active} {noun} + padding"
        print(f"  Timestep {t}: {column}  <- {note}")
    print()

    print("Packed layout (no padding!):")
    data_flat = packed.data.squeeze(-1).tolist()
    print(f"  Packed data: {data_flat}")

    print("\nHow to read packed data:")
    start = 0
    # batch_sizes[t] is how many sequences contribute a value at timestep t;
    # the flat data is the concatenation of those per-timestep groups.
    for t, batch_size in enumerate(packed.batch_sizes.tolist()):
        stop = start + batch_size
        print(f"  Timestep {t}: {data_flat[start:stop]} <- {batch_size} active sequences")
        start = stop
    print()

    return packed, padded, lengths

def demonstrate_lstm_processing():
    """Run one LSTM over the padded and the packed batch and compare the work.

    Obtains the sample batch from ``demonstrate_packing()`` (which prints the
    title and sections 1-3), then prints section 4 -- output shapes and
    timestep counts for padded versus packed input -- and section 5, the
    percentage of timesteps that packing avoids.

    Returns:
        None. All results are printed.
    """
    # Build the data first: demonstrate_packing() prints the title and
    # sections 1-3, so the section-4 header has to come after it to keep
    # the numbered output in order.
    packed, padded, _ = demonstrate_packing()

    print("4. LSTM PROCESSING COMPARISON:\n")

    # Create LSTM
    lstm = nn.LSTM(input_size=1, hidden_size=2, batch_first=True)

    # Method 1: Process padded data (includes padding computation)
    print("Method 1 - Padded processing:")
    with torch.no_grad():
        padded_output, _ = lstm(padded)
    print(f"  Padded output shape: {padded_output.shape}")
    print(f"  Total computations: {padded_output.shape[0] * padded_output.shape[1]} timesteps")
    print("  (Includes wasted computation on padding!)")
    print()

    # Method 2: Process packed data (no padding computation)
    print("Method 2 - Packed processing:")
    with torch.no_grad():
        packed_output, _ = lstm(packed)
    print(f"  Packed output shape: {packed_output.data.shape}")
    print(f"  Total computations: {packed_output.data.shape[0]} timesteps")
    print("  (Only real data, no padding!)")
    print()

    # Unpack to compare: restores a padded, batch-first tensor.
    unpacked_output, _ = pad_packed_sequence(packed_output, batch_first=True)
    print("After unpacking:")
    print(f"  Unpacked shape: {unpacked_output.shape}")
    print()

    print("5. COMPUTATION SAVINGS:")
    padded_ops = padded.shape[0] * padded.shape[1]  # batch * max_len
    packed_ops = packed.data.shape[0]  # total real elements
    savings = (padded_ops - packed_ops) / padded_ops * 100
    print(f"  Padded approach: {padded_ops} operations")
    print(f"  Packed approach: {packed_ops} operations")
    print(f"  Savings: {savings:.1f}%")

# Run the demonstration only when executed directly (script or notebook
# cell), so importing this module does not print the whole walkthrough.
if __name__ == "__main__":
    demonstrate_lstm_processing()

4. LSTM PROCESSING COMPARISON:

=== PACKED SEQUENCE INTERNALS ===

1. ORIGINAL PADDED DATA:
Shape: torch.Size([3, 4, 1])
Data:
  Batch 0: [1.0, 2.0, 3.0, 0.0] (length: 3)
  Batch 1: [4.0, 5.0, 0.0, 0.0] (length: 2)
  Batch 2: [6.0, 7.0, 8.0, 9.0] (length: 4)

2. AFTER PACKING:
Packed data shape: torch.Size([9, 1])
Packed batch_sizes: [3, 3, 2, 1]
Packed sorted_indices: [2, 0, 1]
Packed unsorted_indices: [1, 2, 0]

3. HOW THE DATA IS REORGANIZED:
Original padded layout:
  Timestep 0: [1.0, 4.0, 6.0]  <- all 3 sequences
  Timestep 1: [2.0, 5.0, 7.0]  <- all 3 sequences
  Timestep 2: [3.0, 0.0, 8.0]  <- 2 sequences + padding
  Timestep 3: [0.0, 0.0, 9.0]  <- 1 sequence + padding

Packed layout (no padding!):
  Packed data: [6.0, 1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 9.0]

How to read packed data:
  Timestep 0: [6.0, 1.0, 4.0] <- 3 active sequences
  Timestep 1: [7.0, 2.0, 5.0] <- 3 active sequences
  Timestep 2: [8.0, 3.0] <- 2 active sequences
  Timestep 3: [9.0] <- 1 active sequences

Meth