In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
from fast_transformers.causal_product import causal_dot_product

import laxtnn.utils.causal_product as causal_product


n = 512
r = 16
b = 8
h = 7

inp = np.random.randn(b,h,n,1)
U = np.random.randn(1,h,n,r)
V = np.random.randn(1,h,n,r)


inp, U, V = [torch.Tensor(x) for x in [inp, U, V]]
U = U.expand((b, -1, -1, -1))
V = V.expand((b,-1,-1,-1))
y = (U @ V.transpose(-1,-2)).tril() @ inp
print(y.shape)

y2 = causal_dot_product(U,V,inp)
print(torch.sqrt(((y-y2)**2).sum()/n))

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([8, 7, 512, 1])
tensor(0.0002)


In [2]:
from einops import rearrange
Uc = U.cuda()
Vc = V.cuda()
inpc = inp.cuda()

Uc = rearrange(Uc, 'b h n r -> (b h) n r').contiguous()
Vc = rearrange(Vc, 'b h n r -> (b h) n r').contiguous()
inpc = rearrange(inpc, 'b h n 1 -> (b h) n 1').contiguous()

# requires double to be accurate!
yt = rearrange((Uc.double() @ Vc.transpose(-1,-2).double()).tril() @ inpc.double(), '(b h) n 1 -> b h n 1', b=b).float()

y3 = rearrange(causal_product.causal_product_naive_cumsum(Uc, Vc, inpc), '(b h) n 1 -> b h n 1', b=b) 


Sc = torch.eye(r, device='cuda')[None].expand((b*h, -1, -1))
print(Uc.shape, Vc.shape, Sc.shape)
y4 = rearrange(causal_product.causal_product_trio(Uc, Sc, Vc, inpc[...,0]), '(b h) n -> b h n 1', b=b)

print(torch.sqrt(((y-yt.cpu())**2).sum()))
print(torch.sqrt(((y-y3.cpu())**2).sum()))
print(torch.sqrt(((y-y4.cpu())**2).sum()))

torch.Size([56, 512, 16]) torch.Size([56, 512, 16]) torch.Size([56, 16, 16])
tensor(0.0031)
tensor(0.0049)
tensor(0.0049)


In [7]:
from einops import rearrange
from laxtnn.utils.toep_mat import ToepMat

Uc = U.cuda()
inpc = inp.cuda()

Uc = rearrange(Uc, 'b h n r -> (b h) n r').contiguous()
inpc = rearrange(inpc, 'b h n 1 -> (b h) n 1').contiguous()

yU = (U @ U.transpose(-1,-2)).tril() @ inp

ac = torch.ones((1, 1), device='cuda').expand((b*h, -1))
Sc = ToepMat(ac, r)
print(torch.sqrt(((Sc@Uc.transpose(-1,-2)-Uc.transpose(-1,-2))**2).sum()))

print(Uc.shape, ac.shape)
y5 = rearrange(causal_product.causal_product_trio_toep(Uc, ac, inpc[...,0]), '(b h) n -> b h n 1', b=b)

print(torch.sqrt(((yU-y5.cpu())**2).sum()))

tensor(0.0001, device='cuda:0')
torch.Size([56, 512, 16]) torch.Size([56, 1])
tensor(0.0053)


## time forward only

In [3]:
%%timeit
torch.cuda.synchronize()
((Uc.double() @ Sc.double() @ Vc.transpose(-1,-2).double()).tril() @ inpc.double()).float()
torch.cuda.synchronize()

2.1 ms ± 310 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
%%timeit
torch.cuda.synchronize()
causal_dot_product(Uc[None],(Vc@Sc.transpose(-1,-2))[None],inpc[None])[0]
torch.cuda.synchronize()

194 µs ± 220 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [5]:
%%timeit
torch.cuda.synchronize()
causal_product.causal_product_trio(Uc, Sc, Vc, inpc[...,0])
torch.cuda.synchronize()

187 µs ± 83.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## backward too

