In [1]:
#https://medium.com/@pallawi.ds/semantic-segmentation-with-u-net-train-and-test-on-your-custom-data-in-keras-39e4f972ec89
#Keras loss functions
#https://github.com/maxvfischer/keras-image-segmentation-loss-functions
#https://github.com/maxvfischer/keras-image-segmentation-loss-functions/blob/master/losses/multiclass_losses.py#L107
import os
import sys
from pathlib import Path
utilsPath = os.path.join(Path(os.getcwd()).parent.absolute(),'utils')
if utilsPath not in sys.path:
    sys.path.append(utilsPath)
#print(sys.path)
from getTiles import image_gen
from getTiles import read_image
from lossFunctions import iou_coef, DiceLoss, weighted_categorical_crossentropy
import random
import glob
import warnings
import numpy as np
import pandas as pd
import matplotlib
import datetime

from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from keras.models import Model, load_model
from keras.layers import Input, Dense
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.regularizers import l2
from keras.layers.merge import concatenate
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint
from tensorflow.keras.utils import to_categorical
from keras import backend as K
from tensorflow.keras.optimizers import Adam
#from tensorflow.keras.optimizers import SGD

import keras
import tensorflow as tf
import cv2
import rasterio
#import tensorflow.keras.backend as K

from typing import Callable, Union

from sklearn.metrics import confusion_matrix, accuracy_score, multilabel_confusion_matrix, classification_report

import matplotlib.pyplot as plt
from matplotlib import colors
from rasterio.plot import show


In [2]:
#Define training data and test data

seed = 42
# Call the seeding functions. (Assigning `random.seed = seed` replaces the functions
# with an int, which seeds nothing and makes later seed calls crash.)
random.seed(seed)
np.random.seed(seed)

warnings.filterwarnings('ignore', category=UserWarning, module='skimage')

num_classes = 8

labelFilePath = '../inputs/labeledTilesValidYears/plot_{}_{}.tif'
bandFilePath = '../inputs/landsatTilesMultiYear/Landsat7_{}_{}.tif'


def parsePlotIDYear(path):
    """Return ``[plotID, year]`` parsed from a ``plot_<id>_<year>.tif`` file path.

    Only the file name (without extension) is split, so underscores in parent
    directories do not shift the fields.
    """
    parts = Path(path).stem.split('_')
    return [int(parts[-2]), int(parts[-1])]


fileNames = sorted(glob.glob('../inputs/labeledTilesValidYears/*.tif'))
plotIDYears = [parsePlotIDYear(x) for x in fileNames]

# Split by plot ID (not by plot-year) so all years of one plot land on the same side.
plotIDS = np.unique([plotID for plotID, _ in plotIDYears])
np.random.shuffle(plotIDS)
# NOTE(review): the model loaded later was trained elsewhere; confirm its training code used
# this same seeded split, otherwise "validation" plots may overlap its training plots.

#Split into training and validation (70% / 30% of plot IDs)
splitIndex = int(len(plotIDS) * 0.7)
trainingPlotIDs = plotIDS[:splitIndex]
validationPlotIDs = plotIDS[splitIndex:]

trainingPlots = [x for x in plotIDYears if x[0] in trainingPlotIDs]
validationPlots = [x for x in plotIDYears if x[0] in validationPlotIDs]
print(len(trainingPlots), len(validationPlots))


484 209


In [3]:
#set parameters
n_classes = 8          # channels of the softmax output layer, one per class
image_size = 64        # tiles are image_size x image_size pixels
batch_size = 64
bands = (1, 5, 6, 7)   # band numbers passed to read_image(bands=...)
# One model input channel per selected band; deriving it keeps the two in sync.
n_features = len(bands)


In [4]:
# Seed TensorFlow so the he_normal kernel initialisers below are reproducible. (The previous
# `random.seed = seed` / `np.random.seed = seed` lines assigned an int over the seeding
# functions instead of calling them, which seeded nothing.)
tf.random.set_seed(seed)

# Four 2x2 poolings then four 2x upsamplings: skip concatenations only line up when the
# tile size is divisible by 2**4.
assert image_size % 16 == 0, 'image_size must be divisible by 16 for this U-Net'

#Build U-Net model
inputs = Input((image_size, image_size, n_features))

w_decay = 0.0005  # L2 penalty on every 3x3 convolution kernel


