In [1]:
#@title Download NTIRE2020 validation dataset (toy example)
# Fetch the NTIRE 2020 spectral validation archive (10 ARAD_HS_*.mat scenes, ~573 MB)
# from Google Drive and extract it into /content/data.
# - Absolute paths keep this cell correct even after the working directory is changed
#   to the cloned repository by a later cell.
# - The download is skipped when the archive is already present, so Restart & Run All
#   does not re-fetch it.
# - `unzip -n` never overwrites existing files, so a re-run cannot block on an
#   interactive "replace?" prompt.
# - gdown is given the full URL (works on all gdown versions); the `--id` flag is deprecated.
![ -f /content/NTIRE2020_Validation_Spectral.zip ] || gdown "https://drive.google.com/uc?id=1IK2emmr39ribjE7zeIGW221u8na8_ar2" -O /content/NTIRE2020_Validation_Spectral.zip
!mkdir -p /content/data
!unzip -n /content/NTIRE2020_Validation_Spectral.zip -d /content/data

Downloading...
From: https://drive.google.com/uc?id=1IK2emmr39ribjE7zeIGW221u8na8_ar2
To: /content/NTIRE2020_Validation_Spectral.zip
100% 573M/573M [00:03<00:00, 159MB/s]
Archive:  NTIRE2020_Validation_Spectral.zip
  inflating: /content/data/ARAD_HS_0451.mat  
  inflating: /content/data/ARAD_HS_0453.mat  
  inflating: /content/data/ARAD_HS_0455.mat  
  inflating: /content/data/ARAD_HS_0456.mat  
  inflating: /content/data/ARAD_HS_0457.mat  
  inflating: /content/data/ARAD_HS_0459.mat  
  inflating: /content/data/ARAD_HS_0462.mat  
  inflating: /content/data/ARAD_HS_0463.mat  
  inflating: /content/data/ARAD_HS_0464.mat  
  inflating: /content/data/ARAD_HS_0465.mat  


In [2]:
#@title Clone JR2net GitHub Repository
# Clone the JR2net code and make it the working directory. The following cell imports
# `dataset` and `jr2net` by module name, which relies on this working directory.
# The clone is skipped when the repository already exists, and `%cd` uses an absolute
# path, so re-running this cell (or Restart & Run All) is safe.
![ -d /content/JR2net ] || git clone https://github.com/bemc22/JR2net /content/JR2net
%cd /content/JR2net

Cloning into 'JR2net'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (60/60), done.[K
remote: Compressing objects: 100% (46/46), done.[K
remote: Total 60 (delta 25), reused 31 (delta 10), pack-reused 0[K
Unpacking objects: 100% (60/60), done.
/content/JR2net


In [3]:
import os
import tensorflow as tf
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt

from dataset import get_csi_pipeline, get_val_csi
from jr2net.utils import dd_cassi , dd_icassi, coded2DTO3D

In [4]:
# --- Data configuration -----------------------------------------------------
RGB = [27, 17, 4]  # RGB channels (band indices into the 31-band cube)
BATCH_SIZE = 3
PATCHES = 9        # passed as `factor`; the final batch size is BATCH_SIZE * PATCHES
split = 0.9        # NOTE(review): not used by any cell visible here
INPUT_SHAPE = (104, 104, 31)   # training patch shape; last axis = 31 spectral bands
DATASET_SIZE = (482, 512, 31)  # full scene shape; last axis = 31 spectral bands
DATASE_SIZE = DATASET_SIZE     # backward-compatible alias for the original misspelled name

FOLDER_PATH = "/content"

data_path = os.path.join(FOLDER_PATH, 'data')
test_path = data_path
# Fail early with a clear message instead of deep inside the tf.data pipeline.
assert os.path.isdir(data_path), f"{data_path} not found - run the dataset download cell first"

DATA__PARAMS = dict(
    input_size=INPUT_SHAPE,
    batch_size=BATCH_SIZE,
    origin_size=DATASET_SIZE,
)

train_ds = get_csi_pipeline(data_path=data_path, buffer_size=10, **DATA__PARAMS, factor=PATCHES)
# NOTE(review): validation reads the same folder as training (test_path == data_path),
# so scores computed on val_ds are not held-out estimates.
val_ds = get_val_csi(test_path)

In [5]:
from jr2net.models import JR2net
from jr2net.metrics import prior_loss, psnr
from jr2net.utils import ClearMemory
from tensorflow.keras.callbacks import ModelCheckpoint, Callback, ReduceLROnPlateau
from tensorflow.keras import backend as k

# --- Model hyperparameters ----------------------------------------------------
STAGES = 7
# NOTE(review): FACTORS has 6 entries while STAGES is 7 - confirm which length JR2net expects.
FACTORS = [1, 1, 1/2, 1/2, 1/4, 1/8]
EPOCHS = 10

# --- Checkpoint / learning-rate schedule settings ----------------------------
# Presumably the `psnr` metric reported for the model output named "recons";
# confirm against model.metrics_names, otherwise the callbacks warn and do nothing.
MONITOR = 'recons_psnr'
LR_FACTOR = 0.9
# NOTE(review): larger than EPOCHS, so ReduceLROnPlateau cannot trigger in this run.
LR_PATIENCE = 100

MODEL_PARAMS = {
    'loss': ['mse', 'mse'],  # one MSE loss per model output (two losses are configured)
    'optimizer': tf.keras.optimizers.Adam(learning_rate=1e-4, amsgrad=False),
    'metrics': [psnr],
}

UNROLLED_PARAMS = {
    'input_size': (None, None, INPUT_SHAPE[-1]),  # any spatial size, fixed band count
    'num_stages': STAGES,
    'factors': FACTORS,
    'training': True,
}

unrolled_weights = "unrolled_weights.h5"
# These callbacks only take effect when passed to model.fit(..., callbacks=callbacks).
callbacks = [
    # Keep only the best weights seen so far (higher MONITOR is better).
    ModelCheckpoint(unrolled_weights,
                    monitor=MONITOR,
                    save_best_only=True,
                    save_weights_only=True,
                    mode='max'),
    ClearMemory(),
    ReduceLROnPlateau(monitor=MONITOR,
                      factor=LR_FACTOR,
                      patience=LR_PATIENCE,
                      verbose=0,
                      mode="max"),
]

main_model = JR2net(**UNROLLED_PARAMS)
model = main_model.unrolled
model.compile(**MODEL_PARAMS)
model.summary()


Model: "JR2net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 spectral_image (InputLayer)    [(None, None, None,  0           []                               
                                 31)]                                                             
                                                                                                  
 coded_aperture (InputLayer)    [(None, None, None,  0           []                               
                                 31)]                                                             
                                                                                                  
 forward_cassi (Lambda)         (None, None, None,   0           ['spectral_image[0][0]',         
                                1)                                'coded_aperture[0][0]',    

In [6]:
# Train with the `callbacks` list built in the model-setup cell. Passing callbacks=None
# silently disables best-weight checkpointing to `unrolled_weights` and the LR schedule.
# No validation set is passed: val_ds is read from the same folder as the training data.
history = model.fit(train_ds, epochs=EPOCHS, validation_data=None, callbacks=callbacks)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7fd7600d6ed0>