In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.filters import threshold_otsu
import pandas as pd
import sys
sys.path.append("/home/ws/oc9627/cerebral_organoid_quant_mri")

from scripts.utils.constants import MRI_ORG_LOC_GT_DIR, MRI_CYST_LOC_GT_DIR
from scripts.utils.segmentation_performance_extractor import SegmentationPerformanceExtractor
from scripts.utils.segmentation_overlay_visualizer import SegmentationOverlayVisualizer
from scripts.utils.io_utils import get_orig_imgs, get_masks
from scripts.utils.global_cyst_classification import get_compactness
from scripts.utils.metrics import calculate_roc_auc
from scripts.utils.dw_mri import get_metrics_global_cyst_seg_dw_mri

from scripts.utils.visualization import plot_compactness, plot_examples_lq_hq_organoids, plot_trace_lq_hq_mean_org_int, plot_correlation_compactness_cysticity, plot_organoid_growth_over_time
from scripts.utils.global_cyst_classification import get_all_otsu_masks, get_org_mean_and_compactness
from scripts.utils.constants import CSV_GLOBAL_CYST_ANNOT

# Run from the repository root so that the relative 'results/...' paths used
# throughout this notebook resolve.
# NOTE(review): absolute, machine-specific path (the same path is appended to
# sys.path above) — adjust both when running on another machine.
wdir = '/home/ws/oc9627/cerebral_organoid_quant_mri'
os.chdir(wdir)

# Organoid segmentation

#### extract results on test set

In [None]:
# Directory holding the organoid-segmentation predictions on the test sets
# (relative to the working directory set above).
pred_dir = 'results/organoid_segmentation/checkpoint_dirs_trained_previously/all_predictions_on_test_sets'
# Ground-truth directory for organoid segmentation (from scripts.utils.constants).
# NOTE: `gt_dir` is reassigned to MRI_CYST_LOC_GT_DIR in the local cyst section
# further down, so cells in this section must run before that reassignment.
gt_dir = MRI_ORG_LOC_GT_DIR

In [None]:
# Compare the test-set predictions with the ground truth, write the
# performance plot as PNG and print the Dice summary.
spe = SegmentationPerformanceExtractor(pred_dir=pred_dir, gt_dir=gt_dir)
spe.extract_test_performance()
spe.plot_test_performance(
    save_to='results/organoid_segmentation/plots/organoid_seg_performance.png',
)
spe.print_test_dice_mean_sd()

In [None]:
# Same evaluation as in the previous cell, this time writing the performance
# plot as SVG.
# NOTE(review): the extractor is rebuilt and the performance re-extracted only
# to save a second file format — consider reusing the existing `spe`.
spe = SegmentationPerformanceExtractor(pred_dir=pred_dir, gt_dir=gt_dir)
spe.extract_test_performance()
spe.plot_test_performance(
    save_to='results/organoid_segmentation/plots/organoid_seg_performance.svg',
)
spe.print_test_dice_mean_sd()

#### plot weakest, intermediate and best prediction

In [None]:
# Load the original MRI images once; the overlay cells below pass `imgs` to
# SegmentationOverlayVisualizer.
imgs = get_orig_imgs(kind='mri')

In [None]:
# weakest prediction
# Render the overlay once per prediction colour (each colour gets its own
# `save_to` file). The result of the last colour is kept so it can be displayed
# below without rendering that overlay a second time.
overlay = None
for pred_color in ['green', 'orange']:
    sov = SegmentationOverlayVisualizer(imgs,
                                        org_id='org3_0609',
                                        org_loc=[15, 65, 65, 108],
                                        gt_dir=MRI_ORG_LOC_GT_DIR,
                                        pred_dir=pred_dir,
                                        px_threshold=0,
                                        planes='sagittal',
                                        keep_planes=[44, 47, 50],
                                        rot_img=270,
                                        pred_color=pred_color,
                                        save_to=f'results/organoid_segmentation/plots/organoid_seg_overlay_org3_0609_pred-{pred_color}.png')
    overlay = sov.plot_segmentation_overlays()
