In [1]:
# Imports
import nibabel as nib
import numpy as np
import torch
import os
from CNN1DModel import CNN1DModel
from CVRRegressionModel import CVRRegressionModel
from utils import save_cvr_as_3d_mat
from utils import normalize_fmri_timeseries
from utils import extract_subject_id
from utils import save_fmri_as_4d_mat


In [2]:
# Define Common directories 

BASE_DIR = "/Users/muhammadmahajna/workspace/research/data/cvr_est_project"
OUT_DIR = "./output/"

if not os.path.exists(OUT_DIR):
    os.makedirs(OUT_DIR)

# Subdirectories
TRAIN_INPUT_DIR = os.path.join(BASE_DIR, "func/registered/main_data/training")
VAL_INPUT_DIR = os.path.join(BASE_DIR, "func/registered/main_data/validation")
TEST_INPUT_DIR = os.path.join(BASE_DIR, "func/registered/main_data/testing")

TRAIN_TARGET_DIR = os.path.join(BASE_DIR, "CVR_MAPS/registered/training")
VAL_TARGET_DIR = os.path.join(BASE_DIR, "CVR_MAPS/registered/validation")
TEST_TARGET_DIR = os.path.join(BASE_DIR, "CVR_MAPS/registered/testing")

input_dir = TEST_INPUT_DIR
target_dir = TEST_TARGET_DIR

input_files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(('.nii', '.nii.gz'))])
target_files = sorted([os.path.join(target_dir, f) for f in os.listdir(target_dir) if f.endswith(('.nii', '.nii.gz'))])


In [None]:
import numpy as np

from scipy.io import savemat
npy_file = "val_inputs.npy"
mat_file = "val_inputs.mat"
data = np.load(npy_file)
variable_name="data"
# Save as .mat file
savemat(mat_file, {variable_name: data})

In [None]:
# Save input and target data as mat files for visualization
for fmri_file_path in input_files:
    subject_id = extract_subject_id(fmri_file_path)
    
    output_file_path = os.path.join(OUT_DIR, subject_id + "_4D.mat")
    print("Saving results as MAT files")
    print(output_file_path)
    
    save_fmri_as_4d_mat(fmri_file_path, output_file_path, remove_time_points=5)
    break

for cvr_file_path in target_files:
    subject_id = extract_subject_id(cvr_file_path)
    
    output_file_path = os.path.join(OUT_DIR, subject_id + "_CVR.mat")
    print("Saving results as MAT files")
    print(output_file_path)
    
    save_cvr_as_3d_mat(cvr_file_path, output_file_path)
    break
    


In [5]:
REMOVE_TIME_POINT = 5  # Remove the first samples from the fMRI data
DATA_THRESHOLD = 1
ZERO_COUNT_THRESHOLD = 43 # 10%x430

def predict_cvr_map(fmri_file_path, model_path, output_file_path, model):
    """
    Predict a CVR map from an fMRI scan using a trained 1D CNN model.

    Args:
        fmri_file_path (str): Path to the input fMRI NIfTI file.
        model_path (str): Path to the trained model state dictionary (.pth).
        output_file_path (str): Path to save the predicted CVR map (NIfTI file).
    """
    # Load the fMRI scan
    print(f"Loading fMRI scan from {fmri_file_path}...")
    fmri_img = nib.load(fmri_file_path)
    fmri_data = fmri_img.get_fdata()  # Shape: (X, Y, Z, T)
    fmri_data = fmri_data[..., REMOVE_TIME_POINT:]  # Remove initial time points

    # Normalize the fMRI data along the time dimension
    print("Normalizing fMRI data...")
    #fmri_data_normalized, _ = normalize_fmri_timeseries(fmri_data)
    fmri_data_normalized = fmri_data
    
    # Prepare model
    input_size = fmri_data.shape[-1]  # Use time dimension as input size
    print(f"Initializing model with input_size={input_size}...")

    device = torch.device('cpu')
    print(f"Loading model weights from {model_path}...")
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    
    # Predict CVR map slice by slice
    print("Predicting CVR map...")
    cvr_map = np.zeros(fmri_data.shape[:3])  # Initialize empty CVR map

    for x in range(fmri_data.shape[0]):
        for y in range(fmri_data.shape[1]):
            for z in range(fmri_data.shape[2]):
                # Extract 1D time-series for each voxel
                voxel_time_series = fmri_data_normalized[x, y, z, :].reshape(1, 1, -1)  # Shape: (batch=1, channel=1, time)

                voxel_time_series[voxel_time_series < DATA_THRESHOLD] = 0

                # Check if the voxel time series is all zeros
                if np.sum(voxel_time_series == 0) > ZERO_COUNT_THRESHOLD:
                    cvr_value = 0.0  # Set CVR value to zero
                else:
                    # Predict CVR value for the voxel
                    input_tensor = torch.tensor(voxel_time_series, dtype=torch.float32)
                    with torch.no_grad():
                        cvr_value = model(input_tensor)[0].item()  # Get scalar output

                cvr_map[x, y, z] = cvr_value

    # Save the predicted CVR map as a NIfTI file
    print(f"Saving predicted CVR map to {output_file_path}...")
    predicted_cvr_img = nib.Nifti1Image(cvr_map, affine=fmri_img.affine)
    nib.save(predicted_cvr_img, output_file_path)

    print(f"Predicted CVR map saved successfully at {output_file_path}")

In [None]:
# Define model
model_path = "best_model_base_cnn_pre_proc.pth"
INPUT_SIZE = 430 # Number of time points
model_cnn = CNN1DModel(input_size=INPUT_SIZE)
best_model_base_cnn = CVRRegressionModel(input_length=INPUT_SIZE)

for fmri_file_path in input_files:
    subject_id = extract_subject_id(fmri_file_path)
    
    output_file_path = os.path.join(OUT_DIR, subject_id + "_PRED.nii")
    print(output_file_path)
    predict_cvr_map(fmri_file_path, model_path, output_file_path, best_model_base_cnn)
    
    
    print("Saving results as MAT files")
    save_cvr_as_3d_mat(output_file_path, output_file_path.replace(".nii", ".mat"))
    matches = [file for file in target_files if subject_id in file]
    if matches:
        ref_cvr_map_file = matches[0]
        save_cvr_as_3d_mat(ref_cvr_map_file, output_file_path.replace("_PRED.nii", "_ref.mat"))
    else:
        print("Could not locate the reference CVR file")
    break
    

./output/SF_01135_PRED.nii
Loading fMRI scan from /Users/muhammadmahajna/workspace/research/data/cvr_est_project/func/registered/main_data/testing/SF_01135_2_T1.nii.gz...
Normalizing fMRI data...
Initializing model with input_size=430...
Loading model weights from best_model_base_cnn_pre_proc.pth...
Predicting CVR map...


  state_dict = torch.load(model_path, map_location=device)
