In [None]:
import pandas as pd
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import map_coordinates

In [None]:
import csv
import os
import nrrd
import nibabel as nib
from os import path
import matplotlib.pyplot as plt
from utils.img_operations import MR_normalize,CT_normalize
pair_list = []
dataset_dir = 'sample_dataset_dir'

# Read the registration pairs verbatim. csv.reader returns every row of the
# file, including any header row, so indices into pair_list count that row too.
with open(path.join(dataset_dir, 'task1_diff_time_same_seq.csv'), 'r') as csv_file:
    pair_list.extend(csv.reader(csv_file))
def warp_arr(arr, disp, order=0):
    """Resample an array through a dense, channels-first displacement field.

    Output element ``p`` takes the value of ``arr`` sampled at ``p + disp[:, p]``
    (coordinates in voxel/index units). Coordinates falling outside the array
    are clamped to the nearest edge element (``mode='nearest'``).

    Args:
        arr: N-D array (the notebook uses 3-D volumes of shape (D, H, W)).
        disp: displacement field of shape ``(arr.ndim, *arr.shape)``, one
            component per axis of ``arr``; spatial axes may be any shape that
            broadcasts to ``arr.shape`` (e.g. ``(3, 1, 1, 1)`` for a constant shift).
        order: spline order passed to ``map_coordinates``. The default 0
            (nearest neighbour) never invents new label values, which is what
            segmentation masks need; use 1 for intensity images.

    Returns:
        Array with the shape of ``arr`` holding the warped values.

    Raises:
        ValueError: if ``disp`` is not channels-first with ``arr.ndim``
            components, or its spatial shape does not broadcast to ``arr``.
    """
    arr = np.asarray(arr)
    disp = np.asarray(disp)
    # Without this check a displacement with the wrong layout (e.g. a single
    # (D, H, W) component) would silently broadcast onto every axis.
    if disp.ndim != arr.ndim + 1 or disp.shape[0] != arr.ndim:
        raise ValueError(
            f'disp must be channels-first with shape {(arr.ndim,) + arr.shape} '
            f'(or broadcastable to it), got {disp.shape}')
    # Identity sampling grid stacked channels-first: identity[k] holds the
    # index along axis k (same as meshgrid(..., indexing='ij')).
    identity = np.indices(arr.shape)
    return map_coordinates(arr, identity + disp, order=order, mode='nearest')

def dice_coefficient(mask1, mask2):
    """Dice similarity coefficient between two masks: 2|A∩B| / (|A| + |B|).

    Every non-zero element is treated as foreground. Float label maps and
    label values other than 1 are therefore counted correctly, instead of
    their values being summed.

    Args:
        mask1, mask2: arrays of the same (or broadcast-compatible) shape.

    Returns:
        Dice score in [0, 1]; 1.0 when both masks are empty (they agree that
        there is nothing to segment).
    """
    fg1 = np.asarray(mask1).astype(bool)
    fg2 = np.asarray(mask2).astype(bool)
    size1 = np.count_nonzero(fg1)
    size2 = np.count_nonzero(fg2)
    if size1 + size2 == 0:
        return 1.0  # Both masks are empty
    intersection = np.count_nonzero(fg1 & fg2)
    return 2.0 * intersection / (size1 + size2)

def show_overlay(ax, image, mask, target_classes=(1,), title='', alpha=0.4):
    """Draw a grayscale 2-D slice with selected mask labels overlaid in red.

    Args:
        ax: matplotlib Axes to draw into.
        image: 2-D array shown with the ``gray`` colormap.
        mask: label array aligned with ``image``; values are truncated to int
            before being compared with ``target_classes``.
        target_classes: a single int label, or an iterable of labels, to highlight.
        title: axes title.
        alpha: opacity of the overlay.
    """
    ax.imshow(image, cmap='gray')

    # Generate binary mask for specified classes
    if isinstance(target_classes, int):
        target_classes = [target_classes]
    binary_mask = np.isin(mask.astype(int), target_classes)

    # Overlay in red. Background pixels are masked out, so they stay transparent.
    # 'bwr' maps 1.0 to pure red, and pinning vmin/vmax stops the constant
    # foreground value from collapsing to the bottom of the colormap.
    overlay = np.ma.masked_where(~binary_mask, binary_mask.astype(float))
    ax.imshow(overlay, cmap='bwr', vmin=0.0, vmax=1.0, alpha=alpha)

    ax.axis('off')
    ax.set_title(title)

