<a href="https://colab.research.google.com/github/lsprietog/public_release/blob/main/inference_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧠 Apply IVIM-DKI Machine Learning to Your Data

This notebook allows you to use the pre-trained Machine Learning models from the paper **"Exploring the Potential of Machine Learning Algorithms to Improve Diffusion Nuclear Magnetic Resonance Imaging Models Analysis"** on your own NIfTI data.

**What this tool does:**
1. Loads your Diffusion MRI data (NIfTI format).
2. Applies a pre-trained Extra Trees Regressor to estimate $D$, $f$, $D^*$, and $K$.
3. Generates and saves the parameter maps.

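**Note:** The pre-trained model included here was trained on a standard set of 15 b-values: `[0, 10, 20, 30, 50, 80, 100, 150, 200, 400, 600, 800, 1000, 1500, 2000]`. Your 4D volume should contain these b-values in the same order, with the b=0 volume first, because each voxel's signal is normalized by the first volume before prediction. For best results on your specific protocol, we recommend retraining the model using the `demo_colab.ipynb` notebook.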

In [None]:
# @title 1. Setup & Install
!git clone https://github.com/lsprietog/public_release.git
%cd public_release
!pip install -r requirements.txt

import numpy as np
import nibabel as nib
import joblib
import matplotlib.pyplot as plt
import os
from google.colab import files

In [None]:
# @title 2. Upload Your Data
# Collects the DWI file name (required) and an optional mask file name.
# Defines: uploaded, dwi_filename, uploaded_mask, mask_filename.
print("Please upload your 4D DWI NIfTI file (.nii or .nii.gz)...")
uploaded = files.upload()
if not uploaded:
    print("No file uploaded.")
    # Stop here with a clear message: the following cells need dwi_filename,
    # and continuing would either raise a NameError later or silently reuse a
    # file name left over from a previous run of this cell.
    raise ValueError(
        "No DWI file was uploaded. Re-run this cell and select a 4D NIfTI file."
    )
dwi_filename = next(iter(uploaded))
print(f"Uploaded: {dwi_filename}")

print("\n(Optional) Upload a Mask file (binary mask). If none, press Cancel or skip.")
mask_filename = None
try:
    uploaded_mask = files.upload()
except Exception as upload_error:
    # The mask is optional, so a failed upload is reported and then treated
    # like "no mask". Catching Exception (not a bare except) lets a manual
    # kernel interrupt still stop the cell.
    uploaded_mask = {}
    print(f"Mask upload did not complete ({upload_error}).")

if uploaded_mask:
    mask_filename = next(iter(uploaded_mask))
    print(f"Uploaded mask: {mask_filename}")
else:
    # Printed for both the "cancelled" and the "failed" case.
    print("No mask provided. Processing entire volume (this might take longer).")

In [None]:
# @title 3. Load Model & Data
# Loads the pre-trained regressor and the uploaded NIfTI volume(s).
# Defines: model_path, model, img, data, affine, mask.

# Load the pre-trained model
model_path = 'models/ivim_dki_extratrees.joblib'
if not os.path.exists(model_path):
    # Nothing in this notebook downloads the model, so fail with an
    # actionable message instead of announcing a download that never happens.
    raise FileNotFoundError(
        f"Pre-trained model not found at '{model_path}'. "
        "Run the 'Setup & Install' cell first so the repository is cloned "
        "and is the current working directory."
    )

print(f"Loading model from {model_path}...")
# NOTE: joblib.load unpickles the file, which can execute arbitrary code.
# Only load model files that come from a source you trust.
model = joblib.load(model_path)

# Load NIfTI
img = nib.load(dwi_filename)
data = img.get_fdata()
affine = img.affine

# The prediction cell unpacks data.shape into four values (x, y, z, b).
if data.ndim != 4:
    raise ValueError(
        f"Expected a 4D DWI volume (x, y, z, b-values) but got shape {data.shape}."
    )

if mask_filename:
    mask = nib.load(mask_filename).get_fdata() > 0
    # Accept masks stored as 4D with a single trailing volume.
    if mask.ndim == 4 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    # The mask is used to index the three spatial axes of data, so the shapes
    # have to agree exactly.
    if mask.shape != data.shape[:3]:
        raise ValueError(
            f"Mask shape {mask.shape} does not match the spatial shape "
            f"of the data {data.shape[:3]}."
        )
else:
    # Create a simple threshold mask to avoid background noise
    # Assuming b0 is the first volume
    mask = data[..., 0] > (np.mean(data[..., 0]) * 0.1)

print(f"Data shape: {data.shape}")
print(f"Mask voxels: {np.sum(mask)}")

In [None]:
# @title 4. Run Prediction
# Normalizes the masked voxels by their first volume, feeds them to the model
# and writes the predictions back into a 4D parameter volume.
# Defines: n_x, n_y, n_z, n_b, flat_data, S0, X_input, expected_b_len,
# predictions, param_maps.
print("Preprocessing data...")

# Reshape for prediction: boolean indexing yields one row per masked voxel.
n_x, n_y, n_z, n_b = data.shape
flat_data = data[mask]

# Normalize signal (S/S0)
# Assuming the first volume is b=0
# S0 is a view into the first column of flat_data, so replacing zeros here also
# replaces them in flat_data; for those voxels the normalized first feature
# therefore becomes 1 (1 / 1).
S0 = flat_data[:, 0][:, np.newaxis]
S0[S0 == 0] = 1 # Avoid division by zero
X_input = flat_data / S0

# Check if the number of volumes matches what the model was fitted on.
# Fitted scikit-learn estimators expose this as n_features_in_; fall back to
# the 15 b-values of the training script if the attribute is missing.
expected_b_len = getattr(model, 'n_features_in_', 15)
if n_b != expected_b_len:
    print(f"WARNING: Your data has {n_b} volumes, but the model expects {expected_b_len}.")
    print("We will truncate or zero-pad the signal, but results may be inaccurate.")
    # Simple truncation or padding for demo purposes
    if n_b > expected_b_len:
        X_input = X_input[:, :expected_b_len]
    else:
        # Pad with zeros (Not ideal, but prevents crash)
        padding = np.zeros((X_input.shape[0], expected_b_len - n_b))
        X_input = np.hstack([X_input, padding])

print("Predicting parameters (this is the fast part!)...")
if X_input.shape[0] == 0:
    # An empty mask would hand the model zero samples; skip the prediction and
    # leave the maps at zero instead.
    print("WARNING: The mask selects no voxels; the parameter maps will be all zeros.")
    predictions = np.zeros((0, 4))
else:
    predictions = model.predict(X_input)

# Reconstruct 3D maps
param_maps = np.zeros((n_x, n_y, n_z, 4)) # D, f, D*, K
param_maps[mask] = predictions

print("Prediction complete!")

In [None]:
# @title 5. Visualize & Download
# Split the 4-channel parameter volume into one 3D map per parameter.
D_map, f_map, Dstar_map, K_map = (param_maps[..., channel] for channel in range(4))

# Show the central slice along the third axis.
z_slice = n_z // 2

# One entry per panel: (volume, title, colormap, upper display limit).
panels = [
    (D_map, 'Diffusion (D)', 'gray', 0.003),
    (f_map, 'Perfusion Fraction (f)', 'jet', 0.5),
    (Dstar_map, 'Pseudo-Diffusion (D*)', 'hot', 0.1),
    (K_map, 'Kurtosis (K)', 'magma', 2.0),
]

fig, axes = plt.subplots(1, 4, figsize=(20, 5))
for panel_ax, (volume, panel_title, cmap_name, upper_limit) in zip(axes, panels):
    panel_ax.imshow(np.rot90(volume[..., z_slice]), cmap=cmap_name, vmin=0, vmax=upper_limit)
    panel_ax.set_title(panel_title)
plt.show()

# Write each map to its own NIfTI file, reusing the affine of the input image.
outputs = [
    (D_map, 'map_D.nii.gz'),
    (f_map, 'map_f.nii.gz'),
    (Dstar_map, 'map_Dstar.nii.gz'),
    (K_map, 'map_K.nii.gz'),
]

print("Saving NIfTI files...")
for volume, out_name in outputs:
    nib.save(nib.Nifti1Image(volume, affine), out_name)

# Trigger the downloads only after every file has been written.
print("Downloading maps...")
for _, out_name in outputs:
    files.download(out_name)