In [1]:
import os
os.environ["OMP_PROC_BIND"] = os.environ.get("OMP_PROC_BIND", "true")

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torch.nn as nn
import scipy as sp
from scipy.interpolate import RectBivariateSpline
import torch
import torch.nn.functional as F
import scipy.misc

#from pix_transform.pix_transform import PixTransform
#from baselines.baselines import bicubic
#from utils.utils import downsample,align_images
##from prox_tv import tvgen
#from utils.plots import plot_result

####  load dataset  #############################################################
data_path = "./data/depth_sample_images.npz"

# Open the archive as a context manager so the underlying file handle is
# released even when one of the expected arrays is missing from the file.
with np.load(data_path) as dataset:
    # squeeze() drops every length-1 axis from the stored arrays.
    target_imgs = dataset["target_imgs"].squeeze()
    guide_imgs = dataset["guide_imgs"].squeeze()

In [None]:
##Functions




class PixTransformNet(nn.Module):
    """Per-pixel network with separate spatial and colour branches and a shared head.

    The input tensor is split along the channel axis: the first
    ``channels_in - 2`` channels are fed to ``color_net`` and the remaining
    channels (expected to be 2) are fed to ``spatial_net``. Both branches
    produce 2048 feature maps which are summed and reduced to one output
    channel by ``head_net``.

    Args:
        channels_in: total number of input channels, including the two
            channels consumed by the spatial branch.
        kernel_size: kernel size of the second convolution in each branch and
            of the first convolution in the head; padding is
            ``(kernel_size - 1) // 2``.
        weights_regularizer: optional sequence of three weight-decay values in
            the order (spatial, colour, head). When ``None`` the defaults
            ``0.0001, 0.001, 0.0001`` are used.

    Attributes set here:
        params_with_regularizer: list of three parameter-group dicts
            (``'params'`` and ``'weight_decay'``), one per sub-network.
    """

    def __init__(self, channels_in=5, kernel_size=1, weights_regularizer=None):
        super(PixTransformNet, self).__init__()

        self.channels_in = channels_in

        wide_kernel = (kernel_size, kernel_size)
        wide_padding = (kernel_size - 1) // 2

        self.spatial_net = nn.Sequential(
            nn.Conv2d(2, 32, (1, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 2048, wide_kernel, padding=wide_padding),
        )
        self.color_net = nn.Sequential(
            nn.Conv2d(channels_in - 2, 32, (1, 1), padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 2048, wide_kernel, padding=wide_padding),
        )
        self.head_net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(2048, 32, wide_kernel, padding=wide_padding),
            nn.ReLU(),
            nn.Conv2d(32, 1, (1, 1), padding=0),
        )

        if weights_regularizer is None:
            reg_spatial = 0.0001
            reg_color = 0.001
            reg_head = 0.0001
        else:
            reg_spatial = weights_regularizer[0]
            reg_color = weights_regularizer[1]
            reg_head = weights_regularizer[2]

        # Materialise the parameters as lists: Module.parameters() returns a
        # generator, which would be exhausted after the first iteration and
        # leave every later consumer of these groups with nothing to iterate.
        self.params_with_regularizer = [
            {'params': list(self.spatial_net.parameters()), 'weight_decay': reg_spatial},
            {'params': list(self.color_net.parameters()), 'weight_decay': reg_color},
            {'params': list(self.head_net.parameters()), 'weight_decay': reg_head},
        ]

    def forward(self, input):
        """Run both branches on their channel slices and return the head output."""
        split_at = self.channels_in - 2

        # Colour channels come first, the spatial channels are the trailing ones.
        input_color = input[:, 0:split_at, :, :]
        input_spatial = input[:, split_at:, :, :]

        merged_features = self.spatial_net(input_spatial) + self.color_net(input_color)

        return self.head_net(merged_features)

In [None]:



def bicubic(source_img, scaling_factor):
    """Upsample a 2-D image by an integer factor using a bivariate spline.

    A ``RectBivariateSpline`` (default degrees, i.e. cubic in both directions)
    is fitted on the integer pixel grid of ``source_img`` and evaluated on a
    grid that is ``scaling_factor`` times denser along each axis.

    Args:
        source_img: 2-D array; rows and columns may have different lengths.
        scaling_factor: integer upsampling factor applied to both axes.

    Returns:
        2-D array of shape ``(rows * scaling_factor, cols * scaling_factor)``.
    """
    n_rows, n_cols = source_img.shape[0], source_img.shape[1]

    # Each axis gets its own coordinate vector so non-square inputs work too.
    row_coords = np.arange(n_rows, dtype=float)
    col_coords = np.arange(n_cols, dtype=float)
    spline = RectBivariateSpline(row_coords, col_coords, source_img)

    # Sample positions expressed in source-pixel coordinates: i / factor - 0.5.
    # NOTE(review): the first samples fall at -0.5, i.e. outside the fitted
    # range; how they are evaluated is left to scipy -- confirm this is intended.
    row_coords_up = np.arange(n_rows * scaling_factor, dtype=float) / scaling_factor - 0.5
    col_coords_up = np.arange(n_cols * scaling_factor, dtype=float) / scaling_factor - 0.5

    row_grid, col_grid = np.meshgrid(row_coords_up, col_coords_up, indexing="ij")
    return spline.ev(row_grid, col_grid)

