In [1]:
import torch.nn as nn
import time
import torch.optim as optim
from autoencoder_model_segnet_rev1 import SegNet
import torch
from torchvision import models
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from address_gram_dataset import AddressGramDataset
from torch.utils.data import DataLoader
from utils import *
import matplotlib.pyplot as plt

# Load training/val dataset for one font only

In [2]:

# arial_bold_test_dataset = FontDataset('./address_grams/arialbold/test')

# arial_italic_train_dataset = FontDataset('./address_grams/arialitalic/train')
# arial_italic_test_dataset = FontDataset('./address_grams/arialitalic/test')

# vera_train_dataset = FontDataset('./address_grams/vera/train')
# vera_test_dataset = FontDataset('./address_grams/vera/test')

# arial_train_dataset = FontDataset('./address_grams/arial/train')
# arial_test_dataset = FontDataset('./address_grams/arial/test')

In [3]:

# arial_bold_test_dataloader = DataLoader(arial_bold_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# arial_italic_train_dataloader = DataLoader(arial_italic_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# arial_italic_test_dataloader = DataLoader(arial_italic_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# vera_train_dataloader = DataLoader(vera_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# vera_test_dataloader = DataLoader(vera_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# arial_train_dataloader = DataLoader(vera_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# arial_test_dataloader = DataLoader(vera_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Define train and val loops

In [4]:
def train(epoch, train_loader, model, optimizer):
    # Ensure dropout layers are in train mode
    model.train()

    batch_time = ExpoAverageMeter()  # forward prop. + back prop. time
    losses = ExpoAverageMeter()  # loss (per word decoded)

    start = time.time()

    # Batches
    for i_batch, (x, y) in enumerate(train_loader):
        # Set device options
        x = x.to(device)
        y = y.to(device)

        # print('x.size(): ' + str(x.size())) # [32, 3, 224, 224]
        # print('y.size(): ' + str(y.size())) # [32, 3, 224, 224]

        # Zero gradients
        optimizer.zero_grad()

        y_hat = model(x)
        # print('y_hat.size(): ' + str(y_hat.size())) # [32, 3, 224, 224]

        loss = torch.sqrt((y_hat - y).pow(2).mean())
        loss.backward()

        optimizer.step()

        # Keep track of metrics
        losses.update(loss.item())
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i_batch % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i_batch, len(train_loader),
                                                                  batch_time=batch_time,
                                                                  loss=losses))
            
    return losses.avg

In [5]:
def valid(val_loader, model):
    model.eval()  # eval mode (no dropout or batchnorm)

    batch_time = ExpoAverageMeter()  # forward prop. + back prop. time
    losses = ExpoAverageMeter()  # loss (per word decoded)

    start = time.time()

    with torch.no_grad():
        # Batches
        for i_batch, (x, y) in enumerate(val_loader):
            # Set device options
            x = x.to(device)
            y = y.to(device)

            y_hat = model(x)

            loss = torch.sqrt((y_hat - y).pow(2).mean())

            # Keep track of metrics
            losses.update(loss.item())
            batch_time.update(time.time() - start)

            start = time.time()

            # Print status
            if i_batch % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(i_batch, len(val_loader),
                                                                      batch_time=batch_time,
                                                                      loss=losses))

    return losses.avg

In [6]:
def plot_graph(train_losses, val_losses):
    
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.legend(loc='best')
    plt.title('Loss vs Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss (MSE)')
    
    plt.savefig('autoencoder_training_fig.png')

# Load training/val dataset for one font only

In [7]:
# create dataset, pass in path to dir of images
arial_bold_train_dataset = AddressGramDataset('./address_grams/arialbold/train')
arial_bold_val_dataset = AddressGramDataset('./address_grams/arialbold/val')

In [8]:
BATCH_SIZE = 32

# create dataloader
arial_bold_train_dataloader = DataLoader(arial_bold_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
arial_bold_val_dataloader = DataLoader(arial_bold_val_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [9]:
def main():
    
    start_time = time.time()
    
    train_loader = arial_bold_train_dataloader
    val_loader = arial_bold_val_dataloader

    # Create SegNet model
    label_nbr = 3
    model = SegNet(label_nbr)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [40, xxx] -> [10, ...], [10, ...], [10, ...], [10, ...] on 4 GPUs
        model = nn.DataParallel(model)
    # Use appropriate device
    model = model.to(device)
    # print(model)

    # define the optimizer
    # optimizer = optim.LBFGS(model.parameters(), lr=0.8)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_loss = 100000
    epochs_since_improvement = 0
    
    train_losses = []
    val_losses = []

    # Epochs
    for epoch in range(start_epoch, epochs):
        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(optimizer, 0.8)

        # One epoch's training
        train_loss = train(epoch, train_loader, model, optimizer)

        # One epoch's validation
        val_loss = valid(val_loader, model)
        print('\n * LOSS - {loss:.3f}\n'.format(loss=val_loss))
        
        # append losses
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Check if there was an improvement
        is_best = val_loss < best_loss
        best_loss = min(best_loss, val_loss)

        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # only save checkpoint every 5 epochs or if last one
        if epoch % 5 == 0 and is_best and epoch > 60:
            print('Saving checkpoint at epoch:', epoch)
            save_checkpoint(epoch, model, optimizer, val_loss, is_best)
            
    plot_graph(train_losses, val_losses)
    
    # print total training time
    print('total training time (mins): ', (time.time() - start_time) / 60)

In [10]:
if __name__ == '__main__':
    main()

Let's use 4 GPUs!
Epoch: [0][0/125]	Batch Time 5.713 (0.571)	Loss 0.8538 (0.0854)	
Epoch: [0][20/125]	Batch Time 0.238 (0.284)	Loss 0.7523 (0.6837)	
Epoch: [0][40/125]	Batch Time 0.249 (0.250)	Loss 0.7423 (0.7368)	
Epoch: [0][60/125]	Batch Time 0.268 (0.251)	Loss 0.7371 (0.7386)	
Epoch: [0][80/125]	Batch Time 0.240 (0.247)	Loss 0.7336 (0.7356)	
Epoch: [0][100/125]	Batch Time 0.239 (0.246)	Loss 0.7324 (0.7331)	
Epoch: [0][120/125]	Batch Time 0.241 (0.244)	Loss 0.7301 (0.7308)	
Validation: [0/32]	Batch Time 0.123 (0.012)	Loss 0.7285 (0.0728)	
Validation: [20/32]	Batch Time 0.119 (0.107)	Loss 0.7289 (0.6490)	

 * LOSS - 0.704

Saving checkpoint at epoch: 0
Epoch: [1][0/125]	Batch Time 0.244 (0.024)	Loss 0.7297 (0.0730)	
Epoch: [1][20/125]	Batch Time 0.253 (0.226)	Loss 0.7272 (0.6482)	
Epoch: [1][40/125]	Batch Time 0.242 (0.242)	Loss 0.7242 (0.7160)	
Epoch: [1][60/125]	Batch Time 0.261 (0.248)	Loss 0.7220 (0.7226)	
Epoch: [1][80/125]	Batch Time 0.251 (0.254)	Loss 0.7210 (0.7221)	
Epoch: [1

KeyboardInterrupt: 