K.clear_session() is useful when you're creating multiple models in succession, such as during hyperparameter search or cross-validation. Each model you train adds nodes (potentially numbering in the thousands) to the graph. TensorFlow executes the entire graph whenever you (or Keras) call tf.Session.run() or tf.Tensor.eval(), so your models will become slower and slower to train, and you may also run out of memory.

del deletes a variable in Python, and since model is a variable, del model will delete it, but the TF graph will be left unchanged (TF is your Keras backend). That said, K.clear_session() will destroy the current TF graph and create a new one. Creating a new model seems to be an independent step, but don't forget the backend:

In [3]:
#%matplotlib widget

import sys
import os
sys.path.append('../../Utils')
from metrics import compute_metrics
from sklearn.model_selection import train_test_split
import sklearn.metrics
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pickle
import time
import gpflow
from gpflow.utilities import print_summary
from gpflow.config import default_float
import scipy.stats
import warnings

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Locate the data directory relative to the current working directory.
current_dir = os.getcwd()
data_dir = os.path.join(current_dir, '../../../Data/')

# Path of the pickle file that is loaded further down as rna_prot_embed.
RNA_PROT_EMBED = data_dir+'ProcessedData/protein_embeddings/rna_protein_u64embeddings.pkl'
# Set the Keras backend default float type to float64.
tf.keras.backend.set_floatx('float64')
#gpflow.config.set_default_float('float32')


In [4]:
# Index (into the list of detected GPUs) of the one GPU TensorFlow may use.
# NOTE(review): physical_devices[device] raises IndexError when fewer than
# device+1 GPUs are detected.
device = 1
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(physical_devices[device], 'GPU')

# Report the execution mode and the selected device.
print(f'TF eager exectution: {tf.executing_eagerly()}')
print(f'Using device {physical_devices[device]}')

TF eager exectution: True
Using device PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')


### Try with different activation functions, ReLU and log-sigmoid
* Manifold GP uses activations after every output, try omitting final activation as well.

In [5]:
class General_MLP(keras.Model):
    def __init__(self,hidden_nodes,input_shape=None,activation='relu',last=False):
        '''
        Stack of Dense layers whose output is cast to the gpflow default float.

        https://www.tensorflow.org/api_docs/python/tf/keras/Model
        hidden_nodes (array like) - all the dimensions after input including output size,
            must contain at least one entry (the output size)
        input_shape (tuple, optional) - when given, the model is built immediately with this
            shape and in_size is set to input_shape[1]; when None, the model is not built
            here and in_size is None.
        activation (string), type of activation function to use, must be in keras activations
        last (Bool), whether or not to have activation on the final output layer
        raises ValueError if hidden_nodes is empty
        ex:
        atlas_mlp = General_MLP([66,1],input_shape=(1,66))
        atlas_mlp.summary()
        '''
        if len(hidden_nodes) == 0:
            raise ValueError('hidden_nodes must contain at least the output size')

        super(General_MLP, self).__init__()
        self.mlp_layers = []
        # Every entry but the last is a hidden layer and always gets the activation
        for nodes in hidden_nodes[:-1]:
            self.mlp_layers.append(keras.layers.Dense(nodes, activation=activation))

        # The output layer only gets an activation when `last` is set
        final_activation = activation if last else None
        self.mlp_layers.append(keras.layers.Dense(hidden_nodes[-1], activation=final_activation))

        #Specific line is to cast the output to the gpflow default precision
        self.mlp_layers.append(tf.keras.layers.Lambda(lambda x: tf.cast(x, default_float())))
        self.out_size = hidden_nodes[-1]

        # in_size is always defined so that reading it never raises AttributeError
        self.in_size = None
        if input_shape is not None:
            self.in_size = input_shape[1]
            self.build(input_shape)

    #training flag if specific layers behave differently (ex: batch norm), for mlp no difference
    def call(self, inputs, training=True):
        '''Pass inputs through every layer in order and return the final output.'''
        for layer in self.mlp_layers:
            inputs = layer(inputs)
        return inputs

