In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# local helper modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [4]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np
from numpy.random import default_rng

# Plotting
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ML
from soap import extract_species_pair_groups

from sklearn.svm import LinearSVC
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV


# Utilities
import h5py
import json
import itertools
from tempfile import mkdtemp
from shutil import rmtree
from copy import deepcopy
from tqdm.notebook import tqdm
import project_utils as utils
from tools import load_json, save_json

# Import COSMO style toolkit
import cosmoplot.colorbars as cosmocbars
import cosmoplot.utils as cosmoutils
import cosmoplot.style as cosmostyle

# Apply the cosmoplot 'article' style and keep a handle on the style's color cycle
cosmostyle.set_style('article')
colorList = cosmostyle.color_cycle

# Load train and test splits

In [7]:
# Load SOAP cutoffs: read the SOAP hyperparameter JSON and keep its
# 'interaction_cutoff' entry, which the model loops below iterate over
# (one model directory per cutoff)
with open('../Processed_Data/soap_hyperparameters.json', 'r') as f:
    soap_hyperparameters = json.load(f)
    
cutoffs = soap_hyperparameters['interaction_cutoff']

In [8]:
# Load train set and CV indices for Deem (both files are parsed as integers).
# The train indices select rows of the Deem SOAPs and canton labels;
# the CV indices are passed to the cross-validation generator.
deem_train_idxs = np.loadtxt('../Processed_Data/DEEM_330k/train.idxs', dtype=int)
deem_cv_idxs = np.loadtxt('../Processed_Data/DEEM_330k/cv_2.idxs', dtype=int)

In [None]:
# Load train set and CV indices for IZA (both files are parsed as integers).
# The train indices select rows of the IZA SOAPs and canton labels;
# the CV indices are combined with the Deem CV indices further down.
iza_train_idxs = np.loadtxt('../Processed_Data/IZA_226/train.idxs', dtype=int)
iza_cv_idxs = np.loadtxt('../Processed_Data/IZA_226/cv_2.idxs', dtype=int)

In [11]:
# Load IZA cantons: only the second column (usecols=1) of the file is read,
# parsed as integer class labels
iza_cantons = np.loadtxt('../Raw_Data/GULP/IZA_226/cantons.txt', usecols=1, dtype=int)

In [None]:
# Load DEEM cantons: integer class labels for the 2-class and 4-class problems.
# Both arrays are indexed with `deem_train_idxs` when the master labels are built.
deem_cantons_2 = np.loadtxt('../Processed_Data/DEEM_330k/Data/cantons_2-class.dat', dtype=int)
deem_cantons_4 = np.loadtxt('../Processed_Data/DEEM_330k/Data/cantons_4-class.dat', dtype=int)

In [13]:
# Assemble the "master" canton labels, keyed by the number of classes.
# In both cases the IZA training labels come first and the Deem training
# labels second. For the 2-class problem every IZA structure is labelled 1.
cantons = {
    4: np.concatenate([
        iza_cantons[iza_train_idxs],
        deem_cantons_4[deem_train_idxs],
    ]),
    2: np.concatenate([
        np.ones(len(iza_train_idxs), dtype=int),
        deem_cantons_2[deem_train_idxs],
    ]),
}

In [None]:
# Load dummy Deem cantons to test the "null" case
# (integer labels keyed by the number of classes, like `cantons`).
# NOTE(review): unlike the real Deem cantons, these arrays are used without
# indexing by `deem_train_idxs`; presumably the files already contain one label
# per Deem training structure — confirm.
dummy_cantons = {}
dummy_cantons[2] = np.loadtxt('../Processed_Data/DEEM_330k/dummy_cantons_2-class.dat', dtype=int)
dummy_cantons[4] = np.loadtxt('../Processed_Data/DEEM_330k/dummy_cantons_4-class.dat', dtype=int)

In [None]:
# Concatenate IZA and Deem CV indices.
# The Deem indices are shifted by the number of IZA training samples because the
# combined SOAP matrix and label vectors place the IZA samples first, followed
# by the Deem samples.
# NOTE(review): np.vstack requires both index arrays to agree in every dimension
# but the first — confirm the layout of the cv_2.idxs files.
cv_idxs = np.vstack((iza_cv_idxs, deem_cv_idxs + len(iza_train_idxs)))

