In [1]:
# Disable TensorFlow debugging info and warnings
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 2: Info and warnings not displayed 

In [2]:
# Add massimal tools folder to path
import sys
sys.path.append("/massimal/python/tools")

In [3]:
import tensorflow as tf
import numpy as np
import skimage.io
import sklearn.decomposition
import matplotlib.pyplot as plt
import pathlib
#import tqdm
import pickle
import hyspec_cnn
import datetime

In [4]:
# Paths
base_dir = pathlib.Path('/massimal/data/Vega_Sola/Hyperspectral/20220823/Area')
train_tiles_path = base_dir / '3a_PCA_TrainValidationSplit/Training/PCA-Tiles/20220823_Vega_Sola_Train_Tiles'
val_tiles_path = base_dir / '3a_PCA_TrainValidationSplit/Validation/PCA-Tiles/20220823_Vega_Sola_Val_Tiles'

unet_model_save_dir = base_dir / 'M_UnetModels'
unet_model_save_dir.mkdir(exist_ok=True)
tensorboard_log_dir = base_dir / 'M_TensorBoardLogs'
tensorboard_log_dir.mkdir(exist_ok=True)

In [5]:
# Check if GPU is used
tf.config.get_visible_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [6]:
# Parameters
OUTPUT_CHANNELS = 8
BATCH_SIZE = 8
DEPTH = 4

In [7]:
# Load datasets 
train_dataset = tf.data.Dataset.load(str(train_tiles_path))
val_dataset = tf.data.Dataset.load(str(val_tiles_path))

In [8]:
# Get number of tiles in each dataset, and dataset shape
n_tiles_train = train_dataset.cardinality()
n_tiles_val = val_dataset.cardinality()
tile_nrows,tile_ncols,tile_nchannels = train_dataset.element_spec[0].shape.as_list()
print(f'Number of training tiles: {n_tiles_train}')
print(f'Number of validation tiles: {n_tiles_val}')
print(f'Tile data shape (PCA tiles): {(tile_nrows,tile_ncols,tile_nchannels)}')

Number of training tiles: 3266
Number of validation tiles: 711
Tile data shape (PCA tiles): (128, 128, 8)


In [9]:
# From https://www.tensorflow.org/tutorials/images/segmentation#optional_imbalanced_classes_and_class_weights
def add_sample_weights(image, label, name):
    class_weights = tf.constant([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) # Hard-coded for 7 classes
    class_weights = class_weights/tf.reduce_sum(class_weights)

    # Create an image of `sample_weights` by using the label at each pixel as an 
    # index into the `class weights` .
    sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

    return image, label, sample_weights

In [10]:
# Shuffle training dataset (tiles are originally ordered by image) and add sample weights
train_dataset = train_dataset.shuffle(buffer_size=n_tiles_train)
train_dataset = train_dataset.map(add_sample_weights)
val_dataset = val_dataset.map(add_sample_weights)

In [11]:
# Batch datasets
train_dataset_batch = train_dataset.batch(BATCH_SIZE)
val_dataset_batch = val_dataset.batch(BATCH_SIZE)

In [12]:
# Create the U-Net model
unet = hyspec_cnn.unet(input_channels=tile_nchannels,
                       output_channels=OUTPUT_CHANNELS,
                       first_layer_channels=32,
                       depth = DEPTH,
               )
unet.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, None, None,  0           []                               
                                 8)]                                                              
                                                                                                  
 augmentation (Sequential)      (None, None, None,   0           ['input_image[0][0]']            
                                8)                                                                
                                                                                                  
 initial_convolution (Conv2D)   (None, None, None,   2336        ['augmentation[0][0]']           
                                32)                                                           

In [13]:
# Print layers with sublayers
for layer in unet.layers:
    print('----')
    print(layer.name)
    if hasattr(layer,'layers'):
        for l in layer.layers:
            print('\t'+l.name)
       # print(layer.layers)

----
input_image
----
augmentation
	random_flip