In [6]:
class nn_based_kernel(gpflow.kernels.Kernel):
    """
    Kernel that passes its inputs through a neural network and then
    evaluates a base kernel on the network outputs.
    """
    def __init__(self,base_kernel: gpflow.kernels.Kernel,nn_model):
        """
        base_kernel - gpflow kernel evaluated on the transformed inputs
        nn_model - network applied to the inputs; it must already be built
        """
        super().__init__()
        assert nn_model.built, "NN model is not built, input shape is not initialized"

        self.model = nn_model
        self.base_kernel = base_kernel

    def K(self,X,X2=None,presliced=False):
        """
        Base-kernel covariance of the transformed X and transformed X2.
        X2 is forwarded as None when it is not supplied.
        """
        embedded_X = self.model(X)
        if X2 is None:
            embedded_X2 = None
        else:
            embedded_X2 = self.model(X2)
        return self.base_kernel.K(embedded_X, embedded_X2, presliced)

    def K_diag(self, X_input,presliced=False):
        """Base-kernel K_diag evaluated on the transformed X_input."""
        embedded_X = self.model(X_input)
        return self.base_kernel.K_diag(embedded_X, presliced)

In [7]:
# Load the pickled object stored at RNA_PROT_EMBED.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only open files that come from a trusted source.
with open(RNA_PROT_EMBED,'rb') as file:
    rna_prot_embed = pickle.load(file)

In [8]:
#Code Parameters
#-----------------------------------------
cols_drop=['ProteinAUC']  # target column(s), dropped when the feature matrices are built
MRNA_THRESH = 3           # rows are kept only when mRNA_TMM is strictly above this value
ZSCORE = True             # z-score all splits with the training mean/std
BATCH = 32                # batch size of the tf.data datasets
SAVE = False              # not read by any of the code in this notebook section
LOG_TRANS = True          # apply log2(x+1) to mRNA_TMM, ProteinAUC and ProteinLength
#-----------------------------------------

# Take a copy of one entry of the loaded object so the original is not modified.
data = rna_prot_embed['AM_04M_F0'].copy()
data.drop(columns='AvgChrs',inplace=True)
data = data[data['mRNA_TMM']>MRNA_THRESH]

if LOG_TRANS:
    #Log transform mRNA, protein, and protein length -> log-normal distributed
    data['mRNA_TMM'] = np.log2(data['mRNA_TMM']+1)
    data['ProteinAUC'] = np.log2(data['ProteinAUC']+1)
    data['ProteinLength'] = np.log2(data['ProteinLength']+1)

# First split: 80% train, 20% held out.
SEED = 10
train,test = train_test_split(data,test_size=0.2,random_state=SEED)
# Second split: the held-out 20% is halved into test and validation.
# SEED is deliberately rebound, so the two splits use different random states.
SEED = 42
test,val = train_test_split(test,test_size=0.5,random_state=SEED)
train.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mRNA_TMM,ProteinAUC,ProteinLength,0,1,2,3,4,5,6,...,54,55,56,57,58,59,60,61,62,63
Gene.names,Majority.protein.IDs,cell,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
Sept11,Q8C1B7,AM_04M_F0,7.945561,27.64407,8.754888,-0.111104,0.127562,-0.109677,-0.974325,-0.039085,-0.02864,0.120612,...,0.116717,0.022352,-0.087753,-0.003184,-0.031651,0.522554,-0.158445,-0.016598,0.383635,0.050379
Dido1,Q8C9B9,AM_04M_F0,6.389982,22.593388,11.140191,-0.059608,0.060762,-0.162357,-0.977233,-0.078824,-0.384336,-0.06512,...,0.227555,0.084422,0.169869,-0.154971,0.003582,0.260411,-0.20982,0.005686,0.1546,0.004119
Gpcpd1,Q8C0L9,AM_04M_F0,8.533405,23.408089,9.400879,-0.049922,0.094703,-0.123327,-0.959221,-0.022357,-0.226174,0.083336,...,0.129952,0.008575,-0.026459,-0.018394,-0.034741,0.239011,-0.096025,-0.05263,0.390176,-0.008997
Rab19,P35294,AM_04M_F0,4.558499,21.870202,7.768184,-0.111689,0.13359,-0.073235,-0.963093,0.005832,-0.131049,0.120637,...,0.12837,0.011257,-0.101153,0.017144,-0.06718,0.512086,-0.108738,-0.049255,0.499497,-0.009795
Cdk2,P97377,AM_04M_F0,5.064248,23.498134,8.438792,-0.072193,0.104111,-0.129663,-0.968208,-0.037051,-0.105913,0.078196,...,0.106188,0.012131,-0.059163,-0.022869,-0.027992,0.340807,-0.116832,-0.048478,0.384089,-0.002199


