In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')
sys.path.append('..')

# Maths
import numpy as np

# Plotting
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch, ConnectionPatch
from matplotlib.patches import Rectangle
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize, BoundaryNorm
from matplotlib.colors import ListedColormap

# ML
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.pipeline import Pipeline
from sklearn.linear_model import Ridge
from sklearn.compose import TransformedTargetRegressor
from sklearn.neighbors import KernelDensity
from errors import MAE
from kernels import sqeuclidean_distances

# Atoms
from ase.io import read
from soap import extract_species_pair_groups

# Utilities
import h5py
from tools import load_json
from tqdm.notebook import tqdm
import project_utils as utils

# Import COSMO style toolkit
import cosmoplot.colorbars as cosmocbars
import cosmoplot.utils as cosmoutils
import cosmoplot.style as cosmostyle

cosmostyle.set_style('article')
color_list = cosmostyle.color_cycle

# Analysis setup

In [3]:
# Load SOAP cutoffs
soap_hyperparameters = load_json('../../Processed_Data/soap_hyperparameters.json')   
cutoffs = soap_hyperparameters['interaction_cutoff']

In [4]:
# Load train and test indices
deem_train_idxs = np.loadtxt('../../Processed_Data/DEEM_330k/svm_train.idxs', dtype=int)
deem_test_idxs = np.loadtxt('../../Processed_Data/DEEM_330k/svm_test.idxs', dtype=int)

iza_train_idxs = np.loadtxt('../../Processed_Data/IZA_230/svm_train.idxs', dtype=int)
iza_test_idxs = np.loadtxt('../../Processed_Data/IZA_230/svm_test.idxs', dtype=int)

In [5]:
# Load indices of DEEM 10k set in 330k
idxs_deem_10k = np.loadtxt('../../Processed_Data/DEEM_330k/deem_10k.idxs', dtype=int)

In [6]:
# Load cantons for IZA and Deem
iza_cantons = np.loadtxt('../../Raw_Data/IZA_230/cantons.dat', usecols=1, dtype=int)
deem_cantons_2 = np.loadtxt('../../Processed_Data/DEEM_330k/Data/cantons_2-class.dat', dtype=int)
deem_cantons_4 = np.loadtxt('../../Processed_Data/DEEM_330k/Data/cantons_4-class.dat', dtype=int)
n_deem = len(deem_cantons_2)

In [7]:
# Build set of "master" canton labels for the train set
train_cantons = {}

train_cantons[4] = np.concatenate((
    iza_cantons[iza_train_idxs], 
    deem_cantons_4[deem_train_idxs]
))

train_cantons[2] = np.concatenate((
    np.ones(len(iza_train_idxs), dtype=int),
    deem_cantons_2[deem_train_idxs]
))

train_class_weights = {
    n_cantons: utils.balanced_class_weights(train_cantons[n_cantons]) for n_cantons in (2, 4)
}

# Build set of "master" canton labels for the test set
test_cantons = {}

test_cantons[4] = np.concatenate((
    iza_cantons[iza_test_idxs], 
    deem_cantons_4[deem_test_idxs]
))

test_cantons[2] = np.concatenate((
    np.ones(len(iza_test_idxs), dtype=int),
    deem_cantons_2[deem_test_idxs]
))

test_class_weights = {
    n_cantons: utils.balanced_class_weights(test_cantons[n_cantons]) for n_cantons in (2, 4)
}

In [8]:
# Load dummy Deem cantons to test the "null" case
dummy_cantons = {}
dummy_cantons[2] = np.loadtxt('../../Processed_Data/DEEM_330k/Data/dummy_cantons_2-class.dat', dtype=int)
dummy_cantons[4] = np.loadtxt('../../Processed_Data/DEEM_330k/Data/dummy_cantons_4-class.dat', dtype=int)

In [9]:
# Master set of dummy labels for the train set
dummy_train_cantons = {}
dummy_train_cantons[2] = dummy_cantons[2][deem_train_idxs]
dummy_train_cantons[4] = dummy_cantons[4][deem_train_idxs]

dummy_train_class_weights = {
    n_cantons: utils.balanced_class_weights(dummy_train_cantons[n_cantons]) for n_cantons in (2, 4)
}

# Master set of dummy labels for the test set
dummy_test_cantons = {}
dummy_test_cantons[2] = dummy_cantons[2][deem_test_idxs]
dummy_test_cantons[4] = dummy_cantons[4][deem_test_idxs]

dummy_test_class_weights = {
    n_cantons: utils.balanced_class_weights(dummy_test_cantons[n_cantons]) for n_cantons in (2, 4)
}

In [10]:
class_names = {
    2: ['IZA', 'DEEM'],
    4: ['IZA1', 'IZA2', 'IZA3', 'DEEM']

}

ticklabels = {
    2: ['IZA', 'DEEM'],
    4: ['IZA1', 'IZA2', 'IZA3', 'DEEM']
}

dummy_ticklabels = {
    2: ['DEEM1', 'DEEM2'],
    4: ['DEEM1', 'DEEM2', 'DEEM3', 'DEEM4']
}

In [11]:
# Linear model setup
n_species = 2
group_names = {'power': ['OO', 'OSi', 'SiSi', 
                         'OO+OSi', 'OO+SiSi', 'OSi+SiSi',
                         'OO+OSi+SiSi'], 
               'radial': ['O', 'Si', 'O+Si']}

In [12]:
deem_name = 'DEEM_330k'
iza_name = 'IZA_230'
deem_dir = f'../../Processed_Data/{deem_name}/Data'
iza_dir = f'../../Processed_Data/{iza_name}/Data'

# LR check

In [13]:
batch_size = 100000

