In [210]:
import numpy as np
import nibabel as nib
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool
import torch.optim as optim
from torch_geometric.data import Data, Batch
from sklearn.model_selection import train_test_split


In [211]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [212]:
subject_id = "100206"
subject_path = f"diffusion_data/{subject_id}/T1w/Diffusion"
dwi_img = nib.load(f'{subject_path}/data.nii.gz')
mask_img = nib.load(f'{subject_path}/nodif_brain_mask.nii.gz')

In [213]:
dwi_data = dwi_img.get_fdata()
original_mask = mask_img.get_fdata() > 0

In [239]:
# store original indices before any voxel filtering 
original_idx = np.where(original_mask)

In [240]:
# Map (x,y,z) to original linear index
coord_to_gtt = {}
for i in range(len(original_idx[0])):
    x, y, z = original_idx[0][i], original_idx[1][i], original_idx[2][i]
    coord_to_gtt[(x,y,z)] = i  # i is the index in ground_truth_tensors

In [241]:
print(f"DWI data shape: {dwi_data.shape}")  # (X, Y, Z, num_volumes)

DWI data shape: (145, 174, 145, 288)


In [242]:
bvals, bvecs = read_bvals_bvecs(f'{subject_path}/bvals', 
                               f'{subject_path}/bvecs')
gtab = gradient_table(bvals, bvecs)

In [243]:
# get b0 data and b0_avg
b0_mask = gtab.b0s_mask
b0_data = dwi_data[..., b0_mask]
print(f"B0 data shape: {b0_data.shape}")
b0_avg = np.mean(b0_data, axis=-1)

# b1000 images mask
b1000_mask = (bvals >= 990) & (bvals <= 1010)
print(f"Number of b=1000 volumes: {np.sum(b1000_mask)}")

B0 data shape: (145, 174, 145, 18)
Number of b=1000 volumes: 90


In [244]:
# Get all b1000 scans
dwi_vols = dwi_data[..., b1000_mask]
print(f"DWI volumes shape: {dwi_vols.shape}")

DWI volumes shape: (145, 174, 145, 90)


In [245]:
# mask out voxels with very low b0 signal
b0_threshold = 250
valid_b0_mask = b0_avg > b0_threshold
mask = original_mask & valid_b0_mask  # Combine with brain mask

In [246]:
# Find valid voxels using the brain mask
valid_idx = np.where(mask) # shape: 3 x num_valid_voxels
print(f"\nNumber of valid voxels in mask: {len(valid_idx[0])}")


Number of valid voxels in mask: 926671


In [247]:
# Sample random voxels for training
n_samples = 150000
sample_idx = np.random.choice(len(valid_idx[0]), 
                            min(n_samples, len(valid_idx[0])), 
                            replace=False)

In [248]:
# Extract features (signal intensities) from sampled voxels
features = []
gtt_indices = [] # store ground truth tensor indices. gtt_indices[i] gives index of ith feature in original number of voxels (93XXXX)

for idx in sample_idx:
    x, y, z = valid_idx[0][idx], valid_idx[1][idx], valid_idx[2][idx]
    signal = dwi_vols[x, y, z, :]
    b0_ref = b0_avg[x, y, z]
    normalized_signal = signal / (b0_ref)
    features.append(normalized_signal)
    gtt_indices.append(coord_to_gtt[(x, y, z)])

# gtt_indices shape = 100000 x 1
# features shape = 100000 x 90 x 1


In [249]:
features = np.array(features)
gtt_indices = np.array(gtt_indices)
gtt_indices_gpu = torch.from_numpy(gtt_indices).to(device)
gradient_directions = bvecs[b1000_mask]  # Only keep directions for DWI volumes, shape = 90 x 3

In [250]:
print("Loading ground truth data...")
gt_data = np.load('ground_truth_v2.npz')
ground_truth_tensors = gt_data['tensors'] # shape 936256(all voxels) x 6

Loading ground truth data...


In [251]:
ground_truth_tensors.shape
n_directions = 21