display(overlay)

In [None]:
# intermediate prediction
# Render the overlay once per prediction colour (each colour gets its own
# `save_to` file). The result of the last colour is kept so it can be displayed
# below without rendering that overlay a second time.
overlay = None
for pred_color in ['green', 'orange']:
    sov = SegmentationOverlayVisualizer(imgs,
                                        org_id='org2_0615',
                                        org_loc=[20, 75, 30, 75],
                                        gt_dir=MRI_ORG_LOC_GT_DIR,
                                        pred_dir=pred_dir,
                                        px_threshold=0,
                                        planes='sagittal',
                                        keep_planes=[40, 50, 58],
                                        rot_img=270,
                                        pred_color=pred_color,
                                        save_to=f'results/organoid_segmentation/plots/organoid_seg_overlay_org2_0615_pred-{pred_color}.png')
    overlay = sov.plot_segmentation_overlays()
display(overlay)

In [None]:
# strongest prediction
# Render the overlay once per prediction colour. The plot call has to sit
# inside the loop: previously it ran only after the loop, so the 'green'
# visualizer was constructed but never rendered (unlike the weakest and
# intermediate cells above).
img = None
for pred_color in ['green', 'orange']:
    sov = SegmentationOverlayVisualizer(imgs,
                                        org_id='org5_0530',
                                        org_loc=[20, 65, 38, 75],
                                        gt_dir=MRI_ORG_LOC_GT_DIR,
                                        pred_dir=pred_dir,
                                        px_threshold=0,
                                        planes='sagittal',
                                        keep_planes=[34, 40, 52],
                                        rot_img=270,
                                        pred_color=pred_color,
                                        save_to=f'results/organoid_segmentation/plots/organoid_seg_overlay_org5_0530_pred-{pred_color}.png')
    img = sov.plot_segmentation_overlays()
# Show the overlay of the last colour ('orange'), as before.
display(img)

# Global cyst classification

#### Calculate mean organoid intensities and compactness

In [None]:
# Table used below for the ROC AUC computation and the compactness plot; the
# following cell reads its 'mean_organoid_intensity' and 'compactness' columns.
df = get_compactness()

#### Calculate ROC AUC and visualize class separation

In [None]:
# ROC AUC of each candidate feature, computed from the table built above.
roc_auc_moi = calculate_roc_auc(df, col='mean_organoid_intensity')
roc_auc_compactness = calculate_roc_auc(df, col='compactness')

print(f'ROC AUC for mean organoid intensity:\t{roc_auc_moi}')
print(f'ROC AUC for Compactness:\t\t{roc_auc_compactness}')

# Plot compactness together with its ROC AUC and save the figure as SVG.
compactness_plot_path = 'results/global_cyst_classification/compactness_separates_lq_hq_organoids.svg'
plot_compactness(df, roc_auc_compactness, save_to=compactness_plot_path)

#### Visualize examples of LQ and HQ organoids

In [None]:
# Example organoids and their (hard-coded) compactness values, in matching order.
org_ids = ['org6_0518', 'org2_0518', 'org1_0615', 'org4_0609']
compactnesses = [31.3, 34.3, 14.7, 11.2]

# Per organoid: two selected sagittal planes (the single-number entry). The
# other two entries are ranges used to crop the images to the EV boundaries
# for better visibility (manually extracted).
org_locs_planes = [
    (['88-117', '59', '32-67'], ['88-117', '60', '32-67']),
    (['40-69', '41', '40-75'], ['40-69', '45', '40-75']),
    (['5-55', '58', '15-90'], ['5-55', '61', '15-90']),
    (['17-67', '36', '15-85'], ['17-67', '53', '15-85']),
]

plot_examples_lq_hq_organoids(
    org_ids,
    compactnesses,
    org_locs_planes,
    save_to='results/global_cyst_classification/examples_lq_hq_organoids.png',
)

#### DW-MRI to probe tissue characteristics