#### Try not transforming embeddings?

In [9]:
def zscore(train_df):
    """
    Standardise every column of a DataFrame.

    train_df (pd.DataFrame) - data to standardise
    returns (zscored, means, stds): the standardised frame together with the
    per-column mean and standard deviation that were used, so the same
    transform can be applied to other data.
    """
    assert isinstance(train_df,pd.DataFrame)
    column_means = train_df.mean(axis=0)
    column_stds = train_df.std(axis=0)
    centered = train_df-column_means
    return centered/column_stds, column_means, column_stds

# Standardise all three splits with statistics computed on the training split only.
if ZSCORE:
    print(f'Data is z-scored')
    train, train_mean, train_std = zscore(train) #zscore data
    val = (val-train_mean)/train_std #zscore validation data using mean and std from train set
    test = (test-train_mean)/train_std #zscore test data using mean and std from train set

Data is z-scored


In [10]:
# Features are every column except those in cols_drop; the target is the
# ProteinAUC column, kept two-dimensional by selecting it with a list.
x_train = train.drop(columns=cols_drop).values
y_train = train[['ProteinAUC']].values

x_val = val.drop(columns=cols_drop).values
y_val = val[['ProteinAUC']].values

x_test = test.drop(columns=cols_drop).values
y_test = test[['ProteinAUC']].values

print(f'train dataset size: {x_train.shape}')
print(f'validation dataset size: {x_val.shape}')
print(f'test dataset size: {x_test.shape}')

# Batched tf.data pipelines; the training one is shuffled with a buffer that
# covers the whole training set.
# NOTE(review): the grid-search cell in this notebook passes x_train/y_train
# directly and does not read trn_dataset or val_dataset.
trn_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train))
trn_dataset = trn_dataset.shuffle(buffer_size=x_train.shape[0]).batch(BATCH) #I think default is 32

val_dataset = tf.data.Dataset.from_tensor_slices((x_val,y_val))
val_dataset = val_dataset.batch(BATCH)

train dataset size: (2820, 66)
validation dataset size: (353, 66)
test dataset size: (352, 66)


#### Gridsearch parameters

In [11]:
# Width of an optional first hidden layer; 0 means that layer is left out.
first_hidden = [0,32,64,96]
# Width of the hidden layer: 16, 32, 48, 64.
hidden_units = np.arange(16,72,16)
# Width of the network output layer: 4, 8, 12, 16.
final_units = np.arange(4,18,4)
# L2 penalty strengths: 0.1, 1, 10, 100, 1000.
l2_penalty = np.geomspace(.1,1000,num=5)
# Activation used by the Dense layers.
activations = ['relu','sigmoid','tanh']
# Whether the output layer also gets the activation.
final_act = [False,True]

In [12]:
# Exhaustive grid search over the network/regulariser settings defined above.
# One row is added to gridsearch_results per trial; trial_name maps the row
# counter to a readable label that becomes the row index at the very end.
gridsearch_results = pd.DataFrame()
trial_name = dict()
counter = 0

#Try except control flow
# Raised from the optimizer callback to abort a trial (early stopping).
class generalization_increase(Exception):
    pass

