In [None]:
# Run the notebook as if it's in the PROJECT directory.
# Every relative path used below (fonts/, skopi/, confusion_matrix/, and the
# pdfs/ output directory derived from os.getcwd()) resolves against this directory.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
%bookmark PROJ_ROOT /reg/data/ana03/scratch/cwang31/spi
%cd -b PROJ_ROOT

In [None]:
import numpy as np
import pickle
import os
import torch
from deepprojection.utils import ConfusionMatrix

In [None]:
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import matplotlib.transforms as transforms
%matplotlib inline

In [None]:
def config_fonts(font_name = "Helvetica", font_size = 14, drc_root = '.'):
    """Register a bundled TrueType font with matplotlib and make it the default.

    The font file is read from ``<drc_root>/fonts/<font_name>/<font_name>.ttf``
    (with the defaults: ``./fonts/Helvetica/Helvetica.ttf``).

    Parameters
    ----------
    font_name : str
        Name of the font directory and ``.ttf`` file stem.
    font_size : int or float
        Default font size written to ``plt.rcParams['font.size']``.
    drc_root : str
        Directory that contains the ``fonts`` folder.
    """
    path_font = os.path.join(drc_root, "fonts", font_name, f"{font_name}.ttf")

    # Register the file first so matplotlib can resolve the family by name...
    font_manager.fontManager.addfont(path_font)
    prop_font = font_manager.FontProperties(fname = path_font)

    # Specify fonts for pyplot...
    plt.rcParams['font.family'] = prop_font.get_name()
    plt.rcParams['font.size']   = font_size

config_fonts()

### Check that no PDB entries are shared between the training and test sets

In [None]:
def read_pdb_list(path_dat):
    """Return the non-blank, whitespace-stripped lines of ``path_dat`` (one PDB entry per line).

    Blank lines are skipped. Otherwise a trailing newline in both files would
    yield ``''`` in each list and show up as a spurious shared entry.
    """
    with open(path_dat) as handle:
        return [ line.strip() for line in handle if line.strip() ]

pdb_train_list = read_pdb_list('skopi/h5s_mini.sq.train.dat')
pdb_test_list  = read_pdb_list('skopi/h5s_mini.sq.test.corrected.dat')

# PDB entries present in both splits; an empty set means no train/test leakage.
pdb_overlap = set(pdb_train_list).intersection(pdb_test_list)
pdb_overlap

### Load input and collect performance data for each scenario

In [None]:
# Container for scenario -> {pdb -> list of performance records}, filled when the pickles are loaded.
scenario_to_pdb_to_perf_dict = dict()

In [None]:
# Per-PDB performance records for each N-shot scenario, one pickle per scenario.
# These runs applied shot-to-shot fluctuation before noise.
# Earlier versions of this cell loaded the uncorrected
# `confusion_matrix.<run>.support_<N>.min.pickle` / `.mean_dist.pickle` files of the
# same run; change FL_PICKLE_TEMPLATE (and drc_pickle if needed) to reproduce them.
drc_pickle         = 'confusion_matrix'
FL_PICKLE_TEMPLATE = '2023_0101_0856_44.epoch_71.seed_0.support_{num_support}.mean_dist.corrected.pickle'
NUM_SUPPORT_LIST   = [1, 5, 10, 15, 20]

# Keys are zero-padded ('01-shot', ..., '20-shot') to match the color table below.
fl_pickle_dict = {
    f'{num_support:02d}-shot' : FL_PICKLE_TEMPLATE.format(num_support = num_support)
    for num_support in NUM_SUPPORT_LIST
}

def load_pickle(path_pickle):
    """Load and return the object stored in ``path_pickle``.

    NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
    """
    with open(path_pickle, 'rb') as handle:
        return pickle.load(handle)

# Rebuild from scratch so re-running this cell never keeps stale scenarios around.
scenario_to_pdb_to_perf_dict = {
    scenario : load_pickle(os.path.join(drc_pickle, fl_pickle))
    for scenario, fl_pickle in fl_pickle_dict.items()
}

In [None]:
# For every scenario, collect one accuracy curve and one F-1 curve per PDB entry.
# Each perf record is indexed as perf[0] (the x value plotted against below) and
# perf[1] (the raw confusion-matrix input).
# NOTE(review): get_metrics(1) presumably scores class 1 as positive, and indices
# 0 / 4 are taken as accuracy / F-1 — confirm against deepprojection.utils.ConfusionMatrix.
ACC_IDX = 0
F1_IDX  = 4

