In [None]:
import sys, os
from pathlib import Path
import random
from types import SimpleNamespace
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

sys.path.append(os.path.abspath('../src'))
from foe_fingerprint import FOEFingerprint
from foe_patch_dataset import FOEPatchDataset
from foe_conv_net import FOEConvNet
from foe_patch import FOEOrientation
from foe_autoencoder import Encoder, Decoder

# Accumulators filled while reading args.txt:
#   keys  - option names, types - type names, confs - one namespace per config row.
keys, types, confs = [], [], []

def make_bool(str_val):
    """Return True only for the exact text "True"; anything else yields False."""
    return str_val == "True"

# Converters selected by the type names given on the second line of args.txt.
methods = {'str': str, 'int': int, 'float': float, 'bool': make_bool }

# args.txt layout (whitespace separated):
#   1st non-blank line  -> option names
#   2nd non-blank line  -> type name per option (a key of `methods`)
#   remaining lines     -> one configuration per line
# Blank / whitespace-only lines are skipped everywhere, so a trailing empty
# line no longer crashes the parser with an IndexError.
with open("args.txt") as f:
    rows = [fields for fields in (line.split() for line in f) if fields]

if len(rows) < 3:
    raise ValueError("args.txt needs a names line, a types line and at least "
                     "one configuration line")

keys, types = rows[0], rows[1]
if len(types) < len(keys):
    raise ValueError("args.txt declares {} option names but only {} types"
                     .format(len(keys), len(types)))

confs = []
for vals in rows[2:]:
    if len(vals) < len(keys):
        raise ValueError("args.txt configuration has {} values, expected {}: {}"
                         .format(len(vals), len(keys), ' '.join(vals)))
    # zip stops at len(keys); surplus values on a line are ignored, as before.
    confs.append(SimpleNamespace(**{
        key: methods[type_name](val)
        for key, type_name, val in zip(keys, types, vals)
    }))

# Only the first configuration row is used by the rest of the notebook.
args = confs[0]

# --- Input / output locations taken from the parsed configuration ---
base_path = Path(args.base_path)
output_path = Path(args.output_dir)

# --- Hyper-parameters ---
patch_size, n_folds, n_classes = args.patch_size, args.n_folds, args.n_classes
n_epochs, batch_size, learning_rate = args.n_epochs, args.batch_size, args.learning_rate

# --- Output folders (created eagerly; log_dir is only a path, it is not created here) ---
output_path.mkdir(parents=True, exist_ok=True)
log_dir = output_path / 'log/foe'
model_dir = output_path / 'models'
model_dir.mkdir(parents=True, exist_ok=True)

# --- Device: use the GPU only when it is both requested and available ---
use_gpu = (not args.cpu) and torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'
print("Using {}...".format(device))

# --- Fixed settings for this notebook ---
VALIDATION_BATCH_SIZE = 256   # batch size of the validation loader
NUM_WORKERS = 4               # DataLoader worker processes
FOLD = 0                      # index of the fold used as validation split

 # load fingerprint data
# The 'Good' index is always read; the 'Bad' index is added unless disabled.
# NOTE(review): the trailing True/False flag presumably marks the quality
# class of the loaded fingerprints - confirm against FOEFingerprint.
dset = FOEFingerprint.load_index_file(base_path / 'Good', 'index.txt', True)
if not args.without_bad:
    dset.extend(
        FOEFingerprint.load_index_file(base_path / 'Bad', 'index.txt', False)
    )
print('Loaded {} fingerprints.'.format(len(dset)))

# Shuffle in place before splitting (uses the global `random` state, no seed set here).
if not args.no_shuffle:
    random.shuffle(dset)
    print('Randomized splits.')

# Build the training / validation patch datasets for the selected fold.
tset, vset = FOEPatchDataset.trainval_from_fingerprints(
    dset, n_classes, patch_size, n_folds, FOLD
)

# Data augmentation parameters
# tset.set_hflip(True)
# tset.set_delta_r(np.pi / 6.0)

