### Setup

In [9]:
from importlib import reload

import numpy as np
import tensorflow as tf

from hot_dust import preprocess, model

In [37]:
# "reload" to get changes in preprocess.py without restarting the kernel
reload(preprocess)
reload(model)
from hot_dust.preprocess import prepare_training_data, split_training_data
from hot_dust.model import to_tensorflow, compile, pretraining

### Parameters

In [38]:
epochs = 500  # Max 500
batch_size = 64
buffer_size = 10 * batch_size

ds = prepare_training_data()
train, validate, test = to_tensorflow(split_training_data(ds))

train = train.shuffle(buffer_size).batch(batch_size)
validate = validate.batch(batch_size).cache()

In [12]:
layer = pretraining(dataset=test)
network = compile(normalization=layer)
network.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 11)]              0         
                                                                 
 normalization_3 (Normaliza  (None, 11)                23        
 tion)                                                           
                                                                 
 dense_3 (Dense)             (None, 8)                 96        
                                                                 
 dense_4 (Dense)             (None, 1)                 9         
                                                                 
Total params: 128 (516.00 Byte)
Trainable params: 105 (420.00 Byte)
Non-trainable params: 23 (96.00 Byte)
_________________________________________________________________


### Fitting

In [15]:
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=20,
    verbose=1,
)

fit = network.fit(
    train,
    epochs=epochs,
    validation_data=validate,
    verbose=2,
    callbacks=[early_stopping],
)

network.save("data/network")
fit = {"epoch": fit.epoch, **fit.history}
np.savez("data/fit.npz", fit)

Epoch 1/500
3122/3122 - 8s - loss: 0.1971 - val_loss: 0.1959 - 8s/epoch - 3ms/step
Epoch 2/500
3122/3122 - 7s - loss: 0.1969 - val_loss: 0.1959 - 7s/epoch - 2ms/step
Epoch 3/500
3122/3122 - 7s - loss: 0.1969 - val_loss: 0.1958 - 7s/epoch - 2ms/step
Epoch 4/500
3122/3122 - 7s - loss: 0.1969 - val_loss: 0.1957 - 7s/epoch - 2ms/step
Epoch 5/500
3122/3122 - 7s - loss: 0.1970 - val_loss: 0.1960 - 7s/epoch - 2ms/step
Epoch 6/500
3122/3122 - 7s - loss: 0.1969 - val_loss: 0.1960 - 7s/epoch - 2ms/step
Epoch 6: early stopping
INFO:tensorflow:Assets written to: data/network/assets


INFO:tensorflow:Assets written to: data/network/assets


### Training and Validation Losses

In [16]:
model.plot_loss(fit)