In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
from tqdm import tqdm
import h5py
import numpy as np
import torch
import os
from torch.utils.data import Dataset, DataLoader, Subset
from scipy.spatial import cKDTree
from google.colab import drive
import shutil
import time
from datetime import datetime
import gc
import zipfile
from pathlib import Path
import struct

In [None]:
def normalize_to_sphere(pc):
    """Center a point cloud on its centroid and scale it into the unit sphere.

    Args:
        pc: Array with one point per row (2-D, points along axis 0).

    Returns:
        The centered cloud divided by the largest point-to-centroid distance,
        so the farthest point ends up at distance 1. When every point sits on
        the centroid (largest distance is 0) the centered cloud is returned
        without scaling.
    """
    centered = pc - np.mean(pc, axis=0)
    radius = np.linalg.norm(centered, axis=1).max()
    # A zero radius means a degenerate cloud; skip the division to avoid NaNs.
    return centered if radius == 0 else centered / radius

In [None]:
# DATASET

class ModelNet40(Dataset):
    """Point-cloud completion dataset read from HDF5 shards.

    Every cloud found under the ``'data'`` key of the listed HDF5 files is
    normalized with ``normalize_to_sphere`` and kept in memory. In
    ``mode='encoder'`` each item is a ``(partial, complete)`` pair where the
    partial cloud has local patches removed and is zero-padded back to
    ``num_points`` rows; in any other mode only the complete cloud is returned.

    Args:
        root_dir: Directory containing ``file_list`` and the HDF5 files.
        file_list: Text file (inside ``root_dir``) with one HDF5 path per
            line, each relative to ``root_dir``.
        mode: ``'encoder'`` enables corruption (and augmentation, if on).
        corruption_rate: Fraction of points removed when the curriculum is
            disabled.
        num_points: Row count of the zero-padded partial cloud.
        use_curriculum: If true, the corruption rate is ramped linearly with
            the epoch set through ``set_epoch``.
        use_augmentation: If true (and ``mode == 'encoder'``), random
            rotation / jitter / scale / flip are applied to each cloud.
        curriculum_initial_rate: Corruption rate at epoch 0 of the curriculum.
        curriculum_final_rate: Corruption rate once the ramp is finished.
        curriculum_epochs: Number of epochs over which the rate is ramped.
    """

    def __init__(self, root_dir, file_list, mode='train', corruption_rate=0.5, num_points=2048,
                 use_curriculum=True, use_augmentation=True,
                 curriculum_initial_rate=0.3, curriculum_final_rate=0.9, curriculum_epochs=100):
        self.mode = mode
        self.num_points = num_points
        self.corruption_rate = corruption_rate
        self.use_curriculum = use_curriculum
        self.use_augmentation = use_augmentation
        self.curriculum_initial_rate = curriculum_initial_rate
        self.curriculum_final_rate = curriculum_final_rate
        self.curriculum_epochs = curriculum_epochs
        self.current_epoch = 0
        self.data = []

        with open(os.path.join(root_dir, file_list), 'r') as f:
            h5_files = [os.path.join(root_dir, line.strip()) for line in f]

        for h5_file in h5_files:
            with h5py.File(h5_file, 'r') as f:
                pcds = f['data'][:].astype('float32')
                self.data.extend(normalize_to_sphere(pcd) for pcd in pcds)

        # Kept for compatibility: the directory is created here but nothing in
        # this class writes to it.
        self.corrupted_dir = os.path.join(root_dir, "corrupted_dataset")
        os.makedirs(self.corrupted_dir, exist_ok=True)

    def __len__(self):
        """Return the number of clouds held in memory."""
        return len(self.data)

    def set_epoch(self, epoch):
        """Record the current epoch (drives the curriculum and corruption seed)."""
        self.current_epoch = epoch

    def get_current_corruption_rate(self):
        """Return the fraction of points to remove for the current epoch.

        Without curriculum this is the fixed ``corruption_rate``. With
        curriculum the rate grows linearly from ``curriculum_initial_rate`` to
        ``curriculum_final_rate`` over ``curriculum_epochs`` epochs and then
        stays at the final value.
        """
        if not self.use_curriculum:
            return self.corruption_rate

        # Also covers curriculum_epochs <= 0, so the division below is safe.
        if self.current_epoch >= self.curriculum_epochs:
            return self.curriculum_final_rate

        progress = self.current_epoch / self.curriculum_epochs
        span = self.curriculum_final_rate - self.curriculum_initial_rate
        return self.curriculum_initial_rate + span * progress

    def augment_pointcloud(self, pc):
        """Return a randomly augmented copy of ``pc``.

        The input is returned unchanged (same object) when augmentation is
        disabled or the dataset is not in ``'encoder'`` mode. Randomness comes
        from NumPy's global generator.
        """
        if not self.use_augmentation or self.mode != 'encoder':
            return pc

        pc = pc.copy()

        # Rotation about the z axis by a random angle (70% of the time).
        if np.random.rand() < 0.7:
            theta = np.random.uniform(0, 2 * np.pi)
            cos_theta = np.cos(theta)
            sin_theta = np.sin(theta)
            rot_matrix = np.array([
                [cos_theta, -sin_theta, 0],
                [sin_theta, cos_theta, 0],
                [0, 0, 1]
            ], dtype=np.float32)
            pc = pc @ rot_matrix.T

        # Small tilt (up to 15 degrees) about the x and y axes (30% of the time).
        if np.random.rand() < 0.3:
            angle_x = np.random.uniform(-np.pi/12, np.pi/12)
            angle_y = np.random.uniform(-np.pi/12, np.pi/12)

            rot_x = np.array([
                [1, 0, 0],
                [0, np.cos(angle_x), -np.sin(angle_x)],
                [0, np.sin(angle_x), np.cos(angle_x)]
            ], dtype=np.float32)
            rot_y = np.array([
                [np.cos(angle_y), 0, np.sin(angle_y)],
                [0, 1, 0],
                [-np.sin(angle_y), 0, np.cos(angle_y)]
            ], dtype=np.float32)

            pc = pc @ rot_x.T @ rot_y.T

        # Gaussian jitter, uniform scaling and mirroring along x.
        if np.random.rand() < 0.5:
            jitter = np.random.normal(0, 0.01, pc.shape).astype(np.float32)
            pc += jitter

        if np.random.rand() < 0.3:
            scale = np.random.uniform(0.95, 1.05)
            pc *= scale

        if np.random.rand() < 0.2:
            pc[:, 0] *= -1

        return pc

    def corrupt_pointcloud(self, pc, idx, corruption_rate=None):
        """Remove local patches of points until the target fraction is gone.

        A random surviving point is picked as a patch center, a random number
        of its nearest neighbors (drawn from ``[5, 64)``) is looked up in a
        KD-tree over the full cloud, and the neighbors that are still present
        are dropped. This repeats until ``int(corruption_rate * N)`` points
        have been removed.

        Args:
            pc: Array of shape (N, 3).
            idx: Sample index; together with the current epoch it seeds the
                random generator, so the same (idx, epoch) pair always yields
                the same partial cloud.
            corruption_rate: Fraction to remove; defaults to
                ``get_current_corruption_rate()``.

        Returns:
            The surviving points, in their original order. ``pc`` itself is
            returned when nothing has to be removed.
        """
        if corruption_rate is None:
            corruption_rate = self.get_current_corruption_rate()

        seed = idx + self.current_epoch * 100000
        rng = np.random.default_rng(seed)
        num_points = pc.shape[0]
        num_to_remove = int(corruption_rate * num_points)

        if num_to_remove == 0:
            return pc

        tree = cKDTree(pc)
        mask = np.ones(num_points, dtype=bool)
        available_indices = np.arange(num_points)
        removed = 0

        while removed < num_to_remove and len(available_indices) > 0:
            seed_idx = rng.choice(available_indices)
            patch_size = rng.integers(5, 64)
            _, neighbor_indices = tree.query(
                pc[seed_idx], k=min(patch_size, len(available_indices))
            )
            # query() returns a scalar when k == 1.
            neighbor_indices = np.atleast_1d(neighbor_indices)

            # Only neighbors that are still present count towards the budget.
            candidates = neighbor_indices[mask[neighbor_indices]]
            if candidates.size == 0:
                # The tree still contains removed points, so with exact
                # duplicates the query can return only points that are already
                # gone. Drop the seed itself so the loop always makes progress.
                candidates = np.array([seed_idx])

            candidates = candidates[:num_to_remove - removed]
            mask[candidates] = False
            removed += candidates.size
            available_indices = np.flatnonzero(mask)

        return pc[mask]

    # Main entry point used by the DataLoader.
    def __getitem__(self, idx):
        """Return sample ``idx``.

        In ``'encoder'`` mode returns ``(partial, complete)`` float tensors,
        where ``partial`` has ``num_points`` rows: the surviving points first,
        zero rows after. In any other mode returns only the complete cloud.

        Raises:
            ValueError: If the partial cloud has more than ``num_points`` rows.
        """
        # augment_pointcloud is a no-op unless augmentation applies.
        original_pc = self.augment_pointcloud(self.data[idx])

        if self.mode != 'encoder':
            return torch.tensor(original_pc).float()

        corrupted_pc = self.corrupt_pointcloud(original_pc, idx)
        num_kept = corrupted_pc.shape[0]
        if num_kept > self.num_points:
            raise ValueError(
                f"partial cloud has {num_kept} points but num_points={self.num_points}"
            )

        padded_pc = np.zeros((self.num_points, 3), dtype=np.float32)
        padded_pc[:num_kept] = corrupted_pc
        return torch.tensor(padded_pc).float(), torch.tensor(original_pc).float()

