# Project imports

In [None]:
"""
All needed imports included here
"""
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import matplotlib as plt
import trimesh
import torch
import skimage
from torchvision import transforms
import pytorch_lightning as pl



In [None]:
# Run on the current CUDA device when PyTorch can see one, otherwise on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Data Loading step

In [None]:
"""
Create data loaders and augmentations needed here
"""
from Data.Image2MeshDataLoader import Image2MeshDataLoader
image2mesh_train_dataset = Image2MeshDataLoader(images_path = "Assets/Data/Image2Mesh/train/images/",
                                meshes_path = "Assets/Data/Image2Mesh/train/meshes/",
                                image_size= 256, voxel_dims = (32,32,32), sample_rate = 8192)

image2mesh_val_dataset = Image2MeshDataLoader(images_path = "Assets/Data/Image2Mesh/val/images/",
                                meshes_path = "Assets/Data/Image2Mesh/val/meshes/",
                                image_size= 256, voxel_dims = (32,32,32), sample_rate = 8192)

In [None]:
print("number of training datapoints is:", len(image2mesh_train_dataset))
print("Images have shapes:", image2mesh_train_dataset[0][0].shape)
print("Meshes have shapes:",image2mesh_train_dataset[0][1].shape)


# Reconstruction Networks

In [None]:
"""
creation, training, and testing of the image2mesh reconstruction networks
"""
from Networks.Image2Mesh import Image2Voxel

model = Image2Voxel()

In [None]:
"""
Test out forward pass and ensure output sizes
"""
X = torch.rand(256,3,128,128)
pred = model(X)
pred.shape


In [None]:
"""
Viewing some of the dataset datapoints to see what the model is training on
"""
# TODO: display a few (image, mesh) pairs from image2mesh_train_dataset

In [None]:
def image2meshScoreFunction(preds, labels, thresh=0.5):
    """Return the element-wise accuracy of thresholded predictions.

    Every entry of ``preds`` is binarised (``< thresh`` -> 0,
    ``>= thresh`` -> 1) and compared with ``labels``. The fraction of
    matching entries is returned.

    Args:
        preds: tensor of predicted values (presumably occupancy
            probabilities -- TODO confirm against the model's output).
        labels: tensor of 0/1 targets. If its shape differs from ``preds``
            but the element count is the same, it is reshaped to match.
        thresh: cut-off used to binarise ``preds``.

    Returns:
        Accuracy in [0, 1] as a Python float (NaN if ``preds`` is empty).
    """
    with torch.no_grad():
        # Detach inside no_grad so no autograd node is recorded for the copy.
        source = preds.detach()
        predicted_vals = source.clone()
        # Both masks are taken from the untouched values. Masking the already
        # modified tensor would turn the freshly written zeros into ones
        # whenever thresh <= 0. NaN predictions match neither mask, stay NaN,
        # and therefore never count as correct.
        predicted_vals[source < thresh] = 0
        predicted_vals[source >= thresh] = 1

        # Layouts that differ only in shape, e.g. (B, 1, D, H, W) vs
        # (B, D, H, W), would otherwise broadcast into a B x B cross
        # comparison and silently report a wrong accuracy.
        if labels.shape != predicted_vals.shape and labels.numel() == predicted_vals.numel():
            labels = labels.reshape(predicted_vals.shape)

        acc = (labels == predicted_vals).float().mean()
    return acc.item()

In [None]:
from Networks.Trainer import Trainer

# Create the checkpoint directory and any missing parents. exist_ok only
# tolerates the directory already existing; any other failure (e.g.
# permissions) surfaces here rather than later, when a checkpoint is saved.
os.makedirs("Assets/Models/Image2Mesh/", exist_ok=True)

trainer = Trainer(
    model=model,
    model_save_path="Assets/Models/Image2Mesh/image2mesh.model",
    loss_function=torch.nn.L1Loss(),
    # The optimizer class (not an instance) is passed; presumably Trainer
    # instantiates it with the learning rate given to fit() -- TODO confirm.
    optimizer=torch.optim.Adam,
    batch_size=8,
    device=device,
    training_dataset=image2mesh_train_dataset,
    validation_dataset=image2mesh_val_dataset,
    score_function=image2meshScoreFunction,
)
try:
    trainer.fit(epochs=2, learning_rate=0.001)
except KeyboardInterrupt:
    print("Stopped by user saving last file")
    interrupt_save_path = "Assets/Models/Image2Mesh/Keyboard_interrupt_temp.model"
    if hasattr(model, "save"):
        # Prefer the model's own serialisation when it provides one.
        model.save(interrupt_save_path)
    else:
        # torch.nn.Module has no ``save`` method; fall back to saving weights.
        torch.save(model.state_dict(), interrupt_save_path)

# Purifying predicted Meshes

In [None]:
"""
Code to purify meshes predicted by the previous networks to be used in the retrieval step
"""

# Mesh Encoding

In [None]:
"""
Autoencoder models and/or other techniques used to encode the mesh into a lower-dimensional representation
"""

# Mesh Retrieval Networks

In [None]:
"""
Models/techniques that use the previous encoding steps to retrieve objects from a specified database
"""

# Inference and Full Testing

In [None]:
"""
Testing the entire pipeline implemented with added visualizations and discussions.
"""

# Citations

[1].....