# Reverse MNIST with deconvolutional neural networks and PyTorch

A simple deconvolutional network that reverses a convolutional network to re-generate MNIST digits using PyTorch.

In [None]:
import random
import matplotlib.pyplot as plt
import torch
import torch.utils.data
import torchvision

## Setup GPU

Specific versions of PyTorch are needed depending on your system configuration.

* For GPU training on devices supporting the ROCm backend, use `requirements/rocm.txt`
* For GPU training on devices supporting the MPS backend (i.e. Apple Metal), use `requirements/cpu.txt`
* For CPU training, use `requirements/cpu.txt`

In [None]:
# Pick the fastest available accelerator, in priority order: CUDA, then
# Apple's Metal Performance Shaders (MPS), falling back to the CPU.
# (ROCm builds of PyTorch are expected to report through the torch.cuda API
# as well, per PyTorch's HIP notes.) Later checks only run if earlier ones fail.
backend = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

device = torch.device(backend)
# Bare expression so the notebook cell displays the selected device.
device

## CVAE
https://github.com/lyeoni/pytorch-mnist-CVAE/blob/master/pytorch-mnist-CVAE.ipynb
https://www.tensorflow.org/tutorials/generative/cvae

In [None]:
class Encoder(torch.nn.Module):
    """Convolutional encoder of the conditional VAE.

    Two stride-2 convolutions shrink a single-channel image. The flattened
    feature map is concatenated with the label vector and projected through a
    256-unit hidden layer to the mean and log-variance of the latent code.

    Args:
        latent_dim: Width of the latent space (size of ``mu`` and ``logvar``).
    """

    def __init__(self, latent_dim):
        super().__init__()
        # kernel 3 / stride 2 / padding 1 halves the spatial size each time:
        # 28x28 -> 14x14 -> 7x7. This assumes 28x28 MNIST inputs; fc1's width
        # below hard-codes the 7x7 result.
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        # Flattened 64x7x7 features plus a 10-wide label vector
        # (presumably one-hot over the 10 digit classes).
        self.fc1 = torch.nn.Linear(64 * 7 * 7 + 10, 256)
        self.fc_mu = torch.nn.Linear(256, latent_dim)
        self.fc_logvar = torch.nn.Linear(256, latent_dim)

    def forward(self, x, labels):
        """Encode images ``x`` conditioned on ``labels``.

        Args:
            x: Image batch, assumed ``(batch, 1, 28, 28)``.
            labels: Label batch of shape ``(batch, 10)``.

        Returns:
            ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
        """
        relu = torch.nn.functional.relu
        features = relu(self.conv2(relu(self.conv1(x))))
        flat = features.view(features.size(0), -1)
        # Hidden representation (not yet the latent sample).
        hidden = relu(self.fc1(torch.cat((flat, labels), dim=1)))
        return self.fc_mu(hidden), self.fc_logvar(hidden)


class Decoder(torch.nn.Module):
    """Transposed-convolution decoder of the conditional VAE.

    Maps a latent code, concatenated with the label vector, back to a
    1x28x28 image whose pixels are squashed into (0, 1) by a sigmoid.

    Args:
        latent_dim: Width of the latent space; must match the encoder's.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # forward() concatenates the 10-wide label vector onto z, so fc1 must
        # accept latent_dim + 10 features (the encoder's fc1 likewise reserves
        # 10 inputs for the labels).
        self.fc1 = torch.nn.Linear(latent_dim + 10, 256)
        self.fc2 = torch.nn.Linear(256, 64 * 7 * 7)
        # With kernel 3, stride 2, padding 1 a transposed conv maps H -> 2H - 1.
        # output_padding=1 makes it exactly 2H, so 7 -> 14 -> 28, mirroring
        # the encoder's 28 -> 14 -> 7.
        self.deconv1 = torch.nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.deconv2 = torch.nn.ConvTranspose2d(
            32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )

    def forward(self, z, labels):
        """Decode latent codes ``z`` conditioned on ``labels``.

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.
            labels: Label batch of shape ``(batch, 10)``.

        Returns:
            Reconstructed images of shape ``(batch, 1, 28, 28)`` in (0, 1).
        """
        z = torch.cat((z, labels), dim=1)
        h = torch.nn.functional.relu(self.fc1(z))
        h = torch.nn.functional.relu(self.fc2(h))
        # Un-flatten into a 64-channel 7x7 feature map for the deconvolutions.
        h = h.view(h.size(0), 64, 7, 7)
        h = torch.nn.functional.relu(self.deconv1(h))
        x = torch.sigmoid(self.deconv2(h))
        return x


class CVAE(torch.nn.Module):
    """Conditional variational autoencoder composed of ``Encoder`` and ``Decoder``.

    Args:
        latent_dim: Width of the latent space shared by encoder and decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Registration order (encoder first) fixes parameter order and
        # initialization order under a given seed.
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        ``logvar`` holds the log-variance, so ``sigma = exp(logvar / 2)``.
        Sampling this way keeps ``z`` differentiable w.r.t. ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x, labels):
        """Encode ``x``, sample a latent code, and decode it under the same labels.

        Returns:
            ``(x_pred, mu, logvar)``: the reconstruction plus the posterior
            parameters.
        """
        mu, logvar = self.encoder(x, labels)
        x_pred = self.decoder(self.reparameterize(mu, logvar), labels)
        return x_pred, mu, logvar