In [2]:
## Instalaciones

%pip install torch

Collecting torch
  Downloading torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.5.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collec

In [1]:
## Dependencias

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
## T-net
"""
T-net es una 'mini-red' que aprende una matriz de transformación de tamaño
dim x dim que transforma la entrada a una representación 'canónica', la cual
es invariante a transformaciones rígidas (rotación, traslación, reflexión).
"""

class Tnet(nn.Module):
    """Transformation network (T-Net): predicts one ``dim x dim`` matrix per sample.

    The input goes through three shared per-point layers, a max pool over the
    point axis and three fully connected layers. The result is reshaped into a
    square matrix to which the identity matrix is added.

    Args:
        dim: Number of input channels per point; also the side length of the
            predicted matrix.
        num_points: Kernel size of the max pool over the point axis. Inputs are
            expected to contain this many points so that the pool yields a
            single value per channel.
    """

    def __init__(self, dim, num_points):
        super().__init__()

        self.dim = dim

        # Activation function
        self.act = F.relu

        # A Conv1d with kernel_size=1 applies the same linear map to every
        # point, i.e. it is a simple implementation of a 'shared MLP'.
        self.shared_mlp1 = nn.Conv1d(dim, 64, kernel_size=1)
        self.shared_mlp2 = nn.Conv1d(64, 128, kernel_size=1)
        self.shared_mlp3 = nn.Conv1d(128, 1024, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

        # Non-shared MLPs (operate on the pooled feature vector)
        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, dim**2)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Compute the transformation matrices for a batch.

        Args:
            x: Tensor of shape ``(batch, dim, num_points)``.

        Returns:
            Tensor of shape ``(batch, dim, dim)`` with the same dtype and
            device as the network activations.
        """
        bs = x.shape[0]

        # Shared MLPs: activation first, then batch norm
        x = self.bn1(self.act(self.shared_mlp1(x)))
        x = self.bn2(self.act(self.shared_mlp2(x)))
        x = self.bn3(self.act(self.shared_mlp3(x)))

        # Max pool over the points, flattened to one feature vector per sample
        x = self.max_pool(x).view(bs, -1)

        # Non-shared MLPs
        x = self.bn4(self.act(self.linear1(x)))
        x = self.bn5(self.act(self.linear2(x)))
        x = self.linear3(x)

        # Reshape the flat prediction into a square matrix
        x = x.view(-1, self.dim, self.dim)

        # Add the identity matrix for stability. It is a constant, so it needs
        # no gradient; creating it with the dtype and device of x keeps this
        # working on any device (not only CPU/CUDA), and broadcasting avoids
        # materialising one copy per sample.
        identity = torch.eye(self.dim, dtype=x.dtype, device=x.device)
        return x + identity

In [22]:
## Point-net classifier

class PointnetClassifier(nn.Module):
    """PointNet-style classifier for point clouds.

    Pipeline: input T-Net -> shared MLPs -> feature T-Net -> shared MLPs ->
    max pool over the points (global features) -> classification MLP.

    Args:
        dim: Number of channels (coordinates) per input point.
        num_points: Number of points per cloud; used as the kernel size of the
            max pools, so inputs are expected to contain this many points.
        num_global_feats: Number of global features produced by the max pool.
        num_classes: Number of output classes (size of the logits vector).
    """

    def __init__(self, dim, num_points, num_global_feats, num_classes):
        super().__init__()

        # Activation function
        self.act = F.relu

        # T-Net applied to the input points
        self.input_transform = Tnet(dim, num_points)

        # First shared MLP: turns the input points into per-point features.
        # The input channels must follow `dim` (they were hard-coded to 3,
        # which broke every point cloud with dim != 3).
        self.shared_mlp1 = nn.Conv1d(dim, 64, kernel_size=1)
        self.shared_mlp2 = nn.Conv1d(64, 64, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)

        # T-Net applied to the 64-channel features
        self.feature_transform = Tnet(64, num_points)

        # Second shared MLP: produces the features the global max pool reduces
        self.shared_mlp3 = nn.Conv1d(64, 64, kernel_size=1)
        self.shared_mlp4 = nn.Conv1d(64, 128, kernel_size=1)
        self.shared_mlp5 = nn.Conv1d(128, num_global_feats, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(num_global_feats)

        # Max pool that extracts the global features. Returning the indices
        # exposes which points ('critical points') determine each feature.
        self.max_pool = nn.MaxPool1d(kernel_size=num_points, return_indices=True)

        # Classification MLP
        self.linear1 = nn.Linear(num_global_feats, 512)
        self.linear2 = nn.Linear(512, 256)
        self.bn_linear1 = nn.BatchNorm1d(512)
        self.bn_linear2 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(p=0.3)

        # Output layer
        self.linear3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Classify a batch of point clouds.

        Args:
            x: Tensor of shape ``(batch, dim, num_points)``.

        Returns:
            A tuple ``(logits, critical_indexes, feature_matrix)``:
            ``logits`` has shape ``(batch, num_classes)``; ``critical_indexes``
            holds the max-pool argmax indices flattened to one row per sample;
            ``feature_matrix`` is the ``(batch, 64, 64)`` output of the
            feature T-Net.
        """
        # Batch size, i.e. how many point clouds are in the batch
        bs = x.shape[0]

        # Input transform: multiply every point (as a row vector) by the
        # predicted matrix, then restore the (batch, channels, points) layout.
        input_matrix = self.input_transform(x)
        x = torch.bmm(x.transpose(2, 1), input_matrix).transpose(2, 1)

        # First shared MLPs
        x = self.bn1(self.act(self.shared_mlp1(x)))
        x = self.bn2(self.act(self.shared_mlp2(x)))

        # Feature transform, applied the same way as the input transform
        feature_matrix = self.feature_transform(x)
        x = torch.bmm(x.transpose(2, 1), feature_matrix).transpose(2, 1)

        # Second shared MLPs
        x = self.bn3(self.act(self.shared_mlp3(x)))
        x = self.bn4(self.act(self.shared_mlp4(x)))
        x = self.bn5(self.act(self.shared_mlp5(x)))

        # Global features and the indices of the points that produced them
        global_features, critical_indexes = self.max_pool(x)
        global_features = global_features.view(bs, -1)
        critical_indexes = critical_indexes.view(bs, -1)

        # Classification head
        x = self.bn_linear1(self.act(self.linear1(global_features)))
        x = self.bn_linear2(self.act(self.linear2(x)))
        x = self.dropout(x)
        x = self.linear3(x)

        # Return the logits plus the auxiliary outputs
        return x, critical_indexes, feature_matrix