# Initialize the data loaders used by every training cell below.
# NOTE(review): the validation loader is shuffled as well; this does not
# affect aggregated metrics but makes per-batch inspection non-deterministic.
train_loader = DataLoader(tset, batch_size=batch_size, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=use_gpu)
val_loader = DataLoader(vset, batch_size=VALIDATION_BATCH_SIZE, shuffle=True,
                        num_workers=NUM_WORKERS, pin_memory=use_gpu)

# Dataset sizes (number of samples, not number of batches).
n_train = len(train_loader.dataset)
n_val = len(val_loader.dataset)

print("""Number of classes: {}
Patch size: {}
Batch size: {}
Training set size: {}
Validation set size: {}""".format(n_classes, patch_size, batch_size, n_train, n_val))

In [None]:
def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):
    """Train the autoencoder for one pass over `dataloader`.

    Each batch is a 3-tuple whose first element is the image tensor; the other
    two elements are ignored. The reconstruction `decoder(encoder(images))` is
    compared against the input images with `loss_fn`, and `optimizer` takes one
    step per batch.

    Returns:
        The unweighted mean of the per-batch loss values (NumPy scalar).
    """
    for module in (encoder, decoder):
        module.train()

    batch_losses = []
    for image_batch, _, _ in dataloader:
        image_batch = image_batch.to(device)
        reconstruction = decoder(encoder(image_batch))
        loss = loss_fn(reconstruction, image_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep a detached host-side copy so no autograd graph is retained.
        batch_losses.append(loss.detach().cpu().numpy())
    return np.mean(batch_losses)

def val_epoch(encoder, decoder, device, dataloader, loss_fn):
    """Evaluate the autoencoder on every batch of `dataloader`.

    Both networks are switched to eval mode and run without gradient tracking.
    All reconstructions and inputs are gathered on the CPU, concatenated, and
    `loss_fn` is applied once to the full set, so the result is not a mean of
    per-batch means.

    Returns:
        The loss over the whole loader as a tensor detached via `.data`.
    """
    encoder.eval()
    decoder.eval()

    reconstructions, originals = [], []
    with torch.no_grad():
        for image_batch, _, _ in dataloader:
            image_batch = image_batch.to(device)
            reconstructions.append(decoder(encoder(image_batch)).cpu())
            originals.append(image_batch.cpu())
        val_loss = loss_fn(torch.cat(reconstructions), torch.cat(originals))
    return val_loss.data

def find_lr(device, train_loader, loss_fn, init_value = 1e-2, final_value=1e3, beta=0.98, mult=1.1,
            encoded_space_dim=256):
    """Sweep learning rates geometrically and record the loss reached at each.

    For every candidate learning rate a fresh Encoder/Decoder pair is created
    and trained with Adam for one full epoch (`train_epoch`). The sweep starts
    at `init_value`, multiplies by `mult` each step and ends once `final_value`
    is covered, or earlier when the smoothed loss explodes or is non-finite.

    Args:
        device: device the freshly built models are moved to.
        train_loader: loader passed to `train_epoch` for each candidate.
        loss_fn: reconstruction loss passed to `train_epoch`.
        init_value: first learning rate tried.
        final_value: upper end of the sweep.
        beta: smoothing factor of the exponential moving average of the loss.
        mult: multiplicative step between consecutive learning rates.
        encoded_space_dim: latent size of the temporary autoencoders
            (previously hard-coded to 256).

    Returns:
        (log_lrs, losses): log10 of each learning rate tried and the raw
        (unsmoothed) epoch loss obtained with it.
    """
    lr = init_value
    avg_loss = 0.
    best_loss = 0.
    losses = []
    log_lrs = []
    num = int(np.log(final_value / init_value) / np.log(mult))+1
    for i in range(1,num+1):
        print(i, num)
        # A new model per candidate, so results of different rates are independent.
        encoder = Encoder(encoded_space_dim).to(device)
        decoder = Decoder(encoded_space_dim).to(device)
        params_to_optimize = [
            {'params': encoder.parameters()},
            {'params': decoder.parameters()}
        ]
        optimizer = optim.Adam(params_to_optimize, lr=lr)

        loss = train_epoch(encoder, decoder, device, train_loader, loss_fn, optimizer)

        # Bias-corrected exponential moving average of the loss.
        avg_loss = beta * avg_loss + (1-beta) *loss
        smoothed_loss = avg_loss / (1 - beta**i)

        # Once the average is NaN/inf it never recovers and every comparison
        # below is False, so stop instead of training the remaining candidates.
        if not np.isfinite(smoothed_loss):
            return log_lrs, losses
        # Stop if the loss is exploding
        if i > 1 and smoothed_loss > 4 * best_loss:
            return log_lrs, losses
        # Record the best loss
        if smoothed_loss < best_loss or i==1:
            best_loss = smoothed_loss
        # Store the values
        losses.append(loss)
        log_lrs.append(np.log10(lr))
        # Learning rate for the next candidate; the optimizer is rebuilt with
        # it at the top of the loop, so no param_group update is needed here.
        lr *= mult
    return log_lrs, losses

def plot_ae_outputs(encoder,decoder,n=5):
    """Show the first `n` validation patches above their reconstructions.

    Reads the module-level `vset` (element 0 of each sample is used as the
    image) and `device`. Both networks are put in eval mode and evaluated
    without gradient tracking. The top row holds the inputs, the bottom row
    the reconstructions; only the middle column carries the row title.
    """
    def _draw(position, tensor, title, with_title):
        # One grayscale panel of the 2 x n grid, with both axes hidden.
        ax = plt.subplot(2, n, position)
        plt.imshow(tensor.cpu().squeeze().numpy(), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if with_title:
            ax.set_title(title)

    plt.figure(figsize=(10,4.5))
    plt.ion()
    middle = n // 2
    for i in range(n):
        # Add a leading batch dimension so a single patch can be fed to the networks.
        img = vset[i][0].unsqueeze(0).to(device)
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            rec_img = decoder(encoder(img))

        _draw(i + 1, img, 'Original images', i == middle)
        _draw(i + 1 + n, rec_img, 'Reconstructed images', i == middle)
    plt.show()

In [None]:
# loss_fn = torch.nn.MSELoss()

# logs,losses = find_lr(device, train_loader,loss_fn)
# plt.figure(figsize=(10,8))
# plt.plot(logs,losses)
# print(logs, losses)

In [None]:
# Latent size shared by the encoder and the decoder.
encoded_space_dim = 64
encoder = Encoder(encoded_space_dim).to(device)
decoder = Decoder(encoded_space_dim).to(device)
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]

# The autoencoder is trained at a tenth of the configured learning rate.
optimizer = optim.Adam(params_to_optimize, lr=learning_rate/10)

# Drop the learning rate by 10x once half of the epochs are done.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                           [int(0.50 * n_epochs)],
                                           gamma=0.1)

# scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_loader), epochs=n_epochs)
loss_fn = torch.nn.MSELoss()

# Per-epoch loss history, stored as plain floats so it plots the same
# regardless of whether the epoch functions return tensors or NumPy scalars.
diz_loss = {'train_loss':[],'val_loss':[]}
for epoch in range(n_epochs):
    train_loss = train_epoch(encoder,decoder,device, train_loader,loss_fn,optimizer)
    val_loss = val_epoch(encoder,decoder,device,val_loader,loss_fn)
    print('EPOCH {}/{} \t train loss {} \t val loss {}'.format(epoch + 1, n_epochs,train_loss,val_loss))
    diz_loss['train_loss'].append(float(train_loss))
    diz_loss['val_loss'].append(float(val_loss))
    # Show sample reconstructions every 20 epochs
    if epoch%20 == 0:
        plot_ae_outputs(encoder,decoder,n=5)
    # Advance the schedule once per epoch; without this call the MultiStepLR
    # above is never applied and the learning rate stays constant.
    scheduler.step()

In [None]:
# Final reconstructions for two validation patches.
plot_ae_outputs(encoder,decoder,n=2)

# Training vs. validation reconstruction loss per epoch (log-scaled y axis).
# The figure is drawn once; the cell previously contained the same plot twice.
plt.figure(figsize=(10,8))
plt.semilogy(diz_loss['train_loss'], label='Train')
plt.semilogy(diz_loss['val_loss'], label='Valid')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.legend()
plt.show()

