In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import time


def g(x: torch.Tensor) -> torch.Tensor:
    """Strictly positive activation used by the minGRU/minLSTM layers.

    Returns ``x + 0.5`` for non-negative entries and ``sigmoid(x)`` for
    negative ones. Both branches meet at 0.5, so the map is continuous and
    always > 0, which is what lets the layers work with ``log(g(x))``.
    """
    negative_part = torch.sigmoid(x)
    positive_part = x + 0.5
    # Select element-wise; the two candidates are computed in full first.
    return torch.where(x < 0, negative_part, positive_part)

def log_g(x: torch.Tensor) -> torch.Tensor:
    """Numerically stable ``log(g(x))``.

    For non-negative entries this is ``log(x + 0.5)``. ``relu`` keeps the log
    argument >= 0.5 even on the branch ``torch.where`` discards. For negative
    entries it is ``log(sigmoid(x)) == -softplus(-x)``, which stays finite for
    very negative ``x``.
    """
    log_of_sigmoid = -F.softplus(-x)
    log_of_shifted = torch.log(F.relu(x) + 0.5)
    return torch.where(x < 0, log_of_sigmoid, log_of_shifted)


def parallel_scan_log(
    log_coeffs: torch.Tensor,
    log_values: torch.Tensor,
    *,
    verbose: bool = True,
) -> torch.Tensor:
    """
    Parallel scan in log-space for h_t = a_t * h_{t-1} + b_t.

    Unrolling the recurrence gives h_t = A_t * (h_0 + sum_{s=1..t} b_s / A_s)
    with A_t = a_1 * ... * a_t. In log space the product becomes a cumulative
    sum and the inner sum becomes a log-cumsum-exp. The whole sequence is
    therefore computed with two scan primitives instead of a Python loop.
    Because the inputs are logs, a_t, b_t and h_0 are implicitly positive.

    log_coeffs: (batch, seq_len, hidden) -- log a_t for t = 1..seq_len.
    log_values: (batch, seq_len+1, hidden), with log_values[:,0] = log_h0
        and log_values[:,t] = log b_t for t >= 1.
    verbose: keyword-only. When True (the default, preserving the original
        behaviour), intermediate shapes are printed for debugging. Pass
        False in hot loops.
    Returns h: (batch, seq_len+1, hidden), where h[:,0] = h0.

    Raises ValueError if either input is not 3-D, or if
    log_values.size(1) != log_coeffs.size(1) + 1. Without this check, a 2-D
    input would be scanned over the hidden axis, and a length-1 log_values
    would silently broadcast against a_star. Both cases give wrong results
    rather than an error.
    """
    if log_coeffs.dim() != 3 or log_values.dim() != 3:
        raise ValueError(
            "parallel_scan_log expects 3-D (batch, seq, hidden) inputs, got "
            f"log_coeffs {tuple(log_coeffs.shape)} and log_values {tuple(log_values.shape)}"
        )
    if log_values.size(1) != log_coeffs.size(1) + 1:
        raise ValueError(
            "log_values must have exactly one more time step than log_coeffs "
            f"(the leading log_h0), got {log_values.size(1)} vs {log_coeffs.size(1)}"
        )
    if verbose:
        print(f"[parallel_scan_log] log_coeffs shape: {log_coeffs.shape}, log_values shape: {log_values.shape}")
    # log A_t = cumulative sum of log a_t. The leading zero pad on dim 1 gives
    # log A_0 = 0, so a_star lines up with log_values: (batch, seq_len+1, hidden).
    a_star = F.pad(torch.cumsum(log_coeffs, dim=1), (0, 0, 1, 0))
    if verbose:
        print(f"[parallel_scan_log] a_star shape: {a_star.shape}")
    # log(h_0 + sum_{s<=t} b_s / A_s), accumulated stably with logcumsumexp.
    log_h0_plus_b_star = torch.logcumsumexp(log_values - a_star, dim=1)
    if verbose:
        print(f"[parallel_scan_log] log_h0_plus_b_star shape: {log_h0_plus_b_star.shape}")
    # Multiply back by A_t (an addition in log space) and leave log space.
    log_h = a_star + log_h0_plus_b_star
    h = torch.exp(log_h)
    if verbose:
        print(f"[parallel_scan_log] output h shape: {h.shape}")
    return h


