In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded automatically before code is executed
%load_ext autoreload
%autoreload 2

In [2]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np

# Plotting
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# ML
from regression import PCovR
from soap import extract_species_pair_groups

from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.compose import TransformedTargetRegressor

# Utilities
import h5py
import json
import itertools
from tqdm.notebook import tqdm
import project_utils as utils
from tools import load_json, save_json
from tempfile import mkdtemp
from shutil import rmtree

# Import COSMO style toolkit
import cosmoplot.colorbars as cosmocbars
import cosmoplot.utils as cosmoutils
import cosmoplot.style as cosmostyle

# Activate the cosmoplot 'article' style and keep a handle on its color cycle
cosmostyle.set_style('article')
colorList = cosmostyle.color_cycle

In [3]:
# Extend the module search path so that `helpers` can be imported below.
# NOTE(review): these are absolute, machine-specific paths; the notebook will
# not import `helpers` on another machine unless they are made configurable.
sys.path.append('/scratch/helfrech/Sync/GDrive/Projects/KPCovR/kernel-tutorials')
sys.path.append('/scratch/helfrech/Sync/GDrive/Projects/KPCovR/KernelPCovR/analysis/scripts')
# from utilities.sklearn_covr.kpcovr import KernelPCovR as KPCovR2
# from utilities.sklearn_covr.pcovr import PCovR as PCovR2
from helpers import l_regr, l_kpcovr

# Load train and test splits

In [5]:
# Read the SOAP hyperparameters from JSON and pick out the interaction cutoffs
with open('../Processed_Data/soap_hyperparameters.json', 'r') as hyperparameter_file:
    soap_hyperparameters = json.load(hyperparameter_file)

cutoffs = soap_hyperparameters['interaction_cutoff']

In [6]:
# Read the Deem train/test split indices and count the structures they cover
deem_train_idxs = np.loadtxt(
    '../Processed_Data/DEEM_330k/train.idxs',
    dtype=int
)
deem_test_idxs = np.loadtxt(
    '../Processed_Data/DEEM_330k/test.idxs',
    dtype=int
)
n_deem = sum(len(idxs) for idxs in (deem_train_idxs, deem_test_idxs))

7750 2250


In [None]:
# Read the IZA train/test split indices and count the structures they cover
iza_train_idxs = np.loadtxt(
    '../Processed_Data/IZA_226/train.idxs',
    dtype=int
)
iza_test_idxs = np.loadtxt(
    '../Processed_Data/IZA_226/test.idxs',
    dtype=int
)
n_iza = sum(len(idxs) for idxs in (iza_train_idxs, iza_test_idxs))

In [9]:
# Read the IZA cantons as integers from column index 1 of the cantons file
iza_cantons = np.loadtxt(
    '../Raw_Data/GULP/IZA_226/cantons.txt',
    dtype=int,
    usecols=1
)

In [11]:
# Build set of "master" canton labels, keyed by the number of classes.
# In both label sets the IZA entries come first and the Deem entries second.
cantons = {}

# 4-class labels: IZA structures keep the cantons loaded from file,
# every Deem structure is assigned canton 4.
# (The loaded array is named `iza_cantons`; the previous reference to
# `cantons_iza` raised a NameError.)
cantons[4] = np.concatenate((
    iza_cantons,
    np.full(n_deem, 4, dtype=int)
))

# 2-class labels: every IZA structure is labelled 1, every Deem structure 2
cantons[2] = np.concatenate((
    np.ones(n_iza, dtype=int),
    np.full(n_deem, 2, dtype=int)
))

In [12]:
# Read the dummy Deem cantons for the 2-class and 4-class cases,
# keyed by the number of classes
dummy_cantons = {
    2: np.loadtxt('../Processed_Data/DEEM_330k/dummy_cantons_2-class.dat', dtype=int),
    4: np.loadtxt('../Processed_Data/DEEM_330k/dummy_cantons_4-class.dat', dtype=int),
}

