In [1]:
# Reload imported modules automatically before each cell is executed, so that
# edits to the project's helper modules are picked up without a kernel restart
%load_ext autoreload
%autoreload 2

In [2]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np

# Plotting
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ML
from regression import PCovR, KPCovR, SparseKPCovR
from regression import LR, KRR
from kernels import build_kernel, linear_kernel, gaussian_kernel
from kernels import center_kernel, center_kernel_fast
from kernels import center_kernel_oos, center_kernel_oos_fast
from soap import compute_soap_density, reshape_soaps
from soap import rrw_neighbors, make_tuples
from soap import extract_species_pair_groups

from sklearn.svm import SVC, LinearSVC
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LogisticRegression

# Utilities
import h5py
import json
import itertools
from tqdm.notebook import tqdm
import project_utils as utils
from tools import load_json

# Import COSMO style toolkit
import cosmoplot.colorbars as cosmocbars
import cosmoplot.utils as cosmoutils
import cosmoplot.style as cosmostyle

# Select the toolkit's 'article' style and keep a handle on its color cycle
# (presumably the matplotlib colors used by plots later on -- confirm in cosmoplot)
cosmostyle.set_style('article')
colorList = cosmostyle.color_cycle

In /home/helfrech/.config/matplotlib/stylelib/cosmo.mplstyle: 
The savefig.frameon rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
In /home/helfrech/.config/matplotlib/stylelib/cosmoLarge.mplstyle: 
The savefig.frameon rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
The savefig.frameon rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
  self[key] = other[key]


In [3]:
# sys.path.append('/scratch/helfrech/Sync/GDrive/Projects/KPCovR/kernel-tutorials')
# sys.path.append('/scratch/helfrech/Sync/GDrive/Projects/KPCovR/KernelPCovR/analysis/scripts')
# from utilities.sklearn_covr.kpcovr import KernelPCovR as KPCovR2
# from utilities.sklearn_covr.pcovr import PCovR as PCovR2
# from helpers import l_regr, l_kpcovr

# Load train and test splits

In [4]:
# Read the SOAP hyperparameters; only the interaction cutoffs are needed here.
# Every model cell below loops over these cutoffs.
with open('../Processed_Data/soap_hyperparameters.json', 'r') as f:
    soap_hyperparameters = json.loads(f.read())

cutoffs = soap_hyperparameters['interaction_cutoff']

In [5]:
# DEEM train/test membership, stored on disk as integer structure indices
idxs_deem_train = np.loadtxt('../Processed_Data/DEEM_10k/train.idxs', dtype=int)
idxs_deem_test = np.loadtxt('../Processed_Data/DEEM_10k/test.idxs', dtype=int)

# Structure counts: per split, then overall
n_deem_train, n_deem_test = idxs_deem_train.size, idxs_deem_test.size
n_deem = n_deem_train + n_deem_test

print(n_deem_train, n_deem_test)

7750 2250


In [6]:
# Placeholder canton labels for DEEM: every DEEM structure gets the label 4
cantons_deem = np.full(n_deem, 4, dtype=int)

In [7]:
# IZA canton labels are taken from the second column (index 1) of the cantons file
cantons_iza = np.loadtxt('../Raw_Data/GULP/IZA_226/cantons.txt', usecols=1, dtype=int)

# Position of the first IZA entry labelled 4; that entry is removed from the
# IZA set in the next cell (presumably the RWY framework -- confirm)
RWY = (cantons_iza == 4).nonzero()[0][0]

In [8]:
# Remove the entry at index RWY from the IZA canton labels and count what is left.
# NOTE: `cantons_iza` is rebound from itself, so this cell is not idempotent:
# re-running it on its own deletes one more entry. Re-run the cell that loads
# `cantons_iza` first.
cantons_iza = np.delete(cantons_iza, RWY)
n_iza = len(cantons_iza)

In [9]:
idxs_iza_train_file = '../Processed_Data/IZA_226/train.idxs' 
idxs_iza_test_file = '../Processed_Data/IZA_226/test.idxs'

# Reuse the stored IZA train/test split when both index files can be read
try:
    # atleast_1d: loadtxt yields a 0-d array for a one-line file, which has no len()
    idxs_iza_train = np.atleast_1d(np.loadtxt(idxs_iza_train_file, dtype=int))
    idxs_iza_test = np.atleast_1d(np.loadtxt(idxs_iza_test_file, dtype=int))

