<a href="https://colab.research.google.com/github/mehdihosseinimoghadam/Pytorch-Tutorial/blob/main/CNN_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import torchaudio.transforms as transforms

In [12]:
class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(ConvBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.Conv = nn.Sequential(
        nn.Conv2d(self.in_channels,
                  self.out_channels,
                  4,
                  2,
                  1),
        nn.BatchNorm2d(self.out_channels),
        nn.LeakyReLU()          
    )
  def forward(self, x):
    return self.Conv(x)  



class TransConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(TransConvBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.Conv = nn.Sequential(
        nn.ConvTranspose2d(self.in_channels,
                  self.out_channels,
                  4,
                  2,
                  1),
        nn.BatchNorm2d(self.out_channels),
        nn.LeakyReLU()          
    )
  def forward(self, x):
    return self.Conv(x)  






class Encode(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, filter_num_list=[2, 4, 8]):
    super(Encode, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.latent_dim = latent_dim
    self.filter_num_list = filter_num_list

    self.Conv = nn.Sequential(
        ConvBlock(self.in_channels, self.out_channels),
        ConvBlock(self.out_channels, self.out_channels * self.filter_num_list[0]),
        ConvBlock(self.out_channels * self.filter_num_list[0], self.out_channels * self.filter_num_list[1]),
        ConvBlock(self.out_channels * self.filter_num_list[1], self.out_channels * self.filter_num_list[2])
    )

    self.fc = nn.Linear(80 * 2 * 2, self.latent_dim)
    self.mu = nn.Linear(self.latent_dim, self.latent_dim)
    self.logvar = nn.Linear(self.latent_dim, self.latent_dim)

  def forward(self, x):
    x = self.Conv(x)
    print(x.shape)
    x = x.reshape(-1, 80 * 2 * 2)
    x = self.fc(x)
    mu = self.mu(x)
    logvar = self.logvar(x)
    return mu, logvar


class Decoder(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, filter_num_list=[2,4,8]):
    super(Decoder, self).__init__() 
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.latent_dim = latent_dim
    self.filter_num_list = filter_num_list 

    self.DeConve = nn.Sequential(
        TransConvBlock(self.out_channels* self.filter_num_list[2] , self.out_channels * self.filter_num_list[1]),
        TransConvBlock(self.out_channels * self.filter_num_list[1], self.out_channels * self.filter_num_list[0]),
        TransConvBlock(self.out_channels * self.filter_num_list[0], self.out_channels ),
        TransConvBlock(self.out_channels , self.in_channels)
    )
    self.fc = nn.Linear(latent_dim, 1*80*2*2)


  def reparameterezation(self, mu, logvar):
    eps = torch.randn_like(logvar)
    z = mu + eps*torch.exp(.5*logvar) 
    return z


  def forward(self, mu, logvar): 
    z = self.reparameterezation(mu, logvar)
    print(z.shape)
    z = nn.functional.relu(z)
    print(z.shape)
    x = nn.functional.relu(self.fc(z))
    print(x.shape)
    x = x.reshape(-1, 80,2,2)
    x = self.DeConve(x)
    return x, mu, logvar







class VAE(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, filter_num_list=[2,4,8]):
    super(VAE, self).__init()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.latent_dim = latent_dim
    self.filter_num_list = filter_num_list

    self.Encodee = Encode(in_channels = 1, out_channels = 10, latent_dim = self.filter_num_list)
    self.Decoder = Decoder(in_channels = 1, out_channels = 10, latent_dim = self.filter_num_list)




In [13]:
a = torch.rand(1,32,32).unsqueeze(0)
En = Encode(in_channels = self.in_channels , out_channels = self.out_channels, latent_dim = 100)
En(a)[0].shape

torch.Size([1, 80, 2, 2])


torch.Size([1, 100])

In [19]:
b = torch.rand(1,100)
De = Decoder(in_channels = 1, out_channels = 10, latent_dim = 100)
x, mu, logvar = De(b,b)
x.shape

torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 320])


torch.Size([1, 1, 32, 32])

In [24]:
import torchvision.datasets as Datasets
import torchvision.transforms as transform
from torch.utils.data import DataLoader 

In [25]:
trans = transform.Compose([
                           transform.Resize((32,32)),
                           transform.ToTensor()
])

In [26]:
train_data = Datasets.MNIST(root="/", train=True, transform= trans, download=True)
test_data = Datasets.MNIST(root="/", train=False, transform= trans, download=True)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=True)