In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers.weight_init import lecun_normal_, trunc_normal_


# from timm and modified for Conv1d implementation
def init_weights_vit_jax(module: nn.Module, name: str = ''):
    """ ViT weight initialization, matching JAX (Flax) impl """
    if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()

class MultiHeadSelfAttention(nn.Module):
    def __init__(self,  embed_dim, output_dim, num_heads):
        """
        dim of each head = embed_dim//num_heads (i.e. 64 as default)
        """
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.dim_each_heads = embed_dim // num_heads

        self.to_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=1, bias=True)
        self.to_k = nn.Conv1d(embed_dim, embed_dim, kernel_size=1, bias=True)
        self.to_v = nn.Conv1d(embed_dim, embed_dim, kernel_size=1, bias=True)

        self.dk = float( torch.rsqrt(torch.FloatTensor([self.dim_each_heads])).item())   # dk: dim of q and k

        self.out_proj = nn.Conv1d(embed_dim, output_dim, kernel_size=1, bias=True)
        
    def forward(self, x):
        """
        x: batch size, embedding dimension, sequence length   (B, N, L)
        """
        assert x.ndim == 3
        
        B = x.size(0)       # batch size
        Lq = x.shape[-1]    # q sequence length
        
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        
        q_multihead = q.view(B, self.num_heads, self.dim_each_heads, -1)
        k_multihead = k.view(B, self.num_heads, self.dim_each_heads, -1)
        v_multihead = v.view(B, self.num_heads, self.dim_each_heads, -1)
        
        scaled_qk = torch.einsum('bhnl, bhnm->bhlm', q_multihead, k_multihead) * self.dk # (B, Lq, Lk)
        attention_weight = torch.softmax(scaled_qk, dim=2)                               # (B, Lq, Lk)
        
        transformed_cocatenated_heads = torch.einsum('bhlm, bhkm->bhkl', attention_weight, v_multihead).reshape(B, -1, Lq) # concatenation
        
        return self.out_proj( transformed_cocatenated_heads )       
    
    
class TransformerEncoder(nn.Module):
    def __init__(self,  embed_dim,  num_heads, ffn_dim):
        super(TransformerEncoder, self).__init__()
        self.mhsa = MultiHeadSelfAttention(embed_dim, embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=True, eps=1e-6)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=True, eps=1e-6)
        
        self.ffn = nn.Sequential( *[nn.Conv1d(embed_dim, ffn_dim, kernel_size=1, bias=True),
                                    nn.ReLU(inplace=True),
                                    nn.Conv1d(ffn_dim, embed_dim, kernel_size=1, bias=True)])
        
    def forward(self, x):
        """
        x: batch size, embedding dimension, sequence length   (B, N, L)
        """
        assert x.ndim == 3
        skip = x
        x = self.mhsa(x)
        
        x = skip + x
        x = self.norm1( x.permute(0, 2, 1) ).permute(0, 2, 1)
        
        skip = x
        x = self.ffn( x )
        
        x = skip + x
        x = self.norm2( x.permute(0, 2, 1) ).permute(0, 2, 1)
        
        return x
    
class VisionTransformerEncoder(nn.Module):
    def __init__(self,  embed_dim,  num_heads, mlp_dim):
        super(VisionTransformerEncoder, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim, elementwise_affine=True, eps=1e-6)
        self.mhsa = MultiHeadSelfAttention(embed_dim, embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim, elementwise_affine=True, eps=1e-6)
        
        self.mlp = nn.Sequential(*[nn.Conv1d(embed_dim, mlp_dim, kernel_size=1, bias=True),
                                 nn.ReLU(inplace=True),
                                 nn.Conv1d(mlp_dim, embed_dim, kernel_size=1, bias=True)])
        
    def forward(self, x):
        """
        x: batch size, embedding dimension, sequence length   (B, N, L)
        """
        assert x.ndim == 3
        skip_1 = x
        x = self.norm1( x.permute(0, 2, 1) ).permute(0, 2, 1)
        x = self.mhsa(x)
        x = skip_1 + x

        skip_2 = x
        x = self.norm2( x.permute(0, 2, 1) ).permute(0, 2, 1)
        x = self.mlp(x)
        x = skip_2 + x        
        
        return x
    
class ImageEmbedding(nn.Module):
    def __init__(self,  image_dim=3, embed_dim=512,  patch_size=(16, 16)):
        super(ImageEmbedding, self).__init__()
        self.proj = nn.Conv2d(image_dim, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0, bias=True)
        self.embed_dim= embed_dim       
        
    def forward(self, x):
        """
        x: batch size, image_dim, height, width   (B, C, H, W)
        """
        B = x.size(0)
        
        return self.proj( x ).view(B, self.embed_dim, -1)   
    
