# Train a Depth Seeding Network

In [None]:
import sys, os
import json
from time import time
import glob

import torch
import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.io
import cv2

# My libraries
import src.data_loader as data_loader
import src.segmentation as segmentation
import src.util.utilities as util_
import src.util.flowlib as flowlib

# Expose only GPU index 0 to this process.
# NOTE(review): this is set after `import torch`; it only takes effect if CUDA has not
# been initialized yet in this process -- confirm none of the imports above touch the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = "0" # TODO: Change this if you have more than 1 GPU

In [None]:
def torch_to_numpy(torch_tensor, is_standardized_image = False):
    """ Copy a torch tensor into a numpy array laid out for plotting

        A 4-dimensional input is treated as NCHW and returned as NHWC; inputs of any
        other rank keep their axis order. The input tensor itself is never modified,
        because the conversion works on a cloned CPU copy.

        If is_standardized_image is set, the first three entries of the last axis are
        multiplied by the ImageNet per-channel std, shifted by the per-channel mean,
        and the whole array is then scaled by 255 (back to the [0,255] range).
    """
    cpu_copy = torch_tensor.cpu().clone().detach()
    array = cpu_copy.numpy()

    if array.ndim == 4: # NCHW -> NHWC
        array = np.transpose(array, (0, 2, 3, 1))

    if not is_standardized_image:
        return array

    # Undo ImageNet standardization one channel at a time, in place
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    for channel, (mu, sigma) in enumerate(zip(imagenet_mean, imagenet_std)):
        plane = array[..., channel] # basic indexing -> a view, so in-place ops write into `array`
        plane *= sigma
        plane += mu
    array *= 255

    return array

# Example Dataset: TableTop Object Dataset (TOD)

In [None]:
TOD_filepath = '...' # TODO: change this to the dataset you want to train on

# Settings handed to the TOD dataloader; grouped by what they configure.
data_loading_params = dict(
    # Camera/Frustum parameters
    img_width=640,
    img_height=480,
    near=0.01,
    far=100,
    fov=45, # vertical field of view in degrees

    use_data_augmentation=True,

    # Multiplicative noise
    gamma_shape=1000.,
    gamma_scale=0.001,

    # Additive noise
    gaussian_scale=0.005, # 5mm standard dev
    gp_rescale_factor=4,

    # Random ellipse dropout
    ellipse_dropout_mean=10,
    ellipse_gamma_shape=5.0,
    ellipse_gamma_scale=1.0,

    # Random high gradient dropout
    gradient_dropout_left_mean=15,
    gradient_dropout_alpha=2.,
    gradient_dropout_beta=5.,

    # Random pixel dropout
    pixel_dropout_alpha=1.,
    pixel_dropout_beta=10.,
)

# Training dataloader: shuffled batches of 4, loaded with 8 workers
dl = data_loader.get_TOD_train_dataloader(
    TOD_filepath,
    data_loading_params,
    batch_size=4,
    num_workers=8,
    shuffle=True,
)

## Train Depth Seeding Network

In [None]:
# Configuration passed to the Depth Seeding Network constructor
dsn_params = dict(
    # Sizes
    feature_dim=64,

    # algorithm parameters
    lr=1e-2, # learning rate
    iter_collect=20, # Collect results every _ iterations
    max_iters=100000,

    # architecture parameters
    use_coordconv=False,

    # Loss function parameters
    lambda_fg=1,
    lambda_direction=1.,

    # Hough Voting parameters
    skip_pixels=10,
    inlier_threshold=0.9,
    angle_discretization=100,
    inlier_distance=20,
    # percentage_threshold depends on skip_pixels, angle_discretization and
    # inlier_distance; it has to be tuned empirically
    percentage_threshold=0.5,
    object_center_kernel_radius=10,
)
depth_seeding_network = segmentation.DepthSeedingNetwork(dsn_params)

In [None]:
# Train the network for 1 epoch
# NOTE(review): `dl` is whichever dataloader is currently bound to that name. The
# visualization section further down rebinds `dl` to the test dataloader, so re-running
# this cell after that one would train on the test loader -- check cell order first.
num_epochs = 1
depth_seeding_network.train(num_epochs, dl)

## Plot some losses

In [None]:
%matplotlib inline
fig = plt.figure(1, figsize=(15,3))
total_subplots = 3
starting_epoch = 0
info_items = {k:v for (k,v) in depth_seeding_network.infos.items() if k > starting_epoch}