In [6]:
Uc = U.cuda()
Vc = V.cuda()
Sc = torch.eye(r, device='cuda')[None, None].expand((b, h, -1, -1))
inpc = inp.cuda()
z = torch.zeros_like(inpc)
Uc.requires_grad = False
Sc.requires_grad = True
Vc.requires_grad = False
inpc.requires_grad = True

In [7]:
%%timeit
torch.cuda.synchronize()
y = causal_dot_product(Uc, Vc @ Sc.transpose(-1,-2), inpc)
loss = torch.nn.functional.mse_loss(y, z)
loss.backward()
torch.cuda.synchronize()

755 µs ± 210 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [8]:
%%timeit
torch.cuda.synchronize()
y = causal_dot_product(Uc @ Sc, Vc, inpc)
loss = torch.nn.functional.mse_loss(y, z)
loss.backward()
torch.cuda.synchronize()

750 µs ± 290 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [9]:
Uc = U.cuda()
Vc = V.cuda()
Sc = torch.eye(r, device='cuda')[None].expand((b*h, -1, -1))
inpc = inp.cuda()
Uc = rearrange(Uc, 'b h n r -> (b h) n r').contiguous()
Vc = rearrange(Vc, 'b h n r -> (b h) n r').contiguous()
inpc = rearrange(inpc, 'b h n 1-> (b h) n').contiguous()
z = torch.zeros_like(inpc)
Uc.requires_grad = False
Sc.requires_grad = True
Vc.requires_grad = False
inpc.requires_grad = True

In [10]:
%%timeit
torch.cuda.synchronize()
y = causal_product.causal_product_trio(Uc, Sc, Vc, inpc)
loss = torch.nn.functional.mse_loss(y, z)
loss.backward()
torch.cuda.synchronize()

485 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


##### making sure the gradient values are the same

In [11]:
S = torch.rand((1, h, r, r)).expand((b, -1, -1, -1))

Uc = U.cuda()
Vc = V.cuda()
#Sc = torch.eye(r, device='cuda')[None, None].expand((b, h, -1, -1))
Sc = S.cuda()
inpc = inp.cuda()
z = torch.zeros_like(inpc)
Uc.requires_grad = False
Sc.requires_grad = True
Vc.requires_grad = False
inpc.requires_grad = True
y = causal_dot_product(Uc @ Sc, Vc, inpc)
loss = torch.nn.functional.mse_loss(y, z)
loss.backward()

dX = inpc.grad.detach()
dS = Sc.grad.detach()


Uc = U.cuda()
Vc = V.cuda()
#Sc = torch.eye(r, device='cuda')[None, None].expand((b, h, -1, -1))
Sc = S.cuda()
inpc = inp.cuda()
z = torch.zeros_like(inpc)
Uc.requires_grad = False
Sc.requires_grad = True
Vc.requires_grad = False
inpc.requires_grad = True
y = causal_dot_product(Uc, Vc @ Sc.transpose(-1, -2), inpc)
loss = torch.nn.functional.mse_loss(y, z)
loss.backward()


dX1 = inpc.grad.detach()
dS1 = Sc.grad.detach()



Uc = U.cuda()
Vc = V.cuda()
#Sc = torch.eye(r, device='cuda')[None].expand((b*h, -1, -1))
Sc = rearrange(S.cuda(), 'b h r s -> (b h) r s').contiguous()
inpc = inp.cuda()
Uc = rearrange(Uc, 'b h n r -> (b h) n r').contiguous()
Vc = rearrange(Vc, 'b h n r -> (b h) n r').contiguous()
inpc = rearrange(inpc, 'b h n 1-> (b h) n').contiguous()
z = torch.zeros_like(inpc)
Uc.requires_grad = False
Sc.requires_grad = True
Vc.requires_grad = False
inpc.requires_grad = True
y = causal_product.causal_product_trio(Uc, Sc, Vc, inpc)
loss = torch.nn.functional.mse_loss(y, z)
loss.backward()

dX2 = inpc.grad.detach()
dS2 = Sc.grad.detach()

print(torch.sqrt(((dX1-dX)**2).sum()))
print(torch.sqrt(((dS1-dS)**2).sum()))