class VisionTransformer(nn.Module):
    def __init__(self,  
                 image_size=(256, 256), 
                 patch_size=(16, 16), 
                 image_dim=3, 
                 embed_dim=384, 
                 mlp_dim=1536, 
                 num_heads=8,
                 num_layers=8,
                 img_embed=ImageEmbedding,
                 global_pool='avg'):
        """
        No dropout implementation for simplicity
        """
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.image_dim  = image_dim
        self.embed_dim  = embed_dim
        self.mlp_dim    = mlp_dim
        self.num_heads  = num_heads
        self.num_layers = num_layers
        embed_len = (image_size[0]//patch_size[0]) * (image_size[1]//patch_size[1])
        self.pos_embed = nn.Parameter(torch.randn(1,  embed_dim, embed_len+1) * .02)
        self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, 1))
        
        trunc_normal_(self.pos_embed, std=.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        
        if img_embed is not None:
            self.img_embed = img_embed(image_dim, embed_dim, patch_size)
        else:
            self.img_embed = nn.Identity()
        
        self.transformer = nn.ModuleList([VisionTransformerEncoder(embed_dim, num_heads, mlp_dim) for _ in range(num_layers) ])   
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=True, eps=1e-6)
        
        self.global_pool = global_pool
        
        # weight init
        for name, module in self.named_modules():
            init_weights_vit_jax(module, name)
        
    def forward(self, x):
        x = self.img_embed(x)
        x = torch.cat( [self.cls_token.expand(x.size(0), -1, -1), x], dim=2)  # concatenate cls token
        x = x + self.pos_embed                                                # add positional embedding
        for layer in self.transformer: 
            x = layer(x)
        x = self.norm( x.permute(0, 2, 1) ).permute(0, 2, 1)
        if self.global_pool == 'avg':
            cls = x.mean(dim=2)
        else:
            cls = x[:, :, 0]   # token  in timm
        
        return  cls
        

In [23]:
ViT_S_config = {'num_layers':12, 'embed_dim': 384, 'mlp_dim': 1536, 'num_heads': 6} 

tfe = TransformerEncoder(512, 8, 3072)
vtfe = VisionTransformerEncoder(512, 8, 3072)
ie = ImageEmbedding(image_dim=3, embed_dim=512,  patch_size=(16, 16))
vit = VisionTransformer(**ViT_S_config, global_pool='token')
x = torch.rand(4, 3, 256, 256)

with torch.no_grad():
    print( tfe( ie(x) ).size() )
    print( vtfe( ie(x) ).size() )
    print( vit( x ).size() )

torch.Size([4, 512, 256])
torch.Size([4, 512, 256])
torch.Size([4, 384])


In [24]:
# parameter counts
ViT_S_config = {'num_layers':12, 'embed_dim': 384, 'mlp_dim': 1536, 'num_heads': 6} 
ViT_B_config = {'num_layers':12, 'embed_dim': 768, 'mlp_dim': 3072, 'num_heads':12} 
HIPT256_config = {'num_layers':8, 'embed_dim': 384, 'mlp_dim': 1536,'num_heads':6} 
HIPT4096_config = {'num_layers':4, 'embed_dim': 192, 'mlp_dim': 768, 'num_heads':6, 'img_embed': None} 

num_params = 0
for params in VisionTransformer(**ViT_S_config, global_pool='token').parameters():
    num_params += torch.prod( torch.tensor(params.shape) )
print("Number of parameters: {:0.2f}M".format((num_params / 10**6).item()))

num_params = 0
for params in VisionTransformer(**ViT_B_config, global_pool='token').parameters():
    num_params += torch.prod( torch.tensor(params.shape) )
print("Number of parameters: {:0.2f}M".format((num_params / 10**6).item()))

num_params = 0
for params in VisionTransformer(**HIPT256_config, global_pool='token').parameters():
    num_params += torch.prod( torch.tensor(params.shape) )
print("Number of parameters: {:0.2f}M".format((num_params / 10**6).item()))

num_params = 0
for params in VisionTransformer(**HIPT4096_config, global_pool='token').parameters():
    num_params += torch.prod( torch.tensor(params.shape) )
print("Number of parameters: {:0.2f}M".format((num_params / 10**6).item()))


Number of parameters: 21.69M
Number of parameters: 85.84M
Number of parameters: 14.59M
Number of parameters: 1.83M


In [6]:
# Quick DINO implementation
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from copy import deepcopy
import timm

HIPT256_config = {'num_layers':8, 'embed_dim': 384, 'mlp_dim': 1536,'num_heads':6} 
HIPT256_config_timm = {'depth':8, 'embed_dim': 384, 'mlp_ratio': 4.0,'num_heads':6, 'img_size': 256, 'weight_init': 'jax'} 