plt.subplot(1,total_subplots,1)
losses = [x['loss'] for (k,x) in info_items.items()]
plt.plot(info_items.keys(), losses)
plt.xlabel('Iteration')
plt.title('Losses. {0}'.format(losses[-1]))

plt.subplot(1,total_subplots,2)
fg_losses = [x['FG loss'] for (k,x) in info_items.items()]
plt.plot(info_items.keys(), fg_losses)
plt.xlabel('Iteration')
plt.title('Foreground Losses. {0}'.format(fg_losses[-1]))

plt.subplot(1,total_subplots,3)
direction_losses = [x['Direction loss'] for (k,x) in info_items.items()]
plt.plot(info_items.keys(), direction_losses)
plt.xlabel('Iteration')
plt.title('Direction Losses. {0}'.format(direction_losses[-1]))

print("Number of iterations: {0}".format(depth_seeding_network.iter_num))

## Visualize some stuff

Run the trained network on a single batch from the test dataloader and plot the results.

In [None]:
# NOTE: this rebinds `dl`, replacing the training dataloader created earlier
dl = data_loader.get_TOD_test_dataloader(
    TOD_filepath,
    data_loading_params,
    batch_size=8,
    num_workers=8,
    shuffle=True,
)
dl_iter = iter(dl)

# Pull one batch and convert the tensors needed for plotting to numpy.
# torch_to_numpy moves 4-D tensors from NCHW to NHWC and leaves other ranks as they are.
batch = next(dl_iter)
rgb_imgs = torch_to_numpy(batch['rgb'], is_standardized_image=True) # Shape: [N x H x W x 3]
xyz_imgs = torch_to_numpy(batch['xyz']) # Shape: [N x H x W x 3]
foreground_labels = torch_to_numpy(batch['foreground_labels']) # Shape: [N x H x W]
# Assuming the loader yields [N x 2 x H x W] here, the converted array is [N x H x W x 2]
direction_labels = torch_to_numpy(batch['direction_labels'])
N, H, W = foreground_labels.shape[:3]

In [None]:
print("Number of images: {0}".format(N))

depth_seeding_network.eval_mode()

### Compute segmentation masks ###
# CUDA work is queued asynchronously, so drain the GPU before starting and before
# stopping the clock; otherwise the measured time (and FPS) can be under-reported.
# The guard keeps the cell runnable on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.synchronize()
st_time = time()
seg_masks, direction_predictions, object_centers, initial_masks = depth_seeding_network.run_on_batch(batch)
if torch.cuda.is_available():
    torch.cuda.synchronize()
total_time = time() - st_time
print('Total time taken for Segmentation: {0} seconds'.format(round(total_time, 3)))
print('FPS: {0}'.format(round(N / total_time,3)))

# Get results in numpy
seg_masks = seg_masks.cpu().numpy()
# Move the channel axis last (axis order 0,2,3,1) so it can be indexed per image for plotting
direction_predictions = direction_predictions.cpu().numpy().transpose(0,2,3,1)
initial_masks = initial_masks.cpu().numpy()
# object_centers is indexed per image and converted entry by entry, in place
for i in range(N):
    object_centers[i] = object_centers[i].cpu().numpy()

In [None]:
# One figure per image in the batch, each with five side-by-side panels
fig_index = 1
for i in range(N):

    fig = plt.figure(fig_index)
    fig_index += 1
    fig.set_size_inches(20,5)

    # Only keep direction predictions where the predicted mask equals 2
    # NOTE(review): presumably label 2 marks foreground pixels -- confirm against run_on_batch
    fg_mask = np.expand_dims(seg_masks[i,...] == 2, axis=-1)

    # (title, image) for each panel, left to right
    panels = [
        ('Image {0}'.format(i+1), rgb_imgs[i,...].astype(np.uint8)),
        ('Depth', xyz_imgs[i,...,2]),
        ("Predicted Masks", util_.get_color_mask(seg_masks[i,...])),
        ("Center Direction Predictions", flowlib.flow_to_image(direction_predictions[i,...] * fg_mask)),
        # The object count is the number of distinct labels in the initial mask minus one
        (f"Initial Masks. #objects: {np.unique(initial_masks[i,...]).shape[0]-1}", util_.get_color_mask(initial_masks[i,...])),
    ]

    for panel_column, (panel_title, panel_image) in enumerate(panels, start=1):
        plt.subplot(1, 5, panel_column)
        plt.imshow(panel_image)
        plt.title(panel_title)