for act in activations:
    for f_act in final_act:
        for s_u in first_hidden:
            for h_u in hidden_units:
                for f_u in final_units:
                    for pen in l2_penalty:
                        trial_name[counter] = act+f',final_act:{f_act},{s_u},{h_u},{f_u},{pen}'

                        #Reset/delete all graphs in case it slows down
                        tf.keras.backend.clear_session() #destroys graph and all values

                        #Model Initialization
                        # s_u == 0 drops the first hidden layer, leaving [h_u, f_u].
                        if s_u==0:
                            manifold_mlp = General_MLP([h_u,f_u],input_shape=(1,66),last=f_act,activation=act)
                        else:
                            manifold_mlp = General_MLP([s_u,h_u,f_u],input_shape=(1,66),last=f_act,activation=act)
                        base_kernel = gpflow.kernels.SquaredExponential(lengthscale=[1]*manifold_mlp.out_size) #Initialize ARD for lengthscale
                        k = nn_based_kernel(base_kernel,manifold_mlp)
                        model = gpflow.models.GPR(data=(x_train, y_train), kernel=k, mean_function=None)
                        model.likelihood.variance.assign(0.1)

                        #Parameter Tracking
                        # neg_likelihood, kernel_variance, model_variance and lengthscales each
                        # receive one entry here, before optimisation, so they hold one more
                        # entry than validation_mse / validation_pearson.
                        neg_likelihood = list()
                        neg_likelihood.append(-model.log_marginal_likelihood().numpy())
                        lengthscales = model.kernel.base_kernel.lengthscale.numpy().copy()
                        kernel_variance = list()
                        kernel_variance.append(model.kernel.base_kernel.variance.numpy())
                        model_variance = list()
                        model_variance.append(model.likelihood.variance.numpy())
                        validation_mse = list()
                        validation_pearson = list()
                        regularizer = tf.keras.regularizers.l2(l=pen)

                        #Optimization function
                        opt = gpflow.optimizers.Scipy()
                        # Objective handed to the optimizer. Every call also records the
                        # tracking metrics and the validation scores as a side effect.
                        # NOTE(review): the histories therefore grow once per objective
                        # evaluation, which is presumably not the same as once per optimizer
                        # iteration - confirm against the optimizer in use.
                        def objective_closure():
                            # lengthscales is rebound below, hence the global declaration
                            # (this cell runs at module level).
                            global lengthscales, x_val, y_val
                            neg_likelihood.append(-model.log_marginal_likelihood().numpy())
                            lengthscales = np.vstack((lengthscales,model.kernel.base_kernel.lengthscale.numpy().copy()))
                            kernel_variance.append(model.kernel.base_kernel.variance.numpy())
                            model_variance.append(model.likelihood.variance.numpy())

                            mean, var = model.predict_f(x_val)
                            validation_mse.append(sklearn.metrics.mean_squared_error(mean.numpy(),y_val))
                            validation_pearson.append(scipy.stats.pearsonr(mean.numpy().squeeze(),y_val.squeeze())[0])

                            # L2 penalty over trainable variables whose name contains 'dense'
                            # but not 'bias'.
                            weight_penalty = 0
                            for model_variables in model.trainable_variables:
                                if 'dense' in model_variables.name and 'bias' not in model_variables.name:
                                    weight_penalty+=regularizer(model_variables)

                            return - model.log_marginal_likelihood() + weight_penalty

                        # Optimizer callback: abort once the validation MSE has risen on each
                        # of the last four recorded entries.
                        def early_stopping(xi):
                            if len(validation_mse)>5 and validation_mse[-1]>validation_mse[-2]>validation_mse[-3]>validation_mse[-4]>validation_mse[-5]:
                                raise generalization_increase()

                        #Optimization
                        time_start = time.time()
                        try:
                            opt_logs = opt.minimize(objective_closure,
                                                    model.trainable_variables,
                                                    options=dict(maxiter=100),
                                                    callback = early_stopping)
                        except generalization_increase:
                            # Only trials ended by early stopping print this line.
                            print(f'Finished trial {trial_name[counter]}, mse {validation_mse[-1]}, p_r {validation_pearson[-1]}')
                        print(f'Run time {time.time()-time_start}')

                        #Saving evaluation metrics
                        temp_dict = dict()
                        temp_dict['val_mse_final'] = validation_mse[-1]
                        temp_dict['pearson_final'] = validation_pearson[-1]
                        temp_dict['neg_likelihood'] = neg_likelihood.copy()
                        temp_dict['lengthscales'] = lengthscales.copy()
                        temp_dict['k_var'] = kernel_variance.copy()
                        temp_dict['m_var'] = model_variance.copy()
                        temp_dict['val_mse'] = validation_mse.copy()
                        temp_dict['val_pearson'] = validation_pearson.copy()
                        # NOTE(review): DataFrame.append is deprecated in newer pandas releases
                        # (removed in 2.0) - verify against the installed version.
                        gridsearch_results = gridsearch_results.append(temp_dict.copy(),ignore_index=True)

                        counter+=1

