In [None]:
import torch
import matplotlib.pyplot as plt
import json
import os

In [None]:
# Base directory against which saved/models/DNR checkpoints and exported
# libtorch-models scripts are resolved.
PATH = "."
# Show 10 digits after the decimal point when printing tensors.
torch.set_printoptions(precision=10)

In [None]:
# Execute the project scripts in this notebook's namespace (IPython %run copies
# their top-level names into it). Presumably these provide RenderNet and
# UVDataLoader, which are used below without any other import -- TODO confirm.
%run models/rendernet.py
%run data_loaders/scannet_render_loader.py

In [None]:
# Load a trained model using its best weights unless otherwise specified
def load_trained_model(train_id, checkpoint_name='model_best'):
    """Rebuild a RenderNet for a training run and load one of its checkpoints.

    Reads ``<PATH>/saved/models/DNR/<train_id>/config.json`` for the
    architecture arguments, and ``<checkpoint_name>.pth`` from the same
    directory for the weights. Weights are mapped onto the CPU.

    Args:
        train_id: name of the run directory, e.g. '0625_181653'.
        checkpoint_name: checkpoint file name without the '.pth' extension
            (default 'model_best').

    Returns:
        The RenderNet with the checkpoint's 'state_dict' loaded, in eval mode.

    Raises:
        FileNotFoundError: if the run's config.json or the checkpoint file
            does not exist.
    """
    model_path = os.path.join(PATH, 'saved/models/DNR', train_id)

    # Load config file; fail with a clear message if the run directory is wrong.
    config_file = os.path.join(model_path, "config.json")
    if not os.path.isfile(config_file):
        raise FileNotFoundError(
            'No config.json for train id {!r}: {}'.format(train_id, config_file))
    with open(config_file, 'r') as f:
        config = json.load(f)
    arch_args = config['arch']['args']

    # Load model weights.
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True; if this
    # checkpoint also stores non-tensor objects it may need weights_only=False.
    checkpoint_path = os.path.join(model_path, checkpoint_name) + '.pth'
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Build the model with the architecture parameters from the config file.
    model = RenderNet(arch_args['texture_size'], arch_args['texture_depth'])

    # Assign model weights and switch to eval (not train) mode.
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    return model

In [None]:
# Load every listed training run (best weights) into a dict keyed by train id
def load_trained_models(train_ids):
    """Return ``{train_id: model}`` for each id, in the order given.

    Each model is loaded with ``load_trained_model(train_id)``, i.e. using
    that function's default checkpoint.
    """
    return {run_id: load_trained_model(run_id) for run_id in train_ids}

In [None]:
# Create a libtorch (TorchScript) file containing the model that can be loaded into C++
def create_libtorch_script(model, train_id, checkpoint_name='model_best'):
    """Compile `model` with torch.jit.script and save it under PATH/libtorch-models.

    The file is named ``DNR-<train_id>-<checkpoint_name>_model.pt``.
    `checkpoint_name` only labels the file; the weights come from `model`.

    Args:
        model: the module to compile with torch.jit.script.
        train_id: training-run id used in the file name.
        checkpoint_name: checkpoint label used in the file name.

    Side effects:
        Creates the output directory if needed, writes the scripted model,
        and prints the output path.
    """
    scripted = torch.jit.script(model)

    script_dir = os.path.join(PATH, 'libtorch-models')
    # ScriptModule.save does not create missing parent directories.
    os.makedirs(script_dir, exist_ok=True)

    model_script_name = 'DNR-{}-{}_model.pt'.format(train_id, checkpoint_name)
    model_script_path = os.path.join(script_dir, model_script_name)
    scripted.save(model_script_path)
    print(model_script_path)

