In [3]:
# https://gist.github.com/bigsnarfdude/dde651f6e06f266b48bc3750ac730f80
    
import os

import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

# Output directory for the reconstruction snapshots written during training.
# makedirs(exist_ok=True) avoids the check-then-create race of
# os.path.exists() + os.mkdir() and is a no-op when the directory exists.
os.makedirs('./mlp_img', exist_ok=True)


def to_img(x):
    """Reshape a batch of flattened 28x28 images to ``(N, 1, 28, 28)``.

    N is taken from the first dimension of *x*; ``view`` is used, so *x*
    must hold exactly ``N * 784`` elements in a view-compatible layout.
    """
    batch = x.shape[0]
    return x.view(batch, 1, 28, 28)

# Training hyperparameters: passes over the dataset, minibatch size, Adam step size.
num_epochs, batch_size, learning_rate = 20, 128, 1e-3


def add_noise(img, noise_factor=0.2):
    """Return a copy of *img* with additive zero-mean Gaussian noise.

    Args:
        img: input tensor of any shape, on any device.
        noise_factor: standard deviation of the noise (default 0.2, the
            value that was previously hard-coded).

    Returns:
        ``img + noise_factor * N(0, 1)`` as a new tensor; *img* is not
        modified. The result is not clipped, so values may fall outside
        the input's original range.
    """
    # Allocate the noise on img's device. torch.randn(img.size()) always
    # produced a CPU tensor, which fails when img lives on a GPU. Default
    # float dtype is kept so integer inputs still promote as before.
    noise = torch.randn(img.shape, device=img.device) * noise_factor
    return img + noise


def plot_sample_img(img, name):
    """Write one 28x28 image (flattened or not) to ``./sample_<name>.png``."""
    out_path = './sample_{}.png'.format(name)
    single = img.view(1, 28, 28)
    save_image(single, out_path)


def min_max_normalization(tensor, min_value, max_value):
    """Linearly rescale *tensor* so its values span ``[min_value, max_value]``.

    The smallest element maps to *min_value* and the largest to *max_value*.

    Args:
        tensor: input tensor (not modified).
        min_value: target value for the tensor's minimum.
        max_value: target value for the tensor's maximum.

    Returns:
        A new tensor with the same shape. If every element is equal there is
        no range to stretch, and all elements map to *min_value*. Previously
        this divided 0 by 0 and returned NaN.
    """
    shifted = tensor - tensor.min()
    value_range = shifted.max()
    # Guard the division: a constant tensor has value_range == 0, and 0/0
    # would fill the result with NaN.
    if value_range > 0:
        shifted = shifted / value_range
    return shifted * (max_value - min_value) + min_value


def tensor_round(tensor):
    """Return *tensor* with every element rounded to the nearest integer."""
    rounded = tensor.round()
    return rounded


# Per-image preprocessing: convert to a tensor, stretch the image's own
# min..max range onto [0, 1], then round each pixel to the nearest integer.
# After the stretch, that leaves only 0s and 1s.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: min_max_normalization(t, 0, 1)),
    transforms.Lambda(tensor_round),
])

# download=False: the MNIST files must already be present under ./data.
dataset = MNIST('./data', transform=img_transform, download=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


class autoencoder(nn.Module):
    """Fully connected autoencoder: input -> hidden -> code -> hidden -> input.

    The defaults reproduce the original 784-256-64-256-784 architecture.
    Layer order inside each ``nn.Sequential`` is unchanged, so
    ``encoder[0]`` is still the first Linear layer and state_dict keys
    (``encoder.0.weight``, ``decoder.2.bias``, ...) match earlier checkpoints.

    Args:
        input_dim: size of each flattened input vector (default 28 * 28).
        hidden_dim: width of the hidden layer on both sides (default 256).
        code_dim: size of the bottleneck representation (default 64).
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=256, code_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, code_dim),
            nn.ReLU(inplace=True),
        )
        # The final Sigmoid keeps reconstructions in [0, 1], which nn.BCELoss
        # requires of its input.
        self.decoder = nn.Sequential(
            nn.Linear(code_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode *x*; the last dimension of *x* must be ``input_dim``."""
        return self.decoder(self.encoder(x))


# Use the GPU when present; fall back to CPU so the script also runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = autoencoder().to(device)
criterion = nn.BCELoss()  # training loss on pixel targets in [0, 1]
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)
mse = nn.MSELoss()  # monitoring metric only; never back-propagated

# save_image does not create parent directories, and nothing else in the
# script creates './filters', so the epoch-0 filter dump would otherwise fail.
os.makedirs('./filters', exist_ok=True)

for epoch in range(num_epochs):
    for img, _ in dataloader:
        img = img.view(img.size(0), -1)
        # Noise is added on the CPU, then both tensors are moved to the device.
        # (torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.)
        noisy_img = add_noise(img).to(device)
        img = img.to(device)
        # ===================forward=====================
        output = model(noisy_img)
        loss = criterion(output, img)
        # Compute the MSE metric outside autograd so no graph is kept for it.
        with torch.no_grad():
            MSE_loss = mse(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    # Both values come from the epoch's last minibatch, not an epoch average.
    print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss.item(), MSE_loss.item()))
    if epoch % 10 == 0:
        x = to_img(img.detach().cpu())
        x_hat = to_img(output.detach().cpu())
        x_noisy = to_img(noisy_img.detach().cpu())
        # First-layer weights, one 28x28 tile per hidden unit.
        weights = to_img(model.encoder[0].weight.detach().cpu())
        save_image(x, './mlp_img/x_{}.png'.format(epoch))
        save_image(x_hat, './mlp_img/x_hat_{}.png'.format(epoch))
        save_image(x_noisy, './mlp_img/x_noisy_{}.png'.format(epoch))
        save_image(weights, './filters/epoch_{}.png'.format(epoch))

#torch.save(model.state_dict(), './sim_autoencoder.pth')

epoch [1/20], loss:0.1134, MSE_loss:0.0348
epoch [2/20], loss:0.0915, MSE_loss:0.0279
epoch [3/20], loss:0.0860, MSE_loss:0.0260
epoch [4/20], loss:0.0691, MSE_loss:0.0206
epoch [5/20], loss:0.0648, MSE_loss:0.0193
epoch [6/20], loss:0.0621, MSE_loss:0.0186
epoch [7/20], loss:0.0597, MSE_loss:0.0178
epoch [8/20], loss:0.0610, MSE_loss:0.0183
epoch [9/20], loss:0.0576, MSE_loss:0.0169
epoch [10/20], loss:0.0586, MSE_loss:0.0175
epoch [11/20], loss:0.0516, MSE_loss:0.0152
epoch [12/20], loss:0.0524, MSE_loss:0.0156
epoch [13/20], loss:0.0545, MSE_loss:0.0163
epoch [14/20], loss:0.0541, MSE_loss:0.0162
epoch [15/20], loss:0.0510, MSE_loss:0.0153
epoch [16/20], loss:0.0484, MSE_loss:0.0145
epoch [17/20], loss:0.0483, MSE_loss:0.0142
epoch [18/20], loss:0.0494, MSE_loss:0.0145
epoch [19/20], loss:0.0502, MSE_loss:0.0149
epoch [20/20], loss:0.0483, MSE_loss:0.0141
