# Basic Setup

In [None]:
# Detect automatically where this notebook is running, so that later cells
# can pick matching paths.
import sys

# True when the `google.colab` package is already loaded in this kernel.
is_on_colab = "google.colab" in sys.modules

# True when 'teampc' appears in the launch path of the current process.
launcher_path = sys.argv[0]
is_on_zerus = "teampc" in launcher_path

for label, flag in (("Is on colab: ", is_on_colab), ("Is on zerus:", is_on_zerus)):
    print(label, flag)

## Setup for Colab

In [None]:
if is_on_colab:
    # Google Colab setup
    
    # Mount drive
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

    # Retrieve repository and cd into root folder
    from getpass import getpass
    import urllib
    import os
    os.chdir("/content")
    user = input('Github user name: ')
    password = getpass('Github password: ')
    password = urllib.parse.quote(password) # your password is converted into url format
    branch = "" # "-b " + "branch_name"
    cmd_string = 'git clone {0} https://{1}:{2}@github.com/lukasHoel/novel-view-synthesis.git'.format(branch, user, password)
    os.system(cmd_string)
    os.chdir("novel-view-synthesis")

    # Install PyTorch3D libraries (required for pointcloud computations.)
    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git'
    !pwd

## Setup for Local Execution

In [None]:
# ONLY NECESSARY FOR LOCAL EXECUTION (WORKS WITHOUT THIS CELL IN GOOGLE COLAB)
# Setup that is necessary for jupyter notebook to find sibling-directories
# see: https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im


if not is_on_colab:

    import os
    import sys

    # Make the parent of the current working directory importable, so that
    # sibling directories of the notebook folder can be imported as packages.
    module_path = os.path.abspath('..')
    if sys.path.count(module_path) == 0:
        sys.path.append(module_path)


## General Settings

In [None]:
# Imports for this notebook
from models.nvs_model import NovelViewSynthesisModel
from models.synthesis.synt_loss_metric import SynthesisLoss, SceneEditingLoss, SceneEditingAndSynthesisLoss, SynthesisLossRGBandSeg
from util.nvs_solver import NVS_Solver
from util.gan_wrapper_solver import GAN_Wrapper_Solver
from data.nuim_dataloader import ICLNUIMDataset
from data.nuim_dynamics_dataloader import ICLNUIM_Dynamic_Dataset
from data.mp3d_dataloader import MP3D_Habitat_Offline_Dataset
from projection.z_buffer_manipulator import PtsManipulator

from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms
import torch
import torch.nn as nn
import numpy as np

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Check training on GPU?

cuda = torch.cuda.is_available()

print("Training is on GPU with CUDA: {}".format(cuda))

device = "cuda:0" if cuda else "cpu"

print("Device: {}".format(device))

!nvidia-smi

