In [8]:
import nibabel as nib
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
import os
from CNN1DModel import CNN1DModel



In [9]:
def predict_cvr_map(fmri_file_path, model_path, output_file_path, batch_size=4096):
    """
    Predict a CVR map from an fMRI scan using a trained 1D CNN model.

    Each voxel's time series is fed to the model as a (channel=1, time) sequence.
    The model's single output value for that sequence becomes the voxel's CVR value.
    Voxels are processed in batches rather than one forward pass per voxel.

    Args:
        fmri_file_path (str): Path to the input fMRI NIfTI file. Must be 4D (X, Y, Z, T).
        model_path (str): Path to the trained model state dictionary (.pth).
        output_file_path (str): Path to save the predicted CVR map (NIfTI file).
        batch_size (int): Number of voxels per forward pass. ``1`` reproduces the
            original voxel-by-voxel behaviour. Defaults to 4096.

    Returns:
        numpy.ndarray: The predicted CVR map of shape (X, Y, Z). It is also written to
        ``output_file_path`` using the input image's affine.

    Raises:
        ValueError: If ``batch_size`` < 1, if the input image is not 4D, or if the
            model does not produce exactly one value per voxel.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Load the fMRI scan
    print(f"Loading fMRI scan from {fmri_file_path}...")
    fmri_img = nib.load(fmri_file_path)
    fmri_data = fmri_img.get_fdata()  # Shape: (X, Y, Z, T)
    if fmri_data.ndim != 4:
        # A 3D volume would otherwise be silently reinterpreted with Z as "time".
        raise ValueError(
            f"Expected a 4D fMRI image (X, Y, Z, T), got shape {fmri_data.shape} "
            f"from {fmri_file_path}"
        )
    spatial_shape = fmri_data.shape[:3]
    n_timepoints = fmri_data.shape[-1]

    # Standardize the data. StandardScaler standardizes each column. With voxels as
    # rows and time points as columns, this z-scores every time point across all
    # voxels. It does NOT z-score each voxel's own time series.
    # NOTE(review): this must match the preprocessing used at training time — confirm.
    print("Normalizing fMRI data...")
    scaler = StandardScaler()
    voxel_series = scaler.fit_transform(fmri_data.reshape(-1, n_timepoints))  # (n_voxels, T)

    # Prepare model
    print(f"Initializing model with input_size={n_timepoints}...")
    model = CNN1DModel(input_size=n_timepoints)

    print(f"Loading model weights from {model_path}...")
    # A state dict holds only tensors in plain containers. weights_only=True is
    # therefore sufficient, and it avoids executing arbitrary pickled code from the file.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    print("Predicting CVR map...")
    cvr_map = _predict_voxels(model, voxel_series, batch_size).reshape(spatial_shape)

    # Save the predicted CVR map as a NIfTI file
    print(f"Saving predicted CVR map to {output_file_path}...")
    predicted_cvr_img = nib.Nifti1Image(cvr_map, affine=fmri_img.affine)
    nib.save(predicted_cvr_img, output_file_path)

    print(f"Predicted CVR map saved successfully at {output_file_path}")
    return cvr_map


def _predict_voxels(model, voxel_series, batch_size):
    """Run ``model`` on each row of ``voxel_series`` (n_voxels, T) in batches.

    Returns a float64 vector of length n_voxels holding one prediction per row.
    """
    n_voxels = voxel_series.shape[0]
    predictions = np.zeros(n_voxels)
    with torch.no_grad():
        for start in range(0, n_voxels, batch_size):
            stop = min(start + batch_size, n_voxels)
            # (batch, channel=1, time): same per-sample layout as a single-voxel input.
            batch = torch.tensor(voxel_series[start:stop], dtype=torch.float32).unsqueeze(1)
            output = model(batch)
            if output.numel() != stop - start:
                raise ValueError(
                    f"Model returned {output.numel()} values for a batch of "
                    f"{stop - start} voxels; expected exactly one value per voxel."
                )
            predictions[start:stop] = output.reshape(-1).cpu().numpy()
    return predictions

In [None]:
# Main function to call the prediction script
if __name__ == "__main__":
    # Define paths. Each one can be overridden with an environment variable so the cell
    # runs on machines other than the original author's. The defaults are unchanged.
    # NOTE(review): the default input lives under CVR_MAPS/, but the recorded output of
    # this cell was produced from a func/registered/... scan. predict_cvr_map expects a
    # 4D fMRI time series, not a CVR map, so confirm this path.
    fmri_file_path = os.environ.get(
        "CVR_FMRI_FILE",
        "/Users/muhammadmahajna/workspace/research/data/cvr_est_project/CVR_MAPS/registered/testing/SF_01138_CVR_2_T1.nii.gz",
    )
    model_path = os.environ.get("CVR_MODEL_PATH", "best_model.pth")
    output_file_path = os.environ.get("CVR_OUTPUT_PATH", "output_cvr_map.nii")

    # Check if paths exist
    if not os.path.exists(fmri_file_path):
        raise FileNotFoundError(f"Input fMRI file not found at {fmri_file_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Trained model file not found at {model_path}")

    # Predict the CVR map
    predict_cvr_map(fmri_file_path, model_path, output_file_path)


Loading fMRI scan from /Users/muhammadmahajna/workspace/research/data/cvr_est_project/func/registered/main_data/training/SF_01035_2_T1.nii.gz...
Normalizing fMRI data...
Initializing model with input_size=435...
Loading model weights from best_model.pth...
Predicting CVR map...


  state_dict = torch.load(model_path, map_location=torch.device('cpu'))


Saving predicted CVR map to output_cvr_map.nii...
Predicted CVR map saved successfully at output_cvr_map.nii
