In [None]:
import warnings
warnings.filterwarnings('ignore') # disable warnings relateds to versions of tf
import numpy as np
import dianna
import onnx
from onnx_tf.backend import prepare
import matplotlib.pyplot as plt
from pathlib import Path
from dianna import visualization
from keras import utils
import onnx
import onnxruntime
from scipy.special import softmax

def run_model(data, name):
    """Run an ONNX model stored under ``models/`` and return softmax scores.

    Args:
        data: batch fed to the model's first input. In this notebook it is a
            float32 array such as ``imgs[:, None, ...]`` — the exact layout the
            model expects is an assumption, TODO confirm.
        name: file name of the model inside the ``models`` directory.

    Returns:
        ``softmax`` over axis 1 of the model's first output.

    Note: a new onnxruntime session is created on every call.
    """
    session = onnxruntime.InferenceSession(str(Path('models', name)))
    input_key = session.get_inputs()[0].name
    output_key = session.get_outputs()[0].name

    raw_output = session.run([output_key], {input_key: data})[0]
    return softmax(raw_output, axis=1)

DATA_PATH           = Path('data', 'shapes.npz')
MODEL_NAME          = 'geometric_shapes_model.onnx'
MODEL_PATH          = Path('models', MODEL_NAME)
# Load the testing data and the related labels. For .npz archives np.load
# returns an NpzFile that keeps the underlying file open; reading the arrays
# inside a `with` block guarantees the file handle is closed afterwards.
with np.load(DATA_PATH) as shapes_npz:
    X_test          = shapes_npz['X_test'].astype(np.float32).reshape([-1, 1, 64, 64])
    y_test          = shapes_npz['y_test']

# Load the saved onnx model. The onnx-tf backend is used here only to look up
# the name of the model's first output tensor.
onnx_model          = onnx.load(MODEL_PATH)
output_node         = prepare(onnx_model, gen_tensor_dict=True).outputs[0]
# Label for each output index of the model.
class_name          = ['circle', 'triangle']

In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code; only load this file
# from a trusted source.
with Path("data", "test_rotation.pk").open("rb") as pickle_file:
    data = pickle.load(pickle_file)

In [None]:
sample_idx = 99  # arbitrary triangle from the rotation set, for a visual sanity check
plt.imshow(data['triangles'][sample_idx]['image'])

In [None]:
# Gather every triangle image into one array, scale by 1/255 and cast to float32
# (assumes the stored images are 0-255 intensities — TODO confirm).
imgs = np.array([data['triangles'][k]['image']
                 for k in range(data['triangles'].shape[0])])
imgs = (imgs / 255.0).astype(np.float32)

In [None]:
# Score all rotated images in one batch; `np.newaxis` inserts a channel axis
# (assuming imgs is (N, H, W), the batch becomes (N, 1, H, W)).
rotated_onnx = run_model(imgs[:, np.newaxis], MODEL_NAME)
# Will collect one LIME explanation per image.
r_arr = []

In [None]:
# Class index per image, taken from the onnxruntime batch prediction.
rotated_ids            = rotated_onnx.argmax(axis=1)
# Convert the onnx model to the onnx-tf backend once; doing this inside the
# loop re-ran the (slow) conversion for every single image.
tf_model               = prepare(onnx_model)
# Reset here so re-running this cell does not append duplicate explanations.
r_arr                  = []
for i_instance, image in enumerate(imgs):
    # select instance for testing; astype already returns a fresh copy
    test_sample     = image.astype(np.float32).reshape(1, 64, 64)
    # model predictions with added batch axis to test sample
    predictions     = tf_model.run(test_sample[None, ...])[f'{output_node}']
    pred_class      = class_name[np.argmax(predictions)]
    print("The predicted class is:", pred_class, "No.", i_instance)

    class_idx       = rotated_ids[i_instance]
    # LIME explanation for the class chosen by the onnxruntime prediction.
    relevances      = dianna.explain_image(MODEL_PATH, test_sample,
                                           method="LIME", labels=[class_idx],
                                           top_label = 1,
                                           num_features = 10,
                                           num_samples = 1000,
                                           )
    r_arr.append(relevances)


