In [1]:
import sys, os
import numpy as np
import torch
sys.path.append(os.getcwd())
from src.data_utils.prepare_data import read_off_file
from torch.utils.data import DataLoader
from src.data_utils.data_load import CubeDataset, SphericDataset
from src.autoencoder.autoencoder_module import PointNetAutoencoder
from src.autoencoder.eval_autoencoder import create_autoencoder_dataloader, get_batch_reconstructions
import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt
from src.visualization.vis_utils import plot_point_cloud_3d, visualize_predictions
from omegaconf import DictConfig, OmegaConf
from hydra import compose, initialize

from warnings import filterwarnings
# NOTE(review): called with no category/module filter, this silences *every* warning
# for the rest of the kernel session (including ones raised by torch / Lightning / hydra
# that may point at real problems). Consider narrowing it to the specific noisy warning
# (e.g. via `category=` or `module=`) once that warning is identified.
filterwarnings("ignore")


In [2]:
# Load the trained PointNet autoencoder and reconstruct one batch of samples
# built from a single Al configuration file.
CHECKPOINT_PATH = 'output/2025-01-30/01-19-56/pointnet-epoch=59-val_loss=2.15.ckpt'
CONFIG_DIR = "configs"
CONFIG_NAME = "Al_autoencoder"
file_path = 'datasets/Al/inherent_configurations_off/240ps.off'

model = PointNetAutoencoder.load_from_checkpoint(CHECKPOINT_PATH)

# Hydra's compose API has to be called inside an `initialize` context;
# the composed config object remains usable after the context exits.
with initialize(version_base=None, config_path=CONFIG_DIR):
    cfg = compose(config_name=CONFIG_NAME)

dataloader = create_autoencoder_dataloader(cfg, file_path)
first_batch = next(iter(dataloader))  # only the first batch is inspected
points = first_batch[0]
# NOTE(review): confirm get_batch_reconstructions switches the model to eval mode /
# disables gradients; nothing in this cell does so explicitly.
original_points, reconstructed_points = get_batch_reconstructions(
    model, points, cfg.data.num_points
)

Read 1048576 points
Size of space: [263.22167  263.219158 263.220081]
Min coords: [ 3.30e-04  8.42e-04 -8.10e-05]
Max coords: [263.222 263.22  263.22 ]
Avg added 2.89 points, avg dropped 0.22 points
Number of samples in spheric dataset: 50653
Input points shape: torch.Size([3048, 32, 3])
Reconstructed shape: torch.Size([3048, 64, 3])
Original points shape after processing: (3048, 32, 4)
Reconstructed points shape after processing: (3048, 64, 4)


In [3]:
# Compare one sample visually: the input cloud vs. the autoencoder's reconstruction.
SAMPLE_IDX = 0                  # which sample of the batch to display
FIG_WIDTH, FIG_HEIGHT = 600, 400
POINT_SIZE = 3
N_CONNECTIONS = 3


def show_point_cloud(points, title, **plot_kwargs):
    """Plot the XYZ columns of one point cloud, resize the figure, and display it.

    Parameters
    ----------
    points : array-like
        One cloud indexed as ``points[:, :3]``; only the first three columns
        (XYZ) are plotted and any further columns are dropped.
    title : str
        Figure title passed to ``plot_point_cloud_3d``.
    **plot_kwargs
        Extra keyword arguments forwarded unchanged to ``plot_point_cloud_3d``
        (e.g. ``color='red'``). Omitted keys keep that function's defaults.

    Returns
    -------
    The figure object returned by ``plot_point_cloud_3d``, already resized and
    shown, so later cells can still tweak or re-display it.
    """
    fig = plot_point_cloud_3d(
        points[:, :3],  # keep only XYZ coordinates
        n_connections=N_CONNECTIONS,
        title=title,
        point_size=POINT_SIZE,
        **plot_kwargs,
    )
    fig.update_layout(width=FIG_WIDTH, height=FIG_HEIGHT)
    fig.show()
    return fig


fig_original = show_point_cloud(original_points[SAMPLE_IDX], title='Original')
fig_reconstructed = show_point_cloud(
    reconstructed_points[SAMPLE_IDX], title='Reconstructed', color='red'
)