# Visualize one case with its mask overlay

In [None]:
# Pick one registration pair to inspect. csv.reader keeps every row of the CSV
# (including any header row), so index 3 is the 4th row of the file.
pair = pair_list[3]
moving_fn, fixed_fn = pair[0], pair[1]  # the moving volume is registered onto the fixed one

# Output files are keyed on the file stem: everything before the first '.'
# (so 'case.nii.gz' -> 'case').
fixed_basename = os.path.basename(fixed_fn).split('.')[0]
moving_basename = os.path.basename(moving_fn).split('.')[0]

source = 'img'
image_dir = path.join(dataset_dir, f'{source}_processed')
img_fixed = nib.load(path.join(image_dir, fixed_fn))
img_moving = nib.load(path.join(image_dir, moving_fn))

arr_fixed = img_fixed.get_fdata()
arr_moving = img_moving.get_fdata()
z_fixed = z_moving = 128  # index along the third axis shown in the overlay plots
print('image shape', arr_fixed.shape)

target = 'dt_seg_processed' # where you save the target mask

aff_mov = img_moving.affine
seg_fixed = nib.load(path.join(dataset_dir, target, fixed_fn)).get_fdata()
seg_moving = nib.load(path.join(dataset_dir, target, moving_fn)).get_fdata()

In [None]:
import os
import nibabel as nib

# Registration run to inspect. The output directory name encodes the model and
# `fs` (presumably the feature-map size, cf. 'fmsize'), plus the experiment note.
exp_note = 'breast_same-seq-multi-time'
model = 'sam'
fs = 64
output_dir = (
    f'output/foundReg-model-{model}-2smooth-1000iter-itersmoothK7R5-lr3-fmd1-fmsize{fs}-noconvex/'
    + exp_note
)

# Load warped image and displacement field written for this pair
warped_path = os.path.join(output_dir, '{}_to_{}_warped_{}.nii.gz'.format(
    moving_basename, fixed_basename, exp_note))
disp_path = os.path.join(output_dir, '{}_to_{}_disp_{}.nii.gz'.format(
    moving_basename, fixed_basename, exp_note))
warped_image = nib.load(warped_path)
disp_img = nib.load(disp_path)

arr_warped = warped_image.get_fdata()
# The field is stored with its vector components on the last axis (axis 3);
# warp_arr expects them first.
arr_disp = np.moveaxis(disp_img.get_fdata(), 3, 0)
seg_warped = warp_arr(seg_moving, arr_disp)

In [None]:
# Plotting
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
target_classes = [1]
z_warped= 128
show_overlay(axs[0], arr_fixed[:, :, z_fixed], seg_fixed[:, :, z_fixed],
             target_classes=target_classes)
show_overlay(axs[1], arr_moving[:, :, z_moving], seg_moving[:, :, z_moving],
             target_classes=target_classes)
show_overlay(axs[2], arr_warped[:, :, z_warped], seg_warped[:, :, z_warped],
             target_classes=target_classes)

plt.tight_layout()
plt.show()

## Calculate the Dice similarity coefficient (DSC) over the entire set

In [None]:
import pandas as pd
import nibabel as nib
import numpy as np
import os
import pandas as pd

