In [1]:
import numpy as np
import nibabel as nib
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table

In [2]:
subject_id = "100206"
subject_path = f"diffusion_data/{subject_id}/T1w/Diffusion"
dwi_img = nib.load(f'{subject_path}/data.nii.gz')
mask_img = nib.load(f'{subject_path}/nodif_brain_mask.nii.gz')

In [3]:
# Convert to numpy arrays for processing
dwi_data = dwi_img.get_fdata()
mask = mask_img.get_fdata()

In [4]:
print(f"DWI data shape: {dwi_data.shape}")  # Should be (X, Y, Z, num_volumes)
print(f"Mask shape: {mask.shape}")          # Should be (X, Y, Z)

DWI data shape: (145, 174, 145, 288)
Mask shape: (145, 174, 145)


In [24]:
# Load gradient information (bvals and bvecs)
print("\nLoading gradient information...")
bvals, bvecs = read_bvals_bvecs(f'{subject_path}/bvals', 
                               f'{subject_path}/bvecs')
print(f"Number of gradient directions: {len(bvals)}")
print(f"bvals shape: {bvals.shape}")     # Should match number of volumes
print(f"bvecs shape: {bvecs.shape}")     # Should be (num_volumes, 3)


Loading gradient information...
Number of gradient directions: 288
bvals shape: (288,)
bvecs shape: (288, 3)


In [25]:
# Create gradient table for DIPY
gtab = gradient_table(bvals, bvecs)

In [26]:
# Identify and extract B0 (non-diffusion weighted) volumes
b0_mask = gtab.b0s_mask
b1000_mask = (bvals >= 990) & (bvals <= 1010)
print(f"Number of b=1000 directions: {np.sum(b1000_mask)}")
b0_data = dwi_data[..., b0_mask]
print(f"\nNumber of B0 volumes: {np.sum(b0_mask)}")
print(f"B0 data shape: {b0_data.shape}")
# Average B0 volumes to get a single reference image
b0_avg = np.mean(b0_data, axis=-1)
print(f"Average B0 shape: {b0_avg.shape}")

Number of b=1000 directions: 90

Number of B0 volumes: 18
B0 data shape: (145, 174, 145, 18)
Average B0 shape: (145, 174, 145)


In [27]:
# Extract and normalize diffusion weighted volumes
dwi_mask = ~b0_mask  # Mask for diffusion weighted volumes
dwi_vols = dwi_data[..., b1000_mask]
print(f"\nNumber of DWI volumes: {np.sum(dwi_mask)}")
print(f"DWI volumes shape: {dwi_vols.shape}")


Number of DWI volumes: 270
DWI volumes shape: (145, 174, 145, 90)


In [28]:
# # Normalize DWI volumes by B0 (avoid division by zero with small epsilon)
# dwi_norm = dwi_vols / (b0_avg[..., None] + 1e-6)
# print(f"Normalized DWI shape: {dwi_norm.shape}")

In [29]:
# Find valid voxels using the brain mask
valid_idx = np.where(mask > 0)
print(f"\nNumber of valid voxels in mask: {len(valid_idx[0])}")


Number of valid voxels in mask: 936256


In [30]:
# Sample random voxels for training
n_samples = 200000  # Adjust this number based on your needs
sample_idx = np.random.choice(len(valid_idx[0]), 
                            min(n_samples, len(valid_idx[0])), 
                            replace=False)

In [31]:
# Extract features (signal intensities) from sampled voxels
features = []
for idx in sample_idx:
    x, y, z = valid_idx[0][idx], valid_idx[1][idx], valid_idx[2][idx]
    features.append(dwi_vols[x, y, z, :])

In [32]:
features = np.array(features)
gradient_directions = bvecs[b1000_mask]  # Only keep directions for DWI volumes

In [33]:
feature_max = np.max(features)
features = features / feature_max

In [34]:
print("\nFinal data shapes:")
print(f"Features shape: {features.shape}")           # Should be (n_samples, n_directions)
print(f"Gradient directions shape: {gradient_directions.shape}")  # Should be (n_directions, 3)

# Basic sanity checks
print("\nSanity checks:")
print(f"Max normalized value: {np.max(features)}")
print(f"Min normalized value: {np.min(features)}")
print(f"Gradient directions magnitude close to 1: {np.allclose(np.linalg.norm(gradient_directions, axis=1), 1, atol=1e-3)}")


Final data shapes:
Features shape: (200000, 90)
Gradient directions shape: (90, 3)

Sanity checks:
Max normalized value: 1.0
Min normalized value: 0.0
Gradient directions magnitude close to 1: True


