In [1]:
# Pixel2Mesh reconstructs each mesh in the pose of its input view; this notebook re-aligns the reconstructed models to a class-canonical alignment.

In [1]:
# Reload edited modules (e.g. `utils`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Set to True by the next cell once the working directory has been moved up one level.
notebook_fixed_dir = False

In [2]:
# Move up one directory (to the project root, per the printed path) so that
# `from utils import utils` and relative paths resolve. The `notebook_fixed_dir`
# flag makes re-running this cell a no-op instead of climbing another level;
# re-running the flag-initialising cell resets it.
import os

if not notebook_fixed_dir:
    os.chdir(os.pardir)
    notebook_fixed_dir = True

print(os.getcwd())

/home/svcl-oowl/brandon/research/sil_consistent_at_inference


In [3]:
import pprint
import pickle
import glob
import random
from pathlib import Path
import os

import torch
from tqdm.autonotebook import tqdm
from PIL import Image
import numpy as np
from pytorch3d.renderer import look_at_view_transform
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import pytorch3d
from pytorch3d.io import load_obj
from pytorch3d.io import save_obj
from pytorch3d.renderer.cameras import look_at_view_transform
from pytorch3d.renderer import Textures

from utils import utils

In [6]:
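# (disabled) pytorch3d sanity check: load one refined mesh, save it after applying R_adjusted @ R_adjusted.T, and render it from a fixed camera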
#test_mesh_path = "data/refinements/shapenet_occnet_refinements/gt_pose/02691156/batch_1_of_2/d18592d9615b01bbbc0909d98a1ff2b4.obj"
#test_mesh = utils.load_untextured_mesh(test_mesh_path, device)
#test_dist = 1
#test_elev = 10
#test_azim = 40
#test_R, test_T = look_at_view_transform(dist=test_dist, elev=test_elev, azim=test_azim)
#m = torch.tensor([[-1,0,0],
#                  [0,1,0],
#                  [0,0,-1]], dtype=torch.float)
#R_adjusted = test_R[0]@m
#save_obj("test_mesh_out.obj", test_mesh.verts_packed()@R_adjusted@R_adjusted.T, test_mesh.faces_packed())
#test_render = utils.render_mesh(test_mesh, test_R, test_T, device)
#plt.imshow(test_render[0,...,:3].detach().cpu().numpy())

In [4]:
def center_points(points):
    """Translate a point set so that its mean lies at the origin.

    Args:
        points: tensor whose rows are points; the mean is taken over dim 0
            (presumably an (N, 3) vertex tensor).

    Returns:
        A new tensor of the same shape with the dim-0 mean subtracted.
    """
    centroid = points.mean(dim=0)
    return points - centroid

def rot_x(theta, degrees=True):
    """Build the 3x3 rotation matrix about the x-axis.

    Args:
        theta: rotation angle, in degrees unless ``degrees`` is False.
        degrees: when True, ``theta`` is converted from degrees to radians.

    Returns:
        np.ndarray of shape (3, 3): [[1, 0, 0], [0, cos, -sin], [0, sin, cos]].
    """
    angle = theta * (np.pi / 180) if degrees else theta
    c, s = np.cos(angle), np.sin(angle)
    return np.array([
        [1, 0, 0],
        [0, c, -s],
        [0, s, c],
    ])

def rot_y(theta, degrees=True):
    """Build the 3x3 rotation matrix about the y-axis.

    Args:
        theta: rotation angle, in degrees unless ``degrees`` is False.
        degrees: when True, ``theta`` is converted from degrees to radians.

    Returns:
        np.ndarray of shape (3, 3): [[cos, 0, sin], [0, 1, 0], [-sin, 0, cos]].
    """
    angle = theta * (np.pi / 180) if degrees else theta
    c, s = np.cos(angle), np.sin(angle)
    return np.array([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])

def rot_z(theta, degrees=True):
    """Build the 3x3 rotation matrix about the z-axis.

    Args:
        theta: rotation angle, in degrees unless ``degrees`` is False.
        degrees: when True, ``theta`` is converted from degrees to radians.

    Returns:
        np.ndarray of shape (3, 3): [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]].
    """
    angle = theta * (np.pi / 180) if degrees else theta
    c, s = np.cos(angle), np.sin(angle)
    return np.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])

def get_mask_iou(render1_mask, render2_mask, empty_value=0.0):
    """Intersection-over-union of two boolean masks.

    Args:
        render1_mask, render2_mask: boolean tensors of broadcast-compatible shape.
        empty_value: returned when the union is empty (both masks all False).
            There IoU would be 0/0 = NaN, and a NaN makes every later
            `iou > best` comparison False (e.g. in make_flat's search),
            silently freezing the argmax on whichever candidate came first.

    Returns:
        Python float: |intersection| / |union|, or ``empty_value`` if the union is empty.
    """
    intersection = torch.logical_and(render1_mask, render2_mask)
    union = torch.logical_or(render1_mask, render2_mask)
    union_count = torch.sum(union, dtype=torch.float)
    if union_count.item() == 0:
        return float(empty_value)
    IOU = torch.sum(intersection, dtype=torch.float) / union_count
    return IOU.item()

def make_flat(input_mesh, device, min_rot=-20, max_rot=20, num_rot=20):
    """Return ``input_mesh`` rotated about the x-axis by the candidate angle whose
    silhouette best matches its own horizontal mirror image (i.e. the most "flat" pose).

    Each candidate angle from ``np.linspace(min_rot, max_rot, num_rot)`` (degrees) is
    applied as ``verts @ rot_x(angle)``. The result is rendered with a camera at
    dist=0.7, elev=0, azim=90, which sits on the +x axis looking along the rotation
    axis. It is scored by ``get_mask_iou(silhouette, flipped silhouette)``, and the
    highest-scoring rotated mesh is returned.

    Args:
        input_mesh: pytorch3d ``Meshes``; assumed to hold a single mesh, since vertices
            are read with ``verts_packed()`` and re-batched with ``unsqueeze(0)``.
        device: device the candidate meshes are moved to before rendering.
        min_rot, max_rot: inclusive search range in degrees.
        num_rot: number of candidate angles. NOTE(review): with the defaults,
            ``np.linspace(-20, 20, 20)`` does not contain 0, so the identity
            rotation is never evaluated.

    Returns:
        The best rotated ``Meshes`` (on ``device``, carrying white vertex colours),
        or None if ``num_rot`` is 0.
    """
    R, T = look_at_view_transform(dist=0.7, elev=0, azim=90)
    verts = input_mesh.verts_packed()
    # White vertex colours; only the alpha channel (index 3) is used for scoring below.
    textures = Textures(verts_rgb=torch.ones_like(verts)[None])  # (1, V, 3)
    # Use the input mesh's own faces (previously this read the notebook-global `mesh`).
    faces = input_mesh.faces_padded()

    highest_iou = None
    flattest_mesh = None
    for rot_amt in np.linspace(min_rot, max_rot, num=num_rot):
        # Build the rotation on the vertices' device so GPU-resident meshes also work.
        rot = torch.tensor(rot_x(rot_amt), dtype=torch.float, device=verts.device)
        rotated_verts = verts @ rot
        rotated_mesh = pytorch3d.structures.Meshes(
            verts=rotated_verts.unsqueeze(0), faces=faces, textures=textures
        ).to(device)
        rotated_render = utils.render_mesh(rotated_mesh, R, T, device)
        # Presumably rotated_render[0, ..., 3] is an (H, W) alpha map; flipping dim 1
        # mirrors it left-right.
        alpha = rotated_render[0, ..., 3]
        iou = get_mask_iou(alpha > 0, torch.flip(alpha, [1]) > 0)
        if highest_iou is None or iou > highest_iou:
            highest_iou = iou
            flattest_mesh = rotated_mesh

    return flattest_mesh

def normalize_pointclouds(points):
    """Uniformly scale a point set so its largest axis-aligned extent becomes 1.

    The scale factor is ``1 / max(max_along_dim0 - min_along_dim0)``. Proportions are
    preserved and the points are not translated, so center them beforehand if needed.

    Args:
        points: tensor of points with coordinates along dim 1 (presumably (N, 3) vertices).

    Returns:
        A new tensor of the same shape. If ``points`` is empty, or its extent is zero
        (single or coincident points), an unscaled copy is returned. Otherwise
        ``torch.max`` would raise on the empty input, or the zero extent would yield inf/NaN.
    """
    if points.numel() == 0:
        return points.clone()
    max_vert_values = torch.max(points, 0).values
    min_vert_values = torch.min(points, 0).values
    max_width = torch.max(max_vert_values - min_vert_values)
    if max_width.item() == 0:
        return points.clone()
    normalized_points = points * (1 / max_width)

    return normalized_points

In [6]:
# ShapeNet synset ids whose reconstructions will be re-aligned.
class_ids = ["04401088", "04530566"]
# Other class lists used in earlier runs:
#class_ids = ["02691156", "02828884", "02933112", "02958343", "03001627", "03211117", "03636649", "03691459", "04090263", "04256520", "04379243", "04401088", "04530566"]
#class_ids = ["02828884", "02933112", "02958343", "03001627", "03211117", "03636649", "03691459", "04090263", "04256520", "04379243", "04401088", "04530566"]
#class_ids=["test"]

# Render on the first GPU when one is visible; otherwise fall back to CPU so the
# notebook still runs (slowly) instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cpu_device = torch.device("cpu")

# Root of the Pixel2Mesh reconstructions (one sub-directory per class id).
# Override with the PIX2MESH_REC_DIR environment variable instead of editing the path.
pix2mesh_rec_dir = os.environ.get(
    "PIX2MESH_REC_DIR",
    "/home/svcl-oowl/brandon/research/Pixel2Mesh/rec_files/pytorch3d_shapenet_renders",
)
#class_pose_dict_path = "/home/svcl-oowl/brandon/research/Pixel2Mesh/rec_files/pytorch3d_shapenet_renders/02691156/rgba/renders_camera_params.pt"

In [7]:
# Canonicalise every Pixel2Mesh reconstruction of each class:
#   1. center the vertices on their mean,
#   2. rotate out of the input view's camera frame using the instance's recorded pose,
#   3. pick the x-axis tilt whose silhouette is most mirror-symmetric (make_flat),
#   4. rescale so the largest extent is 1.
# WARNING(review): results overwrite the source .obj files in place, and nothing
# renames or marks them, so re-running this cell re-applies the transform to meshes
# that are already aligned. Paths containing "_aligned" are skipped.

# diag(-1, 1, -1): a 180-degree rotation about y composed with the camera rotation.
# It is loop-invariant, so it is built once here.
m = torch.tensor([[-1, 0, 0],
                  [0, 1, 0],
                  [0, 0, -1]], dtype=torch.float)

for class_name in class_ids:
    meshes_dir = os.path.join(pix2mesh_rec_dir, class_name)
    class_pose_dict_path = os.path.join(meshes_dir, "rgba", "renders_camera_params.pt")
    mesh_paths = [str(path) for path in Path(meshes_dir).rglob("*.obj")]
    # Locally generated file; pickle can execute arbitrary code, so never load untrusted data here.
    with open(class_pose_dict_path, "rb") as pose_file:
        pose_dict = pickle.load(pose_file)
    print(class_name)

    with torch.no_grad():
        for mesh_path in tqdm(mesh_paths):
            if "_aligned" in mesh_path:
                continue

            mesh = utils.load_untextured_mesh(mesh_path, cpu_device)
            instance = Path(mesh_path).stem
            pose = pose_dict[instance]
            R, T = look_at_view_transform(dist=pose["dist"], elev=pose["elev"], azim=pose["azim"])
            R_adjusted = R[0] @ m

            partially_aligned_verts = center_points(mesh.verts_packed()) @ R_adjusted.T
            partially_aligned_mesh = pytorch3d.structures.Meshes(
                verts=partially_aligned_verts.unsqueeze(0), faces=mesh.faces_padded()
            )
            aligned_mesh = make_flat(partially_aligned_mesh, device)

            save_obj(mesh_path, normalize_pointclouds(aligned_mesh.verts_packed()), mesh.faces_packed())

04401088


HBox(children=(FloatProgress(value=0.0, max=210.0), HTML(value='')))


04530566


HBox(children=(FloatProgress(value=0.0, max=387.0), HTML(value='')))


