In [1]:
from typing import List
import tensorflow as tf
from tensorflow.keras.optimizers import Adam, SGD, Nadam

In [16]:
CROP_BORDER = 3
MAX_PIXEL_SHIFT = 2*CROP_BORDER

def shiftCompensatedcPSNR(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor) -> float:
    '''
    The maximum cPSNR of every possible pixel shift between the predicted HR image and the ground truth.
    This is how the ESA has been computing the submissions of the contestants.
    See details at the ff link: https://kelvins.esa.int/proba-v-super-resolution/scoring/
    '''
    N, H, W, C = tf.shape(patchHR)

    cropSizeHeight = H - MAX_PIXEL_SHIFT
    cropSizeWidth = W - MAX_PIXEL_SHIFT
    cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
    cachecPSNR = []

    # Iterate through all possible shift configurations
    for i in tf.range(MAX_PIXEL_SHIFT+1):
        for j in tf.range(MAX_PIXEL_SHIFT+1):
            stackcPSNR(i, j, patchHR, maskHR, cropPrediction, cachecPSNR)
    cachecPSNR = tf.stack(cachecPSNR)
    maxcPSNR = tf.reduce_max(cachecPSNR)
    return maxcPSNR



def shiftCompensatedL2Loss(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor) -> float:
    '''
    The minimum L2 Loss of every possible pixel shift between the predicted HR image and the ground truth.
    This is how the ESA has been computing the submissions of the contestants.
    See details at the ff link: https://kelvins.esa.int/proba-v-super-resolution/scoring/
    '''
    N, H, W, C = tf.shape(patchHR)

    cropSizeHeight = H - MAX_PIXEL_SHIFT
    cropSizeWidth = W - MAX_PIXEL_SHIFT
    cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
    cacheLosses = []

    # Iterate through all possible shift configurations
    for i in tf.range(MAX_PIXEL_SHIFT+1):
        for j in tf.range(MAX_PIXEL_SHIFT+1):
            stackL2Loss(i, j, patchHR, maskHR, cropPrediction, cacheLosses)
    cacheLosses = tf.stack(cacheLosses)
    minLoss = tf.reduce_min(cacheLosses)
    return minLoss

def shiftCompensatedL1Loss(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor) -> float:
    '''
    The minimum L1 Loss of every possible pixel shift between the predicted HR image and the ground truth.
    This is how the ESA has been computing the submissions of the contestants.
    See details at the ff link: https://kelvins.esa.int/proba-v-super-resolution/scoring/
    '''
    theShape = tf.shape(patchHR)

    cropSizeHeight = 96 - MAX_PIXEL_SHIFT
    cropSizeWidth = 96 - MAX_PIXEL_SHIFT
    cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
    cacheLosses = []

    # Iterate through all possible shift configurations
    for i in range(MAX_PIXEL_SHIFT+1):
        for j in range(MAX_PIXEL_SHIFT+1):
            stackL1Loss(i, j, patchHR, maskHR, cropPrediction, cacheLosses)
    cacheLosses = tf.stack(cacheLosses)
    minLoss = tf.reduce_min(cacheLosses)
    return minLoss


def shiftCompensatedL1Lossv2(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor) -> float:
    '''
    The minimum L1 Loss of every possible pixel shift between the predicted HR image and the ground truth.
    This is how the ESA has been computing the submissions of the contestants.
    See details at the ff link: https://kelvins.esa.int/proba-v-super-resolution/scoring/
    '''
    theShape = tf.shape(patchHR)

    cropSizeHeight = 96 - MAX_PIXEL_SHIFT
    cropSizeWidth = 96 - MAX_PIXEL_SHIFT
    cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
    cacheLosses = []

    # Iterate through all possible shift configurations
    for i in range(MAX_PIXEL_SHIFT+1):
        for j in range(MAX_PIXEL_SHIFT+1):
            theShape = tf.shape(cropPrediction)
            cropTrueImg = cropImage(patchHR, i, theShape[1], j, theShape[2])
            cropTrueMsk = cropImage(maskHR, i, theShape[1], j, theShape[2])
            cropPredMskd = cropPrediction * cropTrueMsk
            totalClearPixels = tf.reduce_sum(cropTrueMsk, axis=(1, 2, 3))

            b = (1.0 / totalClearPixels) * tf.reduce_sum(tf.subtract(cropTrueMsk, cropPredMskd), axis=(1, 2, 3))
            b = tf.reshape(b, (theShape[0], 1, 1, theShape[3]))

    correctedCropPred = cropPrediction + b
    correctedCropPredMskd = correctedCropPred * cropTrueMsk

    L1Loss = computeL1Loss(totalClearPixels, cropTrueImg, correctedCropPredMskd)
    cacheLosses.append(L1Loss)
    cacheLosses = tf.stack(cacheLosses)
    minLoss = tf.reduce_min(cacheLosses)
    return minLoss