m = 0.9    # centering momentum
b = 0.996    # ema momentum

augment = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.01),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomSolarize(0.2, p=0.1),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
])

# define device and models
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gs = VisionTransformer(**HIPT256_config, global_pool='token') # student network
#gs = timm.models.vision_transformer.VisionTransformer(**HIPT256_config_timm, global_pool='token')  # comment out if use timm version
#gs.head = nn.Identity()                                                                   # comment out if use timm version
gt = deepcopy(gs)                                            # teacher network
gs.to(device)
gt.to(device)

for param_s in gs.parameters():
    param_s.requires_grad = True

for param_t in gt.parameters():
    param_t.requires_grad = False
    param_t.grad = None

# define loss function
def H(t, s, C, tps=0.04, tpt=0.04, eps=1e-6, literal=False):
    t = t.detach()
    t = torch.softmax((C - t)/tpt, dim=1) 
    if literal:  # unstable and outputs different value than F.log_softmax
        s = torch.softmax(s/tps, dim=1) 
        return - (t*torch.log(s+eps)).sum(dim=1).mean()
    else:
        return - (t*F.log_softmax((s)/tps, dim=1)).sum(dim=1).mean()

# optimizer
optimizer = torch.optim.AdamW(gs.parameters(), lr=0.0005)

# initialize center (C)
C = torch.zeros(gs.embed_dim).to(device)

# Demo: single loop training
while True:
    x = TF.to_pil_image(torch.rand(3, 256, 256).clamp_(0.0, 1.0)) # start from PIL Image
    x1 = augment(x).unsqueeze(0).to(device)
    x2 = augment(x).unsqueeze(0).to(device)

    s1, s2 = gs(x1), gs(x2)
    with torch.no_grad():
        t1, t2 = gt(x1), gt(x2)

    loss = H(t1, s2, C)*0.5 + H(t2, s1, C)*0.5
    loss.backward()
    nn.utils.clip_grad_norm_(gs.parameters(), max_norm=3.0)
    optimizer.step() # update gs

    # ema: exponential moving average from https://github.com/facebookresearch/moco/blob/main/moco/builder.py
    for param_t, param_s in zip(gt.parameters(), gs.parameters()):
        param_t.data = param_t.data * b + param_s.data * (1. - b)
    # update center
    C = m*C + (1-m)*torch.cat([t1, t2], dim=0).mean(dim=0)
    
    print(loss.item())


141.53729248046875
153.80113220214844
62.70068359375
60.295562744140625
58.73565673828125
71.57267761230469
41.55241012573242
67.35658264160156
53.23797607421875
76.87071228027344
67.72838592529297
24.76553726196289
121.81295776367188
113.68099975585938
101.64875793457031
83.54463195800781
56.784637451171875
50.604331970214844
45.24760055541992
71.9202880859375
0.11235539615154266
38.60708999633789
15.8598051071167
23.179563522338867
83.88377380371094
31.170677185058594
53.966033935546875
51.64238357543945
38.58306121826172
68.44325256347656
70.31535339355469
63.14631271362305
53.127044677734375
33.256980895996094
66.13126373291016
51.48120880126953
44.306785583496094
59.341514587402344
44.36029815673828
54.79623794555664
57.2401237487793
35.839439392089844
58.978271484375
51.214149475097656
28.859167098999023
47.431640625
47.35696029663086
53.12445068359375
43.50859069824219
30.73584747314453
46.30270767211914
63.622528076171875
48.988704681396484
43.19705581665039
41.82673263549805
3

KeyboardInterrupt: 

In [30]:
# end-to-end training,  is it possible ?
vit256 = VisionTransformer(**HIPT256_config, global_pool='token')
vit4096 = VisionTransformer(**HIPT4096_config, global_pool='token')

mlp256_4096 = nn.Sequential(*[ nn.Conv1d(HIPT256_config['embed_dim'], HIPT4096_config['embed_dim'], kernel_size=1, bias=True),
                               nn.ReLU(inplace=True),
                               nn.Conv1d(HIPT4096_config['embed_dim'], HIPT4096_config['embed_dim'], kernel_size=1, bias=True)])

x = torch.rand(1, 3, 4096, 4096)  # 4K image
B = x.size(0)

tiled_x = x.unfold(2, 256, 256).unfold(3, 256, 256).permute(0, 2, 3, 1, 4, 5).reshape(-1, 3, 256, 256)

with torch.no_grad():
    print ( vit4096( mlp256_4096( vit256(tiled_x).reshape(B, -1, 256) ) ).size() )


torch.Size([1, 192])