In [252]:
def select_diverse_directions(all_directions, n_select=21):
    selected = [0]  # Start with first direction
    while len(selected) < n_select:
        # Calculate angles with all selected directions
        angles = []
        for i in range(len(all_directions)):
            if i in selected:
                continue
            min_angle = float('inf')
            for s in selected:
                # Cosine similarity
                angle = np.arccos(np.clip(
                    np.dot(all_directions[i], all_directions[s]), -1.0, 1.0))
                min_angle = min(min_angle, angle)
            angles.append((i, min_angle))
        # Select direction with largest minimum angle
        next_idx = max(angles, key=lambda x: x[1])[0]
        selected.append(next_idx)
    return selected

In [253]:
def initialize_direction_sets(all_directions, n_sets=3, n_directions=21):
    base_sets = []
    n_total = len(all_directions)
    
    for i in range(n_sets):
        # Create a shuffled index array for this set
        shuffled_indices = np.random.permutation(n_total)
        # Use select_diverse_directions on the shuffled indices
        selected = select_diverse_directions(all_directions[shuffled_indices], n_directions)
        # Map back to original indices
        original_indices = shuffled_indices[selected]
        base_sets.append(original_indices)
    
    return base_sets

In [254]:
class DiffusionGNN(torch.nn.Module):
    def __init__(self, node_features = 4, hidden_dim = 128):
        super(DiffusionGNN, self).__init__()
    
        # Layers
        self.conv1 = DiffusionConv(node_features, hidden_dim)
        self.conv2 = DiffusionConv(hidden_dim, hidden_dim)

        # MLP for tensor prediction
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 6),
        )

        self.register_buffer('loss_weights',
            torch.tensor([10, 500, 10, 500, 500, 10]))

    
    def forward(self, x, edge_index, edge_attr, batch):
        """
        x: node_features
        edge_index = graph conn info [2, num_edges]
        batch: batch assignment for nodes
        """
        # print("pre conv", x.shape)
        # x shape = 672 (batch_size*num_directions) x 4
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        # x shape = 672 x hidden_dim(64)
        # print("after conv", x.shape)

        # combine node features for graph
        x = global_mean_pool(x, batch)  # [batch_size, hidden_dim]. 1 representation for each voxel/graph
        # print("after pool", x.shape)

        # raw predictions
        out = self.mlp(x) # shape = 32 x 6

        diag_idx = [0, 2, 5]
        offdiag_idx = [1, 3, 4]

        # Diagonal elements: [0,1] using sigmoid
        diag = torch.sigmoid(out[:, diag_idx])
        
        # Off-diagonal elements: [-1,1] using tanh
        offdiag = torch.tanh(out[:, offdiag_idx])

        out_reordered = torch.zeros_like(out)
        out_reordered[:, diag_idx] = diag
        out_reordered[:, offdiag_idx] = offdiag

        return out_reordered # shape = 32 x 6
    
    def weighted_mse_loss(self, pred, target):
        squared_diff = (pred - target) ** 2  # [batch_size, 6]
        weighted_diff = squared_diff * self.loss_weights
        return torch.mean(weighted_diff)


class DiffusionConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(DiffusionConv, self).__init__(aggr="mean")
        
        # MLP to process messages
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
    
    def forward(self, x, edge_index, edge_weight):
        return self.propagate(edge_index, x=x, edge_weight = edge_weight)
    
    def message(self, x_i, x_j, edge_weight):
        """
        x_i: features of target nodes
        x_j: features of source nodes
        Returns: messages to be aggregated
        """
        tmp = torch.cat([x_i, x_j], dim = 1)
        return self.mlp(tmp) * edge_weight.view(-1,1)

In [255]:
def convert_to_torch_geometric(nodes, edges, edge_weights, tensors):
    data_list = []
    device = nodes.device
    
    for i in range(len(nodes)):
        tensor = tensors[i].reshape(-1, 6) if len(tensors[i].shape) == 1 else tensors[i]
        data = Data(
            x=nodes[i],  # Already GPU tensor
            edge_index = edges[i].clone().detach().to(dtype=torch.long, device=device).T,
            edge_attr=edge_weights[i].clone().detach().to(dtype=torch.float, device=device),
            y=tensor  # Already GPU tensor
        )
        data_list.append(data)
    
    return data_list

