In [1]:
import torch
import torch.nn as nn

"""
Small code snippet demonstrating that a linear layer and a convolutional layer with kernel_size=1 are equivalent.

In some code bases the Q, K, V projection layers of a transformer are implemented with a Conv1d layer with
kernel_size=1 instead of a Linear layer. A Conv1d with kernel_size=1 applies the same weight matrix independently
at every position along its last (length) axis, i.e. it is a Linear layer over the channel axis. The only
difference is the input layout each layer expects: Conv1d takes channel-first input (batch, channels, length),
which is the natural layout for image / feature-map data, whereas Linear takes channel-last input (..., features).
Using Conv1d therefore avoids transposing channel-first activations before the projection.

So I had to convince myself that the Conv1d layer with kernel_size=1 is equivalent to a Linear layer.
"""

# Shape after unsqueeze: (batch_size=1, in_features=4)
x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True).unsqueeze(0)
print(x.shape)
print(x)

embed_dim = 4

# Both layers map embed_dim -> 3*embed_dim (Q, K and V stacked along the output axis).
qkv_proj_conv = nn.Conv1d(embed_dim, 3 * embed_dim, kernel_size=1, bias=False)
qkv_proj_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

# Copy the weights from the linear layer to the conv layer so both compute the same map.
# There is no bias to copy: both layers were built with bias=False.
with torch.no_grad():
    # Linear weight is (out_features, in_features); Conv1d weight is (out_channels, in_channels, kernel_size).
    # With kernel_size=1 the conversion is just appending a trailing kernel axis of size 1.
    qkv_proj_conv.weight.copy_(qkv_proj_linear.weight.unsqueeze(-1))

# Apply the linear layer
output_linear = qkv_proj_linear(x)  # Shape: (batch_size=1, out_features=12)

# Apply the convolutional layer
# Reshape input to (batch_size, in_channels, sequence_length) which is (1, 4, 1)
x_conv = x.unsqueeze(-1)  # Shape: (1, 4, 1)
output_conv = qkv_proj_conv(x_conv)  # Shape: (batch_size=1, out_channels=12, sequence_length=1)
output_conv = output_conv.squeeze(-1)  # Shape: (batch_size=1, out_channels=12)

# Check if the outputs are the same
print(torch.allclose(output_linear, output_conv, atol=1e-6))

# The check above only covers a single position (sequence_length=1). The equivalence also holds for a whole
# sequence: the linear layer works on the channel-last layout, the conv layer on the channel-first layout.
# A deterministic input is used so that this extra check does not consume the global RNG state.
x_seq = torch.linspace(-1.0, 1.0, 2 * 5 * embed_dim).reshape(2, 5, embed_dim)  # (batch, seq_len, embed_dim)
with torch.no_grad():
    seq_linear = qkv_proj_linear(x_seq)  # (2, 5, 12)
    # (2, 5, 4) -> (2, 4, 5) for Conv1d, output (2, 12, 5) -> back to (2, 5, 12)
    seq_conv = qkv_proj_conv(x_seq.transpose(1, 2)).transpose(1, 2)
assert torch.allclose(seq_linear, seq_conv, atol=1e-6)

torch.Size([1, 4])
tensor([[1., 2., 3., 4.]], grad_fn=<UnsqueezeBackward0>)
True
