In [None]:
!pip install torch
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [None]:
import numpy as np
import torch
from torch import nn, einsum
from einops import rearrange, repeat

In [None]:
import torch.nn.functional as f

In [None]:
class PatchMerging_Conv(nn.Module):
    """Convolution-based patch merging that returns a channels-last feature map.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``downscaling_factor`` (no padding) maps each non-overlapping patch of the
    input to ``out_channels`` features.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        # kernel == stride, padding 0: the patches tile the input without overlap
        conv_kwargs = dict(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.patch_merge_conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)

    def forward(self, x):
        """Convolve ``x`` (B, C_in, H, W) and return the result as (B, H', W', C_out)."""
        merged = self.patch_merge_conv(x)  # (B, C_out, H', W')
        return merged.permute(0, 2, 3, 1)  # channels moved to the last axis

In [None]:
class PatchMerging(nn.Module):
    """Merge non-overlapping ``df x df`` patches and project them linearly.

    ``nn.Unfold`` (kernel == stride == ``downscaling_factor``, no padding)
    flattens every patch into ``in_channels * df**2`` values; a linear layer
    then maps those values to ``out_channels`` features. The result is
    channels-last.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(
            kernel_size=downscaling_factor,
            stride=downscaling_factor,
            padding=0,
        )
        flattened_patch = in_channels * downscaling_factor ** 2
        self.linear = nn.Linear(flattened_patch, out_channels)

    def forward(self, x):
        """Map ``x`` (b, c, h, w) to (b, h // df, w // df, out_channels)."""
        batch, _, height, width = x.shape
        factor = self.downscaling_factor
        grid_h = height // factor
        grid_w = width // factor

        patches = self.patch_merge(x)  # (b, c*df*df, grid_h*grid_w)
        # Restore the 2-D patch grid, then move the patch features to the last axis.
        patches = patches.view(batch, -1, grid_h, grid_w).permute(0, 2, 3, 1)
        return self.linear(patches)  # (b, grid_h, grid_w, out_channels)

In [None]:
# Sanity check of PatchMerging on two inputs:
#   - an image-sized tensor (1,3,224,224) merged with factor 4
#   - a feature map (1,96,56,56) merged with factor 2
dmx = torch.randn(1,3,224,224)
pm = PatchMerging(3, 96, 4)
dmx = pm(dmx)
dmx2 = torch.randn(1,96,56,56)
pm = PatchMerging(96, 192, 2) # note: rebinds `pm`; the first instance is no longer referenced
dmx2 = pm(dmx2)
# Both results are channels-last: (b, new_h, new_w, out_channels)
dmx.shape, dmx2.shape


(torch.Size([1, 56, 56, 96]), torch.Size([1, 28, 28, 192]))