In [None]:
def knn_grouping(xyz, features, k=16):
    """Gather the features of each point's k nearest neighbors.

    Neighbors are found by Euclidean distance between the rows of ``xyz``
    (pairwise ``torch.cdist`` followed by ``topk`` of the smallest distances),
    so each point's own index is normally among its neighbors.

    Args:
        xyz: Tensor of shape (B, N, D) with point coordinates.
        features: Tensor of shape (B, N, C) with per-point features.
        k: Number of neighbors. Clamped to ``N`` so clouds with fewer than
            ``k`` points do not make ``topk`` fail.

    Returns:
        Tuple ``(neighbor_features, idx)`` with shapes (B, N, k, C) and
        (B, N, k); ``idx`` holds the neighbor indices along the point axis.
    """
    B, N, _ = features.shape
    k = min(k, N)

    dist = torch.cdist(xyz, xyz)
    _, idx = torch.topk(dist, k, dim=-1, largest=False)

    # A (B, 1, 1) batch index broadcasts against idx (B, N, k); indexing the
    # first two axes leaves the feature axis intact, giving (B, N, k, C)
    # without building per-channel index tensors.
    batch_idx = torch.arange(B, device=features.device).view(B, 1, 1)
    neighbor_features = features[batch_idx, idx]

    return neighbor_features, idx

