In [1]:
import os
import time
import cv2
import scipy.io as sio
import numpy as np
import pandas as pd

from cest_mrf.metrics.dot_product import dot_prod_matching
from cest_mrf.metrics.euc_dist import euc_dist_matching

from my_funcs.new_plot_functions import mask_roi_finder
from my_funcs.new_plot_functions import mask_check_plot
from my_funcs.new_plot_functions import compare_txt_method
from my_funcs.new_plot_functions import t1_t2_pixel_reader
from my_funcs.new_plot_functions import real_t1_t2
from my_funcs.new_plot_functions import real_t1_t2_bg
from my_funcs.new_plot_functions import dict_t1_t2
from my_funcs.new_plot_functions import dict_fs_ksw
from my_funcs.new_plot_functions import plot_norm_sig
from my_funcs.new_plot_functions import multi_b_fit_plot
from my_funcs.new_plot_functions import z_mtr_plot

from my_funcs.cest_functions import bruker_dataset_creator
from my_funcs.cest_functions import dicom_data_arranger

from my_funcs.path_functions import make_folder


## 1.1. Enter Data: ##

In [2]:
# ---- Run-level settings -----------------------------------------------------
general_fn = os.path.abspath(os.curdir)   # absolute working dir; scan folders live under <general_fn>/scans
txt_file_name = 'labarchive_notes.txt'    # notes file handed to bruker_dataset_creator for every scan
save_name = '24_03_13'                    # results sub-folder name under exp/ and images/
fp_prtcl_names = ['107a']                 # fingerprint protocols (to be looped later with all options)

# Dictionary fs_0 = Glu concentration [mM] * f_const; the plotting cells undo it with `* 110e3 / 3`.
f_const = 3 / 110000

# (min, max) bounds per dictionary column, used when filtering the simulated dictionary.
dict_ranges = dict(
    fs_0=(-1 * f_const, 30 * f_const),
    ksw_0=(5000, 12000),
    t1w=(3.3, 4.2),
    t2w=(1.45, 1.9),
)

# Values shared by every phantom. They are unpacked between 'month' and 'tag_x_locs'
# so each subject dict keeps exactly the same key order as a fully written-out literal.
_common_fields = dict(save_name=save_name, temp=37)

# One entry per phantom scan (scans/<scan_name>/<sub_name>).
# NOTE(review): concs / phs / tag_*_locs presumably follow the order of `tags` - confirm.
# z_b1s are saturation B1 values [uT]; z_b1s_names are the matching labels.
subject_dicts = [
    dict(scan_name='23_12_08_glu_phantom_37deg',
         sub_name='1_glu_phantom_12_16_20mM_ph7_37deg',
         tags=['a', 'c', 'b'], concs=[20, 12, 16], phs=[7, 7, 7],
         month='dec', **_common_fields,
         tag_x_locs=[0, -4, 0], tag_y_locs=[-15, 15, 15],
         z_b1s=[0.7, 2, 4, 6], z_b1s_names=['0p7uT', '2uT', '4uT', '6uT']),
    # Excluded for now:
    # dict(scan_name='24_02_12_glu_phantom_vardeg',
    #      sub_name='2_glu_phantom_20_16_12mM_ph7_dec_37deg',
    #      tags=['b', 'a', 'c'], concs=[16, 20, 12], phs=[7, 7, 7],
    #      month='dec', **_common_fields,
    #      tag_x_locs=[-2, 1, -2], tag_y_locs=[-15, -15, 14],
    #      z_b1s=[4], z_b1s_names=['4uT']),
    dict(scan_name='24_01_16_glu_phantom_37deg',
         sub_name='2_glu_phantom_5_10_15mM_ph7_37deg',
         tags=['b', 'c', 'a'], concs=[10, 5, 15], phs=[7, 7, 7],
         month='jan', **_common_fields,
         tag_x_locs=[-2, 3, -2], tag_y_locs=[-15, 14, 14],
         z_b1s=[4, 6], z_b1s_names=['4uT', '_6uT']),
    dict(scan_name='24_01_16_glu_phantom_37deg',
         sub_name='3_glu_phantom_20mM_ph6p8_7_7p2_37deg',
         tags=['b', 'a', 'c'], concs=[20, 20, 20], phs=[7, 7.2, 6.8],
         month='jan', **_common_fields,
         tag_x_locs=[3, -3, 5], tag_y_locs=[-15, 15, 15],
         z_b1s=[4, 6], z_b1s_names=['4uT', '_6uT']),
    dict(scan_name='24_02_12_glu_phantom_vardeg',
         sub_name='5_glu_phantom_20_16_12mM_ph7_feb_37deg',
         tags=['c', 'a', 'b'], concs=[12, 20, 16], phs=[7, 7, 7],
         month='feb', **_common_fields,
         tag_x_locs=[0, -3, 0], tag_y_locs=[-15, 14, 14],
         z_b1s=[4], z_b1s_names=['4uT']),
    dict(scan_name='24_03_03_glu_phantom_37deg',
         sub_name='1_glu_phantom_6_9_12mM_ph7_37deg',
         tags=['b', 'a', 'c'], concs=[9, 12, 6], phs=[7, 7, 7],
         month='mar', **_common_fields,
         tag_x_locs=[0, 0, 0], tag_y_locs=[-15, 15, 15],
         z_b1s=[0.7, 2, 4, 6], z_b1s_names=['0p7uT', '2uT', '4uT', '6uT']),
    dict(scan_name='24_03_03_glu_phantom_37deg',
         sub_name='2_glu_phantom_16mM_ph6p5_6p75_7_37deg',
         tags=['a', 'b', 'c'], concs=[16, 16, 16], phs=[7, 6.75, 6.5],
         month='mar', **_common_fields,
         tag_x_locs=[2, -5, 0], tag_y_locs=[-14.5, -14, 14],
         z_b1s=[0.7, 2, 4, 6], z_b1s_names=['0p7uT', '2uT', '4uT', '6uT']),
]


