In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np
from numpy.random import default_rng

# Plotting
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ML
from soap import extract_species_pair_groups

from sklearn.svm import LinearSVC, SVC
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.multiclass import OneVsRestClassifier

# Utilities
import h5py
import json
import itertools
from tempfile import mkdtemp
from shutil import rmtree
from copy import deepcopy
from tqdm.auto import tqdm
import project_utils as utils
from tools import load_json, save_json

# Import COSMO style toolkit
import cosmoplot.colorbars as cosmocbars
import cosmoplot.utils as cosmoutils
import cosmoplot.style as cosmostyle

cosmostyle.set_style('article')
colorList = cosmostyle.color_cycle

# Load train and test splits

In [3]:
# Load SOAP cutoffs
soap_hyperparameters = load_json('../Processed_Data/soap_hyperparameters.json')   
cutoffs = soap_hyperparameters['interaction_cutoff']

In [4]:
# Load train sets for IZA and Deem
iza_train_idxs = np.loadtxt('../Processed_Data/IZA_230/svm_train.idxs', dtype=int)
iza_sort_idxs = np.argsort(iza_train_idxs)
iza_unsort_idxs = np.argsort(iza_sort_idxs)
deem_train_idxs = np.loadtxt('../Processed_Data/DEEM_330k/svm_train.idxs', dtype=int)

In [5]:
# Load cantons for IZA and Deem
iza_cantons = np.loadtxt('../Raw_Data/IZA_230/cantons.dat', usecols=1, dtype=int)
deem_cantons_2 = np.loadtxt('../Processed_Data/DEEM_330k/Data/cantons_2-class.dat', dtype=int)
deem_cantons_4 = np.loadtxt('../Processed_Data/DEEM_330k/Data/cantons_4-class.dat', dtype=int)

In [6]:
# Build set of "master" canton labels
cantons = {}

cantons[4] = np.concatenate((
    iza_cantons[iza_train_idxs], 
    deem_cantons_4[deem_train_idxs]
))

cantons[2] = np.concatenate((
    np.ones(len(iza_train_idxs), dtype=int),
    deem_cantons_2[deem_train_idxs]
))

# Build set of class weights (by sample) for centering and scaling
class_weights = {n_cantons: utils.balanced_class_weights(cantons[n_cantons]) for n_cantons in (2, 4)}

In [7]:
# Load dummy Deem cantons to test the "null" case
dummy_cantons = {}
dummy_cantons[2] = np.loadtxt('../Processed_Data/DEEM_330k/Data/dummy_cantons_2-class.dat', dtype=int)
dummy_cantons[2] = dummy_cantons[2][deem_train_idxs]
dummy_cantons[4] = np.loadtxt('../Processed_Data/DEEM_330k/Data/dummy_cantons_4-class.dat', dtype=int)
dummy_cantons[4] = dummy_cantons[4][deem_train_idxs]

# Build set of dummy class weights (by sample) for centering and scaling
dummy_class_weights = {n_cantons: utils.balanced_class_weights(dummy_cantons[n_cantons]) for n_cantons in (2, 4)}

# Model setup

In [8]:
model_dir = '../Processed_Data/Models'

deem_name = 'DEEM_330k'
iza_name = 'IZA_230'
deem_dir = f'../Processed_Data/{deem_name}/Data'
iza_dir = f'../Processed_Data/{iza_name}/Data'

In [9]:
# CV splits
n_splits = 2

# When using the OneVsRestClassifier, 
# n_classes binary problems are passed to SVC, 
# and the decision function shape doesn't have an impact
svc_parameters = dict(
    kernel='precomputed',
    decision_function_shape='ovo',
    class_weight='balanced',
    tol=1.0E-3,
    cache_size=1000
)

# Linear SVC

In [10]:
# Linear model setup
n_species = 2
group_names = {
    'power': ['OO', 'OSi', 'SiSi', 'OO+OSi', 'OO+SiSi', 'OSi+SiSi', 'OO+OSi+SiSi'], 
    'radial': ['O', 'Si', 'O+Si']
}

## Optimize LinearSVC parameters