In [35]:
print("Loading ground truth data...")
gt_data = np.load('ground_truth.npz')
ground_truth_tensors = gt_data['tensors']
# ground_truth_tensors = torch.tensor(ground_truth_tensors, dtype=torch.float32)
valid_coordinates = gt_data['coordinates']
print(f"Ground truth tensors shape: {ground_truth_tensors.shape}")
print(f"Valid coordinates shape: {valid_coordinates.shape}")

Loading ground truth data...
Ground truth tensors shape: (936256, 6)
Valid coordinates shape: (936256, 3)


In [36]:
# Number of directions to use for sparse estimation
n_directions = 21  # We can adjust this number

def create_graph_data(features, gradient_directions, n_sparse_directions, batch_size = 32, threshold_angle = 45):
    """
    Create graph data for a batch of samples, including ground tensor
    """
    voxel_indices = np.random.choice(len(features), size = batch_size, replace = False)
    # print(f"Creating graphs for batch of {batch_size} voxels using {n_sparse_directions} directions each")

    batch_nodes = []
    batch_edges = []
    batch_tensors = []

    for idx in voxel_indices:
        # Randomly selected directions
        selected_dir_idx = np.random.choice(len(gradient_directions), size=n_sparse_directions, replace=False)

        directions = gradient_directions[selected_dir_idx]
        signals = features[idx, selected_dir_idx]
        
        # Create node features: x, y, z, signal
        nodes = np.column_stack([directions, signals])

        # Create edges acc. to angle
        directions_norm = directions / np.linalg.norm(directions, axis=1, keepdims=True)
        cos_sim = np.dot(directions_norm, directions_norm.T)
        angles = np.arccos(np.clip(cos_sim, -1.0, 1.0)) * 180/np.pi
        
        src, dst = np.where(angles < threshold_angle)
        mask = src != dst
        edges = np.column_stack([src[mask], dst[mask]])

        batch_nodes.append(nodes)
        batch_edges.append(edges)
        batch_tensors.append(ground_truth_tensors[idx])
        
    return batch_nodes, batch_edges, batch_tensors

# nodes, edges, gt_tensors = create_graph_data(features, gradient_directions, n_directions)
# # Print info for verification
# print("\nBatch statistics:")
# print(f"Number of samples in batch: {len(nodes)}")
# print(f"Number of nodes per graph: {nodes[0].shape[0]}")
# print(f"Node feature dimensionality: {nodes[0].shape[1]}")
# print(f"Ground truth tensors: {len(gt_tensors)}")

# # Print example for first sample
# print("\nFirst sample in batch:")
# print(f"Number of nodes: {nodes[0].shape[0]}")
# print(f"Number of edges: {edges[0].shape[0]}")
# print(f"Ground truth tensor components: {gt_tensors[0]}")

In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool
import torch.optim as optim
from torch_geometric.data import Data, Batch
from sklearn.model_selection import train_test_split

In [46]:
class DiffusionGNN(torch.nn.Module):
    def __init__(self, node_features = 4, hidden_dim = 128):
        super(DiffusionGNN, self).__init__()
    
        # Layers
        self.conv1 = DiffusionConv(node_features, hidden_dim)
        self.conv2 = DiffusionConv(hidden_dim, hidden_dim)

        # MLP for tensor prediction
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 6),
        )

        self.scale_net = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    
    def forward(self, x, edge_index, batch):
        """
        x: node_features
        edge_index = graph conn info [2, num_edges]
        batch: batch assignment for nodes
        """
        # pass through gnn layers
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # combine node features for graph
        x = global_mean_pool(x, batch)  # [batch_size, hidden_dim]

        # raw predictions
        out = self.mlp(x)

        scale = self.scale_net(x)
        
        diag_idx = [0, 2, 5]
        diag = out[:, diag_idx]

        offdiag_idx = [1, 3, 4]
        offdiag = out[:, offdiag_idx]

        # Diagonal elements: [0,1] using sigmoid
        diag = torch.sigmoid(diag)
        
        # Off-diagonal elements: [-1,1] using tanh
        offdiag = torch.tanh(offdiag)

        out_reordered = torch.zeros_like(out)
        out_reordered[:, diag_idx] = diag      # Put diagonal components in right places
        out_reordered[:, offdiag_idx] = offdiag  # Put off-diagonal components in right places
        
        return out_reordered * scale


class DiffusionConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(DiffusionConv, self).__init__(aggr="mean")
        
        # MLP to process messages
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
    
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_i, x_j):
        """
        x_i: features of target nodes
        x_j: features of source nodes
        Returns: messages to be aggregated
        """
        tmp = torch.cat([x_i, x_j], dim = 1)
        return self.mlp(tmp)

