## Train the Model

Train the WaveNet-LSTM model on the processed data.

In [19]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Date    : Feb-16-22 00:28
# @Author  : Kelley Kan HUANG (kan.huang@connect.ust.hk)
# @RefLink : https://www.kaggle.com/wimwim/wavenet-lstm

import imp
import os
import argparse
from datetime import datetime
import tensorflow as tf
from model.wavenet_lstm import WaveNet_LSTM
from datasets.load_data import load_data
from datasets.window_sequences import WindowSequences
import pandas as pd
from scipy.signal import convolve
from scipy.signal import hilbert
from scipy.signal import hann
import scipy.signal as sg
# import seaborn as sns
# from sklearn.metrics import mean_absolute_error
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt

import warnings
# Suppress all Python warnings for the rest of the session.
# NOTE(review): this also hides deprecation notices from the libraries imported above.
warnings.filterwarnings("ignore")


# Show floating-point values in pandas output with 15 decimal places.
pd.options.display.precision = 15
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [20]:
# --- Experiment identification -------------------------------------------
exper_name = "lstm"
MODEL_TYPE = "LSTM"

# --- Device placement ----------------------------------------------------
# GPU indices used when building the model and when running training.
model_gpu = 0
train_gpu = 0

# TensorFlow device strings derived from the indices above.
model_device = "/device:GPU:{}".format(model_gpu)
train_device = "/device:GPU:{}".format(train_gpu)

# --- Hyperparameters -----------------------------------------------------
model_type = "WaveNet_LSTM"

# Window length handed to WindowSequences and used as the model's input length.
# Author's note: 7200 acoustic -> 1s time_to_failure.
# NOTE(review): the original inline comment labelled this value "seconds",
# which contradicts the note above — TODO confirm the unit.
window_size = 7200

stride = 10  # alternative value noted by the author: 100

batch_size = 32  # alternative value noted by the author: 64

validation_split = 0.2

epochs = 20

In [27]:
# Load data
# The dataset root can be overridden with the LANL_DATA_DIR environment
# variable so the notebook is not tied to one machine; when the variable is
# unset the original hard-coded location is used.
BASE_PATH = os.environ.get(
    "LANL_DATA_DIR", "E:\\DeepLearningData\\LANL-Earthquake-Prediction")

train_acoustic_data = load_data(path=BASE_PATH, name="train_acoustic_data")

train_time_to_failure = load_data(
    path=BASE_PATH, name="train_time_to_failure")

# Options shared by the training and validation sequences. Both sequences are
# built from the same two arrays with the same seed and validation_split and
# differ only in `subset`; keeping the options in one place prevents the two
# calls from drifting apart.
window_sequence_options = dict(
    window_size=window_size,
    stride=stride,
    batch_size=batch_size,
    shuffle=True,
    seed=42,
    validation_split=validation_split,
)
window_sequence_train = WindowSequences(
    train_acoustic_data, train_time_to_failure,
    **window_sequence_options, subset="training")
window_sequence_val = WindowSequences(
    train_acoustic_data, train_time_to_failure,
    **window_sequence_options, subset="validation")

# Config paths
# Root directory (under the user's home) for everything this run writes.
prefix = os.path.join(
    "~", "Documents", "DeepLearningData", "Earthquake")
date_time = datetime.now().strftime("%Y%m%d-%H%M%S")
# Per-run sub-path: <model_type>/stride_<stride>/<timestamp>, timestamp last.
subfix = os.path.join(model_type,
                      "_".join(["stride", str(stride)]), date_time)

# Keep checkpoints and logs in separate directory trees.
log_dir = os.path.expanduser(os.path.join(prefix, "logs", subfix))
ckpt_dir = os.path.expanduser(os.path.join(prefix, "ckpts", subfix))
# exist_ok=True keeps the cell idempotent: if a directory already exists
# (e.g. the cell is re-run within the same second, or a previous run was
# interrupted between the two calls) no FileExistsError is raised.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(ckpt_dir, exist_ok=True)


load time used: 0.0023492000000260305 seconds.
load time used: 0.0017505000000710425 seconds.


## Prepare model

WaveNet-LSTM model

