In [44]:
'''Confusion matrix'''

import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from train import data_prep

# Train/validation split provided by the project's data_prep().
x_train, y_train, x_valid, y_valid = data_prep()

# Class indices 0-8 (8 being 'Null'); one-hot versions of both target sets.
n_classes = 9
y_train_categ, y_valid_categ = (
    tf.keras.utils.to_categorical(targets, num_classes=n_classes)
    for targets in (y_train, y_valid)
)

# Score the validation inputs with the saved network; the predicted class of each
# sample is the column with the highest output.
model_path = './pretrained.hdf5'
model = tf.keras.models.load_model(model_path)
y_pred_categ = model.predict(x_valid)
y_pred = y_pred_categ.argmax(axis=1)


# Class index -> display name (same table as the visualisation cell).
# Index 8 ('Null') is deliberately left out of the matrix via `labels` below.
label_names = {0:'Open Water', 1:'Developed', 2:'Barren Land',
               3:'Forest', 4:'Scrub', 5:'Grassland/Crops', 6:'Pasture',
               7:'Wetland', 8:'Null'
               }
labels = list(label_names.values())[:-1]

# Replace integer targets/predictions by their class names. The lookup table must
# be defined before this point, otherwise a fresh kernel raises NameError.
y_valid = [label_names[y] for y in y_valid]
y_pred = [label_names[y] for y in y_pred]

# Row-normalised confusion matrix restricted to the classes in `labels`; samples whose
# true or predicted name is 'Null' are outside `labels` and therefore not counted.
# `labels` is keyword-only in current scikit-learn, so it must not be passed positionally.
cm = confusion_matrix(y_valid,
                      y_pred,
                      labels=labels,
                      normalize='true')  # normalized by rows

fig, ax = plt.subplots(figsize=(12, 12))
cax = ax.matshow(cm)
ax.set_title('Confusion matrix of the CNN')
fig.colorbar(cax)

# Pin exactly one tick per class before naming them; setting tick labels without
# fixed tick positions (the old `[''] + labels` trick) lets labels drift off their cells.
ticks = np.arange(len(labels))
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')

# Write each row-normalised rate in the centre of its cell.
for (i, j), z in np.ndenumerate(cm):
    ax.text(j, i, '{:0.2f}'.format(z), ha='center', va='center')

# savefig does not create missing directories.
os.makedirs('./figure', exist_ok=True)
fig.savefig('./figure/confusion_matrix.png')


In [45]:
'''True vs Predicted visualization'''

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib.patches as mpatches
import tensorflow as tf
from util import classify

arr = np.load('./L8_NLCD_extracted_dataset.npy')

# Hold out the last n_test samples for visualisation.
n_test = 10
arr_test = arr[-n_test:,:,:,:]

# The last axis is split into channel 8 (arr_nlcd, the reference map later remapped
# through nlcd_dic) and channels 0-7 (arr_l8, the model input passed to classify()).
# arr_nlcd is copied: a bare slice is a view into `arr`, so any in-place remapping of
# arr_nlcd would otherwise silently overwrite the raw codes stored in the full dataset.
arr_nlcd = arr_test[:,:,:,8].copy()
arr_l8 = arr_test[:,:,:,:8]

patch_size = 15
model_path = './pretrained.hdf5'
model = tf.keras.models.load_model(model_path)

arr_cls = classify(arr_l8, model, patch_size)

# Lookup tables from raw NLCD codes to the model's class indices, keyed by scheme name.
# 'simple' collapses the codes into classes 0-8; each line below feeds one class
# (names as in the `labels` table of this cell).
# NOTE(review): in the published NLCD legend 82 is "Cultivated Crops", yet it maps to
# class 6 ('Pasture') while class 5 is named 'Grassland/Crops' -- confirm class names.
nlcd_dic = {
    'simple': {
        0: 8,                           # Null
        11: 0, 12: 0,                   # Open Water
        21: 1, 22: 1, 23: 1, 24: 1,     # Developed
        31: 2,                          # Barren Land
        41: 3, 42: 3, 43: 3,            # Forest
        51: 4, 52: 4,                   # Scrub
        71: 5, 72: 5, 73: 5, 74: 5,     # Grassland/Crops
        81: 6, 82: 6,                   # Pasture
        90: 7, 95: 7,                   # Wetland
    },
}

