In [20]:
import math
from abc import ABC, abstractmethod

import numpy as np
from scipy.sparse import triu
from scipy.sparse.linalg import spsolve

In [21]:
class PYToolsSubset:
    """Stateless helpers for reshaping image arrays and computing linear indices."""

    @staticmethod
    def VectorizeMatrix(mat):
        """Collapse the first two axes of an H x W x C array into an HW x C array."""
        height, width, channels = mat.shape[0], mat.shape[1], mat.shape[2]
        return mat.reshape((height * width, channels))

    @staticmethod
    def UnvectorizeMatrix(vec, rows, cols):
        """Expand an HW x C array back into a rows x cols x C array."""
        channels = vec.shape[1]
        return vec.reshape((rows, cols, channels))

    @staticmethod
    def Sub2ind_fast2(siz, i, j):
        """Linear index from 2-D subscripts; no bounds check is performed.

        The formula is ``i + (j - 1) * siz[0]``, i.e. the first subscript
        varies fastest and ``j`` is offset by one before being scaled.
        """
        column_offset = (j - 1) * siz[0]
        return i + column_offset

    @staticmethod
    def Sub2ind_fast3(siz, i, j, k):
        """Linear index from 3-D subscripts; no bounds check is performed.

        The formula is ``i + (j - 1) * siz[0] + (k - 1) * siz[0] * siz[1]``.
        """
        column_offset = (j - 1) * siz[0]
        page_offset = (k - 1) * siz[0] * siz[1]
        return i + column_offset + page_offset

