In [None]:
!pip install -U pandas matplotlib

In [None]:
import plotly
import plotly.express as px

In [None]:
%reload_ext autoreload
%autoreload 2

import h5py
import json
import numpy as np
import pandas as pd
import scipy.stats
from pathlib import Path
from itertools import product

import sklearn.model_selection
import tensorflow as tf

from tqdm.auto import tqdm
tqdm.get_lock().locks = []

from IPython.display import Image, display, HTML, Math, Latex
import ipywidgets as widgets

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.ticker import FormatStrFormatter
import matplotlib
import seaborn as sns
import altair as alt
import dataframe_image as dfi

import plotly
import plotly.express as px
import plotly.offline as ply
import plotly.graph_objs as plygo
import cufflinks as cf

# Machine-specific path to the orca binary for plotly static image export.
# NOTE(review): hardcoded home-directory path; confirm which export engine fig.write_image uses here.
plotly.io.orca.config.executable = '/home/kiran/.local/bin/orca'
# Offline notebook mode for plotly (connected=False: do not load plotly.js from a CDN).
ply.init_notebook_mode(connected=False)
# cufflinks plotting defaults: offline, not world-readable, white theme.
cf.set_config_file(offline=True, world_readable=False, theme='white')

In [None]:
import plotly.express as px

In [None]:
from datasets.hsi_dataset import HSIDataset
from sklearn.metrics import precision_recall_fscore_support

In [None]:
# Dataset name (as stored in each result file's `dataset_name` attribute) ->
# source GeoTIFF used to rebuild the HSIDataset and its class labels.
datasets = dict(
    Suburban='/storage/kiran/data/suburban/20170820_Urban_Ref_Reg_Subset.tif',
    Urban='/storage/kiran/data/urban/20170820_Urban2_INT_Final.tif',
    Forest='/storage/kiran/data/forest/20170820_Forest_Final_INT.tif',
)

# Collect results from H5 files

In [None]:
# Collect per-run metrics and row-normalised confusion matrices from every result file.
# Produces `df_metrics` (one row per run x class, plus one "average_weighted" row per run)
# and `confusion_matrices` keyed by (dataset_name, input_type, compression_class, compression_rate).
h5_files = sorted(Path('/storage/kiran/results/data/').glob('*.h5'))
metrics = []
confusion_matrices = {}
# Class-label maps depend only on the dataset, so build each HSIDataset once
# instead of once per result file.
labels_by_dataset = {}
progress_label = widgets.Label(value="Not started...")
display(progress_label)
for path in tqdm(h5_files):
    progress_label.value = f"{path}"
    # The context manager closes the HDF5 handle once the arrays are read.
    with h5py.File(path, 'r') as h5_file:
        attrs = dict(h5_file.attrs.items())
        predicted = h5_file['predictions'][()]
        targets = h5_file['targets'][()]

    dataset_name = attrs['dataset_name']
    input_type = attrs['input_type']
    compression_class = attrs['compression_class']

    n_components = int(attrs['n_components'])
    # round() before int(): plain truncation turns e.g. 0.29 * 100 == 28.999999999999996 into 28.
    # Re-run the "Write" lines of the pickle-cache cell so the cached data matches.
    compression_rate = int(round(attrs['compression_rate'] * 100))
    reconstruction_loss = attrs['reconstruction_loss']

    execution_times = json.loads(attrs['execution_times'])

    progress_label.value = f"{path} -- {dataset_name} ; {compression_class} ; {input_type} ; {attrs['compression_rate']}"

    if dataset_name not in labels_by_dataset:
        dataset = HSIDataset(datasets[dataset_name], dataset_name)
        labels, _ = dataset.trainingset
        # 'undefined' is excluded from every metric and confusion matrix below.
        labels.pop('undefined', None)
        labels_by_dataset[dataset_name] = labels
    labels = labels_by_dataset[dataset_name]
    label_names = list(labels.keys())
    label_ids = list(labels.values())

    run_fields = [dataset_name, input_type, compression_class, compression_rate, n_components]

    # -------------------------------------------------------------------------------- #
    # Support-weighted average over the classes in label_ids.
    precision, recall, fbeta_score, support = precision_recall_fscore_support(
        y_true=targets, y_pred=predicted, labels=label_ids, average='weighted')
    metrics.append(run_fields + ["average_weighted", precision, recall, fbeta_score, support,
                                 reconstruction_loss, execution_times])

    # -------------------------------------------------------------------------------- #
    # One record per class, in label_names order.
    precision, recall, fbeta_score, support = precision_recall_fscore_support(
        y_true=targets, y_pred=predicted, labels=label_ids, average=None)
    for idx, label in enumerate(label_names):
        metrics.append(run_fields + [label, precision[idx], recall[idx], fbeta_score[idx], support[idx],
                                     reconstruction_loss, execution_times])

    # -------------------------------------------------------------------------------- #
    # Confusion matrix with each row normalised to sum to 1.
    cm = sklearn.metrics.confusion_matrix(y_true=targets, y_pred=predicted, labels=label_ids)
    cm = pd.DataFrame(cm, index=label_names, columns=label_names)
    confusion_matrices[(dataset_name, input_type, compression_class, compression_rate)] = cm.div(cm.sum(axis=1), axis=0)