In [None]:
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    
    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        
        # Load SOAPs
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_soaps = utils.load_hdf5(iza_file)
        
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        f = h5py.File(deem_file, 'r')
        deem_330k_dataset = f['0']
        
        train_soaps = np.vstack((iza_soaps[iza_train_idxs], deem_330k_dataset[deem_train_idxs]))
        
        n_features = train_soaps.shape[1]
        feature_groups = extract_species_pair_groups(
            n_features, n_species, 
            spectrum_type=spectrum_type,
            combinations=True
        )
        
        # Prepare batches for LR
        n_samples_330k = len(deem_test_idxs)
        n_batches = n_samples_330k // batch_size
        if n_samples_330k % batch_size > 0:
            n_batches += 1
        
        for species_pairing, feature_idxs in zip(
            tqdm(group_names[spectrum_type], desc='Species', leave=False),
            feature_groups
        ):
                        
            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):              
                
                # Load decision functions
                df_dir = f'LSVC/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                
                # Load decision functions
                iza_dfs = np.loadtxt(f'{iza_dir}/{cutoff}/{df_dir}/svc_structure_dfs.dat')
                deem_dfs = np.loadtxt(f'{deem_dir}/{cutoff}/{df_dir}/svc_structure_dfs.dat')
                
                train_dfs = np.concatenate((iza_dfs[iza_train_idxs], deem_dfs[deem_train_idxs]))
                test_dfs = np.concatenate((iza_dfs[iza_test_idxs], deem_dfs[deem_test_idxs]))
                     
                # Need to preprocess y (the decision functions) by hand since TransformedTargetRegressor
                # doesn't pass any fit_params to the transformer
                df_scaler = utils.StandardNormScaler(featurewise=True)
                scaled_train_dfs = df_scaler.fit_transform(
                    train_dfs, sample_weight=train_class_weights[n_cantons]
                )
                
                pipeline = Pipeline(
                    [
                        ('norm_scaler', utils.StandardNormScaler()),
                        ('ridge', Ridge(alpha=1.0E-10)),
                    ],
                )
                fit_params = {
                    'norm_scaler__sample_weight': train_class_weights[n_cantons], 
                    'ridge__sample_weight': train_class_weights[n_cantons],
                }
                pipeline.fit(train_soaps, scaled_train_dfs, **fit_params)
                
                predicted_iza_dfs = pipeline.predict(iza_soaps[iza_test_idxs])
                predicted_iza_dfs = df_scaler.inverse_transform(predicted_iza_dfs)
                
                predicted_deem_dfs = np.zeros(deem_dfs[deem_test_idxs].shape)       
                for i in tqdm(range(0, n_batches), desc='Batch', leave=False):
                    batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                    batch_idxs = deem_test_idxs[batch_slice]
                    
                    deem_330k_batch = deem_330k_dataset[batch_idxs]
                    predicted_deem_dfs[batch_slice] = pipeline.predict(deem_330k_batch)
                
                predicted_deem_dfs = df_scaler.inverse_transform(predicted_deem_dfs)
                predicted_dfs = np.concatenate((predicted_iza_dfs, predicted_deem_dfs))
                
                mae = mean_absolute_error(
                    test_dfs, predicted_dfs, 
                    sample_weight=test_class_weights[n_cantons],
                    multioutput='raw_values'
                )
                rmse = mean_squared_error(
                    test_dfs, predicted_dfs, 
                    sample_weight=test_class_weights[n_cantons], 
                    squared=False, multioutput='raw_values'
                )
                
                print(f'-----{n_cantons}-Class {spectrum_name} {species_pairing}-----')
                print(f'Test MAE = {mae}')
                print(f'Test RMSE = {rmse}')
                print()
        
        f.close()

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Power OO-----
Test MAE = [4.95218246e-05]
Test RMSE = [1.81994999e-07]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Power OO-----
Test MAE = [0.5839351  0.51671835 0.54157188 0.53575357]
Test RMSE = [0.70163627 0.47943542 0.57531916 0.62499404]



HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Power OSi-----
Test MAE = [5.28591177e-05]
Test RMSE = [2.29463063e-07]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Power OSi-----
Test MAE = [0.47891326 0.43398279 0.50963214 0.44010401]
Test RMSE = [0.55430154 0.34768403 0.42489182 0.47536034]



HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Power SiSi-----
Test MAE = [7.20133986e-05]
Test RMSE = [6.38545645e-07]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Power SiSi-----
Test MAE = [0.59779817 0.52095342 0.54290055 0.50965192]
Test RMSE = [0.91105665 0.7223323  0.92986195 0.56433875]



HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Power OO+OSi-----
Test MAE = [1.24052998e-05]
Test RMSE = [9.58247396e-09]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Power OO+OSi-----
Test MAE = [0.51710506 0.50197257 0.47432752 0.44738544]
Test RMSE = [0.57930212 0.4886385  0.38938215 0.46959039]



HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Power OO+SiSi-----
Test MAE = [2.4777398e-05]
Test RMSE = [6.09424421e-08]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Power OO+SiSi-----
Test MAE = [0.56100085 0.38218911 0.47916496 0.45169838]
Test RMSE = [0.56882029 0.2930045  0.44293592 0.40111341]



HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Power OSi+SiSi-----
Test MAE = [8.64377032e-06]
Test RMSE = [9.11085442e-09]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Power OSi+SiSi-----
Test MAE = [0.54650014 0.42719268 0.51282793 0.4911126 ]
Test RMSE = [0.57976929 0.348686   0.47916032 0.49042846]



HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Power OO+OSi+SiSi-----
Test MAE = [1.13314665e-06]
Test RMSE = [4.78613572e-12]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Power OO+OSi+SiSi-----
Test MAE = [0.53428721 0.4846022  0.52423521 0.47372373]
Test RMSE = [0.52152799 0.43955855 0.55443007 0.40467837]



HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Radial O-----
Test MAE = [7.04695339e-06]
Test RMSE = [3.14211573e-10]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Radial O-----
Test MAE = [0.62039702 0.32056319 0.39975763 0.53804778]
Test RMSE = [0.61065689 0.1968055  0.24524467 0.59880995]



HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Radial Si-----
Test MAE = [2.16314653e-07]
Test RMSE = [1.3912309e-13]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Radial Si-----
Test MAE = [0.51962695 0.28747358 0.41000526 0.83969791]
Test RMSE = [0.42255181 0.13844271 0.29631771 1.04239756]



HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Radial O+Si-----
Test MAE = [2.16532343e-06]
Test RMSE = [1.29452923e-11]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----4-Class Radial O+Si-----
Test MAE = [0.61501565 0.56879673 0.39724608 0.54736124]
Test RMSE = [0.6080658  0.51544423 0.31035967 0.56616107]



HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

-----2-Class Power OO-----
Test MAE = [2.35203474e-05]
Test RMSE = [5.07686514e-09]



HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

# Classification statistics

In [None]:
for cutoff in cutoffs:
    for spectrum_type in ('power', 'radial'):
        spectrum_name = spectrum_type.capitalize()
        
        for species_pairing in group_names[spectrum_type]:

            for n_cantons in (2, 4):
                
                # Prepare outputs
                data_dir = f'LSVC/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
            
                print(f'===== {cutoff} | {n_cantons}-Class | {spectrum_name} | {species_pairing} =====')

                predicted_cantons_iza = np.loadtxt(
                    f'{iza_dir}/{cutoff}/{data_dir}/svc_structure_cantons.dat', dtype=int
                )
                predicted_cantons_deem = np.loadtxt(
                    f'{deem_dir}/{cutoff}/{data_dir}/svc_structure_cantons.dat', dtype=int
                )

                predicted_cantons_train = np.concatenate((
                    predicted_cantons_iza[iza_train_idxs],
                    predicted_cantons_deem[deem_train_idxs]
                ))
                predicted_cantons_test = np.concatenate((
                    predicted_cantons_iza[iza_test_idxs],
                    predicted_cantons_deem[deem_test_idxs]
                ))

                # TODO: weighted confusion matrices
                matrix_train = confusion_matrix(cantons_train[n_cantons], predicted_cantons_train)
                matrix_test = confusion_matrix(cantons_test[n_cantons], predicted_cantons_test)

                print('----- Train -----')
                print(
                    classification_report(
                        cantons_train[n_cantons], predicted_cantons_train, zero_division=0
                    )
                )
                # TODO: do balanced accuracy
                print('Train accuracy:', np.sum(np.diag(matrix_train)) / np.sum(matrix_train))
                print()
                print(matrix_train)
                print()

                print('----- Test -----')
                print(
                    classification_report(
                        cantons_test[n_cantons], predicted_cantons_test, zero_division=0
                    )
                )
                print('Test accuracy:', np.sum(np.diag(matrix_test)) / np.sum(matrix_test))
                print()
                print(matrix_test)
                print()
                
                # Extract structure indices for the misclassified DEEM 330k
                key = f'{cutoff}-{spectrum_type}-{species_pairing}-{n_cantons}'
                misclassified_deem_330k[key] = {}
                for i in range(1, n_cantons):
                    misclassified_deem_330k[key][i] = np.nonzero(predicted_cantons_deem == i)[0]