In [256]:
train_idx, test_idx = train_test_split(np.arange(len(features)), test_size=0.2, random_state=42)
print(f"Training with {len(train_idx)} voxels, testing with {len(test_idx)} voxels")

Training with 120000 voxels, testing with 30000 voxels


In [257]:
def evaluate_model(pred, gt, epoch, print_results=True):
    """
    Compute and optionally print evaluation metrics
    Args:
        pred: predictions array (n_samples x 6)
        gt: ground truth array (n_samples x 6)
        epoch: current epoch number
        print_results: whether to print metrics
    Returns:
        Dictionary of metrics
    """
    metrics = {}
    
    # Component-wise statistics and relative errors
    for i in range(6):
        # Relative error
        rel_error = np.mean(np.abs(pred[:,i] - gt[:,i]) / (np.abs(gt[:,i]) + 1e-10))
        metrics[f'rel_error_{i}'] = rel_error
        
        # Statistics for monitoring distribution
        metrics[f'pred_mean_{i}'] = np.mean(pred[:,i])
        metrics[f'pred_std_{i}'] = np.std(pred[:,i])
        metrics[f'gt_mean_{i}'] = np.mean(gt[:,i])
        metrics[f'gt_std_{i}'] = np.std(gt[:,i])
    
    # Mean Diffusivity error
    pred_md = np.mean([pred[:,0], pred[:,2], pred[:,5]], axis=0)  # Dxx, Dyy, Dzz
    gt_md = np.mean([gt[:,0], gt[:,2], gt[:,5]], axis=0)
    metrics['md_rel_error'] = np.mean(np.abs(pred_md - gt_md) / (np.abs(gt_md) + 1e-10))

    if print_results:
        print(f"\nEpoch {epoch+1} Statistics:")
        print("\nRelative Errors per component:")
        for i in range(6):
            print(f"Component {i}: {metrics[f'rel_error_{i}']:.2e}")
        print(f"Mean Diffusivity Error: {metrics['md_rel_error']:.2e}")
        
        print("\nPrediction Statistics:")
        for i in range(6):
            print(f"Component {i}: mean={metrics[f'pred_mean_{i}']:.2e}, std={metrics[f'pred_std_{i}']:.2e}")
        
        print("\nGround Truth Statistics:")
        for i in range(6):
            print(f"Component {i}: mean={metrics[f'gt_mean_{i}']:.2e}, std={metrics[f'gt_std_{i}']:.2e}")
    
    return metrics

In [258]:
def precompute_direction_structures(direction_indices, threshold_angle=60):
    """
    Precompute direction-based graph structures that remain constant across batches.
    
    Args:
        direction_indices: indices of directions to use, shape (21,)
        
    Returns:
        dict: Precomputed structures for graph creation
    """
    current_directions = gradient_directions[direction_indices]  # Shape (21, 3)
    print(f"cd shape {current_directions.shape}")
    directions_norm = current_directions / np.linalg.norm(current_directions, axis=1, keepdims=True)
    
    cos_sim = np.dot(directions_norm, directions_norm.T)
    angles = np.arccos(np.clip(cos_sim, -1.0, 1.0)) * 180/np.pi
    src, dst = np.where(angles < threshold_angle)
    mask = src != dst
    
    return {
        'edge_template': np.column_stack([src[mask], dst[mask]]),
        'edge_weights': cos_sim[src[mask], dst[mask]],
        'directions': current_directions,
        'direction_indices': direction_indices
    }