In [None]:
# Visualize a model prediction
def generate_images(model, test_input, tar):
    """Plot input, ground truth and prediction for the first sample of a batch.

    Args:
        model: callable mapping `test_input` to a prediction tensor. The
            prediction is permuted with (0, 2, 3, 1), so it is assumed to be
            channels-first (N, C, H, W) -- TODO confirm.
        test_input: 4-D batch. Only ``test_input[:, :, :, 0]`` (the first entry
            of the last axis) is displayed as the input image.
        tar: ground-truth batch, permuted like the prediction.

    Pixel values are mapped with ``x * 0.5 + 0.5`` before plotting to bring
    them into [0, 1].
    """
    # Inference only: no autograd graph is needed for visualization.
    with torch.no_grad():
        prediction = model(test_input)

    input_channel0 = test_input[:, :, :, 0]
    target_hwc = tar.permute(0, 2, 3, 1)
    prediction_hwc = prediction.detach().permute(0, 2, 3, 1)

    # .cpu() so plotting also works if the tensors live on an accelerator.
    display_list = [input_channel0[0].cpu().numpy(),
                    target_hwc[0].cpu().numpy(),
                    prediction_hwc[0].cpu().numpy()]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']

    plt.figure(figsize=(20, 20))
    for i, (image, name) in enumerate(zip(display_list, title)):
        plt.subplot(1, 3, i + 1)
        plt.title(name)
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
def generate_comparison(display_images, title, title_color, img_per_row=2):
    """Plot several image batches in a grid, one titled panel per batch.

    Only the first sample of each batch is drawn. The caller's list is not
    modified.

    Args:
        display_images: list of 4-D tensors, each permuted with (0, 2, 3, 1)
            before plotting (assumed channels-first N, C, H, W -- TODO confirm).
        title: panel titles, one per image (extra entries are ignored).
        title_color: matplotlib colour for each title, one per image.
        img_per_row: maximum number of panels per row (default 2).

    Raises:
        ValueError: if img_per_row < 1, or fewer titles/colours than images
            are given.
    """
    n_images = len(display_images)
    if n_images == 0:
        return  # nothing to plot
    if img_per_row < 1:
        raise ValueError('img_per_row must be >= 1, got {}'.format(img_per_row))
    if len(title) < n_images or len(title_color) < n_images:
        raise ValueError(
            'Need one title and one title colour per image: got {} images, '
            '{} titles, {} colours'.format(n_images, len(title), len(title_color)))

    # Build a new channels-last list instead of overwriting the caller's list,
    # so calling this twice on the same list does not permute twice.
    images_hwc = [image.permute(0, 2, 3, 1) for image in display_images]

    # matplotlib's subplot() expects integer row/column counts.
    cols = min(n_images, img_per_row)
    rows = (n_images + img_per_row - 1) // img_per_row  # ceiling division
    plt.figure(figsize=(35 * cols, 30 * rows))
    for i, image in enumerate(images_hwc):
        plt.subplot(rows, cols, i + 1)
        plt.title(title[i], color=title_color[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(image[0].cpu().numpy() * 0.5 + 0.5)
    plt.show()

In [None]:
##-- Execute code below --##

In [None]:
# Training runs to compare, each mapped to the checkpoint to load
# ('model_best' = the run's best weights).
CHECKPOINT_BY_TRAIN_ID = {
    '0625_181653': 'checkpoint-epoch120',
    '0625_114333': 'model_best',
    '0626_000812': 'model_best',
}
# Dict order is preserved, so train_ids keeps the comparison/plotting order.
train_ids = list(CHECKPOINT_BY_TRAIN_ID)
models = {train_id: load_trained_model(train_id, checkpoint_name)
          for train_id, checkpoint_name in CHECKPOINT_BY_TRAIN_ID.items()}

In [None]:
# Show one validation input, its ground truth and each model's prediction #
loader = UVDataLoader('data', 1, True, 6).split_validation(size=(256, 342))
# Take the first batch only.
data, target = next(iter(loader))
for train_id in train_ids:
    print('Train ID:', train_id)
    generate_images(models[train_id], data, target)

In [None]:
# Show a validation sample prediction for each model #

# Load one batch from the validation dataset
loader = UVDataLoader('data', 1, True, 6).split_validation()
data, target = next(iter(loader))

# Ground truth first, then one prediction per model (in train_ids order)
display_images = [target]
with torch.no_grad():  # inference only, no autograd graph needed
    for train_id in train_ids:
        display_images.append(models[train_id](data).detach())

# Plot results -- titles and colours must line up with display_images
title = ['Ground Truth', 'Prediction Exp 1: 521 train, 104 val',
         'Prediction Exp 1: 1042 train, 208 val',
         'Prediction Exp 1: 2083 train, 417 val']
title_color = ['black', 'magenta', 'green', 'blue']
assert len(title) == len(title_color) == len(display_images), \
    'update title/title_color when changing train_ids'

generate_comparison(display_images, title, title_color)

In [None]:
# Export one trained model as a TorchScript (libtorch) file for C++ #
train_id = '0626_000812'
print(train_id)
create_libtorch_script(model=models[train_id], train_id=train_id)