def conv_block(x, filters, dropout_rate):
    """One U-Net stage: 3x3 ReLU conv -> Dropout -> 3x3 ReLU conv (same padding, L2-regularised)."""
    x = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal',
               padding='same', kernel_regularizer=l2(w_decay))(x)
    x = Dropout(dropout_rate)(x)
    return Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal',
                  padding='same', kernel_regularizer=l2(w_decay))(x)


def up_concat(x, skip, filters):
    """Upsample ``x`` 2x with a transposed conv and concatenate the encoder tensor ``skip`` on channels."""
    x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    return concatenate([x, skip], axis=-1)


# Identity Lambda: passes the input through unchanged; kept so the layer graph matches the
# original architecture.
s = Lambda(lambda x: x)(inputs)

# Contracting path
c1 = conv_block(s, 16, 0.1)
p1 = MaxPooling2D((2, 2))(c1)
c2 = conv_block(p1, 32, 0.1)
p2 = MaxPooling2D((2, 2))(c2)
c3 = conv_block(p2, 64, 0.2)
p3 = MaxPooling2D((2, 2))(c3)
c4 = conv_block(p3, 128, 0.2)
p4 = MaxPooling2D(pool_size=(2, 2))(c4)

# Bottleneck
c5 = conv_block(p4, 256, 0.3)

# Expanding path with skip connections
u6 = up_concat(c5, c4, 128)
c6 = conv_block(u6, 128, 0.2)
u7 = up_concat(c6, c3, 64)
c7 = conv_block(u7, 64, 0.2)
u8 = up_concat(c7, c2, 32)
c8 = conv_block(u8, 32, 0.1)
u9 = up_concat(c8, c1, 16)
c9 = conv_block(u9, 16, 0.1)

# Per-pixel softmax over n_classes.
# NOTE(review): no Model is built from inputs/outputs here; later cells load a trained model from disk.
outputs = Conv2D(n_classes, (1, 1), padding="same", activation="softmax")(c9)





2021-12-09 09:50:34.456413: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [6]:
#Model name to save
modelName = 'unet2'

#Define if we want to reinitialize a new model
# NOTE(review): `newModel` is not read in any visible cell; the model below is always loaded from disk.
newModel = False


# Define weights for classes!
classnames = ['No Loss','Hard commodities', 'Forest products', 'Other disturbances',
                'Soft commodities', 'Settlements/infrastructure', 'Fires','Hansen Mistake']
class_weights = np.array([2, 80, 60, 100,
                          100, 90, 100,100]).astype(float)
# The weighted loss needs exactly one weight per class.
assert len(class_weights) == len(classnames), 'expected one class weight per class name'

# Build the weighted loss once and register it under both names the saved model may reference.
lossFunction = weighted_categorical_crossentropy(class_weights)
model = keras.models.load_model('models/{}'.format(modelName), custom_objects = {"loss": lossFunction, "iou_coef": iou_coef, "weighted_categorical_crossentropy": lossFunction})



In [None]:
def NormalizeData(data):
    """Min-max scale ``data`` to the range [0, 1].

    Args:
        data: array-like of numbers.

    Returns:
        A float array of the same shape. Constant input (max == min) returns all zeros
        instead of the NaNs a 0/0 division would produce.
    """
    data = np.asarray(data, dtype=float)
    dataMin = np.min(data)
    dataRange = np.max(data) - dataMin
    if dataRange == 0:
        return np.zeros_like(data)
    return (data - dataMin) / dataRange

# Display labels (line breaks keep the colorbar readable) and one colour per class index 0-7.
# NOTE: this deliberately replaces the `classnames` list defined alongside the class weights.
classnames = ['No\nloss','Hard\ncommodities', 'Forest\nproducts', 'Other\ndisturbances',
                'Soft\ncommodities', 'Settlements/\ninfrastructure', 'Fires','Hansen\nmistake']
classcolors = ['#ffffff','#FCABAB','#93D896','#C5E4FC','#FBFD38','#BABABA','#FC3B26','#3540c8']

# Optional per-tile diagnostics (previously commented-out code). Off by default, so the
# loop below only writes prediction GeoTIFFs.
showPlots = False
printMetrics = False
figureDir = None  # directory to save each figure as <idy>.png; None disables saving

# Load the "<modelName>2" model; predictions are written under the same name.
outputModelName = modelName + '2'
modelPath = 'models/{}'.format(outputModelName)
lossFunction = weighted_categorical_crossentropy(class_weights)
model = keras.models.load_model(modelPath, custom_objects = {"loss": lossFunction, "iou_coef": iou_coef, "weighted_categorical_crossentropy": lossFunction})

