## Objective
Here, we inspect the denoiser's performance using the stored prediction files.

In [None]:
from IPython.display import display, HTML
# Stretch the Jupyter notebook container to the full browser width.
_full_width_css = "<style>.container { width:100% !important; }</style>"
display(HTML(_full_width_css))

In [None]:
import os
# Flag forwarded to setup_syspath_disentangle below; its exact effect is defined in
# nb_core/root_dirs.ipynb, which is not visible here.
DEBUG=False
# Run the shared root-dirs notebook so its top-level names land in this namespace;
# presumably this is where setup_syspath_disentangle (called next) comes from — confirm.
%run ./nb_core/root_dirs.ipynb
setup_syspath_disentangle(DEBUG)
# Shared imports for the rest of this notebook. Later cells use names such as plt and np
# without a visible import; presumably they come from here or from the
# `denoisplit.scripts.evaluate import *` below — verify.
%run ./nb_core/disentangle_imports.ipynb

In [None]:
from denoisplit.scripts.evaluate import * 
from denoisplit.config_utils import get_configdir_from_saved_predictionfile, load_config
from denoisplit.core.data_split_type import DataSplitType
from denoisplit.core.tiff_reader import load_tiff
from denoisplit.core.data_split_type import get_datasplit_tuples
import ml_collections



# Directory with the saved prediction files (alternatives kept for other runs).
# data_dir = '/group/jug/ashesh/data/paper_stats/All_P128_G64_M50_Sk44/'
data_dir = '/group/jug/ashesh/data/paper_stats/All_P128_G64_M50_Sk32'
# data_dir = '/group/jug/ashesh/data/paper_stats/All_P128_G64_M50_Sk0'
denoiser_prediction_fname = "pred_disentangle_2402_D3-M23-S0-L0_11.tif"
# Channel (last axis) of the high-res data that the prediction is compared against.
channel_idx = 0

# Load the stored prediction and keep only the frames of the test split.
pred = load_tiff(os.path.join(data_dir, denoiser_prediction_fname))
_, _, test_idx = get_datasplit_tuples(0.1, 0.1, pred.shape[0], starting_test=False)
test_pred = pred[test_idx]
denoiser_configdir = get_configdir_from_saved_predictionfile(denoiser_prediction_fname)
print(denoiser_configdir)

# Load the high-res (target) data described by the denoiser's training config.
denoiser_config = ml_collections.ConfigDict(load_config(denoiser_configdir))
denoiser_data_type = denoiser_config.data.data_type
if denoiser_data_type == DataType.BioSR_MRC:
    denoiser_input_dir = '/group/jug/ashesh/data/BioSR/'
elif denoiser_data_type == DataType.OptiMEM100_014:
    denoiser_input_dir = '/group/jug/ashesh/data/microscopy/OptiMEM100x014.tif'
elif denoiser_data_type == DataType.SeparateTiffData:
    denoiser_input_dir = '/group/jug/ashesh/data/ventura_gigascience/'
    # Point the config at the 'highsnr' files instead of the 'lowsnr' ones.
    denoiser_config.data.ch1_fname = denoiser_config.data.ch1_fname.replace('lowsnr', 'highsnr')
    denoiser_config.data.ch2_fname = denoiser_config.data.ch2_fname.replace('lowsnr', 'highsnr')
else:
    # Fail loudly: otherwise a NameError, or worse, a stale denoiser_input_dir left over
    # from a previous run of this cell would be used silently.
    raise ValueError(f'No high-res data location configured for data_type={denoiser_data_type}')

with denoiser_config.unlocked():
    highres_data = get_data_without_synthetic_noise(denoiser_input_dir, denoiser_config, DataSplitType.Test)

# Crop the target to the prediction's spatial size and keep the selected channel.
h, w = pred.shape[1:3]
highres_data = highres_data[:, :h, :w]
highres_data = highres_data[..., channel_idx].copy()

# The two test splits are derived independently (hard-coded fractions above vs. the
# data loader); make a mismatch explicit instead of failing inside the PSNR computation.
if highres_data.shape[0] != test_pred.shape[0]:
    raise ValueError(
        f'Test split mismatch: {test_pred.shape[0]} predicted frames vs '
        f'{highres_data.shape[0]} high-res frames.'
    )
highres_data = highres_data[..., channel_idx].copy()

In [None]:
# Visual sanity check: last test prediction (left) next to its high-res target (right).
last_pred_frame = test_pred[-1]
_, ax = plt.subplots(figsize=(8, 4), ncols=2)
ax[0].imshow(last_pred_frame)
last_target_frame = highres_data[-1, ...]
ax[1].imshow(last_target_frame)

