## Data Visualization

In [None]:
import os
import shutil
from PIL import Image
import numpy as np
from scipy import io
import h5py
import pprint
import random

In [None]:
# Input files: the labeled NYU Depth V2 dataset (opened with h5py below) and
# the file holding the official train/test split indices.
nyud_file_path = './data/nyu_depth_v2_labeled.mat'
splits_file_path = './data/splits.mat'

In [None]:
def get_dataset(source_dir, target_dir, splits_path=None):
    """Open the labeled NYU Depth V2 file and load its train/test splits.

    Args:
        source_dir: path of the labeled ``.mat`` file, opened read-only with
            h5py (the notebook passes ``nyud_file_path``). When ``None``, the
            module-level ``nyud_file_path`` is used.
        target_dir: not read by this function; kept for interface compatibility.
        splits_path: path of the splits ``.mat`` file, loaded with
            ``scipy.io.loadmat``. Defaults to the module-level ``splits_file_path``.

    Returns:
        ``(nyud_dict, splits_dict)``: the open ``h5py.File`` and the dict from
        ``scipy.io.loadmat``. The h5py file is deliberately left open because
        later cells read arrays from it.
    """
    nyud_path = nyud_file_path if source_dir is None else source_dir
    if splits_path is None:
        splits_path = splits_file_path

    print("Loading dataset: NYU Depth V2")
    nyud_dict = h5py.File(nyud_path, 'r')
    splits_dict = io.loadmat(splits_path)
    return nyud_dict, splits_dict

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Open the dataset files and list the top-level entries of the HDF5 file.
# (target_dir is passed through but is not read by get_dataset.)
target_dir = '/content/nyu_depth_v2/'
nyud_dict, splits_dict = get_dataset(nyud_file_path, target_dir)
pprint.PrettyPrinter().pprint(nyud_dict.keys())

In [None]:
# Read every RGB image into memory and swap the last two axes. Afterwards each
# images[i] is presumably (channels, rows, cols): later cells transpose it with
# (1, 2, 0) for display.
images = np.asarray(nyud_dict['images']).swapaxes(2, 3)
images.shape

In [None]:
# Same spatial axis swap for the depth maps, then insert a singleton axis at
# position 1 so every depths[i] has a leading one-element "channel" axis.
depths = np.expand_dims(np.asarray(nyud_dict['depths']).swapaxes(1, 2), axis=1)
depths.shape

In [None]:
# Official split. The stored indices are shifted down by one to get 0-based
# positions (presumably splits.mat uses MATLAB's 1-based indexing).
train_indices = np.subtract(splits_dict['trainNdxs'][:, 0], 1)
print("Training Data Size: ", train_indices.shape[0])
test_indices = np.subtract(splits_dict['testNdxs'][:, 0], 1)
print("Testing Data Size: ", test_indices.shape[0])

In [None]:
#train_images = np.take(images, train_indices, axis=0)
#test_images = np.take(images, test_indices, axis=0)

#print(train_images.shape)

#train_depths = np.take(depths, train_indices, axis=0)
#test_depths = np.take(depths, test_indices, axis=0)

#print(train_depths.shape)

**Sample Data Visualization** (the first image/depth pairs of the full labeled set)

In [None]:
print(len(images))
fig = plt.figure(figsize=(20,20))
# Seven samples in a 5x4 grid: each RGB image is immediately followed by its
# depth map (plasma colormap).
for sample_idx in range(7):
    panels = (
        (images[sample_idx], {}),
        (depths[sample_idx], {'cmap': 'plasma'}),
    )
    for offset, (panel, extra) in enumerate(panels):
        plt.subplot(5, 4, 2 * sample_idx + offset + 1)
        plt.imshow(panel.transpose((1, 2, 0)), interpolation='none', **extra)
        plt.xticks([])
        plt.yticks([])

fig.tight_layout()
fig.show()

## Data Loaders

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import hflip
from torchvision import datasets

