# In [25]:
import torch
import torchvision
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.datasets as datasets
import os

# Output directory for the reconstructed-image grids written during training.
# exist_ok=True makes this a no-op when the directory is already present,
# without the check-then-create race of os.path.exists() + os.mkdir().
os.makedirs('./dc_img', exist_ok=True)

# Convert network output into an image batch that can be saved
def to_img(x):
    """Map a batch of values from [-1, 1] to [0, 1] and shape it as images.

    The input is shifted and scaled by ``0.5 * (x + 1)``, clipped to
    ``[0, 1]``, and viewed as ``(batch, 1, 28, 28)`` where ``batch`` is the
    size of the first dimension of ``x``.
    """
    rescaled = (x + 1) * 0.5
    clipped = rescaled.clamp(min=0, max=1)
    batch = clipped.size(0)
    return clipped.view(batch, 1, 28, 28)

# Hyperparameters
num_epochs = 100
batch_size = 128
learning_rate = 0.001

# Convert images to tensors and rescale pixels from [0, 1] to [-1, 1] so the
# reconstruction targets share the range of the decoder's Tanh output and of
# the inverse mapping applied in to_img().
# The pipeline gets its own name so the imported `transforms` module is not
# shadowed (re-running this cell would otherwise fail on `transforms.Compose`).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train/test datasets; the 60000 training images are then split
# into 50000 for training and 10000 for validation.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)
mnist_trainset, mnist_validationset = torch.utils.data.random_split(mnist_trainset, [50000, 10000])

# Data loaders for the train, test and validation sets
trainloader = DataLoader(mnist_trainset, batch_size=batch_size, num_workers=2, shuffle=True)
testloader = DataLoader(mnist_testset, batch_size=batch_size, num_workers=2, shuffle=True)
validationloader = DataLoader(mnist_validationset, batch_size=batch_size, num_workers=2, shuffle=True)


class autoencoder(nn.Module):
    """Convolutional autoencoder whose decoder mirrors the encoder.

    The encoder alternates convolutions with 2x2 max-pooling; the decoder
    undoes each stage in reverse order with max-unpooling followed by a
    transposed convolution, and ends with ``Tanh`` so outputs lie in
    ``[-1, 1]``.

    ``nn.MaxUnpool2d`` needs the indices produced by the matching pooling
    step, which ``nn.Sequential`` cannot pass along, so ``forward`` walks the
    layers itself and carries the indices (and pre-pool sizes) across.
    Each transposed convolution uses the same kernel size and padding as the
    convolution it inverts, so the output has the same shape as the input
    (for a 1x28x28 input the spatial sizes are 28 -> 26 -> 13 -> 11 -> 5 ->
    6 -> 3 in the encoder and back again in the decoder).
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 5, stride=1, padding=1), nn.MaxPool2d(2, stride=2),
            nn.Conv2d(16, 8, 5, stride=1, padding=1), nn.MaxPool2d(2, stride=2),
            nn.Conv2d(8, 4, 2, stride=1, padding=1), nn.MaxPool2d(2, stride=2),
        )
        # padding=1 on every transposed convolution matches the padding of
        # the encoder convolution it reverses; without it the output grows
        # past the input size and cannot be compared with the input.
        self.decoder = nn.Sequential(
            nn.MaxUnpool2d(2, stride=2), nn.ConvTranspose2d(4, 8, 2, stride=1, padding=1),
            nn.MaxUnpool2d(2, stride=2), nn.ConvTranspose2d(8, 16, 5, stride=1, padding=1),
            nn.MaxUnpool2d(2, stride=2), nn.ConvTranspose2d(16, 1, 5, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return its reconstruction, shaped like ``x``."""
        # Stack of (indices, pre-pool size) pairs: pushed by each pooling
        # step and popped by the matching unpooling step in reverse order.
        pooled = []

        for layer in self.encoder:
            if isinstance(layer, nn.MaxPool2d):
                size_before = x.size()
                # Call the functional form with the layer's own settings so
                # the indices are returned while self.encoder stays usable
                # on its own as a plain Sequential.
                x, indices = F.max_pool2d(
                    x,
                    layer.kernel_size,
                    layer.stride,
                    layer.padding,
                    layer.dilation,
                    ceil_mode=layer.ceil_mode,
                    return_indices=True,
                )
                pooled.append((indices, size_before))
            else:
                x = layer(x)

        for layer in self.decoder:
            if isinstance(layer, nn.MaxUnpool2d):
                indices, size_before = pooled.pop()
                # output_size restores odd sizes that pooling floored away
                # (e.g. 11 -> 5 must unpool back to 11, not 10).
                x = layer(x, indices, output_size=size_before)
            else:
                x = layer(x)

        return x

# Build the model
model = autoencoder()

# Adam with a small weight decay; reconstruction error is measured with MSE
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.00001)
criterion = nn.MSELoss()

# Training loop
for epoch in range(num_epochs):
    for img, _ in trainloader:
        # Forward pass: the input image is its own reconstruction target.
        # Tensors are passed directly; the Variable wrapper is deprecated.
        output = model(img)
        loss = criterion(output, img)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the loss of the last mini-batch of this epoch
    print('epoch {} = loss:{:.4f}'
          .format(epoch+1, loss.item()))
    # Save the last mini-batch's reconstructions every 10 epochs
    if epoch % 10 == 0:
        # detach() drops the autograd graph before converting to an image
        pic = to_img(output.detach().cpu())
        save_image(pic, './dc_img/image_{}.png'.format(epoch))

torch.save(model.state_dict(), './conv_autoencoder.pth')

# Notebook output: KeyboardInterrupt (the run was interrupted)