In [None]:
from denoisplit.core.psnr import RangeInvariantPsnr
# Both arrays are cast to float32 before scoring; the result exposes .mean().item(),
# so report the mean as a plain Python float.
target_f32 = highres_data.astype(np.float32)
prediction_f32 = test_pred.astype(np.float32)
mean_psnr = RangeInvariantPsnr(target_f32, prediction_f32).mean().item()
print(f'PSNR: {mean_psnr:.2f}')

In [None]:
# PSNR of each HDN run, keyed by the run path relative to the training root that
# ConfigInfo resolves. Values stay as strings exactly as recorded; they are converted
# with float() where the table is built.
_hdn_run_prefix = "2402/D16-M23-S0-L0"
_hdn_psnr_by_run = (
    ("93", "39.230"),
    ("88", "43.930"),
    ("94", "37.86"),
    ("89", "42.1"),
    ("95", "36.68"),
    ("87", "40.66"),
    ("92", "33.38"),
    ("90", "29.39"),
    ("104", "38.320"),
    ("96", "36.48"),
    ("105", "36.78"),
    ("97", "34.92"),
    ("106", "35.43"),
    ("98", "33.8"),
    ("107", "31.81"),
    ("99", "30.32"),
    ("114", "44.13"),
    ("101", "37.3"),
    ("113", "42.21"),
    ("100", "36.37"),
    ("117", "40.91"),
    ("103", "35.18"),
    ("120", "29.390"),
    ("102", "32.03"),
)
hdn_psnr_dict = {f"{_hdn_run_prefix}/{run_id}": psnr for run_id, psnr in _hdn_psnr_by_run}

In [None]:
import json
import os
from denoisplit.config_utils import load_config
# Root directory of the training runs; the keys of the PSNR dictionaries in this
# notebook are run paths relative to it. (Not named `dir`, which shadowed the builtin.)
TRAINING_ROOT_DIR = '/home/ashesh.ashesh/training/disentangle/'


class ConfigInfo:
    """Expose noise level and channel information from a saved training config.

    Args:
        config_path: Run path relative to ``root_dir`` (e.g. ``"2402/D16-M23-S0-L0/93"``).
        root_dir: Base directory of the training runs. ``None`` (default) means
            ``TRAINING_ROOT_DIR``, looked up when the instance is created.
    """

    def __init__(self, config_path, root_dir=None) -> None:
        self._config_path = config_path
        self._root_dir = TRAINING_ROOT_DIR if root_dir is None else root_dir
        self.cfg = self.get_config_from_path(config_path)

    def get_config_from_path(self, config_path):
        """Load the config stored under ``<root_dir>/<config_path>`` via load_config."""
        config_fpath = os.path.join(self._root_dir, config_path)
        return load_config(config_fpath)

    def get_noise_level(self):
        """Return ``(synthetic_gaussian_scale, poisson_noise_factor)`` from ``cfg.data``."""
        return self.cfg.data.synthetic_gaussian_scale, self.cfg.data.poisson_noise_factor

    def get_channel(self):
        """Return the channel file(s) this run works on.

        Returns ``cfg.data.ch1_fname`` / ``cfg.data.ch2_fname`` when
        ``cfg.model.denoise_channel`` is ``'Ch1'`` / ``'Ch2'``; otherwise the list
        ``[ch1_fname, ch2_fname]``.
        """
        # The key lives under cfg.model, so membership must be tested there; testing the
        # top-level cfg never matched and always fell through to the two-file list.
        model_cfg = self.cfg.model if 'model' in self.cfg else {}
        denoise_channel = model_cfg['denoise_channel'] if 'denoise_channel' in model_cfg else None
        if denoise_channel == 'Ch1':
            return self.cfg.data.ch1_fname
        if denoise_channel == 'Ch2':
            return self.cfg.data.ch2_fname
        return [self.cfg.data.ch1_fname, self.cfg.data.ch2_fname]


In [None]:
import pandas as pd
# One row per HDN run: noise settings, channel file and the recorded PSNR.
hdn_df = pd.DataFrame([], columns=['Gaus', 'Pois', 'Ch', 'PSNR'])
for run_key, psnr_str in hdn_psnr_dict.items():
    run_info = ConfigInfo(run_key)
    gaus_scale, pois_factor = run_info.get_noise_level()
    hdn_df.loc[run_key] = [gaus_scale, pois_factor, run_info.get_channel(), float(psnr_str)]

In [None]:
# HDN PSNR on the ER channel, ordered by Gaussian noise scale.
er_hdn_psnr = hdn_df[hdn_df.Ch == 'ER/GT_all.mrc'].sort_values('Gaus')['PSNR']
er_hdn_psnr.plot(marker='o', linestyle='-', label='ER/GT_all.mrc')

