# Project imports

In [1]:
"""
All needed imports included here
"""
%load_ext autoreload
%autoreload 2
from pathlib import Path
import numpy as np
import matplotlib as plt
import torch
import pytorch_lightning as pl

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


# Data Loading step

In [2]:
"""
Create data loaders and augmentations needed here
"""
from Data.ShapeNetDataLoader import ShapeNetVoxelData
from Utils.visualization import visualize_occupancy

# Settings shared by the train / val / test datasets.
overfit = False

shapenet_core_path = Path("Data/ShapeNetCoreVoxel32")
shapenet_splits_csv_path = Path("Data/shapenet_splits.csv")
voxel_filename = "model_3.binvox"
# Only models belonging to these synsets are loaded.
synset_id_filter = ["04379243"]  # tables


def _load_split(split):
    """Build the ShapeNetVoxelData dataset for `split` from the settings above."""
    return ShapeNetVoxelData(
        shapenet_core_path=shapenet_core_path,
        shapenet_splits_csv_path=shapenet_splits_csv_path,
        split=split,
        overfit=overfit,
        synset_id_filter=synset_id_filter,
        voxel_filename=voxel_filename,
    )


# Each split is built and its size reported before the next one is loaded.
train_data = _load_split("train")
print(f"Train Set Size: {len(train_data)}")
val_data = _load_split("val")
print(f"Validation Set Size: {len(val_data)}")
test_data = _load_split("test")
print(f"Test Set Size: {len(test_data)}")

# Sanity check on the first training sample: print its shape, then render it.
train_sample = train_data[0]
print(f'Voxel Dimensions: {train_sample.shape}')

visualize_occupancy(train_sample.squeeze(), flip_axes=True)

Train Set Size: 1368
Validation Set Size: 236
Test Set Size: 383
Voxel Dimensions: (1, 32, 32, 32)


Output()

# Create Autoencoder

In [3]:
#%env CUDA_LAUNCH_BLOCKING=1
"""
AutoEncoder Models and/or different techniques used to encode the mesh to a smaller dimensions
"""
from Networks.VoxelAutoencoder import VoxelAutoencoder
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Author's tuning notes for kl_divergence_scale:
#   * lower kl_divergence_scale -> smoother latent space but worse reconstruction
#   * however a higher kl_divergence_scale causes too much overlap between latent
#     space distributions, which is impractical for retrieval
# NOTE(review): these two notes read as contradictory; in a standard VAE loss it is
# a *larger* KL weight that regularizes the latent space at the cost of
# reconstruction -- confirm against the loss in VoxelAutoencoder and fix the wording.
kl_divergence_scale = 0.05
# 128 was also tried: slightly better reconstructions, but worse retrieval.
latent_dim = 64
model = VoxelAutoencoder(
    train_data,
    val_data,
    test_data,
    device,
    kl_divergence_scale=kl_divergence_scale,
    latent_dim=latent_dim,
)

# TensorBoard event files are written below tb_logs/my_model.
logger = TensorBoardLogger("tb_logs", name="my_model")

# Checkpointing and early stopping both track the validation loss (lower is better).
model_checkpoint = ModelCheckpoint(
    dirpath="Assets/Models/VoxelAutoencoder/",
    filename="voxel-autoencoder-02-{epoch:0004d}-{val_loss:.4f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    every_n_epochs=8,
)
tqdm_progess_bar = TQDMProgressBar(refresh_rate=1)
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=32)

# NOTE(review): `gpus=` is the legacy Lightning argument; newer releases expect
# `accelerator=` / `devices=` instead -- revisit when upgrading pytorch_lightning.
num_gpus = 1 if torch.cuda.is_available() else None
trainer_options = dict(
    max_epochs=1024,
    gpus=num_gpus,
    log_every_n_steps=1,
    logger=logger,
    callbacks=[model_checkpoint, tqdm_progess_bar, early_stopping],
    profiler="simple",
)
trainer = pl.Trainer(**trainer_options)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


# Training

In [None]:

# Train the autoencoder with the Trainer configured in the previous cell. No
# dataloaders are passed here: the model was given the train/val/test datasets
# when it was constructed.
trainer.fit(model=model)

# Visualize Reconstruction