In [None]:


def downsample(image, scaling_factor):
    """Downsample a 2-D numpy image by average pooling.

    Args:
        image: 2-D numpy array.
        scaling_factor: pooling window size handed to ``F.avg_pool2d`` (the
            stride defaults to the same value).

    Returns:
        2-D float64 numpy array holding the pooled image.
    """
    # avg_pool2d expects leading batch and channel axes, so add two of them.
    batched = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).double()
    pooled = F.avg_pool2d(batched, scaling_factor)

    # Strip exactly the two axes added above. A bare squeeze() would also
    # collapse a genuine spatial axis of length one (e.g. a 1 x N result).
    return pooled[0, 0].numpy()


def align_images(img_s,img_t,limits=(-1.,1.),steps=25):
    """Find the sub-pixel shift of ``img_t`` that best matches ``img_s``.

    ``img_t`` is interpolated with a ``RectBivariateSpline`` and re-sampled on
    a ``steps`` x ``steps`` grid of (row, column) offsets spanning ``limits``.
    The offset with the lowest masked mean squared error against ``img_s`` is
    kept. Both returned images are cropped by a border of
    ``ceil(max(abs(limits)))`` pixels; with a zero border nothing is cropped.

    Args:
        img_s: 2-D reference image, same shape as ``img_t``.
        img_t: 2-D image to be shifted.
        limits: (lowest, highest) offset tried along each axis, in pixels.
        steps: number of offsets tried per axis.

    Returns:
        Tuple ``(img_s_cropped, img_t_shifted_cropped)``.
    """
    n_rows, n_cols = img_t.shape[0], img_t.shape[1]

    maximum_limit = int(np.ceil(np.max(np.abs(np.array(limits)))))
    # slice(0, -0) would be empty, so a zero border must select everything.
    crop = slice(maximum_limit, -maximum_limit) if maximum_limit > 0 else slice(None)

    # Only the interior contributes to the error; the border is zeroed out.
    mask = np.zeros_like(img_s)
    mask[crop, crop] = 1.

    row_coords = np.arange(n_rows, dtype=float)
    col_coords = np.arange(n_cols, dtype=float)
    img_t_shifter = RectBivariateSpline(row_coords, col_coords, img_t)

    def _resample(row_shift, col_shift):
        # Evaluate the spline of img_t on the pixel grid moved by the given offsets.
        x_grid, y_grid = np.meshgrid(row_coords + row_shift, col_coords + col_shift, indexing="ij")
        return img_t_shifter.ev(x_grid, y_grid)

    delta = np.linspace(limits[0], limits[1], steps)
    # Start from infinity so the first finite error is always accepted.
    mse_best = np.inf
    x_best = 0
    y_best = 0
    for row_shift in delta:
        for col_shift in delta:
            mse = np.mean((mask * (_resample(row_shift, col_shift) - img_s)) ** 2)
            if mse < mse_best:
                mse_best = mse
                x_best = row_shift
                y_best = col_shift

    img_t_shifted = _resample(x_best, y_best)

    return img_s[crop, crop], img_t_shifted[crop, crop]


In [None]:



def plot_result(guide_img, input_img_nearest, output_img, bicubic_img, label_img=None, data_type="rgb", fig_size=(16, 4)):
    """Show guide, source, prediction and bicubic baseline side by side.

    Args:
        guide_img: guide image, either 2-D or with the channel axis first.
        input_img_nearest: source image shown in the 'Source' panel.
        output_img: predicted image.
        bicubic_img: bicubic baseline image.
        label_img: optional ground-truth image. When given, a 'Target' panel
            is added, the colour range is taken from it and MSE values are
            written into the titles. When ``None`` the colour range comes
            from ``input_img_nearest`` and no MSE is reported.
        data_type: ``"sat"`` reverses the first three guide channels,
            ``"rgb"`` keeps them in order, anything else averages the
            channels into a grey image.
        fig_size: figure size forwarded to ``plt.subplots``.

    Returns:
        Tuple ``(figure, axes_array)``.
    """
    cmap = "Spectral"

    if len(guide_img.shape) > 2:
        # Move the leading channel axis to position 2 so imshow can use it.
        guide_img = np.moveaxis(guide_img, 0, 2)

        if data_type == "sat":
            guide_img = guide_img[:, :, [2, 1, 0]]
        elif data_type == "rgb":
            guide_img = guide_img[:, :, [0, 1, 2]]
        else:
            guide_img = np.mean(guide_img, axis=2)

    # Robust per-channel stretch to [0, 1] based on the 0.05 / 99.95 percentiles.
    guide_min = np.percentile(guide_img, 0.05, axis=(0, 1), keepdims=True)
    guide_max = np.percentile(guide_img, 99.95, axis=(0, 1), keepdims=True)
    guide_range = guide_max - guide_min
    # A constant channel has zero range; divide by one instead of producing NaNs.
    guide_range = np.where(guide_range == 0, 1.0, guide_range)
    guide_img = np.clip((guide_img - guide_min) / guide_range, 0, 1)

    # Build the (title, image) pairs that follow the guide panel.
    if label_img is not None:
        vmin = np.min(label_img)
        vmax = np.max(label_img)
        panels = [
            ('Source', input_img_nearest),
            ('Target', label_img),
            ('Predicted Target (MSE {:.3f})'.format(np.mean((label_img - output_img) ** 2)), output_img),
            ('Bicubic (MSE {:.3f})'.format(np.mean((label_img - bicubic_img) ** 2)), bicubic_img),
        ]
    else:
        vmin = np.min(input_img_nearest)
        vmax = np.max(input_img_nearest)
        # No ground truth is available here, so no error can be reported.
        panels = [
            ('Source', input_img_nearest),
            ('Predicted Target', output_img),
            ('Bicubic', bicubic_img),
        ]

    f, axarr = plt.subplots(1, len(panels) + 1, figsize=fig_size)

    if len(guide_img.shape) > 2:
        axarr[0].imshow(guide_img)
    else:
        axarr[0].imshow(guide_img, cmap="gray")

    titles = ['Guide']
    for ax, (title, img) in zip(axarr[1:], panels):
        ax.imshow(img, vmin=vmin, vmax=vmax, cmap=cmap)
        titles.append(title)

    for i, ax in enumerate(axarr):
        ax.set_axis_off()
        ax.set_title(titles[i])

    plt.tight_layout()
    return f, axarr

