In [None]:
## Instalaciones

%pip install pytorch
%pip install open3d
%pip install plotly
%pip install "notebook>=7.0" "anywidget>=0.9.13"

In [1]:
## Dependencias

import os
from utils.plotter import notebook_plot_pcd_from_points
from random import randrange
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from modelnet10 import ModelNetClass, ModelNet, DatasetType
from model import PointNetClassifier, PointNetLoss
from utils.transformation import (Normalization,
                                  Rotation, Translation, Reflection, Scale,
                                  DropRandom, DropSphere, Jittering, Noise)
import torch
import open3d as o3d
import numpy as np

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using {DEVICE}.")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Using cuda.


In [8]:
## Parameters and constants

# When False, the trained checkpoint is loaded and its prediction is shown next to the ground truth.
IGNORE_CLASSIFIER = True

# Path of the checkpoint *file* (despite the name), resolved against the current working directory.
ROOT_DIR = os.getcwd()
MODEL_DIR = os.path.join(ROOT_DIR, "checkpoint", "best_model.pth")

# Model hyper-parameters, passed to PointNetClassifier when the checkpoint is loaded;
# they have to be compatible with the saved weights.
classes = list(ModelNetClass)
num_classes = len(classes)
dim = 3
num_points = 1024
num_global_feats = 1024

# Augmentations applied to the test-split samples that the demo draws from.
t = [
    Rotation(),
    Reflection(),
    Scale(max_ratio=3.0),
    Jittering(max_units=0.005),
    DropRandom(loss_ratio=0.4),
    Noise(),
]
data = ModelNet(
    classes,
    DatasetType.TEST,
    repetitions=3,
    transformations=t,
    normalize=True,
    preserve_original=False,
)

In [9]:
## Object and function initialization

if not IGNORE_CLASSIFIER:
    classifier = PointNetClassifier(dim, num_points, num_global_feats, num_classes).to(DEVICE)
    # map_location remaps stored tensors onto DEVICE, so a checkpoint saved on a GPU
    # can also be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    classifier.load_state_dict(torch.load(MODEL_DIR, map_location=DEVICE))
    # Inference only: switch layers such as BatchNorm/Dropout to evaluation behaviour.
    classifier.eval()

def map_label(num):
    """Translate a class index into its label string.

    ``num`` is used as an index into the module-level ``classes`` list, and the
    ``label`` attribute of the selected ModelNetClass entry is returned.
    """
    member = classes[num]
    return member.label

def prettier(string):
    """Turn a snake_case label into display text, e.g. ``"night_stand"`` -> ``"Night stand"``."""
    spaced = " ".join(string.split("_"))
    return spaced.capitalize()

def predicted_class(pcd):
    """Classify one point cloud with ``classifier`` and return the predicted label.

    Args:
        pcd: the cloud in the layout built by ``pcd_path_to_tensor``, i.e.
            ``(1, dim, N)``, or an unbatched ``(dim, N)`` array/tensor (a batch
            axis is added). It is converted to float32 and moved to ``DEVICE``.

    Returns:
        The label string (via ``map_label``) of the highest-scoring class.

    Raises:
        RuntimeError: if more than one cloud is passed (``.item()`` needs a single value).
    """
    inputs = torch.as_tensor(pcd, dtype=torch.float32)
    if inputs.dim() == 2:
        inputs = inputs.unsqueeze(0)
    inputs = inputs.to(DEVICE)

    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        logits, _, _ = classifier(inputs)

    # softmax is monotonic, so the arg-max over the raw logits picks the same class.
    pred_index = logits.argmax(dim=1).item()
    return map_label(pred_index)

def random_sample():
    """Draw one uniformly random item from ``data``.

    Returns:
        A pair ``(points, label)``: the item's first element with its two axes
        swapped (``transpose(1, 0)``), and the item's second element unchanged.
    """
    sample = data[randrange(len(data))]
    points, label = sample[0], sample[1]
    return points.transpose(1, 0), label

def pcd_path_to_tensor(pcd_path):
    """Read a point-cloud file with Open3D and return it as a batched tensor.

    Args:
        pcd_path: path of a point-cloud file readable by ``o3d.io.read_point_cloud``.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, N)``: a batch of one cloud
        with the xyz coordinates on axis 1 and the N points on axis 2.

    Raises:
        ValueError: if no points could be read from ``pcd_path``.
    """
    pcd = o3d.io.read_point_cloud(pcd_path)
    points = np.asarray(pcd.points, dtype=float)
    # read_point_cloud reports a missing/unreadable file with a warning and an
    # empty cloud instead of raising, which would otherwise become a (1, 3, 0) tensor.
    if points.size == 0:
        raise ValueError(f"No points could be read from {pcd_path!r}")
    # (N, 3) -> (3, N) -> (1, 3, N)
    return torch.tensor(points.T[np.newaxis], dtype=torch.float32)

In [10]:
# Output widgets: one area for the 3D plot, one for the class card, plus the trigger button.
output_plot, output_text = widgets.Output(), widgets.Output()
button = widgets.Button(description="Generar ejemplo")

def _result_card_html(rows):
    """Build the bordered HTML card displayed below the plot.

    Args:
        rows: iterable of ``(caption, text, color)`` tuples; each becomes one line
            ``caption: text`` with ``text`` rendered in ``color``.

    Returns:
        The HTML markup as a string.
    """
    lines = "\n".join(
        f'  <div><strong style="color: #444;">{caption}:</strong> '
        f'<span style="color: {color};">{text}</span></div>'
        for caption, text, color in rows
    )
    return f"""
<div style="font-size: 20px; font-family: Arial, sans-serif; border: 1px solid #ccc; padding: 15px; border-radius: 10px; background-color: #f9f9f9; width: fit-content;">
{lines}
</div>
"""


def visualize_random_point_cloud(b=None):
    """Plot a random sample from ``data`` and show its class (and, optionally, the prediction).

    Serves both as the button's ``on_click`` handler and for the initial draw.

    Args:
        b: the clicked ``Button`` passed by ``on_click``; unused.
    """
    # All work happens inside output_plot so errors raised from a button click
    # are rendered in the widget instead of being lost.
    with output_plot:
        clear_output(wait=True)

        # Random selection: ``x`` is in the plotting layout, ``label`` is a class index.
        x, label = random_sample()
        label_name = map_label(label)

        notebook_plot_pcd_from_points(x, output_size=(1000, 400), zoom=1.0)

        if IGNORE_CLASSIFIER:
            rows = [("Class", prettier(label_name), "#00008B")]
        else:
            # The classifier takes (batch, dim, N) -- the layout pcd_path_to_tensor
            # builds -- so undo random_sample's transpose and add a batch axis.
            model_input = (
                torch.as_tensor(x, dtype=torch.float32)
                .transpose(1, 0)
                .unsqueeze(0)
                .to(DEVICE)
            )
            pred_name = predicted_class(model_input)
            # Compare label *names* (both strings): green on a match, red otherwise.
            color = "#2ECC40" if pred_name == label_name else "#FF4136"
            rows = [
                ("Ground Truth", prettier(label_name), color),
                ("Predicted Class", prettier(pred_name), color),
            ]

        with output_text:
            clear_output(wait=True)
            display(HTML(_result_card_html(rows)))

# Initialization: wire the button to the handler, show the button and the two
# output areas (in this order), then draw a first example without waiting for a click.
button.on_click(visualize_random_point_cloud)
display(button, output_plot, output_text)
visualize_random_point_cloud()

Button(description='Generar ejemplo', style=ButtonStyle())

Output()

Output()