In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

from PIL import Image

In [2]:
df = pd.read_csv('df.csv', index_col=0)

In [3]:
import torch
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [4]:
from autoencoder import Autoencoder

  warn(f"Failed to load image Python extension: {e}")


In [5]:
a = Autoencoder().to(device)

In [6]:
from data import get_split_df
from torchvision.transforms import Resize, ToTensor

resize = Resize((224))
to_tensor = ToTensor()

transforms = [to_tensor, resize]

train_dataset, test_dataset = get_split_df('df.csv', 
                                           transform=transforms, 
                                           target_transform=transforms, 
                                           stds=[0.1, 0.01])

len(train_dataset) + len(test_dataset) == len(df)

True

In [7]:
from torch.utils.data import DataLoader

bs = 12

dataloader_kwargs = {'batch_size': bs, 
                     'num_workers': 4,
                     'prefetch_factor': 2,
                     'persistent_workers': True,
                     'shuffle': True, 
                     'pin_memory': False}

train_dataloader = DataLoader(train_dataset, **dataloader_kwargs)
test_dataloader = DataLoader(test_dataset, **dataloader_kwargs)

In [8]:
class run_logger():

    def __init__(self,
                **kwargs):
                     
        self._init(**kwargs)
        

    def _init(self,
             run_id, 
             n_epoch, 
             n_iter_train, 
             n_iter_test):
                 
        self.run_id = run_id
        
        self.log_dir = os.path.join(f'logs/{self.run_id}')
        os.makedirs(self.log_dir, exist_ok=False)
        
        self.train_loss_matrix = np.zeros((n_epoch, n_iter_train))
        self.train_losses_outpath = os.path.join(self.log_dir, 'train_losses.npy')
        
        self.test_loss_matrix = np.zeros((n_epoch, n_iter_test))
        self.test_losses_outpath = os.path.join(self.log_dir, 'test_losses.npy')

    def save_losses(self):
        np.save(self.train_losses_outpath, self.train_loss_matrix)
        np.save(self.test_losses_outpath, self.test_loss_matrix)

In [9]:
# Reconstruction losses:
# https://research.nvidia.com/sites/default/files/pubs/2017-03_Loss-Functions-for/NN_ImgProc.pdf
MSE = torch.nn.MSELoss()
MAE = torch.nn.L1Loss()
criterion = MAE

opt = torch.optim.Adam(a.parameters(), lr=3e-4)

In [10]:
from tqdm import tqdm
from uuid import uuid4

In [11]:
n_epoch = 56

run_id = str(uuid4())
logger = run_logger(run_id=run_id, 
                    n_epoch=n_epoch, 
                    n_iter_train=len(train_dataloader), 
                    n_iter_test=len(test_dataloader))

In [12]:
change_strategy_window = 50
lr_rrf = 0.95

for epoch in range(n_epoch):
    
    a = a.train(True)
    
    train_loop = tqdm(train_dataloader)

    # latest_lr = opt.defaults['lr']

    # TRAIN LOOP
    for i, (x, y) in enumerate(train_loop):
        opt.zero_grad()
        
        x = x.to(device)
        y = y.to(device)

        if len(x.shape) > 4:
            
            x = x.reshape((-1, 3, 224, 224))
            y = y.reshape((-1, 3, 224, 224))
            
        out = a(y)
        
        batch_MSE = MSE(out, x)
        batch_loss = criterion(out, x) if criterion == MAE else batch_MSE
        
        logger.train_loss_matrix[epoch][i] = batch_MSE.item()

        # # Change strategy if the loss hasn't dropped within a window of batches
        # if i>change_strategy_window:
        #     if logger.train_loss_matrix[epoch][i-change_strategy_window:i].min() == logger.train_loss_matrix[epoch][i-change_strategy_window]:
        #         opt.defaults['lr'] *= lr_rrf

        batch_loss.backward()
        opt.step()

        train_loop.set_description(f"Epoch [{epoch}/{n_epoch}]")
        train_loop.set_postfix({'MSE_train': np.mean(logger.train_loss_matrix[epoch][logger.train_loss_matrix[epoch]>0]),
                         'criterion': 'MSE' if criterion == MSE else 'MAE',
                         'lr': opt.defaults['lr']})

    
    # TEST LOOP
    a = a.train(False)
    
    test_loop = tqdm(test_dataloader)

    with torch.no_grad():
        for i, (x, y) in enumerate(test_loop):
            
            x = x.to(device)
            y = y.to(device)
            out = a(y)
            
            test_batch_loss = MSE(out, x)
            logger.test_loss_matrix[epoch][i] = test_batch_loss.item()
    
            test_loop.set_description(f"Epoch [{epoch}/{n_epoch}]")
            test_loop.set_postfix({'MSE_test': np.mean(logger.test_loss_matrix[epoch][logger.test_loss_matrix[epoch]>0]),})

    
    epoch_test_loss = np.mean(logger.test_loss_matrix[epoch])

    # Save this model if it had the best mean test loss
    if np.argmin(np.mean(logger.test_loss_matrix[:epoch+1], axis=-1)) == epoch:
        model_out_path = os.path.join(logger.log_dir, 'best.pt')
        torch.save(a.state_dict(), model_out_path)
        print(f'saving best, MSE_test: {epoch_test_loss:.6f}')
    else:
        # opt.defaults['lr'] = latest_lr
        criterion = MSE if criterion == MAE else MAE
        opt.defaults['lr'] *= 0.75

    logger.save_losses()

Epoch [0/56]:   4%|▍         | 16/360 [00:11<04:16,  1.34it/s, MSE_train=0.0823, criterion=MAE, lr=0.0003]


KeyboardInterrupt: 

In [None]:
fig, ax = plt.subplots(nrows=bs, ncols=3, figsize=(12,4*bs))

a.train(False)
a.load_state_dict(torch.load(model_out_path))

x, y = next(iter(test_dataloader))
x = x.to(device)
y = y.to(device)
out = a(y)

for i in range(x.shape[0]):

    _x = x[i].permute(1,2,0).cpu().detach().numpy()
    _out = out[i].permute(1,2,0).cpu().detach().numpy()
    _y = y[i].permute(1,2,0).cpu().detach().numpy()
    
    ax[i][0].imshow(_x)
    ax[i][1].imshow(_out)
    ax[i][2].imshow(_y)
    
    ax[i][0].set_title('x')
    ax[i][1].set_title('x_prime')
    ax[i][2].set_title('y')
    
    ax[i][0].axis('off')
    ax[i][1].axis('off')
    ax[i][2].axis('off')