In [None]:
####  define parameters  ########################################################
# Run configuration; this same dict is handed to PixTransform in the processing loop.
params = dict(
    # Indices of the images to process; an empty list selects all of them.
    img_idxs=[],

    scaling=8,
    # Turn image into grey-scale.
    greyscale=False,
    channels=-1,

    spatial_features_input=True,
    # Weight regularisers, ordered: spatial, color, head.
    weights_regularizer=[0.0001, 0.001, 0.0001],
    loss='l1',

    optim='adam',
    lr=0.001,

    batch_size=32,
    iteration=(1024 * 32 * 32) // 32,

    logstep=64,

    # Total Generalized Variation in post-processing.
    final_TGV=False,
    # Move image around for evaluation in case guide image and target image
    # are not perfectly aligned.
    align=False,
    # Delta for percentage of bad pixels.
    delta_PBP=1,
)

In [None]:
# An empty 'img_idxs' list means every image in the dataset is processed;
# otherwise the user-supplied indices are used as they are.
idxs = params['img_idxs'] if len(params['img_idxs']) > 0 else np.arange(target_imgs.shape[0])

In [None]:
# NOTE(review): PixTransform and tvgen are neither defined nor imported in the
# visible cells (their imports are commented out in the first cell); they have
# to be available in the kernel before this cell is run.
for n_image,idx in enumerate(idxs):

    print("####### image {}/{} - image idx {} ########".format(n_image+1,len(idxs),idx))

    guide_img = guide_imgs[idx]
    target_img = target_imgs[idx]
    # The low-resolution source is simulated by average-pooling the target.
    source_img = downsample(target_img,params['scaling'])

    bicubic_target_img = bicubic(source_img=source_img, scaling_factor=params['scaling'])

    predicted_target_img = PixTransform(guide_img=guide_img,source_img=source_img,params=params,target_img=target_img)

    if params['final_TGV'] :
        print("applying TGV...")
        predicted_target_img = tvgen(predicted_target_img,[0.1, 0.1],[1, 2],[1, 1])

    if params['align'] :
        print("aligning...")
        rows_before, cols_before = target_img.shape[0], target_img.shape[1]
        target_img,predicted_target_img = align_images(target_img,predicted_target_img)

        # align_images returns both images with a symmetric border removed.
        # Trim the bicubic baseline by the same border, otherwise it no longer
        # matches the cropped target when plot_result computes its MSE.
        row_margin = (rows_before - target_img.shape[0]) // 2
        col_margin = (cols_before - target_img.shape[1]) // 2
        bicubic_target_img = bicubic_target_img[row_margin:rows_before - row_margin,
                                                col_margin:cols_before - col_margin]

    f, ax = plot_result(guide_img,source_img,predicted_target_img,bicubic_target_img,target_img)
    plt.show()
    # Release the figure so figures do not pile up over many images.
    plt.close(f)

    if target_img is not None:
        # compute metrics and plot results
        abs_error = np.abs(predicted_target_img - target_img)
        MSE = np.mean((predicted_target_img - target_img) ** 2)
        MAE = np.mean(abs_error)
        # Fraction of pixels whose absolute error exceeds delta_PBP.
        PBP = np.mean(abs_error > params["delta_PBP"])

        print("MSE: {:.3f}  ---  MAE: {:.3f}  ---  PBP: {:.3f}".format(MSE,MAE,PBP))
        print("\n\n")