In [None]:
class LocalAggregation(nn.Module):
    """Attention over each point's k nearest neighbors with a residual link.

    For every point, the neighbor features (plus an MLP encoding of the
    relative neighbor position) are weighted by a softmax over their dot
    product with the projected query feature, summed, projected by
    ``out_conv`` and added to the input features.

    Args:
        dim: Feature width C of the input and output.
        k: Number of neighbors requested from ``knn_grouping``.
    """

    def __init__(self, dim=128, k=16):
        super().__init__()
        self.k = k

        # attention
        self.query_conv = nn.Conv1d(dim, dim, 1)
        # NOTE(review): key_conv and value_conv are never applied in forward();
        # keys/values are the raw neighbor features plus the positional
        # encoding. They are kept so existing state_dicts keep loading --
        # confirm whether the projections were meant to be used.
        self.key_conv = nn.Conv1d(dim, dim, 1)
        self.value_conv = nn.Conv1d(dim, dim, 1)

        self.pos_mlp = nn.Sequential(
            nn.Conv2d(3, dim, 1),
            nn.ReLU(),
            nn.Conv2d(dim, dim, 1)
        )

        self.out_conv = nn.Conv1d(dim, dim, 1)

    def forward(self, xyz, features):
        """Aggregate neighbor information.

        Args:
            xyz: Tensor (B, N, 3) of point coordinates.
            features: Tensor (B, N, C) of per-point features.

        Returns:
            Tensor (B, N, C): attention output plus the input features.
        """
        B, N, C = features.shape
        neighbor_features, idx = knn_grouping(xyz, features, self.k)

        # Gather neighbor coordinates with the indices actually returned,
        # instead of rebuilding index tensors from self.k.
        batch_idx = torch.arange(B, device=xyz.device).view(B, 1, 1)
        neighbor_xyz = xyz[batch_idx, idx]  # (B, N, k, 3)

        pos_rel = (xyz.unsqueeze(2) - neighbor_xyz).permute(0, 3, 1, 2)
        pos_enc = self.pos_mlp(pos_rel)  # (B, C, N, k)

        q = self.query_conv(features.transpose(1, 2)).unsqueeze(-1)  # (B, C, N, 1)
        # Keys and values are the same tensor, so it is computed once.
        kv = neighbor_features.permute(0, 3, 1, 2) + pos_enc

        attn = torch.sum(q * kv, dim=1, keepdim=True) / (C ** 0.5)
        attn = F.softmax(attn, dim=-1)  # normalize over the k neighbors

        out = torch.sum(attn * kv, dim=-1)
        out = self.out_conv(out)

        return out.transpose(1, 2) + features

In [None]:
class Encoder(nn.Module):
    """Point-wise feature extractor that also produces a global descriptor.

    Pipeline: two point-wise conv blocks (3 -> 64 -> 128), local neighbor
    aggregation, then two more conv blocks (128 -> 256 -> ``latent_dim``).
    The global descriptor is the max over points of the last block's output.

    Args:
        latent_dim: Width of the global feature vector.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        self.conv1 = self._pointwise_block(3, 64)
        self.conv2 = self._pointwise_block(64, 128)
        self.local_agg = LocalAggregation(dim=128, k=16)
        self.conv3 = self._pointwise_block(128, 256)
        self.conv4 = self._pointwise_block(256, latent_dim)

    @staticmethod
    def _pointwise_block(in_channels, out_channels):
        """Build a kernel-size-1 Conv1d -> BatchNorm1d -> ReLU block."""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode a batch of point clouds.

        Args:
            x: Tensor (B, N, 3) of point coordinates.

        Returns:
            Tuple ``(global_feat, point_feat)`` with shapes (B, latent_dim)
            and (B, N, 256).
        """
        # Conv1d expects channels first, the aggregation layer points first,
        # hence the transposes between stages.
        shallow = self.conv2(self.conv1(x.transpose(1, 2)))
        aggregated = self.local_agg(x, shallow.transpose(1, 2))
        mid = self.conv3(aggregated.transpose(1, 2))

        global_feat = self.conv4(mid).max(dim=2).values
        return global_feat, mid.transpose(1, 2)