# ROC curves

In [None]:
# ROC curves
# TODO: evaluate only on the test set
# TODO: 'micro' averaging probably isn't correct. 'macro' or 'weighted' is probably better
# TODO: roc curve with sample weights
fig = plt.figure(figsize=(7.0, 7.0))
axs = {}
axs[3.5] = fig.add_subplot(2, 2, 1)
axs[6.0] = fig.add_subplot(2, 2, 2)
legend_lines = {cutoff: [] for cutoff in cutoffs}

plot_parameters = dict(power={'linestyle': '-'}, radial={'linestyle': '--'})

for cutoff in cutoffs:
    model_dir = f'../../Processed_Data/Models/{cutoff}'
    axs[cutoff].set_aspect('equal')
    axs_inset = axs[cutoff].inset_axes([0.40, 0.10, 0.50, 0.50])
    roc_line_pairs = []
    max_auc = 0.0
    max_auc_idx = 0
    plot_idx = 0
    for spectrum_type in ('power', 'radial'):
        spectrum_name = spectrum_type.capitalize()
        plot_parameters['linestyle'] = ''
        for species_pairing in group_names[spectrum_type]:
            data_dir = f'LSVC/2-Class/{spectrum_name}/{species_pairing}'
            
            dfs_iza = np.loadtxt(f'{iza_dir}/{cutoff}/{data_dir}/svc_structure_dfs.dat')
            dfs_deem = np.loadtxt(f'{deem_dir}/{cutoff}/{data_dir}/svc_structure_dfs.dat')
            
            fpr, tpr, thresholds = roc_curve(
                cantons_test[2], 
                np.concatenate((dfs_iza[idxs_iza_test], dfs_deem)), 
                pos_label=2
            )
            auc = roc_auc_score(
                cantons_test[2],
                np.concatenate((dfs_iza[idxs_iza_test], dfs_deem)),
                average='micro'
            )
            if auc > max_auc:
                max_auc = auc
                max_auc_idx = plot_idx
                
            plot_idx += 1
            #print(f'{cutoff}-{spectrum_name}-{species_pairing}: {auc}')
            print(f'{cutoff}-{spectrum_name}-{species_pairing}: {-np.log10(1.0-auc)}')
            
            line_pair = []
            for ax in (axs[cutoff], axs_inset):
                line = ax.plot(fpr, tpr, **plot_parameters[spectrum_type], 
                        label=f'{spectrum_name} {species_pairing}')
                line_pair.extend(line)
            roc_line_pairs.append(line_pair)
    
    for ldx, line_pair in enumerate(roc_line_pairs):
        legend_lines[cutoff].append(Line2D(
            [0], [0], 
            label=line_pair[0].get_label(), 
            color=line_pair[0].get_color(),
            linestyle=line_pair[0].get_linestyle(),
            alpha=0.50
        ))

        if ldx == max_auc_idx:
            color_transparency = ''
            line_scale = 2
        else:
            color_transparency = '40' # Hex for alpha = 0.25
            line_scale = 1
            
        for line in line_pair:
            line.set_color(line.get_color() + color_transparency)
            line.set_linewidth(line.get_linewidth() * line_scale)
                            
    zoom_x = ([-0.02, 0.20])
    zoom_y = ([0.80, 1.02])
    axs_inset.set_xlim(zoom_x)
    axs_inset.set_ylim(zoom_y)    
                
    for ax in (axs[cutoff], axs_inset):        
        ax.axvline(0.0, color=color_list[11], linestyle=':')
        ax.axhline(1.0, color=color_list[11], linestyle=':')
        
    axs[cutoff].indicate_inset_zoom(axs_inset, label=None)

    #axs[cutoff].set_xlim([-0.01, 0.2])
    axs[cutoff].set_xlabel('False Positive Rate')
    
    axs[cutoff].text(0.95, 0.90, f'{cutoff} ' + u'\u00c5',
                     horizontalalignment='right', verticalalignment='top',
                     transform=axs[cutoff].transAxes)
    
    # TODO: save AUC so we can put them in a table

wspace=0.1
axs[3.5].set_ylabel('True Positive Rate')

axs[3.5].legend(handles=legend_lines[3.5], bbox_to_anchor=(0.0, 1.0, 2.0+wspace, 0.5), 
                loc='lower left', bbox_transform=axs[3.5].transAxes,
                ncol=3, mode='expand', borderaxespad=0.0)

axs[6.0].tick_params(axis='y', which='both', labelleft=False)

fig.subplots_adjust(wspace=wspace)

fig.savefig('../../Results/roc_svc.pdf', bbox_inches='tight')

plt.show()

