In [None]:
# --- Training schedule -------------------------------------------------------
EPOCHS = 100
BATCH_SIZE = 32
LEARNING_RATE = 2e-4

# --- AdamW moment-decay coefficients (passed as beta_1 / beta_2) -------------
BETA_1 = 0.9
BETA_2 = 0.99

# Capacity handed to PoolData, expressed as a multiple of the batch size.
POOL_SIZE = BATCH_SIZE * 5

# When True, the latest training/EMA checkpoints are restored before training.
RESTORE_CHECKPOINT = False

# --- Export switches ---------------------------------------------------------
# Historical name of the weights-export switch; kept so existing edits to it
# still take effect.
SAVE_GAN_WEIGHTS = True
# Name actually read by the final export cell (gates `save_weights`). It was
# previously never defined, which made the notebook end in a NameError after
# training had finished; it now follows SAVE_GAN_WEIGHTS.
SAVE_NORMAL_MODEL = SAVE_GAN_WEIGHTS
# Gates the export of the EMA model in the final cell.
SAVE_EMA_MODEL = True


In [1]:
import tensorflow as tf
import time
import os
from glob import glob
from utils.train import no_gan_inner_step, AverageModelWeights

from keras.optimizers.optimizer_experimental.adamw import AdamW

from utils.data import load_function, feed_data, PoolData
from utils.data import feed_props_1, feed_props_2
from utils.data import usm_sharpener


2022-05-21 22:32:49.552663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/khoa/anaconda3/envs/nine/lib/python3.9/site-packages/cv2/../../lib64:
2022-05-21 22:32:49.552716: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2022-05-21 22:32:49.552736: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (zen): /proc/driver/nvidia/version does not exist
2022-05-21 22:32:49.552975: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Single AdamW instance used by the training step and tracked by the training
# checkpoint; every hyper-parameter comes from the configuration constants.
optimizer = AdamW(
    learning_rate=LEARNING_RATE,
    beta_1=BETA_1,
    beta_2=BETA_2,
)


In [None]:
from networks.models import RRDBNet

def _build_rrdb(input_shape=(None, 256, 256, 3)):
    """Instantiate an RRDBNet and build it for `input_shape`.

    Args:
        input_shape: Shape passed to `build`; the default leaves the batch
            dimension free and uses 256x256 inputs with 3 channels.

    Returns:
        The freshly built network.
    """
    net = RRDBNet()
    net.build(input_shape)
    return net


# Network that is optimised directly by the training step.
no_gan_model = _build_rrdb()

# Second network with the same architecture; it receives the averaged weights.
ema_no_gan_model = _build_rrdb()

# Weight averager, seeded with the training network's current weights.
ema_api = AverageModelWeights(ema_no_gan_model, no_gan_model.get_weights())


In [None]:
# Collect the training images; sorting makes the file list deterministic.
data_path = os.path.abspath("./DIV2K_train_HR/*.png")
train_images_paths = sorted(glob(data_path))

# Fail fast with a readable message: an empty list would otherwise surface as
# an opaque TensorFlow error from `shuffle`, whose buffer size must be positive.
if not train_images_paths:
    raise FileNotFoundError(f"No training images match {data_path!r}")

train_dataset = tf.data.Dataset.from_tensor_slices(train_images_paths)
# Shuffle the (cheap) path strings before decoding, with a buffer as large as
# the dataset so the whole list is shuffled rather than a sliding window.
train_dataset = train_dataset.shuffle(len(train_images_paths))
train_dataset = train_dataset.map(
    load_function, num_parallel_calls=tf.data.AUTOTUNE)
# No `drop_remainder`: the last batch may hold fewer than BATCH_SIZE items.
train_generator = train_dataset.batch(BATCH_SIZE)


In [None]:
# NOTE: the PoolData buffer is deliberately not created in this cell. It is
# instantiated together with the rest of the training state right before the
# training loop, so an extra instance here was never used.

# One directory for the raw training state, one for the EMA copy of the network.
checkpoint_dir = './training_checkpoints'
ema_checkpoint_dir = './ema_training_checkpoints'

# The training checkpoint tracks the optimizer alongside the model.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(model=no_gan_model, optimizer=optimizer)

# The EMA checkpoint only tracks the averaged model.
ema_checkpoint_prefix = os.path.join(ema_checkpoint_dir, "ema_ckpt")
ema_checkpoint = tf.train.Checkpoint(model=ema_no_gan_model)


In [None]:
def _restore_latest(ckpt, directory, label):
    """Restore `ckpt` from the newest checkpoint found in `directory`, if any.

    Args:
        ckpt: `tf.train.Checkpoint` object to restore into.
        directory: Directory searched with `tf.train.latest_checkpoint`.
        label: Text printed before the resolved checkpoint path.

    Returns:
        The checkpoint path that was restored, or None when none was found.
    """
    print(label)
    # Resolve the path once so the printed path is exactly the one restored.
    latest = tf.train.latest_checkpoint(directory)
    print(latest)
    if latest is None:
        # `restore(None)` restores nothing and reports no error, so training
        # would quietly start from scratch; make that visible instead.
        print(f"WARNING: no checkpoint found in {directory!r}; nothing restored.")
        return None
    ckpt.restore(latest)
    return latest


