### Libraries

In [None]:
import gc
import os
import random
import sys
from pathlib import Path

import hydra
import laspy
import numpy as np
import torch
from torch_geometric.nn.pool.consecutive import consecutive_cluster

# Current user's home directory; checkpoint paths below are built from it.
server = os.path.expanduser("~")

In [None]:
# Make the parent of the current working directory (the repository root when
# the notebook runs from a sub-folder) importable, so ``src`` can be found.
PATH = os.path.dirname(os.path.abspath(''))
if PATH not in sys.path:  # avoid piling up duplicates when the cell is re-run
    sys.path.append(PATH)

from src.datasets.dales import CLASS_NAMES as DALES_CLASS_NAMES
from src.datasets.dales import CLASS_COLORS as DALES_CLASS_COLORS
from src.datasets.kitti360 import CLASS_NAMES as KITTI_CLASS_NAMES
from src.datasets.kitti360 import CLASS_COLORS as KITTI_CLASS_COLORS
from src.data import Data, InstanceData
from src.transforms import instantiate_datamodule_transforms
from src.transforms import NAGRemoveKeys
from src.utils import init_config

### `Data` reader

#### Generate Dynamic Classes

In [None]:
def generate_class_mapping(filepath, default_class_id=7):
    """Build a class remapping dynamically from the classification codes of a LAS/LAZ file.

    Every distinct classification code present in the file gets a consecutive
    train id, in ascending code order. Codes below the largest present code
    that do not occur in the file map to ``default_class_id``.

    NOTE(review): with 8 or more distinct codes, ``default_class_id=7``
    collides with a real train id — confirm this is intended.

    Args:
        filepath: path to a LAS/LAZ file readable by ``laspy.read``.
        default_class_id: train id used for codes absent from the file and,
            in ``map_classification``, for codes outside the lookup table.

    Returns:
        ``(ID2TRAINID, CLASS_NAMES, CLASS_COLORS, map_classification)``:
        an int64 lookup array indexed by raw code, names ``"Class_<code>"``,
        an array of random RGB colours (one row per class), and a function
        mapping a tensor of raw codes to train ids.
    """
    # Read the LAS/LAZ file
    las = laspy.read(filepath)

    classifications = np.asarray(las['classification'])
    unique_classes = np.unique(classifications)  # already sorted ascending
    # Python int avoids numpy wrap-around (uint8 255 + 1 would give 0).
    max_class_id = int(unique_classes.max()) if unique_classes.size > 0 else 0
    ID2TRAINID = np.full(max_class_id + 1, default_class_id, dtype=np.int64)

    class_index_mapping = {cls: idx for idx, cls in enumerate(unique_classes)}
    for cls, mapped_idx in class_index_mapping.items():
        ID2TRAINID[cls] = mapped_idx

    CLASS_NAMES = [f"Class_{cls}" for cls in unique_classes]
    CLASS_COLORS = np.array([[random.randint(0, 255) for _ in range(3)] for _ in unique_classes])

    # Shares memory with ID2TRAINID, like the per-call from_numpy it replaces.
    lookup = torch.from_numpy(ID2TRAINID)

    def map_classification(y):
        """Map raw classification codes to train ids.

        Codes outside ``[0, len(ID2TRAINID))`` are assigned ``default_class_id``
        instead of raising.
        """
        # Cast to long: a uint8/bool index tensor would be treated as a mask.
        y = torch.as_tensor(y).long()
        in_range = (y >= 0) & (y < len(lookup))
        # Index with a sanitised copy so out-of-range codes cannot raise;
        # they are overwritten with the default right after.
        safe_idx = torch.where(in_range, y, torch.zeros_like(y))
        y_mapped = lookup.to(y.device)[safe_idx]
        y_mapped[~in_range] = default_class_id
        return y_mapped

    print("Clases encontradas y asignadas:")
    for cls, idx in class_index_mapping.items():
        print(f"Clase {cls} -> Índice {idx} - Color: {CLASS_COLORS[idx]}")
    return ID2TRAINID, CLASS_NAMES, CLASS_COLORS, map_classification

