In [1]:
import os
import sys
from pathlib import Path

#Make the sibling ../utils folder importable (modelFunctions and getTiles are imported from it below)
utilsPath = str(Path.cwd().parent.absolute() / 'utils')
if utilsPath not in sys.path:
    sys.path.append(utilsPath)

from modelFunctions import iou_coef, DiceLoss, weighted_categorical_crossentropy, prepUnetModel, compileUnetModel
from getTiles import image_gen
from getTiles import read_image

#Import other modules
import random
import glob
import warnings
import numpy as np
import pandas as pd
import matplotlib
import datetime

from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from keras.models import Model, load_model
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam

import keras
import tensorflow as tf
import cv2
import rasterio

from typing import Callable, Union

from sklearn.metrics import confusion_matrix, accuracy_score, multilabel_confusion_matrix, classification_report

import matplotlib.pyplot as plt
from matplotlib import colors
from rasterio.plot import show


In [3]:
#Load data

#Define seed. random.seed / np.random.seed are functions and must be CALLED:
#assigning to them (random.seed = seed) replaces the function with an int and seeds nothing.
seed = 30
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

warnings.filterwarnings('ignore', category=UserWarning, module='skimage')

#Define file path for label and bands
trainingFolderPath = '../inputs/trainingPlots'
validationFolderPath = '../inputs/validationPlots'
bandFilePath = '../inputs/landsatTilesPostYear'
#Previously both formats were assigned to labelFileFormat, so the plot_ format was silently overwritten.
labelFileFormat = 'plot_{}_{}.tif'     #label tiles in the training/validation folders
bandFileFormat = 'Landsat7_{}_{}.tif'  #band tiles in bandFilePath


def plotIdAndYear(filePath):
    """Parse a '<prefix>_<plotID>_<year>.tif' path into [plotID, year] as ints.

    Only the file's base name is split, so underscores in directory names
    cannot shift the fields.
    """
    parts = os.path.basename(filePath).split('.')[0].split('_')
    return [int(parts[1]), int(parts[2])]


#Get list of [[plotID, treeCoverLossYear]] for training set and validation set
trainingFileNames = sorted(glob.glob(os.path.join(trainingFolderPath, '*.tif')))
trainingPlots = [plotIdAndYear(x) for x in trainingFileNames]

validationFileNames = sorted(glob.glob(os.path.join(validationFolderPath, '*.tif')))
validationPlots = [plotIdAndYear(x) for x in validationFileNames]

#Fail fast here rather than deep inside the generators if a path is wrong.
assert trainingPlots, 'No .tif tiles found in {}'.format(trainingFolderPath)
assert validationPlots, 'No .tif tiles found in {}'.format(validationFolderPath)

print('Num training tiles: ',len(trainingPlots), ' Num validation tiles: ',len(validationPlots))
print('First 5 training tiles: ',trainingPlots[:5], ' First 5 validation tiles: ',validationPlots[:5])


Num training tiles:  1248  Num validation tiles:  511
First 5 training tiles:  [[103, 2001], [103, 2002], [103, 2003], [103, 2004], [103, 2005]]  First 5 validation tiles:  [[107, 2001], [107, 2002], [107, 2003], [107, 2004], [107, 2006]]


In [4]:
#Set parameters
n_classes, n_features = 8, 4     #number of output classes; channels on the model's input last axis
image_size, batch_size = 64, 64  #tile edge length in pixels; batch size handed to the generators
bands = 'all'  #Format for listing bands: (1,5,6,7)


In [7]:
#Model name to save
modelName = 'unet'
runNum = 0


#Define if we want to reinitialize a new model (True) or load models/<modelName><runNum> (False)
newModel = True


# Define weights for classes! (class_weights[i] is the loss weight for class i)
classnames = ['No Loss','Hard commodities', 'Forest products', 'Other disturbances',
                'Soft commodities', 'Settlements/infrastructure', 'Fires','Hansen Mistake']
class_weights = np.array([1, 10, 10, 30,
                          30, 30, 30,30]).astype(float)
#Guard against changing n_classes without updating the per-class lists (or vice versa).
assert len(classnames) == n_classes, (len(classnames), n_classes)
assert len(class_weights) == n_classes, (len(class_weights), n_classes)


if newModel:
    #Create new model
    model = compileUnetModel(n_classes, image_size, n_features, class_weights, learning_rate=0.001, w_decay=0.0005)
else:
    #Load model from a previous run; the weighted loss is registered under both custom-object keys.
    weightedLoss = weighted_categorical_crossentropy(class_weights)
    model = keras.models.load_model(
        'models/{}'.format(modelName + str(runNum)),
        custom_objects={"loss": weightedLoss,
                        "iou_coef": iou_coef,
                        "weighted_categorical_crossentropy": weightedLoss})

##print to see model structure
#model.summary()


In [8]:
#Re-seed by CALLING the seeders (assigning to random.seed / np.random.seed replaces the
#functions and seeds nothing) so re-running this cell starts from the same random state.
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)


#Scheduler for decreasing learning rate as epochs go on
epochCutoff = 5
initialLearningRate = 0.001  #keep in sync with learning_rate passed to compileUnetModel
def scheduler(epoch):
    """Return the learning rate to use for the given epoch index.

    Before `epochCutoff` the rate is `initialLearningRate`; from the cutoff on
    it decays exponentially by a factor of exp(-0.1) per epoch past the cutoff.

    Returns a plain Python float in both branches. (The earlier version
    returned a tf tensor after the cutoff, which some Keras versions reject
    as a schedule output.)
    """
    if epoch < epochCutoff:
        return initialLearningRate
    return float(initialLearningRate * np.exp(0.1 * (epochCutoff - epoch)))

#Callbacks for training
lr_schedule = tf.keras.callbacks.LearningRateScheduler(scheduler)
#NOTE(review): with epochs=5 in model.fit below, epochs 0-4 all fall before epochCutoff (5), so the
#scheduler never decays, and EarlyStopping(patience=5) cannot trigger within 5 epochs.
#Raise `epochs` (or lower these thresholds) for either callback to have an effect.
earlystopper = EarlyStopping(patience=5, verbose=1)
#Checkpoint named after the run being trained (runNum + 1), like the models/ save below;
#for runNum == 0 the file name is unchanged from the previous hard-coded '-1' suffix.
checkpointer = ModelCheckpoint('model-{}-{}.h5'.format(modelName, runNum + 1), verbose=1, save_best_only=True)
log_dir = 'reports/tensorboard/{}/'.format(modelName) + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard = TensorBoard(
    log_dir=log_dir,
    write_graph=True,
    write_images=True,
    update_freq='epoch'
)

#Define generators for training (rotation_range=90 and flips enabled) and validation (disabled) data
train_sequence_generator = image_gen(trainingPlots, trainingFolderPath, bandFilePath,
                                     n_classes,
                                     n_features,
                                     batch_size=batch_size,
                                     rotation_range=90,
                                     horizontal_flip=True,
                                     vertical_flip=True,
                                     bands=bands)

validation_sequence_generator = image_gen(validationPlots, validationFolderPath, bandFilePath,
                                          n_classes,
                                          n_features,
                                          batch_size=batch_size,
                                          rotation_range=0,
                                          horizontal_flip=False,
                                          vertical_flip=False,
                                          bands=bands)

#Fit the model!
results = model.fit(train_sequence_generator,
                    validation_data=validation_sequence_generator, validation_steps=5,
                    epochs=5, steps_per_epoch=100,
                    callbacks=[lr_schedule, earlystopper, checkpointer, tensorboard],
                    verbose=1)


Epoch 1/5
  9/100 [=>............................] - ETA: 2:00 - loss: 4.5871 - iou_coef: 0.9235

KeyboardInterrupt: 

In [None]:
#Save the trained model as the next run: models/<modelName><runNum + 1>
nextModelPath = f'models/{modelName}{runNum + 1}'
model.save(nextModelPath)


In [None]:
#Tensorboard
#--logdir must be hard coded to match modelName
log_dir_short = 'reports/tensorboard/{}'.format(modelName)
%reload_ext tensorboard 
%tensorboard --logdir reports/tensorboard/unet


#http://localhost:6010/

In [None]:
#Plot some results!

#Normalize band data
def NormalizeData(data, scale=0.15):
    """Scale band values for RGB display by dividing by `scale`.

    Args:
        data: numeric array (or scalar) of band values.
        scale: the value that maps to 1.0. Defaults to the previously
            hard-coded 0.15. Alternatives tried earlier were
            (data + 0.2) / (0.3 - 0.2) and per-tile min-max scaling.

    Returns:
        data / scale, with the same shape as `data`. Values above `scale`
        land above 1.0; matplotlib's imshow clips float RGB input to [0, 1].
    """
    return data / scale

    
#Display labels (line breaks keep colorbar tick labels narrow) and one color per class id
classnames = [
    'No\nloss', 'Hard\ncommodities', 'Forest\nproducts', 'Other\ndisturbances',
    'Soft\ncommodities', 'Urbanization', 'Fires', 'Hansen\nmistake',
]
classcolors = ['#ffffff', '#FCABAB', '#93D896', '#C5E4FC',
               '#FBFD38', '#BABABA', '#FC3B26', '#3540c8']

