In [None]:
import glob
import math
import os
import re
import timeit

import numpy as np
import tensorflow as tf
from PIL import Image

from myutils_tf import bwutils

In [None]:

def get_model(model_name, model_sig):
    """Load a Keras model structure and its most recent checkpoint weights.

    The structure is read from
    ``model_dir/checkpoint/<model_name>_model_structure.h5`` and the weights
    from the last ``*.h5`` file (natural sort order) inside
    ``model_dir/checkpoint/<model_name>_<model_sig>/``. Paths are relative to
    the current working directory.

    Args:
        model_name: Base name of the model, e.g. ``'bwunet'``.
        model_sig: Run signature appended to the checkpoint directory name.

    Returns:
        The ``tf.keras`` model with the selected checkpoint weights loaded.

    Raises:
        FileNotFoundError: if the checkpoint directory holds no ``*.h5`` files.
    """
    base_path = os.path.join('model_dir', 'checkpoint')
    structure_path = os.path.join(base_path, model_name + '_model_structure.h5')
    ckpt_path = os.path.join(base_path, model_name + '_' + model_sig)
    print(structure_path, '\n', ckpt_path)

    # Look for checkpoints before loading the structure, so a missing or empty
    # directory fails fast with a clear message instead of an IndexError.
    ckpts = glob.glob(os.path.join(ckpt_path, '*.h5'))
    if not ckpts:
        raise FileNotFoundError(f'no *.h5 checkpoints found in {ckpt_path!r}')

    # Natural sort: digit runs compare numerically, so 'ep10' ranks after 'ep9'.
    # For zero-padded names this gives the same order as a plain string sort.
    # re.split with a capture group always alternates str/digits, so the key
    # lists compare element-wise without mixing int and str.
    def _natural_key(path):
        return [int(tok) if tok.isdigit() else tok
                for tok in re.split(r'(\d+)', path)]

    ckpt = max(ckpts, key=_natural_key)

    # load model structure, then the selected weights
    model = tf.keras.models.load_model(structure_path)
    model.load_weights(ckpt)

    print(ckpt)
    return model


In [None]:
# Which trained model to load: checkpoint name prefix and run signature.
model_name, model_sig = 'bwunet', 'noise'

In [None]:
# Build the network, load its newest checkpoint, and print the layer table.
model = get_model(model_name=model_name, model_sig=model_sig)
model.summary()

In [None]:
# CFA pattern name handed to bwutils for patternizing the input.
cfa_pattern = "tetra"

In [None]:
# Path to the validation *.npy inputs.
# Resolution order: $MIPI_VAL_DIR if set; otherwise the Google Drive path when
# running from Colab's Drive mount; otherwise the local default below.
PATH_VAL = '/Users/bw/Dataset/MIPI_demosaic_hybridevs/val/input'

cwd = os.getcwd()
print(cwd)
if '/content/drive/MyDrive' in cwd:
    PATH_VAL = '/content/drive/MyDrive/Datasets/MIPI_tetra_hybridenvs/valid/input'
# An explicit environment override beats both hardcoded defaults.
PATH_VAL = os.environ.get('MIPI_VAL_DIR', PATH_VAL)

# get file lists (sorted so output numbering is deterministic)
files = sorted(glob.glob(os.path.join(PATH_VAL, '*.npy')))
if not files:
    # Surface a missing/mistyped dataset path here rather than as a confusing
    # error further down the notebook.
    print(f'WARNING: no .npy files found in {PATH_VAL}')

# Next: utilities for the patternized input

In [None]:
# Tiling geometry used by the inference cell: tiles are patch_size square and
# pad_size pixels on each side of a tile's prediction are discarded.
patch_size = 128
pad_size = 32

# bwutils instance that converts the 1-channel input into the patternized
# form the model expects (keyword order is irrelevant to the call).
utils = bwutils(
    cfa_pattern=cfa_pattern,
    input_type='nonshrink',
    input_max=1,
    patch_size=patch_size,
    crop_size=patch_size,
    use_unprocess=False,
    cache_enable=False,
    loss_type=['rgb'],
    loss_mode='2norm',
    loss_scale=1e4,
)


