In [1]:
import os
import json
import time
from shutil import copyfile

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, _prepare_batch, Events
from ignite.metrics import Loss
from ignite.handlers import ModelCheckpoint, EarlyStopping

from Autoencoder import Autoencoder
from dataloader import dataloader, train_transform
from utils import show_image

In [2]:
# --- data location ---
root = 'data/'

# --- training hyper-parameters ---
device = 'cpu'
batch_size = 8
num_epochs = 30
code_size = 64   # passed to the Autoencoder constructor

# --- logging / output options ---
save_images = True   # when truthy, image pairs are written to images/ after each epoch
log_interval = 1     # print the training loss every `log_interval` iterations

In [3]:
# optimisation hyper-parameters
initial_lr = 0.001
gamma = 0.8

# network and reconstruction objective (mean squared error)
model = Autoencoder(code_size).to(device)
criterion = nn.MSELoss()

# Adam optimiser; StepLR multiplies the learning rate by `gamma`
# once every 30 scheduler steps
optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
lr_scheduler = StepLR(optimizer, step_size=30, gamma=gamma)

# three loaders: training, validation, and a loader used to evaluate the training loss
train_loader, val_loader, train_eval_loader = dataloader(root=root, batch_size=batch_size)

# report how many samples each loader's dataset holds
print(
    *(
        '{} {}'.format(label, len(loader.dataset))
        for label, loader in (
            ('Train-size: ', train_loader),
            ('Test-size: ', val_loader),
            ('Train-eval-size: ', train_eval_loader),
        )
    ),
    sep='\n'
)

Train-size:  7299
Test-size:  809
Train-eval-size:  809


In [4]:
# Here I use the ignite framework to wrap the plain Python and PyTorch training code.
# See https://pytorch.org/ignite/index.html for more details.

def create_unsupervised_evaluator(model, metrics=None, device=None):
    """Build an ignite ``Engine`` that evaluates a reconstruction model.

    Each batch is passed through ``model`` in eval mode with gradients
    disabled. The engine's per-iteration output is ``(x_pred, x)``: the model
    output paired with its own input, so the input acts as the target for
    any attached metric.

    Args:
        model: the module to evaluate.
        metrics: optional mapping ``name -> metric``; every metric is attached
            to the returned engine under its name. Defaults to no metrics.
        device: optional device; when truthy the model is moved to it and it
            is forwarded to ``_prepare_batch`` for every batch.

    Returns:
        The configured ``Engine``.
    """
    # A ``None`` sentinel avoids sharing one mutable default dict between calls.
    if metrics is None:
        metrics = {}

    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            # The second element of the prepared batch is ignored.
            x, _ = _prepare_batch(batch, device=device)
            x_pred = model(x)
            return x_pred, x

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine


def process_function(engine, batch):
    """Perform one optimisation step on ``batch`` and return the scalar loss.

    Relies on the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. The loss compares the model output with the model input; the
    second element of the prepared batch is discarded.
    """
    model.train()
    inputs, _ = _prepare_batch(batch, device=device)

    optimizer.zero_grad()
    reconstruction = model(inputs)
    batch_loss = criterion(reconstruction, inputs)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

In [5]:
# engine that drives the training loop, one `process_function` call per batch
trainer = Engine(process_function)

# a single metric, registered under the key 'avg_loss'
metrics = {'avg_loss': Loss(criterion)}

# two evaluators built from identical arguments (note: both receive the
# same `metrics` dict), created in the order train, then validation
train_evaluator, val_evaluator = (
    create_unsupervised_evaluator(model, metrics=metrics, device=device)
    for _ in range(2)
)

# print the training loss every `log_interval` iterations
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    """Print epoch, position within the epoch and the latest training loss."""
    batches_per_epoch = len(train_loader)
    # engine.state.iteration counts across epochs; map it to 1..batches_per_epoch
    step_in_epoch = (engine.state.iteration - 1) % batches_per_epoch + 1
    if step_in_epoch % log_interval != 0:
        return
    message = "Epoch[{}] Iteration[{}/{}] Loss: {:.4f}".format(
        engine.state.epoch,
        step_in_epoch,
        batches_per_epoch,
        engine.state.output,
    )
    print(message)


# log train and validation loss and learning rate value after epoch
@trainer.on(Events.EPOCH_COMPLETED)
def compute_and_display_val_metrics(engine):
    """Evaluate on the validation and train-eval loaders and print a summary.

    The validation evaluator runs first, which also fires any handlers
    registered on it. The line printed contains the epoch number, both
    average losses and the learning rate of the first optimizer parameter
    group.
    """
    metrics_val = val_evaluator.run(val_loader).metrics
    metrics_train = train_evaluator.run(train_eval_loader).metrics
    # Argument order must follow the labels in the format string:
    # training loss first, validation loss second.
    print("Epoch:{} Training Average Loss: {:.4f}, Validation Average Loss: {:.4f}, LR={:.7f}"
          .format(engine.state.epoch,
                  metrics_train['avg_loss'],
                  metrics_val['avg_loss'],
                  float(optimizer.param_groups[0]['lr'])))

    
