In [3]:
import os
import sys
from typing import List, Optional
sys.path.insert(0, '../utils')
sys.path.insert(0, '..')
import logging

# Configure root logging once for the notebook: timestamp followed by the message.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
# Use the module's real name (the variable __name__), not the literal string '__name__',
# so the logger is named after the running module ('__main__' in a notebook kernel).
logger = logging.getLogger(__name__)

import numpy as np
import tensorflow as tf
from tensorflow.keras.metrics import Mean
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam, SGD, Nadam

from modelsTF import *
from loss import *
%reload_ext autoreload
%autoreload 2

In [39]:
# Load the low-resolution (LR) and high-resolution (HR) training patches for one band.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
CLEAN_DATA_DIR = '/home/mark/DataBank/PROBA-V-CHKPT/trimmedPatchesDir'
band = 'NIR'
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
# The arrays are presumably numpy masked arrays (later cells read `y.mask`) — confirm.
X_train = np.load(os.path.join(CLEAN_DATA_DIR, f'TRAINpatchesLR_{band}.npy'), allow_pickle=True)
y_train = np.load(os.path.join(CLEAN_DATA_DIR, f'TRAINpatchesHR_{band}.npy'), allow_pickle=True)

# Reorder axes (0, 1, 2, 3, 4) -> (0, 3, 4, 2, 1); per the recorded output below this
# yields X with shape (N, 32, 32, 9, 1).
X = X_train.transpose((0, 3, 4, 2, 1))
# Same reordering for the HR patches, then drop axis 3 (squeeze raises if it is not
# of length 1); per the recorded output this yields y with shape (N, 96, 96, 1).
y = y_train.transpose((0, 3, 4, 2, 1)).squeeze(3)
print(f'Input shape: {X.shape} --------> Output shape: {y.shape}')

Input shape: (52563, 32, 32, 9, 1) --------> Output shape: (52563, 96, 96, 1)


In [5]:
import tensorflow as tf
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU')]

In [6]:
# Build the network. WDSRConv3D is not defined in this notebook; presumably it comes
# from the star-import of `modelsTF` — confirm.
# numImgLR=9 and patchSizeLR=32 match the (N, 32, 32, 9, 1) input printed above, and
# scale=3 matches the 32 -> 96 patch size of the targets.
model = WDSRConv3D(scale=3, numFilters=32, kernelSize=(3, 3, 3), numResBlocks=8,
                expRate=8, decayRate=0.8, numImgLR=9, patchSizeLR=32, isGrayScale=True)

In [7]:
# Optimizer used for training.
optimizer = Nadam(learning_rate=5e-4)
# Bundle everything that must be saved/restored together: the step counter, the PSNR
# value (read by fitTrainData when saveBestOnly is set), the optimizer and the model.
checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     psnr=tf.Variable(1.0),
                                     optimizer=optimizer,
                                     model=model)
# Manager that writes checkpoints into the directory below and keeps at most 5 of them.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
checkpointManager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                                   directory='/home/mark/DataBank/PROBA-V-CHKPT/models',
                                                   max_to_keep=5)

In [45]:
def loadTrainDataAsTFDataSet(X, y, epochs, batchSize, bufferSize):
    """Build the training input pipeline as a ``tf.data.Dataset``.

    Each element is a batch of ``(X, y, mask)`` where ``mask`` is the boolean
    mask of ``y``.

    Args:
        X: Input patches; sliced along the first axis.
        y: Target patches, normally a numpy masked array; sliced along the first axis.
        epochs: Number of times the shuffled data is repeated.
        batchSize: Number of samples per batch.
        bufferSize: Size of the shuffle buffer.

    Returns:
        A dataset that is shuffled (reshuffled on every iteration), repeated
        ``epochs`` times, batched and prefetched.
    """
    # `y.mask` may be the scalar `np.ma.nomask` when nothing is masked (and does not
    # exist on a plain ndarray); getmaskarray always returns a full boolean array with
    # the same shape as `y`, so it can be sliced alongside X and y.
    maskY = np.ma.getmaskarray(y)

    dataset = tf.data.Dataset.from_tensor_slices((X, y, maskY))
    dataset = dataset.shuffle(bufferSize, reshuffle_each_iteration=True)
    dataset = dataset.repeat(epochs)
    dataset = dataset.batch(batchSize)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)


def loadValDataAsTFDataSet(X, y, valSteps, batchSize, bufferSize):
    """Build the validation input pipeline as a ``tf.data.Dataset``.

    Each element is a batch of ``(X, y, mask)`` where ``mask`` is the boolean
    mask of ``y``.

    Args:
        X: Input patches; sliced along the first axis.
        y: Target patches, normally a numpy masked array; sliced along the first axis.
        valSteps: Maximum number of batches yielded per pass over the dataset.
        batchSize: Number of samples per batch.
        bufferSize: Size of the shuffle buffer.

    Returns:
        A dataset that is shuffled, batched, prefetched and truncated to
        ``valSteps`` batches.
    """
    # `y.mask` may be the scalar `np.ma.nomask` when nothing is masked (and does not
    # exist on a plain ndarray); getmaskarray always returns a full boolean array with
    # the same shape as `y`, so it can be sliced alongside X and y.
    maskY = np.ma.getmaskarray(y)

    dataset = tf.data.Dataset.from_tensor_slices((X, y, maskY))
    dataset = dataset.shuffle(bufferSize)
    dataset = dataset.batch(batchSize)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset.take(valSteps)

