In [2]:
import os
import subprocess

import nibabel as nib
import numpy as np


In [89]:
import torch
import torch.nn as nn

def get_SDF(direction=('z', 'x', 'y'), out_path='DF.nii.gz'):
    """Load per-axis displacement components and stack them into one field.

    Reads ``DF{d}.nii.gz`` (relative to the current working directory) for
    each ``d`` in ``direction``, stacks the volumes along a new last axis so
    channel ``i`` holds component ``direction[i]``, writes the stacked array
    to ``out_path`` with an identity affine, and returns it.

    NOTE(review): channel order matters downstream -- ``generate_grid`` orders
    its channels as (last, middle, first) spatial axis. Confirm the default
    ('z', 'x', 'y') order matches the convention of the DF* component files.

    Args:
        direction: component suffixes to load, in output channel order.
        out_path: file the stacked field is saved to.

    Returns:
        np.ndarray of shape (*volume_shape, len(direction)).

    Raises:
        ValueError: if ``direction`` is empty.
    """
    components = [nib.load(f'DF{d}.nii.gz').get_fdata() for d in direction]
    if not components:
        raise ValueError('direction must name at least one displacement component')
    sdf = np.stack(components, axis=-1)
    # Save the stacked field itself (previously referenced an undefined `df`).
    nib.save(nib.Nifti1Image(sdf, np.eye(4)), out_path)
    return sdf

def generate_grid(imgshape):
    """Build an identity grid of voxel indices for a 3-D volume.

    Only the first three entries of ``imgshape`` are used. The result is a
    float32 tensor of shape (1, D0, D1, D2, 3) whose last axis holds each
    voxel's indices in reversed axis order: ``[..., 0]`` is the index along
    D2, ``[..., 1]`` along D1 and ``[..., 2]`` along D0 -- the (x, y, z)
    channel order used by ``torch.nn.functional.grid_sample``.
    """
    axis_ranges = [np.arange(extent) for extent in imgshape[:3]]
    idx0, idx1, idx2 = np.meshgrid(*axis_ranges, indexing='ij')
    voxel_coords = np.stack((idx2, idx1, idx0), axis=-1)
    return torch.from_numpy(voxel_coords).float().unsqueeze(0)
    
def transform(x, flow):
    """Warp ``x`` by a dense voxel-displacement field using ``grid_sample``.

    Args:
        x: input volume, a 5-D tensor (N, C, D, H, W) as required by
            ``torch.nn.functional.grid_sample`` with a 5-D grid.
        flow: displacement field of shape (N, D_out, H_out, W_out, 3) in
            voxel units of ``x``. Its channel order must match
            ``generate_grid``: channel 0 moves along W, 1 along H, 2 along D.
            (Assumes the output grid and ``x`` share the same voxel lattice.)

    Returns:
        Tensor (N, C, D_out, H_out, W_out): ``x`` sampled (trilinearly) at
        ``identity + flow``; samples outside ``x`` are zero (grid_sample's
        default padding).
    """
    # Identity grid in voxel indices, moved to flow's device so the sum also works on GPU.
    coords = generate_grid(flow.shape[1:]).to(flow.device) + flow
    # grid_sample expects coordinates in [-1, 1]. With align_corners=True,
    # index 0 maps to -1 and index (extent - 1) to +1, where extent is the size
    # of the *input* x along that axis. Channel order (x, y, z) addresses x's
    # spatial dims (W, H, D), i.e. x.shape[2:] reversed. Every batch item is
    # normalized (not only index 0), and a size-1 axis no longer divides by 0.
    normalized = []
    for channel, extent in enumerate(tuple(x.shape[2:])[::-1]):
        span = max(extent - 1, 1)
        normalized.append((coords[..., channel] - span / 2) / span * 2)
    grid = torch.stack(normalized, dim=-1)
    return torch.nn.functional.grid_sample(x, grid, mode='bilinear', align_corners=True)

In [86]:
# Register brain.nii.gz (source) to the MNI152 template (target) with DRAMMS.
# An argument list avoids shell parsing, and check=True raises on a non-zero
# exit instead of letting later cells silently read stale or missing outputs.
# NOTE(review): the next cells read DFz/DFx/DFy.nii.gz, which this command does
# not write (it writes def_src2trg.nii.gz); the step producing them is not shown.
dramms_result = subprocess.run(
    [
        './bin/dramms',
        '--source', 'brain.nii.gz',
        '--target', 'MNI152SymNonLinear.nii.gz',
        '--outimg', 'src2trg.nii.gz',
        '--outdef', 'def_src2trg.nii.gz',
    ],
    check=True,
)

-------------------------------------------------------------------------------------------
DRAMMS: Deformable image Registration via Attribute Matching and Mutual-Saliency weighting
-------------------------------------------------------------------------------------------

Step 1:   Convert images to byte datatype...
Step 2:   Match histograms if necessary...
Step 3:   Affine registration of images by FSL's flirt tool (may take several minutes)...
Step 4:   Skip preprocessing of the initial transformation as none was specified.
Step 5a:  Generate multi-resolution images for the extraction of Gabor attributes...
Step 5b:  Extract Gabor attributes for deformable registration...
Step 6:   Deformably register images via attribute matching and mutual-saliency weighting (be patient, may take tens of minutes)...

Deform3D -b0,0,0 -p -r3 -C0 -n5 -k10 -s0.50 -m1 -f0 -M2 -w1 -g.35555555555555555554 -e0 -F0 -S0 -u2 -a/tmp/dramms-WXiDHI/features/B_level1_mask.nii.gz /tmp/dramms-WXiDHI/features/A

0

In [91]:
# Stacked displacement field, shape (1, *volume_shape, n_components).
sdf = torch.from_numpy(get_SDF()).unsqueeze(dim=0).float()
# Volumes as (1, 1, *volume_shape) tensors, the layout grid_sample expects.
source = torch.from_numpy(nib.load('brain.nii.gz').get_fdata()).unsqueeze(dim=0).unsqueeze(dim=0).float()
target_img = nib.load('MNI152SymNonLinear.nii.gz')
target = torch.from_numpy(target_img.get_fdata()).unsqueeze(dim=0).unsqueeze(dim=0).float()

# Drop only the batch and channel axes; .squeeze() would also drop a size-1 spatial axis.
warped = transform(source, sdf)[0, 0]
# Round before casting (a bare int cast truncates interpolated values toward
# zero) and clamp so out-of-range intensities saturate instead of wrapping.
int16_info = torch.iinfo(torch.int16)
s2t = torch.round(warped).clamp(int16_info.min, int16_info.max).type(torch.int16)

# NOTE(review): when the warped volume has the target's shape it presumably lives
# on the target grid, so reuse the target affine to keep orientation; otherwise
# fall back to the identity affine used previously.
out_affine = target_img.affine if tuple(s2t.shape) == tuple(target_img.shape[:3]) else np.eye(4)
nib.save(nib.Nifti1Image(s2t.numpy(), out_affine), 'S2T.nii.gz')