# Model setup

In [27]:
# Root directory under which the per-cutoff model outputs are written
model_dir = '../Processed_Data/Models'

# Dataset names and the directories holding their processed data
# (the SOAP feature files are read from per-cutoff subdirectories of these)
deem_name = 'DEEM_330k'
iza_name = 'IZA_226'
deem_dir = f'../Processed_Data/{deem_name}/Data'
iza_dir = f'../Processed_Data/{iza_name}/Data'

In [24]:
# Baseline LinearSVC keyword arguments shared by every model
svc_parameters = {
    'dual': False,
    'multi_class': 'ovr',
    'class_weight': 'balanced',
    'tol': 1.0E-3,  # Consistent with kernel SVC
}

# Candidate regularization strengths for cross-validation:
# 11 log-spaced values, one per decade from 1e-5 to 1e5
C = np.logspace(start=-5, stop=5, num=11)
parameter_grid = {'svc__C': C}

# Linear SVC

In [26]:
# Linear model setup
n_species = 2

# Names of the feature groups for each spectrum type: every non-empty
# combination of the base groups, ordered by combination size and joined with '+'
_base_group_names = {'power': ('OO', 'OSi', 'SiSi'), 'radial': ('O', 'Si')}
group_names = {
    spectrum: [
        '+'.join(combination)
        for size in range(1, len(base_names) + 1)
        for combination in itertools.combinations(base_names, size)
    ]
    for spectrum, base_names in _base_group_names.items()
}

## Optimize LinearSVC parameters

In [30]:
def _fit_grid_search(estimator, cv_indices, features, labels):
    """Cross-validate `estimator` over `parameter_grid` and return the fitted search.

    The folds come from `utils.cv_generator(cv_indices)`. Four classification
    scores are computed for both train and test folds, no final refit is done,
    and any error during fitting or scoring is raised immediately.
    """
    grid_search = GridSearchCV(
        estimator, parameter_grid,
        scoring=[
            'accuracy', 'balanced_accuracy',
            'roc_auc_ovr', 'roc_auc_ovr_weighted'
        ],
        cv=utils.cv_generator(cv_indices),
        refit=False, return_train_score=True, error_score='raise'
    )
    grid_search.fit(features, labels)
    return grid_search


