In [1]:
import skimage as sk
from skimage import io, util
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import math, sys, os
from residual_model_resdnet import *
from MMNet_TBPTT import *
import glob
from subprocess import call
from tqdm import *

In [2]:
def rescale_to_255f(img):
    dtype = img.dtype
    if dtype == np.uint16:
        img = img / 2**16 * 2**8
    elif dtype == np.uint8:
        img = img.astype(np.float32)
    return img

# Define HyperParameters

In [3]:
args_noise_estimation = True # wheter to estimate noise or not
args_init = True # wheter to initialize the input with bilinear
args_max_iter = 20 # maximum number of iterations
args_use_gpu = True 
block_size = (128, 128)
args_model = 'results_experimentation/msr_bilinear_noisy_iter20_bugfix_2ndrun/' # model path

# Load Model

In [4]:
model_params = torch.load(args_model+'model_best.pth')

model = ResNet_Den(BasicBlock, model_params[2], weightnorm=True)
mmnet = MMNet(model, max_iter=model_params[1])
for param in mmnet.parameters():
    param.requires_grad = False

mmnet.load_state_dict(model_params[0])
if args_use_gpu:
    mmnet = mmnet.cuda()

# Process Images

In [None]:
# Define folder with RAW images
img_folder = 'real_images_new/'
filepaths_img = glob.glob(img_folder+'*')

for img_path in tqdm(filepaths_img):
    try:
        print('Processing ', img_path)
        call(["dcraw","-d","-T","-6","-W",img_path])
        img_path = img_path.split(".")
        img_path[-1] = '.tiff'
        img = io.imread("".join(img_path))
        res = rescale_to_255f(img)
        if res.shape[0] % block_size[0] != 0:
            mod = block_size[0]- res.shape[0] % block_size[0]
            res = np.pad(res, ((0,mod),(0,0)), 'constant')

        if res.shape[1] % block_size[1] != 0:
            mod = block_size[1]- res.shape[1] % block_size[1]
            res = np.pad(res, ((0,0), (0,mod)), 'constant')
        blocks = util.view_as_blocks(res, block_shape=(128, 128))

        def process_patch(patch):
            with torch.no_grad():
                mmnet.eval()
                mosaic = torch.FloatTensor(patch).float()[None]
                r_mask = np.zeros(patch.shape)
                r_mask[0::2, 0::2] = 1

                g_mask = np.zeros(patch.shape)
                g_mask[::2, 1::2] = 1
                g_mask[1::2, ::2] = 1

                b_mask = np.zeros(patch.shape)
                b_mask[1::2, 1::2] = 1
                mask = np.zeros(patch.shape +(3,))
                mask[:, :, 0] = r_mask
                mask[:, :, 1] = g_mask
                mask[:, :, 2] = b_mask
                M = torch.FloatTensor(mask)[None]
                mosaic = mosaic[...,None]*M
                mosaic = mosaic.permute(0,3,1,2)
                M = M.permute(0,3,1,2)

                p = Demosaic(mosaic.float(), M.float())
                if args_use_gpu:
                    p.cuda_()

                xcur = mmnet.forward_all_iter(p, max_iter=args_max_iter, init=args_init, noise_estimation=args_noise_estimation)
                return xcur[0].cpu().data.permute(1,2,0).numpy()

        block_size = blocks.shape[-2:]
        num_blocks = blocks.shape[:2]
        original_size = (blocks.shape[0] * blocks.shape[2], blocks.shape[1] * blocks.shape[3])
        final_img = np.zeros((original_size[0], original_size[1],3), dtype=np.float32)
        for i in range(num_blocks[0]):
            for j in range(num_blocks[1]):
                patch_result = process_patch(blocks[i,j])
                final_img[i*block_size[0]:(i+1)*block_size[0], j*block_size[1]:(j+1)*block_size[1]] = patch_result
        final_img = final_img.astype(np.uint8)
        final_img = final_img[:img.shape[0],:img.shape[1]]
        call(["rm","".join(img_path)])
        img_path[-1] = '.jpeg'
        io.imsave("".join(img_path),final_img)
    except e as Exception:
        print(e)
plt.imshow(img)

  0%|          | 0/57 [00:00<?, ?it/s]

Processing  real_images_new/r3092a7e7t.NEF


  2%|▏         | 1/57 [00:48<45:31, 48.77s/it]

Processing  real_images_new/r3378d41et.NEF


  4%|▎         | 2/57 [01:51<51:09, 55.81s/it]

Processing  real_images_new/r3201c080t.NEF


  5%|▌         | 3/57 [02:54<52:18, 58.13s/it]

Processing  real_images_new/r3647a976t.NEF


  7%|▋         | 4/57 [03:43<49:15, 55.76s/it]

Processing  real_images_new/r3639e3f4t.NEF


  9%|▉         | 5/57 [04:31<47:07, 54.38s/it]

Processing  real_images_new/r3206d2cdt.NEF


 11%|█         | 6/57 [05:35<47:27, 55.84s/it]

Processing  real_images_new/r3910a007t.NEF


 12%|█▏        | 7/57 [06:38<47:25, 56.92s/it]

Processing  real_images_new/r3644d097t.NEF


 14%|█▍        | 8/57 [07:41<47:05, 57.66s/it]

Processing  real_images_new/r3333cbf6t.NEF


 16%|█▌        | 9/57 [08:29<45:18, 56.64s/it]

Processing  real_images_new/r3335cb09t.NEF


 18%|█▊        | 10/57 [09:32<44:52, 57.29s/it]

Processing  real_images_new/r3463f166t.NEF


 19%|█▉        | 11/57 [10:36<44:19, 57.82s/it]

Processing  real_images_new/r3972e5dbt.NEF


 21%|██        | 12/57 [11:39<43:41, 58.25s/it]

Processing  real_images_new/r3724e50ct.NEF


 23%|██▎       | 13/57 [12:41<42:58, 58.61s/it]

Processing  real_images_new/r3602d116t.NEF


 25%|██▍       | 14/57 [13:44<42:13, 58.92s/it]

Processing  real_images_new/r3197ab81t.NEF


 26%|██▋       | 15/57 [14:33<40:45, 58.24s/it]

Processing  real_images_new/r3615b30dt.NEF