scenario_to_acc_dict = {}
scenario_to_f1_dict  = {}
photon_list = None    # shared x grid; checked to be identical for every PDB entry
for scenario, pdb_to_perf_dict in scenario_to_pdb_to_perf_dict.items():
    acc_list = []
    f1_list  = []
    for pdb, perf_list in pdb_to_perf_dict.items():
        pdb_photon_list = [ perf[0] for perf in perf_list ]

        # The plot averages curves across PDB entries on one x axis, so every
        # entry must use the same grid (previously the last one was used silently).
        if photon_list is None:
            photon_list = pdb_photon_list
        elif len(pdb_photon_list) != len(photon_list) or not np.allclose(pdb_photon_list, photon_list):
            raise ValueError(
                f"PDB entry {pdb!r} in scenario {scenario!r} uses a different x grid "
                "than the other entries; curves cannot be averaged."
            )

        metrics_list = [ ConfusionMatrix(perf[1]).get_metrics(1) for perf in perf_list ]
        acc_list.append([ metrics[ACC_IDX] for metrics in metrics_list ])
        f1_list.append ([ metrics[F1_IDX]  for metrics in metrics_list ])

    scenario_to_acc_dict[scenario] = acc_list
    scenario_to_f1_dict[scenario]  = f1_list

In [None]:
# Number of PDB entries (one accuracy curve each) evaluated in the 1-shot scenario.
num_pdb_01_shot = len(scenario_to_acc_dict['01-shot'])
num_pdb_01_shot

In [None]:
def rgb2hex(r, g, b):
    """Format 0-255 integer channels as a lowercase '#rrggbb' color string."""
    return '#%02x%02x%02x' % (r, g, b)

# One fixed color per N-shot scenario, so both figure panels use matching lines.
_scenario_rgb_pairs = (
    ('01-shot', ( 78, 129, 183)),
    ('05-shot', (244, 148,  69)),
    ('10-shot', ( 90, 164,  78)),
    ('15-shot', (195,  55,  43)),
    ('20-shot', (135,  97, 179)),
)
scenario_to_color_dict = { scenario : rgb2hex(*rgb) for scenario, rgb in _scenario_rgb_pairs }

In [None]:
def plot_metric_panel(ax, x_list, scenario_to_metric_dict, scenario_to_color_dict, ylabel,
                      show_legend = False, xlim = (-2.0, 2.0)):
    """Draw the PDB-averaged metric curve of every scenario on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x_list : x values shared by all curves.
    scenario_to_metric_dict : dict mapping a scenario name (e.g. '01-shot') to a
        list with one metric curve per PDB entry, each aligned with ``x_list``.
        ``None`` entries (undefined metric) are ignored in the average.
    scenario_to_color_dict : dict mapping each scenario name to a line color.
    ylabel : y-axis label.
    show_legend : draw a legend for this panel.
    xlim : x-axis limits.
    """
    for scenario, metric_list in scenario_to_metric_dict.items():
        # dtype=float turns None into NaN, so np.nanmean skips undefined values.
        metric_array = np.asarray(metric_list, dtype = float)

        # '01-shot' -> '1-shot' for a cleaner legend.
        label = scenario.lstrip('0')
        ax.plot(x_list, np.nanmean(metric_array, axis = 0), '-',
                color = scenario_to_color_dict[scenario], label = label)

    # Styling does not depend on the scenario, so apply it once.
    ax.set_ylabel(ylabel)
    ax.set_xlim(xlim)
    ax.set_yticks(np.arange(0.4, 1.0 + 0.1, 0.1))
    ax.grid(True, linestyle = '--')
    if show_legend:
        ax.legend()

nrows, ncols = 2, 1
h, w = 6.1, 8.5
fig = plt.figure(figsize = (w, h))

gspec = fig.add_gridspec( nrows, ncols,
                          width_ratios  = [1],
                          height_ratios = [1, 1],
                        )
ax_list = [ fig.add_subplot(gspec[i, 0]) for i in range(nrows) ]

# Upper panel - Accuracy (carries the legend; x tick labels hidden, the lower panel shows them)...
ax = ax_list[0]
plot_metric_panel(ax, photon_list, scenario_to_acc_dict, scenario_to_color_dict, 'Accuracy',
                  show_legend = True)
ax.tick_params(axis = 'x', labelbottom = False)

# Lower panel - F-1...
ax = ax_list[1]
plot_metric_panel(ax, photon_list, scenario_to_f1_dict, scenario_to_color_dict, 'F-1')
ax.set_xlabel('Scaling exponent');

In [None]:
# Define the filename...
filename = 'Figure.perf_to_scaling.mean_dist'

# Output directory (pdfs/ under the current working directory), created if missing.
# exist_ok avoids the check-then-create race of an explicit os.path.exists test.
DRCPDF         = "pdfs"
drc_cwd        = os.getcwd()
prefixpath_pdf = os.path.join(drc_cwd, DRCPDF)
os.makedirs(prefixpath_pdf, exist_ok = True)

# Specify file...
fl_pdf   = f"{filename}.pdf"
path_pdf = os.path.join(prefixpath_pdf, fl_pdf)

# Export the two-panel figure built above...
fig.savefig(path_pdf, dpi = 300)