<a href="https://colab.research.google.com/github/fraco03/6D_pose/blob/pose_rgb/notebooks/pose_rgb/pose_rgb_pointnet_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import sys

# Clone the repository, or update it if it already exists
repo_url = "https://github.com/fraco03/6D_pose.git"
repo_dir = "/content/6D_pose"
branch = "pose_rgb"

if not os.path.exists(repo_dir):
    !git clone -b {branch} {repo_url}
    print(f"Cloned {repo_url}")
else:
    %cd {repo_dir}
    !git fetch origin
    !git checkout {branch}
    !git reset --hard origin/{branch}
    %cd ..
    print(f"Updated {repo_url}")

if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

%cd 6D_pose

Cloning into '6D_pose'...
remote: Enumerating objects: 349, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (14/14), done.
remote: Total 349 (delta 1), reused 4 (delta 1), pack-reused 334 (from 1)
Receiving objects: 100% (349/349), 5.57 MiB | 29.39 MiB/s, done.
Resolving deltas: 100% (167/167), done.
Cloned https://github.com/fraco03/6D_pose.git
/content/6D_pose


In [2]:
from google.colab import drive
from utils.load_data import mount_drive

mount_drive()

dataset_root = "/content/drive/MyDrive/Linemod_preprocessed"
print("\n✅ Setup complete!")
print(f"📁 Dataset path: {dataset_root}")

Mounted at /content/drive
✅ Drive mounted at /content/drive

✅ Setup complete!
📁 Dataset path: /content/drive/MyDrive/Linemod_preprocessed


In [3]:
!pip install plyfile

Collecting plyfile
  Downloading plyfile-1.1.3-py3-none-any.whl.metadata (43 kB)
Downloading plyfile-1.1.3-py3-none-any.whl (36 kB)
Installing collected packages: plyfile
Successfully installed plyfile-1.1.3


In [4]:
from src.pose_rgb.pointcloud_dataset import LineModPointCloudDataset
from src.pose_rgb.pointnet_model import PointNetPose
from src.pose_rgb.loss import AutomaticWeightedLoss
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"🔥 Using device: {DEVICE}")

🔥 Using device: cuda


## 📝 Note on the Loss Function

**Difference between the PointNet and RGB approaches:**

- **RGB (ResNet + TranslationNet)**:
  - Predicts `[dx, dy, log(z)]` → then pinhole projection → `[X, Y, Z]`
  - Uses `DisentangledTranslationLoss`: separates XY from Z
  - This makes sense because XY depend on the pinhole geometry, while Z is independent

- **PointNet**:
  - Predicts `[X, Y, Z]` **directly** from the point cloud
  - Uses a **unified** translation loss: treats X, Y, Z symmetrically
  - There is no pinhole projection, so separating XY from Z is not meaningful
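
To make the pinhole relationship concrete, here is a minimal sketch of how a pixel location plus a predicted log-depth maps back to `[X, Y, Z]`. The function names and the intrinsics are illustrative assumptions, not the repository's code, and the pixel `(u, v)` is assumed to already include the predicted offset `(dx, dy)`.

```python
import math

def project(point, fx, fy, cx, cy):
    """Pinhole projection of a camera-frame point [X, Y, Z] to a pixel (u, v)."""
    x, y, z = point
    return fx * x / z + cx, fy * y / z + cy

def backproject(u, v, log_z, fx, fy, cx, cy):
    """Invert the projection given a pixel location and a predicted log-depth."""
    z = math.exp(log_z)
    return (u - cx) * z / fx, (v - cy) * z / fy, z

# Illustrative intrinsics and point only; real values come from the dataset's camera matrix.
fx, fy, cx, cy = 572.4, 573.6, 325.3, 242.0
u, v = project((-0.10, -0.05, 1.0), fx, fy, cx, cy)
recovered = backproject(u, v, math.log(1.0), fx, fy, cx, cy)
```

Because X and Y are both multiplied by the recovered depth, an error in `log(z)` also shifts X and Y, which is one reason an RGB pipeline may want to supervise XY and Z separately. A network that regresses `[X, Y, Z]` directly has no such coupling.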

## ⚡ Performance Optimizations

**Optimizations applied to speed up training:**

1. **Fewer points**: 1024 → 512 points per point cloud (~2x speed-up)
2. **Cached YAML files**: `linemod_config` caches `info.yml` and `gt.yml` instead of reopening them on every access (~10-20x speed-up!)
   - **Critical on Google Drive**: I/O latency is very high, so caching is essential
3. **Torch sampling**: `torch.randperm` instead of `np.random.choice` (~1.5x speed-up)
4. **Mixed precision**: automatic FP16/FP32 with `torch.amp` (~1.5x speed-up)
5. **Batch size**: increased from 32 → 64 for better GPU utilization
6. **DataLoader**: `num_workers=2` + `pin_memory=True` for parallel I/O

**Estimated total speed-up: ~20-30x** over the initial version! 🚀

**Note**: The first batches can take longer while all the YAML files are loaded and cached; after that, loading becomes much faster.
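
The caching idea in item 2 can be sketched with `functools.lru_cache`. This is an assumption about the general pattern, not the actual `linemod_config` implementation; the `load_annotations` helper is hypothetical, and JSON stands in for YAML so the example needs only the standard library.

```python
import functools
import json
import tempfile
from pathlib import Path

@functools.lru_cache(maxsize=None)
def load_annotations(path: str) -> dict:
    """Parse an annotation file once; later calls with the same path hit the cache."""
    with open(path, "r") as f:
        return json.load(f)  # the real files are YAML; JSON keeps this sketch stdlib-only

with tempfile.TemporaryDirectory() as tmp:
    path = str(Path(tmp) / "gt.json")
    Path(path).write_text(json.dumps({"0": {"obj_id": 1}}))
    first = load_annotations(path)    # reads from disk
    second = load_annotations(path)   # served from memory, no file access
    stats = load_annotations.cache_info()
```

On a high-latency filesystem such as a mounted Google Drive, avoiding the repeated open and parse per sample is where most of the saving would come from.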

In [5]:
# Build the point cloud datasets
# use_rgb=True  -> point cloud with 6 channels [x, y, z, r, g, b]
# use_rgb=False -> point cloud with 3 channels [x, y, z]

train_dataset = LineModPointCloudDataset(
    root_dir=dataset_root,
    split='train',
    num_points=512,  # Reduced from 1024 for speed
    use_rgb=True     # Include RGB channels
)

test_dataset = LineModPointCloudDataset(
    root_dir=dataset_root,
    split='test',
    num_points=512,
    use_rgb=True
)

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

✅ LineModConfig initialized: /content/drive/.shortcut-targets-by-id/1sm0RTmd3Q00Uw99X5cErD8SxrnPrGS6v/Linemod_preprocessed
📊 Loaded 2373 samples for train split
📊 Loaded 13407 samples for test split
Train samples: 2373
Test samples: 13407
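
The `num_points=512` argument implies that every cloud is reduced to a fixed size. How `LineModPointCloudDataset` does this is not shown in the notebook; the following is a sketch of fixed-size sampling with `torch.randperm`, where the `subsample_points` helper and the with-replacement fallback are assumptions.

```python
import torch

def subsample_points(points: torch.Tensor, num_points: int) -> torch.Tensor:
    """Return exactly `num_points` rows of an (N, C) point cloud, chosen at random."""
    n = points.shape[0]
    if n >= num_points:
        idx = torch.randperm(n)[:num_points]       # sample without replacement
    else:
        idx = torch.randint(0, n, (num_points,))   # too few points: sample with replacement
    return points[idx]

cloud = torch.rand(2000, 6)             # dummy [x, y, z, r, g, b] cloud
sampled = subsample_points(cloud, 512)  # shape (512, 6)
```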


In [6]:
# Inspect one sample
sample = train_dataset[0]

print("Sample keys:", sample.keys())
print(f"Point cloud shape: {sample['point_cloud'].shape}")  # (512, 6)
print(f"Rotation shape: {sample['rotation'].shape}")        # (4,)
print(f"Translation shape: {sample['translation'].shape}")  # (3,)
print(f"\nRotation (quat): {sample['rotation']}")
print(f"Translation (m): {sample['translation']}")

Sample keys: dict_keys(['point_cloud', 'rotation', 'translation', 'object_id', 'img_id', 'cam_K', 'bbox'])
Point cloud shape: torch.Size([512, 6])
Rotation shape: (4,)
Translation shape: torch.Size([3])

Rotation (quat): [ 0.33261785  0.64730227  0.6364495  -0.2555329 ]
Translation (m): tensor([-0.1036, -0.0498,  1.0251])
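
The rotation is stored as a 4-element quaternion. As a sketch of how it relates to a rotation matrix, the conversion below assumes the components are ordered `[w, x, y, z]`; the notebook does not state the convention, so check the dataset code before relying on it. `quat_to_rotmat` is a hypothetical helper.

```python
import math

def quat_to_rotmat(q):
    """Convert a quaternion, assumed ordered [w, x, y, z], to a 3x3 rotation matrix."""
    w, x, y, z = q
    norm = math.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / norm, x / norm, y / norm, z / norm  # guard against drift from unit length
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y)],
        [2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y)],
    ]

q = [0.33261785, 0.64730227, 0.6364495, -0.2555329]  # rotation printed for train_dataset[0]
R = quat_to_rotmat(q)
```

The printed quaternion has unit norm to within rounding, so the normalization step is only a safeguard here.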


In [7]:
# DataLoaders, tuned for throughput
train_loader = DataLoader(
    train_dataset,
    batch_size=64,   # Increased from 32 for better GPU utilization
    shuffle=True,
    num_workers=2,   # Parallel data-loading workers
    pin_memory=True  # Speeds up CPU -> GPU transfers
)

test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

In [8]:
import os
import json
from datetime import datetime
from itertools import islice
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler

# ==========================================
# HYPERPARAMETERS
# ==========================================
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
USE_MIXED_PRECISION = True  # Mixed precision for speed-up

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
CHECKPOINT_DIR = f'/content/drive/MyDrive/runs/pointnet_{timestamp}'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Initialize PointNet Model
# input_channels=6 because we use [x,y,z,r,g,b]
model = PointNetPose(input_channels=6, use_batch_norm=True).to(DEVICE)

# Loss for PointNet
criterion = AutomaticWeightedLoss(use_disentangled=False).to(DEVICE)

# Optimizer
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(criterion.parameters()),
    lr=LEARNING_RATE
)

# Mixed Precision Scaler
scaler = GradScaler('cuda') if USE_MIXED_PRECISION else None

train_losses = []
val_losses = []
best_val_loss = float('inf')
best_epoch = 0

print(f"\n🔥 STARTING POINTNET TRAINING on {DEVICE}...")
print(f"📁 Checkpoints: {CHECKPOINT_DIR}")
print("⚙️  Loss mode: Unified Translation (no XY/Z separation)")
print(f"⚡ Mixed Precision: {USE_MIXED_PRECISION}")


🔥 STARTING POINTNET TRAINING on cuda...
📁 Checkpoints: /content/drive/MyDrive/runs/pointnet_20251217_132025
⚙️  Loss mode: Unified Translation (no XY/Z separation)
⚡ Mixed Precision: True
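
`AutomaticWeightedLoss` lives in the repository and its code is not shown in this notebook. Since its parameters are handed to the optimizer, it presumably learns how to balance the rotation and translation terms. One common way to do that is homoscedastic uncertainty weighting, sketched here as an assumption; the class name and its parameters are hypothetical.

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Combine two task losses with learnable log-variance weights."""

    def __init__(self):
        super().__init__()
        # One learnable log-variance per task, initialised so both weights start at 1.
        self.log_var_rot = nn.Parameter(torch.zeros(()))
        self.log_var_trans = nn.Parameter(torch.zeros(()))

    def forward(self, loss_rot, loss_trans):
        # exp(-s) scales each task loss; adding s penalises inflating the variance.
        return (torch.exp(-self.log_var_rot) * loss_rot + self.log_var_rot
                + torch.exp(-self.log_var_trans) * loss_trans + self.log_var_trans)

criterion = UncertaintyWeightedLoss()
total = criterion(torch.tensor(0.5), torch.tensor(0.1))
total.backward()  # gradients flow into the two log-variance parameters
```

With a formulation like this, the total can drop below an individual term once a log-variance becomes negative, so the combined value is not simply the sum of the rotation and translation losses.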


In [9]:
# ==========================================
# TRAINING LOOP (with Mixed Precision)
# ==========================================
for epoch in range(NUM_EPOCHS):

    # --- TRAIN PHASE ---
    model.train()
    running_train_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")

    for batch in pbar:
        # Move to device
        point_clouds = batch['point_cloud'].to(DEVICE, non_blocking=True)  # (B, N, 6)
        gt_rot = batch['rotation'].to(DEVICE, non_blocking=True)           # (B, 4)
        gt_trans = batch['translation'].to(DEVICE, non_blocking=True)      # (B, 3) in meters

        optimizer.zero_grad()

        # Mixed Precision Forward + Backward
        if USE_MIXED_PRECISION:
            with autocast('cuda'):
                pred_rot, pred_trans = model(point_clouds)
                loss, l_r, l_t, _ = criterion(pred_rot, gt_rot, pred_trans, gt_trans)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            pred_rot, pred_trans = model(point_clouds)
            loss, l_r, l_t, _ = criterion(pred_rot, gt_rot, pred_trans, gt_trans)
            loss.backward()
            optimizer.step()

        running_train_loss += loss.item()

        pbar.set_postfix({
            'L_Tot': f"{loss.item():.2f}",
            'L_Rot': f"{l_r.item():.2f}",
            'L_Trans': f"{l_t.item():.3f}"
        })

    avg_train_loss = running_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # --- VALIDATION PHASE ---
    model.eval()
    running_val_loss = 0.0
    val_batches_limit = 50
    count_batches = 0

    with torch.no_grad():
        val_iterator = islice(test_loader, val_batches_limit)
        val_pbar = tqdm(val_iterator, total=val_batches_limit, desc="Validating")

        for batch in val_pbar:
            point_clouds = batch['point_cloud'].to(DEVICE, non_blocking=True)
            gt_rot = batch['rotation'].to(DEVICE, non_blocking=True)
            gt_trans = batch['translation'].to(DEVICE, non_blocking=True)

            if USE_MIXED_PRECISION:
                with autocast('cuda'):
                    pred_rot, pred_trans = model(point_clouds)
                    loss, _, _, _ = criterion(pred_rot, gt_rot, pred_trans, gt_trans)
            else:
                pred_rot, pred_trans = model(point_clouds)
                loss, _, _, _ = criterion(pred_rot, gt_rot, pred_trans, gt_trans)

            running_val_loss += loss.item()
            count_batches += 1

    avg_val_loss = running_val_loss / count_batches if count_batches > 0 else 0.0
    val_losses.append(avg_val_loss)

    # --- REPORT & SAVE ---
    print(f"📊 Epoch {epoch+1}: Train={avg_train_loss:.4f} | Val={avg_val_loss:.4f}")

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1

        torch.save({
            'epoch': epoch + 1,  # 1-based, consistent with best_epoch and the final checkpoint
            'model_state_dict': model.state_dict(),
            'criterion_state_dict': criterion.state_dict(),
            'val_loss': best_val_loss
        }, os.path.join(CHECKPOINT_DIR, "best_model.pth"))

        print(f"🏆 New Best Model! (Loss: {best_val_loss:.4f})")

    # Save a checkpoint after the final epoch
    if (epoch + 1) == NUM_EPOCHS:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'criterion_state_dict': criterion.state_dict(),
        }, os.path.join(CHECKPOINT_DIR, f"checkpoint_ep{epoch+1}.pth"))

print("\n🎉 TRAINING COMPLETE!")

Epoch 1/50 [Train]: 100%|██████████| 38/38 [18:01<00:00, 28.45s/it, L_Tot=0.56, L_Rot=0.34, L_Trans=0.226]
Validating: 100%|██████████| 50/50 [24:19<00:00, 29.19s/it]
📊 Epoch 1: Train=0.8885 | Val=0.2928
🏆 New Best Model! (Loss: 0.2928)

Epoch 2/50 [Train]: 100%|██████████| 38/38 [01:55<00:00,  3.04s/it, L_Tot=0.52, L_Rot=0.49, L_Trans=0.040]
Validating: 100%|██████████| 50/50 [01:57<00:00,  2.34s/it]
📊 Epoch 2: Train=0.4411 | Val=0.2863
🏆 New Best Model! (Loss: 0.2863)

Epoch 3/50 [Train]: 100%|██████████| 38/38 [01:54<00:00,  3.01s/it, L_Tot=0.49, L_Rot=0.46, L_Trans=0.040]
Validating: 100%|██████████| 50/50 [01:49<00:00,  2.20s/it]
📊 Epoch 3: Train=0.3790 | Val=0.2810
🏆 New Best Model! (Loss: 0.2810)

Epoch 4/50 [Train]: 100%|██████████| 38/38 [01:53<00:00,  2.98s/it, L_Tot=0.39, L_Rot=0.39, L_Trans=0.025]
Validating: 100%|██████████| 50/50 [01:46<00:00,  2.14s/it]
📊 Epoch 4: Train=0.3558 | Val=0.2652
🏆 New Best Model! (Loss: 0.2652)

Epoch 5/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.22, L_Rot=0.23, L_Trans=0.022]
Validating: 100%|██████████| 50/50 [01:46<00:00,  2.12s/it]
📊 Epoch 5: Train=0.3371 | Val=0.2618
🏆 New Best Model! (Loss: 0.2618)

Epoch 6/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.37, L_Rot=0.35, L_Trans=0.058]
Validating: 100%|██████████| 50/50 [01:46<00:00,  2.14s/it]
📊 Epoch 6: Train=0.3254 | Val=0.2487
🏆 New Best Model! (Loss: 0.2487)

Epoch 7/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.54, L_Rot=0.57, L_Trans=0.016]
Validating: 100%|██████████| 50/50 [01:47<00:00,  2.15s/it]
📊 Epoch 7: Train=0.3222 | Val=0.2417
🏆 New Best Model! (Loss: 0.2417)

Epoch 8/50 [Train]: 100%|██████████| 38/38 [01:53<00:00,  2.98s/it, L_Tot=0.60, L_Rot=0.61, L_Trans=0.032]
Validating: 100%|██████████| 50/50 [01:42<00:00,  2.05s/it]
📊 Epoch 8: Train=0.3177 | Val=0.2352
🏆 New Best Model! (Loss: 0.2352)

Epoch 9/50 [Train]: 100%|██████████| 38/38 [01:51<00:00,  2.94s/it, L_Tot=0.37, L_Rot=0.40, L_Trans=0.020]
Validating: 100%|██████████| 50/50 [01:41<00:00,  2.03s/it]
📊 Epoch 9: Train=0.2984 | Val=0.2183
🏆 New Best Model! (Loss: 0.2183)

Epoch 10/50 [Train]: 100%|██████████| 38/38 [01:51<00:00,  2.94s/it, L_Tot=0.11, L_Rot=0.17, L_Trans=0.006]
Validating: 100%|██████████| 50/50 [01:42<00:00,  2.05s/it]
📊 Epoch 10: Train=0.2823 | Val=0.2252

Epoch 11/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.36, L_Rot=0.42, L_Trans=0.012]
Validating: 100%|██████████| 50/50 [01:42<00:00,  2.05s/it]
📊 Epoch 11: Train=0.2836 | Val=0.1998
🏆 New Best Model! (Loss: 0.1998)

Epoch 12/50 [Train]: 100%|██████████| 38/38 [01:51<00:00,  2.94s/it, L_Tot=0.15, L_Rot=0.22, L_Trans=0.008]
Validating: 100%|██████████| 50/50 [01:41<00:00,  2.02s/it]
📊 Epoch 12: Train=0.2662 | Val=0.1936
🏆 New Best Model! (Loss: 0.1936)

Epoch 13/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.95s/it, L_Tot=0.31, L_Rot=0.38, L_Trans=0.012]
Validating: 100%|██████████| 50/50 [01:42<00:00,  2.06s/it]
📊 Epoch 13: Train=0.2590 | Val=0.1683
🏆 New Best Model! (Loss: 0.1683)

Epoch 14/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.97s/it, L_Tot=0.55, L_Rot=0.62, L_Trans=0.007]
Validating: 100%|██████████| 50/50 [01:42<00:00,  2.05s/it]
📊 Epoch 14: Train=0.2588 | Val=0.1567
🏆 New Best Model! (Loss: 0.1567)

Epoch 15/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.95s/it, L_Tot=0.16, L_Rot=0.24, L_Trans=0.027]
Validating: 100%|██████████| 50/50 [01:43<00:00,  2.08s/it]
📊 Epoch 15: Train=0.2374 | Val=0.1516
🏆 New Best Model! (Loss: 0.1516)

Epoch 16/50 [Train]: 100%|██████████| 38/38 [01:51<00:00,  2.94s/it, L_Tot=0.26, L_Rot=0.35, L_Trans=0.014]
Validating: 100%|██████████| 50/50 [01:41<00:00,  2.03s/it]
📊 Epoch 16: Train=0.2227 | Val=0.1392
🏆 New Best Model! (Loss: 0.1392)

Epoch 17/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.25, L_Rot=0.35, L_Trans=0.014]
Validating: 100%|██████████| 50/50 [01:42<00:00,  2.05s/it]
📊 Epoch 17: Train=0.2234 | Val=0.1587

Epoch 18/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.04, L_Rot=0.17, L_Trans=0.005]
Validating: 100%|██████████| 50/50 [01:41<00:00,  2.03s/it]
📊 Epoch 18: Train=0.2004 | Val=0.1101
🏆 New Best Model! (Loss: 0.1101)

Epoch 19/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.15, L_Rot=0.27, L_Trans=0.008]
Validating: 100%|██████████| 50/50 [01:43<00:00,  2.07s/it]
📊 Epoch 19: Train=0.1921 | Val=0.0962
🏆 New Best Model! (Loss: 0.0962)

Epoch 20/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.97s/it, L_Tot=0.14, L_Rot=0.27, L_Trans=0.003]
Validating: 100%|██████████| 50/50 [01:41<00:00,  2.04s/it]
📊 Epoch 20: Train=0.1838 | Val=0.0897
🏆 New Best Model! (Loss: 0.0897)

Epoch 21/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.09, L_Rot=0.22, L_Trans=0.013]
Validating: 100%|██████████| 50/50 [01:43<00:00,  2.06s/it]
📊 Epoch 21: Train=0.1674 | Val=0.0777
🏆 New Best Model! (Loss: 0.0777)

Epoch 22/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.97s/it, L_Tot=0.11, L_Rot=0.25, L_Trans=0.011]
Validating: 100%|██████████| 50/50 [01:41<00:00,  2.04s/it]
📊 Epoch 22: Train=0.1499 | Val=0.1140

Epoch 23/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.14, L_Rot=0.29, L_Trans=0.004]
Validating: 100%|██████████| 50/50 [01:43<00:00,  2.07s/it]
📊 Epoch 23: Train=0.1462 | Val=0.0697
🏆 New Best Model! (Loss: 0.0697)

Epoch 24/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.97s/it, L_Tot=0.22, L_Rot=0.36, L_Trans=0.011]
Validating: 100%|██████████| 50/50 [01:41<00:00,  2.04s/it]
📊 Epoch 24: Train=0.1220 | Val=0.0255
🏆 New Best Model! (Loss: 0.0255)

Epoch 25/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.96s/it, L_Tot=0.31, L_Rot=0.45, L_Trans=0.010]
Validating: 100%|██████████| 50/50 [01:43<00:00,  2.07s/it]
📊 Epoch 25: Train=0.1160 | Val=0.0624

Epoch 26/50 [Train]: 100%|██████████| 38/38 [01:52<00:00,  2.97s/it, L_Tot=0.19, L_Rot=0.35, L_Trans=0.004]
Validating: 100%|██████████| 50/50 [01:50<00:00,  2.22s/it]
📊 Epoch 26: Train=0.1065 | Val=0.0333

Epoch 27/50 [Train]:   5%|▌         | 2/38 [01:01<18:28, 30.79s/it, L_Tot=0.07, L_Rot=0.24, L_Trans=0.006]


KeyboardInterrupt: 
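
The run above was interrupted during epoch 27, and the loop only writes a non-best checkpoint after the final epoch. A sketch of saving and restoring a resumable checkpoint follows. The `nn.Linear` model is a stand-in for `PointNetPose`, the `save_checkpoint` and `load_checkpoint` helpers are hypothetical, and storing the optimizer state is an addition that the notebook's checkpoints do not include.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(6, 7)  # stand-in for PointNetPose, for illustration only
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def save_checkpoint(path, completed_epochs):
    """Store everything needed to continue training after an interruption."""
    torch.save({
        "epoch": completed_epochs,  # number of finished epochs (1-based)
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),  # not saved by the notebook
    }, path)

def load_checkpoint(path):
    """Restore model and optimizer, returning the index to pass to range()."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "last.pth")
    save_checkpoint(path, completed_epochs=26)
    saved_weight = model.weight.detach().clone()
    with torch.no_grad():
        model.weight.add_(1.0)  # simulate a fresh session with different weights
    start_epoch = load_checkpoint(path)
```

Training could then continue with `for epoch in range(start_epoch, NUM_EPOCHS)`. In the notebook's setup the criterion's state dict would need to be restored the same way, since it holds learnable parameters too.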

In [None]:
# Plot training history
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', marker='o', alpha=0.7)
plt.plot(val_losses, label='Validation Loss', marker='s', alpha=0.7)
if best_epoch > 0:
    plt.axvline(x=best_epoch-1, color='r', linestyle='--', alpha=0.5,
                label=f'Best Epoch ({best_epoch})')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('PointNet Training History')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(train_losses, label='Train Loss', marker='o', alpha=0.7)
plt.plot(val_losses, label='Validation Loss', marker='s', alpha=0.7)
if best_epoch > 0:
    plt.axvline(x=best_epoch-1, color='r', linestyle='--', alpha=0.5,
                label=f'Best Epoch ({best_epoch})')
plt.xlabel('Epoch')
plt.ylabel('Loss (log scale)')
plt.yscale('log')
plt.title('PointNet Training History (Log)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_DIR, 'training_history.png'), dpi=150)
plt.show()

print("\n📊 Training Statistics:")
print(f"   Best epoch: {best_epoch}")
print(f"   Best val loss: {best_val_loss:.6f}")
print(f"   Final train loss: {train_losses[-1]:.6f}")

# Save history
history = {
    'train_losses': [float(x) for x in train_losses],
    'val_losses': [float(x) for x in val_losses],
    'best_epoch': int(best_epoch),
    'best_val_loss': float(best_val_loss),
    'timestamp': timestamp
}

with open(os.path.join(CHECKPOINT_DIR, 'history.json'), 'w') as f:
    json.dump(history, f, indent=2)

## Visualize Point Cloud Sample

Plot the point cloud of a single training sample (index 100), with each point colored by its RGB channels.

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Choose sample
sample = train_dataset[100]
pc = sample['point_cloud'].numpy()  # (512, 6) [x, y, z, r, g, b]

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Plot 3D points with RGB colors
ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2],
           c=pc[:, 3:6], s=2, alpha=0.6)

ax.set_xlabel('X (m)')
ax.set_ylabel('Y (m)')
ax.set_zlabel('Z (m)')
ax.set_title(f'Point Cloud - Object {sample["object_id"]}')
plt.show()

print(f"Object ID: {sample['object_id']}")
print(f"Rotation (quat): {sample['rotation']}")
print(f"Translation (m): {sample['translation']}")