In [None]:
class Residual(nn.Module):
    """Skip-connection wrapper: ``forward`` returns ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn  # wrapped callable; its output is added to the input

    def forward(self, x, **kwargs):
        """Apply the wrapped callable (keyword arguments are forwarded) and add ``x``."""
        branch = self.fn(x, **kwargs)
        return branch + x

In [None]:
class LN(nn.Module):
    """Pre-norm wrapper: LayerNorm over the last dimension, then the wrapped callable."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # dim = C, the size of the last axis
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalize ``x`` and pass it (with any keyword arguments) to ``fn``."""
        # Pre-norm ordering, as used for V1. A post-norm variant (V2) would
        # instead normalize the output of ``fn``.
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    dim        = C = 96 | 192 | 384 | 768
    hidden_dim = C * 4 in the blocks built in this notebook
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        """Apply the network along the last axis; the output has the same shape as ``x``."""
        return self.net(x)

In [None]:
# try ln
# Demo: nn.LayerNorm(C) normalizes each length-C vector along the last axis independently.
torch.manual_seed(10) # fixed seed so the printed values are reproducible
B,H,W,C = 1, 2, 2, 3
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the session.
input = torch.randn(B,H,W,C) * 10
# Overwrite one position with known values (1, 2, 3) so its normalized result is easy to verify by hand
input[0,0,0,0] = 1
input[0,0,0,1] = 2
input[0,0,0,2] = 3

print(input)
print(nn.LayerNorm(C)(input)) # the (1, 2, 3) vector becomes approximately (-1.2247, 0, 1.2247)

tensor([[[[  1.0000,   2.0000,   3.0000],
          [-12.2769,   9.1983,  -3.4847]],

         [[ -8.6918,  -9.5817, -11.9205],
          [ 19.0500,  -9.3733,  -8.4647]]]])
tensor([[[[-1.2247,  0.0000,  1.2247],
          [-1.1445,  1.2917, -0.1471]],

         [[ 1.0083,  0.3547, -1.3629],
          [ 1.4137, -0.7413, -0.6724]]]], grad_fn=<NativeLayerNormBackward0>)


In [None]:
def create_mask(window_size, displacement, upper_lower, left_right):
    """Build an additive attention mask for one window of ``window_size**2`` tokens.

    The result is a ``(window_size**2, window_size**2)`` tensor of zeros in
    which blocked query/key pairs are set to ``-inf`` (``exp(-inf)`` is 0, so
    those pairs get zero weight after softmax).

    Args:
        window_size: side length of the square window.
        displacement: number of token rows/columns that belong to the shifted part.
        upper_lower: if true, block attention between the last
            ``displacement`` token rows and the remaining rows.
        left_right: if true, block attention between the last
            ``displacement`` token columns and the remaining columns.
    """
    num_tokens = window_size ** 2
    neg_inf = float('-inf')
    mask = torch.zeros(num_tokens, num_tokens)

    if upper_lower:
        # Tokens are flattened row-major, so the last `displacement` rows of the
        # window are the last `displacement * window_size` tokens.
        boundary = displacement * window_size
        mask[-boundary:, :-boundary] = neg_inf  # lower-left section
        mask[:-boundary, -boundary:] = neg_inf  # upper-right section

    if left_right:
        # View both token axes as (row, col); writes through the view update `mask`.
        grid = mask.view(window_size, window_size, window_size, window_size)
        grid[:, -displacement:, :, :-displacement] = neg_inf
        grid[:, :-displacement, :, -displacement:] = neg_inf

    return mask

In [None]:
# Left/right mask for a 3x3 window with displacement 1 (upper_lower=False, left_right=True)
create_mask(3, 1, False, True)

tensor([[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.]])

In [None]:
# Upper/lower mask for a 3x3 window with displacement 1 (upper_lower=True, left_right=False)
create_mask(3, 1, True, False)

tensor([[0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.]])

In [None]:
# Demo of torch.roll on a 9x9 grid holding the values 1..81
x = torch.linspace(1,81,81).view(9,9)
print(x)
# Roll by -1 along both axes: every element moves one step up and one step left,
# and the first row / first column wrap around to the end.
y = torch.roll(x,shifts=(-1,-1), dims=(0,1))
print(y)

# shifts = (-1,-1) : shift the window down and right
# shifts = (1,1) : shift the window up and left (back)

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24., 25., 26., 27.],
        [28., 29., 30., 31., 32., 33., 34., 35., 36.],
        [37., 38., 39., 40., 41., 42., 43., 44., 45.],
        [46., 47., 48., 49., 50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59., 60., 61., 62., 63.],
        [64., 65., 66., 67., 68., 69., 70., 71., 72.],
        [73., 74., 75., 76., 77., 78., 79., 80., 81.]])
tensor([[11., 12., 13., 14., 15., 16., 17., 18., 10.],
        [20., 21., 22., 23., 24., 25., 26., 27., 19.],
        [29., 30., 31., 32., 33., 34., 35., 36., 28.],
        [38., 39., 40., 41., 42., 43., 44., 45., 37.],
        [47., 48., 49., 50., 51., 52., 53., 54., 46.],
        [56., 57., 58., 59., 60., 61., 62., 63., 55.],
        [65., 66., 67., 68., 69., 70., 71., 72., 64.],
        [74., 75., 76., 77., 78., 79., 80., 81., 73.],
        [ 2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,  1.]])


In [None]:
class CyclicShift(nn.Module):
    """Cyclically roll a tensor by the same displacement along dims 1 and 2."""

    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement  # signed number of positions to roll by

    def forward(self, x):
        """Return ``x`` rolled by ``displacement`` along dims 1 and 2 (H and W of a (b, H, W, C) input)."""
        step = self.displacement
        return torch.roll(x, shifts=(step, step), dims=(1, 2))


In [None]:
def get_relative_distances(window_size):
    """Pairwise (row, column) offsets between all positions of a square window.

    Positions are enumerated row-major. The returned integer tensor has shape
    ``(window_size**2, window_size**2, 2)`` and ``result[i, j]`` equals
    ``position[j] - position[i]``, so every entry lies in
    ``[-(window_size - 1), window_size - 1]``.
    """
    coords = [[row, col] for row in range(window_size) for col in range(window_size)]
    positions = torch.tensor(np.array(coords))  # (window_size**2, 2); e.g. (49, 2) for a 7x7 window
    targets = positions[None, :, :]  # broadcast along the first axis
    sources = positions[:, None, :]  # broadcast along the second axis
    return targets - sources


In [None]:
class MHWSA(nn.Module):
    """Multi-head window self-attention, optionally on cyclically shifted windows.

    The input is a channels-last feature map ``(b, H, W, dim)``. It is split
    into non-overlapping ``window_size x window_size`` windows and attention is
    computed independently inside every window. When ``shifted`` is true the
    map is rolled by ``-(window_size // 2)`` before windowing, attention between
    tokens that were wrapped around is masked out, and the roll is undone on
    the output.

    Args:
        dim: number of input/output channels (96 | 192 | 384 | 768).
        heads: number of attention heads (3 | 6 | 12 | 24).
        head_dim: channels per head (32).
        shifted: whether to use shifted windows.
        window_size: side length of a window (7).
        relative_pos_embedding: if true, use a learned
            ``(2*window_size-1, 2*window_size-1)`` table indexed by relative
            offsets; otherwise a learned
            ``(window_size**2, window_size**2)`` absolute bias.
    """

    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads  # 96 | 192 | 384 | 768 (= C)

        self.heads = heads
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        # A single linear layer produces Q, K and V (hence inner_dim * 3).
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if self.relative_pos_embedding:
            # Offsets lie in [-(ws-1), ws-1]; adding ws-1 turns them into valid
            # indices into the (2*ws-1, 2*ws-1) table.
            # Registered as a buffer so the index tensor follows the module on
            # .to(device) / .cuda(); persistent=False keeps it out of the
            # state_dict, so the set of checkpoint keys is unchanged.
            self.register_buffer(
                'relative_indices',
                get_relative_distances(window_size) + window_size - 1,
                persistent=False,
            )
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))  # relative: 13x13
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))  # absolute: 49x49

        if self.shifted:
            displacement = window_size // 2
            # The masks are kept as frozen Parameters (not buffers) so that
            # parameters() and the state_dict stay exactly as before.
            self.upper_lower_mask = nn.Parameter(
                create_mask(window_size=window_size, displacement=displacement,
                            upper_lower=True, left_right=False),
                requires_grad=False)
            self.left_right_mask = nn.Parameter(
                create_mask(window_size=window_size, displacement=displacement,
                            upper_lower=False, left_right=True),
                requires_grad=False)
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        """Apply window attention to ``x`` of shape (b, H, W, dim); the output has the same shape.

        H and W are split into ``window_size`` chunks by ``rearrange``, so they
        need to be multiples of ``window_size``.
        """
        # x - (b,56,56,96) | (b,28,28,192) | (b,14,14,384) | (b,7,7,768)
        if self.shifted:
            x = self.cyclic_shift(x)

        _, n_h, n_w, _ = x.shape
        h = self.heads
        nw_h = n_h // self.window_size  # number of windows along the height
        nw_w = n_w // self.window_size  # number of windows along the width

        qkv = self.to_qkv(x).chunk(3, dim=-1)

        # d = head_dim, h = heads, w_h / w_w = window height / width
        q, k, v = map(
            lambda t: rearrange(
                t,
                'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                h=h,
                w_h=self.window_size,
                w_w=self.window_size,
            ),
            qkv,
        )
        # q,k,v - (batch, heads, windows, tokens per window, head_dim)
        # q,k,v - (b,3,64,49,32) | (b,6,16,49,32) | (b,12,4,49,32) | (b,24,1,49,32)

        # Scaled dot-product scores inside every window: (b, h, w, t, t)
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.relative_pos_embedding:
            # relative_indices - (t,t,2); look up one bias per (query, key) pair
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding  # (b,h,w,t,t) + (t,t)

        if self.shifted:
            # Windows are flattened row-major: the last nw_w entries form the
            # bottom row of windows, and every nw_w-th entry starting at
            # nw_w - 1 is in the right-most column of windows.
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        # Weighted sum of the values: (b, h, w, t, d)
        attn = einsum('b h w i j, b h w j d -> b h w i d', attn, v)

        # Merge heads and windows back into a (b, n_h, n_w, inner_dim) map
        out = rearrange(attn, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)

        out = self.to_out(out)  # project to (b, n_h, n_w, dim)

        # Undo the cyclic shift so tokens return to their original positions
        if self.shifted:
            out = self.cyclic_back_shift(out)

        return out

            


In [None]:
class SwinBlock(nn.Module):
    """One Swin block: pre-norm window attention then a pre-norm MLP, each with a residual connection.

    dim = C, mlp_dim = C * 4 in the stages built in this notebook.
    """

    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        attention = MHWSA(
            dim=dim,
            heads=heads,
            head_dim=head_dim,
            shifted=shifted,
            window_size=window_size,
            relative_pos_embedding=relative_pos_embedding,
        )
        self.attention_block = Residual(LN(dim, attention))

        feed_forward = MLP(dim=dim, hidden_dim=mlp_dim)
        self.mlp_block = Residual(LN(dim, feed_forward))

    def forward(self, x):
        """Run attention then MLP; input and output are both (b, H, W, dim)."""
        # x - (b,56,56,96) | (b,28,28,192) | (b,14,14,384) | (b,7,7,768)
        attended = self.attention_block(x)
        return self.mlp_block(attended)

In [None]:
class Stage(nn.Module):
    """One hierarchical stage: patch merging followed by (regular, shifted) Swin block pairs.

    ``layers`` is the total number of Swin blocks and must be even, because
    blocks are always created as a regular/shifted pair.
    b = batch size, df = downscaling_factor.
    """

    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        # Unfold + Linear patch merging; PatchMerging_Conv takes the same
        # constructor arguments and is the convolutional alternative.
        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        def build_block(shifted):
            return SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim,
                             mlp_dim=hidden_dimension * 4, shifted=shifted, window_size=window_size,
                             relative_pos_embedding=relative_pos_embedding)

        block_pairs = [
            nn.ModuleList([build_block(False), build_block(True)])
            for _ in range(layers // 2)
        ]
        self.layers = nn.ModuleList(block_pairs)

    def forward(self, x):
        """Map a channels-first input to a channels-first output at reduced resolution."""
        # x - (b,3,224,224) | (b,96,56,56) | (b,192,28,28) | (b,384,14,14)
        x = self.patch_partition(x)  # merges df x df neighbouring patches; result is channels-last
        # x - (b,56,56,96) | (b,28,28,192) | (b,14,14,384) | (b,7,7,768)
        for regular_block, shifted_block in self.layers:
            x = shifted_block(regular_block(x))
        # back to channels-first: (b,96,56,56) | (b,192,28,28) | (b,384,14,14) | (b,768,7,7)
        return x.permute(0, 3, 1, 2)

In [None]:
class SwinTransformer(nn.Module):
    """Four-stage Swin backbone that returns the feature map of every stage.

    All arguments are keyword-only. For Swin-T: hidden_dim=96,
    layers=(2, 2, 6, 2), heads=(3, 6, 12, 24).

    Stage ``i`` outputs ``hidden_dim * 2**i`` channels; stage 1 reads
    ``channels`` input channels and each later stage reads the previous
    stage's output. For a (b,3,224,224) input with the default settings the
    stage outputs are (b,96,56,56), (b,192,28,28), (b,384,14,14), (b,768,7,7).

    ``mlp_head`` (LayerNorm + Linear to ``num_classes``) is constructed but not
    applied in ``forward``.
    """

    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        # Output width of each stage and the matching input width.
        widths = [hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
        inputs = [channels] + widths[:-1]

        # Registers self.stage1 ... self.stage4, in that order.
        for idx in range(4):
            stage = Stage(in_channels=inputs[idx], hidden_dimension=widths[idx], layers=layers[idx],
                          downscaling_factor=downscaling_factors[idx], num_heads=heads[idx],
                          head_dim=head_dim, window_size=window_size,
                          relative_pos_embedding=relative_pos_embedding)
            setattr(self, f'stage{idx + 1}', stage)

        # Classification head (unused by forward, see class docstring).
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(widths[-1]),
            nn.Linear(widths[-1], num_classes),
        )

    def forward(self, img):
        """Return the tuple ``(x1, x2, x3, x4)`` of channels-first stage outputs."""
        features = []
        x = img
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            x = stage(x)
            features.append(x)
        # Classification would instead be: self.mlp_head(x4.mean(dim=[2, 3]))
        return tuple(features)


In [None]:
def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a ``SwinTransformer`` with Swin-T defaults; extra keyword arguments are forwarded."""
    config = dict(kwargs, hidden_dim=hidden_dim, layers=layers, heads=heads)
    return SwinTransformer(**config)

In [None]:
# Build a Swin-T backbone and run one forward pass to check the per-stage output shapes
net = swin_t(
    channels=3,
    num_classes=3,
    head_dim=32,
    window_size=7,
    downscaling_factors=(4,2,2,2),
    relative_pos_embedding=True
)

x = torch.randn(1,3,224,224)
y1,y2,y3,y4 = net(x) # forward returns the four stage feature maps, not class logits
y1.shape,y2.shape,y3.shape,y4.shape

(torch.Size([1, 96, 56, 56]),
 torch.Size([1, 192, 28, 28]),
 torch.Size([1, 384, 14, 14]),
 torch.Size([1, 768, 7, 7]))

In [None]:
class ShallowFeatures_Conv(nn.Module):
    """Single unpadded ``nn.Conv2d`` used as the shallow feature extractor.

    Subclasses ``nn.Module`` so that instances are callable and the
    convolution's parameters are registered with any parent module.

    Args:
        in_channels: channels of the input map.
        out_channels: channels of the output map.
        kernel_size: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0
        )

    def forward(self, x):
        """Convolve ``x`` (B, C_in, H, W). With padding=0 a kernel larger than 1 shrinks H and W."""
        x = self.conv(x)  # x - (B, C, H, W)
        return x

In [None]:
class SpatialInvariant_Conv(nn.Module):
    """Single unpadded ``nn.Conv2d`` applied after the Swin layers of a residual Swin block.

    Subclasses ``nn.Module`` so that instances are callable and the
    convolution's parameters are registered with any parent module.

    Args:
        in_channels: channels of the input map.
        out_channels: channels of the output map.
        kernel_size: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0
        )

    def forward(self, x):
        """Convolve ``x`` (B, C_in, H, W). With padding=0 a kernel larger than 1 shrinks H and W."""
        x = self.conv(x)  # x - (B, C, H, W)
        return x