predictionFilePath = '../outputs/{}/plot_{}_{}.tif'
# The output file cannot be created if its directory does not exist yet.
os.makedirs(os.path.dirname(predictionFilePath.format(outputModelName, 0, 0)), exist_ok=True)

# Discrete colormap: bin edges at k +/- 0.5 so integer class k is drawn with classcolors[k].
cmap = colors.ListedColormap(classcolors)
norm = colors.BoundaryNorm(np.arange(len(classcolors) + 1) - 0.5, cmap.N)


def plotTile(validationID, validationXRaw, label, ynew, figureName):
    """Show the input (as RGB), the label and the prediction for one tile side by side.

    If ``figureDir`` is set, the figure is also saved as ``<figureDir>/<figureName>.png``.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(12, 6))
    tileText = 'tile {} in {}'.format(validationID[0], validationID[1])

    # Input channels 1-3 rendered as an RGB composite, stretched to [0, 1].
    rgb = NormalizeData(np.asarray(validationXRaw[:, :, 1:4], dtype=float))
    ax1.imshow(rgb, interpolation='none', aspect='auto')
    ax1.set_title('Satellite for ' + tileText)
    ax1.set_aspect('equal', adjustable='box')

    ax2.imshow(label, cmap=cmap, norm=norm)
    ax2.set_title('Label for ' + tileText)
    ax3.imshow(ynew, cmap=cmap, norm=norm)
    ax3.set_title('Prediction for ' + tileText)

    cbar = plt.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap),
                        ax=[ax1, ax2, ax3], orientation='horizontal')
    cbar.set_ticks(list(range(len(classnames))))
    cbar.set_ticklabels(classnames)

    # Save before plt.show() so the PNG does not depend on whether the backend closes
    # figures on show.
    if figureDir is not None:
        fig.savefig(os.path.join(figureDir, '{}.png'.format(figureName)), dpi=800)
    plt.show()
    plt.close(fig)


def printTileMetrics(label, ynew):
    """Print per-class pixel counts, overall accuracy and per-class recall for one tile."""
    labelValues, labelCounts = np.unique(label, return_counts=True)
    predValues, predCounts = np.unique(ynew, return_counts=True)
    print('Counts for L: ', dict(zip(labelValues, labelCounts)))
    print('Counts for P: ', dict(zip(predValues, predCounts)))
    print('Accuracy :', accuracy_score(label.flatten(), ynew.flatten()))
    matrix = confusion_matrix(label.flatten(), ynew.flatten())
    # Rows are true classes; a class present only in the prediction has a zero row -> nan.
    with np.errstate(divide='ignore', invalid='ignore'):
        print('Accuracy by class: ', matrix.diagonal() / matrix.sum(axis=1))


# Tiles are visited in order and prediction is deterministic, so no seeding is needed here.
ids = np.arange(len(validationPlots))

for idy in tqdm(ids, desc='predicting'):
    validationID = validationPlots[idy]
    validationXRaw, label = read_image(validationID, labelFilePath, bandFilePath, bands=bands)
    # Add a batch axis of size 1.
    validationX = np.zeros((1, image_size, image_size, n_features))
    validationX[0, :, :, :] = validationXRaw

    ynew = model.predict(validationX, verbose=0)[0]  # per-pixel class scores
    label = label[0]  # NOTE(review): assumes read_image returns the label with a leading axis of length 1

    # Keep only loss drivers: zero the "No loss" score so argmax picks the most likely
    # loss class, then force class 0 wherever the label itself shows no loss.
    ynew[:, :, 0] = 0
    ynew = np.argmax(ynew, axis=-1)
    ynew = np.where(label == 0, 0, ynew)

    # Reuse the label tile's georeferencing/profile for the prediction raster.
    with rasterio.open(labelFilePath.format(validationID[0], validationID[1])) as labelsrc:
        profile = labelsrc.profile

    with rasterio.open(predictionFilePath.format(outputModelName, validationID[0], validationID[1]),
                       'w', **profile) as dst:
        # Write the class map (previously the undefined name `array`) in the profile's dtype.
        dst.write(ynew.astype(profile['dtype']), 1)

    if showPlots:
        plotTile(validationID, validationXRaw, label, ynew, idy)
    if printMetrics:
        printTileMetrics(label, ynew)
    