In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np

# Plotting
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ML
from soap import extract_species_pair_groups

from sklearn.svm import LinearSVC
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score
from sklearn.linear_model import Ridge

# Utilities
import h5py
import json
import itertools
from tqdm.notebook import tqdm
import project_utils as utils
from tools import load_json, save_json

# Import COSMO style toolkit
import cosmoplot.colorbars as cosmocbars
import cosmoplot.utils as cosmoutils
import cosmoplot.style as cosmostyle

# Apply the 'article' style from the COSMO style toolkit and keep a handle
# on the toolkit's color cycle
cosmostyle.set_style('article')
colorList = cosmostyle.color_cycle

# Load train and test splits

In [5]:
# Read the SOAP hyperparameter file and keep the list of interaction cutoffs,
# which drives the outermost loop of every model cell
with open('../Processed_Data/soap_hyperparameters.json', 'r') as hyperparameter_file:
    soap_hyperparameters = json.load(hyperparameter_file)

cutoffs = soap_hyperparameters['interaction_cutoff']

In [6]:
# Read the Deem train/test index files (integer indices) and count
# how many Deem structures the two splits cover together
deem_train_idxs = np.loadtxt('../Processed_Data/DEEM_330k/train.idxs', dtype=int)
deem_test_idxs = np.loadtxt('../Processed_Data/DEEM_330k/test.idxs', dtype=int)
n_deem = sum(len(split_idxs) for split_idxs in (deem_train_idxs, deem_test_idxs))

7750 2250


In [None]:
# Read the IZA train/test index files (integer indices) and count
# how many IZA structures the two splits cover together
iza_train_idxs = np.loadtxt('../Processed_Data/IZA_226/train.idxs', dtype=int)
iza_test_idxs = np.loadtxt('../Processed_Data/IZA_226/test.idxs', dtype=int)
n_iza = sum(len(split_idxs) for split_idxs in (iza_train_idxs, iza_test_idxs))

In [9]:
# Load IZA cantons
# Only the second column of the file (usecols=1, zero-based) is read, as integers
# NOTE(review): presumably one canton label per IZA structure -- confirm the
# file has as many rows as there are IZA structures
iza_cantons = np.loadtxt('../Raw_Data/GULP/IZA_226/cantons.txt', usecols=1, dtype=int)

In [11]:
# Build set of "master" canton labels, keyed by the number of classes.
# Both label vectors list the IZA structures first and the Deem structures after.
cantons = {}

# 4-class labels: the IZA cantons as loaded from file, followed by class 4
# for every Deem structure.
# (The loaded labels are named `iza_cantons`; the previous `cantons_iza`
# was never defined and raised a NameError.)
# NOTE(review): assumes iza_cantons has n_iza entries so that cantons[4] and
# cantons[2] have the same length -- confirm against the cantons file
cantons[4] = np.concatenate((
    iza_cantons,
    np.full(n_deem, 4, dtype=int)
))

# 2-class labels: class 1 for every IZA structure, class 2 for every Deem structure
cantons[2] = np.concatenate((
    np.ones(n_iza, dtype=int),
    np.full(n_deem, 2, dtype=int)
))

In [12]:
# Dummy canton labels for Deem, keyed by the number of classes (2 or 4)
dummy_cantons = {
    2: np.loadtxt('../Processed_Data/DEEM_330k/dummy_cantons_2-class.dat', dtype=int),
    4: np.loadtxt('../Processed_Data/DEEM_330k/dummy_cantons_4-class.dat', dtype=int),
}

In [None]:
# Merge the IZA and Deem splits into a single index space: the IZA indices
# are kept as they are and every Deem index is shifted up by n_iza
train_idxs = np.concatenate([iza_train_idxs, n_iza + deem_train_idxs])
test_idxs = np.concatenate([iza_test_idxs, n_iza + deem_test_idxs])

# Model setup

In [13]:
# Root directory for the models; the per-cutoff model directories used in the
# SVC cells are built underneath it
model_dir = '../Processed_Data/Models'

In [14]:
# Human-readable class names, keyed by the number of classes
class_names = {
    2: ['IZA', 'DEEM'],
    4: ['IZA1', 'IZA2', 'IZA3', 'DEEM'],
}

In [21]:
# Global model parameters
# TODO: or use .get_params()?
# TODO: load from optimization?
# TODO: should we make sure break_ties=True?

