In [22]:
import numpy as np
import nibabel as nib
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table

In [23]:
subject_id = "100206"
subject_path = f"diffusion_data/{subject_id}/T1w/Diffusion"
dwi_img = nib.load(f'{subject_path}/data.nii.gz')
mask_img = nib.load(f'{subject_path}/nodif_brain_mask.nii.gz')

In [24]:
dwi_data = dwi_img.get_fdata()
original_mask = mask_img.get_fdata() > 0

In [25]:
# store original indices before any voxel filtering 
original_idx = np.where(original_mask)

In [26]:
# Map (x,y,z) to original linear index
coord_to_gtt = {}
for i in range(len(original_idx[0])):
    x, y, z = original_idx[0][i], original_idx[1][i], original_idx[2][i]
    coord_to_gtt[(x,y,z)] = i  # i is the index in ground_truth_tensors

In [27]:
print(f"DWI data shape: {dwi_data.shape}")  # (X, Y, Z, num_volumes)

DWI data shape: (145, 174, 145, 288)


In [28]:
bvals, bvecs = read_bvals_bvecs(f'{subject_path}/bvals', 
                               f'{subject_path}/bvecs')
gtab = gradient_table(bvals, bvecs)

In [29]:
# get b0 data and b0_avg
b0_mask = gtab.b0s_mask
b0_data = dwi_data[..., b0_mask]
print(f"B0 data shape: {b0_data.shape}")
b0_avg = np.mean(b0_data, axis=-1)

# b1000 images mask
b1000_mask = (bvals >= 990) & (bvals <= 1010)
print(f"Number of b=1000 volumes: {np.sum(b1000_mask)}")

B0 data shape: (145, 174, 145, 18)
Number of b=1000 volumes: 90


In [30]:
# Get all b1000 scans
dwi_vols = dwi_data[..., b1000_mask]
print(f"DWI volumes shape: {dwi_vols.shape}")

DWI volumes shape: (145, 174, 145, 90)


In [31]:
# mask out voxels with very low b0 signal
b0_threshold = 250
valid_b0_mask = b0_avg > b0_threshold
mask = original_mask & valid_b0_mask  # Combine with brain mask

In [32]:
# Find valid voxels using the brain mask
valid_idx = np.where(mask) # shape: 3 x num_valid_voxels
print(f"\nNumber of valid voxels in mask: {len(valid_idx[0])}")


Number of valid voxels in mask: 926671


In [33]:
# Sample random voxels for training
n_samples = 150000
sample_idx = np.random.choice(len(valid_idx[0]), 
                            min(n_samples, len(valid_idx[0])), 
                            replace=False)

In [34]:
# Extract features (signal intensities) from sampled voxels
features = []
gtt_indices = [] # store ground truth tensor indices. gtt_indices[i] gives index of ith feature in original number of voxels (93XXXX)

for idx in sample_idx:
    x, y, z = valid_idx[0][idx], valid_idx[1][idx], valid_idx[2][idx]
    signal = dwi_vols[x, y, z, :]
    b0_ref = b0_avg[x, y, z]
    normalized_signal = signal / (b0_ref)
    features.append(normalized_signal)
    gtt_indices.append(coord_to_gtt[(x, y, z)])

# gtt_indices shape = 100000 x 1
# features shape = 100000 x 90 x 1


In [35]:
features = np.array(features)
gtt_indices = np.array(gtt_indices)
gradient_directions = bvecs[b1000_mask]  # Only keep directions for DWI volumes, shape = 90 x 3

In [36]:
print("Loading ground truth data...")
gt_data = np.load('ground_truth_v2.npz')
ground_truth_tensors = gt_data['tensors'] # shape 936256(all voxels) x 6

Loading ground truth data...


In [37]:
ground_truth_tensors.shape

(936256, 6)

In [38]:
n_directions = 21

