## 0. Install and Import modules

https://towardsdatascience.com/how-to-render-3d-files-using-pytorch3d-ef9de72483f8

https://github.com/codingforpleasure/collection_of_pytorch_helpful_stuff

https://docs.pyvista.org/examples/01-filter/voxelize.html

Ensure `torch` and `torchvision` are installed. The following cell imports the required modules and prints a message if `pytorch3d` is missing; install `pytorch3d` before running the rest of the notebook.

In [None]:
import os
import sys
import torch
import matplotlib.pyplot as plt
try:
    import pytorch3d
except ModuleNotFoundError:
    print("pytorch3d missing")

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes, load_obj

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    TexturesUV,
    TexturesVertex
)

sys.path.append(os.path.abspath(''))
from utils import image_grid

### 1. Load a mesh and texture file

Load an `.obj` file and its associated `.mtl` file and create a **Textures** and **Meshes** object. 

**Meshes** is a data structure provided by PyTorch3D for working with batches of meshes of different sizes. 

**TexturesUV** is an auxiliary data structure for storing vertex UV coordinates and texture maps for meshes. 

**Meshes** has several class methods which are used throughout the rendering pipeline.

In [None]:
# Choose the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    # Make the selected GPU the current CUDA device.
    torch.cuda.set_device(device)

# Location of the cow mesh relative to the notebook's working directory.
DATA_DIR = "./data"
obj_filename = os.path.join(DATA_DIR, "cow_mesh/cow.obj")

# Read the .obj file into a Meshes object placed on the chosen device.
mesh = load_objs_as_meshes([obj_filename], device=device)

#### Let's visualize the texture map

## 2. Mesh visualization 
If you only want to visualize a mesh, you don't really need a differentiable renderer; PyTorch3D also supports plotting Meshes with Plotly. For these Meshes, we use TexturesVertex to define a texture for the rendering.
`plot_scene` creates a Plotly figure with a trace for each Meshes object (and each camera) passed to it. 

In [None]:
# Load the raw geometry: vertex positions and a faces structure whose
# `verts_idx` field holds the per-face vertex indices.
verts, faces_idx, _ = load_obj(obj_filename)
faces = faces_idx.verts_idx

# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
textures = TexturesVertex(verts_features=verts_rgb.to(device))

# Create a Meshes object and attach the per-vertex colors so that the
# plot actually uses the texture defined above.
mesh = Meshes(
    verts=[verts.to(device)],
    faces=[faces.to(device)],
    textures=textures,
)

# One camera per azimuth angle, all at distance 2.7 and elevation 0,
# looking at the mesh from four sides.
azimuths = [0, 90, 180, 270]
R, T = look_at_view_transform(2.7, 0, azimuths)
# Any instance of CamerasBase works, here we use FoVPerspectiveCameras
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

# One trace per camera, named "cameras0" ... "cameras3", so the camera
# positions are drawn alongside the mesh.
camera_traces = {f"cameras{i}": cameras[i] for i in range(len(azimuths))}
fig = plot_scene({
    "subplot1_title": {
        "mesh_trace_title": mesh,
        **camera_traces,
    },
})
fig.show()

In [None]:
from pytorch3d.io import IO

# Create a batch of two meshes; the second copy is scaled by 5 and shifted
# by 5 along every axis to prevent overlap with the first.
mesh_batch = Meshes(
    verts=[verts.to(device), (5 * verts + 5).to(device)],
    faces=[faces.to(device), faces.to(device)],
)

# Merge the batch into a single mesh using the packed representation.
# Reshaping the padded tensors would leave the second mesh's faces pointing
# at the first mesh's vertices and would hard-code the vertex/face counts;
# `faces_packed()` indexes into `verts_packed()`, so the merged faces
# reference the correct vertices for a mesh of any size.
mesh_merged = Meshes(
    verts=[mesh_batch.verts_packed()],
    faces=[mesh_batch.faces_packed()],
)
print(mesh_merged.verts_padded().shape)
print(mesh.verts_padded().shape)

# print(mesh_merged.faces_packed())

# IO().save_mesh(mesh_merged, "moo.obj")