In [None]:
# "Best" ROC curve with confusion matrix inset
# TODO: evaluate only on the test set
# TODO: 'micro' averaging probably isn't correct. 'macro' or 'weighted' is probably better
for cutoff in cutoffs:
    fig = plt.figure(figsize=(3.5, 3.5))
    axs = fig.add_subplot(1, 1, 1)
    
    model_dir = f'../../Processed_Data/Models/6.0'
    data_dir = f'LSVC/2-Class/Power/OO+OSi+SiSi'

    dfs_iza = np.loadtxt(f'{iza_dir}/{cutoff}/{data_dir}/svc_structure_dfs.dat')
    dfs_deem = np.loadtxt(f'{all_deem_dir}/{cutoff}/{data_dir}/svc_structure_dfs.dat')

    dfs_test = np.concatenate((
        dfs_iza[idxs_iza_test], 
        np.delete(dfs_deem, idxs_deem_10k[idxs_deem_train])
    ))
    
    predicted_cantons_iza = \
        np.loadtxt(f'{iza_dir}/{cutoff}/{data_dir}/svc_structure_cantons.dat', dtype=int)

    predicted_cantons_deem = \
        np.loadtxt(f'{all_deem_dir}/{cutoff}/{data_dir}/svc_structure_cantons.dat', dtype=int)
        
    predicted_cantons_test = np.concatenate((
        predicted_cantons_iza[idxs_iza_test], 
        np.delete(predicted_cantons_deem, idxs_deem_10k[idxs_deem_train])
    ))
    
    # ROC curve
    fpr, tpr, thresholds = roc_curve(
        cantons_test[2], 
        dfs_test,
        pos_label=2
    )

    auc = roc_auc_score(
        cantons_test[2],
        dfs_test,
        average='micro' # TODO: change
    )
    
    axs.plot(fpr, tpr, color=color_list[4])

    axs.axvline(0.0, color=color_list[11], linestyle=':')
    axs.axhline(1.0, color=color_list[11], linestyle=':')

    #axs.set_xlim([-0.01, 0.20])
    #axs.set_ylim([0.80, 1.0])
    axs.set_xlabel('False Positive Rate')
    axs.set_ylabel('True Positive Rate')

    # Confusion matrix inset
    vmin = 0.0
    vmax = 1.0

    matrix_test = confusion_matrix(
        cantons_test[2], 
        predicted_cantons_test
    )
    
    # See https://en.wikipedia.org/wiki/Receiver_operating_characteristic
    # sklearn confusion matrix format: 
    #     https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
    #     i.e, C[0, 0] = TN, C[0, 1] = FP, C[1, 0] = FN, C[1, 1] = TP
    # Need weighting?
    tpr_matrix = matrix_test[1, 1] / np.sum(matrix_test[1, :])
    fpr_matrix = matrix_test[0, 1] / np.sum(matrix_test[0, :])
        
    axs.scatter(fpr_matrix, tpr_matrix, marker='o', s=20, zorder=3, color=color_list[3])

    matrix_ref = np.zeros((2, 2), dtype=int)
    for i in range(0, 2):
        matrix_ref[i, i] = np.count_nonzero(cantons_test[2] == (i + 1))

    matrix_norm = matrix_test / np.diagonal(matrix_ref)[:, np.newaxis]

    axs_in = axs.inset_axes([0.25, 0.10, 0.40, 0.40])

    cmap = ListedColormap([color_list[1] + '40', color_list[1], color_list[2], color_list[2] + '40'])
    cmap_norm = BoundaryNorm(np.arange(0, cmap.N + 1), cmap.N)
    matshow = axs_in.imshow(np.arange(0, 4).reshape(2, 2), cmap=cmap, norm=cmap_norm)

    text_size = 'medium'

    for i in range(0, 2):
        for j in range(0, 2):
            if i != j:
                text_color = 'w'
            else:
                text_color = 'k'

            if len(str(matrix_test[i, j])) > 5:

                # This isn't a very 'robust' way of doing this,
                # but since we only go up to 6 digits, this should work
                # (we do this shortening only if the number has 6 digits)
                box_number_str = str(round(matrix_test[i, j], -3))
                box_number_str = box_number_str[0:-3] + 'K'
            else:
                box_number_str = f'{matrix_test[i, j]:d}'

            axs_in.text(
                j, i, box_number_str,
                horizontalalignment='center', verticalalignment='center',
                color=text_color, fontsize=text_size
            )

    axs_in.set_xticks(np.arange(0, 2))
    axs_in.set_yticks(np.arange(0, 2))

    axs_in.set_xticklabels([tl + r'$_\mathrm{P}$' for tl in ticklabels[2]])
    axs_in.tick_params(
        axis='x', which='both',
        bottom=False, top=False,
        labelbottom=False, labeltop=True
    )

    axs_in.set_yticklabels([tl + r'$_\mathrm{T}$' for tl in ticklabels[2]], rotation=90, verticalalignment='center')
    axs_in.tick_params(
        axis='y', which='both',
        left=False, right=False,
        labelleft=True, labelright=False
    )
    
    cax = axs.inset_axes([0.675, 0.10, 0.05, 0.40])
    
    cmap_ticklabels = ['TN', 'FP', 'FN', 'TP']
    cmap_ticks = np.arange(0, cmap.N) + 0.5
    cb = fig.colorbar(matshow, cax=cax, ticks=cmap_ticks)
    cb.set_label('Classification\nType')
    cax.tick_params(axis='y', which='both', right=False)
    cax.set_yticklabels(cmap_ticklabels)
    
    xya = (-0.25, 1.25)
    xyb = (1.85, 1.25)    
    adjust = 0.1
    axs.annotate(
        '', 
        xy=(fpr_matrix, tpr_matrix), 
        xytext=xya,
        xycoords='data',
        textcoords=axs_in.transAxes,
        arrowprops=dict(arrowstyle='-', shrinkA=adjust),
    )
    
    axs.annotate(
        '', 
        xy=(fpr_matrix, tpr_matrix), 
        xytext=xyb,
        xycoords='data',
        textcoords=axs_in.transAxes,
        arrowprops=dict(arrowstyle='-', shrinkA=adjust),
    )
    
    box_x, box_y = np.array(
        [[xya[0], xya[0], xyb[0], xyb[0]],
         [xya[1], -0.10, -0.10, xyb[1]]]
    )
    box = Line2D(box_x, box_y, color='k', transform=axs_in.transAxes)
    
    axs.add_line(box)
    
    # Add line to the ROC curve
    # Skip the first point (zero) b/c there is another
    # nearby point that is almost exactly zero
    tpr_gradient = tpr[2::2]
    fpr_gradient = fpr[2::2]
    roc_gradient = np.gradient(tpr_gradient, fpr_gradient)
        
    roc_idx = np.argmin(np.abs(tpr_gradient - tpr_matrix))
    print(roc_gradient[roc_idx])
    
    arrow_start = np.array([fpr_matrix, tpr_matrix])
    arrow_length = 0.10
    arrow_end = arrow_start + np.array([1.0, 1.0 * roc_gradient[roc_idx]])
    arrow_start_ax = axs.transLimits.transform(arrow_start)
    arrow_end_ax = axs.transLimits.transform(arrow_end)
    arrow_vector_ax = arrow_end_ax - arrow_start_ax
    arrow_vector_ax /= np.linalg.norm(arrow_vector_ax)
    arrow_vector_ax *= arrow_length
        
    axs.annotate(
        '',
        xy=(fpr_matrix, tpr_matrix), 
        xytext=arrow_start_ax + arrow_vector_ax,
        xycoords='data',
        textcoords=axs.transAxes,
        arrowprops=dict(arrowstyle='<-', linestyle='--', color=color_list[3])
    )
    
    axs.annotate(
        '',
        xy=(fpr_matrix, tpr_matrix), 
        xytext=arrow_start_ax - arrow_vector_ax,
        xycoords='data',
        textcoords=axs.transAxes,
        arrowprops=dict(arrowstyle='<-', linestyle='--', color=color_list[3])
    )
    
    tprfpr = np.column_stack((fpr, tpr))
    d_tprfpr = np.linalg.norm(tprfpr - np.array([0.0, 1.0]), axis=1)
    pt = np.argmin(d_tprfpr)
    
    fig.savefig(f'../../Results/{cutoff}/roc_svc_confusion_matrix_{cutoff}.pdf', bbox_inches='tight')
    
    plt.show()

In [None]:
# Should be very close
boundary_idx = np.argmin(np.abs(thresholds))
print(np.abs(thresholds[boundary_idx]))
print(fpr[boundary_idx], tpr[boundary_idx])
print(fpr_matrix, tpr_matrix)

# Precision-Recall curves

In [None]:
# Precision-recall curves
fig = plt.figure(figsize=(7.0, 3.5))
axs = {}
axs[3.5] = fig.add_subplot(2, 2, 1)
axs[6.0] = fig.add_subplot(2, 2, 2)

plot_parameters = dict(power={'linestyle': '-'}, radial={'linestyle': '--'})