## 1.2. Filter dict & Run Dot Product ##

In [14]:
# NOTE: this whole cell is disabled (commented out). When enabled it, for each
# protocol in fp_prtcl_names: filters exp/<save_name>/<protocol>/dict.csv by
# `dict_ranges` into f_dict.csv, runs dot-product matching for every subject,
# and saves quant_maps.mat both in the subject's mrf_files folder and under
# images/<save_name>/subject_<n>/ - the latter is the file the summary-figure
# cell loads.
# # load un-filtered dict
# for fp_prtcl_name in fp_prtcl_names:
#     dict_fn = os.path.join('exp', save_name, fp_prtcl_name, 'dict.csv')
#     synt_df = pd.read_csv(dict_fn, header=0)
# 
#     # filter dict using dict_range
#     f_dict_fn = os.path.join('exp', save_name, fp_prtcl_name, 'f_dict.csv')
#     df_masks = [synt_df[column].between(min_val, max_val) for column, (min_val, max_val) in dict_ranges.items()]
#     filtered_df = synt_df[np.all(df_masks, axis=0)]
#     filtered_df.to_csv(path_or_buf=f_dict_fn, index=False)  # To save as CSV without row indices
# 
#     for subject_i, subject_dict in enumerate(subject_dicts):
#         start = time.perf_counter()
#         print(f'################################# start of phantom {subject_i+1} #################################')
#         glu_phantom_fn = os.path.join(general_fn, 'scans', subject_dict['scan_name'],
#                                  subject_dict['sub_name'])
#         # val_range = subject_dict['dict_ranges']
# 
#         glu_phantom_dicom_fn, glu_phantom_mrf_files_fn, bruker_dataset = bruker_dataset_creator(glu_phantom_fn, txt_file_name, fp_prtcl_name)
#         glu_acquired_data = dicom_data_arranger(bruker_dataset, glu_phantom_dicom_fn)
# 
#         # create acquired data folder: root->scans->date->subject->E->mrf_files->acquired_data
#         acquired_data_fn = os.path.join(glu_phantom_mrf_files_fn, 'acquired_data.mat')
#         # make_folder(glu_phantom_mrf_files_fn)
#         # 
#         # # save acquired data to: root->scans->date->subject->E->mrf_files->acquired_data
#         # sio.savemat(acquired_data_fn, {'acquired_data': glu_acquired_data})
#         # print(f'Acquired data was saved as {glu_acquired_data.shape} sized array')
# 
#         # start = time.perf_counter()
#         # # load un-filtered dict
#         # dict_fn = os.path.join('exp', fp_prtcl_name, 'dict.csv')
#         # synt_df = pd.read_csv(dict_fn, header=0)
#         # 
#         # # filter dict using dict_range
#         # f_dict_fn = os.path.join('exp', fp_prtcl_name, 'f_dict.csv')
#         # df_masks = [synt_df[column].between(min_val, max_val) for column, (min_val, max_val) in val_range.items()]
#         # filtered_df = synt_df[np.all(df_masks, axis=0)]
#         # filtered_df.to_csv(path_or_buf=f_dict_fn, index=False)  # To save as CSV without row indices
# 
#         quant_maps = dot_prod_matching(dict_fn=f_dict_fn, acquired_data_fn=acquired_data_fn)        
#         # quant_maps = euc_dist_matching(dict_fn=f_dict_fn, acquired_data_fn=acquired_data_fn)
#         end = time.perf_counter()
#         s = (end - start)
#         print(f"Dot product matching took {s:.03f} s.")
# 
#         # save acquired data to: root->scans->date->subject->E->mrf_files->quant_maps.mat
#         quant_maps_fn = os.path.join(glu_phantom_mrf_files_fn, 'quant_maps.mat')
#         sio.savemat(quant_maps_fn, quant_maps)
#         quant_maps_fn = os.path.join('images', save_name, f'subject_{subject_i+1}', 'quant_maps.mat')
#         sio.savemat(quant_maps_fn, quant_maps)
#         print('quant_maps.mat saved')
# 
#         print(f'#################################  end of phantom {subject_i+1}  #################################')


