In [1]:
# Enable IPython's autoreload extension (mode 2) so edited project modules are
# re-imported without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Guard flag read by the following cell: False means the working directory has
# not been moved to the parent directory yet.
notebook_fixed_dir = False

In [2]:
# Move the working directory up one level so the project-relative imports and
# paths used below resolve. The notebook_fixed_dir flag turns re-runs of this
# cell into a no-op (as long as the flag is not reset to False in between).
import os
if not notebook_fixed_dir:
    os.chdir('..')
    notebook_fixed_dir = True
# Show the working directory that the rest of the notebook runs in.
print(os.getcwd())

/home/svcl-oowl/brandon/research/sil_consistent_at_inference


In [3]:
import pprint
import glob
from pathlib import Path
import pickle
import random
import os
import json

import torch
from PIL import Image
import numpy as np
from pytorch3d.renderer import look_at_view_transform
import matplotlib.pyplot as plt
import trimesh
import pytorch3d.transforms
import PIL
from scipy import ndimage
from tqdm.autonotebook import tqdm

import postprocess_dataset
from utils import utils
from utils import visualization_tools
from utils.eval_utils import eval_metrics
from utils.brute_force_pose_est import brute_force_estimate_pose, brute_force_estimate_dist, brute_force_estimate_dist_cam_pos, rgba_obj_in_frame
#from evaluation import compute_iou_2d, compute_iou_2d_given_pose, compute_iou_3d, compute_chamfer_L1

In [4]:
# reverse engineered from 
# https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/renderer/cameras.html#camera_position_from_spherical_angles
def cart_to_spherical(cart_coords):
    """Convert a Cartesian position to (dist, elev, azim) spherical coordinates.

    This inverts the convention written out in ``spherical_to_cart``:

        x = dist * cos(elev) * sin(azim)
        y = dist * sin(elev)
        z = dist * cos(elev) * cos(azim)

    Args:
        cart_coords: indexable holding the three numbers ``(x, y, z)``.

    Returns:
        Tuple ``(dist, elev, azim)``. Both angles are in radians; ``elev`` lies
        in [-pi/2, pi/2] and ``azim`` in [-pi, pi].

    Note:
        The result is undefined (NaN) for the zero vector, where ``dist`` is 0.
    """
    x = cart_coords[0]
    y = cart_coords[1]
    z = cart_coords[2]

    dist = np.sqrt(x**2 + y**2 + z**2)
    # Elevation is the arcsine of the *normalised* height. The clip protects
    # against |y/dist| exceeding 1 by a rounding error, which would give NaN.
    elev = np.arcsin(np.clip(y / dist, -1.0, 1.0))
    # arctan2 keeps the signs of x and z, so the azimuth lands in the correct
    # quadrant and z == 0 needs no special handling.
    azim = np.arctan2(x, z)
    return dist, elev, azim
    
def spherical_to_cart(dist, elev, azim):
    """Convert spherical coordinates to a Cartesian position.

    Args:
        dist: distance from the origin.
        elev: elevation angle in radians.
        azim: azimuth angle in radians.

    Returns:
        Tuple ``(x, y, z)`` with
        ``x = dist*cos(elev)*sin(azim)``, ``y = dist*sin(elev)`` and
        ``z = dist*cos(elev)*cos(azim)``.
    """
    x = dist * np.cos(elev) * np.sin(azim)
    y = dist * np.sin(elev)
    z = dist * np.cos(elev) * np.cos(azim)
    # Return the coordinates (the earlier version only printed them, so the
    # result could not be used by calling code).
    return x, y, z