for cutoff in cutoffs:
    model_dir = f'../../Processed_Data/Models/{cutoff}'
    for spectrum_type in ('power', 'radial'):
        spectrum_name = spectrum_type.capitalize()
        plot_parameters['linestyle'] = ''
        for species_pairing in group_names[spectrum_type]:
            data_dir = f'LSVC/2-Class/{spectrum_name}/{species_pairing}'
            
            dfs_iza = \
                    np.loadtxt(f'{iza_dir}/{cutoff}/{data_dir}/svc_structure_dfs.dat')
            dfs_deem = \
                    np.loadtxt(f'{all_deem_dir}/{cutoff}/{data_dir}/svc_structure_dfs.dat')
            
            p, r, thresholds = precision_recall_curve(cantons_test[2], 
                                                      np.concatenate((dfs_iza[idxs_iza_test], dfs_deem)), 
                                                      pos_label=2)
            ap = average_precision_score(cantons_test[2],
                                         np.concatenate((dfs_iza[idxs_iza_test], dfs_deem)),
                                         average='micro', pos_label=2)
            print(f'{cutoff}-{spectrum_name}-{species_pairing}: {ap}')
            
            axs[cutoff].plot(r, p, **plot_parameters[spectrum_type], 
                             label=f'{spectrum_name} {species_pairing}')

    axs[cutoff].set_xlabel('Recall')

    axs[cutoff].text(0.95, 0.05, f'{cutoff} ' + u'\u00c5',
                     horizontalalignment='right', verticalalignment='bottom',
                     transform=axs[cutoff].transAxes)
wspace=0.3
axs[3.5].set_ylabel('Precision')
axs[3.5].set_ylim(axs[6.0].get_ylim())

# TODO: custom legend with power and radial spectrum headers
axs[3.5].legend(bbox_to_anchor=(0.0, 1.0, 2.0+wspace, 0.5), 
                loc='lower left', bbox_transform=axs[3.5].transAxes,
                ncol=3, mode='expand', borderaxespad=0.0)

axs[6.0].tick_params(axis='y', which='both', labelleft=False)

fig.subplots_adjust(wspace=wspace)

fig.savefig('../../Results/precision_recall_svc.pdf', bbox_inches='tight')

plt.show()

# Analysis of misclassifications

In [None]:
# Dictionary for holding the misclassified DEEM structures for each knock-out model
misclassified_deem_330k = {}
misclassified_deem_330k_energies = {}
misclassified_deem_330k_volumes = {}

In [None]:
# Population energies and volumes
deem_330k_energies = np.loadtxt('../../Processed_Data/DEEM_330k/structure_energies.dat')
deem_330k_volumes = np.loadtxt('../../Processed_Data/DEEM_330k/structure_volumes.dat')
n_deem_330k = len(deem_330k_energies)

energy_shift = np.mean(deem_330k_energies)

# Load IZA energies and volumes for comparison
iza_energies = np.loadtxt('../../Processed_Data/IZA_230/structure_energies.dat')
iza_volumes = np.loadtxt('../../Processed_Data/IZA_230/structure_volumes.dat')

In [None]:
misclassification_counts = np.zeros(n_deem_330k)

In [None]:
for cutoff in cutoffs:
    for spectrum_type in ('power', 'radial'):
        for species_pairing in group_names[spectrum_type]:
            for n_cantons in (2, 4):
                key = f'{cutoff}-{spectrum_type}-{species_pairing}-{n_cantons}'
                
                misclassified_deem_330k_energies[key] = {}
                misclassified_deem_330k_volumes[key] = {}
                
                for i in range(1, n_cantons):
                    structure_idxs = misclassified_deem_330k[key][i]
                    misclassified_deem_330k_energies[key][i] = deem_330k_energies[structure_idxs]
                    misclassified_deem_330k_volumes[key][i] = deem_330k_volumes[structure_idxs]
                    misclassification_counts[structure_idxs] += 1
                
#                 print(f'===== {key} =====')
#                 print()
#                 for i in range(1, n_cantons):
#                     print(i, misclassified_deem_330k[key][i])
#                 print()

In [None]:
print(misclassification_counts)
print(np.count_nonzero(misclassification_counts < 1))
print(np.argsort(misclassification_counts))
print(np.sort(misclassification_counts))
print(np.count_nonzero(misclassification_counts == 40))
print(np.nonzero(misclassification_counts == 40))

In [None]:
print(np.mean(misclassification_counts))
print(np.std(misclassification_counts))

In [None]:
fig = plt.figure(figsize=(3.5, 3.5))
ax = fig.add_subplot(1, 1, 1)

ax.hist(misclassification_counts, bins=41, density=False, log=True)
ax.set_xlabel('No. Misclassifications')
ax.set_ylabel('Frequency')

#fig.savefig('../../Results/misclassification_number_histogram.pdf', bbox_inches='tight')

plt.show()

# Histograms of decision function values