In [None]:
def train_model(model, train_loader, optimizer, criterion, device):
    """Train `model` for one epoch on features from the module-level `encoder`.

    Each batch `(x, y, theta)` is processed as follows: `x` is passed through
    the global `encoder`, the result is fed to `model`, and `criterion`
    compares the prediction with `y.float()`. An angle estimate is derived
    from the first two prediction columns with `arctan2` and compared against
    `theta` through `FOEOrientation.estimation_error_sqr`.

    Returns:
        (mean_loss, err_sqr): the loss averaged over batches and the
        accumulated value returned by `estimation_error_sqr` over all batches.
    """
    model.train()
    train_loss = 0.0
    train_err_sqr = 0.0
    n_train_batches = 0
    for x, y, theta in train_loader:
        x, y = x.to(device), y.to(device)
        # The encoder only supplies features: every call site in this notebook
        # passes an optimizer built from model.parameters() alone. Running it
        # without autograd avoids back-propagating into, and piling up
        # never-cleared gradients on, the encoder weights.
        with torch.no_grad():
            x = encoder(x)
        optimizer.zero_grad()
        ye = model(x)

        loss = criterion(ye, y.float())
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # NOTE(review): assumes column 0 / column 1 are the sine-like and
        # cosine-like components of the angle - confirm against the dataset.
        ye_np = ye.detach().cpu().numpy()
        theta_estimated = np.arctan2(ye_np[:,0], ye_np[:,1])
        train_err_sqr += FOEOrientation.estimation_error_sqr(theta, theta_estimated)
        n_train_batches += 1

    return train_loss / n_train_batches, train_err_sqr
def validate_model(model, val_loader, criterion, device):
    """Evaluate `model` on `val_loader` using features from the global `encoder`.

    Runs entirely without gradient tracking. For each batch `(x, y, theta)`
    the encoder output is fed to `model`, `criterion` compares the prediction
    with `y.float()` (the same cast used in `train_model`), and an angle
    estimate from `arctan2` of the first two prediction columns is compared
    against `theta` through `FOEOrientation.estimation_error_sqr`.

    Returns:
        (mean_loss, err_sqr): the loss averaged over batches and the
        accumulated value returned by `estimation_error_sqr` over all batches.
    """
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        val_err_sqr = 0.0
        n_val_batches = 0
        for x, y, theta in val_loader:
            x, y = x.to(device), y.to(device)
            x = encoder(x)
            ye = model(x)
            # Cast the target like train_model does so both paths accept the
            # same target dtypes.
            val_loss += criterion(ye, y.float()).item()
            ye_np = ye.cpu().numpy()
            theta_estimated = np.arctan2(ye_np[:,0], ye_np[:,1])
            val_err_sqr += FOEOrientation.estimation_error_sqr(theta, theta_estimated)

            n_val_batches += 1

        return val_loss / n_val_batches, val_err_sqr

In [None]:
from foe_mlp import FOEMLP
# Regressor operating on the autoencoder's latent vectors.
model = FOEMLP(encoded_space_dim).to(device)
criterion = torch.nn.MSELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# No learning-rate schedule is used for this model.

# Per-epoch history: each entry is [loss, rmse_in_degrees].
train_res = []
val_res = []
for e in range(n_epochs):
    train_loss, train_err_sqr = train_model(model, train_loader, optimizer, criterion, device)
    val_loss, val_err_sqr = validate_model(model, val_loader, criterion, device)

    # Normalise the accumulated squared error by the dataset size, take the
    # root and convert from radians to degrees.
    train_rmse = np.sqrt(train_err_sqr / n_train) * 180.0 / np.pi
    val_rmse = np.sqrt(val_err_sqr / n_val) * 180.0 / np.pi

    train_res.append([train_loss, train_rmse])
    val_res.append([val_loss, val_rmse])

    print('Epoch {}/{}: train loss / rmse = {:.4f} / {:.1f}° validation loss / rmse = {:.4f} / {:.1f}°'
          .format(e+1, n_epochs, train_loss, train_rmse, val_loss, val_rmse))


