# GHD Fitting Demo

This notebook demonstrates how to call the `ghd_fit.py` script from Python so you can iterate quickly inside a Jupyter environment.

### 1  Environment setup
Install any required third‑party libraries *if they’re not already available* in your runtime.

Alternatively, install everything from the provided `requirements.txt` file, e.g. by running `!pip install -r requirements.txt` in a notebook cell.

In [10]:
## If running on a fresh environment you may need:
# !pip install torch torchvision torchaudio
# !pip install numpy trimesh nibabel tqdm

import sys, os, json, torch, pathlib
print(f'Python {sys.version.split()[0]}  |  Torch {torch.__version__}')

Python 3.10.13  |  Torch 2.1.1
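To check which of the libraries above are still missing before installing anything, you can probe them with `importlib.util.find_spec`. This is a minimal sketch: the module list mirrors the commented `pip` lines above and assumes each import name matches its package name. `missing_modules` is a hypothetical helper, not part of `ghd_fit.py`.

```python
import importlib.util
import sys

# Import names taken from the commented pip lines above (assumption: import name == package name).
REQUIRED = ["torch", "torchvision", "torchaudio", "numpy", "trimesh", "nibabel", "tqdm"]

def missing_modules(names):
    """Return the subset of `names` that cannot be found in this runtime."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_modules(REQUIRED)
if missing:
    # Build (but do not run) a pip command that targets the current interpreter.
    cmd = [sys.executable, "-m", "pip", "install", *missing]
    print("Missing:", ", ".join(missing))
    print("Install with:", " ".join(cmd))
else:
    print("All required modules are available.")
```

Using `sys.executable -m pip` ensures that packages are installed into the same interpreter the notebook kernel is running.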


### 2  Dataset Structure & Naming Convention

This notebook fits parametric meshes to fetal cardiac segmentation masks using **GHD**.  
Before running, make sure your data follows the structure below.

---

####  Folder Structure 

Each **case** is stored as its own sub-folder under the directory passed as `--data_root` (`DATA_ROOT` in section 3).
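As a quick check that your data root has this layout, the sketch below lists its sub-folders and validates a requested case list. It assumes, without confirmation from `ghd_fit.py`, that every sub-directory is one case and that the names in `CASES` (section 3) are those folder names. `list_cases` is a hypothetical helper.

```python
import pathlib

def list_cases(data_root, cases=None):
    """Return sorted case folder names under data_root.

    Hypothetical helper: treats every sub-directory as one case, and treats
    an empty or missing `cases` list as "all cases".
    """
    root = pathlib.Path(data_root)
    found = sorted(p.name for p in root.iterdir() if p.is_dir())  # ignore stray files
    if not cases:
        return found
    unknown = set(cases) - set(found)
    if unknown:
        raise FileNotFoundError(f"Cases not found under {root}: {sorted(unknown)}")
    wanted = set(cases)
    return [name for name in found if name in wanted]
```

After running the parameter cell, `list_cases(DATA_ROOT, CASES)` should return `["MITEA_005_scan1"]` if the example data is in place.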

####  File Naming Rules

- Filenames must include a **case ID** and a **time tag**.
  - Example: `MITEA_005_scan1_ED_true_mask.nii.gz`
  - Example: `FeEcho4D_017_time003_true_mask.nii.gz`
- Time tag must be one of:
  - `ED`, `ES` → End-Diastole / End-Systole (2-frame mode)
  - `timeXYZ` → zero-padded frame number in a 3D+T sequence (full cycle), e.g. `time003`
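The naming rules above can be checked with a regular expression. This is a minimal sketch: the rules only require that a filename contain a case ID and a time tag, so the fixed `_true_mask.nii.gz` suffix and the three-digit frame number are assumptions inferred from the two example names. `parse_mask_name` is a hypothetical helper, not the parser used by `ghd_fit.py`.

```python
import re

# Assumed pattern: <case_id>_<time_tag>_true_mask.nii.gz, with time_tag in {ED, ES, timeXYZ}.
MASK_RE = re.compile(r"^(?P<case>.+)_(?P<time>ED|ES|time\d{3})_true_mask\.nii\.gz$")

def parse_mask_name(filename):
    """Split a mask filename into (case_id, time_tag); raise ValueError if it does not match."""
    match = MASK_RE.match(filename)
    if match is None:
        raise ValueError(f"Unrecognized mask filename: {filename}")
    return match.group("case"), match.group("time")
```

Because the time tag must be followed directly by `_true_mask`, the greedy `.+` can only stop in one place, so case IDs that themselves contain underscores (such as `MITEA_005_scan1`) are still split correctly.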

####  Time Selection

- `--times` accepts tags like:
  - `ED ES` → process only ED and ES frames
  - `time001-010` → process time001 to time010
- If not specified, **all time points** are used.
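A range tag such as `time001-010` has to be expanded into individual frame tags at some point. The sketch below shows one way to do this, assuming inclusive ranges and three-digit frame numbers. `expand_times` is hypothetical, and `ghd_fit.py` may handle ranges differently.

```python
import re

# Assumed range syntax: "timeXYZ-ABC", both ends inclusive, three digits each.
RANGE_RE = re.compile(r"^time(\d{3})-(\d{3})$")

def expand_times(tags):
    """Expand e.g. ["ED", "time001-003"] into ["ED", "time001", "time002", "time003"].

    An empty list is returned unchanged, which callers interpret as "all time points".
    """
    out = []
    for tag in tags:
        match = RANGE_RE.match(tag)
        if match:
            start, stop = int(match.group(1)), int(match.group(2))
            if start > stop:
                raise ValueError(f"Empty range: {tag}")
            out.extend(f"time{i:03d}" for i in range(start, stop + 1))
        else:
            out.append(tag)  # "ED", "ES", or a single "timeXYZ" passes through
    return out
```

For example, `expand_times(["time001-010"])` yields ten tags, from `time001` through `time010`.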

### 3  Import helper and set parameters

The cell below exposes the most commonly tuned settings. To customize anything else, such as further optimization settings, loss weights, or canonical mesh paths, edit it directly in `ghd_fit.py`.

In [None]:
import pathlib
import torch

from ghd_fit import run     # adjust this import if the script lives elsewhere

# --- Paths ------------------------------------------------------------------
DATA_ROOT = pathlib.Path("data_example")       # folder with case sub-dirs
MESH_OUT  = pathlib.Path("meshes_out_demo")    # save location for OBJ meshes
MESH_OUT.mkdir(parents=True, exist_ok=True)

# --- Case / time selection --------------------------------------------------
CASES = ["MITEA_005_scan1"]  # [] → all cases in DATA_ROOT
TIMES = ["ED"]               # [] → all frames
# TIMES also accepts "ES", single frames like "time001", and ranges like "time001-005"

# --- Device -----------------------------------------------------------------
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# --- Core hyper-parameters ---------------------------------------------------
MYO_IDX      = 1       # myocardium entry in point_list (FeEcho4D=2, MITEA=1)
NUM_ITER     = 600     # morphing iterations
MESH_SAMPLES = 4000    # points for CPD registration
LR_START     = 5e-3    # initial learning-rate

# Loss weights
LOSS_OCC   = 1.0
LOSS_NORM  = 0.01
LOSS_LAP   = 0.01
LOSS_THICK = 0.01

PC_SPACING = 200       # µm spacing used by point_cloud_extractor

### 4  Launch fitting

In [None]:
# --- Build CLI argument list for ghd_fit.py ---------------------------------
# Every value is passed as a string, exactly as it would be on the command line.
args = [
    "--data_root",               str(DATA_ROOT),
    "--mesh_out",                str(MESH_OUT),
    "--device",                  DEVICE,
    "--myo_idx",                 str(MYO_IDX),
    "--num_iter",                str(NUM_ITER),
    "--mesh_samples",            str(MESH_SAMPLES),
    "--lr_start",                str(LR_START),
    "--loss_occupancy",          str(LOSS_OCC),
    "--loss_normal_consistency", str(LOSS_NORM),
    "--loss_laplacian",          str(LOSS_LAP),
    "--loss_thickness",          str(LOSS_THICK),
    "--pc_spacing",              str(PC_SPACING),
]
if CASES:                        # omit the flag to process all cases
    args += ["--cases", *CASES]
if TIMES:                        # omit the flag to process all time points
    args += ["--times", *TIMES]

print("Argument list:\n" + " ".join(args))
run(args)
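`run(args)` receives the same flat list of strings you would type on the command line. To illustrate how such a list is typically consumed, here is a hypothetical `argparse` parser declaring the flags used above. The real parser in `ghd_fit.py` may use different types, defaults, or extra options, and the defaults below are copied from this notebook rather than from the script. In this sketch, declaring `--cases` and `--times` with `nargs="+"` is what allows `*CASES` and `*TIMES` to be appended as several separate values.

```python
import argparse

def build_parser():
    """Hypothetical parser mirroring the flags above (defaults copied from this notebook)."""
    p = argparse.ArgumentParser(description="GHD fitting (illustrative sketch)")
    p.add_argument("--data_root", required=True)
    p.add_argument("--mesh_out", required=True)
    p.add_argument("--device", default="cpu")
    p.add_argument("--myo_idx", type=int, default=1)
    p.add_argument("--num_iter", type=int, default=600)
    p.add_argument("--mesh_samples", type=int, default=4000)
    p.add_argument("--lr_start", type=float, default=5e-3)
    p.add_argument("--loss_occupancy", type=float, default=1.0)
    p.add_argument("--loss_normal_consistency", type=float, default=0.01)
    p.add_argument("--loss_laplacian", type=float, default=0.01)
    p.add_argument("--loss_thickness", type=float, default=0.01)
    p.add_argument("--pc_spacing", type=float, default=200.0)
    # nargs="+" collects one or more values until the next flag; [] means "all".
    p.add_argument("--cases", nargs="+", default=[])
    p.add_argument("--times", nargs="+", default=[])
    return p
```

With this sketch, `build_parser().parse_args(args)` would turn the strings back into typed values, e.g. `"0.005"` into the float `0.005`.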

### 5  Inspect results
The fitted OBJ meshes are written to `MESH_OUT` (`meshes_out_demo/` in this demo).
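For a quick sanity check of the output, the sketch below counts vertex (`v`) and face (`f`) records in every OBJ file under the output folder using plain Python. Because the layout inside the output folder (flat or per-case sub-folders) is not specified here, it searches recursively. `obj_stats` and `summarize_meshes` are hypothetical helpers. For real analysis, `trimesh.load` (installed above) returns full mesh objects.

```python
import pathlib

def obj_stats(path):
    """Count vertex ('v') and face ('f') records in a Wavefront OBJ file."""
    n_vertices = n_faces = 0
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            parts = line.split(maxsplit=1)   # first token is the record type
            if not parts:
                continue                     # skip blank lines
            if parts[0] == "v":
                n_vertices += 1
            elif parts[0] == "f":
                n_faces += 1
    return n_vertices, n_faces

def summarize_meshes(mesh_dir):
    """Print and return {relative path: (vertices, faces)} for every OBJ under mesh_dir."""
    root = pathlib.Path(mesh_dir)
    results = {}
    for obj in sorted(root.rglob("*.obj")):  # recursive: no folder layout is assumed
        rel = str(obj.relative_to(root))
        results[rel] = obj_stats(obj)
        n_v, n_f = results[rel]
        print(f"{rel}: {n_v} vertices, {n_f} faces")
    return results
```

Calling `summarize_meshes(MESH_OUT)` after the fitting cell lists one line per mesh. Identical vertex counts across frames of the same case are a cheap hint that all meshes were morphed from the same template.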