In [1]:
import torch
import torch.nn as nn
import numpy as np


In [2]:
class ImageToPatchEmbeddings(nn.Module):
    """Split images into non-overlapping square patches and embed them.

    ``forward`` does: patchify -> linear projection of each flattened patch ->
    add a position-dependent embedding -> prepend a class-token embedding.

    Args:
        latent_dim: size of every output embedding vector.
        patch_size: side length (in pixels) of each square patch.
        in_channels: number of image channels (default 3, the value that was
            previously hard-coded).
    """

    def __init__(self, latent_dim, patch_size, in_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.latent_dim = latent_dim
        self.in_channels = in_channels

        # Maps one flattened (in_channels * p * p) patch to a latent_dim vector.
        self.lin_projection = nn.Linear(in_channels * patch_size * patch_size, latent_dim)

        # Only ever applied to an all-ones input, so its output is a single
        # learned vector (weight.sum(dim=1) + bias) used as the class token.
        self.class_embedding = nn.Linear(latent_dim, latent_dim)

        # NOTE: attribute name misspelling ("enbedding") kept so existing
        # state_dict keys keep loading.
        # NOTE(review): it is applied to rows filled with (i + 1), so embedding i
        # equals (i + 1) * weight.sum(dim=1) + bias -- all positional embeddings
        # lie on one line, and their magnitude grows with the patch count.
        self.learnable_positional_enbedding = nn.Linear(latent_dim, latent_dim)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: image batch, presumably (batch, in_channels, H, W). H and W are
                expected to be multiples of ``patch_size``; ``unfold`` silently
                drops trailing rows/columns that do not fill a whole patch.

        Returns:
            Tensor of shape (batch, 1 + num_patches, latent_dim); position 0
            along dim 1 is the class-token embedding, followed by the patches
            in row-major order.
        """
        p = self.patch_size
        # (B, C, H, W) -> (B, C, H/p, W, p) -> (B, C, H/p, W/p, p, p)
        x = x.unfold(-2, p, p)
        x = x.unfold(-2, p, p)
        # -> (B, H/p, W/p, C, p, p) so each patch flattens channel-major.
        x = x.movedim(1, -3)
        x = x.flatten(1, 2)    # (B, N, C, p, p), patches in row-major order
        x = x.flatten(-3, -1)  # (B, N, C * p * p)

        x = self.lin_projection(x)  # (B, N, latent_dim)

        # Build auxiliary inputs on x's device/dtype; plain torch.ones(...)
        # always allocates float32 on the CPU and breaks GPU / float64 models.
        pos = self.positions(x.shape[1], x.shape[2], device=x.device, dtype=x.dtype)
        x = x + self.learnable_positional_enbedding(pos)

        ones = x.new_ones(x.shape[0], 1, self.latent_dim)
        cls_embedding = self.class_embedding(ones)

        return torch.cat((cls_embedding, x), 1)

    def positions(self, num_patch, latent_dim, device=None, dtype=None):
        """Return a (num_patch, latent_dim) tensor whose row i is filled with i + 1.

        Args:
            num_patch: number of rows (patches).
            latent_dim: number of columns.
            device: target device (default: PyTorch's default device).
            dtype: target dtype (default: ``torch.get_default_dtype()``, which
                matches the previous ``torch.ones``-based implementation).
        """
        if dtype is None:
            dtype = torch.get_default_dtype()
        # Integer arange then cast, so values are exact for any float dtype.
        row_values = torch.arange(1, num_patch + 1, device=device).to(dtype)
        return row_values.unsqueeze(1).repeat(1, latent_dim)


In [7]:
class CreateQKV(nn.Module):
    """Produce query, key and value tensors with three bias-free linear maps.

    Args:
        d_model: input and output feature size of each projection.
    """

    def __init__(self, d_model):
        super().__init__()
        # The generator is fully consumed before assignment, so the layers are
        # created (and seeded) and registered in the order WQ, WK, WV.
        self.WQ, self.WK, self.WV = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(3)
        )

    def forward(self, x):
        """Return the tuple ``(query, key, value)`` obtained by projecting ``x``."""
        projections = (self.WQ, self.WK, self.WV)
        return tuple(project(x) for project in projections)

In [9]:
    # NOTE(review): this cell is indented one level; IPython strips a cell's
    # common leading indent, but dedent it if the code moves to a .py file.
    def Attention(query, key, values):
        """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

        Accepts unbatched ``(seq, d)`` tensors as well as tensors with any
        number of leading batch dimensions ``(..., seq, d)``.

        Args:
            query: (..., seq_q, d_k)
            key: (..., seq_k, d_k)
            values: (..., seq_k, d_v)

        Returns:
            Tensor of shape (..., seq_q, d_v).
        """
        # Feature size is the last dim; size(1) is the sequence length once a
        # batch dim is present.
        dk = query.size(-1)
        # Transpose only the last two dims; `.T` reverses every dimension of a
        # batched tensor (and is deprecated for ndim != 2).
        scores = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(dk)
        # Normalise over the key axis, which is always the last one.
        weights = nn.functional.softmax(scores, dim=-1)

        return torch.matmul(weights, values)

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate per-head projection layers.

    Each of the ``heads`` heads has its own bias-free Q/K/V projection from
    ``d_model`` to ``d_model // heads`` features. The head outputs are
    concatenated on the feature axis and mixed by the bias-free ``WO``.
    Relies on a module-level ``Attention(query, key, values)`` function.

    Args:
        d_model: model width; must be divisible by ``heads``.
        heads: number of attention heads (positive).

    Raises:
        ValueError: if ``heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, heads):
        super().__init__()
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive number of heads ({heads})"
            )
        self.d_model = d_model
        self.heads = heads
        d_head = self.d_model // self.heads

        # Creation order (all WQ, then WK, WV, WO) is kept for reproducible init.
        self.WQ = nn.ModuleList([nn.Linear(self.d_model, d_head, bias=False) for _ in range(self.heads)])
        self.WK = nn.ModuleList([nn.Linear(self.d_model, d_head, bias=False) for _ in range(self.heads)])
        self.WV = nn.ModuleList([nn.Linear(self.d_model, d_head, bias=False) for _ in range(self.heads)])
        # heads * d_head == d_model is guaranteed by the divisibility check.
        self.WO = nn.Linear(self.d_model, self.d_model, bias=False)

    def forward(self, query, key, values):
        """Apply multi-head attention.

        Args:
            query: (..., seq_q, d_model)
            key: (..., seq_k, d_model)
            values: (..., seq_k, d_model)

        Returns:
            Tensor of shape (..., seq_q, d_model).
        """
        head_outputs = [
            Attention(wq(query), wk(key), wv(values))
            for wq, wk, wv in zip(self.WQ, self.WK, self.WV)
        ]
        # Concatenate on the feature axis; dim=1 only coincides with it for
        # unbatched (seq, d) input.
        cat_attn = torch.cat(head_outputs, dim=-1)

        return self.WO(cat_attn)

In [None]:
class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder block in the ViT layout.

    ``forward`` computes::

        h = x + MSA(LN1(x))
        y = h + MLP(LN2(h))

    Args:
        num_heads: number of attention heads.
        latent_dim: model width (feature size of input and output).
        mlp_ratio: hidden width of the MLP as a multiple of ``latent_dim``
            (default 4).
    """

    def __init__(self, num_heads, latent_dim, mlp_ratio=4):
        super().__init__()
        self.num_heads = num_heads
        self.latent_dim = latent_dim

        self.layer_norm1 = nn.LayerNorm(self.latent_dim)
        self.layer_norm2 = nn.LayerNorm(self.latent_dim)
        # NOTE(review): MultiHeadAttention applies its own per-head Q/K/V
        # projections, so this extra projection stacks two linear maps; kept to
        # preserve the existing structure.
        self.qkv = CreateQKV(self.latent_dim)
        self.MSA = MultiHeadAttention(self.latent_dim, self.num_heads)

        # Position-wise feed-forward sub-layer (created last so the
        # initialisation of the layers above is unchanged).
        hidden_dim = int(mlp_ratio * self.latent_dim)
        self.mlp = nn.Sequential(
            nn.Linear(self.latent_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, self.latent_dim),
        )

    def forward(self, x):
        """Apply the encoder block; output has the same shape as ``x``.

        Args:
            x: token embeddings, presumably (..., seq, latent_dim).
        """
        # Attention sub-layer with residual connection.
        x_norm1 = self.layer_norm1(x)
        q, k, v = self.qkv(x_norm1)
        attention = self.MSA(q, k, v)
        add1 = x + attention

        # MLP sub-layer with residual connection.
        return add1 + self.mlp(self.layer_norm2(add1))


In [3]:
# Embed one random 2048x2048 RGB image with 16x16 patches:
# (2048 / 16) ** 2 = 16384 patches, plus one class token.
pre = ImageToPatchEmbeddings(latent_dim=1024, patch_size=16)
o = pre(torch.rand(1, 3, 2048, 2048))

In [5]:
o.size()  # (batch, 1 + num_patches, latent_dim)

torch.Size([1, 16385, 1024])

In [101]:
# 5 x 10 float64 array of ones; appears to be scratch input for prototyping
# the row-scaling loop used by `positions`.
x = np.full((5, 10), 1.0)
x

array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [102]:
# Multiply row i by (i + 1), mirroring the loop in `positions`. This mutates
# `x` in place, so re-running the cell compounds the scaling.
x *= np.arange(1, x.shape[0] + 1)[:, np.newaxis]
x

array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
       [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.]])

