In [1]:
import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    """Embed ``c_in`` float features per time step into ``d_model`` dimensions.

    Instead of looking up integer token ids, this module applies a 1-D
    convolution along the time axis, so each output vector is computed from a
    window of ``kernel_size`` neighbouring time steps. Circular padding of
    ``kernel_size // 2`` on each side keeps the output length equal to the
    input length.

    Args:
        c_in: Number of input features (channels) per time step.
        d_model: Size of each output embedding vector.
        kernel_size: Temporal width of the convolution. Must be a positive odd
            integer so symmetric padding preserves the sequence length.
            Defaults to 3 (the previously hard-coded value).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, c_in, d_model, kernel_size=3):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}")
        # For odd k: L_out = L + 2 * (k // 2) - k + 1 = L, so the length is preserved.
        padding = kernel_size // 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=kernel_size, padding=padding,
                                   padding_mode='circular', bias=False)
        # Replace PyTorch's default Conv1d init with Kaiming-normal (fan_in, leaky_relu gain).
        nn.init.kaiming_normal_(self.tokenConv.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """Embed ``x`` of shape ``(batch, seq_len, c_in)`` to ``(batch, seq_len, d_model)``."""
        # Conv1d expects (batch, channels, length): move features onto the channel
        # axis, convolve along time, then move d_model back to the last axis.
        return self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)

In [19]:
# Seed the global RNG so the conv weights and the random input in the next cell are reproducible.
SEED = 0
torch.manual_seed(SEED)

batch_size = 32
seq_len = 8
c_in = 10
d_model = 128
kernel_size = 3
# Symmetric padding of k // 2 keeps the output length equal to seq_len only for odd k.
assert kernel_size % 2 == 1, "kernel_size must be odd to preserve sequence length"
padding = kernel_size // 2

# Same Conv1d configuration as inside TokenEmbedding, but keeping PyTorch's default weight init.
tokenConv = nn.Conv1d(
    in_channels=c_in,
    out_channels=d_model,
    kernel_size=kernel_size,
    padding=padding,
    padding_mode='circular',  # wrap around the sequence ends instead of zero-padding
    bias=False,
)

- In the traditional Transformer, each integer token ID is mapped to a `d_model`-dimensional embedding vector via a lookup table (e.g. `nn.Embedding`).

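- TokenConvolution
    - Instead of integer tokens, each time step has `c_in` float channels (features).
    - The `Conv1d` layer has `d_model` filters, each of shape `(c_in, kernel_size)`, i.e. `c_in * d_model` one-dimensional kernels. Together they lift the `c_in` features to `d_model` dimensions.
    - Because each filter spans `kernel_size` neighbouring time steps, locally adjacent tokens generate their embeddings together. Circular padding wraps around at the sequence ends, so the output keeps length `seq_len`.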

In [21]:
# Walk a random batch through the conv step by step, checking every intermediate shape.
x = torch.rand(batch_size, seq_len, c_in)   # (batch, seq_len, c_in)
x_permuted = x.permute(0, 2, 1)             # (batch, c_in, seq_len): Conv1d wants channels first
x_conved = tokenConv(x_permuted)            # (batch, d_model, seq_len)
x_out = x_conved.transpose(1, 2)            # (batch, seq_len, d_model)

assert x_permuted.shape == (batch_size, c_in, seq_len)
assert x_conved.shape == (batch_size, d_model, seq_len), "padding should preserve seq_len"
assert x_out.shape == (batch_size, seq_len, d_model)

print(f"x shape: {x.shape}")
print(f"x_permuted shape: {x_permuted.shape}")
print(f"x_conved shape: {x_conved.shape}")
print(f"x_out shape: {x_out.shape}")

x shape: torch.Size([32, 8, 10])
x_permuted shape: torch.Size([32, 10, 8])
x_conved shape: torch.Size([32, 128, 8])
x_out shape: torch.Size([32, 8, 128])
