# 06 - Basic Preprocessing: Loading, Masking, and Denoising

This notebook demonstrates several basic but essential preprocessing steps for diffusion MRI (dMRI) data using the `diffusemri` library. Proper preprocessing is crucial for obtaining reliable results from downstream analyses like model fitting and tractography.

We will cover:
1.  **Loading dMRI Data**: For this notebook, we'll create dummy NIfTI data directly. In practice, you would use the I/O utilities shown in previous notebooks (e.g., for DICOM, NRRD, or loading existing NIfTI files).
2.  **Brain Masking**: Isolating the brain tissue from non-brain regions.
3.  **Denoising**: Applying Marchenko-Pastur PCA (MP-PCA) denoising and Gibbs ringing correction to improve data quality.

In [None]:
import os
import shutil
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

# Dipy import for gradient table, if creating data from scratch
from dipy.core.gradients import gradient_table 

# diffusemri library imports
from preprocessing.masking import create_brain_mask
from preprocessing.denoising import denoise_mppca_data, correct_gibbs_ringing_dipy

# Setup a temporary directory for example files
TEMP_DIR = "temp_basic_preproc_example"
if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)
os.makedirs(TEMP_DIR)

print(f"Temporary directory for examples: {os.path.abspath(TEMP_DIR)}")

# Helper function for plotting slices
def show_slice(data, slice_idx=None, title="", vmin=None, vmax=None):
    """Displays a central slice of 2D, 3D, or 4D data (first volume for 4D)."""
    data_to_show = None
    if data.ndim == 4:
        s_idx = slice_idx if slice_idx is not None else data.shape[2] // 2
        data_to_show = data[:, :, s_idx, 0] # Show first volume if 4D
    elif data.ndim == 3:
        s_idx = slice_idx if slice_idx is not None else data.shape[2] // 2
        data_to_show = data[:, :, s_idx]
    elif data.ndim == 2: # Already a 2D slice
        data_to_show = data
    else:
        print(f"Cannot display data with {data.ndim} dimensions.")
        return
    
    plt.imshow(data_to_show.T, cmap='gray', origin='lower', vmin=vmin, vmax=vmax)
    plt.title(title)
    plt.xlabel("X voxel index")
    plt.ylabel("Y voxel index")
    plt.colorbar(label="Intensity")
    plt.show()

## Part 1: Loading dMRI Data

For a self-contained example, we'll create dummy 4D DWI NIfTI data directly using NumPy and Nibabel. In a real workflow, you would typically load your existing dMRI data (e.g., from NIfTI, DICOM, or NRRD files) using the I/O utilities from the `data_io` module of this library, as demonstrated in previous notebooks.
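
Whatever the source of the data, it is worth confirming after loading that the image, the b-values, and the b-vectors agree with each other. The following is a minimal sketch of such a check using only NumPy. The helper name `check_dwi_inputs` and its parameters are hypothetical and are not part of `diffusemri`; the b0 rule (`b <= b0_threshold`) is an assumption chosen to match the threshold used later in this notebook.

```python
import numpy as np

def check_dwi_inputs(data, bvals, bvecs, b0_threshold=50.0, tol=1e-3):
    """Validate a 4D DWI array against its b-values and b-vectors.

    Hypothetical helper for illustration; not part of diffusemri.
    Returns a boolean array marking the b0 volumes.
    """
    data = np.asarray(data)
    bvals = np.asarray(bvals, dtype=float)
    bvecs = np.asarray(bvecs, dtype=float)

    if data.ndim != 4:
        raise ValueError(f"Expected 4D data, got {data.ndim}D")
    n_volumes = data.shape[-1]
    if bvals.shape != (n_volumes,):
        raise ValueError(f"Expected {n_volumes} b-values, got shape {bvals.shape}")
    if bvecs.shape != (n_volumes, 3):
        raise ValueError(f"Expected b-vectors of shape ({n_volumes}, 3), got {bvecs.shape}")

    b0_mask = bvals <= b0_threshold
    # Diffusion-weighted volumes need unit-length gradient directions.
    norms = np.linalg.norm(bvecs[~b0_mask], axis=1)
    if not np.allclose(norms, 1.0, atol=tol):
        raise ValueError("Non-b0 b-vectors must have unit norm")
    return b0_mask

# Usage on small synthetic inputs: one b0 and two diffusion-weighted volumes.
data = np.zeros((4, 4, 2, 3), dtype=np.float32)
bvals = np.array([0, 1000, 1000])
bvecs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=float)
b0_mask = check_dwi_inputs(data, bvals, bvecs)
print(f"b0 volumes: {int(b0_mask.sum())} of {len(bvals)}")
```

Failing early with a clear message is usually cheaper than debugging a shape mismatch inside a later fitting step.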

In [None]:
# Define shape for dummy 4D DWI data: (X, Y, Z, num_volumes)
shape_4d = (30, 30, 10, 7) # Small dimensions for quick processing
dwi_data_np = np.zeros(shape_4d, dtype=np.float32)

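# Create a synthetic signal: a bright central block on a noisy background.
# (b0 volumes are brightened further below, once the gradient table exists.)
np.random.seed(0) # Fixed seed so the dummy data and b-vectors are reproducible
dwi_data_np[10:20, 10:20, 3:7, :] += 500 # A central block with higher signal
dwi_data_np += np.random.rand(*shape_4d) * 100 # Add uniform noise in [0, 100)

# Define a simple diagonal affine matrix (2 x 2 x 2.5 mm voxels)
affine_np = np.diag([2.0, 2.0, 2.5, 1.0])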

# Create dummy b-values and b-vectors
# Example: 1 b0, 6 diffusion-weighted volumes
bvals_np = np.array([0, 1000, 1000, 1000, 1000, 1000, 1000])
bvecs_np = np.random.rand(shape_4d[-1], 3) * 2 - 1 # Random vectors between -1 and 1
bvecs_np[0, :] = 0  # Set b-vector for b0 to zero
# Normalize non-b0 b-vectors
for i in range(1, len(bvecs_np)):
    norm = np.linalg.norm(bvecs_np[i])
    if norm > 1e-6: # Avoid division by zero for potential zero vectors
        bvecs_np[i] /= norm

# Create a Dipy GradientTable object.
# b0_threshold=50 treats any volume with b <= 50 as a b0.
# bvecs is passed by keyword, which works across Dipy versions.
gtab = gradient_table(bvals_np, bvecs=bvecs_np, b0_threshold=50)

# Make b0 volumes have higher intensity for visual distinction
dwi_data_np[..., gtab.b0s_mask] *= 2.0 

# Save the dummy DWI data to a NIfTI file, mimicking a dataset stored on disk.
# (The Gibbs correction step later writes its own input file.)
dummy_nifti_dwi_path = os.path.join(TEMP_DIR, "dummy_dwi_for_preproc.nii.gz")
nib.save(nib.Nifti1Image(dwi_data_np.astype(np.float32), affine_np), dummy_nifti_dwi_path)

print(f"Created dummy NIfTI DWI data: {dummy_nifti_dwi_path}")
print(f"  Data shape: {dwi_data_np.shape}, Data type: {dwi_data_np.dtype}")
print(f"  Gradient table: {gtab.gradients.shape[0]} volumes, {np.sum(gtab.b0s_mask)} b0s")
print(f"  b-values: {gtab.bvals}")

# Show a slice of the original b0 volume
b0_slice_idx = dwi_data_np.shape[2] // 2
show_slice(dwi_data_np[..., gtab.b0s_mask], slice_idx=b0_slice_idx, title=f"Original DWI Data (b0, slice {b0_slice_idx})")

## Part 2: Brain Masking

Brain masking aims to create a binary mask that identifies brain voxels, excluding skull, CSF outside the brain, and background. This is often a crucial first step to restrict subsequent analyses to the brain tissue.

The `create_brain_mask` function (from `preprocessing.masking`) typically averages the b0 volumes (or all volumes if no b0 volumes can be identified) and applies the median Otsu method to the result.
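
To make the thresholding idea concrete, here is a minimal NumPy sketch of Otsu's method applied to a mean b0 image. It is not the implementation used by `create_brain_mask`: the median filtering passes of median Otsu are omitted, and the function name `otsu_threshold` and the synthetic data are assumptions made for illustration.

```python
import numpy as np

def otsu_threshold(values, n_bins=256):
    """Return the threshold that maximizes between-class variance (Otsu)."""
    hist, bin_edges = np.histogram(values, bins=n_bins)
    centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    weights = hist.astype(float) / hist.sum()

    w0 = np.cumsum(weights)                  # probability of the lower class
    w1 = 1.0 - w0                            # probability of the upper class
    cum_mean = np.cumsum(weights * centers)  # cumulative first moment
    total_mean = cum_mean[-1]

    # Between-class variance for each candidate split; skip empty classes.
    valid = (w0 > 1e-12) & (w1 > 1e-12)
    between = np.zeros_like(w0)
    between[valid] = (total_mean * w0[valid] - cum_mean[valid]) ** 2 / (w0[valid] * w1[valid])

    # Use the upper edge of the best bin so the whole bin falls below the threshold.
    return bin_edges[1:][np.argmax(between)]

# Synthetic data similar to this notebook: noisy background plus a bright block.
rng = np.random.default_rng(0)
dwi = rng.random((30, 30, 10, 7)).astype(np.float32) * 100
dwi[10:20, 10:20, 3:7, :] += 500
b0_volumes = np.array([True] + [False] * 6)

mean_b0 = dwi[..., b0_volumes].mean(axis=-1)
threshold = otsu_threshold(mean_b0.ravel())
brain_mask = mean_b0 > threshold
masked_dwi = dwi * brain_mask[..., np.newaxis]  # zero out voxels outside the mask
print(f"Otsu threshold: {threshold:.1f}, voxels in mask: {int(brain_mask.sum())}")
```

On real data the median filtering step matters, because it suppresses isolated bright noise voxels that a plain threshold would otherwise include in the mask.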

In [None]:
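# For brain masking, we need the dMRI data (NumPy array) and voxel size.
# For a diagonal affine like this one, the voxel size is the absolute value of
# the first three diagonal elements. For a general (rotated) affine, use the
# column norms instead, e.g. nibabel.affines.voxel_sizes(affine_np).
voxel_size_dummy = np.abs(np.diag(affine_np)[:3])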
print(f"Derived voxel size for masking: {voxel_size_dummy}")

brain_mask_np = None
masked_dwi_data_np = None
try:
    # Using smaller median_radius and numpass for faster processing on dummy data.
    # These parameters might need adjustment for real data.
    brain_mask_np, masked_dwi_data_np = create_brain_mask(
        dwi_data_np, 
        voxel_size=voxel_size_dummy, 
        median_radius=2, # Default is 4
        numpass=2        # Default is 4
    )
    print("\nBrain mask created successfully.")
    print(f"  Brain mask shape: {brain_mask_np.shape}, Data type: {brain_mask_np.dtype}")
    print(f"  Masked DWI data shape: {masked_dwi_data_np.shape}")

    # Show a slice of the generated brain mask and the masked b0 volume
    mask_slice_idx = brain_mask_np.shape[2] // 2
    show_slice(masked_dwi_data_np[..., gtab.b0s_mask], slice_idx=mask_slice_idx, title=f"Masked DWI Data (b0, slice {mask_slice_idx})")
    show_slice(brain_mask_np, slice_idx=mask_slice_idx, title=f"Brain Mask (slice {mask_slice_idx})", vmin=0, vmax=1)

except Exception as e:
    print(f"An error occurred during brain masking: {e}")
    print("Using original DWI data for subsequent steps if masking failed.")
    # Fallback to using original data if masking fails for any reason
    masked_dwi_data_np = np.copy(dwi_data_np) 
    brain_mask_np = np.ones(dwi_data_np.shape[:3], dtype=bool) # Dummy mask

## Part 3: Denoising

Denoising techniques aim to reduce noise in the dMRI signal, which can improve the accuracy of model fitting and other analyses.
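
A common way to judge a denoising step is to look at the residual, that is, the original data minus the denoised data. Ideally it resembles structureless noise with a mean near zero. The sketch below summarizes the residual with NumPy. The helper `residual_summary` is hypothetical and not part of `diffusemri`, and the "denoised" array is a stand-in (the known clean signal) so that the example stays self-contained.

```python
import numpy as np

def residual_summary(original, denoised, mask=None):
    """Mean and standard deviation of (original - denoised) inside a mask.

    Hypothetical helper for illustration; not part of diffusemri.
    """
    residual = np.asarray(original, dtype=np.float64) - np.asarray(denoised, dtype=np.float64)
    if mask is not None:
        # A 3D boolean mask on 4D data selects an (n_voxels, n_volumes) array.
        residual = residual[np.asarray(mask, dtype=bool)]
    return float(residual.mean()), float(residual.std())

# Synthetic check: a constant signal plus Gaussian noise with sigma = 5.
rng = np.random.default_rng(42)
clean = np.full((10, 10, 4, 6), 100.0)
noisy = clean + rng.normal(0.0, 5.0, size=clean.shape)
mask = np.ones(clean.shape[:3], dtype=bool)

# Stand-in for a denoiser output: the true signal itself.
res_mean, res_std = residual_summary(noisy, clean, mask)
print(f"Residual mean: {res_mean:.2f}, residual std: {res_std:.2f}")
```

If the residual of a real denoiser shows anatomical structure, or its mean is clearly non-zero, signal is probably being removed along with the noise.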

### MP-PCA Denoising

Marchenko-Pastur Principal Component Analysis (MP-PCA) denoises data by applying PCA to local patches and discarding the components whose eigenvalues are consistent with the Marchenko-Pastur distribution that Random Matrix Theory predicts for pure noise. The `denoise_mppca_data` function (from `preprocessing.denoising`) provides an implementation (wrapping a PyTorch-based core).
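
The core idea can be illustrated on a single patch with a few lines of NumPy. The sketch below is a simplification and not the algorithm inside `denoise_mppca_data`: it assumes the noise standard deviation `sigma` is known, whereas MP-PCA estimates it from the eigenvalue spectrum of each patch. The function name `mp_threshold_denoise` and the synthetic low-rank signal are assumptions made for illustration.

```python
import numpy as np

def mp_threshold_denoise(patch, sigma):
    """Keep only PCA components above the Marchenko-Pastur upper edge.

    patch: (n_voxels, n_volumes) matrix for one local neighbourhood.
    sigma: noise standard deviation, assumed known in this sketch.
    """
    n_voxels, n_volumes = patch.shape
    mean = patch.mean(axis=0, keepdims=True)
    u, s, vt = np.linalg.svd(patch - mean, full_matrices=False)

    # Eigenvalues of the (n_volumes x n_volumes) sample covariance matrix.
    eigvals = s ** 2 / n_voxels
    # Largest eigenvalue expected from pure noise of variance sigma**2.
    upper_edge = sigma ** 2 * (1.0 + np.sqrt(n_volumes / n_voxels)) ** 2

    keep = eigvals > upper_edge
    denoised = (u[:, keep] * s[keep]) @ vt[keep] + mean
    return denoised, int(keep.sum())

# Synthetic patch: a rank-2 signal plus Gaussian noise.
rng = np.random.default_rng(0)
n_voxels, n_volumes, sigma = 343, 30, 1.0  # e.g. a 7x7x7 patch with 30 volumes
signal = rng.normal(size=(n_voxels, 2)) @ (5.0 * rng.normal(size=(2, n_volumes)))
noisy_patch = signal + rng.normal(scale=sigma, size=signal.shape)

denoised_patch, n_components = mp_threshold_denoise(noisy_patch, sigma)
err_before = np.linalg.norm(noisy_patch - signal)
err_after = np.linalg.norm(denoised_patch - signal)
print(f"Components kept: {n_components}, error before: {err_before:.1f}, after: {err_after:.1f}")
```

The threshold depends on the ratio of volumes to voxels in the patch, which is one reason the patch radius should be chosen with the number of acquired volumes in mind.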

In [None]:
denoised_mppca_data_np = None
input_for_mppca = masked_dwi_data_np if masked_dwi_data_np is not None else dwi_data_np

try:
    # Using a small patch_radius for faster processing. Default is often 2 or 3.
    denoised_mppca_data_np = denoise_mppca_data(input_for_mppca, patch_radius=1)
    print("\nMP-PCA denoising completed successfully.")
    print(f"  Denoised data shape: {denoised_mppca_data_np.shape}")

    # Show a slice of the denoised b0 volume
    denoised_slice_idx = denoised_mppca_data_np.shape[2] // 2
    show_slice(denoised_mppca_data_np[..., gtab.b0s_mask], slice_idx=denoised_slice_idx, title=f"MP-PCA Denoised (b0, slice {denoised_slice_idx})")

except Exception as e:
    print(f"An error occurred during MP-PCA denoising: {e}")
    print("Using data before MP-PCA for subsequent steps if denoising failed.")
    # Fallback if MP-PCA fails
    denoised_mppca_data_np = np.copy(input_for_mppca)

### Gibbs Ringing Correction

Gibbs ringing artifacts are truncation artifacts that appear as spurious oscillations near sharp intensity edges (e.g., CSF-tissue boundaries). The `correct_gibbs_ringing_dipy` function (from `preprocessing.denoising`) wraps Dipy's `gibbs_removal` to mitigate these artifacts. As used here, the function takes input and output NIfTI file paths rather than in-memory arrays.
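
The origin of the artifact can be reproduced in one dimension: discarding the high spatial frequencies of a sharp-edged profile, as finite k-space sampling does, leaves oscillations on both sides of each edge. The NumPy sketch below only demonstrates the artifact and does not perform any correction; the profile and the cutoff `max_freq` are arbitrary choices for illustration.

```python
import numpy as np

n_samples, max_freq = 256, 32

# An ideal sharp-edged profile: 0 outside, 1 inside.
profile = np.zeros(n_samples)
profile[64:192] = 1.0

# Simulate finite k-space sampling: discard all frequencies above max_freq.
spectrum = np.fft.fft(profile)
freq_index = np.fft.fftfreq(n_samples, d=1.0 / n_samples)  # integer frequency indices
spectrum[np.abs(freq_index) > max_freq] = 0.0
truncated = np.fft.ifft(spectrum).real

overshoot = truncated.max() - profile.max()
undershoot = profile.min() - truncated.min()
print(f"Overshoot: {overshoot:.3f}, undershoot: {undershoot:.3f}")
```

Raising `max_freq` makes the oscillations narrower without removing the overshoot at the edge, which is why acquiring at higher resolution alone does not eliminate the artifact.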

In [None]:
# Determine which data to use as input for Gibbs correction
# Prefer MP-PCA denoised data if available, otherwise masked data, fallback to original.
if denoised_mppca_data_np is not None:
    data_for_gibbs_np = denoised_mppca_data_np
    print("\nUsing MP-PCA denoised data as input for Gibbs correction.")
elif masked_dwi_data_np is not None:
    data_for_gibbs_np = masked_dwi_data_np
    print("\nUsing masked DWI data as input for Gibbs correction.")
else:
    data_for_gibbs_np = dwi_data_np
    print("\nUsing original DWI data as input for Gibbs correction (masking/MP-PCA might have failed).")

# Save this data to a temporary NIfTI file, as correct_gibbs_ringing_dipy expects file paths
pre_gibbs_nifti_path = os.path.join(TEMP_DIR, "dwi_data_for_gibbs_correction.nii.gz")
nib.save(nib.Nifti1Image(data_for_gibbs_np.astype(np.float32), affine_np), pre_gibbs_nifti_path)
print(f"Saved data for Gibbs input to: {pre_gibbs_nifti_path}")

gibbs_corrected_nifti_path = os.path.join(TEMP_DIR, "dwi_gibbs_corrected_output.nii.gz")
gibbs_corrected_data_np = None

try:
    # Using default parameters (slice_axis=2, n_points=3 for unringing)
    # num_processes can be set for parallel processing if data is large
    output_path_gibbs = correct_gibbs_ringing_dipy(
        input_image_file=pre_gibbs_nifti_path, 
        output_corrected_file=gibbs_corrected_nifti_path,
        num_processes=1 # Can be increased if resources allow
    )
    
    if os.path.exists(output_path_gibbs):
        gibbs_corrected_data_np = nib.load(output_path_gibbs).get_fdata()
        print("\nGibbs ringing correction completed successfully.")
        print(f"  Corrected data shape: {gibbs_corrected_data_np.shape}")
        print(f"  Corrected NIfTI file saved to: {output_path_gibbs}")

        # Show a slice of the Gibbs corrected b0 volume
        gibbs_slice_idx = gibbs_corrected_data_np.shape[2] // 2
        show_slice(gibbs_corrected_data_np[..., gtab.b0s_mask], slice_idx=gibbs_slice_idx, title=f"Gibbs Corrected (b0, slice {gibbs_slice_idx})")
    else:
        print("\nGibbs ringing correction function ran, but output file not found.")

except Exception as e:
    print(f"An error occurred during Gibbs ringing correction: {e}")

## Cleanup

Remove the temporary directory and its contents.

In [None]:
if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)
    print(f"Cleaned up temporary directory: {TEMP_DIR}")
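
As an alternative to creating and removing the directory by hand, Python's standard `tempfile.TemporaryDirectory` can manage the lifetime of scratch files. The following is a small sketch of that pattern; the prefix and file name are placeholders, and it is offered as an option rather than as how this notebook is organized.

```python
import os
import tempfile

with tempfile.TemporaryDirectory(prefix="basic_preproc_") as tmp_dir:
    example_path = os.path.join(tmp_dir, "example.txt")
    with open(example_path, "w") as handle:
        handle.write("intermediate result")
    existed_inside = os.path.exists(example_path)

# The directory and everything in it are removed once the block exits.
existed_after = os.path.exists(tmp_dir)
print(f"Existed inside block: {existed_inside}, exists after: {existed_after}")
```

This also avoids deleting a pre-existing directory that happens to share the hard-coded name.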