In [1]:
import sys,os
import torch
import plotly.graph_objects as go
import numpy as np

from src.data_utils.prediction import create_dataloader, predict_phases
from src.cls.lightning_module import PointNetClassifier


In [2]:
# Sampling parameters forwarded to create_dataloader.
# NOTE(review): the exact meaning of cube_size (presumably the edge length of each
# sampled sub-cube, in coordinate units) lives in src.data_utils.prediction, and it
# presumably has to match what the checkpoint loaded below was trained with — confirm.
CUBE_SIZE = 12
BATCH_SIZE = 32

file_path = 'datasets/Al/inherent_configurations_off/170ps.off'
dataloader = create_dataloader(file_path, cube_size=CUBE_SIZE, batch_size=BATCH_SIZE)

Read 1048576 points
Size of space: [266.937768 266.942701 266.936487]
Min coords: [ 0.000232 -0.002701  0.000513]
Max coords: [266.938 266.94  266.937]
Number of samples: 12167


In [3]:
DEVICE = 'cpu'

model_path = 'output/2024-11-21/23-25-55/pointnet-epoch=36-val_acc=0.93.ckpt'
# map_location remaps tensors saved from a GPU run onto DEVICE, so the checkpoint
# also loads on a machine without CUDA.
model = PointNetClassifier.load_from_checkpoint(model_path, map_location=DEVICE)
# Inference mode (disables dropout / uses running BatchNorm stats if present).
# Harmless if predict_phases already does this — can't tell from here.
model.eval()
prediction = predict_phases(model, dataloader, device=DEVICE)

In [4]:
# Sanity check: presumably one prediction per dataloader sample, so this is expected
# to equal the "Number of samples" printed by create_dataloader above.
len(prediction)

12167

In [5]:
def visualize_predictions(predictions, title: str = "Phase Predictions", marker_size: int = 2):
    """Visualize 3D phase predictions as an interactive plotly scatter plot.

    One trace is drawn per distinct predicted phase, so each class gets its own
    colour and legend entry (clicking a legend entry toggles that phase).

    Args:
        predictions: Nx4 array-like where each row is (x, y, z, prediction).
            Anything ``np.asarray`` accepts works (ndarray, list of rows, CPU
            tensor). Columns beyond the fourth are ignored.
        title: Plot title.
        marker_size: Marker size used for every point.

    Raises:
        ValueError: if ``predictions`` is not 2-D with at least 4 columns.
    """
    # Normalise the input so the column slicing below also works for lists/tensors.
    data = np.asarray(predictions)
    if data.ndim != 2 or data.shape[1] < 4:
        raise ValueError(
            f"predictions must be an (N, 4) array of (x, y, z, phase); got shape {data.shape}"
        )

    coords = data[:, :3]
    phases = data[:, 3]

    fig = go.Figure()

    # One scatter trace per unique predicted class.
    for phase in np.unique(phases):
        points = coords[phases == phase]
        fig.add_trace(go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=dict(size=marker_size),
            name=f'Phase {int(phase)}',
        ))

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            # Keep true spatial proportions instead of stretching to a cube.
            aspectmode='data',
        ),
    )

    fig.show()

In [6]:
# Name the source snapshot in the title so the figure is self-explanatory when skimmed.
visualize_predictions(prediction, title=f"Phase Predictions: {os.path.basename(file_path)}")