## Loop over subjects (cell currently commented out) ##

In [15]:
# NOTE: this cell is disabled (commented out), but it is the only place that
# stores subject_dict['sub_df'] (the multi_b_fit_plot result concatenated with
# the dict_fs_ksw result of each protocol). The comparison-plot cell reads
# 'sub_df', so re-enable this cell before running that one on a fresh kernel.
# for subject_i, subject_dict in enumerate(subject_dicts):
#     phantom_choice = subject_i+1
#     print(f'################################# start of subject {phantom_choice} #################################')
#     glu_phantom_fn = os.path.join(general_fn, 'scans', subject_dict['scan_name'],
#                               subject_dict['sub_name'])
# 
#     # mask
#     _, _, bruker_dataset_mask = bruker_dataset_creator(glu_phantom_fn, txt_file_name, '107a')  # always takes mask from 107a
#     vial_rois, full_mask, bg_mask = mask_roi_finder(bruker_dataset_mask)
# 
#     subject_dict['vial_rois'] = vial_rois
#     subject_dict['full_mask'] = full_mask
#     subject_dict['bg_mask'] = bg_mask
#     subject_dict['dict_ranges'] = dict_ranges
# 
#     # z-spectra
#     print(f'################################# z-spectra ################################# ')
#     z_mtr_plot(general_fn, txt_file_name, phantom_choice, subject_dict)
# 
#     # multi-B1 fitting
#     print(f'################################# multi-B1 fitting #################################')
#     b1_i_lim = [0, 4]
#     params = {
#     'tp': 3,  #[s]
#     'fb': 0,
#     'kb': 7500,
#     't1w': 4.2,
#     't2w': 1.8,
#     't1s': 1.2,
#     't2s': 0.007
#     }
#     subject_dict['params'] = params
#     sub_df = multi_b_fit_plot(general_fn, txt_file_name, phantom_choice, subject_dict, b1_i_lim)
# 
#     # t1, t2
#     print(f'################################# real t1 t2 #################################')
#     t1_pixels = t1_t2_pixel_reader(glu_phantom_fn=glu_phantom_fn, txt_file_name=txt_file_name, image_idx=3, t_type='t1')
#     t2_pixels = t1_t2_pixel_reader(glu_phantom_fn=glu_phantom_fn, txt_file_name=txt_file_name, image_idx=3, t_type='t2')
# 
#     real_t1_t2(t1_pixels, t2_pixels, phantom_choice, subject_dict)
#     real_t1_t2_bg(t1_pixels, t2_pixels, phantom_choice, subject_dict)
# 
#     for fp_prtcl_name in fp_prtcl_names:
#         print(f'################################# start of protocol {fp_prtcl_name} #################################')
#         glu_phantom_dicom_fn, glu_phantom_mrf_files_fn, bruker_dataset = bruker_dataset_creator(glu_phantom_fn, txt_file_name, fp_prtcl_name)
# 
#         # quant maps
#         quant_data_fn = os.path.join(glu_phantom_mrf_files_fn, 'quant_maps.mat')
#         quant_maps = sio.loadmat(quant_data_fn)
#         fs_array = quant_maps['fs'] * 110e3 / 3
#         ksw_array = quant_maps['ksw']
#         ed_array = quant_maps['dp']
#         t1w_array = quant_maps['t1w'] * 1000
#         t2w_array = quant_maps['t2w'] * 1000
# 
#         # mrf
#         dict_t1_t2(t1w_array, t2w_array, phantom_choice, fp_prtcl_name, subject_dict)
#         mrf_df = dict_fs_ksw(fs_array, ksw_array, ed_array, phantom_choice, fp_prtcl_name, subject_dict)
#         # mrf_df_l.append(mrf_df)
#         sub_df = pd.concat([sub_df, mrf_df], axis=1)
# 
#         # signal close look
#         subject_dict['quant_data_fn'] = quant_data_fn
#         subject_dict['glu_phantom_mrf_files_fn'] = glu_phantom_mrf_files_fn
# 
#         plot_norm_sig(fp_prtcl_name, phantom_choice, subject_dict)
# 
#     subject_dicts[subject_i]['sub_df'] = sub_df
#     print(f'################################# end of subject {phantom_choice} #################################')  