class ParallelLogMinGRULayer(nn.Module):
    """One minGRU layer evaluated over a whole sequence via a log-space scan.

    Recurrence: h_t = (1 - z_t) * h_{t-1} + z_t * g(W_h x_t), where
    z_t = sigmoid(W_z x_t). Every factor is kept in log space and the scan is
    delegated to ``parallel_scan_log``.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # Creation order (z first, then h) is kept so parameter initialisation
        # and state_dict ordering are unchanged.
        self.linear_z = nn.Linear(input_size, hidden_size)
        self.linear_h = nn.Linear(input_size, hidden_size)

    def forward(self, x: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
        """Run the layer over ``x`` starting from ``h0``.

        x: input sequence; dim 1 is the time axis (the log values are
            concatenated along it).
        h0: initial state, presumably (batch, 1, hidden), as built by
            MultiLayerParallelLogRNN. It is passed through ``g`` before the
            scan, so the effective starting state is g(h0).
        Returns the hidden states for time steps 1..T; the h0 slot is dropped.
        """
        print(f"[GRU Layer] input x shape: {x.shape}, h0 shape: {h0.shape}")
        gate_logits = self.linear_z(x)
        # log z = log sigmoid(k) and log(1 - z) = log sigmoid(-k), both via softplus.
        log_update = -F.softplus(-gate_logits)
        log_keep = -F.softplus(gate_logits)
        print(f"[GRU Layer] mean log_z: {log_update.mean().item():.4f}, mean log_coeffs: {log_keep.mean().item():.4f}")
        log_candidate = log_g(self.linear_h(x))
        log_start = log_g(h0.squeeze(1)).unsqueeze(1)
        log_values = torch.cat((log_start, log_update + log_candidate), dim=1)
        print(f"[GRU Layer] log_values shape: {log_values.shape}")
        states = parallel_scan_log(log_keep, log_values)
        out = states[:, 1:]
        print(f"[GRU Layer] output shape: {out.shape}")
        return out


class ParallelLogMinLSTMLayer(nn.Module):
    """One minLSTM layer evaluated over a whole sequence via a log-space scan.

    Forget/input logits p, k are turned into normalised gates
    f' = sigmoid(p) / (sigmoid(p) + sigmoid(k)) and
    i' = sigmoid(k) / (sigmoid(p) + sigmoid(k)), so that f' + i' = 1. The layer
    then computes h_t = f'_t * h_{t-1} + i'_t * g(W_h x_t) with
    ``parallel_scan_log``.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # Creation order (f, i, h) is kept so parameter initialisation and
        # state_dict ordering are unchanged.
        self.linear_f = nn.Linear(input_size, hidden_size)
        self.linear_i = nn.Linear(input_size, hidden_size)
        self.linear_h = nn.Linear(input_size, hidden_size)

    def forward(self, x: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
        """Run the layer over ``x`` starting from ``h0``.

        x: input sequence; dim 1 is the time axis (the log values are
            concatenated along it).
        h0: initial state, presumably (batch, 1, hidden), as built by
            MultiLayerParallelLogRNN. It is passed through ``g`` before the
            scan, so the effective starting state is g(h0).
        Returns the hidden states for time steps 1..T; the h0 slot is dropped.
        """
        print(f"[LSTM Layer] input x shape: {x.shape}, h0 shape: {h0.shape}")
        forget_logits = self.linear_f(x)
        input_logits = self.linear_i(x)
        # log_ratio = log(sigmoid(k) / sigmoid(p)); the two softplus terms below
        # turn it into log f' and log i' without leaving log space.
        log_ratio = F.softplus(-forget_logits) - F.softplus(-input_logits)
        log_forget = -F.softplus(log_ratio)
        log_input = -F.softplus(-log_ratio)
        print(f"[LSTM Layer] mean log_f: {log_forget.mean().item():.4f}, mean log_i: {log_input.mean().item():.4f}")
        log_candidate = log_g(self.linear_h(x))
        log_start = log_g(h0.squeeze(1)).unsqueeze(1)
        log_values = torch.cat((log_start, log_input + log_candidate), dim=1)
        print(f"[LSTM Layer] log_values shape: {log_values.shape}")
        states = parallel_scan_log(log_forget, log_values)
        out = states[:, 1:]
        print(f"[LSTM Layer] output shape: {out.shape}")
        return out


class MultiLayerParallelLogRNN(nn.Module):
    """Stack of parallel log-space minGRU or minLSTM layers.

    Layer 0 maps ``input_size`` -> ``hidden_size`` and every later layer maps
    ``hidden_size`` -> ``hidden_size``. Each layer consumes the full output
    sequence of the previous one.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int, rnn_type: str = "gru"):
        super().__init__()
        # Validate with an explicit raise. An ``assert`` is stripped under
        # ``python -O``, and the fallback branch would then silently build
        # LSTM layers for any unrecognised ``rnn_type``.
        if rnn_type not in {"gru", "lstm"}:
            raise ValueError(f"rnn_type must be 'gru' or 'lstm', got {rnn_type!r}")
        layer_cls = ParallelLogMinGRULayer if rnn_type == "gru" else ParallelLogMinLSTMLayer
        self.layers = nn.ModuleList()
        for layer in range(num_layers):
            in_size = input_size if layer == 0 else hidden_size
            self.layers.append(layer_cls(in_size, hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every layer in order and return the last layer's output.

        x: input sequence; dim 0 is the batch and dim 1 is the time axis.
            Presumably (batch, seq_len, input_size) -- confirm with callers.
        Returns x unchanged when the stack is empty (num_layers == 0).
        """
        batch = x.size(0)
        out = x
        for idx, layer in enumerate(self.layers):
            print(f"[MultiLayer] Layer {idx} start, input shape: {out.shape}")
            # Both layer types expose their hidden width through a gate projection.
            gate = layer.linear_z if hasattr(layer, 'linear_z') else layer.linear_f
            # Zero initial state. The layers pass it through g, so the scan
            # actually starts from g(0) = 0.5.
            h0 = torch.zeros(batch, 1, gate.out_features, device=x.device)
            out = layer(out, h0)
            print(f"[MultiLayer] Layer {idx} end, output shape: {out.shape}")
        return out


# Example usage
if __name__ == '__main__':
    input_size, hidden_size, seq_len = 16, 32, 50
    batch = 64
    X = torch.randn(batch, seq_len, input_size)
    # No h0 is built here: the model creates a zero initial state per layer.
    model = MultiLayerParallelLogRNN(input_size, hidden_size, num_layers=2, rnn_type='gru')
    out = model(X)
    print(f"Final output shape: {out.shape}")



[MultiLayer] Layer 0 start, input shape: torch.Size([64, 50, 16])
[GRU Layer] input x shape: torch.Size([64, 50, 16]), h0 shape: torch.Size([64, 1, 32])
[GRU Layer] mean log_z: -0.7332, mean log_coeffs: -0.7444
[GRU Layer] log_values shape: torch.Size([64, 51, 32])
[parallel_scan_log] log_coeffs shape: torch.Size([64, 50, 32]), log_values shape: torch.Size([64, 51, 32])
[parallel_scan_log] a_star shape: torch.Size([64, 51, 32])
[parallel_scan_log] log_h0_plus_b_star shape: torch.Size([64, 51, 32])
[parallel_scan_log] output h shape: torch.Size([64, 51, 32])
[GRU Layer] output shape: torch.Size([64, 50, 32])
[MultiLayer] Layer 0 end, output shape: torch.Size([64, 50, 32])
[MultiLayer] Layer 1 start, input shape: torch.Size([64, 50, 32])
[GRU Layer] input x shape: torch.Size([64, 50, 32]), h0 shape: torch.Size([64, 1, 32])
[GRU Layer] mean log_z: -0.6971, mean log_coeffs: -0.7421
[GRU Layer] log_values shape: torch.Size([64, 51, 32])
[parallel_scan_log] log_coeffs shape: torch.Size([64, 