In [40]:
# Split the transposed arrays into train/validation subsets with a fixed seed.
# NOTE(review): test_size=0.7 sends 70% of the samples to validation and keeps only
# 30% for training — confirm this is intended and not a swapped 0.3.
# NOTE(review): this rebinds X_train / y_train, which previously held the raw arrays
# loaded from disk.
X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.7, random_state=17)

In [41]:
# Smoke test of the training pipeline: epochs=100, batchSize=100, bufferSize=512.
test = loadTrainDataAsTFDataSet(X_train, y_train, 100, 100, 512)

In [42]:
# Pull a single batch and print the shape of its third component (the mask of y),
# then stop; a, b, c are the (X, y, mask) batches yielded by the dataset.
for a, b, c in test.as_numpy_iterator():
    print(c.shape)
    break

(100, 96, 96, 1)


In [27]:
# Validation inputs and targets, in the [X, y] order that fitTrainData indexes.
valData = [X_val, y_val]

# Initialize metrics: running means for the training and validation loss / PSNR.
trainLoss = Mean(name='trainLoss')
trainPSNR = Mean(name='trainPSNR')
testLoss = Mean(name='testLoss')
testPSNR = Mean(name='testPSNR')

In [52]:
# Launch training. Positional arguments after the data map to:
#   batchSize=1024, epochs=1000, bufferSize=512, valData, valSteps=100,
#   checkpoint, checkpointManager, logDir, ckptDir, saveBestOnly=1.
# NOTE(review): fitTrainData, trainStep and testStep are defined in a cell that appears
# after this one; on a fresh kernel that cell must be run first.
# NOTE(review): the traceback recorded below reports missing trainStep arguments, which
# does not match the trainStep call in the definition cell — presumably it came from an
# earlier version of the code; confirm and clear the stale output.
fitTrainData(model, optimizer, [trainLoss, trainPSNR, testLoss, testPSNR], shiftCompensatedL1Loss,
                 shiftCompensatedcPSNR,
                 X_train, y_train, 1024, 1000, 512, valData, 100,
                 checkpoint, checkpointManager,
                 '/home/mark/DataBank/PROBA-V-CHKPT/logs', '/home/mark/DataBank/PROBA-V-CHKPT/models', 1)

(15768, 32, 32, 9, 1) (96, 96, 1)


TypeError: in converted code:


    TypeError: tf__trainStep() missing 6 required positional arguments: 'mask', 'checkpoint', 'loss', 'metric', 'trainLoss', and 'trainPSNR'


