In [None]:
from google.colab import drive
drive.mount("/content/drive")
import os
os.chdir("/content/drive/MyDrive/mniscnn")
!pip install dipy healpy

In [None]:
import sys

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import seaborn as sns
import torch

from mniscnn import (
    compartment_model_simulation,
    isft,
    l_max,
    n_coeffs,
    sft,
    sh,
)

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Seaborn styling for every matplotlib figure below.
sns.set_theme()

In [None]:
!nvidia-smi -L

In [None]:
# Define acquisition protocol
#
# b-values are divided by 1e3 (NOTE(review): presumably s/mm^2 -> ms/um^2, matching the
# diffusivities drawn in [0, 3) later — confirm); b-vectors are transposed to one row
# per volume.
bvals = torch.tensor(np.loadtxt("data/train-subject/dwi.bval") / 1e3)
bvecs = torch.tensor(np.loadtxt("data/train-subject/dwi.bvec").T)

# Keep only the diffusion-weighted volumes (b > 0).
dw_mask = bvals > 0
bvals = bvals[dw_mask]
bvecs = bvecs[dw_mask]

# Group the volumes by shell (unique b-value).
bs = torch.unique(bvals)
n_shells = len(bs)
shell_idx_list = [np.where(bvals == b)[0] for b in bs]
bvecs_list = [bvecs[bvals == b] for b in bs]