# Otherwise draw a new random split (half train, remainder test) and store it
except IOError:

    # NOTE: the shuffle uses NumPy's global RNG without a seed in this cell,
    # so the split is reproducible only through the files written below
    idxs_iza = np.arange(0, n_iza)
    np.random.shuffle(idxs_iza)

    idxs_iza_train = idxs_iza[0:n_iza // 2]
    idxs_iza_test = idxs_iza[n_iza // 2:]

    # Make sure the target directory exists before writing the split
    os.makedirs(os.path.dirname(idxs_iza_train_file), exist_ok=True)
    np.savetxt(idxs_iza_train_file, idxs_iza_train, fmt='%d')
    np.savetxt(idxs_iza_test_file, idxs_iza_test, fmt='%d')

# Report the split sizes regardless of which branch produced the split
n_iza_train = len(idxs_iza_train)
n_iza_test = len(idxs_iza_test)

print(n_iza_train, n_iza_test)

112 113


In [10]:
# "Master" canton labels for the combined sets, keyed by the number of classes.
# In every array the IZA entries come first and the DEEM entries follow.
#   4: the IZA canton labels followed by the dummy DEEM label
#   2: label 1 for every IZA structure, label 2 for every DEEM structure
cantons_train = {
    4: np.concatenate((cantons_iza[idxs_iza_train], cantons_deem[idxs_deem_train])),
    2: np.concatenate((np.full(len(idxs_iza_train), 1, dtype=int),
                       np.full(len(idxs_deem_train), 2, dtype=int))),
}

cantons_test = {
    4: np.concatenate((cantons_iza[idxs_iza_test], cantons_deem[idxs_deem_test])),
    2: np.concatenate((np.full(len(idxs_iza_test), 1, dtype=int),
                       np.full(len(idxs_deem_test), 2, dtype=int))),
}

# Model setup

In [None]:
# Root directory under which the per-cutoff model files (stored kernels and
# model parameter files) are kept
model_dir = '../Processed_Data/Models'

In [12]:
# Default keyword arguments shared by all models in this notebook.
# Each dictionary has one entry for the linear models and one for the kernel models.
# TODO: or use .get_params()?
# TODO: load from optimization?
# TODO: should we make sure break_ties=True?
svc_kwargs = {
    'linear': {
        'penalty': 'l2',
        'loss': 'squared_hinge',
        'dual': False,
        'multi_class': 'ovr',
        'class_weight': None,
        'fit_intercept': True,
        'intercept_scaling': 1.0,
        'tol': 1.0E-3,
        'C': 1.0,
    },
    'kernel': {
        'kernel': 'precomputed',
        'decision_function_shape': 'ovr',
        'class_weight': None,
        'break_ties': False,
        'tol': 1.0E-3,
        'C': 10.0,
    },
}

pcovr_kwargs = {
    'linear': {'n_components': None, 'alpha': 0.0, 'regularization': 1.0E-12},
    'kernel': {'n_components': None, 'alpha': 0.0, 'regularization': 1.0E-12},
}

In [13]:
# Row slices used when saving KSVC and KPCovR outputs per dataset.
# The combined train/test label arrays are built with the IZA entries first and
# the DEEM entries after them, so IZA is the leading block and DEEM the rest.
iza_train_slice = slice(0, n_iza_train)
iza_test_slice = slice(0, n_iza_test)
deem_train_slice = slice(n_iza_train, None)
deem_test_slice = slice(n_iza_test, None)

# Build kernels

In [14]:
# Load the kernels or compute if they don't exist

# Locations of the SOAP files used to build the kernels. They are set here so
# that this cell also runs on a fresh kernel: the only other definition of
# these names for the kernel models sits in a cell placed after this one.
deem_name = 'DEEM_10k'
iza_name = 'IZA_226onDEEM_10k'
deem_dir = f'../Processed_Data/{deem_name}/Data'
iza_dir = f'../Processed_Data/{iza_name}/Data'

for cutoff in cutoffs:
    for kernel_type in ('linear', 'gaussian'):
        kernel_name = kernel_type.capitalize()

        work_dir = f'{model_dir}/{cutoff}/Kernel_Models/{kernel_name}/KSVC-KPCovR'

        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(work_dir, exist_ok=True)

        # File to store kernels for re-use
        kernel_file = f'{work_dir}/structure_kernels.hdf5'
        kernel_parameter_file = f'{work_dir}/volumes_mae_parameters.json'

        # Kernels already on disk are re-used as they are
        if os.path.exists(kernel_file):
            continue

        # SOAP files (atomwise, FPS'ed features)
        deem_file = f'{deem_dir}/{cutoff}/soaps.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps.hdf5'

        # Assemble the train and test set SOAPs from IZA and DEEM
        soaps_train, soaps_test = utils.load_soaps(deem_file, iza_file,
                                                   idxs_deem_train, idxs_deem_test,
                                                   idxs_iza_train, idxs_iza_test,
                                                   idxs_iza_delete=[RWY])

        # Compute kernels
        # This can be consolidated if doing linear KRR
        if kernel_type == 'gaussian':
            # Gaussian kernel settings come from the stored parameter file;
            # 'sigma' and 'regularization' are dropped before the remaining
            # entries are forwarded to compute_kernels
            kernel_parameters = load_json(kernel_parameter_file)
            kernel_parameters.pop('sigma')
            kernel_parameters.pop('regularization')
        else:
            kernel_parameters = dict(kernel='linear', zeta=1)

        utils.compute_kernels(soaps_train, soaps_test, **kernel_parameters, kernel_file=kernel_file)

# Kernel models

In [11]:
# Datasets used by the kernel models and the directories that hold their data
deem_name, iza_name = 'DEEM_10k', 'IZA_226onDEEM_10k'

deem_dir = f'../Processed_Data/{deem_name}/Data'
iza_dir = f'../Processed_Data/{iza_name}/Data'

## KernelSVC

In [36]:
# Fit a kernel SVC for every cutoff, kernel type and number of classes, and
# save its decision functions and predicted cantons separately for DEEM and IZA
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    
    for kernel_type in tqdm(('linear', 'gaussian'), desc='Kernel', leave=False):
        kernel_name = kernel_type.capitalize()
        
        kernel_dir = f'{model_dir}/{cutoff}/Kernel_Models/{kernel_name}/KSVC-KPCovR'
        
        # File to store kernels for re-use
        kernel_file = f'{kernel_dir}/structure_kernels.hdf5'
        
        # Load kernels
        K_train, K_test, K_test_test = utils.load_kernels(kernel_file)

        # Center and scale kernels
        K_train, [K_test], [K_test_test] = \
            utils.preprocess_kernels(K_train, K_test=[K_test], K_test_test=[K_test_test],
                                     K_bridge=K_test)
        
        for n_cantons in tqdm((2, 4), desc='Classes', leave=False):
            
            # Prepare outputs
            output_dir = f'Kernel_Models/{kernel_name}/KSVC-KPCovR/{n_cantons}-Class'
            
            os.makedirs(f'{deem_dir}/{cutoff}/{output_dir}', exist_ok=True)
            os.makedirs(f'{iza_dir}/{cutoff}/{output_dir}', exist_ok=True)
            
            svc_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
            svc_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
            
            svc_cantons_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_cantons.dat'
            svc_cantons_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_cantons.dat'
            
            # Location of optimized parameters; currently unused because the
            # global defaults in svc_kwargs are applied instead
            parameter_dir = f'{kernel_dir}/{n_cantons}-Class'
            svc_parameter_file = f'{parameter_dir}/svc_parameters.json'

            # Run KSVC
            svc_parameters = svc_kwargs['kernel'].copy() #load_json(svc_parameter_file) ###
            df_train, df_test, predicted_cantons_train, predicted_cantons_test = \
                utils.do_svc(K_train, K_test, cantons_train[n_cantons], cantons_test[n_cantons], 
                             svc_type='kernel', **svc_parameters,
                             outputs=['decision_functions', 'predictions'])
            
            # Save IZA and DEEM KSVC decision functions
            utils.split_and_save(df_train, df_test,
                                 idxs_deem_train, idxs_deem_test,
                                 deem_train_slice, deem_test_slice,
                                 output=svc_df_deem_file, output_format='%f',
                                 hdf5_attrs=None)
            
            utils.split_and_save(df_train, df_test,
                                 idxs_iza_train, idxs_iza_test,
                                 iza_train_slice, iza_test_slice,
                                 output=svc_df_iza_file, output_format='%f',
                                 hdf5_attrs=None)
            
            # Save IZA and DEEM KSVC canton predictions
            # Cantons are integer class labels, so they are written with '%d',
            # matching how the KPCovR canton predictions are saved
            utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
                                 idxs_deem_train, idxs_deem_test,
                                 deem_train_slice, deem_test_slice,
                                 output=svc_cantons_deem_file, output_format='%d',
                                 hdf5_attrs=None)
            
            utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
                                 idxs_iza_train, idxs_iza_test,
                                 iza_train_slice, iza_test_slice,
                                 output=svc_cantons_iza_file, output_format='%d',
                                 hdf5_attrs=None)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Kernel', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.82      0.20      0.33       113
           2       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.89      0.60      0.65      2363
