In [1]:
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
import torch.utils.data
from torchvision import datasets, transforms
# from torchvision.utils import save_image

from vae_conv_mnist import conv_variational_autoencoder 
import matplotlib.pyplot as plt 
import sys, os 

# CUDA environment for this kernel: order devices by PCI bus ID and make only
# device "0" visible.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0",
)

In [2]:
# Mini-batch size shared by the training and test loaders.
batch_size = 64

# Both MNIST splits live under ../data and are converted with ToTensor().
# Only the training split requests a download when the files are missing.
mnist_train = datasets.MNIST(
    '../data', train=True, download=True, transform=transforms.ToTensor()
)
mnist_test = datasets.MNIST(
    '../data', train=False, transform=transforms.ToTensor()
)

# Both loaders shuffle their samples.
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=batch_size, shuffle=True
)

In [3]:
# Hyperparameters handed (positionally) to conv_variational_autoencoder in the
# next cell. The per-layer lists have one entry per layer, i.e. their lengths
# equal conv_layers / dense_layers respectively.
channels = 1

conv_layers = 4
feature_maps = [64] * conv_layers
filter_shapes = [3] * conv_layers
strides = [1, 2, 1, 1]

dense_layers = 2
dense_neurons = [128, 64]
dense_dropouts = [0.0] * dense_layers

latent_dim = 20

# Image side length, read from axis 1 of the raw training tensor.
image_size = train_loader.dataset.train_data.shape[1]

In [None]:
# Build the model from the hyperparameters defined above. All arguments are
# passed positionally, so their order must match the constructor signature in
# vae_conv_mnist.
autoencoder = conv_variational_autoencoder(
    image_size,
    channels,
    conv_layers,
    feature_maps,
    filter_shapes,
    strides,
    dense_layers,
    dense_neurons,
    dense_dropouts,
    latent_dim,
)

In [None]:
# Training schedule. `log_interval` is assigned here but is not read by any
# code in this cell.
epochs = 50
log_interval = 200
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Epochs are numbered from 1. Each epoch runs one pass over the training
# loader followed by one pass over the test loader.
epoch_numbers = range(1, epochs + 1)
for epoch in epoch_numbers:
    for run_phase, loader in (
        (autoencoder.train, train_loader),
        (autoencoder.test, test_loader),
    ):
        run_phase(loader, epoch)

# Embeddings and reconstructions are computed and plotted once, after
# training, in the cells further down rather than per epoch.

====> Epoch: 1 Average loss: 150.9879 ====> Test set loss: 119.1473
====> Epoch: 2 Average loss: 110.0895 ====> Test set loss: 105.3541
====> Epoch: 3 Average loss: 103.0524 ====> Test set loss: 101.1282
====> Epoch: 4 Average loss: 100.3063 ====> Test set loss: 99.0949


In [None]:
# Decode a regular n x n grid of latent points and tile the decoded digits
# into one large mosaic image.
n = 30            # grid points per axis
digit_size = 28   # side length (pixels) of one decoded tile
figure = np.zeros((digit_size * n, digit_size * n))

# Linearly spaced coordinates for the two latent dimensions being swept.
# grid_y is reversed so that the largest z[1] value lands in the top row.
# NOTE(review): the +/-200 span looks very wide for a VAE latent space;
# confirm it matches the scale of the embeddings the model actually produces.
grid_x = np.linspace(-200, 200, n)
grid_y = np.linspace(-200, 200, n)[::-1]

# The sweep only varies the first two latent coordinates.
if latent_dim < 2:
    raise ValueError("latent grid sweep needs latent_dim >= 2, got %d" % latent_dim)

for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        # The model was built with `latent_dim` latent entries, so the decoder
        # input must have that many features (a bare [[xi, yi]] only fits
        # latent_dim == 2). Sweep z[0] and z[1]; keep all other entries at 0.
        z_sample = torch.zeros(1, latent_dim, device=device)
        z_sample[0, 0] = float(xi)
        z_sample[0, 1] = float(yi)
        with torch.no_grad():
            x_decoded = autoencoder.decode(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        # Row index follows grid_y, column index follows grid_x.
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit.cpu().numpy()

# Show the decoded-digit mosaic with axis ticks labelled by latent coordinate.
plt.figure(figsize=(10, 10))

# One tick per tile, placed at the tile centre. The stop value must yield
# exactly n positions so that ticks and labels have equal length (the former
# "+ 1" produced n + 1 positions for n labels).
start_range = digit_size // 2
end_range = n * digit_size + start_range
pixel_range = np.arange(start_range, end_range, digit_size)

# Tick labels are the latent coordinates rounded to one decimal.
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap='Greys_r')
plt.show()

In [None]:
# Run the trained model on the first 2000 training images and keep the
# reconstructions, embeddings and encoder outputs for the plots below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
    # `train_data` holds the raw pixel values (0..255), whereas the loaders
    # feed the model ToTensor() output scaled to [0, 1]. Apply the same
    # scaling here so inference sees inputs on the scale used for training.
    # Slicing before the float conversion avoids casting the whole set.
    train_all = (
        train_loader.dataset.train_data[:2000]
        .reshape(-1, 1, 28, 28)
        .float()
        / 255.0
    ).to(device)
    print(train_all.shape)
    decoded = autoencoder.predict(train_all)
    embeded = autoencoder.return_embeddings(train_all)
    encoded = autoencoder.encode(train_all)

In [None]:
# Side-by-side comparison: inputs on the top row, reconstructions below.
n_examples = 4

# Move only the displayed samples to the host, once, instead of copying the
# full tensors on every loop iteration.
originals = np.squeeze(train_all[:n_examples].cpu())
reconstructions = np.squeeze(decoded[:n_examples].cpu())

fig, ax = plt.subplots(nrows=2, ncols=n_examples)

for i in range(n_examples):
    ax[0,i].imshow(originals[i])
    ax[0,i].set_title('original')
    ax[1,i].imshow(reconstructions[i])
    ax[1,i].set_title('generated')
    ax[0,i].axis('off')
    ax[1,i].axis('off')
plt.show()

In [None]:
# Scatter of the first two embedding coordinates for the first 2000 training
# images, coloured by digit label.
points = embeded[:2000].detach().cpu().numpy()
labels = train_loader.dataset.train_labels[:2000].cpu().numpy()

plt.figure(figsize=(10,8))
plt.scatter(points[:, 0], points[:, 1], c=labels, cmap='tab10')
plt.colorbar()
# Axis limits are left to matplotlib's autoscaling so every point stays
# visible whatever the scale of the embeddings; fixed +/-500 limits would
# hide or squash the data when that scale differs.
plt.xlabel("embedding[0]")
plt.ylabel("embedding[1]")
plt.show()