In [2]:
!git clone https://github.com/bemc22/JR2net

Cloning into 'JR2net'...
remote: Enumerating objects: 116, done.[K
remote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects: 100% (12/12), done.[K
remote: Total 116 (delta 17), reused 8 (delta 8), pack-reused 96[K
Receiving objects: 100% (116/116), 84.67 MiB | 34.93 MiB/s, done.
Resolving deltas: 100% (51/51), done.


In [3]:
%cd "JR2net"

/content/JR2net


In [4]:
import os

import tensorflow as tf
import numpy as np
import scipy.io as sio

from jr2net.utils import dd_cassi , coded2DTO3D
from jr2net.metrics import SAM

In [6]:
! ls

codes  dataset.py	     demo_train.ipynb  LICENSE	  weights
data   demo_inference.ipynb  jr2net	       README.md


In [7]:
RGB = [27, 17 , 4]
BATCH_SIZE = 1
split = 0.9
INPUT_SHAPE = (512, 512, 31)
data_path = r"data/test"

In [8]:
H = sio.loadmat('./codes/H_T=0.3.mat')['H'].astype(np.float32)
H = H[None, ..., None]
H = tf.cast(H, dtype=tf.float32)
H = coded2DTO3D(H)
print(H.shape)

(1, 512, 512, 31)


In [9]:
from jr2net.models import JR2net
from jr2net.metrics import psnr

unrolled_model = 'jr2net'
STAGES = 7
TRANSFER = False
factors = [1, 1, 1/2, 1/2, 1/4, 1/8]
PRIOR_FACTOR=2


main_model = JR2net(input_size=(512, 512,31),num_stages=STAGES, training=False, factors=factors, prior_factor=PRIOR_FACTOR)
model = main_model.unrolled


model_params = {
    'loss' : ['mse', 'mse'],
    'optimizer' : tf.keras.optimizers.Adam(learning_rate=2e-5, amsgrad=False),
    'metrics' : [psnr]
}

model.compile(**model_params,run_eagerly=False)
unrolled_weights = f'./weights/jr2net_kaist.h5'
model.load_weights(unrolled_weights)

In [10]:
import time

imgs_names = os.listdir(data_path)
imgs_names = [name for name in imgs_names if name.endswith('.mat')]


for name in imgs_names:

    path = os.path.join(data_path, name)
    x = sio.loadmat(path)['img']
    x = tf.constant(x, dtype=tf.float32)

    y =  dd_cassi([x, H])

    t = time.time()
    x_est = model.predict((y,H))
    duration = time.time() - t
    print("inference time:", round(duration,3))

    psnr = tf.image.psnr(x, x_est, max_val=1)
    ssim = tf.image.ssim(x, x_est, 1)
    sam = SAM(x, x_est)

    psnr = tf.reduce_mean(psnr).numpy()
    ssim = tf.reduce_mean(ssim).numpy()
    sam = tf.reduce_mean(sam).numpy()

    psnr = round(psnr, 3)
    ssim = round(ssim, 3)
    sam = round(sam, 3)

    print(name, "PSNR:", psnr, "SSIM:", ssim, "SAM:", sam)

inference time: 11.54
Image_29.mat PSNR: 42.189 SSIM: 0.985 SAM: 0.125
inference time: 0.361
Image_30.mat PSNR: 41.695 SSIM: 0.987 SAM: 0.058
inference time: 0.382
Image_28.mat PSNR: 40.356 SSIM: 0.979 SAM: 0.132
