In [1]:
import string

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from scipy.ndimage import binary_erosion
import seaborn as sns
import torch

from scnn.sh import n_coeffs, spherical_harmonic
from scnn.models import MLPModel

sns.set_theme()

if torch.cuda.is_available():
    device = "cuda"
    print(torch.cuda.get_device_name(0))
    torch.cuda.empty_cache()
else:
    raise Exception("CUDA not available")

BATCH_SIZE = int(5e2)

NVIDIA RTX A1000 6GB Laptop GPU


In [2]:
# Data and acquisition protocol

data_img = nib.load("../data/mri/preprocessed/sub-07/dwi.nii.gz")
data = data_img.get_fdata()
affine = data_img.affine
mask = binary_erosion(
    nib.load("../data/mri/preprocessed/sub-07/brain_mask.nii.gz")
    .get_fdata()
    .astype(bool),
    iterations=2,
)

bvals = torch.round(
    torch.tensor(np.loadtxt("../data/mri/preprocessed/sub-07/dwi.bval") / 1e3),
    decimals=1,
).float()
bvecs = torch.tensor(np.loadtxt("../data/mri/preprocessed/sub-07/dwi.bvec").T).float()
bvecs[:, 0] *= -1

data /= np.mean(data[..., np.where(bvals == 0)[0]], axis=-1)[..., np.newaxis]

idx = bvals > 0
data = data[..., idx]
bvals = bvals[idx]
bvecs = bvecs[idx]

bs = torch.unique(bvals)
n_shells = len(bs)
shell_indices = [torch.where(bvals == b)[0] for b in bs]

bvecs_sft_per_shell = []
for idx in shell_indices:
    shell_bvecs = bvecs[idx]
    thetas = torch.arccos(shell_bvecs[:, 2])
    phis = (torch.arctan2(shell_bvecs[:, 1], shell_bvecs[:, 0]) + 2 * np.pi) % (
        2 * np.pi
    )
    bvecs_isft = torch.zeros(len(shell_bvecs), 45)
    for l in range(0, 8 + 1, 2):
        for m in range(-l, l + 1):
            bvecs_isft[:, int(0.5 * l * (l + 1) + m)] = spherical_harmonic(
                l, m, thetas, phis
            )
    bvecs_sft = torch.zeros((45, len(shell_bvecs)), dtype=float)
    bvecs_sft[0:45] = (
        torch.linalg.pinv(bvecs_isft[:, 0:45].T @ bvecs_isft[:, 0:45])
        @ bvecs_isft[:, 0:45].T
    )
    bvecs_sft_per_shell.append(bvecs_sft.float())

  data /= np.mean(data[..., np.where(bvals == 0)[0]], axis=-1)[..., np.newaxis]


In [3]:
# Load model and make the forward passes

model = MLPModel(120, 47).to(device)
model.load_state_dict(torch.load("../mlp_weights_rot.pt"))
model.eval()

signals = torch.tensor(data[mask]).float()

preds = torch.zeros(len(signals), 47)
with torch.no_grad():
    for i in range(0, len(signals), BATCH_SIZE):
        print(f"{int(100 * i / len(signals))}%", end="\r")
        if i + BATCH_SIZE <= len(signals):
            idx = torch.arange(i, i + BATCH_SIZE)
        else:
            idx = torch.arange(i, len(signals))
        preds[idx] = model(signals[idx].to(device)).cpu()
print("100%")

params = np.zeros((mask.shape + (47,)))
params[mask] = preds

nib.save(nib.Nifti1Image(params[..., 45::], affine), "params_mlp.nii.gz")
nib.save(
    nib.Nifti1Image(params[..., 0:45] * params[..., 46:47].clip(0, 1), affine),
    "odfs_mlp.nii.gz",
)

100%