In [None]:
class CoarseSeedGenerator(nn.Module):
    """MLP that decodes a global feature vector into a coarse point set.

    Args:
        latent_dim: Width of the incoming global feature.
        num_coarse: Number of 3-D seed points to generate.
    """

    def __init__(self, latent_dim=512, num_coarse=256):
        super().__init__()
        self.num_coarse = num_coarse

        # Hidden stack latent_dim -> 1024 -> 512, each followed by ReLU, then
        # a linear head emitting num_coarse * 3 coordinates.
        widths = (latent_dim, 1024, 512)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], num_coarse * 3))
        self.mlp = nn.Sequential(*layers)

    def forward(self, global_feat):
        """Map global features (B, latent_dim) to seeds (B, num_coarse, 3)."""
        return self.mlp(global_feat).view(-1, self.num_coarse, 3)

In [None]:
class Upsampler(nn.Module):
    """Split every input point into ``up_factor`` children near the parent.

    Per-point features computed from the coordinates are fused with the
    global feature, refined by local aggregation, and decoded into bounded
    offsets; each child is its parent plus ``0.1 * tanh(offset)``.

    Args:
        feat_dim: Width of the global feature vector.
        up_factor: Number of output points produced per input point.
    """

    def __init__(self, feat_dim=512, up_factor=2):
        super().__init__()
        self.up_factor = up_factor

        self.feat_mlp = self._conv_stack((3, 64, 128), final_relu=True)
        self.combine = self._conv_stack((128 + feat_dim, 256, 128), final_relu=True)
        self.local_agg = LocalAggregation(dim=128, k=8)
        self.up_mlp = self._conv_stack((128, 64, 3 * up_factor), final_relu=False)

    @staticmethod
    def _conv_stack(widths, final_relu):
        """Chain kernel-size-1 Conv1d layers with ReLU between them."""
        layers = []
        last = len(widths) - 2
        for i, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv1d(c_in, c_out, 1))
            if i < last or final_relu:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, xyz, global_feat):
        """Upsample a point set.

        Args:
            xyz: Tensor (B, N, 3) of parent points.
            global_feat: Tensor (B, feat_dim) global descriptor.

        Returns:
            Tensor (B, N * up_factor, 3); the children of one parent are
            consecutive rows.
        """
        B, N, _ = xyz.shape
        r = self.up_factor

        point_feat = self.feat_mlp(xyz.transpose(1, 2))
        tiled_global = global_feat.unsqueeze(2).expand(-1, -1, N)
        fused = self.combine(torch.cat([point_feat, tiled_global], dim=1))

        refined = self.local_agg(xyz, fused.transpose(1, 2))

        # (B, 3*r, N) -> (B, 3, r, N) -> (B, N, r, 3)
        raw_offsets = self.up_mlp(refined.transpose(1, 2)).view(B, 3, r, N).permute(0, 3, 2, 1)

        # tanh bounds each coordinate offset to +-0.1 around the parent;
        # the parent broadcasts over the r children.
        children = xyz.unsqueeze(2) + 0.1 * torch.tanh(raw_offsets)
        return children.reshape(B, N * r, 3)

In [None]:
class MiniCRAPCN(nn.Module):
    """Coarse-to-fine point-cloud completion network.

    A partial cloud is encoded into a 512-d global feature, decoded into 256
    coarse seeds, and refined by three upsamplers that each double the point
    count (256 -> 512 -> 1024 -> 2048).

    Args:
        num_points: Accepted for interface compatibility; the constructor
            does not read it (the output sizes above are fixed by the layer
            configuration).
    """

    def __init__(self, num_points=2048):
        super().__init__()
        latent_dim = 512

        self.encoder = Encoder(latent_dim=latent_dim)
        self.seed_generator = CoarseSeedGenerator(latent_dim=latent_dim, num_coarse=256)

        self.up1 = Upsampler(feat_dim=latent_dim, up_factor=2)
        self.up2 = Upsampler(feat_dim=latent_dim, up_factor=2)
        self.up3 = Upsampler(feat_dim=latent_dim, up_factor=2)

    def forward(self, partial):
        """Complete a batch of partial clouds.

        Args:
            partial: Tensor (B, N, 3).

        Returns:
            List ``[coarse, fine1, fine2, fine3]`` of point sets ordered from
            the coarsest to the finest resolution.
        """
        # Only the global descriptor is used; per-point features are dropped.
        global_feat = self.encoder(partial)[0]

        outputs = [self.seed_generator(global_feat)]
        for upsampler in (self.up1, self.up2, self.up3):
            # Each stage refines the output of the previous one.
            outputs.append(upsampler(outputs[-1], global_feat))
        return outputs

In [None]:
def remove_padding(padded_pc):
    """Strip all-zero rows from every cloud of a padded batch.

    A row counts as padding when the sum of its absolute coordinates is not
    above 1e-6. A cloud that would end up empty keeps its first row so that
    no entry of the result has zero points.

    Args:
        padded_pc: Iterable of (N, 3) tensors, e.g. a (B, N, 3) tensor.

    Returns:
        List of tensors, one per cloud, each holding the non-padding rows.
    """
    def _strip(pc):
        keep = pc.abs().sum(dim=1) > 1e-6
        kept = pc[keep]
        return kept if len(kept) > 0 else pc[:1]

    return [_strip(pc) for pc in padded_pc]