print(torch.sqrt(((rearrange(dX1, 'b h n 1 -> (b h) n')-dX2)**2).sum()))
print(torch.sqrt(((rearrange(dS1, 'b h n r -> (b h) n r')-dS2)**2).sum()))


tensor(1.1766, device='cuda:0')
tensor(0.8112, device='cuda:0')
tensor(1.1724, device='cuda:0')
tensor(0.8827, device='cuda:0')


#### following is with `triton==2.0.0.post1`. Also fails with `triton==2.1.0` (segfaults):

In [12]:
import math

import torch

import triton
import triton.language as tl

TRITON_MIN_BLOCK_SIZE = 16

@triton.jit
def matvec(X, y):
    return tl.sum(X*y[None, :], 1)

@triton.jit
def test_kernel(
    u_ptr,  # (_, n, r)
    s_ptr,  # (_, r, r)
    v_ptr,  # (_, r, n)  # using transpose to get around a compiler issue
    x_ptr,  # (_, n) input vector
    L_ptr,  # lower triangular mask of size MIN_BLOCK_SIZE
    y_ptr,  # (_, n) output vector
    n, r,  # Size of the tensor dimensions
    BLOCK_SIZE: tl.constexpr,  # >r
    MIN_BLOCK_SIZE: tl.constexpr  # >16, as constrained by triton. Smaller does less redundant computaton though
):
    # each program works on one batch element
    pid = tl.program_id(axis=0)

    # Offset by batch size
    cur_u_pos = pid * n * r
    cur_v_pos = pid * r * n
    cur_s_pos = pid * r * r
    cur_xy_pos = pid * n

    # pointers for one row of U, V, all of S
    s_col_ptrs = tl.arange(0, BLOCK_SIZE)
    s_row_ptrs = s_col_ptrs  # S is square
    s_col_mask = s_col_ptrs < r
    s_row_mask = s_col_mask
        
    # Load all of S as a matrix
    s_block_ptrs = cur_s_pos + s_row_ptrs[:, None] * r + s_col_ptrs[None, :]
    s_mask = s_row_mask[:, None] & s_col_mask[None, :]
    s_offsets = cur_s_pos + s_block_ptrs
    S = tl.load(s_ptr + s_offsets, mask=s_mask)
    
    # Load all of L as a matrix (exactly b x b, so no masking logic needed)
    L_row_ptrs = tl.arange(0, MIN_BLOCK_SIZE)
    L_col_ptrs = L_row_ptrs  #L is square
    L_block_ptrs = L_row_ptrs[:, None] * MIN_BLOCK_SIZE + L_col_ptrs[None, :]
    L_offsets = L_block_ptrs
    L = tl.load(L_ptr + L_offsets)
    
    # pointers for blocks of U, V.T, X, Y
    u_col_ptrs = s_col_ptrs  # dim same size as S
    u_col_mask = s_col_mask
    u_row_ptrs = tl.arange(0, MIN_BLOCK_SIZE)
    u_block_ptrs = u_row_ptrs[:, None] * r + u_col_ptrs[None, :]
    u_mask = (u_row_ptrs[:, None] < n) & (u_col_mask[None, :])
    
    v_row_ptrs = u_col_ptrs  # U, V same shape, but we're using V.T
    v_row_mask = u_col_mask    
    v_col_ptrs = u_row_ptrs
    v_block_ptrs = v_row_ptrs[:, None] * n + v_col_ptrs[None, :]
    v_mask = (v_row_mask[:, None]) & (v_col_ptrs[None, :] < n)
    
    xy_block_ptrs = tl.arange(0, MIN_BLOCK_SIZE)  # produce corresponding blocks for X, Y
    xy_mask = xy_block_ptrs < n
    
    
    # current state [r, ] vector
    # FP32 accumulation
    Ck = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    
    # Actual math
    loop_len = tl.cdiv(n, MIN_BLOCK_SIZE)
    for i in range(0, loop_len):
        # Offset for a block of rows in U, V
        u_offsets = cur_u_pos + u_block_ptrs
        v_offsets = cur_v_pos + v_block_ptrs
        xy_offsets = cur_xy_pos + xy_block_ptrs
        
        # load a block of X as a vector
        Xk = tl.load(x_ptr + xy_offsets, mask=xy_mask)  #(b,)
        
        # Load a block of rows of U and VT as matrices
        Uk = tl.load(u_ptr + u_offsets, mask=u_mask)  #(b, r)
        VkT = tl.load(v_ptr + v_offsets, mask=v_mask)  #(r, b)
        
        # Intermediate mat (this can be O(r log r) if we have FFT...)
        ## untested because segfault
        #VkpT = tl.dot(S, VkT, allow_tf32=False)  #(r, b)
        #VkpT = tl.sum(S[:, :, None] * VkT[None, :, :], 1)    #(r, b)
        
        VkpT = tl.dot(Uk, VkT, allow_tf32=False)
        
        # Compute output = (U S V^T) X
        
        #Yk = tl.dot(Uk, VkpT, allow_tf32=False)  #(b, b)
        #Yk = matvec(tl.dot(Uk, VkpT, allow_tf32=False), Xk)
        
        
        
        
        
        tl.store(y_ptr + xy_offsets, tl.sum(VkpT, 0), mask=xy_mask)
        
        # prep next iter
        if i < loop_len-1:
            # Move to next block of rows
            cur_u_pos += MIN_BLOCK_SIZE * r
            cur_xy_pos += MIN_BLOCK_SIZE
            
            # Move to next block of columns
            cur_v_pos += MIN_BLOCK_SIZE
        """
        
        
        # Compute output = (L . (U S V^T)) X  [b, b] x [b,] => [b,]
        Yk = matvec(Uk, Ck) + matvec(L*tl.dot(Uk, VkpT, allow_tf32=False), Xk)
        
        # Store the result of this block
        tl.store(y_ptr + xy_offsets, Yk, mask=xy_mask)
        
        # prep next iter
        if i < loop_len-1:
            # Compute next context [M, b] x [b] => [M]
            Ck += matvec(VkpT, Xk)
            # Move to next block of rows
            cur_uv_pos += MIN_BLOCK_SIZE * r
            cur_xy_pos += MIN_BLOCK_SIZE
        """
    
