In [3]:
import sys,os
import torch
import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt
from src.visualization.vis_utils import plot_point_cloud_3d, visualize_predictions
from src.cls.prediction import create_dataloader, predict_phases
from src.cls.lightning_module import PointNetClassifier
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf
from warnings import filterwarnings
filterwarnings("ignore")


In [4]:
# Compose configs/Al_classification.yaml through Hydra and echo the resolved
# settings, so the values used for this run are recorded in the notebook output.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="Al_classification.yaml")
config_yaml = OmegaConf.to_yaml(cfg)
print(config_yaml)

training:
  learning_rate: 0.002
  batch_size: 256
  epochs: 100
  decay_rate: 0.002
  gpu: true
  log_every_n_steps: 5
  num_workers: 4
data:
  cube_size: 16
  radius: 8.0
  sample_type: regular
  sample_shape: spheric
  n_samples: 8000
  num_points: 128
project_name: pointnet
experiment_name: setup



In [51]:
# Inference-time overrides, applied in place to the composed `cfg`.
# Other data settings can be overridden the same way when needed, e.g.:
#   cfg.data.sample_shape = 'spheric'
#   cfg.data.num_points = 128
file_path = 'datasets/Al/inherent_configurations_off/166ps.off'
cfg.data.radius = 7.9
dataloader = create_dataloader(cfg, file_path)

Read 1048576 points
Size of space: [267.0368   267.035942 267.03684 ]
Min coords: [-8.0e-04  5.8e-05 -8.4e-04]
Max coords: [267.036 267.036 267.036]
Avg added 14.79 points, avg dropped 0.0 points
Number of samples in spheric dataset: 4913


In [52]:
import itertools
import random

# Seeded so the same sample is inspected on every Restart & Run All.
SAMPLE_SEED = 0
sample_rng = random.Random(SAMPLE_SEED)

# Choose a batch index up front and stop iterating once it is reached, instead of
# materialising every batch with list(dataloader) just to pick one of them.
# NOTE(review): assumes the dataloader defines __len__ (map-style dataset) — confirm.
batch_idx = sample_rng.randrange(len(dataloader))
points_batch, coords_batch = next(itertools.islice(dataloader, batch_idx, None))

# First sample of the chosen batch: its point cloud and its center coordinates.
points = points_batch[0].numpy()
coords = np.array(coords_batch)[:, 0]
print(f"Points shape: {points.shape}")
print(f"Center coordinates: {coords}")

fig = plot_point_cloud_3d(points, n_connections=4)
fig.show()


Points shape: (128, 3)
Center coodinates: [  7.8992   244.900058  23.69916 ]


In [53]:
# Load the trained classifier straight onto the CPU. The checkpoint may have been
# written from a GPU run, and inference below is explicitly done with device='cpu',
# so there is no reason to materialise the weights on a GPU first.
model_path = 'output/2024-12-03/17-31-38/pointnet-epoch=86-val_acc=0.99.ckpt'
model = PointNetClassifier.load_from_checkpoint(model_path, map_location='cpu')
prediction = predict_phases(model, dataloader, device='cpu')

In [54]:
# Plot the output of predict_phases from the previous cell; rendering is delegated
# to src.visualization.vis_utils.visualize_predictions.
visualize_predictions(prediction)

In [55]:
# Load a fresh copy of the same checkpoint onto the CPU (matching device='cpu' below)
# and ask predict_phases for probabilities rather than its default output.
# NOTE(review): `return_probablitity` is presumably spelled this way in the
# predict_phases signature — do not correct it here without renaming it there too.
model = PointNetClassifier.load_from_checkpoint(model_path, map_location='cpu')
predictions_proba = predict_phases(model, dataloader, device='cpu', return_probablitity=True)

In [56]:
from src.visualization.vis_utils import visualize_probabilities

def visualize_probabilities(points_and_probs, title: str = "Phase Probabilities"):
    """Visualize 3D points coloured and sized by a per-point probability using plotly.

    Note: this notebook-local definition shadows the ``visualize_probabilities``
    imported from ``src.visualization.vis_utils`` at the top of the same cell.

    Args:
        points_and_probs: Array-like of shape (N, 4) where each row is
            (x, y, z, probability). Anything ``np.asarray`` accepts works
            (NumPy array, CPU torch tensor, nested list).
        title: Plot title.

    Raises:
        ValueError: If ``points_and_probs`` is not 2-D with exactly 4 columns.
    """
    data = np.asarray(points_and_probs)
    if data.ndim != 2 or data.shape[1] != 4:
        raise ValueError(
            f"points_and_probs must have shape (N, 4), got {data.shape}"
        )
    xyz, probs = data[:, :3], data[:, 3]

    # Lower-probability points are drawn larger so they stand out:
    # size max_size at probability 0, min_size at probability 1.
    min_size, max_size = 3, 5
    marker_sizes = min_size + (max_size - min_size) * (1 - probs)

    fig = go.Figure()
    fig.add_trace(go.Scatter3d(
        x=xyz[:, 0],
        y=xyz[:, 1],
        z=xyz[:, 2],
        mode='markers',
        marker=dict(
            size=marker_sizes,
            color=probs,
            colorscale='Viridis',
            # Pin the colour scale to the probability range [0, 1] so a given colour
            # means the same probability in every plot, instead of being stretched
            # to each dataset's own min/max.
            cmin=0.0,
            cmax=1.0,
            colorbar=dict(title='Probability'),
            opacity=0.8,
        ),
        name='Points',
    ))

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        width=800,
        height=800,
    )

    fig.show()


In [57]:
# Plot each predicted point coloured by its probability, using the
# notebook-local visualize_probabilities defined in the previous cell.
visualize_probabilities(points_and_probs=predictions_proba)