In [None]:
def count_parameters(model):
    """Return the number of trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted; each of
    them contributes its element count (``numel()``).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Load Data
Load ICL-NUIM dataset or Matterport3D dataset.


In [None]:
# Central switches for this notebook; later cells read these three values to
# choose the dataset path, the dataset class, the loss and the model input size.

# Dataset / projection mode: swap the two lines below to train on Matterport3D instead.
#dataset_mode = PtsManipulator.matterport_mode
dataset_mode = PtsManipulator.icl_nuim_mode 

# True selects the ICL-NUIM dynamics dataset and the scene-editing loss in later cells.
use_dynamics = True

# Side length (pixels) of the square images fed to the model.
image_size = 128

print("Using dataset: " + dataset_mode)
print("Image Size: " + str(image_size))
print("Using dynamics: " + str(use_dynamics))


In [None]:
# Load dataset from drive or local
# Resolve the dataset root for the current machine, dataset mode and dynamics setting.

# Reset first: otherwise a combination that matches no branch would silently
# reuse the `path` left over from an earlier run of this cell.
path = None

if is_on_colab:
    if dataset_mode == PtsManipulator.matterport_mode:
        path = "/content/drive/My Drive/Novel_View_Synthesis/matterport3d"
    elif dataset_mode == PtsManipulator.icl_nuim_mode and not use_dynamics:
        path = "/content/drive/My Drive/Novel_View_Synthesis/ICL-NUIM/living_room_traj2_loop"
    elif dataset_mode == PtsManipulator.icl_nuim_mode and use_dynamics:
        # Mirrors the zerus/local branches, which use ".../custom/seq0001" for dynamics.
        # NOTE(review): assumes seq0001 is present in the Drive folder - confirm.
        path = "/content/drive/My Drive/Novel_View_Synthesis/ICL-NUIM/custom/seq0001"

elif is_on_zerus:
    if dataset_mode == PtsManipulator.matterport_mode:
        raise ValueError("Path to mp3d on zerus not specified in this notebook!")
    elif dataset_mode == PtsManipulator.icl_nuim_mode and not use_dynamics:
        path = "/mnt/raid/teampc/ICL-NUIM/living_room_traj2_loop"
    elif dataset_mode == PtsManipulator.icl_nuim_mode and use_dynamics:
        path = "/mnt/raid/teampc/ICL-NUIM/custom/seq0001"

else:
    if dataset_mode == PtsManipulator.matterport_mode:
        path = "/home/lukas/Desktop/git/synsin/dataset"
    elif dataset_mode == PtsManipulator.icl_nuim_mode and not use_dynamics:
        path = "/home/lukas/Desktop/datasets/ICL-NUIM/prerendered_data/living_room_traj2_loop"
    elif dataset_mode == PtsManipulator.icl_nuim_mode and use_dynamics:
        path = "/home/lukas/Desktop/datasets/ICL-NUIM/custom/seq0001"

if path is None:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError("No dataset path configured for dataset_mode={!r}, use_dynamics={!r}".format(
        dataset_mode, use_dynamics))

# Preprocessing applied by every dataset below: resize to the square working
# resolution, then convert to a tensor.
transform = torchvision.transforms.Compose([
    #torchvision.transforms.ToPILImage(), # no longer needed: new dataloader now returns PIL Images
    torchvision.transforms.Resize((image_size, image_size)),
    torchvision.transforms.ToTensor()
])

# Dataset configuration; it is printed below and merged into the arguments
# that are later handed to the solver.
data_dict = {
    "mode": dataset_mode,
    "image_size": image_size,
    "use_dynamics": use_dynamics,
    "path": path,
    "sampleOutput": True,
    "inverse_depth": False,
    "cacheItems": False, # Caching will work only if num_workers = 0. Decide what you like more!
}

if dataset_mode == PtsManipulator.matterport_mode:

    # THIS IS THE HARDCODED IMAGE SIZE THAT WE SET IN THE HABITAT FRAMEWORK WHEN RENDERING MP3D IMAGES
    # THIS DOES NOT CHANGE WHEN WE USE DIFFERENT IMAGE SIZES IN A TRANSFORM OBJECT
    # WHEN CHANGING THE IMAGE SIZE IN TRANSFORM OBJECT, THIS GETS REFLECTED IN THE image_size ATTRIBUTE
    data_dict['mp3d_image_input_size'] = 256

    data_dict['train_path'] = path + "/train"
    data_dict['val_path'] = path + "/val"

    # Train and validation datasets differ only in their folder.
    mp3d_kwargs = dict(in_size=data_dict['mp3d_image_input_size'],
                       transform=transform,
                       sampleOutput=data_dict["sampleOutput"],
                       inverse_depth=data_dict["inverse_depth"],
                       cacheItems=data_dict["cacheItems"])

    train_dataset = MP3D_Habitat_Offline_Dataset(data_dict['train_path'], **mp3d_kwargs)

    print("Loaded following data: {} (samples: {}) with configuration: {}\n".format(data_dict["train_path"], len(train_dataset), data_dict))

    val_dataset = MP3D_Habitat_Offline_Dataset(data_dict['val_path'], **mp3d_kwargs)

    print("Loaded following data: {} (samples: {}) with configuration: {}\n".format(data_dict["val_path"], len(val_dataset), data_dict))

elif dataset_mode == PtsManipulator.icl_nuim_mode and not use_dynamics:

    data_dict['icl_nuim_output_size'] = image_size

    dataset = ICLNUIMDataset(data_dict['path'],
                             transform=transform,
                             sampleOutput=data_dict["sampleOutput"],
                             inverse_depth=data_dict["inverse_depth"],
                             cacheItems=data_dict["cacheItems"],
                             out_shape=(image_size, image_size))

    print("Loaded following data: {} (samples: {}) with configuration: {}".format(data_dict["path"], len(dataset), data_dict))

elif dataset_mode == PtsManipulator.icl_nuim_mode and use_dynamics:

    data_dict['icl_nuim_output_size'] = image_size
    data_dict['icl_dynamic_output_from_other_view'] = False

    # Read every setting from data_dict (previously hard-coded here) so that the
    # configuration printed and logged below is the one that is actually used.
    dataset = ICLNUIM_Dynamic_Dataset(data_dict['path'],
                                      sampleOutput=data_dict["sampleOutput"],
                                      output_from_other_view=data_dict['icl_dynamic_output_from_other_view'],
                                      inverse_depth=data_dict["inverse_depth"],
                                      cacheItems=data_dict["cacheItems"],
                                      transform=transform,
                                      out_shape=(image_size, image_size))

    print("Loaded following data: {} (samples: {}) with configuration: {}".format(data_dict["path"], len(dataset), data_dict))

In [None]:
# Build the train/validation DataLoaders for the dataset chosen above and
# collect all loader-related settings (merged with data_dict) in dataset_args.
dataset_args = {
    "batch_size": 2,
    "num_workers": 1, # Dataset Caching will work only if num_workers = 0. Decide what you like more!
    "random_seed": 42, # seed random generation for shuffling indices to always get same images in train/val
    "shuffle_dataset": True,
    **data_dict
}

if dataset_mode == PtsManipulator.matterport_mode:
    # For mp3d we have separate train/val folders so we can just create different loaders out of the different datasets

    # Each sampler covers its complete dataset; SubsetRandomSampler draws the indices in random order.
    train_len = len(train_dataset)
    train_sampler = SubsetRandomSampler(list(range(train_len)))

    val_len = len(val_dataset)
    val_sampler = SubsetRandomSampler(list(range(val_len)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=dataset_args["batch_size"], 
                                               #shuffle=dataset_args["shuffle_dataset"],
                                               sampler=train_sampler,
                                               num_workers=dataset_args["num_workers"])
    
    validation_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=dataset_args["batch_size"], 
                                            #shuffle=dataset_args["shuffle_dataset"],
                                            sampler=val_sampler,
                                            num_workers=dataset_args["num_workers"])

elif dataset_mode == PtsManipulator.icl_nuim_mode:
    # Create Train and Val dataset with 80% train and 20% val.
    # from: https://stackoverflow.com/questions/50544730/how-do-i-split-a-custom-dataset-into-training-and-test-datasets

    # For ICL dataset we do not have train/val datasets so we split the existing dataset 80% to 20%
    dataset_args["validation_percentage"] = 0.2

    # Creating data indices for training and validation splits:
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(dataset_args["validation_percentage"] * dataset_size))
    if dataset_args["shuffle_dataset"]:
        # Seeds NumPy's global RNG, so the shuffled order (and the split) is the same on every run.
        np.random.seed(dataset_args["random_seed"])
        np.random.shuffle(indices)
    # The first `split` indices form the validation set, the remainder the training set.
    train_indices, val_indices = indices[split:], indices[:split]

    # #####################
    # ICL OVERFITTING CASE:
    # #####################
    # Restricts training to 4 samples and validation to 2 samples.
    train_indices = train_indices[:4] # train_indices[0:4] # [train_indices[0]]
    val_indices = val_indices[:2]

    # Inspect the first training sample: its indices, value range, camera data and images.
    overfit_item = dataset.__getitem__(train_indices[0])
    print("OVERFITTING Input Image: {}, Output Image: {}".format(
        train_indices[0],
        overfit_item["output"]["idx"]))

    input_img = overfit_item["image"].cpu().detach().numpy()
    output_img = overfit_item["output"]["image"].cpu().detach().numpy()
    output_seg = overfit_item["output"]["seg"].cpu().detach().numpy()

    print(torch.min(overfit_item["output"]["image"]))
    print(torch.max(overfit_item["output"]["image"]))
    print(overfit_item["cam"])

    %matplotlib inline

    import matplotlib
    import matplotlib.pyplot as plt

    # The first axis is moved to the end before plotting, as imshow expects the channel axis last.
    print("OVERFIT TRAIN INPUT IMAGE")
    plt.imshow(np.moveaxis(input_img, 0, -1))
    plt.show()

    print("OVERFIT TRAIN OUTPUT IMAGE")
    plt.imshow(np.moveaxis(output_img, 0, -1))
    plt.show()
    
    print("OVERFIT TRAIN OUTPUT SEG")
    plt.imshow(np.moveaxis(output_seg, 0, -1))
    plt.show()
    # #########################
    # END ICL OVERFITTING CASE
    # #########################
    
    # Creating PT data samplers and loaders:
    # Both loaders read from the same dataset object; the samplers keep the index sets disjoint.
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=dataset_args["batch_size"], 
                                               sampler=train_sampler,
                                               num_workers=dataset_args["num_workers"])
    validation_loader = torch.utils.data.DataLoader(dataset,
                                                    batch_size=dataset_args["batch_size"],
                                                    sampler=valid_sampler,
                                                    num_workers=dataset_args["num_workers"])

# len() of a DataLoader is its number of batches, not its number of samples.
dataset_args["train_len"] = len(train_loader)
dataset_args["val_len"] = len(validation_loader)

print("Dataset parameters: {}".format(dataset_args))


# Model & Loss Init

Instantiate and initialize NovelViewSynthesisModel and a selected flavor of SynthesisLoss.

In [None]:
# TODO: Define more parameters in the dict according to the available ones in the model, as soon as they are needed.
# Right now we just use the default parameters for the rest (see the commented-out list or the .py file)

# Hyper-parameters of the model and of the losses. The cells below read them
# to build the loss and the model and merge them into the solver's extra_args.
model_args={
    'imageSize': image_size, # change this now in the first dataloading cell from above!
    
    'use_gt_depth': True,
    'normalize_images': False,
    'use_rgb_features': True,

    'num_depth_filters': 16,
    
    # Encoder: in the active configuration each *_blk_types list has one entry fewer than its *_dims list.
    'enc_dims': [3, 8, 8, 16, 16, 32, 32, 64],
    'enc_blk_types': ["id", "id", "id", "id", "id", "id", "id"],
    #'enc_dims': [3, 8, 8, 16, 16, 32, 32, 64, 64, 64],
    #'enc_blk_types': ["id", "id", "id", "id", "id", "id", "id", "id", "id"],
    #'enc_dims': [3, 8, 8],
    #'enc_blk_types': ["id", "id"],
    'enc_noisy_bn': False,
    'enc_spectral_norm': True,
    
    # Decoder (alternative configurations kept as comments).
    'dec_activation_func': nn.Sigmoid(),
    #'dec_dims': [64, 64, 32, 32, 32, 16, 16, 8, 8, 3],
    #'dec_blk_types': ["id", "id", "id", "id", "id", "id", "id", "id", "id"],
    #'dec_dims': [64, 32, 32, 16, 16, 8, 8, 3],
    #'dec_blk_types': ["id", "id", "id", "id", "id", "id", "id"],
    #'dec_dims': [3, 8, 8, 16, 16, 32, 32, 64, 64, 32, 32, 16, 16, 8, 8, 3],
    #'dec_blk_types': ["id", "id", "id", "id", "id", "id", "id", "id", "id", "id", "id", "id", "id", "id", "id"],
    'dec_dims': [3, 16, 32, 64, 64, 64, 64, 64, 32, 128, 128, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 3],
    'dec_blk_types': ["id", "id", "id", "id", "id", "id", "id", "id", "avg", "avg", "id", "ups", "ups", "id", "id", "id", "id", "id", "id", "id", "id"],
    'seg_dims': [64, 64, 64, 3],
    'seg_blk_types': ["id", "id", "id"],
    'shared_layers': 14,
    'dec_noisy_bn': False,
    'dec_spectral_norm': True,
                      
    'projection_mode': dataset_mode,
    
    # from here attributes for the SynthesisLossRGBandSeg of the nvs_model
    # Each entry is a string of the form '<weight>_<loss name>'.
    'rgb_l1_loss': '1.0_l1',
    'rgb_content_loss': '10.0_content', # synsin default: 10.0
    'seg_l1_loss': '1.0_l1', # this is only relevant for training without dynamics, will be ignored otherwise
    'seg_content_loss': '10.0_content', # this is only relevant for training without dynamics, will be ignored otherwise
    
    # from here attributes for the SceneEditingLoss of the nvs_model
    'lp_region_lambda': 10.0,
    'lp_region_params': [1.0, 1.0, 5.0],
}

# Keep these loss objects constant; change which losses are active by e.g.
# setting one coefficient to 0 in model_args.
rgb_loss_specs = [model_args['rgb_l1_loss'], model_args['rgb_content_loss']]

if use_dynamics:
    # Dynamics training: RGB synthesis losses combined with the scene-editing loss.
    nvs_loss = SceneEditingAndSynthesisLoss(
        synthesis_losses=rgb_loss_specs,
        scene_editing_weight=model_args['lp_region_lambda'],
        scene_editing_lpregion_params=model_args['lp_region_params'],
    )
else:
    # Static training: separate loss lists for the RGB and the segmentation output.
    seg_loss_specs = [model_args['seg_l1_loss'], model_args['seg_content_loss']]
    nvs_loss = SynthesisLossRGBandSeg(rgb_losses=rgb_loss_specs,
                                      seg_losses=seg_loss_specs)

# Collect the constructor arguments first, then build the model from them.
nvs_model_kwargs = dict(
    imageSize=model_args['imageSize'],

    # depth range and depth-network width
    max_z=10,
    min_z=0,
    num_filters=model_args['num_depth_filters'],

    # encoder
    enc_dims=model_args['enc_dims'],
    enc_blk_types=model_args['enc_blk_types'],
    enc_noisy_bn=model_args['enc_noisy_bn'],
    enc_spectral_norm=model_args['enc_spectral_norm'],

    # decoder / segmentation head
    dec_dims=model_args['dec_dims'],
    dec_blk_types=model_args['dec_blk_types'],
    seg_dims=model_args['seg_dims'],
    seg_blk_types=model_args['seg_blk_types'],
    shared_layers=model_args['shared_layers'],
    dec_activation_func=model_args['dec_activation_func'],
    dec_noisy_bn=model_args['dec_noisy_bn'],
    dec_spectral_norm=model_args['dec_spectral_norm'],

    # projection (further options are left at the model's defaults)
    projection_mode=model_args['projection_mode'],
    #points_per_pixel=8,
    #learn_feature=True,
    #radius=3.0,
    #rad_pow=2,
    #accumulation='alphacomposite',
    #accumulation_tau=1,

    use_rgb_features=model_args['use_rgb_features'],
    use_gt_depth=model_args['use_gt_depth'],
    #use_inverse_depth=False,
    normalize_images=model_args['normalize_images'],
)
model = NovelViewSynthesisModel(**nvs_model_kwargs)

# Record the model class name alongside the hyper-parameters for logging.
model_args["model"] = type(model).__name__

print("Model configuration: {}".format(model_args))

#print("Architecture:", model)
print("Total number of paramaters:", count_parameters(model))
# Each sub-module is looked up right before its line is printed.
for label, submodule_name in (("Parameters ENCODER:", "encoder"),
                              ("Parameters DEPTH:", "pts_regressor"),
                              ("Parameters DECODER:", "projector")):
    print(label, count_parameters(getattr(model, submodule_name)))

# Training Visualization

Start Tensorboard for visualization of the upcoming training / validation / test steps.

In [None]:
# Start tensorboard. Might need to make sure, that the correct runs directory is chosen here.
#%load_ext tensorboard
#%tensorboard --logdir "../runs"
#!tensorboard --logdir ../runs

# Training

Start training process.

In [None]:
# This flag decides which solver gets used (GAN_Wrapper_Solver when True, NVS_Solver otherwise)
# and into which directory the logs are written.
train_with_discriminator = False

In [None]:
# Create unique ID for this training process for saving to disk.

from datetime import datetime
import uuid

now = datetime.now() # current date and time
# Named `run_id` rather than `id`, so the built-in `id()` is not shadowed in
# the notebook namespace.
run_id = str(uuid.uuid1())
# Timestamp plus UUID; reused below for the log directory and later for the
# file names of the saved models.
id_suffix = now.strftime("%Y-%b-%d_%H-%M-%S") + "_" + run_id

# Runs with and without discriminator are logged into separate sub-directories.
log_dir_name = "Full_GAN" if train_with_discriminator else "Full_No_GAN"

log_dir = "../runs/" + log_dir_name + "/" + id_suffix # Might need to make sure, that the correct runs directory is chosen here.
print("log_dir:", log_dir)

In [None]:
# Configure solver
# extra_args bundles the model, dataset and discriminator settings that are
# handed to the solver.
extra_args = {
    **model_args,
    **dataset_args,
    'num_D': 3, # number of discriminators, each downsamples by 2
    'size_D': 64, # number of channels each conv in the discriminator has
    'loss_D': 'original', # discriminator loss, options are original(cross-entropy), ls (MSE), hinge, w
    'no_feature_loss': False, # if discriminator should not use feature loss
    'init_weights': True,
    'lr_step': 10, # number of epochs after which the learning rate is multiplied with gamma
    'lr_gamma': 0.3
}

# Every optimizer in this notebook uses the same Adam settings. Each solver
# argument receives its own copy of this dict.
adam_settings = {
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "eps": 1e-8,
    "weight_decay": 0.0, # is the l2 regularization parameter, see: https://pytorch.org/docs/stable/optim.html
}

if not train_with_discriminator:
    solver = NVS_Solver(
        optim=torch.optim.Adam,
        optim_args=dict(adam_settings),
        loss_func=nvs_loss,
        extra_args=extra_args,
        tensorboard_writer=None, # let solver create a new instance
        log_dir=log_dir,
    )
else:
    solver = GAN_Wrapper_Solver(
        optim_d=torch.optim.Adam,
        optim_d_args=dict(adam_settings),
        optim_g=torch.optim.Adam,
        optim_g_args=dict(adam_settings),
        g_loss_func=nvs_loss,
        extra_args=extra_args,
        log_dir=log_dir,
        num_D=extra_args['num_D'],
        size_D=extra_args['size_D'],
        loss_D=extra_args['loss_D'],
        no_gan_feature_loss=extra_args['no_feature_loss'],
        init_discriminator_weights=extra_args['init_weights'],
        lr_step=extra_args['lr_step'],
        lr_gamma=extra_args['lr_gamma'],
    )

In [None]:
# Start training

num_epochs=10000
log_nth_iter=100
log_nth_epoch=100
# tqdm_mode:
#     'total': tqdm log how long all epochs will take,
#     'epoch': tqdm for each epoch how long it will take,
#     anything else, e.g. None: do not use tqdm
tqdm_mode='total'

# Use CUDA_VISIBLE_DEVICES=0,1,2 jupyter notebook etc depending on how many gpu you want to use.
# The model will then be replicated on each gpu and the batches are split between them.
# So if you use 2 gpus you could make the batch_size twice as large as before.
# good for processing a lot of data quickly
#
# The isinstance check keeps this cell re-runnable: without it, every further
# run would wrap the already wrapped model in another DataParallel layer.
if torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    print("Using multiple GPUs!!")
    model = nn.DataParallel(model)

# Arguments shared by both solvers' train() calls.
train_kwargs = dict(num_epochs=num_epochs,
                    log_nth_iter=log_nth_iter,
                    log_nth_epoch=log_nth_epoch,
                    tqdm_mode=tqdm_mode)

# TODO: Add parameters to extra_args dict?
if train_with_discriminator:
    steps = 1 # how many steps of training for discriminator/generator before switching to generator/discriminator
    solver.train(model,
                 train_loader,
                 validation_loader,
                 steps=steps,
                 **train_kwargs)
else:
    solver.train(model,
                 train_loader,
                 validation_loader,
                 verbose=False,
                 **train_kwargs)

In [None]:
# To download tensorboard runs from Colab

# TODO: Make sure that only new ones are copied --> for tensorboard runs on colab, do not use git repository as "runs" directory?
# TODO: Instead of downloading, directly move it to the git repository that is currently checked out and push changes?
if is_on_colab:
  from google.colab import files
  !zip -r /content/runs.zip /content/runs
  files.download("/content/runs.zip")

# Test

Test with test dataset.
Will load the data and start the test run.

Visualizations can be seen in Tensorboard above.

In [None]:
# Load test data
# TODO: Find real test split, for now we load the SAME dataset as for train/val (just that this notebook is complete...)
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
if is_on_colab:
    test_path = "/content/drive/My Drive/Novel_View_Synthesis/ICL-NUIM/living_room_traj2_loop"
elif is_on_zerus:
    test_path = "/mnt/raid/teampc/ICL-NUIM/office_room_traj2_loop"
else:
    # Without this branch a local run either hit a NameError or silently reused
    # a `test_path` left over from another cell.
    raise ValueError("Path to the ICL-NUIM test set for local execution not specified in this notebook!")

test_dataset = ICLNUIMDataset(test_path, transform=transform) # TODO also use rest of parameters...

# Evaluate on a random tenth of the test set.
test_indices = list(range(len(test_dataset)))
np.random.shuffle(test_indices)

test_sampler = SubsetRandomSampler(test_indices[:len(test_indices)//10])

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=dataset_args["batch_size"],
                                          sampler=test_sampler,
                                          num_workers=4)

# len() of the loader is its number of batches.
print("Length of test set: {}".format(len(test_loader)))
print("Loaded test set: {}".format(test_path))

In [None]:
# Start testing
# Runs the solver's test routine for `model` on the test loader built in the previous cell.
# NOTE(review): test_prefix / log_nth presumably control the logging tag and interval - confirm in NVS_Solver.test.

solver.test(model, test_loader, test_prefix="icl_test", log_nth=1)

## Evaluate on an ICL dynamics dataset

In [None]:
sequence = "seq0002" # indices [1-3] are available

# Pick the root of the custom ICL-NUIM sequences for the current machine.
if is_on_colab:
    test_path = "/content/drive/My Drive/Novel_View_Synthesis/ICL-NUIM/custom/" + sequence

elif is_on_zerus:
    test_path = "/mnt/raid/teampc/ICL-NUIM/custom/" + sequence

else:
    test_path = "/home/lukas/Desktop/datasets/ICL-NUIM/custom/" + sequence

# The key only exists in data_dict when the training data was loaded in
# dynamics mode; fall back to the value used there (False) instead of raising
# a KeyError after a run without dynamics.
output_from_other_view = data_dict.get('icl_dynamic_output_from_other_view', False)

test_dataset = ICLNUIM_Dynamic_Dataset(test_path,
                                       sampleOutput=True,
                                       output_from_other_view=output_from_other_view,
                                       inverse_depth=False,
                                       cacheItems=False,
                                       transform=transform,
                                       out_shape=(image_size, image_size))

# batch_size=1 without shuffling: batch i of the loader is item i of the dataset.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                          shuffle=False,
                                          num_workers=4)

print("Length of test set: {}".format(len(test_loader)))
print("Loaded test set: {}".format(test_path))

In [None]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import imageio

# Run the model on every test batch, show the predicted image and depth and
# write both as PNG files into the log directory of this run.
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_loader)):
        _, out, _ = solver.forward_pass(model, batch)

        # Drop the singleton dimensions; the image additionally gets its
        # channel axis moved to the end for plotting and saving.
        img_tensor = out["PredImg"].squeeze().permute((1,2,0))
        depth_tensor = out["PredDepth"].squeeze()
        pred_img = img_tensor.cpu().detach().numpy()
        pred_depth = depth_tensor.cpu().detach().numpy()

        for prediction in (pred_img, pred_depth):
            plt.imshow(prediction)
            plt.show()

        for file_stem, prediction in (('/pred_img_', pred_img), ('/pred_depth_', pred_depth)):
            imageio.imwrite(log_dir + file_stem + str(i) + '.png', prediction)
        
        

## Generating a Test Time Trajectory

In [None]:
from pprint import pprint
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import math
from util.nvs_solver import to_cuda, default_batch_loader
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt


# Load the model if needed by using the last cell

# Switch the model to evaluation mode and move it to the GPU.
model.eval()
model.cuda()

# Pick an image for the first frame and extract items related to it
test_item_idx = 0
test_item = test_dataset[test_item_idx]
first_frame = to_cuda(default_batch_loader(test_item))
(input_img, K, K_inv,
 input_RT, input_RT_inv,
 output_RT, output_RT_inv,
 gt_img, depth_img) = first_frame

### Rotation & Translation
Use the sliders to jointly modify every axis and the rotation around them.

In [None]:
# Keep modified RT2 matrices
traj = []

def modify_frame(x=0,y=0,z=0, rx=0, ry=0, rz=0):
    """Create a new output view from slider values, render it and record it.

    Builds three 4x4 matrices from the arguments, sets the global
    ``output_RT_inv`` to ``(R_X @ R_Y @ R_Z) @ input_RT_inv``, renders a pseudo
    ground-truth image with ``model.pts_transformer.forward_justpts``, shows it
    and appends ``(output_RT, output_RT_inv, gt_img)`` to the global ``traj``.

    Args:
        x, y, z: translation entries written into the last column of ``R_X``
            (``x`` and ``y`` are multiplied by 255, ``z`` is used as is).
        rx, ry, rz: rotation angles around the x, y and z axis in degrees.

    Side effects:
        Overwrites the globals ``output_RT_inv`` and ``gt_img`` and grows
        ``traj`` by one entry per call (i.e. per slider change).

    NOTE(review): ``output_RT`` is neither recomputed nor modified here; the
    stored tuple pairs the new ``output_RT_inv`` with the previous
    ``output_RT``. Confirm that the projection only relies on the inverse.
    NOTE(review): the factors 255 / (1/255) on the x/y translation and on the
    sine terms of ``R_X``/``R_Y`` are presumably a dataset-specific unit
    conversion - confirm.
    """
    global traj, input_img, K, K_inv, input_RT, input_RT_inv, output_RT, output_RT_inv, gt_img, depth_img

    # Exact degree -> radian conversion (previously approximated by 3.1415/180).
    cos_x, sin_x = math.cos(math.radians(rx)), math.sin(math.radians(rx))
    cos_y, sin_y = math.cos(math.radians(ry)), math.sin(math.radians(ry))
    cos_z, sin_z = math.cos(math.radians(rz)), math.sin(math.radians(rz))

    R_X = torch.Tensor([
      [ 1.0,  0.0, 0.0, x*255],
      [ 0.0,  cos_x, -sin_x/255, y*255],
      [ 0.0,  sin_x/255, cos_x, z],
      [ 0.0,  0.0, 0.0, 1.0]]).cuda()

    R_Y = torch.Tensor([
      [ cos_y,  0.0, sin_y/255, 0.0],
      [ 0.0,  1.0, 0.0, 0.0],
      [ -sin_y/255,  0.0, cos_y, 0.0],
      [ 0.0,  0.0, 0.0, 1.0]]).cuda()

    R_Z = torch.Tensor([
      [ cos_z,  -sin_z, 0.0, 0.0],
      [ sin_z,  cos_z, 0.0, 0.0],
      [ 0.0,  0.0, 1.0, 0.0],
      [ 0.0,  0.0, 0.0, 1.0]]).cuda()

    # Apply the combined rotation/translation to the inverse input pose
    output_RT_inv = (R_X@R_Y@R_Z).mm(input_RT_inv)

    # Perform projection to obtain a pseudo GT for the manipulation
    # (every tensor gets a leading batch dimension of size 1)
    gt_img = model.pts_transformer.forward_justpts(
        input_img.unsqueeze(0),
        depth_img.unsqueeze(0),
        K.unsqueeze(0),
        K_inv.unsqueeze(0),
        input_RT.unsqueeze(0),
        input_RT_inv.unsqueeze(0),
        output_RT.unsqueeze(0),
        output_RT_inv.unsqueeze(0),
    )
    print("Projection with new RT:")
    gt_img_np = gt_img.squeeze(0).cpu().detach().numpy()
    plt.imshow(np.moveaxis(gt_img_np, 0, -1))
    plt.show()

    # Store matrices for the new view
    traj.append((output_RT, output_RT_inv, gt_img))

In [None]:
# Create one slider per argument of modify_frame; every slider change calls
# modify_frame again with the current values (and so appends to `traj`).
# The tuples give the slider range as (min, max) or (min, max, step).
matrix = interact(modify_frame, 
                  x=(-1.0,1.0),
                  y=(-1.0,1.0), 
                  z=(-1.0,1.0), 
                  rx=(-10.0,10.0, 1), 
                  ry=(-10.0,10.0, 1),
                  rz=(-10.0,10.0, 1));

In [None]:
# Print the RT1 camera matrix of the selected frame, followed by all recorded trajectory entries.
pprint(test_item["cam"]["RT1"])
# traj.pop(0) # Interactive slider sometimes first item has the same RT matrix as the input view (RT1), discard it
pprint(traj)

In [None]:
# It is important to pass this image to loader to apply same transforms that was applied during training.
# We have to make sure that test time images get the same transforms as train time to have meaningful results.
#
# The frame is read from `test_dataset`, the dataset that `test_item_idx` and
# the recorded trajectory refer to (previously the training `dataset` was used).
# Dedicated variable names keep `test_loader`, `test_item` and the pose globals
# used by other cells untouched.
single_item_sampler = SubsetRandomSampler([test_item_idx])
single_item_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                                 sampler=single_item_sampler, num_workers=0)

with torch.no_grad():
    # Get the frame (triggers get_item and transforms)
    test_batch = next(iter(single_item_loader))
    test_batch = to_cuda(default_batch_loader(test_batch)) # List of contents in dict
    # For each different RT matrix perform a forward pass using GT depth
    for traj_RT, traj_RT_inv, traj_gt_img in traj:
        out = model(test_batch[0], # input_img
                    test_batch[1], # K
                    test_batch[2], # K_inv
                    test_batch[3], # input_RT
                    test_batch[4], # input_RT_inv
                    traj_RT.unsqueeze(0),
                    traj_RT_inv.unsqueeze(0),
                    traj_gt_img,   # Not used
                    test_batch[-1] # GT depth
                   )
        # Visualize prediction by network
        pred = out["PredImg"]
        pred_np = pred.squeeze().cpu().detach().numpy()
        plt.imshow(np.moveaxis(pred_np, 0, -1))
        plt.show()

In [None]:
# NN-Search usage: generate trajectories and let NN run.
# generate gif from images.
#
# For every recorded trajectory pose, find the test item whose RT2inv matrix
# has the closest norm and show that item's image.

# adjust these values to limit the range of the nearest neighbour search
nn_search_start, nn_search_stop = 300, 400
# A candidate is only accepted when its norm difference is below this bound.
nn_max_norm_diff = 100

# Compute the candidate norms once by indexing the dataset directly, instead of
# iterating over the complete loader again for every trajectory entry.
candidate_norms = {}
for i in range(nn_search_start, min(nn_search_stop, len(test_dataset))):
    candidate_norms[i] = np.linalg.norm(test_dataset[i]['cam']['RT2inv'])

min_idx = []
for rts in traj:
    target_norm = np.linalg.norm(rts[1].cpu().numpy())
    # Reset per trajectory entry, so a result of a previous entry is never reused.
    best_id, best_diff = None, nn_max_norm_diff
    for i, candidate_norm in candidate_norms.items():
        diff = abs(target_norm - candidate_norm)
        if diff < best_diff:
            best_id, best_diff = i, diff
    if best_id is None:
        print("No nearest neighbour found in the search range, skipping this trajectory entry.")
        continue
    min_idx.append(best_id)

for ids in min_idx:
    img = test_dataset[ids]['image']
    # np.moveaxis needs an array: it would call torch's incompatible Tensor.transpose otherwise.
    if torch.is_tensor(img):
        img = img.detach().cpu().numpy()
    plt.imshow(np.moveaxis(img, 0, -1))
    plt.show()

# Save the model

Save network with its weights to disk.

See torch.save function: https://pytorch.org/docs/stable/notes/serialization.html#recommend-saving-models 

Load again with `the_model = TheModelClass(*args, **kwargs)` followed by `the_model.load_state_dict(torch.load(PATH))`.

In [None]:
def save_model(modelname, model, save_dir="../saved_models"):
    """Write the ``state_dict`` of ``model`` to ``<save_dir>/<modelname>.pt``.

    Args:
        modelname: file name without extension.
        model: object providing ``state_dict()``.
        save_dir: target directory, created (including parents) when missing.
            Defaults to ``../saved_models`` relative to the working directory.
            Might need to make sure that the correct directory is chosen here.
    """
    from pathlib import Path

    target_dir = Path(save_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    filepath = target_dir / (modelname + ".pt")
    torch.save(model.state_dict(), str(filepath))

In [None]:
# Save the trained model under a name derived from this run's unique suffix.
nvs_modelname = "nvs_" + id_suffix
save_model(nvs_modelname, model)

if train_with_discriminator:
    # Also save the discriminator - currently this can only be accessed through the solver (change it!)
    gan_modelname = "gan_" + id_suffix
    save_model(gan_modelname, solver.netD)

In [None]:
# LOAD MODEL AGAIN for verification purposes
# Should print: <All keys matched successfully> per each model if it works

new_model=False
# add a different model name to be loaded here
if new_model:
    nvs_modelname="nvs_2020-May-29_18-44-55_b9c02778-a1cb-11ea-82a9-5542432396e9"
    gan_modelname="gan_2020-May-29_18-44-55_b9c02778-a1cb-11ea-82a9-5542432396e9"

# map_location=device remaps the stored tensors onto the device selected at the
# top of the notebook, so weights saved on a GPU can also be loaded on a
# CPU-only machine.
nvs_filepath = "../saved_models/" + nvs_modelname + ".pt"
print("NVS_Model loading: ", model.load_state_dict(torch.load(nvs_filepath, map_location=device)))

if train_with_discriminator:
    gan_filepath = "../saved_models/" + gan_modelname + ".pt"
    print("Discriminator loading: ", solver.netD.load_state_dict(torch.load(gan_filepath, map_location=device)))