In [24]:
## Test

# Seed the RNG so that both the random input and the weight initialisation
# (and therefore the printed output) are reproducible across kernel restarts.
torch.manual_seed(0)

# dataset parameters
batch_size = 32             # number of point clouds
dim = 3                     # number of dimensions per point
num_points = 1024           # number of points per point cloud
num_classes = 2             # number of classification classes

# hyperparameters
num_global_feats = 1024     # number of global features computed

test_data = torch.rand(batch_size, dim, num_points)

classifier = PointnetClassifier(dim, num_points, num_global_feats, num_classes)
# Named targets instead of `_`, which would shadow IPython's last-output variable.
out, critical_indexes, feature_matrix = classifier(test_data)

# Sanity checks on the shapes of everything the forward pass returns
assert out.shape == (batch_size, num_classes), f'unexpected logits shape: {tuple(out.shape)}'
assert critical_indexes.shape == (batch_size, num_global_feats)
assert feature_matrix.shape == (batch_size, 64, 64)

print(f'Class output shape: {out.shape}')
print(f'Class output: {out}')

Class output shape: torch.Size([32, 2])
Class output: tensor([[ 6.0447e-01, -6.1188e-01],
        [ 6.9314e-02,  7.3064e-01],
        [ 2.8404e-01,  4.3951e-01],
        [ 4.2944e-01, -3.8633e-01],
        [-1.0889e-01,  2.9398e-01],
        [ 5.9915e-01,  1.2308e-03],
        [ 4.9978e-01,  6.7860e-01],
        [ 6.9066e-01,  3.9589e-01],
        [ 1.7057e-01,  4.5106e-01],
        [-2.3058e-01,  2.8857e-01],
        [ 1.1676e+00,  2.4708e+00],
        [-7.0644e-01,  4.6825e-02],
        [-5.7438e-01, -9.9266e-01],
        [-9.6524e-01,  3.2507e-01],
        [-5.5060e-01, -1.2496e-03],
        [ 4.7521e-01, -4.5222e-01],
        [-3.6642e-01, -7.8267e-01],
        [ 2.2702e-01, -7.0550e-01],
        [ 4.7587e-01, -1.4030e-01],
        [-6.2072e-01,  7.1639e-01],
        [ 1.0209e-01,  2.0904e+00],
        [ 1.0551e-01, -8.1594e-02],
        [ 2.4185e-01,  3.7723e-01],
        [-1.0678e-02, -1.4720e+00],
        [-6.1999e-01, -5.5066e-01],
        [-1.1785e+00,  1.3708e+00],
        [ 