In [6]:
import torch
import torch.nn as nn
import torch.optim as optim

class PatchEmbedding(nn.Module):
    """Non-overlapping patch embedding, used here as a drop-in ``patch_embed`` for a ViT.

    A ``Conv2d`` whose kernel size equals its stride projects every
    ``patch_size x patch_size`` patch to an ``embed_dim`` vector; the resulting
    feature map is flattened into a token sequence.

    Args:
        patch_size: Side length of each patch (int), or an ``(h, w)`` pair.
            Used as both the conv kernel size and stride.
        in_chans: Number of input image channels.
        embed_dim: Embedding size produced for each patch.
        fix_size: Number of patches the caller expects per image, exposed as
            ``num_patches``. It is NOT computed from the input: the caller must
            feed images whose patch grid yields exactly this many patches
            (e.g. a ``(8, 8 * 100)`` image with ``patch_size=8`` gives 100).
        bias: Whether the projection conv has a bias term.
        norm_layer: Optional callable ``norm_layer(embed_dim) -> nn.Module``
            applied to the tokens. ``None`` (default) uses ``nn.Identity``.
    """

    def __init__(self, patch_size = 8, in_chans = 3, embed_dim = 768, fix_size = 196, bias = True,
                 norm_layer = None):
        super().__init__()
        # Stored as an (h, w) tuple. NOTE(review): timm's own PatchEmbed keeps a
        # 2-tuple ``patch_size`` and some timm ViT helpers read it — confirm for your version.
        if isinstance(patch_size, (tuple, list)):
            self.patch_size = tuple(patch_size)
        else:
            self.patch_size = (patch_size, patch_size)
        # kernel == stride -> non-overlapping patches; trailing rows/cols that do
        # not fill a whole patch are dropped by the strided conv.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()
        self.num_patches = fix_size

    def forward(self, x):
        """Embed a batched image tensor into a token sequence.

        Expects a 4-D input ``(B, in_chans, H, W)`` and returns
        ``(B, (H // p_h) * (W // p_w), embed_dim)``.
        """
        x = self.proj(x)                  # (B, embed_dim, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, H' * W', embed_dim)
        return self.norm(x)

In [7]:
import timm

# ViT-B/16 backbone with pretrained weights and a 10-way classifier.
# NOTE(review): a 10-class head cannot come from the pretrained (ImageNet)
# checkpoint, so it is presumably freshly initialized by timm — confirm.
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=10,
)

In [8]:
# Sanity check: the unmodified model maps one 224x224 RGB image to 10 logits.
# This is only a shape check, so run it without building an autograd graph.
fake_image_net_img = torch.rand((1, 3, 224, 224))
with torch.no_grad():
    out = model(fake_image_net_img)
print(out.size())

torch.Size([1, 10])


In [9]:
# One place for the patch geometry: an image of height PATCH_SIZE and width
# PATCH_SIZE * NUM_PATCHES is cut into a 1 x NUM_PATCHES grid of patches.
PATCH_SIZE = 8
NUM_PATCHES = 100
fake_qdt = torch.rand((1, 3, PATCH_SIZE, PATCH_SIZE * NUM_PATCHES))
# Replace the pretrained 16x16 patch projection with a freshly initialized
# PATCH_SIZE x PATCH_SIZE one (its conv weights are untrained).
model.patch_embed = PatchEmbedding(patch_size=PATCH_SIZE, fix_size=NUM_PATCHES)

with torch.no_grad():
    patch_embed_out = model.patch_embed(fake_qdt)
assert patch_embed_out.shape[1] == model.patch_embed.num_patches, "input does not yield NUM_PATCHES tokens"

print(patch_embed_out.size())
print(model.pos_embed.size())

torch.Size([1, 100, 768])
torch.Size([1, 197, 768])


In [10]:
# Re-create the positional embedding for the new token count (random init,
# std 0.02 — the pretrained positions are discarded, not interpolated).
# Sizes come from the model itself instead of hard-coded 100 / 768, and the new
# parameter is created on the old one's device/dtype so a GPU model keeps working.
# NOTE(review): timm ViTs built with ``no_embed_class=True`` keep prefix tokens
# out of pos_embed; the getattr fallback assumes False (as for vit_base_patch16_224).
num_pos_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else model.num_prefix_tokens
model.pos_embed = nn.Parameter(
    torch.randn(
        1,
        model.patch_embed.num_patches + num_pos_prefix_tokens,
        model.pos_embed.shape[-1],
        device=model.pos_embed.device,
        dtype=model.pos_embed.dtype,
    ) * .02
)

In [11]:
# Expect (1, patch count + prefix tokens, embed width) after the resize.
print(model.pos_embed.shape)

torch.Size([1, 101, 768])


In [14]:
# End-to-end check: the modified model accepts the 8x800 input and still emits
# 10 logits. Shape check only, so skip autograd; the bare last line is displayed.
with torch.no_grad():
    qdt_logits = model(fake_qdt)
qdt_logits.size()

torch.Size([1, 10])