def create_graph_data(batch_indices, direction_indices, batch_size=32, threshold_angle=60):
    """
    Optimized version of create_graph_data that processes the batch in parallel
    using vectorized operations.
    """
    # Pre-compute directions once for the whole batch
    current_directions = gradient_directions[direction_indices]  # shape = 21 x 3
    directions_norm = current_directions / np.linalg.norm(current_directions, axis=1, keepdims=True)
    
    # Pre-compute angle matrix once
    cos_sim = np.dot(directions_norm, directions_norm.T)
    angles = np.arccos(np.clip(cos_sim, -1.0, 1.0)) * 180/np.pi
    src, dst = np.where(angles < threshold_angle)
    mask = src != dst
    edge_template = np.column_stack([src[mask], dst[mask]])
    edge_weights_template = cos_sim[src[mask], dst[mask]]
    
    # Process all signals at once
    all_signals = features[batch_indices][:, direction_indices]  # shape = batch_size x 21 x 1
    
    # Pre-allocate arrays
    batch_nodes = np.zeros((len(batch_indices), n_directions, 4))
    batch_edges = [edge_template for _ in range(len(batch_indices))]
    batch_edge_weights = [edge_weights_template for _ in range(len(batch_indices))]
    
    # Vectorized node creation
    batch_nodes[..., 0] = all_signals.reshape(len(batch_indices), -1)
    batch_nodes[..., 1:4] = current_directions[None, :, :]  # Broadcasting
    
    # Get tensors in one operation
    batch_tensors = ground_truth_tensors[gtt_indices[batch_indices]]
    
    return batch_nodes, batch_edges, batch_edge_weights, batch_tensors

In [39]:
def select_diverse_directions(all_directions, n_select=21):
    selected = [0]  # Start with first direction
    while len(selected) < n_select:
        # Calculate angles with all selected directions
        angles = []
        for i in range(len(all_directions)):
            if i in selected:
                continue
            min_angle = float('inf')
            for s in selected:
                # Cosine similarity
                angle = np.arccos(np.clip(
                    np.dot(all_directions[i], all_directions[s]), -1.0, 1.0))
                min_angle = min(min_angle, angle)
            angles.append((i, min_angle))
        # Select direction with largest minimum angle
        next_idx = max(angles, key=lambda x: x[1])[0]
        selected.append(next_idx)
    return selected

In [40]:
def initialize_direction_sets(all_directions, n_sets=3, n_directions=21):
    base_sets = []
    n_total = len(all_directions)
    
    for i in range(n_sets):
        # Create a shuffled index array for this set
        shuffled_indices = np.random.permutation(n_total)
        # Use select_diverse_directions on the shuffled indices
        selected = select_diverse_directions(all_directions[shuffled_indices], n_directions)
        # Map back to original indices
        original_indices = shuffled_indices[selected]
        base_sets.append(original_indices)
    
    return base_sets

In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool
import torch.optim as optim
from torch_geometric.data import Data, Batch
from sklearn.model_selection import train_test_split
import time

In [42]:
class DiffusionGNN(torch.nn.Module):
    def __init__(self, node_features = 4, hidden_dim = 128):
        super(DiffusionGNN, self).__init__()
    
        # Layers
        self.conv1 = DiffusionConv(node_features, hidden_dim)
        self.conv2 = DiffusionConv(hidden_dim, hidden_dim)

        # MLP for tensor prediction
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 6),
        )

        self.register_buffer('loss_weights',
            torch.tensor([10, 500, 10, 500, 500, 10]))

    
    def forward(self, x, edge_index, edge_attr, batch):
        """
        x: node_features
        edge_index = graph conn info [2, num_edges]
        batch: batch assignment for nodes
        """
        # print("pre conv", x.shape)
        # x shape = 672 (batch_size*num_directions) x 4
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        # x shape = 672 x hidden_dim(64)
        # print("after conv", x.shape)

        # combine node features for graph
        x = global_mean_pool(x, batch)  # [batch_size, hidden_dim]. 1 representation for each voxel/graph
        # print("after pool", x.shape)

        # raw predictions
        out = self.mlp(x) # shape = 32 x 6

        diag_idx = [0, 2, 5]
        offdiag_idx = [1, 3, 4]

        # Diagonal elements: [0,1] using sigmoid
        diag = torch.sigmoid(out[:, diag_idx])
        
        # Off-diagonal elements: [-1,1] using tanh
        offdiag = torch.tanh(out[:, offdiag_idx])

        out_reordered = torch.zeros_like(out)
        out_reordered[:, diag_idx] = diag
        out_reordered[:, offdiag_idx] = offdiag

        return out_reordered # shape = 32 x 6
    
    def weighted_mse_loss(self, pred, target):
        squared_diff = (pred - target) ** 2  # [batch_size, 6]
        weighted_diff = squared_diff * self.loss_weights
        return torch.mean(weighted_diff)


class DiffusionConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(DiffusionConv, self).__init__(aggr="mean")
        
        # MLP to process messages
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
    
    def forward(self, x, edge_index, edge_weight):
        return self.propagate(edge_index, x=x, edge_weight = edge_weight)
    
    def message(self, x_i, x_j, edge_weight):
        """
        x_i: features of target nodes
        x_j: features of source nodes
        Returns: messages to be aggregated
        """
        tmp = torch.cat([x_i, x_j], dim = 1)
        return self.mlp(tmp) * edge_weight.view(-1,1)

In [43]:
def convert_to_torch_geometric(nodes, edges, edge_weights, tensors):
    data_list = []
    for i in range(len(nodes)):
        # Ensure tensor is properly shaped [6] not flattened
        tensor = tensors[i].reshape(-1, 6) if len(tensors[i].shape) == 1 else tensors[i]
        
        data = Data(
            x=torch.FloatTensor(nodes[i]),          # [n_nodes]
            edge_index=torch.LongTensor(edges[i].T), # [2, n_edges]
            edge_attr=torch.FloatTensor(edge_weights[i]),
            y=torch.FloatTensor(tensor)             # [1, 6]
        )
        data_list.append(data)
    return data_list

In [44]:
train_idx, test_idx = train_test_split(np.arange(len(features)), test_size=0.2, random_state=42)
print(f"Training with {len(train_idx)} voxels, testing with {len(test_idx)} voxels")

Training with 120000 voxels, testing with 30000 voxels


In [45]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [46]:
def evaluate_model(pred, gt, epoch, print_results=True):
    """
    Compute and optionally print evaluation metrics
    Args:
        pred: predictions array (n_samples x 6)
        gt: ground truth array (n_samples x 6)
        epoch: current epoch number
        print_results: whether to print metrics
    Returns:
        Dictionary of metrics
    """
    metrics = {}
    
    # Component-wise statistics and relative errors
    for i in range(6):
        # Relative error
        rel_error = np.mean(np.abs(pred[:,i] - gt[:,i]) / (np.abs(gt[:,i]) + 1e-10))
        metrics[f'rel_error_{i}'] = rel_error
        
        # Statistics for monitoring distribution
        metrics[f'pred_mean_{i}'] = np.mean(pred[:,i])
        metrics[f'pred_std_{i}'] = np.std(pred[:,i])
        metrics[f'gt_mean_{i}'] = np.mean(gt[:,i])
        metrics[f'gt_std_{i}'] = np.std(gt[:,i])
    
    # Mean Diffusivity error
    pred_md = np.mean([pred[:,0], pred[:,2], pred[:,5]], axis=0)  # Dxx, Dyy, Dzz
    gt_md = np.mean([gt[:,0], gt[:,2], gt[:,5]], axis=0)
    metrics['md_rel_error'] = np.mean(np.abs(pred_md - gt_md) / (np.abs(gt_md) + 1e-10))

    if print_results:
        print(f"\nEpoch {epoch+1} Statistics:")
        print("\nRelative Errors per component:")
        for i in range(6):
            print(f"Component {i}: {metrics[f'rel_error_{i}']:.2e}")
        print(f"Mean Diffusivity Error: {metrics['md_rel_error']:.2e}")
        
        print("\nPrediction Statistics:")
        for i in range(6):
            print(f"Component {i}: mean={metrics[f'pred_mean_{i}']:.2e}, std={metrics[f'pred_std_{i}']:.2e}")
        
        print("\nGround Truth Statistics:")
        for i in range(6):
            print(f"Component {i}: mean={metrics[f'gt_mean_{i}']:.2e}, std={metrics[f'gt_std_{i}']:.2e}")
    
    return metrics

In [49]:
n_sets = 3
direction_sets = initialize_direction_sets(gradient_directions, n_sets)
batch_size = 32


In [None]:
# # Cache creation before training
# print("Pre-computing graph data for all direction sets...")
# cached_samples = {}  # Dictionary: direction_set_idx -> list of preprocessed samples

