In [30]:
import numpy as np
import nibabel as nib
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table

In [31]:
subject_id = "100206"
subject_path = f"diffusion_data/{subject_id}/T1w/Diffusion"

# Open the diffusion-weighted series and its brain mask for this subject.
dwi_img, mask_img = (
    nib.load(f'{subject_path}/data.nii.gz'),
    nib.load(f'{subject_path}/nodif_brain_mask.nii.gz'),
)

In [32]:
# Read voxel data; any positive mask value counts as brain.
original_mask = mask_img.get_fdata() > 0
dwi_data = dwi_img.get_fdata()

In [33]:
# Coordinates of every voxel inside the brain mask, captured before any further
# voxel filtering: a tuple of three index arrays (x, y, z) in row-major order.
original_idx = np.nonzero(original_mask)

In [34]:
# Map each in-mask voxel coordinate (x, y, z) to its position in original_idx.
# That position is used as the row index into ground_truth_tensors (loaded later),
# which assumes the ground truth was saved in this same np.where ordering.
coord_to_gtt = {
    coord: position
    for position, coord in enumerate(zip(*original_idx))
}

In [35]:
print("DWI data shape:", dwi_data.shape)  # (X, Y, Z, num_volumes)

DWI data shape: (145, 174, 145, 288)


In [36]:
# b-values and gradient directions that accompany data.nii.gz.
bvals, bvecs = read_bvals_bvecs(
    f'{subject_path}/bvals',
    f'{subject_path}/bvecs',
)
gtab = gradient_table(bvals, bvecs=bvecs)

In [37]:
# Mean of the b0 volumes: the per-voxel reference used to normalize DWI signals.
b0_mask = gtab.b0s_mask
b0_data = dwi_data[..., b0_mask]
print(f"B0 data shape: {b0_data.shape}")
b0_avg = b0_data.mean(axis=-1)

# Volumes on the b=1000 shell, accepting b-values within 990..1010 inclusive.
b1000_mask = np.abs(bvals - 1000) <= 10
print(f"Number of b=1000 volumes: {np.sum(b1000_mask)}")

B0 data shape: (145, 174, 145, 18)
Number of b=1000 volumes: 90


In [38]:
# Keep only the b=1000 diffusion-weighted volumes (last axis of the 4-D series).
dwi_vols = dwi_data[:, :, :, b1000_mask]
print(f"DWI volumes shape: {dwi_vols.shape}")

DWI volumes shape: (145, 174, 145, 90)


In [39]:
# Exclude voxels whose mean b0 signal is too low to normalize by, then
# intersect with the brain mask.
b0_threshold = 250
valid_b0_mask = np.greater(b0_avg, b0_threshold)
mask = np.logical_and(original_mask, valid_b0_mask)

In [40]:
# Voxels that pass both the brain mask and the b0 threshold.
# valid_idx is a tuple of three index arrays (x, y, z), each of length num_valid_voxels.
valid_idx = np.nonzero(mask)
print(f"\nNumber of valid voxels in mask: {valid_idx[0].size}")


Number of valid voxels in mask: 926671


In [41]:
# Sample random voxels for training.
# Seed NumPy's global RNG so this sample is reproducible under Restart & Run All.
# The later np.random.permutation / np.random.shuffle calls draw from the same
# global RNG, so they are reproducible too. Torch weight init is not covered.
SEED = 42
np.random.seed(SEED)

n_samples = 100000
n_valid_voxels = len(valid_idx[0])
sample_idx = np.random.choice(n_valid_voxels,
                              min(n_samples, n_valid_voxels),
                              replace=False)

In [42]:
# Extract features (b0-normalized signal intensities) from the sampled voxels.
# Vectorized gather: one fancy-index per axis instead of a Python loop per voxel.
sample_x, sample_y, sample_z = (axis_idx[sample_idx] for axis_idx in valid_idx)

# Row k holds the b=1000 signals of sampled voxel k divided by that voxel's mean b0.
# Every voxel in `mask` has b0_avg > b0_threshold, so the divisor is never zero.
sampled_b0 = b0_avg[sample_x, sample_y, sample_z]
normalized_signals = dwi_vols[sample_x, sample_y, sample_z] / sampled_b0[:, None]
features = list(normalized_signals)  # kept as a list of rows, as before

# Ground-truth row index for each sampled voxel (see coord_to_gtt).
gtt_indices = [coord_to_gtt[coord] for coord in zip(sample_x, sample_y, sample_z)]

