In [1]:
import torch.nn as nn
import time
import torch.optim as optim
from autoencoder_model_segnet_rev1 import SegNet
import torch
from torchvision import models
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from address_gram_dataset import AddressGramDataset
from torch.utils.data import DataLoader
from utils import *

# Datasets and dataloaders for the other fonts (currently disabled, kept for reference)

In [2]:

# arial_bold_test_dataset = FontDataset('./address_grams/arialbold/test')

# arial_italic_train_dataset = FontDataset('./address_grams/arialitalic/train')
# arial_italic_test_dataset = FontDataset('./address_grams/arialitalic/test')

# vera_train_dataset = FontDataset('./address_grams/vera/train')
# vera_test_dataset = FontDataset('./address_grams/vera/test')

# arial_train_dataset = FontDataset('./address_grams/arial/train')
# arial_test_dataset = FontDataset('./address_grams/arial/test')

In [3]:

# arial_bold_test_dataloader = DataLoader(arial_bold_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# arial_italic_train_dataloader = DataLoader(arial_italic_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# arial_italic_test_dataloader = DataLoader(arial_italic_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# vera_train_dataloader = DataLoader(vera_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# vera_test_dataloader = DataLoader(vera_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# arial_train_dataloader = DataLoader(arial_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# arial_test_dataloader = DataLoader(arial_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Define train and val loops

In [4]:
def train(epoch, train_loader, model, optimizer):
    """Train ``model`` for a single pass over ``train_loader``.

    Every batch ``(x, y)`` is moved to ``device``, passed through the model, and
    scored with the root-mean-square error (RMSE) between the prediction and
    ``y``; the optimizer then takes one step on that loss. A progress line is
    printed every ``print_freq`` batches.

    Args:
        epoch: epoch number; only used in the progress message.
        train_loader: iterable of ``(x, y)`` tensor pairs (a DataLoader here).
        model: the network being trained; switched to training mode first.
        optimizer: optimizer over ``model``'s parameters.

    Note:
        ``device``, ``print_freq`` and ``ExpoAverageMeter`` are not defined in
        this notebook; presumably they come from ``from utils import *`` —
        confirm there.
    """
    model.train()  # training-mode behaviour for layers such as dropout / batch norm

    # Running statistics used only for the progress printout.
    iter_timer = ExpoAverageMeter()  # seconds per iteration, including data loading
    rmse_meter = ExpoAverageMeter()  # per-batch RMSE loss

    # The first measured iteration also includes the time to fetch the first batch.
    tic = time.time()
    for step, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        prediction = model(inputs)
        loss = (prediction - targets).pow(2).mean().sqrt()
        loss.backward()
        optimizer.step()

        rmse_meter.update(loss.item())
        iter_timer.update(time.time() - tic)
        tic = time.time()

        if step % print_freq == 0:
            print(('Epoch: [{0}][{1}/{2}]\t'
                   'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t').format(
                       epoch, step, len(train_loader),
                       batch_time=iter_timer, loss=rmse_meter))

In [5]:
def valid(val_loader, model):
    """Evaluate ``model`` on ``val_loader`` and return the validation RMSE.

    Every batch ``(x, y)`` is moved to ``device`` and the model's prediction is
    compared with ``y``. A progress line (per-batch RMSE plus a running value)
    is printed every ``print_freq`` batches.

    Args:
        val_loader: iterable of ``(x, y)`` tensor pairs.
        model: network to evaluate; switched to eval mode and run without autograd.

    Returns:
        float: RMSE over all elements of all batches,
        ``sqrt(sum((y_hat - y) ** 2) / num_elements)``. It is computed exactly
        instead of being read from the ``ExpoAverageMeter`` used for printing,
        because that meter's ``avg`` is an exponentially weighted running value:
        it depends on batch order and on how many batches there are, and it is
        biased low for short loaders. Returns ``float('inf')`` when
        ``val_loader`` yields no batches, so an empty validation set can never
        count as an improvement.

    Note:
        ``device``, ``print_freq`` and ``ExpoAverageMeter`` are not defined in
        this notebook; presumably they come from ``from utils import *``.
    """
    model.eval()  # eval-mode behaviour (e.g. dropout disabled, batch norm uses running stats)

    batch_time = ExpoAverageMeter()  # seconds per iteration (data loading + forward pass)
    losses = ExpoAverageMeter()  # per-batch RMSE, for the progress printout only

    # Exact accumulators for the returned dataset-level RMSE.
    sum_sq_err = 0.0
    num_elements = 0

    start = time.time()

    with torch.no_grad():
        for i_batch, (x, y) in enumerate(val_loader):
            x = x.to(device)
            y = y.to(device)

            y_hat = model(x)

            sq_err = (y_hat - y).pow(2)
            loss = torch.sqrt(sq_err.mean())

            # Accumulate in float64 so summing many elements does not lose precision.
            sum_sq_err += sq_err.sum(dtype=torch.float64).item()
            num_elements += sq_err.numel()

            # Running metrics for the progress printout.
            losses.update(loss.item())
            batch_time.update(time.time() - start)

            start = time.time()

            if i_batch % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(i_batch, len(val_loader),
                                                                      batch_time=batch_time,
                                                                      loss=losses))

    if num_elements == 0:
        return float('inf')
    return (sum_sq_err / num_elements) ** 0.5