del h5_files

df_metrics = pd.DataFrame(metrics, columns=['dataset_name', 'input_type', 'compression_class',
                                            'compression_rate', 'n_components',
                                            'label', 'precision', 'recall', 'f1', 'support',
                                            'reconstruction_loss', 'execution_times'
                                            ]
                          )

In [None]:
# Pickle cache of the collection step. Uncomment "Write" to refresh it after re-collecting;
# "Read" REPLACES the in-memory df_metrics / confusion_matrices with the cached versions.
import pickle

# Write
# df_metrics.to_pickle('/storage/kiran/results/df_metrics.pickle')
# with open("/storage/kiran/results/confusion_matrices.pickle", "wb") as fh:
#     pickle.dump(confusion_matrices, fh)

# Read -- only for trusted, locally produced files: unpickling can execute arbitrary code.
df_metrics = pd.read_pickle('/storage/kiran/results/df_metrics.pickle')
with open("/storage/kiran/results/confusion_matrices.pickle", "rb") as fh:
    confusion_matrices = pickle.load(fh)

display(confusion_matrices[('Forest', 'HSI', 'AE', 96)])
# Number of records per run configuration (each key listed once).
display(df_metrics.groupby(['dataset_name', 'input_type', 'compression_class', 'label', 'compression_rate']).count())

In [None]:
# Canonical row order for df_metrics, plus the dataset / input-type /
# compression-method lists that later cells iterate over.
df_metrics = (
    df_metrics
    .sort_values(['dataset_name', 'compression_class', 'label', 'compression_rate'])
    .reset_index(drop=True)
)

datasets = ['Suburban', 'Urban', 'Forest']
input_types = ['HSI', 'HSI_SG']
compression_classes = ["RGB", "PCA", "KPCA", "ICA", "AE", "DAE"]
categories = sorted(df_metrics['label'].unique())

print(df_metrics.shape)
display(df_metrics.sample(5))

In [None]:
# Overall (weighted-average) F1 per compression method on HSI input, with the
# uncompressed RGB and HSI runs shown as baselines.
# .copy(): the frame is modified below, and df_metrics[...] would otherwise be a slice.
df = df_metrics[(df_metrics.label == 'average_weighted') & (df_metrics.input_type != 'HSI_SG')].copy()
cr = 0
# Fold the RGB baseline into the HSI panel as its own "compression class".
is_rgb = df.input_type == 'RGB'
df.loc[is_rgb, 'compression_class'] = 'RGB'
df.loc[is_rgb, 'compression_rate'] = cr
df.loc[is_rgb, 'input_type'] = 'HSI'
# Uncompressed HSI runs carry compression_class 'NA'.
df.loc[df.compression_class == 'NA', 'compression_class'] = 'HSI'
df = df.rename({'f1': 'f1-score'}, axis=1)

