In [3]:
from typing import List
import tensorflow as tf

In [94]:
CROP_BORDER = 3
MAX_PIXEL_SHIFT = 2*CROP_BORDER

def shiftCompensatedcPSNR(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor) -> float:
    '''
    The maximum cPSNR of every possible pixel shift between the predicted HR image and the ground truth.
    This is how the ESA has been computing the submissions of the contestants.
    See details at the ff link: https://kelvins.esa.int/proba-v-super-resolution/scoring/
    '''
    N, H, W, C = tf.shape(patchHR)

    cropSizeHeight = H - MAX_PIXEL_SHIFT
    cropSizeWidth = W - MAX_PIXEL_SHIFT
    cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
    cachecPSNR = []

    # Iterate through all possible shift configurations
    for i in tf.range(MAX_PIXEL_SHIFT+1):
        for j in tf.range(MAX_PIXEL_SHIFT+1):
            stackcPSNR(i, j, patchHR, maskHR, cropPrediction, cachecPSNR)
    cachecPSNR = tf.stack(cachecPSNR)
    maxcPSNR = tf.reduce_max(cachecPSNR)
    return maxcPSNR



def shiftCompensatedL2Loss(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor) -> float:
    '''
    The minimum L2 Loss of every possible pixel shift between the predicted HR image and the ground truth.
    This is how the ESA has been computing the submissions of the contestants.
    See details at the ff link: https://kelvins.esa.int/proba-v-super-resolution/scoring/
    '''
    N, H, W, C = tf.shape(patchHR)

    cropSizeHeight = H - MAX_PIXEL_SHIFT
    cropSizeWidth = W - MAX_PIXEL_SHIFT
    cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
    cacheLosses = []

    # Iterate through all possible shift configurations
    for i in tf.range(MAX_PIXEL_SHIFT+1):
        for j in tf.range(MAX_PIXEL_SHIFT+1):
            stackL2Loss(i, j, patchHR, maskHR, cropPrediction, cacheLosses)
    cacheLosses = tf.stack(cacheLosses)
    minLoss = tf.reduce_min(cacheLosses)
    return minLoss


def shiftCompensatedL1Loss(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor) -> float:
    '''
    The minimum L1 Loss of every possible pixel shift between the predicted HR image and the ground truth.
    This is how the ESA has been computing the submissions of the contestants.
    See details at the ff link: https://kelvins.esa.int/proba-v-super-resolution/scoring/
    '''
    N, H, W, C = tf.shape(patchHR)

    cropSizeHeight = H - MAX_PIXEL_SHIFT
    cropSizeWidth = W - MAX_PIXEL_SHIFT
    cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
    cacheLosses = []

    # Iterate through all possible shift configurations
    for i in tf.range(MAX_PIXEL_SHIFT+1):
        for j in tf.range(MAX_PIXEL_SHIFT+1):
            stackL1Loss(i, j, patchHR, maskHR, cropPrediction, cacheLosses)
    cacheLosses = tf.stack(cacheLosses)
    minLoss = tf.reduce_min(cacheLosses)
    return minLoss


def stackL1Loss(i: int, j: int, patchHR: tf.Tensor, maskHR: tf.Tensor, cropPred: tf.Tensor, cache: List[float]):
    N, cropSizeHeight, cropSizeWidth, C = tf.shape(cropPred)
    cropTrueImg = cropImage(patchHR, i, cropSizeHeight, j, cropSizeWidth)
    cropTrueMsk = cropImage(maskHR, i, cropSizeHeight, j, cropSizeWidth)
    cropPredMskd = cropPred * cropTrueMsk
    totalClearPixels = tf.reduce_sum(cropTrueMsk, axis=(1, 2, 3))

    b = computeBiasBrightness(totalClearPixels, cropTrueImg, cropPredMskd)

    correctedCropPred = cropPred + b
    correctedCropPredMskd = correctedCropPred * cropTrueMsk

    L1Loss = computeL1Loss(totalClearPixels, cropTrueImg, correctedCropPredMskd)
    cache.append(L1Loss)



def stackL2Loss(i: int, j: int, patchHR: tf.Tensor, maskHR: tf.Tensor, cropPred: tf.Tensor, cache: List[float]):
    N, cropSizeHeight, cropSizeWidth, C = tf.shape(cropPred)
    cropTrueImg = cropImage(patchHR, i, cropSizeHeight, j, cropSizeWidth)
    cropTrueMsk = cropImage(maskHR, i, cropSizeHeight, j, cropSizeWidth)
    cropPredMskd = cropPred * cropTrueMsk
    totalClearPixels = tf.reduce_sum(cropTrueMsk, axis=(1, 2, 3))

    b = computeBiasBrightness(totalClearPixels, cropTrueImg, cropPredMskd)

    correctedCropPred = cropPred + b
    correctedCropPredMskd = correctedCropPred * cropTrueMsk

    L2Loss = computeL2Loss(totalClearPixels, cropTrueImg, correctedCropPredMskd)
    cache.append(L2Loss)