In [11]:
# Regularization parameters for cross-validation
C = np.logspace(-4, 4, 9)
parameter_grid = dict(svc__estimator__C=C)

In [12]:
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    work_dir = f'{model_dir}/{cutoff}/LSVC'
    
    os.makedirs(work_dir, exist_ok=True)
    
    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        
        # Load SOAPs
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_soaps = utils.load_hdf5(iza_file, indices=iza_train_idxs[iza_sort_idxs])
        iza_soaps = iza_soaps[iza_unsort_idxs]

        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        deem_soaps = utils.load_hdf5(deem_file, indices=deem_train_idxs)
        
        soaps = np.vstack((iza_soaps, deem_soaps))
        
        n_features = soaps.shape[1]
        feature_groups = extract_species_pair_groups(
            n_features, n_species, 
            spectrum_type=spectrum_type,
            combinations=True
        )
        
        for species_pairing, feature_idxs in zip(
            tqdm(group_names[spectrum_type], desc='Species', leave=False),
            feature_groups
        ):

            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):
                
                # IZA + Deem classification
                pipeline = Pipeline(
                    [
                        ('norm_scaler', utils.StandardNormScaler()),
                        ('kernel_constructor', utils.KernelConstructor()),
                        ('svc', OneVsRestClassifier(SVC(**svc_parameters)))
                    ],
                )
                
                gscv = GridSearchCV(
                    pipeline, parameter_grid,
                    scoring=[
                        'accuracy', 'balanced_accuracy',
                    ],
                    cv=StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0),
                    refit=False, return_train_score=True, error_score='raise', n_jobs=4
                )
                fit_params = {'norm_scaler__sample_weight': class_weights[n_cantons]}
                gscv.fit(soaps[:, feature_idxs], cantons[n_cantons], **fit_params)

                # Prepare outputs
                output_dir = f'{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                os.makedirs(f'{work_dir}/{output_dir}', exist_ok=True)
                save_json(gscv.cv_results_, f'{work_dir}/{output_dir}/cv_results.json', array_convert=True)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…




## Check the cross-validated parameters

In [17]:
#### IZA + DEEM classification
for cutoff in cutoffs:
    work_dir = f'{model_dir}/{cutoff}/LSVC'
    
    for spectrum_type in ('power', 'radial'):
        spectrum_name = spectrum_type.capitalize()
        
        for group_name in group_names[spectrum_type]:
            for n_cantons in (2, 4):
                result_dir = f'{n_cantons}-Class/{spectrum_name}/{group_name}'
                cv_results = load_json(f'{work_dir}/{result_dir}/cv_results.json')
                print(f'-----Optimal Parameters for {cutoff} {spectrum_type} {group_name} {n_cantons} -----')
                
                for score in ('accuracy', 'balanced_accuracy'):
                    idx = np.argmin(cv_results[f'rank_test_{score}'])
                    opt_parameters = utils.get_optimal_parameters(cv_results, score, **svc_parameters)
                    print(f'{score} =', cv_results[f'mean_test_{score}'][idx])
                    print(opt_parameters)
                    print('')
                    
                    save_json(opt_parameters, f'{work_dir}/{result_dir}/svc_parameters_{score}.json')

-----Optimal Parameters for 3.5 power OO 2 -----
accuracy = 0.9485809164762218
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 10000.0}

balanced_accuracy = 0.9359014594146899
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 10.0}

-----Optimal Parameters for 3.5 power OO 4 -----
accuracy = 0.9471965205830745
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 10000.0}

balanced_accuracy = 0.5111435836617872
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 10.0}

-----Optimal Parameters for 3.5 power OSi 2 -----
accuracy = 0.969840534676836
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 10000.0}

balanc


balanced_accuracy = 0.5147306350586356
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 0.001}

-----Optimal Parameters for 6.0 radial Si 2 -----
accuracy = 0.8680905360848354
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 1000.0}

balanced_accuracy = 0.8769258248140857
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 1000.0}

-----Optimal Parameters for 6.0 radial Si 4 -----
accuracy = 0.8745180926363507
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 10000.0}

balanced_accuracy = 0.5264365436151943
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 0.1}

