In [1]:
import scipy.io
import jax.numpy as np
from scipy.io import loadmat 
import numpy as onp
import nibabel as nib
import matplotlib.pyplot as plt 
import os
import shutil
from PIL import Image
from inrmri.bart import bart_acquisition_from_arrays
from inrmri.data_harvard import get_csmaps_and_mask, get_reference_reco
import pandas as pd
from jax import jit, vmap, random
from inrmri.dip import TimeDependant_DIP_Net, helix_generator, circle_generator
from inrmri.new_radon import ForwardRadonOperator
from inrmri.fourier import fastshiftfourier, get_freqs
from inrmri.basic_nn import weighted_loss 
from inrmri.utils import to_complex, is_inside_of_radial_lim, meshgrid_from_subdiv_autolims, total_variation_batch_complex, save_matrix_and_dict_in_zpy, load_matrix_and_dict_from_zpy    
from inrmri.utils import create_exp_file_name, total_variation_complex
import optax 
from inrmri.metrics_rd import mean_psnr, mean_ssim, mean_artifact_power
from inrmri.utils_rdls import seconds_to_min_sec_format, filter_and_get_columns, apply_transform

from inrmri.image_processor import BeforeLinRegNormalizer
from inrmri.basic_plotting import full_halph_FOV_space_time 
from inrmri.image_processor import reduce_FOV 

from inrmri.utils_rdls import safe_normalize, get_center

from inrmri.utils_rdls import get_info_volunteer, read_segmentation, read_ref_dataset

## General

In [2]:
# Acquisition layout: slices in the stack and cardiac frames (bins) per slice.
total_slices, num_frames = 8, 30
# Root folder containing every dataset / volunteer tree used below.
base_path = '/mnt/workspace/datasets/pulseqCINE/'
# Columns requested from filter_and_get_columns when reading the experiment CSV logs.
target_columns = [
    'training_name',
    'psnr',
    'ssim',
    'it',
    'duration [min]',
    'duration [s]',
]

## Volunteer

In [3]:
dataset   = 'DATA_0.55T'
volunteer = 'AA'

# --- PATH ---
# Every folder string keeps a trailing '/' so sub-paths can be appended directly.
base_folder                     = f'{base_path}{dataset}/{volunteer}/'
train_data_folder               = f'{base_folder}traindata/'
segmentation_folder             = f'{base_folder}segmentations/'
segmentation_endocardium_folder = f'{segmentation_folder}endocardium/'
segmentation_septum_folder      = f'{segmentation_folder}septum/'

In [4]:
# Output folder for the EF figures, area plot and EF summary written further down.
ef_folder = base_folder + 'ef/'
# exist_ok=True creates the folder when missing and is a no-op otherwise,
# without the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(ef_folder, exist_ok=True)

## DIP Model reconstructions

In [5]:
# Settings identifying which training run to load for each DIP model.
# 'experiment_name' names the CSV log; 'training_params' is handed to
# filter_and_get_columns to select a run from it. 'slice' is filled in per
# slice inside the data-loading loop. The two dicts are separate objects so
# later cells can add model-specific keys to each.
stdip_parameters = dict(
    experiment_name='iter_lr_init_value',
    training_params=dict(iter=2000, lr_init_value=1e-3, slice=None),
)

tddip_parameters = dict(
    experiment_name='iter_lr_init_value',
    training_params=dict(iter=2000, lr_init_value=1e-3, slice=None),
)

In [6]:
# stDIP results folder, plus the experiment CSV log and its `_summary` companion.
stdip_parameters['path'] = f"{base_folder}stDIP"
stdip_parameters.update(
    csv_path=f"{stdip_parameters['path']}/{stdip_parameters['experiment_name']}.csv",
    csv_path_summary=f"{stdip_parameters['path']}/{stdip_parameters['experiment_name']}_summary.csv",
)

In [7]:
# tdDIP results folder, plus the experiment CSV log and its `_summary` companion.
tddip_parameters['path'] = f"{base_folder}tdDIP"
tddip_parameters.update(
    csv_path=f"{tddip_parameters['path']}/{tddip_parameters['experiment_name']}.csv",
    csv_path_summary=f"{tddip_parameters['path']}/{tddip_parameters['experiment_name']}_summary.csv",
)