In [39]:
def convert_to_torch_geometric(nodes, edges, tensors):
    data_list = []
    for i in range(len(nodes)):
        # Ensure tensor is properly shaped [6] not flattened
        tensor = tensors[i].reshape(-1, 6) if len(tensors[i].shape) == 1 else tensors[i]
        
        data = Data(
            x=torch.FloatTensor(nodes[i]),          # [n_nodes, 4]
            edge_index=torch.LongTensor(edges[i].T), # [2, n_edges]
            y=torch.FloatTensor(tensor)             # [1, 6]
        )
        data_list.append(data)
    return data_list

In [40]:
train_idx, test_idx = train_test_split(np.arange(len(features)), test_size=0.2, random_state=42)
print(f"Training with {len(train_idx)} voxels, testing with {len(test_idx)} voxels")

Training with 160000 voxels, testing with 40000 voxels


In [47]:
# Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = DiffusionGNN(node_features=4, hidden_dim=32).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
batch_size = 32
epochs = 10

all_pred_values = [[] for _ in range(6)]
all_gt_values = [[] for _ in range(6)]

print("Starting training...")
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    n_batches = 0

    # Shuffle training indices
    np.random.shuffle(train_idx)
    
    # Process batches
    for start in range(0, len(train_idx), batch_size):
        batch_idx = train_idx[start:start + batch_size]
        # Create graphs for batch
        nodes, edges, tensors = create_graph_data(
            features[batch_idx], gradient_directions, 21
        )
        
        # Convert to PyG and process
        data_list = convert_to_torch_geometric(nodes, edges, tensors)
        batch_data = Batch.from_data_list(data_list).to(device)
        
        # Training step
        optimizer.zero_grad()
        pred = model(batch_data.x, batch_data.edge_index, batch_data.batch)
        loss = F.mse_loss(pred, batch_data.y) * 1e6
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        n_batches += 1
    
    for i in range(6):
        all_pred_values[i].extend(pred[:, i].detach().cpu().numpy())
        all_gt_values[i].extend(batch_data.y[:, i].cpu().numpy())

    if epoch % 1 == 0:  # Print every epoch
        # Move entire tensors to CPU once
        pred_cpu = pred.detach().cpu().numpy()
        gt_cpu = batch_data.y.cpu().numpy()
        
        print(f"\nEpoch {epoch+1} Statistics:")
        # Component-wise prediction stats
        print("\nPredictions per component:")
        for i in range(6):
            print(f"Component {i}: min={np.min(pred_cpu[:,i]):.2e}, "
                f"max={np.max(pred_cpu[:,i]):.2e}, "
                f"mean={np.mean(pred_cpu[:,i]):.2e}, "
                f"std={np.std(pred_cpu[:,i]):.2e}")
        
        # Component-wise ground truth stats
        print("\nGround Truth per component:")
        for i in range(6):
            print(f"Component {i}: min={np.min(gt_cpu[:,i]):.2e}, "
                f"max={np.max(gt_cpu[:,i]):.2e}, "
                f"mean={np.mean(gt_cpu[:,i]):.2e}, "
                f"std={np.std(gt_cpu[:,i]):.2e}")
        
        # Relative error per component
        print("\nRelative Error per component:")
        for i in range(6):
            rel_error = np.mean(np.abs(pred_cpu[:,i] - gt_cpu[:,i]) / 
                            (np.abs(gt_cpu[:,i]) + 1e-10))
            print(f"Component {i}: {rel_error:.2e}")
    
    avg_loss = epoch_loss / n_batches
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    

print("\nTraining Values Statistics:")
for i in range(6):
    pred_vals = np.array(all_pred_values[i])
    gt_vals = np.array(all_gt_values[i])
    print(f"\nComponent {i}:")
    print("Predictions:")
    print(f"Min: {np.min(pred_vals):.6f}")
    print(f"Max: {np.max(pred_vals):.6f}")
    print(f"Mean: {np.mean(pred_vals):.6f}")
    print(f"1st percentile: {np.percentile(pred_vals, 1):.6f}")
    print(f"99th percentile: {np.percentile(pred_vals, 99):.6f}")
    print("Ground Truth:")
    print(f"Min: {np.min(gt_vals):.6f}")
    print(f"Max: {np.max(gt_vals):.6f}")
    print(f"Mean: {np.mean(gt_vals):.6f}")
    print(f"1st percentile: {np.percentile(gt_vals, 1):.6f}")
    print(f"99th percentile: {np.percentile(gt_vals, 99):.6f}")

# Final evaluation on full test set
print("\nFinal Evaluation...")
model.eval()
test_losses = []
all_test_preds = []
all_test_gts = []