In [105]:
# Random (3, 5, 10) batch from the global (unseeded) NumPy RNG, used below to
# check broadcasting against the (5, 10) array `x`.
xx = np.random.random((3, 5, 10))
xx

array([[[0.97939543, 0.52654997, 0.68040251, 0.14814911, 0.66951159,
         0.09927551, 0.56171187, 0.18610922, 0.37494385, 0.32788911],
        [0.72699297, 0.96470595, 0.95027024, 0.17224075, 0.44855508,
         0.51678971, 0.12351106, 0.32853366, 0.29950902, 0.63464192],
        [0.21312251, 0.87399493, 0.9251385 , 0.00955824, 0.65094259,
         0.88571189, 0.08502573, 0.93209916, 0.6730934 , 0.71931374],
        [0.7492376 , 0.11851665, 0.63399055, 0.47007862, 0.30583137,
         0.38315912, 0.82187785, 0.9527907 , 0.7735972 , 0.68588855],
        [0.06134051, 0.86871411, 0.82104252, 0.04978584, 0.36772339,
         0.97140903, 0.17883904, 0.11841364, 0.24477979, 0.72758525]],

       [[0.73468369, 0.11631524, 0.97561519, 0.24343114, 0.79762096,
         0.68914959, 0.84889841, 0.61958279, 0.3018582 , 0.66298499],
        [0.27929298, 0.16951288, 0.53438094, 0.86954213, 0.46116236,
         0.88433585, 0.59452318, 0.97853236, 0.19320634, 0.15282897],
        [0.55038326, 0.57

In [106]:
np.add(xx, x)  # (5, 10) `x` broadcasts over the leading batch axis of `xx`

array([[[1.97939543, 1.52654997, 1.68040251, 1.14814911, 1.66951159,
         1.09927551, 1.56171187, 1.18610922, 1.37494385, 1.32788911],
        [2.72699297, 2.96470595, 2.95027024, 2.17224075, 2.44855508,
         2.51678971, 2.12351106, 2.32853366, 2.29950902, 2.63464192],
        [3.21312251, 3.87399493, 3.9251385 , 3.00955824, 3.65094259,
         3.88571189, 3.08502573, 3.93209916, 3.6730934 , 3.71931374],
        [4.7492376 , 4.11851665, 4.63399055, 4.47007862, 4.30583137,
         4.38315912, 4.82187785, 4.9527907 , 4.7735972 , 4.68588855],
        [5.06134051, 5.86871411, 5.82104252, 5.04978584, 5.36772339,
         5.97140903, 5.17883904, 5.11841364, 5.24477979, 5.72758525]],

       [[1.73468369, 1.11631524, 1.97561519, 1.24343114, 1.79762096,
         1.68914959, 1.84889841, 1.61958279, 1.3018582 , 1.66298499],
        [2.27929298, 2.16951288, 2.53438094, 2.86954213, 2.46116236,
         2.88433585, 2.59452318, 2.97853236, 2.19320634, 2.15282897],
        [3.55038326, 3.57

In [24]:
# Define the toy input explicitly instead of relying on a deleted cell; the
# values match the displayed `X` (a (1, 4, 4) tensor holding 1..16).
X = torch.arange(1.0, 17.0).reshape(1, 4, 4)
# First unfold windows the rows: (1, 4, 4) -> (1, 2, 4, 2).
X_unfold = X.unfold(-2, 2, 2)
# Second unfold windows the columns: -> (1, 2, 2, 2, 2) grid of 2x2 patches.
X_unfold.unfold(-2, 2, 2)

tensor([[[[[ 1.,  2.],
           [ 5.,  6.]],

          [[ 3.,  4.],
           [ 7.,  8.]]],


         [[[ 9., 10.],
           [13., 14.]],

          [[11., 12.],
           [15., 16.]]]]])

In [25]:
# `X` must be defined by an earlier cell; per the displayed output it is a
# (1, 4, 4) float tensor holding 1..16.
X

tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.],
         [13., 14., 15., 16.]]])

In [28]:
# Window rows, then columns: (1, 4, 4) -> (1, 2, 2, 2, 2) grid of 2x2 patches.
XX = X.unfold(-2, 2, 2).unfold(-2, 2, 2)

In [31]:
XX.flatten(start_dim=1, end_dim=2)  # merge the 2x2 patch grid into 4 patches (row-major)

tensor([[[[ 1.,  2.],
          [ 5.,  6.]],

         [[ 3.,  4.],
          [ 7.,  8.]],

         [[ 9., 10.],
          [13., 14.]],

         [[11., 12.],
          [15., 16.]]]])