In [68]:
import torch
from torch import nn
from torch import functional as F

In [3]:
class VAE(nn.Module):
    """Variational autoencoder: encode to (mean, log-variance), sample, decode.

    NOTE(review): the ``Encoder`` defined later in this notebook takes
    ``(input_dim, K)`` and returns a single tensor, and no ``Decoder`` class
    is defined in the visible cells, so constructing or calling this class
    fails as written. It looks left over from an earlier design; confirm
    before relying on it.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, inputs):
        """Return ``(reconstruction, z_mean, z_log_var)`` for ``inputs``."""
        z_mean, z_log_var = self.encoder(inputs)
        latent_sample = self.reparameterize(z_mean, z_log_var)
        return self.decoder(latent_sample), z_mean, z_log_var

    def reparameterize(self, mu, log_var):
        """Reparameterisation trick: return ``mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        ``log_var`` is the log-variance, so ``sigma = exp(log_var / 2)``.
        The noise is drawn with ``torch.randn_like`` (same shape, dtype and
        device as ``sigma``).
        """
        sigma = (0.5 * log_var).exp()
        return mu + torch.randn_like(sigma) * sigma

In [4]:
# Smoke test of torch's built-in nn.Transformer. The source sequence (10 steps)
# feeds the encoder and the target sequence (20 steps) feeds the decoder, so the
# output takes the target's shape. The layout is (seq_len, batch, d_model)
# because batch_first is left at its default (False); d_model defaults to 512.
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
# Previously called transformer_model(src, src), which left `tgt` unused and
# decoded the source sequence instead of the target.
out = transformer_model(src, tgt)

In [102]:

#MLP block
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers separated by ReLU activations.

    Args:
        in_dim: size of the input's last dimension.
        block_sizes: output width of each successive Linear layer.
        last_linear: if True (default) the final Linear is not followed by a
            ReLU; if False, every Linear, including the last, gets one.
        **kwargs: forwarded to ``nn.Module.__init__``.

    Calling with ``transpose=True`` walks the layers in reverse, multiplying by
    each Linear's weight matrix (biases are not used). This maps
    ``block_sizes[-1]`` features back to ``in_dim``, presumably so the same
    weights can serve as a tied decoder.
    """

    def __init__(self,
                 in_dim,
                 block_sizes,
                 last_linear=True,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.in_dim = in_dim
        self.last_linear = last_linear

        final_idx = len(block_sizes) - 1
        layers = []
        prev_width = in_dim
        for idx, width in enumerate(block_sizes):
            layers.append(nn.Linear(prev_width, width))
            # Every hidden Linear gets a ReLU; the final one only when last_linear is False.
            if not (idx >= final_idx and last_linear):
                layers.append(nn.ReLU())
            prev_width = width
        self.mlp = nn.Sequential(*layers)

    def forward(self, input, transpose=False):
        """Run the stack forwards, or its transposed mirror when ``transpose`` is truthy."""
        return self._transpose_call(input) if transpose else self.mlp(input)

    def _transpose_call(self, input):
        """Apply the layers last-to-first, using ``x @ W`` in place of each Linear."""
        layers = list(self.mlp)
        layers.reverse()
        result = input
        for layer in layers:
            # W has shape (out_features, in_features), so x @ W maps out -> in.
            result = torch.matmul(result, layer.weight) if isinstance(layer, nn.Linear) else layer(result)
        return result


# Residual self-attention block (post-norm: LayerNorm applied after each residual add)
class TransformerLayer(nn.Module):
    """Post-norm residual self-attention layer followed by a residual MLP.

    ``forward`` computes::

        x = norm1(x + dropout(MHA(x, x, x)))
        x = norm2(x + dropout(MLP(x)))

    Args:
        embed_dim: size of the input's last dimension; ``nn.MultiheadAttention``
            requires it to be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: drop probability applied to the attention output and to the
            MLP output before each residual add. The default 0.2 is the value
            that used to be hard-coded.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """
    def __init__(
        self,
        embed_dim=128,
        num_heads=8,
        dropout=0.2,
        **kwargs
        ):
        super(TransformerLayer, self).__init__(**kwargs)

        # One dropout module shared by both residual branches.
        self.do = nn.Dropout(p=dropout)

        # Multi-head self-attention branch.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads
            )

        # Feed-forward branch: embed_dim -> 4*embed_dim -> embed_dim.
        self.mlp = MLP(embed_dim, [4*embed_dim, embed_dim])
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None, attn_mask=None):
        """Apply the attention and MLP sub-layers; the result has the shape of ``x``.

        Args:
            x: used as query, key and value. The attention module keeps its
                default ``batch_first=False`` layout, i.e. ``(L, E)`` unbatched
                or ``(L, N, E)`` batched.
            mask: forwarded to the attention module as ``key_padding_mask``.
            attn_mask: forwarded to the attention module as ``attn_mask``.
        """
        # Pass the padding mask by keyword. Positionally it lands in MHA's 4th
        # slot (key_padding_mask), which the bare name `mask` hides.
        attn_out, _ = self.mha(x, x, x, key_padding_mask=mask,
                               attn_mask=attn_mask, need_weights=False)
        x = self.norm1(x + self.do(attn_out))
        x = self.norm2(x + self.do(self.mlp(x)))
        return x

In [103]:
# Smoke test of the custom TransformerLayer on a 2-D (10, 512) input.
# NOTE: rebinds `transformer_model`, `src` and `out` from the nn.Transformer cell.
transformer_model = TransformerLayer(embed_dim=512)
src = torch.rand(10, 512)
out = transformer_model(src)

In [104]:
# TransformerLayer keeps the input's shape (residual adds require it).
out.shape

torch.Size([10, 512])

In [135]:
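# Earlier TransMLPBlock variant that built its own MLP internally; superseded by
# the definition below, which takes a pre-built (shareable) MLP instead.
# class TransMLPBlock(nn.Module):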
    
#     def __init__(self, input_dim, hidden_dim, num_heads, **kwargs):
#         super(TransMLPBlock, self).__init__(**kwargs)
        
#         self.trans_layer = TransformerLayer(embed_dim=input_dim,
#                                             num_heads=num_heads)
#         self.dropout = nn.Dropout(p=0.2)
#         self.mlp = MLP(input_dim, [input_dim//2, hidden_dim])
        
#     def forward(self, x):
#         out = self.trans_layer(x)
#         out = self.dropout(out)
#         out = self.mlp(out)
#         return out
    

class TransMLPBlock(nn.Module):
    """TransformerLayer -> Dropout(0.2) -> MLP, with the MLP optionally run transposed.

    Args:
        mlp: a module called as ``mlp(x, transpose=...)``, e.g. an ``MLP``. It is
            stored as-is, so passing the same instance to several blocks
            shares its weights.
        input_dim: embedding size of the TransformerLayer; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.
        transpose_mlp: when True, the MLP is called with ``transpose=True``.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, mlp, input_dim, num_heads, transpose_mlp=False, **kwargs):
        super().__init__(**kwargs)
        self.trans_layer = TransformerLayer(embed_dim=input_dim, num_heads=num_heads)
        self.dropout = nn.Dropout(p=0.2)
        self.mlp = mlp
        self.transpose_mlp = transpose_mlp

    def forward(self, x):
        """Attend over ``x``, apply dropout, then project through the MLP."""
        attended = self.dropout(self.trans_layer(x))
        return self.mlp(attended, transpose=self.transpose_mlp)

In [106]:
# TransMLPBlock takes a pre-built MLP as its first argument, so the old call
# TransMLPBlock(512, 14) no longer fits its signature. Rebuild the
# 512 -> 256 -> 14 MLP that the earlier variant created internally;
# num_heads=8 mirrors TransformerLayer's default.
block = TransMLPBlock(MLP(512, [512 // 2, 14]), 512, num_heads=8)
out = block(src)

In [107]:
# The block's MLP maps the 512 input features down to 14.
out.shape

torch.Size([10, 14])

In [108]:
class Encoder(nn.Module):
    """Three stacked TransMLPBlocks mapping ``input_dim`` features down to ``K``.

    Widths go input_dim -> 64 -> 32 -> K. Each stage runs a TransformerLayer
    at its input width (5, 4 and 4 heads, so ``input_dim`` must be divisible
    by 5), followed by ``MLP(in_dim, [in_dim // 2, out_dim])``.

    Args:
        input_dim: number of input features.
        K: size of the output code.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, input_dim, K, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        # (width in, width out, attention heads) for each stage.
        stages = [(input_dim, 64, 5), (64, 32, 4), (32, K, 4)]
        # TransMLPBlock takes a pre-built MLP first. The old calls used the
        # superseded (input_dim, hidden_dim, num_heads) signature: an int got
        # bound to `mlp` and TransformerLayer(64, num_heads=5) failed. Build the
        # same in -> in//2 -> out MLP the earlier variant created internally.
        blocks = [TransMLPBlock(MLP(d_in, [d_in // 2, d_out]), d_in, num_heads=heads)
                  for d_in, d_out, heads in stages]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every stage; the last dimension goes input_dim -> K."""
        return self.blocks(x)

In [109]:
# Encoder from 100 input features down to a 7-dim code.
encoder = Encoder(input_dim=100, K=7)

In [110]:
# Push 10 random 100-feature samples through the encoder.
src = torch.rand(10, 100)
out = encoder(src)

In [111]:
# One K=7 code per input sample.
out.shape

torch.Size([10, 7])

In [157]:
class TransAE(nn.Module):
    """Autoencoder of TransMLPBlocks whose decoder reuses the encoder's MLPs transposed.

    Encoder: input_dim -> 64 -> 32 -> K. Each stage is a TransformerLayer at the
    stage's input width followed by an MLP.

    Decoder: K -> 32 -> 64 -> input_dim. It uses the same three MLP instances
    in reverse order, called with ``transpose=True``, so the weights are shared
    with the encoder and biases are not used on the way back. Each MLP is
    preceded by its own TransformerLayer.

    Args:
        input_dim: number of input features.
        K: size of the code.
        **kwargs: forwarded to ``nn.Module.__init__``.

    The preferred head counts are the original ones (encoder 5/4/4, decoder
    3/4/4). ``nn.MultiheadAttention`` needs the width to be divisible by the
    head count, so when a preferred count does not divide the width (e.g.
    ``K=7`` with 3 heads), the largest divisor below it is used instead.
    """

    def __init__(self, input_dim, K, **kwargs):
        super(TransAE, self).__init__(**kwargs)
        heads = self._num_heads

        mlps = [MLP(input_dim, [64, 64]),
                MLP(64, [32, 32]),
                MLP(32, [16, K])]

        blocks = [TransMLPBlock(mlps[0], input_dim, num_heads=heads(input_dim, 5)),
                  TransMLPBlock(mlps[1], 64, num_heads=heads(64, 4)),
                  TransMLPBlock(mlps[2], 32, num_heads=heads(32, 4))]

        self.encoder = nn.Sequential(*blocks)

        # Decoder mirrors the encoder: same MLP objects, applied transposed.
        blocks = [TransMLPBlock(mlps[2], K, num_heads=heads(K, 3), transpose_mlp=True),
                  TransMLPBlock(mlps[1], 32, num_heads=heads(32, 4), transpose_mlp=True),
                  TransMLPBlock(mlps[0], 64, num_heads=heads(64, 4), transpose_mlp=True)]

        self.decoder = nn.Sequential(*blocks)

    @staticmethod
    def _num_heads(dim, preferred):
        """Return ``preferred`` if it divides ``dim``, else the largest divisor of ``dim`` below it (min 1)."""
        for h in range(min(preferred, dim), 0, -1):
            if dim % h == 0:
                return h
        return 1

    def forward(self, inputs):
        """Encode ``inputs`` to a K-dim code and decode it back to ``input_dim`` features."""
        code = self.encoder(inputs)
        out = self.decoder(code)
        return out
    
#     def forward(self, inputs):
#         z_mean, z_log_var = self.encoder(inputs)
#         z = self.reparameterize(z_mean, z_log_var)
#         reconstructed = self.decoder(z)
#         return reconstructed, z_mean, z_log_var

#     def reparameterize(self, mu, log_var):
#         std = torch.exp(0.5 * log_var)
#         eps = torch.randn_like(std)
#         return mu + (eps * std)

In [158]:
# Autoencoder with 100 input features and a 3-dim code.
ae = TransAE(input_dim=100, K=3)

In [159]:
# Toy input batch: 200 samples x 100 features, drawn uniformly from [0, 1).
x = torch.rand((200, 100))


In [160]:
# Encode the batch; each of the 200 samples becomes a K=3 code.
code = ae.encoder(x)
code.shape

torch.Size([200, 3])

In [161]:
# Decode the codes back to 100 features through the transposed (shared) MLPs.
out = ae.decoder(code)

In [162]:
# The reconstruction has the same shape as the input batch.
out.shape

torch.Size([200, 100])

In [118]:
# Standalone MLP for checking forward/transpose: 10 -> 8 -> 4.
mlp = MLP(in_dim=10, block_sizes=[8, 4])

In [114]:
# Forward pass on 20 samples with 10 features each.
inputs = torch.rand((20, 10))
out = mlp(inputs, transpose=False)

In [115]:
# Forward output: 20 samples x 4 features.
out.shape

torch.Size([20, 4])

In [116]:
# Transposed pass reuses the same weights to map the 4 features back to 10.
rec = mlp(out, transpose=True)

In [117]:
# Reconstruction width equals the MLP's in_dim (10).
rec.shape

torch.Size([20, 10])

In [124]:
# NOTE(review): this cell ran (In [124]) after the `rec.sum().backward()` cell
# further down (In [123]). On a fresh Restart & Run All no backward pass has
# run yet, so this displays None instead of the tensor below. Move it after
# the backward cell.
mlp.mlp[0].weight.grad

tensor([[2.1223, 2.1223, 2.1223, 2.1223, 2.1223, 2.1223, 2.1223, 2.1223, 2.1223,
         2.1223],
        [0.2600, 0.2600, 0.2600, 0.2600, 0.2600, 0.2600, 0.2600, 0.2600, 0.2600,
         0.2600],
        [1.1222, 1.1222, 1.1222, 1.1222, 1.1222, 1.1222, 1.1222, 1.1222, 1.1222,
         1.1222],
        [0.1715, 0.1715, 0.1715, 0.1715, 0.1715, 0.1715, 0.1715, 0.1715, 0.1715,
         0.1715],
        [3.5680, 3.5680, 3.5680, 3.5680, 3.5680, 3.5680, 3.5680, 3.5680, 3.5680,
         3.5680],
        [0.4781, 0.4781, 0.4781, 0.4781, 0.4781, 0.4781, 0.4781, 0.4781, 0.4781,
         0.4781],
        [0.1711, 0.1711, 0.1711, 0.1711, 0.1711, 0.1711, 0.1711, 0.1711, 0.1711,
         0.1711],
        [1.3130, 1.3130, 1.3130, 1.3130, 1.3130, 1.3130, 1.3130, 1.3130, 1.3130,
         1.3130]])

In [120]:
# Transposed pass on a fresh (10, 4) batch, mapping 4 features back to 10.
# NOTE: rebinds `x`, which earlier held the (200, 100) autoencoder batch.
x = torch.rand((10, 4))
rec = mlp(x, transpose=True)

In [121]:
# 10 samples mapped back to in_dim=10 features.
rec.shape

torch.Size([10, 10])

In [123]:
# Backprop a scalar through the transposed pass. This accumulates .grad on the
# Linear weights of `mlp`; the transposed path never touches the biases.
rec.sum().backward()