#### CTD Classes

In [None]:
CTD_NUM_CLASSES = 7

# Legacy mapping from raw classification codes (array index) to CTD train ids;
# codes not listed map to 0 ('Default').
# NOTE(review): this table used to be built into ``ID2TRAINID`` and was then
# immediately overwritten by the identity mapping below, so it never took
# effect. It is kept under its own name so it can be used explicitly.
RAW_ID2TRAINID = np.full(65, 0, dtype=np.int64)
RAW_ID2TRAINID[7] = 1
RAW_ID2TRAINID[11] = 1
RAW_ID2TRAINID[4] = 2
RAW_ID2TRAINID[20] = 3
RAW_ID2TRAINID[8] = 4
RAW_ID2TRAINID[60] = 5
RAW_ID2TRAINID[30] = 6

# Identity mapping actually used by ``read_ctd360_tile``: assumes the LAS
# classification already stores CTD train ids 0..6 (TODO confirm).
ID2TRAINID = np.arange(CTD_NUM_CLASSES, dtype=np.int64)

CTD_CLASS_NAMES = [
    'Default',
    'Ground',
    'Building',
    'Pole',
    'Wires',
    'Vegetation',
    'Car',
]

CTD_CLASS_COLORS = np.asarray([
    [0, 0, 0],         # Black
    [139, 87, 42],     # Brown
    [74, 144, 226],    # Blue
    [245, 166, 35],    # Orange
    [208, 2, 27],      # Red
    [126, 211, 33],    # Green
    [248, 231, 28]     # Yellow
])

In [None]:
# Label names and colours per dataset key ("ctd" is this notebook's own set).
_LABEL_SETS = {
    "kitti360": (KITTI_CLASS_NAMES, KITTI_CLASS_COLORS),
    "dales": (DALES_CLASS_NAMES, DALES_CLASS_COLORS),
    "ctd": (CTD_CLASS_NAMES, CTD_CLASS_COLORS),
}
class_names = {key: names for key, (names, _) in _LABEL_SETS.items()}
class_colors = {key: colors for key, (_, colors) in _LABEL_SETS.items()}

#### LAS/LAZ reader

In [None]:
def read_ctd360_tile(
        filepath,
        load_xyz=True,
        load_rgb=True,
        load_intensity=True,
        load_semantic=True,
        load_instance=True,
        remap_labels=True,
        max_intensity=600):
    """Read a CTD360 tile saved as LAS or LAZ.

    Args:
        filepath: path (``str`` or ``os.PathLike``) to a ``.las``/``.laz`` file.
        load_xyz: set ``data.pos`` from the integer X/Y/Z records times the
            header scale (the header offset is not applied), re-centred on the
            first point.
        load_rgb: set ``data.rgb`` = 16-bit colour / 65535 when the file has
            red/green/blue dimensions; otherwise print a notice.
        load_intensity: set ``data.intensity`` =
            clip(intensity, 0, max_intensity) / max_intensity.
        load_semantic: set ``data.y`` from ``classification``.
        load_instance: set ``data.obj`` from ``user_data`` (made consecutive).
        remap_labels: pass classification codes through the module-level
            ``ID2TRAINID`` lookup (codes outside it raise ``IndexError``).
        max_intensity: intensity normalisation constant.

    Returns:
        ``(data, pos_offset, scale)``. ``pos_offset`` is the float64 position
        of the first point (already subtracted from ``data.pos``) and ``scale``
        the float64 header scale; both are ``None`` when ``load_xyz`` is False.
        Passing them to ``save_las_file`` restores the raw integer records.
    """
    data = Data()
    # Defined up front so the return works even when load_xyz is False.
    pos_offset = None
    scale = None

    if str(filepath).lower().endswith('.laz'):
        las = laspy.read(filepath, laz_backend=laspy.LazBackend.LazrsParallel)
    else:
        las = laspy.read(filepath)

    if load_xyz:
        # Work in float64: raw integer coordinates routinely exceed float32's
        # 24-bit mantissa, so casting them to float32 first loses precision.
        xyz = np.stack(
            [np.asarray(las[axis], dtype=np.float64) for axis in ("X", "Y", "Z")],
            axis=-1)
        scale = torch.as_tensor(np.asarray(las.header.scale, dtype=np.float64))
        pos = torch.from_numpy(xyz) * scale
        if len(pos) > 0:
            pos_offset = pos[0].clone()
        else:
            pos_offset = torch.zeros(3, dtype=torch.float64)
        # Re-centred coordinates are small enough for float32 storage.
        data['pos'] = (pos - pos_offset).float()

    if load_rgb:
        if all(axis in las.point_format.dimension_names for axis in ['red', 'green', 'blue']):
            data.rgb = torch.stack([
                torch.FloatTensor(np.asarray(las[axis]).astype('float32') / 65535.0)
                for axis in ['red', 'green', 'blue']], dim=-1)
        else:
            print("El archivo LAS no contiene información RGB.")

    if load_intensity:
        data.intensity = torch.FloatTensor(
            np.asarray(las['intensity']).astype('float32')
        ).clip(min=0, max=max_intensity) / max_intensity

    if load_semantic:
        y = torch.from_numpy(np.asarray(las['classification'], dtype=np.int64))
        data.y = torch.from_numpy(ID2TRAINID)[y] if remap_labels else y

    # Instance labels are read from 'user_data'
    if load_instance and hasattr(las, 'user_data'):
        obj = torch.from_numpy(np.asarray(las['user_data'], dtype=np.int64))
        idx = torch.arange(len(obj))
        obj, _ = consecutive_cluster(obj)
        count = torch.ones_like(obj)
        if load_semantic:
            y = data.y
        else:
            y = torch.zeros(len(obj), dtype=torch.long)
        data.obj = InstanceData(idx, obj, count, y, dense=True)

    return data, pos_offset, scale

