In [131]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [164]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import time
import os
import cv2
import matplotlib.pyplot as plt
from ipdb import set_trace

In [169]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        set_trace()
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class ResidualConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualConvBlock, self).__init__()
        self.conv1 = DepthwiseSeparableConv(in_channels, out_channels, 3)
        self.conv2 = DepthwiseSeparableConv(out_channels, out_channels, 3)
        self.conv3 = DepthwiseSeparableConv(out_channels, out_channels, 3)

    def forward(self, x):
        set_trace()
        # Residual skip connection, might need to add downsample depending on input and output channels
        residual = self.conv1(x)
        out = self.conv2(residual)
        out += residual
        out = self.conv3(out)
        return out

class ECCNet(nn.Module):
    def __init__(self):
        super(ECCNet, self).__init__()

        self.conv_block1 = ResidualConvBlock(5, 32)
        self.conv_block2 = ResidualConvBlock(32, 64)
        self.conv_block3 = ResidualConvBlock(64, 128)
        self.conv_block4 = ResidualConvBlock(128, 256)

        self.upconv_block3 = ResidualConvBlock(256 + 128, 128)
        self.upconv_block2 = ResidualConvBlock(128 + 64, 64)
        self.upconv_block1 = ResidualConvBlock(64 + 32, 32)

        self.out = nn.Conv2d(32, 3, kernel_size = 3, padding=1)

    def forward(self, img, angle):
        set_trace()
        x = torch.cat([img, angle], dim=-1)

        # Encoder
        x1 = self.conv_block1(x)
        x2 = F.max_pool2d(x1, 2)
        x2 = self.conv_block2(x2)
        x3 = F.max_pool2d(x2, 2)
        x3 = self.conv_block3(x3)
        x4 = F.max_pool2d(x3, 2)
        x4 = self.conv_block4(x4)

        # Decoder with skip connections
        x = F.interpolate(x4, scale_factor=2, mode='nearest')
        x = torch.cat([x, x3], dim=1)
        x = self.upconv_block3(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = torch.cat([x, x2], dim=1)
        x = self.upconv_block2(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = torch.cat([x, x1], dim=1)
        x = self.upconv_block1(x)

        output = self.out(x)
        flow = output[:, :2, :, :]
        brightness_map = torch.sigmoid(output[:, 2, :, :])

        return flow, brightness_map

In [133]:
def get_filename_info(filename):
    info_dict = {}
    titles_types = {'ID': int, 'T': str, 'N': int, 'F': int, 'V': float, 'H': float, 'img': None, 'target': None}
    info_list = filename.split('_')
    for title, info in zip(titles_types, info_list):
        info_dict[title] = titles_types[title](info[len(title):])
    return info_dict

In [134]:
img_infos = []
input_file_path = os.path.join(os.getcwd(), 'imgs_0_cutouts')
for filename in sorted(os.listdir(input_file_path)):
    file = os.path.join(input_file_path, filename)
    file_no_type = filename[:filename.rindex('.')]
    if file[-4:] == '.bmp':
        img = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (64, 32))
        info = get_filename_info(file_no_type)
        info['img'] = img
        info['target'] = info['F'] == 1
        img_infos.append(info)
    elif file[-5:] == '.json':
        continue
    else:
        print('Invalid file present:', filename)
        break

In [135]:
img_df = pd.DataFrame(img_infos)
img_df.sort_values(['ID', 'F'], ignore_index=True, inplace=True)

In [136]:
img_df.head()

Unnamed: 0,ID,T,N,F,V,H,img,target
0,1,gaussian,1,1,0.1,0.0,"[[[97, 97, 97], [91, 98, 105], [131, 103, 85],...",True
1,1,gaussian,1,2,-8.26,-9.59,"[[[97, 97, 97], [91, 98, 105], [131, 103, 85],...",False
2,1,gaussian,1,3,0.61,4.42,"[[[97, 97, 97], [91, 98, 105], [131, 103, 85],...",False
3,1,gaussian,1,4,8.3,14.35,"[[[97, 97, 97], [91, 98, 105], [131, 103, 85],...",False
4,1,gaussian,1,5,12.21,-5.26,"[[[97, 97, 97], [91, 98, 105], [131, 103, 85],...",False


In [137]:
X_img = np.stack(img_df.query('target == False')['img'])
X_angle = img_df.query('target == False')[['H', 'V']].to_numpy()
X_angle = np.tile(X_angle[:, np.newaxis, np.newaxis, :], (1, 32, 64, 1))
y = np.stack(img_df.query('target == True')['img'])
y = np.repeat(y, 39, axis=0)

In [138]:
num_samples = X_img.shape[0]
split_proportions = [0.7, 0.2, 0.1]

X_img_train, X_img_valid, X_img_test = np.split(X_img, [int(num_samples * split_proportions[0]), int(num_samples * (split_proportions[0] + split_proportions[1]))])
X_angle_train, X_angle_valid, X_angle_test = np.split(X_angle, [int(num_samples * split_proportions[0]), int(num_samples * (split_proportions[0] + split_proportions[1]))])
y_train, y_valid, y_test = np.split(y, [int(num_samples * split_proportions[0]), int(num_samples * (split_proportions[0] + split_proportions[1]))])

In [139]:
X_img_train = torch.tensor(X_img_train, dtype=torch.int)
X_angle_train = torch.tensor(X_angle_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.int)

X_img_valid = torch.tensor(X_img_valid, dtype=torch.int)
X_angle_valid = torch.tensor(X_angle_valid, dtype=torch.float32)
y_valid = torch.tensor(y_valid, dtype=torch.int)

X_img_test = torch.tensor(X_img_test, dtype=torch.int)
X_angle_test = torch.tensor(X_angle_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.int)

In [140]:
batch_size = 25
train_dataset = TensorDataset(X_img_train, X_angle_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = TensorDataset(X_img_valid, X_angle_valid, y_valid)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataset = TensorDataset(X_img_test, X_angle_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [170]:
model = ECCNet()

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 100
weight_correction_loss = 0.8
weight_reconstruction_loss = 0.2

printout_freq = 4
start_time = time.time()

train_losses, valid_losses = [], []
for epoch in range(num_epochs):
    epoch_start_time = time.time()
    model.train()
    train_loss = 0
    for imgs, angles, targets in train_loader:
        imgs, angles, targets = imgs.to(device), angles.to(device), targets.to(device)

        outputs_correction = model(imgs, angles)
        loss_correction = criterion(outputs_correction, targets)

        inverted_angles = -angles
        outputs_reconstruction = model(outputs_correction, inverted_angles)
        loss_reconstruction = criterion(outputs_reconstruction, inputs)

        loss = (weight_correction_loss * loss_correction) + (weight_reconstruction_loss) * loss_reconstruction
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_losses.append(train_loss / len(train_loader))

    model.eval()
    with torch.no_grad():
        valid_loss = 0
        for imgs, angles, targets in valid_loader:
            imgs, angles, targets = imgs.to(device), angles.to(device), targets.to(device)

            outputs_correction = model(imgs, angles)
            loss_correction = criterion(outputs_correction, targets)

            inverted_angles = -angles
            outputs_reconstruction = model(outputs_correction, inverted_angles)
            loss_reconstruction = criterion(outputs_reconstruction, inputs)

            loss = (weight_correction_loss * loss_correction) + (weight_reconstruction_loss) * loss_reconstruction
            valid_loss += loss.item()
        valid_losses.append(valid_loss / len(valid_loader))
        
    epoch_time = time.time() - epoch_start_time
    overall_time = time.time() - start_time
    num_days = int(overall_time / 86400)
    num_hrs = int((overall_time-(86400*num_days)) / 3600)
    num_mins = int((overall_time-(86400*num_days)-(3600*num_hrs)) / 60)
    num_secs = overall_time-(86400*num_days)-(3600*num_hrs)*(60*num_mins)
    if (epoch + 1) % printout_freq == 0:
        print(f"""Finished epoch {epoch + 1}/{num_epochs} ({(epoch+1)/num_epochs*100:.2f}%)
               total time is {num_days}d:{num_hrs}h:{num_mins}m:{num_secs:.3f}s; time for this epoch is {epoch_time:.2f}s
               training loss was {bcolors.BOLD}{train_loss:.3f}{bcolors.ENDC}, validation_loss was {bcolors.BOLD}{valid_loss:.3f}{bcolors.ENDC}.""")

model.eval()
test_loss = 0
with torch.no_grad():
    for imgs, angles, targets in test_loader:
        imgs, angles, targets = imgs.to(device), angles.to(device), targets.to(device)

        outputs_correction = model(imgs, angles)
        loss_correction = criterion(outputs_correction, targets)

        inverted_angles = -angles
        outputs_reconstruction = model(outputs_correction, inverted_angles)
        loss_reconstruction = criterion(outputs_reconstruction, inputs)

        loss = (weight_correction_loss * loss_correction) + (weight_reconstruction_loss) * loss_reconstruction
        test_loss += loss.item()
test_loss /= len(test_loader)
print(f'Test Loss: {test_loss}')

> [1;32mc:\users\rickr\appdata\local\temp\ipykernel_23120\4058938542.py[0m(54)[0;36mforward[1;34m()[0m

ipdb> l
[0;32m     49 [0m[1;33m[0m[0m
[0;32m     50 [0m        [0mself[0m[1;33m.[0m[0mout[0m [1;33m=[0m [0mnn[0m[1;33m.[0m[0mConv2d[0m[1;33m([0m[1;36m32[0m[1;33m,[0m [1;36m3[0m[1;33m,[0m [0mkernel_size[0m [1;33m=[0m [1;36m3[0m[1;33m,[0m [0mpadding[0m[1;33m=[0m[1;36m1[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[0;32m     51 [0m[1;33m[0m[0m
[0;32m     52 [0m    [1;32mdef[0m [0mforward[0m[1;33m([0m[0mself[0m[1;33m,[0m [0mimg[0m[1;33m,[0m [0mangle[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m
[0;32m     53 [0m        [0mset_trace[0m[1;33m([0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[1;32m---> 54 [1;33m        [0mx[0m [1;33m=[0m [0mtorch[0m[1;33m.[0m[0mcat[0m[1;33m([0m[1;33m[[0m[0mimg[0m[1;33m,[0m [0mangle[0m[1;33m][0m[1;33m,[0m [0mdim[0m[1;33m=[0m[1;33m-[0m[1;36m1[0m[1;33m

BdbQuit: 

In [154]:
img = np.random.rand(32, 64, 3)
angle = np.random.rand(32, 64, 2)
img, angle = torch.tensor(img), torch.tensor(angle)
torch.cat([img, angle], dim=-1).shape

torch.Size([32, 64, 5])