# Project imports

In [None]:
"""
All needed imports included here
"""
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import matplotlib.pyplot as plt
import trimesh
import torch
import skimage
from torchvision import transforms
import pytorch_lightning as pl
from Utils.Visualize import visualize_mesh, visualize_occupancy, visualize_pointcloud,visualize_sdf

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Data Loading step

Data is available at the following links: <br>
    Images: https://mega.nz/file/bN9iRRTB#kd6-1FDc5f3cxd69Ku0btEXgAjGwtXXHCkwRGHlnHn0 <br>
    Meshes: https://drive.google.com/drive/folders/1BtkAeuMpAC9gKexyoEpu4baRxiga5vqa
    

In [None]:
"""
Create data loaders and augmentations needed here
"""
from Data.Image2MeshDataLoader import Image2MeshDataLoader

# Training split. Both splits are built with identical settings; only the
# image and mesh directories differ.
image2mesh_train_dataset = Image2MeshDataLoader(
    images_path="Assets/Data/image2mesh/train/images/",
    meshes_path="Assets/Data/image2mesh/train/meshes/",
    image_size=256,
    voxel_dims=(32, 32, 32),
    sample_rate=8192,
)

# Validation split.
image2mesh_val_dataset = Image2MeshDataLoader(
    images_path="Assets/Data/image2mesh/val/images/",
    meshes_path="Assets/Data/image2mesh/val/meshes/",
    image_size=256,
    voxel_dims=(32, 32, 32),
    sample_rate=8192,
)

In [None]:
# Report how many samples each split contains.
print("number of training datapoints is:", len(image2mesh_train_dataset))
print("number of validation datapoints is:", len(image2mesh_val_dataset))

# Inspect the first training sample: element 0 is reported as the image and
# element 1 as the mesh target. Each `[0]` lookup indexes the dataset again,
# so the sample is fetched twice here.
print("Images have shapes:", image2mesh_train_dataset[0][0].shape)
print("Meshes have shapes:",image2mesh_train_dataset[0][1].shape)


# Reconstruction Networks

In [None]:
"""
creation, training, and testing of the image2mesh reconstruction networks
"""
from Networks.Image2Mesh import Image2Voxel

# Build the network with its constructor defaults. The training cell further
# down creates its own fresh Image2Voxel instance and rebinds `model`.
model = Image2Voxel()

In [None]:
"""
Test out forward pass and ensure output sizes
"""
# Random input of shape (10, 3, 128, 128) — presumably (batch, channels,
# height, width); TODO confirm against Image2Voxel.
# NOTE(review): the datasets above are created with image_size=256 while this
# check feeds 128x128 inputs — confirm which resolution the model expects.
X = torch.rand(10,3,128,128)
pred = model(X)
# Last expression of the cell, so the notebook displays the output shape.
pred.shape


In [None]:
"""
Viewing some of the dataset datapoints to see what the model is training on
"""
for _ in range(3):
    idx = np.random.randint(0, len(image2mesh_train_dataset))
    # Index the dataset once per iteration: the image and the occupancy grid
    # then come from the same lookup and the sample is only loaded once.
    sample = image2mesh_train_dataset[idx]
    # Reorder the image axes with permute(1, 2, 0) before handing it to
    # imshow (assumes a channels-first tensor — TODO confirm).
    image = np.array(sample[0].permute(1, 2, 0))
    mesh = np.array(sample[1])
    # ToDo
    plt.figure()
    plt.imshow(image)
    plt.show()
    visualize_occupancy(mesh)

In [None]:
def image2meshScoreFunction(preds, labels, thresh=0.5):
    """Return the fraction of entries whose thresholded prediction equals the label.

    Args:
        preds: Tensor of predicted values; it is not modified.
        labels: Tensor compared element-wise against the binarised predictions.
        thresh: Predictions ``>= thresh`` count as 1, all others (including
            NaN) count as 0.

    Returns:
        The mean element-wise agreement as a Python float.
    """
    with torch.no_grad():
        # Binarise with a single comparison. Thresholding in two in-place
        # passes would re-label the freshly written zeros as ones whenever
        # ``thresh <= 0``.
        binarized = (preds >= thresh).to(preds.dtype)
        accuracy = (labels == binarized).float().mean()
    return accuracy.item()