## Data

In [8]:
import os

# One dict per slice (slices 1..total_slices, in order), each holding:
#   'recon'        -> reconstructions keyed by method ('fs', 'grasp', 'sense', 'stdip', 'tddip')
#   'segmentation' -> endocardial (and for some methods septal) segmentations keyed by method
#   'metadata'     -> slice number, dataset name and the volunteer parameters
all_slices_data = []

for i in range(1, total_slices + 1):
    print(f'slice {i}')
    slice_data = {}

    # === Volunteer and Training Parameters ===
    volunteer_params = get_info_volunteer(dataset, volunteer, i)
    # The slice number is part of the criteria used below to pick each model's run from its CSV log.
    stdip_parameters['training_params']['slice'] = i
    tddip_parameters['training_params']['slice'] = i

    # === Reference Dataset Path ===
    dataset_name = f'slice_{i}_{total_slices}_nbins{num_frames}'
    path_dataset = os.path.join(train_data_folder, f'{dataset_name}.npz')

    # === Reference Reconstructions ===
    # get_center (presumably a centre crop), then normalization, then alignment.
    # FS uses its own transform ('trans_gt'); every other method uses 'trans'.
    recon_fs, recon_grasp, recon_sense, time_grasp, time_sense = read_ref_dataset(path_dataset)
    recon_grasp    = get_center(recon_grasp)
    recon_sense    = get_center(recon_sense)
    recon_fs       = get_center(recon_fs)
    recon_grasp    = safe_normalize(recon_grasp)
    recon_sense    = safe_normalize(recon_sense)
    recon_fs       = safe_normalize(recon_fs)
    recon_fs = apply_transform(recon_fs, volunteer_params['trans_gt'])
    recon_grasp = apply_transform(recon_grasp, volunteer_params['trans'])
    recon_sense = apply_transform(recon_sense, volunteer_params['trans'])
    slice_data['recon'] = {
        'fs': recon_fs,
        'grasp': recon_grasp,
        'sense': recon_sense
    }

    # NOTE(review): the DIP reconstructions below skip get_center — presumably they
    # were saved already cropped; confirm.

    # === stDIP Reconstruction ===
    stdip_parameters['exp_folder_path'] = os.path.join(stdip_parameters['path'], dataset_name, stdip_parameters['experiment_name'])
    df_stdip = pd.read_csv(stdip_parameters['csv_path'], delimiter=';')
    stDIP_results = filter_and_get_columns(df_stdip, stdip_parameters['training_params'], target_columns)[0]
    stdip_parameters['best_recon_path'] = os.path.join(stdip_parameters['exp_folder_path'], stDIP_results['training_name'], 'best_recon.npz')
    # allow_pickle=True lets np.load unpickle object arrays; only use it on trusted experiment outputs.
    recon_stdip = onp.load(stdip_parameters['best_recon_path'], allow_pickle=True)['best_recon']
    recon_stdip = safe_normalize(recon_stdip)
    recon_stdip = apply_transform(recon_stdip, volunteer_params['trans'])
    slice_data['recon']['stdip'] = recon_stdip

    # === tdDIP Reconstruction ===
    tddip_parameters['exp_folder_path'] = os.path.join(tddip_parameters['path'], dataset_name, tddip_parameters['experiment_name'])
    df_tddip = pd.read_csv(tddip_parameters['csv_path'], delimiter=';')
    tdDIP_results = filter_and_get_columns(df_tddip, tddip_parameters['training_params'], target_columns)[0]
    tddip_parameters['best_recon_path'] = os.path.join(tddip_parameters['exp_folder_path'], tdDIP_results['training_name'], 'best_recon.npz')
    recon_tddip = onp.load(tddip_parameters['best_recon_path'], allow_pickle=True)['best_recon']
    # Normalize tdDIP itself, like every other method, so the shared `saturation`
    # display window also applies to it. (Previously this result was assigned to
    # `recon_stdip`, which left tdDIP un-normalized.)
    recon_tddip = safe_normalize(recon_tddip)
    recon_tddip = apply_transform(recon_tddip, volunteer_params['trans'])
    slice_data['recon']['tddip'] = recon_tddip

    # === FS Segmentations ===
    # FS has its own ED/ES frame indices ('EDV_gt'/'ESV_gt'); the other methods share 'EDV'/'ESV'.
    # seg_end_area is indexed by frame — presumably the per-frame endocardial area.
    seg_endo, seg_endo_fill, seg_end_area = read_segmentation(
        os.path.join(segmentation_endocardium_folder, f"{volunteer}_recon_fs_slice{i}.nii"))
    slice_data['segmentation'] = {
        'fs': {
            'endo': seg_endo,
            'endo_fill': seg_endo_fill,
            'area_ed': seg_end_area[volunteer_params['EF_frames']['EDV_gt']],
            'area_es': seg_end_area[volunteer_params['EF_frames']['ESV_gt']]
        }
    }

    # NOTE(review): the septum mask for both GRASP and tdDIP is read from the *stDIP*
    # septum file (and stDIP itself gets no 'sep' entry). Confirm this reuse is intended.

    # === GRASP Segmentations ===
    seg_endo, seg_endo_fill, seg_end_area = read_segmentation(
        os.path.join(segmentation_endocardium_folder, f"{volunteer}_grasp_slice{i}.nii"))
    seg_sep = read_segmentation(
        os.path.join(segmentation_septum_folder, f"{volunteer}_stdip_slice{i}.nii"), fill=False)
    slice_data['segmentation']['grasp'] = {
        'sep': seg_sep,
        'endo': seg_endo,
        'endo_fill': seg_endo_fill,
        'area_ed': seg_end_area[volunteer_params['EF_frames']['EDV']],
        'area_es': seg_end_area[volunteer_params['EF_frames']['ESV']]
    }

    # === stDIP Segmentations ===
    seg_endo, seg_endo_fill, seg_end_area = read_segmentation(
        os.path.join(segmentation_endocardium_folder, f"{volunteer}_stdip_slice{i}.nii"))
    slice_data['segmentation']['stdip'] = {
        'endo': seg_endo,
        'endo_fill': seg_endo_fill,
        'area_ed': seg_end_area[volunteer_params['EF_frames']['EDV']],
        'area_es': seg_end_area[volunteer_params['EF_frames']['ESV']]
    }

    # === tdDIP Segmentations ===
    seg_endo, seg_endo_fill, seg_end_area = read_segmentation(
        os.path.join(segmentation_endocardium_folder, f"{volunteer}_tddip_slice{i}.nii"))
    seg_sep = read_segmentation(
        os.path.join(segmentation_septum_folder, f"{volunteer}_stdip_slice{i}.nii"), fill=False)
    slice_data['segmentation']['tddip'] = {
        'sep': seg_sep,
        'endo': seg_endo,
        'endo_fill': seg_endo_fill,
        'area_ed': seg_end_area[volunteer_params['EF_frames']['EDV']],
        'area_es': seg_end_area[volunteer_params['EF_frames']['ESV']]
    }

    # === Save Slice Info ===
    slice_data['metadata'] = {
        'slice_number': i,
        'dataset_name': dataset_name,
        'volunteer_params': volunteer_params
    }

    all_slices_data.append(slice_data)