In [5]:
# Checkpoint to evaluate. To use the best checkpoint of the current training run
# instead, replace the path below with:
#     best_model_path = model_checkpoint.best_model_path
best_model_path = "Assets/Models/VoxelAutoencoder/voxel-autoencoder-01-epoch=0279-val_loss=0.0346.ckpt"
checkpoint_init_kwargs = dict(
    train_set=train_data,
    val_set=val_data,
    test_set=test_data,
    device=device,
    kl_divergence_scale=kl_divergence_scale,
    latent_dim=latent_dim,
)
model = VoxelAutoencoder.load_from_checkpoint(best_model_path, **checkpoint_init_kwargs)

# Show the input voxel grid first, then its reconstruction further below.
train_test_sample = train_data[1]

visualize_occupancy(train_test_sample.squeeze(), flip_axes=True)

# Prepend a batch axis before handing the sample to the network.
sample_tensor = torch.from_numpy(train_test_sample[None, ...])

model.eval()
with torch.no_grad():
    decoded_test = model(sample_tensor)
    latent_code = model.encode(sample_tensor)[0]
    print(latent_code.shape)

# Binarize the decoder output at 0.5: values below become 0, all others 1.
decoded_binary = decoded_test.masked_fill(decoded_test < 0.5, 0).masked_fill(
    decoded_test >= 0.5, 1
)

decoded_test_np = decoded_binary.squeeze().detach().numpy()

visualize_occupancy(decoded_test_np, flip_axes=True)

Output()

torch.Size([64])


Output()

# Compute latent vectors of training samples

In [6]:
# Build the retrieval database: one averaged latent vector per training sample.
#
# `latent_vectors` maps latent tensor -> voxel grid. Tensors hash by identity, so
# the dict is effectively a list of (latent, voxel) pairs; iterate it with
# `.items()` rather than looking entries up by value.
latent_vectors = {}
num_draws = 4  # number of encoder passes averaged per sample
model.eval()
with torch.no_grad():
    for train_sample in train_data:
        # Prepend a batch axis before handing the sample to the encoder.
        sample_tensor = torch.from_numpy(train_sample[np.newaxis, :])
        # Average several encoder passes. Stacking the draws takes shape, dtype
        # and device from the encoder output, so nothing here has to be kept in
        # sync with `latent_dim` (the previous `torch.zeros(64)` accumulator broke
        # as soon as the latent size was changed).
        draws = [model.encode(sample_tensor)[0] for _ in range(num_draws)]
        vec = torch.stack(draws).mean(dim=0)
        latent_vectors[vec] = train_sample

In [12]:
# compute latent vector of test sample
# NOTE(review): the query is drawn from train_data, which is also what the
# retrieval database (`latent_vectors`) was built from, so the query object itself
# is among the candidates. Use a sample from test_data to evaluate retrieval on
# unseen shapes -- confirm which one is intended.
test_sample = train_data[40]
num_draws = 4  # number of encoder passes averaged for the query
num_matches = 3  # how many nearest neighbours to show
model.eval()
with torch.no_grad():
    # result is stochastic -> can be run multiple times to get different results
    sample_tensor = torch.from_numpy(test_sample[np.newaxis, :])
    # Stacking the draws avoids hard-coding the latent size.
    test_latent_vector = torch.stack(
        [model.encode(sample_tensor)[0] for _ in range(num_draws)]
    ).mean(dim=0)

print("Test sample:")
visualize_occupancy(test_sample.squeeze(), flip_axes=True)

# Rank every stored latent vector by Euclidean distance to the query.
# The previous running-minimum loop only remembered the *earlier minima*, which are
# not the 2nd/3rd nearest neighbours and could be left as None (crashing on
# `.squeeze()`) when the minimum was updated fewer than three times.
# Sorting on the distance alone (key=...) avoids ever comparing voxel arrays on ties.
ranked_matches = sorted(
    (
        (torch.dist(test_latent_vector, train_latent_vector).item(), train_voxel)
        for train_latent_vector, train_voxel in latent_vectors.items()
    ),
    key=lambda match: match[0],
)
top_matches = ranked_matches[:num_matches]

# Keep the names the original cell exposed; entries stay None when fewer than
# three candidates exist.
min_distance = top_matches[0][0] if top_matches else float('inf')
best_voxel_match_0, best_voxel_match_1, best_voxel_match_2 = (
    [voxel for _, voxel in top_matches] + [None] * 3
)[:3]

for rank, (_, matched_voxel) in enumerate(top_matches, start=1):
    print(f"Retrieved object {rank}:")
    visualize_occupancy(matched_voxel.squeeze(), flip_axes=True)

Test sample:


Output()

Retrieved object 1:


Output()

Retrieved object 2:


Output()

Retrieved object 3:


Output()