In [5]:
from sunyata.pytorch.tiny_imagenet import TinyImageNet

In [59]:
from torch.utils.data import DataLoader
from torchvision import transforms

In [49]:
# Training-time preprocessing: resize, random crop back to 224, random flip,
# then convert to a tensor.
# NOTE(review): Resize(224) directly followed by RandomResizedCrop(224) looks
# redundant, since the crop already rescales its output to 224 — confirm intent.
train_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

In [54]:
# Training split of TinyImageNet with the augmentation pipeline applied per sample.
train_data = TinyImageNet(
    split='train',
    transform=train_transforms,
)

tiny-imagenet-200.zipalready downloaded and verified.


In [58]:
# Pull a single transformed sample and check the image tensor's shape
# (channels, height, width).
image, target = train_data[0]
image.shape

torch.Size([3, 224, 224])

In [60]:
# A tiny batch keeps the step-by-step forward pass below cheap to run.
# NOTE(review): shuffle=True with no seed set in the visible cells, so the
# batch drawn later differs between runs.
batch_size = 2
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [61]:
# Number of samples in the dataset vs. number of mini-batches in the loader.
len(train_data), len(train_loader)

(100000, 50000)

In [66]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [67]:
# helpers

def pair(t):
    """Normalise a size argument into a tuple.

    A scalar ``t`` is duplicated into ``(t, t)``. A tuple or list is returned
    as a tuple with its elements unchanged, so ``[224, 224]`` is treated the
    same as ``(224, 224)`` instead of being duplicated into a pair of lists.

    Args:
        t: A scalar (e.g. an int size) or a tuple/list of sizes.

    Returns:
        ``tuple(t)`` when ``t`` is a tuple or list, otherwise ``(t, t)``.
    """
    if isinstance(t, (tuple, list)):
        return tuple(t)
    return (t, t)

In [68]:
class PreNorm(nn.Module):
    """Layer-normalise the input, then hand it to a wrapped callable.

    Args:
        dim: Size of the last input dimension; passed to ``nn.LayerNorm``.
        fn: Callable (typically an ``nn.Module``) applied to the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``; extra keyword args go to ``fn`` untouched."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)

In [69]:
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the intermediate layer.
        dropout: Drop probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # Order matters: it fixes the sub-module indices inside `self.net`.
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``dim``."""
        return self.net(x)

In [82]:
# --- ViT hyperparameters used by the cells below ---
image_size = 224   # matches the 224 crop produced by train_transforms
patch_size = 16    # side length of each square patch
num_classes = 200  # output size of the classification head

dim = 1024      # token embedding width
depth = 6       # number of TransformerLayer blocks stacked
heads = 16      # passed to each TransformerLayer
mlp_dim = 2048  # passed to each TransformerLayer

pool = 'cls'   # 'cls' -> take token 0, 'mean' -> average over tokens
channels = 3   # image channels, used to compute patch_dim
dim_head = 64  # NOTE(review): not referenced by any cell visible here

# NOTE(review): `dropout` is rebound to an nn.Dropout module in a later cell,
# so this float is no longer reachable under that name afterwards.
dropout = 0.1
emb_dropout = 0.1  # drop probability applied to the embedded tokens

In [75]:
# Expand the scalar sizes into (height, width) pairs.
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

# Each image side must split into a whole number of patches.
assert (
    image_height % patch_height == 0
    and image_width % patch_width == 0
), 'Image dimensions must be divisible by the patch size.'

In [76]:
# Length of one flattened patch (all channels) and how many patches tile the image.
patch_dim = channels * patch_height * patch_width
num_patches = (image_height // patch_height) * (image_width // patch_width)

# Only the two pooling modes handled by the pooling cell are allowed.
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

In [78]:
# Patch embedding: cut each image into non-overlapping p1 x p2 patches,
# flatten every patch (with its channels) into a vector of length patch_dim,
# then project each vector linearly to the model width `dim`.
to_patch_embedding = nn.Sequential(
    # (b, c, H, W) -> (b, num_patches, patch_dim)
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
    # (b, num_patches, patch_dim) -> (b, num_patches, dim)
    nn.Linear(patch_dim, dim),
)

In [116]:
# One positional vector per token; the extra +1 slot is for the class token
# that gets concatenated in front of the patch tokens.
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# Single class token, repeated across the batch at use time.
cls_token = nn.Parameter(torch.randn(1, 1, dim))
# NOTE(review): this rebinds `dropout` from the 0.1 float set in the config
# cell to an nn.Dropout module — from here on `dropout` is callable, not a rate.
dropout = nn.Dropout(emb_dropout)

In [80]:
from sunyata.pytorch.transformer2 import TransformerLayer

In [87]:
# Stack `depth` TransformerLayer blocks and run them back to back.
# The dropout rate is a literal 0.1 because, in top-to-bottom order, the name
# `dropout` already refers to an nn.Dropout module by the time this cell runs.
transformer = nn.Sequential(
    *(TransformerLayer(dim, mlp_dim, heads, dropout=0.1) for _ in range(depth))
)

In [88]:
# Classification head: normalise the pooled feature, then map it to class scores.
mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

In [99]:
# Draw one mini-batch from the (shuffled) loader to trace through the model.
# This rebinds `image` and `target` from the single sample to a batch.
image, target = next(iter(train_loader))
image.shape, target.shape

(torch.Size([2, 3, 224, 224]), torch.Size([2]))

In [101]:
# Patchify and project the batch: (b, c, H, W) -> (b, num_patches, dim).
x = to_patch_embedding(image)
x.shape

torch.Size([2, 196, 1024])

In [104]:
# Sanity check of the derived patch geometry.
patch_dim, patch_height, patch_width, num_patches

(768, 16, 16, 196)

In [105]:
# b = batch size, n = number of patch tokens; the embedding width is ignored.
b, n, _ = x.shape

In [106]:
# Copy the single class token once per batch element: (1, 1, dim) -> (b, 1, dim).
cls_tokens = repeat(cls_token, '1 n d -> b n d', b = b)
cls_tokens.shape

torch.Size([2, 1, 1024])

In [107]:
# Prepend the class token along the token axis, giving n + 1 tokens per image.
# NOTE(review): rebinds `x`, so re-running this cell alone prepends another token.
x = torch.cat((cls_tokens, x), dim=1)

In [111]:
# Token and positional-embedding shapes must agree on the last two axes.
x.shape, pos_embedding.shape

(torch.Size([2, 197, 1024]), torch.Size([1, 197, 1024]))

In [114]:
# Add the positional embedding; its leading axis of size 1 broadcasts over the batch.
# NOTE(review): this updates `x` in place, so re-running the cell adds the
# embedding a second time — re-run from the patch-embedding cell instead.
x += pos_embedding


In [117]:
# `dropout` here is the nn.Dropout(emb_dropout) module, not the 0.1 float
# from the config cell.
x = dropout(x)

In [118]:
# Pass the token sequence through the stacked transformer layers.
x = transformer(x)

In [119]:
# The encoder is expected to leave the (batch, tokens, dim) shape unchanged.
x.shape

torch.Size([2, 197, 1024])

In [120]:
# Collapse the token axis: average every token, or keep only the class token (index 0).
if pool == 'mean':
    x = x.mean(dim=1)
else:
    x = x[:, 0]

In [121]:
# After pooling there is one feature vector per image.
x.shape

torch.Size([2, 1024])

In [122]:
# Map each pooled feature vector to num_classes scores via LayerNorm + Linear.
output = mlp_head(x)