def test(u, s, v, x, min_block_size=TRITON_MIN_BLOCK_SIZE):
    """
    Accepts 4 tensors U, S, V, X of shape: [B, n, r], [B, r, r], [B, r, n], and [B, n]
    """
    assert all(x.is_cuda and x.is_contiguous for x in (u,s,v,x))
    assert u.size()[0] == v.size()[0]
    assert u.size()[1:] == v.size()[:0:-1]
    assert u.size()[:-1] == x.size(), (u.size(), x.size())
    assert s.size()[-2] == s.size()[-1], s.size()
    assert s.size()[-1] == u.size()[-1], (s.size(), u.size())
    
    # We need to preallocate the output
    y = torch.zeros_like(x)
    
    batch, n, r = u.size()

    def grid(meta): return (batch,)
    block_size = int(2 ** math.ceil(math.log2(r)))
    L = torch.ones((min_block_size, min_block_size), device='cuda').tril()
    test_kernel[grid](
        u, s, v, x,
        L, y, n, r,
        BLOCK_SIZE=block_size,
        MIN_BLOCK_SIZE=min_block_size
    )
    return y

b = 4
n = 64
r = 32


u = torch.rand((b,n,r), device='cuda')
s = torch.rand((b,r,r), device='cuda')
v = torch.rand((b,r,n), device='cuda')
x = torch.rand((b,n), device='cuda')
y_gt = v.sum(-2)
y_gt = (s @ v).sum(-2)

#y_gt = (u @ v).sum(-2)

#y_gt = ((u @ s @ v).tril() @ x[...,None])[...,0]
y = test(u, s, v, x, min_block_size=32)
#print(s)
print(f'error: {torch.sum(torch.abs(y-y_gt)).cpu().numpy()}')

print(torch.stack([y,y_gt], dim=2).cpu().numpy())
#print(torch.stack([y,y_gt], dim=2)[:16,:16].cpu().numpy())


SyntaxError: future feature annotations is not defined (base.py, line 1)