In [1]:
%load_ext autoreload
%autoreload 2

import torch

from conv1d import CausalConv1dConfig, CausalConv1d
from einops import rearrange

## Match parallel and recurrent impl of causal 1D convolution

In [2]:
DTYPE = torch.float32 
DEVICE = torch.device("cuda:0")
torch.manual_seed(0)

<torch._C.Generator at 0x7f90963a68d0>

In [3]:
B = 1
S = 12
DH = 5
EPS = 0.0

In [4]:
xs = torch.randn((B, S, DH), dtype=DTYPE, device=DEVICE)
xs, xs.shape

(tensor([[[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209],
          [-0.5797, -0.6229, -0.3284, -1.0745, -0.3631],
          [-1.6711,  2.2655,  0.3117, -0.1842,  1.2866],
          [ 1.1820, -0.1271,  1.2169,  1.4353,  1.0605],
          [-0.4941, -1.4244, -0.7244, -1.2973,  0.0697],
          [-0.0074,  1.8969,  0.6878, -0.0779, -0.8373],
          [ 1.3506, -0.2879, -0.5965, -0.3283, -0.9086],
          [-0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
          [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806],
          [ 1.0252, -1.4622, -0.7554, -0.1836,  0.3824],
          [ 0.3918, -0.0830,  0.8971, -1.1123,  0.1116],
          [ 0.4863, -0.5499, -0.3231, -0.5469,  0.9049]]], device='cuda:0'),
 torch.Size([1, 12, 5]))

In [5]:
conv1d = CausalConv1d(config=CausalConv1dConfig(feature_dim=DH, kernel_size=4, causal_conv_bias=True)).to(device=DEVICE)

In [6]:
y_p = conv1d(xs)
y_p, y_p.shape

(tensor([[[ 0.0333,  0.4533,  1.0648, -0.2179,  0.0256],
          [ 0.3536,  0.3336,  1.6271, -0.2148,  0.1696],
          [ 0.1185, -0.6049,  1.2795, -0.4546, -0.3236],
          [-0.1661,  1.5106,  0.5153, -0.4060, -1.0016],
          [-1.4377,  0.3595, -0.0985, -0.4110, -0.3364],
          [ 0.4205, -1.1767, -0.0652, -0.2971,  0.0118],
          [-0.8406,  1.5220,  0.1182,  0.0782,  0.4352],
          [-0.6601,  0.2417,  0.5137, -0.8620, -0.2299],
          [ 0.7227, -0.3354, -0.1103, -0.1781, -0.5223],
          [-0.8804,  1.1659, -0.0971, -0.0738,  0.2892],
          [-0.8929, -0.2839, -0.1745,  0.0970, -0.6405],
          [-0.1276,  0.6855, -0.1451, -0.2405, -0.3184]]], device='cuda:0',
        grad_fn=<TransposeBackward0>),
 torch.Size([1, 12, 5]))

In [7]:
ys = []
conv_state = None
for x in xs.split(split_size=1, dim=1):
    y, conv_state = conv1d.step(x, conv_state)
    # print(conv_state)
    ys.append(y)
torch.cat(ys)

  from .autonotebook import tqdm as notebook_tqdm


tensor([[ 0.0333,  0.4533,  1.0648, -0.2179,  0.0256],
        [ 0.3536,  0.3336,  1.6271, -0.2148,  0.1696],
        [ 0.1185, -0.6049,  1.2795, -0.4546, -0.3236],
        [-0.1661,  1.5106,  0.5153, -0.4060, -1.0016],
        [-1.4377,  0.3595, -0.0985, -0.4110, -0.3364],
        [ 0.4205, -1.1767, -0.0652, -0.2971,  0.0118],
        [-0.8406,  1.5220,  0.1182,  0.0782,  0.4352],
        [-0.6601,  0.2417,  0.5137, -0.8620, -0.2299],
        [ 0.7227, -0.3354, -0.1103, -0.1781, -0.5223],
        [-0.8804,  1.1659, -0.0971, -0.0738,  0.2892],
        [-0.8929, -0.2839, -0.1745,  0.0970, -0.6405],
        [-0.1276,  0.6855, -0.1451, -0.2405, -0.3184]], device='cuda:0',
       grad_fn=<CatBackward0>)

In [8]:
torch.cat(ys) - y_p

tensor([[[-1.1176e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  1.1921e-07,  1.4901e-08,  0.0000e+00],
         [-1.4901e-08,  0.0000e+00,  0.0000e+00, -2.9802e-08, -2.9802e-08],
         [ 4.4703e-08, -1.1921e-07,  0.0000e+00,  2.9802e-08,  0.0000e+00],
         [-1.1921e-07,  2.9802e-08,  4.4703e-08,  0.0000e+00, -2.9802e-08],
         [ 8.9407e-08, -1.1921e-07,  3.7253e-08, -2.9802e-08,  2.7008e-08],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  2.9802e-08,  0.0000e+00],
         [-5.9605e-08,  2.9802e-08,  5.9605e-08,  0.0000e+00,  0.0000e+00],
         [-5.9605e-08,  0.0000e+00, -1.4901e-08, -1.4901e-08,  5.9605e-08],
         [-5.9605e-08,  1.1921e-07,  5.2154e-08,  2.9802e-08,  0.0000e+00],
         [ 0.0000e+00,  2.9802e-08, -1.4901e-08,  7.4506e-09, -5.9605e-08],
         [ 0.0000e+00,  0.0000e+00, -2.9802e-08,  0.0000e+00,  0.0000e+00]]],
       device='cuda:0', grad_fn=<SubBackward0>)

