In [80]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import glob
import PIL
from PIL import Image
from scipy.misc import toimage
from torch.utils import data as D
import random

In [81]:
def denorm(x):
    """Undo the [-1, 1] normalisation so a tensor can be written as an image.

    Args:
        x: tensor whose values are nominally in [-1, 1].

    Returns:
        A new tensor equal to ``(x + 1) / 2`` with every value limited to
        the closed interval [0, 1].
    """
    rescaled = (x + 1) / 2
    return rescaled.clamp(min=0, max=1)

In [82]:
# --- Experiment configuration -------------------------------------------
# is_2d selects the convolutional autoencoder (True) or the fully connected
# one (False); the input description and the output folder follow from it.
is_2d = True

# (height, width, channels) for the conv model, flat feature count otherwise.
img_size = (28, 28, 1) if is_2d else 28*28
hidden_size = 256

num_epochs = 100
batch_size = 64

dataset_dir = "./data/FMNIST"
sample_dir = "./result_vanilla_ae_2d_FMNIST/" if is_2d else "./result_vanilla_ae_1d_FMNIST/"

# exist_ok=True makes the cell safe to re-run and avoids the race between
# checking for the directory and creating it.
os.makedirs(dataset_dir, exist_ok=True)
os.makedirs(sample_dir, exist_ok=True)

In [83]:
# Pre-processing pipeline: convert to a tensor and normalise with
# mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
# The images are single-channel, so mean/std must be one-element tuples;
# three-element tuples make Normalize fail with a broadcast-shape error on
# current torchvision releases.
transform_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
if is_2d:
    # The convolutional model additionally resizes to img_size[0] first.
    transform_steps.insert(0, transforms.Resize(img_size[0]))
transform = transforms.Compose(transform_steps)

# Training split of FashionMNIST, downloaded into dataset_dir when missing.
trainset = torchvision.datasets.FashionMNIST(root=dataset_dir, train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

In [84]:
class Vanilla_autoencoder_1d(nn.Module):
    """Fully connected autoencoder operating on flattened inputs.

    Layer widths are ``input_size -> hidden_size -> hidden_size // 4 ->
    hidden_size -> input_size``. Every hidden layer is followed by an
    in-place ReLU and the output layer by Tanh, so reconstructions lie in
    [-1, 1].

    Args:
        input_size: number of features per flattened sample.
        hidden_size: width of the outer hidden layers; the bottleneck is
            ``hidden_size // 4`` wide.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        bottleneck = hidden_size // 4

        # Encoder: input_size -> hidden_size -> bottleneck
        encoder = [
            nn.Linear(input_size, hidden_size, bias=True),
            nn.ReLU(True),
            nn.Linear(hidden_size, bottleneck, bias=True),
            nn.ReLU(True),
        ]
        # Decoder: bottleneck -> hidden_size -> input_size
        decoder = [
            nn.Linear(bottleneck, hidden_size, bias=True),
            nn.ReLU(True),
            nn.Linear(hidden_size, input_size, bias=True),
            nn.Tanh(),
        ]

        self.network = nn.Sequential(*(encoder + decoder))

    def forward(self, x):
        """Return the reconstruction of ``x`` produced by the full network."""
        return self.network(x)

In [85]:
class Vanilla_autoencoder_2d(nn.Module):
    """Convolutional autoencoder built from stride-2 (transposed) convolutions.

    Two 3x3 stride-2 convolutions halve the spatial size twice while the
    channel count goes ``input_size[2] -> hidden_size -> hidden_size * 4``;
    two 3x3 stride-2 transposed convolutions mirror that path back. Every
    convolution is bias-free and followed by BatchNorm2d; the hidden stages
    use in-place ReLU and the final stage uses Tanh.

    Args:
        input_size: indexable whose element ``[2]`` is the number of image
            channels (the notebook passes ``(28, 28, 1)``).
        hidden_size: channel count after the first convolution.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        img_channels = input_size[2]
        wide = hidden_size * 4

        # Settings shared by all four (transposed) convolutions.
        down_args = dict(kernel_size=3, padding=1, stride=2, bias=False)
        # output_padding=1 lets the transposed conv exactly double H and W.
        up_args = dict(down_args, output_padding=1)

        # For 28x28 inputs: 28 -> 14 -> 7
        encoder = [
            nn.Conv2d(img_channels, hidden_size, **down_args),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            nn.Conv2d(hidden_size, wide, **down_args),
            nn.BatchNorm2d(wide),
            nn.ReLU(True),
        ]
        # For 28x28 inputs: 7 -> 14 -> 28
        decoder = [
            nn.ConvTranspose2d(wide, hidden_size, **up_args),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_size, img_channels, **up_args),
            nn.BatchNorm2d(img_channels),
            nn.Tanh(),
        ]

        self.network = nn.Sequential(*(encoder + decoder))

    def forward(self, x):
        """Return the reconstruction of the image batch ``x``."""
        return self.network(x)

In [86]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Build the autoencoder variant selected by is_2d.
if is_2d:
    AE_model = Vanilla_autoencoder_2d(img_size, hidden_size)
else:
    AE_model = Vanilla_autoencoder_1d(img_size, hidden_size)

# Move the parameters to the device; as the last expression of the cell this
# also displays the module structure.
AE_model.to(device)

cuda:0


Vanilla_autoencoder_2d(
  (network): Sequential(
    (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(256, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d(1024, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d(256, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
    (10): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Tanh()
  )
)

In [87]:
# Adam over every parameter of the autoencoder, learning rate 2e-4.
optimizer = optim.Adam(AE_model.parameters(), lr=2e-4)
# Mean-squared-error criterion.
criterion = nn.MSELoss()

In [None]:
# Train the autoencoder to reproduce its input and, after every epoch, save a
# picture comparing the last mini-batch with its reconstruction.
total_step = len(trainloader)
log_interval = 800  # print the loss every `log_interval` mini-batches

AE_model.train()
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(trainloader):
        if not is_2d:
            # The fully connected model takes one flat vector per sample.
            images = images.view(images.size(0), -1)
        images = images.to(device)

        outputs = AE_model(images)

        # The target is the input itself; labels are ignored.
        mse_loss = criterion(outputs, images)
        optimizer.zero_grad()

        mse_loss.backward()
        optimizer.step()

        if (i + 1) % log_interval == 0:
            # Report epochs 1-based so the log matches the saved file names.
            print('Epoch [{}/{}], Step [{}/{}], mse loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, mse_loss.item()))

    # Save real images and output_images from the last mini-batch of the epoch.
    images = images.reshape(images.size(0), 1, 28, 28)
    # detach(): the saved picture does not need the autograd graph.
    output_images = outputs.detach().reshape(outputs.size(0), 1, 28, 28)

    # Concatenate along the height axis: input on top, reconstruction below.
    results = torch.cat([images, output_images], dim=2)

    save_image(denorm(results), os.path.join(sample_dir, 'ae_model_result-{}.png'.format(epoch+1)))

Epoch [0/100], Step [800/938], mse loss: 0.1742
Epoch [1/100], Step [800/938], mse loss: 0.0813
Epoch [2/100], Step [800/938], mse loss: 0.0359