### `Data` visualization

In [None]:
# file_path = "/home/binahlab/AI-Labs/clever-data/electrical-elements/data/icadel/ibague/ut/raw/pointclouds/UT_CLAS.las"
# ID2TRAINID, ENEL_CLASS_NAMES, ENEL_CLASS_COLORS, map_classification = generate_class_mapping(file_path)
# print("\nENEL_CLASS_NAMES:", ENEL_CLASS_NAMES)
# print("\nID2TRAINID:\n", ID2TRAINID)
# data, _, _ = read_ctd360_tile(file_path, load_semantic=True, load_instance=False)

In [None]:
# data.y.unique(return_counts=True)

In [None]:
# data.show(class_names=ENEL_CLASS_NAMES, class_colors=ENEL_CLASS_COLORS)

In [None]:
# file_path = "/home/binahlab/AI-Labs/clever-data/electrical-elements/data/celsia/Tolima-2024_08_28/2024_07_12/proccesed/labeling/supervisely/CML-5FFC8A-2024-07-12-16-09-29/ML-5FFC8A-2024-07-12-16-09-29.las"
# data, _, _ = read_ctd360_tile(file_path, load_semantic=True, load_instance=False)
# data.show(class_names=CTD_CLASS_NAMES, class_colors=CTD_CLASS_COLORS)

### Single-file inference

In [None]:
# file_path = "/home/binahlab/AI-Labs/clever-data/electrical-elements/data/enel/ABC_TRACK_A/raw/pointclouds/ABC_Track_A_FirstProfiler_1.las"
# data, _, _ = read_ctd360_tile(file_path, load_semantic=False, load_instance=False)

In [None]:
# dataset = "kitti360" 
# experiment = "panoptic"
# exp = "spt-2" if experiment == "semantic" else "supercluster"
# ckpt_path = f"{server}/AI-Labs/superpoint_transformer/ckpt/{exp}_{dataset}.ckpt"
# cfg = init_config(overrides=[f"experiment={experiment}/{dataset}"])
# transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)
# model = hydra.utils.instantiate(cfg.model)._load_from_checkpoint(ckpt_path).eval()

