In [32]:
import numpy as np
import nibabel as nib
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table

In [3]:
subject_id = "100206"
subject_path = f"diffusion_data/{subject_id}/T1w/Diffusion"
dwi_img = nib.load(f'{subject_path}/data.nii.gz')
mask_img = nib.load(f'{subject_path}/nodif_brain_mask.nii.gz')

In [4]:
# Convert to numpy arrays for processing
dwi_data = dwi_img.get_fdata()
mask = mask_img.get_fdata()

In [5]:
print(f"DWI data shape: {dwi_data.shape}")  # Should be (X, Y, Z, num_volumes)
print(f"Mask shape: {mask.shape}")          # Should be (X, Y, Z)

DWI data shape: (145, 174, 145, 288)
Mask shape: (145, 174, 145)


In [6]:
# Load gradient information (bvals and bvecs)
print("\nLoading gradient information...")
bvals, bvecs = read_bvals_bvecs(f'{subject_path}/bvals', 
                               f'{subject_path}/bvecs')
print(f"Number of gradient directions: {len(bvals)}")
print(f"bvals shape: {bvals.shape}")     # Should match number of volumes
print(f"bvecs shape: {bvecs.shape}")     # Should be (num_volumes, 3)


Loading gradient information...
Number of gradient directions: 288
bvals shape: (288,)
bvecs shape: (288, 3)


In [7]:
# Create gradient table for DIPY
gtab = gradient_table(bvals, bvecs)

In [8]:
# Identify and extract B0 (non-diffusion weighted) volumes
b0_mask = gtab.b0s_mask
b0_data = dwi_data[..., b0_mask]
print(f"\nNumber of B0 volumes: {np.sum(b0_mask)}")
print(f"B0 data shape: {b0_data.shape}")
# Average B0 volumes to get a single reference image
b0_avg = np.mean(b0_data, axis=-1)
print(f"Average B0 shape: {b0_avg.shape}")


Number of B0 volumes: 18
B0 data shape: (145, 174, 145, 18)
Average B0 shape: (145, 174, 145)


In [9]:
# Extract and normalize diffusion weighted volumes
dwi_mask = ~b0_mask  # Mask for diffusion weighted volumes
dwi_vols = dwi_data[..., dwi_mask]
print(f"\nNumber of DWI volumes: {np.sum(dwi_mask)}")
print(f"DWI volumes shape: {dwi_vols.shape}")


Number of DWI volumes: 270
DWI volumes shape: (145, 174, 145, 270)


In [10]:
# Normalize DWI volumes by B0 (avoid division by zero with small epsilon)
dwi_norm = dwi_vols / (b0_avg[..., None] + 1e-6)
print(f"Normalized DWI shape: {dwi_norm.shape}")

Normalized DWI shape: (145, 174, 145, 270)


In [11]:
# Find valid voxels using the brain mask
valid_idx = np.where(mask > 0)
print(f"\nNumber of valid voxels in mask: {len(valid_idx[0])}")


Number of valid voxels in mask: 936256


In [25]:
# Sample random voxels for training
n_samples = 50000  # Adjust this number based on your needs
sample_idx = np.random.choice(len(valid_idx[0]), 
                            min(n_samples, len(valid_idx[0])), 
                            replace=False)

In [26]:
# Extract features (signal intensities) from sampled voxels
features = []
for idx in sample_idx:
    x, y, z = valid_idx[0][idx], valid_idx[1][idx], valid_idx[2][idx]
    features.append(dwi_norm[x, y, z, :])

In [27]:
features = np.array(features)
gradient_directions = bvecs[dwi_mask]  # Only keep directions for DWI volumes

In [28]:
print("\nFinal data shapes:")
print(f"Features shape: {features.shape}")           # Should be (n_samples, n_directions)
print(f"Gradient directions shape: {gradient_directions.shape}")  # Should be (n_directions, 3)

# Basic sanity checks
print("\nSanity checks:")
print(f"Max normalized value: {np.max(features)}")
print(f"Min normalized value: {np.min(features)}")
print(f"Gradient directions magnitude close to 1: {np.allclose(np.linalg.norm(gradient_directions, axis=1), 1, atol=1e-3)}")


Final data shapes:
Features shape: (50000, 270)
Gradient directions shape: (270, 3)

Sanity checks:
Max normalized value: 474281555.17578125
Min normalized value: 0.0
Gradient directions magnitude close to 1: True


In [29]:
print("Loading ground truth data...")
gt_data = np.load('ground_truth.npz')
ground_truth_tensors = gt_data['tensors']
valid_coordinates = gt_data['coordinates']
print(f"Ground truth tensors shape: {ground_truth_tensors.shape}")
print(f"Valid coordinates shape: {valid_coordinates.shape}")

Loading ground truth data...
Ground truth tensors shape: (936256, 6)
Valid coordinates shape: (936256, 3)