# Cross-validate the LinearSVC regularization for every combination of cutoff,
# spectrum type, species feature group, and number of classes, and save the
# cross-validation results as JSON.
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    work_dir = f'{model_dir}/{cutoff}/Linear_Models/SVC'

    # exist_ok avoids the check-then-create race of a separate existence test
    os.makedirs(work_dir, exist_ok=True)

    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()

        # Load SOAPs (training rows only) and stack them IZA first, Deem second,
        # matching the order of the labels in `cantons`
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_soaps = utils.load_hdf5(iza_file, indices=iza_train_idxs)

        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        deem_soaps = utils.load_hdf5(deem_file, indices=deem_train_idxs)

        soaps = np.vstack((iza_soaps, deem_soaps))

        n_features = soaps.shape[1]
        feature_groups = extract_species_pair_groups(n_features, n_species,
                                                     spectrum_type=spectrum_type,
                                                     combinations=True)

        # NOTE(review): zip stops at the shorter input, so the names in
        # `group_names` are assumed to match `feature_groups` in number and
        # order — confirm against extract_species_pair_groups.
        for species_pairing, feature_idxs in zip(tqdm(group_names[spectrum_type],
                                                      desc='Species', leave=False),
                                                 feature_groups):

            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):

                cache_dir = mkdtemp()

                # The grid searches raise on any failure (error_score='raise'),
                # so the temporary cache is removed in a `finally` block to
                # avoid leaking one directory per failed run.
                try:
                    # We don't center the SOAPs for the SVC since we need
                    # the weights to reflect the (uncentered) SOAPs
                    # for the real space conversion.
                    # The weights and decision functions (once accounting for the intercept)
                    # will be very close between the centered and uncentered SOAPs
                    # even if scaled afterwards, so centering doesn't appear
                    # to have a large effect on the final predictions.

                    # The NormScaler should be cached so that
                    # it doesn't need to be re-fit when the input SOAPs
                    # and parameters are the same
                    pipeline = Pipeline(
                        [
                            ('norm_scaler', utils.NormScaler(with_mean=False)),
                            ('svc', LinearSVC(**svc_parameters))
                        ],
                        memory=cache_dir
                    )

                    # IZA + DEEM classification
                    gscv = _fit_grid_search(
                        pipeline, cv_idxs,
                        soaps[:, feature_idxs], cantons[n_cantons]
                    )

                    # Dummy DEEM classification
                    dummy_gscv = _fit_grid_search(
                        pipeline, deem_cv_idxs,
                        deem_soaps[:, feature_idxs], dummy_cantons[n_cantons]
                    )

                    # Prepare outputs
                    output_dir = f'{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                    os.makedirs(f'{work_dir}/{output_dir}', exist_ok=True)

                    save_json(gscv.cv_results_, f'{work_dir}/{output_dir}/cv_results.json')
                    save_json(dummy_gscv.cv_results_, f'{work_dir}/{output_dir}/dummy_cv_results.json')
                finally:
                    rmtree(cache_dir)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='C', max=11.0, style=ProgressStyle(description_width='init…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

HBox(children=(FloatProgress(value=0.0, description='n', max=5.0, style=ProgressStyle(description_width='initi…

0.01



## Check the cross-validated parameters

In [None]:
# IZA + DEEM classification
# For every saved grid search, report the best mean test score and save the
# optimal parameters separately for each scoring metric.
for cutoff in cutoffs:
    work_dir = f'{model_dir}/{cutoff}/Linear_Models/SVC'
    for spectrum_type in ('power', 'radial'):
        spectrum_name = spectrum_type.capitalize()
        for group_name in group_names[spectrum_type]:
            for n_cantons in (2, 4):
                # Same directory layout as used when the CV results were saved;
                # built from this loop's own variables rather than from names
                # left over from the optimization cell
                result_dir = f'{n_cantons}-Class/{spectrum_name}/{group_name}'
                cv_results = load_json(f'{work_dir}/{result_dir}/cv_results.json')
                print(f'-----Optimal Parameters for {cutoff} {spectrum_type} {group_name} {n_cantons} -----')
                for score in ('accuracy', 'balanced_accuracy', 'roc_auc_ovr', 'roc_auc_ovr_weighted'):
                    # Position of the lowest (best) rank for this metric
                    idx = np.argmin(cv_results[f'rank_test_{score}'])
                    opt_parameters = utils.get_optimal_parameters(cv_results, score, **svc_parameters)
                    print(f'{score} =', cv_results[f'mean_test_{score}'][idx])
                    print(opt_parameters)
                    print('')

                    save_json(opt_parameters, f'{work_dir}/{result_dir}/svc_parameters_{score}.json')

In [None]:
# Dummy DEEM model
# For every saved dummy grid search, report the best mean test score and save
# the optimal parameters separately for each scoring metric.
for cutoff in cutoffs:
    work_dir = f'{model_dir}/{cutoff}/Linear_Models/SVC'
    for spectrum_type in ('power', 'radial'):
        spectrum_name = spectrum_type.capitalize()
        for group_name in group_names[spectrum_type]:
            for n_cantons in (2, 4):
                # Same directory layout as used when the CV results were saved;
                # built from this loop's own variables rather than from names
                # left over from the optimization cell
                result_dir = f'{n_cantons}-Class/{spectrum_name}/{group_name}'
                cv_results = load_json(f'{work_dir}/{result_dir}/dummy_cv_results.json')
                print(f'-----Optimal Parameters for {cutoff} {spectrum_type} {group_name} {n_cantons} -----')
                for score in ('accuracy', 'balanced_accuracy', 'roc_auc_ovr', 'roc_auc_ovr_weighted'):
                    # Position of the lowest (best) rank for this metric
                    idx = np.argmin(cv_results[f'rank_test_{score}'])
                    opt_parameters = utils.get_optimal_parameters(cv_results, score, **svc_parameters)
                    print(f'{score} =', cv_results[f'mean_test_{score}'][idx])
                    print(opt_parameters)
                    print('')

                    save_json(opt_parameters, f'{work_dir}/{result_dir}/dummy_svc_parameters_{score}.json')