In [None]:
class NYUDepthDataset(torch.utils.data.Dataset):
    """Map-style dataset of (RGB image, depth map) tensor pairs.

    Args:
        images: indexable collection of per-sample image arrays. Values are
            divided by 255, so presumably they are 0-255 — confirm. Augmentation
            indexes them as ``[[2, 1, 0], :, :]``, so they are channel-first with
            at least 3 channels.
        maps: indexable collection of per-sample depth arrays, returned as
            float without rescaling.
        transform: optional callable applied separately to the image tensor
            and to the depth tensor.
        augment: when True (the default, matching the original behaviour),
            each sample is horizontally flipped with probability 0.5. A flipped
            image also has its first three channels reversed. Pass ``False``
            for validation/test data so samples are returned unaugmented.
    """

    def __init__(self, images, maps, transform=None, augment=True):
        self.images = images
        self.maps = maps
        self.transform = transform
        self.augment = augment

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = torch.from_numpy(self.images[index]).float().div(255)
        dmap = torch.from_numpy(self.maps[index]).float()

        if self.transform:
            image = self.transform(image)
            dmap = self.transform(dmap)

        if self.augment and random.random() > 0.5:
            # Flip image and depth together so they stay pixel-aligned.
            image = hflip(image)
            # NOTE(review): the channel reversal only happens together with a
            # flip (never on its own); confirm this coupling is intended.
            image = image[[2, 1, 0], :, :]
            dmap = hflip(dmap)

        return image, dmap

In [None]:
# Train on a fixed subset: the first 1000 samples by position. The official
# split computed above (train_indices / test_indices) is NOT consulted, so this
# subset may contain images that belong to the official test split.
# Official-split alternative:
#   train_data = NYUDepthDataset(np.take(images, train_indices, axis=0), np.take(depths, train_indices, axis=0))
#   val_data = NYUDepthDataset(np.take(images, test_indices, axis=0), np.take(depths, test_indices, axis=0))
train_subset = np.arange(1000)
train_data = NYUDepthDataset(
    np.take(images, train_subset, axis=0),
    np.take(depths, train_subset, axis=0),
)

In [None]:
# Pull sample 0 straight from the dataset (augmentation may apply) as NumPy.
img, dmap = (tensor.numpy() for tensor in train_data[0])

In [None]:
(img.shape, dmap.shape)

In [None]:
fig = plt.figure(figsize=(20,20))
# RGB image on the left, its depth map (plasma) on the right.
for pos, (panel, extra) in enumerate(((img, {}), (dmap, {'cmap': 'plasma'})), start=1):
    plt.subplot(1, 2, pos)
    plt.imshow(panel.transpose((1, 2, 0)), interpolation='none', **extra)
    plt.xticks([])
    plt.yticks([])
fig.tight_layout()
fig.show()

In [None]:
# Shuffled mini-batches of two (image, depth) pairs.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=2, shuffle=True)

In [None]:
# Grab one batch from the loader and convert both tensors to NumPy.
batch_images, batch_depths = next(iter(train_loader))
img, dmap = batch_images.numpy(), batch_depths.numpy()

In [None]:
(img.shape, dmap.shape)

In [None]:
fig = plt.figure(figsize=(20,20))
# First element of the batch: RGB image (left) and depth map (right).
first_rgb, first_depth = img[0], dmap[0]
for pos, (panel, extra) in enumerate(((first_rgb, {}), (first_depth, {'cmap': 'plasma'})), start=1):
    plt.subplot(1, 2, pos)
    plt.imshow(panel.transpose((1, 2, 0)), interpolation='none', **extra)
    plt.xticks([])
    plt.yticks([])
fig.tight_layout()
fig.show()

## Network

In [None]:
import torch.nn as nn
from torch.nn import functional as F
from torchvision import models

In [None]:
def get_backbone(name, pretrained=True):
    """Build an encoder backbone and name the children the U-Net taps.

    Args:
        name: ``'densenet169'`` or ``'resnet50'``.
        pretrained: forwarded to the torchvision model constructor.

    Returns:
        ``(backbone, feature_names, backbone_output)``:
        - ``backbone``: the module whose named children are run in order.
        - ``feature_names``: the children whose outputs become skip
          connections. ``None`` marks the decoder stage without a skip.
        - ``backbone_output``: the name of the last child to run.

    Raises:
        NotImplementedError: for any other ``name``.
    """
    if name == 'densenet169':
        # Only the convolutional trunk (``.features``) is used as the encoder.
        backbone = models.densenet169(pretrained=pretrained).features
        feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
        backbone_output = 'denseblock4'
    elif name == 'resnet50':
        # The full model is returned; children after 'layer4' (avgpool, fc)
        # are never reached because iteration stops at backbone_output.
        backbone = models.resnet50(pretrained=pretrained)
        feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
        backbone_output = 'layer4'
    else:
        raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))

    return backbone, feature_names, backbone_output

