# TODO: convert model saving and loading to save_json and load_json

In [1]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np

# ML
from sklearn.decomposition import PCA

# Utilities
import h5py
import json
from tqdm.notebook import tqdm
import project_utils as utils
from selection import FPS
from tools import load_json, save_json

# Initial setup

In [2]:
# Indices of the DEEM 10k subset within the DEEM 330k set
deem_10k_idxs = np.loadtxt(
    fname='../Processed_Data/DEEM_330k/deem_10k.idxs',
    dtype=int,
)

In [3]:
# Read the SOAP hyperparameters and extract the interaction cutoffs to iterate over
with open('../Processed_Data/soap_hyperparameters.json') as hyperparameter_file:
    soap_hyperparameters = json.load(hyperparameter_file)

cutoffs = soap_hyperparameters['interaction_cutoff']

# Linear PCA

## IZA on DEEM 10k

In [4]:
# IZA canton labels (second column of cantons.txt) and how many IZA entries there are
cantons_iza = np.loadtxt(
    '../Raw_Data/GULP/IZA_226/cantons.txt',
    dtype=int,
    usecols=1,
)
n_iza = len(cantons_iza)

In [5]:
# Dataset names and their processed-data directories
deem_name, iza_name = 'DEEM_330k', 'IZA_226'
deem_dir = '../Processed_Data/{}/Data'.format(deem_name)
iza_dir = '../Processed_Data/{}/Data'.format(iza_name)

In [6]:
# Rows of the DEEM 330k set transformed per batch, and PCA components kept
batch_size, n_components = 10_000, 3

In [11]:
# For each SOAP cutoff: fit a linear PCA on the DEEM 10k subset, save the
# fitted model as JSON, and write PCA projections of the IZA 226 set and of
# the full DEEM 330k set (the latter computed in batches) to HDF5.

# PCA settings and the output sub-directory do not depend on the cutoff.
# NOTE(review): PCA's default svd_solver='auto' can select the randomized
# solver for large inputs; that solver depends on random_state, which is not
# set here -- pass one if reproducible projections are required.
pca_parameters = dict(n_components=n_components)
output_dir = 'Linear_Models/PCA'

for cutoff in tqdm(cutoffs):

    # Output locations for this cutoff
    model_dir = f'../Processed_Data/Models/{cutoff}/Linear_Models/PCA'
    deem_output_dir = f'{deem_dir}/{cutoff}/{output_dir}'
    iza_output_dir = f'{iza_dir}/{cutoff}/{output_dir}'

    # exist_ok avoids the check-then-create pattern of os.path.exists + makedirs
    for directory in (model_dir, deem_output_dir, iza_output_dir):
        os.makedirs(directory, exist_ok=True)

    pca_model_file = f'{model_dir}/pca.json'
    pca_deem_file = f'{deem_output_dir}/pca_structures.hdf5'
    pca_iza_file = f'{iza_output_dir}/pca_structures.hdf5'

    # Input SOAP files
    deem_file = f'{deem_dir}/{cutoff}/soaps_power_full_avg_nonorm.hdf5'
    iza_file = f'{iza_dir}/{cutoff}/soaps_power_full_avg_nonorm.hdf5'

    # Load IZA structures
    # TODO: change this once IZA SOAPs are saved concatenated
    iza_226 = utils.load_structures_from_hdf5(iza_file, datasets=None, concatenate=True)

    # The context manager closes the DEEM file even if fitting,
    # transforming or saving raises
    with h5py.File(deem_file, 'r') as f:
        deem_330k = f['0']

        # NOTE(review): h5py fancy indexing requires increasing, unique
        # indices -- assumes deem_10k.idxs is sorted; confirm
        deem_10k = deem_330k[deem_10k_idxs, :]

        # Fit the PCA on the DEEM 10k subset and save the model
        pca_structures = PCA(**pca_parameters)
        pca_structures.fit(deem_10k)
        save_json(pca_structures.__dict__, pca_model_file, array_convert=True)

        # To reload the saved model later:
        #   pca = PCA()
        #   pca.__dict__ = load_json(pca_model_file, array_convert=True)

        # Project the IZA structures
        T_iza_226 = pca_structures.transform(iza_226)
        utils.save_hdf5(pca_iza_file, T_iza_226, attrs=pca_structures.get_params())

        # Project the DEEM 330k structures batch by batch, so only one batch
        # of SOAP vectors is read from disk at a time
        n_samples_330k = deem_330k.len()
        n_batches = -(-n_samples_330k // batch_size)  # ceiling division
        T_deem_330k = np.zeros((n_samples_330k, pca_structures.n_components_))

        for i in tqdm(range(n_batches), desc='Batch', leave=False):
            batch_slice = slice(i * batch_size, (i + 1) * batch_size)
            T_deem_330k[batch_slice] = pca_structures.transform(deem_330k[batch_slice, :])

    utils.save_hdf5(pca_deem_file, T_deem_330k, attrs=pca_parameters)

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Batch', max=34.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Batch', max=34.0, style=ProgressStyle(description_width='…