# advance the learning rate scheduler once per epoch
# NOTE(review): this fires at the *start* of every epoch, i.e. before the first
# optimizer.step(). PyTorch >= 1.1.0 expects scheduler.step() to be called after
# optimizer.step() -- confirm against the installed version whether this should
# move to Events.EPOCH_COMPLETED.
@trainer.on(Events.EPOCH_STARTED)
def update_lr_scheduler(engine):
    """Step the notebook-level ``lr_scheduler``; ``engine`` is unused."""
    lr_scheduler.step()
    

def score_function(engine):
    """Return the negated 'avg_loss' metric so that a lower loss scores higher."""
    avg_loss = engine.state.metrics['avg_loss']
    return -avg_loss


# keep only the single best-scoring model (score = negated validation loss)
model_saver = ModelCheckpoint(
    "best_models",
    filename_prefix="autoencoder",
    score_function=score_function,
    score_name="loss",
    n_saved=1,
    create_dir=True,
    save_as_state_dict=True,
)
# the checkpoint handler runs each time a validation pass completes
val_evaluator.add_event_handler(Events.COMPLETED, model_saver, {"model": model})


# rolling checkpoint: after every epoch store the model together with the
# optimizer and scheduler, keeping only the most recent one
training_saver = ModelCheckpoint(
    "checkpoint",
    filename_prefix="checkpoint",
    n_saved=1,
    save_interval=1,
    create_dir=True,
    save_as_state_dict=True,
)
to_save = dict(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
trainer.add_event_handler(Events.EPOCH_COMPLETED, training_saver, to_save)


# terminate `trainer` when the validation score (negated loss) has not
# improved within the configured patience of 10
early_stopping = EarlyStopping(
    patience=10,
    score_function=score_function,
    trainer=trainer,
)
val_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stopping)

# save 10 original and decoded images after each epoch
if save_images:
    # exist_ok avoids the check-then-create race and makes the cell re-runnable
    os.makedirs('images/', exist_ok=True)

    # The handler deliberately does NOT reuse the name `save_images`: defining a
    # function with that name would overwrite the boolean config flag, so the
    # flag could no longer be switched off without redefining it.
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_reconstructions(engine):
        """Write side-by-side original/decoded image pairs to ``images/``.

        The same 10 validation samples are drawn every epoch from a fixed
        seed, so the files ``images/<epoch>_<index>`` can be compared across
        epochs. An index drawn more than once overwrites its own file.
        """
        model.eval()
        epoch = engine.state.epoch

        # A private RandomState keeps the draw reproducible without resetting
        # NumPy's global RNG on every epoch. The upper bound is the dataset
        # length (exclusive), not the number of batches in the loader.
        rng = np.random.RandomState(46)
        indices = rng.randint(0, len(val_loader.dataset), 10).tolist()

        for i in indices:
            image, _ = val_loader.dataset[i]
            image = image.to(device)
            with torch.no_grad():
                # image[None] adds a leading batch axis; [0] removes it again
                predict_image = model(image[None])[0].cpu()
            plt.subplot(1, 2, 1)
            show_image(np.swapaxes(image.cpu().numpy(), 0, 2))
            plt.subplot(1, 2, 2)
            show_image(np.swapaxes(predict_image.numpy(), 0, 2))
            plt.savefig('images/{}_{}'.format(epoch, i))
            # close the current figure so images do not pile up on the same
            # axes and memory does not grow over the run
            plt.close()

In [6]:
trainer.run(train_loader, num_epochs)

Epoch[1] Iteration[1/913] Loss: 7.6955
Epoch[1] Iteration[2/913] Loss: 63.0387
Epoch[1] Iteration[3/913] Loss: 8.7645
Epoch[1] Iteration[4/913] Loss: 30.6356
Epoch[1] Iteration[5/913] Loss: 17.6803
Epoch[1] Iteration[6/913] Loss: 5.6425
Epoch[1] Iteration[7/913] Loss: 3.2057
Epoch[1] Iteration[8/913] Loss: 6.3335
Epoch[1] Iteration[9/913] Loss: 7.0950
Epoch[1] Iteration[10/913] Loss: 5.5838
Epoch[1] Iteration[11/913] Loss: 3.9486
Epoch[1] Iteration[12/913] Loss: 3.0486
Epoch[1] Iteration[13/913] Loss: 1.5038
Epoch[1] Iteration[14/913] Loss: 1.3292
Epoch[1] Iteration[15/913] Loss: 1.5371
Epoch[1] Iteration[16/913] Loss: 1.5390
Epoch[1] Iteration[17/913] Loss: 1.5695
Epoch[1] Iteration[18/913] Loss: 1.6034
Epoch[1] Iteration[19/913] Loss: 1.8532
Epoch[1] Iteration[20/913] Loss: 1.3660
Epoch[1] Iteration[21/913] Loss: 1.4295
Epoch[1] Iteration[22/913] Loss: 1.3949
Epoch[1] Iteration[23/913] Loss: 1.0575
Epoch[1] Iteration[24/913] Loss: 0.9107
Epoch[1] Iteration[25/913] Loss: 1.0762
Epoch[

KeyboardInterrupt: 