| **Name**  | **Surname** | **Student ID** | **Email**                                |
|-----------|-------------|----------------|------------------------------------------|
| Mattia    | Buzzoni     | 0001145667     | mattia.buzzoni@studio.unibo.it           |


# Project Work Deep Learning

## Project Goal

Build a flexible, end-to-end 2D-->3D reconstruction pipeline that transforms multiple rendered views of an object into a high‑fidelity 3D volumetric model. Leveraging the ShapeNet dataset’s paired 2D images and 3D voxel ground truth, this project will:

- **Encode** each view with a deep CNN (ResNet) to extract rich visual features.  
- **Fuse** backbone features at multiple spatial scales (top-down, per view) into a single latent representation that is then lifted to 3D.  
- **Decode** the latent code into a dense 32³ voxel grid, then apply a lightweight refinement module to sharpen object surfaces.  
- **Visualize** results interactively: render raw voxels and smooth meshes with Plotly for both qualitative inspection and quantitative evaluation.

The pipeline is category‑agnostic and configured through `get_config()`. In this experiment, we restrict it to the **Chair** category to demonstrate its capability on a single object class; any other ShapeNet category can be swapped in by changing `category` and the matching `model_ids`.  


# 1. Dataset Acquisition and Preparation

 Download and extract the ShapeNetRendering images and corresponding ShapeNetVox32 voxel grids. These provide multi-view 2D renders and ground-truth 3D occupancy grids.



*   **Source:** ShapeNet, an extensive repository of 3D models across hundreds of object categories.
*   **2D Views:** Pre-rendered images (ShapeNetRendering) provide 24 viewpoints per model, stored as PNG files inside a compressed tar archive (~5.4 GB).
*   **3D Voxels:** Corresponding voxel grids (ShapeNetVox32) give binary occupancy in a 32×32×32 grid, stored as .binvox files (~1.2 GB).

**Actions:**


1.   Download archives via `wget` from Stanford CVGL servers.
2.   Extract locally or from Google Drive using `tar`.
3.   Organize directories: `/content/ShapeNetRendering/` and `/content/ShapeNetVox32/`.
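
As a quick sanity check after extraction, the sketch below counts the PNG views per model for one category and keeps only models that have both a `rendering/` folder and a `model.binvox` file. The helper name `check_shapenet_layout` is hypothetical; the directory layout is the one described above.

```python
import os

def check_shapenet_layout(image_root, voxel_root, category):
    """Return {model_id: number_of_png_views} for models with renders AND voxels."""
    found = {}
    cat_dir = os.path.join(image_root, category)
    if not os.path.isdir(cat_dir):
        return found  # nothing extracted for this category
    for model_id in sorted(os.listdir(cat_dir)):
        render_dir = os.path.join(cat_dir, model_id, "rendering")
        voxel_file = os.path.join(voxel_root, category, model_id, "model.binvox")
        if os.path.isdir(render_dir) and os.path.isfile(voxel_file):
            found[model_id] = sum(
                fn.endswith(".png") for fn in os.listdir(render_dir)
            )
    return found
```

Calling `check_shapenet_layout("/content/ShapeNetRendering", "/content/ShapeNetVox32", "03001627")` on a complete extraction should list every chair model together with its view count; an empty dictionary points to a wrong root path or category ID.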

In [1]:
!pip install torch torchvision

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


If you want to download the ShapeNet dataset from scratch:

In [None]:
# Approx sizes (compressed):
# - ShapeNetRendering.tgz: ~5.4 GB
# - ShapeNetVox32.tgz: ~1.2 GB

!wget http://cvgl.stanford.edu/data2/ShapeNetRendering.tgz -O ShapeNetRendering.tgz
!wget http://cvgl.stanford.edu/data2/ShapeNetVox32.tgz -O ShapeNetVox32.tgz

# Unzip/tar the files
!tar -xvzf ShapeNetRendering.tgz
!tar -xvzf ShapeNetVox32.tgz

# After extraction:
#  - /content/ShapeNetRendering/...
#  - /content/ShapeNetVox32/...


In [1]:
# Install necessary packages
!pip install trimesh pillow



In [2]:
!tar -xvzf /content/drive/MyDrive/datasetVox/ShapeNetRendering.tgz
!tar -xvzf /content/drive/MyDrive/datasetVox/ShapeNetVox32.tgz

Output streaming truncated to the last 5000 lines.
ShapeNetVox32/02958343/92c882d35dfca864acee48fc4abca0f4/model.binvox
ShapeNetVox32/02958343/b098f1db2f190a71d61b6a34f3fd808c/
ShapeNetVox32/02958343/b098f1db2f190a71d61b6a34f3fd808c/model.binvox
ShapeNetVox32/02958343/3174a11023f0a6cdd9ebe2632c1ec249/
ShapeNetVox32/02958343/3174a11023f0a6cdd9ebe2632c1ec249/model.binvox
ShapeNetVox32/02958343/c83458f94ae8752f63ee8a34069b7c5/
ShapeNetVox32/02958343/c83458f94ae8752f63ee8a34069b7c5/model.binvox
ShapeNetVox32/02958343/6c6254a92c485787f1ca7626ddabf47/
ShapeNetVox32/02958343/6c6254a92c485787f1ca7626ddabf47/model.binvox
ShapeNetVox32/02958343/8997065fe94841771ef06e9b490109d8/
ShapeNetVox32/02958343/8997065fe94841771ef06e9b490109d8/model.binvox
ShapeNetVox32/02958343/55181c34dadb4f032db09455d18fca0/
ShapeNetVox32/02958343/55181c34dadb4f032db09455d18fca0/model.binvox
ShapeNetVox32/02958343/9901e5d7520df242c61cbe35764dfac1/
ShapeNetVox32/02958343/9901e5d7520df242c61cbe35764dfac1/mode

In [3]:
!pip install plotly



In [4]:
!pip install tensorboard



In [5]:
!pip install tqdm



In [6]:
!pip install trimesh tqdm scikit-image scipy plotly



In [7]:
!pip install trimesh scikit-image scipy plotly matplotlib torch-tb-profiler

Collecting torch-tb-profiler
  Downloading torch_tb_profiler-0.4.3-py3-none-any.whl.metadata (1.4 kB)
Downloading torch_tb_profiler-0.4.3-py3-none-any.whl (1.1 MB)
Installing collected packages: torch-tb-profiler
Successfully installed torch-tb-profiler-0.4.3


## Imports

In [16]:
###############################################
# 0. Imports
###############################################
import os, copy
import numpy as np
from PIL import Image
import trimesh
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from skimage import measure

# 2. Reconstruction Network

## 2.1 Configuration Management

All path settings, model parameters, augmentation transforms, and training hyperparameters are centralized in one place:

- **Data paths**  
  - `image_root`: Root directory for ShapeNet 2D renderings (e.g. `/content/ShapeNetRendering`)  
  - `voxel_root`: Root directory for ShapeNet voxel grids (e.g. `/content/ShapeNetVox32`)  
  - `category`: ShapeNet category ID (`"03001627"` for chairs)  
  - `model_ids`: List of specific object IDs to include in the subset

- **Data augmentation** (`transform`)  
  - Random resizing & cropping to 224×224  
  - Horizontal flips, small rotations, and color jitter  
  - Normalization to ImageNet mean/std

- **Training hyperparameters**  
  - `epochs`: Number of training epochs (e.g. 250)  
  - `batch_size`: Batch size per GPU (e.g. 32)  
  - `lr`: Initial learning rate (e.g. 5e‑4)  
  - `patience`: Early‑stopping patience on validation loss  
  - `num_workers`: Number of background workers for data loading

- **Voxel & visualization settings**  
  - `voxel_size`: Resolution of predicted occupancy grid (32³)  
  - `threshold`: Probability cutoff for marching cubes (e.g. 0.8)  
  - `refine_mesh`: Whether to run Laplacian smoothing on the extracted mesh  
  - `viz_margin`: Padding (in px) around Plotly figures

- **Device & logging**  
  - `device`: Automatically selects CUDA if available  
  - `log_dir`: Directory for TensorBoard logs (e.g. `./runs/3drecon`)  
  - Scheduler warm‑restart period `T_0`  

- **Camera parameters**  
  - `camera_eye` and `camera_up`: Default Plotly 3D camera position for consistent visualizations

By defining these in **one function**, you ensure that every experiment run uses the same settings, making it easy to reproduce results or tweak individual parameters without hunting through the notebook.  
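
One way to tweak individual parameters without editing `get_config()` is to copy the returned dictionary and override selected keys. This is a minimal sketch: `with_overrides` is a hypothetical helper, and the small `base` dictionary stands in for the real configuration so the example runs without PyTorch.

```python
def with_overrides(base_cfg, **overrides):
    """Return a copy of base_cfg with selected keys replaced.

    Unknown keys raise KeyError so that a typo cannot silently add a new setting.
    """
    unknown = set(overrides) - set(base_cfg)
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    return {**base_cfg, **overrides}

# Stand-in for a few entries of get_config(); values mirror the examples above.
base = {"category": "03001627", "epochs": 250, "batch_size": 32, "lr": 5e-4}

# A short smoke-test run: fewer epochs, smaller batches, base left untouched.
quick = with_overrides(base, epochs=5, batch_size=8)
print(quick["epochs"], base["epochs"])  # 5 250
```

In the notebook the same pattern would read `cfg = with_overrides(get_config(), epochs=5)`, which keeps the defaults in one place while making each deviation explicit.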


In [9]:
###############################################
# 1. Configuration
###############################################
def get_config():
    return {
        # data
        "image_root":    "/content/ShapeNetRendering",
        "voxel_root":    "/content/ShapeNetVox32",
        "category":      "03001627",  # Chair Category
        "model_ids":     [
            "1006be65e7bc937e9141f9b58470d646",
            "184c07776b38d92e35836c728d324152",
            "114f72b38dcabdf0823f29d871e57676",
            "11358c94662a68117e66b3e5c11f24d4",
            "123b44b99e48eccb6960dc525de2f934",
            "1015e71a0d21b127de03ab2a27ba7531",
            "1016f4debe988507589aae130c1f06fb",
            "10e523060bb5b51f9ee9f382b1dfb770",
            "113016635d554d5171fb733891076ecf",
            "18005751014e6ee9747c474f2e537e26",
            "34dc6c4e022b380cf91663a74ccd2338",
            "3503fedf43c99f0afb63ac4eaa5569d8"
        ],
        # augmentation
        "transform": transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8,1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(0.2,0.2,0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],
                                 [0.229,0.224,0.225]),
        ]),
        # training
        "epochs":       250,
        "batch_size":   32,
        "lr":           5e-4,
        "patience":     20,
        "num_workers":  4,
        # voxel / viz
        "voxel_size":   32,
        "threshold":    0.8,
        "refine_mesh":  True,
        # Add default margin for Plotly visualization padding (in px)
        "viz_margin":   30,
        # device
        "device":       torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        # logging / scheduler
        "log_dir":      "./runs/3drecon",
        "T_0":          10,
        # camera (front view, z up)
        "camera_eye":   dict(x=0,   y=-2.5, z=1.5),
        "camera_up":    dict(x=0,   y=0,    z=1.0),
    }

## 2.2 Voxel Data Loading and Caching

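`read_binvox()` and the dataset class below handle voxel loading:
*   `read_binvox()` parses the `.binvox` format (text header plus run-length-encoded payload) into a NumPy `uint8` array of 0/1 occupancy values.
*   With `fix_coords=True` (the default) it reverses the axis order via a transpose to correct the orientation.
*   `ChairSubsetDataset` reads each model's grid once at construction and keeps it in a plain Python dictionary keyed by model ID, so the same file is never re-read during training. This is a simple preload cache rather than an LRU cache, since nothing is evicted.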
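
The payload of a `.binvox` file is a sequence of `(value, count)` byte pairs. The sketch below decodes such a payload with `np.repeat` on a hand-made 2×2×2 example; `decode_rle` is a hypothetical stand-alone helper, not part of the notebook's loader, and it skips the header parsing entirely.

```python
import numpy as np

def decode_rle(raw: bytes, dims):
    """Expand (value, count) byte pairs into a dense uint8 grid of shape dims."""
    pairs = np.frombuffer(raw, dtype=np.uint8)
    if pairs.size % 2:
        raise ValueError("RLE payload must contain an even number of bytes")
    values, counts = pairs[0::2], pairs[1::2]
    flat = np.repeat(values, counts)          # each value repeated `count` times
    if flat.size != int(np.prod(dims)):
        raise ValueError(f"decoded {flat.size} voxels, expected {int(np.prod(dims))}")
    return flat.reshape(dims)

# 3 empty voxels, 2 occupied voxels, 3 empty voxels -> 8 voxels in total
raw = bytes([0, 3, 1, 2, 0, 3])
grid = decode_rle(raw, (2, 2, 2))
print(grid.shape, int(grid.sum()))  # (2, 2, 2) 2
```

Because a count is a single byte, a run longer than 255 voxels is stored as several consecutive pairs with the same value, which this decoding handles without special cases.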


In [10]:
###############################################
# 2. Binvox loader & cache
###############################################

def read_binvox(path, fix_coords=True):
    """Read a .binvox file and optionally swap axes for correct orientation."""
    with open(path, "rb") as f:
        header = f.readline().strip()
        if not header.startswith(b"#binvox"):
            raise ValueError("Not a binvox file")
        dims = None
        while True:
            line = f.readline().strip()
            if line.startswith(b"dim"):
                dims = list(map(int, line.split()[1:]))
            elif line.startswith(b"data"):
                break
        raw = f.read()
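    if dims is None:
        raise ValueError("binvox header has no 'dim' line")
    # Payload is run-length encoded as (value, count) byte pairs.
    pairs = np.frombuffer(raw, dtype=np.uint8)
    values, counts = pairs[0::2], pairs[1::2]
    arr = np.repeat(values, counts).reshape(dims)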
    if fix_coords:
        arr = np.transpose(arr, (2,1,0))
    return arr

class ChairSubsetDataset(Dataset):
    """A tiny ShapeNet chair subset for rapid experimentation."""
    def __init__(self, cfg):
        self.tfm       = cfg["transform"]
        img_root       = cfg["image_root"]
        vox_root       = cfg["voxel_root"]
        cat, mids      = cfg["category"], cfg["model_ids"]
        self.vox_cache = {}
        self.samples   = []

        for mid in mids:
            vox_fp  = os.path.join(vox_root, cat, mid, "model.binvox")
            img_dir = os.path.join(img_root, cat, mid, "rendering")
            if not os.path.isdir(img_dir) or not os.path.exists(vox_fp):
                print(f"⚠ skip missing {mid}")
                continue
            vox = read_binvox(vox_fp)
            self.vox_cache[mid] = torch.from_numpy(vox).unsqueeze(0).float()
            for fn in sorted(os.listdir(img_dir)):
                if fn.endswith(".png"):
                    self.samples.append((mid, os.path.join(img_dir, fn)))

        if not self.samples:
            raise RuntimeError("No samples found—check paths & model_ids")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        mid, img_fp = self.samples[idx]
        img         = Image.open(img_fp).convert("RGB")
        x           = self.tfm(img)
        y           = self.vox_cache[mid]
        return x, y

## 2.3 Model Architecture

This module implements a 2D‑to‑3D reconstruction network with five key stages:

1. **2D Feature Encoder (ResNet50 Backbone)**  
   - Loads a pre‑trained `resnet50` (`ResNet50_Weights.DEFAULT`).  
   - Reuses the initial layers:
     - `conv1`, `bn1`, `relu`, `maxpool` for low‑level features  
     - `layer1` → C=256 feature maps  
     - `layer2` → C=512  
     - `layer3` → C=1024  
     - `layer4` → C=2048  

2. **Multi-Scale Fusion**  
   - **Fusion at scale 3**:  
     1. Upsample `layer4` to match `layer3` spatial size.  
     2. Concatenate along channels (2048 + 1024 → 3072).  
     3. Reduce to 1024 via `Conv2d(1×1) + BatchNorm + ReLU`.  
   - **Fusion at scale 2**:  
     1. Upsample fused output `f3` to match `layer2`.  
     2. Concatenate (1024 + 512 → 1536) and reduce to 512 channels.  
   - **Final conv**:  
     - `Conv2d(512 → 256) + BatchNorm + ReLU` yields a 256‑channel 2D feature map.

3. **Volumetric Latent Pooling**  
   - Apply `AdaptiveAvgPool2d((4,4))` to produce a 256×4×4 tensor.  
   - Reshape and repeat along a new depth dimension to form a 5D tensor:  
     ```
     (batch, 256, 4, 4, 4)
     ```

4. **3D Decoder**  
   - Three transpose‑conv blocks upsample to full resolution (32³):
     1. **256 → 128** channels → output 8×8×8  
     2. **128 → 64**  channels → 16×16×16  
     3. **64 → 1**    channel  → 32×32×32  
   - The first two blocks are `ConvTranspose3d(kernel=4,stride=2,pad=1) → BatchNorm3d → ReLU`; the last block is a bare `ConvTranspose3d`, so the decoder outputs raw logits.

5. **Refinement Module**  
   - A lightweight residual CNN for boundary sharpening:
     1. `Conv3d(1 → 32, k=3, pad=1) → BatchNorm3d → ReLU`  
     2. `Conv3d(32 → 32, k=3, pad=1) → BatchNorm3d → ReLU`  
     3. `Conv3d(32 → 1,  k=3, pad=1)`  
   - Final output:  
     ```python
     output = coarse_decode + refine_module(coarse_decode)
     ```
   - Enhances the coarse voxel grid with fine detail.
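
The spatial sizes quoted in the stages above can be checked with a few lines of arithmetic. The sketch assumes a 224×224 input (the crop size used by the augmentation) and the standard transposed-convolution size formula without `output_padding` or dilation; the helper name is illustrative only.

```python
def conv_transpose_out(size, kernel=4, stride=2, pad=1):
    """Output length of a transposed convolution (no output_padding, dilation 1)."""
    return (size - 1) * stride - 2 * pad + kernel

input_hw = 224
# ResNet50 reduces resolution by 4 before layer1 (conv1 + maxpool),
# then by a further factor of 2 in each of layer2, layer3 and layer4.
backbone_hw = {f"layer{i}": input_hw // (2 ** (i + 1)) for i in range(1, 5)}
print(backbone_hw)  # {'layer1': 56, 'layer2': 28, 'layer3': 14, 'layer4': 7}

# The pooled latent is 4x4x4; each decoder block doubles every dimension.
decoder_sizes = [4]
for _ in range(3):
    decoder_sizes.append(conv_transpose_out(decoder_sizes[-1]))
print(decoder_sizes)  # [4, 8, 16, 32]
```

With kernel 4, stride 2 and padding 1 the formula reduces to exactly `2 * size`, which is why three blocks take the 4³ latent to the 32³ target grid.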

---

**Forward Pass Summary**  
```python
# Pseudocode; x: [B, 3, H, W] input images
# feat1 → feat2 → feat3 → feat4: outputs of ResNet50 layer1 … layer4
f3 = fuse3( upsample(feat4), feat3 )
f2 = fuse2( upsample(f3), feat2 )
f  = convf(f2)                      # 256×H'×W'
p  = pool(f).view(B,256,4,4,1).repeat(1,1,1,1,4)
coarse = dec(p)                     # [B,1,32,32,32]
refine = refine_module(coarse)      # [B,1,32,32,32]
return coarse + refine              # final voxel prediction (logits)
```

In [11]:
###############################################
# 3. Model
###############################################
class FusionRefined3DReconstruction(nn.Module):
    """2D CNN backbone + multi‑scale fusion + 3D decoder with refinement."""
    def __init__(self):
        super().__init__()
        backbone            = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.conv1, self.bn1, self.relu, self.maxpool = (
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool
        )
        self.layer1, self.layer2, self.layer3, self.layer4 = (
            backbone.layer1, backbone.layer2,
            backbone.layer3, backbone.layer4
        )
        self.fuse3 = nn.Sequential(
            nn.Conv2d(2048+1024,1024,1), nn.BatchNorm2d(1024), nn.ReLU()
        )
        self.fuse2 = nn.Sequential(
            nn.Conv2d(1024+512,512,1),  nn.BatchNorm2d(512),  nn.ReLU()
        )
        self.convf = nn.Sequential(
            nn.Conv2d(512,256,1), nn.BatchNorm2d(256), nn.ReLU()
        )
        self.pool  = nn.AdaptiveAvgPool2d((4,4))
        self.dec   = nn.Sequential(
            nn.ConvTranspose3d(256,128,4,2,1), nn.BatchNorm3d(128), nn.ReLU(),
            nn.ConvTranspose3d(128,64,4,2,1),  nn.BatchNorm3d(64),  nn.ReLU(),
            nn.ConvTranspose3d(64,1,4,2,1)
        )
        self.refine= nn.Sequential(
            nn.Conv3d(1,32,3,padding=1), nn.BatchNorm3d(32), nn.ReLU(),
            nn.Conv3d(32,32,3,padding=1), nn.BatchNorm3d(32), nn.ReLU(),
            nn.Conv3d(32,1,3,padding=1)
        )

    def forward(self, x):
        # Feature extraction
        x   = self.relu(self.bn1(self.conv1(x)))
        x   = self.maxpool(x)
        x1  = self.layer1(x); x2 = self.layer2(x1)
        x3  = self.layer3(x2); x4 = self.layer4(x3)
        # Top‑down fusion
        u4  = F.interpolate(x4, size=x3.shape[2:],
                            mode="bilinear", align_corners=False)
        f3  = self.fuse3(torch.cat([u4, x3], 1))
        u3  = F.interpolate(f3, size=x2.shape[2:],
                            mode="bilinear", align_corners=False)
        f2  = self.fuse2(torch.cat([u3, x2], 1))
        f   = self.convf(f2)
        # 3D decoding
        p   = self.pool(f).unsqueeze(-1).repeat(1,1,1,1,4)
        coarse = self.dec(p)
        return coarse + self.refine(coarse)

## 2.4 Loss Function Composition

Voxel occupancy is extremely sparse—most voxels are empty—so a single loss term often fails to guide the network toward accurate object boundaries. We therefore combine four complementary objectives:

1. **Binary Cross‑Entropy Loss (BCE)**  
   $$
   L_{BCE} = \mathrm{BCEWithLogitsLoss}(\text{logits}, \text{targets})
   $$  
   - Penalizes each voxel’s predicted probability vs. its true label.  
   - Provides stable gradients across all voxels.

2. **Dice Loss**  
   $$
   \mathrm{Dice} = \frac{2\,\sum p\,t + \varepsilon}{\sum p + \sum t + \varepsilon},
   \quad
   L_{Dice} = 1 - \mathrm{Dice}
   $$  
   - Measures overlap between prediction and ground truth.  
   - Focuses on overall shape agreement rather than individual voxels.

3. **Focal Loss**  
   $$
   L_{Focal}(p_t) = -\,\alpha\,(1 - p_t)^\gamma \,\log(p_t)
   $$  
   where \(p_t\) is the model’s estimated probability for the true class.  
   - Down‑weights “easy” negatives (empty voxels) to emphasize hard positives (occupied voxels).  
   - Hyperparameters \(\alpha\) and \(\gamma\) control the balance and focusing strength.

4. **IoU Loss**  
   $$
   \mathrm{IoU} = \frac{\sum p\,t + \varepsilon}{\sum p + \sum t - \sum p\,t + \varepsilon},
   \quad
   L_{IoU} = 1 - \mathrm{IoU}
   $$  
   - Directly optimizes the Intersection over Union metric.  
   - Encourages correct shape reconstruction at the volume level.
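
To make the two overlap terms concrete, the sketch below evaluates the soft Dice and soft IoU formulas on a four-voxel toy example. It is a plain-Python illustration with a hypothetical helper name; `eps=1.0` mirrors the `smooth=1.` default used in the notebook's loss classes.

```python
def soft_dice_and_iou(probs, targets, eps=1.0):
    """Soft Dice and IoU scores for flat lists of probabilities and 0/1 labels."""
    inter = sum(p * t for p, t in zip(probs, targets))
    p_sum, t_sum = sum(probs), sum(targets)
    dice = (2 * inter + eps) / (p_sum + t_sum + eps)
    iou = (inter + eps) / (p_sum + t_sum - inter + eps)
    return dice, iou

probs   = [0.9, 0.8, 0.1, 0.0]   # predicted occupancy probabilities
targets = [1,   1,   0,   0]     # ground-truth occupancy

dice, iou = soft_dice_and_iou(probs, targets)
print(f"Dice loss: {1 - dice:.4f}   IoU loss: {1 - iou:.4f}")
```

Here the intersection is 1.7, so Dice is 4.4/4.8 and IoU is 2.7/3.1. For the same prediction the IoU score is never higher than the Dice score, which makes the IoU term the stricter of the two.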

---

#### Final Composite Objective

We weight all four terms equally to balance voxel‑wise accuracy, overlap, and focus on hard examples:

$$
L = 0.25 \bigl(L_{BCE} + L_{Dice} + L_{Focal} + L_{IoU}\bigr)
$$

This hybrid loss ensures stable training (BCE), shape‑aware alignment (Dice & IoU), and targeted learning on under‑represented occupied voxels (Focal).  
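
The effect of the focal term is easiest to see numerically. The sketch below compares plain cross-entropy with the focal-weighted value for a confident, an uncertain, and a badly wrong voxel, using the defaults `alpha=1` and `gamma=2`; `focal_term` is an illustrative per-voxel function, not the notebook's `FocalLoss` module.

```python
import math

def focal_term(p_t, alpha=1.0, gamma=2.0):
    """Focal loss for one voxel, where p_t is the probability of the true class."""
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)

for p_t in (0.95, 0.5, 0.1):
    bce = -math.log(p_t)                 # plain cross-entropy for the same voxel
    weight = (1.0 - p_t) ** 2            # modulating factor with gamma = 2
    print(f"p_t={p_t:.2f}  weight={weight:.4f}  bce={bce:.4f}  focal={focal_term(p_t):.4f}")
```

With `gamma=2` the modulating factor is 0.0025 for a voxel classified with `p_t = 0.95` but 0.81 for one at `p_t = 0.1`, so well-classified empty voxels contribute very little to the gradient compared with misclassified ones.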


In [12]:
###############################################
# 4. Composite loss
###############################################
class FocalLoss(nn.Module):
    def __init__(self, alpha=1., gamma=2.):
        super().__init__(); self.alpha, self.gamma = alpha, gamma
    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p   = torch.sigmoid(logits)
        p_t = targets*p + (1-targets)*(1-p)
        return (self.alpha*(1-p_t)**self.gamma * bce).mean()

class IoULoss(nn.Module):
    def __init__(self, smooth=1.):
        super().__init__(); self.smooth = smooth
    def forward(self, logits, targets):
        preds = torch.sigmoid(logits)
        inter = (preds*targets).sum((1,2,3,4))
        union= preds.sum((1,2,3,4)) + targets.sum((1,2,3,4)) - inter
        return (1 - (inter + self.smooth)/(union + self.smooth)).mean()

class CompositeLoss(nn.Module):
    """BCE + Dice + Focal + IoU (equal weights)."""
    def __init__(self, alpha=1., gamma=2., smooth=1.):
        super().__init__()
        self.bce   = nn.BCEWithLogitsLoss()
        self.focal = FocalLoss(alpha, gamma)
        self.iou   = IoULoss(smooth)
        self.smooth= smooth
    def forward(self, logits, targets):
        b   = self.bce(logits, targets)
        p   = torch.sigmoid(logits)
        inter = (p*targets).sum((1,2,3,4))
        dice  = (2*inter + self.smooth) / (
                   p.sum((1,2,3,4)) + targets.sum((1,2,3,4)) + self.smooth
               )
        d     = 1 - dice.mean()
        f     = self.focal(logits, targets)
        i     = self.iou(logits, targets)
        return 0.25*(b + d + f + i)

## 2.5 Visualization Helpers

* `visualize_voxel(vol, threshold, ...)`: extracts an isosurface from a 3D occupancy/probability grid with `skimage.measure.marching_cubes()` at the given threshold, falling back to the midpoint of the value range when the threshold lies outside it.
* Optionally smooths the extracted mesh with `trimesh.smoothing.filter_laplacian()` when `refine_mesh=True`.
* Displays the result as an interactive `plotly.graph_objects.Mesh3d` that can be rotated and zoomed, using the configured camera position and margin.
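
The threshold fallback matters early in training, when predicted probabilities may all sit below the configured cutoff and `marching_cubes` would raise a `ValueError` because the level lies outside the data range. The sketch below isolates that rule as a small function; `robust_level` is a hypothetical name and the example values are made up.

```python
def robust_level(vol_min, vol_max, threshold):
    """Return threshold if it lies inside [vol_min, vol_max], else the midpoint."""
    if threshold < vol_min or threshold > vol_max:
        return 0.5 * (vol_min + vol_max)
    return threshold

# Early epoch: all probabilities are low, so 0.8 cannot be used as isovalue.
print(robust_level(0.02, 0.30, 0.8))

# Later epoch: the prediction spans the configured threshold, which is kept.
print(robust_level(0.01, 0.99, 0.8))
```

The midpoint always yields a surface when the volume is not constant, at the cost of showing a shape at a lower confidence level than the configured one.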

In [13]:
###############################################
# 5. Visualization helper
###############################################

def visualize_voxel(
    vol,
    threshold=0.5,
    refine_mesh=False,
    camera_eye=None,
    camera_up=None,
    margin=30,
):
    """Render a volumetric prediction as an interactive 3D mesh with Plotly.

    Args:
        vol (ndarray): 3‑D numpy array of occupancy/probability values.
        threshold (float): Marching‑cubes isovalue.
        refine_mesh (bool): Optional Laplacian smoothing for nicer visuals.
        camera_eye (dict): Plotly camera eye position.
        camera_up (dict): Plotly camera up vector.
        margin (int): Padding (in px) around the figure.
    """
    # pick a robust default threshold if out of range
    if threshold < vol.min() or threshold > vol.max():
        threshold = 0.5*(vol.min() + vol.max())

    # extract surface
    verts, faces, _, _ = measure.marching_cubes(vol, level=threshold)

    if refine_mesh:
        mesh = trimesh.Trimesh(vertices=verts, faces=faces)
        trimesh.smoothing.filter_laplacian(mesh, iterations=5)
        verts, faces = mesh.vertices, mesh.faces

    # swap X<->Z for correct orientation
    x0, y0, z0 = verts.T
    x, y, z    = z0, y0, x0

    i, j, k = faces.T
    mesh3d = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, opacity=0.6)

    fig = go.Figure(data=[mesh3d])
    fig.update_layout(
        scene=dict(
            aspectmode='data',
            camera=dict(
                eye = camera_eye or dict(x=0, y=-2.5, z=1.5),
                up  = camera_up  or dict(x=0, y=0,    z=1.0)
            ),
            xaxis=dict(title='x'),
            yaxis=dict(title='y'),
            zaxis=dict(title='z'),
        ),
        # extra padding all around
        margin=dict(l=margin, r=margin, t=margin, b=margin)
    )
    fig.show()

## 2.6 Training Procedure

### Function: `train_model(cfg)`

1. **Data Preparation**  
   - Instantiate the full `ChairSubsetDataset(cfg)`, then split it into training (80%), validation (10%), and test (10%) sets using a fixed random seed for reproducibility.  
   - Create PyTorch `DataLoader`s for the training and validation subsets, with shuffling, batching, and multi‑worker loading.

2. **Model & Optimization Setup**  
   - Initialize the `FusionRefined3DReconstruction` model on the chosen device.  
   - Configure the optimizer (`AdamW` with weight decay) and a cosine‑annealing‑with‑warm‑restarts learning‑rate scheduler.  
   - Wrap training in mixed‑precision via `GradScaler` and `autocast` for speed and memory efficiency.  
   - Define the composite voxel‑reconstruction loss (`BCE + Focal + Dice + IoU`).

3. **Epoch Loop**  
   For each epoch up to `cfg["epochs"]`:
   - **Training Phase**  
     - Set `model.train()`, zero the gradients, and iterate over batches:  
       1. Forward pass through the network.  
       2. Compute the composite loss.  
       3. Scale and backpropagate gradients.  
       4. Clip gradients to avoid explosion.  
       5. Perform the optimizer step and update the scaler.  
     - Accumulate per‑batch losses to compute the epoch’s average training loss.
   - **Validation Phase**  
     - Switch to `model.eval()` and disable grads: iterate over validation batches to compute the average validation loss.
   - **Learning Rate & Logging**  
     - Record the current learning rate, then step the scheduler.  
     - Print a summary line (`Epoch X/Y | Train Loss: … | Val Loss: … | LR: …`).  
     - Push scalar metrics to TensorBoard via `SummaryWriter`.

4. **Checkpointing & Early Stopping**  
   - After each epoch, if validation loss improves, save the model weights to `best_model.pth` and reset a no‑improvement counter.  
   - If the counter reaches `cfg["patience"]`, break the loop early to avoid overfitting.

5. **Periodic Visualization**  
   - Every N epochs (e.g., every 25), run a forward pass on a fixed validation sample (`fixed_img`) and call `visualize_voxel()` to display an intermediate reconstruction in the notebook.

6. **Finalization**  
   - After early stopping or completing all epochs, load the best weights into the model and save `final_model.pth`.  
   - Close the TensorBoard writer and return:
     ```python
     return train_losses, val_losses, lrs, last_epoch, fixed_img, model, test_ds
     ```
   - This output feeds into downstream plotting and inference steps.
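
Two pieces of arithmetic in this procedure are worth working through: the split sizes and the restart points of the learning-rate schedule. The sketch below is plain Python with hypothetical helper names. It assumes all 12 configured models are present with 24 PNG views each and that the scheduler is stepped once per epoch; the closed form follows the usual definition of cosine annealing with warm restarts.

```python
import math

def split_sizes(total, train_frac=0.8, val_frac=0.1):
    """Mirror the int() truncation used for the train/val/test lengths."""
    train_len = int(train_frac * total)
    val_len = int(val_frac * total)
    return train_len, val_len, total - train_len - val_len

def cycle_position(epoch, t_0=10, t_mult=2):
    """Return (t_cur, t_i): position inside the current cosine cycle and its length."""
    t_cur, t_i = epoch, t_0
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult          # every cycle is t_mult times longer than the last
    return t_cur, t_i

def warm_restart_lr(epoch, base_lr=5e-4, eta_min=0.0, t_0=10, t_mult=2):
    """Learning rate at a 0-based epoch index."""
    t_cur, t_i = cycle_position(epoch, t_0, t_mult)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))

# Assumption: 12 models x 24 views = 288 samples.
print(split_sizes(12 * 24))  # (230, 28, 30)

# 0-based epoch indices at which the rate jumps back to base_lr.
restarts = [e for e in range(1, 160) if cycle_position(e)[0] == 0]
print(restarts)  # [10, 30, 70, 150]
```

Because the training loop records the rate before calling `scheduler.step()` and counts epochs from 1, these restarts would appear at epochs 11, 31, 71 and 151 in the printed log. Any remainder from truncation goes to the test split, which is why it ends up slightly larger than the validation split.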


In [18]:
###############################################
# 6. Training loop (prints metrics each epoch)
###############################################

def train_model(cfg):
    ds             = ChairSubsetDataset(cfg)
    total = len(ds)
    train_len = int(0.8 * total)
    val_len   = int(0.1 * total)
    test_len  = total - train_len - val_len
    train_ds, val_ds, test_ds = random_split(
        ds,
        [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(42)  # for reproducibility
    )

    fixed_img, _   = val_ds[0]

    tl = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True,
                    num_workers=cfg["num_workers"], pin_memory=True)
    vl = DataLoader(val_ds,   batch_size=cfg["batch_size"], shuffle=False,
                    num_workers=cfg["num_workers"], pin_memory=True)

    model     = FusionRefined3DReconstruction().to(cfg["device"])
    optimizer = optim.AdamW(model.parameters(), lr=cfg["lr"], weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    optimizer, T_0=cfg["T_0"], T_mult=2)
    scaler    = GradScaler()
    loss_fn   = CompositeLoss()

    writer = SummaryWriter(cfg["log_dir"])
    best_val, wait = float("inf"), 0
    best_wts = copy.deepcopy(model.state_dict())  # fallback if validation never improves
    train_losses, val_losses, lrs = [], [], []

    for epoch in range(1, cfg["epochs"] + 1):
        # ——— training ———
        model.train(); running = 0.0
        for x, y in tl:
            x, y = x.to(cfg["device"]), y.to(cfg["device"])
            optimizer.zero_grad()
            with autocast(device_type=cfg["device"].type):
                logits = model(x)
                loss   = loss_fn(logits, y)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer); scaler.update()
            running += loss.item() * x.size(0)
        train_loss = running / len(tl.dataset)
        train_losses.append(train_loss)

        # ——— validation ———
        model.eval(); vrunning = 0.0
        with torch.no_grad():
            for x, y in vl:
                x, y = x.to(cfg["device"]), y.to(cfg["device"])
                with autocast(device_type=cfg["device"].type):
                    logits = model(x)
                    vloss  = loss_fn(logits, y)
                vrunning += vloss.item() * x.size(0)
        val_loss = vrunning / len(vl.dataset)
        val_losses.append(val_loss)

        lr = optimizer.param_groups[0]['lr']
        lrs.append(lr)

        # log & console
        print(f"Epoch {epoch}/{cfg['epochs']}  "
              f"Train Loss: {train_loss:.4f}  "
              f"Val Loss:   {val_loss:.4f}  "
              f"LR:         {lr:.6f}")

        writer.add_scalar("Loss/Train", train_loss, epoch)
        writer.add_scalar("Loss/Val",   val_loss,   epoch)
        writer.add_scalar("LR",         lr,          epoch)

        scheduler.step()

        # early stopping
        if val_loss < best_val:
            best_val = val_loss
            best_wts = copy.deepcopy(model.state_dict())
            torch.save(best_wts, "best_model.pth")
            wait = 0
        else:
            wait += 1
            if wait >= cfg["patience"]:
                print(f"Early stopping @ epoch {epoch}")
                break

        # periodic reconstruction visual
        if epoch % 25 == 0:
            with torch.no_grad():
                recon = torch.sigmoid(
                    model(fixed_img.unsqueeze(0).to(cfg["device"]))
                )[0,0].cpu().numpy()
            print(f"\nReconstruction @ epoch {epoch}")
            visualize_voxel(
                recon,
                threshold=cfg["threshold"],
                refine_mesh=cfg["refine_mesh"],
                camera_eye=cfg["camera_eye"],
                camera_up=cfg["camera_up"],
                margin=cfg["viz_margin"]  # use configured padding
            )

    model.load_state_dict(best_wts)
    torch.save(model.state_dict(), "final_model.pth")
    writer.close()
    return train_losses, val_losses, lrs, epoch, fixed_img, model, test_ds

# 3. Experiment Entry Point

1. **Load Configuration**  
   Retrieve all experiment settings (data paths, hyperparameters, device, etc.) by calling:  
   ```python
   cfg = get_config()
   ```  
   Optionally override these defaults via command‑line arguments if you integrate `argparse`.

2. **Train the Model**  
   Start the end‑to‑end training process with:  
   ```python
   train_losses, val_losses, lrs, last_epoch, fixed_img, model, test_ds = train_model(cfg)
   ```  
   This step will:
   - Instantiate datasets & data loaders  
   - Build the network, optimizer, scheduler, and loss functions  
   - Run the training loop with early stopping and checkpointing  
   - Return training metrics and a held‑out test split

3. **Visualize Training Performance**  
   Use the returned `train_losses`, `val_losses`, and `lrs` to plot:  
   - **Training vs. validation loss** curves over epochs  
   - **Learning rate schedule** on a secondary axis  
   This helps assess convergence behavior and detect overfitting.

4. **Inspect a Fixed Validation Sample**  
   Run a quick forward pass on the fixed validation image (`fixed_img`) to generate a reconstruction:  
   ```python
   with torch.no_grad():
       out_val = model(fixed_img.unsqueeze(0).to(cfg["device"]))
       recon_val = torch.sigmoid(out_val)[0,0].cpu().numpy()
   ```  
   Display the original 2D view next to its 3D reconstruction for qualitative sanity checks.

5. **Evaluate on Held‑Out Test Set**  
   Loop over `test_ds` (the 10% reserved data) to:
   - Compute per‑sample **Intersection over Union (IoU)** between prediction and ground truth  
   - Render side‑by‑side subplots of:
     - **Input view**  
     - **Ground‑truth mesh** (wireframe)  
     - **Predicted mesh** (solid, with IoU in the title)  
   This provides both quantitative and visual evidence of generalization.

6. **Wrap Up & Export**  
   Optionally save final figures to disk or export interactive HTML widgets for sharing.  
   Cleanly exit or return from `main()` once all visualizations are complete.
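
The per-sample IoU in step 5 can be computed by binarizing the predicted probabilities and comparing them with the ground-truth grid. This is a minimal NumPy sketch: `voxel_iou` is a hypothetical helper, and the default cutoff of 0.5 is an assumption, since the evaluation threshold is not stated above.

```python
import numpy as np

def voxel_iou(probs, target, threshold=0.5):
    """Hard IoU between a probability grid and a binary occupancy grid."""
    pred = probs >= threshold
    gt = target > 0.5
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0             # both grids empty: treat as a perfect match
    return float(np.logical_and(pred, gt).sum() / union)

# Toy example: a 2x2x2 block of ground truth, prediction one slice too wide.
target = np.zeros((4, 4, 4))
target[1:3, 1:3, 1:3] = 1          # 8 occupied voxels
probs = np.zeros((4, 4, 4))
probs[1:3, 1:3, 1:4] = 0.9         # 12 voxels predicted as occupied

print(round(voxel_iou(probs, target), 4))  # 0.6667
```

Unlike the soft IoU used in the loss, this metric depends on the chosen cutoff, so the same cutoff should be used for every sample and reported next to the scores.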

In [19]:
###############################################
# 7. Main & Enhanced Visualization
###############################################
def main():
    cfg = get_config()
    train_losses, val_losses, lrs, last_epoch, fixed_img, model, test_ds = train_model(cfg)

    # 1) Training Curves with Clear Titles & Axes
    epochs = list(range(1, last_epoch + 1))
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(
        go.Scatter(x=epochs, y=train_losses, name="Train Loss", mode="lines+markers"),
        secondary_y=False,
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=val_losses, name="Val Loss", mode="lines+markers"),
        secondary_y=False,
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=lrs, name="Learning Rate", mode="lines"),
        secondary_y=True,
    )
    fig.update_layout(
        title="Training vs. Validation Loss & Learning Rate",
        xaxis_title="Epoch",
        legend_title="Metrics",
        margin=dict(l=50, r=50, t=80, b=50)
    )
    fig.update_yaxes(title_text="Loss", secondary_y=False)
    fig.update_yaxes(title_text="LR", secondary_y=True)
    fig.show()

    # 2) Fixed Validation Sample: Input → Prediction
    #    Show the fixed validation image alongside its reconstruction
    with torch.no_grad():
        out_val = model(fixed_img.unsqueeze(0).to(cfg["device"]))
        recon_val = torch.sigmoid(out_val)[0,0].detach().cpu().numpy()

    fig2 = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Fixed Val Input Image", "Reconstruction at Last Epoch"),
        specs=[[{"type": "image"}, {"type": "scene"}]]
    )
    # Left: input image
    inv = transforms.Normalize(
        mean=[-m/s for m,s in zip([0.485,0.456,0.406],[0.229,0.224,0.225])],
        std=[1/s for s in [0.229,0.224,0.225]]
    )
    # .cpu() is a no-op for CPU tensors and avoids a failure if fixed_img lives on the GPU
    im = inv(fixed_img).clamp(0,1).permute(1,2,0).cpu().numpy()
    fig2.add_trace(
        go.Image(z=(im * 255).astype("uint8")),
        row=1, col=1
    )
    # Right: 3D mesh
    verts, faces, _, _ = measure.marching_cubes(recon_val, level=cfg["threshold"])
    x0,y0,z0 = verts.T; x, y, z = z0, y0, x0
    i,j,k = faces.T
    mesh = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, opacity=0.6)
    fig2.add_trace(mesh, row=1, col=2)
    fig2.update_layout(
        title="Fixed Validation Example",
        scene2=dict(
            camera=dict(eye=cfg["camera_eye"], up=cfg["camera_up"]),
            xaxis_title="X", yaxis_title="Y", zaxis_title="Z",
            aspectmode="data"
        ),
        margin=dict(l=50, r=50, t=80, b=50)
    )
    fig2.show()

    # 3) Test‐Set Reconstructions: Input / Ground‑Truth / Prediction
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=True)
    for idx, (x_test, y_true) in enumerate(test_loader):
        if idx >= 5:
            break

        x_test = x_test.to(cfg["device"])
        with torch.no_grad(), autocast(device_type=cfg["device"].type):
            logits = model(x_test)
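        # cast to float32 first: autocast can return half or bfloat16, and NumPy cannot convert bfloat16
        vox_pred = torch.sigmoid(logits.float())[0,0].cpu().numpy()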
        vox_true = y_true[0,0].numpy()

        # Build 1×3 subplot: [input image | GT mesh | Pred mesh]
        fig3 = make_subplots(
            rows=1, cols=3,
            subplot_titles=("Input View", "Ground‑Truth Mesh", "Predicted Mesh"),
            specs=[[{"type":"image"}, {"type":"scene"}, {"type":"scene"}]]
        )

        # Input image
        imt = inv(x_test[0]).clamp(0,1).permute(1,2,0).cpu().numpy()
        fig3.add_trace(go.Image(z=(imt*255).astype("uint8")), row=1, col=1)

        # Ground‑truth mesh
        verts_t, faces_t, _, _ = measure.marching_cubes(vox_true, level=0.5)
        xt0, yt0, zt0 = verts_t.T; xt, yt, zt = zt0, yt0, xt0
        it, jt, kt = faces_t.T
        fig3.add_trace(
            go.Mesh3d(x=xt, y=yt, z=zt, i=it, j=jt, k=kt, opacity=0.6),
            row=1, col=2
        )

        # Predicted mesh
        verts_p, faces_p, _, _ = measure.marching_cubes(vox_pred, level=cfg["threshold"])
        xp0, yp0, zp0 = verts_p.T; xp, yp, zp = zp0, yp0, xp0
        ip, jp, kp = faces_p.T
        fig3.add_trace(
            go.Mesh3d(x=xp, y=yp, z=zp, i=ip, j=jp, k=kp, opacity=0.6),
            row=1, col=3
        )

        fig3.update_layout(
            title=f"Test Sample #{idx+1}",
            scene2=dict(camera=dict(eye=cfg["camera_eye"], up=cfg["camera_up"]), aspectmode="data"),
            scene3=dict(camera=dict(eye=cfg["camera_eye"], up=cfg["camera_up"]), aspectmode="data"),
            margin=dict(l=40, r=40, t=80, b=40)
        )
        fig3.show()

if __name__ == "__main__":
    main()


Epoch 1/250  Train Loss: 0.6590  Val Loss:   0.6738  LR:         0.000500
Epoch 2/250  Train Loss: 0.4951  Val Loss:   0.5292  LR:         0.000488
Epoch 3/250  Train Loss: 0.4469  Val Loss:   0.4558  LR:         0.000452
Epoch 4/250  Train Loss: 0.4142  Val Loss:   0.4201  LR:         0.000397
Epoch 5/250  Train Loss: 0.3909  Val Loss:   0.3875  LR:         0.000327
Epoch 6/250  Train Loss: 0.3730  Val Loss:   0.3876  LR:         0.000250
Epoch 7/250  Train Loss: 0.3589  Val Loss:   0.3790  LR:         0.000173
Epoch 8/250  Train Loss: 0.3434  Val Loss:   0.3420  LR:         0.000103
Epoch 9/250  Train Loss: 0.3339  Val Loss:   0.3357  LR:         0.000048
Epoch 10/250  Train Loss: 0.3295  Val Loss:   0.3291  LR:         0.000012
Epoch 11/250  Train Loss: 0.3227  Val Loss:   0.3696  LR:         0.000500
Epoch 12/250  Train Loss: 0.2989  Val Loss:   0.3032  LR:         0.000497
Epoch 13/250  Train Loss: 0.2809  Val Loss:   0.3334  LR:         0.000488
Epoch 14/250  Train Loss: 0.2627  

Epoch 26/250  Train Loss: 0.1434  Val Loss:   0.1304  LR:         0.000073
Epoch 27/250  Train Loss: 0.1404  Val Loss:   0.1248  LR:         0.000048
Epoch 28/250  Train Loss: 0.1391  Val Loss:   0.1302  LR:         0.000027
Epoch 29/250  Train Loss: 0.1386  Val Loss:   0.1225  LR:         0.000012
Epoch 30/250  Train Loss: 0.1377  Val Loss:   0.1226  LR:         0.000003
Epoch 31/250  Train Loss: 0.1398  Val Loss:   0.1494  LR:         0.000500
Epoch 32/250  Train Loss: 0.1500  Val Loss:   0.1581  LR:         0.000499
Epoch 33/250  Train Loss: 0.1479  Val Loss:   0.1852  LR:         0.000497
Epoch 34/250  Train Loss: 0.1391  Val Loss:   0.1354  LR:         0.000493
Epoch 35/250  Train Loss: 0.1457  Val Loss:   0.1456  LR:         0.000488
Epoch 36/250  Train Loss: 0.1378  Val Loss:   0.1771  LR:         0.000481
Epoch 37/250  Train Loss: 0.1348  Val Loss:   0.1392  LR:         0.000473
Epoch 38/250  Train Loss: 0.1312  Val Loss:   0.1206  LR:         0.000463
Epoch 39/250  Train Loss:

Epoch 51/250  Train Loss: 0.1021  Val Loss:   0.0921  LR:         0.000250
Epoch 52/250  Train Loss: 0.0977  Val Loss:   0.0892  LR:         0.000230
Epoch 53/250  Train Loss: 0.0949  Val Loss:   0.0875  LR:         0.000211
Epoch 54/250  Train Loss: 0.0944  Val Loss:   0.0917  LR:         0.000192
Epoch 55/250  Train Loss: 0.0936  Val Loss:   0.0854  LR:         0.000173
Epoch 56/250  Train Loss: 0.0923  Val Loss:   0.0824  LR:         0.000154
Epoch 57/250  Train Loss: 0.0922  Val Loss:   0.0833  LR:         0.000137
Epoch 58/250  Train Loss: 0.0913  Val Loss:   0.0826  LR:         0.000119
Epoch 59/250  Train Loss: 0.0910  Val Loss:   0.0876  LR:         0.000103
Epoch 60/250  Train Loss: 0.0908  Val Loss:   0.0832  LR:         0.000088
Epoch 61/250  Train Loss: 0.0905  Val Loss:   0.0824  LR:         0.000073
Epoch 62/250  Train Loss: 0.0910  Val Loss:   0.0820  LR:         0.000060
Epoch 63/250  Train Loss: 0.0904  Val Loss:   0.0880  LR:         0.000048
Epoch 64/250  Train Loss:

Epoch 76/250  Train Loss: 0.0938  Val Loss:   0.0872  LR:         0.000495
Epoch 77/250  Train Loss: 0.1043  Val Loss:   0.0876  LR:         0.000493
Epoch 78/250  Train Loss: 0.0945  Val Loss:   0.0886  LR:         0.000491
Epoch 79/250  Train Loss: 0.0943  Val Loss:   0.0918  LR:         0.000488
Epoch 80/250  Train Loss: 0.0926  Val Loss:   0.0838  LR:         0.000485
Epoch 81/250  Train Loss: 0.0899  Val Loss:   0.0829  LR:         0.000481
Epoch 82/250  Train Loss: 0.0992  Val Loss:   0.0798  LR:         0.000477
Epoch 83/250  Train Loss: 0.0920  Val Loss:   0.0818  LR:         0.000473
Epoch 84/250  Train Loss: 0.0897  Val Loss:   0.0834  LR:         0.000468
Epoch 85/250  Train Loss: 0.0907  Val Loss:   0.0826  LR:         0.000463
Epoch 86/250  Train Loss: 0.0901  Val Loss:   0.0818  LR:         0.000458
Early stopping @ epoch 86
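
The learning rates in the log above jump back to 0.0005 at epochs 11 and 31 and decay along a cosine in between, and the values from epoch 76 onward fit a further restart at epoch 71. This pattern is consistent with cosine annealing with warm restarts (initial period of 10 epochs, doubling after each restart, peak 5e-4, floor 0). These parameters are inferred from the printed values, not read from the training code, so treat them as an assumption. A small sketch that reproduces the logged values under that assumption (the function name `warm_restart_lr` is illustrative):

```python
import math

def warm_restart_lr(epoch, base_lr=5e-4, first_period=10, period_mult=2, min_lr=0.0):
    """Learning rate at a 1-indexed epoch under cosine annealing with warm restarts."""
    t = epoch - 1              # epochs elapsed since training started
    period = first_period
    while t >= period:         # skip over completed cycles
        t -= period
        period *= period_mult
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t / period))

for epoch in (1, 2, 10, 11, 30, 31, 51, 76, 86):
    print(f"epoch {epoch:>2}: lr = {warm_restart_lr(epoch):.6f}")
```

In PyTorch this shape corresponds to `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)`, although this section does not show which scheduler the notebook actually uses. The restarts also line up with the temporary jumps in validation loss at epochs 11 and 31.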


# 4. Inference and Final Results

1. **Load configuration & model**  
   - Call `cfg = get_config()` and set `device = cfg["device"]`.  
   - Instantiate `FusionRefined3DReconstruction()` on the device and load `"final_model.pth"`.

2. **Recreate held‑out test split**  
   - Build the full `ChairSubsetDataset(cfg)`.  
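   - Compute `train/val/test` lengths (80%/10%/10%) and call `random_split(..., generator=torch.Generator().manual_seed(42))` to extract `test_ds`; `random_split` has no `seed` argument, so the seed is supplied through the generator.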

3. **Prepare test DataLoader**  
   - Wrap `test_ds` in a `DataLoader(batch_size=1, shuffle=True, ...)` for visualization.

4. **Define visualization helpers**  
   - **Inverse-normalization** transform to map tensor back to displayable RGB.  
   - **Camera settings**: reuse `cfg["camera_eye"]` and `cfg["camera_up"]`.  
   - **IoU function** to compute overlap between binarized prediction and ground truth.

5. **Inference & plotting loop** (for N samples)  
   - Run `with torch.no_grad(), autocast(...)` to get `logits = model(x_test)` and then `vox_pred = sigmoid(logits)`.  
   - Extract the corresponding `vox_true` from `y_true`.  
   - Compute **per-sample IoU** after binarizing the prediction at `cfg["threshold"]`: `(inter + smooth)/(union + smooth)`.  
   - **Invert** the input normalization to recover the 2D input image for display.  
   - **Marching cubes** on `vox_true` (level=0.5) and `vox_pred` (level=`cfg["threshold"]`) to get meshes.  
   - Render a **1×3 subplot**:  
     - **Left**: Input RGB view.  
     - **Center**: GT mesh as a light, semi-transparent grey surface (labelled "wireframe" in the plot).  
     - **Right**: Predicted mesh as a solid surface, with the IoU in the subplot title.

6. **Review & compare**  
   - Repeat for N test samples (drawn in random order, since the loader shuffles) to qualitatively and quantitatively assess reconstruction performance.  
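
The split arithmetic in step 2 truncates the train and validation sizes and gives the remainder to the test set, so the three lengths always sum to the dataset size. A small sketch with a made-up dataset size (the helper name `split_lengths` and the value 1003 are illustrative only):

```python
def split_lengths(total, train_frac=0.8, val_frac=0.1):
    """Truncate train/val sizes; the test split absorbs the rounding remainder."""
    train_len = int(train_frac * total)
    val_len = int(val_frac * total)
    test_len = total - train_len - val_len
    return train_len, val_len, test_len

train_len, val_len, test_len = split_lengths(1003)
print(train_len, val_len, test_len)   # 802 100 101
```

The test split matches the one held out during training only if the same lengths are passed to `random_split` with the same seeded generator, and only assuming the dataset enumerates its samples in the same order on both runs.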


In [21]:
# -----------------------------
# Inference on Held‑Out Test Set (Refactored)
# -----------------------------

# 1) Load config & model
cfg    = get_config()
device = cfg["device"]

model = FusionRefined3DReconstruction().to(device)
model.load_state_dict(torch.load("final_model.pth", map_location=device))
model.eval()

# 2) Rebuild & split dataset exactly as in training
ds = ChairSubsetDataset(cfg)
total      = len(ds)
train_len  = int(0.8 * total)
val_len    = int(0.1 * total)
test_len   = total - train_len - val_len
_, _, test_ds = random_split(
    ds, [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(42)
)

# 3) Test loader
test_loader = DataLoader(
    test_ds, batch_size=1, shuffle=True,
    num_workers=cfg["num_workers"], pin_memory=True
)

# 4) Inverse-normalization (for display) & camera
inv_norm = transforms.Normalize(
    mean=[-m/s for m,s in zip([0.485,0.456,0.406],[0.229,0.224,0.225])],
    std=[1/s    for s     in [0.229,0.224,0.225]]
)
camera   = dict(eye=cfg["camera_eye"], up=cfg["camera_up"])

# 5) Helper: compute IoU
def compute_iou(pred, target, thresh=0.5, smooth=1.):
    p = (pred >= thresh).astype(float)
    t = target.astype(float)
    inter = (p * t).sum()
    union = p.sum() + t.sum() - inter
    return (inter + smooth) / (union + smooth)

# 6) Loop & visualize
N = 5
for idx, (x_test, y_true) in enumerate(test_loader):
    if idx >= N: break

    x_test = x_test.to(device)
    with torch.no_grad(), torch.amp.autocast(device_type=device.type):
        logits    = model(x_test)
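        # cast to float32 first: autocast can return half or bfloat16, and NumPy cannot convert bfloat16
        vox_pred  = torch.sigmoid(logits.float())[0,0].cpu().numpy()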
        vox_true  = y_true[0,0].numpy()

    # compute IoU
    iou = compute_iou(vox_pred, vox_true, thresh=cfg["threshold"])

    # prepare the input image for display
    img_vis = inv_norm(x_test[0]).clamp(0,1).permute(1,2,0).cpu().numpy()
    img_uint = (img_vis * 255).astype("uint8")

    # extract meshes
    def extract_mesh(vox, level):
        verts, faces, _, _ = measure.marching_cubes(vox, level=level)
        x0, y0, z0 = verts.T
        # reorder vertex coordinates: array axis 2 becomes x, axis 1 stays y, axis 0 becomes z
        return (z0, y0, x0), faces.T

    (xt, yt, zt), (it, jt, kt) = extract_mesh(vox_true, level=0.5)
    (xp, yp, zp), (ip, jp, kp) = extract_mesh(vox_pred, level=cfg["threshold"])

    # 1×3 subplot: Input | GT (wireframe) | Pred (solid)
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=(
            "Input View",
            "GT Mesh (wireframe)",
            f"Pred Mesh (solid) — IoU: {iou:.2f}"
        ),
        specs=[[{"type":"image"}, {"type":"scene"}, {"type":"scene"}]]
    )
    # input
    fig.add_trace(go.Image(z=img_uint), row=1, col=1)

    # GT wireframe
    fig.add_trace(
        go.Mesh3d(
            x=xt, y=yt, z=zt, i=it, j=jt, k=kt,
            opacity=0.2, color="grey",
            flatshading=True, showscale=False
        ), row=1, col=2
    )

    # Prediction solid
    fig.add_trace(
        go.Mesh3d(
            x=xp, y=yp, z=zp, i=ip, j=jp, k=kp,
            opacity=0.7, color="steelblue",
            flatshading=True, showscale=False
        ), row=1, col=3
    )

    # layout
    fig.update_layout(
        title=f"Test Sample #{idx+1}",
        margin=dict(l=30, r=30, t=80, b=30),
        scene2=dict(camera=camera, aspectmode="data"),
        scene3=dict(camera=camera, aspectmode="data")
    )
    fig.show()
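
The `compute_iou` helper above adds a `smooth` term to both numerator and denominator. This avoids a division by zero when prediction and ground truth are both empty, but it also shifts every score upward slightly. A worked sketch on raw voxel counts (the function name `smoothed_iou` and the counts are illustrative, not taken from the experiment):

```python
def smoothed_iou(inter, union, smooth=1.0):
    """IoU from voxel counts, with the same additive smoothing as compute_iou."""
    return (inter + smooth) / (union + smooth)

# Both grids empty: plain IoU would be 0/0; smoothing returns 1.0
print(smoothed_iou(0, 0))
# No overlap over 9 occupied voxels: plain IoU is 0, smoothed is 0.1
print(smoothed_iou(0, 9))
# 500 shared voxels out of a 1000-voxel union: the shift is small
print(round(smoothed_iou(500, 1000), 4))   # 0.5005
```

For chair-sized objects in a 32³ grid the bias is small, and passing `smooth=0` gives the unsmoothed value whenever the union is non-empty.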
