# Hyperparameter tuning

Random search over ConvLSTM hyperparameters (`HyperConvLSTM`) using Keras Tuner.

In [1]:
import os
import glob
import json

import numpy as np

#import tensorflow as tf
#from tensorflow import keras
#from tensorflow.keras import losses, optimizers
#from tensorflow.keras.models import Sequential
#from tensorflow.keras.layers import ConvLSTM2D, BatchNormalization

from kerastuner import HyperModel

from sclouds.helpers  import (path_input, path_convlstm_results, get_lon_array, get_lat_array)
from sclouds.io.utils import (dataset_to_numpy_grid_keras, get_xarray_dataset_for_period, 
                              train_test_split_keras)

# Custom properties made for keras.
from sclouds.ml.ConvLSTM.hyper_convlstm import HyperConvLSTM
from sclouds.ml.ConvLSTM.utils import r2_keras, keras_custom_loss_function

# Packages from keras tuner.
from kerastuner.tuners import RandomSearch
from kerastuner import HyperParameters

# Hyperparameter search setup

In [2]:
# Build the ConvLSTM hypermodel and a random-search tuner over its search space.
SEED = 42  # fixes which hyperparameter combinations the random search samples
# NOTE: when an existing project is reloaded from `directory`, the oracle's
# saved state may take precedence over this seed -- verify.

# NOTE(review): `hp` is not passed to the tuner and no visible cell reads it.
hp = HyperParameters()

# seq_length must match the seq_length used in train_test_split_keras.
hypermodel = HyperConvLSTM(num_hidden_layers=2, seq_length=4)

tuner = RandomSearch(
    hypermodel,
    # NOTE(review): this ranks trials by *training* MSE even though validation
    # data is passed to `tuner.search`; 'val_mean_squared_error' is probably the
    # intended objective (use a fresh project so old trial scores don't mix).
    objective='mean_squared_error',
    max_trials=10,
    seed=SEED,
    allow_new_entries=True,
    directory=path_convlstm_results,
    project_name='test_hyperparameters',
)

INFO:tensorflow:Reloading Oracle from /home/hanna/lagrings/results/convlstm/test_hyperparameters/oracle.json


In [3]:
# Show the search space the tuner will sample from (built by the HyperConvLSTM hypermodel).
tuner.search_space_summary()

# Load data

In [4]:
# Load the input data for the period 2012-01-01 .. 2012-01-31.
data = get_xarray_dataset_for_period(start='2012-01-01', stop='2012-01-31')

# Split into training and hold-out arrays. The "test" arrays are the
# val_split=0.2 hold-out and are used as validation data during the search.
# seq_length must match the one given to HyperConvLSTM.
X_train, y_train, X_test, y_test = train_test_split_keras(data, seq_length=4, val_split=0.2)

# Sanity checks: both splits are non-empty and inputs/targets pair up sample-for-sample.
assert len(X_train) > 0 and len(X_test) > 0, "empty train or validation split"
assert len(X_train) == len(y_train), "train inputs/targets have different lengths"
assert len(X_test) == len(y_test), "validation inputs/targets have different lengths"

Num files 5


In [5]:
# The train/validation split is computed once, in the data-loading cell;
# this cell only inspects it. Shapes: (X_train, y_train, X_test, y_test).
X_train.shape, y_train.shape, X_test.shape, y_test.shape

# Run the hyperparameter search

In [None]:
# Train each sampled configuration for a short run; the (X_test, y_test)
# hold-out split is passed as validation data.
tuner.search(
    X_train,
    y_train,
    epochs=2,
    validation_data=(X_test, y_test),
)

Train on 148 samples, validate on 38 samples
Epoch 1/2


In [21]:
# Retrieve the two best models found by the search, as ranked by the tuner's objective.
models = tuner.get_best_models(num_models=2)

In [22]:
# Display the retrieved models.
models

[<tensorflow.python.keras.engine.sequential.Sequential at 0x7fd3d41ab7b8>,
 <tensorflow.python.keras.engine.sequential.Sequential at 0x7fd3d41ab438>]