In [None]:
## Installations
# NOTE(review): versions are unpinned, so a fresh environment may resolve different
# releases than the ones that produced the reported results; consider pinning them.
# `tabulate` is not imported in this notebook — presumably needed by a project module.

%pip install torch
%pip install open3d
%pip install tabulate

In [None]:
## Dependencias

import os
import random
from typing import List, Optional

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import open3d as o3d
import numpy as np

from model import PointNetClassifier, PointNetLoss, PointNetKAN
from modelnet10 import ModelNetClass, ModelNet, DatasetType
from utils.csv import save_loss_dict
from utils.transformation import (Normalization,
                                  Rotation, Translation, Reflection, Scale,
                                  DropRandom, DropSphere, Jittering, Noise)
from trainer import PointNetTrainer


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {DEVICE}.")

# Seed every RNG the notebook may rely on (Python `random`, NumPy, PyTorch) so
# shuffling, weight initialisation and the (presumably random) point-cloud
# transformations are reproducible under Restart & Run All.
# torch.manual_seed also seeds all CUDA devices. Some cuDNN kernels remain
# nondeterministic, so GPU results may still differ slightly between runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Global parameters
checkpoint_freq = 25           # checkpoint frequency handed to PointNetTrainer

# Dataset parameters
classes = list(ModelNetClass)  # all members of ModelNetClass
num_classes = len(classes)
batch_size = 32
dim = 3                        # point dimensionality passed to the models (presumably xyz)
num_points = 1024              # points per cloud passed to the models

# Hyperparameters
num_global_feats = 1024        # size of the global feature vector computed by PointNetClassifier
learning_rate = 0.001
reg_weight = 0.001
gamma = 2                      # focal-loss focusing parameter; value recommended by the focal loss paper

# Training augmentation pipeline; the same list is reused for the validation split.
t = [
    Rotation(),
    Reflection(),
    Scale(max_ratio=2.5),
    Jittering(max_units=0.005),
    DropRandom(loss_ratio=0.4),
    Noise(),
]
_augmented_split = dict(repetitions=3, transformations=t, preserve_original=False)

train_data = ModelNet(classes, DatasetType.TRAIN, **_augmented_split)
validation_data = ModelNet(classes, DatasetType.VALIDATION, **_augmented_split)

# TODO: use alpha later for imbalanced classes (the "mod" training run already passes one).

