# Reverse MNIST with deconvolutional neural networks and PyTorch

A simple deconvolutional network that reverses a convolutional MNIST classifier, re-generating MNIST digits using PyTorch.

In [None]:
import random
import matplotlib.pyplot as plt
import torch
import torch.utils.data
import torchvision

## Setup GPU

Specific versions of PyTorch are needed depending on your system configuration.

* For GPU training on devices that support the ROCm backend, use `requirements/rocm.txt`
* For GPU training on devices that support the MPS backend (i.e. Apple Metal), use `requirements/cpu.txt`
* For CPU training, use `requirements/cpu.txt`

In [None]:
# Choose the compute backend in order of preference: CUDA, then Apple MPS,
# falling back to the CPU when neither accelerator is available.
backend = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

device = torch.device(backend)
device

## CVAE
https://github.com/lyeoni/pytorch-mnist-CVAE/blob/master/pytorch-mnist-CVAE.ipynb
https://www.tensorflow.org/tutorials/generative/cvae

In [None]:
class Encoder(torch.nn.Module):
    """Label-conditioned convolutional encoder for 28x28 single-channel images.

    Two stride-2 convolutions downsample the image (28x28 -> 14x14 -> 7x7).
    The flattened feature map is concatenated with a 10-wide conditioning
    vector and passed through a hidden layer. Two linear heads then produce
    the mean and log-variance of the latent distribution.

    Args:
        latent_dim: Size of the latent space (width of ``mu`` and ``logvar``).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        # 64 channels x 7x7 spatial positions from conv2, plus 10 label features.
        self.fc1 = torch.nn.Linear(64 * 7 * 7 + 10, 256)
        self.fc_mu = torch.nn.Linear(256, latent_dim)
        self.fc_logvar = torch.nn.Linear(256, latent_dim)

    def forward(self, x, labels):
        """Encode an image batch and its conditioning labels.

        Args:
            x: Image batch. The ``fc1`` input size requires shape
                (batch, 1, 28, 28).
            labels: Conditioning batch of shape (batch, 10), e.g. one-hot
                digit classes as floats.

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, latent_dim).
        """
        x = torch.nn.functional.relu(self.conv1(x))
        x = torch.nn.functional.relu(self.conv2(x))
        x = x.flatten(start_dim=1)
        x = torch.cat((x, labels), dim=1)
        x = torch.nn.functional.relu(self.fc1(x))
        return self.fc_mu(x), self.fc_logvar(x)

    # Alias for the original misspelled method name, kept so existing callers
    # keep working. New code should call the module directly (``encoder(x, labels)``).
    forwrad = forward


class AutoencoderCVAE(torch.nn.Module):
    """Convolutional variational autoencoder for 28x28 single-channel images.

    The encoder downsamples 1x28x28 -> 32x14x14 -> 64x7x7, and two linear
    heads produce the mean and log-variance of the latent distribution. A
    latent sample is drawn with the reparameterization trick. The decoder
    maps it back to 64x7x7 and upsamples with transpose convolutions to a
    1x28x28 image in [0, 1].

    Args:
        latent_dim: Size of the latent space. Defaults to 2.
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, stride=2, padding=1),
            torch.nn.ReLU(),
        )

        flat_features = 64 * 7 * 7  # encoder output (64 channels x 7x7), flattened
        self.fc_mu = torch.nn.Linear(flat_features, latent_dim)
        self.fc_logvar = torch.nn.Linear(flat_features, latent_dim)
        self.fc_decode = torch.nn.Linear(latent_dim, flat_features)

        # Mirror the encoder: 64x7x7 -> 32x14x14 -> 1x28x28. With kernel 3,
        # stride 2 and padding 1, output_padding=1 makes each transpose
        # convolution exactly double the spatial size.
        self.deconv1 = torch.nn.ConvTranspose2d(
            64, 32, 3, stride=2, padding=1, output_padding=1
        )
        self.deconv2 = torch.nn.ConvTranspose2d(
            32, 1, 3, stride=2, padding=1, output_padding=1
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` for an image batch of shape (batch, 1, 28, 28)."""
        h = self.encoder(x).flatten(start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Sampling this way keeps ``z`` differentiable with respect to ``mu``
        and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        """Map latent vectors of shape (batch, latent_dim) to images (batch, 1, 28, 28)."""
        x = torch.nn.functional.relu(self.fc_decode(z))
        x = x.view(-1, 64, 7, 7)  # Reshape to an n-length 7x7 grid of 64 channels
        x = torch.nn.functional.relu(self.deconv1(x))
        # Sigmoid keeps pixel intensities in [0, 1].
        return torch.sigmoid(self.deconv2(x))

    def forward(self, x):
        """Encode, sample and decode an image batch.

        Args:
            x: Image batch of shape (batch, 1, 28, 28).

        Returns:
            Tuple ``(reconstruction, mu, logvar)``. The reconstruction has the
            same shape as ``x``; ``mu`` and ``logvar`` have shape
            (batch, latent_dim).
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar