In [1]:
import argparse
import time
from typing import Tuple

import h5py
import numpy as np
import tensorflow as tf

from models.model_zoo import unet_tf2

In [2]:
# Optional GPU pinning, left disabled:
# gpus = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

# Notebook-wide settings: HDF5 database locations (the 'test_*' entries point
# at the valid_db files) and the path prefix used when saving models.
params = {
    'train_db01_path': '/Users/sallesd/Projects/f3_ffn_db/train_db01.hdf',
    'test_db01_path': '/Users/sallesd/Projects/f3_ffn_db/valid_db01.hdf',
    'train_db02_path': '/Users/sallesd/Projects/f3_ffn_db/train_db02.hdf',
    'test_db02_path': '/Users/sallesd/Projects/f3_ffn_db/valid_db02.hdf',
    'model_path': '/Users/sallesd/Projects/f3_ffn_model',
}

In [3]:
class MeanIoU(tf.keras.metrics.Metric):
    """Mean intersection-over-union for per-class score tensors.

    Thin wrapper around ``tf.keras.metrics.MeanIoU``: ``y_true`` and ``y_pred``
    are both reduced to class indices with ``argmax`` over their last axis and
    then forwarded to the wrapped metric, which holds all accumulated state.

    Args:
        name: Metric name passed to the Keras ``Metric`` base class.
        num_classes: Number of classes handed to the wrapped metric. Defaults
            to 2, the value that used to be hard-coded.
        **kwargs: Forwarded unchanged to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name='mean_iou', num_classes=2, **kwargs):
        super(MeanIoU, self).__init__(name=name, **kwargs)
        self.tf_mean_iou = tf.keras.metrics.MeanIoU(num_classes=num_classes)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch into the wrapped metric.

        Args:
            y_true: Ground-truth tensor with the class dimension last.
            y_pred: Prediction tensor with the class dimension last.
            sample_weight: Optional weights, passed through to the wrapped
                metric (they were previously accepted but ignored).
        """
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)

        self.tf_mean_iou.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the current value of the wrapped mean-IoU metric."""
        return self.tf_mean_iou.result()

    def reset_state(self):
        """Clear the wrapped metric's accumulated state."""
        # Newer Keras releases name this `reset_state`; older ones only offer
        # `reset_states`. Use whichever the wrapped metric provides.
        inner_reset = getattr(self.tf_mean_iou, 'reset_state', None)
        if inner_reset is None:
            inner_reset = self.tf_mean_iou.reset_states
        inner_reset()

    def reset_states(self):
        """Backward-compatible alias of :meth:`reset_state`."""
        self.reset_state()

In [4]:
def _reset_metric(metric) -> None:
    """Clear a Keras metric, using `reset_state` or the older `reset_states`."""
    reset = getattr(metric, 'reset_state', None)
    if reset is None:
        reset = metric.reset_states
    reset()


