# A demonstration of synthetic handwritten digit image generation using variational autoencoders

In [1]:
import torch
from models.vae import VAE
from models.cvae import CVAE
import matplotlib.pyplot as plt

In [2]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
# The bare name on the last line echoes the chosen device as the cell output.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _has_cuda else torch.device("cpu")
device

device(type='cuda')

# Using VAE

In [6]:
# Load trained model weights into a fresh VAE on the selected device.
# The directory name appears to encode the training run's settings
# (epochs, learning rate, batch size) — NOTE(review): confirm against the training script.
path = "outputs/vae epoch_3 lr_0.001 bsize_64/weights.pth"
model = VAE().to(device)
# weights_only=True restricts torch.load to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code while being unpickled
# (supported since torch 1.13; it is the default from torch 2.6).
state_dict = torch.load(path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode; the returned module is shown as the cell output

VAE(
  (encoder): Encoder(
    (conv1): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (fc_mu): Linear(in_features=3136, out_features=20, bias=True)
    (fc_logvar): Linear(in_features=3136, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (fc): Linear(in_features=20, out_features=6272, bias=True)
    (deconv1): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (deconv2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (deconv3): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

# Using Conditional VAE

In [5]:
# Load trained model weights into a fresh conditional VAE on the selected device.
# The directory name appears to encode the training run's settings
# (epochs, learning rate, batch size) — NOTE(review): confirm against the training script.
path = "outputs/conditional_vae epoch_3 lr_0.001 bsize_64/weights.pth"
model = CVAE(num_classes=10).to(device)  # 10 digit classes of the MNIST dataset
# weights_only=True restricts torch.load to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code while being unpickled
# (supported since torch 1.13; it is the default from torch 2.6).
state_dict = torch.load(path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode; the returned module is shown as the cell output

CVAE(
  (encoder): Encoder(
    (conv1): Conv2d(11, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (fc_mu): Linear(in_features=3136, out_features=20, bias=True)
    (fc_logvar): Linear(in_features=3136, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (fc): Linear(in_features=30, out_features=6272, bias=True)
    (deconv1): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (deconv2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (deconv3): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)