In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np

# Plotting
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ML
from soap import extract_species_pair_groups
from skcosmo.decomposition import PCovR

from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.compose import TransformedTargetRegressor
from sklearn.model_selection import GridSearchCV

# Utilities
import h5py
import json
import itertools
from copy import deepcopy
from tqdm.notebook import tqdm
import project_utils as utils
from tools import load_json, save_json
from tempfile import mkdtemp
from shutil import rmtree
import functools

# Import COSMO style toolkit
import cosmoplot.colorbars as cosmocbars
import cosmoplot.utils as cosmoutils
import cosmoplot.style as cosmostyle

cosmostyle.set_style('article')
colorList = cosmostyle.color_cycle

# Load train and test splits

In [3]:
# Load SOAP cutoffs
soap_hyperparameters = load_json('../Processed_Data/soap_hyperparameters.json')   
cutoffs = soap_hyperparameters['interaction_cutoff']

In [4]:
# Load train sets for IZA and Deem
iza_train_idxs = np.loadtxt('../Processed_Data/IZA_230/svm_train.idxs', dtype=int)
iza_sort_idxs = np.argsort(iza_train_idxs)
iza_unsort_idxs = np.argsort(iza_sort_idxs)
deem_train_idxs = np.loadtxt('../Processed_Data/DEEM_330k/svm_train.idxs', dtype=int)

In [5]:
# Load cantons for IZA and Deem
iza_cantons = np.loadtxt('../Raw_Data/IZA_230/cantons.dat', usecols=1, dtype=int)
deem_cantons_2 = np.loadtxt('../Processed_Data/DEEM_330k/Data/cantons_2-class.dat', dtype=int)
deem_cantons_4 = np.loadtxt('../Processed_Data/DEEM_330k/Data/cantons_4-class.dat', dtype=int)

In [6]:
# Build set of "master" canton labels
cantons = {}

cantons[4] = np.concatenate((
    iza_cantons[iza_train_idxs], 
    deem_cantons_4[deem_train_idxs]
))

cantons[2] = np.concatenate((
    np.ones(len(iza_train_idxs), dtype=int),
    deem_cantons_2[deem_train_idxs]
))

# Build set of class weights (by sample) for centering and scaling
class_weights = {n_cantons: utils.balanced_class_weights(cantons[n_cantons]) for n_cantons in (2, 4)}

# Model setup

In [7]:
model_dir = '../Processed_Data/Models'

deem_name = 'DEEM_330k'
iza_name = 'IZA_230'
deem_dir = f'../Processed_Data/{deem_name}/Data'
iza_dir = f'../Processed_Data/{iza_name}/Data'

In [8]:
# CV splits
n_splits = 2

pcovr_parameters = dict(n_components=2, tol=1.0E-10)
y_scaler_parameters = dict(featurewise=False)

# Linear PCovR

In [9]:
# Linear model setup
n_species = 2
group_names = {
    'power': ['OO', 'OSi', 'SiSi', 'OO+OSi', 'OO+SiSi', 'OSi+SiSi', 'OO+OSi+SiSi'], 
    'radial': ['O', 'Si', 'O+Si']
}

## Optimize PCovR parameters

In [10]:
mixings = np.linspace(0.0, 1.0, 11)
alphas = np.logspace(-10, 0, 11)
parameter_grid=dict(
    pcovr__regressor__regressor__alpha=alphas, 
    pcovr__regressor__mixing=mixings
)

