In [41]:
import os, sys
sys.path.append(os.getcwd())
import onnx
import cv2
import torch
import torch.onnx 
import os
from importlib import reload  # Python 3.4+
import onnxruntime as ort
import numpy as np

In [42]:
def psnr(x,y):
    bands = x.shape[2]
    x = np.reshape(x, [-1, bands])
    y = np.reshape(y, [-1, bands])
    msr = np.mean((x-y)**2, 0)
    maxval = np.max(y, 0)**2
    return np.mean(10*np.log10(maxval/msr))

def sam(x, y):
    num = np.sum(np.multiply(x, y), 2)
    den = np.sqrt(np.multiply(np.sum(x**2, 2), np.sum(y**2, 2)))
    sam = np.sum(np.degrees(np.arccos(num / den))) / (x.shape[1]*x.shape[0])
    return sam

def rmse(x, y):
    aux = (np.sum((y-x)**2, (0,1))) / (x.shape[0]*x.shape[1])
    r = np.sqrt(np.mean(aux))
    return r
def awgn(x, snr):
    x = torch.from_numpy(x).float()  
    snr = 10 ** (snr / 10.0)
    xpower = torch.sum(x ** 2) / x.numel()
    npower = torch.sqrt(xpower / snr)
    noisy_x = x + torch.randn(x.shape) * npower 
    return noisy_x.cpu().detach().numpy()

In [43]:
##########################
#import encoder (SNR training strategy)
##########################
encoder = ort.InferenceSession('rtcs_ENCODER_SNR.onnx')
##########################
# input shape (full HSI) #
##########################
input_data = np.load('data/masked_CM/f090819t01p00r07rdn_b_ort_img_256.npy').astype(np.float32)
#print(input_data.shape)
##############################################################################
#Percepted Stripe Input 
##############################################################################
input_data = input_data[0:128,0:4,:]
input_data = np.transpose(input_data, (2,0,1))
input_data = input_data[np.newaxis,:,:,:]
#print(input_data.shape)
##############################################################################
input_name = encoder.get_inputs()[0].name
output_name = encoder.get_outputs()[0].name
result = encoder.run([output_name], {input_name: input_data})
encoded_result = np.array(result[0])
encoded_result_shape = encoded_result.shape
print("Encoded_shape:", encoded_result_shape)


Encoded_shape: (1, 27, 32, 1)


In [48]:
##########################
# import decoder (SNR training strategy)
##########################
decoder = ort.InferenceSession('rtcs_DECODER_SNR.onnx')
#input_data = np.random.randn(1, 27, 32, 1).astype(np.float32)  
input_name = decoder.get_inputs()[0].name
output_name = decoder.get_outputs()[0].name

result = decoder.run([output_name], {input_name: encoded_result})
decoded_result = np.array(result[0])
original_shape = decoded_result.shape
print("Decoded_shape:", original_shape)


Decoded_shape: (1, 172, 128, 4)


In [52]:
##########################
# Encoder-Decoder (Full model) (SNR training strategy)
##########################
model = ort.InferenceSession('rtcs_ENCODER_DECODER_SNR.onnx')
input_data = np.load('data/masked_CM/f090819t01p00r07rdn_b_ort_img_256.npy').astype(np.float32)
#input_data = np.random.randn(1, 172, 128, 4).astype(np.float32)  
input_data = input_data[0:128,0:4,:]
input_data = np.transpose(input_data, (2,0,1))
input_data = input_data[np.newaxis,:,:,:]
input_name = model.get_inputs()[0].name
output_name = model.get_outputs()[0].name
result = model.run([output_name], {input_name: input_data})

np_result = np.array(result[0])
original_shape = np_result.shape
print("Encoded_Decoded_output_shape:", original_shape)
print('SAM:',sam(input_data, np_result), 'RMSE:',(rmse(input_data, np_result)), 'PSNR:',psnr(input_data, np_result))

Encoded_Decoded_output_shape: (1, 172, 128, 4)
SAM: 1.9902306490166242 RMSE: 15.967152 PSNR: 47.819965


In [49]:
##########################
#import encoder (Mask training strategy)
##########################
encoder = ort.InferenceSession('rtcs_ENCODER_RM.onnx')
##########################
# input shape (full HSI) #
##########################
input_data = np.load('data/masked_CM/f090819t01p00r07rdn_b_ort_img_256.npy').astype(np.float32)
#print(input_data.shape)
##############################################################################
#Percepted Stripe Input 
##############################################################################
input_data = input_data[0:128,0:4,:]
input_data = np.transpose(input_data, (2,0,1))
input_data = input_data[np.newaxis,:,:,:]
#print(input_data.shape) 
##############################################################################
input_name = encoder.get_inputs()[0].name
output_name = encoder.get_outputs()[0].name
result = encoder.run([output_name], {input_name: input_data})
encoded_result = np.array(result[0])
encoded_result_shape = encoded_result.shape
print("Encoded_shape:", encoded_result_shape)


Encoded_shape: (1, 27, 32, 1)


In [50]:
##########################
# import decoder (Mask training strategy)
##########################
decoder = ort.InferenceSession('rtcs_DECODER_RM.onnx')
#input_data = np.random.randn(1, 27, 32, 1).astype(np.float32)  
input_name = decoder.get_inputs()[0].name
output_name = decoder.get_outputs()[0].name
result = decoder.run([output_name], {input_name: encoded_result})

decoded_result = np.array(result[0])
original_shape = decoded_result.shape
print("Decoded_shape:", original_shape)

Decoded_shape: (1, 172, 128, 4)


In [None]:
##########################
# Encoder-Decoder (Mask training strategy)
##########################
model = ort.InferenceSession('rtcs_ENCODER_DECODER_RM.onnx')
input_data = np.load('data/masked_CM/f090819t01p00r07rdn_b_ort_img_256.npy').astype(np.float32)
#input_data = np.random.randn(1, 172, 128, 4).astype(np.float32)  
input_data = input_data[0:128,0:4,:]
input_data = np.transpose(input_data, (2,0,1))
input_data = input_data[np.newaxis,:,:,:]
input_name = model.get_inputs()[0].name
output_name = model.get_outputs()[0].name
result = model.run([output_name], {input_name: input_data})

np_result = np.array(result[0])
original_shape = np_result.shape
print("Encoded_Decoded_output_shape:", original_shape)
print('SAM:',sam(input_data, np_result), 'RMSE:',(rmse(input_data, np_result)), 'PSNR:',psnr(input_data, np_result))

Encoded_Decoded_output_shape: (1, 172, 128, 4)
SAM: 2.2291919796965844 RMSE: 18.649158 PSNR: 46.45769
