In [146]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import scipy
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.layers import Input, Activation, Dense, Conv2D, MaxPool2D, Flatten, UpSampling3D, UpSampling2D, Conv3D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

from HSI2RGB import HSI2RGB

In [147]:
def _loss_function(ground_truth, prediction):
    """Mean squared error taken over the last axis of the tensors.

    For this model's outputs the last axis holds the 31 spectral channels, so
    the result keeps the batch/height/width axes and drops the channel axis.
    """
    residual = ground_truth - prediction
    per_channel_error = tf.square(residual)
    return tf.reduce_mean(per_channel_error, axis=-1)

# RGB patch (64x64x3) -> HSI patch (64x64x31): a 1x1 conv lifts the 3 colour
# channels to 31 spectral channels, then five "same"-padded 3x3 convs refine
# them without changing the 64x64 spatial size.
model = Sequential(
    [
        Input([64, 64, 3]),
        Conv2D(filters=31, kernel_size=(1, 1), activation="relu"),
    ]
    + [
        Conv2D(filters=31, kernel_size=(3, 3), activation="relu", padding="same")
        for _ in range(5)
    ]
)

model.compile(
    optimizer=Adam(learning_rate=0.002),
    loss=_loss_function,
    # NOTE(review): "accuracy" does not measure reconstruction quality for this
    # regression target; it is kept because the checkpoint filename reads val_accuracy.
    metrics=["accuracy"],
)
model.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_16 (Conv2D)          (None, 64, 64, 31)        124       
                                                                 
 conv2d_17 (Conv2D)          (None, 64, 64, 31)        8680      
                                                                 
 conv2d_18 (Conv2D)          (None, 64, 64, 31)        8680      
                                                                 
 conv2d_19 (Conv2D)          (None, 64, 64, 31)        8680      
                                                                 
 conv2d_20 (Conv2D)          (None, 64, 64, 31)        8680      
                                                                 
 conv2d_21 (Conv2D)          (None, 64, 64, 31)        8680      
                                                                 
Total params: 43,524
Trainable params: 43,524
Non-trai

In [148]:
from os import path, listdir
import cv2
import h5py

def load_images(load_dir):
    """Load the paired RGB images and hyperspectral cubes stored under ``load_dir``.

    ``load_dir`` must contain an ``rgb`` and a ``spectral`` sub-directory. The
    file names in each are sorted so that the i-th RGB image lines up with the
    i-th cube: ``os.listdir`` returns entries in arbitrary order, and callers
    (``split_images``) pair the two lists purely by position.

    Args:
        load_dir: dataset root containing ``rgb/`` and ``spectral/``.

    Returns:
        ``(rgb_images, hsi_images, bands)``:
        - RGB images converted from OpenCV's BGR order to RGB and divided by 255;
        - HSI ``cube`` datasets with axes 0 and 2 swapped so they match the
          height x width x channel layout of the RGB images;
        - the ``bands`` dataset of the last cube read, or ``None`` if the
          ``spectral`` directory is empty.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the RGB files.
    """
    rgb_dir = path.join(load_dir, "rgb")
    hsi_dir = path.join(load_dir, "spectral")
    rgb_paths = [path.join(rgb_dir, name) for name in sorted(listdir(rgb_dir))]
    hsi_paths = [path.join(hsi_dir, name) for name in sorted(listdir(hsi_dir))]

    rgb_images = []
    for rgb_path in rgb_paths:
        bgr = cv2.imread(rgb_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read RGB image: {rgb_path}")
        rgb_images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255)

    hsi_images = []
    bands = None  # stays None when there are no spectral files (was an UnboundLocalError)
    for hsi_path in hsi_paths:
        with h5py.File(hsi_path, 'r') as hf:
            # swapping axis to match the height x width x color form of the rgb images
            hsi_images.append(np.array(hf['cube']).swapaxes(0, 2))
            bands = np.array(hf['bands'])

    print(f"Loaded {len(rgb_images)} RGB and {len(hsi_images)} HSI Images")
    return rgb_images, hsi_images, bands