def collate_fn_remove_padding(batch):
    """Collate ``(padded_partial, complete)`` pairs for the DataLoader.

    Complete clouds are stacked into one tensor. Partial clouds are stacked,
    then un-padded with ``remove_padding``, so they come back as a list of
    tensors that may differ in length.

    Args:
        batch: Sequence of ``(padded_partial, complete)`` tensor pairs.

    Returns:
        Tuple ``(partials_list, completes)``.
    """
    padded_items, complete_items = zip(*batch)

    completes = torch.stack(complete_items)
    return remove_padding(torch.stack(padded_items)), completes


def chamfer_distance(pred, gt):
    """Symmetric squared Chamfer distance, averaged over the batch.

    For each sample, the mean squared distance from every predicted point to
    its nearest ground-truth point is added to the mean squared distance from
    every ground-truth point to its nearest predicted point. Samples are
    processed one at a time, so only one sample's pairwise distance tensor
    exists at once.

    Args:
        pred: Tensor (B, P, 3).
        gt: Tensor (B, G, 3).

    Returns:
        Scalar tensor: the per-sample distances averaged over B.
    """
    num_samples = pred.size(0)
    running = 0

    for i in range(num_samples):
        # (1, P, 1, 3) - (1, 1, G, 3) -> squared distances of shape (1, P, G)
        sq_dist = (pred[i:i+1, :, None, :] - gt[i:i+1, None, :, :]).pow(2).sum(dim=3)

        pred_to_gt = sq_dist.min(dim=2).values
        gt_to_pred = sq_dist.min(dim=1).values

        running += pred_to_gt.mean() + gt_to_pred.mean()

    return running / num_samples


def multi_scale_loss(pred_list, gt):
    """Weighted sum of Chamfer distances over the prediction scales.

    The predictions are paired in order with the weights 2.0, 1.5, 1.0 and
    0.5, so the first (coarsest) entry counts most. Entries beyond the
    fourth are ignored.

    Args:
        pred_list: Sequence of predicted point sets, one per scale.
        gt: Ground-truth point set compared against every scale.

    Returns:
        The weighted total loss.
    """
    weights = [2.0, 1.5, 1.0, 0.5]
    return sum(w * chamfer_distance(pred, gt) for pred, w in zip(pred_list, weights))

