In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Guard flag read by the following cell so that the working directory is only
# moved up once. NOTE: re-running this cell resets the flag to False, which
# re-arms the chdir in the following cell.
notebook_fixed_dir = False

In [2]:
# Step the working directory up one level, exactly once per kernel session.
# The notebook_fixed_dir flag (initialised in the first cell) records that the
# move already happened, so re-running this cell does not climb further.
import os

if notebook_fixed_dir:
    # Already relocated on a previous run of this cell; nothing to do.
    pass
else:
    os.chdir('..')
    notebook_fixed_dir = True

# Show where we ended up so a wrong starting directory is noticed immediately.
print(os.getcwd())

/home/svcl-oowl/brandon/research/sil_consistent_at_inference


In [11]:
import pprint
import glob
from pathlib import Path
import pickle
import random

import torch
import PIL
from PIL import Image
import numpy as np
from pytorch3d.renderer import look_at_view_transform
import matplotlib.pyplot as plt
import trimesh
from tqdm.autonotebook import tqdm

from utils import utils
from utils.brute_force_pose_est import brute_force_estimate_pose

In [14]:
# Given an input directory of meshes and an input directory of corresponding images,
# this notebook predicts the pose of each mesh, renders the mesh at the predicted pose,
# and saves the render side-by-side with the input image. The resulting output folder
# of images is useful for later visual evaluation of pose accuracy.

INPUT_IMG_DIR = "data/img_pix3d_chair/chair"
INPUT_MESH_DIR = "data/onet_chair_pix3d_no_DA_simplified"
COMPARES_OUTPUT_DIR = "data_prep_tools/pose_est_compare/pix3d_chair_occnet"

# An optional processed folder with precomputed poses (pred_poses.p files).
# If set to an empty string, poses are recomputed from scratch.
PROCESSED_MESH_DIR = "data/onet_chair_pix3d_no_DA_simplified_processed"

# Passed to brute_force_estimate_pose when poses are recomputed.
# Presumably the number of azimuth / elevation / distance candidates searched;
# confirm against utils.brute_force_pose_est.
num_azims = 20
num_elevs = 20
num_dists = 40

# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a usable GPU instead of failing later.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [15]:
# Merge every precomputed pose cache (pred_poses.p) found under
# PROCESSED_MESH_DIR into a single dict. The dict is always defined, so it is
# simply empty when no processed directory is configured.
cached_pred_poses = {}
if PROCESSED_MESH_DIR != "":
    # Sorted so that, if two cache files share a key, which one wins does not
    # depend on filesystem traversal order (later paths override earlier ones).
    pred_pose_paths = sorted(Path(PROCESSED_MESH_DIR).rglob('pred_poses.p'))
    for pred_pose_path in pred_pose_paths:
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # PROCESSED_MESH_DIR at cache files you produced or otherwise trust.
        with open(pred_pose_path, "rb") as pred_pose_file:
            curr_cache = pickle.load(pred_pose_file)
        # In-place merge avoids rebuilding the whole dict on every iteration.
        cached_pred_poses.update(curr_cache)

In [16]:
def get_concat_h(im1, im2):
    """Return a new RGB image with ``im1`` on the left and ``im2`` on the right.

    The canvas is as wide as both images together and as tall as the taller
    of the two, so neither image is cropped. Both images are pasted at the
    top; any area not covered by them keeps the cyan (0, 255, 255) background,
    which makes size mismatches easy to spot.

    Args:
        im1: Left image; must expose ``width`` and ``height`` (e.g. a PIL image).
        im2: Right image; same requirements as ``im1``.

    Returns:
        The newly created side-by-side image.
    """
    canvas_size = (im1.width + im2.width, max(im1.height, im2.height))
    dst = Image.new('RGB', canvas_size, color=(0, 255, 255))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst

In [20]:
# For every mesh in INPUT_MESH_DIR, render it at its predicted pose and save
# the render next to the matching input image in COMPARES_OUTPUT_DIR.
os.makedirs(COMPARES_OUTPUT_DIR, exist_ok=True)

# Path(...).stem drops only the final ".obj" suffix and is independent of the
# path separator; sorting makes the processing order reproducible.
instance_names = sorted(
    Path(path).stem for path in glob.glob(os.path.join(INPUT_MESH_DIR, "*.obj"))
)

for instance_name in tqdm(instance_names):
    img_path = os.path.join(INPUT_IMG_DIR, instance_name + ".png")
    mesh_path = os.path.join(INPUT_MESH_DIR, instance_name + ".obj")

    # Read the pixels inside `with` so the file handle is released every
    # iteration. Converting to RGBA guarantees a 4th (alpha) channel exists;
    # note that an image stored without transparency becomes fully opaque,
    # i.e. its mask is all True.
    with Image.open(img_path) as img_file:
        image = img_file.convert("RGBA")
    # Foreground mask: every pixel with non-zero alpha.
    mask = np.asarray(image)[:, :, 3] > 0

    with torch.no_grad():
        mesh = utils.load_untextured_mesh(mesh_path, device)
        if PROCESSED_MESH_DIR == "":
            # No cache configured: estimate the pose from the mask.
            # The trailing 8 is passed through unchanged; presumably a batch
            # size - confirm against brute_force_estimate_pose.
            _, _, _, render, _ = brute_force_estimate_pose(mesh, mask, num_azims, num_elevs, num_dists, device, 8)
        else:
            # Re-render at the cached pose instead of searching again.
            cached_pose = cached_pred_poses[instance_name]
            cached_dist = cached_pose["dist"]
            cached_elev = cached_pose["elev"]
            cached_azim = cached_pose["azim"]
            R, T = look_at_view_transform(cached_dist, cached_elev, cached_azim)
            render = utils.render_mesh(mesh, R, T, device)[0]

    # Clip before the uint8 cast so values outside [0, 1] saturate instead of
    # wrapping around.
    render_rgb = np.clip(render[..., :3].cpu().numpy(), 0.0, 1.0)
    render_image = Image.fromarray((render_rgb * 255).astype(np.uint8)).resize((224, 224))

    # Composite the input image over white, using its alpha channel as the paste mask.
    input_image_rgb = Image.new("RGB", image.size, (255, 255, 255))
    input_image_rgb.paste(image, mask=image.split()[3])

    pred_pose_compare_img = get_concat_h(input_image_rgb, render_image)
    pred_pose_compare_img.save(os.path.join(COMPARES_OUTPUT_DIR, "{}.png".format(instance_name)))


HBox(children=(FloatProgress(value=0.0, max=3838.0), HTML(value='')))


