In [75]:
import torch

# Use the GPU whenever PyTorch reports that CUDA is available.
USE_CUDA = torch.cuda.is_available()

# Hyperparameters shared by the cells below.
BATCH_SIZE = 128         # samples per mini-batch
LATENT_SPACE_SIZE = 4    # width of the autoencoder bottleneck
LEARNING_RATE = 0.01     # learning rate passed to the optimizer
EPOCHS = 100             # number of passes over the dataset

In [64]:
# Some utils
import matplotlib.pyplot as plt
from torchvision.utils import save_image

def plot_image(image):
    """Draw `image` with the 'gray' colormap and show the figure."""
    render_options = {'cmap': 'gray'}
    plt.imshow(image, **render_options)
    plt.show()
    
def save_sample(img, rec, epoch, out_dir='./samples/dense'):
    """Write originals and reconstructions side by side into one PNG.

    Both tensors are reshaped into rows of 28 values, so a batch of 28x28
    images becomes one tall strip. The two strips are concatenated along the
    width: originals on the left, reconstructions on the right.

    Args:
        img: Tensor of input images; its element count must be a multiple
            of 28.
        rec: Tensor of reconstructions with the same number of elements
            as `img`.
        epoch: Value interpolated into the file name
            (``sample-epoch-<epoch>.png``).
        out_dir: Directory that receives the file. It is created when it
            does not exist yet.
    """
    import os

    # save_image does not create missing directories; without this the first
    # call fails on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)

    # Detach so a tensor still attached to the autograd graph can be exported;
    # reshape (unlike view) also accepts non-contiguous inputs.
    img = img.detach().reshape(-1, 28)
    rec = rec.detach().reshape(-1, 28)

    pair = torch.cat((img, rec), 1)

    save_image(pair, os.path.join(out_dir, 'sample-epoch-{}.png'.format(epoch)))

def log_training(loss, epoch, batch, n_epochs, n_batches, overwrite=False):
    """Print a single progress line for the training loop.

    Args:
        loss: Loss value; rounded to 6 decimal places before printing.
        epoch: Epoch counter shown before the slash.
        batch: Batch (or sample) counter shown before the slash.
        n_epochs: Total shown after the epoch counter.
        n_batches: Total shown after the batch counter.
        overwrite: When true the line ends with a carriage return, so the
            next print overwrites it; otherwise it ends with a newline.
    """
    line_end = '\r' if overwrite else '\n'
    message = "Epoch {}/{} | Batch {}/{} | Loss: {}".format(
        epoch, n_epochs, batch, n_batches, round(loss, 6)
    )
    print(message, end=line_end)

In [3]:
# Read data

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

# download=True fetches MNIST into ./mnist-data only when the files are not
# already present, so this cell also runs on a fresh machine.
dataset = MNIST('./mnist-data', download=True, transform=Compose([ToTensor()]))

# Shuffled mini-batches of BATCH_SIZE samples.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [5]:
# Build model

from torch import nn

class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 inputs.

    The encoder narrows 784 -> 128 -> 64 -> 32 -> ``latent_size`` with ReLU
    between the linear layers; the decoder mirrors those widths back to 784
    values and finishes with a Tanh.

    Args:
        latent_size: Width of the bottleneck layer. When omitted, the
            notebook-level ``LATENT_SPACE_SIZE`` is used, which matches the
            previous hard-wired behaviour.
    """

    def __init__(self, latent_size=None):
        super().__init__()

        # Resolve the default lazily so the global is read at construction
        # time, exactly as before.
        if latent_size is None:
            latent_size = LATENT_SPACE_SIZE
        self.latent_size = latent_size

        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, latent_size)
        )

        # NOTE(review): Tanh emits values in [-1, 1]. If the inputs are scaled
        # to [0, 1], a Sigmoid may be the better fit - confirm against the
        # data pipeline before changing.
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Tanh()
        )

    def forward(self, x):
        """Encode `x` to the latent space and decode it back.

        Args:
            x: Tensor whose last dimension is 784 (28 * 28).

        Returns:
            Tensor with the same shape as `x`, holding the reconstruction.
        """
        latent = self.encoder(x)
        return self.decoder(latent)
    


In [73]:
# Build the network and move its parameters to the GPU when CUDA is in use.
model = Autoencoder()
model = model.cuda() if USE_CUDA else model

In [78]:
# Training


# Reconstruction objective: mean squared error between output and input.
criterion = torch.nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    # Samples processed so far in this epoch; drives the progress display.
    samples_seen = 0

    for img, _ in dataloader:
        # Flatten every 28x28 image into a 784-vector for the dense layers.
        img = img.view(-1, 28 * 28)

        if USE_CUDA:
            img = img.cuda()

        rec = model(img)
        loss = criterion(rec, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Add the batch's real size (the last batch can be smaller) so the
        # counter ends exactly at len(dataset); show the epoch 1-based so the
        # final epoch reads "EPOCHS/EPOCHS".
        samples_seen += img.size(0)
        log_training(loss.item(), epoch + 1, samples_seen, EPOCHS, len(dataset), overwrite=True)

    if epoch % 10 == 0:
        print("Saving sample...", end='\r')
        # The file name keeps the 0-based epoch index. The reconstruction is
        # detached because it is still attached to the autograd graph.
        save_sample(img.view(-1, 1, 28, 28), rec.detach().view(-1, 1, 28, 28), epoch)

Epoch 1/100 | Batch 20352/60000 | Loss: 0.040945

KeyboardInterrupt: 