slice 1


2025-08-04 18:00:53.199732: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


slice 2
slice 3
slice 4
slice 5
slice 6
slice 7
slice 8


In [9]:
# Display-window upper bound (used as `vmax` by show_recons) recorded for the first slice.
(
    all_slices_data[0]
    ['metadata']
    ['volunteer_params']
    ['saturation']
)

0.3

In [10]:
def plot_area_per_slice(all_slices_data, methods=['fs', 'stdip'], save_path=None):
    """Plot the ED and ES endocardial area of every slice, one colour per method.

    Args:
        all_slices_data: list of per-slice dicts. Each needs
            ``['metadata']['slice_number']`` and, for each plotted method,
            ``['segmentation'][method]['area_ed']`` / ``['area_es']``.
        methods: segmentation keys to plot. A method whose segmentation is
            missing (or empty) for a slice is drawn as a gap (NaN) there.
            The list is only read, never mutated.
        save_path: if given, the figure is written there at 300 dpi, any parent
            directory is created, and the figure is closed. Otherwise it is shown.
    """
    import matplotlib.pyplot as plt
    import os

    slices = [slice_data['metadata']['slice_number'] for slice_data in all_slices_data]
    color_cycle = plt.get_cmap('tab10').colors
    linestyle_map = {'ED': '-', 'ES': '--'}
    marker_map = {'ED': 'o', 'ES': 'x'}

    fig, ax = plt.subplots(figsize=(10, 5))

    for idx, method in enumerate(methods):
        color = color_cycle[idx % len(color_cycle)]
        ed_areas, es_areas = [], []

        for slice_data in all_slices_data:
            segmentation = slice_data['segmentation'].get(method)
            if segmentation:
                ed_areas.append(segmentation['area_ed'])
                es_areas.append(segmentation['area_es'])
            else:
                # float NaN avoids depending on whichever module the notebook's global `np` points to.
                ed_areas.append(float('nan'))
                es_areas.append(float('nan'))

        ax.plot(slices, ed_areas, marker_map['ED'] + linestyle_map['ED'], color=color, label=f'{method.upper()} ED')
        ax.plot(slices, es_areas, marker_map['ES'] + linestyle_map['ES'], color=color, label=f'{method.upper()} ES')

    ax.set_xlabel('Slice Number')
    ax.set_ylabel('Endocardial Area')
    ax.set_title('Endocardial Area across Slices')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300)
        print(f"Saved plot to {save_path}")
        plt.close(fig)
    else:
        plt.show()




