In [54]:
import torch
import torchvision

In [55]:
print(f'torch: {torch.__version__}')
print(f'torchvision: {torchvision.__version__}')

torch: 1.9.0+cu111
torchvision: 0.10.0+cu111


In [56]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [57]:
np.random.seed(1234)
torch.manual_seed(33)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [58]:
print(f'Device:{device}')

Device:cuda


In [59]:
root = os.path.join('data', 'mnist')

In [60]:
transform = transforms.Compose([
                               transforms.ToTensor(),
                               lambda x: x.view(-1)
])

In [61]:
mnist_train = torchvision.datasets.MNIST(root=root,
                                         download=True,
                                         train=True,
                                         transform=transform)

mnist_test = torchvision.datasets.MNIST(root=root,
                                        download=True,
                                        train=False,
                                        transform=transform
                                        )

In [62]:
mnist_train

Dataset MNIST
    Number of datapoints: 60000
    Root location: data/mnist
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               <function <lambda> at 0x7f56a02d34d0>
           )

In [63]:
mnist_test

Dataset MNIST
    Number of datapoints: 10000
    Root location: data/mnist
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               <function <lambda> at 0x7f56a02d34d0>
           )

In [64]:
train_dataloader = DataLoader(mnist_train,
                              batch_size=100,
                              shuffle=True)
test_dataloader = DataLoader(mnist_test,
                             batch_size=100,
                             shuffle=False)

In [65]:
# for x in train_dataloader:
#     print(x)
#     break

VAE Class

In [66]:
class VAE(nn.Module):
    """
    Variation AutoEncoder
    """
    def __init__(self, device):
        super().__init__()
        self.device = device
        self.encoder = Encoder(device)
        self.decoder = Decoder(device)

    def forward(self, x):

        mean, var = self.encoder(x)
        # Creating latent variables
        z = self.reparameterize(mean, var)

        y = self.decoder(z)
        # Generated images and latent variables
        return y, z

    def reparameterize(self, mean, var):
        """
        Creating latent variables
        """
        # Creating a standard normal distribution
        eps = torch.randn(mean.size()).to(self.device)
        # Reparameterization trick
        z = mean + torch.sqrt(var) * eps

    def lower_bound(self, x):
        # Variance and mean vectors
        mean, var = self.encoder(x)

        # Creating latent variables from means and variances
        z = self.reparameterize(mean, var)
        # Create a generated image from a latent variable
        y = self.decoder(z)
        # Reconstruction error
        reconst = - torch.mean(torch.sum(x * torch.log(y) + (1-x) * torch.log(1-y),dim=1))
        # Regularization
        kl = - 1/2 * torch.mean(torch.sum(1
                                          + torch.log(var)
                                          - mean**2
                                          - var, dim=1))

        L = reconst + kl

        return L

In [67]:
class Encoder(nn.Module):

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.l1 = nn.Linear(784, 200)
        self.l_mean = nn.Linear(200, 10)
        self.l_var = nn.Linear(200, 10)

    def forward(self, x):

        h = self.l1(x)
        h = torch.relu(h)

        mean = self.l_mean(h)
        var = self.l_var(h)

        var = F.softplus(var)

        return mean, var

In [68]:
class Decoder(nn.Module):

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.l1 = nn.Linear(10, 200)
        self.out = nn.Linear(200, 784)

    def forward(self, x):

        h = self.l1(x)
        h = torch.relu(h)
        h = self.out(h)
        y = torch.sigmoid(h)

        return y

In [70]:
model = VAE(device=device).to(device)

In [71]:
criterion = model.lower_bound

In [72]:
optimizer = optimizers.Adam(model.parameters())

In [73]:
print(model)

VAE(
  (encoder): Encoder(
    (l1): Linear(in_features=784, out_features=200, bias=True)
    (l_mean): Linear(in_features=200, out_features=10, bias=True)
    (l_var): Linear(in_features=200, out_features=10, bias=True)
  )
  (decoder): Decoder(
    (l1): Linear(in_features=10, out_features=200, bias=True)
    (out): Linear(in_features=200, out_features=784, bias=True)
  )
)


In [74]:
!pip install torchinfo



In [75]:
from torchinfo import summary

In [76]:
# summary(model, input_size=(100,784))

In [77]:
EPOCHS = 20

In [78]:
for epoch in range(EPOCHS):
    
    train_loss = 0

    for (x, _) in train_dataloader:
        x = x.to(device)

        print(f'x \n {x.shape}')
        print(f'training...')

        model.train()
        loss = criterion(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
    
    train_loss /= len(train+dataloader)

    print(f'epoch:{epoch}, loss{train_loss}')

RuntimeError: ignored