In [None]:
class UpsampleBlock(nn.Module):
    """Decoder stage: 2x upsampling, optional skip concatenation, conv refinement.

    Args:
        ch_in: channels of the incoming feature map.
        ch_out: output channels; defaults to ``ch_in // 2``.
        skip_in: channels of the skip tensor concatenated after upsampling
            (0 when the stage has no skip connection).
        use_bn: add BatchNorm after each convolution (the convs then have no bias).
        parametric: if True, upsample with a learned stride-2 ``ConvTranspose2d``.
            Otherwise use a bilinear resize followed by a 3x3 convolution.
    """

    def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=True):
        super(UpsampleBlock, self).__init__()

        self.parametric = parametric
        # Channel counts must be ints: ``ch_in / 2`` is a float, which the
        # conv/BatchNorm constructors cannot use to allocate their weights.
        ch_out = ch_in // 2 if ch_out is None else ch_out

        if parametric:
            # kernel 4, stride 2, padding 1 gives exactly twice the input size.
            self.up = nn.ConvTranspose2d(in_channels=ch_in,
                                         out_channels=ch_out,
                                         kernel_size=(4,4),
                                         stride=2,
                                         padding=1,
                                         output_padding=0,
                                         bias=(not use_bn))
        else:
            # Upsampling is a fixed bilinear resize in forward(); the skip
            # tensor is concatenated before conv1, so conv1 sees both.
            self.up = None
            ch_in += skip_in
            self.conv1 = nn.Conv2d(in_channels=ch_in,
                                   out_channels=ch_out,
                                   kernel_size=(3,3),
                                   stride=1,
                                   padding=1,
                                   bias=(not use_bn))

        self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
        self.relu = nn.ReLU(inplace=True)

        # In the parametric path the skip is concatenated right before conv2.
        conv2_in = ch_out if not parametric else ch_out + skip_in
        self.conv2 = nn.Conv2d(in_channels=conv2_in,
                               out_channels=ch_out,
                               kernel_size=(3,3),
                               stride=1,
                               padding=1,
                               bias=(not use_bn))
        self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None

    def forward(self, x, skip_connection=None):
        """Upsample ``x`` by 2, optionally concatenate ``skip_connection``
        along channels (it must match the upsampled spatial size), and refine."""
        if self.parametric:
            x = self.up(x)
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)
        else:
            # align_corners=False is what F.interpolate uses when given None.
            x = F.interpolate(x, size=None, scale_factor=2, mode='bilinear', align_corners=False)

        if skip_connection is not None:
            x = torch.cat([x, skip_connection], dim=1)

        if not self.parametric:
            x = self.conv1(x)
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x) if self.bn2 is not None else x
        x = self.relu(x)

        return x

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built on a torchvision backbone.

    The encoder is the module returned by ``get_backbone``. Outputs of the
    children named in ``skip_features`` feed the decoder, a stack of
    ``UpsampleBlock`` stages followed by a 1x1 convolution.

    Args:
        backbone_name: passed to ``get_backbone``.
        pretrained: passed to ``get_backbone``.
        encoder_freeze: if True, backbone parameters get ``requires_grad=False``.
        input_size: ``(C, H, W)`` of the dummy tensor used to infer the skip
            and backbone-output channel counts.
        classes: output channels of the final 1x1 convolution.
        decoder_filters: output channels of successive decoder stages,
            truncated to the number of skip features.
        parametric_upsampling: forwarded as ``UpsampleBlock(parametric=...)``.
        decoder_batchnorm: forwarded as ``UpsampleBlock(use_bn=...)``.
    """

    def __init__(self,
                 backbone_name='densenet169',
                 pretrained=True,
                 encoder_freeze=True,
                 input_size=(3, 480, 640),
                 classes=1,
                 decoder_filters=(256, 128, 64, 32, 16),
                 parametric_upsampling=True,
                 decoder_batchnorm=True):
        super(UNet, self).__init__()

        # encoder
        self.backbone_name = backbone_name
        self.input_size = input_size
        self.backbone, self.skip_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained)

        skip_chs, bb_out_chs = self.infer_skip_channels(input_size)

        # decoder: one UpsampleBlock per entry of skip_features
        self.upsample_blocks = nn.ModuleList()

        decoder_filters = decoder_filters[:len(self.skip_features)]
        # Each stage consumes the previous stage's output; the first consumes
        # the backbone output.
        decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])

        num_blocks = len(self.skip_features)

        for i, (filters_in, filters_out) in enumerate(zip(decoder_filters_in, decoder_filters)):
            # forward() walks skip_features in reverse, so the channel list is
            # indexed from the end as well.
            self.upsample_blocks.append(UpsampleBlock(filters_in,
                                                      filters_out,
                                                      skip_in=skip_chs[num_blocks-i-1],
                                                      parametric=parametric_upsampling,
                                                      use_bn=decoder_batchnorm))

        self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1,1))

        if encoder_freeze:
            self.freeze_encoder()

    def infer_skip_channels(self, input_size):
        """Probe the backbone with a zeros tensor to read off channel counts.

        Returns:
            ``(channels, out_channels)``. ``channels`` is ``[0]`` followed by the
            channel count of each tapped child, in backbone order. The leading
            0 pairs with a ``None`` entry in ``skip_features``. ``out_channels``
            is the channel count of the child named ``self.bb_out_name``.

        Raises:
            ValueError: if ``self.bb_out_name`` is not a child of the backbone.
        """
        x = torch.unsqueeze(torch.zeros(input_size), 0)
        channels = [0]
        out_channels = None

        # Probe in eval mode and without autograd. In train mode, BatchNorm
        # layers would fold the all-zeros batch into their (pretrained) running
        # statistics, and autograd would keep a graph for a full-size forward pass.
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                for name, child in self.backbone.named_children():
                    x = child(x)
                    if name in self.skip_features:
                        channels.append(x.shape[1])
                    if name == self.bb_out_name:
                        out_channels = x.shape[1]
                        break
        finally:
            self.backbone.train(was_training)

        if out_channels is None:
            raise ValueError('backbone has no child named {!r}'.format(self.bb_out_name))
        return channels, out_channels

    def freeze_encoder(self):
        """Disable gradients for every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, *inputs):
        """Encode, decode with skip connections, and apply ``final_conv``."""
        x, features = self.forward_backbone(*inputs)

        # Deepest skip first; a None entry looks up features[None] == None,
        # i.e. that stage gets no skip tensor.
        for skip_name, upsample_block in zip(self.skip_features[::-1], self.upsample_blocks):
            x = upsample_block(x, features[skip_name])

        return self.final_conv(x)

    def forward_backbone(self, x):
        """Run backbone children up to ``bb_out_name``, collecting skip outputs.

        Returns:
            ``(x, features)``: the output of ``bb_out_name`` and a dict that maps
            each tapped child name to its output. It also holds ``None -> None``
            when ``skip_features`` contains ``None``.
        """
        features = {None: None} if None in self.skip_features else dict()
        for name, child in self.backbone.named_children():
            x = child(x)
            if name in self.skip_features:
                features[name] = x
            if name == self.bb_out_name:
                break

        return x, features