def stackcPSNR(i: int, j: int, patchHR: tf.Tensor, maskHR: tf.Tensor, cropPred: tf.Tensor, cache: List[float]):
    N, cropSizeHeight, cropSizeWidth, C = tf.shape(cropPred)
    cropTrueImg = cropImage(patchHR, i, cropSizeHeight, j, cropSizeWidth)
    cropTrueMsk = cropImage(maskHR, i, cropSizeHeight, j, cropSizeWidth)
    cropPredMskd = cropPred * cropTrueMsk
    totalClearPixels = tf.reduce_sum(cropTrueMsk, axis=(1, 2, 3))

    b = computeBiasBrightness(totalClearPixels, cropTrueImg, cropPredMskd)

    correctedCropPred = cropPred + b
    correctedCropPredMskd = correctedCropPred * cropTrueMsk

    cPSNR = computecPSNR(totalClearPixels, cropTrueImg, correctedCropPredMskd)
    cache.append(cPSNR)


def computeL1Loss(totalClearPixels, HR, correctedSR):
    loss = (1.0 / totalClearPixels) * tf.reduce_sum(tf.abs(tf.subtract(HR, correctedSR)), axis=(1, 2))
    return loss


def computeL2Loss(totalClearPixels, HR, correctedSR):
    loss = (1.0 / totalClearPixels) * tf.reduce_sum(tf.square(tf.subtract(HR, correctedSR)), axis=(1, 2))
    return loss


def computecPSNR(totalClearPixels, HR, correctedSR):
    loss = (1.0 / totalClearPixels) * tf.reduce_sum(tf.square(tf.subtract(HR, correctedSR)), axis=(1, 2))
    cPSNR = -tf.math.log(loss) / tf.math.log(tf.constant(10, dtype=tf.float32))
    return cPSNR


def computeBiasBrightness(totalClearPixels, HR, SR):
    N, H, W, C = tf.shape(HR)
    b = (1.0 / totalClearPixels) * tf.reduce_sum(tf.subtract(HR, SR), axis=(1, 2, 3))
    b = tf.reshape(b, (N, 1, 1, C))
    return b


def cropImage(imgBatch: tf.Tensor, startIdxH: int, lengthHeight: int,
              startIdxW: int, lengthWidth: int) -> tf.Tensor:
    return tf.cast(imgBatch[:, startIdxH: startIdxH + lengthHeight, startIdxW: startIdxW + lengthWidth, :], tf.float32)


In [5]:
import numpy as np

In [83]:
patchHR = np.random.randint(0, 1400, (10, 96, 96, 1))
maskHR = np.random.randint(0, 2, (10, 96, 96, 1)).astype(np.bool)
predPatchHR = np.random.randint(0, 1400, (10, 96, 96, 1))
patchHR = tf.convert_to_tensor(patchHR, dtype=tf.float32)
maskHR = tf.convert_to_tensor(maskHR, dtype=tf.bool)
predPatchHR = tf.convert_to_tensor(predPatchHR, dtype=tf.float32)

In [92]:


N, H, W, C = tf.shape(patchHR)
tf.print(H, W, C)
cropSizeHeight = H - MAX_PIXEL_SHIFT
cropSizeWidth = W - MAX_PIXEL_SHIFT
cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
cacheLosses = []

for i in range(MAX_PIXEL_SHIFT+1):
    for j in range(MAX_PIXEL_SHIFT+1):

        N, cropSizeHeight, cropSizeWidth, C = tf.shape(cropPrediction)
        cropTrueImg = cropImage(patchHR, i, cropSizeHeight, j, cropSizeWidth)
        cropTrueMsk = cropImage(maskHR, i, cropSizeHeight, j, cropSizeWidth)
        cropPredMskd = cropPrediction * cropTrueMsk
        totalClearPixels = tf.reduce_sum(cropTrueMsk, axis=(1, 2, 3))

        N, H, W, C = tf.shape(cropTrueImg)
        b = (1.0 / totalClearPixels) * tf.reduce_sum(tf.subtract(cropTrueImg, cropPredMskd), axis=(1, 2, 3))
        tf.print(tf.shape(b))
        b = tf.reshape(b, (N, 1, 1, C))
        tf.print(tf.shape(b))



        correctedCropPred = cropPrediction + b
        correctedCropPredMskd = correctedCropPred * cropTrueMsk

        L1Loss = computeL1Loss(totalClearPixels, cropTrueImg, correctedCropPredMskd)
        cacheLosses.append(L1Loss)

96 96 1
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]
[10]
[10 1 1 1]


In [95]:
stackL1Loss(i, j, patchHR, maskHR, cropPrediction, cacheLosses)

In [64]:
H, W, C = tf.shape(patchHR)

cropSizeHeight = H - MAX_PIXEL_SHIFT
cropSizeWidth = W - MAX_PIXEL_SHIFT
cropPrediction = cropImage(predPatchHR, CROP_BORDER, cropSizeHeight, CROP_BORDER, cropSizeWidth)
cacheLosses = []

for i in range(MAX_PIXEL_SHIFT+1):
    for j in range(MAX_PIXEL_SHIFT+1):
        stackL1Loss(i, j, patchHR, maskHR, cropPrediction, cacheLosses)

In [71]:
tf.print(shiftCompensatedL2Loss(patchHR, maskHR, predPatchHR))

1418917.38


In [72]:
tf.print(shiftCompensatedcPSNR(patchHR, maskHR, predPatchHR))

-6.15195656


In [73]:
tf.print(shiftCompensatedL1Loss(patchHR, maskHR, predPatchHR))

1421.98535