In [None]:
class AddInductiveBias_Conv(nn.Module):
    """Single unpadded ``nn.Conv2d`` appended after the deep feature extraction blocks.

    Subclasses ``nn.Module`` so that instances are callable, can be stored in
    an ``nn.ModuleList``, and have their parameters registered.

    Args:
        in_channels: channels of the input map.
        out_channels: channels of the output map.
        kernel_size: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0
        )

    def forward(self, x):
        """Convolve ``x`` (B, C_in, H, W). With padding=0 a kernel larger than 1 shrinks H and W."""
        x = self.conv(x)  # x - (B, C, H, W)
        return x

In [None]:
class RSTB(nn.Module):
    """A stack of (regular, shifted) Swin block pairs followed by a convolution.

    Input and output are channels-first ``(B, C, H, W)``. The Swin blocks work
    on channels-last tensors, so ``forward`` permutes to ``(B, H, W, C)`` for
    the blocks and back to ``(B, C, H, W)`` for the convolution.

    Args:
        in_channels: input channels of the trailing convolution. The Swin
            blocks keep ``hidden_dimension`` channels, so this is expected to
            equal ``hidden_dimension``.
        hidden_dimension: channel width of the Swin blocks and output channels
            of the convolution.
        layers: total number of Swin blocks; must be even.
        num_heads, head_dim, window_size, relative_pos_embedding: forwarded to
            every ``SwinBlock``.
        rstb_conv_kernel_size, rstb_conv_stride: kernel size and stride of the
            trailing convolution.
    """

    def __init__(self, in_channels, hidden_dimension, layers, num_heads, head_dim, window_size,
                 relative_pos_embedding, rstb_conv_kernel_size, rstb_conv_stride):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'
        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
            ]))

        self.spatial_invariant_conv = SpatialInvariant_Conv(in_channels=in_channels,
                                                            out_channels=hidden_dimension,
                                                            kernel_size=rstb_conv_kernel_size,
                                                            stride=rstb_conv_stride)

    def forward(self, x):
        """Run the Swin layers and the convolution on ``x`` of shape (B, C, H, W)."""
        # x - (b,180,56,56) : (B,C,H,W)
        # SwinBlock normalizes and attends over the LAST axis, so the channels
        # have to be moved there before the blocks run.
        x = x.permute(0, 2, 3, 1)  # (B,H,W,C)
        for swin_block, shifted_swin_block in self.layers:
            x = swin_block(x)
            x = shifted_swin_block(x)
        x = x.permute(0, 3, 1, 2)  # back to (B,C,H,W) for the convolution
        # NOTE(review): there is no skip connection around this block; confirm
        # whether one is intended (the unpadded conv would change H and W).
        x = self.spatial_invariant_conv(x)
        return x

In [None]:
class SwinIRUpscaler(nn.Module):
    """Upscaler skeleton: shallow conv, deep feature extraction, residual add, pixel shuffle.

    All arguments are keyword-only.

    Args:
        rstb: number of ``RSTB`` blocks in the deep feature extractor.
        stl: number of Swin layers per ``RSTB`` (must be even).
        window_size, heads, head_dim, relative_pos_embedding: forwarded to every ``RSTB``.
        channels: feature width used throughout the network.
        rstb_conv_kernel_size, rstb_conv_stride: convolution settings inside each ``RSTB``.
        upscale_factor: factor used by ``nn.PixelShuffle``.
        conv_kernel_size, conv_stride: settings of the convolution appended after the ``RSTB`` blocks.
        shallow_conv_kernel_size, shallow_conv_stride: settings of the shallow feature convolution.

    Open issues carried over from the original notes:
      * TODO: conv_stride / conv_kernel_size values for the multiple convolutions.
      * TODO: tensor shape math (the convolutions are unpadded, so shapes need
        checking before the residual addition in ``forward``).
      * TODO: convert feature channels to image RGB channels (PixelShuffle
        outputs ``channels / upscale_factor**2`` channels).
    """

    def __init__(self, *, rstb=6, stl=6, window_size=8, heads=6, channels=180, head_dim=32,
                 rstb_conv_kernel_size=3, rstb_conv_stride=1, relative_pos_embedding=True, upscale_factor=3,
                 conv_kernel_size=3, conv_stride=1, shallow_conv_kernel_size=3, shallow_conv_stride=1):
        super().__init__()

        self.shallow_features = ShallowFeatures_Conv(in_channels=3, out_channels=channels,
                                                     kernel_size=shallow_conv_kernel_size,
                                                     stride=shallow_conv_stride)
        self.deep_extract = nn.ModuleList([])

        for _ in range(rstb):
            # RSTB's keyword is `in_channels` (plural).
            self.deep_extract.append(RSTB(in_channels=channels, hidden_dimension=channels, layers=stl,
                                          num_heads=heads, head_dim=head_dim, window_size=window_size,
                                          relative_pos_embedding=relative_pos_embedding,
                                          rstb_conv_kernel_size=rstb_conv_kernel_size,
                                          rstb_conv_stride=rstb_conv_stride))

        self.deep_extract.append(AddInductiveBias_Conv(in_channels=channels, out_channels=channels,
                                                       kernel_size=conv_kernel_size, stride=conv_stride))

        # C = C / (upscale_factor**2), H = H*upscale_factor, W = W*upscale_factor
        self.subpixelconv = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        """Extract shallow and deep features, add them, and upsample with pixel shuffle."""
        x = self.shallow_features(x)
        y = x
        # The deep extractor modules are stored in `self.deep_extract`.
        for stage in self.deep_extract:
            y = stage(y)
        x = x + y  # Residual
        x = self.subpixelconv(x)
        return x

