In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SimpleChunkScan(nn.Module):
    def __init__(self, chunk_size, nheads, headdim, dstate):
        super(SimpleChunkScan, self).__init__()
        self.chunk_size = chunk_size
        self.nheads = nheads
        self.headdim = headdim
        self.dstate = dstate

    def forward(self, x, dt, A, B, C, D=None, z=None, dt_bias=None, initial_states=None):
        batch, seqlen, nheads, headdim = x.shape
        nchunks = (seqlen + self.chunk_size - 1) // self.chunk_size

        # Initialize states and outputs
        states = torch.zeros(batch, nheads, nchunks, headdim, self.dstate, device=x.device, dtype=x.dtype)
        out = torch.zeros_like(x)
        
        # Perform chunked scan
        for chunk in range(nchunks):
            start_idx = chunk * self.chunk_size
            end_idx = min((chunk + 1) * self.chunk_size, seqlen)
            chunk_len = end_idx - start_idx

            x_chunk = x[:, start_idx:end_idx, :, :]
            dt_chunk = dt[:, start_idx:end_idx, :]
            B_chunk = B[:, start_idx:end_idx, :, :]
            C_chunk = C[:, start_idx:end_idx, :, :]
            
            if D is not None:
                D_chunk = D[:, :]
            else:
                D_chunk = None

            if dt_bias is not None:
                dt_chunk += dt_bias.unsqueeze(0).unsqueeze(0)

            if initial_states is not None and chunk == 0:
                state = initial_states
            else:
                state = states[:, :, chunk - 1, :, :]

            for t in range(chunk_len):
                # Compute new state
                dt_t = torch.exp(dt_chunk[:, t, :])
                A_t = A.unsqueeze(0).unsqueeze(0)
                B_t = B_chunk[:, t, :, :]
                C_t = C_chunk[:, t, :, :]
                
                state = state * torch.exp(-dt_t.unsqueeze(-1)) + B_t

                # Compute output
                out[:, start_idx + t, :, :] = x_chunk[:, t, :, :] * (A_t + C_t)

                if D_chunk is not None:
                    out[:, start_idx + t, :, :] += x_chunk[:, t, :, :] * D_chunk

                states[:, :, chunk, :, :] = state

                if z is not None:
                    out[:, start_idx + t, :, :] += z[:, start_idx + t, :, :]

        return out

In [3]:
# Example usage
batch_size = 2
seq_len = 10
nheads = 4
headdim = 8
dstate = 16
chunk_size = 5

x = torch.randn(batch_size, seq_len, nheads, headdim)
dt = torch.randn(batch_size, seq_len, nheads)
A = torch.randn(nheads)
B = torch.randn(batch_size, seq_len, nheads, dstate)
C = torch.randn(batch_size, seq_len, nheads, dstate)
D = torch.randn(nheads, headdim)
z = torch.randn(batch_size, seq_len, nheads, headdim)
dt_bias = torch.randn(nheads)

model = SimpleChunkScan(chunk_size, nheads, headdim, dstate)
out = model(x, dt, A, B, C, D, z, dt_bias)
print(out)

RuntimeError: The size of tensor a (8) must match the size of tensor b (4) at non-singleton dimension 2