# nag = transforms_dict['pre_transform'](data)
# nag = NAGRemoveKeys(level=0, keys=[k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys])(nag)
# nag = NAGRemoveKeys(level='1+', keys=[k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys])(nag)

# nag = nag.cuda()
# nag = transforms_dict['on_device_test_transform'](nag)

# with torch.no_grad():
#     output = model(nag)
#     nag[0].semantic_pred = output.voxel_semantic_pred(super_index=nag[0].super_index)
#     if exp == "panoptic":
#         vox_labels, vox_instance_idx, vox_instance_data = output.voxel_panoptic_pred(super_index=nag[0].super_index)
#         nag[0].semantic_pred = vox_labels
#         nag[0].instance_pred = vox_instance_idx
#         nag[0].instance_data = vox_instance_data

# nag.show(class_names=class_names[dataset], class_colors=class_colors[dataset])
# torch.cuda.empty_cache()
# gc.collect()

### Crop route using images

### Run inference and save results

In [None]:
def save_las_file(tile_data, file_path, classification=None, pos_offset=None, scale=None,
                  max_intensity=600):
    """Save ``tile_data`` as a LAS file (point format 3).

    Args:
        tile_data: ``Data`` with ``pos`` of shape (N, 3) and, optionally,
            ``rgb`` and ``intensity`` normalised to [0, 1] and ``y``.
        file_path: output path.
        classification: optional per-point labels written as
            ``classification``; falls back to ``tile_data.y`` when omitted.
        pos_offset, scale: as returned by ``read_ctd360_tile``. When both are
            given, the raw integer X/Y/Z records are restored as
            ``round((pos + pos_offset) / scale)`` and ``scale`` is written to
            the header. The header offset is not restored, because the reader
            does not use it. Otherwise ``pos`` is written as real coordinates
            using laspy's default header scale/offset.
        max_intensity: factor undoing the reader's intensity normalisation.
    """
    # Point format 3 holds XYZ, intensity, classification and RGB. It used to
    # be derived from pos.shape[1], which equals 3 only by coincidence.
    header = laspy.LasHeader(point_format=3)
    pos_cpu = tile_data.pos.detach().cpu().double()

    if pos_offset is not None and scale is not None:
        scale_cpu = torch.as_tensor(scale).detach().cpu().double()
        offset_cpu = torch.as_tensor(pos_offset).detach().cpu().double()
        # Store the raw records with the scale they were read with; otherwise
        # laspy's default scale would reinterpret them.
        header.scales = scale_cpu.numpy()
        las = laspy.LasData(header)
        # Round rather than truncate: 12344.9999 must become 12345.
        raw = torch.round((pos_cpu + offset_cpu) / scale_cpu).numpy().astype(np.int32)
        las.X = raw[:, 0]
        las.Y = raw[:, 1]
        las.Z = raw[:, 2]
    else:
        # Positions are already in coordinate units: let laspy quantise them.
        las = laspy.LasData(header)
        pos_np = pos_cpu.numpy()
        las.x = pos_np[:, 0]
        las.y = pos_np[:, 1]
        las.z = pos_np[:, 2]

    # Save RGB if present (round + clamp to uint16; an int16 cast overflows).
    rgb = getattr(tile_data, 'rgb', None)
    if rgb is not None:
        rgb16 = (rgb.detach().cpu().double() * 65535).round().clamp(0, 65535).numpy().astype(np.uint16)
        las.red = rgb16[:, 0]
        las.green = rgb16[:, 1]
        las.blue = rgb16[:, 2]

    # Save intensity if present, undoing the reader's normalisation.
    intensity = getattr(tile_data, 'intensity', None)
    if intensity is not None:
        las.intensity = (
            (intensity.detach().cpu().double() * max_intensity)
            .round().clamp(0, 65535).numpy().astype(np.uint16)
        )

    # Use provided classification (from inference) or fall back to tile_data.y
    if classification is not None:
        cls_np = classification.detach().cpu().numpy()
        las.classification = cls_np
        print(np.unique(cls_np))
    else:
        y = getattr(tile_data, 'y', None)
        if y is not None:
            las.classification = y.detach().cpu().numpy()

    # Write the LAS file
    las.write(file_path)

