In [None]:
import argparse
import os
from pathlib import Path

import PIL
import PIL.Image
import cv2
import torch
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import MeshRasterizer, RasterizationSettings, MeshRenderer
from pytorch3d.renderer.blending import BlendParams

from colmap.scripts.python.read_write_model import read_model
from utils import SimpleShader, p3d_cam_from_colmap, p3dworld_to_colworld

In [None]:
# Raise Pillow's decompression-bomb guard (default ~89.5 Mpx) so very large images can be
# opened -- presumably the texture of the refined dense mesh below (TODO confirm).
# `PIL.Image` must be imported explicitly: `import PIL` alone does not load the submodule.
PIL.Image.MAX_IMAGE_PIXELS = 933120000

obj_mesh_path = Path('/media/clementin/data/Dehazing/2015/textured_refined_dense_mesh.obj')
colmap_model_dir = Path('/media/clementin/data/Dehazing/2015/sparse')
# Directory whose file names select which COLMAP images get rendered.
results_dir = Path('/home/clementin/Documents/CVPR2023/results')
image_name_list = os.listdir(results_dir)
device = 'cpu'

In [None]:
print('Loading OBJ...')
mesh = load_objs_as_meshes([obj_mesh_path], device=device)
# Re-express the vertices in the COLMAP world frame using the 3x3 block of
# `p3dworld_to_colworld`. In row-vector form this is v' = v @ R, identical to the
# former (R^T @ V^T)^T. NOTE(review): only the rotation block is applied; any translation
# in the full matrix is ignored, as before -- confirm it is zero.
# `update_padded` builds a new Meshes that keeps faces/textures and keeps its list/packed/
# padded vertex views consistent, instead of overwriting the private `_verts_list`.
verts_p3d = mesh.verts_padded()
p3d_to_col_rot = p3dworld_to_colworld[:3, :3].to(device=verts_p3d.device, dtype=verts_p3d.dtype)
mesh = mesh.update_padded(verts_p3d @ p3d_to_col_rot)

print('Loading COLMAP model...')
colmap_cameras, colmap_images, _ = read_model(colmap_model_dir)

In [None]:
# Render the mesh from every COLMAP camera whose image name is present in the results
# directory, and write the render back there as `<stem>_mesh<suffix>`.
results_dir = Path('/home/clementin/Documents/CVPR2023/results')
# A set gives O(1) membership tests; the list was scanned once per COLMAP image.
wanted_image_names = set(image_name_list)
# The shader's arguments do not depend on the camera, so build it once and share it
# (assumes SimpleShader keeps no per-render state -- TODO confirm).
shader = SimpleShader(blend_params=BlendParams(background_color=(0.0, 0.0, 0.0)))


def mesh_render_name(image_name: str) -> str:
    """Return ``image_name`` with ``_mesh`` inserted before its extension.

    ``'IMG_01.jpg'`` -> ``'IMG_01_mesh.jpg'``; ``'a.b.png'`` -> ``'a.b_mesh.png'``;
    ``'IMG'`` -> ``'IMG_mesh'``. Splitting on ``'.'`` and editing element 1 instead
    turned ``'IMG_01.jpg'`` into ``'IMG_01.jpg_mesh'`` (no cv2 writer for that
    extension) and raised IndexError for names without a dot.
    """
    path = Path(image_name)
    return str(path.with_name(f'{path.stem}_mesh{path.suffix}'))


selected_images = [img for img in colmap_images.values() if img.name in wanted_image_names]

# Pure inference: no autograd graph is needed for rendering.
with torch.no_grad():
    for colmap_image in tqdm.tqdm(selected_images, desc='Rendering'):
        colmap_camera = colmap_cameras[colmap_image.camera_id]
        p3d_cam = p3d_cam_from_colmap(colmap_image, colmap_camera, device=device)
        rasterizer = MeshRasterizer(
            cameras=p3d_cam,
            raster_settings=RasterizationSettings(
                image_size=(colmap_camera.height, colmap_camera.width),
                blur_radius=0.0,
                faces_per_pixel=1,
                perspective_correct=True,
            ),
        )
        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
        image = renderer(mesh)

        # First batch element (a single mesh was loaded), first 3 channels, RGB -> BGR for OpenCV.
        bgr = image[0, :, :, :3].cpu().numpy()[:, :, ::-1]
        # Clip before scaling so out-of-range values saturate instead of wrapping in the
        # uint8 cast, and round rather than truncate.
        bgr_u8 = np.rint(np.clip(bgr, 0.0, 1.0) * 255.0).astype(np.uint8)
        out_path = results_dir / mesh_render_name(colmap_image.name)
        # cv2.imwrite may signal failure only through a False return value.
        if not cv2.imwrite(str(out_path), bgr_u8):
            raise IOError(f'cv2.imwrite failed to write {out_path}')