def map_nlcd(arr3d, mode='simple', mapping=None):
    '''Translate raw NLCD codes into class indices through a lookup table.

    Parameters
    ----------
    arr3d : array_like
        Raw NLCD codes. Callers pass a (N, X, Y) stack, but any shape works.
    mode : str
        Key into the module-level ``nlcd_dic`` selecting the code -> class table.
        Ignored when ``mapping`` is given.
    mapping : dict, optional
        Explicit code -> class table that overrides ``nlcd_dic[mode]``.

    Returns
    -------
    np.ndarray
        A new array with the same shape and dtype as ``arr3d`` holding the mapped
        class indices. ``arr3d`` itself is left untouched, so repeated calls on the
        same input give the same result.

    Raises
    ------
    KeyError
        If ``arr3d`` contains a code that is missing from the table.
    '''
    table = nlcd_dic[mode] if mapping is None else mapping
    arr3d = np.asarray(arr3d)
    # Look each distinct code up once instead of once per pixel; `inverse` tells
    # which distinct code every element holds.
    codes, inverse = np.unique(arr3d, return_inverse=True)
    classes = np.array([table[code] for code in codes.tolist()], dtype=arr3d.dtype)
    # reshape covers both the flat and input-shaped `inverse` layouts numpy has used.
    return classes[inverse].reshape(arr3d.shape)

# Display name of every class index 0-8 (8 = 'Null').
labels = dict(enumerate([
    'Open Water', 'Developed', 'Barren Land', 'Forest', 'Scrub',
    'Grassland/Crops', 'Pasture', 'Wetland', 'Null',
]))

# Colour of every class index: hex strings for the legend patches ...
c_hex = ['#3264aa', '#fa0000', '#8c96aa', '#0a8228', '#8caa14',
         '#3cf014', '#d2f014', '#6ea0be', '#000000']
# ... and the same colours as (R, G, B) tuples of 0-255 ints, decoded from c_hex
# so both palettes always agree.
clist = [tuple(int(code[pos:pos + 2], 16) for pos in (1, 3, 5)) for code in c_hex]

# Translate the raw NLCD codes into class indices using nlcd_dic['simple'] (the default
# mode) and rebind arr_nlcd to the result, so get_arr_rgb can index clist with it.
arr_nlcd = map_nlcd(arr_nlcd)

def get_arr_rgb(arr_nlcd, clist):
    '''Colour a stack of class-index maps.

    Parameters
    ----------
    arr_nlcd : array_like of int
        Class indices, typically of size (N, X, Y); any shape is accepted.
        Every value must be a valid index into ``clist``.
    clist : sequence of (R, G, B) tuples
        Colour of each class index.

    Returns
    -------
    np.ndarray of int, shape ``arr_nlcd.shape + (3,)``
        The RGB triple of each element's class.

    Raises
    ------
    IndexError
        If an index falls outside ``clist``.
    '''
    # One fancy-indexing lookup into a (n_classes, 3) palette replaces the
    # per-pixel Python loop; the result is a fresh array.
    palette = np.asarray(clist, dtype=int)
    return palette[np.asarray(arr_nlcd)]

import os

arr_nlcd_rgb = get_arr_rgb(arr_nlcd, clist)
arr_cls_rgb = get_arr_rgb(arr_cls, clist)

# Plot
# One legend proxy patch per class; the last colour ('Null', black) is left out.
patches = [mpatches.Patch(color=c_hex[i], label=labels[i])
           for i in range(len(c_hex) - 1)]

SITE_IDX = 3  # which of the held-out samples to draw
BORDER = 7    # pixels trimmed from every edge before drawing
# NOTE(review): BORDER equals patch_size // 2 for patch_size = 15; presumably the margin
# classify() cannot label -- confirm, and derive it from patch_size if so.


def _save_site_map(rgb, path):
    '''Draw sample SITE_IDX of an RGB class-map stack with the class legend; save to path.'''
    # A fresh figure per map: drawing both maps on the current axes would stack them
    # (or, outside a notebook, paint over whatever figure happens to be current).
    fig, ax = plt.subplots()
    ax.imshow(rgb[SITE_IDX, BORDER:-BORDER, BORDER:-BORDER, :])
    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    ax.axis('off')
    # The legend is anchored outside the axes; 'tight' keeps it inside the saved PNG.
    fig.savefig(path, bbox_inches='tight')


# savefig does not create missing directories.
os.makedirs('./figure', exist_ok=True)
_save_site_map(arr_nlcd_rgb, './figure/site_true.png')
_save_site_map(arr_cls_rgb, './figure/site_predicted.png')