In [43]:
# First ten ground-truth row indices of the sampled voxels.
gtt_indices[0:10]

[562793, 201823, 866778, 795546, 16446, 493057, 574195, 883643, 589013, 213797]

In [44]:
# Gradient directions of the b=1000 volumes; row j matches column j of `features`.
gradient_directions = bvecs[b1000_mask]

features, gtt_indices = np.array(features), np.array(gtt_indices)

In [45]:
# print("\nFinal data shapes:")
# print(f"Features shape: {features.shape}")           # Should be (n_samples, n_directions)
# print(f"Gradient directions shape: {gradient_directions.shape}")  # Should be (n_directions, 3)

# # Basic sanity checks
# print("\nSanity checks:")
# print(f"Max normalized value: {np.max(features)}")
# print(f"Min normalized value: {np.min(features)}")
# print(f"Gradient directions magnitude close to 1: {np.allclose(np.linalg.norm(gradient_directions, axis=1), 1, atol=1e-3)}")

In [46]:
print("Loading ground truth data...")
# Use the NpzFile as a context manager so its file handle is closed once the
# array has been read into memory.
with np.load('ground_truth_v2.npz') as gt_data:
    ground_truth_tensors = gt_data['tensors']
print(f"Ground truth tensors shape: {ground_truth_tensors.shape}")

# coord_to_gtt assumes row i of ground_truth_tensors belongs to the i-th voxel of
# np.where(original_mask). If the counts differ, that alignment cannot hold and
# every sampled target would be silently wrong.
assert ground_truth_tensors.shape[0] == len(original_idx[0]), (
    f"ground_truth_tensors has {ground_truth_tensors.shape[0]} rows but the brain mask "
    f"has {len(original_idx[0])} voxels; the row/voxel alignment assumed by coord_to_gtt is broken"
)

Loading ground truth data...


In [47]:
# Number of directions to use for sparse estimation
n_directions = 21  # We can adjust this number

def create_graph_data(batch_idx, direction_indices, batch_size = 32, threshold_angle = 60):
    """
    Build one graph per sample in a batch, together with its ground-truth tensor.

    Each graph has one node per selected gradient direction, with node features
    ``[normalized_signal, gx, gy, gz]`` (array shape ``[n_dirs, 4]``). Two distinct
    nodes are joined (in both directions) when the angle between their gradient
    directions is below ``threshold_angle`` degrees; the edge weight is the cosine
    of that angle.

    Reads the module-level ``features``, ``gradient_directions``, ``gtt_indices``
    and ``ground_truth_tensors``.

    Parameters
    ----------
    batch_idx : iterable of int
        Row indices into ``features`` / ``gtt_indices``.
    direction_indices : sequence of int
        Columns of ``features`` (and rows of ``gradient_directions``) to use as nodes.
    batch_size : int, optional
        Unused; kept so existing callers keep working.
    threshold_angle : float, optional
        Maximum angle in degrees between two directions for them to share an edge.

    Returns
    -------
    batch_nodes, batch_edges, batch_edge_weights, batch_tensors : list
        One entry per element of ``batch_idx``. Edge arrays are ``[n_edges, 2]`` and
        weight arrays ``[n_edges]``. They depend only on ``direction_indices``, so
        the same array objects are shared by every sample; treat them as read-only.
    """
    current_directions = gradient_directions[direction_indices]

    # The graph topology depends only on the chosen directions, not on the voxel,
    # so it is computed once per call instead of once per sample.
    directions_norm = current_directions / np.linalg.norm(current_directions, axis=1, keepdims=True)
    cos_sim = np.dot(directions_norm, directions_norm.T)
    angles = np.arccos(np.clip(cos_sim, -1.0, 1.0)) * 180/np.pi
    # NOTE(review): angles use signed directions, so a direction and its near-antipode
    # count as far apart. Confirm whether g and -g should be treated as equivalent.

    src, dst = np.where(angles < threshold_angle)
    not_self = src != dst  # drop self-loops
    src, dst = src[not_self], dst[not_self]
    edges = np.column_stack([src, dst])  # [n_edges, 2]
    weights = cos_sim[src, dst]          # [n_edges]

    batch_nodes = []
    batch_edges = []
    batch_edge_weights = []
    batch_tensors = []
    for idx in batch_idx:
        signals = features[idx, direction_indices]

        nodes = np.empty((len(signals), 4))  # Shape: [n_directions, 4]
        nodes[:, 0] = signals
        nodes[:, 1:4] = current_directions

        batch_nodes.append(nodes)
        batch_edges.append(edges)
        batch_edge_weights.append(weights)
        batch_tensors.append(ground_truth_tensors[gtt_indices[idx]])
    return batch_nodes, batch_edges, batch_edge_weights, batch_tensors

