# Training Time

In [1]:
import glob
import os
import pickle
import re
from pprint import pprint

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from slice_generator import slice_generator

In [2]:
# Record which TensorFlow build this notebook was executed with.
tf_version = tf.__version__
print("tensorflow version:", tf_version)

tensorflow version: 2.2.0-dlenv


## Model variables

In [3]:
# Input geometry of the model. These mirror the `_1f_1c_21x_21y` suffix of the
# saved model filenames: frames, channels, x pixels, y pixels.
frames, channels = 1, 1
pixels_x = pixels_y = 21

## Load the model

In [4]:
model_dir = "../models/"
# Every saved HDF5 model in model_dir, kept in lexical order so indices are stable.
h5_pattern = "{}*.h5".format(model_dir)
models_list = glob.glob(h5_pattern)
models_list.sort()
pprint(models_list)

['../models/full_stack_1f_1c_21x_21y.h5',
 '../models/t_full_stack_1f_1c_21x_21y.h5']


In [5]:
# choose a model
# Choose which saved model to load (an index into models_list above).
file_index = 1
model_path = models_list[file_index]
# Derive the base model name from the filename instead of hard-coding it, e.g.
# '../models/t_full_stack_1f_1c_21x_21y.h5' -> 'full_stack_1f_1c_21x_21y'.
# The leading 't_' is the prefix the save cell adds to trained copies. Strip it
# so re-saving does not produce 't_t_...'.
model_name = re.sub(r"^t_", "", os.path.splitext(os.path.basename(model_path))[0])
model_name

In [6]:
# Load the selected saved model and print its layer-by-layer summary.
selected_model_path = models_list[file_index]
model = tf.keras.models.load_model(selected_model_path)
model.summary()

Model: "Full_stack"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
model_input (InputLayer)        [(None, 2, 1, 21, 21 0                                            
__________________________________________________________________________________________________
tf_op_layer_unstack_1 (TensorFl [(None, 1, 21, 21),  0           model_input[0][0]                
__________________________________________________________________________________________________
gaussian_noise_1 (GaussianNoise (None, 1, 21, 21)    0           tf_op_layer_unstack_1[0][0]      
__________________________________________________________________________________________________
convB1 (Conv2D)                 (None, 8, 9, 9)      208         tf_op_layer_unstack_1[0][1]      
_________________________________________________________________________________________

## Check model history

In [7]:
# NOTE: the per-epoch training History is not part of what Keras writes to the
# saved model file. On a freshly loaded model `model.history` is empty until
# `fit()` runs in this session, so the bare expression showed nothing.
# Display the recorded metrics if there are any; otherwise say so explicitly.
loaded_history = getattr(model, "history", None)
loaded_history.history if loaded_history is not None else "no training history: model was loaded from disk"

## Prepare for training

In [8]:
train_file_path = "../data/train"
valid_file_path = "../data/validate"
vars_           = ['t2m']
proc_type       = "convlstm"

# Generator steps per epoch.
# - 365 * 24 counts the hours in a year, which assumes one record per hour
#   (TODO confirm against the data).
# - Each step presumably consumes one `frames`-long slice from slice_generator.
# - Keras expects integer step counts, so use floor division; plain `/` produced
#   floats such as 26280.0.
hours_per_year  = 365 * 24
# 3 years of training data
train_steps     = 3 * hours_per_year // frames
# 1 year of validation data
valid_steps     = 1 * hours_per_year // frames

## Train

In [10]:
# Number of training epochs; each epoch runs steps_per_epoch generator batches.
epochs: int = 8

In [12]:
# Shared generator settings, so the training and validation streams cannot drift apart.
generator_kwargs = dict(
    slice_size=frames,
    vars_=vars_,
    proc_type=proc_type,
    pixels_x=pixels_x,
    pixels_y=pixels_y,
    debug=False,
)

# Keep the returned History object; it holds the per-epoch loss/metric values.
history = model.fit(
    slice_generator(img_dir=train_file_path, **generator_kwargs),
    # Cast defensively: Keras step counts are integers, and the counts may have
    # been computed with true division.
    steps_per_epoch=int(train_steps),
    epochs=epochs,
    verbose=1,
    # NOTE: Keras ignores `shuffle` when `x` is a generator; kept only to state intent.
    shuffle=False,
    initial_epoch=0,
    validation_steps=int(valid_steps),
    validation_data=slice_generator(img_dir=valid_file_path, **generator_kwargs),
)

# Save the trained model under a 't_' prefix next to the original.
# - The path ends in '.h5'. tf.keras (TF 2.x) selects HDF5 from that extension
#   even when save_format='tf' is passed, so name the format actually written.
# - `signatures` only applies to the SavedModel format, so it is dropped.
tf.keras.models.save_model(
    model=model,
    filepath=model_dir + 't_' + model_name + '.h5',
    overwrite=True,
    include_optimizer=True,
    save_format='h5',
)


Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8


## Get training and validation loss

In [21]:
# Per-epoch metrics recorded by the History callback during the training run above.
epoch_metrics = model.history.history
epoch_metrics

{'loss': [1693.7786865234375,
  151.57391357421875,
  9.25869369506836,
  1.3110136985778809,
  1.2769964933395386,
  1.2693883180618286,
  1.2686151266098022,
  1.2681891918182373],
 'accuracy': [0.015561354346573353,
  0.015905631706118584,
  0.1809016466140747,
  0.15021562576293945,
  0.13372653722763062,
  0.1338932365179062,
  0.13390591740608215,
  0.1337084174156189],
 'mean_absolute_error': [150.7167205810547,
  59.357295989990234,
  4.1898345947265625,
  3.658730983734131,
  3.6664934158325195,
  3.6664674282073975,
  3.6671292781829834,
  3.6672863960266113],
 'val_loss': [56639.94921875,
  389165.9375,
  1.521070957183838,
  1.6959989070892334,
  1.4849776029586792,
  1.4700580835342407,
  1.4693400859832764,
  1.4689908027648926],
 'val_accuracy': [0.019188953563570976,
  0.023456186056137085,
  0.11307892948389053,
  0.10407153517007828,
  0.10852359235286713,
  0.11307892948389053,
  0.11307892948389053,
  0.11307892948389053],
 'val_mean_absolute_error': [104.8054275512

## Visually inspect predictions

In [None]:
# One generator per split; both use identical slicing settings.
shared_settings = dict(slice_size=frames, vars_=vars_, proc_type=proc_type,
                       pixels_x=pixels_x, pixels_y=pixels_y, debug=False)
slice_train = slice_generator(img_dir=train_file_path, **shared_settings)
slice_val = slice_generator(img_dir=valid_file_path, **shared_settings)

In [None]:
# Pull one (input, target) pair from the training stream and inspect the input shape.
first_batch = next(slice_train)
in_, out_ = first_batch
in_.shape

In [None]:
# Run the model on the sampled input and display the raw predictions.
test_pred = model.predict(x=in_, verbose=1)
test_pred