In [1]:
import sys

# Put the parent of the current working directory on the import path so the
# `amc_neural_networks` package imported in the next cell can be found.
sys.path.extend(["./.."])

In [2]:
import tensorflow as tf
import os
from amc_neural_networks.data_loading import (
    load_tfrecords,
    get_tf_batch,
    get_n_tfrecord_files,
)
from amc_neural_networks.model import create_model

In [3]:
# Root directory of the TFRecord dataset — replace the placeholder with the real location.
dataset_root_path = "path/to/tfrecord/dataset"

# Per-split counts from `get_n_tfrecord_files` (presumably the number of
# TFRecord files in each split); the train count is used below as the
# shuffle buffer for the training file list.
n_train_files, n_test_files = (
    get_n_tfrecord_files(dataset_root_path, split) for split in ("train", "test")
)

In [4]:
BATCH_SIZE = 32


def _get_tf_batch_with_snr(batch):
    """Call ``get_tf_batch`` on ``batch`` with ``with_snr=True``."""
    return get_tf_batch(batch, with_snr=True)


def make_dataset(split, shuffle_buffer=None, with_snr=False):
    """Build a batched, prefetched ``tf.data`` pipeline for one dataset split.

    Args:
        split: Split name forwarded to ``load_tfrecords`` (e.g. ``"train"``, ``"test"``).
        shuffle_buffer: If not None, shuffle the dataset returned by
            ``load_tfrecords`` with this buffer size, reshuffling every epoch.
        with_snr: If True, parse batches with ``get_tf_batch(..., with_snr=True)``;
            otherwise ``get_tf_batch`` is called with its default arguments.

    Returns:
        A ``tf.data.Dataset`` yielding the result of the parse function for each
        batch of ``BATCH_SIZE`` serialized TFRecord entries.
    """
    files = load_tfrecords(dataset_root_path, split=split)
    if shuffle_buffer is not None:
        files = files.shuffle(shuffle_buffer, reshuffle_each_iteration=True)

    parse_batch = _get_tf_batch_with_snr if with_snr else get_tf_batch

    # Each element of `files` is handed to TFRecordDataset, so it must be a
    # TFRecord filename; several files are read concurrently.
    records = files.interleave(
        tf.data.TFRecordDataset,
        cycle_length=tf.data.AUTOTUNE,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return (
        records.batch(BATCH_SIZE)
        # Parse batches in parallel; output order stays deterministic by default.
        .map(parse_batch, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )


train_ds = make_dataset("train", shuffle_buffer=n_train_files)
test_ds = make_dataset("test")
# Same test records as `test_ds`, parsed with `with_snr=True`
# (presumably for evaluation broken down by SNR).
eval_ds = make_dataset("test", with_snr=True)

In [5]:
LEARNING_RATE = 1e-4

model = create_model(
    use_se=True, use_act=True, residual=False, dilate=True, self_attention=False,
)
# Use `learning_rate`, not the deprecated `lr` alias: TF >= 2.11 optimizers only
# warn about `lr` and ignore it (silently training at the default 1e-3), and
# Keras 3 raises an error for it.
adam = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, clipnorm=1)
model.compile(
    adam,
    loss=["sparse_categorical_crossentropy"],
    metrics=[
        "acc",
        # A prediction counts as correct if the true class is among the top 2 scores.
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=2, name="top_2_acc"),
    ],
)
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 1024, 2)]    0                                            
__________________________________________________________________________________________________
conv_1 (Conv1D)                 (None, 1024, 32)     480         input[0][0]                      
__________________________________________________________________________________________________
se_block_1_stat (GlobalAverageP (None, 32)           0           conv_1[0][0]                     
__________________________________________________________________________________________________
se_block_1_reshape (Reshape)    (None, 1, 32)        0           se_block_1_stat[0][0]            
______________________________________________________________________________________________

In [6]:
N_EPOCHS = 1
# Batches per epoch for both training and validation. 1 gives a quick smoke
# test of the whole pipeline; set to None to iterate over the full datasets.
SMOKE_TEST_STEPS = 1

callbacks = [
    # Stop training as soon as a NaN loss is encountered.
    tf.keras.callbacks.TerminateOnNaN(),
    # Multiply the learning rate by 0.1 after 12 epochs without improvement
    # in training loss, never going below 1e-7.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="loss", factor=0.1, patience=12, min_lr=1e-7, mode="min"
    ),
    # Stop after 20 epochs without training-loss improvement and restore the
    # weights from the best epoch.
    # NOTE(review): both callbacks monitor the *training* loss; consider
    # "val_loss" so early stopping reacts to overfitting.
    tf.keras.callbacks.EarlyStopping(
        monitor="loss", mode="min", patience=20, restore_best_weights=True
    ),
]

history = model.fit(
    train_ds,
    steps_per_epoch=SMOKE_TEST_STEPS,
    validation_data=test_ds,
    validation_steps=SMOKE_TEST_STEPS,
    callbacks=callbacks,
    epochs=N_EPOCHS,
    verbose=2,
)

1/1 - 3s - loss: 3.1778 - acc: 0.0625 - top_2_acc: 0.1250 - val_loss: 3.1775 - val_acc: 0.0625 - val_top_2_acc: 0.0938


To add CSV logging, model checkpointing, and TensorBoard logging, extend `callbacks` with the snippet below before calling `model.fit`. All output files are written under `save_dir`.

```python

# Output directory for the training log, best checkpoint, and TensorBoard events.
save_dir = "path/to/save_dir"
# CSVLogger does not create missing parent directories, so create it up front.
os.makedirs(save_dir, exist_ok=True)

callbacks.extend(
    [
        # Append per-epoch metrics to a CSV file.
        tf.keras.callbacks.CSVLogger(os.path.join(save_dir, "training.log")),
        # Keep only the model with the lowest validation loss seen so far.
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(save_dir, "model.hdf5"),
            monitor="val_loss",
            save_best_only=True,
            mode="min",
        ),
        tf.keras.callbacks.TensorBoard(log_dir=os.path.join(save_dir, "logs")),
    ]
)

```