In [None]:
# Merge the IZA and Deem split indices into one index set per split;
# the Deem indices are shifted by the number of IZA structures
train_idxs = np.concatenate([iza_train_idxs, deem_train_idxs + n_iza])
test_idxs = np.concatenate([iza_test_idxs, deem_test_idxs + n_iza])

# Model setup

In [13]:
# Root directory under which the per-cutoff model directories are located
model_dir = '../Processed_Data/Models'

In [14]:
# Class display names, keyed by the number of classes
class_names = {
    2: ['IZA', 'DEEM'],
    4: ['IZA1', 'IZA2', 'IZA3', 'DEEM'],
}

In [21]:
# Global model parameters
# pcovr_kwargs = dict(n_components=None, alpha=0.0, regularization=1.0E-12)

# TODO: load optimal parameters

# Linear PCovR

In [17]:
# Linear model setup: number of species and, per spectrum type,
# the names of the species groupings
n_species = 2
group_names = {
    'power': [
        'OO', 'OSi', 'SiSi',
        'OO+OSi', 'OO+SiSi', 'OSi+SiSi',
        'OO+OSi+SiSi',
    ],
    'radial': ['O', 'Si', 'O+Si'],
}

In [18]:
# Dataset names and the data directories derived from them
deem_name, iza_name = 'DEEM_330k', 'IZA_226'

deem_dir = f'../Processed_Data/{deem_name}/Data'
iza_dir = f'../Processed_Data/{iza_name}/Data'

In [18]:
# Number of DEEM 330k structures read and processed per batch
batch_size = 100000

