In [9]:
import torch
from conformer import ConformerConvModule, ConformerBlock, Conformer

In [7]:
layer = ConformerConvModule(
    dim = 512,
    causal = False,             # auto-regressive or not - 1d conv will be made causal with padding if so
    expansion_factor = 2,       # what multiple of the dimension to expand for the depthwise convolution
    kernel_size = 31,           # kernel size, 17 - 31 was said to be optimal
    dropout = 0.                # dropout at the very end
)

x = torch.randn(1, 1024, 512)
print(x)
x = layer(x) + x
print(x)

tensor([[[ 0.5265, -0.4810,  0.1760,  ...,  0.3470, -0.9452,  0.7958],
         [ 0.8491, -1.6638, -0.5008,  ..., -2.3336,  0.4957,  0.7496],
         [-0.2855,  0.2087, -1.1082,  ...,  1.2685,  0.2433,  0.3270],
         ...,
         [-0.9187,  1.2772, -0.7080,  ..., -0.4587,  0.1281,  1.1617],
         [-1.8390, -0.2245,  0.9514,  ...,  0.4004, -0.1579, -1.0373],
         [ 1.1652, -1.9307,  0.8463,  ...,  0.6792, -0.1495,  0.6091]]])
tensor([[[ 0.5922, -0.5241, -0.1524,  ...,  0.0465, -0.9541,  0.7753],
         [ 0.9469, -1.8314, -0.5312,  ..., -2.6137,  0.4106,  0.6094],
         [-0.2634,  0.1208, -1.2688,  ...,  1.1890,  0.1873,  0.5199],
         ...,
         [-0.7835,  1.5197, -1.1714,  ..., -0.3793,  0.0170,  1.3951],
         [-1.7959, -0.4181,  0.5922,  ...,  0.8836, -0.5154, -0.7655],
         [ 1.5140, -1.9135,  0.6860,  ...,  0.4897, -0.2789,  0.2855]]],
       grad_fn=<AddBackward0>)


In [8]:
block = ConformerBlock(
    dim = 512,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    conv_expansion_factor = 2,
    conv_kernel_size = 31,
    attn_dropout = 0.,
    ff_dropout = 0.,
    conv_dropout = 0.
)

x = torch.randn(1, 1024, 512)
print(x)

block(x) # (1, 1024, 512)
print(x)


tensor([[[ 2.0100,  1.2666, -0.4990,  ..., -0.3092,  0.1367, -0.0962],
         [-0.0834,  1.3541, -0.1584,  ...,  0.1919, -0.1561,  0.3935],
         [ 0.7074, -0.3509, -1.4604,  ..., -1.6113, -0.9696,  0.3001],
         ...,
         [ 1.5183,  0.5957, -1.7596,  ...,  0.3979, -0.4933, -0.5757],
         [ 2.2552,  0.3020,  0.4947,  ..., -0.2974, -0.7883,  0.6733],
         [ 0.6021, -0.3604,  0.8213,  ...,  1.4099, -0.0175, -0.1067]]])
tensor([[[ 2.0100,  1.2666, -0.4990,  ..., -0.3092,  0.1367, -0.0962],
         [-0.0834,  1.3541, -0.1584,  ...,  0.1919, -0.1561,  0.3935],
         [ 0.7074, -0.3509, -1.4604,  ..., -1.6113, -0.9696,  0.3001],
         ...,
         [ 1.5183,  0.5957, -1.7596,  ...,  0.3979, -0.4933, -0.5757],
         [ 2.2552,  0.3020,  0.4947,  ..., -0.2974, -0.7883,  0.6733],
         [ 0.6021, -0.3604,  0.8213,  ...,  1.4099, -0.0175, -0.1067]]])


In [10]:
conformer = Conformer(
    dim = 512,
    depth = 12,          # 12 blocks
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    conv_expansion_factor = 2,
    conv_kernel_size = 31,
    attn_dropout = 0.,
    ff_dropout = 0.,
    conv_dropout = 0.
)

x = torch.randn(1, 1024, 512)

conformer(x) # (1, 1024, 512)


tensor([[[-0.2011,  1.0815, -1.2923,  ...,  0.3705,  0.5901,  1.1830],
         [ 1.4419, -0.9692, -0.0616,  ...,  0.8974,  0.4045, -0.2396],
         [ 0.8609, -0.7805, -0.4860,  ..., -0.2618,  1.1178, -0.0750],
         ...,
         [ 1.0500, -0.2999, -0.0979,  ...,  1.3331, -2.2303, -0.4900],
         [ 1.2021,  2.0945,  0.2122,  ...,  0.0690, -1.1321, -1.1737],
         [ 1.6938,  1.6155, -0.2591,  ..., -0.2987, -2.6603,  0.0206]]],
       grad_fn=<NativeLayerNormBackward0>)