In [48]:
def select_diverse_directions(all_directions, n_select=21, antipodal=False):
    """
    Greedily pick ``n_select`` well-spread directions (farthest-point selection).

    Starts from index 0, then repeatedly adds the not-yet-selected direction whose
    smallest angle to any already-selected direction is largest. Ties go to the
    lowest index.

    Parameters
    ----------
    all_directions : array_like, shape (n, 3)
        Direction vectors. Angles are ``arccos(clip(dot, -1, 1))``, so the vectors
        are expected to be unit length.
    n_select : int, optional
        Number of indices to return. Index 0 is always returned, even for n_select <= 1.
    antipodal : bool, optional
        If True, treat ``v`` and ``-v`` as the same direction by using ``|dot|``.
        Angles then lie in [0, 90] degrees. The default False keeps signed angles.

    Returns
    -------
    list of int
        Selected row indices, in selection order.

    Raises
    ------
    ValueError
        If ``n_select`` exceeds the number of available directions.
    """
    directions = np.asarray(all_directions, dtype=float)
    n_total = len(directions)
    if n_select > n_total:
        raise ValueError(
            f"cannot select {n_select} directions from only {n_total} available")

    selected = [0]  # Start with first direction
    if n_select <= 1:
        return selected

    cos = directions @ directions.T
    if antipodal:
        cos = np.abs(cos)
    angles = np.arccos(np.clip(cos, -1.0, 1.0))  # (n, n) pairwise angles in radians

    # min_angle[i] = smallest angle between direction i and any selected direction.
    # Updated incrementally, so each step is O(n) instead of O(n * len(selected)).
    min_angle = angles[0].copy()
    while len(selected) < n_select:
        candidates = min_angle.copy()
        candidates[selected] = -np.inf  # never re-pick a selected direction
        next_idx = int(np.argmax(candidates))  # first maximum -> lowest index on ties
        selected.append(next_idx)
        np.minimum(min_angle, angles[next_idx], out=min_angle)
    return selected

In [49]:
def initialize_direction_sets(all_directions, n_sets=3, n_directions=21):
    """
    Build ``n_sets`` subsets of well-spread gradient directions.

    For each set, the directions are randomly permuted with NumPy's global RNG.
    ``select_diverse_directions`` then runs on that permuted order, so each set
    starts from a different random direction. The picks are mapped back to
    indices into ``all_directions``.

    Returns a list of ``n_sets`` integer index arrays.
    """
    n_total = len(all_directions)

    def _one_set():
        # Random starting point via permutation, then farthest-point picks.
        order = np.random.permutation(n_total)
        picks = select_diverse_directions(all_directions[order], n_directions)
        return order[picks]

    return [_one_set() for _ in range(n_sets)]

In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool
import torch.optim as optim
from torch_geometric.data import Data, Batch
from sklearn.model_selection import train_test_split

In [51]:
class DiffusionGNN(torch.nn.Module):
    """
    Graph network mapping a voxel's direction graph to 6 diffusion-tensor components.

    Two ``DiffusionConv`` layers (each followed by ReLU) feed a per-graph mean pool
    and an MLP head. Output columns 0, 2, 5 (diagonal entries) are squashed with a
    sigmoid into (0, 1). Columns 1, 3, 4 (off-diagonal entries) are squashed with
    tanh into (-1, 1).
    """

    def __init__(self, node_features = 4, hidden_dim = 128):
        super().__init__()

        # Message-passing layers. Submodules are created in the same order as before,
        # so parameter initialization consumes the RNG identically.
        self.conv1 = DiffusionConv(node_features, hidden_dim)
        self.conv2 = DiffusionConv(hidden_dim, hidden_dim)

        # Head: pooled graph embedding -> 6 raw tensor components
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 6),
        ]
        self.mlp = nn.Sequential(*head_layers)

        # NOTE(review): scale_net is constructed (so its parameters exist and are
        # handed to the optimizer), but forward() never calls it.
        self.scale_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 6), nn.Softplus(),
        )

        # Per-component loss weights (integer tensor). Off-diagonal slots 1, 3, 4
        # weigh 10x more than diagonal slots 0, 2, 5.
        self.register_buffer('loss_weights',
            torch.tensor([10, 100, 10, 100, 100, 10]))

    def forward(self, x, edge_index, edge_attr, batch):
        """
        x: node features [num_nodes, node_features]
        edge_index: graph connectivity [2, num_edges]
        edge_attr: per-edge weight handed to each DiffusionConv
        batch: graph id of every node
        Returns predictions of shape [num_graphs, 6].
        """
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x, edge_index, edge_attr))

        pooled = global_mean_pool(x, batch)  # [batch_size, hidden_dim]
        raw = self.mlp(pooled)

        # Diagonal columns (0, 2, 5) -> sigmoid; off-diagonal columns (1, 3, 4) -> tanh.
        is_diag = torch.tensor([True, False, True, False, False, True], device=raw.device)
        return torch.where(is_diag, torch.sigmoid(raw), torch.tanh(raw))

    def weighted_mse_loss(self, pred, target):
        """Mean over batch and components of loss_weights * (pred - target)**2."""
        return ((pred - target).pow(2) * self.loss_weights).mean()