# Render the plotly figure
# plot batch of meshes in different traces
fig = plot_scene(
    {
        "cow_plot1": {
            "cow_mesh1": mesh_batch[0],
            "cow_mesh2": mesh_batch[1],
        }
    },
    xaxis={"backgroundcolor": "rgb(200, 200, 230)"},
    yaxis={"backgroundcolor": "rgb(230, 200, 200)"},
    zaxis={"backgroundcolor": "rgb(200, 230, 200)"},
    axis_args=AxisArgs(showgrid=True),
)

fig.show()


In [None]:
# from utils.voxeliser.voxelizer import read_file_and_reshape_stl, voxelize

# # path to the stl file
# input_path = 'data\cow_mesh\cow.obj'
# # number of voxels used to represent the largest dimension of the 3D model
# resolution = 100

# # read and rescale
# mesh, bounding_box = read_file_and_reshape_stl(input_path, resolution)
# # print(len(mesh[0]))
# # create voxel array
# voxels, bounding_box = voxelize(mesh, bounding_box)

# # print(voxels)

In [None]:
import os

import numpy as np

import pyvista as pv
pv.set_jupyter_backend('pythreejs')
# Build the path from its components instead of a backslash-separated
# literal: '\c' is an invalid string escape, and a backslash path only
# resolves on Windows.
cowmesh = pv.read(os.path.join("data", "cow_mesh", "cow.obj"))


In [None]:
# Voxelize the cow mesh with a voxel size of 1/200 of the mesh's `length`.
voxels = pv.voxelize(cowmesh, density=cowmesh.length / 200, check_surface=False)

# Overlay the voxels, the original surface and a reference box in one
# notebook plotter; the layers are added in the order listed here.
p = pv.Plotter(notebook=True, window_size=(800, 800))
layers = (
    (voxels, {"color": True, "show_edges": True, "opacity": 0.5}),
    (cowmesh, {"color": "lightblue", "opacity": 0.5}),
    (pv.Box(), {"color": "tan", "opacity": 0.5}),
)
for dataset, style in layers:
    p.add_mesh(dataset, **style)
p.show()

In [None]:
import os

import pyvista as pv
import numpy as np
import vtk
from pyvista import examples
pv.set_jupyter_backend('pythreejs')

# A standalone camera whose view frustum is drawn in the scene below.
camera = pv.Camera()
near_range = 0.3
far_range = 0.8
camera.clipping_range = (near_range, far_range)
# Camera direction divided by the distance between the focal point and
# the camera position.
unit_vector = np.array(camera.direction) / np.linalg.norm(
    np.array([camera.focal_point]) - np.array([camera.position])
)

frustum = camera.view_frustum(1.0)
position = camera.position
focal_point = camera.focal_point

# Line from the camera position to its focal point.
line = pv.Line(position, focal_point)
# model = pv.Box(bounds=(-0.1, 0.1, -0.1, 0.1, -0.1, 0.1))
# Build the path from its components instead of a backslash-separated
# literal: '\c' is an invalid string escape, and a backslash path only
# resolves on Windows.
model = pv.read(os.path.join("data", "cow_mesh", "cow.obj"))

voxels = pv.voxelize(model, density=model.length / 200, check_surface=False)

# Translate the model and its voxels so that the mean of the model's points
# lands 0.6 along `unit_vector` from the camera position (0.6 lies between
# the near and far clipping values set above).
xyz = camera.position + unit_vector * 0.6 - np.mean(model.points, axis=0)
model.translate(xyz, inplace=True)
voxels.translate(xyz, inplace=True)
pl = pv.Plotter(window_size=(1200, 800))

pl.add_mesh(voxels, color=True, show_edges=True, opacity=0.5 )
pl.subplot(0, 0)
pl.add_mesh(model, opacity=0.5)
pl.add_mesh(frustum, style="wireframe")
pl.add_mesh(line, color="b")

# Viewpoint of the plotter itself (separate from `camera`, whose frustum
# is being visualized).
pl.camera.position = (1.1, 1.5, 0.0)
pl.camera.focal_point = (0.2, 0.3, 0.3)
pl.camera.up = (0.0, 1.0, 0.0)
pl.camera.zoom(0.5)

# Changing pov
# camera.zoom(0.5)
# pl.camera = camera
pl.show()