In [None]:
# Plot the DW-MRI trace figure for LQ vs. HQ organoids and save it as SVG.
plot_trace_lq_hq_mean_org_int(save_to='results/global_cyst_classification/trace_mean_org_intensities.svg')

# Local cyst classification

#### extract results

In [None]:
# Directory holding the local-cyst-segmentation predictions on the test sets.
pred_dir_local_cysts = 'results/local_cyst_segmentation/checkpoint_dirs_trained_previously/all_predictions_on_test_sets'
# NOTE: this rebinds `gt_dir`, which pointed to MRI_ORG_LOC_GT_DIR in the
# organoid segmentation section; re-running earlier cells after this one would
# evaluate organoid predictions against the cyst ground truth.
gt_dir = MRI_CYST_LOC_GT_DIR

In [None]:
# Compare the local cyst predictions with the cyst ground truth, write the
# performance plot as SVG and print the Dice summary.
spe = SegmentationPerformanceExtractor(pred_dir=pred_dir_local_cysts, gt_dir=gt_dir)
spe.extract_test_performance()
spe.plot_test_performance(
    save_to='results/local_cyst_segmentation/plots/local_cyst_seg_performance.svg',
)
spe.print_test_dice_mean_sd()

#### visualize weakest, intermediate and strongest prediction

In [None]:
# Reload the original MRI images for the cyst overlays below.
# NOTE(review): the same call was already made in the organoid segmentation
# section; this reload only matters if `imgs` was modified in between.
imgs = get_orig_imgs(kind='mri')

In [None]:
# weakest prediction
# NOTE(review): unlike the intermediate and strongest cells below, no
# `px_threshold` is passed here, so the visualizer's default applies —
# confirm this is intended.
sov = SegmentationOverlayVisualizer(
    imgs,
    org_id='org1_0615',
    org_loc=[10, 75, 10, 60],
    gt_dir=gt_dir,
    pred_dir=pred_dir_local_cysts,
    planes='sagittal',
    keep_planes=[51, 55, 60],
    rot_img=270,
    pred_color='orange',
    save_to='results/local_cyst_segmentation/plots/cyst_seg_overlay_org1_0615_pred-orange.png',
)
overlay = sov.plot_segmentation_overlays()
display(overlay)

In [None]:
# intermediate prediction
# Overlay of ground truth and prediction on three sagittal planes of org4_0609.
sov = SegmentationOverlayVisualizer(
    imgs,
    org_id='org4_0609',
    org_loc=[20, 70, 25, 65],
    gt_dir=gt_dir,
    pred_dir=pred_dir_local_cysts,
    planes='sagittal',
    px_threshold=0,
    keep_planes=[42, 49, 52],
    pred_color='orange',
    rot_img=270,
    save_to='results/local_cyst_segmentation/plots/cyst_seg_overlay_org4_0609_pred-orange.png',
)
overlay = sov.plot_segmentation_overlays()
display(overlay)

In [None]:
# strongest prediction
# Overlay of ground truth and prediction on three sagittal planes of org7_0530.
# The ground-truth directory is given via the constant directly rather than
# through `gt_dir`.
sov = SegmentationOverlayVisualizer(
    imgs,
    org_id='org7_0530',
    org_loc=[22, 65, 20, 55],
    gt_dir=MRI_CYST_LOC_GT_DIR,
    pred_dir=pred_dir_local_cysts,
    planes='sagittal',
    px_threshold=0,
    keep_planes=[49, 56, 63],
    rot_img=270,
    pred_color='orange',
    save_to='results/local_cyst_segmentation/plots/cyst_seg_overlay_org7_0530_pred-orange.png',
)
overlay = sov.plot_segmentation_overlays()
display(overlay)

# Supplementary

### organoid segmentation

#### plot organoid growth over time

In [None]:
# Draw the growth-over-time figure, then save whatever figure is current.
# NOTE(review): unlike the other plotting helpers used in this notebook, no
# `save_to` is passed; plt.savefig only captures the plot if the helper leaves
# its figure as the current one (i.e. does not show/close it) — confirm.
plot_organoid_growth_over_time()
plt.savefig('results/supplement/organoid_growth_over_time.svg')