# Replace the integer row labels with the readable trial names.
gridsearch_results.rename(trial_name,inplace=True)

Finished trial relu,final_act:False,0,16,4,0.1, mse 0.6166949723087195, p_r 0.6858132774934939
Run time 13.574345827102661
Finished trial relu,final_act:False,0,16,4,1.0, mse 0.6839287456281923, p_r 0.6480175775641853
Run time 14.196211338043213
Finished trial relu,final_act:False,0,16,4,10.0, mse 0.6019086630020496, p_r 0.6955685113773331
Run time 39.43218207359314
Finished trial relu,final_act:False,0,16,4,100.0, mse 0.504985211473711, p_r 0.7562333827564004
Run time 23.150994062423706
Finished trial relu,final_act:False,0,16,4,1000.0, mse 0.5135323443519485, p_r 0.7522074941408202
Run time 59.66253900527954
Finished trial relu,final_act:False,0,16,8,0.1, mse 0.7107788308577221, p_r 0.6341704225913564
Run time 15.078193664550781
Finished trial relu,final_act:False,0,16,8,1.0, mse 0.7105044154053084, p_r 0.6421200122605712
Run time 19.190613985061646
Finished trial relu,final_act:False,0,16,8,10.0, mse 0.6022086278683378, p_r 0.6943834511272167
Run time 15.079932689666748
Finished tri



Run time 28.31288480758667
Finished trial relu,final_act:False,64,32,4,0.1, mse 0.7302707416535366, p_r 0.6168678247238157
Run time 8.967565536499023
Finished trial relu,final_act:False,64,32,4,1.0, mse 0.6740087694558943, p_r 0.6540847099714561
Run time 8.957745790481567
Finished trial relu,final_act:False,64,32,4,10.0, mse 0.65043000667008, p_r 0.6727897725117746
Run time 21.346724033355713
Finished trial relu,final_act:False,64,32,4,100.0, mse 0.546767757114431, p_r 0.7281448153624006
Run time 48.983150482177734
Finished trial relu,final_act:False,64,32,4,1000.0, mse 0.5320956416284645, p_r 0.7392397741466256
Run time 56.46200156211853
Finished trial relu,final_act:False,64,32,8,0.1, mse 0.7658756239827453, p_r 0.6185973170552537
Run time 17.2268168926239
Finished trial relu,final_act:False,64,32,8,1.0, mse 0.7732592891945859, p_r 0.6037132460180031
Run time 10.340738773345947
Finished trial relu,final_act:False,64,32,8,10.0, mse 0.6985112623746756, p_r 0.645093125050678
Run time 10



Finished trial relu,final_act:True,32,16,4,1000.0, mse 1.1649317982031597, p_r nan
Run time 6.199992656707764
Finished trial relu,final_act:True,32,16,8,0.1, mse 0.6821016034955396, p_r 0.652352706565203
Run time 16.582798957824707
Finished trial relu,final_act:True,32,16,8,1.0, mse 0.6965438709668563, p_r 0.6431930713646732
Run time 15.969040632247925
Finished trial relu,final_act:True,32,16,8,10.0, mse 0.5805397410249007, p_r 0.7091987511893191
Run time 19.343472242355347
Finished trial relu,final_act:True,32,16,8,100.0, mse 0.5138199309612393, p_r 0.7484413925474076
Run time 28.925281047821045
Finished trial relu,final_act:True,32,16,8,1000.0, mse 0.5491741545583921, p_r 0.7337509890129071
Run time 27.547584056854248
Finished trial relu,final_act:True,32,16,12,0.1, mse 0.7224552694472269, p_r 0.6247494788065191
Run time 9.656793594360352
Finished trial relu,final_act:True,32,16,12,1.0, mse 0.5966537597641396, p_r 0.6978434189494227
Run time 11.725154161453247
Finished trial relu,fin



