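# Model

Template notebook for defining a CTC-trained model, training it (optionally on a TPU) and saving checkpoints and TensorBoard logs under a Google Cloud Storage path.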

In [None]:
!pip3 install -q open_speech; pip3 freeze | grep open-speech

try:
    from google.colab import auth
    auth.authenticate_user()

except Exception as e: print(e)

In [None]:
import json
import os
import re
import sys
sys.path.append("..")

import tensorflow as tf
import tensorflow.keras as keras
import util

print(f"Using tensorflow: {tf.__version__}")  # record the TF version in the cell output

## Model parameters

In [None]:
model_name = "model"
# TPU to train on:
#   "node-X" -> address looked up by name in <root_path>/tpus.json
#   "colab"  -> address taken from $COLAB_TPU_ADDR
#   None (or "") -> no TPU
tpu_node = None

root_path = "gs://open-speech-train"
model_path = root_path + "/" + model_name

print("model_path:", model_path)
if tpu_node:
    print("tpu_node:", tpu_node)

## Distribution strategy

In [None]:
# according to [Use TPUs](https://www.tensorflow.org/guide/tpu):
# `The TPU initialization code has to be at the beginning of your program.`

# Resolve the TPU address from `tpu_node`. Any falsy value (None or "") means
# "no TPU", the same test the parameters cell uses before printing tpu_node.
if not tpu_node:
    tpu_addr = None
elif tpu_node == "colab":
    tpu_addr = os.environ["COLAB_TPU_ADDR"]
else:
    # tpus.json is read as a JSON object and indexed by node name.
    tpus = json.loads(tf.io.read_file(root_path + "/tpus.json").numpy())
    tpu_addr = tpus.get(tpu_node)
    if tpu_addr is None:
        # Fail loudly instead of silently building a non-TPU strategy while
        # get_metrics() still behaves as if a TPU were in use.
        raise ValueError(f"TPU node {tpu_node!r} not found in {root_path}/tpus.json")

strategy = util.create_strategy(tpu_addr=tpu_addr)

## Training parameters

In [None]:
learn_rate = 1e-4  # Adam learning rate, read by get_optimizer()

# NOTE(review): presumably loads the dataset metadata that defines
# util.alphabet / util.num_chars used below — confirm in util.
util.init_data()
print("alphabet:", util.alphabet)

## Model definition

In [None]:
def create_model():
    """Build the (template) model.

    The input layer takes tensors of shape util.get_input_shape(); the final
    kernel-size-1 Conv1D emits util.num_chars unnormalised scores (logits) per
    step. The returned keras.Model is named after the notebook's model_name.
    """
    inputs = keras.layers.Input(shape=util.get_input_shape())
    x = inputs

    # ... (insert model layers here, transforming `x`) ...

    # NB: softmax is applied inside ctc_loss() and ctc_decode()
    logits = keras.layers.Conv1D(filters=util.num_chars, kernel_size=1)(x)

    return keras.Model(inputs=inputs, outputs=logits, name=model_name)

def get_optimizer():
    """Return a newly constructed Adam optimizer using the global learn_rate."""
    optimizer = keras.optimizers.Adam(learning_rate=learn_rate)
    return optimizer

def get_loss():
    """Return the CTC loss function (util.ctc_loss) passed to model.compile()."""
    loss_fn = util.ctc_loss
    return loss_fn

def get_metrics():
    """Return the metrics list for model.compile(), or None when on a TPU.

    util.edit_distance is not supported on TPU, so it is only enabled when no
    TPU is configured. Any falsy tpu_node (None or "") counts as "no TPU",
    matching the `if tpu_node:` check in the parameters cell.
    """
    if tpu_node:
        return None  # edit_distance is not supported on TPU
    return [util.edit_distance]

## Create model

In [None]:
ckpt_path = model_path + "/checkpoints"
# `{epoch}` is substituted by keras.callbacks.ModelCheckpoint when saving.
ckpt_templ = ckpt_path + "/epoch-{epoch:04d}.ckpt"
# Extracts the epoch number from a checkpoint path. Raw string: "\." in a
# plain literal is an invalid escape sequence (DeprecationWarning, and a
# SyntaxWarning since Python 3.12).
ckpt_regex = re.compile(r"epoch-([0-9]+)\.ckpt")

logs_path = model_path + "/logs"

print("checkpoints:", ckpt_path)
print("logs:", logs_path)

# Most recent checkpoint under ckpt_path (None when nothing has been saved yet).
ckpt_latest = tf.train.latest_checkpoint(ckpt_path)
print("Latest checkpoint:", ckpt_latest)

# To restore weights from a specific checkpoint instead, uncomment and edit the
# line below; set ckpt_latest = "" to ignore all checkpoints and start fresh.
# NB: ModelCheckpoint numbers {epoch} from 1, so the first file is epoch-0001.ckpt.
#ckpt_latest = ckpt_path + "/epoch-0001.ckpt"
print("Using checkpoint:", ckpt_latest)

In [None]:
# Build, compile and (optionally) restore the model inside the distribution
# strategy's scope, so its variables are created under `strategy`
# (see the TPU guide linked in the "Distribution strategy" cell).
with strategy.scope():
    model = create_model()
    model.compile(optimizer=get_optimizer(), loss=get_loss(), metrics=get_metrics())
    model.summary()

    # ckpt_latest is None (no checkpoint found) or "" (explicitly ignored)
    # when there is nothing to restore.
    if ckpt_latest:
        print("Loading weights:", ckpt_latest)
        model.load_weights(ckpt_latest)

# Input pipelines for fit(). NOTE(review): util.AUTOTUNE is presumably
# tf.data.AUTOTUNE, letting tf.data pick the prefetch buffer size — confirm in util.
print("Loading datasets")
train_data = util.get_train_dataset(prefetch=util.AUTOTUNE)
valid_data = util.get_valid_dataset(prefetch=util.AUTOTUNE)

## Train the model

In [None]:
def _epoch_from_checkpoint(path):
    """Return the epoch number embedded in a checkpoint path like '.../epoch-0007.ckpt'.

    Raises:
        ValueError: if `path` does not match ckpt_regex (e.g. a mistyped
            hand-set checkpoint path), instead of an opaque AttributeError.
    """
    match = ckpt_regex.search(path)
    if match is None:
        raise ValueError(f"Cannot parse an epoch number from checkpoint path {path!r}")
    return int(match.group(1))

# Resume numbering after the restored checkpoint; 0 when training from scratch.
init_epoch = _epoch_from_checkpoint(ckpt_latest) if ckpt_latest else 0
print("init_epoch:", init_epoch)

def update_init_epoch(epoch, logs):
    """LambdaCallback hook: after (0-based) epoch `epoch`, resume from epoch + 1.

    Keeps the global init_epoch current so that re-running the fit cell after
    an interruption continues from the next epoch. `logs` is unused but
    required by the on_epoch_end callback signature.
    """
    global init_epoch
    init_epoch = epoch + 1

In [None]:
num_epochs = 100  # total epoch count; training starts at init_epoch

training_callbacks = [
    # Save weights after every epoch to ckpt_templ (epoch-NNNN.ckpt).
    keras.callbacks.ModelCheckpoint(
        filepath=ckpt_templ,
        save_weights_only=True,
    ),
    keras.callbacks.TensorBoard(log_dir=logs_path),
    # Advance init_epoch after each epoch so re-running this cell resumes
    # where training stopped instead of repeating epochs.
    keras.callbacks.LambdaCallback(on_epoch_end=update_init_epoch),
]

hist = model.fit(
    x=train_data,
    validation_data=valid_data,
    initial_epoch=init_epoch,
    epochs=num_epochs,
    callbacks=training_callbacks,
)