In [1]:
%load_ext autoreload
%autoreload 2

In [19]:
from wavenet import WaveNetBlock, WaveNet

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

## dilated causal convolutions

In [4]:
class DilatedCausalConv1d(nn.Conv1d):
    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=1,
            padding=dilation,
            dilation=dilation,
            bias=False,
            padding_mode="zeros"
        )
        self.dilation_ = dilation

    def forward(self, x):
        ret = super().forward(x)
        return ret[..., :-self.dilation_]

In [5]:
dccs = {2**i:DilatedCausalConv1d(2, 2, 2**i) for i in range(4)}

In [6]:
dccs

{1: DilatedCausalConv1d(2, 2, kernel_size=(2,), stride=(1,), padding=(1,), bias=False),
 2: DilatedCausalConv1d(2, 2, kernel_size=(2,), stride=(1,), padding=(2,), dilation=(2,), bias=False),
 4: DilatedCausalConv1d(2, 2, kernel_size=(2,), stride=(1,), padding=(4,), dilation=(4,), bias=False),
 8: DilatedCausalConv1d(2, 2, kernel_size=(2,), stride=(1,), padding=(8,), dilation=(8,), bias=False)}

In [7]:
for dcc in dccs.values():
    dcc.weight.data.fill_(1)

In [8]:
seq = torch.arange(8, dtype=torch.long).reshape(1, -1)
print(seq)

embed = nn.Embedding(255, 2)

seq = embed(seq)
seq = rearrange(seq, "batch seq channels -> batch channels seq") # PyTorch Conv1d required order 
                           
print(seq)

for dil in (1, 2, 4, 8):
    dcc = dccs[dil]
    seq = dcc(seq)
    print(seq)

tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
tensor([[[-0.3989, -0.0918,  0.7880, -1.1292,  0.5464, -0.4332, -0.9581,
          -2.1084],
         [ 0.2470,  2.2259, -0.6204, -1.3825, -1.3547, -1.1894, -1.1706,
          -1.4953]]], grad_fn=<PermuteBackward0>)
tensor([[[-0.1519,  1.9822,  2.3017, -2.3442, -3.3201, -2.4310, -3.7513,
          -5.7323],
         [-0.1519,  1.9822,  2.3017, -2.3442, -3.3201, -2.4310, -3.7513,
          -5.7323]]], grad_fn=<SliceBackward0>)
tensor([[[ -0.3038,   3.9645,   4.2995,  -0.7240,  -2.0368,  -9.5505, -14.1428,
          -16.3266],
         [ -0.3038,   3.9645,   4.2995,  -0.7240,  -2.0368,  -9.5505, -14.1428,
          -16.3266]]], grad_fn=<SliceBackward0>)
tensor([[[ -0.6076,   7.9289,   8.5991,  -1.4480,  -4.6813, -11.1720, -19.6866,
          -34.1013],
         [ -0.6076,   7.9289,   8.5991,  -1.4480,  -4.6813, -11.1720, -19.6866,
          -34.1013]]], grad_fn=<SliceBackward0>)
tensor([[[ -1.2153,  15.8578,  17.1982,  -2.8960,  -9.3627, -22.3441, -39.3

```
our perceptual system feeds raw tokens (photons, waves, etc.) into a Perceiver model that outputs embeddings, which are then processed by other models using a universal set of embeddings

the current formulation of "embeddings" is holding ML back because it's arbitrary for every new model architecture...
```

## wavenet block

In [15]:
wnb = WaveNetBlock(1, 6, 4, 2)

In [16]:
x = torch.randn(1, 8, 4)
print(x)

tensor([[[ 0.6415,  2.0079,  1.2987, -1.6521],
         [-0.4333,  2.0356,  0.5587, -0.2746],
         [ 0.0465, -1.1569, -0.2227, -0.9221],
         [-0.3786,  0.3795,  0.8372,  0.6274],
         [ 2.1452, -0.3153,  1.5244, -0.4513],
         [ 0.5395,  0.7476,  0.0573,  0.9140],
         [ 1.0305, -0.6988,  0.3695, -0.6360],
         [ 0.1316, -0.8818, -0.1656,  0.4459]]])


In [17]:
wnb(x)

(tensor([[[-0.2766, -0.2638, -0.0616, -0.0880, -0.0358, -0.1975, -0.0095,
           -0.1001],
          [ 0.2595,  0.4514,  0.4089,  0.3159,  0.3149,  0.4715,  0.3673,
            0.3817],
          [-0.3067, -0.1060, -0.3404, -0.4422, -0.4741, -0.0669, -0.4493,
           -0.2762],
          [-0.4703, -0.5385, -0.4511, -0.4565, -0.4142, -0.4178, -0.4365,
           -0.3758]]], grad_fn=<ConvolutionBackward0>),
 tensor([[[ 0.5628,  0.2582,  0.5419,  0.6263,  0.7145,  0.4346,  0.6398,
            0.6511],
          [-0.4450, -0.1115, -0.3777, -0.5143, -0.5695, -0.1573, -0.4977,
           -0.4053]]], grad_fn=<ConvolutionBackward0>))

## wavenet

In [40]:
wn = WaveNet(5)

In [41]:
x2 = torch.randint(0, 5, (1, 8), dtype=torch.long)
print(x2)

tensor([[3, 2, 1, 1, 0, 2, 1, 3]])


In [42]:
wn(x2)

tensor([[[-0.9521,  0.9509,  0.4908],
         [ 2.2902, -0.5103,  0.3629],
         [ 0.8098,  0.8580, -0.1801],
         [ 0.8098,  0.8580, -0.1801],
         [ 2.4660, -0.2746, -0.0791],
         [ 2.2902, -0.5103,  0.3629],
         [ 0.8098,  0.8580, -0.1801],
         [-0.9521,  0.9509,  0.4908]]], grad_fn=<EmbeddingBackward0>)
[tensor([[[0.0453, 0.1430, 0.1095, 0.0920, 0.1358, 0.1542, 0.1095, 0.0638],
         [0.1175, 0.4894, 0.3909, 0.3295, 0.5229, 0.5372, 0.3909, 0.1265]]],
       grad_fn=<ConvolutionBackward0>), tensor([[[-0.0691, -0.1369,  0.0361, -0.0846, -0.1191, -0.1052, -0.0986,
          -0.0434],
         [ 0.3147,  0.2504, -0.0290,  0.3456,  0.0863,  0.1387,  0.3233,
           0.3268]]], grad_fn=<ConvolutionBackward0>), tensor([[[ 0.0227,  0.3507,  0.1165,  0.1611,  0.3120,  0.3692,  0.1612,
           0.1179],
         [-0.4224, -0.2954, -0.4045, -0.3934, -0.3220, -0.3053, -0.3876,
          -0.3880]]], grad_fn=<ConvolutionBackward0>)]
tensor([[[-0.0011,  0.3568,  

tensor([[[ 0.0162, -0.2296,  0.2837,  0.1779,  0.0148],
         [-0.0729, -0.1715,  0.3773,  0.2691,  0.0466],
         [-0.0080, -0.1913,  0.3153,  0.2200,  0.0756],
         [-0.0280, -0.1932,  0.3291,  0.2300,  0.0291],
         [-0.0507, -0.1780,  0.3563,  0.2525,  0.0579],
         [-0.0727, -0.1675,  0.3798,  0.2715,  0.0656],
         [-0.0337, -0.1917,  0.3343,  0.2342,  0.0251],
         [ 0.0014, -0.2013,  0.3015,  0.2083,  0.0461]]],
       grad_fn=<PermuteBackward0>)