def stackL1Loss(i: int, j: int, patchHR: tf.Tensor, maskHR: tf.Tensor, cropPred: tf.Tensor, cache: List[float]):
    #N, cropSizeHeight, cropSizeWidth, C = tf.shape(cropPred)
    theShape = tf.shape(cropPred)
    cropTrueImg = cropImage(patchHR, i, theShape[1], j, theShape[2])
    cropTrueMsk = cropImage(maskHR, i, theShape[1], j, theShape[2])
    cropPredMskd = cropPred * cropTrueMsk
    totalClearPixels = tf.reduce_sum(cropTrueMsk, axis=(1, 2, 3))

    b = computeBiasBrightness(totalClearPixels, cropTrueImg, cropPredMskd)

    correctedCropPred = cropPred + b
    correctedCropPredMskd = correctedCropPred * cropTrueMsk

    L1Loss = computeL1Loss(totalClearPixels, cropTrueImg, correctedCropPredMskd)
    cache.append(L1Loss)



def stackL2Loss(i: int, j: int, patchHR: tf.Tensor, maskHR: tf.Tensor, cropPred: tf.Tensor, cache: List[float]):
    N, cropSizeHeight, cropSizeWidth, C = tf.shape(cropPred)
    cropTrueImg = cropImage(patchHR, i, cropSizeHeight, j, cropSizeWidth)
    cropTrueMsk = cropImage(maskHR, i, cropSizeHeight, j, cropSizeWidth)
    cropPredMskd = cropPred * cropTrueMsk
    totalClearPixels = tf.reduce_sum(cropTrueMsk, axis=(1, 2, 3))

    b = computeBiasBrightness(totalClearPixels, cropTrueImg, cropPredMskd)

    correctedCropPred = cropPred + b
    correctedCropPredMskd = correctedCropPred * cropTrueMsk

    L2Loss = computeL2Loss(totalClearPixels, cropTrueImg, correctedCropPredMskd)
    cache.append(L2Loss)


def stackcPSNR(i: int, j: int, patchHR: tf.Tensor, maskHR: tf.Tensor, cropPred: tf.Tensor, cache: List[float]):
    N, cropSizeHeight, cropSizeWidth, C = tf.shape(cropPred)
    cropTrueImg = cropImage(patchHR, i, cropSizeHeight, j, cropSizeWidth)
    cropTrueMsk = cropImage(maskHR, i, cropSizeHeight, j, cropSizeWidth)
    cropPredMskd = cropPred * cropTrueMsk
    totalClearPixels = tf.reduce_sum(cropTrueMsk, axis=(1, 2, 3))

    b = computeBiasBrightness(totalClearPixels, cropTrueImg, cropPredMskd)

    correctedCropPred = cropPred + b
    correctedCropPredMskd = correctedCropPred * cropTrueMsk

    cPSNR = computecPSNR(totalClearPixels, cropTrueImg, correctedCropPredMskd)
    cache.append(cPSNR)


def computeL1Loss(totalClearPixels, HR, correctedSR):
    loss = (1.0 / totalClearPixels) * tf.reduce_sum(tf.abs(tf.subtract(HR, correctedSR)), axis=(1, 2))
    return loss


def computeL2Loss(totalClearPixels, HR, correctedSR):
    loss = (1.0 / totalClearPixels) * tf.reduce_sum(tf.square(tf.subtract(HR, correctedSR)), axis=(1, 2))
    return loss