#Load the model saved as run runNum + 1
modelCustomObjects = {
    "loss": weighted_categorical_crossentropy(class_weights),
    "iou_coef": iou_coef,
    "weighted_categorical_crossentropy": weighted_categorical_crossentropy(class_weights),
}
model = keras.models.load_model('models/{}{}'.format(modelName, runNum + 1), custom_objects=modelCustomObjects)

maskToTCL = False   #True: drop class 0 before argmax and blank predictions where the label is 0
printStats = False  #True: print per-tile pixel counts, accuracy, confusion matrix and report
numPlotsToShow = 30

#Colormap shared by the label and prediction panels: class id i is drawn in classcolors[i].
#Loop-invariant, so built once here instead of once per tile.
cmap = colors.ListedColormap(classcolors)
bounds = np.arange(len(classcolors) + 1) - 0.5  #bin edges -0.5, 0.5, ..., n-0.5 centre each integer id
norm = colors.BoundaryNorm(bounds, cmap.N)

##Pick a reproducible random subset of validation plots.
#The seeders must be CALLED; assigning to random.seed / np.random.seed seeds nothing.
random.seed(seed)
np.random.seed(seed)
ids = np.random.choice(len(validationPlots), min(numPlotsToShow, len(validationPlots)), replace=False)

# #Otherwise load all validation plots
# ids = np.arange(len(validationPlots))

for idy in ids:
    validationID = validationPlots[idy]
    tileDescription = 'tile {} in {}'.format(validationID[0], validationID[1])

    #Validation labels are read from validationFolderPath, the same folder the validation
    #generator uses (labelFilePath is not defined anywhere in this notebook).
    #Named labelTile so it does not shadow skimage.morphology.label imported above.
    validationXRaw, labelTile = read_image(validationID, validationFolderPath, bandFilePath, bands=bands)
    #NOTE(review): read_image appears to return the label with a leading axis; keep its first slice.
    labelTile = labelTile[0]

    #Add a batch axis: the model input is (1, image_size, image_size, n_features).
    validationX = np.asarray(validationXRaw, dtype=np.float32)[np.newaxis, ...]
    assert validationX.shape == (1, image_size, image_size, n_features), validationX.shape
    classProbs = model.predict(tf.convert_to_tensor(validationX), verbose=0)[0]

    if maskToTCL:
        #Filter to only loss: zero class 0 so argmax picks another class, then blank pixels labelled 0
        classProbs[:, :, 0] = 0
        prediction = np.argmax(classProbs, axis=-1)
        prediction = np.where(labelTile == 0, 0, prediction)
    else:
        #Otherwise just show the most likely class per pixel
        prediction = np.argmax(classProbs, axis=-1)

    fig1, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(12, 6))  #three panels: satellite, label, prediction

    #Satellite panel: input bands 1, 2, 3 become display channels 0, 1, 2
    #(TODO confirm this band order is true-colour RGB for these Landsat tiles)
    X = NormalizeData(np.asarray(validationXRaw[:, :, 1:4], dtype=float))
    ax1.imshow(X, interpolation='none', aspect='auto')
    ax1.set_title('Satellite for ' + tileDescription)

    ax2.imshow(labelTile, cmap=cmap, norm=norm)
    ax2.set_title('Label for ' + tileDescription)

    ax3.imshow(prediction, cmap=cmap, norm=norm)
    ax3.set_title('Prediction for ' + tileDescription)

    for ax in (ax1, ax2, ax3):
        ax.set_aspect('equal', adjustable='box')

    cbar = fig1.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap), ax=[ax1, ax2, ax3], orientation='horizontal')
    cbar.set_ticks(list(range(len(classnames))))
    cbar.set_ticklabels(classnames)

    #Save figure
    #fig1.savefig('../viz/{}.png'.format(validationID), dpi=800)

    plt.show()
    plt.close(fig1)  #many tiles are plotted in this loop; release each figure once shown

    #Get statistics of image (opt-in; was previously commented out)
    if printStats:
        labelFlat, predFlat = labelTile.flatten(), prediction.flatten()
        labelValues, labelCounts = np.unique(labelFlat, return_counts=True)
        predValues, predCounts = np.unique(predFlat, return_counts=True)
        print('Counts for L: ', dict(zip(labelValues, labelCounts)))
        print('Counts for P: ', dict(zip(predValues, predCounts)))
        print('Accuracy :', accuracy_score(labelFlat, predFlat))
        matrix = confusion_matrix(labelFlat, predFlat)
        #Rows for classes absent from the label have sum 0 -> nan; silence the divide warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            print('Accuracy by class: ', matrix.diagonal() / matrix.sum(axis=1))
        print('Confusion matrix for {}: '.format(tileDescription))
        print(matrix)
        print(classification_report(labelFlat, predFlat))
    
    