class DiffusionConv(MessagePassing):
    """
    Edge-weighted message passing with mean aggregation.

    For an edge j -> i, the message is ``mlp(concat(x_i, x_j)) * edge_weight``.
    Each node's new feature is the mean of its incoming messages. There is no
    separate self term, so a node's own features enter only as ``x_i``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="mean")

        # Shared MLP applied to every concatenated (target, source) feature pair
        self.mlp = nn.Sequential(
            nn.Linear(in_channels * 2, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x, edge_index, edge_weight):
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_i, x_j, edge_weight):
        """
        x_i: target-node features per edge
        x_j: source-node features per edge
        edge_weight: one scalar per edge
        Returns: one weighted message per edge, to be mean-aggregated
        """
        pair_features = torch.cat((x_i, x_j), dim=1)
        per_edge_scale = edge_weight.view(-1, 1)
        return self.mlp(pair_features) * per_edge_scale

In [52]:
def convert_to_torch_geometric(nodes, edges, edge_weights, tensors):
    """
    Wrap per-sample numpy arrays (as produced by create_graph_data) into PyG ``Data`` objects.

    Parameters
    ----------
    nodes : list of arrays, each ``[n_nodes, n_node_features]``
    edges : list of integer arrays, each ``[n_edges, 2]`` of (source, target) pairs
    edge_weights : list of arrays, each ``[n_edges]``
    tensors : list of per-sample targets. A 1-D entry (the 6 tensor components) is
        reshaped to ``[1, 6]``; other shapes are used as-is.

    Returns
    -------
    list of Data
        ``x``, ``edge_attr`` and ``y`` are float32; ``edge_index`` is int64 ``[2, n_edges]``.
        Keeping ``y`` as ``[1, 6]`` lets the training loop treat ``batch_data.y``
        as ``[batch_size, 6]``.
    """
    data_list = []
    for node_arr, edge_arr, weight_arr, target in zip(nodes, edges, edge_weights, tensors):
        target = np.asarray(target)
        if target.ndim == 1:
            target = target.reshape(-1, 6)  # [6] -> [1, 6]

        data_list.append(Data(
            x=torch.tensor(node_arr, dtype=torch.float32),                       # [n_nodes, n_feats]
            edge_index=torch.tensor(np.asarray(edge_arr).T, dtype=torch.long),   # [2, n_edges]
            edge_attr=torch.tensor(weight_arr, dtype=torch.float32),             # [n_edges]
            y=torch.tensor(target, dtype=torch.float32),                         # [1, 6]
        ))
    return data_list

In [53]:
# 80/20 split of row positions in `features`; random_state fixes the split.
train_idx, test_idx = train_test_split(
    np.arange(len(features)),
    test_size=0.2,
    random_state=42,
)
print(f"Training with {len(train_idx)} voxels, testing with {len(test_idx)} voxels")

Training with 80000 voxels, testing with 20000 voxels


In [54]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [55]:
n_sets = 3
# Pass n_directions explicitly so that changing it in the graph-construction cell
# actually changes the subset size (the function's default just happens to be 21 too).
direction_sets = initialize_direction_sets(gradient_directions, n_sets, n_directions)

In [56]:
# Training
# Weighted-MSE multiplier: ground-truth tensor entries are ~1e-3, so unscaled
# squared errors are tiny.
LOSS_SCALE = 1e6

torch.manual_seed(42)  # reproducible weight initialization
model = DiffusionGNN(hidden_dim=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
batch_size = 32
epochs = 20


def _forward_batch(batch_idx, direction_indices):
    """Build the PyG batch for `batch_idx` with `direction_indices` and run `model` on it.

    Returns (pred, batch_data), both on `device`.
    """
    nodes, edges, edge_weights, tensors = create_graph_data(batch_idx, direction_indices)
    data_list = convert_to_torch_geometric(nodes, edges, edge_weights, tensors)
    batch_data = Batch.from_data_list(data_list).to(device)
    pred = model(batch_data.x, batch_data.edge_index, batch_data.edge_attr, batch_data.batch)
    return pred, batch_data


def _print_component_stats(title, values):
    """Print min/max/mean/std of each of the 6 components of `values` ([n, 6])."""
    print(title)
    for comp in range(6):
        col = values[:, comp]
        print(f"Component {comp}: min={np.min(col):.2e}, "
              f"max={np.max(col):.2e}, "
              f"mean={np.mean(col):.2e}, "
              f"std={np.std(col):.2e}")


# Predictions / targets of the LAST mini-batch of every epoch (not the full epoch).
all_pred_values = [[] for _ in range(6)]
all_gt_values = [[] for _ in range(6)]

print("Starting training...")
for epoch in range(epochs):
    # Rotate through the direction subsets from epoch to epoch.
    direction_indices = direction_sets[epoch % n_sets]
    model.train()
    epoch_loss = 0.0
    n_batches = 0

    # Shuffle training indices (in place)
    np.random.shuffle(train_idx)

    for start in range(0, len(train_idx), batch_size):
        batch_idx = train_idx[start:start + batch_size]
        optimizer.zero_grad()
        pred, batch_data = _forward_batch(batch_idx, direction_indices)
        loss = model.weighted_mse_loss(pred, batch_data.y) * LOSS_SCALE
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    # The diagnostics below describe only the final mini-batch of this epoch.
    pred_cpu = pred.detach().cpu().numpy()
    gt_cpu = batch_data.y.cpu().numpy()
    for i in range(6):
        all_pred_values[i].extend(pred_cpu[:, i])
        all_gt_values[i].extend(gt_cpu[:, i])

    print(f"\nEpoch {epoch+1} Statistics (last batch):")
    _print_component_stats("\nPredictions per component:", pred_cpu)
    _print_component_stats("\nGround Truth per component:", gt_cpu)

    # Off-diagonal ground-truth values sit near zero, so this ratio can be very large for them.
    print("\nRelative Error per component:")
    for i in range(6):
        rel_error = np.mean(np.abs(pred_cpu[:, i] - gt_cpu[:, i]) /
                            (np.abs(gt_cpu[:, i]) + 1e-10))
        print(f"Component {i}: {rel_error:.2e}")

    avg_loss = epoch_loss / n_batches
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")


print("\nTraining Values Statistics (last batch of each epoch):")
for i in range(6):
    pred_vals = np.array(all_pred_values[i])
    gt_vals = np.array(all_gt_values[i])
    print(f"\nComponent {i}:")
    print("Predictions:")
    print(f"Min: {np.min(pred_vals):.6f}")
    print(f"Max: {np.max(pred_vals):.6f}")
    print(f"Mean: {np.mean(pred_vals):.6f}")
    print(f"1st percentile: {np.percentile(pred_vals, 1):.6f}")
    print(f"99th percentile: {np.percentile(pred_vals, 99):.6f}")
    print("Ground Truth:")
    print(f"Min: {np.min(gt_vals):.6f}")
    print(f"Max: {np.max(gt_vals):.6f}")
    print(f"Mean: {np.mean(gt_vals):.6f}")
    print(f"1st percentile: {np.percentile(gt_vals, 1):.6f}")
    print(f"99th percentile: {np.percentile(gt_vals, 99):.6f}")

# Final evaluation on the full test set, with the direction subset used in the
# last training epoch (chosen explicitly rather than via a leftover loop variable).
eval_direction_indices = direction_sets[(epochs - 1) % n_sets]
print("\nFinal Evaluation...")
model.eval()
all_test_preds = []
all_test_gts = []
test_loss_sum = 0.0

with torch.no_grad():
    # Every test voxel is evaluated, including a trailing partial batch.
    for start in range(0, len(test_idx), batch_size):
        batch_idx = test_idx[start:start + batch_size]
        pred, batch_data = _forward_batch(batch_idx, eval_direction_indices)

        test_loss = model.weighted_mse_loss(pred, batch_data.y) * LOSS_SCALE
        # Weight each batch mean by its size so the final average is per-sample.
        test_loss_sum += test_loss.item() * len(batch_idx)

        all_test_preds.append(pred.cpu().numpy())
        all_test_gts.append(batch_data.y.cpu().numpy())

# Combine all batches
all_test_preds = np.concatenate(all_test_preds, axis=0)
all_test_gts = np.concatenate(all_test_gts, axis=0)

# Print overall test loss
avg_test_loss = test_loss_sum / len(all_test_preds)
print(f"\nFinal Test Results:")
print(f"Average test loss: {avg_test_loss:.6f}")

# Component-wise analysis
print("\nTest Set Component Analysis:")
for i in range(6):
    pred_comp = all_test_preds[:, i]
    gt_comp = all_test_gts[:, i]

    print(f"\nComponent {i}:")
    print(f"Predictions - min: {np.min(pred_comp):.2e}, max: {np.max(pred_comp):.2e}, "
          f"mean: {np.mean(pred_comp):.2e}, std: {np.std(pred_comp):.2e}")
    print(f"Ground Truth - min: {np.min(gt_comp):.2e}, max: {np.max(gt_comp):.2e}, "
          f"mean: {np.mean(gt_comp):.2e}, std: {np.std(gt_comp):.2e}")

    rel_error = np.mean(np.abs(pred_comp - gt_comp) / (np.abs(gt_comp) + 1e-10))
    print(f"Relative Error: {rel_error:.2e}")


# A few example predictions vs ground truth, taken from the collected test arrays
# so this works however small the last batch is.
print("\nExample predictions vs ground truth:")
for i in range(min(5, len(all_test_preds))):
    print(f"\nSample {i+1}:")
    print(f"Predicted: {all_test_preds[i]}")
    print(f"Actual: {all_test_gts[i]}")
    

Starting training...

Epoch 1 Statistics:

Predictions per component:
Component 0: min=1.35e-03, max=1.60e-03, mean=1.50e-03, std=6.24e-05
Component 1: min=3.92e-03, max=4.18e-03, mean=4.00e-03, std=6.05e-05
Component 2: min=2.93e-03, max=3.43e-03, mean=3.22e-03, std=1.22e-04
Component 3: min=-2.00e-03, max=-1.61e-03, mean=-1.67e-03, std=7.95e-05
Component 4: min=2.52e-03, max=3.17e-03, mean=3.03e-03, std=1.30e-04
Component 5: min=1.18e-03, max=1.40e-03, mean=1.31e-03, std=5.55e-05

Ground Truth per component:
Component 0: min=4.91e-04, max=3.18e-03, mean=1.01e-03, std=5.10e-04
Component 1: min=-2.30e-04, max=7.96e-05, mean=-3.11e-05, std=6.57e-05
Component 2: min=3.25e-04, max=3.33e-03, mean=1.03e-03, std=5.64e-04
Component 3: min=-2.05e-04, max=1.68e-04, mean=-1.25e-05, std=6.84e-05
Component 4: min=-5.59e-05, max=2.46e-04, mean=1.42e-05, std=6.93e-05
Component 5: min=5.89e-04, max=3.20e-03, mean=1.03e-03, std=5.42e-04

Relative Error per component:
Component 0: 7.38e-01
Component 1:

In [57]:
# print("Original Ground Truth Tensor Analysis:")
# print("\nComponent-wise statistics:")
# for i in range(6):
#     comp = ground_truth_tensors[:, i]
#     print(f"\nComponent {i}:")
#     print(f"Min: {np.min(comp):.2e}")
#     print(f"Max: {np.max(comp):.2e}")
#     print(f"Mean: {np.mean(comp):.2e}")
#     print(f"Std: {np.std(comp):.2e}")

In [58]:
# print("Loading ground truth data 2...")
# gt_data_v2 = np.load('ground_truth_v2.npz')
# ground_truth_tensors_v2 = gt_data_v2['tensors']
# print(f"Ground truth tensors shape: {ground_truth_tensors_v2.shape}")

In [59]:
# print("V2 Original Ground Truth Tensor Analysis:")
# print("\nComponent-wise statistics:")
# for i in range(6):
#     comp = ground_truth_tensors_v2[:, i]
#     print(f"\nComponent {i}:")
#     print(f"Min: {np.min(comp):.2e}")
#     print(f"Max: {np.max(comp):.2e}")
#     print(f"Mean: {np.mean(comp):.2e}")
#     print(f"Std: {np.std(comp):.2e}")