In [61]:
import plotly
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import plotly.io as pio
from my_funcs.new_plot_functions import vial_locator
import matplotlib.pyplot as plt

# Plotly colorscales whose lowest colour is replaced by black, so zero-valued
# pixels (e.g. outside the vial masks), which sit at or below the colour-axis
# minimum, are drawn black instead of in the colormap's darkest colour.
custom_viridis = np.array(plotly.colors.sequential.Viridis)
custom_viridis[0] = '#000000'

custom_plasma = np.array(plotly.colors.sequential.Plasma)
custom_plasma[0] = '#000000'

custom_plotly3 = np.array(plotly.colors.sequential.Plotly3)  # not used by the figures below
custom_plotly3[0] = '#000000'

_MONTH_NAMES = {
    'jan': 'January', 'feb': 'February', 'mar': 'March', 'apr': 'April',
    'may': 'May', 'jun': 'June', 'jul': 'July', 'aug': 'August',
    'sep': 'September', 'oct': 'October', 'nov': 'November', 'dec': 'December',
}


def _subject_title(subject):
    """Build a 'Month<br>concs mM<br>pH values' subplot title from one subject dict.

    Distinct concentrations and pH values are listed in descending order, e.g.
    concs [20, 12, 16] -> '20,16,12 mM' and phs [7, 7, 7] -> 'pH 7'. An unknown
    month key is shown as-is.
    """
    month = _MONTH_NAMES.get(subject['month'], subject['month'])
    concs = ','.join(f'{c:g}' for c in sorted(set(subject['concs']), reverse=True))
    phs = ','.join(f'{p:g}' for p in sorted(set(subject['phs']), reverse=True))
    return f'{month}<br>{concs} mM<br>pH {phs}'


# Grid of 2 rows x one column per subject: row 1 = Glu [mM] map, row 2 = ksw [Hz] map.
# Column count and titles come from subject_dicts, so adding/removing a subject
# no longer requires editing a hard-coded column count and title list.
fig = make_subplots(rows=2, cols=len(subject_dicts), horizontal_spacing=0.04, vertical_spacing=0.01,
                    subplot_titles=[_subject_title(subject) for subject in subject_dicts])
fp_prtcl_name = '107a'

