In [3]:
from abc import ABC, abstractmethod
import numpy as np
import math
from scipy.sparse import triu
from scipy.sparse import coo_matrix
from skimage.io import imread
from skimage.transform import resize
import cv2
import matplotlib.pyplot as plt

In [4]:
class PYToolsSubset:
    @staticmethod
    # Reshape H x W x C matrix into HW x C
    def VectorizeMatrix(mat):
        vec = mat.reshape((mat.shape[0]*mat.shape[1], mat.shape[2]))
        return vec
        
    @staticmethod
    # Reshape HW x C matrix into H x W x C
    def UnvectorizeMatrix(vec, rows, cols):
        mat = vec.reshape((rows, cols, vec.shape[1]))
        return mat
        
    @staticmethod
    # Fast replacement of SUB2IND (Linear index from multiple subscripts)
    # in dimension 2; no check is performed.
    def Sub2ind_fast2(siz, i, j):
        ind = i + (j-1)*siz[0]
        return ind
        
    @staticmethod
    # Fast replacement of SUB2IND (Linear index from multiple subscripts)
    # in dimension 3; no check is performed.
    def Sub2ind_fast3(siz, i, j, k):
        ind = i + (j-1)*siz[0] + (k-1)*siz[0]*siz[1]
        return ind

In [5]:
class LocalTransfer:
    def __init__(self, imgInput, imgMatch, imgTarget):
        assert imgInput is not None, "imgInput must not be empty"

        self.imgSize = imgInput.shape

        # Vectorize all images
        # self.imgInput_3R = np.reshape(imgInput, [-1, np.prod(self.imgSize)])
        # self.imgMatch_3R = np.reshape(imgMatch, [-1, np.prod(self.imgSize)])
        # self.imgTarget_3R = np.reshape(imgTarget, [-1, np.prod(self.imgSize)])
        self.imgInput_3R = PYToolsSubset.VectorizeMatrix(imgInput)
        self.imgMatch_3R = PYToolsSubset.VectorizeMatrix(imgMatch)
        self.imgTarget_3R = PYToolsSubset.VectorizeMatrix(imgTarget)

        self.linearIndicesInWindow_NP = None
        self.sparseLinearIndices_pattern_row = None

    @staticmethod
    def fastApplyTransforms(imgInput, localTransform_array3CR):
        C = localTransform_array3CR.shape[1]
        R = localTransform_array3CR.shape[2]
        assert (imgInput.shape[0] * imgInput.shape[1]) == R

        imgInput_31R = np.reshape(np.transpose(np.reshape(imgInput, (R, 3))), (3, 1, R))
        if C == 4:
            imgInput_C1R = np.vstack((imgInput_31R, np.ones((1, 1, R))))
        else:
            imgInput_C1R = imgInput_31R

        imgOutput_31R = localTransform_array3CR @ imgInput_C1R
        imgOutput = np.transpose(np.reshape(np.transpose(imgOutput_31R), (3, R)), (1, 0)).reshape(imgInput.shape)

        return imgOutput

    @staticmethod
    def ApplyPrecomputedTransform(imgInput, localTransform_array3CR, patchWidth):
        if patchWidth == 1:
            imgOutput = fastApplyTransforms(imgInput, localTransform_array3CR) 
        else:
            localTransferRec = LocalTransfer_affineModel(imgInput, None, None)
            localTransferRec.gatherSquarePatches(patchWidth)
            
            # If cached transform is diagonal or linear, extend it to affine
            if localTransform_array3CR.shape[1] < 4:
                localTransform_array3CR[:,3,:] = 0
            
            # Get the linear indices of the patch centers
            patchCenters_linearIndex_1P = \
                localTransferRec.linearIndicesInWindow_NP[(patchWidth**2)//2,:]
            
            # Find the per-patch transform using the patch centers    
            Ak_perPatch_34P = localTransform_array3CR[:,:,patchCenters_linearIndex_1P]
            
            # Apply the transforms on the input image
            localTransferRec.estimateOutput_givenLocalTransforms(Ak_perPatch_34P)
            
            imgOutput = localTransferRec.imgOutput_3R.T.reshape(localTransferRec.imgSize)
        
        return imgOutput

    @abstractmethod
    def computeGlobalTransform(self):
        pass

    @abstractmethod
    def initializeLocalTransforms(self):
        pass

    @abstractmethod
    def estimateLocalTransforms_givenOutput(self):
        pass

    @abstractmethod
    def buildClosedFormMatrices(self):
        pass
    
    @abstractmethod
    def estimateInvBk_perPatch(self):
        pass

    @abstractmethod
    def estimateOutput_givenLocalTransforms(self):
        pass

    def computeError(self):
        # IMPORTANT: This method assumes that the input and match images
        # are identical. It computes the error between the transformed input
        # (or transformed match) and the target image.

        # Compute the per-pixel squared error
        perPixelError_1R = np.sum((self.imgTarget_3R - self.imgOutput_3R) ** 2, axis=1)

        perPixelError_1R = np.sqrt(perPixelError_1R)

        # Sum the error over all pixels of every patch; store the error at the
        # center pixel.
        N = self.linearIndicesInWindow_NP.shape[0]
        P = self.linearIndicesInWindow_NP.shape[1]
        centerPixel_linearIndex_1P = self.linearIndicesInWindow_NP[math.ceil(N / 2) - 1, :]
        perPixelError_NP = np.reshape(perPixelError_1R[self.linearIndicesInWindow_NP], (N, P))
        perPatchError_1R = np.full_like(perPixelError_1R, np.nan)
        perPatchError_1R[centerPixel_linearIndex_1P] = np.sum(perPixelError_NP, axis=0)

        return perPixelError_1R, perPatchError_1R

    def gatherSquarePatches(self, patchWidth):
        halfK = patchWidth // 2
        imageWidth = self.imgSize[1]
        imageHeight = self.imgSize[0]

        # Get the coordinate of the center pixel of all patches in the images,
        # except near the borders
        centerPixel_x = np.tile(np.arange(1 + halfK, imageWidth - halfK), (imageHeight - 2 * halfK, 1))
        centerPixel_y = np.tile(np.arange(1 + halfK, imageHeight - halfK)[:, np.newaxis], (1, imageWidth - 2 * halfK))
        # centerPixel_linearIndices = np.ravel_multi_index((centerPixel_y, centerPixel_x), (imageHeight, imageWidth))
        centerPixel_linearIndices = PYToolsSubset.Sub2ind_fast2([imageHeight, imageWidth], centerPixel_y, centerPixel_x)

        # Get the linear index offsets of all neighbors in a window
        neighbors_offset_col = np.tile(np.arange(-halfK, halfK + 1), (patchWidth, 1))
        neighbors_offset_row = neighbors_offset_col.T
        neighbors_offset_linearIndices = neighbors_offset_col * imageHeight + neighbors_offset_row

        # Get the linear index of each neighbor in each patch
        N = neighbors_offset_linearIndices.size
        patchesLinearIndices_NP = centerPixel_linearIndices.reshape(-1, 1) + neighbors_offset_linearIndices.ravel()

        # Discard patches that contain at least one NaN pixel
        badPatchIndices = findBadPatchIndices(self.imgInput_3R, patchesLinearIndices_NP)
        if self.imgMatch_3R is not None:
            badPatchIndices = np.concatenate((badPatchIndices, findBadPatchIndices(self.imgMatch_3R, patchesLinearIndices_NP)))
        if self.imgTarget_3R is not None:
            badPatchIndices = np.concatenate((badPatchIndices, findBadPatchIndices(self.imgTarget_3R, patchesLinearIndices_NP)))
        patchesLinearIndices_NP = np.delete(patchesLinearIndices_NP, badPatchIndices, axis=1)

        self.linearIndicesInWindow_NP = patchesLinearIndices_NP
        self.sparseLinearIndices_pattern_row = np.tile(np.arange(N), (N, 1))  # used for computing neighborhoods in the adjacency matrices later


    def findBadPatchIndices(img_3R, patchesLinearIndices_NP):
        N, P = patchesLinearIndices_NP.shape

        pixelValues_3NP = img_3R[:, patchesLinearIndices_NP.reshape(-1)]

        nbNaNValuesPerPatch_1P = np.sum(np.isnan(pixelValues_3NP.reshape(3, N, P)), axis=0)

        return np.nonzero(nbNaNValuesPerPatch_1P > 0)[0]

    def perPixelLocalTransforms(self, Ak_perPatch_3CP):
        N = self.linearIndicesInWindow_NP.shape[0]  # number of pixels per patch
        R = self.imgInput_3R.shape[1]  # number of pixels in image
        C = Ak_perPatch_3CP.shape[1]

        # Create C images, where each pixel shows some part of the local
        # transform estimated for the patch centered on this pixel
        localTransform_array3CR = np.empty((3, C, R))
        localTransform_array3CR[:] = np.nan
        patchCenterLinearIndex_1P = self.linearIndicesInWindow_NP[N // 2, :]
        localTransform_array3CR[:, :, patchCenterLinearIndex_1P] = Ak_perPatch_3CP

        imgMask_HW = None
        if (nargout > 1):  # FIND OUT nargout!!!!
            isTransformNaN_1R = np.sum(np.isnan(localTransform_array3CR.reshape(3*C, R)), axis=0) > 0
            imgMask_HW = np.reshape(~isTransformNaN_1R, self.imgSize)

        return localTransform_array3CR, imgMask_HW

    def perPixelTransformError(self, localTransform_array3CR):
        N = self.linearIndicesInWindow_NP.shape[0]  # number of pixels per patch
        P = self.linearIndicesInWindow_NP.shape[1]  # number of patches
        R = self.imgInput_3R.shape[1]  # total number of pixels
        C = localTransform_array3CR.shape[1]

        # Gather matrix vk(M) and vk(T) for each patch
        if C == 3:
            M_CNP = np.reshape(self.imgMatch_3R[:, self.linearIndicesInWindow_NP.flatten()],
                            (3, *self.linearIndicesInWindow_NP.shape))
        elif C == 4:
            M_CNP = np.reshape(self.imgMatch_4R[:, self.linearIndicesInWindow_NP.flatten()],
                            (4, *self.linearIndicesInWindow_NP.shape))
        T_3NP = np.reshape(self.imgTarget_3R[:, self.linearIndicesInWindow_NP.flatten()],
                        (3, *self.linearIndicesInWindow_NP.shape))

        # Get linear index of the center of each pixel
        centerPixel_linearIndex_1P = self.linearIndicesInWindow_NP[np.ceil(N / 2).astype(int) - 1, :]
        Ak_perPatch_3CP = localTransform_array3CR[:, :, centerPixel_linearIndex_1P]

        # Apply the local transforms on each neighborhood
        transformedM_3NP = np.matmul(Ak_perPatch_3CP, M_CNP)

        # Sum the error over all pixels of every patch; store the error at the
        # center pixel.
        perTransformError_1P = np.sum(np.reshape((T_3NP - transformedM_3NP) ** 2, (-1, P)), axis=0)
        perTransformError_1R = np.full((1, R), np.nan)
        perTransformError_1R[:, centerPixel_linearIndex_1P] = perTransformError_1P

        return perTransformError_1R

    def transfer_closedForm(self):
        # Parameters
        epsilon = 1
        gamma = 0.01

        # Compute global transform, and globally transformed image
        globalTransform_3C = self.computeGlobalTransform()

        # Construct system matrix and right side
        M_sparseRR, u_3R, invBk_perPatch_CCP = self.buildClosedFormMatrices(globalTransform_3C, epsilon, gamma)

        # Make sure matrix M is symmetric
        # (sometimes there are some inaccuracies, and M-M'=1e-14 in some cells)
        # by replacing the lower triangle by the transpose of the upper triangle
        M_sparseRR = triu(M_sparseRR) + triu(M_sparseRR, 1).transpose()

        # Solve linear system
        self.imgOutput_3R = np.linalg.solve(M_sparseRR, u_3R)

        # Output transformed image
        imgOutput = self.imgOutput_3R.transpose().reshape(self.imgSize)

        # Estimate the best-matching local transforms
        Ak_perPatch_3CP = self.estimateLocalTransforms_givenOutput(invBk_perPatch_CCP, globalTransform_3C, epsilon, gamma)

        # Output the per-pixel local transform and corresponding binary mask
        localTransform_array3CR, imgMask_HW = self.perPixelLocalTransforms(Ak_perPatch_3CP)

        return imgOutput, localTransform_array3CR, imgMask_HW

    def transfer_iterative(self, nbIterations):
        # Parameters
        epsilon = 1
        gamma = 0.01

        # Initialize lists that will contain per-iteration results
        savePerIterationResults = (nargout > 3)
        if savePerIterationResults:
            imgOutput_perIteration = []
            localTransform_array3CR_perIteration = []

        # Compute global transform, and globally transformed image
        globalTransform_3C = self.computeGlobalTransform()

        # Precompute Bk^-1 for each patch
        invBk_perPatch_CCP = self.estimateInvBk_perPatch(epsilon, gamma)

        for it in range(1, nbIterations + 1):

            if (it == 1):
                # Initialize local transforms
                Ak_perPatch_3CP = self.initializeLocalTransforms(
                    globalTransform_3C, epsilon, gamma)
            else:
                # Re-estimate the local transforms, given the new output
                Ak_perPatch_3CP = self.estimateLocalTransforms_givenOutput(
                    invBk_perPatch_CCP, globalTransform_3C, epsilon, gamma)

            # Apply transform on each patch to estimate the output image
            self.estimateOutput_givenLocalTransforms(Ak_perPatch_3CP)

            if (savePerIterationResults or (it == nbIterations)):

                # Output transformed image
                imgOutput = self.imgOutput_3R.T.reshape(self.imgSize)

                # Output the per-pixel local transform and corresponding binary mask
                localTransform_array3CR, imgMask_HW = self.perPixelLocalTransforms(
                    Ak_perPatch_3CP)

                # Also save the per-iteration results
                if savePerIterationResults:
                    imgOutput_perIteration.append(imgOutput)
                    localTransform_array3CR_perIteration.append(localTransform_array3CR)

        return (imgOutput, localTransform_array3CR, imgMask_HW, imgOutput_perIteration, localTransform_array3CR_perIteration)


In [6]:
class LocalTransfer_affineModel(LocalTransfer):
    def __init__(self, imgInput, imgMatch, imgTarget):
        super().__init__(imgInput, imgMatch, imgTarget)

        R = self.imgSize[0] * self.imgSize[1]
        self.imgInput_4R = np.vstack((self.imgInput_3R, np.ones((1, R))))
        self.imgMatch_4R = np.vstack((self.imgMatch_3R, np.ones((1, R))))
        self.imgTarget_4R = np.vstack((self.imgTarget_3R, np.ones((1, R))))

    def buildClosedFormMatrices(self, G_34, epsilon, gamma):
        N = self.linearIndicesInWindow_NP.shape[0]  # number of pixels per patch
        P = self.linearIndicesInWindow_NP.shape[1]  # number of patches
        R = self.imgInput_3R.shape[1]  # total number of pixels

        # Gather matrices vk_bar(I), vk_bar(M), vk(T), for each patch
        Iban_4NP = self.imgInput_4R[:, self.linearIndicesInWindow_NP.flatten()].reshape(4, N, P)
        Mar_4NP = self.imgMatch_4R[:, self.linearIndicesInWindow_NP.flatten()].reshape(4, N, P)
        T_3NP = self.imgTarget_3R[:, self.linearIndicesInWindow_NP.flatten()].reshape(3, N, P)

        # Compute inverse matrix Bk at each patch
        invBk_perPatch_44P = self.estimateInvBk_perPatch(epsilon, gamma)
        Bk_44P = np.linalg.inv(invBk_perPatch_44P)

        # We obtain the matrix values for each patch by:
        # sparseMatrix_vals_NN = eye(N) - Ibar' * Bk * Ibar;
        Bk_Iban_4NP = np.matmul(Bk_44P, Iban_4NP)
        Iban_t_Bk_Iban_NNP = np.matmul(Iban_4NP.transpose(0, 2, 1), Bk_Iban_4NP)
        sparseMatrix_vals_NNP = np.eye(N) - Iban_t_Bk_Iban_NNP

        # We obtain the right-side vector elements for each patch by:
        # rightSide_vals_3N = (epsilon * T*Mbar' + gamma*G) * Bk * Ibar;
        T_Mar_t_34P = np.matmul(T_3NP, Mar_4NP.transpose(0, 2, 1))
        rightSide_vals_3NP = np.matmul((epsilon * T_Mar_t_34P + gamma * G_34), Bk_Iban_4NP)

        # Store the sparse matrix triplets
        sparseMatrix_rows_NNP = self.linearIndicesInWindow_NP[self.sparseLinearIndices_pattern_row, :, :]
        sparseMatrix_cols_NNP = self.linearIndicesInWindow_NP[self.sparseLinearIndices_pattern_row[:, np.newaxis], :, :]
        sparseMatrix_rows_NP = np.repeat(sparseMatrix_rows_NNP, N, axis=1)
        sparseMatrix_cols_NP = np.tile(sparseMatrix_cols_NNP, (N, 1, 1))
        sparseMatrix_vals_NP = sparseMatrix_vals_NNP.transpose(2, 0, 1).flatten()

        # Assemble the final sparse matrix
        M_sparseRR = coo_matrix((sparseMatrix_vals_NP, (sparseMatrix_rows_NP.flatten(), sparseMatrix_cols_NP.flatten())),
                                shape=(R, R)).tocsr()

        # Assemble the right side
        # TODO: replace this with an accumarray
        u_3R = np.zeros((3, R))
        for k in range(P):
            u_3R[:, self.linearIndicesInWindow_NP[:, k]] += rightSide_vals_3NP[:, :, k]

        return M_sparseRR, u_3R, invBk_perPatch_44P

    def computeGlobalTransform(self):
        R = self.imgMatch_3R.shape[1]

        nonNaNPixels = np.unique(self.linearIndicesInWindow_NP)
        M_4R = self.imgMatch_4R[:, nonNaNPixels]
        T_3R = self.imgTarget_3R[:, nonNaNPixels]

        G_34 = np.dot(T_3R, np.linalg.pinv(M_4R))

        if (nargout > 1):
            imgTransformed_3R = np.dot(G_34, np.vstack([self.imgInput_3R, np.ones((1, R))]))
            return (G_34, imgTransformed_3R)

        return G_34

    def estimateInvBk_perPatch(self, epsilon, gamma):
        # Gather matrices vk_bar(I), vk_bar(M), vk(T), for each patch
        Ibar_4NP = self.imgInput_4R[:, self.linearIndicesInWindow_NP].reshape((4, -1))
        Mbar_4NP = self.imgMatch_4R[:, self.linearIndicesInWindow_NP].reshape((4, -1))

        # We obtain (Bk)^-1 at each patch by vectorizing:
        # invBk_44 = Ibar*Ibar' + epsilon*(Mbar*Mbar') + gamma*eye(4);
        Ibar_Ibar_t_44P = np.dot(Ibar_4NP, Ibar_4NP.T)
        Mbar_Mbar_t_44P = np.dot(Mbar_4NP, Mbar_4NP.T)
        invBk_perPatch_44P = Ibar_Ibar_t_44P + epsilon * Mbar_Mbar_t_44P + gamma * np.eye(4)

        return invBk_perPatch_44P

    def estimateLocalTransforms_givenOutput(self, invBk_perPatch_44P, G_34, epsilon, gamma):
        P = self.linearIndicesInWindow_NP.shape[1]  # number of patches

        # Gather matrices vk_bar(I), vk_bar(M), vk(T), vk(O), for each patch
        Ibar_4NP = self.imgInput_4R[:, self.linearIndicesInWindow_NP].reshape((4, -1))
        Mbar_4NP = self.imgMatch_4R[:, self.linearIndicesInWindow_NP].reshape((4, -1))
        T_3NP = self.imgTarget_3R[:, self.linearIndicesInWindow_NP].reshape((3, -1))
        O_3NP = self.imgOutput_3R[:, self.linearIndicesInWindow_NP].reshape((3, -1))

        # Compute inverse matrix Bk at each patch
        Bk_44P = np.zeros_like(invBk_perPatch_44P)
        for k in range(P):
            Bk_44P[:, :, k] = np.linalg.inv(invBk_perPatch_44P[:, :, k])

        # We obtain Ak at each patch by vectorizing:
        # Ak_34 = (O*Ibar' + epsilon*T*Mbar' + gamma*G) * Bk;
        O_Ibar_t_34P = np.dot(O_3NP, Ibar_4NP.T)
        T_Mbar_t_34P = np.dot(T_3NP, Mbar_4NP.T)
        leftSide_34P = gamma * G_34 + O_Ibar_t_34P + epsilon * T_Mbar_t_34P
        Ak_perPatch_34P = np.dot(leftSide_34P, Bk_44P)

        return Ak_perPatch_34P

    def initializeLocalTransforms(self, G_34, epsilon, gamma):
        P = self.linearIndicesInWindow_NP.shape[1]

        # Gather matrices vk_bar(M), vk(T), for each patch
        Mbar_4NP = self.imgMatch_4R[:, self.linearIndicesInWindow_NP].reshape((4, -1))
        T_3NP = self.imgTarget_3R[:, self.linearIndicesInWindow_NP].reshape((3, -1))

        # We obtain Ak at each patch by vectorizing:
        # Ak_34 = (epsilon * (T*Mbar') + gamma*G) * ...
        #   inv(epsilon * (Mbar*Mbar') + gamma*eye(4));
        Mbar_Mbar_t_44P = np.dot(Mbar_4NP, Mbar_4NP.T)
        T_Mbar_t_34P = np.dot(T_3NP, Mbar_4NP.T)
        leftSide_34P = gamma * G_34 + epsilon * T_Mbar_t_34P
        invRightSide_44P = gamma * np.eye(4) + epsilon * Mbar_Mbar_t_44P

        # Inverse the right side for each patch
        rightSide_44P = np.zeros_like(invRightSide_44P)
        for k in range(P):
            rightSide_44P[:, :, k] = np.linalg.inv(invRightSide_44P[:, :, k])

        Ak_perPatch_34P = np.dot(leftSide_34P, rightSide_44P)

        return Ak_perPatch_34P



In [7]:
# Set method parameters
scale_factor = 0.5
use_model = 'affine'
use_closed_form = 1
nb_iterations = 3
patch_width = 5

# Load example input data and create synthetic match/target images
imgA = resize(plt.imread('imageA.jpg'), (0, 0), scale_factor, anti_aliasing=True)
imgB = resize(plt.imread('imageB.jpg'), (0, 0), scale_factor, anti_aliasing=True)

imgA = np.clip(imgA, 0, 1)
imgB = np.clip(imgB, 0, 1)

imgInput = imgA
imgMatch = imgA

imgTarget = np.multiply(imgA, np.reshape([0.9, 0.8, 0.7], (1, 1, 3)))
imgTarget[:, :imgTarget.shape[1]//2, :] = -0.3 + np.multiply(imgTarget[:, :imgTarget.shape[1]//2, :], np.reshape([1.5, 1.3, 1], (1, 1, 3)))

# Apply the color transfer
if use_model == 'affine':
    localTransfer = LocalTransfer_affineModel(imgInput, imgMatch, imgTarget)
elif use_model == 'linear':
    localTransfer = LocalTransfer_linearModel(imgInput, imgMatch, imgTarget)
elif use_model == 'diagonal':
    localTransfer = LocalTransfer_diagonalModel(imgInput, imgMatch, imgTarget)

# Gather patches
localTransfer.gatherSquarePatches(patch_width)

# Apply transfer
if use_closed_form:
    imgOutput, localTransform_array3CR, imgMask_HW = localTransfer.transfer_closedForm()
else:
    imgOutput, localTransform_array3CR, imgMask_HW = localTransfer.transfer_iterative(nb_iterations)

# Compute scaled error image
imgDiff = 10 * np.abs(imgTarget - imgOutput)

# Display the results
fig, axs = plt.subplots(2, 3, figsize=(10, 8))
fig.suptitle('Input and results')
axs[0, 0].imshow(imgMatch)
axs[0, 0].set_title('Match image')
axs[0, 1].imshow(imgInput)
axs[0, 1].set_title('Input image')
axs[0, 2].imshow(imgTarget)
axs[0, 2].set_title('Target image')
axs[1, 0].imshow(imgOutput)
axs[1, 0].set_title('Output image with local transforms')
axs[1, 1].imshow(imgDiff)
axs[1, 1].set_title('Error')
axs[1, 2].imshow(imgMask_HW)
axs[1, 2].set_title('Mask')
plt.show()

  factors = np.divide(input_shape, output_shape)


OverflowError: cannot convert float infinity to integer