In [None]:
# Imports and set torch device
import json
import sys

import igl
import kaolin as kal
import matplotlib.pyplot as plt
import meshplot as mp
import numpy as np
import torch

from scripts.helper_functions import segment
from meshseg.models.GLIP.glip import GLIPModel
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Pick the first CUDA device when one is present, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    torch.cuda.set_device(device)

print('Torch will run on:', device)

# Mesh to segment. Other inputs previously tried here:
#   data/demo/penguin.obj, data/demo/bed.obj, data/FAUST/scans/tr_scan_000.obj
obj_path = 'data/demo/lamp.obj'

# GLIP model, used further down to run a text-prompted prediction on a rendered view.
GM = GLIPModel()

In [None]:
# Read mesh: load geometry (normals requested, materials skipped) from `obj_path`.
mesh = kal.io.obj.import_mesh(obj_path, with_normals=True, with_materials=False)

# Torch copies on the compute device, used by the renderer below.
vertices_tensor, faces_tensor = mesh.vertices.to(device), mesh.faces.to(device)

# Host-side numpy copies, used for meshplot visualisation.
vertices, faces = (t.detach().cpu().numpy() for t in (vertices_tensor, faces_tensor))

print('Number of vertices: ', vertices.shape[0])
print('Number of faces: ', faces.shape[0])

In [None]:
# Visualize mesh: interactive preview of the raw geometry before rendering/segmentation.
mesh_preview = mp.plot(vertices, faces)
plt.show()

In [None]:
# Show a sample of rendered images: render one randomly oriented view of the mesh as a
# sanity check of the camera/lighting setup. The final uint8 `image` is the input
# to the GLIP cell below.

# --- Camera sampling --------------------------------------------------------------
# gen.seed() draws a fresh seed every run, so the view changes each time;
# use gen.manual_seed(<int>) instead to reproduce a particular view.
gen = torch.Generator()
gen.seed()
std = 4
center_elev = 0
center_azim = np.pi  # azim = pi puts the camera on the -x side (was the truncated literal 3.14)
elev = torch.randn(1, generator=gen) * np.pi / std + center_elev
azim = torch.randn(1, generator=gen) * 2 * np.pi / std + center_azim
r = 2  # camera distance from the mesh centre

# Place the camera on a sphere of radius r around the mesh centre (y is the up axis)...
center = vertices_tensor.mean(dim=0)
offset = torch.cat([
    r * torch.cos(elev) * torch.cos(azim),
    r * torch.sin(elev),
    r * torch.cos(elev) * torch.sin(azim),
]).to(device)
pos = (center + offset).unsqueeze(0)
# ...and aim it at that centre. kaolin's `camera_look_at` is the target *point* (the view
# axis is built as position - look_at), not a viewing direction. The previous
# `center - pos` only pointed at the mesh when it happened to be centred at the origin.
look_at = center.unsqueeze(0)
direction = torch.tensor([[0.0, 1.0, 0.0]], device=device)  # camera up vector
camera_transform = kal.render.camera.generate_transformation_matrix(pos, look_at, direction)

# Nine spherical-harmonic lighting coefficients for a single view, and a black background.
lights = torch.tensor([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]], device=device)
background = torch.tensor([0.0, 0.0, 0.0], device=device)

image_size = 1024  # square render resolution in pixels

# --- Rasterization ----------------------------------------------------------------
(
    face_vertices_camera,
    face_vertices_image,
    face_normals,
) = kal.render.mesh.prepare_vertices(
    vertices_tensor,
    faces_tensor,
    kal.render.camera.generate_perspective_projection(np.pi / 3).to(device),
    camera_transform=camera_transform,
)

# Per-face attributes to interpolate: a uniform grey colour and an all-ones hard seg. mask.
face_attributes = [
    kal.ops.mesh.index_vertices_by_faces(
        torch.full((1, vertices_tensor.shape[0], 3), 0.5, device=device),
        faces_tensor,
    ),
    torch.ones((1, faces_tensor.shape[0], 3, 1), device=device),
]
(image_features, mask), soft_mask, face_idx = kal.render.mesh.dibr_rasterization(
    image_size,
    image_size,
    face_vertices_camera[:, :, :, -1],
    face_vertices_image,
    face_attributes,
    face_normals[:, :, -1],
)

# --- Shading ----------------------------------------------------------------------
# Light each pixel with the SH lighting evaluated at the normal of the face it hit.
# NOTE(review): background pixels presumably carry face_idx == -1 and so pick up the last
# face's normal; their colour is zero, so the product stays zero (checked below).
image_normals = face_normals[:, face_idx].squeeze(0)
image_lighting = kal.render.mesh.spherical_harmonic_lighting(
    image_normals, lights
).unsqueeze(0)
image = torch.clamp(image_features, 0.0, 1.0)
image = image * image_lighting.repeat(1, 3, 1, 1).permute(0, 2, 3, 1)

# --- Background -------------------------------------------------------------------
mask = mask.squeeze(-1)
background_idx = torch.where(mask == 0)
# Sanity check that uncovered pixels are still black; drop it if it is too slow.
assert torch.all(image[background_idx] == 0)
background_mask = torch.zeros_like(image)
background_mask[background_idx] = background
image = torch.clamp(image + background_mask, 0.0, 1.0).squeeze()

plt.imshow(image.cpu().numpy())
plt.show()

# Convert to an 8-bit HxWx3 array for the GLIP cell below.
image = (image * 255).cpu().numpy().astype(np.uint8)

In [None]:
# Run GLIP on the rendered view with a comma-separated text prompt and show its output.
gmPrediction = GM.predict(image, 'body, lamp')
# NOTE(review): element [1] is printed and element [0] is drawn as an image —
# presumably the raw predictions and an annotated image; confirm against GLIPModel.predict.
print(gmPrediction[1])
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(gmPrediction[0])
plt.show()

In [None]:
# Run the segmentation pipeline for the lamp demo.
# NOTE(review): the output directory presumably receives face_preds.json, which the next
# cell reads back — confirm against scripts.helper_functions.segment.
config_path = 'configs/demo/lamp.yaml'
mesh_file = 'lamp.obj'
output_dir = 'outputs/demo/ABO/lamp/'
segment(config_path, mesh_file, output_dir)

In [None]:
# Load the per-face predictions (presumably written by the `segment` call above) and
# colour the mesh by predicted prompt. `with` guarantees the file handle is closed.
with open('./outputs/demo/ABO/lamp/face_preds.json') as output_file:
    output = np.array(json.load(output_file))
# `prompts`: the distinct predicted labels (sorted); `colors`: each face's index into them.
prompts, colors = np.unique(output, return_inverse=True)
print('Prompts: ', prompts)
mp.plot(vertices, faces, colors)
plt.show()