In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class LRU(nn.Module):
    def __init__(self,in_features,out_features,state_features, rmin=0, rmax=1,max_phase=6.283):
        super().__init__()
        self.out_features=out_features
        self.D=nn.Parameter(torch.randn([out_features,in_features])/math.sqrt(in_features))
        u1=torch.rand(state_features)
        u2=torch.rand(state_features)
        self.nu_log= nn.Parameter(torch.log(-0.5*torch.log(u1*(rmax+rmin)*(rmax-rmin) + rmin**2)))
        self.theta_log= nn.Parameter(torch.log(max_phase*u2))
        Lambda_mod=torch.exp(-torch.exp(self.nu_log))
        self.gamma_log=nn.Parameter(torch.log(torch.sqrt(torch.ones_like(Lambda_mod)-torch.square(Lambda_mod))))
        B_re=torch.randn([state_features,in_features])/math.sqrt(2*in_features)
        B_im=torch.randn([state_features,in_features])/math.sqrt(2*in_features)
        self.B=nn.Parameter(torch.complex(B_re,B_im))
        C_re=torch.randn([out_features,state_features])/math.sqrt(state_features)
        C_im=torch.randn([out_features,state_features])/math.sqrt(state_features)
        self.C=nn.Parameter(torch.complex(C_re,C_im))
        self.state=torch.complex(torch.zeros(state_features),torch.zeros(state_features))

    def forward(self, input,state=None):
        self.state=self.state.to(self.B.device) if state==None else state
        Lambda_mod=torch.exp(-torch.exp(self.nu_log))
        Lambda_re=Lambda_mod*torch.cos(torch.exp(self.theta_log))
        Lambda_im=Lambda_mod*torch.sin(torch.exp(self.theta_log))
        Lambda=torch.complex(Lambda_re,Lambda_im)
        Lambda=Lambda.to(self.state.device)
        gammas=torch.exp(self.gamma_log).unsqueeze(-1).to(self.B.device)
        gammas=gammas.to(self.state.device)
        output=torch.empty([i for i in input.shape[:-1]] +[self.out_features],device=self.B.device)
        #Handle input of (Batches,Seq_length, Input size)
        if input.dim()==3:
            for i,batch in enumerate(input):
                out_seq=torch.empty(input.shape[1],self.out_features)
                for j,step in enumerate(batch):
                    self.state=(Lambda@self.state + gammas* self.B@step.to(dtype= self.B.dtype))
                    out_step= (self.C@self.state).real + self.D@step
                    out_seq[j]=out_step
                self.state=torch.complex(torch.zeros_like(self.state.real),torch.zeros_like(self.state.real))
                output[i]=out_seq
        #Handle input of (Seq_length, Input size)
        if input.dim()==2:
            for i,step in enumerate(input):
                self.state=(Lambda@self.state + gammas* self.B@step.to(dtype= self.B.dtype))
                out_step= (self.C@self.state).real + self.D@step
                output[i]=out_step
            self.state=torch.complex(torch.zeros_like(self.state.real),torch.zeros_like(self.state.real))
        return output

In [2]:
rnn = LRU(in_features=1,
          out_features=32,
          state_features=32)


In [8]:
sequence = torch.randn([32, 100, 1])

In [10]:
%%timeit
output = rnn(sequence)

259 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
compiled_rnn = torch.compile(LRU(in_features=1,
          out_features=32,
          state_features=32))

In [12]:
%%timeit
output = compiled_rnn(sequence)

KeyboardInterrupt: 

In [13]:
import torch
import numpy as np
import time
import triton
import triton.language as tl
from triton.runtime.jit import TensorWrapper, reinterpret

int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
    
def to_triton(x: np.ndarray, device='cuda', dst_type=None):
    t = x.dtype.name
    if t in uint_dtypes:
        signed_type_name = t.lstrip('u')  # e.g. "uint16" -> "int16"
        x_signed = x.astype(getattr(np, signed_type_name))
        return reinterpret(torch.tensor(x_signed, device=device).contiguous(), getattr(tl, t))
    else:
        if dst_type and 'float8' in dst_type:
            return reinterpret(torch.tensor(x, device=device).contiguous(), getattr(tl, dst_type))
        if t == 'float32' and dst_type == 'bfloat16':
            return torch.tensor(x, device=device).contiguous().bfloat16()
        return torch.tensor(x, device=device).contiguous()
    
