<a href="https://colab.research.google.com/github/mehdihosseinimoghadam/Pytorch-Tutorial/blob/main/CNN_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import torchaudio.transforms as transforms

In [40]:
class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(ConvBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.Conv = nn.Sequential(
        nn.Conv2d(self.in_channels,
                  self.out_channels,
                  4,
                  2,
                  1),
        nn.BatchNorm2d(self.out_channels),
        nn.LeakyReLU()          
    )
  def forward(self, x):
    return self.Conv(x)  



class TransConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(TransConvBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.Conv = nn.Sequential(
        nn.ConvTranspose2d(self.in_channels,
                  self.out_channels,
                  4,
                  2,
                  1),
        nn.BatchNorm2d(self.out_channels),
        nn.LeakyReLU()          
    )
  def forward(self, x):
    return self.Conv(x)  






class Encode(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, filter_num_list=[2, 4, 8]):
    super(Encode, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.latent_dim = latent_dim
    self.filter_num_list = filter_num_list

    self.Conv = nn.Sequential(
        ConvBlock(self.in_channels, self.out_channels),
        ConvBlock(self.out_channels, self.out_channels * self.filter_num_list[0]),
        ConvBlock(self.out_channels * self.filter_num_list[0], self.out_channels * self.filter_num_list[1]),
        ConvBlock(self.out_channels * self.filter_num_list[1], self.out_channels * self.filter_num_list[2])
    )

    self.fc = nn.Linear(80 * 2 * 2, self.latent_dim)
    self.mu = nn.Linear(self.latent_dim, self.latent_dim)
    self.logvar = nn.Linear(self.latent_dim, self.latent_dim)

  def forward(self, x):
    x = self.Conv(x)
    print(x.shape)
    x = x.reshape(-1, 80 * 2 * 2)
    x = self.fc(x)
    mu = self.mu(x)
    logvar = self.logvar(x)
    return mu, logvar


class Decoder(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, filter_num_list=[2,4,8]):
    super(Decoder, self).__init__() 
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.latent_dim = latent_dim
    self.filter_num_list = filter_num_list 

    self.DeConve = nn.Sequential(
        TransConvBlock(self.out_channels* self.filter_num_list[2] , self.out_channels * self.filter_num_list[1]),
        TransConvBlock(self.out_channels * self.filter_num_list[1], self.out_channels * self.filter_num_list[0]),
        TransConvBlock(self.out_channels * self.filter_num_list[0], self.out_channels ),
        TransConvBlock(self.out_channels , self.in_channels)
    )
    self.fc = nn.Linear(latent_dim, 1*80*2*2)


  def reparameterezation(self, mu, logvar):
    eps = torch.randn_like(logvar)
    z = mu + eps*torch.exp(.5*logvar) 
    return z


  def forward(self, mu, logvar): 
    z = self.reparameterezation(mu, logvar)
    print(z.shape)
    z = nn.functional.relu(z)
    print(z.shape)
    x = nn.functional.relu(self.fc(z))
    print(x.shape)
    x = x.reshape(-1, 80,2,2)
    x = self.DeConve(x)
    return x, mu, logvar







class VAE(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, filter_num_list=[2,4,8]):
    super(VAE, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.latent_dim = latent_dim
    self.filter_num_list = filter_num_list

    self.Encoder = Encode(self.in_channels,  self.out_channels, self.latent_dim)
    self.Decoder = Decoder(self.in_channels, self.out_channels, self.latent_dim)


  def forward(self, x):
    mu, logvar = self.Encoder(x)
    x, mu, logvar = self.Decoder(mu, logvar)
    return x, mu, logvar

In [31]:
a = torch.rand(1,32,32).unsqueeze(0)
En = Encode(in_channels = self.in_channels , out_channels = self.out_channels, latent_dim = 100)
En(a)[0].shape

NameError: ignored

In [32]:
b = torch.rand(1,100)
De = Decoder(in_channels = 1, out_channels = 10, latent_dim = 100)
x, mu, logvar = De(b,b)
x.shape

torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 320])


torch.Size([1, 1, 32, 32])

In [41]:
vae = VAE(in_channels = 1, out_channels = 10, latent_dim = 100)
vae(a)

1
torch.Size([1, 80, 2, 2])
2
torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 320])


(tensor([[[[-4.8670e-03,  1.4876e-01,  8.9300e-02,  ..., -2.5375e-03,
            -3.0995e-03, -5.4577e-03],
           [-4.9363e-03, -3.4678e-03, -1.3821e-02,  ..., -2.4478e-03,
             6.6602e-02, -5.3227e-03],
           [-5.1893e-03,  7.1839e-01,  6.1820e-01,  ...,  2.1360e-01,
             1.7043e-01, -6.7877e-03],
           ...,
           [-8.1936e-03, -8.1051e-04, -2.3900e-03,  ..., -4.0086e-03,
            -2.1950e-03, -2.5957e-03],
           [-2.2080e-03,  1.4105e+00,  2.3565e-01,  ...,  8.0272e-02,
            -3.5431e-04, -4.1178e-03],
           [-2.8191e-03, -4.2940e-03, -5.9128e-03,  ..., -5.4708e-03,
            -1.4982e-03, -2.8085e-03]]]], grad_fn=<LeakyReluBackward0>),
 tensor([[ 0.1519, -0.4500,  0.0620, -0.0200,  0.0613,  0.2973,  0.0926,  0.2685,
           0.0106,  0.2134,  0.4791, -0.5707, -0.0932,  0.2887,  0.0886,  0.1570,
          -0.0405, -0.0912, -0.0115,  0.3044, -0.0906, -0.5512,  0.1918,  0.2910,
           0.1034, -0.4667,  0.2199,  0.1940,  0.3

In [24]:
import torchvision.datasets as Datasets
import torchvision.transforms as transform
from torch.utils.data import DataLoader 

In [25]:
trans = transform.Compose([
                           transform.Resize((32,32)),
                           transform.ToTensor()
])

In [26]:
train_data = Datasets.MNIST(root="/", train=True, transform= trans, download=True)
test_data = Datasets.MNIST(root="/", train=False, transform= trans, download=True)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=True)