# NOTE: with class_weight='balanced', the class weights times the class frequencies sum to one
# NOTE: LSVC decision functions are OVR (or Crammer-Singer) only; KSVC decision functions
# are always computed with OVO, but can be converted to OVR. Hence we use OVR with both
# LSVC and KSVC for consistency
# svc_kwargs = dict(
#     penalty='l2',
#     loss='squared_hinge',
#     dual=False,
#     multi_class='ovr',
#     #class_weight=None,
#     class_weight='balanced',
#     fit_intercept=True,
#     intercept_scaling=1.0,
#     tol=1.0E-3,
#     max_iter=5000,
#     C=0.01
# )

# TODO: load optimal parameters

# Linear SVC

In [17]:
# Linear model setup
# Number of species, and the names of the feature groups per spectrum type:
# the single groups first, then the pairwise combinations, then the full set
n_species = 2
group_names = {
    'power': [
        'OO', 'OSi', 'SiSi',
        'OO+OSi', 'OO+SiSi', 'OSi+SiSi',
        'OO+OSi+SiSi',
    ],
    'radial': [
        'O', 'Si',
        'O+Si',
    ],
}

In [18]:
# Dataset names and the data directory that belongs to each of them
deem_name, iza_name = 'DEEM_330k', 'IZA_226'

deem_dir = f'../Processed_Data/{deem_name}/Data'
iza_dir = f'../Processed_Data/{iza_name}/Data'

In [19]:
# TODO: SAVE THE SOAP SCALING!

In [22]:
# Linear SVC setup: for every cutoff, spectrum type, species pairing and number
# of classes, create the output directories and build the file names for the
# decision functions, canton predictions, weights and model parameters.
# The actual SVC run is still commented out pending the pipeline rewrite (see TODOs).
# NOTE(review): `soaps_train` and `all_deem_dir` are not defined in any cell
# visible in this notebook (the SOAP loading is still a TODO), so this cell
# presumably relies on leftover kernel state -- confirm before a fresh run.
# NOTE(review): the lines tagged with ##### look like temporary filters that
# restrict the run to a single configuration -- confirm before removing them.
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    linear_dir = f'{model_dir}/{cutoff}/Linear_Models/SVC'
    if cutoff == 3.5: #####
        continue #####

    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        if spectrum_type == 'radial': #####
            continue #####

        # Load SOAPs
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'

        # TODO: new SOAP loading with concatenation

        n_features = soaps_train.shape[1]
        feature_groups = extract_species_pair_groups(n_features, n_species,
                                                     spectrum_type=spectrum_type,
                                                     combinations=True)

        # Prepare loading of the DEEM 330k structures.
        # The file is opened read-only in a context manager so that it is
        # closed even when an iteration below raises; the previous trailing
        # f.close() was skipped on any exception and leaked the handle.
        all_deem_file = f'{all_deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        with h5py.File(all_deem_file, 'r') as f:
            deem_330k_dataset = f['0']

            for species_pairing, feature_idxs in zip(tqdm(group_names[spectrum_type],
                                                          desc='Species', leave=False),
                                                     feature_groups):

                if species_pairing != 'OO+OSi+SiSi': #####
                    continue #####

                # Scale the SOAPs so they are of a 'usable' magnitude for the SVM
                # TODO: SOAP feature selection (before scaling!)
                # TODO: new SOAP preprocessing with pipeline

                for n_cantons in tqdm((2, 4), desc='Classes', leave=False):

                    if n_cantons == 4: #####
                        continue #####

                    # Prepare outputs
                    output_dir = f'Linear_Models/SVC/{n_cantons}-Class/{spectrum_name}/{species_pairing}'

                    # exist_ok=True replaces the separate os.path.exists check,
                    # so an already existing directory is simply left alone
                    os.makedirs(f'{deem_dir}/{cutoff}/{output_dir}', exist_ok=True)
                    os.makedirs(f'{all_deem_dir}/{cutoff}/{output_dir}', exist_ok=True)
                    os.makedirs(f'{iza_dir}/{cutoff}/{output_dir}', exist_ok=True)

                    # Decision-function outputs
                    svc_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
                    svc_df_all_deem_file = f'{all_deem_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
                    svc_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'

                    # Predicted-canton outputs
                    svc_cantons_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_cantons.dat'
                    svc_cantons_all_deem_file = f'{all_deem_dir}/{cutoff}/{output_dir}/svc_structure_cantons.dat'
                    svc_cantons_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_cantons.dat'

                    # Model parameter, model and weight files
                    parameter_dir = f'{linear_dir}/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                    svc_parameter_file = f'{parameter_dir}/svc_parameters.json'
                    svc_model_file = f'{parameter_dir}/svc.json'

                    svc_weights_file = f'{parameter_dir}/svc_weights.dat'

                    # TODO: run SVC with pipeline and optimal parameters

                    # Run LSVC
