In [None]:
import contextlib
import json
import os
import shutil
import subprocess

import nibabel as nib
import numpy as np
import SimpleITK as sitk

# Location of the greedy executable and root of the COMULIS3DCLEM dataset.
# data_path keeps its trailing slash because file paths are formed by plain
# string concatenation with the relative entries from the dataset JSON.
greedy_path, data_path = '/Applications/greedy', '../data/COMULIS3DCLEM/'

def affine_displacement_field(affine, shape):
    """Return the dense displacement field of an affine map on a voxel grid.

    For every integer voxel index ``x`` of a grid of size ``shape`` the field
    holds ``A @ x + t - x``, where ``A`` is the top-left ``d x d`` block of
    ``affine``, ``t`` is the first ``d`` entries of its last column, and
    ``d = len(shape)``.

    Parameters
    ----------
    affine : array_like
        Affine matrix with ``d + 1`` columns and at least ``d`` rows, e.g. a
        homogeneous ``(d + 1, d + 1)`` matrix or a ``(d, d + 1)`` one. Only the
        first ``d`` rows are used.
    shape : sequence of int
        Grid size, one entry per axis. Any dimensionality ``d`` is accepted.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape ``(*shape, d)``. The last axis holds the
        displacement components in the same axis order as ``shape``.

    Raises
    ------
    ValueError
        If ``affine`` is not 2-D with ``d + 1`` columns and at least ``d`` rows.
    """
    affine = np.asarray(affine, dtype=np.float64)
    shape = tuple(shape)
    ndim = len(shape)
    if affine.ndim != 2 or affine.shape[0] < ndim or affine.shape[1] != ndim + 1:
        raise ValueError(
            f"affine must have {ndim + 1} columns and at least {ndim} rows "
            f"for a {ndim}-D grid, got shape {affine.shape}"
        )

    # (d, N) integer voxel indices in 'ij' order, the same as
    # np.meshgrid(*map(np.arange, shape), indexing='ij').
    coords = np.indices(shape).reshape(ndim, -1)
    # Apply the linear part and the translation directly. This avoids
    # materialising a (d + 1, N) homogeneous copy of every coordinate.
    mapped = affine[:ndim, :ndim] @ coords + affine[:ndim, ndim:]
    displacement = (mapped - coords).reshape(ndim, *shape)
    # Move the component axis last: (d, *shape) -> (*shape, d).
    return np.moveaxis(displacement, 0, -1)

def _zscore(img):
    """Return ``img`` shifted to zero mean and scaled to unit standard deviation.

    Non-float images are promoted to float64 first, because in-place float
    arithmetic on an integer array raises a casting error. Float images keep
    their dtype. A constant image is returned centred (all zeros) rather than
    divided by a zero std, which would fill it with NaNs.
    """
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    centered = img - img.mean()
    std = centered.std()
    return centered / std if std > 0 else centered


def _greedy_to_array_affine(mat):
    """Re-express a 4x4 greedy affine matrix in the index frame used by the arrays below.

    The result is ``M @ mat @ inv(M)`` with ``M = diag(1, -1, -1) @ R``, where
    ``R`` reverses the order of the three axes. It is written as axis flips
    plus sign masks. The fourth row is copied unchanged.
    NOTE(review): M presumably compensates for the ``.transpose()`` of the
    arrays and for greedy's physical-space sign convention. It was not derived
    here; confirm.
    """
    out = mat.copy()
    out[:3, 3] = np.flip(mat[:3, 3], 0) * np.array([1, -1, -1])
    out[:3, :3] = np.flip(np.flip(mat[:3, :3], 0), 1) * np.array([[1, -1, -1], [-1, 1, 1], [-1, 1, 1]])
    return out


# Start from an empty submission directory. `rm -r` errored when the
# directory did not exist yet.
shutil.rmtree('submission', ignore_errors=True)
os.makedirs('submission', exist_ok=True)

# Load the pair list up front instead of holding the JSON file open for the
# whole registration loop.
with open(data_path + 'COMULIS3DCLEM_dataset.json') as f:
    dataset_json = json.load(f)

# Argument-list form: no shell, and a greedy_path containing spaces still works.
greedy_cmd = [
    greedy_path, '-d', '3', '-a', '-dof', '12', '-jitter', '0',
    '-search', '1000', '0', '10', '-m', 'MI', '-n', '5',
    '-i', 'fix.nii.gz', 'mov.nii.gz', '-o', 'affine.mat',
]

for val_pair in dataset_json['registration_val']:
    fixed_img_path = data_path + val_pair['fixed']
    moving_img_path = data_path + val_pair['moving']

    fixed_img = sitk.GetArrayFromImage(sitk.ReadImage(fixed_img_path)).transpose()
    # Keeps only index 0 of the first two axes of the moving array.
    # NOTE(review): assumed to be singleton/extra axes of the moving file; confirm.
    moving_img = sitk.GetArrayFromImage(sitk.ReadImage(moving_img_path))[0, 0].transpose()

    fixed_img = _zscore(fixed_img)
    moving_img = _zscore(moving_img)

    sitk.WriteImage(sitk.GetImageFromArray(fixed_img), 'fix.nii.gz')
    sitk.WriteImage(sitk.GetImageFromArray(moving_img), 'mov.nii.gz')

    # Remove the previous pair's matrix, then fail loudly if greedy fails.
    # Otherwise a stale affine.mat would silently be reused for this pair.
    with contextlib.suppress(FileNotFoundError):
        os.remove('affine.mat')
    subprocess.run(greedy_cmd, check=True, stdout=subprocess.DEVNULL)

    affine = _greedy_to_array_affine(np.loadtxt('affine.mat'))

    # NOTE(review): the field is sampled on the moving image grid. If fixed
    # and moving shapes differ, confirm which grid the evaluation expects.
    disp = affine_displacement_field(affine, moving_img.shape)

    # Encode both the fixed and the moving case IDs in the filename. IDs are
    # taken from the file basename so underscores in data_path cannot shift them.
    fixed_id = os.path.basename(fixed_img_path).split('_')[1]
    moving_id = os.path.basename(moving_img_path).split('_')[1]
    out_path = os.path.join('submission', f"disp_{fixed_id}_{moving_id}.nii.gz")
    nib.save(nib.Nifti1Image(disp.astype(np.float32), np.eye(4)), out_path)

# Write a fresh archive. `zip -r` updated an existing submission.zip in place,
# so stale entries from earlier runs survived.
shutil.make_archive('submission', 'zip', root_dir='.', base_dir='submission')