Run time 31.01314902305603
Finished trial sigmoid,final_act:False,64,48,8,0.1, mse 0.5342364606234514, p_r 0.7345755657413022
Run time 17.913865089416504
Finished trial sigmoid,final_act:False,64,48,8,1.0, mse 0.5408653381271108, p_r 0.7306435918215346
Run time 24.81158685684204
Finished trial sigmoid,final_act:False,64,48,8,10.0, mse 0.5233128134193135, p_r 0.7428584528513347
Run time 39.967865228652954
Finished trial sigmoid,final_act:False,64,48,8,100.0, mse 0.5603058147796537, p_r 0.7281180816530717
Run time 22.835940837860107
Finished trial sigmoid,final_act:False,64,48,8,1000.0, mse 1.164931588650019, p_r 0.274942523412127
Run time 4.823401689529419
Finished trial sigmoid,final_act:False,64,48,12,0.1, mse 0.5392429605540041, p_r 0.7316770473858767
Run time 12.423718690872192
Finished trial sigmoid,final_act:False,64,48,12,1.0, mse 0.539548319311829, p_r 0.7320834998350283
Run time 20.881171464920044
Finished trial sigmoid,final_act:False,64,48,12,10.0, mse 0.5349144510800384, p_r



Finished trial sigmoid,final_act:False,96,32,12,1000.0, mse 1.1649317982031606, p_r nan
Run time 9.664430141448975
Finished trial sigmoid,final_act:False,96,32,16,0.1, mse 0.5303784352747237, p_r 0.7363917387929255
Run time 18.658405780792236
Finished trial sigmoid,final_act:False,96,32,16,1.0, mse 0.5971040751778018, p_r 0.7012660130906496
Run time 31.19925093650818
Finished trial sigmoid,final_act:False,96,32,16,10.0, mse 0.5183968319828478, p_r 0.7467891327379393
Run time 47.652719020843506
Finished trial sigmoid,final_act:False,96,32,16,100.0, mse 0.5443185167412905, p_r 0.7345452699340485
Run time 41.70586824417114
Finished trial sigmoid,final_act:False,96,32,16,1000.0, mse 1.1649316381722543, p_r 0.3263986986168656
Run time 4.151305675506592
Finished trial sigmoid,final_act:False,96,48,4,0.1, mse 0.5546257409602384, p_r 0.722912166613462
Run time 27.753209352493286
Finished trial sigmoid,final_act:False,96,48,4,1.0, mse 0.5251944370564352, p_r 0.7413851433345134
Run time 14.61242



Run time 20.015158891677856
Finished trial sigmoid,final_act:True,64,48,8,0.1, mse 0.5506237568442175, p_r 0.7256498111179195
Run time 23.455157995224
Finished trial sigmoid,final_act:True,64,48,8,1.0, mse 0.5890796851416976, p_r 0.7017231980700077
Run time 24.127560138702393
Finished trial sigmoid,final_act:True,64,48,8,10.0, mse 0.5214510971398767, p_r 0.7444505451223247
Run time 38.63401770591736
Finished trial sigmoid,final_act:True,64,48,8,100.0, mse 1.1649202216856096, p_r 0.2762036348213599
Run time 7.58053731918335
Finished trial sigmoid,final_act:True,64,48,8,1000.0, mse 1.1649315305280676, p_r 0.38391666446570616
Run time 4.138140678405762
Finished trial sigmoid,final_act:True,64,48,12,0.1, mse 0.5345716861982434, p_r 0.7349305570538979
Run time 17.326372623443604
Finished trial sigmoid,final_act:True,64,48,12,1.0, mse 0.5312348012672274, p_r 0.7389439644992818
Run time 11.162145614624023
Finished trial sigmoid,final_act:True,64,48,12,10.0, mse 0.5222900888427489, p_r 0.74464

In [134]:
# Number of validation-MSE entries recorded for each trial
iterations = list(map(len, gridsearch_results['val_mse']))
gridsearch_results['iterations'] = iterations

In [163]:
# Negative log marginal likelihood five entries from the end of each trial's
# history. For a trial ended by early stopping (four consecutive rises in
# validation MSE) that is the entry just before the rises began.
# NOTE(review): a trial with fewer than five recorded entries raises IndexError here.
lowest_mse_likelihood = [trial[-5] for trial in gridsearch_results['neg_likelihood']]
# Store the list computed above (the original assigned the undefined name
# final_neg_likelihood).
gridsearch_results['lowest_mse_likelihood'] = lowest_mse_likelihood