# Compressed runs with 0 < rate < 97, plus the RGB/HSI baselines. Equality, not string
# ordering: `compression_class >= 'HSI'` would also keep PCA/KPCA/ICA at every rate.
in_range = (df.compression_rate > 0) & (df.compression_rate < 97)
df = df[in_range | df.compression_class.isin(['RGB', 'HSI'])]

df = (
    df.drop(columns=['support', 'reconstruction_loss', 'execution_times', 'n_components',
                     'input_type', 'precision', 'recall'])
    # Best run per configuration (AE/DAE were trained several times).
    .groupby(['dataset_name', 'compression_class', 'compression_rate', 'label']).max()
    .reset_index()
    .drop(columns=['dataset_name', 'compression_rate', 'label'])
    .melt(id_vars=['compression_class'])
    .reset_index(drop=True)
)

fig = px.box(data_frame=df,
             x='variable',
             y='value',
             color='compression_class',
             template='plotly_white',
             points='suspectedoutliers',
             range_y=[0.8, 1],
             boxmode='group',
             notched=False,
             category_orders={"compression_class": ["RGB", "HSI", "PCA", "KPCA", "ICA", "AE", "DAE"]},
             )

fig.update_layout(legend_title="", font=dict(size=30, color="Black", family='Times New Roman'))
fig.update_yaxes(title_standoff=0, title_font=dict(size=30, family='Times New Roman'))
fig.update_xaxes(title_standoff=0, title_font=dict(size=30, family='Times New Roman'))
fig.layout.xaxis.title.text = ""
fig.layout.yaxis.title.text = ""
fig.update_layout(boxgroupgap=0.2, boxgap=0.2)
fig.update_layout(
    legend=dict(
        x=1,
        y=1,
        title="",
        font=dict(family="Times New Roman", size=20, color="black"),
        bgcolor='rgba(255, 255, 255, 0)',
        borderwidth=0,
        orientation='v',
    )
)

filename = '/storage/kiran/results/charts/overall_scores_boxplot.png'
print(f"Saving: {filename}")
fig.write_image(filename, scale=1, width=1200, height=500)
Image(filename)

# Reconstruction Error (MSE)

In [None]:
# One figure per (input type, compression method): reconstruction MSE vs. compression
# rate, one line per dataset. Figures are only shown unless SAVE_MSE_FIGURES is True.
SAVE_MSE_FIGURES = False

filt = (df_metrics.compression_rate > 0) & (df_metrics.compression_rate < 100)
filt &= (df_metrics.label == 'average_weighted')
# .drop returns a new frame instead of mutating a slice of df_metrics in place.
df = df_metrics[filt].drop(['n_components', 'precision', 'recall', 'f1', 'support'], axis=1)
compression_classes = ["PCA", "KPCA", "ICA", "AE", "DAE"]

sns.set(context='paper', font="Times New Roman", style="whitegrid")

pbar = tqdm(input_types)
for input_type in pbar:
    for compression_class in compression_classes:
        fig, ax = plt.subplots(figsize=(8, 5))

        for dataset in datasets:
            pbar.set_description(f"{dataset} {input_type} {compression_class} ")
            chart_data = df[(df.input_type == input_type) & (df.compression_class == compression_class) & (df.dataset_name == dataset)]
            # `marker` goes straight to matplotlib; seaborn's `markers`/`dashes` only apply
            # to levels of a `style=` variable, which is not used here.
            sns.lineplot(ax=ax, data=chart_data, x='compression_rate', y='reconstruction_loss',
                         label=dataset, marker='o', linewidth=3)

        ax.set_title(f"${input_type.replace('SG', '{SG}')}$", fontdict={'fontsize': 24, 'fontfamily': 'Times New Roman'})
        # `nonpositive` replaces the `nonposy` keyword, which newer matplotlib no longer accepts.
        ax.set_yscale("log", nonpositive='clip')
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{x:.1e}'))
        ax.tick_params(axis='both', labelsize=28)
        ax.set_xlabel("compression rate (%)", fontsize=24, fontfamily="Times New Roman")
        ax.set_ylabel(f"MSE - {compression_class}", fontsize=24, fontfamily="Times New Roman")
        ax.legend(ncol=1, loc='best', prop={'family': 'Times New Roman', 'size': 20})
        ax.grid(True)

        filename = f'/storage/kiran/results/charts/mse_reconstruction_{compression_class}_{input_type}.pdf'
        fig.tight_layout()
        if SAVE_MSE_FIGURES:
            print(f"Saving: {filename}")
            fig.savefig(filename, bbox_inches='tight', dpi=300, transparent=True)
        plt.show()
        plt.close(fig)