In [55]:
def fitTrainData(model: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer,
                 metrics: List[tf.keras.metrics.Mean],
                 lossFunc,
                 PSNRFunc,
                 X: np.ma.MaskedArray, y: np.ma.MaskedArray,
                 batchSize: int, epochs: int, bufferSize: int,
                 valData: List[np.ma.MaskedArray], valSteps: int,
                 checkpoint: tf.train.Checkpoint, checkpointManager: tf.train.CheckpointManager,
                 logDir: str, ckptDir: str, saveBestOnly: bool,
                 evalTestStep: Optional[int] = None):
    """Run the custom training loop with periodic validation and checkpointing.

    Args:
        model: Not used directly; the model stored in ``checkpoint.model`` is the
            one that gets trained (see ``trainStep``).
        optimizer: Not used directly; ``checkpoint.optimizer`` is the one applied.
        metrics: ``[trainLoss, trainPSNR, testLoss, testPSNR]`` running means,
            in exactly this order.
        lossFunc: Callable ``(patchHR, maskHR, predPatchHR) -> loss``.
        PSNRFunc: Callable ``(patchHR, maskHR, predPatchHR) -> metric``.
        X: Training inputs.
        y: Training targets (its mask is fed to the loss and the metric).
        batchSize: Samples per batch.
        epochs: Number of times the training data is repeated.
        bufferSize: Shuffle buffer size for both datasets.
        valData: ``[X_val, y_val]``.
        valSteps: Maximum number of validation batches per evaluation.
        checkpoint: Holds ``step``, ``psnr``, ``optimizer`` and ``model``.
        checkpointManager: Used to write checkpoints.
        logDir: Directory for the ``tf.summary`` event files.
        ckptDir: Unused; kept so existing callers keep working.
        saveBestOnly: When truthy, a checkpoint is only written if the validation
            PSNR improves on ``checkpoint.psnr``.
        evalTestStep: Run validation every this many steps within an epoch.
            Defaults to once per epoch.

    Raises:
        ValueError: If ``evalTestStep`` is smaller than 1.
    """
    trainSet = loadTrainDataAsTFDataSet(X, y, epochs, batchSize, bufferSize)
    valSet = loadValDataAsTFDataSet(valData[0], valData[1], valSteps, batchSize, bufferSize)

    # Summary writer for the training/validation scalars.
    writer = tf.summary.create_file_writer(logDir)

    # At least one step per epoch, so the modulo below never divides by zero when
    # batchSize exceeds the dataset length.
    stepsPerEpoch = max(1, len(X) // batchSize)
    if evalTestStep is None:
        evalTestStep = stepsPerEpoch
    if evalTestStep < 1:
        raise ValueError('evalTestStep must be a positive integer')

    # Plain Python counters keep the log lines and summary steps simple; the
    # persistent counter lives in checkpoint.step.
    globalStep = int(checkpoint.step.numpy())
    step = globalStep % stepsPerEpoch
    epoch = 0

    # Metrics
    trainLoss, trainPSNR, testLoss, testPSNR = metrics

    with writer.as_default():
        for x_batch_train, y_batch_train, y_mask_batch_train in trainSet:
            if step == stepsPerEpoch:
                epoch += 1
                step = globalStep % stepsPerEpoch
                logger.info('Start of epoch %d', epoch)
                # Reset metrics so every epoch reports its own running means.
                trainLoss.reset_states()
                trainPSNR.reset_states()
                testLoss.reset_states()
                testPSNR.reset_states()

            step += 1
            globalStep += 1
            trainStep(x_batch_train, y_batch_train, y_mask_batch_train, checkpoint,
                      lossFunc, PSNRFunc, trainLoss, trainPSNR)
            checkpoint.step.assign_add(1)

            logger.info(
                f"step {step}/{stepsPerEpoch}, loss: {float(trainLoss.result()):.3f}, "
                f"psnr: {float(trainPSNR.result()):.3f}")

            tf.summary.scalar('Train PSNR', trainPSNR.result(), step=globalStep)
            tf.summary.scalar('Train loss', trainLoss.result(), step=globalStep)

            if step % evalTestStep != 0:
                continue

            # Reset states for test
            testLoss.reset_states()
            testPSNR.reset_states()
            for x_batch_val, y_batch_val, y_mask_batch_val in valSet:
                testStep(x_batch_val, y_batch_val, y_mask_batch_val, checkpoint,
                         lossFunc, PSNRFunc, testLoss, testPSNR)

            valLoss = float(testLoss.result())
            valPSNR = float(testPSNR.result())
            tf.summary.scalar('Test loss', valLoss, step=globalStep)
            tf.summary.scalar('Test PSNR', valPSNR, step=globalStep)
            logger.info(
                f"Validation results... val_loss: {valLoss:.3f}, val_psnr: {valPSNR:.3f}")
            writer.flush()

            if saveBestOnly and valPSNR <= float(checkpoint.psnr.numpy()):
                continue

            # Update the tracked variable in place; rebinding the attribute would
            # replace the tf.Variable held by the checkpoint with a plain tensor.
            checkpoint.psnr.assign(valPSNR)
            checkpointManager.save()


@tf.function
def trainStep(patchLR, patchHR, maskHR, checkpoint, loss, metric, trainLoss, trainPSNR):
    """Run one optimisation step and update the running training metrics.

    Args:
        patchLR: Batch of inputs fed to ``checkpoint.model``.
        patchHR: Batch of targets.
        maskHR: Batch of target masks passed to the loss and the metric.
        checkpoint: Object exposing the ``model`` and ``optimizer`` to use.
        loss: Callable ``(patchHR, maskHR, predPatchHR) -> loss``.
        metric: Callable ``(patchHR, maskHR, predPatchHR) -> metric``.
        trainLoss: Running mean updated with the loss.
        trainPSNR: Running mean updated with the metric.
    """
    network = checkpoint.model

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        predPatchHR = network(patchLR, training=True)
        lossValue = loss(patchHR, maskHR, predPatchHR)

    weights = network.trainable_variables
    grads = tape.gradient(lossValue, weights)
    checkpoint.optimizer.apply_gradients(zip(grads, weights))

    metricValue = metric(patchHR, maskHR, predPatchHR)
    trainLoss(lossValue)
    trainPSNR(metricValue)


@tf.function
def testStep(patchLR, patchHR, maskHR, checkpoint, loss, metric, testLoss, testPSNR):
    """Evaluate one validation batch and update the running validation metrics.

    Args:
        patchLR: Batch of inputs fed to ``checkpoint.model``.
        patchHR: Batch of targets.
        maskHR: Batch of target masks passed to the loss and the metric.
        checkpoint: Object exposing the ``model`` to evaluate.
        loss: Callable ``(patchHR, maskHR, predPatchHR) -> loss``.
        metric: Callable ``(patchHR, maskHR, predPatchHR) -> metric``.
        testLoss: Running mean updated with the loss.
        testPSNR: Running mean updated with the metric.
    """
    # The prediction must be bound to the name used below; the original stored it in
    # `sr` and then referenced an undefined `predPatchHR`.
    predPatchHR = checkpoint.model(patchLR, training=False)
    lossValue = loss(patchHR, maskHR, predPatchHR)
    metricValue = metric(patchHR, maskHR, predPatchHR)

    testLoss(lossValue)
    testPSNR(metricValue)