In [None]:
# Default DenseNet-169 U-Net with a trainable (not frozen) encoder.
net = UNet(backbone_name='densenet169', encoder_freeze=False)

In [None]:
net.to(torch.device('cuda'))

In [None]:
# Smoke-test the untrained network on one batch. eval() makes BatchNorm use its
# stored statistics instead of updating them. no_grad() avoids building an
# autograd graph that `output` would otherwise keep alive on the GPU.
# (The training cell calls net.train() before training.)
net.eval()
img, dmap = next(iter(train_loader))
with torch.no_grad():
    output = net(img.cuda())

In [None]:
output.shape

In [None]:
# Negate and offset the raw prediction before plotting it.
# NOTE(review): purpose unclear — the training loop targets 1000 / depth, not
# 1 - x; confirm what this display transform is meant to show.
output = 1 - output

In [None]:
dmap.shape

In [None]:
fig = plt.figure(figsize=(20,20))
# Ground-truth depth (left) next to the network output (right), both plasma.
target_np = dmap[0].numpy()
pred_np = output[0].detach().cpu().numpy()
for pos, panel in enumerate((target_np, pred_np), start=1):
    plt.subplot(1, 2, pos)
    plt.imshow(panel.transpose((1, 2, 0)), cmap='plasma', interpolation='none')
    plt.xticks([])
    plt.yticks([])
fig.tight_layout()
fig.show()

## Losses

In [None]:
# Pixel-wise L1 term of the training loss, averaged over all elements.
l1_criterion = nn.L1Loss(reduction='mean')