with torch.no_grad():
    for start in range(0, len(test_idx), batch_size):
        batch_idx = test_idx[start:start + batch_size]
        if len(batch_idx) < batch_size:
            continue
            
        nodes, edges, tensors = create_graph_data(
            features[batch_idx], gradient_directions, 21
        )
        data_list = convert_to_torch_geometric(nodes, edges, tensors)
        batch_data = Batch.from_data_list(data_list).to(device)
        
        pred = model(batch_data.x, batch_data.edge_index, batch_data.batch)
        pred = pred.cpu()
        batch_data = batch_data.cpu()
        
        test_loss = F.mse_loss(pred, batch_data.y) * 1e6
        test_losses.append(test_loss.item())
        
        # Store predictions and ground truths
        all_test_preds.append(pred.detach().numpy())
        all_test_gts.append(batch_data.y.numpy())

# Combine all batches
all_test_preds = np.concatenate(all_test_preds, axis=0)
all_test_gts = np.concatenate(all_test_gts, axis=0)

# Print overall test loss
avg_test_loss = np.mean(test_losses)
print(f"\nFinal Test Results:")
print(f"Average test loss: {avg_test_loss:.6f}")

# Component-wise analysis
print("\nTest Set Component Analysis:")
for i in range(6):
    pred_comp = all_test_preds[:, i]
    gt_comp = all_test_gts[:, i]
    
    # Stats
    print(f"\nComponent {i}:")
    print(f"Predictions - min: {np.min(pred_comp):.2e}, max: {np.max(pred_comp):.2e}, "
          f"mean: {np.mean(pred_comp):.2e}, std: {np.std(pred_comp):.2e}")
    print(f"Ground Truth - min: {np.min(gt_comp):.2e}, max: {np.max(gt_comp):.2e}, "
          f"mean: {np.mean(gt_comp):.2e}, std: {np.std(gt_comp):.2e}")
    
    # Relative Error
    rel_error = np.mean(np.abs(pred_comp - gt_comp) / (np.abs(gt_comp) + 1e-10))
    print(f"Relative Error: {rel_error:.2e}")


# Print some example predictions vs ground truth
print("\nExample predictions vs ground truth:")
for i in range(5):
    print(f"\nSample {i+1}:")
    print(f"Predicted: {pred[i].numpy()}")
    print(f"Actual: {batch_data.y[i].numpy()}")

Using device: cuda
Starting training...

Epoch 1 Statistics:

Predictions per component:
Component 0: min=1.17e-04, max=5.72e-04, mean=2.92e-04, std=8.90e-05
Component 1: min=-3.12e-05, max=-5.01e-06, mean=-1.69e-05, std=7.01e-06
Component 2: min=1.19e-04, max=6.10e-04, mean=3.06e-04, std=9.63e-05
Component 3: min=-2.20e-05, max=3.29e-05, mean=2.28e-05, std=9.56e-06
Component 4: min=-2.08e-04, max=1.05e-05, mean=-5.55e-05, std=3.96e-05
Component 5: min=1.11e-04, max=5.60e-04, mean=2.82e-04, std=8.82e-05

Ground Truth per component:
Component 0: min=1.62e-04, max=5.48e-04, mean=3.04e-04, std=9.81e-05
Component 1: min=-4.25e-05, max=4.70e-05, mean=-1.67e-06, std=2.24e-05
Component 2: min=1.67e-04, max=4.85e-04, mean=3.20e-04, std=9.22e-05
Component 3: min=-1.88e-05, max=4.25e-05, mean=1.49e-05, std=1.61e-05
Component 4: min=-6.59e-05, max=2.80e-05, mean=-1.14e-05, std=2.10e-05
Component 5: min=1.43e-04, max=4.65e-04, mean=2.95e-04, std=8.37e-05

Relative Error per component:
Component 0:

In [42]:
print("Original Ground Truth Tensor Analysis:")
print("\nComponent-wise statistics:")
for i in range(6):
    comp = ground_truth_tensors[:, i]
    print(f"\nComponent {i}:")
    print(f"Min: {np.min(comp):.2e}")
    print(f"Max: {np.max(comp):.2e}")
    print(f"Mean: {np.mean(comp):.2e}")
    print(f"Std: {np.std(comp):.2e}")

Original Ground Truth Tensor Analysis:

Component-wise statistics:

Component 0:
Min: 3.34e-10
Max: 2.67e-01
Mean: 6.67e-04
Std: 4.27e-04

Component 1:
Min: -6.89e-04
Max: 7.34e-03
Mean: -1.48e-06
Std: 8.80e-05

Component 2:
Min: 3.34e-10
Max: 2.72e-01
Mean: 6.93e-04
Std: 4.32e-04

Component 3:
Min: -1.66e-02
Max: 9.90e-04
Mean: 8.52e-08
Std: 9.10e-05

Component 4:
Min: -6.56e-03
Max: 7.76e-04
Mean: -1.75e-05
Std: 9.18e-05

Component 5:
Min: 3.34e-10
Max: 2.74e-01
Mean: 6.75e-04
Std: 4.36e-04