def map_kitti_to_ctd(vox_labels):
    """Translate KITTI-360 predicted label ids into CTD train ids.

    Any id not listed below (noise ``-1``, traffic light, traffic sign,
    person, ids >= 13, …) becomes 0, CTD's 'Default'. The result has the
    same shape, dtype and device as ``vox_labels``.
    """
    kitti_to_ctd = {
        0: 1,    # Ground
        1: 1,    # Sidewalk
        2: 2,    # Building
        3: 2,    # Wall
        4: 2,    # Fence
        5: 3,    # Pole
        8: 5,    # Vegetation
        9: 5,    # Terrain
        11: 6,   # Car
        12: 6,   # Truck
    }
    ctd_labels = torch.zeros_like(vox_labels)
    for kitti_id, ctd_id in kitti_to_ctd.items():
        ctd_labels[vox_labels == kitti_id] = ctd_id
    return ctd_labels

def map_dales_to_ctd(vox_labels):
    """Translate DALES predicted label ids into CTD train ids.

    Any id not listed below (including 8 'Unknown') becomes 0, CTD's
    'Default'. The result has the same shape, dtype and device as
    ``vox_labels``.
    """
    dales_to_ctd = {
        0: 1,   # Ground
        1: 5,   # Vegetation
        2: 6,   # Cars
        3: 6,   # Trucks
        4: 4,   # Power lines
        5: 2,   # Fences
        6: 3,   # Poles
        7: 2,   # Buildings
    }
    ctd_labels = torch.zeros_like(vox_labels)
    for dales_id, ctd_id in dales_to_ctd.items():
        ctd_labels[vox_labels == dales_id] = ctd_id
    return ctd_labels

In [None]:
def _infer_and_save(model, data, file_path, project_path, transforms_dict, cfg, exp,
                    dataset, class_names, class_colors, pos_offset, scale, visualize):
    """Run one tile through the model, save the CTD-labelled LAS, optionally show it.

    Kept separate from ``run_inference_and_save_las`` so every GPU tensor
    created here is a local that becomes unreachable once this returns.
    """
    nag = transforms_dict['pre_transform'](data)
    nag = NAGRemoveKeys(level=0, keys=[k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys])(nag)
    nag = NAGRemoveKeys(level='1+', keys=[k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys])(nag)

    nag = nag.cuda()
    nag = transforms_dict['on_device_test_transform'](nag)

    # This notebook selects panoptic checkpoints with exp="supercluster", so
    # that name must reach the panoptic branch as well as "panoptic".
    is_panoptic = exp in ("panoptic", "supercluster")

    with torch.no_grad():
        output = model(nag)

        if is_panoptic:
            vox_labels, _vox_instance_idx, _vox_instance_data = output.voxel_panoptic_pred(
                super_index=nag[0].super_index
            )
            full_res_pred = output.full_res_panoptic_pred(
                super_index_level0_to_level1=nag[0].super_index,
                sub_level0_to_raw=nag[0].sub
            )
        else:
            vox_labels = output.voxel_semantic_pred(super_index=nag[0].super_index)
            full_res_pred = output.full_res_semantic_pred(
                super_index_level0_to_level1=nag[0].super_index,
                sub_level0_to_raw=nag[0].sub
            )

    # NOTE(review): assumes the panoptic full-res prediction, like
    # voxel_panoptic_pred, is a sequence whose first item is the semantic
    # label tensor — confirm against the model's output class.
    if isinstance(full_res_pred, (tuple, list)):
        inf_y = full_res_pred[0]
    else:
        inf_y = full_res_pred

    to_ctd = {'kitti360': map_kitti_to_ctd, 'dales': map_dales_to_ctd}.get(dataset)
    if to_ctd is not None:
        ctd_labels = to_ctd(inf_y)
        ctd_labels_show = to_ctd(vox_labels)
    else:
        ctd_labels = inf_y
        ctd_labels_show = vox_labels

    las_name = os.path.basename(file_path)
    save_path = os.path.join(project_path, f"processed/classified_{exp}_{dataset}")
    os.makedirs(save_path, exist_ok=True)

    save_las_file(
        data,
        file_path=os.path.join(save_path, las_name),
        classification=ctd_labels,
        pos_offset=pos_offset,
        scale=scale
    )
    if visualize:
        nag[0].semantic_pred = ctd_labels_show
        nag.show(class_names=class_names["ctd"], class_colors=class_colors["ctd"])