In [164]:
# Validation Pearson r five entries from the end of each trial's history
lowest_mse_pearsonr = [history[-5] for history in gridsearch_results['val_pearson']]
gridsearch_results['lowest_mse_pearsonr'] = lowest_mse_pearsonr

In [165]:
# Validation MSE five entries from the end of each trial's history
lowest_mse = [history[-5] for history in gridsearch_results['val_mse']]
gridsearch_results['lowest_val_mse'] = lowest_mse

### Discard trials with more than 100 iterations; those stopped in some strange way

In [176]:
# Work on a copy so that gridsearch_results itself is not modified by the filtering.
filtered_iterations_gridsearch = gridsearch_results.copy()

In [177]:
# Remove, in place, every trial whose 'iterations' value exceeds 100
over_limit = filtered_iterations_gridsearch['iterations']>100
filtered_iterations_gridsearch.drop(filtered_iterations_gridsearch.index[over_limit],inplace=True)

In [178]:
# Show the five trials with the largest lowest_mse_pearsonr.
filtered_iterations_gridsearch.sort_values(by=['lowest_mse_pearsonr'],ascending=False).head()

Unnamed: 0,k_var,lengthscales,m_var,neg_likelihood,val_mse,val_pearson,iterations,lowest_val_mse,lowest_mse_likelihood,lowest_mse_pearsonr
"relu,final_act:False,0,48,16,100.0","[1.0, 1.0, 1.0498190724747998, 1.0612336491676...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.127...","[4216.132716423424, 4216.132716423424, 3481.66...","[0.9516255408999622, 0.7873573410702047, 0.747...","[0.4619055730894084, 0.5778495977365445, 0.606...",29,0.48978,2808.774511,0.76381
"sigmoid,final_act:False,32,64,12,1.0","[1.0, 1.0, 1.0082747346344174, 1.0151552580194...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.163...","[8858.865553668025, 8858.865553668025, 5505.26...","[0.7568362895954688, 0.6339327360682601, 0.581...","[0.5971005280053571, 0.6846413943243643, 0.711...",24,0.490268,2849.135465,0.762651
"sigmoid,final_act:True,0,16,16,0.1","[1.0, 1.0, 1.0146208517557171, 1.0212426173215...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.198...","[9291.785340594339, 9291.785340594339, 4450.47...","[0.7547625191459439, 0.6163478771517387, 0.590...","[0.6066438268065506, 0.6922343021298762, 0.707...",28,0.492845,2843.869754,0.761145
"sigmoid,final_act:True,0,64,4,10.0","[1.0, 1.0, 1.0017912342807402, 1.0034884529579...","[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [...","[0.1000000014901161, 0.1000000014901161, 0.195...","[12449.594169842989, 12449.594169842989, 5554....","[1.0646332555980946, 0.8432369725164829, 0.736...","[0.30357008144346964, 0.5299106996127467, 0.60...",93,0.493854,2859.483194,0.760447
"sigmoid,final_act:False,32,64,8,0.1","[1.0, 1.0, 1.0039918501311471, 1.0064973323232...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1....","[0.1000000014901161, 0.1000000014901161, 0.160...","[10555.152368866155, 10555.152368866155, 5833....","[0.925029530335916, 0.7326509615625881, 0.6463...","[0.45556599238135154, 0.6132866675332149, 0.66...",46,0.489929,2688.21297,0.760023


In [179]:
# Show the five trials with the smallest lowest_val_mse.
filtered_iterations_gridsearch.sort_values(by=['lowest_val_mse']).head()