def load_hsi(hsi_path):
    """Read the ``cube`` dataset of one spectral file with axes 0 and 2 swapped.

    The swap yields the same height x width x band layout that ``load_images``
    produces for its cubes.
    """
    with h5py.File(hsi_path, 'r') as spectral_file:
        cube = np.array(spectral_file['cube'])
    return cube.swapaxes(0, 2)

def split_images(rgb_images, hsi_images):
    """Tile every (RGB, HSI) pair into spatially aligned patches.

    Both images of a pair are cut by ``split_image``, so the i-th RGB patch and
    the i-th HSI patch always cover the same pixels. Patches of all pairs are
    concatenated in input order.

    Args:
        rgb_images: sequence of RGB arrays indexed as (height, width, channels).
        hsi_images: sequence of HSI arrays paired by position with ``rgb_images``.
            As with ``zip``, extra items in the longer sequence are ignored.

    Returns:
        ``(split_rgb, split_hsi)``: two equally long lists of aligned patches.

    Raises:
        ValueError: if the two images of a pair differ in height or width, in
            which case their tiles could not line up.
    """
    split_rgb = []
    split_hsi = []
    for index, (rgb, hsi) in enumerate(zip(rgb_images, hsi_images)):
        if rgb.shape[:2] != hsi.shape[:2]:
            raise ValueError(
                f"pair {index}: RGB size {rgb.shape[:2]} does not match HSI size {hsi.shape[:2]}"
            )
        # Reuse the single-image tiler so training patches use exactly the same
        # grid as inference (split_image / merge_image).
        split_rgb.extend(split_image(rgb))
        split_hsi.extend(split_image(hsi))

    return split_rgb, split_hsi

def split_image(image, patch_size=64, stride=48):
    """Cut an image into overlapping square tiles laid out on a regular grid.

    Along each axis the grid has ``floor(extent / stride) + 1`` tiles. All but
    the last start at multiples of a step that is ``stride`` reduced just
    enough to spread the tiles out. The last tile is pinned to the far edge, so
    it may overlap its neighbour more than the others do. Regularly spaced
    neighbours overlap by at least ``patch_size - stride`` pixels (16 with the
    defaults, i.e. 8 px on each side of a 48 px core).

    NOTE(review): for large extents the step shrinks by only 1 px, so the grid
    can leave a gap before the pinned last tile. For example, an extent of
    1727 px with the defaults leaves index 1662 uncovered. This is fine for the
    482x512 training images; check before using much larger inputs.

    Args:
        image: array indexed as ``(height, width, channels)``.
        patch_size: side length of each square tile.
        stride: nominal (maximum) spacing between tile origins; must lie in
            ``(0, patch_size]``.

    Returns:
        List of ``(patch_size, patch_size, channels)`` slices of ``image``, in
        row-major order (top-left to bottom-right). ``merge_image`` consumes
        tiles in this same order.

    Raises:
        ValueError: if ``stride`` is out of range, or if the image is smaller
            than one tile in either dimension (such images would otherwise be
            cut into wrongly-sized tiles).
    """
    height, width, _ = image.shape
    if not 0 < stride <= patch_size:
        raise ValueError(f"stride must be in (0, {patch_size}], got {stride}")
    if height < patch_size or width < patch_size:
        raise ValueError(
            f"image of size {height}x{width} is smaller than one {patch_size}x{patch_size} tile"
        )

    def tile_origins(extent):
        # floor(extent / stride) + 1 tiles; the step is shrunk to space them out,
        # and the last tile always goes to the edge.
        count = extent // stride + 1
        step = stride + (extent - count * stride) // count
        return [i * step for i in range(count - 1)] + [extent - patch_size]

    y_origins = tile_origins(height)
    x_origins = tile_origins(width)
    return [
        image[y0:y0 + patch_size, x0:x0 + patch_size]
        for y0 in y_origins
        for x0 in x_origins
    ]

