In [2]:
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Triton imports for parallel scan
import triton
import triton.language as tl

# --- Triton kernel for prefix-scan: h_t = a_t * h_{t-1} + b_t ---
@triton.jit
def parallel_scan_kernel(
    a_ptr, b_ptr, h0_ptr, out_ptr,
    seq_len: tl.constexpr, stride_seq: tl.constexpr, stride_s: tl.constexpr
):
    pid = tl.program_id(0)
    # initial hidden state
    h = tl.load(h0_ptr + pid)
    for t in range(seq_len):
        off = t * stride_seq + pid * stride_s
        a_t = tl.load(a_ptr + off)
        b_t = tl.load(b_ptr + off)
        h = a_t * h + b_t
        tl.store(out_ptr + off, h)

# Python prefix-scan (for backward)
def python_scan(a: torch.Tensor, b: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
    P = torch.cumprod(a, dim=0)
    P_inv = 1.0 / P
    C = torch.cumsum(b * P_inv, dim=0)
    return P * (h0.unsqueeze(0) + C)

# Autograd-enabled Triton scan
class _TritonScanFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, h0):
        # Save tensors for backward
        ctx.save_for_backward(a, b, h0)
        T, B, H = a.shape
        S = B * H
        # Flatten tensors for Triton
        a_flat = a.reshape(T, S).contiguous()
        b_flat = b.reshape(T, S).contiguous()
        out_flat = torch.empty_like(a_flat)
        h0_flat = h0.reshape(S).contiguous()
        # Strides for pointer arithmetic
        stride_seq, stride_s = a_flat.stride()
        # Launch Triton kernel
        parallel_scan_kernel[(S,)](
            a_flat, b_flat, h0_flat, out_flat,
            T, stride_seq, stride_s
        )
        # Reshape back to (T, B, H)
        return out_flat.view(T, B, H)

    @staticmethod
    def backward(ctx, grad_output):
        a, b, h0 = ctx.saved_tensors
        # Compute gradients for a and b using Python scan under grad
        with torch.enable_grad():
            a_ = a.detach().requires_grad_(True)
            b_ = b.detach().requires_grad_(True)
            H_py = python_scan(a_, b_, h0)
        # Compute gradients
        grad_a, grad_b = torch.autograd.grad(
            outputs=H_py,
            inputs=(a_, b_),
            grad_outputs=grad_output,
            retain_graph=False,
            allow_unused=True
        )
        return grad_a, grad_b, None

def triton_scan(a: torch.Tensor, b: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
    """Wrapper for Triton-enabled prefix scan"""
    return _TritonScanFunction.apply(a, b, h0)

# --- Toy dataset generation ---
def make_toy_data(seq_len=10, n_samples=500, input_size=16, n_classes=2):
    X = torch.randn(seq_len, n_samples, input_size)
    y = (X[0].sum(dim=1) > 0).long()
    dataset = TensorDataset(X.permute(1,0,2), y)
    def collate(batch):
        xs, ys = zip(*batch)
        xs = torch.stack(xs).permute(1,0,2)
        ys = torch.tensor(ys)
        return xs, ys
    return DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

loader = make_toy_data()

# --- Custom MinGRU layer ---
class MinGRULayer(nn.Module):
    def __init__(self, in_dim, hidden_size, dropout=0.1, is_last=False):
        super().__init__()
        self.W_z = nn.Linear(in_dim, hidden_size)
        self.W_h = nn.Linear(in_dim, hidden_size)
        self.b_h = nn.Parameter(torch.zeros(hidden_size))
        self.ln = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout) if not is_last else None

    def forward(self, inp, h_prev):
        z = torch.sigmoid(self.W_z(inp))
        h_tilde = torch.tanh(self.W_h(inp) + self.b_h)
        h = (1 - z) * h_prev + z * h_tilde
        h = self.ln(h)
        if self.dropout:
            h = self.dropout(h)
        return h

# --- Model definitions ---
class StandardMultiLayerGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, dropout=0.1):
        super().__init__()
        self.cells = nn.ModuleList([
            nn.GRUCell(input_size if i==0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])
        self.lns = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        T, B, _ = x.size()
        h = [torch.zeros(B, self.cells[0].hidden_size, device=x.device) for _ in self.cells]
        for t in range(T):
            inp = x[t]
            for i, cell in enumerate(self.cells):
                h[i] = cell(inp, h[i])
                h[i] = self.lns[i](h[i])
                if i < len(self.cells) - 1:
                    h[i] = self.dropout(h[i])
                inp = h[i]
        return self.fc(h[-1])

class MultiLayerMinGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, dropout=0.1):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_dim = input_size if i==0 else hidden_size
            layers.append(MinGRULayer(in_dim, hidden_size, dropout, is_last=(i==num_layers-1)))
        self.layers = nn.ModuleList(layers)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        T, B, _ = x.size()
        h = [torch.zeros(B, self.layers[0].W_z.out_features, device=x.device) for _ in self.layers]
        for t in range(T):
            inp = x[t]
            for i, layer in enumerate(self.layers):
                h[i] = layer(inp, h[i])
                inp = h[i]
        return self.fc(h[-1])

class ParallelScanMinGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, dropout=0.1):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_dim = input_size if i==0 else hidden_size
            layers.append(
                nn.ModuleDict({
                    'W_z': nn.Linear(in_dim, hidden_size),
                    'W_h': nn.Linear(in_dim, hidden_size),
                    'ln': nn.LayerNorm(hidden_size)
                })
            )
        self.layers = nn.ModuleList(layers)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        T, B, _ = x.size()
        for i, layer in enumerate(self.layers):
            W_z, W_h, ln = layer['W_z'], layer['W_h'], layer['ln']
            x_flat = x.reshape(T*B, -1).contiguous()
            z = torch.sigmoid(W_z(x_flat)).view(T, B, -1)
            a = 1 - z
            h_tilde = torch.tanh(W_h(x_flat)).view(T, B, -1)
            b = z * h_tilde
            h0 = torch.zeros(B, z.size(2), device=x.device)
            H = triton_scan(a, b, h0)
            H = ln(H)
            if i < len(self.layers) - 1:
                H = self.dropout(H)
            x = H
        return self.fc(x[-1])

# --- Training & evaluation ---
def train_and_eval(model, loader, epochs=20, lr=1e-3, device='cuda'):
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    crit = nn.CrossEntropyLoss()
    start = time.time()
    for _ in range(epochs):
        model.train()
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = crit(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    train_time = time.time() - start
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return train_time, correct/total

# --- Benchmarking ---
if __name__ == '__main__':
    INPUT_DIM, HIDDEN, LAYERS, CLASSES = 16, 32, 4, 2
    models = {
        'StandardGRU': StandardMultiLayerGRU(INPUT_DIM, HIDDEN, LAYERS, CLASSES),
        'LoopMinGRU':  MultiLayerMinGRU(INPUT_DIM, HIDDEN, LAYERS, CLASSES),
        'ParScanGRU':  ParallelScanMinGRU(INPUT_DIM, HIDDEN, LAYERS, CLASSES),
    }
    loader = make_toy_data()
    for name, m in models.items():
        t, acc = train_and_eval(m, loader)
        print(f"{name:12s} → time: {t:.2f}s, accuracy: {acc*100:5.2f}%")


StandardGRU  → time: 6.91s, accuracy: 99.80%
LoopMinGRU   → time: 10.42s, accuracy: 100.00%
ParScanGRU   → time: 1.70s, accuracy: 97.00%