# Variability of AE and DAE reconstruction error across several trained models

In [None]:
# Spread of reconstruction MSE across the repeated AE / DAE trainings, per dataset (HSI input).
filt = df_metrics.compression_class.isin(['AE', 'DAE'])
filt &= (df_metrics.compression_rate > 0) & (df_metrics.compression_rate < 100)
filt &= (df_metrics.input_type == 'HSI')
# NOTE(review): per-class rows are kept, so each run's reconstruction_loss appears once per
# class in the box plot; label == 'average_weighted' would give exactly one row per run.
filt &= (df_metrics.label != 'average_weighted')

# .drop returns a new frame instead of mutating a slice of df_metrics in place.
df = df_metrics[filt].drop(['n_components', 'support', 'precision', 'recall', 'f1', 'execution_times', 'label'], axis=1)
compression_classes = sorted(df.compression_class.unique().tolist())
categories = sorted(df_metrics.label.unique().tolist())
display(df.groupby(['dataset_name', 'input_type', 'compression_class', 'compression_rate']).min())

fig = px.box(data_frame=df,
             x='dataset_name',
             y='reconstruction_loss',
             color='compression_class',
             category_orders={'dataset_name': ['Suburban', 'Urban', 'Forest'],
                              "compression_class": ["RGB", "PCA", "ICA", "KPCA", "AE", "DAE"]
                              },
             template='plotly_white',
             orientation='v',
             )

# Strip plotly's "column=" prefixes from facet/legend annotations.
for prefix in ("input_type=", "label=", "compression_class=", "dataset_name="):
    fig.for_each_annotation(lambda a, p=prefix: a.update(text=a.text.replace(p, "")))
fig.update_layout(legend_title="", font=dict(size=24, color="Black", family='Times New Roman'))
fig.layout.xaxis.title.text = ""
fig.layout.yaxis.title.text = "mse"
fig.update_yaxes(nticks=6, title_standoff=0, title_font=dict(size=24, family='Times New Roman'))
fig.update_xaxes(nticks=10, title_standoff=0, title_font=dict(size=24, family='Times New Roman'))

fig.update_layout(
    legend=dict(
        x=0,
        y=1,
        title_font_family="Times New Roman",
        font=dict(family="Times New Roman", size=18, color="black"),
        bgcolor='rgba(255, 255, 255, 0)',
        borderwidth=0,
        orientation='v',
    )
)

filename = '/storage/kiran/results/charts/AE_DAE_variability.png'
print(f"Saving: {filename}")
fig.write_image(filename, scale=1.5, width=800, height=400)
Image(filename)

# F1 score vs. compression rate (per class)

In [None]:
# Per-class F1 vs. compression rate: one panel per (compression method, dataset).
filt = (df_metrics.compression_rate > 0) & (df_metrics.compression_rate < 100)
filt &= (df_metrics.label != 'average_weighted')
# .drop returns a new frame instead of mutating a slice of df_metrics in place.
df = df_metrics[filt].drop(['n_components', 'support'], axis=1)

compression_classes = ["PCA", "KPCA", "ICA", "AE", "DAE"]
categories = sorted(df_metrics.label.unique().tolist())
# Successive classes in a panel get these styles in order (cycled if there are more than four).
line_styles = ['-', '--', '-.', ':']