def sum_areas(methods, all_slices_data, save_path=None):
    """Sum ED/ES endocardial areas over all slices and derive an EF per method.

    For each method the returned dict maps:
        'ED' -> sum over slices of ``segmentation[method]['area_ed']``
        'ES' -> sum over slices of ``segmentation[method]['area_es']``
        'S'  -> ED - ES
        'EF' -> 100 * S / ED, or NaN when the summed ED area is 0 (EF undefined)

    Areas are summed as-is, with no slice-thickness or pixel-size factor.
    EF is a ratio, so a constant per-slice scale factor would not change it.

    Args:
        methods: segmentation keys to summarize. Each must be present in every slice.
        all_slices_data: list of per-slice dicts holding ``['segmentation'][method]``.
        save_path: if given, one ``"<METHOD> <key>: <value:.3f>"`` line per entry is
            written there, and any parent directory is created first.

    Returns:
        dict: ``{method: {'ED', 'ES', 'S', 'EF'}}`` in the order of ``methods``.

    Raises:
        KeyError: if a slice lacks the segmentation for a requested method.
    """
    import os

    total = {}

    for method in methods:
        ed_sum = 0.0
        es_sum = 0.0
        for slice_data in all_slices_data:
            segmentation = slice_data['segmentation'][method]
            ed_sum += segmentation['area_ed']
            es_sum += segmentation['area_es']

        stroke = ed_sum - es_sum
        # Guard against an empty/zero ED total instead of dividing by zero.
        ef = 100 * stroke / ed_sum if ed_sum != 0 else float('nan')
        total[method] = {'ED': ed_sum, 'ES': es_sum, 'S': stroke, 'EF': ef}

    if save_path:
        save_dir = os.path.dirname(save_path)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_path, 'w') as f:
            for method, values in total.items():
                for key, val in values.items():
                    f.write(f"{method.upper()} {key}: {val:.3f}\n")
        print(f"Saved EF summary to {save_path}")

    return total



import os
import numpy as np
import matplotlib.pyplot as plt

