# CRuDeNet Full Pipeline: DeepCAD → ConvGRU

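This notebook demonstrates the complete workflow:
1. Train DeepCAD-RT (bundled in crudenet) to produce denoised stacks that serve as pseudo-ground truth
2. Run DeepCAD-RT inference on the raw data to generate that pseudo-ground truth
3. Train a ConvGRU using the DeepCAD-RT output as its training target
4. Run ConvGRU inference on the raw data (optionally scoring it against the DeepCAD-RT output)
5. Export the trained ConvGRU to ONNX for real-time inference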


In [1]:
# DeepCAD-RT is bundled inside crudenet, so no separate DeepCAD import is needed.
# Its training/inference helpers are imported from crudenet in the Setup cell below.

## Setup and Imports


In [2]:
import os
import glob
from pathlib import Path

import numpy as np
import tifffile

# CRuDeNet imports - DeepCAD is now bundled inside crudenet!
from crudenet import (
    train_deepcad,
    test_deepcad,
    get_deepcad_output_path,
    train as train_convgru,
    infer_stream,
    export_onnx,
)



## Configuration


In [None]:
# Paths
# Folder passed to DeepCAD as `datasets_path`; it must contain `raw_tif_file`.
# Set CRUDENET_RAW_DATA_DIR to point at your own data instead of editing this
# user-specific cluster path (which is kept as the default).
raw_data_path = os.environ.get(
    "CRUDENET_RAW_DATA_DIR", "/scratch.global/iyer0106/KaraLab/3PVisStim"
)
raw_tif_file = "3PVisStim.tif"  # Specific file name used for ConvGRU training/inference

# Output directories
output_root = "./outputs"
deepcad_pth_dir = os.path.join(output_root, "deepcad_pth")
deepcad_results_dir = os.path.join(output_root, "deepcad_results")
convgru_output_dir = os.path.join(output_root, "convgru")

# exist_ok=True makes this idempotent, so re-running the cell is safe.
for _out_dir in (output_root, deepcad_pth_dir, deepcad_results_dir, convgru_output_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Model name passed as `denoise_model` to test_deepcad / get_deepcad_output_path.
# NOTE(review): train_deepcad is not given this name — confirm it matches the
# model folder that training writes under deepcad_pth_dir.
model_name = "BrainSlice"

# Training parameters
deepcad_epochs = 10
convgru_epochs = 10

print(f"Output root: {output_root}")
print(f"Raw data: {os.path.join(raw_data_path, raw_tif_file)}")


## Step 1: Train DeepCAD-RT


In [None]:
banner = "=" * 60
print(banner)
print("Step 1: Training DeepCAD-RT")
print(banner)

# All DeepCAD-RT training settings in one place. Step 2 passes the same
# patch_xy / patch_t / overlap_factor / fmap values for inference.
deepcad_train_kwargs = dict(
    datasets_path=raw_data_path,
    pth_dir=deepcad_pth_dir,
    n_epochs=deepcad_epochs,
    patch_xy=150,
    patch_t=8,
    overlap_factor=0.4,
    train_datasets_size=2500,
    fmap=16,
    GPU="0",
    num_workers=0,  # Set to 0 on Windows
    save_test_images_per_epoch=True,
)
deepcad_config = train_deepcad(**deepcad_train_kwargs)

print("\n✅ DeepCAD training completed!")


## Step 2: Run DeepCAD Inference to Generate Ground Truth


In [None]:
print("=" * 60, "Step 2: Running DeepCAD inference", "=" * 60, sep="\n")

# Inference settings. patch_xy / patch_t / overlap_factor / fmap repeat the
# values used for training in Step 1 (presumably they must match — keep in sync).
deepcad_test_kwargs = {
    "datasets_path": raw_data_path,
    "denoise_model": model_name,
    "pth_dir": deepcad_pth_dir,
    "output_dir": deepcad_results_dir,
    "patch_xy": 150,
    "patch_t": 8,
    "overlap_factor": 0.4,
    "test_datasize": 500,
    "fmap": 16,
    "GPU": "0",
    "num_workers": 0,
}
test_config, deepcad_output_path = test_deepcad(**deepcad_test_kwargs)

print(f"\nDeepCAD output path: {deepcad_output_path}")


In [None]:
# Resolve the DeepCAD output file that Step 3 uses as the ConvGRU target (`gt_path`).
gt_output_file = get_deepcad_output_path(
    datasets_path=raw_data_path,
    denoise_model=model_name,
    output_dir=deepcad_results_dir,
)

print(f"Ground truth file: {gt_output_file}")

if not os.path.exists(gt_output_file):
    print(f"⚠️  Warning: GT file not found at {gt_output_file}")
    # Fall back to the most recently modified *_output.tif anywhere under the
    # results directory.
    # NOTE(review): if several runs share deepcad_results_dir this may pick up
    # output from a different dataset/model — confirm it is the intended file.
    pattern = os.path.join(deepcad_results_dir, "**", "*_output.tif")
    candidates = glob.glob(pattern, recursive=True)
    if not candidates:
        # Previously the cell only warned here, leaving gt_output_file pointing at
        # a missing path, and ConvGRU training would proceed on it.
        raise FileNotFoundError(
            f"No DeepCAD output found at {gt_output_file} or matching {pattern}; "
            "run Step 2 before training ConvGRU."
        )
    gt_output_file = max(candidates, key=os.path.getmtime)
    print(f"Found most recent output: {gt_output_file}")

# Load in both branches so the chosen file (expected or fallback) is verified.
gt_stack = tifffile.imread(gt_output_file)
print(f"✅ Loaded GT stack: {gt_stack.shape}")


## Step 3: Train ConvGRU Using DeepCAD Output as Ground Truth


In [None]:
print("=" * 60)
print("Step 3: Training ConvGRU on DeepCAD output")
print("=" * 60)

raw_file_path = os.path.join(raw_data_path, raw_tif_file)

# Fail fast: report missing inputs before training starts, rather than from
# somewhere inside train_convgru.
missing_inputs = [p for p in (raw_file_path, gt_output_file) if not os.path.exists(p)]
if missing_inputs:
    raise FileNotFoundError(f"ConvGRU training inputs not found: {missing_inputs}")

# Train ConvGRU. Steps 4 and 5 pass the same c_hid / num_layers; keep them in
# sync (presumably required to rebuild the model from this checkpoint).
convgru_ckpt = train_convgru(
    raw_path=raw_file_path,
    gt_path=gt_output_file,
    out_dir=convgru_output_dir,
    epochs=convgru_epochs,
    lr=1e-4,
    batch_size=2,
    patch_t=8,
    patch_xy=128,
    c_hid=32,
    num_layers=3,
    num_workers=0,  # Set to 0 on Windows
)

print("\n✅ ConvGRU training completed!")
print(f"Checkpoint saved to: {convgru_ckpt}")


## Step 4: Test ConvGRU Inference


In [None]:
banner = "=" * 60
print(banner)
print("Step 4: Running ConvGRU inference")
print(banner)

# Run the trained checkpoint over the raw stack. gt_path is optional and only
# used for metrics; c_hid / num_layers mirror the training cell.
infer_kwargs = {
    "raw_path": raw_file_path,
    "ckpt_path": convgru_ckpt,
    "out_dir": convgru_output_dir,
    "gt_path": gt_output_file,
    "c_hid": 32,
    "num_layers": 3,
}
denoised_output = infer_stream(**infer_kwargs)

print("\n✅ Inference completed!")
print(f"Denoised output: {denoised_output}")


## Step 5: Export ConvGRU to ONNX


In [None]:
print("=" * 60, "Step 5: Exporting ConvGRU to ONNX", "=" * 60, sep="\n")

# Export settings: patch_t / patch_xy / c_hid / num_layers repeat the values
# used when training the checkpoint in Step 3.
onnx_export_kwargs = dict(
    ckpt_path=convgru_ckpt,
    out_dir=convgru_output_dir,
    patch_t=8,
    patch_xy=128,
    c_hid=32,
    num_layers=3,
    opset=13,
)
onnx_path = export_onnx(**onnx_export_kwargs)

print("\n✅ ONNX export completed!")
print(f"ONNX model: {onnx_path}")


## Summary


In [None]:
print("=" * 60, "Pipeline Summary", "=" * 60, sep="\n")

# Every artifact produced by the pipeline, in step order.
pipeline_artifacts = [
    ("DeepCAD checkpoints", deepcad_pth_dir),
    ("DeepCAD GT output", gt_output_file),
    ("ConvGRU checkpoint", convgru_ckpt),
    ("ConvGRU denoised output", denoised_output),
    ("ONNX model", onnx_path),
]
print()  # blank line between the banner and the numbered list
for step_no, (label, artifact) in enumerate(pipeline_artifacts, start=1):
    print(f"{step_no}. {label}: {artifact}")
print("\n✅ Full pipeline completed successfully!")