#                     svc_parameters = svc_kwargs['linear'].copy() #load_json(svc_parameter_file) ###
#                     df_train, df_test, predicted_cantons_train, predicted_cantons_test, weights = \
#                         utils.do_svc(soaps_train[:, feature_idxs], soaps_test[:, feature_idxs],
#                                      cantons_train[n_cantons], cantons_test[n_cantons],
#                                      svc_type='linear', **svc_parameters,
#                                      outputs=['decision_functions', 'predictions', 'weights'],
#                                      save_model=svc_model_file)

#                     # Save IZA and DEEM LSVC decision functions
#                     utils.split_and_save(df_train, df_test,
#                                          idxs_deem_train, idxs_deem_test,
#                                          deem_train_slice, deem_test_slice,
#                                          output=svc_df_deem_file, output_format='%f',
#                                          hdf5_attrs=None)

#                     utils.split_and_save(df_train, df_test,
#                                          idxs_iza_train, idxs_iza_test,
#                                          iza_train_slice, iza_test_slice,
#                                          output=svc_df_iza_file, output_format='%f',
#                                          hdf5_attrs=None)

#                     # Save IZA and DEEM LSVC canton predictions
#                     utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
#                                          idxs_deem_train, idxs_deem_test,
#                                          deem_train_slice, deem_test_slice,
#                                          output=svc_cantons_deem_file, output_format='%d',
#                                          hdf5_attrs=None)

#                     utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
#                                          idxs_iza_train, idxs_iza_test,
#                                          iza_train_slice, iza_test_slice,
#                                          output=svc_cantons_iza_file, output_format='%d',
#                                          hdf5_attrs=None)

#                     # Save weights
#                     # TODO: when constructing the real space density,
#                     # do we also have to account for the intercept?
#                     np.savetxt(svc_weights_file, weights)

#                     # Load the SVC model
#                     svc_model_dict = load_json(svc_model_file, array_convert=True)
#                     svc_model = LinearSVC()
#                     svc_model.__dict__ = svc_model_dict

#                     # Read the DEEM_330k structures and compute decision functions
#                     # and canton predictions

#                     # We could load up the whole dataset beforehand and probably make this faster
#                     # since we wouldn't have to read the HDF5 for both the DFs and predictions,
#                     # but the single-read seems to be memory-prohibitive and this
#                     # multiple-read construction appears tolerably fast
#                     # and also seems to use much less memory
#                     df_all_deem = svc_model.decision_function(deem_330k_dataset[:, feature_idxs]/soap_scale)
#                     predicted_cantons_all_deem = svc_model.predict(deem_330k_dataset[:, feature_idxs]/soap_scale)

#                     # Save DEEM 330k LSVC decision functions and canton predictions
#                     np.savetxt(svc_df_all_deem_file, df_all_deem)
#                     np.savetxt(svc_cantons_all_deem_file, predicted_cantons_all_deem, fmt='%d')

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…




## Check that a random split of DEEM can't be predicted

