In [None]:
import sys

import nibabel as nib
import numpy as np
import torch

from mniscnn import (
    compartment_model_simulation,
    isft,
    l_max,
    n_coeffs,
    sft,
    sh,
)

# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Define acquisition protocol

# b-values are divided by 1e3 on load. The direction table is transposed so that
# indexing with `idx` below selects one direction per b-value
# (assumes dwi.bvec is stored as 3 x N -- TODO confirm against the file).
bvals = torch.tensor(np.loadtxt("dwi.bval") / 1e3)
bvecs = torch.tensor(np.loadtxt("dwi.bvec").T)

# Keep only the measurements with a positive b-value.
idx = bvals > 0
bvals = bvals[idx]
bvecs = bvecs[idx]

# Group the remaining measurements into shells, one per unique b-value.
# torch.unique returns the values sorted, so shell j is the j-th smallest b-value.
bs = torch.unique(bvals)
n_shells = len(bs)
shell_idx_list = [np.where(bvals == b)[0] for b in bs]
bvecs_list = [bvecs[bvals == b] for b in bs]

# For every shell, build the matrix that evaluates the even-degree basis
# functions `sh` at the shell's directions (bvecs_isft) together with its
# least-squares left inverse (bvecs_sft).
bvecs_sft_list = []
bvecs_isft_list = []
# The loop variable is intentionally NOT called `bvecs`: reusing that name would
# leave the notebook-level `bvecs` pointing at the last shell only once the loop
# has finished.
for shell_bvecs in bvecs_list:
    # Angles derived from the Cartesian components; phi is shifted by pi so it
    # lies in [0, 2*pi].
    thetas = np.arccos(shell_bvecs[:, 2])
    phis = np.arctan2(shell_bvecs[:, 1], shell_bvecs[:, 0]) + np.pi
    bvecs_isft = np.zeros((len(shell_bvecs), n_coeffs))
    for l in range(0, l_max + 1, 2):
        for m in range(-l, l + 1):
            # Column of (l, m) is l(l + 1)/2 + m; l is even, so the integer
            # division is exact.
            bvecs_isft[:, l * (l + 1) // 2 + m] = sh(l, m, thetas, phis)
    # Normal-equations pseudo-inverse (A^T A)^-1 A^T of the evaluation matrix.
    bvecs_sft = np.linalg.inv(bvecs_isft.T @ bvecs_isft) @ bvecs_isft.T
    bvecs_sft_list.append(torch.tensor(bvecs_sft).float())
    bvecs_isft_list.append(torch.tensor(bvecs_isft).float())

In [None]:
# Define test dataset

SNR = 30
n_test = int(1e6)

# Seed torch before the two uniform draws; their order determines the values.
# test_ds is uniform on [0, 3) and test_fs on [0, 1)
# (presumably diffusivities and fractions -- confirm against the simulation cell).
torch.random.manual_seed(666)
test_ds = 3 * torch.rand(n_test)
test_fs = torch.rand(n_test)

# Pick n_test rows from the table in odfs_sh.txt under a fixed NumPy seed
# (np.random.choice samples with replacement by default).
odf_table = torch.tensor(np.loadtxt("odfs_sh.txt")).float()
np.random.seed(666)
odf_rows = np.random.choice(len(odf_table), n_test)
test_odfs_sh = odf_table[odf_rows]

# Targets are saved with test_ds rescaled to [0, 1) in column 0 and test_fs in
# column 1.
test_targets = torch.stack((test_ds / 3, test_fs), dim=1)
np.savetxt("test_targets.txt", test_targets)

In [None]:
# Simulate signals

# Directions per shell, read from the per-shell evaluation matrices instead of
# being hard-coded. Every shell must have the same number of directions for the
# rectangular `signals` tensor below to be valid.
n_dirs = bvecs_isft_list[0].shape[0]
batch_size = int(1e4)

signals = torch.zeros(n_test, n_shells, n_dirs)
for i in range(0, n_test, batch_size):

    # Overwrite a single progress line with the percentage done.
    sys.stdout.write(f"\r{int(100 * (i + 1) / n_test)}%")
    sys.stdout.flush()

    # Clamp the last batch so the indices never run past n_test when n_test is
    # not a multiple of batch_size.
    idx = torch.arange(i, min(i + batch_size, n_test))

    # Two columns per sample, one per compartment
    # (names suggest axial/radial diffusivities and fractions -- TODO confirm
    # against compartment_model_simulation).
    # Column 0: (test_ds, 0, test_fs); column 1: (test_ds, (1 - f) * d, 1 - f).
    batch_ads = torch.vstack((test_ds[idx], test_ds[idx])).T
    batch_rds = torch.vstack(
        (
            torch.zeros(len(idx)),
            (1 - test_fs[idx]) * test_ds[idx],
        )
    ).T
    batch_fs = torch.vstack((test_fs[idx], 1 - test_fs[idx])).T
    batch_odfs = test_odfs_sh[idx]

    # Simulate each shell separately and bring the result back to the CPU.
    for j, b in enumerate(bs):
        signals[idx, j, :] = (
            compartment_model_simulation(
                b,
                bvecs_isft_list[j],
                batch_ads,
                batch_rds,
                batch_fs,
                batch_odfs,
                "linear",
                device,
            )
            .cpu()
            .squeeze(-1)
        )

In [None]:
# Reorganize data and save in a nifti file

# Reload the complete protocol (the earlier cell filtered out the b = 0 entries)
# so the output has one column per entry of dwi.bval.
bvals = torch.tensor(np.loadtxt("dwi.bval") / 1e3)
bvecs = torch.tensor(np.loadtxt("dwi.bvec").T)

test_signals = torch.zeros(n_test, len(bvals))
# Columns with b = 0 are set to 1.
test_signals[:, np.where(bvals == 0)[0]] = 1
# Copy each simulated shell into the columns that carry its b-value. The shells
# are taken from the protocol itself rather than hard-coded literals;
# torch.unique returns them sorted, which is the same order the simulation cell
# used for the second axis of `signals`.
for j, b in enumerate(torch.unique(bvals[bvals > 0])):
    test_signals[:, np.where(bvals == b)[0]] = signals[:, j, :]

# Additive zero-mean Gaussian noise with standard deviation 1 / SNR on every
# column, including the b = 0 ones.
test_signals += torch.normal(
    mean=torch.zeros(test_signals.size()),
    std=torch.ones(test_signals.size()) / SNR,
)

# The samples are laid out as a 100 x 100 x 100 volume, so this reshape requires
# n_test == 100 ** 3. The affine is the identity.
nib.save(
    nib.Nifti1Image(test_signals.numpy().reshape(100, 100, 100, len(bvals)), np.eye(4)),
    "test_signals.nii.gz",
)