In [None]:
# denoiSplit PSNRs per run, keyed by the run path relative to the training root.
# Each value is a JSON-encoded list with one PSNR per output channel ([Ch0, Ch1]);
# the strings are kept verbatim and decoded with json.loads where they are used.
_denoisplit_run_prefix = "2402/D16-M3-S0-L0"
_denoisplit_psnr_by_run = (
    ("149", "[36.79, 38.93]"),
    ("143", "[35.36, 37.24]"),
    ("151", "[33.96, 36.1]"),
    ("153", "[30.47, 31.92]"),
    ("150", "[30.2, 29.77]"),
    ("144", "[29.2, 28.71]"),
    ("152", "[27.42, 26.65]"),
    ("155", "[25.19, 24.49]"),
    ("154", "[39.9, 36.36]"),
    ("145", "[38.44, 34.85]"),
    ("156", "[36.82, 33.51]"),
    ("157", "[32.24, 29.07]"),
)
denoisplit_dict = dict(
    (f"{_denoisplit_run_prefix}/{run_id}", psnr_pair)
    for run_id, psnr_pair in _denoisplit_psnr_by_run
)
# Two rows per denoiSplit run (one per output channel), labelled '<run>_Ch<idx>'.
df_denoisplit = pd.DataFrame([], columns=['Gaus', 'Pois', 'Ch', 'PSNR'])
for run_key, psnr_json in denoisplit_dict.items():
    run_info = ConfigInfo(run_key)
    psnr_per_channel = json.loads(psnr_json)
    gaus_scale, pois_factor = run_info.get_noise_level()
    for ch_idx in (0, 1):
        row_label = f'{run_key}_Ch{ch_idx}'
        channel_fname = run_info.get_channel()[ch_idx]
        df_denoisplit.loc[row_label] = [gaus_scale, pois_factor, channel_fname, psnr_per_channel[ch_idx]]

In [None]:
# Index both result tables by noise configuration and channel so they can be joined.
_join_keys = ['Gaus', 'Pois', 'Ch']
df_hdn = hdn_df.set_index(_join_keys)
df_denoisplit = df_denoisplit.set_index(_join_keys)

In [None]:
# Inner-join denoiSplit and HDN results on (Gaus, Pois, Ch); PSNR columns get suffixes.
merged = df_denoisplit.merge(df_hdn, left_index=True, right_index=True,
                             suffixes=('_denoisplit', '_hdn'))
df = merged.reset_index()


def _short_channel_name(fname):
    # e.g. 'ER/GT_all.mrc' -> 'ER'
    return fname.replace('GT_all.mrc', '').replace('/', '')


df.Ch = df.Ch.map(_short_channel_name)

df.head()

In [None]:
# ER rows: noise level and both PSNRs, ascending in Gaussian scale.
df.loc[df.Ch == 'ER', ['Gaus', 'PSNR_denoisplit', 'PSNR_hdn']].sort_values('Gaus')

In [None]:
# ER PSNR (denoiSplit vs HDN) at the 'CCPs vs ER' noise levels. A single combined mask
# avoids pandas' "Boolean Series key will be reindexed" warning from chained df[a][b];
# rows are sorted and plotted against Gaus so the x-axis is the noise level.
df[(df.Ch == 'ER') & df.Gaus.isin([3400, 5100, 6800, 13600])].sort_values('Gaus').plot(
    x='Gaus', y=['PSNR_denoisplit', 'PSNR_hdn'])

In [None]:
# ER PSNR (denoiSplit vs HDN) at the 'ER vs MT' noise levels. Ch was already shortened
# to 'ER' when df was built, so filtering on 'ER/GT_all.mrc' selected nothing and the
# plot failed. A single combined mask avoids the chained-indexing reindex warning.
df[(df.Ch == 'ER') & df.Gaus.isin([4450, 6675, 8900, 17800])].sort_values('Gaus').plot(
    x='Gaus', y=['PSNR_denoisplit', 'PSNR_hdn'])

In [None]:
# Microtubules rows: noise level and both PSNRs, ascending in Gaussian scale.
df.loc[df.Ch == 'Microtubules', ['Gaus', 'PSNR_denoisplit', 'PSNR_hdn']].sort_values('Gaus')

In [None]:
# Microtubules PSNR (denoiSplit vs HDN) at the 'ER vs MT' noise levels. A single combined
# mask avoids pandas' chained-indexing reindex warning; plotted against sorted Gaus.
df[(df.Ch == 'Microtubules') & df.Gaus.isin([4450, 6675, 8900, 17800])].sort_values('Gaus').plot(
    x='Gaus', y=['PSNR_denoisplit', 'PSNR_hdn'])