def merge_image(splits, size, patch_size=64, stride=48):
    """Stitch tiles produced by ``split_image`` back into a single image.

    The tile grid is recomputed from ``size`` in the same way ``split_image``
    computes it, and tiles are written in the same row-major order. Where tiles
    overlap, the later tile overwrites the earlier one; there is no blending.

    Args:
        splits: iterable of tiles in the order ``split_image`` returned them;
            each must be assignable to a ``(patch_size, patch_size, depth)`` region.
        size: ``(height, width, depth)`` of the image to rebuild.
        patch_size: tile side length; must match the value used for splitting.
        stride: nominal tile spacing; must match the value used for splitting.

    Returns:
        A float64 array of shape ``(height, width, depth)``.

    Raises:
        ValueError: raised in any of these cases:
            - ``stride`` is not in ``(0, patch_size]``;
            - ``size`` is smaller than one tile;
            - the grid leaves pixels uncovered, which would otherwise leave
              uninitialised values in the result;
            - ``splits`` does not hold exactly one tile per grid cell, usually a
              sign that ``size`` does not match the image that was split.
    """
    height, width, depth = size
    if not 0 < stride <= patch_size:
        raise ValueError(f"stride must be in (0, {patch_size}], got {stride}")
    if height < patch_size or width < patch_size:
        raise ValueError(
            f"image of size {height}x{width} is smaller than one {patch_size}x{patch_size} tile"
        )

    def tile_origins(extent):
        # Same grid as split_image: floor(extent / stride) + 1 tiles, shrunken
        # step, last tile pinned to the far edge.
        count = extent // stride + 1
        step = stride + (extent - count * stride) // count
        return [i * step for i in range(count - 1)] + [extent - patch_size]

    def fully_covered(origins, extent):
        # The grid is a product of the two axes, so 1-D coverage per axis
        # implies every pixel of the 2-D image is written.
        hit = np.zeros(extent, dtype=bool)
        for origin in origins:
            hit[origin:origin + patch_size] = True
        return bool(hit.all())

    y_origins = tile_origins(height)
    x_origins = tile_origins(width)
    if not (fully_covered(y_origins, height) and fully_covered(x_origins, width)):
        raise ValueError(
            f"tile grid for size {size} leaves uncovered pixels; the merged image would contain garbage"
        )
    expected = len(y_origins) * len(x_origins)

    # np.empty is safe: the coverage check above guarantees every pixel is written.
    image = np.empty((height, width, depth))
    missing = object()  # sentinel; tiles are arrays, never this object
    tiles = iter(splits)
    placed = 0
    for y0 in y_origins:
        for x0 in x_origins:
            tile = next(tiles, missing)
            if tile is missing:
                raise ValueError(f"expected {expected} tiles for size {size}, got {placed}")
            image[y0:y0 + patch_size, x0:x0 + patch_size] = tile
            placed += 1

    if next(tiles, missing) is not missing:
        raise ValueError(f"expected {expected} tiles for size {size}, got more")
    return image
    

In [149]:
# Load the paired training images and cut each pair into aligned 64x64 patches.
rgb, hsi, bands = load_images("./data/train")
split_rgb, split_hsi = split_images(rgb, hsi)

# Stack the patch lists into (N, 64, 64, C) arrays that model.fit accepts.
rgb_stack, hsi_stack = np.stack(split_rgb), np.stack(split_hsi)
print(hsi_stack.shape)
print(rgb_stack.shape)

Loaded 50 RGB and 50 HSI Images
(6050, 64, 64, 31)
(6050, 64, 64, 3)


In [151]:
# Each run writes its checkpoints into its own timestamped sub-directory, so
# re-running this cell does not overwrite earlier runs.
# "%m" is the month; the previous "%M" inserted minutes.
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join("checkpoints", ts)
os.makedirs(checkpoint_dir, exist_ok=True)  # also creates ./checkpoints if needed

checkpoint_filepath = os.path.join(
    checkpoint_dir,
    "checkpoint.{epoch:02d}-VLOSS_{val_loss:.4f}-VACC_{val_accuracy:.4f}.hdf5",
)
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    # "accuracy" does not measure reconstruction quality for this regression
    # target, so track the validation loss instead. With save_best_only=False
    # every epoch is saved; monitor/mode only matter if save_best_only is enabled.
    monitor="val_loss",
    mode="min",
    save_best_only=False,
    verbose=1,
)

# Per the Keras docs, validation_split takes the LAST 10% of samples before
# shuffling. Here those are the patches of the final images loaded.
model.fit(
    x=rgb_stack,
    y=hsi_stack,
    validation_split=0.1,
    batch_size=50,
    epochs=10,
    shuffle=True,
    verbose=2,
    callbacks=[model_checkpoint_callback],
)