In [25]:
# Sanity check on dummy canton labels for Deem: for every cutoff, spectrum type,
# species pairing and number of classes, print the accuracy and the confusion
# matrix of the predicted cantons against the dummy cantons, for train and test.
# NOTE(review): `predicted_cantons_train` / `predicted_cantons_test` are only
# assigned in the commented-out LSVC run below, and `soaps_train`,
# `dummy_cantons_deem_train` and `dummy_cantons_deem_test` are not defined in any
# cell visible in this notebook (the dummy labels are loaded as `dummy_cantons`).
# As written the cell presumably relies on leftover kernel state -- confirm
# before a fresh Restart & Run All.
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    linear_dir = f'{model_dir}/{cutoff}/Linear_Models/SVC'
    
    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        
        # Load SOAPs
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        
        # TODO: new SOAP loading with concatenation
        
        # One group of feature indices per entry of group_names[spectrum_type]
        # (the two are zipped together below)
        n_features = soaps_train.shape[1]
        feature_groups = extract_species_pair_groups(n_features, n_species, 
                                                     spectrum_type=spectrum_type,
                                                     combinations=True)

        for species_pairing, feature_idxs in zip(tqdm(group_names[spectrum_type], 
                                                      desc='Species', leave=False),
                                                 feature_groups):
            
            # Scale the SOAPs so they are of a 'usable' magnitude for the SVC
            # TODO: SOAP feature selection (before scaling!)
            # TODO: new SOAP preprocessing with pipeline
            
            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):
                
                # Prepare outputs
                output_dir = f'Linear_Models/SVC/{n_cantons}-Class/{spectrum_name}/{species_pairing}'

                svc_cantons_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/dummy_svc_structure_cantons.dat'
                                
                parameter_dir = f'{linear_dir}/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                svc_parameter_file = f'{parameter_dir}/svc_parameters.json'
                
                svc_weights_file = f'{parameter_dir}/svc_weights.dat'
                
                # TODO: run SVC with pipeline

#                 # Run LSVC
#                 svc_parameters = svc_kwargs['linear'].copy() #load_json(svc_parameter_file) ###
#                 #svc_parameters['class_weight'] = None ###
#                 predicted_cantons_train, predicted_cantons_test = \
#                     utils.do_svc(soaps_train[:, feature_idxs], soaps_test[:, feature_idxs], 
#                                  dummy_cantons_deem_train[n_cantons], 
#                                  dummy_cantons_deem_test[n_cantons], 
#                                  svc_type='linear', **svc_parameters,
#                                  outputs=['predictions'])
                
#                 # Save DEEM LSVC canton predictions
#                 utils.split_and_save(predicted_cantons_train, predicted_cantons_test,
#                                      idxs_deem_train, idxs_deem_test,
#                                      slice(None), slice(None),
#                                      output=svc_cantons_deem_file, output_format='%d',
#                                      hdf5_attrs=None)
                
                # Report train accuracy and confusion matrix, then the same for test
                print(accuracy_score(dummy_cantons_deem_train[n_cantons], predicted_cantons_train))
                print(confusion_matrix(dummy_cantons_deem_train[n_cantons], predicted_cantons_train))
                print(accuracy_score(dummy_cantons_deem_test[n_cantons], predicted_cantons_test))
                print(confusion_matrix(dummy_cantons_deem_test[n_cantons], predicted_cantons_test))

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.520774193548387
[[1858 2036]
 [1678 2178]]
0.5022222222222222
[[527 608]
 [512 603]]
0.2713548387096774
[[343 543 554 525]
 [285 592 547 543]
 [275 557 588 557]
 [260 519 482 580]]