In [5]:
# assumes cam_pos is a vector of numbers
def find_best_fitting_cam_pos(mesh, cam_pos, num_dists, device, batch_size=8):
    """Return the closest scaled camera position whose render fits in the frame.

    ``cam_pos`` is multiplied by ``num_dists`` geometrically spaced factors
    between 0.005 and 2. The mesh is rendered once per scaled position and the
    candidates are checked from the smallest factor upwards; the first one for
    which ``rgba_obj_in_frame`` is truthy is returned.

    Args:
        mesh: mesh object exposing ``extend(n)``; the extended batch is passed
            to ``utils.render_mesh``.
        cam_pos: camera position; must support multiplication by a scalar
            (e.g. a numpy array).
        num_dists: number of candidate scale factors to try.
        device: forwarded to ``utils.render_mesh``.
        batch_size: currently unused; kept so existing calls keep working.

    Returns:
        The scaled camera position ``cam_pos * factor`` of the first fit.

    Raises:
        IndexError: if no candidate render fits in the frame.
    """
    with torch.no_grad():
        # Candidate eye positions along the cam_pos direction, ordered from
        # closest to farthest.
        eyes = [cam_pos * scale for scale in np.geomspace(0.005, 2, num_dists)]
        R, T = look_at_view_transform(eye=eyes)
        renders = utils.render_mesh(mesh.extend(num_dists), R, T, device)

        # Walk outwards and stop at the first render that lies fully in frame.
        # NOTE(review): stopping early assumes rgba_obj_in_frame has no side
        # effects -- confirm against its implementation.
        for i in range(renders.shape[0]):
            if rgba_obj_in_frame(renders[i].cpu().numpy()):
                return eyes[i]

    # Same exception type the original index walk produced, but with a message
    # that explains what went wrong.
    raise IndexError(
        "none of the {} candidate camera distances keeps the object fully in frame".format(num_dists)
    )

def get_iou(mask1, mask2):
    """Intersection-over-union of two same-shaped masks.

    Any non-zero element counts as foreground, so boolean, 0/1 and 0/255 masks
    all give the same answer. Works on numpy arrays and torch tensors.

    Args:
        mask1: first mask (array or tensor).
        mask2: second mask, same shape as ``mask1``.

    Returns:
        ``|mask1 AND mask2| / |mask1 OR mask2|``. The result keeps the input
        library's scalar type (numpy scalar for arrays, 0-dim tensor for
        tensors). It is NaN when both masks are empty.
    """
    fg1 = mask1 != 0
    fg2 = mask2 != 0
    # Explicit logical operators: `*` and `+` only act as AND / OR on boolean
    # inputs and would over-count overlapping pixels for integer masks.
    intersect = fg1 & fg2
    union = fg1 | fg2
    IOU = intersect.sum()/float(union.sum())
    return IOU