In [None]:
# Microtubules PSNR (denoiSplit vs HDN) at the 'CCPs vs MT' noise levels. A single combined
# mask avoids pandas' chained-indexing reindex warning; plotted against sorted Gaus.
df[(df.Ch == 'Microtubules') & df.Gaus.isin([3150, 4725, 6300, 12600])].sort_values('Gaus').plot(
    x='Gaus', y=['PSNR_denoisplit', 'PSNR_hdn'])

In [None]:
# CCPs rows: noise level and both PSNRs, ascending in Gaussian scale.
df.loc[df.Ch == 'CCPs', ['Gaus', 'PSNR_denoisplit', 'PSNR_hdn']].sort_values('Gaus')

In [None]:
# CCPs PSNR (denoiSplit vs HDN) at the 'CCPs vs MT' noise levels. A single combined
# mask avoids pandas' chained-indexing reindex warning; plotted against sorted Gaus.
df[(df.Ch == 'CCPs') & df.Gaus.isin([3150, 4725, 6300, 12600])].sort_values('Gaus').plot(
    x='Gaus', y=['PSNR_denoisplit', 'PSNR_hdn'])

In [None]:
# CCPs PSNR (denoiSplit vs HDN) at the 'CCPs vs ER' noise levels. A single combined
# mask avoids pandas' chained-indexing reindex warning; plotted against sorted Gaus.
df[(df.Ch == 'CCPs') & df.Gaus.isin([3400, 5100, 6800, 13600])].sort_values('Gaus').plot(
    x='Gaus', y=['PSNR_denoisplit', 'PSNR_hdn'], linestyle='-', marker='o')

In [None]:
# All merged rows for the ER channel.
df.loc[df['Ch'] == 'ER']

In [None]:
import matplotlib.pyplot as plt
# Paper figure: PSNR vs. Gaussian noise sigma for HDN and denoiSplit, one panel per channel.
params = {'mathtext.default': 'regular'}
plt.rcParams.update(params)

# Gaussian sigmas belonging to each denoiSplit channel pair; a pair's rows in `df`
# are selected by these levels.
pair_gaus_levels = {
    'ER vs MT': [4450, 6675, 8900, 17800],
    'CCPs vs ER': [3400, 5100, 6800, 13600],
    'CCPs vs MT': [3150, 4725, 6300, 12600],
}
# (channel, denoiSplit pairs) per panel. The plotting order determines line colours,
# so it follows the original hand-written sequence (HDN first, then the pairs).
panels = [
    ('ER', ['ER vs MT', 'CCPs vs ER']),
    ('Microtubules', ['ER vs MT', 'CCPs vs MT']),
    ('CCPs', ['CCPs vs MT', 'CCPs vs ER']),
]

_, ax = plt.subplots(figsize=(12, 3), ncols=3)
for axis, (channel, pairs) in zip(ax, panels):
    # Filter and sort once. Sub-selecting from this frame keeps the masks aligned
    # (no chained df[a][b] "reindexed" warning), and every curve, including the
    # denoiSplit ones, is drawn in ascending sigma instead of dictionary order.
    channel_df = df[df.Ch == channel].sort_values('Gaus')
    channel_df.plot(x='Gaus', y='PSNR_hdn', ax=axis, linestyle='-', marker='*', label='HDN')
    for pair in pairs:
        pair_df = channel_df[channel_df.Gaus.isin(pair_gaus_levels[pair])]
        pair_df.plot(x='Gaus', y='PSNR_denoisplit', ax=axis, linestyle='-', marker='^', label=pair)

    axis.set_title(channel)
    # Raw string: '\ ' and '\s' are invalid escape sequences in a regular string literal.
    axis.set_xlabel(r'$Gaussian\ \sigma$')
    axis.set_ylim(24, 44.7)
    axis.yaxis.grid(color='gray', linestyle='dashed')
    axis.xaxis.grid(color='gray', linestyle='dashed')
    axis.set_facecolor('xkcd:light grey')

ax[2].legend(loc='upper right')
ax[0].set_ylabel('PSNR')
# All panels use the same y-range, so only the left panel keeps its y tick labels.
ax[1].set_yticklabels([])
ax[2].set_yticklabels([])

paper_figures_dir = '/group/jug/ashesh/data/paper_figures'
fpath = os.path.join(paper_figures_dir, 'hdn_denoisplit_comparison.png')
plt.savefig(fpath, dpi=200, bbox_inches='tight')
print('Saved to:', fpath)
