In [1]:
import torch
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import torch.nn.functional as F

from auto_encoder import AutoEncoder

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:

# Training configuration: every value a reader might want to tune lives here.

# Seed PyTorch's global RNG up front so that weight initialisation and the
# DataLoader shuffle order repeat across "Restart Kernel -> Run All".
# NOTE: some GPU kernels may still introduce small run-to-run differences.
seed = 42
torch.manual_seed(seed)

num_epochs = 200        # full passes over the training subset
batch_size = 32         # samples per optimizer step
learning_rate = 3e-4    # initial learning rate handed to the optimizer
decay_factor = 0.99     # per-epoch multiplicative learning-rate decay
model_save_path = "./auto_encoder_model.pth"  # destination for the trained weights

In [4]:
# MNIST training split, downloaded into ./data/ if it is not already there;
# each image is converted with transforms.ToTensor().
train_data = datasets.MNIST(
    './data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Keep only the first 1000 samples of the split for training.
indices = torch.arange(1000)
train_data_1k = data_utils.Subset(train_data, indices)

# Shuffled mini-batches, fetched by 4 worker processes into pinned host memory.
train_loader = DataLoader(
    train_data_1k,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [5]:
# Build the autoencoder on the selected device and switch it to training mode.
model = AutoEncoder().to(device)
model.train()

# Adam over all model parameters, starting from `learning_rate`.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Exponential decay: every scheduler.step() multiplies the learning rate by
# `decay_factor`, i.e. lr = learning_rate * decay_factor ** epoch. The built-in
# scheduler replaces an equivalent LambdaLR whose lambda closed over the global
# `decay_factor` (so rebinding that name later would silently alter the
# schedule) and which is not stored in scheduler.state_dict().
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_factor)
print(model)

AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.01, inplace=True)
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.01, inplace=True)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): LeakyReLU(negative_slope=0.01, inplace=True)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (7): LeakyReLU(negative_slope=0.01, inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
  )
  (linear1): Linear(in_features=3136, out_features=2, bias=True)
  (linear2): Linear(in_features=2, out_features=3136, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.01, inplace=True)
    (2): ConvTranspose2d(64, 64, kernel_size=(4,

In [6]:
def r_loss(y_true, y_pred):
    """Reconstruction loss: per-sample mean squared error, averaged over the batch.

    The squared error is first averaged over every non-batch dimension of each
    sample and the resulting per-sample values are then averaged. Inputs of any
    rank >= 2 are accepted (previously only 4-D tensors worked); for 4-D inputs
    the result matches averaging over dimensions 1, 2 and 3.

    Args:
        y_true: Target tensor; dimension 0 is treated as the batch dimension.
        y_pred: Prediction tensor, broadcastable against ``y_true``.

    Returns:
        A zero-dimensional tensor holding the loss.
    """
    squared_error = torch.square(y_true - y_pred)
    # Collapse all non-batch dimensions so the reduction no longer depends on
    # the input being exactly 4-D.
    per_sample = squared_error.flatten(start_dim=1).mean(dim=1)
    return per_sample.mean()

Train

In [8]:
# Re-assert training mode so this cell does not depend on the mode left behind
# by whichever cell ran last.
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    # The autoencoder is trained to reproduce its own input, so the class
    # labels are never used and are not copied to the device.
    for inputs, _labels in train_loader:
        # The loader pins host memory, so the host-to-device copy can be
        # issued asynchronously; on CPU the flag has no effect.
        inputs = inputs.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Explicitly enable autograd in case it was disabled globally elsewhere.
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = F.mse_loss(outputs, inputs)
            loss.backward()
            optimizer.step()

        # F.mse_loss returns the batch mean; weight it by the batch size so a
        # smaller final batch contributes proportionally to the epoch average.
        running_loss += loss.item() * inputs.size(0)

    # Decay the learning rate once per epoch.
    scheduler.step()
    epoch_loss = running_loss / len(train_data_1k)
    print('Epoch {0:03d}\tLoss: {1:0.5f}'.format(epoch, epoch_loss))

# Persist only the learned weights (state_dict), not the whole module object.
torch.save(model.state_dict(), model_save_path)

Epoch 000	Loss: 0.05499
Epoch 001	Loss: 0.05331
Epoch 002	Loss: 0.05161
Epoch 003	Loss: 0.05065
Epoch 004	Loss: 0.04971
Epoch 005	Loss: 0.04893
Epoch 006	Loss: 0.04826
Epoch 007	Loss: 0.04797
Epoch 008	Loss: 0.04744
Epoch 009	Loss: 0.04696
Epoch 010	Loss: 0.04671
Epoch 011	Loss: 0.04617
Epoch 012	Loss: 0.04580
Epoch 013	Loss: 0.04556
Epoch 014	Loss: 0.04513
Epoch 015	Loss: 0.04504
Epoch 016	Loss: 0.04454
Epoch 017	Loss: 0.04438
Epoch 018	Loss: 0.04422
Epoch 019	Loss: 0.04375
Epoch 020	Loss: 0.04348
Epoch 021	Loss: 0.04325
Epoch 022	Loss: 0.04301
Epoch 023	Loss: 0.04277
Epoch 024	Loss: 0.04260
Epoch 025	Loss: 0.04235
Epoch 026	Loss: 0.04208
Epoch 027	Loss: 0.04186
Epoch 028	Loss: 0.04190
Epoch 029	Loss: 0.04189
Epoch 030	Loss: 0.04145
Epoch 031	Loss: 0.04128
Epoch 032	Loss: 0.04115
Epoch 033	Loss: 0.04077
Epoch 034	Loss: 0.04063
Epoch 035	Loss: 0.04057
Epoch 036	Loss: 0.04036
Epoch 037	Loss: 0.04044
Epoch 038	Loss: 0.04007
Epoch 039	Loss: 0.04004
Epoch 040	Loss: 0.03985
Epoch 041	Loss: 