def _centered_crop(image, center_row, center_col, size):
    """Cut a ``size`` x ``size`` window out of the 2-D ``image``, centred on a pixel.

    The window covers rows ``center_row - size // 2`` up to (not including)
    ``center_row - size // 2 + size``, and likewise for columns. Any part of the
    window that falls outside ``image`` is zero-filled. Plain slicing would treat a
    negative start as counting from the end (and truncate at the far edge),
    returning a wrongly-shaped array for vials close to the image border.
    """
    half = size // 2
    crop = np.zeros((size, size), dtype=image.dtype)
    row0, col0 = center_row - half, center_col - half
    r_lo, r_hi = max(row0, 0), min(row0 + size, image.shape[0])
    c_lo, c_hi = max(col0, 0), min(col0 + size, image.shape[1])
    if r_lo < r_hi and c_lo < c_hi:
        crop[r_lo - row0:r_hi - row0, c_lo - col0:c_hi - col0] = image[r_lo:r_hi, c_lo:c_hi]
    return crop


CROP_SIZE = 26                # side length [pixels] of the window cut around each vial
VIAL_ORDER = ('a', 'b', 'c')  # vials are stacked top-to-bottom in this tag order
FS_DISPLAY_MAX_MM = 20        # upper limit of the Glu colour scale [mM]

for subject_i, subject_dict in enumerate(subject_dicts):
    phantom_choice = subject_i + 1  # 1-based subject number, also the subplot column
    print(f'################################# start of subject {phantom_choice} #################################')
    glu_phantom_fn = os.path.join(general_fn, 'scans', subject_dict['scan_name'],
                                  subject_dict['sub_name'])

    # Vial masks are always taken from the 107a acquisition.
    _, _, bruker_dataset_mask = bruker_dataset_creator(glu_phantom_fn, txt_file_name, '107a')
    vial_rois, full_mask, bg_mask = mask_roi_finder(bruker_dataset_mask)
    subject_dict['vial_rois'] = vial_rois
    subject_dict['full_mask'] = full_mask
    subject_dict['bg_mask'] = bg_mask
    subject_dict['dict_ranges'] = dict_ranges

    roi_masks, vial_loc = vial_locator(full_mask, vial_rois)

    # Matched maps saved by the dot-product cell. A local name is used so the
    # notebook-level `save_name` is not overwritten inside the loop.
    subject_save_name = subject_dict['save_name']
    quant_maps_fn = os.path.join('images', subject_save_name, f'subject_{phantom_choice}', 'quant_maps.mat')
    quant_maps = sio.loadmat(quant_maps_fn)
    fs_array = quant_maps['fs'] * 110e3 / 3  # back to Glu [mM] (inverse of f_const)
    ksw_array = quant_maps['ksw']

    # One CROP_SIZE x CROP_SIZE tile per vial, stacked vertically in VIAL_ORDER.
    fs_flat = np.zeros([len(VIAL_ORDER) * CROP_SIZE, CROP_SIZE])
    ksw_flat = np.zeros([len(VIAL_ORDER) * CROP_SIZE, CROP_SIZE])
    for n, tag_letter in enumerate(VIAL_ORDER):
        roi_i = subject_dict['tags'].index(tag_letter)
        mask = roi_masks[roi_i]
        y_loc = int(np.round(vial_loc[roi_i, 0]))
        x_loc = int(np.round(vial_loc[roi_i, 1]))
        rows = slice(CROP_SIZE * n, CROP_SIZE * (n + 1))
        fs_flat[rows, :] = _centered_crop(fs_array * mask, y_loc, x_loc, CROP_SIZE)
        ksw_flat[rows, :] = _centered_crop(ksw_array * mask, y_loc, x_loc, CROP_SIZE)

    # Both heatmaps are bound to the shared colour axes configured after the loop.
    fig.add_trace(go.Heatmap(z=fs_flat, colorscale=custom_viridis, coloraxis='coloraxis'),
                  row=1, col=phantom_choice)
    fig.add_trace(go.Heatmap(z=ksw_flat, colorscale=custom_plasma, coloraxis='coloraxis2'),
                  row=2, col=phantom_choice)
    for row in (1, 2):
        fig.update_xaxes(row=row, col=phantom_choice, showgrid=False, showticklabels=False)
        # Reversed y-axis so row 0 of the image is drawn at the top.
        fig.update_yaxes(row=row, col=phantom_choice, showgrid=False, showticklabels=False,
                         autorange='reversed')