In [None]:
# Training function
def train(
        epochs: int,
        name: str,
        num_global_feats: int,
        learning_rate: float,
        use_scheduler: bool,
        alpha: Optional[List[float]],
        gamma: float,
        reg_weight: float,
        use_kan: bool,
        ignore_Tnet: bool,
):
    """Train one PointNet variant and print the best epoch reported by the trainer.

    Builds the classifier, an Adam optimizer, an optional CyclicLR scheduler, the
    loss and the data loaders, then delegates the loop to ``PointNetTrainer.fit``.

    Args:
        epochs: Number of epochs passed to ``PointNetTrainer.fit``.
        name: Run identifier passed to the trainer and shown in the summary line.
        num_global_feats: Global feature size for ``PointNetClassifier``.
            Ignored when ``use_kan`` is True.
        learning_rate: Initial Adam learning rate. When the scheduler is enabled,
            CyclicLR resets the rate to ``base_lr`` (1e-4) and cycles it up to
            1e-2, so this value has no effect in that case.
        use_scheduler: Whether to attach a ``CyclicLR`` scheduler.
        alpha: Per-class weights forwarded to ``PointNetLoss``, or None.
        gamma: Focal-loss focusing parameter forwarded to ``PointNetLoss``.
        reg_weight: Regularisation weight forwarded to ``PointNetLoss``.
        use_kan: Build ``PointNetKAN`` instead of ``PointNetClassifier``.
        ignore_Tnet: Forwarded to ``PointNetClassifier``. Ignored when ``use_kan`` is True.

    Reads notebook globals: ``dim``, ``num_points``, ``num_classes``,
    ``batch_size``, ``train_data``, ``validation_data``, ``checkpoint_freq``
    and ``DEVICE``.
    """
    if use_kan:
        # The KAN variant takes neither num_global_feats nor ignore_Tnet.
        classifier = PointNetKAN(dim, num_points, num_classes, scaling=2.0).to(DEVICE)
    else:
        classifier = PointNetClassifier(dim, num_points, num_global_feats, num_classes, ignore_Tnet=ignore_Tnet).to(DEVICE)
    optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)

    # CyclicLR is device-agnostic, so honour use_scheduler on every device.
    # Gating it on CUDA made the same configuration silently train with a
    # different LR schedule on CPU. step_size_up counts scheduler.step() calls
    # (presumably one per batch inside PointNetTrainer — TODO confirm).
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer, base_lr=0.0001, max_lr=0.01, step_size_up=2000, cycle_momentum=False
        )
    else:
        scheduler = None

    trainer = PointNetTrainer(
        name=name,
        model=classifier,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=PointNetLoss(alpha=alpha, gamma=gamma, reg_weight=reg_weight, size_average=True).to(DEVICE),
        device=DEVICE,
        train_loader=DataLoader(train_data, batch_size=batch_size, shuffle=True),
        val_loader=DataLoader(validation_data, batch_size=batch_size, shuffle=False),
        checkpoint_dir=os.path.join(os.getcwd(), "checkpoint"),
        checkpoint_freq=checkpoint_freq,
    )

    loss_dict, best_epoch, best_loss, best_acc = trainer.fit(epochs=epochs)
    # Optional: persist the per-epoch losses (currently disabled).
    # save_loss_dict(loss_dict, os.path.join(os.getcwd(), "csv", f"{name}_loss_dict.csv"))
    print(f"{name} | Best model @ epoch {best_epoch}: loss = {best_loss:.4f}, acc = {best_acc:.4f}")

# Training runs
EPOCHS = 200  # epoch budget shared by every training run below

In [None]:
# Baseline run: PointNet with T-Net, gamma=0, no class weights, fixed learning rate.
train(
    epochs=EPOCHS,
    name="base",
    num_global_feats=num_global_feats,
    learning_rate=learning_rate,
    use_scheduler=False,
    alpha=None,
    gamma=0,
    reg_weight=reg_weight,
    use_kan=False,
    ignore_Tnet=False,
)

In [None]:
# "mod" run: focal loss (gamma), inverse-frequency class weights (alpha) and a CyclicLR schedule.
# Per-class sample counts; their total (3991) matches the ModelNet10 training split.
# ASSUMPTION: listed in ModelNetClass iteration order — TODO confirm against modelnet10.py.
class_counts = [106, 515, 889, 200, 200, 465, 200, 680, 392, 344]
assert len(class_counts) == num_classes, "alpha needs exactly one weight per class"
# alpha_c = N_total / N_c, so rarer classes receive larger weights.
n_total = sum(class_counts)
alpha = [n_total / count for count in class_counts]
train(epochs=EPOCHS, name="mod", num_global_feats=num_global_feats, learning_rate=learning_rate,
      use_scheduler=True, alpha=alpha, gamma=gamma, reg_weight=reg_weight, use_kan=False, ignore_Tnet=False)

In [None]:
# Ablation run: PointNet without T-Net. reg_weight is 0 here (presumably because
# the regulariser targets the T-Net transform).
train(
    epochs=EPOCHS,
    name="no_tnet",
    num_global_feats=num_global_feats,
    learning_rate=learning_rate,
    use_scheduler=False,
    alpha=None,
    gamma=0,
    reg_weight=0,
    use_kan=False,
    ignore_Tnet=True,
)

# Datasets de prueba

In [None]:
# Test datasets: the test split evaluated clean, under affine transforms only,
# and under the same augmentation pipeline used for training.
_test_split = dict(repetitions=1, preserve_original=False)