----
initial_convolution
----
downsamp_res_1/2
	conv2d
	batch_normalization
	leaky_re_lu
----
downsamp_res_1/4
	conv2d_1
	batch_normalization_1
	leaky_re_lu_1
----
downsamp_res_1/8
	conv2d_2
	batch_normalization_2
	leaky_re_lu_2
----
downsamp_res_1/16
	conv2d_3
	batch_normalization_3
	leaky_re_lu_3
----
upsamp_res_1/8
	conv2d_transpose
	batch_normalization_4
	re_lu
----
skipconnection_res_1/8
----
upsamp_res_1/4
	conv2d_transpose_1
	batch_normalization_5
	re_lu_1
----
skipconnection_res_1/4
----
upsamp_res_1/2
	conv2d_transpose_2
	batch_normalization_6
	re_lu_2
----
skipconnection_res_1/2
----
upsamp_res_1/1
	conv2d_transpose_3
	batch_normalization_7
	re_lu_3
----
skipconnection_res_1/1
----
classification


In [14]:
# Define callbacks
model_save_filename = str(unet_model_save_dir) + '/unet_model.depth' + str(DEPTH) +'.epoch{epoch:02d}-loss{val_loss:.6f}-acc{val_sparse_categorical_accuracy:.3f}.hdf5'
callbacks =[tf.keras.callbacks.ModelCheckpoint(filepath = model_save_filename,
                                               save_best_only=True,
                                               verbose = 1),
            tf.keras.callbacks.ReduceLROnPlateau(factor=0.2, verbose=1),
            tf.keras.callbacks.TensorBoard(log_dir= tensorboard_log_dir)]

In [15]:
model_save_filename

'/massimal/data/Vega_Sola/Hyperspectral/20220823/Area/M_UnetModels/unet_model.depth4.epoch{epoch:02d}-loss{val_loss:.6f}-acc{val_sparse_categorical_accuracy:.3f}.hdf5'

In [16]:
# Compile model
unet.compile(optimizer=tf.keras.optimizers.RMSprop(0.0001), 
             loss="sparse_categorical_crossentropy",
             weighted_metrics=['sparse_categorical_accuracy'], # Need weights to ignore background
             metrics = []) # Sparse because classes are numbered, not one-hot

In [17]:
# print(tensorboard_log_dir)
#%tensorboard --logdir /massimal/data/Vega_Sola/Hyperspectral/20220823/Area/M_TensorBoardLogs

In [18]:
# Fit model to dataset
history = unet.fit(train_dataset.batch(BATCH_SIZE),
                   epochs=100,
                   validation_data=val_dataset.batch(BATCH_SIZE),
                   callbacks=callbacks)

Epoch 1/100
Epoch 1: val_loss improved from inf to 0.08209, saving model to /massimal/data/Vega_Sola/Hyperspectral/20220823/Area/M_UnetModels/unet_model.depth4.epoch01-loss0.082087-acc0.587.hdf5
Epoch 2/100
Epoch 2: val_loss improved from 0.08209 to 0.07243, saving model to /massimal/data/Vega_Sola/Hyperspectral/20220823/Area/M_UnetModels/unet_model.depth4.epoch02-loss0.072435-acc0.661.hdf5
Epoch 3/100
Epoch 3: val_loss improved from 0.07243 to 0.06779, saving model to /massimal/data/Vega_Sola/Hyperspectral/20220823/Area/M_UnetModels/unet_model.depth4.epoch03-loss0.067786-acc0.652.hdf5
Epoch 4/100
Epoch 4: val_loss did not improve from 0.06779
Epoch 5/100
Epoch 5: val_loss did not improve from 0.06779
Epoch 6/100
Epoch 6: val_loss did not improve from 0.06779
Epoch 7/100
Epoch 7: val_loss improved from 0.06779 to 0.05611, saving model to /massimal/data/Vega_Sola/Hyperspectral/20220823/Area/M_UnetModels/unet_model.depth4.epoch07-loss0.056109-acc0.719.hdf5
Epoch 8/100
Epoch 8: val_loss d