# Deliberately NOT `input_types`: rebinding that shared list to ['HSI'] here made
# later cells (e.g. the heatmaps) silently skip HSI_SG on a top-to-bottom run.
f1_input_types = ['HSI']
for input_type in f1_input_types:
    sns.set(context='paper', font="Times New Roman", style="whitegrid")
    fig, axs = plt.subplots(nrows=len(compression_classes), ncols=len(datasets), figsize=(15, 15),
                            sharex=True, sharey=True, squeeze=False)

    for col, dataset in enumerate(tqdm(datasets)):
        for row, compression_class in enumerate(compression_classes):
            ax = axs[row, col]
            n_plotted = 0
            for category in categories:
                chart_data = df[(df.dataset_name == dataset) & (df.compression_class == compression_class)
                                & (df.label == category) & (df.input_type == input_type)]
                # Best F1 over repeated runs at each compression rate.
                chart_data = (chart_data
                              .groupby(['dataset_name', 'input_type', 'compression_class', 'compression_rate'])['f1']
                              .max()
                              .reset_index()
                              .sort_values(['compression_rate']))
                if chart_data.empty:
                    continue

                sns.lineplot(ax=ax, data=chart_data, x='compression_rate', y='f1', label=category, linewidth=2)
                ax.lines[-1].set_linestyle(line_styles[n_plotted % len(line_styles)])
                n_plotted += 1

            ax.set_title(f"{dataset}, {compression_class}", fontsize=24, fontfamily="Times New Roman")
            if n_plotted:
                legend = ax.legend(ncol=2, prop={'size': 16}, framealpha=1)
                legend.get_frame().set_facecolor('white')
                # Only the top row shows its legend.
                legend.set_visible(row == 0)

    for ax in axs.ravel():
        ax.grid(True)
        ax.tick_params(axis='both', labelsize=24)
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{x:.2f}'))
        ax.set_xlabel("compression rate (%)", fontsize=24, fontfamily="Times New Roman")
        ax.set_ylim(0.8, 1)
        ax.set_xlim(0, 99)
        ax.set_ylabel("f1 score", fontsize=24, fontfamily="Times New Roman")

    fig.tight_layout()
    filename = f'/storage/kiran/results/charts/f1_score_curves_{input_type}.png'
    print(f"Saving: {filename}")
    # `bbox_inches` is the savefig option; `bbox` is not a savefig parameter.
    fig.savefig(filename, bbox_inches='tight', dpi=300, transparent=True)
    plt.show()
    plt.close(fig)

# Confusion matrices for compressed inputs

In [None]:
# Save a confusion-matrix heatmap for each compressed configuration found in `confusion_matrices`.
datasets = ['Forest', 'Suburban', 'Urban']
compression_classes = ["PCA", "KPCA", "ICA", "AE", "DAE"]
compression_rates = [0, 10, 50, 90]
cm_input_types = ['HSI', 'HSI_SG']

items = list(product(datasets, cm_input_types, compression_classes, compression_rates))

sns.set(style="whitegrid", font_scale=1.3, context="paper")
for dataset, input_type, compression_class, compression_rate in tqdm(items):
    key = (dataset, input_type, compression_class, compression_rate)
    confusion_matrix = confusion_matrices.get(key)
    if confusion_matrix is None:
        print(f"Confusion Matrix not found: {key}")
        continue
    # Abbreviate class names to their first letter on a copy, leaving the shared
    # `confusion_matrices` entry untouched.
    confusion_matrix = confusion_matrix.copy()
    confusion_matrix.columns = [f"{c[0]}" for c in confusion_matrix.columns]
    confusion_matrix.index = [f"{c[0]}" for c in confusion_matrix.index]

    fig, ax = plt.subplots(figsize=(4, 4))
    sns.heatmap(data=confusion_matrix,
                annot=True,
                cmap='Pastel2_r',
                linewidths=0.1,
                linecolor='white',
                cbar=False,
                fmt='.3f',
                ax=ax,
                annot_kws={'size': 14, 'weight': 'medium'}
                )
    ax.set_title(f"{dataset}, {input_type}", fontsize=18)
    # Style tick labels before layout/saving so the written PNG includes them.
    ax.tick_params(axis='x', labelsize=20, labelrotation=0)
    ax.tick_params(axis='y', labelsize=20, labelrotation=0)
    fig.tight_layout()

    # input_type is part of the name so HSI and HSI_SG matrices do not overwrite each other.
    filename = f'/storage/kiran/results/charts/confusion_matrix_{dataset}_{input_type}_{compression_class}_{compression_rate}.png'
    print(f"Saving: {filename}")
    fig.savefig(filename, bbox_inches='tight', dpi=300, transparent=True)
    plt.close(fig)