In [19]:
# For every SOAP cutoff, spectrum type, species pairing, and number of classes:
# fit a linear PCovR pipeline, then compute projections, decision functions,
# and canton predictions for the DEEM 330k structures in batches and save them.
#
# NOTE(review): this cell is still work in progress. `soaps_train`, `dfs_train`,
# `iza_soaps`, `pcovr_parameters`, and `svc_parameters` are not defined anywhere
# in the visible notebook (see the TODOs below), so the cell cannot run on a
# fresh kernel until the corresponding loading code is added.
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    linear_dir = f'{model_dir}/{cutoff}/Linear_Models/PCovR'

    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()

        # Load SOAPs
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'

        # TODO: new SOAP loading with concatenation
        # want train soaps and all iza soaps and all deem soaps

        # Prepare loading of the DEEM 330k structures. The context manager
        # closes the HDF5 file once this spectrum type is done, even if an
        # iteration below raises.
        with h5py.File(deem_file, 'r') as f:
            deem_330k_dataset = f['0']

            n_features = soaps_train.shape[1]
            feature_groups = extract_species_pair_groups(n_features, n_species,
                                                         spectrum_type=spectrum_type,
                                                         combinations=True)

            # Prepare batches for PCovR: number of batches needed to cover
            # all samples, counting a final partial batch
            n_samples_330k = deem_330k_dataset.len()
            n_batches = n_samples_330k // batch_size
            if n_samples_330k % batch_size > 0:
                n_batches += 1

            for species_pairing, feature_idxs in zip(tqdm(group_names[spectrum_type],
                                                          desc='Species', leave=False),
                                                     feature_groups):

                for n_cantons in tqdm((2, 4), desc='Classes', leave=False):

                    # TODO: decision function loading and concatenation

                    # Prepare outputs
                    output_dir = f'Linear_Models/PCovR/{n_cantons}-Class/{spectrum_name}/{species_pairing}'

                    svc_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'
                    svc_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/svc_structure_dfs.dat'

                    pcovr_projection_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structures.hdf5'
                    pcovr_projection_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structures.hdf5'

                    pcovr_df_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structure_dfs.dat'
                    pcovr_df_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structure_dfs.dat'

                    pcovr_cantons_deem_file = f'{deem_dir}/{cutoff}/{output_dir}/pcovr_structure_cantons.dat'
                    pcovr_cantons_iza_file = f'{iza_dir}/{cutoff}/{output_dir}/pcovr_structure_cantons.dat'

                    parameter_dir = f'{linear_dir}/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                    pcovr_parameter_file = f'{parameter_dir}/pcovr_parameters.json'
                    pcovr_model_file = f'{parameter_dir}/pcovr.json'

                    # Temporary directory used as the pipeline's cache; it is
                    # removed in the `finally` clause so that it does not leak
                    # when fitting or prediction fails
                    cache_dir = mkdtemp()
                    try:
                        pipeline = Pipeline(
                            [
                                ('norm_scaler', utils.NormScaler()),
                                ('pcovr', TransformedTargetRegressor(
                                    regressor=PCovR(**pcovr_parameters),
                                    # Keyword was misspelled as `tranfsormer`,
                                    # which TransformedTargetRegressor rejects
                                    transformer=utils.NormScaler(featurewise=True)
                                ))
                            ],
                            memory=cache_dir
                        )

                        pipeline.fit(soaps_train[:, feature_idxs], dfs_train)

                        # Fitted PCovR regressor inside the target-transforming wrapper.
                        # NOTE(review): `pcovr_model` was used below without being
                        # defined; presumably this is the intended object — confirm
                        # that it exposes `Pxt` and `Pty`.
                        pcovr_model = pipeline.named_steps['pcovr'].regressor_

                        # Read the IZA structures and compute decision functions
                        # and canton predictions
                        # NOTE(review): `Pipeline.transform` is only available when the
                        # final step implements `transform`; TransformedTargetRegressor
                        # does not, so the projections presumably have to be obtained
                        # from the fitted PCovR regressor instead — confirm.
                        pipeline.transform(iza_soaps[:, feature_idxs])
                        pipeline.predict(iza_soaps[:, feature_idxs])

                        # TODO: save iza
                        # TODO: save the model

                        # Prepare output arrays for batch processing
                        T_deem = np.zeros((n_samples_330k, pcovr_model.Pxt.shape[1]))
                        df_deem = np.zeros((n_samples_330k, pcovr_model.Pty.shape[1]))

                        # Read the DEEM_330k structures and compute decision functions
                        # and canton predictions in batches
                        for i in tqdm(range(0, n_batches), desc='Batch', leave=False):
                            batch_slice = slice(i * batch_size, (i + 1) * batch_size)

                            deem_330k_batch = deem_330k_dataset[batch_slice, feature_idxs]
                            T_deem[batch_slice] = pipeline.transform(deem_330k_batch)
                            df_deem[batch_slice] = pipeline.predict(deem_330k_batch)

                        # Squeeze the decision functions (required for df_to_class in the
                        # 2-class case). `df_all_deem` was previously used without ever
                        # being assigned.
                        # NOTE(review): the earlier comment also mentioned uncentering and
                        # unscaling; presumably the target transformer now takes care of
                        # that inside `predict` — confirm.
                        df_all_deem = np.squeeze(df_deem)
                        predicted_cantons_all_deem = utils.df_to_class(
                            df_all_deem, df_type=svc_parameters['multi_class'], n_classes=n_cantons
                        )

                        # Save DEEM 330k PCovR projections, decision functions, and canton predictions
                        utils.save_hdf5(pcovr_projection_deem_file, T_deem, attrs=pcovr_parameters)
                        np.savetxt(pcovr_df_deem_file, df_all_deem, fmt='%f')
                        np.savetxt(pcovr_cantons_deem_file, predicted_cantons_all_deem, fmt='%d')
                    finally:
                        rmtree(cache_dir)

HBox(children=(FloatProgress(value=0.0, description='Cutoff', max=2.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Spectrum', max=2.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='Species', max=7.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Species', max=3.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Classes', max=2.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=4.0, style=ProgressStyle(description_width='i…