def training_loop(train_dataset: tf.data.Dataset,
                  val_dataset: tf.data.Dataset,
                  model: tf.keras.Model,
                  model_path: str,
                  optimizer: tf.keras.optimizers.Optimizer,
                  loss_fn: tf.keras.losses.Loss,
                  metrics: Tuple[dict, dict],
                  epochs: int):
    """Train ``model`` with a hand-written loop, validating after every epoch.

    For each training batch the per-element loss returned by ``loss_fn`` is
    multiplied by ``9 * labels[..., -1] + labels[..., 0]`` (labels are indexed
    as 4-D tensors), summed, and divided by the product of the first three
    label dimensions. With one-hot labels this weights the last class 9x the
    first one -- presumably a class-imbalance correction; confirm with the
    data owner.

    After each epoch the validation metrics are computed. Whenever the LAST
    metric in ``val_metrics`` (dict order) strictly exceeds the best value
    seen so far, the model is saved to ``f'{model_path}_best.h5'``.

    Args:
        train_dataset: Iterable of ``(features, labels)`` training batches.
        val_dataset: Iterable of ``(features, labels)`` validation batches.
        model: Model to optimise; called directly on the feature batches.
        model_path: Path prefix for the best-model checkpoint.
        optimizer: Optimizer used to apply the gradients.
        loss_fn: Loss called as ``loss_fn(labels, predictions)``. It must not
            reduce over the pixel dimensions, since the result is multiplied
            element-wise by the per-pixel weights.
        metrics: ``(train_metrics, val_metrics)``, two dicts mapping a display
            name to a Keras metric object.
        epochs: Number of passes over ``train_dataset``.

    Returns:
        The same ``model`` object, after training.
    """
    train_metrics, val_metrics = metrics[0], metrics[1]

    # Best monitored validation value so far; the starting value of 0.0 means
    # a checkpoint is only written once the metric becomes positive.
    best = 0.0
    for epoch in range(epochs):
        print('Start of epoch %d' % (epoch,))
        epoch_loss = 0
        num_batches = 0
        # Iterate over the batches of the dataset.
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                # training=True makes mode-dependent layers (e.g. dropout,
                # batch normalisation) run in training mode.
                logits = model(x_batch_train, training=True)
                pixel_loss = loss_fn(y_batch_train, logits)
                # Per-pixel weight: 9 * last label channel + first label channel.
                w = y_batch_train[:, :, :, -1] * 9 + y_batch_train[:, :, :, 0]
                # Cast so integer-typed labels do not break the multiplication.
                pixel_loss *= tf.cast(w, pixel_loss.dtype)
                loss_value = tf.reduce_sum(pixel_loss) * (
                        1. / y_batch_train.shape[0] / y_batch_train.shape[1] / y_batch_train.shape[2])

            grads = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

            epoch_loss += loss_value
            num_batches += 1

            # Update training metric.
            for metric in train_metrics.values():
                metric(y_batch_train, logits)

            # Log every 2 batches.
            if step % 2 == 0:
                print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
                print('Seen so far: %s samples' % ((step + 1) * x_batch_train.shape[0]))

        # Guard against an empty training dataset (previously a NameError on
        # the undefined loop variable `step`).
        if num_batches:
            print(f'Epoch loss: {epoch_loss / num_batches}')
        else:
            print('Epoch loss: n/a (training dataset yielded no batches)')

        # Display metrics at the end of each epoch, then clear them so the
        # next epoch starts from scratch.
        for metric_name, metric in train_metrics.items():
            print(f"| {metric_name}: {metric.result()} ", end="", flush=True)
            _reset_metric(metric)
        print("|")

        # Run a validation loop at the end of each epoch (inference mode).
        for x_batch_val, y_batch_val in val_dataset:
            val_logits = model(x_batch_val, training=False)
            for metric in val_metrics.values():
                metric(y_batch_val, val_logits)

        metric_dict = {}
        monitored_value = None  # value of the last validation metric, if any
        for metric_name, metric in val_metrics.items():
            metric_dict[metric_name] = metric.result().numpy()
            monitored_value = metric_dict[metric_name]
            print(f'| {metric_name}: {metric_dict[metric_name]} ', end='', flush=True)
            _reset_metric(metric)
        print("|")

        # Checkpoint only on a strict improvement, and remember the new best
        # (previously `best` was never updated, so the checkpoint was
        # overwritten on every epoch with a positive metric).
        if monitored_value is not None and best < monitored_value:
            best = monitored_value
            model.save(f'{model_path}_best.h5')

    return model

In [5]:
# Wall-clock reference for the total-time report printed after training.
start = time.time()


def _load_hdf_dataset(path: str) -> tf.data.Dataset:
    """Read the 'features' and 'label' arrays of an HDF5 file into a Dataset.

    Both arrays are loaded fully into memory and sliced along their first
    axis into ``(features, label)`` pairs.

    Args:
        path: Location of the HDF5 file.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(features, label)`` elements.

    Raises:
        KeyError: If the file lacks a 'features' or 'label' dataset.
    """
    with h5py.File(path, 'r') as h5_file:
        # Index by key instead of `.get(...)`: a missing dataset then raises
        # KeyError here rather than `np.array(None)` producing an object array.
        features = np.array(h5_file['features'])
        labels = np.array(h5_file['label'])
    return tf.data.Dataset.from_tensor_slices((features, labels))


# Training data is shuffled before batching; validation data keeps file order.
train_dataset = _load_hdf_dataset(params['train_db01_path'])
train_dataset = train_dataset.shuffle(buffer_size=10192).batch(batch_size=2048)

test_dataset = _load_hdf_dataset(params['test_db01_path'])
test_dataset = test_dataset.batch(batch_size=512)

# Passed to the model builder as its input shape and number of output channels.
input_shape = (64, 64, 2)
num_classes_output = 2

In [None]:
# Separate metric objects for training and validation so their accumulated
# state never mixes.
train_metrics = {
    'train_mean_iou': MeanIoU(),
    'train_cat_acc': tf.keras.metrics.CategoricalAccuracy()
}
val_metrics = {
    'val_mean_iou': MeanIoU(),
    'val_cat_acc': tf.keras.metrics.CategoricalAccuracy()
}

model = training_loop(
    train_dataset=train_dataset,
    val_dataset=test_dataset,
    model=unet_tf2(input_shape=input_shape, output_channels=num_classes_output),
    model_path=params['model_path'],
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias
    # that newer Keras releases no longer accept.
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    # Reduction.NONE keeps one loss value per pixel, which training_loop
    # multiplies by its per-pixel weights.
    loss_fn=tf.keras.losses.CategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE),
    metrics=(train_metrics, val_metrics),
    epochs=5)
# Save the final-epoch model.
model.save(f"{params['model_path']}.h5")

print(f"Total training time: {time.time() - start}")
