# Project imports

In [1]:
"""
All needed imports included here
"""
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import matplotlib as plt
import trimesh
import torch
import skimage
from torchvision import transforms
import pytorch_lightning as pl

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Data Loading step

In [3]:
"""
Create data loaders and augmentations needed here
"""
from Data.Image2MeshDataLoader import Image2MeshDataLoader
image2mesh_train_dataset = Image2MeshDataLoader(images_path = "E:/UNI/Masters/ML3D/Project/image to mesh/train/images/",
                                meshes_path = "E:/UNI/Masters/ML3D/Project/image to mesh/train/meshes/",
                                image_size= 256, voxel_dims = (32,32,32), sample_rate = 8192)

image2mesh_val_dataset = Image2MeshDataLoader(images_path = "E:/UNI/Masters/ML3D/Project/image to mesh/val/images/",
                                meshes_path = "E:/UNI/Masters/ML3D/Project/image to mesh/val/meshes/",
                                image_size= 256, voxel_dims = (32,32,32), sample_rate = 8192)

In [4]:
print("number of training datapoints is:", len(image2mesh_train_dataset))
print("number of validation datapoints is:", len(image2mesh_val_dataset))

print("Images have shapes:", image2mesh_train_dataset[0][0].shape)
print("Meshes have shapes:",image2mesh_train_dataset[0][1].shape)


number of training datapoints is: 4867
number of validation datapoints is: 50


  spacing=pitch)


Images have shapes: torch.Size([3, 128, 128])
Meshes have shapes: torch.Size([32, 32, 32])


# Reconstruction Networks

In [5]:
"""
creation, training, and testing of the image2mesh reconstruction networks
"""
from Networks.Image2Mesh import Image2Voxel

model = Image2Voxel()

In [6]:
"""
Test out forward pass and ensure output sizes
"""
X = torch.rand(256,3,128,128)
pred = model(X)
pred.shape


torch.Size([256, 32, 32, 32])

In [7]:
"""
Viewing some of the dataset datapoints to see what the model is training on
"""
# ToDo

'\nViewing some of the dataset datapoints to see what the model is training on\n'

In [8]:
def image2meshScoreFunction(preds,labels,thresh=0.5):
    predicted_vals = preds.clone()
    with torch.no_grad():
        predicted_vals[predicted_vals<thresh] = 0
        predicted_vals[predicted_vals>=thresh] = 1

        acc = (labels == predicted_vals).float().mean()
    return acc.item()

In [None]:
from Networks.Trainer import Trainer

path_prefix = "Assets/Models/Image2Mesh/"

try:
    os.mkdir(path_prefix)
except:
    pass

trainer = Trainer( 
        model = torch.load(f"{path_prefix}/image2mesh.model"),#model,
        model_save_path = f"{path_prefix}/image2mesh.model",
        loss_function = torch.nn.BCELoss(),#torch.nn.L1Loss(), 
        optimizer = torch.optim.Adam,
        batch_size = 8,
        device = device,
        training_dataset = image2mesh_train_dataset,
        validation_dataset = image2mesh_val_dataset,
        score_function = image2meshScoreFunction
    )
try:
    trainer.fit(epochs=10,learning_rate=0.01)
except KeyboardInterrupt:
    print("\nStopped by user saving last file")
    model.save(f"{path_prefix}/Keyboard_interrupt_temp.model")

# Purifying predicted Meshes

In [None]:
"""
Code to purify meshes predicted by the previous networks to be used in the retrieval step
"""

# Mesh Encoding

In [None]:
"""
AutoEncoder Models and/or different techniques used to encode the mesh to a smaller dimensions
"""

# Mesh Retreival Networks

In [None]:
"""
Models/Techniques to use the previous encoding steps to retreive objects from a specified database
"""

# Inference and Full Testing

In [None]:
"""
Testing the entire pipeline implemented with added visualizations and discussions.
"""

# Citations

[1].....