weighted avg       0.95      0.96      0.95      2363

[[  23   90]
 [   5 2245]]
              precision    recall  f1-score   support

           1       0.25      0.14      0.18        14
           2       0.71      0.07      0.13        68
           3       0.25      0.03      0.06        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.54      0.31      0.34      2363
weighted avg       0.94      0.96      0.94      2363

[[   2    1    0   11]
 [   5    5    3   55]
 [   0    1    1   29]
 [   1    0    0 2249]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.90      0.31      0.46       113
           2       0.97      1.00      0.98      2250

    accuracy                           0.97      2363
   macro avg       0.93      0.65      0.72      2363
weighted avg       0.96      0.97      0.96      2363

[[  35   78]
 [   4 2246]]
              precision    recall  f1-score   support

           1       0.38      0.21      0.27        14
           2       0.77      0.15      0.25        68
           3       0.27      0.10      0.14        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.60      0.36      0.41      2363
weighted avg       0.95      0.96      0.94      2363

[[   3    2    0    9]
 [   4   10    6   48]
 [   0    1    3   27]
 [   1    0    2 2247]]


HBox(children=(FloatProgress(value=0.0, description='Kernel', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.93      0.70      0.80       113
           2       0.99      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.96      0.85      0.89      2363
weighted avg       0.98      0.98      0.98      2363

[[  79   34]
 [   6 2244]]
              precision    recall  f1-score   support

           1       0.36      0.36      0.36        14
           2       0.72      0.50      0.59        68
           3       0.31      0.13      0.18        31
           4       0.98      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.59      0.50      0.53      2363
weighted avg       0.96      0.97      0.96      2363

[[   5    4    0    5]
 [   6   34    8   20]
 [   0    8    4   19]
 [   3    1    1 2245]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.93      0.68      0.79       113
           2       0.98      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.96      0.84      0.89      2363
weighted avg       0.98      0.98      0.98      2363

[[  77   36]
 [   6 2244]]
              precision    recall  f1-score   support

           1       0.36      0.36      0.36        14
           2       0.74      0.57      0.64        68
           3       0.45      0.29      0.35        31
           4       0.99      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.63      0.55      0.59      2363
weighted avg       0.97      0.97      0.97      2363

[[   5    5    0    4]
 [   7   39    9   13]
 [   0    8    9   14]
 [   2    1    2 2245]]



## KRR check

In [46]:
# Sanity check for every cutoff, kernel type and number of classes:
# regress the saved KSVC decision functions on the preprocessed kernels
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    for kernel_type in tqdm(('linear', 'gaussian'), desc='Kernel', leave=False):
        kernel_name = kernel_type.capitalize()
        kernel_dir = f'{model_dir}/{cutoff}/Kernel_Models/{kernel_name}/KSVC-KPCovR'
        kernel_file = f'{kernel_dir}/structure_kernels.hdf5'

        # Stored kernels, then centered and scaled
        K_train, K_test, K_test_test = utils.load_kernels(kernel_file)
        K_train, [K_test], [K_test_test] = utils.preprocess_kernels(
            K_train, K_test=[K_test], K_test_test=[K_test_test], K_bridge=K_test
        )

        for n_cantons in tqdm((2, 4), desc='Classes', leave=False):

            # Decision functions written by the KSVC cell
            input_dir = f'Kernel_Models/{kernel_name}/KSVC-KPCovR/{n_cantons}-Class'
            svc_df_deem_file = f'{deem_dir}/{cutoff}/{input_dir}/svc_structure_dfs.dat'
            svc_df_iza_file = f'{iza_dir}/{cutoff}/{input_dir}/svc_structure_dfs.dat'

            df_train, df_test = utils.load_data(
                svc_df_deem_file, svc_df_iza_file,
                idxs_deem_train, idxs_deem_test,
                idxs_iza_train, idxs_iza_test,
            )

            # Centered and scaled decision functions are the regression targets;
            # the returned center and scale are not needed for this check
            df_train, df_test, df_center, df_scale = utils.preprocess_data(df_train, df_test)

            # Check that KRR can reproduce the decision functions
            utils.regression_check(K_train, K_test, df_train, df_test,
                                   regression_type='kernel')

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Kernel', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.402407674818527e-07
2.563285939083567e-07
[0.23737999 0.19256604 0.205351   0.15324587]
[0.25570135 0.20392276 0.21818637 0.21883868]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

9.53750867978601e-09
1.7538193820153714e-06
[0.00453011 0.0045754  0.00583132 0.00563494]
[1.2519002  1.1003435  1.3837016  2.18657991]


HBox(children=(FloatProgress(value=0.0, description='Kernel', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.6016311846371444e-07
1.6826738304952807e-07
[0.18725306 0.18849139 0.18977415 0.18832027]
[0.20814343 0.20774303 0.20894607 0.26139914]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

8.976657572570218e-13
3.5131696161718493e-07
[5.57619892e-11 1.30316961e-10 1.13797759e-10 5.75528065e-11]
[0.2379738  0.26253117 0.27975249 0.22307988]



In [None]:
# TODO: re-run the KRR check above with decision_function_shape='ovo' and see whether the results are better, as they were before

## KPCovR

In [30]:
# Fit KPCovR on the saved KSVC decision functions for every cutoff, kernel type
# and number of classes; save projections, decision functions and predicted
# cantons separately for DEEM and IZA
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    for kernel_type in tqdm(('linear', 'gaussian'), desc='Kernel', leave=False):
        kernel_name = kernel_type.capitalize()
        kernel_dir = f'{model_dir}/{cutoff}/Kernel_Models/{kernel_name}/KSVC-KPCovR'
        kernel_file = f'{kernel_dir}/structure_kernels.hdf5'

        # Stored kernels, then centered and scaled
        K_train, K_test, K_test_test = utils.load_kernels(kernel_file)
        K_train, [K_test], [K_test_test] = utils.preprocess_kernels(
            K_train, K_test=[K_test], K_test_test=[K_test_test], K_bridge=K_test
        )

        for n_cantons in tqdm((2, 4), desc='Classes', leave=False):

            # Input (KSVC decision functions) and output file names
            output_dir = f'Kernel_Models/{kernel_name}/KSVC-KPCovR/{n_cantons}-Class'

            svc_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
            svc_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'

            pcovr_projection_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structures.hdf5'
            pcovr_projection_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structures.hdf5'

            pcovr_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structure_dfs.dat'
            pcovr_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structure_dfs.dat'

            pcovr_cantons_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structure_cantons.dat'
            pcovr_cantons_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structure_cantons.dat'

            # Locations of optimized parameters; currently unused because the
            # global defaults are applied instead
            parameter_dir = f'{kernel_dir}/{n_cantons}-Class'
            pcovr_parameter_file = f'{parameter_dir}/pcovr_parameters.json'
            svc_parameter_file = f'{parameter_dir}/svc_parameters.json'

            # Regression targets: centered and scaled KSVC decision functions
            df_train, df_test = utils.load_data(
                svc_df_deem_file, svc_df_iza_file,
                idxs_deem_train, idxs_deem_test,
                idxs_iza_train, idxs_iza_test,
            )
            df_train, df_test, df_center, df_scale = utils.preprocess_data(df_train, df_test)

            # Run KPCovR
            pcovr_parameters = pcovr_kwargs['kernel'].copy() #load_json(pcovr_parameter_file) ###
            svc_parameters = svc_kwargs['kernel'].copy() #load_json(svc_parameter_file) ###
            T_train, T_test, dfp_train, dfp_test = utils.do_pcovr(
                K_train, K_test, df_train, df_test,
                pcovr_type='kernel', **pcovr_parameters
            )

            # Post process the KPCovR decision functions
            # (i.e., turn them back into canton predictions)
            predicted_cantons_train, predicted_cantons_test = \
                utils.postprocess_decision_functions(
                    dfp_train, dfp_test, df_center, df_scale,
                    df_type=svc_parameters['decision_function_shape'],
                    n_classes=n_cantons
                )

            # One entry per saved quantity, in saving order: projections,
            # decision functions, canton predictions. Each entry lists
            # (train data, test data, DEEM file, IZA file, format, HDF5 attributes)
            save_jobs = (
                (T_train, T_test,
                 pcovr_projection_deem_file, pcovr_projection_iza_file,
                 '%f', pcovr_parameters),
                (dfp_train, dfp_test,
                 pcovr_df_deem_file, pcovr_df_iza_file,
                 '%f', None),
                (predicted_cantons_train, predicted_cantons_test,
                 pcovr_cantons_deem_file, pcovr_cantons_iza_file,
                 '%d', None),
            )

            for save_train, save_test, deem_target, iza_target, save_format, save_attrs in save_jobs:

                # DEEM portion first ...
                utils.split_and_save(save_train, save_test,
                                     idxs_deem_train, idxs_deem_test,
                                     deem_train_slice, deem_test_slice,
                                     output=deem_target, output_format=save_format,
                                     hdf5_attrs=save_attrs)

                # ... then the IZA portion
                utils.split_and_save(save_train, save_test,
                                     idxs_iza_train, idxs_iza_test,
                                     iza_train_slice, iza_test_slice,
                                     output=iza_target, output_format=save_format,
                                     hdf5_attrs=save_attrs)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Kernel', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Kernel', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…




# Linear Models

In [56]:
# Linear model setup: number of atomic species and, per spectrum type, the
# names of the species(-pair) feature groups and their combinations
n_species = 2
group_names = {
    'power': [
        'OO', 'OSi', 'SiSi',
        'OO+OSi', 'OO+SiSi', 'OSi+SiSi',
        'OO+OSi+SiSi',
    ],
    'radial': ['O', 'Si', 'O+Si'],
}

In [34]:
# Datasets used by the linear models and the directories that hold their data.
# This rebinds iza_name / iza_dir, which the kernel-model cells set differently.
deem_name, iza_name = 'DEEM_10k', 'IZA_226'

deem_dir = f'../Processed_Data/{deem_name}/Data'
iza_dir = f'../Processed_Data/{iza_name}/Data'

## LinearSVC

In [57]:
# Fit a linear SVC for every cutoff, spectrum type, species grouping and number
# of classes, and save its decision functions and predicted cantons separately
# for DEEM and IZA
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    linear_dir = f'{model_dir}/{cutoff}/Linear_Models/LSVC-LPCovR'
    
    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        
        # Load SOAPs
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        
        soaps_train, soaps_test = utils.load_soaps(deem_file, iza_file,
                                                   idxs_deem_train, idxs_deem_test,
                                                   idxs_iza_train, idxs_iza_test,
                                                   idxs_iza_delete=[RWY],
                                                   train_test_concatenate=True)

        # Scale the SOAPs so they are of a 'usable' magnitude for the SVC
        soaps_train, soaps_test = utils.preprocess_soaps(soaps_train, soaps_test)
        
        # Column indices of each species(-pair) feature group; the groups are
        # paired with the names in group_names by position in the zip below
        n_features = soaps_train.shape[1]
        feature_groups = extract_species_pair_groups(n_features, n_species, 
                                                     spectrum_type=spectrum_type,
                                                     combinations=True)

        for species_pairing, feature_idxs in zip(tqdm(group_names[spectrum_type], 
                                                      desc='Species', leave=False),
                                                 feature_groups):
            
            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):
                                
                # Prepare outputs
                output_dir = f'Linear_Models/LSVC-LPCovR/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                
                os.makedirs(f'{deem_dir}/{cutoff}/{output_dir}', exist_ok=True)
                os.makedirs(f'{iza_dir}/{cutoff}/{output_dir}', exist_ok=True)
                
                svc_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
                svc_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'

                svc_cantons_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_cantons.dat'
                svc_cantons_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_cantons.dat'
                
                # Location of optimized parameters; currently unused because the
                # global defaults in svc_kwargs are applied instead
                parameter_dir = f'{linear_dir}/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                svc_parameter_file = f'{parameter_dir}/svc_parameters.json'

                # Run LSVC
                # TODO: save classification reports and confusion matrices?
                svc_parameters = svc_kwargs['linear'].copy() #load_json(svc_parameter_file) ###
                df_train, df_test, predicted_cantons_train, predicted_cantons_test = \
                    utils.do_svc(soaps_train[:, feature_idxs], soaps_test[:, feature_idxs], 
                                 cantons_train[n_cantons], cantons_test[n_cantons], 
                                 svc_type='linear', **svc_parameters,
                                 outputs=['decision_functions', 'predictions'])

                # Save IZA and DEEM LSVC decision functions
                utils.split_and_save(df_train, df_test,
                                     idxs_deem_train, idxs_deem_test,
                                     deem_train_slice, deem_test_slice,
                                     output=svc_df_deem_file, output_format='%f',
                                     hdf5_attrs=None)
                
                utils.split_and_save(df_train, df_test,
                                     idxs_iza_train, idxs_iza_test,
                                     iza_train_slice, iza_test_slice,
                                     output=svc_df_iza_file, output_format='%f',
                                     hdf5_attrs=None)

                # Save IZA and DEEM LSVC canton predictions
                # Cantons are integer class labels, so they are written with '%d',
                # matching how the KPCovR canton predictions are saved
                utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
                                     idxs_deem_train, idxs_deem_test,
                                     deem_train_slice, deem_test_slice,
                                     output=svc_cantons_deem_file, output_format='%d',
                                     hdf5_attrs=None)
                
                utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
                                     idxs_iza_train, idxs_iza_test,
                                     iza_train_slice, iza_test_slice,
                                     output=svc_cantons_iza_file, output_format='%d',
                                     hdf5_attrs=None)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

OO


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.80      0.04      0.07       113
           2       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.88      0.52      0.52      2363
weighted avg       0.95      0.95      0.93      2363

[[   4  109]
 [   1 2249]]
              precision    recall  f1-score   support

           1       0.00      0.00      0.00        14
           2       0.00      0.00      0.00        68
           3       0.33      0.03      0.06        31
           4       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.32      0.26      0.26      2363
weighted avg       0.91      0.95      0.93      2363

[[   0    0    0   14]
 [   0    0    2   66]
 [   0    0    1   30]
 [   0    0    0 2250]]
OSi


  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.80      0.11      0.19       113
           2       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.88      0.55      0.58      2363
weighted avg       0.95      0.96      0.94      2363

[[  12  101]
 [   3 2247]]
              precision    recall  f1-score   support

           1       0.00      0.00      0.00        14
           2       0.33      0.01      0.03        68
           3       0.20      0.03      0.06        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.37      0.26      0.27      2363
weighted avg       0.92      0.95      0.93      2363

[[   0    1    0   13]
 [   0    1    4   63]
 [   0    1    1   29]
 [   0    0    0 2250]]
SiSi


  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.95      0.19      0.31       113
           2       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.96      0.59      0.65      2363
weighted avg       0.96      0.96      0.95      2363

[[  21   92]
 [   1 2249]]
              precision    recall  f1-score   support

           1       0.20      0.07      0.11        14
           2       0.62      0.07      0.13        68
           3       0.33      0.03      0.06        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.53      0.29      0.32      2363
weighted avg       0.94      0.96      0.94      2363

[[   1    2    0   11]
 [   4    5    2   57]
 [   0    1    1   29]
 [   0    0    0 2250]]
OO+OSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.78      0.12      0.21       113
           2       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.87      0.56      0.60      2363
weighted avg       0.95      0.96      0.94      2363

[[  14   99]
 [   4 2246]]
              precision    recall  f1-score   support

           1       0.00      0.00      0.00        14
           2       0.67      0.03      0.06        68
           3       0.33      0.10      0.15        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.49      0.28      0.30      2363
weighted avg       0.93      0.95      0.93      2363

[[   0    1    0   13]
 [   0    2    5   61]
 [   0    0    3   28]
 [   0    0    1 2249]]
OO+SiSi


  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.88      0.39      0.54       113
           2       0.97      1.00      0.98      2250

    accuracy                           0.97      2363
   macro avg       0.93      0.69      0.76      2363
weighted avg       0.97      0.97      0.96      2363

[[  44   69]
 [   6 2244]]
              precision    recall  f1-score   support

           1       0.29      0.14      0.19        14
           2       0.73      0.16      0.27        68
           3       0.29      0.06      0.11        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.57      0.34      0.39      2363
weighted avg       0.94      0.96      0.94      2363

[[   2    3    0    9]
 [   4   11    4   49]
 [   0    1    2   28]
 [   1    0    1 2248]]
OSi+SiSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.88      0.33      0.48       113
           2       0.97      1.00      0.98      2250

    accuracy                           0.97      2363
   macro avg       0.92      0.66      0.73      2363
weighted avg       0.96      0.97      0.96      2363

[[  37   76]
 [   5 2245]]
              precision    recall  f1-score   support

           1       0.50      0.14      0.22        14
           2       0.78      0.21      0.33        68
           3       0.29      0.06      0.11        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.63      0.35      0.41      2363
weighted avg       0.95      0.96      0.95      2363

[[   2    2    0   10]
 [   1   14    5   48]
 [   0    2    2   27]
 [   1    0    0 2249]]
OO+OSi+SiSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.87      0.42      0.56       113
           2       0.97      1.00      0.98      2250

    accuracy                           0.97      2363
   macro avg       0.92      0.71      0.77      2363
weighted avg       0.97      0.97      0.96      2363

[[  47   66]
 [   7 2243]]
              precision    recall  f1-score   support

           1       0.33      0.14      0.20        14
           2       0.70      0.21      0.32        68
           3       0.25      0.06      0.10        31
           4       0.97      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.56      0.35      0.40      2363
weighted avg       0.94      0.96      0.95      2363

[[   2    2    0   10]
 [   3   14    5   46]
 [   0    4    2   25]
 [   1    0    1 2248]]


HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

O


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.00      0.00      0.00       113
           2       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.48      0.50      0.49      2363
weighted avg       0.91      0.95      0.93      2363

[[   0  113]
 [   0 2250]]
              precision    recall  f1-score   support

           1       0.00      0.00      0.00        14
           2       0.00      0.00      0.00        68
           3       0.00      0.00      0.00        31
           4       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.24      0.25      0.24      2363
weighted avg       0.91      0.95      0.93      2363

[[   0    0    0   14]
 [   0    0    0   68]
 [   0    0    0   31]
 [   0    0    0 2250]]
Si


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.00      0.00      0.00       113
           2       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.48      0.50      0.49      2363
weighted avg       0.91      0.95      0.93      2363

[[   0  113]
 [   0 2250]]
              precision    recall  f1-score   support

           1       0.00      0.00      0.00        14
           2       0.00      0.00      0.00        68
           3       0.00      0.00      0.00        31
           4       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.24      0.25      0.24      2363
weighted avg       0.91      0.95      0.93      2363

[[   0    0    0   14]
 [   0    0    0   68]
 [   0    0    0   31]
 [   0    0    0 2250]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


O+Si


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.00      0.00      0.00       113
           2       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.48      0.50      0.49      2363
weighted avg       0.91      0.95      0.93      2363

[[   0  113]
 [   0 2250]]
              precision    recall  f1-score   support

           1       0.00      0.00      0.00        14
           2       0.00      0.00      0.00        68
           3       0.00      0.00      0.00        31
           4       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.24      0.25      0.24      2363
weighted avg       0.91      0.95      0.93      2363

[[   0    0    0   14]
 [   0    0    0   68]
 [   0    0    0   31]
 [   0    0    0 2250]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

OO


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.92      0.71      0.80       113
           2       0.99      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.95      0.85      0.90      2363
weighted avg       0.98      0.98      0.98      2363

[[  80   33]
 [   7 2243]]
              precision    recall  f1-score   support

           1       0.50      0.43      0.46        14
           2       0.75      0.49      0.59        68
           3       0.32      0.29      0.31        31
           4       0.98      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.64      0.55      0.59      2363
weighted avg       0.97      0.97      0.97      2363

[[   6    3    0    5]
 [   5   33   14   16]
 [   0    8    9   14]
 [   1    0    5 2244]]
OSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.90      0.75      0.82       113
           2       0.99      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.95      0.87      0.91      2363
weighted avg       0.98      0.98      0.98      2363

[[  85   28]
 [   9 2241]]
              precision    recall  f1-score   support

           1       0.33      0.36      0.34        14
           2       0.77      0.50      0.61        68
           3       0.48      0.42      0.45        31
           4       0.99      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.64      0.57      0.60      2363
weighted avg       0.97      0.97      0.97      2363

[[   5    3    0    6]
 [   6   34   11   17]
 [   1    6   13   11]
 [   3    1    3 2243]]
SiSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.94      0.58      0.72       113
           2       0.98      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.96      0.79      0.86      2363
weighted avg       0.98      0.98      0.98      2363

[[  66   47]
 [   4 2246]]
              precision    recall  f1-score   support

           1       0.44      0.29      0.35        14
           2       0.76      0.47      0.58        68
           3       0.46      0.19      0.27        31
           4       0.98      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.66      0.49      0.55      2363
weighted avg       0.96      0.97      0.96      2363

[[   4    3    0    7]
 [   3   32    6   27]
 [   0    7    6   18]
 [   2    0    1 2247]]
OO+OSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.90      0.75      0.82       113
           2       0.99      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.95      0.87      0.91      2363
weighted avg       0.98      0.98      0.98      2363

[[  85   28]
 [   9 2241]]
              precision    recall  f1-score   support

           1       0.33      0.36      0.34        14
           2       0.74      0.50      0.60        68
           3       0.37      0.32      0.34        31
           4       0.99      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.61      0.54      0.57      2363
weighted avg       0.97      0.97      0.97      2363

[[   5    4    0    5]
 [   7   34   13   14]
 [   1    7   10   13]
 [   2    1    4 2243]]
OO+SiSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.90      0.66      0.77       113
           2       0.98      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.94      0.83      0.88      2363
weighted avg       0.98      0.98      0.98      2363

[[  75   38]
 [   8 2242]]
              precision    recall  f1-score   support

           1       0.42      0.36      0.38        14
           2       0.77      0.49      0.59        68
           3       0.38      0.32      0.35        31
           4       0.98      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.64      0.54      0.58      2363
weighted avg       0.97      0.97      0.97      2363

[[   5    3    0    6]
 [   6   33   13   16]
 [   0    7   10   14]
 [   1    0    3 2246]]
OSi+SiSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.90      0.76      0.82       113
           2       0.99      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.94      0.88      0.91      2363
weighted avg       0.98      0.98      0.98      2363

[[  86   27]
 [  10 2240]]
              precision    recall  f1-score   support

           1       0.31      0.36      0.33        14
           2       0.77      0.53      0.63        68
           3       0.50      0.42      0.46        31
           4       0.99      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.64      0.58      0.60      2363
weighted avg       0.97      0.97      0.97      2363

[[   5    4    0    5]
 [   7   36   11   14]
 [   1    6   13   11]
 [   3    1    2 2244]]
OO+OSi+SiSi


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.91      0.74      0.82       113
           2       0.99      1.00      0.99      2250

    accuracy                           0.98      2363
   macro avg       0.95      0.87      0.91      2363
weighted avg       0.98      0.98      0.98      2363

[[  84   29]
 [   8 2242]]
              precision    recall  f1-score   support

           1       0.36      0.36      0.36        14
           2       0.75      0.49      0.59        68
           3       0.41      0.39      0.40        31
           4       0.99      1.00      0.99      2250

    accuracy                           0.97      2363
   macro avg       0.63      0.56      0.58      2363
weighted avg       0.97      0.97      0.97      2363

[[   5    3    0    6]
 [   7   33   13   15]
 [   1    7   12   11]
 [   1    1    4 2244]]


HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

O


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.83      0.22      0.35       113
           2       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.90      0.61      0.66      2363
weighted avg       0.96      0.96      0.95      2363

[[  25   88]
 [   5 2245]]
              precision    recall  f1-score   support

           1       0.00      0.00      0.00        14
           2       0.64      0.10      0.18        68
           3       0.50      0.03      0.06        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.52      0.28      0.30      2363
weighted avg       0.94      0.96      0.94      2363

[[   0    2    0   12]
 [   0    7    1   60]
 [   0    2    1   28]
 [   1    0    0 2249]]
Si


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.00      0.00      0.00       113
           2       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.48      0.50      0.49      2363
weighted avg       0.91      0.95      0.93      2363

[[   0  113]
 [   0 2250]]
              precision    recall  f1-score   support

           1       0.00      0.00      0.00        14
           2       0.00      0.00      0.00        68
           3       0.00      0.00      0.00        31
           4       0.95      1.00      0.98      2250

    accuracy                           0.95      2363
   macro avg       0.24      0.25      0.24      2363
weighted avg       0.91      0.95      0.93      2363

[[   0    0    0   14]
 [   0    0    0   68]
 [   0    0    0   31]
 [   0    0    0 2250]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


O+Si


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

              precision    recall  f1-score   support

           1       0.84      0.23      0.36       113
           2       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.90      0.61      0.67      2363
weighted avg       0.96      0.96      0.95      2363

[[  26   87]
 [   5 2245]]
              precision    recall  f1-score   support

           1       0.50      0.07      0.12        14
           2       0.67      0.09      0.16        68
           3       0.20      0.03      0.06        31
           4       0.96      1.00      0.98      2250

    accuracy                           0.96      2363
   macro avg       0.58      0.30      0.33      2363
weighted avg       0.94      0.96      0.94      2363

[[   1    1    0   12]
 [   0    6    4   58]
 [   0    2    1   28]
 [   1    0    0 2249]]



## LR check

In [67]:
# LR sanity check: for every cutoff / spectrum type / species grouping /
# number of classes, hand the preprocessed SOAP features and the stored
# SVC decision functions to `utils.regression_check` (linear regression).
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):

    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        # Capitalized spectrum name is used as a directory component below
        spectrum_name = spectrum_type.capitalize()

        # SOAP files for this cutoff and spectrum type
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'

        soaps_train, soaps_test = utils.load_soaps(
            deem_file, iza_file,
            idxs_deem_train, idxs_deem_test,
            idxs_iza_train, idxs_iza_test,
            idxs_iza_delete=[RWY],
            train_test_concatenate=True
        )

        # Feature (column) indices belonging to each species grouping
        n_features = soaps_train.shape[1]
        feature_groups = extract_species_pair_groups(
            n_features, n_species,
            spectrum_type=spectrum_type,
            combinations=True
        )

        # Only the group names are wrapped in the progress bar; they are
        # paired with the feature index groups in order.
        species_progress = tqdm(group_names[spectrum_type],
                                desc='Species', leave=False)
        for species_pairing, feature_idxs in zip(species_progress, feature_groups):

            # Select this grouping's SOAP columns and preprocess them
            # like the decision functions (i.e., center and scale)
            # for the regression.
            (x_train, x_test,
             x_center, x_scale) = utils.preprocess_data(soaps_train[:, feature_idxs],
                                                        soaps_test[:, feature_idxs])

            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):

                # Locate the stored SVC decision functions
                input_dir = f'Linear_Models/LSVC-LPCovR/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                svc_df_deem_file = f'{deem_dir}/{cutoff}/{input_dir}/svc_structure_dfs.dat'
                svc_df_iza_file = f'{iza_dir}/{cutoff}/{input_dir}/svc_structure_dfs.dat'

                df_train, df_test = utils.load_data(
                    svc_df_deem_file, svc_df_iza_file,
                    idxs_deem_train, idxs_deem_test,
                    idxs_iza_train, idxs_iza_test
                )

                # Center and scale the decision functions
                (df_train, df_test,
                 df_center, df_scale) = utils.preprocess_data(df_train, df_test)

                # Check that LR can reproduce the decision functions
                utils.regression_check(x_train, x_test, df_train, df_test,
                                       regression_type='linear')

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.0749010089037303e-06
3.0315473299287485e-06
[3.63892498e-07 3.97145911e-07 6.21325129e-07 5.35731276e-07]
[5.98772760e-07 8.42157903e-07 1.59321916e-06 1.51619337e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.2388988022683794e-06
2.4732538551147494e-06
[3.77787233e-07 3.82632520e-07 6.11368178e-07 6.09921324e-07]
[7.16662125e-07 7.63164674e-07 1.15149195e-06 1.22969027e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

9.238253136990646e-07
1.0976398878215434e-06
[3.28747386e-07 4.06701489e-07 9.31504384e-07 4.62960654e-07]
[3.41776585e-07 4.77889541e-07 1.09934635e-06 5.49176257e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.22460679799062e-06
7.319865274627528e-06
[5.07444110e-07 7.35669214e-07 7.73564136e-07 1.06058247e-06]
[1.56932947e-06 2.15473845e-06 3.41141888e-06 3.47096372e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.0661227947568726e-06
2.9707262768848435e-06
[3.13497376e-07 4.02394307e-07 5.95798619e-07 5.37496446e-07]
[8.64249236e-07 8.85993631e-07 1.66647437e-06 1.49423390e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.8984352942797284e-06
8.307526980538326e-06
[4.07208401e-07 8.61568641e-07 1.49049333e-06 1.38378658e-06]
[9.13502811e-07 2.45133130e-06 3.77538131e-06 3.98713889e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.0689722921622423e-06
9.448480451652057e-06
[6.44646981e-07 7.13733480e-07 1.11235193e-06 9.68614097e-07]
[2.30612901e-06 3.01172591e-06 5.74816206e-06 4.49033679e-06]


HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.024694752283358e-06
1.9879084004926056e-06
[1.15609247e-06 5.24506625e-07 5.59488524e-06 1.01234737e-06]
[1.18875682e-06 5.19566836e-07 5.59876834e-06 9.93954185e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

3.9402097333571434e-06
3.8872535333406995e-06
[2.05544565e-06 1.68491055e-06 2.36873350e-05 1.97010338e-06]
[2.04946843e-06 1.69045669e-06 2.39655500e-05 1.94362313e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.2215553125479915e-06
2.247651781983556e-06
[1.02689589e-06 5.50799852e-07 6.37523276e-06 1.11077775e-06]
[1.05030851e-06 5.50055203e-07 6.38953446e-06 1.12382529e-06]


HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.708602341800333e-07
2.2551931200933813e-07
[1.47090252e-07 5.67775605e-08 7.86535299e-08 9.07553532e-08]
[1.83302979e-07 6.66205913e-08 1.00018162e-07 1.16945564e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.1463225685860387e-06
1.4671853539861524e-06
[5.99061879e-07 6.57609126e-07 4.63051348e-07 5.52743935e-07]
[8.11864422e-07 8.77825851e-07 5.88426970e-07 7.08105608e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

8.541777460937013e-07
9.704272888066993e-07
[3.81579891e-07 4.17598382e-07 3.00549261e-07 4.25904384e-07]
[4.38390815e-07 4.51680871e-07 3.43563578e-07 4.76647083e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.3482344581074784e-06
2.1962328337158324e-06
[5.57240891e-07 6.50360054e-07 5.10530786e-07 6.60151018e-07]
[9.79883293e-07 1.12829095e-06 8.47947306e-07 1.06818369e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

9.322477076527293e-07
1.3424293686597926e-06
[5.28166900e-07 2.35175628e-07 3.85013344e-07 4.80342923e-07]
[7.23748288e-07 3.49264738e-07 5.52031476e-07 6.89909512e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.8278744981731676e-06
2.722550555833599e-06
[1.31135261e-06 1.25735269e-06 1.02164665e-06 9.25669634e-07]
[1.99788962e-06 1.93599905e-06 1.55547428e-06 1.37683803e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.921651111044156e-06
3.4893581600025316e-06
[1.22997719e-06 1.04590112e-06 8.46045106e-07 9.69585136e-07]
[2.28449790e-06 2.06159709e-06 1.55793994e-06 1.75741947e-06]


HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

3.553851952625699e-07
3.599934929469955e-07
[4.60249422e-07 2.42833622e-07 6.05090102e-07 1.76127927e-07]
[4.54021498e-07 2.42779982e-07 5.95550458e-07 1.78952114e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.5178059225795018e-06
1.5651067364030127e-06
[9.39259115e-07 9.48844572e-07 9.80167710e-07 7.69554547e-07]
[9.39077517e-07 9.31584504e-07 9.48550040e-07 7.53678811e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

3.6013261140750726e-07
3.6158475163623923e-07
[3.59404104e-07 2.34158969e-07 5.31608426e-07 1.78960498e-07]
[3.57532063e-07 2.32741821e-07 5.28003675e-07 1.79144505e-07]



## PCovR

In [82]:
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    linear_dir = f'{model_dir}/{cutoff}/Linear_Models/LSVC-LPCovR'
    
    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        
        # Load SOAPs
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'

        soaps_train, soaps_test = utils.load_soaps(deem_file, iza_file,
                                                   idxs_deem_train, idxs_deem_test,
                                                   idxs_iza_train, idxs_iza_test,
                                                   idxs_iza_delete=[RWY],
                                                   train_test_concatenate=True)
        
        # Scale the SOAPs so they are of a 'usable' magnitude for the SVC
        soaps_train, soaps_test = utils.preprocess_soaps(soaps_train, soaps_test)
        
        n_features = soaps_train.shape[1]
        # Feature-column index groups for the current spectrum type. They are
        # paired positionally (via zip) with the names in
        # group_names[spectrum_type] in the loop below; zip stops at the shorter
        # of the two, so a length mismatch would be silent.
        feature_groups = extract_species_pair_groups(n_features, n_species, 
                                                     spectrum_type=spectrum_type,
                                                     combinations=True)
        
        # For each species pairing: restrict the SOAP features to that pairing's
        # columns, then (for 2 and 4 classes) run linear PCovR against the SVC
        # decision functions and save the results for DEEM and IZA.
        for species_pairing, feature_idxs in zip(tqdm(group_names[spectrum_type], 
                                                      desc='Species', leave=False),
                                                 feature_groups):
            
            # Keep only the feature columns belonging to this species pairing
            x_train = soaps_train[:, feature_idxs]
            x_test = soaps_test[:, feature_idxs]

            # Preprocess the SOAPs like the decision functions
            # (i.e., center and scale) for the regression.
            # Done once per species pairing, before the class loop, because
            # these features do not depend on n_cantons.
            # x_center and x_scale are not referenced again in the visible
            # part of this cell.
            x_train, x_test, x_center, x_scale = \
                utils.preprocess_data(x_train, x_test)
            
            # Repeat the analysis for the 2-class and the 4-class problem
            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):
                
                # Prepare outputs
                # All paths below are built from output_dir, placed under the
                # DEEM and IZA directories for the current cutoff.
                output_dir = f'Linear_Models/LSVC-LPCovR/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                
                # Inputs: SVC decision functions (read by utils.load_data below)
                svc_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
                svc_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
                
                # Outputs: PCovR projections (T_train / T_test)
                pcovr_projection_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structures.hdf5'
                pcovr_projection_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structures.hdf5'
                
                # Outputs: PCovR decision functions (dfp_train / dfp_test)
                pcovr_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structure_dfs.dat'
                pcovr_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structure_dfs.dat'
                
                # Outputs: canton predictions derived from the PCovR decision functions
                pcovr_cantons_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structure_cantons.dat'
                pcovr_cantons_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structure_cantons.dat'
                
                # Location of the stored PCovR parameters. In the visible code
                # pcovr_parameter_file is only referenced by the commented-out
                # load_json call further down.
                parameter_dir = f'{linear_dir}/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                pcovr_parameter_file = f'{parameter_dir}/pcovr_parameters.json'
                
                # Load the SVC decision functions from the DEEM and IZA files as
                # train and test sets (how the two sources are combined is
                # decided inside utils.load_data -- not visible here).
                df_train, df_test = utils.load_data(svc_df_deem_file, svc_df_iza_file,
                                                    idxs_deem_train, idxs_deem_test,
                                                    idxs_iza_train, idxs_iza_test)
            
                # Center and scale the decision functions
                # df_center and df_scale are reused below when post-processing.
                df_train, df_test, df_center, df_scale = \
                    utils.preprocess_data(df_train, df_test)
                
                # Run PCovR
                # NOTE(review): pcovr_parameters is currently a copy of the
                # *SVC* keyword arguments (svc_kwargs['linear']); loading the
                # PCovR parameters from pcovr_parameter_file is commented out.
                # It is not passed to utils.do_pcovr and is only stored as
                # hdf5_attrs on the projection files below -- confirm that this
                # is intentional and not a leftover debugging shortcut.
                pcovr_parameters = svc_kwargs['linear'].copy() #load_json(pcovr_parameter_file) ###
                # T_*: PCovR projections; dfp_*: PCovR decision functions
                # (as they are named and saved further down).
                T_train, T_test, dfp_train, dfp_test = \
                    utils.do_pcovr(x_train, x_test, df_train, df_test, pcovr_type='linear')

                # Post process the PCovR decision functions
                # (i.e., turn them back into canton predictions)
                # df_center / df_scale are presumably used to undo the
                # centering and scaling applied above -- confirm in utils.
                # NOTE(review): svc_parameters is not assigned in the visible
                # part of this cell; verify it refers to the SVC settings that
                # produced the decision functions loaded above and is not stale
                # kernel state.
                predicted_cantons_train, predicted_cantons_test = \
                    utils.postprocess_decision_functions(dfp_train, dfp_test, df_center, df_scale,
                                                         df_type=svc_parameters['multi_class'],
                                                         n_classes=n_cantons)
                
                # Each quantity below is written twice: once with the DEEM
                # indices/slices to the DEEM file, once with the IZA
                # indices/slices to the IZA file.

                # Save IZA and DEEM PCovR projections
                # (the only saves that attach pcovr_parameters as hdf5_attrs)
                utils.split_and_save(T_train, T_test,
                                     idxs_deem_train, idxs_deem_test,
                                     deem_train_slice, deem_test_slice,
                                     output=pcovr_projection_deem_file, output_format='%f',
                                     hdf5_attrs=pcovr_parameters)
                
                utils.split_and_save(T_train, T_test,
                                     idxs_iza_train, idxs_iza_test,
                                     iza_train_slice, iza_test_slice,
                                     output=pcovr_projection_iza_file, output_format='%f',
                                     hdf5_attrs=pcovr_parameters)

                # Save IZA and DEEM PCovR decision functions
                utils.split_and_save(dfp_train, dfp_test,
                                     idxs_deem_train, idxs_deem_test,
                                     deem_train_slice, deem_test_slice,
                                     output=pcovr_df_deem_file, output_format='%f',
                                     hdf5_attrs=None)
                
                utils.split_and_save(dfp_train, dfp_test,
                                     idxs_iza_train, idxs_iza_test,
                                     iza_train_slice, iza_test_slice,
                                     output=pcovr_df_iza_file, output_format='%f',
                                     hdf5_attrs=None)

                # Save IZA and DEEM PCovR canton predictions
                # (integer format '%d', unlike the float-formatted saves above)
                utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
                                     idxs_deem_train, idxs_deem_test,
                                     deem_train_slice, deem_test_slice,
                                     output=pcovr_cantons_deem_file, output_format='%d',
                                     hdf5_attrs=None)
                
                utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
                                     idxs_iza_train, idxs_iza_test,
                                     iza_train_slice, iza_test_slice,
                                     output=pcovr_cantons_iza_file, output_format='%d',
                                     hdf5_attrs=None)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…




# Logistic Regression