In [1]:
import torch

USE_CUDA = torch.cuda.is_available()

BATCH_SIZE = 128


LEARNING_RATE = 1e-2

EPOCHS = 100

In [8]:
# Some utils
import matplotlib.pyplot as plt
from torchvision.utils import save_image

def plot_image(image):
    plt.imshow(image, cmap='gray')
    plt.show()
    
def save_sample(img, rec, epoch):
    img = img.view(-1, 28)
    rec = rec.view(-1, 28)
    
    pair = torch.cat((img, rec), 1)
    
    save_image(pair, './samples/conv/sample-epoch-{}.png'.format(epoch))

def log_training(loss, epoch, batch, n_epochs, n_batches, overwrite=False):
    print("Epoch {}/{} | Batch {}/{} | Loss: {}".format(epoch,n_epochs, batch, n_batches, round(loss, 6)), end='\r' if overwrite else '\n')

In [9]:
# Read data

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

dataset = MNIST('./mnist-data', transform=Compose([ToTensor()]))

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [5]:
# Build model

from torch import nn

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(16, 8, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=1)
        )
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        
        return x
    


In [6]:
model = Autoencoder()

if USE_CUDA:
    model.cuda()

In [16]:
# Training


class timeit():
    from datetime import datetime
    def __enter__(self):
        self.tic = self.datetime.now()
    def __exit__(self, *args, **kwargs):
        print('runtime: {}'.format(self.datetime.now() - self.tic))


criterion = torch.nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    for batch, (img, _) in enumerate(dataloader):
        break
        if USE_CUDA:
            img = img.cuda()

        rec = model(img)

        loss = criterion(rec, img)


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        

        log_training(loss.item(), epoch, batch * BATCH_SIZE, EPOCHS, len(dataset), overwrite=True) 

    if epoch % 10 == 0:
        print("Saving sample...", end='\r')
        save_sample(img, rec, epoch)
        break

Saving sample...

In [14]:
model.load_state_dict(torch.load("./weights/conv.weights", map_location='cpu'))

In [15]:
model.eval()

Autoencoder(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(3, 3), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(8, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace)
    (2): ConvTranspose2d(16, 8, kernel_size=(5, 5), stride=(3, 3), padding=(1, 1))
    (3): ReLU(inplace)
    (4): ConvTranspose2d(8, 1, kernel_size=(2, 2), stride=(2, 2), padding=(1, 1))
    (5): Tanh()
  )
)