0.2408888888888889
[[ 92 170 168 165]
 [ 80 139 182 160]
 [ 82 143 152 148]
 [ 95 165 150 159]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.535483870967742
[[2005 1889]
 [1711 2145]]
0.5017777777777778
[[525 610]
 [511 604]]
0.2663225806451613
[[246 560 545 614]
 [243 556 546 622]
 [241 538 577 621]
 [223 430 503 685]]
0.2457777777777778
[[ 72 178 166 179]
 [ 74 133 172 182]
 [ 65 143 166 151]
 [ 69 143 175 182]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5163870967741936
[[1863 2031]
 [1717 2139]]
0.49866666666666665
[[519 616]
 [512 603]]
0.26670967741935486
[[424 600 398 543]
 [380 614 449 524]
 [416 580 475 506]
 [369 545 373 554]]
0.24
[[124 193 132 146]
 [107 151 141 162]
 [130 154 108 133]
 [106 181 125 157]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5508387096774193
[[2048 1846]
 [1635 2221]]
0.49333333333333335
[[509 626]
 [514 601]]
0.28735483870967743
[[203 620 489 653]
 [151 702 466 648]
 [165 596 560 656]
 [135 503 441 762]]
0.24666666666666667
[[ 51 180 159 205]
 [ 46 163 164 188]
 [ 53 143 139 190]
 [ 56 159 152 202]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5369032258064517
[[2002 1892]
 [1697 2159]]
0.5057777777777778
[[536 599]
 [513 602]]
0.2718709677419355
[[332 641 372 620]
 [290 681 399 597]
 [329 606 424 618]
 [273 545 353 670]]
0.23777777777777778
[[ 95 191 118 191]
 [ 92 159 126 184]
 [102 151 102 170]
 [ 88 189 113 179]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5463225806451613
[[2093 1801]
 [1715 2141]]
0.48977777777777776
[[544 591]
 [557 558]]
0.27625806451612905
[[314 555 506 590]
 [284 599 516 568]
 [278 537 590 572]
 [268 467 468 638]]
0.24844444444444444
[[ 96 165 166 168]
 [ 89 143 163 166]
 [ 84 130 147 164]
 [ 78 163 155 173]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5603870967741935
[[2131 1763]
 [1644 2212]]
0.5
[[541 594]
 [531 584]]
0.28761290322580646
[[241 504 611 609]
 [208 543 611 605]
 [210 467 721 579]
 [191 392 534 724]]
0.25155555555555553
[[ 75 143 182 195]
 [ 62 134 193 172]
 [ 60 119 169 177]
 [ 62 150 169 188]]


HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5060645161290322
[[1913 1981]
 [1847 2009]]
0.4822222222222222
[[531 604]
 [561 554]]
0.25690322580645164
[[672 349 294 650]
 [660 359 256 692]
 [644 351 296 686]
 [575 310 292 664]]
0.2653333333333333
[[200 103 100 192]
 [175 108  91 187]
 [164  95  87 179]
 [191 116  60 202]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5019354838709678
[[1939 1955]
 [1905 1951]]
0.5008888888888889
[[558 577]
 [546 569]]
0.2603870967741935
[[798 269 243 655]
 [762 272 242 691]
 [750 274 260 693]
 [664 255 234 688]]
0.26311111111111113
[[253  70  73 199]
 [200  80  82 199]
 [218  70  64 173]
 [210  88  76 195]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5033548387096775
[[1909 1985]
 [1864 1992]]
0.488
[[533 602]
 [550 565]]
0.25793548387096776
[[920  67 297 681]
 [906  63 275 723]
 [874  57 315 731]
 [779  63 298 701]]
0.27111111111111114
[[281  30  88 196]
 [247  19  90 205]
 [233  16  94 182]
 [266  22  65 216]]


HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5676129032258065
[[2226 1668]
 [1683 2173]]
0.496
[[556 579]
 [555 560]]
0.35096774193548386
[[703 447 398 417]
 [454 682 371 460]
 [439 437 638 463]
 [388 404 352 697]]
0.2351111111111111
[[141 145 146 163]
 [148 134 132 147]
 [128 137 107 153]
 [144 143 135 147]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5940645161290322
[[2316 1578]
 [1568 2288]]
0.49777777777777776
[[554 581]
 [549 566]]
0.3664516129032258
[[684 413 448 420]
 [378 699 445 445]
 [395 399 778 405]
 [351 377 434 679]]
0.23155555555555554
[[125 141 192 137]
 [137 129 142 153]
 [127 137 129 132]
 [130 147 154 138]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5619354838709677
[[2215 1679]
 [1716 2140]]
0.5048888888888889
[[576 559]
 [555 560]]
0.32387096774193547
[[587 408 477 493]
 [430 565 451 521]
 [407 394 685 491]
 [368 370 430 673]]
0.23866666666666667
[[123 148 176 148]
 [138 129 152 142]
 [125 114 133 153]
 [137 138 142 152]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.6099354838709677
[[2386 1508]
 [1515 2341]]
0.5115555555555555
[[575 560]
 [539 576]]
0.39083870967741935
[[824 365 367 409]
 [395 729 373 470]
 [416 406 691 464]
 [374 350 332 785]]
0.22755555555555557
[[143 137 151 164]
 [159 122 127 153]
 [136 127 107 155]
 [143 156 130 140]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5944516129032258
[[2330 1564]
 [1579 2277]]
0.5151111111111111
[[570 565]
 [526 589]]
0.3781935483870968
[[648 457 436 424]
 [378 745 379 465]
 [381 380 756 460]
 [328 386 345 782]]
0.23777777777777778
[[122 152 159 162]
 [113 136 149 163]
 [109 135 127 154]
 [131 157 131 150]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.6068387096774194
[[2375 1519]
 [1528 2328]]
0.5048888888888889
[[561 574]
 [540 575]]
0.38167741935483873
[[648 397 474 446]
 [367 699 443 458]
 [357 364 856 400]
 [325 334 427 755]]
0.24
[[124 139 184 148]
 [117 122 153 169]
 [105 130 148 142]
 [126 141 156 146]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.6268387096774194
[[2458 1436]
 [1456 2400]]
0.5022222222222222
[[557 578]
 [542 573]]
0.40825806451612906
[[714 407 419 425]
 [353 835 345 434]
 [384 397 779 417]
 [316 368 321 836]]
0.24044444444444443
[[123 156 157 159]
 [120 144 135 162]
 [115 135 123 152]
 [124 169 125 151]]


HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5073548387096775
[[2114 1780]
 [2038 1818]]
0.5044444444444445
[[637 498]
 [617 498]]
0.2603870967741935
[[438 460 422 645]
 [438 468 411 650]
 [443 429 453 652]
 [396 400 386 659]]
0.25066666666666665
[[143 144 123 185]
 [124 112 127 198]
 [111 117 115 182]
 [139 129 107 194]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5077419354838709
[[2141 1753]
 [2062 1794]]
0.5146666666666667
[[656 479]
 [613 502]]
0.26464516129032256
[[396 577 366 626]
 [356 602 401 608]
 [381 576 417 603]
 [364 493 348 636]]
0.25466666666666665
[[124 190 120 161]
 [107 157 113 184]
 [110 147  99 169]
 [110 177  89 193]]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

0.5092903225806452
[[2130 1764]
 [2039 1817]]
0.4968888888888889
[[630 505]
 [627 488]]
0.2709677419354839
[[510 479 339 637]
 [426 517 367 657]
 [459 479 365 674]
 [407 432 294 708]]
0.24977777777777777
[[141 160  94 200]
 [119 135 104 203]
 [133 124  82 186]
 [140 148  77 204]]



## LR check: linear (Ridge) regression of the SVC decision functions on the SOAP features

In [19]:
# LR check: for every cutoff, spectrum type, species pairing and number of
# classes, load the saved SVC decision functions for Deem and IZA. The
# regression itself is not implemented yet (see the trailing comments), and
# the cell contains no print statement.
# NOTE(review): `soaps_train`, `idxs_deem_train`, `idxs_deem_test`,
# `idxs_iza_train` and `idxs_iza_test` are not defined in any cell visible in
# this notebook; the index arrays loaded earlier are named `deem_train_idxs`,
# `deem_test_idxs`, `iza_train_idxs` and `iza_test_idxs`. Confirm which indices
# utils.load_data expects before renaming.
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    
    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        
        # Load SOAPs
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'

        # TODO: new SOAP loading with concatenation
        
        # One group of feature indices per entry of group_names[spectrum_type]
        # (the two are zipped together below)
        n_features = soaps_train.shape[1]
        feature_groups = extract_species_pair_groups(n_features, n_species, 
                                                     spectrum_type=spectrum_type, 
                                                     combinations=True)
        
        for species_pairing, feature_idxs in zip(tqdm(group_names[spectrum_type], 
                                                      desc='Species', leave=False),
                                                 feature_groups):

            # Preprocess the SOAPs like the decision functions
            # (i.e., center and scale) for the regression.
            # Use the TransformedTargetRegressor in a pipeline
            
            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):
                
                # Load decision functions
                input_dir = f'Linear_Models/SVC/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                svc_df_deem_file = f'{deem_dir}/{cutoff}/{input_dir}/svc_structure_dfs.dat'
                svc_df_iza_file = f'{iza_dir}/{cutoff}/{input_dir}/svc_structure_dfs.dat'

                # NOTE(review): presumably returns the decision functions split
                # into train and test parts -- verify against utils.load_data
                df_train, df_test = utils.load_data(svc_df_deem_file, svc_df_iza_file,
                                                    idxs_deem_train, idxs_deem_test,
                                                    idxs_iza_train, idxs_iza_test)
                
                # Preprocess the decision functions in a pipeline
                # Predict with Ridge using a default regularization, no need to optimize

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

5.951420625549286e-07
1.6733860364302191e-06
[1.27004889e-07 1.59816071e-07 3.71299635e-07 4.15208233e-07]
[3.51954034e-07 4.60954017e-07 9.21982539e-07 1.21917692e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.1336435050418972e-06
2.2915594072670793e-06
[2.78726764e-07 3.50718283e-07 6.09254550e-07 5.52089460e-07]
[6.43106398e-07 6.41818354e-07 1.27383805e-06 1.11330573e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.401468187425319e-06
1.7834418977920897e-06
[3.16946628e-07 4.13547584e-07 1.70760246e-06 4.12898734e-07]
[3.77048792e-07 4.92051099e-07 2.19583235e-06 4.94628453e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.8200318111673028e-06
6.147678787394855e-06
[6.11355423e-07 6.00204629e-07 8.65233702e-07 8.65623998e-07]
[2.26922278e-06 1.92278123e-06 4.01612505e-06 2.51576042e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.3546632619055442e-06
2.766828475847418e-06
[2.53812100e-07 4.11765710e-07 8.53094381e-07 5.09669025e-07]
[8.12143717e-07 1.01092561e-06 1.79882459e-06 1.34374931e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.8142097907520355e-06
7.87111612056096e-06
[4.45139887e-07 8.72749767e-07 1.71392377e-06 1.35282434e-06]
[8.96218541e-07 2.39137770e-06 4.67103222e-06 3.90593901e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.280919189269994e-06
9.151368023117194e-06
[6.76746694e-07 7.91320770e-07 1.36701317e-06 9.78198207e-07]
[2.77727105e-06 3.51527258e-06 6.83335531e-06 4.48541045e-06]


HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

3.188688474869591e-07
3.2420831956504324e-07
[7.21500483e-08 1.03701209e-07 2.35601964e-07 4.60912257e-07]
[7.43509734e-08 1.01960558e-07 2.33813763e-07 4.61855085e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

6.861719458648793e-07
6.901253115705306e-07
[1.55463263e-07 1.64971051e-07 7.94888951e-07 8.58593158e-07]
[1.57981609e-07 1.66589641e-07 7.91007705e-07 8.51238058e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

3.3822295997953675e-07
3.3982591693237734e-07
[7.89963008e-08 1.10155668e-07 2.50457663e-07 5.00631340e-07]
[8.00302750e-08 1.12015077e-07 2.50111968e-07 4.93415687e-07]


HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.3707687359337926e-07
1.8159440438692538e-07
[1.31017245e-07 5.69018463e-08 6.58603384e-08 8.83658489e-08]
[1.64767187e-07 7.15867866e-08 8.57823719e-08 1.14529477e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

9.48241554775321e-07
1.2372174579251195e-06
[6.59764682e-07 6.31439479e-07 4.13242724e-07 5.03456366e-07]
[8.89059368e-07 8.44358691e-07 5.24529597e-07 6.51523325e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

6.531282921848146e-07
7.41362259711033e-07
[3.75519407e-07 4.94054474e-07 5.04454885e-07 3.79284177e-07]
[4.25662055e-07 5.43048494e-07 5.78109698e-07 4.20722640e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.294283570839404e-06
2.171071564886283e-06
[6.11375079e-07 5.99981848e-07 4.13029556e-07 6.00053990e-07]
[1.08267622e-06 1.04600222e-06 6.61170632e-07 9.93250372e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

7.195733018029833e-07
1.0717216968617537e-06
[5.42399177e-07 2.22065230e-07 3.95905535e-07 3.97158579e-07]
[7.47123281e-07 3.30861892e-07 5.56808904e-07 5.77168240e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.830049615923203e-06
2.7324137449766664e-06
[1.35509822e-06 1.22292678e-06 9.31614467e-07 8.92945925e-07]
[2.06130941e-06 1.89368518e-06 1.42331693e-06 1.33446588e-06]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

1.861334650866247e-06
3.3297284218006098e-06
[1.14042985e-06 8.98084707e-07 7.32659623e-07 8.90690077e-07]
[2.11802873e-06 1.77845048e-06 1.35734849e-06 1.62056273e-06]


HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.5441362296163235e-07
2.5001778686504963e-07
[1.03283577e-07 8.66931567e-08 1.66168109e-07 1.50184284e-07]
[1.03222293e-07 8.61848068e-08 1.63436131e-07 1.48061277e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

3.3018848518137207e-07
3.3124577381153986e-07
[1.79514578e-07 1.45543356e-07 2.05262148e-07 3.83007699e-07]
[1.75973643e-07 1.45824762e-07 2.05914799e-07 3.88154357e-07]


HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

2.4769552778583826e-07
2.46530652628551e-07
[9.31849887e-08 8.72353987e-08 1.45886103e-07 1.49308888e-07]
[9.21033395e-08 8.95295424e-08 1.45161181e-07 1.47399920e-07]