In [None]:
for cutoff in cutoffs:
    fig = plt.figure(figsize=(7.0, 3.5), constrained_layout=True)
    gs = fig.add_gridspec(nrows=2, ncols=4, width_ratios=(1, 1, 1, 1), height_ratios=(1, 1))

    axs_2 = fig.add_subplot(gs[0:2, 0:2])
    axs_4 = [fig.add_subplot(gs[i // 2, 2 + (i % 2)]) for i in range(0, 4)]

    dfs_2 = np.loadtxt(f'../../Processed_Data/IZA_226/Data/{cutoff}/LSVC/2-Class/Power/OO+OSi+SiSi/svc_structure_dfs.dat')
    dfs_4 = np.loadtxt(f'../../Processed_Data/IZA_226/Data/{cutoff}/LSVC/4-Class/Power/OO+OSi+SiSi/svc_structure_dfs.dat')
    histogram_parameters = dict(bins=50, density=True, log=True, color=color_list[1], alpha=0.5)

    df_4_min = np.amin(dfs_4)
    df_4_max = np.amax(dfs_4)

    # 2-Class histogram
    axs_2.hist(dfs_2, **histogram_parameters)
    axs_2.set_xlabel('Decision Function')
    axs_2.set_ylabel('Prob. Density')
    axs_2.axvline(0, color=color_list[0], linestyle='-')

    axs_2.text(0.05, 0.95, class_names[2][0], verticalalignment='top', horizontalalignment='left',
               transform=axs_2.transAxes)
    axs_2.text(0.95, 0.95, class_names[2][1], verticalalignment='top', horizontalalignment='right',
               transform=axs_2.transAxes)

    # 4-Class histograms
    for adx, ax in enumerate(axs_4):
        ax.hist(dfs_4[:, adx], **histogram_parameters)
        ax.set_xlabel('Decision Function')
        ax.set_ylabel('Prob. Density')
        ax.axvline(0, color=color_list[0], linestyle='-')
        ax.set_xlim([df_4_min, df_4_max])

        ax.text(0.05, 0.95, 'Rest', verticalalignment='top', horizontalalignment='left',
                transform=ax.transAxes)
        ax.text(0.95, 0.95, class_names[4][adx], verticalalignment='top', horizontalalignment='right',
                transform=ax.transAxes)

    fig.suptitle(f'IZA Classifications, {cutoff} ' + u'\u00c5')

    fig.savefig(f'../../Results/{cutoff}/df_iza_histogram_{cutoff}_power_OO+OSi+SiSi_lsvc.pdf', bbox_inches='tight')
    plt.show()

In [None]:
for cutoff in cutoffs:
    fig = plt.figure(figsize=(7.0, 3.5), constrained_layout=True)
    gs = fig.add_gridspec(nrows=2, ncols=4, width_ratios=(1, 1, 1, 1), height_ratios=(1, 1))

    axs_2 = fig.add_subplot(gs[0:2, 0:2])
    axs_4 = [fig.add_subplot(gs[i // 2, 2 + (i % 2)]) for i in range(0, 4)]

    dfs_2 = np.loadtxt(f'../../Processed_Data/DEEM_330k/Data/{cutoff}/LSVC/2-Class/Power/OO+OSi+SiSi/svc_structure_dfs.dat')
    dfs_4 = np.loadtxt(f'../../Processed_Data/DEEM_330k/Data/{cutoff}/LSVC/4-Class/Power/OO+OSi+SiSi/svc_structure_dfs.dat')
    histogram_parameters = dict(bins=50, density=True, log=True, color=color_list[2], alpha=0.5)

    df_4_min = np.amin(dfs_4)
    df_4_max = np.amax(dfs_4)

    # 2-Class histogram
    axs_2.hist(dfs_2, **histogram_parameters)
    axs_2.set_xlabel('Decision Function')
    axs_2.set_ylabel('Prob. Density')
    axs_2.axvline(0, color=color_list[0], linestyle='-')

    axs_2.text(0.05, 0.95, class_names[2][0], verticalalignment='top', horizontalalignment='left',
               transform=axs_2.transAxes)
    axs_2.text(0.95, 0.95, class_names[2][1], verticalalignment='top', horizontalalignment='right',
               transform=axs_2.transAxes)

    # 4-Class histograms
    for adx, ax in enumerate(axs_4):
        ax.hist(dfs_4[:, adx], **histogram_parameters)
        ax.set_xlabel('Decision Function')
        ax.set_ylabel('Prob. Density')
        ax.axvline(0, color=color_list[0], linestyle='-')
        ax.set_xlim([df_4_min, df_4_max])

        ax.text(0.05, 0.95, 'Rest', verticalalignment='top', horizontalalignment='left',
                transform=ax.transAxes)
        ax.text(0.95, 0.95, class_names[4][adx], verticalalignment='top', horizontalalignment='right',
                transform=ax.transAxes)

    fig.suptitle(f'DEEM 330k Classifications, {cutoff} ' + u'\u00c5')

    fig.savefig(f'../../Results/{cutoff}/df_deem_histogram_{cutoff}_power_OO+OSi+SiSi_lsvc.pdf', bbox_inches='tight')
    plt.show()

In [None]:
# Probability distribution from decision functions
for cutoff in cutoffs:
    fig = plt.figure(figsize=(3.5, 3.5), constrained_layout=True)
    axs = fig.add_subplot(1, 1, 1)

    dfs_deem = np.loadtxt(f'../../Processed_Data/DEEM_330k/Data/{cutoff}/LSVC/2-Class/Power/OO+OSi+SiSi/svc_structure_dfs.dat')
    dfs_iza = np.loadtxt(f'../../Processed_Data/IZA_226/Data/{cutoff}/LSVC/2-Class/Power/OO+OSi+SiSi/svc_structure_dfs.dat')
    histogram_range = (np.amin(np.concatenate((dfs_iza, dfs_deem))), 
                       np.amax(np.concatenate((dfs_iza, dfs_deem))))
    
    histogram_parameters = dict(bins=100, density=True, log=False, range=histogram_range)
    
    # 2-Class histogram
    # We modify the FN and FP bars to be fully opaque. For Deem we look at the right bin edges
    # to find the misclassified; for IZA we look at the left. This leaves the bin that contains
    # zero as a "blend" of the 50% transparency IZA/DEEM. This is easier than making sure
    # a bin edge is at zero
    hist_deem, bin_edges, patches_deem = axs.hist(dfs_deem, **histogram_parameters, color=color_list[2] + '40', label='DEEM')
    for bin_edge, patch in zip(bin_edges[1:], patches_deem):
        if bin_edge < 0.0:
            patch.set_facecolor(color_list[2])
            patch.zorder = 2

    hist_iza, bin_edges, patches_iza = axs.hist(dfs_iza, **histogram_parameters, color=color_list[1] + '40', label='IZA')
    for bin_edge, patch in zip(bin_edges[0:-1], patches_iza):
        if bin_edge >= 0.0:
            patch.set_facecolor(color_list[1])
            patch.zorder = 2
    
    axs.set_xlabel('Decision Function')
    axs.set_ylabel('Prob. Density')
    axs.axvline(0, color=color_list[0], linestyle='-')

    half_length = 0.1
    x_pos = axs.transLimits.transform((0.0, 0.0))[0]
    y_pos = 0.9
    axs.annotate(
        '',
        xy=(x_pos - half_length, y_pos),
        xytext=(x_pos + half_length, y_pos),
        xycoords=axs.transAxes,
        textcoords=axs.transAxes,
        arrowprops=dict(arrowstyle='<->', linestyle='--', color=color_list[3])
    )
    
    axs.text(
        x_pos - half_length / 2, 0.98,
        'IZA (N)',
        horizontalalignment='right', verticalalignment='top',
        transform=axs.transAxes
    )
    axs.text(
        x_pos + half_length / 2, 0.98,
        'Deem (P)',
        horizontalalignment='left', verticalalignment='top',
        transform=axs.transAxes
    )
    
    patch_colors = [color_list[1] + '40', color_list[1], color_list[2], color_list[2] + '40']
    patch_labels = ['TN', 'FP', 'FN', 'TP']
    legend_patches = [Patch(facecolor=pc, label=label) for pc, label in zip(patch_colors, patch_labels)]
        
    axs.legend(handles=legend_patches, loc='upper right', bbox_to_anchor=(1.0, 1.0))
        
    x_pos_data, y_pos_data = axs.transLimits.inverted().transform((x_pos, y_pos))
    y_pos_data, y_pos_data = axs.transScale.inverted().transform((x_pos_data, y_pos_data))
    axs.scatter(x_pos_data, y_pos_data, s=10, color=color_list[3])
    
    fig.savefig(f'../../Results/{cutoff}/df_histogram_{cutoff}_power_OO+OSi+SiSi_lsvc.pdf', bbox_inches='tight')
    plt.show()

In [None]:
fig = plt.figure(figsize=(7.0, 3.5))
axs_vol = fig.add_subplot(1, 2, 1)
axs_energy = fig.add_subplot(1, 2, 2)

histogram_parameters = dict(bins=50, density=True, alpha=0.5)
energy_range = (np.amin(deem_330k_energies-energy_shift), np.amax(deem_330k_energies-energy_shift))
volume_range = (np.amin(deem_330k_volumes), np.amax(deem_330k_volumes))

plot_key = '6.0-power-OO+OSi+SiSi-4'

axs_vol.hist(deem_330k_volumes, range=volume_range, 
             **histogram_parameters, label='Population')
axs_energy.hist(deem_330k_energies-energy_shift, range=energy_range, 
                **histogram_parameters, label='Population')

for class_key in misclassified_deem_330k[plot_key].keys():
    axs_vol.hist(misclassified_deem_330k_volumes[plot_key][class_key], range=volume_range,
                 **histogram_parameters, label=f'IZA{class_key}*')
    axs_energy.hist(misclassified_deem_330k_energies[plot_key][class_key]-energy_shift, range=energy_range, 
                    **histogram_parameters, label=f'IZA{class_key}*')
    
axs_energy.legend()

axs_vol.set_title('Volume per Si')
axs_vol.set_xlabel(u'Volume (\u00c5' + r'$^3$/Si atom)')
axs_vol.set_ylabel('Prob. Density')
#axs_vol.set_xlim([np.amin(deem_330k_volumes), 80])

axs_energy.set_title('Energy per Si')
axs_energy.set_xlabel('Energy (kJ/mol Si)')
axs_energy.set_ylabel('Prob. Density')


fig.suptitle('DEEM 330k')
fig.subplots_adjust(wspace=0.3)

#fig.savefig('../../Results/deem_property_histogram.pdf', bbox_inches='tight')
    
plt.show()

In [None]:
fig = plt.figure(figsize=(7.0, 3.5))
axs_vol = fig.add_subplot(1, 2, 1)
axs_energy = fig.add_subplot(1, 2, 2)

histogram_parameters = dict(bins=50, density=True, alpha=0.5)
energy_range = (np.amin(deem_330k_energies-energy_shift), np.amax(deem_330k_energies-energy_shift))
volume_range = (np.amin(deem_330k_volumes), np.amax(deem_330k_volumes))

axs_vol.hist(iza_volumes, range=volume_range, **histogram_parameters, label='Population')
axs_energy.hist(iza_energies-energy_shift, range=energy_range, **histogram_parameters, label='Population')

# Skip RWY (canton 4)
for i in np.unique(iza_cantons)[0:-1]:
    canton_idxs = np.nonzero(iza_cantons == i)[0]
    axs_vol.hist(iza_volumes[canton_idxs], range=volume_range,
                 **histogram_parameters, label=f'IZA{class_key}')
    axs_energy.hist(iza_energies[canton_idxs]-energy_shift, range=energy_range,
                    **histogram_parameters, label=f'IZA{class_key}')
    
axs_energy.legend()

axs_vol.set_title('Volume per Si')
axs_vol.set_xlabel(u'Volume (\u00c5' + r'$^3$/Si atom)')
axs_vol.set_ylabel('Prob. Density')
#axs_vol.set_xlim([np.amin(deem_330k_volumes), 80])

axs_energy.set_title('Energy per Si')
axs_energy.set_xlabel('Energy (kJ/mol Si)')
axs_energy.set_ylabel('Prob. Density')

fig.suptitle('IZA')
fig.subplots_adjust(wspace=0.3)

#fig.savefig('../../Results/iza_property_histogram.pdf', bbox_inches='tight')
    
plt.show()

# Confusion matrices

## IZA + Deem

In [None]:
fig_2class = plt.figure(figsize=(12, 2.7), constrained_layout=True)
fig_4class = plt.figure(figsize=(12, 2.7), constrained_layout=True)

n_rows = 2
n_cols = 10

vmin = 0.0
vmax = 1.0

width_ratios = np.ones(n_cols + 1)
width_ratios[-1] = 0.25

for fig, n_cantons in zip((fig_2class, fig_4class), (2, 4)):
    gs = fig.add_gridspec(
        nrows=n_rows, ncols=n_cols+1, 
        width_ratios=width_ratios,
        hspace=0.05, wspace=0.05
    )
    axes = np.array([[fig.add_subplot(gs[i, j]) for j in range(0, n_cols)] for i in range(0, n_rows)])
    row_idx = 0
    for cutoff in cutoffs:
        col_idx = 0
        for spectrum_type in ('power', 'radial'):
            spectrum_name = spectrum_type.capitalize()

            for species_pairing in group_names[spectrum_type]:

                # Prepare outputs
                data_dir = f'LSVC/{n_cantons}-Class/{spectrum_name}/{species_pairing}'

                predicted_cantons_iza = np.loadtxt(
                    f'{iza_dir}/{cutoff}/{data_dir}/svc_structure_cantons.dat', dtype=int
                )
                predicted_cantons_deem = np.loadtxt(
                    f'{deem_dir}/{cutoff}/{data_dir}/svc_structure_cantons.dat', dtype=int
                )

                predicted_cantons_train = np.concatenate((
                    predicted_cantons_iza[iza_train_idxs],
                    predicted_cantons_deem[deem_train_idxs]
                ))
                predicted_cantons_test = np.concatenate((
                    predicted_cantons_iza[iza_test_idxs],
                    predicted_cantons_deem[deem_test_idxs]
                ))

                matrix_train = confusion_matrix(cantons_train[n_cantons], predicted_cantons_train)
                matrix_test = confusion_matrix(cantons_test[n_cantons], predicted_cantons_test)

                matrix_ref = np.zeros((n_cantons, n_cantons), dtype=int)
                for i in range(0, n_cantons):
                    matrix_ref[i, i] = np.count_nonzero(cantons_test[n_cantons] == (i + 1))
                    
                matrix_norm = matrix_test / np.diagonal(matrix_ref)[:, np.newaxis]
                
                ax = axes[row_idx, col_idx]
                ax.imshow(matrix_norm, cmap='Purples', vmin=vmin, vmax=vmax)
                
                if n_cantons == 2:
                    text_size = 'medium'
                else:
                    text_size = 'xx-small'
                
                for i in range(0, n_cantons):
                    for j in range(0, n_cantons):
                        if matrix_norm[i, j] > (0.5 * vmax):
                            text_color = 'w'
                        else:
                            text_color = 'k'
                            
                        if len(str(matrix_test[i, j])) > 5:
                            
                            # This isn't a very 'robust' way of doing this,
                            # but since we only go up to 6 digits, this should work
                            # (we do this shortening only if the number has 6 digits)
                            box_number_str = str(round(matrix_test[i, j], -3))
                            box_number_str = box_number_str[0:-3] + 'K'
                        else:
                            box_number_str = f'{matrix_test[i, j]:d}'
                        
                        ax.text(
                            j, i, box_number_str,
                            horizontalalignment='center', verticalalignment='center',
                            color=text_color, fontsize=text_size
                        )
                
                ax.set_xticks(np.arange(0, n_cantons))
                ax.set_yticks(np.arange(0, n_cantons))
                
                ax.tick_params(
                    axis='both', which='both', 
                    left=False, bottom=False, right=False, top=False,
                    labelleft=False, labelbottom=False, labelright=False, labeltop=False
                )
                
                if col_idx == 0:
                    ax.set_ylabel(f'{cutoff} ' + u'\u00c5')
                    ax.set_yticklabels(ticklabels[n_cantons])
                    ax.tick_params(axis='y', which='both', labelleft=True)
                    
                if row_idx == 0:
                    ax.set_title(f'{species_pairing}', fontsize='medium')
                elif row_idx == (n_rows - 1):
                    ax.set_xticklabels(ticklabels[n_cantons], rotation=90)
                    ax.tick_params(axis='x', which='both', labelbottom=True)

                col_idx += 1
        
        row_idx += 1        
    
    cp_xy = (0.5, 1.3)
    cp_power_bar_frac = -0.02
    cp_radial_bar_frac = cp_power_bar_frac * 3
    cp_power = ConnectionPatch(
        xyA=cp_xy, coordsA=axes[0, 0].transAxes,
        xyB=cp_xy, coordsB=axes[0, 6].transAxes,
        connectionstyle=f'bar,fraction={cp_power_bar_frac}'
    )
    cp_radial = ConnectionPatch(
        xyA=cp_xy, coordsA=axes[0, 7].transAxes,
        xyB=cp_xy, coordsB=axes[0, 9].transAxes,
        connectionstyle=f'bar,fraction={cp_radial_bar_frac}'
    )
    
    n_power = len(group_names['power'])
    n_radial = len(group_names['radial'])
    
    power_label_ax = axes[0, n_power // 2]
    if n_power % 2 == 0:
        power_shift = 0.5
    else:
        power_shift = 0.0
        
    radial_label_ax = axes[0, n_radial // 2 + n_power]
    if n_radial % 2 == 0:
        radial_shift = 0.5
    else:
        radial_shift = 0.0
    
    fig.text(
        0.5 - power_shift, 1.65, 'Power Spectrum',
        horizontalalignment='center', verticalalignment='center',
        transform=power_label_ax.transAxes, fontsize='large'
    )
    
    fig.text(
        0.5 - radial_shift, 1.65, 'Radial Spectrum',
        horizontalalignment='center', verticalalignment='center',
        transform=radial_label_ax.transAxes, fontsize='large'
    )
    
    cax = fig.add_subplot(gs[:, -1])
    cb = fig.colorbar(
        ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap='Purples'), cax=cax
    )
    cb.set_label('True Class Proportion')
    
    fig.add_artist(cp_power)
    fig.add_artist(cp_radial)
    
    fig.savefig(f'../../Results/svc_confusion_matrices_{n_cantons}-class.pdf', bbox_inches='tight')
    
plt.show()

## Dummy prediction

In [None]:
fig_2class = plt.figure(figsize=(12, 2.7), constrained_layout=True)
fig_4class = plt.figure(figsize=(12, 2.7), constrained_layout=True)

n_rows = 2
n_cols = 10

vmin = 0.0
vmax = 1.0

width_ratios = np.ones(n_cols + 1)
width_ratios[-1] = 0.25

for fig, n_cantons in zip((fig_2class, fig_4class), (2, 4)):
    gs = fig.add_gridspec(
        nrows=n_rows, ncols=n_cols+1, 
        width_ratios=width_ratios,
        hspace=0.05, wspace=0.05
    )
    axes = np.array([[fig.add_subplot(gs[i, j]) for j in range(0, n_cols)] for i in range(0, n_rows)])
    row_idx = 0
    for cutoff in cutoffs:
        col_idx = 0
        for spectrum_type in ('power', 'radial'):
            spectrum_name = spectrum_type.capitalize()

            for species_pairing in group_names[spectrum_type]:

                # Prepare outputs
                data_dir = f'LSVC/{n_cantons}-Class/{spectrum_name}/{species_pairing}'

                predicted_dummy_cantons = np.loadtxt(
                    f'{deem_dir}/{cutoff}/{data_dir}/dummy_svc_structure_cantons.dat', dtype=int
                )
                predicted_dummy_train_cantons = predicted_dummy_cantons[deem_train_idxs]
                predicted_dummy_test_cantons = predicted_dummy_cantons[deem_test_idxs]

                matrix_train = confusion_matrix(
                    dummy_train_cantons[n_cantons], predicted_dummy_train_cantons
                )
                matrix_test = confusion_matrix(
                    dummy_test_cantons[n_cantons], predicted_dummy_test_cantons
                )

                matrix_ref = np.zeros((n_cantons, n_cantons), dtype=int)
                for i in range(0, n_cantons):
                    matrix_ref[i, i] = np.count_nonzero(dummy_test_cantons[n_cantons] == (i + 1))
                    
                matrix_norm = matrix_test / np.diagonal(matrix_ref)[:, np.newaxis]
                
                ax = axes[row_idx, col_idx]
                ax.imshow(matrix_norm, cmap='Purples', vmin=vmin, vmax=vmax)
                
                if n_cantons == 2:
                    text_size = 'medium'
                else:
                    text_size = 'xx-small'
                
                for i in range(0, n_cantons):
                    for j in range(0, n_cantons):
                        if matrix_norm[i, j] > (0.5 * vmax):
                            text_color = 'w'
                        else:
                            text_color = 'k'
                            
                        if len(str(matrix_test[i, j])) > 5:
                            
                            # This isn't a very 'robust' way of doing this,
                            # but since we only go up to 6 digits, this should work
                            # (we do this shortening only if the number has 6 digits)
                            box_number_str = str(round(matrix_test[i, j], -3))
                            box_number_str = box_number_str[0:-3] + 'K'
                        else:
                            box_number_str = f'{matrix_test[i, j]:d}'
                        
                        ax.text(
                            j, i, box_number_str,
                            horizontalalignment='center', verticalalignment='center',
                            color=text_color, fontsize=text_size
                        )
                
                ax.set_xticks(np.arange(0, n_cantons))
                ax.set_yticks(np.arange(0, n_cantons))
                
                ax.tick_params(
                    axis='both', which='both', 
                    left=False, bottom=False, right=False, top=False,
                    labelleft=False, labelbottom=False, labelright=False, labeltop=False
                )
                
                if col_idx == 0:
                    ax.set_ylabel(f'{cutoff} ' + u'\u00c5')
                    ax.set_yticklabels(ticklabels[n_cantons])
                    ax.tick_params(axis='y', which='both', labelleft=True)
                    
                if row_idx == 0:
                    ax.set_title(f'{species_pairing}', fontsize='medium')
                elif row_idx == (n_rows - 1):
                    ax.set_xticklabels(ticklabels[n_cantons], rotation=90)
                    ax.tick_params(axis='x', which='both', labelbottom=True)

                col_idx += 1
        
        row_idx += 1        
    
    cp_xy = (0.5, 1.3)
    cp_power_bar_frac = -0.02
    cp_radial_bar_frac = cp_power_bar_frac * 3
    cp_power = ConnectionPatch(
        xyA=cp_xy, coordsA=axes[0, 0].transAxes,
        xyB=cp_xy, coordsB=axes[0, 6].transAxes,
        connectionstyle=f'bar,fraction={cp_power_bar_frac}'
    )
    cp_radial = ConnectionPatch(
        xyA=cp_xy, coordsA=axes[0, 7].transAxes,
        xyB=cp_xy, coordsB=axes[0, 9].transAxes,
        connectionstyle=f'bar,fraction={cp_radial_bar_frac}'
    )
    
    n_power = len(group_names['power'])
    n_radial = len(group_names['radial'])
    
    power_label_ax = axes[0, n_power // 2]
    if n_power % 2 == 0:
        power_shift = 0.5
    else:
        power_shift = 0.0
        
    radial_label_ax = axes[0, n_radial // 2 + n_power]
    if n_radial % 2 == 0:
        radial_shift = 0.5
    else:
        radial_shift = 0.0
    
    fig.text(
        0.5 - power_shift, 1.65, 'Power Spectrum',
        horizontalalignment='center', verticalalignment='center',
        transform=power_label_ax.transAxes, fontsize='large'
    )
    
    fig.text(
        0.5 - radial_shift, 1.65, 'Radial Spectrum',
        horizontalalignment='center', verticalalignment='center',
        transform=radial_label_ax.transAxes, fontsize='large'
    )
    
    cax = fig.add_subplot(gs[:, -1])
    cb = fig.colorbar(
        ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap='Purples'), cax=cax
    )
    cb.set_label('True Class Proportion')
    
    fig.add_artist(cp_power)
    fig.add_artist(cp_radial)
    
    fig.savefig(f'../../Results/dummy_svc_confusion_matrices_{n_cantons}-class.pdf', bbox_inches='tight')
    
plt.show()