In [None]:
def gradient_loss(gen_frames, gt_frames, alpha=1):
    """Mean difference between finite-difference gradients of two images.

    Follows the idea of ``tf.image.image_gradients``: forward differences along
    width (dx) and height (dy). The missing last column of dx and last row of
    dy are zero, so both keep the input shape.

    Args:
        gen_frames: predicted tensor of shape ``(..., H, W)``, e.g. ``(B, C, H, W)``.
        gt_frames: target tensor with the same shape.
        alpha: exponent applied to the absolute gradient differences.

    Returns:
        Scalar tensor: the mean over all elements of
        ``|gt_dx - gen_dx| ** alpha + |gt_dy - gen_dy| ** alpha``.
    """

    def gradient(x):
        # idea from tf.image.image_gradients(image)
        # https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/image_ops_impl.py#L3441-L3512
        # Forward differences, zero-padded on the right/bottom so dx and dy
        # match x's shape (and the mean's denominator). Built out-of-place, so
        # there are no in-place writes on autograd tensors.
        dx = F.pad(x[..., :, 1:] - x[..., :, :-1], [0, 1, 0, 0])
        dy = F.pad(x[..., 1:, :] - x[..., :-1, :], [0, 0, 0, 1])
        return dx, dy

    gen_dx, gen_dy = gradient(gen_frames)
    gt_dx, gt_dy = gradient(gt_frames)

    grad_diff_x = torch.abs(gt_dx - gen_dx)
    grad_diff_y = torch.abs(gt_dy - gen_dy)

    # condense into one tensor and average
    return torch.mean(grad_diff_x ** alpha + grad_diff_y ** alpha)

In [None]:
from loss import ssim

In [None]:
# Adam over all of net's parameters.
# NOTE(review): 0.01 is a large learning rate for fine-tuning a pretrained
# encoder with Adam; consider something closer to 1e-4.
lr = 0.01
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

## Training Loop

In [None]:
epochs = 10  # number of full passes over train_loader in the training loop

In [None]:
# Running mean of the batch loss. The 1 / (i + 1) weight equals 1 at i == 0,
# so the mean restarts at the beginning of every epoch.
losses = 0.0
net.train()
N = len(train_loader)

for epoch in range(epochs):
    for i, (image, depth) in enumerate(train_loader):
        image, depth = image.cuda(), depth.cuda()
        # Inverse-depth target used by every loss term.
        # NOTE(review): val_range=1000/10 below assumes depth_n lies in about
        # [1, 100], i.e. depth values in [10, 1000]. Confirm the units of
        # `depths`, and that they contain no zeros (which would give inf here).
        depth_n = 1000.0/depth

        optimizer.zero_grad()

        output = net(image)

        # All three terms compare the prediction against the same target.
        l_depth = l1_criterion(output, depth_n)
        # (1 - SSIM) / 2 maps SSIM from [-1, 1] to a dissimilarity in [0, 1].
        # The parentheses matter: `1 - ssim * 0.5` would lie in [0.5, 1.5].
        l_ssim = torch.clamp((1 - ssim(output, depth_n, val_range=1000.0/10.0)) * 0.5, 0, 1)
        l_grad = gradient_loss(output, depth_n)

        loss = (0.1 * l_depth) + (1.0 * l_ssim) + (1.0 * l_grad)

        loss.backward()
        optimizer.step()

        # `loss` is already a batch average (L1Loss and gradient_loss take a
        # mean; presumably ssim does too), so it is not divided by the batch
        # size again.
        losses += (loss.item() - losses) / (i + 1)

        if i % 10 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {4:.4f} ({3:.4f})'
                  .format(epoch+1, i, N, loss.item(), losses))

In [None]:
# Inference on one batch after training. eval() makes BatchNorm use its
# running statistics, and no_grad() skips building an autograd graph.
# (The training cell calls net.train() again if training is resumed.)
net.eval()
img, dmap = next(iter(train_loader))
with torch.no_grad():
    output = net(img.cuda())

In [None]:
fig = plt.figure(figsize=(20,20))
# One row per sample: ground-truth depth (left) next to the prediction (right).
# NOTE(review): the network is trained on 1000 / depth while dmap is raw depth,
# so the two colormaps run in opposite directions.
for sample_idx in range(2):
    target = dmap[sample_idx].numpy()
    pred = output[sample_idx].detach().cpu().numpy()
    for col, panel in enumerate((target, pred)):
        plt.subplot(2, 2, 2 * sample_idx + col + 1)
        plt.imshow(panel.transpose((1, 2, 0)), cmap='plasma', interpolation='none')
        plt.xticks([])
        plt.yticks([])
# Lay out and show once, after every panel exists.
fig.tight_layout()
fig.show()