Epoch 22/100
Epoch 22: val_loss did not improve from 0.05416
Epoch 23/100
Epoch 23: val_loss improved from 0.05416 to 0.04984, saving model to /massimal/data/Vega_Sola/Hyperspectral/20220823/Area/M_UnetModels/unet_model.depth4.epoch23-loss0.049841-acc0.764.hdf5
Epoch 24/100
Epoch 24: val_loss did not improve from 0.04984
Epoch 25/100
Epoch 25: val_loss did not improve from 0.04984
Epoch 26/100
Epoch 26: val_loss did not improve from 0.04984
Epoch 27/100
Epoch 27: val_loss did not improve from 0.04984
Epoch 28/100
Epoch 28: val_loss did not improve from 0.04984
Epoch 29/100
Epoch 29: val_loss did not improve from 0.04984
Epoch 30/100
Epoch 30: val_loss did not improve from 0.04984
Epoch 31/100
Epoch 31: val_loss did not improve from 0.04984
Epoch 32/100
Epoch 32: val_loss did not improve from 0.04984
Epoch 33/100
Epoch 33: val_loss did not improve from 0.04984

Epoch 33: ReduceLROnPlateau reducing learning rate to 1.9999999494757503e-05.
Epoch 34/100
Epoch 34: val_loss improved from 0.0

Epoch 44/100
Epoch 44: val_loss did not improve from 0.04980
Epoch 45/100
Epoch 45: val_loss did not improve from 0.04980
Epoch 46/100
Epoch 46: val_loss did not improve from 0.04980
Epoch 47/100
Epoch 47: val_loss did not improve from 0.04980
Epoch 48/100
Epoch 48: val_loss did not improve from 0.04980
Epoch 49/100
Epoch 49: val_loss did not improve from 0.04980
Epoch 50/100
Epoch 50: val_loss did not improve from 0.04980
Epoch 51/100
Epoch 51: val_loss did not improve from 0.04980
Epoch 52/100
Epoch 52: val_loss did not improve from 0.04980
Epoch 53/100
Epoch 53: val_loss did not improve from 0.04980

Epoch 53: ReduceLROnPlateau reducing learning rate to 7.999999979801942e-07.
Epoch 54/100
Epoch 54: val_loss did not improve from 0.04980
Epoch 55/100
Epoch 55: val_loss did not improve from 0.04980
Epoch 56/100
Epoch 56: val_loss did not improve from 0.04980
Epoch 57/100
Epoch 57: val_loss did not improve from 0.04980
Epoch 58/100
Epoch 58: val_loss did not improve from 0.04980
Epoch 5

Epoch 67/100
Epoch 67: val_loss did not improve from 0.04980
Epoch 68/100
Epoch 68: val_loss did not improve from 0.04980
Epoch 69/100
Epoch 69: val_loss did not improve from 0.04980
Epoch 70/100
Epoch 70: val_loss did not improve from 0.04980
Epoch 71/100
Epoch 71: val_loss did not improve from 0.04980
Epoch 72/100
Epoch 72: val_loss did not improve from 0.04980
Epoch 73/100
Epoch 73: val_loss did not improve from 0.04980

Epoch 73: ReduceLROnPlateau reducing learning rate to 3.199999980552093e-08.
Epoch 74/100
Epoch 74: val_loss did not improve from 0.04980
Epoch 75/100
Epoch 75: val_loss did not improve from 0.04980
Epoch 76/100
Epoch 76: val_loss did not improve from 0.04980
Epoch 77/100
Epoch 77: val_loss did not improve from 0.04980
Epoch 78/100
Epoch 78: val_loss did not improve from 0.04980
Epoch 79/100
 51/409 [==>...........................] - ETA: 1:24 - loss: 0.0227 - sparse_categorical_accuracy: 0.8994

KeyboardInterrupt: 