In [None]:
import skimage as sk
from skimage import io, util
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import math, sys, os
from residual_model_resdnet import *
from MMNet_TBPTT import *
import glob
from subprocess import call, check_output
from tqdm import *
from subprocess import Popen, PIPE

In [None]:
def rescale_to_255f(img):
    """Convert an 8- or 16-bit image to float32 on a 0-255 scale.

    - uint16 input is divided by 2**8 (so 65535 maps to ~255.996).
    - uint8 input is only cast to float32.
    - Any other dtype is returned unchanged (same object).

    Both integer branches return float32 so downstream code sees one dtype.
    """
    if img.dtype == np.uint16:
        # Dividing by a power of two is exact, so this equals img / 2**16 * 2**8.
        return (img / 2**8).astype(np.float32)
    if img.dtype == np.uint8:
        return img.astype(np.float32)
    return img

def check_pattern(img_path):
    """Return the (row, column) shift that aligns the image's CFA to RGGB.

    Runs ``dcraw -i -v`` on ``img_path`` and parses the "Filter pattern:" line.
    The result ``(rollx, rolly)`` is meant as ``np.roll`` shifts along axis 0
    (rows) and axis 1 (columns).

    Fujifilm files and RGGB sensors need no shift and give ``(0, 0)``.

    Raises:
        NotImplementedError: if no filter pattern is reported, or the pattern
            is not one of RGGB / GBRG / GRBG.
    """
    proc = Popen(['dcraw', '-i', '-v', img_path], stdin=PIPE, stdout=PIPE, stderr=PIPE)
    output, _ = proc.communicate()
    output = output.decode('utf-8', errors='replace')

    # Checked first so it does not depend on a "Filter pattern:" line existing.
    if 'Fujifilm' in output:
        return 0, 0
    if 'Filter pattern:' not in output:
        raise NotImplementedError('no CFA filter pattern reported for {}'.format(img_path))

    pattern_lines = output.split('Filter pattern:', 1)[1].strip().splitlines()
    # Depending on the dcraw version the pattern is printed as e.g.
    # "RGGBRGGB..." or "RG/GB"; drop the slashes and keep the first 2x2 quad.
    cfa = pattern_lines[0].replace('/', '')[:4] if pattern_lines else ''

    shifts = {
        'RGGB': (0, 0),
        'GBRG': (-1, 0),  # move row 1 (RG) up to row 0
        'GRBG': (0, -1),  # move column 1 (R/G) left to column 0
    }
    if cfa not in shifts:
        raise NotImplementedError('unsupported CFA pattern {!r}'.format(cfa))
    return shifts[cfa]


def linrgb_to_srgb(img):
    """Map linear-RGB values to sRGB using the standard piecewise transfer curve.

    Values below 0.0031308 (including negatives) are scaled by 12.92. All
    other values use 1.055 * v**(1/2.4) - 0.055.
    The input must be a float32/float64 array and is not modified; a new array
    of the same dtype is returned.
    See https://en.wikipedia.org/wiki/SRGB
    """
    assert img.dtype in [np.float32, np.float64]
    in_linear_segment = img < 0.0031308
    in_gamma_segment = ~in_linear_segment
    srgb = img.copy()
    srgb[in_gamma_segment] = 1.055 * img[in_gamma_segment] ** (1 / 2.4) - 0.055
    srgb[in_linear_segment] = 12.92 * img[in_linear_segment]
    return srgb


# Define Hyperparameters

In [None]:
args_noise_estimation = True  # whether to estimate the noise level or not
args_init = True  # whether to initialize the input with bilinear interpolation
# Use the GPU only when one is available, so the notebook also runs on CPU-only machines.
args_use_gpu = torch.cuda.is_available()
args_block_size = (512, 512)  # (rows, cols) of the tiles the image is processed in
args_model = 'pretrained_models/bayer_noisy/'  # model path
# Folder with RAW images; override with the RAW_IMG_FOLDER environment variable.
args_img_folder = os.environ.get('RAW_IMG_FOLDER', '/home/datasets/raise/')
args_output_folder = 'output/'  # save results to folder
args_type = '.png'  # image type to save as
# The CFA pattern the network expects is inferred from the model path.
args_pattern = 'xtrans' if 'xtrans' in args_model else 'RGGB'

# Load Model

In [None]:
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint saved on a GPU load on CPU-only machines;
# the model is moved to the GPU below when requested.
model_params = torch.load(os.path.join(args_model, 'model_best.pth'), map_location='cpu')

# The checkpoint is indexed as (state_dict, max_iter, ResNet_Den argument).
# NOTE(review): the meaning of model_params[2] is not visible here -- confirm
# against the training script.
model = ResNet_Den(BasicBlock, model_params[2], weightnorm=True)
mmnet = MMNet(model, max_iter=model_params[1])
# Inference only: freeze all weights.
for param in mmnet.parameters():
    param.requires_grad = False