In [None]:
def train_mini_crapcn(
    model,
    train_loader,
    val_loader,
    num_epochs=200,
    learning_rate=0.0001,
    checkpoint_dir='/content/drive/MyDrive/checkpoints',
    device='cuda',
    save_frequency=10,
    val_frequency=5
):
    """Train the completion model with Adam and a step LR schedule.

    Each batch is a list of variable-length partial clouds, so samples are
    fed through the model one at a time. The loss of each sample is divided
    by the batch size and back-propagated immediately; gradients accumulate
    in ``.grad`` and one optimizer step is taken per batch. This yields the
    gradient of the batch-mean loss while keeping only one sample's autograd
    graph alive at a time.

    Args:
        model: Network returning a list of multi-scale predictions.
        train_loader: Loader yielding ``(partials_list, completes)``. If its
            dataset has ``set_epoch``, it is called at the start of every
            epoch.
        val_loader: Loader for validation, or ``None`` to skip validation.
        num_epochs: Number of epochs to train.
        learning_rate: Initial Adam learning rate (halved every 50 epochs).
        checkpoint_dir: Directory for checkpoints (created if missing).
        device: Device the model and data are moved to.
        save_frequency: Save a periodic checkpoint every this many epochs.
        val_frequency: Run validation every this many epochs.

    Returns:
        Tuple ``(train_losses, val_losses)``: mean training loss per epoch and
        the validation loss for every validation run.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    print(f"{'='*60}")
    print("Training MiniCRAPCN")
    print(f"Device: {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    print(f"Checkpoint dir: {checkpoint_dir}")
    print(f"{'='*60}\n")

    for epoch in range(num_epochs):
        if hasattr(train_loader.dataset, 'set_epoch'):
            train_loader.dataset.set_epoch(epoch)
            current_corruption = train_loader.dataset.get_current_corruption_rate()
            print(f"Epoch {epoch+1}/{num_epochs} - Corruption Rate: {current_corruption:.2%}")

        model.train()
        epoch_loss = 0
        epoch_start = time.time()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, (partials_list, completes) in enumerate(pbar):
            completes = completes.to(device)
            batch_size = len(partials_list)
            if batch_size == 0:
                # Nothing to learn from; also avoids dividing by zero below.
                continue

            optimizer.zero_grad()
            batch_loss = torch.zeros((), device=device)

            for i in range(batch_size):
                partial = partials_list[i].unsqueeze(0).to(device)
                complete = completes[i:i+1]
                pred_list = model(partial)

                # Scale first so the accumulated gradients equal those of the
                # batch-mean loss, then free this sample's graph right away.
                sample_loss = multi_scale_loss(pred_list, complete) / batch_size
                sample_loss.backward()
                batch_loss += sample_loss.detach()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Single host sync per batch instead of one per sample.
            batch_loss_value = batch_loss.item()
            epoch_loss += batch_loss_value

            pbar.set_postfix({
                'loss': f'{batch_loss_value:.6f}',
                'avg_loss': f'{epoch_loss/(batch_idx+1):.6f}'
            })

        avg_train_loss = epoch_loss / max(len(train_loader), 1)
        train_losses.append(avg_train_loss)
        epoch_time = time.time() - epoch_start

        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.6f} - Time: {epoch_time:.1f}s - LR: {scheduler.get_last_lr()[0]:.6f}")

        if (epoch + 1) % val_frequency == 0 and val_loader is not None:
            val_loss = validate(model, val_loader, device)
            val_losses.append(val_loss)
            print(f"Validation Loss: {val_loss:.6f}")

            # best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_path = os.path.join(checkpoint_dir, 'best_model.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Stored so training can resume from the best checkpoint
                    # just like from a periodic one.
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_loss': val_loss,
                    'train_loss': avg_train_loss
                }, save_path)
                print(f"✓ Saved best model (val_loss: {val_loss:.6f})")

        if (epoch + 1) % save_frequency == 0:
            save_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'scheduler_state_dict': scheduler.state_dict()
            }, save_path)
            print("✓ Saved checkpoint")

        scheduler.step()
        print()

    print(f"{'='*60}")
    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.6f}")
    print(f"{'='*60}")

    return train_losses, val_losses


def validate(model, val_loader, device):
    """Compute the mean multi-scale loss over a validation loader.

    The model is switched to eval mode and left there; gradients are
    disabled. Each batch contributes the mean loss of its samples, and the
    result is the mean of those batch values.

    Args:
        model: Network returning a list of multi-scale predictions.
        val_loader: Loader yielding ``(partials_list, completes)``.
        device: Device the inputs are moved to.

    Returns:
        Average batch loss as a Python float.
    """
    model.eval()
    running = 0

    with torch.no_grad():
        for partials_list, completes in val_loader:
            completes = completes.to(device)
            num_samples = len(partials_list)

            batch_total = 0
            # Partials differ in length, so they are evaluated one by one.
            for i, partial in enumerate(partials_list):
                preds = model(partial.unsqueeze(0).to(device))
                batch_total += multi_scale_loss(preds, completes[i:i+1])

            running += (batch_total / num_samples).item()

    return running / len(val_loader)

In [None]:
# Paths: mount Google Drive and make sure the dataset is present on the VM's
# local disk before any dataset object is built.

drive.mount('/content/drive')

# Archive on Drive, its temporary local copy, the directory checked to decide
# whether the dataset is already available, and the extraction target.
# NOTE(review): local_data_root is presumably the folder created by extracting
# the archive under extraction_path -- confirm against the zip layout.
drive_zip_path = "/content/drive/MyDrive/ADONestDataset/data.zip"
local_zip_path = "/content/data.zip"
local_data_root = "/content/data/modelnet40_ply_hdf5_2048"
extraction_path = "/content/"

print("Verificando o dataset...")
# Only the existence of the directory is checked, so an interrupted extraction
# would be treated as a complete dataset on the next run.
if not os.path.exists(local_data_root):
    print(f"Dataset não encontrado em {local_data_root}.")

    # Copy the archive off Drive first, then extract from the local copy.
    print(f"Copiando {drive_zip_path} para a VM local...")
    shutil.copy(drive_zip_path, local_zip_path)
    print("Cópia do ZIP concluída.")

    print(f"Extraindo {local_zip_path} para {extraction_path}...")
    # IPython shell escape; {name} is interpolated from the Python namespace.
    !unzip -q {local_zip_path} -d {extraction_path}
    print("Extração concluída.")

    # The local zip is no longer needed once extracted.
    print(f"Removendo arquivo ZIP local: {local_zip_path}")
    os.remove(local_zip_path)
    print("Limpeza concluída.")
else:
    print(f"Dataset já existe em {local_data_root}")

# File lists, given relative to local_data_root.
train_file_list = "train_files.txt"
val_file_list = "test_files.txt"

In [None]:
# Hyperparameters

BATCH_SIZE = 128 + 32  # 160 samples per batch
# Use the GPU when one is visible; DEVICE is the string, device the torch object.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(DEVICE)
# NOTE(review): presumably the intended DataLoader worker count -- confirm it
# matches the num_workers actually passed to the loaders.
NUM_WORKERS = 4
# NOTE(review): presumably the fixed corruption rate for the datasets --
# confirm it matches the corruption_rate values passed to ModelNet40.
CORRUPTION_RATE = 0.5
NUM_EPOCHS = 200
LEARNING_RATE = 0.0001
NUM_POINTS = 2048  # rows of each zero-padded partial cloud
CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints_new_v1_2'

In [None]:
# Instantiate datasets

train_dataset = ModelNet40(
    root_dir=local_data_root,
    file_list='train_files.txt',
    mode='encoder',
    corruption_rate=0.5,
    num_points=NUM_POINTS,
    use_curriculum=True,
    use_augmentation=True
)

# Validation uses a fixed corruption rate and no augmentation so that its
# loss is comparable across epochs.
val_dataset_full = ModelNet40(
    root_dir=local_data_root,
    file_list='test_files.txt',
    mode='encoder',
    corruption_rate=0.7,
    num_points=NUM_POINTS,
    use_curriculum=False,
    use_augmentation=False
)

# Validation subset. VAL_SIZE caps the number of validation samples
# (None = use every sample); the subset is drawn with a fixed seed so the cell
# gives the same split after a kernel restart.
VAL_SIZE = None
VAL_SPLIT_SEED = 0

val_size = len(val_dataset_full) if VAL_SIZE is None else min(VAL_SIZE, len(val_dataset_full))
split_rng = np.random.default_rng(VAL_SPLIT_SEED)
# Sorted so the subset keeps the dataset's own order; anything that numbers
# samples by loader position then maps back to stable dataset indices.
val_indices = np.sort(split_rng.permutation(len(val_dataset_full))[:val_size]).tolist()

val_dataset = Subset(val_dataset_full, val_indices)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Device: {DEVICE}")

In [None]:
# DataLoaders
# Both loaders share everything except shuffling: training batches are
# reshuffled every epoch, validation keeps the dataset order. The custom
# collate function returns the partial clouds as a list of un-padded tensors.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=2,
    collate_fn=collate_fn_remove_padding,
    pin_memory=True,
)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

for summary_line in (
    f"Train dataset size: {len(train_dataset)}",
    f"Validation dataset size: {len(val_dataset)}",
    f"Device: {DEVICE}",
):
    print(summary_line)

In [None]:
# Clear cache: run Python's garbage collector, then ask PyTorch to release the
# cached GPU memory it is no longer using, before the model is created.

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Build a fresh (untrained) model.
model = MiniCRAPCN(num_points=NUM_POINTS)

# Train
# Returns the per-epoch training losses and the validation losses of every
# validation run. With these settings a checkpoint is written every 2 epochs
# and validation runs every 5 epochs.
train_losses, val_losses = train_mini_crapcn(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    checkpoint_dir=CHECKPOINT_DIR,
    device=device,
    save_frequency=2,
    val_frequency=5
)

In [None]:
# Plot results

# Must match the val_frequency passed to train_mini_crapcn: validation runs
# after epochs VAL_FREQUENCY, 2 * VAL_FREQUENCY, ...
VAL_FREQUENCY = 5

try:
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 5))

    # Epochs are numbered from 1, like the training log.
    train_epochs = range(1, len(train_losses) + 1)
    ax.plot(train_epochs, train_losses, label='Train Loss')
    if val_losses:
        # The i-th validation run happened after epoch (i + 1) * VAL_FREQUENCY.
        val_epochs = [(i + 1) * VAL_FREQUENCY for i in range(len(val_losses))]
        ax.plot(val_epochs, val_losses, label='Val Loss', marker='o')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.set_title('Training Progress')
    ax.grid(True)
    fig.savefig(os.path.join(CHECKPOINT_DIR, 'training_curve.png'))
    plt.show()
    print(f"Saved training curve to {CHECKPOINT_DIR}/training_curve.png")
except ImportError:
    print("Could not plot (matplotlib not available or not in notebook)")
except Exception as exc:
    # Plotting stays best-effort, but the real cause is reported instead of
    # being hidden behind the "matplotlib not available" message.
    print(f"Could not plot: {exc!r}")

In [None]:
def save_ply(points: np.ndarray, filepath: str):
    """Write a point cloud to a binary little-endian PLY file.

    Only vertex positions are stored, as three 32-bit floats (x, y, z) per
    point. The payload is always encoded little-endian, matching the header,
    regardless of the host's byte order.

    Args:
        points: Array-like of shape (N, C) with C >= 3; the first three
            columns are written.
        filepath: Destination path; an existing file is overwritten.

    Raises:
        ValueError: If ``points`` is not 2-D with at least three columns.
    """
    pts = np.asarray(points)
    if pts.ndim != 2 or pts.shape[1] < 3:
        raise ValueError(f"points must have shape (N, >=3), got {pts.shape}")

    # '<f4' pins the byte order to what the header declares.
    xyz = np.ascontiguousarray(pts[:, :3], dtype='<f4')

    header = (
        "ply\n"
        "format binary_little_endian 1.0\n"
        f"element vertex {xyz.shape[0]}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "end_header\n"
    )

    with open(filepath, 'wb') as f:
        f.write(header.encode('ascii'))
        # One bulk write instead of a struct.pack call per point.
        f.write(xyz.tobytes())


def inference_to_zip(
    checkpoint_path: str,
    test_loader: torch.utils.data.DataLoader,
    output_zip: str = 'predictions_mini_crapcn.zip',
    device: str = 'cuda',
    use_finest: bool = True,
    save_all_scales: bool = False
):
    """Run the model over a loader and pack the predictions into a zip.

    Each sample is written as ``prediction_<index>.ply`` (or one file per
    scale when ``save_all_scales`` is set), where the index counts samples in
    the order the loader yields them. Files are staged in a scratch directory
    that is emptied before use and removed afterwards, even on failure.

    Args:
        checkpoint_path: Checkpoint containing ``'model_state_dict'``.
        test_loader: Loader yielding either a tensor batch of partial clouds,
            or a tuple/list whose first element is such a tensor or a list of
            per-sample tensors.
        output_zip: Path of the zip archive to create.
        device: Device used for inference.
        use_finest: When saving one scale, pick the last (finest) prediction
            if true, otherwise the first (coarsest).
        save_all_scales: Save all four scales for every sample.

    Returns:
        Number of samples processed.
    """
    model = MiniCRAPCN(num_points=2048).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    epoch = checkpoint.get('epoch', 'unknown')
    val_loss = checkpoint.get('val_loss', checkpoint.get('train_loss', 'unknown'))
    print(f" Epoch: {epoch}")
    print(f"  Checkpoint loss: {val_loss}")

    temp_dir = Path('temp_ply_crapcn')
    # Leftovers from an interrupted run would otherwise end up in the new zip.
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    temp_dir.mkdir()

    sample_idx = 0
    scale_names = ['coarse_256', 'fine1_512', 'fine2_1024', 'fine3_2048']

    print(f"  Saving {'all scales' if save_all_scales else 'finest scale only'}")

    try:
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Processing"):
                # The partial clouds are the batch itself or its first element.
                first = batch[0] if isinstance(batch, (tuple, list)) else batch
                if isinstance(first, list):
                    partials_list = first
                else:
                    partials = first.to(device)
                    partials_list = [partials[i] for i in range(partials.shape[0])]

                for partial in partials_list:
                    pred_list = model(partial.unsqueeze(0).to(device))

                    if save_all_scales:
                        for pred, scale_name in zip(pred_list, scale_names):
                            pred_file = temp_dir / f'prediction_{sample_idx:06d}_{scale_name}.ply'
                            save_ply(pred[0].cpu().numpy(), str(pred_file))
                    else:
                        chosen = pred_list[-1] if use_finest else pred_list[0]
                        pred_file = temp_dir / f'prediction_{sample_idx:06d}.ply'
                        save_ply(chosen[0].cpu().numpy(), str(pred_file))

                    sample_idx += 1

        with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
            ply_files = sorted(temp_dir.glob('*.ply'))
            for ply_file in tqdm(ply_files, desc="Zipping"):
                zipf.write(ply_file, ply_file.name)
    finally:
        # Always remove the scratch directory, including after an error.
        shutil.rmtree(temp_dir, ignore_errors=True)

    zip_size = Path(output_zip).stat().st_size / (1024**2)
    files_per_sample = 4 if save_all_scales else 1
    print(f"\n Inference complete!")
    print(f"  Samples processed: {sample_idx}")
    print(f"  Files created: {sample_idx * files_per_sample}")
    print(f"  Output: {output_zip}")
    print(f"  Zip size: {zip_size:.1f} MB")

    return sample_idx


def quick_visual_check(
    checkpoint_path: str,
    test_loader: torch.utils.data.DataLoader,
    num_samples: int = 5,
    device: str = 'cuda'
):
    """Print summary statistics of the predictions for a few samples.

    Loads the checkpoint, takes the first batch of ``test_loader`` and, for up
    to ``num_samples`` partial clouds, prints the input shape and the mean and
    standard deviation of each of the four predicted scales. A warning is
    printed when the finest prediction contains NaN or Inf values.

    Args:
        checkpoint_path: Checkpoint containing ``'model_state_dict'``.
        test_loader: Loader yielding a tensor batch, or a tuple/list whose
            first element is a tensor batch or a list of per-sample tensors.
        num_samples: Maximum number of samples inspected.
        device: Device used for inference.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device) if False else None
    model = MiniCRAPCN(num_points=2048).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(" Quick visual check (first few samples):\n")

    # Line prefixes for the four scales, coarsest first.
    scale_labels = (
        "  Coarse (256):  ",
        "  Fine1 (512):   ",
        "  Fine2 (1024):  ",
        "  Fine3 (2048):  ",
    )

    with torch.no_grad():
        batch = next(iter(test_loader))

        # The partial clouds are the batch itself or its first element.
        first = batch[0] if isinstance(batch, (tuple, list)) else batch
        if isinstance(batch, (tuple, list)) and isinstance(first, list):
            partials_list = first[:num_samples]
        else:
            partials = first.to(device)[:num_samples]
            partials_list = [partials[i] for i in range(min(num_samples, partials.shape[0]))]

        for sample_no, partial in enumerate(partials_list, start=1):
            partial = partial.unsqueeze(0).to(device)
            pred_list = model(partial)

            print(f"Sample {sample_no}:")
            print(f"  Input shape: {partial.shape} ({partial.shape[1]} points)")
            for level, label in enumerate(scale_labels):
                cloud = pred_list[level][0]
                print(f"{label}mean={cloud.mean():.3f}, std={cloud.std():.3f}")

            finest = pred_list[-1]
            if torch.isnan(finest).any():
                print("WARNING: NaN values detected!")
            if torch.isinf(finest).any():
                print("WARNING: Inf values detected!")

            print()

In [None]:
# Inference

print("\n" + "="*60)
print("STEP 2: Full Inference")
print("="*60)
num_samples = inference_to_zip(
    # Training writes its best checkpoint to CHECKPOINT_DIR/best_model.pth, so
    # this path exists after a fresh Restart & Run All.
    checkpoint_path=os.path.join(CHECKPOINT_DIR, 'best_model.pth'),
    test_loader=val_loader,
    output_zip='predictions_mini_crapcn1_2.zip',
    # Same device choice as training, so the cell also runs without a GPU.
    device=DEVICE,
    use_finest=True,
    save_all_scales=False
)

print(f"\n All done! Processed {num_samples} samples")
