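# Training Time

Load a saved ConvLSTM model from `../models/`, continue training it on `t2m` slices from `../data/train` (validating on `../data/validate`), and save the result to `../models/` with a `t_` prefix.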

In [1]:
import glob
import os
from pprint import pprint

import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from slice_generator import slice_generator

In [2]:
# Record the TensorFlow build this notebook was run with.
print("tensorflow version: " + tf.__version__)

tensorflow version: 2.2.0-dlenv


## Model variables

In [3]:
# Input geometry: frames per sample, channels per frame,
# and the spatial grid size (x, y) in pixels.
frames, channels = 1, 1
pixels_x, pixels_y = 21, 21

## Load the model

In [7]:
# Collect every saved .h5 model in model_dir, in a stable (sorted) order.
model_dir = "../models/"
models_list = sorted(glob.glob("{}*.h5".format(model_dir)))
pprint(models_list)

['../models/convlstm_1f_1c_21x_21y.h5',
 '../models/convlstm_6f_1c_21x_21y.h5',
 '../models/encoder_convlstm_1f_1c_21x_21y.h5',
 '../models/t_convlstm_1f_1c_21x_21y.h5']


In [8]:
# choose a model
file_index = 2
models_list[file_index]

'../models/encoder_convlstm_1f_1c_21x_21y.h5'

In [9]:
# Load the selected model file and print its layer summary.
model = keras.models.load_model(models_list[file_index])
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input1 (InputLayer)     [(None, 1, 1, 21, 21 0                                            
__________________________________________________________________________________________________
encoder1 (ConvLSTM2D)           [(None, 1, 1, 21, 21 204         encoder_input1[0][0]             
__________________________________________________________________________________________________
decoder_input1 (InputLayer)     [(None, 1, 1, 21, 21 0                                            
__________________________________________________________________________________________________
decoder1 (ConvLSTM2D)           [(None, 1, 1, 21, 21 204         encoder_input1[0][0]             
                                                                 encoder1[0][1]               

## Prepare for training

In [10]:
train_file_path = "../data/train"
valid_file_path = "../data/validate"
vars_           = ['t2m']
proc_type       = "conv_lstm"
# Step counts assume one sample per hour (365-day years) grouped into slices of
# `frames` hours — TODO confirm against slice_generator.
# Integer division: Keras expects integer steps_per_epoch / validation_steps.
hours_per_year = 365 * 24
# 3 years of training data -> 26280 steps when frames == 1
train_steps = 3 * hours_per_year // frames
# 1 year of validation data -> 8760 steps when frames == 1
valid_steps = 1 * hours_per_year // frames

In [11]:
# The training and validation generators share every setting except the
# directory they read slices from.
_slice_kwargs = dict(
    slice_size=frames,
    vars_=vars_,
    proc_type=proc_type,
    pixels_x=pixels_x,
    pixels_y=pixels_y,
    debug=False,
)
slice_train = slice_generator(img_dir=train_file_path, **_slice_kwargs)
slice_val = slice_generator(img_dir=valid_file_path, **_slice_kwargs)

## Train

In [12]:
# Number of passes over the training generator (train_steps batches each).
epochs: int = 5

In [13]:
# Train on the training generator and evaluate on the validation generator after
# each epoch. Note: Keras ignores `shuffle` when x is a plain Python generator;
# whether slice_generator is one is not visible here — TODO confirm.
history = model.fit(
    x=slice_train,
    validation_data=slice_val,
    epochs=epochs,
    initial_epoch=0,
    steps_per_epoch=train_steps,
    validation_steps=valid_steps,
    shuffle=False,
    verbose=1,
)

Epoch 1/5
 1702/26280 [>.............................] - ETA: 6:17 - loss: 48.7383 - mean_absolute_error: nan

KeyboardInterrupt: 

## Save trained model

In [None]:
# Save the trained model next to the source model, as "t_<source name>.h5".
# model_name is derived from the file loaded above (it was previously undefined,
# which raised NameError on a fresh kernel).
# save_format='h5' matches the ".h5" extension: 'tf' would write a SavedModel
# *directory* named "*.h5", which the "*.h5" glob above would then list.
# `signatures` only applies to the 'tf' format, so it is omitted.
model_name = os.path.splitext(os.path.basename(models_list[file_index]))[0]
tf.keras.models.save_model(
    model=model,
    filepath=os.path.join(model_dir, 't_' + model_name + '.h5'),
    overwrite=True,
    include_optimizer=True,
    save_format='h5',
)