In [9]:
conv1d.conv.weight.shape, conv1d.conv.bias.shape

(torch.Size([5, 1, 4]), torch.Size([5]))

In [10]:
## the conv state are the kernel_size last input elements
# take the first 4
conv_state = xs[:, -4:].clone()
conv_state

tensor([[[ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806],
         [ 1.0252, -1.4622, -0.7554, -0.1836,  0.3824],
         [ 0.3918, -0.0830,  0.8971, -1.1123,  0.1116],
         [ 0.4863, -0.5499, -0.3231, -0.5469,  0.9049]]], device='cuda:0')

In [11]:
torch.roll(conv_state, shifts=-1, dims=1)

tensor([[[ 1.0252, -1.4622, -0.7554, -0.1836,  0.3824],
         [ 0.3918, -0.0830,  0.8971, -1.1123,  0.1116],
         [ 0.4863, -0.5499, -0.3231, -0.5469,  0.9049],
         [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806]]], device='cuda:0')

In [12]:
x_new = xs[:, :1, :]
x_new, x_new.shape

(tensor([[[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209]]], device='cuda:0'),
 torch.Size([1, 1, 5]))

In [13]:
conv_state_new = torch.roll(conv_state, shifts=-1, dims=1)
conv_state_new[:, -1, :] = x_new

In [14]:
conv_state_new, conv_state_new.shape

(tensor([[[ 1.0252, -1.4622, -0.7554, -0.1836,  0.3824],
          [ 0.3918, -0.0830,  0.8971, -1.1123,  0.1116],
          [ 0.4863, -0.5499, -0.3231, -0.5469,  0.9049],
          [-0.9247, -0.4253, -2.6438,  0.1452, -0.1209]]], device='cuda:0'),
 torch.Size([1, 4, 5]))

In [15]:
torch.sum(conv_state_new * rearrange(conv1d.conv.weight, 'D 1 KS -> KS D'), dim=1)

tensor([[ 0.2971,  0.1015,  0.7834, -0.4558, -0.4674]], device='cuda:0',
       grad_fn=<SumBackward1>)

In [16]:
def conv1d_step(x: torch.Tensor, conv_state: torch.Tensor, conv1d_weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    B: batch size
    S: sequence length
    D: feature dimension
    KS: kernel size
    Args:
        x (torch.Tensor): (B, S, D)
        conv_state (torch.Tensor): (B, KS, D)
        conv1d_weight (torch.Tensor): (KS, D)
    """
    assert x.shape[0] == conv_state.shape[0], f"x has batch size {x.shape[0]} but conv_state has batch size {conv_state.shape[0]}"
    assert x.shape[2] == conv_state.shape[2], f"x has feature dimension {x.shape[2]} but conv_state has feature dimension {conv_state.shape[2]}"
    assert x.shape[1] == 1, f"x has sequence length {x.shape[1]} but it should be 1"
    conv_state_new = torch.roll(conv_state, shifts=-1, dims=1)
    conv_state_new[:, -1, :] = x
    return torch.sum(conv_state_new * conv1d_weight, dim=1), conv_state_new

In [17]:
torch.zeros_like(xs[:, :4, :]).shape

torch.Size([1, 4, 5])

In [18]:
xs

tensor([[[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209],
         [-0.5797, -0.6229, -0.3284, -1.0745, -0.3631],
         [-1.6711,  2.2655,  0.3117, -0.1842,  1.2866],
         [ 1.1820, -0.1271,  1.2169,  1.4353,  1.0605],
         [-0.4941, -1.4244, -0.7244, -1.2973,  0.0697],
         [-0.0074,  1.8969,  0.6878, -0.0779, -0.8373],
         [ 1.3506, -0.2879, -0.5965, -0.3283, -0.9086],
         [-0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
         [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806],
         [ 1.0252, -1.4622, -0.7554, -0.1836,  0.3824],
         [ 0.3918, -0.0830,  0.8971, -1.1123,  0.1116],
         [ 0.4863, -0.5499, -0.3231, -0.5469,  0.9049]]], device='cuda:0')