In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import glob
import PIL
from PIL import Image
from scipy.misc import toimage
from torch.utils import data as D
import random

In [2]:
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

In [4]:
img_size = (28, 28, 1) 
hidden_size = 256

num_epochs = 100
batch_size = 64

dataset_dir = "./data/FMNIST"
sample_dir = "./result_denoising_ae_FMNIST/" 

if not os.path.exists(dataset_dir):
    os.makedirs(dataset_dir)

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [5]:
transform = transforms.Compose([
            transforms.Resize(img_size[0]),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),std=(0.5, 0.5, 0.5))])
    
trainset = torchvision.datasets.FashionMNIST(root=dataset_dir, train=True ,transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

In [15]:
class Denoising_autoencoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Denoising_autoencoder, self).__init__()
        
        self.network = nn.Sequential(
            # (-1, 1, 28, 28) --> (-1, 256, 14, 14)
            nn.Conv2d(input_size[2], hidden_size, kernel_size=3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            
            # (-1, 256, 14, 14) --> (-1, 1024, 7, 7)
            nn.Conv2d(hidden_size, hidden_size*4, kernel_size=3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(hidden_size*4),
            nn.ReLU(True),
            
            # (-1, 1024, 7, 7) --> (-1, 2048, 7, 7)
            nn.Conv2d(hidden_size*4, hidden_size*8, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(hidden_size*8),
            nn.ReLU(True),
            
            # (-1, 2048, 7, 7) --> (-1, 1024, 7, 7)
            nn.Conv2d(hidden_size*8, hidden_size*4, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(hidden_size*4),
            nn.ReLU(True),
            
            # (-1, 1024, 7, 7) --> (-1, 256, 14, 14)
            nn.ConvTranspose2d(hidden_size*4, hidden_size, kernel_size=3, padding=1, stride=2, output_padding=1, bias=False),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            
            # (-1, 256, 14, 14) --> (-1, 1, 28, 28)
            nn.ConvTranspose2d(hidden_size, input_size[2], kernel_size=3, padding=1, stride=2, output_padding=1, bias=False),
            nn.BatchNorm2d(input_size[2]),
            nn.Tanh()
        )
        
    def forward(self, x):
        output = self.network(x)
        return output

In [16]:
AE_model = Denoising_autoencoder(img_size, hidden_size) 

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

AE_model.to(device)

cuda:0


Denoising_autoencoder(
  (network): Sequential(
    (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(256, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): Conv2d(2048, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (10): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace)
    (12): ConvTranspose2d(1024, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
    (13): Bat

In [32]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(AE_model.parameters(), lr=0.01)
#optimizer = torch.optim.SGD(AE_model.parameters(), lr=0.1, momentum=0.9)
init_lr = 0.1

mean = 0
std = 0.2

In [33]:
# Start training
total_step = len(trainloader)
for epoch in range(num_epochs):
    if epoch % 20 == 0:
        optimizer = torch.optim.Adam(AE_model.parameters(), lr=init_lr)
        #optimizer = torch.optim.SGD(AE_model.parameters(), lr=init_lr, momentum=0.9)

        init_lr /= 10
    
    for i, (images, _) in enumerate(trainloader):
        # add noise
       
        noise = torch.tensor(images.data.new(images.size()).normal_(mean, std))
        images_noise = images + noise
        images = images.to(device)
        images_noise = images_noise.to(device)

        outputs = AE_model(images_noise)
        
        mse_loss = criterion(outputs, images)
        optimizer.zero_grad()
        
        mse_loss.backward()
        optimizer.step()
        
        if (i+1) % 800 == 0:
            print('Epoch [{}/{}], Step [{}/{}], mse loss: {:.4f}, lr: {:.5f}' 
                  .format(epoch, num_epochs, i+1, total_step, mse_loss.item(), init_lr))
    
    # Save real images and output_images
    images_noise = images_noise.reshape(images_noise.size()[0], 1, 28, 28)
    output_images = outputs.reshape(outputs.size()[0], 1, 28, 28)
        
    results = torch.cat([images_noise, output_images], dim=2)
    
    save_image(denorm(results), os.path.join(sample_dir, 'ae_model_result-{}.png'.format(epoch+1)))

Epoch [0/100], Step [800/938], mse loss: 0.0082, lr: 0.01000
Epoch [1/100], Step [800/938], mse loss: 0.0067, lr: 0.01000
Epoch [2/100], Step [800/938], mse loss: 0.0076, lr: 0.01000
Epoch [3/100], Step [800/938], mse loss: 0.0067, lr: 0.01000
Epoch [4/100], Step [800/938], mse loss: 0.0075, lr: 0.01000
Epoch [5/100], Step [800/938], mse loss: 0.0071, lr: 0.01000
Epoch [6/100], Step [800/938], mse loss: 0.0067, lr: 0.01000
Epoch [7/100], Step [800/938], mse loss: 0.0067, lr: 0.01000
Epoch [8/100], Step [800/938], mse loss: 0.0069, lr: 0.01000
Epoch [9/100], Step [800/938], mse loss: 0.0072, lr: 0.01000
Epoch [10/100], Step [800/938], mse loss: 0.0067, lr: 0.01000
Epoch [11/100], Step [800/938], mse loss: 0.0060, lr: 0.01000
Epoch [12/100], Step [800/938], mse loss: 0.0068, lr: 0.01000
Epoch [13/100], Step [800/938], mse loss: 0.0063, lr: 0.01000
Epoch [14/100], Step [800/938], mse loss: 0.0090, lr: 0.01000
Epoch [15/100], Step [800/938], mse loss: 0.0060, lr: 0.01000
Epoch [16/100], St

KeyboardInterrupt: 