def to_numpy(x):
    if isinstance(x, TensorWrapper):
        return x.base.cpu().numpy().astype(getattr(np, torch_dtypes(x.dtype)))
    elif isinstance(x, torch.Tensor):
        if x.dtype is torch.bfloat16:
            return x.cpu().float().numpy()
        return x.cpu().numpy()
    else:
        raise ValueError(f"Not a triton-compatible tensor: {x}")
    
if __name__ == "__main__":
    use_gpu = True

    if use_gpu:
        device = torch.device('cuda:0')
    else:
        device = None

    triton_times = []
    loop_times = []
    loop_comp_times = []
    vals_to_compare = []

    op = 'cumsum'
    num_warps = 16
    dtype_str = 'float32'
    axis = 0
    shape = (1024, 1)
    n_timings = 10

    x = np.random.rand(*shape).astype(dtype=np.float32)
    inp = torch.tensor(x, device=device, requires_grad=True, dtype=torch.float32)
    init = torch.zeros(shape[1], 1, device=device, requires_grad=True)
    inp_scan = inp

    @triton.jit
    def sum_op(a, b):
        return a + b

    @triton.jit
    def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
        range_m = tl.arange(0, BLOCK_M)
        range_n = tl.arange(0, BLOCK_N)
        x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
        #tl.device_print("z", x)
        z = tl.associative_scan(x, 0, sum_op)
        #tl.device_print("z", z)
        tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)

    z = np.empty_like(x)
    x_tri = to_triton(x, device=device)
    numpy_op = np.cumsum
    z_dtype_str = dtype_str
    z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
    # triton result
    z_tri = to_triton(z, device=device)
    val = kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
    out_triton = to_numpy(z_tri)
    vals_to_compare.append(out_triton)

    for _ in range(n_timings):
        start = time.time_ns()
        kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
        stop = time.time_ns()
        triton_times.append((stop - start) / (10 ** 9))

    def f(carry, x):
        return carry+x, carry+x

    def _fake_scan(f, init, x):
        zs = []
        carry = init
        for xp in x:
            carry, out = f(carry, xp)
            zs.append(out)
        return carry, torch.stack(zs)

    expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan)
    out_loop = expected_ys[:, 0, :]
    vals_to_compare.append(out_loop.cpu().detach().numpy())

    for _ in range(n_timings):
        start = time.time_ns()
        expected_carry_out, expected_ys = _fake_scan(f, init, inp_scan)
        stop = time.time_ns()
        loop_times.append((stop - start) / (10 ** 9))

    _fake_scan_comp = torch.compile(_fake_scan, mode='reduce-overhead', fullgraph=True, dynamic=False)

    #Warm-up cycles
    for _ in range(5):
        expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)

    expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)
    out_loop_comp = expected_ys_comp[:, :, 0]
    vals_to_compare.append(out_loop_comp.cpu().detach().numpy())

    for _ in range(n_timings):
        start = time.time_ns()
        expected_carry_out_comp, expected_ys_comp = _fake_scan_comp(f, init, inp_scan)
        stop = time.time_ns()
        loop_comp_times.append((stop - start) / (10 ** 9))

    #Check all results for deviations
    for ind1, res1 in enumerate(vals_to_compare):
        for ind2, res2 in enumerate(vals_to_compare):
            if not np.allclose(res1, res2):
                print((ind1, res1))
                print((ind2, res2))
                raise Exception('Comparison of ' + str(ind1) + ' with ' + str(ind2) + ' failed!')

    print('Times regular loop ' + str(np.array(loop_times).mean()))
    print('Times compiled loop ' + str(np.array(loop_comp_times).mean()))
    print('Times triton ' + str(np.array(triton_times).mean()))
    print('Script ended')

AttributeError: module 'triton.language' has no attribute 'associative_scan'