In [31]:
# Number of directions to use for sparse estimation
n_directions = 21  # We can adjust this number

def create_graph_data(features, gradient_directions, n_sparse_directions, batch_size = 32, threshold_angle = 45):
    """
    Create graph data for a batch of samples, including ground tensor
    """
    voxel_indices = np.random.choice(len(features), size = batch_size, replace = False)
    print(f"Creating graphs for batch of {batch_size} voxels using {n_sparse_directions} directions each")
    
    batch_nodes = []
    batch_edges = []
    batch_tensors = []

    for idx in voxel_indices:
        # Randomly selected directions
        selected_dir_idx = np.random.choice(len(gradient_directions), size=n_sparse_directions, replace=False)

        directions = gradient_directions[selected_dir_idx]
        signals = features[idx, selected_dir_idx]
        
        # Create node features: x, y, z, signal
        nodes = np.column_stack([directions, signals])

        # Create edges acc. to angle
        directions_norm = directions / np.linalg.norm(directions, axis=1, keepdims=True)
        cos_sim = np.dot(directions_norm, directions_norm.T)
        angles = np.arccos(np.clip(cos_sim, -1.0, 1.0)) * 180/np.pi
        
        src, dst = np.where(angles < threshold_angle)
        mask = src != dst
        edges = np.column_stack([src[mask], dst[mask]])

        batch_nodes.append(nodes)
        batch_edges.append(edges)
        batch_tensors.append(ground_truth_tensors[idx])

        if idx == voxel_indices[0]:
            print("\nFirst sample edge check:")
            print(f"Total edges created: {edges.shape[0]}")
            print("Example edges and their angles:")
            for i in range(min(5, len(edges))):
                e1, e2 = edges[i]
                print(f"Edge {i}: {e1}->{e2}, Angle: {angles[e1,e2]:.2f} degrees")
        
    return batch_nodes, batch_edges, batch_tensors

nodes, edges, gt_tensors = create_graph_data(features, gradient_directions, n_directions)
# Print info for verification
print("\nBatch statistics:")
print(f"Number of samples in batch: {len(nodes)}")
print(f"Number of nodes per graph: {nodes[0].shape[0]}")
print(f"Node feature dimensionality: {nodes[0].shape[1]}")
print(f"Ground truth tensors: {len(gt_tensors)}")

# Print example for first sample
print("\nFirst sample in batch:")
print(f"Number of nodes: {nodes[0].shape[0]}")
print(f"Number of edges: {edges[0].shape[0]}")
print(f"Ground truth tensor components: {gt_tensors[0]}")

Creating graphs for batch of 32 voxels using 21 directions each

First sample edge check:
Total edges created: 80
Example edges and their angles:
Edge 0: 0->19, Angle: 24.71 degrees
Edge 1: 1->3, Angle: 24.15 degrees
Edge 2: 1->6, Angle: 20.78 degrees
Edge 3: 1->8, Angle: 36.96 degrees
Edge 4: 1->12, Angle: 37.86 degrees

Batch statistics:
Number of samples in batch: 32
Number of nodes per graph: 21
Node feature dimensionality: 4
Ground truth tensors: 32

First sample in batch:
Number of nodes: 21
Number of edges: 80
Ground truth tensor components: [ 8.39275777e-04 -2.08106019e-05  9.17346774e-04  4.89926518e-05
 -1.59106882e-06  8.41369929e-04]


In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool
import torch.optim as optim
from torch_geometric.data import Data, Batch

In [34]:
class DiffusionGNN(torch.nn.Module):
    def __init__(self, node_features = 4, hidden_dim = 32):
        super(DiffusionGNN, self).__init__()
    
        # Layers
        self.conv1 = DiffusionConv(node_features, hidden_dim)
        self.conv2 = DiffusionConv(hidden_dim, hidden_dim)

        # MLP for tensor prediction
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 6)
        )
    
    def forward(self, x, edge_index, batch):
        """
        x: node_features
        edge_index = graph conn info [2, num_edges]
        batch: batch assignment for nodes
        """
        # pass through gnn layers
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # combine node features for graph
        x = global_mean_pool(x, batch)  # [batch_size, hidden_dim]

        # raw predictions
        out = self.mlp(x)

        diag = out[:, :3] # Dxx, Dyy, Dzz
        offdiag = out[:, 3:] # Dxy, Dyz, Dxz

        diag = torch.sigmoid(diag)
        offdiag = torch.tanh(offdiag)

        return torch.cat([diag, offdiag], dim = 1)



class DiffusionConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(DiffusionConv, self).__init__(aggr="mean")
        
        # MLP to process messages
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
    
    def forward(self, x, edge_index):
        return self.propogate(edge_index, x=x)
    
    def message(self, x_i, x_j):
        """
        x_i: features of target nodes
        x_j: features of source nodes
        Returns: messages to be aggregated
        """
        tmp = torch.cat([x_i, x_j], dim = 1)
        return self.mlp(tmp)