In [11]:
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
        #return self.norm(self.fn(x, **kwargs))

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

        #nn.init.xavier_uniform_(self.net[0].weight, gain=2 ** .5)
        #nn.init.normal_(self.net[0].bias, std=1e-6)
        #nn.init.xavier_uniform_(self.net[3].weight, gain=2 ** .5)
        #nn.init.normal_(self.net[3].bias, std=1e-6)
        #nn.init.kaiming_uniform_(self.net[0].weight, a=5**.5)
        #nn.init.kaiming_uniform_(self.net[0].weight, a=5**.5)
        bound1 = 1 / (dim ** .5)
        bound2 = 1 / (hidden_dim ** .5)
        nn.init.uniform_(self.net[0].weight, -bound1, bound1)
        nn.init.uniform_(self.net[0].bias, -bound1, bound1)
        nn.init.uniform_(self.net[3].weight, -bound2, bound2)
        nn.init.uniform_(self.net[0].bias, -bound2, bound2)

        #nn.init.xavier_normal_(self.net[0].weight, gain=2**-.5)
        #nn.init.normal_(self.net[0].bias, std=.1)
        #nn.init.xavier_normal_(self.net[3].weight, gain=2**-.5)
        #nn.init.normal_(self.net[3].bias, std=.1)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)
               
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        #nn.init.xavier_uniform_(self.to_qkv.weight)
        bound = 1 / (dim ** .5)
        nn.init.uniform_(self.to_qkv.weight, -bound, bound)
        #nn.init.xavier_normal_(self.to_qkv.weight, gain=2**-.5)
        #nn.init.normal_(self.to_qkv.weight, std=1)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
        #nn.init.xavier_uniform_(self.to_out[0].weight)
        #nn.init.normal_(self.to_out[0].bias, std=1e-6)
        bound = 1 / (inner_dim ** .5)
        nn.init.uniform_(self.to_out[0].weight, -bound, bound)
        nn.init.uniform_(self.to_out[0].bias, -bound, bound)

        #nn.init.zeros_(self.to_out[0].weight)
        #nn.init.zeros_(self.to_out[0].bias)

        #nn.init.xavier_normal_(self.to_out[0].weight, gain=1)
        #nn.init.normal_(self.to_out[0].bias, std=.05)


    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        qkv = qkv.chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3,
                 dim_head = 64, dropout = 0., emb_dropout = 0., use_cls_token=True, 
                 sessions="ignore", subjects="ignore", training_config="ignore", pretrained="ignore", chunk_idx="ignore", chunk_i="ignore"):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
        
        self.patch_layer_norm = nn.LayerNorm(dim)
        
        bound = 1 / (patch_dim ** .5)
        nn.init.uniform_(self.to_patch_embedding[1].weight, -bound, bound)
        nn.init.uniform_(self.to_patch_embedding[1].bias, -bound, bound)
        #nn.init.xavier_uniform_(self.to_patch_embedding[1].weight)
        #nn.init.normal_(self.to_patch_embedding[1].bias, std=1e-6)
        #nn.init.xavier_normal_(self.to_patch_embedding[1].weight, gain=2**-.5)
        #nn.init.normal_(self.to_patch_embedding[1].bias, std=.1)

        self.use_cls_token = use_cls_token
        if self.use_cls_token:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches + 1, dim))
        else:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches, dim))
        nn.init.normal_(self.pos_embedding, mean=0, std=.02)
        #nn.init.zeros_(self.pos_embedding)

        self.cls_token = nn.Parameter(torch.empty(1, 1, dim))
        nn.init.zeros_(self.cls_token)
        
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
        #nn.init.zeros_(self.mlp_head[1].weight)
        #nn.init.zeros_(self.mlp_head[1].bias)
        bound = 1 / (dim ** .5)
        nn.init.uniform_(self.mlp_head[1].weight, -bound, bound)
        nn.init.uniform_(self.mlp_head[1].bias, -bound, bound)
        #nn.init.xavier_normal_(self.mlp_head[1].weight, gain=1)
        #nn.init.normal_(self.mlp_head[1].weight, std=1)
        #nn.init.normal_(self.mlp_head[1].bias, std=1)

    def forward(self, img):
        
        x = self.to_patch_embedding(img)
        x = self.patch_layer_norm(x)
        b, n, _ = x.shape
        
        if self.use_cls_token:
            cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
            x = torch.cat((cls_tokens, x), dim=1)
            x += self.pos_embedding[:, :(n + 1)]
        else :
            x += self.pos_embedding
        
        x = self.dropout(x)
        
        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        
        x = self.to_latent(x)
        x = self.mlp_head(x)
        return x

In [12]:
vit = ViT(**{
    "image_size": (1, 300),
    "patch_size": (1, 10),
    "channels": 14,
    "num_classes": 8,

    "dim": 64,
    "mlp_dim": 128,
    "dim_head": 32,
    "heads": 8,
    "depth": 1,

    "dropout": .2,
    "emb_dropout": 0,

    "pool": "cls",
    "use_cls_token": True,})

In [13]:
vit(torch.randn(64, 14, 1, 300))

tensor([[-1.3563e-01, -1.1381e+00,  1.0509e-02, -3.2930e-02, -3.7114e-01,
         -6.2313e-01, -1.3273e+00, -7.1051e-01],
        [ 1.0506e-01,  2.9779e-01,  4.1018e-01, -7.8107e-01, -1.6058e-01,
         -8.8037e-01, -7.0013e-01, -3.7741e-01],
        [-8.5659e-01,  1.2082e-01,  1.4838e+00, -5.1904e-01, -1.1494e+00,
         -6.2442e-01, -1.3055e-01, -1.1003e-03],
        [-6.8632e-01,  2.4830e-02,  7.6173e-01, -1.5873e-01, -8.0623e-01,
         -5.0671e-01,  3.1818e-01, -1.8746e-01],
        [-6.8394e-01, -5.6709e-01,  1.2917e+00,  7.0438e-02, -9.2544e-01,
         -7.3106e-01, -6.7923e-01, -5.5209e-01],
        [-4.3025e-01,  1.3769e-01,  8.6306e-01, -5.4819e-01, -9.7765e-01,
         -4.2731e-01,  5.4673e-01, -5.6188e-02],
        [-6.4496e-01, -7.4136e-01, -3.2990e-01,  1.5065e-01,  2.2774e-01,
         -8.6682e-01, -2.4947e-01, -4.9481e-01],
        [ 1.0405e-01,  1.7261e-02, -4.9970e-01,  7.2296e-01, -4.0426e-01,
         -2.3404e-01,  3.1308e-01, -1.0378e+00],
        [-8.9360