def create_batch_data(batch_indices, direction_structures, features_gpu, ground_truth_gpu):
    """
    Create graph data for a batch using GPU tensors.
    
    Args:
        batch_indices: array of indices for this batch, shape (batch_size,)
        direction_structures: dict of precomputed GPU direction structures
        features_gpu: GPU tensor of all features, shape (n_samples, 90)
        ground_truth_gpu: GPU tensor of ground truth, shape (n_samples, 6)
    
    Returns:
        batch_nodes: [batch_size, n_directions, 4] node features (GPU)
        batch_edges: list of [E, 2] edge indices for each graph
        batch_edge_weights: list of [E] edge weights for each graph
        batch_tensors: [batch_size, 6] ground truth tensors (GPU)
    """
    n_directions = direction_structures['directions'].shape[0]  # Should be 21
    
    # Convert batch_indices to GPU tensor if not already
    if not isinstance(batch_indices, torch.Tensor):
        batch_indices = torch.tensor(batch_indices, device=features_gpu.device)
    
    # Get batch of signals and select directions
    # features_gpu shape: (n_samples, 90)
    # batch_features shape: (batch_size, 90)
    batch_features = features_gpu[batch_indices]
    
    # Select specific directions
    # direction_indices shape: (21,)
    # batch_signals shape: (batch_size, 21)
    direction_indices = torch.tensor(direction_structures['direction_indices'], 
                               device=features_gpu.device)
    
    batch_signals = batch_features[:, direction_indices]
    
    # Create node features on GPU
    # batch_nodes shape: (batch_size, n_directions, 4)
    batch_nodes = torch.zeros((batch_indices.shape[0], n_directions, 4), 
                            device=features_gpu.device)
    batch_nodes[..., 0] = batch_signals
    batch_nodes[..., 1:4] = direction_structures['directions'].unsqueeze(0)
    
    # Edge structures (already on GPU from direction_structures_gpu)
    batch_edges = [direction_structures['edge_template'] 
                  for _ in range(batch_indices.shape[0])]
    batch_edge_weights = [direction_structures['edge_weights'] 
                         for _ in range(batch_indices.shape[0])]
    
    # Get ground truth tensors directly from GPU tensor
    # gtt_indices maps from batch_indices to ground truth indices
    # batch_tensors shape: (batch_size, 6)
    batch_tensors = ground_truth_gpu[gtt_indices_gpu[batch_indices]]
    
    # Verify shapes (comparing torch sizes)
    expected_node_shape = torch.Size([batch_indices.shape[0], n_directions, 4])
    expected_tensor_shape = torch.Size([batch_indices.shape[0], 6])
    
    assert batch_nodes.shape == expected_node_shape, \
        f"Expected nodes shape {expected_node_shape}, got {batch_nodes.shape}"
    assert batch_tensors.shape == expected_tensor_shape, \
        f"Expected tensors shape {expected_tensor_shape}, got {batch_tensors.shape}"
    
    # print("\nDEBUG: create_batch_data shapes:")
    # print(f"Input batch_indices shape: {batch_indices.shape}")         # Expected: (256,)
    # print(f"Features GPU shape: {features_gpu.shape}")                 # Expected: (n_samples, 90)
    # print(f"Ground truth GPU shape: {ground_truth_gpu.shape}")        # Expected: (n_samples, 6)
    # print(f"Batch features shape: {batch_features.shape}")            # Expected: (256, 90)
    # print(f"Batch signals shape: {batch_signals.shape}")              # Expected: (256, 21)
    # print(f"Batch nodes shape: {batch_nodes.shape}")                  # Expected: (256, 21, 4)
    # print(f"Edge template shape: {direction_structures['edge_template'].shape}")  # Expected: (E, 2)
    # print(f"Edge weights shape: {direction_structures['edge_weights'].shape}")    # Expected: (E,)
    # print(f"Batch tensors shape: {batch_tensors.shape}")              # Expected: (256, 6)
    

    return batch_nodes, batch_edges, batch_edge_weights, batch_tensors

In [259]:
n_sets = 3
direction_sets = initialize_direction_sets(gradient_directions, n_sets)
direction_structures = {}
for set_idx, direction_set in enumerate(direction_sets):
    direction_structures[set_idx] = precompute_direction_structures(direction_set)
batch_size = 32


cd shape (21, 3)
cd shape (21, 3)
cd shape (21, 3)


In [260]:
# Before training loop begins
features_gpu = torch.from_numpy(features).to(device)
ground_truth_gpu = torch.from_numpy(ground_truth_tensors).to(device)