In [None]:
def image2meshLossFunction(preds, labels):
    """Class-weighted binary cross-entropy plus an L1 term.

    Entries with ``labels >= 0.5`` are weighted by ``1 - f`` and all other
    entries by ``f``, where ``f`` is the fraction of filled entries in the
    batch, clamped to at least 0.03. False negatives on the (usually rarer)
    filled entries therefore cost more.

    Args:
        preds: Predicted values in ``[0, 1]`` (required by ``BCELoss``).
        labels: Target tensor with the same shape as ``preds``.

    Returns:
        Scalar tensor: mean weighted BCE plus mean absolute error.
    """
    # Give higher weight to false negatives.
    filled_fraction_in_batch = (labels.sum() / labels.numel()).item()
    # Clamp the fraction, otherwise we start to get many false positives.
    filled_fraction_in_batch = max(0.03, filled_fraction_in_batch)

    # Build the weights directly on the labels' device so the function does
    # not depend on a notebook-global `device` and never mixes a CPU tensor
    # with a GPU mask.
    weights = torch.full(
        labels.shape,
        filled_fraction_in_batch,
        dtype=preds.dtype,
        device=labels.device,
    )
    weights[labels >= 0.5] = 1 - filled_fraction_in_batch

    reconstruction_loss = torch.nn.BCELoss(reduction="none")(preds, labels)
    reconstruction_loss = (reconstruction_loss * weights).mean()

    l1_loss = torch.nn.L1Loss()(preds, labels)

    return reconstruction_loss + l1_loss

In [None]:
from Networks.Trainer import Trainer
from Networks.Image2Mesh import Image2Voxel

path_prefix = "Assets/Models/Image2Mesh/"

# Create the checkpoint directory (and any missing parents) up front.
# exist_ok=True tolerates only an already existing directory; any other
# failure (e.g. permissions) is raised before the long training run starts.
os.makedirs(path_prefix, exist_ok=True)

model = Image2Voxel()
trainer = Trainer(
    model=model,
    model_save_path=f"{path_prefix}/image2mesh.model",
    loss_function=image2meshLossFunction,
    optimizer=torch.optim.Adam,
    batch_size=32,
    device=device,
    training_dataset=image2mesh_train_dataset,
    validation_dataset=image2mesh_val_dataset,
    score_function=image2meshScoreFunction,
)
try:
    trainer.fit(epochs=100, learning_rate=0.01)
except KeyboardInterrupt:
    # Manual interruption: keep whatever the model has learned so far.
    print("\nStopped by user saving last file")
    # NOTE(review): relies on Image2Voxel providing a `save` method — confirm.
    model.save(f"{path_prefix}/Keyboard_interrupt_temp.model")
# Move the model back to the CPU; as the last expression of the cell the
# returned module is also displayed by the notebook.
model.to('cpu')

# Visualizing Model predictions

In [None]:
# map_location="cpu" lets a checkpoint written during a GPU run be opened on
# a machine without CUDA.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled modules — confirm how the Trainer saves the model.
model = torch.load(f"{path_prefix}/image2mesh.model", map_location="cpu")
model.to("cpu")

In [None]:

idx = np.random.randint(0, len(image2mesh_train_dataset))
# Index the dataset once so the displayed image, the ground truth and the
# network input all come from the same lookup.
sample = image2mesh_train_dataset[idx]
image = np.array(sample[0].permute(1, 2, 0))
gt_mesh = np.array(sample[1])

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    pred_mesh = model(sample[0].unsqueeze(0)).detach().numpy()

# Binarise the prediction in a single step: values >= pred_thresh become 1,
# everything else 0.
pred_thresh = 0.1
pred_mesh = (pred_mesh >= pred_thresh).astype(pred_mesh.dtype)
# ToDo
plt.figure()
plt.imshow(image)
plt.show()
visualize_occupancy(gt_mesh)
visualize_occupancy(pred_mesh)

# Purifying predicted Meshes

In [None]:
"""
Code to purify meshes predicted by the previous networks to be used in the retrieval step
"""

# Mesh Encoding

In [None]:
"""
AutoEncoder models and/or other techniques used to encode the mesh into fewer dimensions
"""

# Mesh Retrieval Networks

In [None]:
"""
Models/Techniques to use the previous encoding steps to retrieve objects from a specified database
"""

# Inference and Full Testing

In [None]:
"""
Testing the entire pipeline implemented with added visualizations and discussions.
"""

# Citations

[1].....