# Confusion Matrices for RGB, HSI, HSI_SG (non-compressed)

In [None]:
# Confusion-matrix heatmaps for the uncompressed baselines (compression_class 'NA', rate 0).
datasets = ['Forest', 'Suburban', 'Urban']
compression_classes = ['NA']
compression_rates = [0]
# Add 'HSI_SG' to also render that input type.
baseline_input_types = ['RGB', 'HSI']

items = list(product(datasets, baseline_input_types, compression_classes, compression_rates))

sns.set(style="whitegrid", font_scale=1.6, context="paper", font="Times New Roman")
for dataset, input_type, compression_class, compression_rate in tqdm(items):
    key = (dataset, input_type, compression_class, compression_rate)
    confusion_matrix = confusion_matrices.get(key)
    if confusion_matrix is None:
        print(f"Confusion Matrix not found: {key}")
        continue
    display(confusion_matrix)
    # Abbreviate class names to their first letter on a copy, leaving the shared
    # `confusion_matrices` entry untouched.
    confusion_matrix = confusion_matrix.copy()
    confusion_matrix.columns = [f"{c[0]}" for c in confusion_matrix.columns]
    confusion_matrix.index = [f"{c[0]}" for c in confusion_matrix.index]

    fig, ax = plt.subplots(figsize=(4, 4))
    sns.heatmap(data=confusion_matrix, annot=True, cmap='Pastel2_r', linewidths=0.1, linecolor='white',
                cbar=False, fmt='.3f', ax=ax, annot_kws={'size': 20, 'weight': 'medium'})
    ax.set_title(f"{dataset}, {input_type}", fontsize=18, fontfamily="Times New Roman")
    # Size tick labels before tight_layout so the layout accounts for them.
    ax.tick_params(axis='x', labelsize=20, labelrotation=0)
    ax.tick_params(axis='y', labelsize=20, labelrotation=0)
    fig.tight_layout()

    filename = f'/storage/kiran/results/charts/confusion_matrix_baseline_{dataset}_{input_type}.png'
    print(f"Saving: {filename}")
    fig.savefig(filename, bbox_inches='tight', dpi=300, transparent=True)
    plt.show()
    plt.close(fig)

# Heatmaps of precision, recall and F1 vs. compression rate (consolidated)

In [None]:
# For each (dataset, input type): stacked heatmaps of weighted precision / recall / F1
# (rows = compression method, columns = compression rate) sharing one colour bar.
df = df_metrics.drop(['reconstruction_loss', 'support', 'n_components'], axis=1)
filt = (df.compression_rate > 0) & (df.compression_rate < 100)
filt &= (df.label == 'average_weighted')
df = df[filt]

heatmap_metrics = ['precision', 'recall', 'f1']
heatmap_classes = ["PCA", "KPCA", "ICA", "AE", "DAE"]
sns.set(style='whitegrid', font_scale=1.6)