# Convert direction structures once
direction_structures_gpu = {}
for idx, struct in direction_structures.items():
    direction_structures_gpu[idx] = {
        'edge_template': torch.from_numpy(struct['edge_template']).to(device),
        'edge_weights': torch.from_numpy(struct['edge_weights']).to(device),
        'directions': torch.from_numpy(struct['directions']).to(device),
        'direction_indices': struct['direction_indices']
    }


In [261]:
import time

In [262]:
# Training
model = DiffusionGNN(hidden_dim=128).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
epochs = 60
print("Starting training...")
for epoch in range(epochs):
    # Select direction set for this epoch
    direction_set_idx = epoch % len(direction_sets)
    current_structures = direction_structures_gpu[direction_set_idx]
    model.train()
    epoch_loss = 0
    n_batches = 0
    
    # Store epoch predictions and ground truth
    epoch_pred_values = []
    epoch_gt_values = []
    
    # Shuffle training indices
    np.random.shuffle(train_idx)
    
    # Process batches
    for start in range(0, len(train_idx), batch_size):
        batch_indices = train_idx[start:start + batch_size]
        
        # Create graph data using precomputed structures
        nodes, edges, edge_weights, tensors = create_batch_data(
            batch_indices,
            current_structures,
            features_gpu,
            ground_truth_gpu
        )

        data_list = convert_to_torch_geometric(nodes, edges, edge_weights, tensors)
        batch_data = Batch.from_data_list(data_list).to(device)
        
        # Training step
        optimizer.zero_grad()
        pred = model(batch_data.x, batch_data.edge_index, batch_data.edge_attr, batch_data.batch)
        loss = model.weighted_mse_loss(pred, batch_data.y) * 1e6
        loss.backward()
        optimizer.step()
        
        # Store batch results
        epoch_loss += loss.item()
        epoch_pred_values.append(pred.detach().cpu().numpy())
        epoch_gt_values.append(batch_data.y.cpu().numpy())
        n_batches += 1
    
    # Combine all batches
    epoch_pred = np.concatenate(epoch_pred_values, axis=0)
    epoch_gt = np.concatenate(epoch_gt_values, axis=0)
    
    # Print epoch results
    avg_loss = epoch_loss / n_batches
    print("\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")
    
    # Evaluate epoch performance
    if epoch % 1 == 0:
        _ = evaluate_model(epoch_pred, epoch_gt, epoch)
    
    

Starting training...

xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
Epoch 1/60, Loss: 20834.347618

Epoch 1 Statistics:

Relative Errors per component:
Component 0: 1.59e+04
Component 1: 4.04e+04
Component 2: 1.51e+04
Component 3: 1.35e+04
Component 4: 3.48e+04
Component 5: 1.56e+04
Mean Diffusivity Error: 1.55e+04

Prediction Statistics:
Component 0: mean=1.39e-02, std=5.99e-02
Component 1: mean=1.10e-04, std=3.64e-03
Component 2: mean=1.30e-02, std=5.75e-02
Component 3: mean=-5.05e-05, std=2.52e-03
Component 4: mean=-1.59e-04, std=4.12e-03
Component 5: mean=1.37e-02, std=5.75e-02

Ground Truth Statistics:
Component 0: mean=9.89e-04, std=5.49e-04
Component 1: mean=-2.17e-07, std=1.15e-04
Component 2: mean=1.03e-03, std=5.57e-04
Component 3: mean=3.91e-06, std=1.16e-04
Component 4: mean=-2.21e-05, std=1.20e-04
Component 5: mean=1.00e-03, std=5.63e-04

xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
Epoch 2/60, Loss: 288.476552

Epoch 2 Statistics:

Relative Errors per component:
Component 0: 3.31e+03


In [263]:
print("\nStarting Final Evaluation...")
model.eval()
test_losses = []
test_preds = []
test_gts = []