In [11]:
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    work_dir = f'{model_dir}/{cutoff}/LPCovR'
    
    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()
        
        # Load SOAPs
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_soaps = utils.load_hdf5(iza_file, indices=iza_train_idxs[iza_sort_idxs])
        iza_soaps = iza_soaps[iza_unsort_idxs]
        
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        deem_soaps = utils.load_hdf5(deem_file, indices=deem_train_idxs)
        
        soaps = np.vstack((iza_soaps, deem_soaps))
        
        n_features = soaps.shape[1]
        feature_groups = extract_species_pair_groups(
            n_features, n_species, 
            spectrum_type=spectrum_type,
            combinations=True
        )
        
        for species_pairing, feature_idxs in zip(
            tqdm(group_names[spectrum_type], desc='Species', leave=False),
            feature_groups
        ):
            
            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):
                
                # Prepare inputs and outputs
                df_dir = f'LSVC/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                
                # Load decision functions
                iza_dfs = np.loadtxt(f'{iza_dir}/{cutoff}/{df_dir}/svc_structure_dfs.dat')
                iza_dfs = iza_dfs[iza_train_idxs]
                
                deem_dfs = np.loadtxt(f'{deem_dir}/{cutoff}/{df_dir}/svc_structure_dfs.dat')
                deem_dfs = deem_dfs[deem_train_idxs]
                
                dfs = np.concatenate((iza_dfs, deem_dfs))
                
                # Instead of using sample weights to account
                # for class imbalance in the below pipeline
                # (since this---at least to me--- seems
                # a bit tricky to do for PCovR),
                # we instead replicate the training samples.
                # This is done through ReplicatedStratifiedKFold,
                # which builds a stratified k-fold division
                # and then replicates the minority class examples
                # to *approximately* match the majority class
                # by replicating the training indices.
                # The test set indices are not replicated;
                # instead, the class-averaged scores are computed
                # using class-balanced scoring
                
                # Sub-pipeline for target data
                y_pipeline = Pipeline(
                    [
                        ('drop_features', utils.ColumnTransformerInverse([('drop_features', 'passthrough', [0])])),
                        ('norm_scaler', utils.StandardNormScaler(**y_scaler_parameters))
                    ]
                )
                
                # Model pipeline
                pipeline = Pipeline(
                    [
                        ('norm_scaler', utils.StandardNormScaler()),
                        ('pcovr', TransformedTargetRegressor(
                            regressor=PCovR(
                                **pcovr_parameters,
                                regressor=Ridge(fit_intercept=False, normalize=False)
                            ),
                            transformer=y_pipeline,
                            check_inverse=False
                        ))
                    ],
                )
                
                # IZA + DEEM classification
                gscv = GridSearchCV(
                    pipeline, parameter_grid,
                    scoring=functools.partial(utils.class_balanced_pcovr_score, class_col=-1),
                    cv=utils.ReplicatedStratifiedKFold(
                        n_splits=n_splits, stratify_col=-1, shuffle=True, random_state=0
                    ),
                    refit=False, return_train_score=True, error_score='raise'
                )
                gscv.fit(soaps[:, feature_idxs], np.column_stack((dfs, cantons[n_cantons])))
                
                # Prepare outputs
                output_dir = f'{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                os.makedirs(f'{work_dir}/{output_dir}', exist_ok=True)
                
                save_json(gscv.cv_results_, f'{work_dir}/{output_dir}/cv_results.json', array_convert=True)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…




## Check the cross-validated parameters

In [None]:
# TODO: also extract optimal for mixing = 0.0 and mixing = 1.0
# TODO: make alpha plot for a given regularization to check

In [33]:
# IZA + DEEM classification
for cutoff in cutoffs:
    work_dir = f'{model_dir}/{cutoff}/LPCovR'
    
    for spectrum_type in ('power', 'radial'):
        spectrum_name = spectrum_type.capitalize()
        
        for group_name in group_names[spectrum_type]:
            for n_cantons in (2, 4):
                result_dir = f'{n_cantons}-Class/{spectrum_name}/{group_name}'
                cv_results = load_json(f'{work_dir}/{result_dir}/cv_results.json')
                print(f'-----Optimal Parameters for {cutoff} {spectrum_type} {group_name} {n_cantons} -----')

                for score in ['score']:
                    idx = np.argmin(cv_results[f'rank_test_{score}'])
                    opt_parameters = utils.get_optimal_parameters(cv_results, score, **pcovr_parameters)
                    print(f'{score} =', cv_results[f'mean_test_{score}'][idx])
                    print(opt_parameters)
                    print('')
                    
                    save_json(opt_parameters, f'{work_dir}/{result_dir}/pcovr_parameters_{score}.json')

-----Optimal Parameters for 3.5 power OO 2 -----
score = -0.12581572867093654
{'n_components': 2, 'tol': 1e-10, 'mixing': 0.5, 'alpha': 1e-10}

-----Optimal Parameters for 3.5 power OO 4 -----
score = -0.10032391506037673
{'n_components': 2, 'tol': 1e-10, 'mixing': 0.5, 'alpha': 1e-10}

-----Optimal Parameters for 3.5 power OSi 2 -----
score = -0.1659369941175184
{'n_components': 2, 'tol': 1e-10, 'mixing': 0.4, 'alpha': 1e-09}

-----Optimal Parameters for 3.5 power OSi 4 -----
score = -0.13750220299203936
{'n_components': 2, 'tol': 1e-10, 'mixing': 0.4, 'alpha': 1e-10}

-----Optimal Parameters for 3.5 power SiSi 2 -----
score = -0.3276354066165555
{'n_components': 2, 'tol': 1e-10, 'mixing': 0.5, 'alpha': 0.01}

-----Optimal Parameters for 3.5 power SiSi 4 -----
score = -0.3215325511924047
{'n_components': 2, 'tol': 1e-10, 'mixing': 0.2, 'alpha': 1.0}

-----Optimal Parameters for 3.5 power OO+OSi 2 -----
score = -0.17803679306156894
{'n_components': 2, 'tol': 1e-10, 'mixing': 0.4, 'alph