# Import all required libraries

In [None]:
# # deactivate GPU 
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [None]:
import os
import h5py as hdf5
import numpy as np
import pandas as pd
import pickle
import itertools
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
print('Using TensorFlow {:s} with {:d} GPUs'.format(tf.__version__,len(tf.config.experimental.list_physical_devices('GPU'))))

# Load files that contain the functions used later

In [None]:
import catalogues as cat
import scaling as sca
import plotting as plo
import observations as obs
import reinforcement as rl
import optimizers as opt

# Global Parameters that should be fixed

In [None]:
# Minimum value for the minmax-normalised features and labels
minvalue   = 0.01

#Fraction of data in the validation set
vali_ratio = 0.1

#Fraction of data in the test set
test_ratio = 0.1

In [None]:
#Define the features that are used and give minimum and maximum values for the scaling
#This makes the minmax scaling independent of the exact input data
#The last column indicates if a feature is first taken to log before scaling
halo_features_used = [
    ['Scale', 0.0, 1.0, False],
    ['Halo_mass', 10.0, 16.0, False],
    ['Halo_mass_peak', 10.0, 16.0, False],
#     ['Halo_radius',      0.01,  3.0,   True],
    ['Halo_growth_rate', 0.01, 1000000.0, True],
    ['Halo_growth_peak', 0.01, 1000000.0, True],
    ['Scale_peak_mass', 0.0, 1.0, False],
#     ['Scale_half_peak_mass',  0.0,   1.0,   False],
    ['Concentration', 0.01, 10000.0, True],
#     ['Halo_spin',        0.001, 1.0,   True],
#     ['Merger', 'categorical'],
    ['Main_galaxies', 'categorical'],
    ['Central', 'categorical'],
#     ['Satellite', 'categorical'],
#     ['Orphan', 'categorical']
 ]

#Define the labels that are used and give minimum and maximum values for the scaling
#This makes the minmax scaling independent of the exact input data
#The last column indicates if a label is first taken to log before scaling
galaxy_labels_used = [
    ['Stellar_mass',     6.0,   13.0, False],
    ['SFR',              1.e-6, 1.e4, True] 
]
# Add used features and labels together
columns_used = halo_features_used + galaxy_labels_used

# pure name tags as lists
halo_columns_active = [column_name[0] for column_name in halo_features_used] 

galaxy_columns_active = [column_name[0] for column_name in galaxy_labels_used] 

columns_active= halo_columns_active + galaxy_columns_active 

halo_columns_scaled = [s + '_scaled' for s in halo_columns_active] 

galaxy_columns_scaled = [s + '_scaled' for s in galaxy_columns_active]

columns_scaled = [s + '_scaled' for s in columns_active]


# needed to prepare the features and labels
halo_columns_active2=[
 'Scale',
 'Halo_mass_scaled',
 'Halo_mass_peak_scaled',
#  'Halo_radius_scaled',  
 'Halo_growth_rate_scaled',
 'Halo_growth_peak_scaled',
 'Scale_peak_mass_scaled',
#  'Scale_half_peak_mass_scaled',
 'Concentration_scaled',
#  'Halo_spin_scaled',
#  'Merger',
 'Main_galaxies',
 'Central',
#  'Satellite',
#  'Orphan',
 'X_pos',
 'Y_pos',
 'Z_pos',
 'Scale_scaled'
 ]

galaxy_columns_active2=['Stellar_mass_scaled',
 'SFR_scaled',
 'weights',
 'Scale']

In [None]:
H0 = 67.8100
Om0 = 0.308000
Lbox = 100
a_scales = np.array([0.08, 0.09, 0.1 , 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18,
       0.19, 0.2 , 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29,
       0.3 , 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4 ,
       0.41, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5 , 0.51,
       0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6 , 0.61, 0.62,
       0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.7 , 0.71, 0.72, 0.73,
       0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8 , 0.81, 0.82, 0.83, 0.84,
       0.85, 0.86, 0.87, 0.88, 0.89, 0.9 , 0.91, 0.92, 0.93, 0.94, 0.95,
       0.96, 0.97, 0.98, 0.99, 1.  ])
#Set the stellar mass array used to compute statistics
dmstar = 0.4
mstar_bins = np.arange(7.0,12.4,dmstar)
file_redshift = np.array([0. , 0.1, 0.2, 0.5, 1. , 2. , 3. , 4. , 6. , 8])
#Set constraints for correlation functions
#Set the redshift used to compare correlation functions to data
wp_redshift = 0.1
#Set the first and last data sets used for the fit
wp_start    = 1
wp_stop     = 4

#Set scaling parameters for the observational uncertainty
obssigma0 = 0.08
obssigmaz = 0.06
zmax_sig  = 4.0
ssfrmin   = 1.0e-12
ssfrthre  = 0.5

In [None]:
#Choose data type
dtypetf = tf.float32

#Use automatic mixed precision scaling:
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
#This *should* make the model run in tf.float16 automatically while the values are still stored in tf.float32

# Load, Scale and Prepare Data for RNN

In [None]:
# ### View full Pandas ####
# pd.set_option('display.max_rows', None) 
# pd.set_option('display.max_columns', None)

In [None]:
# Read the eight merger-tree files tree.0.h5 ... tree.7.h5. Each call yields
# three objects per file: galaxies, offsets and snapshots (pandas DataFrames
# according to the original cell note).
_loaded_trees = [
    cat.load_MergerTree_to_panda_df('tree.{:d}.h5'.format(file_index), 'MergerTrees/h5')
    for file_index in range(8)
]
(
    (gal0, off0, snap0),
    (gal1, off1, snap1),
    (gal2, off2, snap2),
    (gal3, off3, snap3),
    (gal4, off4, snap4),
    (gal5, off5, snap5),
    (gal6, off6, snap6),
    (gal7, off7, snap7),
) = _loaded_trees

# Galaxy tables of all files, in file order
gal_list = [gal0, gal1, gal2, gal3, gal4, gal5, gal6, gal7]

In [None]:
# Build one combined pandas DataFrame from the per-file galaxy tables in gal_list.
# The column lists tell the helper which raw columns are used and how the
# '_scaled' columns are named; the boolean flags are passed through unchanged.
# NOTE(review): the flags presumably switch on the extra derived columns
# (halo growth rate peak, half-peak-mass scale, merger/main-branch/type flags)
# -- confirm against cat.get_whole_dataset.
whole_dataset = cat.get_whole_dataset(gal_list, columns_active, columns_used, columns_scaled, HGRP=True,HPMS=True,
                                      Merger=True, Main_galaxies=True, Type=True)

# Features and labels as tensors (one row per DataFrame row)
halos = tf.convert_to_tensor(np.array(whole_dataset[halo_columns_active])) # unscaled feature columns
halos_scaled = tf.convert_to_tensor(np.array(whole_dataset[halo_columns_scaled])) # scaled feature columns
galaxies = tf.convert_to_tensor(np.array(whole_dataset[galaxy_columns_active])) # unscaled label columns

# Store a loss weight for each galaxy in a new 'weights' column; this column is
# later requested by name in galaxy_columns_active2.
whole_dataset['weights'] = cat.get_loss_weights(galaxies,galaxy_labels_used,dm=0.2,norm=10.0)

In [None]:
#whole_dataset.describe()

In [None]:
# # one hot encode galaxy types
# central, satellite, orphan = cat.get_galaxy_type(whole_dataset)
# #save as pkl 
# with open('pkl_Data/central.pkl', 'wb') as f:
#     pickle.dump(central, f)
# with open('pkl_Data/satellite.pkl', 'wb') as f:
#     pickle.dump(satellite, f)
# with open('pkl_Data/orphan.pkl', 'wb') as f:
#     pickle.dump(orphan, f)

## Calculate Input Data 

In [None]:
# Cut the combined frame back into one slice per input file. The slice
# boundaries are the running totals of the per-file row counts, so slice k
# covers exactly the rows that came from file k.
_row_counts = [len(g) for g in (gal0, gal1, gal2, gal3, gal4, gal5, gal6, gal7)]
_slice_ends = list(itertools.accumulate(_row_counts))
_slice_starts = [0] + _slice_ends[:-1]
(
    gal0_scaled,
    gal1_scaled,
    gal2_scaled,
    gal3_scaled,
    gal4_scaled,
    gal5_scaled,
    gal6_scaled,
    gal7_scaled,
) = [whole_dataset[first:last] for first, last in zip(_slice_starts, _slice_ends)]

In [None]:
# Break every per-file slice into its individual merger trees, using the
# offsets table that was loaded together with that file.
(
    trees0,
    trees1,
    trees2,
    trees3,
    trees4,
    trees5,
    trees6,
    trees7,
) = [
    cat.split_dataset_into_MergerTrees(scaled_slice, offsets)
    for scaled_slice, offsets in (
        (gal0_scaled, off0),
        (gal1_scaled, off1),
        (gal2_scaled, off2),
        (gal3_scaled, off3),
        (gal4_scaled, off4),
        (gal5_scaled, off5),
        (gal6_scaled, off6),
        (gal7_scaled, off7),
    )
]

In [None]:
# #calculate main_galaxies(main branch galaxies) and merger(snapshot positions where mergers on main_branch and general mergers happen)
# merger_list0, main_merger_list0, main_galaxies_list0 = cat.get_main_galaxies_and_mergers(trees0, whole_dataset)
# merger_list1, main_merger_list1, main_galaxies_list1 = cat.get_main_galaxies_and_mergers(trees1, whole_dataset)
# merger_list2, main_merger_list2, main_galaxies_list2 = cat.get_main_galaxies_and_mergers(trees2, whole_dataset)
# merger_list3, main_merger_list3, main_galaxies_list3 = cat.get_main_galaxies_and_mergers(trees3, whole_dataset)
# merger_list4, main_merger_list4, main_galaxies_list4 = cat.get_main_galaxies_and_mergers(trees4, whole_dataset)
# merger_list5, main_merger_list5, main_galaxies_list5 = cat.get_main_galaxies_and_mergers(trees5, whole_dataset)
# merger_list6, main_merger_list6, main_galaxies_list6 = cat.get_main_galaxies_and_mergers(trees6, whole_dataset)
# merger_list7, main_merger_list7, main_galaxies_list7 = cat.get_main_galaxies_and_mergers(trees7, whole_dataset)
# combined_merger_list = merger_list0 + merger_list1 + merger_list2 + merger_list3 + merger_list4 + merger_list5 + merger_list6 + merger_list7
# combined_main_merger_list = main_merger_list0 + main_merger_list1 + main_merger_list2 + main_merger_list3 + main_merger_list4 + main_merger_list5 + main_merger_list6 + main_merger_list7
# combined_main_galaxies_list = main_galaxies_list0 + main_galaxies_list1 + main_galaxies_list2 + main_galaxies_list3 + main_galaxies_list4 + main_galaxies_list5 + main_galaxies_list6 + main_galaxies_list7

In [None]:
# # save as pkl 
# with open('pkl_Data/merger_new.pkl', 'wb') as f:
#     pickle.dump(combined_merger_list, f)
# # save as pkl 
# with open('pkl_Data/main_merger.pkl', 'wb') as f:
#     pickle.dump(combined_main_merger_list, f)
# # save as pkl 
# with open('pkl_Data/main_galaxies_new.pkl', 'wb') as f:
#     pickle.dump(combined_main_galaxies_list, f)

In [None]:
# Divide the full merger trees into sub-branches, one result list per file.
# Original note: with HGRP/HPMS = True the helper also calculates the halo
# growth rate peak / Scale_half_peak_mass and adds them as DataFrame columns.
# Both options are switched off here.
(
    Reduced_trees0,
    Reduced_trees1,
    Reduced_trees2,
    Reduced_trees3,
    Reduced_trees4,
    Reduced_trees5,
    Reduced_trees6,
    Reduced_trees7,
) = [
    cat.split_trees(file_trees, HGRP=False, HPMS=False)
    for file_trees in (trees0, trees1, trees2, trees3, trees4, trees5, trees6, trees7)
]

In [None]:
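# Load the previously pickled Reduced_trees0..7 instead of recomputing them (this overwrites any values already held in these variables)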
with open('pkl_Data/Red_trees/Reduced_trees0.pkl', 'rb') as f:
    Reduced_trees0 = pickle.load(f)
    
with open('pkl_Data/Red_trees/Reduced_trees1.pkl', 'rb') as f:
    Reduced_trees1 = pickle.load(f)

with open('pkl_Data/Red_trees/Reduced_trees2.pkl', 'rb') as f:
    Reduced_trees2 = pickle.load(f)
    
with open('pkl_Data/Red_trees/Reduced_trees3.pkl', 'rb') as f:
    Reduced_trees3 = pickle.load(f)
    
with open('pkl_Data/Red_trees/Reduced_trees4.pkl', 'rb') as f:
    Reduced_trees4 = pickle.load(f)
    
with open('pkl_Data/Red_trees/Reduced_trees5.pkl', 'rb') as f:
    Reduced_trees5 = pickle.load(f)

with open('pkl_Data/Red_trees/Reduced_trees6.pkl', 'rb') as f:
    Reduced_trees6 = pickle.load(f)
    
with open('pkl_Data/Red_trees/Reduced_trees7.pkl', 'rb') as f:
    Reduced_trees7 = pickle.load(f)

In [None]:
# Concatenate the per-file sub-branch lists into one flat list, in file order
Reduced_trees_combined = [
    *Reduced_trees0, *Reduced_trees1, *Reduced_trees2, *Reduced_trees3,
    *Reduced_trees4, *Reduced_trees5, *Reduced_trees6, *Reduced_trees7,
]

In [None]:
with open('pkl_Data/Red_trees/Reduced_trees_combined.pkl', 'rb') as f:
    Reduced_trees_combined = pickle.load(f)

In [None]:
# #calculate HPMS
# HPMS = get_HPMS(Reduced_trees_combined, whole_dataset)
# # save as pkl 
# with open('pkl_Data/HPMS.pkl', 'wb') as f:
#     pickle.dump(HPMS, f)

In [None]:
# #calculate HGRP
# HGRP = cat.get_HGRP(Reduced_trees_combined, whole_dataset)
# # save as pkl 
# with open('pkl_Data/HGRP.pkl', 'wb') as f:
#     pickle.dump(HGRP, f)

In [None]:
# zero padding and converts pandas to tensors
X,pos = cat.prepare_features(Reduced_trees_combined, halo_columns_active2, minvalue=0.00)

In [None]:
# # save as pkl 
# with open('pkl_Data/X_10_3.pkl', 'wb') as f:
#     pickle.dump(X, f)

In [None]:
# # save as pkl 
# with open('pkl_Data/pos.pkl', 'wb') as f:
#     pickle.dump(pos, f)

In [None]:
y,w =  cat.prepare_labels(Reduced_trees_combined, galaxy_columns_active2, minvalue=0.00)

In [None]:
# # save as pkl 
# with open('pkl_Data/y.pkl', 'wb') as f:
#     pickle.dump(y, f)

In [None]:
# # save as pkl 
# with open('pkl_Data/w.pkl', 'wb') as f:
#     pickle.dump(w, f)

## Load from pkl

In [None]:
with open('pkl_Data/Data/X_9.pkl', 'rb') as f:
    X = pickle.load(f)
    
with open('pkl_Data/Data/pos.pkl', 'rb') as f:
    pos = pickle.load(f)

with open('pkl_Data/Data/y.pkl', 'rb') as f:
    y = pickle.load(f)
    
with open('pkl_Data/Data/w.pkl', 'rb') as f:
    w = pickle.load(f)

## Split Data

In [None]:
# Split features, labels and weights into training, validation and test sets.
# The fractions come from the global parameters vali_ratio / test_ratio so that
# changing them at the top of the notebook actually takes effect here.
X_train,X_vali,X_test,y_train,y_vali,y_test,w_train,w_vali,w_test,index = sca.split_data(
    X, y, np.array(w), vali_ratio=vali_ratio, test_ratio=test_ratio)

In [None]:
# ### Remove Full View ####
# pd.reset_option('display.max_rows') # resets pandas options to default value

# Set RNN

### Either, load an existing Keras model...

In [None]:
#Either, load an existing Keras model...

model   = tf.keras.models.load_model('Master/final_models/Master_model_RNN.h5')
weights = model.get_weights()
config  = model.get_config()

# Read the stored training history. The context manager closes the HDF5 file
# even if one of the datasets is missing or cannot be read.
with hdf5.File('Master/final_models/Master_model_RNN_hist.h5', 'r') as histfile:
    # dictionary with loss and val_loss as plain Python lists
    training = {
        'loss': np.array(histfile['loss']).tolist(),
        'val_loss': np.array(histfile['val_loss']).tolist(),
    }

### ... Fine Tune Model

In [None]:
import talos as ta

In [None]:
params = {
    'hidden_layers_1': [2],
   'hidden_layers_2': [2],
    'neurons_1': [2,4,8,16,32],
   'neurons_2': [2,4,8,16,32],
    'activation': ['tanh'],
    'optimizer': ['adam'],  
    'loss': ['mean_squared_error'],
    'kernel_initializer': ['lecun_uniform'],#'glorot_uniform',
    'batch_size': [100],
    'epochs':[50]
}

In [None]:
def create_model_fine_tune(X_train, y_train, X_vali, y_vali, params):
    """Build and fit one GRU network for a single Talos hyper-parameter draw.

    Talos invokes the model function positionally as
    ``model(x_train, y_train, x_val, y_val, params)``, where ``params`` is a
    dictionary holding ONE value per hyper-parameter for the current round.
    The previous signature had no ``params`` argument, so the body read the
    global search grid (whose values are lists) and failed on
    ``range(params['hidden_layers_1'])``.

    Parameters
    ----------
    X_train, y_train : training features and labels.
    X_vali, y_vali : validation features and labels.
    params : dict with the keys 'hidden_layers_1', 'hidden_layers_2',
        'neurons_1', 'neurons_2', 'activation', 'kernel_initializer',
        'loss', 'optimizer', 'epochs' and 'batch_size'.

    Returns
    -------
    (out, model) : the Keras History object returned by ``fit`` and the
        fitted model, in the order Talos expects.
    """
    patience = 40
    early_stopping_cb = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience,
                                                         restore_best_weights=True)
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('checkpoint_talos.01.h5', save_best_only=True)

    model = tf.keras.Sequential()
    # Two stacks of GRU layers; each stack has its own depth and width
    for depth_key, width_key in (('hidden_layers_1', 'neurons_1'),
                                 ('hidden_layers_2', 'neurons_2')):
        for _ in range(params[depth_key]):
            model.add(layers.GRU(units=params[width_key],
                                 activation=params['activation'],
                                 return_sequences=True,
                                 kernel_initializer=params['kernel_initializer']))
    # Two outputs per time step
    model.add(layers.TimeDistributed(layers.Dense(2, kernel_initializer=params['kernel_initializer'])))
    model.compile(loss=params['loss'], optimizer=params['optimizer'])

    out = model.fit(X_train, y_train,
                    epochs=params['epochs'],
                    batch_size=params['batch_size'],
                    validation_data=(X_vali, y_vali),
                    callbacks=[early_stopping_cb,
                               checkpoint_cb])
    return out, model

In [None]:
scan_object = ta.Scan(x=X_train, y=y_train, x_val=X_vali, y_val=y_vali, params=params, model=create_model_fine_tune, 
                      experiment_name="test", random_method='quantum', seed=42)

In [None]:
best_model = scan_object.best_model(metric='val_loss', asc=True)
#best_model.summary()

In [None]:
scan_object.data #0.00104

In [None]:
# # save as pkl 
# with open('scan_object_1.pkl', 'wb') as f:
#     pickle.dump(scan_object, f)

###  ... or create new model

In [None]:
k_init0 = tf.keras.initializers.lecun_normal(seed=123450)
k_init1 = tf.keras.initializers.lecun_normal(seed=123451)
k_init2 = tf.keras.initializers.lecun_normal(seed=123452)
k_init3 = tf.keras.initializers.lecun_normal(seed=123453)
k_init4 = tf.keras.initializers.lecun_normal(seed=123454)
k_init5 = tf.keras.initializers.lecun_normal(seed=123455)
k_init6 = tf.keras.initializers.lecun_normal(seed=123456)
k_init7 = tf.keras.initializers.lecun_normal(seed=123457)
k_init8 = tf.keras.initializers.lecun_normal(seed=123458)
k_init9 = tf.keras.initializers.lecun_normal(seed=123459)
k_init10 = tf.keras.initializers.lecun_normal(seed=123460)

def create_model():
    """Return the uncompiled network: four stacked GRU layers (16, 16, 8, 8
    units) followed by a time-distributed dense layer with two outputs."""

    # Variant 1: plain sequential stack. Each entry is (units, initializer);
    # append further pairs (k_init5, k_init7, ... are available) to deepen it.
    gru_stack = [(16, k_init1), (16, k_init2), (8, k_init3), (8, k_init4)]

    model = tf.keras.Sequential()
    for position, (n_units, initializer) in enumerate(gru_stack):
        # Only the first layer declares the input shape
        first_layer_kwargs = {'input_shape': (93, 9)} if position == 0 else {}
        model.add(layers.GRU(units=n_units, return_sequences=True,
                             kernel_initializer=initializer, **first_layer_kwargs))
    model.add(layers.TimeDistributed(layers.Dense(2, kernel_initializer=k_init6)))

    # Variant 2 (disabled): shared trunk with one branch per output
#     inputs1 = layers.Input(shape=(93,9))
#     hidden1  = layers.GRU(units=32, return_sequences = True, kernel_initializer=k_init1)(inputs1)
#     hidden2  = layers.GRU(units=32, return_sequences = True, kernel_initializer=k_init2)(hidden1)
#     branch11  = layers.GRU(units=16, return_sequences = True, kernel_initializer=k_init3)(hidden2)
#     branch21  = layers.GRU(units=16, return_sequences = True, kernel_initializer=k_init4)(hidden2)
#     branch12 = layers.GRU(units=16, return_sequences = True, kernel_initializer=k_init5)(branch11)
#     branch22 = layers.GRU(units=16, return_sequences = True, kernel_initializer=k_init6)(branch21)
#     output1   = layers.TimeDistributed(layers.Dense(1, kernel_initializer=k_init9))(branch12)
#     output2   = layers.TimeDistributed(layers.Dense(1, kernel_initializer=k_init10))(branch22)
#     concat   = tf.keras.layers.Concatenate()([output1, output2])
#     model    = tf.keras.models.Model(inputs=inputs1, outputs=concat)

    return model


model = create_model()
model.summary()

# Train RNN

In [None]:
#Set the seeds to get as much reproducibility as possible
np.random.seed(43)
tf.random.set_seed(42)

#Define the maximum number of epochs, the patience, and the batch size
epochs = 500
patience = 100
batch_size = 85

#Compile the model using the (sample-weighted) mean squared error.
#sample_weight_mode='temporal' lets the weights be given per time step.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) #, clipnorm=0.001
model.compile(loss='mse',optimizer=optimizer,  sample_weight_mode='temporal')

#Define the checkpoint that stores the best model and the early stopping
checkpoint_cb         = tf.keras.callbacks.ModelCheckpoint('checkpoint.01.h5',save_best_only=True)
early_stopping_cb     = tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=patience,restore_best_weights=True)

history = model.fit(
    X_train,
    y_train,
    sample_weight=w_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(
        X_vali,
        y_vali,
        w_vali
    ),
    callbacks=[
        checkpoint_cb,
        early_stopping_cb,
    ]
)

#Keep the training and validation loss histories for the evaluation cells
training = history.history

#Save the fitting history, one dataset per recorded quantity. The context
#manager closes the file even if writing one of the datasets fails.
with hdf5.File('test_hist.h5', 'w') as hf:
    for item, values in history.history.items():
        hf.create_dataset(item, data=np.array(values))
model.save('test.h5')

# Evaluate RNN


In [None]:
model    = tf.keras.models.load_model('Master/final_models/Master_model_RNN.h5')

In [None]:
model.summary()

In [None]:
#get best val_loss
np.min(training['val_loss'])

In [None]:
#Get the validation and training predictions 
y_vali_pred  = model.predict(X_vali, batch_size=10000)  
# y_pred = model.predict(X, batch_size=10000)
# y_test_pred  = model.predict(X_test, batch_size=10000)  

In [None]:
# Get the X_vali, v_vali and v_vali_pred without zero-padding values
X_vali_original, y_vali_original, y_vali_pred_original = cat.get_data_without_zeropadding(X_vali, y_vali, y_vali_pred,halo_features_used)
# X_original, y_original, y_pred_original = cat.get_data_without_zeropadding(X, y, y_pred, halo_features_used)
# X_test_original, y_test_original, y_test_pred_original = cat.get_data_without_zeropadding(X_test, y_test, y_test_pred,halo_features_used)

In [None]:
plo.plot_history2(training, fs=20, lw=3, ymin=0.0018, ymax=0.03)
plt.savefig('Master_training.png', dpi=100)

In [None]:
plo.compare_scaled_input_prediction_RNN(y_vali_original, y_vali_pred_original, file='Master_Compare.png')

In [None]:
plo.compare_input_prediction_RNN(y_vali_original, y_vali_pred_original,
                             y_test_original, y_test_pred_original,
                             galaxy_labels_used,lw=3,fs=30,
                             axis1=[7.0,12.3,7.0,12.3],
                             axis2=[-6.5,3.5,-6.5,3.5],
                             file='Master_Compare2.png')

In [None]:
# Main-sequence figure: 5 columns (redshift ranges) x 2 rows.
# Top row uses the validation labels, bottom row the validation predictions.
fig = plt.figure(figsize=(18.0,6.0))

axis = [7.0,12.5,-5.9,2.9]

_z_ranges = [(0.0, 0.5), (0.5, 1.0), (1.0, 2.0), (2.0, 4.0), (4.0, 8.0)]
# (labels to plot, keywords for every panel of the row, extra keywords for the row's first panel)
_panel_rows = [
    (y_vali_original, {}, {'barposition': [0.96,0.15,0.03,0.80], 'modelname': 'Emerge'}),
    (y_vali_pred_original, {'showredshift': False}, {'modelname': 'RNN + RL'}),
]

_ipanel = 0
for _labels, _row_kwargs, _first_kwargs in _panel_rows:
    for _column, (_zmin, _zmax) in enumerate(_z_ranges):
        _ipanel += 1
        _extra = dict(_row_kwargs)
        if _column == 0:
            # only the first panel of a row carries the model name (and colour bar)
            _extra.update(_first_kwargs)
        plo.plot_main_sequence_panel_RNN(fig, X_vali_original, _labels, _zmin, _zmax, 0,
                                         halo_features_used, galaxy_labels_used,
                                         H0=H0, Om0=Om0, plot_obs=True,
                                         nxpanel=5, nypanel=2, ipanel=_ipanel,
                                         axis=axis, **_extra)

plt.subplots_adjust(left=0.08, right=0.89, bottom=0.14, top=0.99, hspace=0.0, wspace=0.0)

plt.savefig('Main_Sequence_RNN_16x2+8x2.png', dpi=100)

plt.show()

In [None]:
# Stellar-to-halo-mass figure: 5 columns (redshift ranges) x 2 rows.
# Top row uses the labels of the full data set, bottom row the predictions.
fig = plt.figure(figsize=(18.0,6.0))

axis = [10.5,15.2,7.1,12.8]

_z_ranges = [(0.0, 0.5), (0.5, 1.0), (1.0, 2.0), (2.0, 4.0), (4.0, 8.0)]
# (labels to plot, keywords for every panel of the row, extra keywords for the row's first panel)
_panel_rows = [
    (y_original, {}, {'barposition': [0.96,0.15,0.03,0.80], 'modelname': 'Emerge'}),
    (y_pred_original, {'showredshift': False}, {'modelname': 'RNN'}),
]

_ipanel = 0
for _labels, _row_kwargs, _first_kwargs in _panel_rows:
    for _column, (_zmin, _zmax) in enumerate(_z_ranges):
        _ipanel += 1
        _extra = dict(_row_kwargs)
        if _column == 0:
            # only the first panel of a row carries the model name (and colour bar)
            _extra.update(_first_kwargs)
        plo.plot_shmr_panel_RNN(fig, X_original, _labels, _zmin, _zmax, 0,
                                halo_features_used, galaxy_labels_used,
                                nxpanel=5, nypanel=2, ipanel=_ipanel,
                                axis=axis, **_extra)

plt.subplots_adjust(left=0.08, right=0.89, bottom=0.14, top=0.99, hspace=0.0, wspace=0.0)

plt.savefig('Shmr_RNN_RNN_16x2+8x2.png', dpi=100)

plt.show()

In [None]:
models_hist =['F_2_hist.h5','F_3_hist.h5','F_4_hist.h5','F_5_hist.h5','F_6_hist.h5',
              'F_7_hist.h5','F_8_hist.h5','F_9_hist.h5','F_10_hist.h5','F_11_hist.h5']
        
models =['F_2.h5','F_3.h5','F_4.h5','F_5.h5','F_6.h5',
         'F_7.h5','F_8.h5','F_9.h5','F_10.h5','F_11.h5']

In [None]:
models_hist =['4_hist.h5','2x8_hist.h5','2x8+2x4_hist.h5',
              '2x16+2x8_hist.h5','4x16+4x8_hist.h5','2x32_hist.h5',
              '64_hist.h5','2x32+2x16_hist.h5','4x32_hist.h5',
              '4x32+4x16_hist.h5']
       
models =['4.h5','2x8.h5','2x8+2x4.h5',
         '2x16+2x8.h5','4x16+4x8.h5','2x32.h5',
         '64.h5','2x32+2x16.h5','4x32.h5',
         '4x32+4x16.h5']

In [None]:
plo.plot_compare_models2(models_hist, models, ymin=0.004, ymax=0.06,fs=22,lw=3)

## Calculate mstar_integrated

### Prepare trees

In [None]:
model    = tf.keras.models.load_model('Master/final_models/Master_model_RNN_RL_3rd.h5')

In [None]:
y_pred = model.predict(X, batch_size=10000)

In [None]:
unscaled_labels, unscaled_predictions = cat.unscale_timeserieses(y,y_pred, galaxy_labels_used)

In [None]:
# # save as pkl 
# with open('pkl_Data/del1.pkl', 'wb') as f:
#     pickle.dump(unscaled_predictions, f)

In [None]:
with open('pkl_Data/del1.pkl', 'rb') as f:
    unscaled_predictions = pickle.load(f)

In [None]:
Reduced_trees_combined2 = cat.add_unscaled_predictions(Reduced_trees_combined, unscaled_predictions)

In [None]:
# # save as pkl 
# with open('pkl_Data/del2.pkl', 'wb') as f:
#     pickle.dump(Reduced_trees_combined2, f)

In [None]:
with open('pkl_Data/del2.pkl', 'rb') as f:
    Reduced_trees_combined2 = pickle.load(f)

In [None]:
Full_trees, indices = cat.create_full_merger_tree(Reduced_trees_combined2)

In [None]:
# # save as pkl 
# with open('Master/Preprocess/Full_trees_model_RNN_RL_3rd.pkl', 'wb') as f:
#     pickle.dump(Full_trees, f)

In [None]:
with open('Master/Preprocess/Full_trees_model_RNN_RL_3rd.pkl', 'rb') as f:
    Full_trees = pickle.load(f)

In [None]:
# Keep only the ID, label and prediction columns of every full tree and order
# the rows of each tree by scale factor.
_kept_columns = ['Scale','ID','Up_ID','Desc_ID','Main_ID',
                 'Coprog_ID','Leaf_ID','Num_prog','Stellar_mass','SFR','mstar_pred','sfr_pred']
Full_trees_reduced = [
    full_tree[_kept_columns].sort_values(by=['Scale'])
    for full_tree in Full_trees
]

In [None]:
# # save as pkl 
# with open('Master/Preprocess/Full_trees_reduced_model_RNN_RL_3rd.pkl', 'wb') as f:
#     pickle.dump(Full_trees_reduced, f)

In [None]:
with open('Master/Preprocess/Full_trees_reduced_model_RNN_RL_3rd.pkl', 'rb') as f:
    Full_trees_reduced = pickle.load(f)

### Calculation

In [None]:
calculate_mstar = cat.calculate_mstar(Full_trees_reduced, a_scales, exsitu = True) 

In [None]:
# # save as pkl 
# with open('Master/Preprocess/exsitu_all_RL.pkl', 'wb') as f:
#     pickle.dump(calculate_mstar, f)

In [None]:
with open('Master/Preprocess/exsitu_all_RL.pkl', 'rb') as f:
    calculate_mstar = pickle.load(f)

In [None]:
calculate_mstar_only_insitu  = cat.calculate_mstar(Full_trees_reduced, a_scales, exsitu = False) 

In [None]:
# # save as pkl 
# with open('Master/Preprocess/insitu_all_RL.pkl', 'wb') as f:
#     pickle.dump(calculate_mstar_only_insitu, f)

In [None]:
with open('Master/Preprocess/insitu_all_RL.pkl', 'rb') as f:
    calculate_mstar_only_insitu = pickle.load(f)

In [None]:
exsitu =  [item for sublist in calculate_mstar for item in sublist]

In [None]:
insitu =  [item for sublist in calculate_mstar_only_insitu for item in sublist]

### Graph

In [None]:
# One list of stellar-mass labels per tree
mstar_label = [reduced_tree['Stellar_mass'].tolist() for reduced_tree in Full_trees_reduced]

In [None]:
mstar_label2 = [item for sublist in mstar_label for item in sublist]

In [None]:
# One list of predicted stellar masses per tree
mstar_pred = [reduced_tree['mstar_pred'].tolist() for reduced_tree in Full_trees_reduced]

In [None]:
mstar_pred2 = [item for sublist in mstar_pred for item in sublist]

In [None]:
# The plotting helpers are only reachable through the imported module alias;
# the bare name is never imported in this notebook.
# NOTE(review): assumed to live in plotting.py like the other plot helpers -- confirm.
plo.compare_calculated_mstar(exsitu, insitu, mstar_pred2, mstar_label2, compare='sum',
                             lw=3, fs=25, file='calculated_mstar_combined_mean2.png')

## Compare Main_branches

### Calculations

In [None]:
with open('Master/Preprocess/Full_trees_model_RNN_RL_3rd.pkl', 'rb') as f:
    Full_trees = pickle.load(f)

In [None]:
trees, main_branches, indices_list, len_list = cat.get_example_main_branches(Full_trees, number_trees=10, len_trees=100)
# low = 100
# high = 1000

In [None]:
# Same column reduction as for the full sample, but applied to the example
# trees and keeping the 'Main_galaxies' flag needed to pick the main branch.
_kept_columns = ['Scale','ID','Up_ID','Desc_ID','Main_ID',
                 'Coprog_ID','Leaf_ID','Num_prog','Stellar_mass','SFR','mstar_pred','sfr_pred','Main_galaxies']
Full_trees_reduced = [
    example_tree[_kept_columns].sort_values(by=['Scale'])
    for example_tree in trees
]

In [None]:
calculate_mstar = cat.calculate_mstar(Full_trees_reduced, a_scales, indices_list, exsitu = True)

In [None]:
# Build one small DataFrame per example tree that holds only the main-branch
# rows and the three stellar-mass variants to compare.
dfs = []
for counter, tree in enumerate(Full_trees_reduced):
    # attach the integrated stellar mass computed for this tree (adds a column
    # to the frame stored in Full_trees_reduced)
    tree['mstar_integrated'] = calculate_mstar[counter]
    # keep only rows flagged as main-branch galaxies
    tree2 = tree[tree['Main_galaxies'] == 1.0]
    # scale factor plus predicted, label and integrated stellar mass
    tree3 = tree2[['Scale','mstar_pred','Stellar_mass','mstar_integrated']]
    dfs.append(tree3)

In [None]:
# save as pkl 
with open('Master/Preprocess/example10_RNN_RL_low.pkl', 'wb') as f:
    pickle.dump(dfs, f)

### Graph

In [None]:
with open('Master/Preprocess/example10_RNN_low.pkl', 'rb') as f:
    example10_RNN = pickle.load(f)

with open('Master/Preprocess/example10_RNN_RL_low.pkl', 'rb') as f:
    example10_RNN_RL = pickle.load(f)

with open('Master/Preprocess/example10_NN_low.pkl', 'rb') as f:
    example10_NN = pickle.load(f)
    
with open('Master/Preprocess/example10_NN_low_HaloNet.pkl', 'rb') as f:
    example10_NN_RL = pickle.load(f)

In [None]:
plo.plot_main_branches(example10_RNN, example10_RNN_RL, example10_NN, example10_NN_RL, file='low.png')

In [None]:
example10_RNN[0]

In [None]:
example10_RNN_RL[0]

## Plot Baryon Efficiency

In [None]:
df = cat.calculate_baryon_df(tf.keras.models.load_model('Master/final_models/Master_model_RNN_RL_3rd.h5'), X, y, halo_features_used, galaxy_labels_used)

In [None]:
# only stellar mass as color coding
plo.plot_baryon_efficiency(df, file='test_bayron_compare.png')

In [None]:
# Quantities handed to plo.plot_baryon_efficiency_2; each key maps to either
# None or 'log'. Commented-out entries are currently excluded.
# NOTE(review): presumably the keys are column names of df and 'log' requests a
# logarithmic scale -- confirm against plotting.py.
# NOTE(review): 'Halo_peak_mass' and 'Halo_Growth_rate' are spelled differently
# from the feature names 'Halo_mass_peak' and 'Halo_growth_rate' used at the
# top of this notebook -- verify they match the columns of df.
compare_list = {
  'Stellar_mass': None,
  'SFR': 'log',
#  'Redshift': None,
#  'Scale': None,
  'Halo_mass': None,
  'Halo_peak_mass': None,
  'Halo_Growth_rate': 'log',
  'Halo_growth_peak': 'log',
#  'Scale_peak_mass': None,
  'Concentration': 'log',
#  'Main': None,
#  'Central': None
}

In [None]:
plo.plot_baryon_efficiency_2(df, compare_list=compare_list, file='test_bayron_compare1.png')

# Load observed statistical data

In [None]:
# Load the observed statistics and average them in the redshift bins given by
# file_redshift (and, where applicable, the stellar-mass bins in mstar_bins).
universe    = obs.load_statistics_file('statistics.h5') #It contains ['CSFRD', 'Chi2', 'Clustering', 'FQ', 'Model_Parameters', 'SMF', 'SSFR']
smf         = obs.average_smf_in_z_bins(universe,file_redshift,mstar_bins)
fq          = obs.average_fq_in_z_bins(universe,file_redshift,mstar_bins)
csfrd       = obs.average_csfrd_in_z_bins(universe,file_redshift)
ssfr        = obs.average_ssfr_in_z_bins(universe,file_redshift,mstar_bins)
# Clustering data for the data sets wp_start..wp_stop
wp, wp_mass = obs.get_clustering_data(universe,wp_start,wp_stop) #wp.shape=(n_attributes,n_sets,n_attr_entries) #wp_mass.shape=(n_sets, n_attr(min/max mass))

#Set the minimum and maximum radius for the correlation functions according to the observed values
#(only strictly positive entries of wp[0] are considered)
rmin        = np.min(wp[0][wp[0]>0.0])
# NOTE(review): the maximum is reduced by a hard-coded 20; the reason and the
# units are not documented here -- confirm this offset is intended.
rmax        = np.max(wp[0][wp[0]>0.0]) -20
# number of entries per attribute (third axis of wp)
nrbin       = wp.shape[2]

# Load positions and set bin edges

In [None]:
# Positions for the whole data set, plus bin edges derived from the
# stellar-mass bins and the redshift bins (in that order).
positions = cat.load_positions_RNN(whole_dataset)
mstar_bin_edges, redshift_bin_edges = map(
    obs.get_bin_edges, (mstar_bins, file_redshift)
)

In [None]:
# Dictionary of values that will be passed to the statistics functions.
# 'model' starts as None and is filled in once a model has been loaded.
psodict = dict(
    # Model and inputs
    model=None,
    modeltype_RNN=True,
    X_RNN_input=X,
    X=halos_scaled,
    halos=halos,
    galaxies=galaxies,
    positions=positions,
    pos_RNN_input=pos,
    galaxy_labels_used=galaxy_labels_used,
    halo_features_used=halo_features_used,
    # Binning
    mstar_bin_edges=mstar_bin_edges,
    redshift_bin_edges=redshift_bin_edges,
    # Global parameters defined earlier in the notebook
    obssigma0=obssigma0,
    obssigmaz=obssigmaz,
    zmax_sig=zmax_sig,
    ssfrthre=ssfrthre,
    H0=H0,
    Om0=Om0,
    Lbox=Lbox,
    dmstar=dmstar,
    ssfrmin=ssfrmin,
    # Clustering set-up
    wp_redshift=wp_redshift,
    wp_mass=wp_mass,
    rmin=rmin,
    rmax=rmax,
    nrbin=nrbin,
    # Observed statistics
    smf=smf,
    fq=fq,
    ssfr=ssfr,
    csfrd=csfrd,
    wp=wp,
)

# Compute statistics for the galaxy catalogue

In [None]:
# Statistics of the galaxy catalogue currently stored in psodict
# (the "_em" suffix marks these as the reference values used in later plots).
(
    smf_em, smf_sig_em,
    fq_em, fq_sig_em,
    ssfr_em, ssfr_sig_em,
    csfrd_em, csfrd_sig_em,
    wp_em,
) = rl.compute_statistics(psodict=psodict)

# Print the chi2 values for the same dictionary.
rl.get_chi2(psodict=psodict, printchi=True)

In [None]:
# Hard-coded chi2 values, rounded to one decimal place.
# NOTE(review): these look like numbers copied from an earlier get_chi2
# print-out; they are not recomputed when the cell above is re-run.
chi2_list_em = np.round(
    [
        1046.2301317161969,
        496.9890074660465,
        267.76771598281334,
        40.44518624345288,
        202.77466745905875,
        38.253554564825336,
    ],
    decimals=1,
)

# Fit the parameters with Reinforcement Learning 

## Start with Model

In [None]:
# Load an RNN model to start with, read its weights as the starting
# parameter vector, and only then register the model in psodict
# (statement order kept as in the original).
model = tf.keras.models.load_model(
    'Master/final_models/Master_model_RNN_RL_2nd.h5'
)
best_parameters = rl.get_weights(model, psodict)
psodict['model'] = model

## Particle Swarms

### Particle Swarm

In [None]:
# Create the swarm and define some fitting parameters.
swarm = opt.PSOSwarm(
    n_particles=50,
    start_position=best_parameters,
    init_pos=0.05,
    seed=44,
    w=0.9,
    w_min=0.5,
    c1=0.9,
    c2=0.9,
)

# Train the swarm on the PSO loss.
swarm.train(
    rl.pso_loss,
    psodict=psodict,
    n_iterations=40,
    hist_file='pso.history.03.h5',
    gbest_file='pso.gbest.03.h5',
    gstop=0.5,
    n_loss=30,
)

# Get the best parameters from the swarm.
best_pso_parameters = swarm.gbest_position

### Multi-Particle Swarm

In [None]:
### n_particles must be >= 2 ###
# mso_type options (per the original note): charged, atomic, neutral.
# q_desc implements an exponential decay of the charge to q=1 (original note).
# NOTE(review): c3_max (0.45) is smaller than c3_min (0.9); confirm this is
# intended.
Multi_swarm = opt.MSO(
    n_swarms=6,
    n_particles=8,
    start_position=best_parameters,
    mso_type='neutral',
    q=1.05,
    init_pos=0.003,
    pDeath=0.005,
    pSwap=0.005,
    w_max=0.9,
    w_min=0.5,
    c1_max=1.0,
    c1_min=0.8,
    c2_max=1.0,
    c2_min=0.8,
    c3_max=0.45,
    c3_min=0.9,
    q_desc=False,
    RC=1,
    RP=10,
)

# If phase == True, c3 will be set to zero for half of the iterations
# (original note).
Multi_swarm.train_MSO(
    rl.pso_loss,
    psodict,
    n_iterations=50,
    phase=True,
    gbest_file='Master_mso_2ndrun.h5',
)
best_mso_parameters = Multi_swarm.gbest_position

# Settings of the 1st run, kept for reference:
# Multi_swarm = opt.MSO(n_swarms=6, n_particles=8, start_position=best_parameters,
#                   mso_type='atomic', q=1.05, init_pos=0.005, pDeath = 0.005, pSwap = 0.01, w_max=0.9,
#                   w_min=0.4, c1_max=1.0, c1_min=0.8, c2_max=1.0, c2_min=0.8,
#                   c3_max=0.45, c3_min=0.9, q_desc=False, RC=1, RP=10)
# Multi_swarm.train_MSO(rl.pso_loss, psodict, n_iterations=50, gbest_file='RL_RNN/X_9_mso.h5')


In [None]:
# Read the stored MSO results of the third run.
# The file is opened in a `with` block so that it is closed even when one of
# the dataset look-ups raises (e.g. a missing key); every dataset is copied
# into a NumPy array before the file is closed.
with hdf5.File('Master/final_models/Master_model_RNN_RL_3rd_hist.h5', 'r') as gbestfile:
    # Names of the top-level entries in the file.
    filekeys = list(gbestfile.keys())
    # Global best position (flattened to 1-D) and its loss.
    best_mso_parameters = np.array(gbestfile['Best_position']).flatten()
    best_mso_loss = np.array(gbestfile['Best_loss'])
    # Per-swarm best positions / losses and the stored swarm history.
    swarm_best_positions = np.array(gbestfile['Swarm_best_positions'])
    swarm_best_losses = np.array(gbestfile['Swarm_best_losses'])
    swarm_history = np.array(gbestfile['Swarm_history'])

In [None]:
# Plot the MSO history files of the first, second and third run;
# `savefile` is presumably the output image path.
plo.plot_mso_history2(
    file1='Master/final_models/Master_model_RNN_RL_hist.h5',
    file2='Master/final_models/Master_model_RNN_RL_2nd_hist.h5',
    file3='Master/final_models/Master_model_RNN_RL_3rd_hist.h5',
    savefile='mso_history.png',
)

## Simulated Annealing

In [None]:
# Simulated annealing, started from the best MSO parameters.
opt2 = opt.simulated_annealing(best_mso_parameters)

opt2.train(
    rl.pso_loss,
    psodict,
    scale_max=0.00005,
    scale_min=0.00005,
    maxsteps=1000,
    T0=0.1,
    debug=True,
)

## Brute Force Approach

In [None]:
# Brute-force optimizer, started from the best MSO parameters.
# Argument meaning per the original note:
#   pos   = starting position in parameter space
#   delta = number of parameters changed in one iteration
#   scale = range over which parameters are changed
asd = opt.optimizer(best_mso_parameters)
asd.train(
    rl.pso_loss,
    psodict,
    pos=2,
    delta=1,
    scale=0.1,
    gbest_file='optimizer.h5',
)

## Save Model

In [None]:
# Either save the model: write the best MSO parameters into `model`
# and store it as the third-run model file ...
rl.set_weights(model, best_mso_parameters, psodict)
model.save('Master/final_models/Master_model_RNN_RL_3rd.h5')

# ... or repeat the fitting (continue at the previous cell afterwards):
# best_parameters = Multi_swarm.gbest_position

# Load a PSO history and best parameters from a file

In [None]:
# # Load PSO history from file
# historyfile = hdf5.File('pso/RNN_Merger_2x140+60(b=50)_pso.history.03.h5','r')
# filekeys    = [key for key in historyfile.keys()]
# history_pos = np.array(historyfile['Positions'])
# history_vel = np.array(historyfile['Velocities'])
# historyfile.close()

# Load the best position and its loss from file.
# The `with` block closes the file even if a dataset look-up raises; both
# datasets are copied into NumPy arrays before the file is closed.
with hdf5.File('Master/final_models/Master_model_RNN_RL_3rd_hist.h5', 'r') as gbestfile:
    # Names of the top-level entries in the file.
    filekeys = list(gbestfile.keys())
    best_pso_parameters = np.array(gbestfile['Best_position']).flatten()
    best_pso_loss = np.array(gbestfile['Best_loss'])

# Select a model and write the best parameters to it.
model = tf.keras.models.load_model('Master/final_models/Master_model_RNN.h5')
rl.set_weights(model, best_pso_parameters, psodict)

# Plot the global statistics

In [None]:
# Load the saved third-run RNN + RL model used for all plots below.
model = tf.keras.models.load_model(
    'Master/final_models/Master_model_RNN_RL_3rd.h5'
)

In [None]:
# First get the model predictions for the input X, in batches of 10000.
y_pred = model.predict(
    X,
    batch_size=10000,
)

In [None]:
# Remove zero-padded values and unscale the features and targets
# (per the original note). This rebinds `halos` and `positions`.
(
    gal_pred,
    halos,
    positions,
    gal,
) = cat.data_without_zeropadding_RL(
    X,
    y,
    y_pred,
    galaxy_labels_used,
    halo_features_used,
    pos,
)

In [None]:
# Feed the predictions (and the matching positions / halos) to the dictionary.
psodict.update(
    positions=positions,
    galaxies=gal_pred,
    halos=halos,
)

In [None]:
# Compute the statistics for this model (the "_mod" suffix marks model values).
(
    smf_mod, smf_sig_mod,
    fq_mod, fq_sig_mod,
    ssfr_mod, ssfr_sig_mod,
    csfrd_mod, csfrd_sig_mod,
    wp_mod,
) = rl.compute_statistics(psodict=psodict)

# Print the chi2 values for the same dictionary.
rl.get_chi2(psodict=psodict, printchi=True)

In [None]:
# Hard-coded chi2 values for the model, rounded to one decimal place.
# NOTE(review): these look like numbers copied from an earlier get_chi2
# print-out; they are not recomputed when the cell above is re-run.
chi2_list_mod = np.round(
    [
        655.7301888784817,
        265.6890717148436,
        188.89909819609446,
        32.03477032426668,
        131.38075114897023,
        37.72649749430668,
    ],
    decimals=1,
)

In [None]:
# Work on a copy so the observed SMF used elsewhere stays untouched.
smf_obs_plot = smf.copy()
# NOTE(review): one entry is shifted by -0.2 for plotting only; the reason is
# not documented here.
smf_obs_plot[0, 0, -2] = smf_obs_plot[0, 0, -2] - 0.2

plo.plot_smf(
    universe,
    file_redshift,
    mstar_bins,
    smf_obs_plot,
    smf_mod,
    smf_em,
    smf_huge=None,
    nx=5,
    ny=2,
    plotfile='GalaxyNet_RL_SMF.png',
)

In [None]:
# Redshift-bin indices shown in the quenched-fraction plot (five panels).
_fq_z_idx = [2, 3, 4, 5, 6]

redshifts_fq = np.array([file_redshift[i] for i in _fq_z_idx])
fq_plot = fq[:, _fq_z_idx, :]
fq_mod_plot = fq_mod[_fq_z_idx, :]
fq_em_plot = fq_em[_fq_z_idx, :]

plo.plot_fq(
    universe,
    redshifts_fq,
    mstar_bins,
    fq_plot,
    fq_mod_plot,
    fq_em_plot,
    nx=5,
    ny=1,
    plotfile='GalaxyNet_RL_FQ.png',
)

In [None]:
# Stellar-mass-bin indices shown in the sSFR plot (four panels).
_ssfr_m_idx = [4, 6, 9, 11]

mstar_ssfr = np.array([mstar_bins[i] for i in _ssfr_m_idx])
ssfr_plot = ssfr[:, :, _ssfr_m_idx]
ssfr_mod_plot = ssfr_mod[:, _ssfr_m_idx]
ssfr_em_plot = ssfr_em[:, _ssfr_m_idx]

plo.plot_ssfr(
    universe,
    file_redshift,
    mstar_ssfr,
    ssfr_plot,
    ssfr_mod_plot,
    ssfr_em_plot,
    nx=4,
    ny=1,
    plotfile='GalaxyNet_RL_SSFR.png',
)

In [None]:
# CSFRD plot: observed values, model values and the "_em" reference values.
plo.plot_csfrd(
    universe,
    file_redshift,
    csfrd,
    csfrd_mod,
    csfrd_em,
    plotfile='GalaxyNet_RL_CSFRD.png',
)

In [None]:
# Clustering plot: wp[0, 1] is passed as the radii, the model values as
# `model` and the "_em" reference values as `compare`.
plo.plot_wp(
    universe,
    nx=4,
    ny=1,
    rad=wp[0, 1],
    model=wp_mod,
    compare=wp_em,
    plotfile='GalaxyNet_RL_WP.png',
)

In [None]:
# Compare the hard-coded reference ("_em") chi2 values with the model ("_mod") ones.
plo.compare_chi2(chi2_list_em, chi2_list_mod)

# Plot the Main Sequence and the SHMR

In [None]:
fig = plt.figure(figsize=(18.0, 6.0))

axis = [7.0, 12.5, -5.9, 2.9]

# Redshift intervals, one per column of the 5 x 2 panel grid.
_z_ranges = [(0.0, 0.5), (0.5, 1.0), (1.0, 2.0), (2.0, 4.0), (4.0, 8.0)]

# One grid row per galaxy table:
# (table, model name shown on the first panel of the row, extra keyword
#  arguments passed to every panel of the row).
_rows = [
    (gal, 'Emerge', {}),
    (gal_pred, 'RNN + RL', {'showredshift': False}),
]

for _irow, (_galaxies, _label, _row_kwargs) in enumerate(_rows):
    for _icol, (_zlo, _zhi) in enumerate(_z_ranges):
        _extra = dict(_row_kwargs)
        if _icol == 0:
            # Only the first panel of each row carries the model name ...
            _extra['modelname'] = _label
            if _irow == 0:
                # ... and only the very first panel sets the bar position.
                _extra['barposition'] = [0.96, 0.15, 0.03, 0.80]
        plo.plot_main_sequence_panel_RNN_RL(
            fig, halos, _galaxies, _zlo, _zhi, 0,
            halo_features_used, galaxy_labels_used,
            H0=H0, Om0=Om0, plot_obs=True,
            nxpanel=5, nypanel=2,
            ipanel=_irow * len(_z_ranges) + _icol + 1,  # panels are numbered 1..10
            axis=axis, Unscale=False,
            **_extra,
        )

plt.subplots_adjust(left=0.08, right=0.89, bottom=0.14, top=0.99, hspace=0.0, wspace=0.0)

plt.savefig('Main_Sequence_RNN_16x2+8x2_RL.png', dpi=100)

plt.show()

In [None]:
fig = plt.figure(figsize=(18.0, 6.0))

axis = [10.5, 15.2, 7.1, 12.8]

# Redshift intervals, one per column of the 5 x 2 panel grid.
_z_ranges = [(0.0, 0.5), (0.5, 1.0), (1.0, 2.0), (2.0, 4.0), (4.0, 8.0)]

# One grid row per galaxy table:
# (table, model name shown on the first panel of the row, extra keyword
#  arguments passed to every panel of the row).
_rows = [
    (gal, 'Emerge', {}),
    (gal_pred, 'RNN + RL', {'showredshift': False}),
]

for _irow, (_galaxies, _label, _row_kwargs) in enumerate(_rows):
    for _icol, (_zlo, _zhi) in enumerate(_z_ranges):
        _extra = dict(_row_kwargs)
        if _icol == 0:
            # Only the first panel of each row carries the model name ...
            _extra['modelname'] = _label
            if _irow == 0:
                # ... and only the very first panel sets the bar position.
                _extra['barposition'] = [0.96, 0.15, 0.03, 0.80]
        plo.plot_shmr_panel_RNN_RL(
            fig, halos, _galaxies, _zlo, _zhi, 0,
            halo_features_used, galaxy_labels_used,
            nxpanel=5, nypanel=2,
            ipanel=_irow * len(_z_ranges) + _icol + 1,  # panels are numbered 1..10
            axis=axis, Unscale=False,
            **_extra,
        )

plt.subplots_adjust(left=0.08, right=0.89, bottom=0.14, top=0.99, hspace=0.0, wspace=0.0)

plt.savefig('RNN_merger_128+64shmr_RL.png', dpi=100)

plt.show()