Epoch 1/10

Epoch 1: saving model to checkpoints/checkpoint.01-0.2959.hdf5
109/109 - 75s - loss: 0.0226 - accuracy: 0.4814 - val_loss: 0.0379 - val_accuracy: 0.2959 - 75s/epoch - 688ms/step
Epoch 2/10

Epoch 2: saving model to checkpoints/checkpoint.02-0.3418.hdf5
109/109 - 79s - loss: 0.0225 - accuracy: 0.4718 - val_loss: 0.0408 - val_accuracy: 0.3418 - 79s/epoch - 726ms/step
Epoch 3/10

Epoch 3: saving model to checkpoints/checkpoint.03-0.3213.hdf5
109/109 - 87s - loss: 0.0224 - accuracy: 0.5083 - val_loss: 0.0402 - val_accuracy: 0.3213 - 87s/epoch - 801ms/step
Epoch 4/10

Epoch 4: saving model to checkpoints/checkpoint.04-0.3079.hdf5
109/109 - 71s - loss: 0.0221 - accuracy: 0.5130 - val_loss: 0.0422 - val_accuracy: 0.3079 - 71s/epoch - 648ms/step
Epoch 5/10

Epoch 6: saving model to checkpoints/checkpoint.06-0.3333.hdf5
109/109 - 80s - loss: 0.0218 - accuracy: 0.5249 - val_loss: 0.0421 - val_accuracy: 0.3333 - 80s/epoch - 738ms/step
Epoch 7/10

Epoch 7: saving model to checkpoints/c

<keras.callbacks.History at 0x7f61a6859420>

In [None]:
# Example scene used to visualise a single reconstruction.
example_bgr = cv2.imread("data/train/rgb/ARAD_1K_0901.jpg")
if example_bgr is None:
    # cv2.imread returns None (instead of raising) when a file is missing or unreadable.
    raise FileNotFoundError("Could not read data/train/rgb/ARAD_1K_0901.jpg")
# Kept as 0-255 values here; process_image's caller divides by 255.
example_image = cv2.cvtColor(example_bgr, cv2.COLOR_BGR2RGB)
example_image_hsi = load_hsi("data/train/spectral/ARAD_1K_0901.mat")

In [None]:
def process_image(image):
    """Reconstruct a hyperspectral image from an RGB image with the trained ``model``.

    The image is tiled with ``split_image``, and all tiles go through a single
    batched ``model.predict`` call instead of one call per tile. The
    predictions are then stitched back together with ``merge_image``.

    Args:
        image: RGB array ``(height, width, channels)``, scaled like the training
            inputs (divided by 255).

    Returns:
        Array ``(height, width, C)`` where ``C`` is the model's output channel
        count (31 for the model defined above).
    """
    height, width, _ = image.shape
    tiles = np.stack(split_image(image))  # (n_tiles, 64, 64, channels)
    predictions = model.predict(tiles, verbose=False)  # (n_tiles, 64, 64, C)
    return merge_image(predictions, (height, width, predictions.shape[-1]))

test_reconstructed_image = process_image(example_image / 255)

In [None]:
# False-colour previews: spectral bands 25/15/5 are shown as R/G/B.
# NOTE(review): presumably ~650/550/450 nm if `bands` spans 400-700 nm in 31
# steps; confirm against `bands`.
def false_color(cube, band_indices=(25, 15, 5)):
    """Return bands ``band_indices`` of an ``(H, W, C)`` cube as an RGB image clipped to [0, 1].

    matplotlib would clip float RGB data to [0, 1] anyway, but it also emits a
    warning; clipping here keeps the displayed result the same without the noise.
    """
    return np.clip(cube[:, :, list(band_indices)], 0, 1)


# One row of three roughly square images -> a wide figure, not a tall one.
figure, axis = plt.subplots(1, 3, figsize=(18, 6))
panels = [
    (false_color(test_reconstructed_image), "Reconstructed HSI (bands 25/15/5)"),
    (false_color(example_image_hsi), "Ground-truth HSI (bands 25/15/5)"),
    (example_image, "RGB input"),
]
for ax, (picture, title) in zip(axis, panels):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis("off")

np.median(test_reconstructed_image)