with torch.no_grad():
    # Process all batches, including last partial batch
    for start in range(0, len(test_idx), batch_size):
        end = min(start + batch_size, len(test_idx))
        batch_indices = test_idx[start:end]
        
        # Create graph data using precomputed structures
        nodes, edges, edge_weights, tensors = create_batch_data(
            batch_indices,
            direction_structures_gpu[direction_set_idx],
            features_gpu,
            ground_truth_gpu
        )

        data_list = convert_to_torch_geometric(nodes, edges, edge_weights, tensors)
        batch_data = Batch.from_data_list(data_list).to(device)
        
        # Get predictions
        pred = model(batch_data.x, batch_data.edge_index, batch_data.edge_attr, batch_data.batch)
        
        # Calculate loss
        test_loss = model.weighted_mse_loss(pred, batch_data.y) * 1e6
        test_losses.append(test_loss.item())
        
        # Store predictions and ground truth
        test_preds.append(pred.cpu().numpy())
        test_gts.append(batch_data.y.cpu().numpy())

# Combine results
all_test_preds = np.concatenate(test_preds, axis=0)
all_test_gts = np.concatenate(test_gts, axis=0)
avg_test_loss = np.mean(test_losses)

print(f"\nFinal Test Results:")
print(f"Average test loss: {avg_test_loss:.6f}")

# Use same evaluation function as training
metrics = evaluate_model(all_test_preds, all_test_gts, epoch=-2)

# Print random examples from full test set
print("\nExample predictions vs ground truth (randomly sampled):")
n_examples = 5
rand_indices = np.random.choice(len(all_test_preds), n_examples, replace=False)
for i, idx in enumerate(rand_indices):
    print(f"\nSample {i+1}:")
    print(f"Predicted: {all_test_preds[idx]}")
    print(f"Actual:    {all_test_gts[idx]}")


Starting Final Evaluation...

Final Test Results:
Average test loss: 0.839526

Epoch -1 Statistics:

Relative Errors per component:
Component 0: 3.11e+02
Component 1: 6.78e+02
Component 2: 4.89e+02
Component 3: 5.74e+02
Component 4: 1.03e+03
Component 5: 3.11e+02
Mean Diffusivity Error: 3.70e+02

Prediction Statistics:
Component 0: mean=1.01e-03, std=5.52e-04
Component 1: mean=4.28e-05, std=1.01e-04
Component 2: mean=1.02e-03, std=5.46e-04
Component 3: mean=4.12e-05, std=1.07e-04
Component 4: mean=-4.13e-05, std=1.08e-04
Component 5: mean=9.92e-04, std=5.39e-04

Ground Truth Statistics:
Component 0: mean=9.88e-04, std=5.53e-04
Component 1: mean=-2.94e-07, std=1.15e-04
Component 2: mean=1.03e-03, std=5.59e-04
Component 3: mean=5.08e-06, std=1.17e-04
Component 4: mean=-2.18e-05, std=1.20e-04
Component 5: mean=1.00e-03, std=5.65e-04

Example predictions vs ground truth (randomly sampled):

Sample 1:
Predicted: [ 5.9884146e-04  1.2237392e-04  1.1475059e-03 -2.4117529e-05
 -3.7674789e-04  

In [264]:
# print("Original Ground Truth Tensor Analysis:")
# print("\nComponent-wise statistics:")
# for i in range(6):
#     comp = ground_truth_tensors[:, i]
#     print(f"\nComponent {i}:")
#     print(f"Min: {np.min(comp):.2e}")
#     print(f"Max: {np.max(comp):.2e}")
#     print(f"Mean: {np.mean(comp):.2e}")
#     print(f"Std: {np.std(comp):.2e}")

In [265]:
# print("Loading ground truth data 2...")
# gt_data_v2 = np.load('ground_truth_v2.npz')
# ground_truth_tensors_v2 = gt_data_v2['tensors']
# print(f"Ground truth tensors shape: {ground_truth_tensors_v2.shape}")

In [266]:
# print("V2 Original Ground Truth Tensor Analysis:")
# print("\nComponent-wise statistics:")
# for i in range(6):
#     comp = ground_truth_tensors_v2[:, i]
#     print(f"\nComponent {i}:")
#     print(f"Min: {np.min(comp):.2e}")
#     print(f"Max: {np.max(comp):.2e}")
#     print(f"Mean: {np.mean(comp):.2e}")
#     print(f"Std: {np.std(comp):.2e}")