ksw_lims = dict_ranges['ksw_0']  # same range that was stored in every subject_dict above
fig.update_layout(
    template='plotly_white',
    title=dict(text="Dot-product with 2-norm normalization (37°C)", font=dict(size=24), x=0.2, y=0.97),
    showlegend=False,
    height=700,
    width=900,
    margin=dict(l=10, r=0, t=110, b=20),
    # Shared colour axes: Glu map (row 1) and ksw map (row 2), each with its own colourbar.
    coloraxis=dict(colorscale=custom_viridis, cmin=0, cmax=FS_DISPLAY_MAX_MM,
                   colorbar=dict(x=1.01, y=0.75, len=0.5)),
    coloraxis2=dict(colorscale=custom_plasma, cmin=ksw_lims[0], cmax=ksw_lims[1],
                    colorbar=dict(x=1.01, y=0.23, len=0.5)),
)
fig.update_yaxes(title_text="Glu [mM]", title_font=dict(size=18), row=1, col=1)
fig.update_yaxes(title_text="ksw [Hz]", title_font=dict(size=18), row=2, col=1)

fig.show()

summary_dir = os.path.join('images', save_name)
os.makedirs(summary_dir, exist_ok=True)  # make sure the output folder exists before writing
pio.write_image(fig, os.path.join(summary_dir, 'summary.jpeg'))


################################# start of subject 1 #################################
################################# start of subject 2 #################################
################################# start of subject 3 #################################
################################# start of subject 4 #################################
################################# start of subject 5 #################################
################################# start of subject 6 #################################


## Plot comparison: CEST-MRF dot-matching vs. multi-B1 fitting (ksw) and real concentrations (Glu) ##

In [112]:
import plotly.express as px
from scipy import stats
from sklearn.metrics import mean_squared_error