# Load training/val dataset for one font only

In [10]:
# One AddressGramDataset per split of the Arial Bold font; each is given the path to a directory of images.
arial_bold_train_dataset, arial_bold_val_dataset = (
    AddressGramDataset(split_dir)
    for split_dir in ('./address_grams/arialbold/train', './address_grams/arialbold/val')
)

In [11]:
BATCH_SIZE = 32

# Training batches are reshuffled every epoch. Validation batches are NOT shuffled,
# so the validation pass sees the same batches in the same order every epoch and the
# loss used for model selection / early stopping is reproducible.
arial_bold_train_dataloader = DataLoader(arial_bold_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
arial_bold_val_dataloader = DataLoader(arial_bold_val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [12]:
def main(train_loader=None, val_loader=None):
    """Train SegNet as an autoencoder with learning-rate decay and early stopping.

    Each epoch runs ``train`` and then ``valid``. The validation loss decides
    whether the epoch is the best so far, and ``save_checkpoint`` is called every
    epoch with ``is_best`` set accordingly.

    Args:
        train_loader: training batches; defaults to ``arial_bold_train_dataloader``.
        val_loader: validation batches; defaults to ``arial_bold_val_dataloader``.

    Note:
        ``device``, ``lr``, ``start_epoch``, ``epochs``, ``adjust_learning_rate``
        and ``save_checkpoint`` are not defined in this notebook; presumably they
        come from ``from utils import *`` — confirm there.
    """
    if train_loader is None:
        train_loader = arial_bold_train_dataloader
    if val_loader is None:
        val_loader = arial_bold_val_dataloader

    # Early-stopping / learning-rate schedule settings.
    patience = 20       # stop after this many consecutive epochs without improvement
    decay_every = 8     # decay the LR after every 8 consecutive epochs without improvement
    decay_factor = 0.8  # passed to adjust_learning_rate (presumably multiplies the LR)

    # Create SegNet model; label_nbr is its output-channel count (presumably 3 for an RGB reconstruction).
    label_nbr = 3
    model = SegNet(label_nbr)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # DataParallel splits each batch along dim 0 across the visible GPUs,
        # e.g. [40, ...] -> four [10, ...] chunks on 4 GPUs.
        model = nn.DataParallel(model)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Start at +inf rather than a finite sentinel so the first finite validation
    # loss always counts as an improvement, however large it is.
    best_loss = float('inf')
    epochs_since_improvement = 0

    for epoch in range(start_epoch, epochs):
        if epochs_since_improvement >= patience:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % decay_every == 0:
            adjust_learning_rate(optimizer, decay_factor)

        # One epoch's training
        train(epoch, train_loader, model, optimizer)

        # One epoch's validation
        val_loss = valid(val_loader, model)
        print('\n * LOSS - {loss:.3f}\n'.format(loss=val_loss))

        is_best = val_loss < best_loss
        best_loss = min(best_loss, val_loss)

        if is_best:
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))

        # Save a checkpoint every epoch; is_best marks the best one so far.
        save_checkpoint(epoch, model, optimizer, val_loss, is_best)

In [13]:
# Inside a Jupyter kernel __name__ is '__main__', so executing this cell starts training.
if __name__ == '__main__':
    main()

Epoch: [0][0/125]	Batch Time 111.191 (11.119)	Loss 0.9227 (0.0923)	


KeyboardInterrupt: 