# In [None]:
# X3Py CT Volume Registration Test

import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
import h5py
import os

from x3py.x3reg import ITKregistration
from x3py.config import ProInit
from x3py.utils import writeh5REG

# --- Load volumes ---
def load_volume(path, dataset='exchange/data'):
    """Read a volume dataset from an HDF5 file as a float32 array.

    Parameters
    ----------
    path : str
        Path to the HDF5 file.
    dataset : str, optional
        Internal HDF5 path of the volume dataset. Defaults to
        ``'exchange/data'``, the key used by the files in this notebook.

    Returns
    -------
    numpy.ndarray
        The full contents of the dataset, cast to ``np.float32``.

    Raises
    ------
    KeyError
        If ``dataset`` is missing from the file or names something other
        than a dataset (e.g. a group).
    """
    with h5py.File(path, 'r') as f:
        node = f.get(dataset)
        # ``get``/``in`` also match groups; only a Dataset supports ``[()]``.
        if not isinstance(node, h5py.Dataset):
            raise KeyError(
                f"Expected dataset '{dataset}' not found in HDF5 file: {path}"
            )
        vol = node[()]
    # ``node[()]`` already returned a fresh in-memory array, so skip a second
    # full-volume copy when the data is float32 already.
    return vol.astype(np.float32, copy=False)

ref_path = '/data/2023-03-Xu_rec/NNO2_1458_rec.h5'
mov_path = '/data/2023-03-Xu_rec/NNO2_1488_rec.h5'

ref_vol = load_volume(ref_path)
mov_vol = load_volume(mov_path)

if ref_vol.shape != mov_vol.shape:
    # The registration resamples the moving volume onto the reference grid,
    # so differing shapes are allowed; flag them because the preview below
    # compares slices by index.
    print(f"Warning: volume shapes differ: reference {ref_vol.shape}, "
          f"moving {mov_vol.shape}")

# --- Show middle slice before registration ---
# Middle z-index of the reference; also reused for the post-registration plot.
mz = ref_vol.shape[0] // 2
# Clamp so a moving volume with fewer slices does not raise IndexError here.
mz_mov = min(mz, mov_vol.shape[0] - 1)
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.imshow(ref_vol[mz], cmap='gray')
plt.title('Reference Volume Slice')
plt.subplot(1, 2, 2)
plt.imshow(mov_vol[mz_mov], cmap='gray')
plt.title('Moving Volume Slice')
plt.show()

# --- Prepare parameters ---
# Passed to x3py's ITKregistration (and to the custom registration function).
# NOTE(review): key names are presumably dictated by x3py, including the
# inconsistent 'Start_energy' / 'EndEnergy' spelling -- confirm before renaming.
params = {
    'CorrectionPlane': 'XYZ',  # Full volume
    'CROP': False,
    # Crop bounds span the whole reference volume (z, y, x).
    'sz': 0, 'ez': ref_vol.shape[0],
    'sy': 0, 'ey': ref_vol.shape[1],
    'sx': 0, 'ex': ref_vol.shape[2],
    'FDreg': 1e-6,
    'ITKcustom': True,  # Will use ReadCustomITKReg
    'Start_energy': 1.0,
    'EndEnergy': 1.0,
    'epoints': 2,
    'DataPrefix': '/path/to/',  # unused here
}

# --- Define a minimal registration function ---
def run_custom_registration(fixed, moving, par):
    """Align ``moving`` to ``fixed`` with a translation and resample it.

    Optimizes a translation-only transform using a mean-squares metric and
    regular-step gradient descent. It then resamples ``moving`` onto the grid
    of ``fixed`` with linear interpolation. Output voxels that map outside
    ``moving`` are set to 0.

    Parameters
    ----------
    fixed, moving : SimpleITK.Image
        Reference and moving images of the same dimension.
    par : mapping or None
        Optional optimizer overrides:
        ``'ITKlearningRate'`` (default 2.0), ``'ITKminStep'`` (default 1e-4),
        ``'ITKiterations'`` (default 100). All other keys are ignored, so the
        general ``params`` dict or an empty dict can be passed unchanged.

    Returns
    -------
    tuple of (numpy.ndarray, SimpleITK.Transform)
        The resampled moving image as an array, on the ``fixed`` grid with
        the pixel type of ``moving``, and the optimized transform.
    """
    from collections.abc import Mapping

    # Tolerate None or non-mapping values for ``par``: fall back to defaults.
    opts = par if isinstance(par, Mapping) else {}

    # The registration metric expects real-valued images (float32/float64).
    # Valid float inputs are used as-is; anything else is cast to float32.
    real_types = (sitk.sitkFloat32, sitk.sitkFloat64)
    if fixed.GetPixelID() == moving.GetPixelID() and fixed.GetPixelID() in real_types:
        fixed_reg, moving_reg = fixed, moving
    else:
        fixed_reg = sitk.Cast(fixed, sitk.sitkFloat32)
        moving_reg = sitk.Cast(moving, sitk.sitkFloat32)

    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsMeanSquares()
    R.SetOptimizerAsRegularStepGradientDescent(
        learningRate=float(opts.get('ITKlearningRate', 2.0)),
        minStep=float(opts.get('ITKminStep', 1e-4)),
        numberOfIterations=int(opts.get('ITKiterations', 100)),
    )
    R.SetInterpolator(sitk.sitkLinear)
    tx = sitk.TranslationTransform(fixed.GetDimension())
    R.SetInitialTransform(tx)
    outTx = R.Execute(fixed_reg, moving_reg)

    # Resample the original moving image so the output keeps its pixel type.
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(outTx)

    out = resampler.Execute(moving)
    return sitk.GetArrayFromImage(out), outTx

# --- Wrap volumes and run ---
# No spacing/origin is set, so SimpleITK defaults apply; the translation is
# therefore presumably in voxel units -- confirm if physical units matter.
fixed = sitk.GetImageFromArray(ref_vol)
moving = sitk.GetImageFromArray(mov_vol)

itkr = ITKregistration([ref_path, 'ITKCustom'], params)
itkr.ReadCustomITKReg(run_custom_registration, {})
# NOTE(review): the registration below calls run_custom_registration directly,
# so `itkr` is not used to execute it -- confirm which path is intended.

out, _ = run_custom_registration(fixed, moving, params)

# --- Save output ---
# Build "<name>_registered<ext>" from the moving file name. Splitting off only
# the final extension avoids rewriting '.h5' elsewhere in the path. It also
# guarantees out_path differs from mov_path, so the input is never overwritten.
root, ext = os.path.splitext(mov_path)
out_path = f"{root}_registered{ext or '.h5'}"
writeh5REG(out_path, out, energy=1.0)

# --- Show result ---
# `out` is on the reference grid, so the reference slice index applies.
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.imshow(ref_vol[mz], cmap='gray')
plt.title('Reference')
plt.subplot(1, 2, 2)
plt.imshow(out[mz], cmap='gray')
plt.title('Aligned')
plt.show()

print("Done.")