mmnet.load_state_dict(model_params[0])
if args_use_gpu:
    mmnet = mmnet.cuda()
mmnet.eval()
    

# Process Images

In [None]:
# Pick a reproducible subset of the RAW files and demosaick each one tile by
# tile with MMNet. Save both the linear-RGB and the sRGB result.
num_images = 50  # maximum number of RAW files to process
sample_seed = 0  # seed used to choose which files are processed

os.makedirs(args_output_folder, exist_ok=True)

filepaths_img = sorted(
    path for path in glob.glob(os.path.join(args_img_folder, '*'))
    if os.path.isfile(path)
)
if len(filepaths_img) > num_images:
    rng = np.random.RandomState(sample_seed)
    filepaths_img = sorted(rng.choice(filepaths_img, num_images, replace=False))


def _pad_to_block_multiple(cfa, block_size):
    """Zero-pad a 2-D array at the bottom/right so each side is a multiple of block_size."""
    pad_rows = -cfa.shape[0] % block_size[0]
    pad_cols = -cfa.shape[1] % block_size[1]
    return np.pad(cfa, ((0, pad_rows), (0, pad_cols)), 'constant')


def _process_patch(patch):
    """Demosaick one 2-D CFA tile with MMNet; returns an (H, W, channels) numpy array."""
    with torch.no_grad():
        mosaic = torch.from_numpy(np.ascontiguousarray(patch, dtype=np.float32))[None]
        # Reflect-pad by 8 pixels so the network sees context at tile borders.
        mosaic = F.pad(mosaic[:, None], (8, 8, 8, 8), 'reflect')[:, 0]
        mask = utils.generate_mask(mosaic[0].shape, pattern=args_pattern)
        M = torch.FloatTensor(mask)[None]
        mosaic = (mosaic[..., None] * M).permute(0, 3, 1, 2)
        M = M.permute(0, 3, 1, 2)

        p = Demosaic(mosaic.float(), M.float())
        if args_use_gpu:
            p.cuda_()
        xcur = mmnet.forward_all_iter(p, max_iter=mmnet.max_iter, init=args_init,
                                      noise_estimation=args_noise_estimation)
        # Crop away the reflection padding added above.
        return xcur[0].cpu().data.permute(1, 2, 0).numpy()[8:-8, 8:-8]


def _demosaick(cfa):
    """Demosaick a padded 2-D CFA image tile by tile; returns a float32 (H, W, 3) array."""
    blocks = util.view_as_blocks(cfa, block_shape=args_block_size)
    block_h, block_w = args_block_size
    result = np.zeros(cfa.shape[:2] + (3,), dtype=np.float32)
    for i in range(blocks.shape[0]):
        for j in range(blocks.shape[1]):
            result[i * block_h:(i + 1) * block_h,
                   j * block_w:(j + 1) * block_w] = _process_patch(blocks[i, j])
    return result


mmnet.eval()
for raw_path in tqdm(filepaths_img):
    # dcraw writes its TIFF next to the RAW file, with the extension replaced.
    tiff_path = os.path.splitext(raw_path)[0] + '.tiff'
    try:
        tqdm.write('Processing {}'.format(raw_path))
        if call(["dcraw", "-j", "-d", "-T", "-4", "-w", "+M", raw_path]) != 0:
            raise RuntimeError('dcraw failed on {}'.format(raw_path))
        rollx, rolly = check_pattern(raw_path)
        img = io.imread(tiff_path)

        # Align the CFA to RGGB: rollx shifts rows (axis 0), rolly columns (axis 1).
        img = np.roll(img, (rollx, rolly), axis=(0, 1))
        res = _pad_to_block_multiple(rescale_to_255f(img), args_block_size)
        final_img = _demosaick(res)[:img.shape[0], :img.shape[1]]
        # Undo the alignment shift so the output matches the original geometry.
        final_img = np.roll(final_img, (-rollx, -rolly), axis=(0, 1))

        out_stem = os.path.join(args_output_folder,
                                os.path.splitext(os.path.basename(raw_path))[0])
        # Save the linRGB image. Clip first, otherwise out-of-range values wrap around in uint8.
        io.imsave(out_stem + args_type, np.clip(final_img, 0, 255).astype(np.uint8))
        # Save the sRGB image.
        srgb = linrgb_to_srgb(final_img / 255)
        io.imsave(out_stem + '_srgb' + args_type, srgb.clip(0, 1))
    except Exception as e:
        # Best effort: report which file failed and continue with the next one.
        tqdm.write('Failed to process {}: {!r}'.format(raw_path, e))
    finally:
        # Remove the intermediate TIFF, but never the input file itself.
        if tiff_path != raw_path and os.path.exists(tiff_path):
            os.remove(tiff_path)