def computecPSNR(totalClearPixels, HR, correctedSR):
    loss = (1.0 / totalClearPixels) * tf.reduce_sum(tf.square(tf.subtract(HR, correctedSR)), axis=(1, 2))
    cPSNR = -tf.math.log(loss) / tf.math.log(tf.constant(10, dtype=tf.float32))
    return cPSNR


def computeBiasBrightness(totalClearPixels, HR, SR):
    theShape = tf.shape(HR)
    b = (1.0 / totalClearPixels) * tf.reduce_sum(tf.subtract(HR, SR), axis=(1, 2, 3))
    b = tf.reshape(b, (theShape[0], 1, 1, theShape[3]))
    return b


def cropImage(imgBatch: tf.Tensor, startIdxH: int, lengthHeight: int,
              startIdxW: int, lengthWidth: int) -> tf.Tensor:
    return tf.cast(imgBatch[:, startIdxH: startIdxH + lengthHeight, startIdxW: startIdxW + lengthWidth, :], tf.float32)


In [3]:
import numpy as np

In [4]:
patchLR = np.random.randint(0, 1400, (10, 32, 32, 9, 1))
patchHR = np.random.randint(0, 1400, (10, 96, 96, 1))
maskHR = np.random.randint(0, 2, (10, 96, 96, 1)).astype(np.bool)
predPatchHR = np.random.randint(0, 1400, (10, 96, 96, 1))
patchLR = tf.convert_to_tensor(patchLR, dtype=tf.float32)
patchHR = tf.convert_to_tensor(patchHR, dtype=tf.float32)
maskHR = tf.convert_to_tensor(maskHR, dtype=tf.bool)
predPatchHR = tf.convert_to_tensor(predPatchHR, dtype=tf.float32)

In [14]:
optimizer = Nadam(learning_rate=5e-4)

@tf.function
def trainStep(patchLR, patchHR, maskHR, model):
    with tf.GradientTape() as tape:

        predPatchHR = model(patchLR, training=True)
        loss = shiftCompensatedL1Lossv2(patchHR, maskHR, predPatchHR)  # Loss(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

In [95]:
stackL1Loss(i, j, patchHR, maskHR, cropPrediction, cacheLosses)

In [64]:
H, W, C = tf.shape(patchHR)

cropSizeHeight = H - MAX_PIXEL_SHIFT
cropSizeWidth = W - MAX_PIXEL_SHIFT
cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
cacheLosses = []

for i in range(MAX_PIXEL_SHIFT+1):
    for j in range(MAX_PIXEL_SHIFT+1):
        stackL1Loss(i, j, patchHR, maskHR, cropPrediction, cacheLosses)

In [6]:
tf.print(shiftCompensatedL2Loss(patchHR, maskHR, predPatchHR))

1388238.5


In [7]:
tf.print(shiftCompensatedcPSNR(patchHR, maskHR, predPatchHR))

-6.14246416


In [8]:
tf.print(shiftCompensatedL1Loss(patchHR, maskHR, predPatchHR))

1395.79224


In [11]:
model = WDSRConv3D(scale=3, numFilters=32, kernelSize=(3, 3, 3), numResBlocks=8,
                expRate=8, decayRate=0.8, numImgLR=9, patchSizeLR=32, isGrayScale=True)

In [17]:
trainStep(patchLR, patchHR, maskHR, model)

NameError: in converted code:

    <ipython-input-10-4cbf4aa26191>:8 trainStep  *
        loss = shiftCompensatedL1Lossv2(patchHR, maskHR, predPatchHR)  # Loss(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor)
    <ipython-input-16-9dc6a4e0f780>:99 shiftCompensatedL1Lossv2  *
        cache.append(L1Loss)

    NameError: name 'cache' is not defined


In [6]:
import tensorflow as tf
from tensorflow_addons.layers import WeightNormalization
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv3D, Conv2D, Lambda, Add, Reshape