In [None]:
# Convolutional model trained with a mean-squared-error objective.
# Alternatives tried earlier: nn.CrossEntropyLoss(), Orientation_Loss().
# NOTE(review): train_model / validate_model feed the global encoder's output
# to whatever model they receive - verify that is intended for FOEConvNet.
model = FOEConvNet(patch_size, n_classes).to(device)
criterion = torch.nn.MSELoss().to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Divide the learning rate by 10 at 25 %, 50 % and 75 % of the epochs.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    [int(fraction * n_epochs) for fraction in (0.25, 0.50, 0.75)],
    gamma=0.1,
)



class Orientation_Loss(torch.nn.Module):
    """Mean squared angular difference for pi-periodic orientations.

    Orientations that differ by a multiple of pi are treated as identical, so
    each difference is wrapped into [-pi/2, pi/2) before squaring. The wrap is
    symmetric: large negative differences are folded the same way as large
    positive ones, and differences beyond one period are handled too.
    """

    def __init__(self):
        super().__init__()

    def forward(self, gt_in_radians, est_in_radians):
        """Return the mean of the squared wrapped differences.

        Args:
            gt_in_radians: tensor of ground-truth angles in radians.
            est_in_radians: tensor of estimated angles in radians,
                broadcastable against `gt_in_radians`.

        Returns:
            Scalar tensor with the mean squared wrapped difference.
        """
        deltas = gt_in_radians - est_in_radians
        # Shift by pi/2, take the non-negative remainder modulo pi, shift back.
        # This maps every difference to its equivalent in [-pi/2, pi/2)
        # without in-place indexing, and stays differentiable.
        deltas = torch.remainder(deltas + np.pi / 2.0, np.pi) - np.pi / 2.0
        return torch.mean(deltas ** 2)

# run training + validation for each epoch
# Per-epoch history: each entry is [loss, rmse_in_degrees].
train_res = []
val_res = []
for e in range(n_epochs):
    train_loss, train_err_sqr = train_model(model, train_loader, optimizer, criterion, device)
    train_rmse = np.sqrt(train_err_sqr / n_train) * 180.0 / np.pi  # in degrees
    train_res.append([train_loss, train_rmse])

    val_loss, val_err_sqr = validate_model(model, val_loader, criterion, device)
    val_rmse = np.sqrt(val_err_sqr / n_val) * 180.0 / np.pi  # in degrees
    val_res.append([val_loss, val_rmse])

    print('Epoch {}/{}: train loss / rmse = {:.4f} / {:.1f}° validation loss / rmse = {:.4f} / {:.1f}°'
          .format(e+1, n_epochs, train_loss, train_rmse, val_loss, val_rmse))

    scheduler.step()

    # Periodic checkpoint every 100 epochs. The file name carries the number
    # of epochs completed so far (e + 1), so successive checkpoints no longer
    # overwrite each other or the final model file.
    if e % 100 == 99:
        model_path = model_dir.joinpath('foe_conv_net_c{}_b{}_e{:04d}.pt'
                                        .format(n_classes, batch_size, e + 1))
        torch.save(model.state_dict(), model_path)
        print('Saved model to {}'.format(model_path))

# save model also at the end, tagged with the total number of epochs
model_path = model_dir.joinpath('foe_conv_net_c{}_b{}_e{:04d}.pt'
                                .format(n_classes, batch_size, n_epochs))

torch.save(model.state_dict(), model_path)

In [None]:
# Each history row is [loss, rmse_in_degrees]; convert to arrays for column slicing.
train_res = np.array(train_res)
val_res = np.array(val_res)

# One figure per metric, with labelled axes and a legend so the train and
# validation curves can be told apart.
for column, ylabel in ((0, 'Loss'), (1, 'RMSE (degrees)')):
    plt.figure()
    plt.plot(train_res[:, column], label='Train')
    plt.plot(val_res[:, column], label='Valid')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()

In [None]:
# Spot check: for every validation batch, print the target and the angle
# (converted from radians to degrees) of its first sample and show channel 0
# of that sample in a new figure.
for x, y, theta in val_loader:
    first_target = y[0]
    first_theta_deg = theta[0] * 180.0 / np.pi
    print(first_target)
    print(first_theta_deg)

    plt.figure()
    plt.imshow(x[0, 0])