In [6]:
def process_pix3d_image(curr_info_dict, visualize=False, inplane=True, use_spherical=True):
    """Build a processed RGBA image of a Pix3D object aligned with a render of its model.

    The ground-truth camera position is normalised to unit length and converted
    to spherical coordinates; the model is rendered from that pose. The photo
    is masked, cropped to its annotated bounding box, resized to the bounding
    box of the rendered silhouette and pasted at the same location on an empty
    ``img_size`` x ``img_size`` canvas.

    Reads the notebook-level globals ``PIX3D_PATH``, ``device`` and ``img_size``.

    Args:
        curr_info_dict: one Pix3D annotation entry. Keys read: "img", "model",
            "mask", "cam_position", "bbox", "img_size" and, when ``inplane`` is
            true, "inplane_rotation".
        visualize: if true, show the masked crop, the render and the processed
            image with matplotlib.
        inplane: if true, rotate the camera up-axis about z by the annotated
            in-plane rotation (passed to np.cos / np.sin, i.e. used as radians).
        use_spherical: currently unused; kept so existing calls keep working.

    Returns:
        Tuple ``(processed_img, dist, elev, azim, final_iou)``:
        the processed PIL RGBA image; the norm of the normalised camera
        position; elevation and azimuth in degrees; and the IoU between the
        rendered silhouette and the processed image's alpha mask.
    """
    # Import the submodule explicitly: `import PIL` on its own does not load
    # PIL.ImageOps, so attribute access only works if some other package
    # happened to import it first.
    from PIL import ImageOps

    #pprint.pprint(curr_info_dict)
    img_path = os.path.join(PIX3D_PATH, curr_info_dict["img"])
    mesh_path = os.path.join(PIX3D_PATH, curr_info_dict["model"])
    mask_path = os.path.join(PIX3D_PATH, curr_info_dict["mask"])
    cam_pos = curr_info_dict["cam_position"]
    img = Image.open(img_path)
    mesh = utils.load_untextured_mesh(mesh_path, device)

    up_axis = [0,1,0] 
    if inplane:
        # Rotate the default up vector about the z axis by the in-plane angle.
        theta = curr_info_dict["inplane_rotation"]
        inplane_R = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta),0],[0,0,1]])
        up_axis = (inplane_R@np.array([up_axis]).T).T[0]

    # obtaining GT pose in spherical coordinates (camera position scaled to unit length first)
    cam_pos = np.array(cam_pos)/np.sqrt(cam_pos[0]**2+cam_pos[1]**2+cam_pos[2]**2)
    dist, elev, azim = cart_to_spherical(cam_pos)
    # look_at_view_transform is given the angles in degrees.
    azim = azim * (180/np.pi) 
    elev = elev * (180/np.pi) 
    R, T = look_at_view_transform(dist,elev,azim, up=[up_axis])
    spherical_based_render = utils.render_mesh(mesh, R, T, device, img_size=img_size)

    # Sanity check: the spherical pose should reproduce the render obtained
    # directly from the camera position. If the silhouettes disagree, the
    # azimuth is assumed to be off by half a turn (an arctan-based conversion
    # cannot tell opposite quadrants apart) and is corrected by 180 degrees.
    R, T = look_at_view_transform(eye=[cam_pos], up=[up_axis])
    cam_based_render = utils.render_mesh(mesh, R, T, device, img_size=img_size)
    render_comparision_iou = get_iou(spherical_based_render[0,...,3]>0, cam_based_render[0,...,3]>0)
    if render_comparision_iou.item() < 0.95:
        azim += 180
        R, T = look_at_view_transform(dist,elev,azim, up=[up_axis])
        spherical_based_render = utils.render_mesh(mesh, R, T, device, img_size=img_size)
    
    # Make everything outside the object mask transparent, then crop to the annotated bbox.
    mask = Image.open(mask_path)
    mask_bbox = curr_info_dict["bbox"]
    img_masked_rgba = Image.composite(Image.new("RGBA", curr_info_dict['img_size']), img.convert('RGBA'), ImageOps.invert(mask))
    img_masked_rgba = img_masked_rgba.crop(mask_bbox)

    # Bounding box of the rendered silhouette (alpha > 0.2).
    objs = ndimage.find_objects(spherical_based_render[0,...,3].detach().cpu().numpy()>0.2)
    # find_objects yields (row_slice, col_slice); reorder to
    # [left, upper, right, lower] as used by PIL.
    #render_bbox = [objs[0][0].start, objs[0][1].start, objs[0][0].stop, objs[0][1].stop]
    render_bbox = [objs[0][1].start, objs[0][0].start, objs[0][1].stop, objs[0][0].stop]
    render_bbox_width = render_bbox[2] - render_bbox[0]
    render_bbox_height = render_bbox[3] - render_bbox[1]
    # Stretch the masked crop to the render's bbox and paste it at the same location.
    img_masked_rgba_resized = img_masked_rgba.resize((render_bbox_width, render_bbox_height))
    processed_img = Image.new("RGBA", (img_size, img_size))
    processed_img.paste(img_masked_rgba_resized, box=render_bbox[:2])

    # Agreement between the rendered silhouette and the processed image's alpha mask.
    final_iou = get_iou(spherical_based_render[0, ..., 3].detach().cpu().numpy() > 0, np.array(processed_img)[...,3]>0)
    if visualize:
        plt.imshow(img_masked_rgba)
        plt.show()
        plt.imshow(spherical_based_render[0, ..., :3].detach().cpu().numpy())
        plt.show()
        plt.imshow(processed_img)
        plt.show()
    return processed_img, dist, elev, azim, final_iou

In [7]:
# Pix3D category processed by this run; also names the output sub-directory.
pix3d_class = "chair"

# Root of the raw Pix3D dataset (machine-specific) and root for processed output.
PIX3D_PATH = "/home/svcl-oowl/dataset/pix3d"
PROCESSED_PIX3D_PATH = "data/pix3d_images_processed"
device = torch.device("cuda:0")
# Side length, in pixels, of the square renders and processed images.
img_size = 224
# Instance names (image path without extension) that are skipped.
blacklist = ["img/table/0045", "img/table/1749"]
# When True, instances whose processed image already exists are processed again.
recompute=False