def run_inference_and_save_las(
    model,
    data,
    file_path: str,
    project_path: str,
    transforms_dict: dict,
    cfg,
    exp: str,
    dataset: str,
    class_names: dict,
    class_colors: dict,
    pos_offset,
    scale,
    visualize: bool = False
):
    """Predict labels for one tile and write them to
    ``<project_path>/processed/classified_<exp>_<dataset>/<file name>``.

    Predictions are mapped to CTD train ids for ``dataset`` "kitti360" or
    "dales" and passed through unchanged otherwise. ``exp`` "panoptic" or
    "supercluster" selects panoptic prediction; anything else is semantic.
    ``pos_offset``/``scale`` come from ``read_ctd360_tile`` and are forwarded
    to ``save_las_file``. With ``visualize`` the voxel-level CTD labels are
    shown with ``class_names["ctd"]``/``class_colors["ctd"]``.
    """
    try:
        _infer_and_save(model, data, file_path, project_path, transforms_dict, cfg, exp,
                        dataset, class_names, class_colors, pos_offset, scale, visualize)
    finally:
        # The helper's GPU tensors are unreferenced now (on success), so
        # collect first and then return the cached blocks to the driver.
        gc.collect()
        torch.cuda.empty_cache()


In [None]:
# Pretrained model selection: training dataset and task type.
dataset = "kitti360"
experiment = "semantic"
# Checkpoint prefix: "spt-2" for semantic models, "supercluster" for panoptic.
exp = "spt-2" if experiment == "semantic" else "supercluster"
ckpt_path = f"{server}/AI-Labs/superpoint_transformer/ckpt/{exp}_{dataset}.ckpt"
# load_full_res_idx presumably provides nag[0].sub, which the full-resolution
# predictions in run_inference_and_save_las rely on — confirm.
cfg = init_config(overrides=[f"experiment={experiment}/{dataset}", f"datamodule.load_full_res_idx={True}"])
transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)
model = hydra.utils.instantiate(cfg.model)._load_from_checkpoint(ckpt_path).eval()
# Built from the home directory, like ckpt_path, instead of a hard-coded user.
project_path = f"{server}/AI-Labs/clever-data/electrical-elements/data/celsia/Tolima-2024_08_28/2024_07_12"

In [None]:
import traceback

pointcloud_dir = Path(project_path) / "raw" / "pointclouds"
# Match the extension case-insensitively (the reader lower-cases it too), so
# files such as "TILE.LAZ" are not silently skipped.
las_paths = sorted(
    path for path in pointcloud_dir.glob("*")
    if path.is_file() and path.suffix.lower() in (".las", ".laz")
)
if not las_paths:
    print(f"No se encontraron archivos .las/.laz en {pointcloud_dir}")

failed_files = []
for file_path in las_paths:
    print(f"Procesando: {file_path.name}")

    try:
        data, pos_offset, scale = read_ctd360_tile(str(file_path), load_semantic=False, load_instance=False)

        run_inference_and_save_las(
            model=model,
            data=data,
            file_path=str(file_path),
            project_path=project_path,
            transforms_dict=transforms_dict,
            cfg=cfg,
            exp=exp,
            dataset=dataset,
            class_names=class_names,
            class_colors=class_colors,
            pos_offset=pos_offset,
            scale=scale,
            visualize=False
        )

    except Exception as e:
        # Best effort: continue with the remaining tiles, but keep the full
        # traceback visible instead of only the message.
        print(f"Error procesando {file_path.name}: {e}")
        traceback.print_exc()
        failed_files.append(file_path.name)

if failed_files:
    print(f"Archivos con error ({len(failed_files)}): {failed_files}")