In [None]:
# tested with python3
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def GetRotateMatrixWithCenter(x, y, angle):
    """Return a 3x3 homogeneous matrix that rotates by `angle` radians about (x, y).

    The result is T(x, y) · R(angle) · T(-x, -y): move the pivot to the origin,
    rotate, then move it back.
    Reference: https://math.stackexchange.com/questions/2093314
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)

    def translation(dx, dy):
        # homogeneous 2-D translation by (dx, dy)
        return np.array([[1, 0, dx],
                         [0, 1, dy],
                         [0, 0, 1]])

    rotation = np.array([[cos_a, -sin_a, 0],
                         [sin_a,  cos_a, 0],
                         [0,      0,     1]])

    return translation(x, y).dot(rotation).dot(translation(-x, -y))

def Apply_Matrix_To_Image(matrix_to_apply, image_map):
    """Resample a 2-D image through a homogeneous 3x3 coordinate transform.

    For every output pixel (x, y) the source coordinate is
    ``matrix_to_apply @ [x, y, 1]`` (only its first two components are used).
    Source coordinates inside the image are truncated to integers
    (nearest-lower neighbour). Output pixels whose source falls outside the
    image are filled with ``image_map.max()``.

    Args:
        matrix_to_apply: 3x3 array-like transform, e.g. from
            GetRotateMatrixWithCenter.
        image_map: 2-D array indexed as [x, y].

    Returns:
        New array with the same 2-D shape and dtype as ``image_map``. Keeping the
        input dtype preserves float heatmaps, which the former ``dtype=int``
        truncated.
    """
    image_map = np.asarray(image_map)
    x_max, y_max = image_map.shape[0], image_map.shape[1]
    if image_map.size == 0:
        # nothing to resample (and .max() would raise on an empty array)
        return np.zeros((x_max, y_max), dtype=image_map.dtype)

    m = np.asarray(matrix_to_apply, dtype=float)
    # Fill with this image's own maximum (previously read the global `img`).
    new_image_map = np.full((x_max, y_max), image_map.max(), dtype=image_map.dtype)

    # Source coordinates for every output pixel at once.
    xs, ys = np.meshgrid(np.arange(x_max), np.arange(y_max), indexing='ij')
    src_x = m[0, 0] * xs + m[0, 1] * ys + m[0, 2]
    src_y = m[1, 0] * xs + m[1, 1] * ys + m[1, 2]

    inside = (src_x >= 0) & (src_x <= x_max - 1) & (src_y >= 0) & (src_y <= y_max - 1)
    # Coordinates are non-negative here, so astype(int) truncation == floor.
    new_image_map[inside] = image_map[src_x[inside].astype(int), src_y[inside].astype(int)]
    return new_image_map


# Use the first LIME heatmap as the reference that is rotated synthetically.
# NOTE(review): assumes r_arr[0][0] is a 2-D heatmap — confirm against dianna's output.
img = r_arr[0][0]

image_width = img.shape[0]
image_height = img.shape[1]

rotated_imgs = []
# Bound by r_arr, the list actually indexed below (normally len(r_arr) == len(imgs)).
for ind in range(len(r_arr) - 1):
    # Sample `ind` is treated as rotated by `ind` degrees — TODO confirm against the dataset.
    rotation_angle = np.deg2rad(ind)  # degree to radian

    rotation_matrix = GetRotateMatrixWithCenter(image_width / 2, image_height / 2, rotation_angle)
    rotated = Apply_Matrix_To_Image(rotation_matrix, img)
    rotated_imgs.append(rotated)

    # Clear the figure so images do not pile up on the same axes every iteration.
    plt.clf()
    plt.subplot(1, 2, 1)
    plt.title('Heatmap of the image (clockwise rotated '  +str(ind) +' deg)', fontsize=6)
    plt.imshow(r_arr[ind][0], cmap='gray', vmin=0, vmax=1)

    plt.subplot(1, 2, 2)
    plt.title('Heatmap of the image (clockwise rotated '  +str(ind+1) +' deg)', fontsize=6)
    # Same colour scale as the left panel so the two heatmaps are comparable.
    plt.imshow(r_arr[ind + 1][0], cmap='gray', vmin=0, vmax=1)
    plt.pause(0.001)