processed_class_output_dir = os.path.join(PROCESSED_PIX3D_PATH, pix3d_class)
# exist_ok avoids the check-then-create race and makes re-running this cell safe.
os.makedirs(processed_class_output_dir, exist_ok=True)
# Pickle caches holding the per-instance pose and IoU results.
pose_dict_path = os.path.join(processed_class_output_dir, "renders_camera_params.pt")
iou_dict_path = os.path.join(processed_class_output_dir, "iou_info.pt")

In [8]:
# Load the Pix3D annotation file (a list with one dict per image).
with open(os.path.join(PIX3D_PATH, "pix3d.json")) as f:
    pix3d_data_json = json.load(f)
# convert list of dicts into a dict of dicts keyed by the image path without its
# extension. splitext strips only the final extension, so a dot elsewhere in the
# path can no longer truncate the key and make two entries collide.
pix3d_data_dict = {os.path.splitext(entry["img"])[0]: entry for entry in pix3d_data_json}

In [9]:
# Resume from the pose / IoU caches of a previous run if they exist.
# NOTE: pickle.load can execute arbitrary code; only load cache files that were
# written by this notebook.
if os.path.exists(pose_dict_path):
    with open(pose_dict_path, "rb") as f:
        pose_dict = pickle.load(f)
else:
    pose_dict = {}
    
if os.path.exists(iou_dict_path):
    with open(iou_dict_path, "rb") as f:
        iou_dict = pickle.load(f)
else:
    iou_dict = {}

# NOTE(review): this is a substring match on the whole instance name, not an
# exact match on the category path component.
class_instance_names = [instance_name for instance_name in pix3d_data_dict.keys() if pix3d_class in instance_name]
for instance_name in tqdm(class_instance_names):
    instance_class_id = instance_name.split('/')[-1]
    processed_img_path = os.path.join(processed_class_output_dir, "{}.png".format(instance_class_id))
    # Skip blacklisted instances and, unless recompute is set, ones already processed.
    if instance_name not in blacklist and (recompute or not os.path.exists(processed_img_path)):
        print(instance_name)
        processed_img, dist, elev, azim, iou = process_pix3d_image(pix3d_data_dict[instance_name], visualize=False)
        iou_dict[instance_class_id] = iou
        pose_dict[instance_class_id] = {"azim": azim, "elev": elev, "dist": dist}
        # Persist the caches after every instance so an interrupted run loses
        # nothing; the context managers guarantee the files are flushed and closed.
        with open(iou_dict_path, "wb") as f:
            pickle.dump(iou_dict, f)
        with open(pose_dict_path, "wb") as f:
            pickle.dump(pose_dict, f)
        # The image is written last: its existence is what marks the instance
        # as done on the next run.
        processed_img.save(processed_img_path)
        

HBox(children=(FloatProgress(value=0.0, max=3839.0), HTML(value='')))

img/chair/1151
img/chair/1152
img/chair/1153
img/chair/1154
img/chair/1155
img/chair/1156
img/chair/1157
img/chair/1158
img/chair/1159
img/chair/1160
img/chair/1161
img/chair/1162
img/chair/1163
img/chair/1164
img/chair/1165
img/chair/1166
img/chair/1167
img/chair/1168
img/chair/1169
img/chair/1170
img/chair/1171
img/chair/1172
img/chair/1173
img/chair/1174
img/chair/1175
img/chair/1176
img/chair/1177
img/chair/1178
img/chair/1179
img/chair/1180
img/chair/1181
img/chair/1182
img/chair/1183
img/chair/1184
img/chair/1185
img/chair/1186
img/chair/1187
img/chair/1188
img/chair/1189
img/chair/1190
img/chair/1191
img/chair/1192
img/chair/1193
img/chair/1194
img/chair/1195
img/chair/1196
img/chair/1197
img/chair/1198
img/chair/1199
img/chair/1200
img/chair/1201
img/chair/1202
img/chair/1203
img/chair/1204
img/chair/1205
img/chair/1206
img/chair/1207
img/chair/1208
img/chair/1209
img/chair/1210
img/chair/1211
img/chair/1212
img/chair/1213
img/chair/1214
img/chair/1215
img/chair/1216
img/chair/

KeyboardInterrupt: 

In [None]:
# Dump the cached IoU and pose dictionaries for inspection. The files are
# opened with a context manager so the handles are closed again.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
for cache_path in (iou_dict_path, pose_dict_path):
    with open(cache_path, "rb") as f:
        pprint.pprint(pickle.load(f))