In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import nibabel as nib
import torch
import sys

from nimosef.data.dataset import NiftiDataset
from nimosef.utils.visualization import get_imgs_experiment

In [None]:
# Root of the dataset. Set the NIMOSEF_DATA_PATH environment variable to run this
# notebook on another machine; the default keeps the original location.
data_path = os.environ.get("NIMOSEF_DATA_PATH", "/media/jaume/DATA/Data/Test_NIMOSEF_Dataset")
splits_filename = os.path.join(data_path, "derivatives", "manifests_nimosef", "dataset_manifest.json")

# Split to analyse: passed as `mode` to NiftiDataset and used in output names.
mode = "train"
results_folder = os.path.join(data_path, "derivatives", "nimosef_results")
# Folder where the comparison figures produced below are saved (created if missing).
save_folder_results = os.path.join(results_folder, f"results_{mode}_comparison")
os.makedirs(save_folder_results, exist_ok=True)

In [None]:
# Index (within the selected split) of the subject to visualise.
subj_idx = 0

dataset = NiftiDataset(splits_filename, mode=mode)
subjects = dataset.patients
# Fail with an informative message instead of a bare IndexError when the
# manifest has no subjects for this mode or subj_idx is out of range.
if not 0 <= subj_idx < len(subjects):
    raise IndexError(
        f"subj_idx={subj_idx} is out of range: found {len(subjects)} subject(s) "
        f"for mode={mode!r} in {splits_filename}"
    )
subj_id = subjects[subj_idx]

# Ground-truth / predicted intensity images and segmentations for this subject.
# `dvol` is returned by the helper but not used further in this notebook.
im_gt, im_pred, seg_gt, seg_pred, dvol = get_imgs_experiment(results_folder, subj_id)

In [None]:
# Draw the segmentation on top of the intensity image in the plotting cell.
overlay = True
# Time frames to show: one column of panels per frame.
t_list = [0, 10, 20, 30, 40]
# Slice indices to show: one row of panels per slice.
z_list = [3, 5, 7]


def _make_mosaic(vol, t_list, z_list):
    """Tile a 4D volume into a single 2D grid of (slice, time-frame) panels.

    Parameters
    ----------
    vol : np.ndarray
        4D array whose last axis is indexed by time and, once time is dropped,
        whose last remaining axis is indexed by slice (assumed (X, Y, Z, T) —
        TODO confirm against get_imgs_experiment).
    t_list : sequence of int
        Time frames, laid out left to right.
    z_list : sequence of int
        Slices, laid out along the row axis (bottom to top with origin='lower').

    Returns
    -------
    np.ndarray
        2D array of shape (X * len(z_list), Y * len(t_list)).
    """
    # Time frames side by side: axis=1 -> horizontal stacking (columns).
    frames = np.concatenate([vol[..., t] for t in t_list], axis=1)
    # Slices on top of each other: axis=0 -> vertical stacking (rows).
    return np.concatenate([frames[..., z] for z in z_list], axis=0)


# Segmentation mosaics
gt_seg_img = _make_mosaic(seg_gt, t_list, z_list)
pred_seg_img = _make_mosaic(seg_pred, t_list, z_list)

# Intensity-image mosaics
im_gt_plot = _make_mosaic(im_gt, t_list, z_list)
im_pred_plot = _make_mosaic(im_pred, t_list, z_list)

# Output paths; the prefix follows `mode` so non-train runs are not mislabelled.
save_gt_path = os.path.join(save_folder_results, f"{mode}_{subj_id}_gt_plot.png")
save_pred_path = os.path.join(save_folder_results, f"{mode}_{subj_id}_baseline_plot.png")

In [None]:
def _plot_overlay(image, seg, title, save_path, seg_vmax, show_seg=True):
    """Show an intensity mosaic, optionally overlaid with its label map, and save it.

    Parameters
    ----------
    image : np.ndarray
        2D intensity mosaic, drawn in grayscale.
    seg : np.ndarray
        2D label mosaic with the same shape as `image`; label 0 is treated as
        background and left transparent.
    title : str
        Figure title.
    save_path : str
        Destination file for the saved figure.
    seg_vmax : float
        Upper colour limit for the labels; share it between figures so the
        same label gets the same colour in every figure.
    show_seg : bool
        If False, only the intensity mosaic is drawn.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    ax.imshow(image, cmap='gray', origin='lower', interpolation='none')
    if show_seg:
        # Mask background so the colormap does not tint the whole image.
        seg_fg = np.ma.masked_where(seg == 0, seg)
        ax.imshow(seg_fg, cmap='jet', origin='lower', interpolation='none',
                  alpha=0.25, vmin=0, vmax=seg_vmax)
    ax.set_title(title)
    # Pixel ticks of a tiled mosaic carry no meaning.
    ax.set_axis_off()
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0)
    return fig


# Common colour scale so GT and prediction labels map to identical colours.
seg_vmax = max(np.max(gt_seg_img), np.max(pred_seg_img))

_plot_overlay(im_gt_plot, gt_seg_img, 'GT Segmentation', save_gt_path,
              seg_vmax, show_seg=overlay)
_plot_overlay(im_pred_plot, pred_seg_img, 'Predicted Segmentation', save_pred_path,
              seg_vmax, show_seg=overlay)
plt.show()