# # Pre-compute individual samples for each direction set
# for set_idx, direction_set in enumerate(direction_sets):
#     print(f"Processing direction set {set_idx + 1}/{len(direction_sets)}")
#     cached_samples[set_idx] = []
    
#     # Process each training sample individually
#     for idx in train_idx:  # train_idx shape: (80,000,)
#         # Create graph data for single sample
#         nodes, edges, edge_weights, tensors = create_graph_data(
#             [idx],  # Single sample
#             direction_set  # Shape: (21, 3)
#         )
#         # Convert to PyG data object
#         data = convert_to_torch_geometric(nodes, edges, edge_weights, tensors)[0]  # Take first (only) element
#         # Move to GPU and store
#         cached_samples[set_idx].append(data.to(device))

In [None]:
# Training
model = DiffusionGNN(hidden_dim=128).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
epochs = 60


print("Starting training...")
for epoch in range(epochs):
    # Select direction set for this epoch
    direction_indices = direction_sets[epoch % len(direction_sets)]
    model.train()
    epoch_loss = 0
    n_batches = 0
    
    # Store epoch predictions and ground truth
    epoch_pred_values = []
    epoch_gt_values = []
    
    # Shuffle training indices
    np.random.shuffle(train_idx)
    
    
    # Process batches
    for start in range(0, len(train_idx), batch_size):
        batch_indices = train_idx[start:start + batch_size]
        
        # Create graph data for batch
        nodes, edges, edge_weights, tensors = create_graph_data(batch_indices, direction_indices)
        data_list = convert_to_torch_geometric(nodes, edges, edge_weights, tensors)
        batch_data = Batch.from_data_list(data_list).to(device)
        
        # Training step
        optimizer.zero_grad()
        pred = model(batch_data.x, batch_data.edge_index, batch_data.edge_attr, batch_data.batch)
        loss = model.weighted_mse_loss(pred, batch_data.y) * 1e6
        loss.backward()
        optimizer.step()
        
        # Store batch results
        epoch_loss += loss.item()
        epoch_pred_values.append(pred.detach().cpu().numpy())
        epoch_gt_values.append(batch_data.y.cpu().numpy())
        n_batches += 1
    
    # Combine all batches
    epoch_pred = np.concatenate(epoch_pred_values, axis=0)
    epoch_gt = np.concatenate(epoch_gt_values, axis=0)
    
    # Print epoch results
    avg_loss = epoch_loss / n_batches
    print("\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    
    # Evaluate epoch performance
    if epoch % 1 == 0:
        _ = evaluate_model(epoch_pred, epoch_gt, epoch)
    
    

Starting training...

xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
Epoch 1/60, Loss: 22311.248545

Epoch 1 Statistics:

Relative Errors per component:
Component 0: 6.21e+03
Component 1: 1.85e+04
Component 2: 7.11e+03
Component 3: 2.25e+04
Component 4: 3.55e+04
Component 5: 7.09e+03
Mean Diffusivity Error: 6.79e+03

Prediction Statistics:
Component 0: mean=1.25e-02, std=5.93e-02
Component 1: mean=1.51e-04, std=3.93e-03
Component 2: mean=1.45e-02, std=6.21e-02
Component 3: mean=-1.06e-04, std=3.59e-03
Component 4: mean=1.20e-04, std=3.97e-03
Component 5: mean=1.43e-02, std=5.76e-02

Ground Truth Statistics:
Component 0: mean=9.90e-04, std=5.52e-04
Component 1: mean=-1.21e-06, std=1.15e-04
Component 2: mean=1.03e-03, std=5.58e-04
Component 3: mean=3.41e-06, std=1.16e-04
Component 4: mean=-2.29e-05, std=1.20e-04
Component 5: mean=1.00e-03, std=5.65e-04

xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
Epoch 2/60, Loss: 220.483912

Epoch 2 Statistics:

Relative Errors per component:
Component 0: 1.97e+03
C

KeyboardInterrupt: 

In [None]:
print("\nStarting Final Evaluation...")
model.eval()
test_losses = []
test_preds = []
test_gts = []

with torch.no_grad():
    # Process all batches, including last partial batch
    for start in range(0, len(test_idx), batch_size):
        end = min(start + batch_size, len(test_idx))
        batch_indices = test_idx[start:end]
        
        # Create and process batch
        nodes, edges, edge_weights, tensors = create_graph_data(
            batch_indices, direction_indices
        )
        data_list = convert_to_torch_geometric(nodes, edges, edge_weights, tensors)
        batch_data = Batch.from_data_list(data_list).to(device)
        
        # Get predictions
        pred = model(batch_data.x, batch_data.edge_index, batch_data.edge_attr, batch_data.batch)
        
        # Calculate loss
        test_loss = model.weighted_mse_loss(pred, batch_data.y) * 1e6
        test_losses.append(test_loss.item())
        
        # Store predictions and ground truth
        test_preds.append(pred.cpu().numpy())
        test_gts.append(batch_data.y.cpu().numpy())

# Combine results
all_test_preds = np.concatenate(test_preds, axis=0)
all_test_gts = np.concatenate(test_gts, axis=0)
avg_test_loss = np.mean(test_losses)

print(f"\nFinal Test Results:")
print(f"Average test loss: {avg_test_loss:.6f}")

# Use same evaluation function as training
metrics = evaluate_model(all_test_preds, all_test_gts, epoch=-2)

# Print random examples from full test set
print("\nExample predictions vs ground truth (randomly sampled):")
n_examples = 5
rand_indices = np.random.choice(len(all_test_preds), n_examples, replace=False)
for i, idx in enumerate(rand_indices):
    print(f"\nSample {i+1}:")
    print(f"Predicted: {all_test_preds[idx]}")
    print(f"Actual:    {all_test_gts[idx]}")


Starting Final Evaluation...

Final Test Results:
Average test loss: 0.412770

Epoch -1 Statistics:

Relative Errors per component:
Component 0: 2.87e+02
Component 1: 1.56e+03
Component 2: 3.06e+02
Component 3: 1.24e+03
Component 4: 5.75e+02
Component 5: 2.95e+02
Mean Diffusivity Error: 2.94e+02

Prediction Statistics:
Component 0: mean=9.82e-04, std=5.25e-04
Component 1: mean=1.02e-04, std=1.01e-04
Component 2: mean=1.02e-03, std=5.33e-04
Component 3: mean=3.12e-05, std=1.00e-04
Component 4: mean=-9.88e-06, std=1.08e-04
Component 5: mean=1.00e-03, std=5.43e-04

Ground Truth Statistics:
Component 0: mean=9.82e-04, std=5.44e-04
Component 1: mean=-5.51e-07, std=1.13e-04
Component 2: mean=1.02e-03, std=5.49e-04
Component 3: mean=2.80e-06, std=1.15e-04
Component 4: mean=-2.30e-05, std=1.20e-04
Component 5: mean=9.99e-04, std=5.54e-04

Example predictions vs ground truth (randomly sampled):

Sample 1:
Predicted: [1.0977978e-03 8.6188316e-05 1.1421236e-03 1.8950552e-05 2.9838644e-05
 1.1174

In [None]:
# print("Original Ground Truth Tensor Analysis:")
# print("\nComponent-wise statistics:")
# for i in range(6):
#     comp = ground_truth_tensors[:, i]
#     print(f"\nComponent {i}:")
#     print(f"Min: {np.min(comp):.2e}")
#     print(f"Max: {np.max(comp):.2e}")
#     print(f"Mean: {np.mean(comp):.2e}")
#     print(f"Std: {np.std(comp):.2e}")

In [None]:
# print("Loading ground truth data 2...")
# gt_data_v2 = np.load('ground_truth_v2.npz')
# ground_truth_tensors_v2 = gt_data_v2['tensors']
# print(f"Ground truth tensors shape: {ground_truth_tensors_v2.shape}")

In [None]:
# print("V2 Original Ground Truth Tensor Analysis:")
# print("\nComponent-wise statistics:")
# for i in range(6):
#     comp = ground_truth_tensors_v2[:, i]
#     print(f"\nComponent {i}:")
#     print(f"Min: {np.min(comp):.2e}")
#     print(f"Max: {np.max(comp):.2e}")
#     print(f"Mean: {np.mean(comp):.2e}")
#     print(f"Std: {np.std(comp):.2e}")