In [None]:
%load_ext autoreload
%autoreload 2
notebook_fixed_dir = False

In [None]:
# Move the working directory from the notebook's folder up to its parent
# (presumably the repository root) so the relative "data/..." and "configs/..."
# paths used below, and presumably the project-local imports, resolve.
# Safe to re-run: the starting directory is remembered on the first run, so any
# later run lands in the same place instead of climbing another level (this no
# longer depends on the flag from the first cell not having been reset).
import os
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_START_DIR))
notebook_fixed_dir = True  # keep the first cell's flag consistent with the new cwd
print(os.getcwd())

In [None]:
import torch
import pickle
from tqdm import tqdm
from PIL import Image
import numpy as np
from pytorch3d.renderer import (
    look_at_view_transform
)
import matplotlib.pyplot as plt

from pose_est import brute_force_pose_est
import postprocess_dataset
from utils import utils

In [None]:
# displays meshes at the predicted pose
def show_meshes(input_dir_img, input_dir_mesh, meshes_group_name="", device=None):
    """Display each input image followed by its mesh rendered at the predicted pose.

    Loads ``pred_poses.p`` from ``input_dir_mesh``. For every instance name it
    contains, shows two figures:

    - the input image ``<input_dir_img>/<instance_name>.png``;
    - the mesh ``<input_dir_mesh>/<instance_name><suffix>.obj``, rendered with the
      camera from ``look_at_view_transform(dist, elev, azim)`` for that
      instance's cached pose.

    Args:
        input_dir_img: directory holding the ``<instance_name>.png`` input images.
        input_dir_mesh: directory holding ``pred_poses.p`` and the ``.obj`` meshes.
        meshes_group_name: when empty (or None), the mesh file is
            ``<instance_name>.obj``; otherwise it is
            ``<instance_name>_<meshes_group_name>.obj``.
        device: torch device used to load and render the meshes. When None,
            falls back to the notebook-level ``device`` variable.
    """
    if device is None:
        # Fall back to the notebook-global `device` set in the config cell.
        device = globals()["device"]

    # The filename suffix does not depend on the instance, so compute it once.
    if meshes_group_name:
        mesh_filename_suffix = "_{}.obj".format(meshes_group_name)
    else:
        mesh_filename_suffix = ".obj"

    # Cached predictions: a mapping from instance name to a pose dict with
    # 'dist', 'elev' and 'azim' entries.
    # NOTE: pickle.load can execute arbitrary code; only load pose caches you
    # produced yourself.
    pred_poses_path = os.path.join(input_dir_mesh, "pred_poses.p")
    with open(pred_poses_path, "rb") as f:
        cached_pred_poses = pickle.load(f)

    for instance_name in cached_pred_poses:
        pred_pose = cached_pred_poses[instance_name]

        # Image.open is lazy and keeps the file open. copy() forces the pixel
        # data to load, so the file can be closed as soon as the `with` exits.
        with Image.open(os.path.join(input_dir_img, instance_name + ".png")) as img:
            input_image = img.copy()

        # Render the mesh at the predicted pose. This is display-only, so keep
        # both loading and rendering outside autograd.
        R, T = look_at_view_transform(pred_pose['dist'], pred_pose['elev'], pred_pose['azim'])
        with torch.no_grad():
            mesh = utils.load_untextured_mesh(
                os.path.join(input_dir_mesh, instance_name + mesh_filename_suffix), device)
            mesh_rendered_at_pred_pose = utils.render_mesh(mesh, R, T, device)

        # Visualize: input image first, then the render. The indexing takes the
        # first image of the batch and its first three channels (presumably RGB
        # of an RGBA render -- confirm against utils.render_mesh).
        plt.imshow(input_image)
        plt.title("{}: input image".format(instance_name))
        plt.show()
        plt.imshow(mesh_rendered_at_pred_pose[0, ..., :3].cpu().numpy())
        plt.title("{}: mesh at predicted pose".format(instance_name))
        plt.show()

In [None]:
# Input images and meshes are read from the same directory for this dataset.
input_dir_img = input_dir_mesh = "data/test_dataset/"
cfg_path = "configs/default.yaml"  # post-processing configuration file

# CUDA device index. It is passed to postprocess_data, and `device` is the
# matching torch.device used when loading and rendering meshes.
gpu_num = 0
device = torch.device("cuda:{}".format(gpu_num))

In [None]:
# Post-process the whole dataset. The result is keyed by instance name, and its
# per-instance loss logs are plotted further below.
dataset_loss_info = postprocess_dataset.postprocess_data(
    input_dir_img,
    input_dir_mesh,
    cfg_path,
    gpu_num,
)

In [None]:
# Meshes from the "postprocessed" group, loaded from "<instance>_postprocessed.obj".
show_meshes(input_dir_img, input_dir_mesh, meshes_group_name="postprocessed")

In [None]:
# show training loss info
# Plot the loss curves recorded for each instance during post-processing.
# Each entry of `dataset_loss_info` is presumably a pandas DataFrame with an
# 'iter' column and one column per loss term (it is drawn with
# DataFrame.plot.line) -- confirm against postprocess_dataset.
# To inspect individual terms, add any of: 'sil_loss', 'img_sym_loss',
# 'vertex_sym_loss', 'l2_loss', 'lap_smoothness_loss', 'normal_consistency_loss'.
LOSS_TERMS_TO_PLOT = ['total_loss']

for instance_name in dataset_loss_info:
    loss_info = dataset_loss_info[instance_name]
    for loss_term in LOSS_TERMS_TO_PLOT:
        loss_info.plot.line(x='iter', y=loss_term, title=str(instance_name))
    # Render (and release) this instance's figures before moving on, so figures
    # don't pile up when there are many instances.
    plt.show()

In [None]:
# Meshes without a group suffix (loaded from "<instance>.obj"), for comparison
# with the post-processed meshes shown above.
show_meshes(input_dir_img, input_dir_mesh, meshes_group_name="")