### global cyst classification

#### correlation of compactness with cysticity from local cyst segmentations

In [None]:
# Plot the compactness/cysticity correlation and save it as PNG.
# NOTE(review): the figure is written to the local_cyst_segmentation plots
# folder although this cell sits in the supplement section — confirm the path.
plot_correlation_compactness_cysticity(save_to='results/local_cyst_segmentation/plots/correlation_compactness_cysticity.png')

#### values for predicted organoid locations (3D U-Net)

In [None]:
# Repeat the global cyst classification analysis, this time with the
# predicted organoid locations instead of the ground-truth ones.
imgs = get_orig_imgs(kind='mri')
# NOTE: despite its name, `gt` holds the *predicted* organoid location masks here.
gt = get_masks(kind='predicted_org_loc')
otsu, _ = get_all_otsu_masks(imgs, gt)

df = get_org_mean_and_compactness(imgs, org_loc=gt, otsu=otsu)

# Attach the quality annotations; with no `on=` given, pandas joins on the
# columns the two frames have in common.
quality_annot = pd.read_csv(CSV_GLOBAL_CYST_ANNOT)
df = df.merge(quality_annot)

roc_auc_moi = calculate_roc_auc(df, col='mean_organoid_intensity')
roc_auc_compactness = calculate_roc_auc(df, col='compactness')

print(f'ROC AUC for mean organoid intensity:\t{roc_auc_moi}')
print(f'ROC AUC for Compactness:\t\t{roc_auc_compactness}')

# No `save_to` here: the figure is only shown, not written to disk.
plot_compactness(df, roc_auc_compactness)

### ROC AUCs for DW-MRI

In [None]:
# Fetch the DW-MRI metrics table and turn the numeric columns into formatted
# strings for display (two decimals for the AUC, scientific notation with one
# decimal for the p-values).
df_dwmri_metrics, _ = get_metrics_global_cyst_seg_dw_mri()
column_formats = {
    'ROC AUC': '{:,.2f}',
    'P-value': '{:,.1E}',
    'P-value corr': '{:,.1E}',
}
for column, fmt in column_formats.items():
    df_dwmri_metrics[column] = df_dwmri_metrics[column].map(fmt.format)
df_dwmri_metrics
# df_dwmri_metrics.to_csv('metrics_global_cyst_seg_dwmri.csv', index=False, sep=';')

### show for one plane how automatic ET voxel determination works

In [None]:
# Illustrate on a single plane how the Otsu threshold and the ground-truth
# organoid mask are combined.
imgs = get_orig_imgs(kind='mri')
gts = get_masks(kind='gt_org_loc')

# Crop one plane of org1_0518 (same window for the image and its mask).
cor_plane = 60
img = imgs['org1_0518'][0:50, 10:58, cor_plane]
gt = gts['org1_0518'][0:50, 10:58, cor_plane]

# Voxels above the Otsu threshold, and the subset of them lying outside the
# ground-truth organoid mask (computed once, used for the last two panels).
otsu_seg = (img > threshold_otsu(img)).astype('bool')
otsu_outside_gt = otsu_seg & ~gt.astype('bool')

# plt.get_cmap replaces plt.cm.get_cmap, which is deprecated and no longer
# available in recent matplotlib releases.
cmap = plt.get_cmap('Greys').reversed()

panels = [
    ('MRI plane', img),
    ('Otsu mask', otsu_seg),
    ('Organoid ground truth', gt),
    ('Otsu outside organoid', np.where(otsu_outside_gt, 1, 0)),
    ('Intensities outside organoid', np.where(otsu_outside_gt, img, 0)),
]

fig, axs = plt.subplots(1, 5, figsize=(15, 3))
for ax, (title, panel) in zip(axs, panels):
    ax.imshow(panel, cmap=cmap)
    # The original `axs[0].set_title` was a bare attribute access and set nothing.
    ax.set_title(title)
    ax.axis('off')