-----Optimal Parameters for 6.0 radial O+Si 2

# Linear SVC for dummy models

We do the dummy models separately since they tend to have convergence issues (since we are after all trying to learn random classes), and so may require some special treatment and so we can keep track of nonconvergence

In [14]:
# Regularization parameters for cross-validation
dummy_C = np.logspace(-4, 4, 9) # Starting at 10^{+/-5} we start not being able to converge within millions of steps
dummy_parameter_grid = dict(svc__estimator__C=dummy_C)

In [15]:
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    work_dir = f'{model_dir}/{cutoff}/LSVC'
    
    os.makedirs(work_dir, exist_ok=True)
    
    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        
        # Load SOAPs
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        deem_soaps = utils.load_hdf5(deem_file, indices=deem_train_idxs)
                
        n_features = soaps.shape[1]
        feature_groups = extract_species_pair_groups(
            n_features, n_species, 
            spectrum_type=spectrum_type,
            combinations=True
        )
        
        for species_pairing, feature_idxs in zip(
            tqdm(group_names[spectrum_type], desc='Species', leave=False),
            feature_groups
        ):

            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):
                
                # Dummy DEEM classification
                dummy_pipeline = Pipeline(
                    [
                        ('norm_scaler', utils.StandardNormScaler()),
                        ('kernel_constructor', utils.KernelConstructor()),
                        ('svc', OneVsRestClassifier(SVC(**svc_parameters)))
                    ],
                )
                
                dummy_gscv = GridSearchCV(
                    dummy_pipeline, dummy_parameter_grid,
                    scoring=[
                        'accuracy', 'balanced_accuracy',
                    ],
                    cv=StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0),
                    refit=False, return_train_score=True, error_score='raise', n_jobs=4
                )
                dummy_fit_params = {'norm_scaler__sample_weight': dummy_class_weights[n_cantons]}
                dummy_gscv.fit(deem_soaps[:, feature_idxs], dummy_cantons[n_cantons], **dummy_fit_params)

                # Prepare outputs
                output_dir = f'{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                os.makedirs(f'{work_dir}/{output_dir}', exist_ok=True)
                save_json(dummy_gscv.cv_results_, f'{work_dir}/{output_dir}/dummy_cv_results.json', array_convert=True)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…




## Check the cross-validated parameters

In [16]:
# Dummy DEEM model
for cutoff in cutoffs:
    work_dir = f'{model_dir}/{cutoff}/LSVC'
    
    for spectrum_type in ('power', 'radial'):
        spectrum_name = spectrum_type.capitalize()
        
        for group_name in group_names[spectrum_type]:
            for n_cantons in (2, 4):
                result_dir = f'{n_cantons}-Class/{spectrum_name}/{group_name}'
                cv_results = load_json(f'{work_dir}/{result_dir}/dummy_cv_results.json')
                print(f'-----Optimal Parameters for {cutoff} {spectrum_type} {group_name} {n_cantons} -----')
                
                for score in ('accuracy', 'balanced_accuracy'):
                    idx = np.argmin(cv_results[f'rank_test_{score}'])
                    opt_parameters = utils.get_optimal_parameters(cv_results, score, **svc_parameters)
                    print(f'{score} =', cv_results[f'mean_test_{score}'][idx])
                    print(opt_parameters)
                    print('')
                    
                    save_json(opt_parameters, f'{work_dir}/{result_dir}/dummy_svc_parameters_{score}.json')

-----Optimal Parameters for 3.5 power OO 2 -----
accuracy = 0.5041502300460092
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 0.01}

balanced_accuracy = 0.5036274704639421
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 0.01}

-----Optimal Parameters for 3.5 power OO 4 -----
accuracy = 0.2520254450890178
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 0.0001}

balanced_accuracy = 0.2503032171514569
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 10000.0}

-----Optimal Parameters for 3.5 power OSi 2 -----
accuracy = 0.5054509301860373
{'kernel': 'precomputed', 'decision_function_shape': 'ovo', 'class_weight': 'balanced', 'tol': 0.001, 'cache_size': 1000, 'C': 100.0}

balanced