In [1]:
from IPython.display import display, HTML
# Stretch the notebook's main content container to the full browser width.
# NOTE(review): `.container` is the classic Notebook layout class; this likely has no effect in JupyterLab.
full_width_css = "<style>.container { width:100% !important; }</style>"
display(HTML(full_width_css))

In [2]:
# Automatically reload imported modules (e.g. `model`, `warp`) before each cell runs,
# so edits to their source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import time
import os
import cv2
import matplotlib.pyplot as plt
from ipdb import set_trace

from model import model
from warp import WarpImageWithFlowAndBrightness

In [4]:
class bcolors:
    """ANSI escape sequences used to colour/format text printed to the terminal.

    Wrap text as ``f"{bcolors.BOLD}text{bcolors.ENDC}"``; ``ENDC`` resets all
    formatting so it does not leak into subsequent output.
    """

    # Bright foreground colours (SGR 91-96)
    FAIL = '\033[91m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    OKBLUE = '\033[94m'
    HEADER = '\033[95m'
    OKCYAN = '\033[96m'

    # Text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Reset every attribute back to the terminal default
    ENDC = '\033[0m'

In [5]:
def get_filename_info(filename):
    """Parse the metadata encoded in an image's base filename (extension removed).

    The name must consist of six underscore-separated fields, each a fixed
    prefix immediately followed by its value, in this order::

        ID<int>_T<str>_N<int>_F<int>_V<float>_H<float>

    e.g. ``"ID1_Tgaussian_N1_F1_V0.1_H0.0"``.

    Args:
        filename: base filename without its extension.

    Returns:
        dict mapping each prefix ('ID', 'T', 'N', 'F', 'V', 'H') to its
        converted value, in that order.

    Raises:
        ValueError: if the name does not have exactly six fields, or a value
            cannot be converted to its expected type.
    """
    # Only the metadata fields are parsed here; 'img' and 'target' are added by the caller.
    parsers = {'ID': int, 'T': str, 'N': int, 'F': int, 'V': float, 'H': float}
    fields = filename.split('_')
    if len(fields) != len(parsers):
        raise ValueError(
            f"Expected {len(parsers)} '_'-separated fields in {filename!r}, got {len(fields)}"
        )
    # Each field is <prefix><value>: strip the prefix and convert the remainder.
    return {
        prefix: parse(field[len(prefix):])
        for (prefix, parse), field in zip(parsers.items(), fields)
    }

In [6]:
# Load every .bmp cutout in ./imgs_0_cutouts as a channel-first RGB array together
# with the metadata parsed from its filename. .json files are skipped; any other
# entry is treated as an error.
img_infos = []
input_file_path = os.path.join(os.getcwd(), 'imgs_0_cutouts')
# sorted() makes load order (and therefore any error raised) deterministic.
for filename in sorted(os.listdir(input_file_path)):
    file = os.path.join(input_file_path, filename)
    # splitext never raises on names without a dot (e.g. sub-directories);
    # those fall through to the invalid-file-type error below.
    file_no_type, extension = os.path.splitext(filename)
    if extension == '.json':
        continue
    if extension != '.bmp':
        raise ValueError(f"Invalid file type present in folder: {filename}")

    bgr = cv2.imread(file)
    if bgr is None:
        # cv2.imread reports unreadable/corrupt files by returning None rather than raising.
        raise ValueError(f"Could not read image file: {file}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (64, 32))      # dsize is (width, height) -> 32 rows x 64 cols
    img = np.transpose(img, (2, 0, 1))   # (H, W, C) -> (C, H, W)

    info = get_filename_info(file_no_type)
    info['img'] = img
    info['target'] = info['F'] == 1      # frame F == 1 is the target image for its ID
    img_infos.append(info)

In [7]:
# One row per image, ordered by ID and then frame number so each ID's frames are
# contiguous and in frame order. Built in one expression (no inplace mutation),
# so re-running this cell always yields the same frame.
img_df = (
    pd.DataFrame(img_infos)
    .sort_values(['ID', 'F'], ignore_index=True)
)
img_df.head()

Unnamed: 0,ID,T,N,F,V,H,img,target
0,1,gaussian,1,1,0.1,0.0,"[[[97, 91, 131, 134, 132, 133, 130, 130, 128, ...",True
1,1,gaussian,1,2,-8.26,-9.59,"[[[97, 91, 131, 134, 132, 133, 130, 130, 128, ...",False
2,1,gaussian,1,3,0.61,4.42,"[[[97, 91, 131, 134, 132, 133, 130, 130, 128, ...",False
3,1,gaussian,1,4,8.3,14.35,"[[[97, 91, 131, 134, 132, 133, 130, 130, 128, ...",False
4,1,gaussian,1,5,12.21,-5.26,"[[[97, 91, 131, 134, 132, 133, 130, 130, 128, ...",False


In [8]:
# Pair every non-target frame with the target (F == 1) image of the same ID.
# Model inputs: the frame image plus its (H, V) angles broadcast to image-sized planes.
# Label: the ID's target image.
source_rows = img_df[~img_df['target']]
target_rows = img_df[img_df['target']]

if not target_rows['ID'].is_unique:
    raise ValueError("Each ID must have exactly one target (F == 1) image")
target_img_by_id = dict(zip(target_rows['ID'], target_rows['img']))
missing_ids = set(source_rows['ID']) - set(target_img_by_id)
if missing_ids:
    raise ValueError(f"No target (F == 1) image for IDs: {sorted(missing_ids)}")

X_img = np.stack(source_rows['img'])
img_height, img_width = X_img.shape[-2:]
X_angle = source_rows[['H', 'V']].to_numpy()
# (N, 2) -> (N, 2, H, W): each angle becomes a constant plane matching the image size.
X_angle = np.tile(X_angle[:, :, np.newaxis, np.newaxis], (1, 1, img_height, img_width))
# Look up each row's target by ID rather than np.repeat(..., 39), so X and y stay
# aligned even when IDs have different numbers of frames.
y = np.stack([target_img_by_id[frame_id] for frame_id in source_rows['ID']])

In [9]:
# Sequential 70/20/10 train/valid/test split (no shuffling). The same two cut
# points are applied to images, angle planes and targets so their rows stay aligned.
# NOTE(review): rows follow img_df's ['ID', 'F'] ordering, so each split is a
# contiguous range of IDs, and an ID sitting on a cut point contributes frames
# to two splits — consider splitting by ID if that leakage matters.
num_samples = X_img.shape[0]
splits = [0.7, 0.2, 0.1]  # train / valid / test; test simply receives the remainder

train_end = int(num_samples * splits[0])
valid_end = int(num_samples * (splits[0] + splits[1]))
cut_points = [train_end, valid_end]

X_img_train, X_img_valid, X_img_test = np.split(X_img, cut_points)
X_angle_train, X_angle_valid, X_angle_test = np.split(X_angle, cut_points)
y_train, y_valid, y_test = np.split(y, cut_points)

In [10]:
def to_split_tensors(imgs, angles, targets):
    """Convert one split's numpy arrays to tensors.

    Images and targets become int32 tensors (the training loop casts them to
    float per batch); angle planes become float32.
    """
    img_tensor = torch.tensor(imgs, dtype=torch.int)
    angle_tensor = torch.tensor(angles, dtype=torch.float32)
    target_tensor = torch.tensor(targets, dtype=torch.int)
    return img_tensor, angle_tensor, target_tensor


X_img_train, X_angle_train, y_train = to_split_tensors(X_img_train, X_angle_train, y_train)
X_img_valid, X_angle_valid, y_valid = to_split_tensors(X_img_valid, X_angle_valid, y_valid)
X_img_test, X_angle_test, y_test = to_split_tensors(X_img_test, X_angle_test, y_test)

In [11]:
batch_size = 25

# Each dataset yields (image, angle planes, target image) triples.
train_dataset = TensorDataset(X_img_train, X_angle_train, y_train)
valid_dataset = TensorDataset(X_img_valid, X_angle_valid, y_valid)
test_dataset = TensorDataset(X_img_test, X_angle_test, y_test)

# Only the training split is shuffled; validation and test batches keep their order.
train_loader, valid_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    for dataset, shuffle in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
# NOTE(review): training is pinned to the CPU. To use a GPU, set
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# after confirming WarpImageWithFlowAndBrightness works on that device.
device = 'cpu'
model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 100
# Total loss = weighted sum of the correction loss (corrected frame vs. target
# image) and the reconstruction loss (re-applying the model with negated angles
# to the corrected frame vs. the original input frame).
weight_correction_loss = 0.8
weight_reconstruction_loss = 0.2

printout_freq = 1  # print a progress line every `printout_freq` epochs
start_time = time.time()

# NOTE(review): the warp module is constructed from one (shuffled) batch of
# training images — presumably only its shape is used; confirm in warp.py.
# If it depends on the batch size, a smaller final batch may not match.
warp = WarpImageWithFlowAndBrightness(next(iter(train_loader))[0])


def compute_cycle_loss(imgs, angles, targets):
    """Return the weighted correction + reconstruction loss for one batch.

    The model's two outputs for (frame, angles) are passed to `warp` to produce a
    corrected frame, compared against the target image. Running the model again
    on the corrected frame with negated angles and warping should reproduce the
    input frame. Images/targets are cast to float since the loaders store ints.
    """
    imgs = imgs.float().to(device)
    angles = angles.to(device)
    targets = targets.float().to(device)

    flow_corr, bright_corr = model(imgs, angles)
    img_corr = warp(imgs, flow_corr, bright_corr)
    loss_correction = criterion(img_corr, targets)

    flow_reconstruction, bright_reconstruction = model(img_corr, -angles)
    img_reconstruction = warp(img_corr, flow_reconstruction, bright_reconstruction)
    loss_reconstruction = criterion(img_reconstruction, imgs)

    return (weight_correction_loss * loss_correction
            + weight_reconstruction_loss * loss_reconstruction)


def train_one_epoch(loader):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for imgs, angles, targets in loader:
        loss = compute_cycle_loss(imgs, angles, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def evaluate(loader):
    """Return the mean batch loss over `loader` in eval mode without gradients (NaN if empty)."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for imgs, angles, targets in loader:
            total_loss += compute_cycle_loss(imgs, angles, targets).item()
    return total_loss / len(loader) if len(loader) else float('nan')


def format_duration(seconds):
    """Format a duration in seconds as '<d>d:<h>h:<m>m:<s.sss>s'."""
    days, rest = divmod(seconds, 86400)
    hours, rest = divmod(rest, 3600)
    minutes, secs = divmod(rest, 60)
    return f"{int(days)}d:{int(hours)}h:{int(minutes)}m:{secs:.3f}s"


train_losses, valid_losses = [], []
for epoch in range(num_epochs):
    epoch_start_time = time.time()
    # Per-batch mean losses; these are both stored in the history and printed.
    train_loss = train_one_epoch(train_loader)
    valid_loss = evaluate(valid_loader)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    epoch_time = time.time() - epoch_start_time
    overall_time = time.time() - start_time
    if (epoch + 1) % printout_freq == 0:
        print(f"""Finished epoch {epoch + 1}/{num_epochs} ({(epoch+1)/num_epochs*100:.2f}%)
               total time is {format_duration(overall_time)}; time for this epoch is {epoch_time:.2f}s
               training loss was {bcolors.BOLD}{train_loss:.3f}{bcolors.ENDC}, validation_loss was {bcolors.BOLD}{valid_loss:.3f}{bcolors.ENDC}.""")

# Final test-set evaluation. Kept inline rather than calling evaluate() so the
# last test batch's tensors (imgs, targets, img_corr, img_reconstruction, ...)
# remain available as notebook variables for inspection afterwards.
model.eval()
test_loss = 0
with torch.no_grad():
    for imgs, angles, targets in test_loader:
        imgs, angles, targets = imgs.float().to(device), angles.to(device), targets.float().to(device)

        flow_corr, bright_corr = model(imgs, angles)
        img_corr = warp(imgs, flow_corr, bright_corr)
        loss_correction = criterion(img_corr, targets)

        inverted_angles = -angles
        flow_reconstruction, bright_reconstruction = model(img_corr, inverted_angles)
        img_reconstruction = warp(img_corr, flow_reconstruction, bright_reconstruction)
        loss_reconstruction = criterion(img_reconstruction, imgs)

        loss = weight_correction_loss * loss_correction + weight_reconstruction_loss * loss_reconstruction
        test_loss += loss.item()
test_loss /= len(test_loader)
print(f'Test Loss: {test_loss}')