<a href="https://colab.research.google.com/github/brainmentorspvtltd/IGDTU_PyTorchTraining/blob/main/IG_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

In [9]:
class VAE(nn.Module):
  def __init__(self, input_dim, z_dim, hidden_dim):
    super().__init__()
    # Encoder
    self.h1 = nn.Linear(input_dim, hidden_dim) # input
    self.h2 = nn.Linear(hidden_dim, z_dim) # mean
    self.h3 = nn.Linear(hidden_dim, z_dim) # standard deviation

    # Decoder
    self.z_to_hidden = nn.Linear(z_dim, hidden_dim)
    self.hidden_to_img = nn.Linear(hidden_dim, input_dim)

  # x - input image
  def encode(self, x):
    x = F.relu(self.h1(x))
    mean_layer = self.h2(x)
    std_layer = self.h3(x)
    return mean_layer, std_layer

  # z - output of encoder
  def decode(self, z):
    x = F.relu(self.z_to_hidden(z))
    x = torch.sigmoid(self.hidden_to_img(x))
    return x

  def forward(self, x):
    mean, std = self.encode(x)
    epsilon = torch.randn_like(std)
    z_param = mean + epsilon * std
    x = self.decode(z_param)
    return x, mean, std

In [4]:
dataset = datasets.MNIST(root="data/", train=True,
                         transform=transforms.ToTensor(),
                         download=True)

train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 109443175.40it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 98564437.61it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 26888530.58it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 15338589.99it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 10

input_dim = 784 # 28 x 28
z_dim = 32
hidden_dim = 256
model = VAE(input_dim, z_dim, hidden_dim).to(device)

optimization = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.BCELoss(reduction="sum")

for epoch in range(epochs):
  var = tqdm(enumerate(train_loader))
  for batch, (x,y) in var:
    x = x.to(device).view(-1, input_dim)
    output_img, mean, std = model(x)
    loss = loss_fn(output_img, x)
    kl_divergence = -torch.sum(1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2))

    total_loss = loss + kl_divergence
    optimization.zero_grad()
    loss.backward()
    optimization.step()
    var.set_postfix(total_loss = total_loss.item())

1875it [00:19, 97.84it/s, total_loss=1.85e+4]
1875it [00:19, 95.66it/s, total_loss=1.92e+4] 
1875it [00:20, 93.25it/s, total_loss=1.98e+4]
1875it [00:19, 95.19it/s, total_loss=1.86e+4] 
1875it [00:19, 95.03it/s, total_loss=2.08e+4] 
1875it [00:19, 98.49it/s, total_loss=1.97e+4] 
1875it [00:19, 95.83it/s, total_loss=1.98e+4]
1875it [00:19, 97.79it/s, total_loss=2.13e+4] 
1875it [00:19, 95.72it/s, total_loss=1.99e+4] 
1875it [00:19, 95.65it/s, total_loss=2.1e+4]