for dataset_name in datasets:
    for input_type in input_types:
        subset = df[(df.dataset_name == dataset_name) & (df.input_type == input_type)]
        # Best score over repeated runs per configuration (shared by all three metrics).
        best = (subset
                .groupby(['dataset_name', 'input_type', 'compression_class', 'compression_rate', 'label'])
                .agg({'precision': 'max', 'recall': 'max', 'f1': 'max'})
                .reset_index())

        # One heatmap row per metric.
        fig, axs = plt.subplots(nrows=len(heatmap_metrics), ncols=1, figsize=(16, 10), sharex=True, sharey=False)
        # Shared colour bar; [left, bottom, width, height] in figure fractions.
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])

        for i, metric in enumerate(heatmap_metrics):
            data = pd.pivot_table(best, index='compression_rate', columns='compression_class', values=metric)
            data = data[heatmap_classes].T
            data.columns = np.around(data.columns.ravel()).astype(int)

            sns.heatmap(data=data, ax=axs[i], vmin=0.90, vmax=1, cbar_ax=(None if i else cbar_ax), cbar=(i == 0),
                        linewidths=0.1, linecolor='lightgray', cmap='jet_r')
            axs[i].set_ylabel(metric, fontsize=26)
            axs[i].tick_params(axis='x', labelrotation=0, labelsize=20, which='major')
            axs[i].tick_params(axis='y', labelrotation=0, labelsize=20, which='major')

            # Keep every third x tick label to avoid overlap.
            xticklabels = list(axs[i].get_xticklabels())
            for idx, tick_label in enumerate(xticklabels):
                if idx % 3 != 0:
                    tick_label.set_text('')
            axs[i].set_xticklabels(xticklabels)

        axs[-1].set_xlabel("compression rate (%)", fontsize=24)
        fig.suptitle(f"{dataset_name}, {input_type}", y=0.95, fontsize=24)
        filename = f'/storage/kiran/results/charts/heatmap_{dataset_name}_{input_type}.png'
        print(f"Saving: {filename}")
        fig.savefig(filename, bbox_inches='tight', dpi=300, transparent=True)
        plt.show()
        plt.close(fig)

# Top classification scores at 95% compression (with RGB baseline)

In [None]:
# Per-class precision / recall / F1 at compression rate `cr`, with RGB as the baseline row;
# exported per dataset as PNG, Excel and LaTeX tables.
cr = 95
filt = ((df_metrics.compression_rate == cr) | (df_metrics.input_type == 'RGB'))
filt &= (df_metrics.label != 'average_weighted')
# .copy(): the frame is modified below, and df_metrics.loc[...] would otherwise be a slice.
df = df_metrics.loc[filt, ['dataset_name',
                           'label',
                           'input_type',
                           'compression_class',
                           'compression_rate',
                           'precision',
                           'recall',
                           'f1',
                           'reconstruction_loss'
                           ]].copy()
df['f1-score'] = df.f1
# Fold the RGB runs into the HSI table as their own "compression" row.
is_rgb = df.input_type == 'RGB'
df.loc[is_rgb, 'compression_class'] = 'RGB'
df.loc[is_rgb, 'compression_rate'] = cr
df.loc[is_rgb, 'input_type'] = 'HSI'

# ----------------------------------------------------

def color_max(x):
    """Styler callback for ``Styler.apply(axis=None)``.

    Within each class label, the best precision / recall / f1-score (idxmax) is
    shown bold green and the worst (idxmin) italic red; all other cells are gray.
    If best and worst coincide (a single row), red wins.

    Returns a DataFrame of CSS strings with the same index/columns as ``x``.
    """
    best_style = 'color: green; font-weight: bold'
    worst_style = 'color: red; font-style: italic'
    styles = pd.DataFrame('color: gray', columns=x.columns, index=x.index)
    score_cols = ['precision', 'recall', 'f1-score']
    for reducer, css in (('idxmax', best_style), ('idxmin', worst_style)):
        extremes = x.groupby(['label']).agg({col: reducer for col in score_cols})
        for col in score_cols:
            styles.loc[extremes[col], col] = css
    return styles


sorter = ["RGB", "PCA", "KPCA", "ICA", "AE", "DAE"]
sorter_index = {name: pos for pos, name in enumerate(sorter)}