In [23]:
def lr_schedule(epoch, *, base_lr=1e-4, decay_epoch=20, decay_factor=0.1):
    """Return the learning rate to use for ``epoch`` (single step decay).

    The rate is ``base_lr`` for every epoch index below ``decay_epoch`` and
    ``base_lr * decay_factor`` from ``decay_epoch`` onwards.

    Args:
        epoch: Epoch index supplied by the caller. Keras'
            ``LearningRateScheduler`` documents this as a 0-based index, so
            with the defaults the decay only applies to runs longer than
            20 epochs.
        base_lr: Learning rate used before the decay point.
        decay_epoch: First epoch index at which the reduced rate is returned.
        decay_factor: Multiplier applied to ``base_lr`` once the decay point
            has been reached.

    Returns:
        The learning rate as a float.

    The tuning parameters are keyword-only on purpose: a scheduler that calls
    ``schedule(epoch, current_lr)`` positionally must not have the current
    rate silently bound to ``base_lr``.
    """
    lr = base_lr
    if epoch >= decay_epoch:
        lr *= decay_factor
    return lr


In [28]:
# Define callbacks
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint, CSVLogger, TensorBoard

# Checkpoint file names embed the model type, the epoch number and the value
# logged under "mse" for that epoch.
ckpt_filename = "%s-epoch-{epoch:03d}-mse-{mse:.4f}.h5" % model_type
ckpt_path = os.path.join(ckpt_dir, ckpt_filename)
checkpoint_callback = ModelCheckpoint(
    filepath=ckpt_path,
    monitor="mse",
    verbose=1,
)

# Learning rate comes from lr_schedule.
lr_scheduler = LearningRateScheduler(lr_schedule)
# Training log is appended to a CSV file in the log directory.
csv_logger = CSVLogger(
    os.path.join(log_dir, "training.log.csv"),
    append=True,
)
# TensorBoard summaries are written to the same log directory.
tensorboard_callback = TensorBoard(
    log_dir,
    histogram_freq=1,
    update_freq="batch",
)
# Early stopping is intentionally not used.
callbacks = [
    csv_logger,
    lr_scheduler,
    checkpoint_callback,
    tensorboard_callback,
]

# Loss and metrics are given as objects; the metric names below are the keys
# they are reported under ("mae", "mse", "mape", "msle").
loss = tf.keras.losses.MeanSquaredError()  # "mse"
metrics = [
    tf.keras.metrics.MeanAbsoluteError(name="mae"),
    tf.keras.metrics.MeanSquaredError(name="mse"),
    tf.keras.metrics.MeanAbsolutePercentageError(name="mape"),
    tf.keras.metrics.MeanSquaredLogarithmicError(name="msle"),
]

# Prepare model
from tensorflow.keras.optimizers import Adam, SGD
with tf.device(model_device):
    if model_type == "WaveNet_LSTM":
        # One input channel per time step, window_size time steps per sample.
        model = WaveNet_LSTM(input_shape=(window_size, 1))
    else:
        # Fail early with a clear message. Without this branch an unknown
        # model_type would either raise NameError on `model` below or, in a
        # live kernel, silently reuse a stale `model` from an earlier run.
        raise ValueError("Unsupported model_type: %r" % (model_type,))

    model.compile(
        # `learning_rate` is the supported argument name; `lr` is a
        # deprecated alias. The initial rate matches the schedule's value
        # for epoch 0; gradients are clipped element-wise to [-1, 1].
        optimizer=Adam(learning_rate=lr_schedule(0), clipvalue=1.0),
        loss=loss,
        metrics=metrics
    )

In [29]:
# Run training under the configured training device. The validation sequence
# is passed to fit alongside the training sequence, and the callbacks defined
# earlier handle CSV logging, learning-rate scheduling, checkpointing and
# TensorBoard summaries.
with tf.device(train_device):
    model.fit(
        window_sequence_train,
        epochs=epochs,
        verbose=1,
        callbacks=callbacks,
        validation_data=window_sequence_val,
    )


  ...
    to  
  ['...']
  ...
    to  
  ['...']
Train for 1572846 steps, validate for 393212 steps
Epoch 1/100
    411/1572846 [..............................] - ETA: 153:10:14 - loss: 14.4190 - mae: 3.0227 - mse: 14.3683 - mape: 388.2914 - msle: 0.5022
Epoch 00001: saving model to C:\Users\kellyhwong\Documents\DeepLearningData\Earthquake\ckpts\WaveNet_LSTM\stride_10\20220309-235601\WaveNet_LSTM-epoch-001-mse-14.3683.h5
    411/1572846 [..............................] - ETA: 153:34:14 - loss: 14.4190 - mae: 3.0227 - mse: 14.3683 - mape: 388.2914 - msle: 0.5022

KeyboardInterrupt: 