def sh_design_matrix(shell_bvecs):
    """Evaluate the even-order SH basis (l = 0, 2, ..., l_max) at one shell's directions.

    Args:
        shell_bvecs: (n_dirs, 3) tensor of gradient directions.

    Returns:
        (n_dirs, n_coeffs) numpy array whose column ``l * (l + 1) // 2 + m`` holds
        ``sh(l, m, thetas, phis)``.
    """
    # Clamp so that rounding in the bvec file (|z| marginally above 1) cannot make
    # arccos return NaN.
    thetas = np.arccos(shell_bvecs[:, 2].clamp(-1.0, 1.0))
    phis = np.arctan2(shell_bvecs[:, 1], shell_bvecs[:, 0]) + np.pi
    design = np.zeros((len(shell_bvecs), n_coeffs))
    for l in range(0, l_max + 1, 2):
        for m in range(-l, l + 1):
            # l * (l + 1) is always even, so integer division is exact.
            design[:, l * (l + 1) // 2 + m] = sh(l, m, thetas, phis)
    return design


# Per shell: the basis matrix (SH coefficients -> signal) and its least-squares
# inverse (signal -> SH coefficients).
bvecs_sft_list = []
bvecs_isft_list = []
for shell_bvecs in bvecs_list:  # distinct name: do not clobber the global `bvecs`
    bvecs_isft = sh_design_matrix(shell_bvecs)
    # pinv is the SVD-based form of inv(A.T @ A) @ A.T: better conditioned, and still
    # defined if a shell has fewer directions than n_coeffs.
    bvecs_sft = np.linalg.pinv(bvecs_isft)
    bvecs_sft_list.append(torch.tensor(bvecs_sft).float())
    bvecs_isft_list.append(torch.tensor(bvecs_isft).float())

In [None]:
# Define model


class MLPModel(torch.nn.Module):
    """Fully connected regressor applied to direction-averaged signals.

    ``forward`` first averages its input over axis 2 with ``torch.nanmean`` (NaN entries,
    e.g. padding, are ignored). The resulting ``n_in`` features then pass through three
    Linear -> BatchNorm1d -> ReLU stages of width ``n_hidden``, followed by a final Linear
    layer with ``n_out`` outputs.

    Args:
        n_in: number of features after averaging (size of input axis 1). The default of 2
            is the original hard-coded value; in this notebook axis 1 indexes shells.
        n_hidden: width of the hidden layers.
        n_out: number of regression outputs.

    Layer attribute names are fixed, so previously saved state dicts still load.
    """

    def __init__(self, n_in=2, n_hidden=256, n_out=2):
        super().__init__()
        # Keep this creation order: with a fixed seed it reproduces the same initial weights.
        self.fc1 = torch.nn.Linear(n_in, n_hidden)
        self.bn1 = torch.nn.BatchNorm1d(n_hidden)
        self.fc2 = torch.nn.Linear(n_hidden, n_hidden)
        self.bn2 = torch.nn.BatchNorm1d(n_hidden)
        self.fc3 = torch.nn.Linear(n_hidden, n_hidden)
        self.bn3 = torch.nn.BatchNorm1d(n_hidden)
        self.fc4 = torch.nn.Linear(n_hidden, n_out)

    def forward(self, x):
        """Map ``x`` of shape (batch, n_in, n_dirs) to predictions of shape (batch, n_out)."""
        x = torch.nanmean(x, dim=2)
        for fc, bn in ((self.fc1, self.bn1), (self.fc2, self.bn2), (self.fc3, self.bn3)):
            x = torch.nn.functional.relu(bn(fc(x)))
        return self.fc4(x)


# Seed right before construction so the initial weights are reproducible.
torch.random.manual_seed(123)
model = MLPModel()
model = model.to(device)

n_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of trainable parameters = {n_parameters}")

In [None]:
# Define validation dataset
#
# A fixed set of n_val noisy synthetic signals with known targets: d ~ U[0, 3),
# f ~ U[0, 1), and ODF SH coefficients sampled (with replacement) from the validation
# subject. Both RNGs are seeded, so the set is reproducible.

SNR = 30
n_val = int(1e5)
val_chunk = int(1e4)  # samples simulated per call, to bound memory use

torch.random.manual_seed(123)
val_ds = torch.rand(n_val) * 3
val_fs = torch.rand(n_val)
val_odfs_sh = torch.tensor(np.loadtxt("data/val-subject/odfs_sh.txt")).float()
np.random.seed(123)
val_odfs_sh = val_odfs_sh[np.random.choice(len(val_odfs_sh), n_val)]

# Targets rescaled to [0, 1): column 0 is d / 3, column 1 is f.
val_targets = torch.vstack((val_ds / 3, val_fs)).T

# Width of the direction axis = largest shell. Shorter shells stay NaN-padded, which the
# model's nanmean over directions ignores; with equal-size shells every entry is written.
n_dirs = max(shell_isft.shape[0] for shell_isft in bvecs_isft_list)
val_signals = torch.full((n_val, n_shells, n_dirs), float("nan"))

for start in range(0, n_val, val_chunk):

    # Clip the last chunk so n_val need not be a multiple of val_chunk.
    idx = torch.arange(start, min(start + val_chunk, n_val))
    chunk_ds = val_ds[idx]
    chunk_fs = val_fs[idx]

    # Two compartments per sample: (1) axial diffusivity d, zero radial diffusivity,
    # fraction f; (2) axial d, radial (1 - f) * d, fraction 1 - f.
    batch_ads = torch.vstack((chunk_ds, chunk_ds)).T
    batch_rds = torch.vstack((torch.zeros(len(idx)), (1 - chunk_fs) * chunk_ds)).T
    batch_fs = torch.vstack((chunk_fs, 1 - chunk_fs)).T
    batch_odfs = val_odfs_sh[idx]

    for j, b in enumerate(bs):
        signals = compartment_model_simulation(
            b,
            bvecs_isft_list[j],
            batch_ads,
            batch_rds,
            batch_fs,
            batch_odfs,
            "linear",
            device,
        ).cpu()
        # Rician noise: magnitude of the signal plus complex Gaussian noise whose real
        # and imaginary parts each have std 1 / SNR.
        noise_std = torch.ones(signals.size()) / SNR
        signals = torch.abs(
            signals
            + torch.normal(mean=torch.zeros(signals.size()), std=noise_std)
            + 1j * torch.normal(mean=torch.zeros(signals.size()), std=noise_std)
        ).squeeze(-1)
        val_signals[idx, j, : signals.shape[1]] = signals

In [None]:
# Train
#
# Each of the n_batches steps draws batch_size fresh (d, f, ODF) samples and simulates
# noisy signals for them. ODFs are training-subject ODFs multiplied by a randomly chosen
# matrix from Rs. Each step then takes one Adam step.
# Simulation and the forward/backward passes run in sub-batches of n_iter samples with
# gradients accumulated, so batch_size may exceed n_iter. The learning rate drops 10x at
# each batch listed in lr_milestones.
# NOTE: re-running only this cell continues training the existing `model`; re-run the
# model cell first to restart from the seeded initialisation.

batch_size = int(1e3)
n_batches = int(5e4)
n_iter = int(1e3)  # sub-batch size; batch_size need not be a multiple of it
lr_milestones = (int(3e4), int(4e4))
val_every = 10  # validate every val_every batches, and after the last one

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Stepped manually only at the milestones, so each step is one 10x decay.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

all_train_odfs_sh = torch.tensor(np.loadtxt("data/train-subject/odfs_sh.txt")).float()
# Stack of (n_coeffs, n_coeffs) matrices applied to SH coefficient vectors (presumably
# SH rotations — confirm). The identity goes first so that unmodified ODFs can also be
# drawn. np.eye(n_coeffs), not a literal 45, keeps its shape consistent with l_max.
Rs = torch.tensor(
    np.concatenate(
        (
            np.eye(n_coeffs)[np.newaxis],
            np.load("data/train-subject/Rs.npy").reshape(-1, n_coeffs, n_coeffs),
        ),
        axis=0,
    )
).float()

# Width of the direction axis = largest shell; shorter shells stay NaN-padded, which the
# model's nanmean over directions ignores.
n_dirs = max(shell_isft.shape[0] for shell_isft in bvecs_isft_list)

# The validation set is fixed: move it to the device once instead of on every pass.
val_signals_dev = val_signals.to(device)
val_targets_dev = val_targets.to(device)

train_losses = []
val_losses = []
val_batches = []  # 1-based batch numbers at which val_losses were recorded

for batch in range(n_batches):

    if batch in lr_milestones:
        scheduler.step()

    # Seed both RNGs per batch so each batch's data is reproducible on its own.
    torch.random.manual_seed(batch)
    train_ds = torch.rand(batch_size) * 3
    train_fs = torch.rand(batch_size)
    np.random.seed(batch)
    train_odfs_sh = all_train_odfs_sh[
        np.random.choice(len(all_train_odfs_sh), batch_size)
    ]
    train_Rs = Rs[np.random.choice(len(Rs), batch_size)]
    train_odfs_sh = (train_Rs @ train_odfs_sh.unsqueeze(-1)).squeeze(-1)

    train_loss = 0.0
    for start in range(0, batch_size, n_iter):

        idx = torch.arange(start, min(start + n_iter, batch_size))
        # This sub-batch's share of the batch. Scaling by it turns the accumulated
        # gradient and the logged loss into batch means rather than sums over
        # sub-batches (it is exactly 1.0 when batch_size == n_iter).
        weight = len(idx) / batch_size

        sub_ds = train_ds[idx]
        sub_fs = train_fs[idx]
        # Two compartments per sample: (1) axial diffusivity d, zero radial diffusivity,
        # fraction f; (2) axial d, radial (1 - f) * d, fraction 1 - f.
        batch_ads = torch.vstack((sub_ds, sub_ds)).T
        batch_rds = torch.vstack((torch.zeros(len(idx)), (1 - sub_fs) * sub_ds)).T
        batch_fs = torch.vstack((sub_fs, 1 - sub_fs)).T
        batch_odfs = train_odfs_sh[idx]

        # Targets rescaled to [0, 1): d / 3 and f.
        batch_targets = torch.vstack((sub_ds / 3, sub_fs)).T

        signals = torch.full((len(idx), n_shells, n_dirs), float("nan"))
        for j, b in enumerate(bs):
            clean = (
                compartment_model_simulation(
                    b,
                    bvecs_isft_list[j],
                    batch_ads,
                    batch_rds,
                    batch_fs,
                    batch_odfs,
                    "linear",
                    device,
                )
                .cpu()
                .squeeze(-1)
            )
            # Rician noise: magnitude of the signal plus complex Gaussian noise whose
            # real and imaginary parts each have std 1 / SNR.
            noise_std = torch.ones(clean.size()) / SNR
            signals[:, j, : clean.shape[1]] = torch.abs(
                clean
                + torch.normal(mean=torch.zeros(clean.size()), std=noise_std)
                + 1j * torch.normal(mean=torch.zeros(clean.size()), std=noise_std)
            )

        y = model(signals.to(device))
        loss = loss_fn(y, batch_targets.to(device))
        (loss * weight).backward()
        train_loss += loss.item() * weight

    train_losses.append(train_loss)

    sys.stdout.write(f"\rbatch = {batch + 1}, train_loss = {train_losses[-1]}")
    sys.stdout.flush()

    optimizer.step()
    optimizer.zero_grad()

    if batch % val_every == 0 or batch == n_batches - 1:
        model.eval()
        with torch.no_grad():
            val_preds = torch.zeros(val_targets.size(), device=device)
            for start in range(0, n_val, n_iter):
                stop = min(start + n_iter, n_val)
                val_preds[start:stop] = model(val_signals_dev[start:stop])
            val_losses.append(loss_fn(val_preds, val_targets_dev).item())
        val_batches.append(batch + 1)
        model.train()

sys.stdout.write(f"\rval_loss = {val_losses[-1]}")
sys.stdout.flush()

fig = plt.figure(figsize=(8, 4))
plt.plot(np.arange(1, n_batches + 1), train_losses)
# Plot validation losses at the batches where they were actually computed.
plt.plot(val_batches, val_losses)
plt.yscale("log")
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.legend(["Training", "Validation"])
fig.tight_layout()
plt.show()

# Final-model predictions vs. targets on the validation set (columns: d / 3, f).
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
for i, title in enumerate(("d / 3", "f")):
    sc = ax[i].scatter(val_targets[:, i], val_preds[:, i].cpu(), s=1, alpha=0.1)
    sc.set_edgecolor("none")
    ax[i].set(title=title, xlabel="Target", ylabel="Prediction")
fig.tight_layout()
plt.show()

torch.save(model.state_dict(), "mlp_model_weights.pt")
np.savetxt("mlp_train_losses.txt", train_losses)
np.savetxt("mlp_val_losses.txt", val_losses)