In [None]:
import csv
import os
import nrrd
import nibabel as nib
from os import path
import matplotlib.pyplot as plt
from utils.img_operations import MR_normalize,CT_normalize
# Read the (moving, fixed) filename pairs for this task from the dataset CSV.
pair_list = []
dataset_dir = 'sample_dataset_dir'
# The csv module expects files opened with newline='' so that quoted fields
# containing line breaks (and \r\n endings) are parsed correctly.
with open(path.join(dataset_dir, 'task1_diff_time_same_seq.csv'), 'r', newline='') as file:
    pair_list.extend(csv.reader(file))

In [None]:
# Show the rows read from the CSV; the next cell uses column 0 as the moving
# filename and column 1 as the fixed filename.
pair_list

## Visualize the moving and fixed images

In [None]:
# Row of the CSV to visualize. NOTE(review): row 0 may be a header row, and the
# original note also mentions index 7 — confirm which pair is intended.
PAIR_INDEX = 1
pair = pair_list[PAIR_INDEX]
moving_fn = pair[0]  # moving image (original note: the template to be applied to other cases)
fixed_fn = pair[1]
print(moving_fn, fixed_fn)

# Filename stems with the directory and every extension removed
# (e.g. 'case.nii.gz' -> 'case'); the registration output filenames are built from these.
fixed_basename, moving_basename = (
    os.path.basename(fn).partition('.')[0] for fn in (fixed_fn, moving_fn)
)

# Load the processed fixed and moving volumes of the selected pair.
source = 'img'
processed_dir = path.join(dataset_dir, f'{source}_processed')
img_fixed = nib.load(path.join(processed_dir, fixed_fn))
img_moving = nib.load(path.join(processed_dir, moving_fn))

arr_fixed, arr_moving = img_fixed.get_fdata(), img_moving.get_fdata()

# Slice index along axis 2 used for both previews.
# NOTE(review): assumes both volumes have more than 128 slices along that axis.
z_fixed = z_moving = 128
print('image shape', arr_fixed.shape)


# NOTE(review): the moving affine is not used in the visible cells.
aff_mov = img_moving.affine

# Side-by-side preview of one slice (along axis 2) of each input volume.
fig, axs = plt.subplots(1, 2, figsize=(15, 5))

axs[0].imshow(arr_fixed[:, :, z_fixed], cmap='gray')
axs[0].set_title("Fixed Image (z={})".format(z_fixed))

# Index with z_moving rather than a hard-coded slice so the image shown matches its title.
axs[1].imshow(arr_moving[:, :, z_moving], cmap='gray')
axs[1].set_title("Moving Image (z={})".format(z_moving))
plt.show()  # render without echoing the last set_title() return value

## Visualize registration results

In [None]:
import os
import nibabel as nib

# Identify the registration run to inspect; these values must match the
# directory and filenames that run wrote.
exp_note = 'breast_same-seq-multi-time'
model = 'MIND'
fs = 32  # appears as `fmsize` in the run directory name
output_dir = f'output/foundReg-model-{model}-2smooth-1000iter-itersmoothK7R5-lr3-fmd1-fmsize{fs}-noconvex/' + exp_note

# Paths of the warped image and the displacement field for the selected pair.
warped_path, disp_path = (
    os.path.join(output_dir, template.format(moving_basename, fixed_basename, exp_note))
    for template in ('{}_to_{}_warped_{}.nii.gz', '{}_to_{}_disp_{}.nii.gz')
)
warped_image = nib.load(warped_path)
disp_img = nib.load(disp_path)  # NOTE(review): loaded but not used in the visible cells
arr_warped = warped_image.get_fdata()

# Middle index along axis 2 of each volume (replaces the slice indices set in the earlier cell).
z_fixed, z_moving, z_warped = (vol.shape[2] // 2 for vol in (arr_fixed, arr_moving, arr_warped))

# Optional intensity normalization of the warped volume (disabled):
#arr_warped  = MR_normalize(arr_warped)
#arr_warped  = CT_normalize(arr_warped,50,400)
# Normalize slices for visualization
def normalize_slice(slice_2d):
    slice_2d = slice_2d - slice_2d.min()
    if slice_2d.max() > 0:
        slice_2d = slice_2d / slice_2d.max()
    return slice_2d

# Plotting: fixed | moving | warped, each at its slice index along axis 2 and
# passed through normalize_slice before display.
panels = [
    ("Fixed Image", arr_fixed, z_fixed),
    ("Moving Image", arr_moving, z_moving),
    ("Warped Image", arr_warped, z_warped),
]
fig, axs = plt.subplots(1, len(panels), figsize=(15, 5))
for ax, (label, volume, z) in zip(axs, panels):
    ax.imshow(normalize_slice(volume[:, :, z]), cmap='gray')
    ax.set_title("{} (z={})".format(label, z))
plt.show()  # render without echoing the last set_title() return value

## Visualize saved feature maps (requires the `vis/` outputs of the registration run)

In [None]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

# Load the saved fix and mov feature-map volumes from <output_dir>/vis.
# Both sides are keyed by the fixed image's filename stem.
base_name = fixed_basename
vis_dir = os.path.join(output_dir, 'vis')
fix_vols = []
for i in range(1):
    fix_vols.append(nib.load(os.path.join(vis_dir, f"{base_name}_fix_{i}.nii.gz")).get_fdata())
mov_vols = []
for i in range(1):
    mov_vols.append(nib.load(os.path.join(vis_dir, f"{base_name}_mov_{i}.nii.gz")).get_fdata())

# Normalize volumes to [0, 1] for RGB display
def normalize(vol):
    vol = vol.astype(np.float32)
    return (vol - np.min(vol)) / (np.max(vol) - np.min(vol) + 1e-8)

# Rescale every feature-map volume to [0, 1].
fix_vols = list(map(normalize, fix_vols))
mov_vols = list(map(normalize, mov_vols))

# Middle index along axis 2 of the first volume on each side.
z_fix = fix_vols[0].shape[2] // 2
z_mov = mov_vols[0].shape[2] // 2

# Stack the chosen slice of each (here: the first) volume along a new trailing axis.
# NOTE(review): the later `[:, :, :, 0]` indexing implies each volume is 4-D
# (H, W, D, C) — confirm against how the maps are saved.
rgb_fix = np.stack([vol[:, :, z_fix] for vol in fix_vols[:1]], axis=-1)
rgb_mov = np.stack([vol[:, :, z_mov] for vol in mov_vols[:1]], axis=-1)

# Plot side-by-side: fixed vs. moving feature-map slice.
fix_slice = rgb_fix[:, :, :, 0]
mov_slice = rgb_mov[:, :, :, 0]

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, image, title in zip(axes, (fix_slice, mov_slice),
                            ("Fix Volume - Center Slice", "Mov Volume - Center Slice")):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
fig.tight_layout()
plt.show()

# The same two slices again, each in its own untitled figure.
for image in (fix_slice, mov_slice):
    plt.figure()
    plt.imshow(image)
    plt.axis('off')
    plt.show()