# Registration run to evaluate. NOTE(review): the model and fs values here differ
# from the single-case cell above; confirm which run is intended.
exp_note = 'breast_same-seq-multi-time'
model = 'MIND'
fs = 32
output_dir = (
    f'output/foundReg-model-{model}-2smooth-1000iter-itersmoothK7R5-lr3-fmd1-fmsize{fs}-noconvex/'
    + exp_note
)
source = 'img'

# Pair table: the evaluation loop reads its 'mov_volume' and 'fix_volume' columns.
task_name = 'task1_diff_time_same_seq'
task = pd.read_csv(path.join(dataset_dir, f'{task_name}.csv'))

# Per-pair metrics, filled in row order by the evaluation loop.
dice_list, mov_pixel_list, warp_pixel_list, fix_pixel_list = [], [], [], []

def dice_coefficient(mask1, mask2):
    """Dice similarity coefficient between two masks: 2|A∩B| / (|A| + |B|).

    NOTE(review): this duplicates an earlier definition in the notebook; keep the
    two in sync or move the helper into a shared module.

    Every non-zero element is treated as foreground. Float label maps and
    label values other than 1 are therefore counted correctly, instead of
    their values being summed.

    Args:
        mask1, mask2: arrays of the same (or broadcast-compatible) shape.

    Returns:
        Dice score in [0, 1]; 1.0 when both masks are empty.
    """
    fg1 = np.asarray(mask1).astype(bool)
    fg2 = np.asarray(mask2).astype(bool)
    size1 = np.count_nonzero(fg1)
    size2 = np.count_nonzero(fg2)
    if size1 + size2 == 0:
        return 1.0  # Both masks are empty
    intersection = np.count_nonzero(fg1 & fg2)
    return 2.0 * intersection / (size1 + size2)

# Evaluate every pair: warp the moving mask with the saved displacement field
# and compare it with the fixed mask, before and after registration.
seg_dir = 'full_breast_seg_rotate'  # masks used for the breast DSC

for moving_fn, fixed_fn in zip(task['mov_volume'], task['fix_volume']):
    # Output files are keyed on the file stem (everything before the first '.').
    fixed_basename = os.path.basename(fixed_fn).split('.')[0]
    moving_basename = os.path.basename(moving_fn).split('.')[0]

    # Only the shape is reported, which nibabel reads from the header, so
    # don't decode the voxel data (the moving/warped images were never used either).
    img_fixed = nib.load(path.join(dataset_dir, f'{source}_processed', fixed_fn))
    print('image shape', img_fixed.shape)

    seg_fixed = nib.load(path.join(dataset_dir, seg_dir, fixed_fn)).get_fdata()
    seg_moving = nib.load(path.join(dataset_dir, seg_dir, moving_fn)).get_fdata()

    # Displacement field: vector components on the last axis (axis 3);
    # warp_arr expects them first.
    disp_img = nib.load(os.path.join(output_dir, '{}_to_{}_disp_{}.nii.gz'.format(
        moving_basename, fixed_basename, exp_note)))
    arr_disp = np.moveaxis(disp_img.get_fdata(), 3, 0)
    seg_warped = warp_arr(seg_moving, arr_disp)

    raw_dice = dice_coefficient(seg_fixed, seg_moving)
    dice = dice_coefficient(seg_fixed, seg_warped)
    print(f'dice before reg:{raw_dice}, dice after reg: {dice}')

    # Count foreground voxels; .sum() would add label values and overcount
    # masks whose labels are not 1.
    dice_list.append(dice)
    mov_pixel_list.append(np.count_nonzero(seg_moving))
    warp_pixel_list.append(np.count_nonzero(seg_warped))
    fix_pixel_list.append(np.count_nonzero(seg_fixed))

task[f'{model} DICE Breast'] = dice_list
task[f'{model} moving num. pixel'] = mov_pixel_list
task[f'{model} warp num. pixel'] = warp_pixel_list
task[f'{model} fix num. pixel'] = fix_pixel_list

task.to_csv(model + '_' + task_name + '_' +'breast' + '.csv')