def WDSRConv3D(scale: int, numFilters: int, kernelSize: tuple,
               numResBlocks: int, expRate: int, decayRate: float,
               numImgLR: int, patchSizeLR: int, isGrayScale: bool) -> Model:
    # Define inputs
    imgLRIn = Input(shape=(patchSizeLR, patchSizeLR, numImgLR, 1)) if isGrayScale \
        else Input(shape=(patchSizeLR, patchSizeLR, numImgLR, 3))

    # Get mean of instance mean patch and over all mean pixel value
    imgLR = Lambda(lambda x: tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]], mode='REFLECT'))(imgLRIn)
    meanImgLR = Lambda(lambda x: tf.reduce_mean(x, axis=3))(imgLR)
    allMean = Lambda(lambda x: tf.reduce_mean(x))(imgLR)
    allStdDev = Lambda(lambda x: tf.math.reduce_std(x))(imgLR)

    # Normalize Instance
    imgLR = Lambda(lambda x: tf.math.divide(tf.math.subtract(x, allMean), allStdDev))(imgLR)
    meanImgLR = Lambda(lambda x: tf.math.divide(tf.math.subtract(x, allMean), allStdDev))(meanImgLR)

    # ImgResBlocks | Main Path
    main = WDSRNetMainPath(imgLR, numFilters, kernelSize,
                           numResBlocks, patchSizeLR, numImgLR,
                           scale, expRate, decayRate)

    # MeanResBlocks | Residual Path
    residual = WDSRNetResidualPath(meanImgLR, kernelSize[:-1], scale)

    # Fuse Main and Residual Patch
    out = Add()([main, residual])

    # Denormalize Instance
    out = Lambda(lambda x: tf.math.add(tf.math.multiply(x, allStdDev), allMean))(out)

    return Model(imgLRIn, out, name='WDSRConv3D')


def WDSRNetResidualPath(meanImgLR: tf.Tensor, kernelSize: tuple, scale: int):
    x = Lambda(lambda x: tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT'))(meanImgLR)
    x = weightNormedConv2D(outChannels=scale*scale, kernelSize=kernelSize, padding='valid', activation='relu')(x)
    x = weightNormedConv2D(outChannels=scale*scale, kernelSize=kernelSize, padding='valid')(x)
    x = Lambda(lambda x: tf.nn.depth_to_space(x, scale))(x)
    return x


def WDSRNetMainPath(imgLR: tf.Tensor, numFilters: int, kernelSize: tuple,
                    numResBlocks: int, patchSizeLR: int, numImgLR: int,
                    scale: int, expRate: int, decayRate: int):
    x = weightNormedConv3D(numFilters, kernelSize, 'same', activation='relu')(imgLR)
    for _ in range(numResBlocks):
        x = ResConv3D(x, numFilters, expRate, decayRate, kernelSize)

    x = ConvReduceAndUpscale(x, numImgLR, scale, numFilters, kernelSize)
    x = Reshape((patchSizeLR, patchSizeLR, scale*scale))(x)
    x = Lambda(lambda x: tf.nn.depth_to_space(x, scale))(x)
    return x


def ConvReduceAndUpscale(x: tf.Tensor, numImgLR: int, scale: int, numFilters: int, kernelSize: tuple):
    # Conv Reducer
    for _ in range(numImgLR//scale):
        x = Lambda(lambda x: tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]], mode='reflect'))(x)
        x = weightNormedConv3D(numFilters, kernelSize, padding='valid', activation='relu')(x)
    # Upscale block
    x = weightNormedConv3D(outChannels=scale*scale, kernelSize=kernelSize, padding='valid')(x)
    return x


def ResConv3D(xIn: tf.Tensor, numFilters: int, expRate: int, decayRate: float, kernelSize: int):
    # Expansion Conv3d | Same padding
    x = weightNormedConv3D(outChannels=numFilters*expRate, kernelSize=1, padding='same', activation='relu')(xIn)
    # Decay Conv3d | Same padding
    x = weightNormedConv3D(outChannels=int(numFilters*decayRate), kernelSize=1, padding='same')(x)
    # Norm Conv3D | Same padding
    x = weightNormedConv3D(outChannels=numFilters, kernelSize=kernelSize, padding='same')(x)
    # Add input and result
    out = Add()([x, xIn])
    return out


def weightNormedConv3D(outChannels: int, kernelSize: int, padding: str, activation=None):
    return WeightNormalization(Conv3D(outChannels, kernelSize, padding=padding, activation=activation),
                               data_init=False)


def weightNormedConv2D(outChannels: int, kernelSize: int, padding: str, activation=None):
    return WeightNormalization(Conv2D(outChannels, kernelSize, padding=padding, activation=activation),
                               data_init=False)