for input_type in ['HSI']:
    for dataset_name in datasets:
        data = df[(df.dataset_name == dataset_name) & (df.input_type == input_type)]

        # Max only matters for AE/DAE (multiple runs); other methods have a single value at cr.
        table = (data
                 .groupby(['dataset_name', 'input_type', 'label', 'compression_class'])
                 .agg({'precision': 'max', 'recall': 'max', 'f1-score': 'max'})
                 .reset_index()
                 .drop(['dataset_name', 'input_type'], axis=1)
                 .rename({'compression_class': 'compression'}, axis=1))
        # Order methods by `sorter` within each label.
        table['order'] = table.compression.map(sorter_index)
        table = (table
                 .sort_values(['label', 'order'])
                 .set_index(['label', 'compression'])
                 .drop(columns='order'))
        styled_table = table.style.apply(color_max, axis=None)
        display(styled_table)

        filename = f'/storage/kiran/results/charts/top_classification_scores_{dataset_name}_{input_type}'
        label = f"table:top_classification_scores_{dataset_name}_{input_type}"
        _c = input_type.replace("HSI_SG", "$HSI_{SG}$")
        # "\\%" is a literal backslash-percent (escaped % for LaTeX); a bare "\%" is an invalid Python escape.
        caption = f"Top classification scores {dataset_name}, {_c}, compression rate={cr}\\%"

        print(f"Saving: {filename}")
        dfi.export(styled_table, f"{filename}.png", fontsize=30)
        display(Image(f"{filename}.png"))

        # .xlsx: pandas 2.0 removed support for writing legacy .xls files.
        table.to_excel(f"{filename}.xlsx", float_format="%.4f")
        table_latex = table.to_latex(buf=None,
                                     caption=caption,
                                     label=label,
                                     header=True,
                                     multicolumn=True,
                                     multirow=True,
                                     bold_rows=True,
                                     index=True,
                                     float_format="%.3f"
                                     )
        table_latex = table_latex.replace('begin{table}', 'begin{table}[H]')
        table_latex = table_latex.replace('centering', 'centering \\footnotesize')

        tex_path = Path(f"{filename}.tex")
        tex_path.write_text(table_latex)
        # Render what was actually written to disk.
        display(Latex(tex_path.read_text()))




# Precision vs. Recall

In [None]:
# Per-class precision vs. recall for every compressed run, coloured by compression rate;
# facets: class label (columns) x compression method (rows).
filt = (df_metrics.compression_rate > 0) & (df_metrics.compression_rate < 100)
filt &= (df_metrics.label != 'average_weighted')
# .drop returns a new frame instead of mutating a slice of df_metrics in place.
df = df_metrics[filt].drop(['f1', 'support', 'reconstruction_loss', 'execution_times'], axis=1)
display(df.head())

# Best precision / recall over repeated runs at each configuration.
df = (df
      .groupby(['dataset_name', 'input_type', 'compression_class', 'label', 'compression_rate'])
      .agg({'precision': 'max', 'recall': 'max'})
      .reset_index())

# NOTE(review): dataset_name and input_type are not mapped to any visual channel,
# so points from all datasets / input types share each facet — confirm intended.
fig = px.scatter(df,
                 y="precision",
                 x="recall",
                 color="compression_rate",
                 facet_col="label",
                 facet_row='compression_class',
                 category_orders={"compression_class": ["PCA", "KPCA", "ICA", "AE", "DAE"],
                                  "dataset_name": ["Suburban", "Urban", "Forest"],
                                  # One entry per class (a single comma-joined string matches no category).
                                  "label": ["Asphalt", "Rooftop", "Grass"]
                                  },
                 range_x=[0, 1],
                 range_y=[0, 1],
                 size_max=5,
                 opacity=0.8,
                 color_continuous_scale='Bluered_r',
                 )

fig.for_each_annotation(lambda a: a.update(text=a.text.replace("label=", "")))
fig.for_each_annotation(lambda a: a.update(text=a.text.replace("compression_class=", "")))
fig.for_each_trace(lambda t: t.update(name=t.name.replace("=", "")))
fig.update_layout(legend_title_text="", legend_title_font=dict(size=1), font=dict(size=18, color="Black", family='Times New Roman'))
fig.update_xaxes(nticks=3, title_standoff=0, title_font=dict(size=14, family='Times New Roman'))
fig.update_yaxes(nticks=3, title_standoff=0, title_font=dict(size=14, family='Times New Roman'), tickfont=dict(size=18))

# title.font instead of the deprecated `titlefont` alias.
fig.update_layout(coloraxis_colorbar=dict(title=dict(text="compression (%)", font=dict(size=14))))

filename = '/storage/kiran/results/charts/precision_vs_recall.png'
print(f"Saving: {filename}")
fig.write_image(filename, scale=1, width=1200, height=800)
Image(filename)