# Training Time

In [1]:
import glob
from pathlib import Path
from pprint import pprint

import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.keras.utils.data_utils import Sequence

from slice_generator import slice_generator

In [2]:
# Record which TensorFlow build produced the outputs in this notebook.
tf_version = tf.__version__
print("tensorflow version:", tf_version)

tensorflow version: 2.2.0-dlenv


## Model variables

In [3]:
# Dimensions shared by the model file names and the data generators below.
# NOTE(review): `channels` is not referenced by any cell visible in this
# notebook — confirm whether it is still needed.
frames, channels = 1, 1
pixels_x = pixels_y = 21

## Load the model

In [4]:
# Collect every HDF5 model file in the models directory, in sorted order,
# so that a model can be picked by index in the next cell.
model_dir = "../models/"
h5_pattern = model_dir + "*.h5"
models_list = glob.glob(h5_pattern)
models_list.sort()
pprint(models_list)

['../models/convlstm_1f_1c_21x_21y.h5',
 '../models/convlstm_6f_1c_21x_21y.h5',
 '../models/encoder_convlstm_1f_1c_21x_21y.h5',
 '../models/full_stack_1f_1c_21x_21y.h5',
 '../models/t_convlstm_1f_1c_21x_21y.h5',
 '../models/t_encoder_convlstm_1f_1c_21x_21y.h5']


In [5]:
# Choose a model by its position in the sorted `models_list` printed above.
file_index = 3
model_path = models_list[file_index]
# Derive the name from the selected file (file name without the ".h5"
# extension) instead of hard-coding it, so the file written by the save cell
# is always named after the model that was actually loaded.
model_name = Path(model_path).stem
# Show the selection as the cell output.
model_path

In [6]:
# Load the selected model file and print its layer summary.
selected_model_file = models_list[file_index]
model = keras.models.load_model(selected_model_file)
model.summary()

Model: "Full_stack"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
model_input (InputLayer)        [(None, 2, 1, 21, 21 0                                            
__________________________________________________________________________________________________
tf_op_layer_unstack_30 (TensorF [(None, 1, 21, 21),  0           model_input[0][0]                
__________________________________________________________________________________________________
gaussian_noise_30 (GaussianNois (None, 1, 21, 21)    0           tf_op_layer_unstack_30[0][0]     
__________________________________________________________________________________________________
convB1 (Conv2D)                 (None, 8, 9, 9)      208         tf_op_layer_unstack_30[0][1]     
_________________________________________________________________________________________

## Prepare for training

In [7]:
# Data locations and generator settings passed to `slice_generator` below.
train_file_path = "../data/train"
valid_file_path = "../data/validate"
vars_           = ['t2m']
proc_type       = "convlstm"

# Step counts are handed to `model.fit` as `steps_per_epoch` /
# `validation_steps`. Those are batch counts, so keep them integers: floor
# division avoids a float (or a fractional value when `frames` does not
# divide the total evenly).
# NOTE(review): 365 * 24 presumably means one sample per hour — confirm
# against the data.
samples_per_year = 365 * 24
# 3 years of training data = 26280 steps when frames == 1
train_steps = 3 * samples_per_year // frames
# 1 year of validation data = 8760 steps when frames == 1
valid_steps = 1 * samples_per_year // frames

In [8]:
def _make_slices(img_dir):
    """Build a `slice_generator` over `img_dir` using the notebook-wide settings."""
    return slice_generator(
        img_dir=img_dir,
        slice_size=frames,
        vars_=vars_,
        proc_type=proc_type,
        pixels_x=pixels_x,
        pixels_y=pixels_y,
        debug=False,
    )


# One generator per data split; `slice_train` is sampled in the inspection
# cells further down.
slice_train = _make_slices(train_file_path)
slice_val = _make_slices(valid_file_path)

## Train

In [9]:
# Number of passes over `steps_per_epoch` batches requested from `model.fit`.
epochs = 1

In [10]:
# Generators dedicated to this fit call. `slice_train` / `slice_val` are not
# passed in here, so this cell does not draw from those objects.
generator_options = dict(
    slice_size=frames,
    vars_=vars_,
    proc_type=proc_type,
    pixels_x=pixels_x,
    pixels_y=pixels_y,
    debug=False,
)
fit_train_batches = slice_generator(img_dir=train_file_path, **generator_options)
fit_valid_batches = slice_generator(img_dir=valid_file_path, **generator_options)

# Train, validating after each epoch; the returned History object is kept
# for the loss inspection cell below.
history = model.fit(
    fit_train_batches,
    validation_data=fit_valid_batches,
    steps_per_epoch=train_steps,
    validation_steps=valid_steps,
    epochs=epochs,
    initial_epoch=0,
    shuffle=False,
    verbose=1,
)

 1561/26280 [>.............................] - ETA: 33:15 - loss: 255.0488 - accuracy: 0.0459 - mean_absolute_error: 154.7438

KeyboardInterrupt: 

## Save trained model

In [None]:
# Training loss recorded by the fit call above.
# `history` is only defined if that cell ran to completion; an interrupted
# fit never assigns it.
history.history['loss']

In [None]:
# Persist the trained model alongside the originals, prefixed with "t_".
# The target path ends in ".h5" (the extension the model-loading glob looks
# for), so request the HDF5 format explicitly rather than 'tf' (SavedModel),
# which contradicts the file extension.
tf.keras.models.save_model(
    model=model,
    filepath='../models/t_' + model_name + '.h5',
    overwrite=True,
    include_optimizer=True,
    save_format='h5',
    signatures=None,
)

## Visually Inspect Prediction

In [None]:
# Pull one (input, target) pair from the training generator and show the
# target's shape.
sample_batch = next(slice_train)
in_, out_ = sample_batch
out_.shape

In [None]:
# Experiment: join the target array with itself along axis 1 and check the
# resulting shape.
concated_inputs = np.concatenate([out_, out_], axis=1)
concated_inputs.shape

In [None]:
# Experiment: split the concatenated array back apart along axis 1, then
# report how many pieces came out and the shape of the first two.
unstacked = tf.unstack(concated_inputs, axis=1, name='unstack')
print(len(unstacked))
for piece_index in (0, 1):
    print(unstacked[piece_index].shape)
# Note to self: another dimension must be added with expand_dims.

### Test prediction

In [None]:
# Run the model on the first element of the sampled input.
# NOTE(review): the structure of `in_` is not visible from this notebook —
# confirm that `in_[0]` matches the shape expected by `model_input`.
first_input = in_[0]
model.predict(first_input, verbose=1)