In [22]:
class LocalTransfer:
    """Base class for patch-based local color transfer on H x W x 3 images.

    Conventions used throughout this class:

    * Images are stored vectorized as ``3 x R`` arrays (``R = H * W``), with
      pixels enumerated in row-major order, i.e. pixel ``(y, x)`` has linear
      index ``y * W + x``.  This matches ``np.ravel_multi_index`` and the
      ``.T.reshape(imgSize)`` used to rebuild output images.
    * ``linearIndicesInWindow_NP`` is an ``N x P`` integer array: for each of
      the ``P`` patches, the linear indices of its ``N`` pixels.
    * Per-patch transforms are ``3 x C x P`` arrays; per-pixel transforms are
      ``3 x C x R`` arrays (``C`` is 3 for linear, 4 for affine).

    The methods decorated with ``@abstractmethod`` are hooks that concrete
    models must override.  NOTE(review): the class does not inherit from
    ``ABC``, so the decorator is not enforced at instantiation time; the hooks
    raise ``NotImplementedError`` when called without being overridden.
    """

    def __init__(self, imgInput, imgMatch, imgTarget):
        """Store the image size and the vectorized (3 x R) images.

        Args:
            imgInput: H x W x 3 array; must not be None.
            imgMatch: H x W x 3 array, or None.
            imgTarget: H x W x 3 array, or None.
        """
        assert imgInput is not None, "imgInput must not be empty"

        self.imgSize = imgInput.shape

        # Vectorize all images (missing images stay None)
        self.imgInput_3R = self._vectorizeImage(imgInput)
        self.imgMatch_3R = self._vectorizeImage(imgMatch)
        self.imgTarget_3R = self._vectorizeImage(imgTarget)

        self.linearIndicesInWindow_NP = None
        self.sparseLinearIndices_pattern_row = None

    @staticmethod
    def _vectorizeImage(img):
        """Return a 3 x R copy of an H x W x 3 image (row-major pixels), or None."""
        if img is None:
            return None
        # Copy so that later in-place edits never alias the caller's image.
        return np.ascontiguousarray(np.asarray(img).reshape(-1, 3).T)

    @staticmethod
    def fastApplyTransforms(imgInput, localTransform_array3CR):
        """Apply one 3 x C transform per pixel to an H x W x 3 image.

        Args:
            imgInput: H x W x 3 array.
            localTransform_array3CR: 3 x C x R array with R == H * W.  When
                C == 4 the input colors are extended with a constant 1
                (affine transform).

        Returns:
            Array with the same shape as ``imgInput``.
        """
        C = localTransform_array3CR.shape[1]
        R = localTransform_array3CR.shape[2]
        assert (imgInput.shape[0] * imgInput.shape[1]) == R

        imgInput_CR = np.reshape(imgInput, (R, 3)).T
        if C == 4:
            imgInput_CR = np.vstack((imgInput_CR, np.ones((1, R))))

        # out[i, r] = sum_c A[i, c, r] * in[c, r]: one small matrix product
        # per pixel (a plain `@` would batch over the wrong axis).
        imgOutput_3R = np.einsum('icr,cr->ir', localTransform_array3CR, imgInput_CR)

        return imgOutput_3R.T.reshape(imgInput.shape)

    @staticmethod
    def ApplyPrecomputedTransform(imgInput, localTransform_array3CR, patchWidth):
        """Apply previously estimated per-pixel transforms to an image.

        Args:
            imgInput: H x W x 3 array.
            localTransform_array3CR: 3 x C x R per-pixel transforms.
            patchWidth: width of the square patches the transforms were
                estimated on; 1 means a purely per-pixel transform.

        Returns:
            The transformed H x W x 3 image.
        """
        if patchWidth == 1:
            return LocalTransfer.fastApplyTransforms(imgInput, localTransform_array3CR)

        # LocalTransfer_affineModel is not defined in this part of the file;
        # it is expected to be a concrete subclass available at call time.
        localTransferRec = LocalTransfer_affineModel(imgInput, None, None)
        localTransferRec.gatherSquarePatches(patchWidth)

        # If cached transform is diagonal or linear, extend it to affine by
        # zero-padding the missing columns (done on a copy: NumPy arrays do
        # not grow on assignment, and the caller's array is left untouched).
        nbColumns = localTransform_array3CR.shape[1]
        if nbColumns < 4:
            padded_34R = np.zeros(
                (localTransform_array3CR.shape[0], 4, localTransform_array3CR.shape[2]),
                dtype=localTransform_array3CR.dtype)
            padded_34R[:, :nbColumns, :] = localTransform_array3CR
            localTransform_array3CR = padded_34R

        # Get the linear indices of the patch centers
        patchCenters_linearIndex_1P = \
            localTransferRec.linearIndicesInWindow_NP[(patchWidth**2)//2, :]

        # Find the per-patch transform using the patch centers
        Ak_perPatch_34P = localTransform_array3CR[:, :, patchCenters_linearIndex_1P]

        # Apply the transforms on the input image
        localTransferRec.estimateOutput_givenLocalTransforms(Ak_perPatch_34P)

        return localTransferRec.imgOutput_3R.T.reshape(localTransferRec.imgSize)

    @abstractmethod
    def computeGlobalTransform(self):
        """Hook: return the global transform (called with no arguments)."""
        raise NotImplementedError("computeGlobalTransform must be overridden")

    @abstractmethod
    def initializeLocalTransforms(self, globalTransform_3C, epsilon, gamma):
        """Hook: return the initial per-patch transforms (3 x C x P)."""
        raise NotImplementedError("initializeLocalTransforms must be overridden")

    @abstractmethod
    def estimateLocalTransforms_givenOutput(self, invBk_perPatch_CCP,
                                            globalTransform_3C, epsilon, gamma):
        """Hook: return per-patch transforms (3 x C x P) given ``self.imgOutput_3R``."""
        raise NotImplementedError("estimateLocalTransforms_givenOutput must be overridden")

    @abstractmethod
    def buildClosedFormMatrices(self, globalTransform_3C, epsilon, gamma):
        """Hook: return ``(M_sparseRR, u_3R, invBk_perPatch_CCP)``."""
        raise NotImplementedError("buildClosedFormMatrices must be overridden")

    @abstractmethod
    def estimateInvBk_perPatch(self, epsilon, gamma):
        """Hook: return the per-patch inverse matrices used by the iterative solver."""
        raise NotImplementedError("estimateInvBk_perPatch must be overridden")

    @abstractmethod
    def estimateOutput_givenLocalTransforms(self, Ak_perPatch_3CP):
        """Hook: set ``self.imgOutput_3R`` from the given per-patch transforms."""
        raise NotImplementedError("estimateOutput_givenLocalTransforms must be overridden")

    def computeError(self):
        """Compute per-pixel and per-patch error between output and target.

        IMPORTANT: This method assumes that the input and match images are
        identical.  It compares ``self.imgOutput_3R`` with
        ``self.imgTarget_3R``.

        Returns:
            (perPixelError_1R, perPatchError_1R): both of length R.  The
            per-pixel error is the Euclidean color distance; the per-patch
            error is the sum of per-pixel errors over each patch, stored at
            the patch's center pixel and NaN everywhere else.
        """
        # Per-pixel Euclidean error: reduce over the channel axis (axis 0)
        perPixelError_1R = np.sqrt(
            np.sum((self.imgTarget_3R - self.imgOutput_3R) ** 2, axis=0))

        # Sum the error over all pixels of every patch; store the error at the
        # center pixel.
        N = self.linearIndicesInWindow_NP.shape[0]
        P = self.linearIndicesInWindow_NP.shape[1]
        centerPixel_linearIndex_1P = self.linearIndicesInWindow_NP[math.ceil(N / 2) - 1, :]
        perPixelError_NP = np.reshape(perPixelError_1R[self.linearIndicesInWindow_NP], (N, P))
        perPatchError_1R = np.full_like(perPixelError_1R, np.nan)
        perPatchError_1R[centerPixel_linearIndex_1P] = np.sum(perPixelError_NP, axis=0)

        return perPixelError_1R, perPatchError_1R

    def gatherSquarePatches(self, patchWidth):
        """Enumerate all square patches fully inside the image.

        Sets ``self.linearIndicesInWindow_NP`` (N x P linear indices, with
        N == patchWidth ** 2) and ``self.sparseLinearIndices_pattern_row``.
        Patches containing a NaN in any available image are discarded.

        Args:
            patchWidth: odd, positive patch width in pixels.

        Raises:
            ValueError: if ``patchWidth`` is not a positive odd integer (an
                even width has no center pixel).
        """
        if patchWidth < 1 or patchWidth % 2 == 0:
            raise ValueError("patchWidth must be a positive odd integer")

        halfK = patchWidth // 2
        imageHeight = self.imgSize[0]
        imageWidth = self.imgSize[1]

        # Get the coordinate of the center pixel of all patches in the images,
        # except near the borders (0-based: valid centers are halfK .. size-halfK-1)
        centerPixel_y, centerPixel_x = np.meshgrid(
            np.arange(halfK, imageHeight - halfK),
            np.arange(halfK, imageWidth - halfK),
            indexing='ij')
        centerPixel_linearIndices = np.ravel_multi_index(
            (centerPixel_y, centerPixel_x), (imageHeight, imageWidth))

        # Get the linear index offsets of all neighbors in a window.  Pixels
        # are indexed row-major, so moving one row down shifts the index by
        # imageWidth and moving one column right shifts it by 1.
        offsets_1K = np.arange(-halfK, halfK + 1)
        neighbors_offset_col = np.tile(offsets_1K, (patchWidth, 1))
        neighbors_offset_row = neighbors_offset_col.T
        neighbors_offset_linearIndices = neighbors_offset_row * imageWidth + neighbors_offset_col

        # Get the linear index of each neighbor in each patch (N x P)
        N = neighbors_offset_linearIndices.size
        patchesLinearIndices_NP = (neighbors_offset_linearIndices.reshape(-1, 1)
                                   + centerPixel_linearIndices.reshape(1, -1))

        # Discard patches that contain at least one NaN pixel
        badPatchIndices = LocalTransfer.findBadPatchIndices(
            self.imgInput_3R, patchesLinearIndices_NP)
        for img_3R in (self.imgMatch_3R, self.imgTarget_3R):
            if img_3R is not None:
                badPatchIndices = np.concatenate(
                    (badPatchIndices,
                     LocalTransfer.findBadPatchIndices(img_3R, patchesLinearIndices_NP)))
        patchesLinearIndices_NP = np.delete(
            patchesLinearIndices_NP, np.unique(badPatchIndices), axis=1)

        self.linearIndicesInWindow_NP = patchesLinearIndices_NP
        # used for computing neighborhoods in the adjacency matrices later
        self.sparseLinearIndices_pattern_row = np.tile(np.arange(N), (N, 1))

    @staticmethod
    def findBadPatchIndices(img_3R, patchesLinearIndices_NP):
        """Return the indices (along P) of patches containing any NaN value.

        Args:
            img_3R: vectorized image, channels x R.
            patchesLinearIndices_NP: N x P linear pixel indices.

        Returns:
            1-D integer array of patch indices.
        """
        N, P = patchesLinearIndices_NP.shape

        pixelValues_3NP = img_3R[:, patchesLinearIndices_NP.reshape(-1)].reshape(-1, N, P)

        # Reduce over both channels and in-patch positions so that the result
        # has one entry per patch.
        hasNaN_1P = np.isnan(pixelValues_3NP).any(axis=(0, 1))

        return np.nonzero(hasNaN_1P)[0]

    def perPixelLocalTransforms(self, Ak_perPatch_3CP):
        """Scatter per-patch transforms to the pixels at the patch centers.

        Args:
            Ak_perPatch_3CP: 3 x C x P per-patch transforms.

        Returns:
            (localTransform_array3CR, imgMask_HW): the 3 x C x R per-pixel
            transforms (NaN where no patch is centered) and an H x W boolean
            mask that is True where a NaN-free transform is available.
        """
        N = self.linearIndicesInWindow_NP.shape[0]  # number of pixels per patch
        R = self.imgInput_3R.shape[1]  # number of pixels in image
        C = Ak_perPatch_3CP.shape[1]

        # Create C images, where each pixel shows some part of the local
        # transform estimated for the patch centered on this pixel
        localTransform_array3CR = np.full((3, C, R), np.nan)
        patchCenterLinearIndex_1P = self.linearIndicesInWindow_NP[N // 2, :]
        localTransform_array3CR[:, :, patchCenterLinearIndex_1P] = Ak_perPatch_3CP

        # The mask is cheap to build, so it is always returned.
        isTransformNaN_1R = np.isnan(localTransform_array3CR).any(axis=(0, 1))
        imgMask_HW = np.reshape(~isTransformNaN_1R, tuple(self.imgSize[:2]))

        return localTransform_array3CR, imgMask_HW

    def perPixelTransformError(self, localTransform_array3CR):
        """Squared fitting error of each patch's transform, stored at its center.

        Args:
            localTransform_array3CR: 3 x C x R per-pixel transforms, C in {3, 4}.

        Returns:
            1 x R array: for every patch center, the sum over the patch of
            the squared difference between the target and the transformed
            match; NaN elsewhere.

        Raises:
            ValueError: if C is neither 3 nor 4.
        """
        windows_NP = self.linearIndicesInWindow_NP
        N = windows_NP.shape[0]  # number of pixels per patch
        P = windows_NP.shape[1]  # number of patches
        R = self.imgInput_3R.shape[1]  # total number of pixels
        C = localTransform_array3CR.shape[1]

        # Gather matrix vk(M) and vk(T) for each patch
        if C == 3:
            imgMatch_CR = self.imgMatch_3R
        elif C == 4:
            # NOTE(review): imgMatch_4R is not set by this class; presumably a
            # subclass provides it.  Fall back to the match image extended
            # with a row of ones, the same homogeneous extension that
            # fastApplyTransforms uses for C == 4.
            imgMatch_CR = getattr(self, 'imgMatch_4R', None)
            if imgMatch_CR is None:
                imgMatch_CR = np.vstack((self.imgMatch_3R, np.ones((1, R))))
        else:
            raise ValueError("localTransform_array3CR must have 3 or 4 columns")
        flatWindows = windows_NP.reshape(-1)
        M_CNP = np.reshape(imgMatch_CR[:, flatWindows], (C, N, P))
        T_3NP = np.reshape(self.imgTarget_3R[:, flatWindows], (3, N, P))

        # Get linear index of the center pixel of each patch
        centerPixel_linearIndex_1P = windows_NP[math.ceil(N / 2) - 1, :]
        Ak_perPatch_3CP = localTransform_array3CR[:, :, centerPixel_linearIndex_1P]

        # Apply the local transforms on each neighborhood:
        # out[i, n, p] = sum_c A[i, c, p] * M[c, n, p]
        transformedM_3NP = np.einsum('icp,cnp->inp', Ak_perPatch_3CP, M_CNP)

        # Sum the error over all pixels of every patch; store the error at the
        # center pixel.
        perTransformError_1P = np.sum(
            np.reshape((T_3NP - transformedM_3NP) ** 2, (-1, P)), axis=0)
        perTransformError_1R = np.full((1, R), np.nan)
        perTransformError_1R[:, centerPixel_linearIndex_1P] = perTransformError_1P

        return perTransformError_1R

    def transfer_closedForm(self):
        """Solve for the output image in closed form.

        Relies on the ``computeGlobalTransform``, ``buildClosedFormMatrices``
        and ``estimateLocalTransforms_givenOutput`` hooks, and sets
        ``self.imgOutput_3R``.

        Returns:
            (imgOutput, localTransform_array3CR, imgMask_HW)
        """
        # Parameters
        epsilon = 1
        gamma = 0.01

        # Compute global transform, and globally transformed image
        globalTransform_3C = self.computeGlobalTransform()

        # Construct system matrix and right side
        M_sparseRR, u_3R, invBk_perPatch_CCP = self.buildClosedFormMatrices(
            globalTransform_3C, epsilon, gamma)

        # Make sure matrix M is symmetric
        # (sometimes there are some inaccuracies, and M-M'=1e-14 in some cells)
        # by replacing the lower triangle by the transpose of the upper triangle
        M_sparseRR = triu(M_sparseRR) + triu(M_sparseRR, 1).transpose()

        # Solve the sparse linear system M * X = u' for the R x 3 unknown X
        # (np.linalg.solve cannot handle scipy sparse matrices).
        solution_R3 = spsolve(M_sparseRR.tocsc(), u_3R.T)
        self.imgOutput_3R = np.asarray(solution_R3).T

        # Output transformed image
        imgOutput = self.imgOutput_3R.T.reshape(self.imgSize)

        # Estimate the best-matching local transforms
        Ak_perPatch_3CP = self.estimateLocalTransforms_givenOutput(
            invBk_perPatch_CCP, globalTransform_3C, epsilon, gamma)

        # Output the per-pixel local transform and corresponding binary mask
        localTransform_array3CR, imgMask_HW = self.perPixelLocalTransforms(Ak_perPatch_3CP)

        return imgOutput, localTransform_array3CR, imgMask_HW

    def transfer_iterative(self, nbIterations, savePerIterationResults=False):
        """Estimate the output image by alternating optimisation.

        Args:
            nbIterations: number of iterations to run (at least 1).
            savePerIterationResults: when True, also collect the output image
                and per-pixel transforms after every iteration.

        Returns:
            (imgOutput, localTransform_array3CR, imgMask_HW,
             imgOutput_perIteration, localTransform_array3CR_perIteration).
            The last two are lists; they are empty unless
            ``savePerIterationResults`` is True.

        Raises:
            ValueError: if ``nbIterations`` is smaller than 1.
        """
        if nbIterations < 1:
            raise ValueError("nbIterations must be at least 1")

        # Parameters
        epsilon = 1
        gamma = 0.01

        # Lists that will contain per-iteration results (always defined so
        # the return statement is valid in every case)
        imgOutput_perIteration = []
        localTransform_array3CR_perIteration = []

        # Compute global transform, and globally transformed image
        globalTransform_3C = self.computeGlobalTransform()

        # Precompute Bk^-1 for each patch
        invBk_perPatch_CCP = self.estimateInvBk_perPatch(epsilon, gamma)

        for it in range(1, nbIterations + 1):

            if it == 1:
                # Initialize local transforms
                Ak_perPatch_3CP = self.initializeLocalTransforms(
                    globalTransform_3C, epsilon, gamma)
            else:
                # Re-estimate the local transforms, given the new output
                Ak_perPatch_3CP = self.estimateLocalTransforms_givenOutput(
                    invBk_perPatch_CCP, globalTransform_3C, epsilon, gamma)

            # Apply transform on each patch to estimate the output image
            self.estimateOutput_givenLocalTransforms(Ak_perPatch_3CP)

            if savePerIterationResults or (it == nbIterations):

                # Output transformed image
                imgOutput = self.imgOutput_3R.T.reshape(self.imgSize)

                # Output the per-pixel local transform and corresponding binary mask
                localTransform_array3CR, imgMask_HW = self.perPixelLocalTransforms(
                    Ak_perPatch_3CP)

                # Also save the per-iteration results
                if savePerIterationResults:
                    imgOutput_perIteration.append(imgOutput)
                    localTransform_array3CR_perIteration.append(localTransform_array3CR)

        return (imgOutput, localTransform_array3CR, imgMask_HW,
                imgOutput_perIteration, localTransform_array3CR_perIteration)