def calculate_nrmse(y_true, y_pred):
    """Root-mean-square error normalised by the range of the reference values.

    NRMSE = sqrt(mean((y_true - y_pred) ** 2)) / (max(y_true) - min(y_true))

    Parameters
    ----------
    y_true : array-like
        Reference values (e.g. multi-B1 fitted ksw or real concentrations).
    y_pred : array-like
        Estimated values, same shape as ``y_true``. pandas Series are compared
        positionally (their index is ignored).

    Returns
    -------
    float
        NRMSE as a fraction (multiply by 100 for percent). ``nan`` when
        ``y_true`` is constant, because the normalising range is then zero.

    Raises
    ------
    ValueError
        If the inputs are empty, differ in shape, or contain NaN/inf.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.shape != y_pred.shape:
        raise ValueError(f'y_true and y_pred have different shapes: {y_true.shape} vs {y_pred.shape}')
    if y_true.size == 0:
        raise ValueError('calculate_nrmse needs at least one sample')
    if not (np.isfinite(y_true).all() and np.isfinite(y_pred).all()):
        raise ValueError('calculate_nrmse inputs contain NaN or infinite values')

    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    data_range = y_true.max() - y_true.min()
    if data_range == 0:
        # Normalisation is undefined for a constant reference; avoid inf / divide warnings.
        return float('nan')
    return float(rmse / data_range)

fig = make_subplots(rows=1, cols=2, vertical_spacing=0.07, horizontal_spacing=0.15,
                    subplot_titles=['ksw [Hz]', 'Glu [mM]'])

# Per-subject result tables ('sub_df') are only stored by the subject-loop cell,
# which is currently commented out; fail early with a clear message instead of
# a bare KeyError on a fresh kernel.
missing_sub_df = [i + 1 for i, subject in enumerate(subject_dicts) if 'sub_df' not in subject]
if missing_sub_df:
    raise RuntimeError(f"subjects {missing_sub_df} have no 'sub_df' - run the subject loop that "
                       "calls multi_b_fit_plot / dict_fs_ksw before this cell")
# Single concat (not one per loop iteration); a fresh index avoids duplicate labels.
complete_df = pd.concat([subject['sub_df'] for subject in subject_dicts], ignore_index=True)

# Positional columns of complete_df: 2 = multi-B1-fitted ksw, 7 = CEST-MRF dot-matched ksw
# (as labelled by the axis titles below).
# NOTE(review): re-check these positions if multi_b_fit_plot / dict_fs_ksw change their columns.
KSW_FIT_COL, KSW_MRF_COL = 2, 7
KSW_AXIS_RANGE = [5000, 12100]
GLU_AXIS_RANGE = [0, 26]

colors = px.colors.qualitative.Pastel  # cycled when there are more pH groups than colours

# One colour / legend group per pH value, shared by both panels.
for pH_i, pH_value in enumerate(complete_df['pH'].unique()):
    ph_df = complete_df[complete_df['pH'] == pH_value]
    trace_style = dict(mode='markers', marker=dict(color=colors[pH_i % len(colors)]),
                       name=f'pH {pH_value}', legendgroup=str(pH_value))
    fig.add_trace(go.Scatter(x=ph_df.iloc[:, KSW_FIT_COL], y=ph_df.iloc[:, KSW_MRF_COL], **trace_style),
                  row=1, col=1)
    # Same legend group as the ksw trace, so it gets no separate legend entry.
    fig.add_trace(go.Scatter(x=ph_df['fb_mean'], y=ph_df['fs_mean'], showlegend=False, **trace_style),
                  row=1, col=2)

# Identity (x=y) reference lines; only the first one is listed in the legend.
for col, axis_range in ((1, KSW_AXIS_RANGE), (2, GLU_AXIS_RANGE)):
    fig.add_trace(go.Scatter(x=axis_range, y=axis_range, mode='lines',
                             line=dict(color='black', dash='dash'), name='x=y line', opacity=0.5,
                             showlegend=(col == 1)),
                  row=1, col=col)

fig.update_xaxes(title_text='multi-B1 fitted', row=1, col=1)
fig.update_yaxes(title_text='CEST-MRF dot-matched', row=1, col=1)
fig.update_xaxes(title_text='real', row=1, col=2)
fig.update_yaxes(title_text='CEST-MRF dot-matched', row=1, col=2)

fig.update_layout(template='plotly_white',
                  title=dict(text="'107a' - MRF dot-matching / multi-B1-fitting", x=0.02, y=0.97),
                  height=300, width=600 + 50,
                  margin=dict(l=65, r=0, t=50, b=0),
                  showlegend=True)
for update_axes in (fig.update_xaxes, fig.update_yaxes):
    update_axes(tickmode='linear', tick0=0, dtick=1000, range=KSW_AXIS_RANGE, row=1, col=1)
    update_axes(tickmode='linear', tick0=0, dtick=5, range=GLU_AXIS_RANGE, row=1, col=2)

# Agreement statistics per panel: Pearson correlation and range-normalised RMSE.
ksw_fit = complete_df.iloc[:, KSW_FIT_COL]
ksw_mrf = complete_df.iloc[:, KSW_MRF_COL]
r_k, p_val_k = stats.pearsonr(ksw_fit, ksw_mrf)
nrmse_k = calculate_nrmse(ksw_fit, ksw_mrf)
r_f, p_val_f = stats.pearsonr(complete_df['fb_mean'], complete_df['fs_mean'])
nrmse_f = calculate_nrmse(complete_df['fb_mean'], complete_df['fs_mean'])

ksw_mrf_max = np.max(ksw_mrf)
annotations = [
    # (subplot column, x, y, text)
    (1, 8000, ksw_mrf_max - 500, f"Pearson r: {r_k:.2f}, p_val: {p_val_k:.2e}"),
    (1, 6500, ksw_mrf_max - 1000, f"nrmse: {nrmse_k*100:.2f}%"),
    (2, 11, 23, f"Pearson r: {r_f:.2f}, p_val: {p_val_f:.2e}"),
    (2, 5.5, 21, f"nrmse: {nrmse_f*100:.2f}%"),
]
for col, x, y, text in annotations:
    fig.add_annotation(x=x, y=y, text=text, showarrow=False, font=dict(color="black"), row=1, col=col)

fig.show()
comparison_dir = os.path.join('images', save_name)
os.makedirs(comparison_dir, exist_ok=True)  # make sure the output folder exists before writing
pio.write_image(fig, os.path.join(comparison_dir, 'ksw_fs.jpeg'))