if RESTORE_CHECKPOINT:
    _restore_latest(checkpoint, checkpoint_dir,
                    "loading training checkpoints: ")
    _restore_latest(ema_checkpoint, ema_checkpoint_dir,
                    "loading EMA training checkpoints: ")


In [None]:
@tf.function
def train_step(gt_images, lq_images):
    """Run one optimisation step on a batch, wrapped in `tf.function`.

    Delegates to `no_gan_inner_step` with the notebook-level `no_gan_model`
    and `optimizer`, and returns its result unchanged. The training loop
    formats and averages that result as the step loss.

    Args:
        gt_images: Batch forwarded as the first argument (ground-truth images,
            going by the name).
        lq_images: Batch forwarded as the second argument (low-quality
            counterparts, going by the name).

    Returns:
        Whatever `no_gan_inner_step` returns.
    """
    step_result = no_gan_inner_step(
        gt_images, lq_images, no_gan_model, optimizer)
    return step_result


In [None]:
# Epoch range consumed by `train`: it iterates over range(start_epoch, epochs).
epochs = EPOCHS
start_epoch = 0

# `train_generator` batches without dropping the remainder, so a trailing
# partial batch still counts as one step. Ceiling division keeps the progress
# denominator equal to the number of batches actually produced per epoch.
train_steps = (len(train_images_paths) + BATCH_SIZE - 1) // BATCH_SIZE

# Sample pool that every (gt, lq) batch passes through before the train step.
pool_train_data = PoolData(POOL_SIZE, BATCH_SIZE)

# Running mean of the step losses (reset at the start of each epoch) and the
# history of that running mean, one entry per step.
train_loss_metric = tf.keras.metrics.Mean()
loss_results = []


In [None]:
def train(epochs):
    """Train `no_gan_model` for epochs `start_epoch` .. `epochs - 1`.

    For every batch from `train_generator`, the function builds a (gt, lq)
    pair with `feed_data`, passes it through `pool_train_data`, sharpens the
    gt images with `usm_sharpener` and runs `train_step`. After each epoch it
    saves the training checkpoint, updates the EMA weights from
    `no_gan_model` and saves the EMA checkpoint.

    Relies on notebook-level state: `start_epoch`, `train_steps`,
    `train_generator`, `pool_train_data`, `train_loss_metric`, `loss_results`,
    `checkpoint`, `ema_checkpoint`, `ema_api` and the checkpoint prefixes.

    Args:
        epochs: Exclusive upper bound of the epoch range.

    Returns:
        None. The running mean loss is appended to `loss_results` once per
        step, and progress is printed to stdout.
    """
    print("Start Training")
    for epoch in range(start_epoch, epochs):
        train_loss_metric.reset_states()
        epoch_time = time.time()
        batch_time = time.time()
        step = 0
        # Defined up front so the epoch summary below cannot hit an unbound
        # name when the generator yields no batch at all.
        mean_loss = float("nan")

        # 1-based epoch number, zero-padded to at least two digits.
        epoch_count = f"{epoch + 1:02d}/{epochs}"

        for img, first_kernel, second_kernel, sinc_kernel in train_generator:
            gt_img, lq_img = feed_data(img, first_kernel, second_kernel, sinc_kernel, [
                                       feed_props_1, feed_props_2])
            gt_img, lq_img = pool_train_data.get_pool_data(gt_img, lq_img)
            # Only the gt images are sharpened, after they come out of the pool.
            gt_img = usm_sharpener.sharp(gt_img)
            step_loss = train_step(gt_img, lq_img)

            # '\r' with end='' keeps the per-step progress on a single line.
            print('\r', 'Epoch', epoch_count, '| Step', f"{step}/{train_steps}",
                  '| Loss:', f"{step_loss:.5f}", "| Step Time:", f"{time.time() - batch_time:.2f}", end='')

            # The step loss and the running mean are kept under separate names:
            # the line above shows the former, the history stores the latter.
            train_loss_metric.update_state(step_loss)
            mean_loss = train_loss_metric.result().numpy()
            step += 1

            loss_results.append(mean_loss)

            batch_time = time.time()

        checkpoint.save(file_prefix=checkpoint_prefix)
        # EMA weights are refreshed once per epoch, then checkpointed.
        ema_api.compute_ema_weights(no_gan_model)
        ema_checkpoint.save(file_prefix=ema_checkpoint_prefix)

        # Epoch summary reports the running mean loss over the epoch.
        print('\r', 'Epoch', epoch_count, '| Step', f"{step}/{train_steps}",
              '| Loss:', f"{mean_loss:.5f}", "| Epoch Time:", f"{time.time() - epoch_time:.2f}")


In [None]:
# Run the whole schedule; `epochs` is the notebook-level value set from EPOCHS.
train(epochs)

# Optional export of the trained network's weights (weights only).
if SAVE_NORMAL_MODEL:
    no_gan_model.save_weights('./checkpoint_weights/last_weights')

# Optional export of the EMA network through `save` (the model itself rather
# than a weights-only file).
if SAVE_EMA_MODEL:
    ema_no_gan_model.save("./no_gan_ema_model")