Unnamed: 0,k_var,lengthscales,m_var,neg_likelihood,val_mse,val_pearson,iterations,lowest_val_mse,lowest_mse_likelihood,lowest_mse_pearsonr
"relu,final_act:False,0,48,16,100.0","[1.0, 1.0, 1.0498190724747998, 1.0612336491676...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.127...","[4216.132716423424, 4216.132716423424, 3481.66...","[0.9516255408999622, 0.7873573410702047, 0.747...","[0.4619055730894084, 0.5778495977365445, 0.606...",29,0.48978,2808.774511,0.76381
"sigmoid,final_act:False,32,64,8,0.1","[1.0, 1.0, 1.0039918501311471, 1.0064973323232...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1....","[0.1000000014901161, 0.1000000014901161, 0.160...","[10555.152368866155, 10555.152368866155, 5833....","[0.925029530335916, 0.7326509615625881, 0.6463...","[0.45556599238135154, 0.6132866675332149, 0.66...",46,0.489929,2688.21297,0.760023
"sigmoid,final_act:False,32,64,12,1.0","[1.0, 1.0, 1.0082747346344174, 1.0151552580194...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.163...","[8858.865553668025, 8858.865553668025, 5505.26...","[0.7568362895954688, 0.6339327360682601, 0.581...","[0.5971005280053571, 0.6846413943243643, 0.711...",24,0.490268,2849.135465,0.762651
"sigmoid,final_act:True,0,16,16,0.1","[1.0, 1.0, 1.0146208517557171, 1.0212426173215...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.198...","[9291.785340594339, 9291.785340594339, 4450.47...","[0.7547625191459439, 0.6163478771517387, 0.590...","[0.6066438268065506, 0.6922343021298762, 0.707...",28,0.492845,2843.869754,0.761145
"sigmoid,final_act:True,0,64,4,10.0","[1.0, 1.0, 1.0017912342807402, 1.0034884529579...","[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [...","[0.1000000014901161, 0.1000000014901161, 0.195...","[12449.594169842989, 12449.594169842989, 5554....","[1.0646332555980946, 0.8432369725164829, 0.736...","[0.30357008144346964, 0.5299106996127467, 0.60...",93,0.493854,2859.483194,0.760447


In [184]:
# Show the five trials with the smallest lowest_mse_likelihood.
filtered_iterations_gridsearch.sort_values(by=['lowest_mse_likelihood']).head()

Unnamed: 0,k_var,lengthscales,m_var,neg_likelihood,val_mse,val_pearson,iterations,lowest_val_mse,lowest_mse_likelihood,lowest_mse_pearsonr
"relu,final_act:False,96,64,12,10.0","[1.0, 1.0, 1.0278935570688073, 1.0359535680650...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.144...","[6587.93682720463, 6587.93682720463, 3710.1414...","[0.9875982849861638, 0.8152656648951754, 0.800...","[0.42734963060431874, 0.5604609367761814, 0.56...",31,0.736901,970.694891,0.648507
"relu,final_act:False,96,64,12,0.1","[1.0, 1.0, 1.0247690403520007, 1.0335938847899...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.146...","[6781.246499996365, 6781.246499996365, 3789.91...","[0.9164304991495723, 0.7602404429648901, 0.755...","[0.4743959039812231, 0.5961769223676308, 0.600...",23,0.754097,1426.337219,0.615027
"relu,final_act:True,0,64,12,10.0","[1.0, 1.0, 1.02199751748821, 1.026816239134334...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.158...","[6845.562039753739, 6845.562039753739, 3754.75...","[0.929840371308307, 0.7855034046109215, 0.7765...","[0.4664039925218152, 0.574246984409797, 0.5810...",41,0.725944,1582.949257,0.644927
"relu,final_act:False,96,16,16,10.0","[1.0, 1.0, 1.0348157364154458, 1.0406218883941...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.146...","[6147.626273004989, 6147.626273004989, 3637.75...","[0.8050706630089826, 0.7959798943236375, 0.795...","[0.5590798689858864, 0.5706481207205008, 0.571...",28,0.769649,1626.689313,0.613273
"relu,final_act:False,64,64,12,10.0","[1.0, 1.0, 1.0300411764361885, 1.0385206714414...","[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,...","[0.1000000014901161, 0.1000000014901161, 0.156...","[6880.971046953677, 6880.971046953677, 3772.81...","[0.8336834120255814, 0.7450423722484417, 0.739...","[0.5411108697578249, 0.6063039239022563, 0.611...",25,0.748228,1640.934579,0.629537


In [182]:
# Persist the full (unfiltered) grid-search table to disk
with open('grid_search.pickle', 'wb') as out_file:
    pickle.dump(gridsearch_results, out_file, protocol=pickle.HIGHEST_PROTOCOL)