base_test_data = ModelNet(classes, DatasetType.TEST, transformations=[], **_test_split)
affine_test_data = ModelNet(
    classes, DatasetType.TEST,
    transformations=[Rotation(), Reflection(), Scale(max_ratio=2.5)],
    **_test_split,
)
complex_test_data = ModelNet(
    classes, DatasetType.TEST,
    transformations=[
        Rotation(), Reflection(), Scale(max_ratio=2.5),
        Jittering(max_units=0.005), DropRandom(loss_ratio=0.4), Noise(),
    ],
    **_test_split,
)

In [None]:
def test_it(classifier_path: str, num_global_feats=num_global_feats, use_kan=False, ignore_Tnet=False):
    """Evaluate a saved classifier on the base, affine and complex test sets.

    The architecture arguments must match the ones used when the checkpoint was
    trained, otherwise ``load_state_dict`` rejects the weights.

    Args:
        classifier_path: Path to a ``state_dict`` saved by a training run.
        num_global_feats: Global feature size for ``PointNetClassifier``
            (ignored when ``use_kan`` is True).
        use_kan: Rebuild the model as ``PointNetKAN`` instead of ``PointNetClassifier``.
        ignore_Tnet: Forwarded to ``PointNetClassifier`` (ignored when ``use_kan`` is True).

    Returns:
        Dict mapping each test-set name ("base", "affine", "complex") to its
        accuracy, or ``nan`` for a test set that yields no samples.
    """
    # The weights are identical for every test set, so build and load the model once.
    if use_kan:
        classifier = PointNetKAN(dim, num_points, num_classes, scaling=2.0).to(DEVICE)
    else:
        classifier = PointNetClassifier(dim, num_points, num_global_feats, num_classes, ignore_Tnet=ignore_Tnet).to(DEVICE)
    # map_location lets a checkpoint saved on GPU be evaluated on a CPU-only machine.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    classifier.load_state_dict(torch.load(classifier_path, map_location=DEVICE))
    classifier.eval()

    accuracies = {}
    test_sets = (("base", base_test_data), ("affine", affine_test_data), ("complex", complex_test_data))
    with torch.no_grad():
        for data_name, data in test_sets:
            data_loader = DataLoader(data, batch_size=batch_size, shuffle=False)
            correct = 0
            total = 0
            for pcds, labels in data_loader:
                pcds = pcds.to(DEVICE)
                # Flatten to 1-D; unlike squeeze() this never yields a 0-d tensor.
                labels = labels.reshape(-1).to(DEVICE)

                # The model returns three outputs; only the first (class scores) is used.
                out, _, _ = classifier(pcds)

                # softmax is monotonic, so argmax over the raw scores picks the same class.
                pred_choice = out.argmax(dim=1)
                correct += pred_choice.eq(labels).sum().item()
                total += labels.numel()

            test_acc = correct / total if total else float("nan")
            accuracies[data_name] = test_acc
            print(f"\tAccuracy on {data_name} dataset:\t", test_acc)
    return accuracies

# Tests
# Evaluate each best checkpoint. The architecture flags passed to test_it must
# match how that run was trained above, otherwise load_state_dict fails on
# mismatched parameter keys.
best_model_dir = os.path.join(os.getcwd(), "checkpoint", "best_model")
evaluation_runs = [
    ("Base classifier:", "base", {}),
    # "mod" is trained with use_kan=False, so it is rebuilt as a PointNetClassifier.
    # If that training cell is switched to use_kan=True, pass {"use_kan": True} here.
    ("Modified classifier with alpha, gamma, scheduler:", "mod", {}),
    ("Base classifier without Tnet:", "no_tnet", {"ignore_Tnet": True}),
]
for header, run_name, model_kwargs in evaluation_runs:
    print(header)
    test_it(os.path.join(best_model_dir, f"{run_name}_best_model.pth"), **model_kwargs)