def show_recons(all_slices_data, methods=['fs', 'stdip'], with_segmentation=True, save_folder=None):
    """Render the ED and ES frames of every slice for each requested method.

    One figure is produced per (slice, method, phase). The magnitude image is
    clipped to ``[0, saturation]``, where ``saturation`` comes from the slice's
    volunteer parameters, and drawn with the 'bone' colormap. The 'fs' method
    uses the 'EDV_gt'/'ESV_gt' frame indices; every other method uses 'EDV'/'ESV'.

    Args:
        all_slices_data: list of per-slice dicts with 'metadata', 'recon' and
            'segmentation' entries. Reconstructions are indexed as ``[:, :, frame]``.
        methods: keys into each slice's 'recon' dict. Only read, never mutated.
        with_segmentation: overlay the 0.5 contour of the method's 'endo' mask,
            falling back to the 'fs' mask when the method's mask is missing.
        save_folder: if given, each figure is saved there as
            ``slice<NN>_<method>_frame<ED|ES>[_seg].png`` and closed.
            Otherwise it is shown.
    """
    for slice_data in all_slices_data:
        metadata = slice_data['metadata']
        slice_num = metadata['slice_number']
        params = metadata['volunteer_params']
        vmax = params['saturation']

        for method in methods:
            frames_info = params['EF_frames']
            ed_key, es_key = ('EDV_gt', 'ESV_gt') if method == 'fs' else ('EDV', 'ESV')
            phase_frames = (frames_info[ed_key], frames_info[es_key])

            for label, frame in zip(('ED', 'ES'), phase_frames):
                fig, ax = plt.subplots(figsize=(5, 5))

                magnitude = np.abs(slice_data['recon'][method][:, :, frame])
                ax.imshow(np.clip(magnitude, 0, vmax), cmap='bone', vmin=0, vmax=vmax)

                if with_segmentation:
                    try:
                        contour_mask = slice_data['segmentation'][method]['endo'][:, :, frame]
                    except KeyError:
                        contour_mask = slice_data['segmentation']['fs']['endo'][:, :, frame]
                    ax.contour(contour_mask, levels=[0.5], colors='red', linewidths=1.5)

                ax.axis('off')

                if not save_folder:
                    plt.show()
                    continue

                os.makedirs(save_folder, exist_ok=True)
                extension = "_seg.png" if with_segmentation else ".png"
                fpath = os.path.join(save_folder, f"slice{slice_num:02d}_{method}_frame{label}" + extension)
                plt.savefig(fpath, dpi=300, bbox_inches='tight', pad_inches=0)
                print(f"Saved: {fpath}")
                plt.close()






In [11]:
# Export the ED/ES frames of every slice and method: first with the endocardial
# contour overlaid, then as plain images.
for overlay in (True, False):
    show_recons(
        all_slices_data,
        methods=['fs', 'stdip', 'tddip', 'grasp'],
        with_segmentation=overlay,
        save_folder=ef_folder,
    )

Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice01_fs_frameED_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice01_fs_frameES_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice01_stdip_frameED_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice01_stdip_frameES_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice01_tddip_frameED_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice01_tddip_frameES_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice01_grasp_frameED_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice01_grasp_frameES_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice02_fs_frameED_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice02_fs_frameES_seg.png
Saved: /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/slice02_stdip_frameED_seg.png
Saved: /mnt/workspace/datasets/pu

In [12]:
# Per-method summed ED/ES areas and EF, written to ef.txt and shown as the cell output.
sum_areas(
    ['fs', 'stdip', 'tddip', 'grasp'],
    all_slices_data,
    save_path=f'{ef_folder}ef.txt',
)

Saved EF summary to /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/ef.txt


{'fs': {'ED': 3250.0, 'ES': 1567.0, 'S': 1683.0, 'EF': 51.784615384615385},
 'stdip': {'ED': 3169.0, 'ES': 1506.0, 'S': 1663.0, 'EF': 52.47712212054276},
 'tddip': {'ED': 3102.0, 'ES': 1457.0, 'S': 1645.0, 'EF': 53.03030303030303},
 'grasp': {'ED': 2977.0, 'ES': 1519.0, 'S': 1458.0, 'EF': 48.97547866980182}}

In [13]:
# Endocardial area per slice at ED and ES, one colour per method, saved as areas.png.
plot_area_per_slice(
    all_slices_data,
    methods=['fs', 'stdip', 'tddip', 'grasp'],
    save_path=f'{ef_folder}areas.png',
)

Saved plot to /mnt/workspace/datasets/pulseqCINE/DATA_0.55T/AA/ef/areas.png