# Let's run inference

In [None]:

def _patch_starts(padded_len, patch_size, pad_size):
    """Start offsets of patch_size-wide tiles along one axis of a padded image.

    Tile i starts at i * stride (stride = patch_size - 2*pad_size), so the
    inner regions [start + pad_size, start + patch_size - pad_size) are
    contiguous. The last tile is shifted back to end exactly at padded_len.
    Together the inner regions cover the whole valid range
    [pad_size, padded_len - pad_size).
    """
    stride = patch_size - 2 * pad_size
    valid_len = padded_len - 2 * pad_size
    n_tiles = math.ceil(valid_len / stride)
    return [min(i * stride, padded_len - patch_size) for i in range(n_tiles)]


# Tiled inference over every validation file. Each input is padded
# symmetrically, cut into overlapping patch_size tiles, and run through the
# model. Only the inner (un-padded) part of each tile's prediction is stitched
# into the output, which is saved as an 8-bit RGB PNG next to the inputs.
if patch_size <= 2 * pad_size:
    raise ValueError(f'patch_size ({patch_size}) must exceed 2*pad_size ({2 * pad_size})')

start = timeit.default_timer()  # wall-clock timestamp (timeit.timeit() would run a benchmark)
infcnt = 0
for idx, file in enumerate(files):
    infcnt += 1
    arr = np.load(file)        # (0, 1023)
    arr = arr / (2**10 - 1)    # (0, 1)
    arr = arr * 2 - 1          # (-1, 1)

    arr = np.pad(arr, ((pad_size, pad_size), (pad_size, pad_size)), 'symmetric')
    height, width = arr.shape  # padded size
    if min(height, width) < patch_size:
        raise ValueError(f'{file}: padded shape {arr.shape} is smaller than patch_size {patch_size}')

    starts_y = _patch_starts(height, patch_size, pad_size)
    starts_x = _patch_starts(width, patch_size, pad_size)
    tcnt = len(starts_y) * len(starts_x)

    arr_pred = np.zeros(arr.shape + (3,))
    print(idx, file, arr.shape, arr_pred.shape)

    # Inner region of a tile / prediction; explicit end keeps pad_size == 0 valid.
    inner = slice(pad_size, patch_size - pad_size)
    cnt = 0
    for sy in starts_y:
        for sx in starts_x:
            if cnt % 10 == 0:
                print(f'{cnt} / {tcnt}')
            cnt += 1

            arr_patch = arr[sy:sy + patch_size, sx:sx + patch_size]
            arr_patch = utils.get_patternized_1ch_to_3ch_image(arr_patch)

            # prediction (verbose=0: no per-call progress bar)
            pred = model.predict(arr_patch[np.newaxis, ...], verbose=0)

            # post-process: keep only the inner region of the prediction
            arr_pred[sy + pad_size:sy + patch_size - pad_size,
                     sx + pad_size:sx + patch_size - pad_size, :] = pred[0, inner, inner, :]

    # Drop the padding, map (-1, 1) -> (0, 1), and clip so out-of-range model
    # outputs saturate instead of wrapping around in the uint8 cast.
    arr_pred = arr_pred[pad_size:height - pad_size, pad_size:width - pad_size, :]
    arr_pred = np.clip((arr_pred + 1) / 2, 0.0, 1.0)
    img_u8 = (arr_pred * 255).astype(np.uint8)
    img_pred = Image.fromarray(img_u8)

    # NOTE(review): outputs are written into the input directory.
    name = os.path.join(PATH_VAL, '%04d.png' % (idx + 1))
    img_pred.save(name)
    print(name, np.amin(img_u8), np.amax(img_u8))

end = timeit.default_timer()
elapsed = end - start
if infcnt:
    print("elapsed time ", elapsed / infcnt)
else:
    print("no input files were processed")
