# Test Harness Jupyter Notebook for [thesis name here]

## *About*
TODO

# Geoscience and Remote Sensing Society (GRSS) Data Fusion Contest (DFC) 2018 University of Houston Dataset

## *LiDAR*
- **Lidar sensor**: Optech Titan MW (14SEN/CON340) w/ integrated camera (This is a multispectral lidar sensor)
- **Ground Sampling Distance (GSD)**: 0.5 meters (for point clouds, multispectral intensity, DSM, and DEM)
- **Lidar MS Channel #1 (C1)**: 1550nm wavelength (near infrared)
- **Lidar MS Channel #2 (C2)**: 1064nm wavelength (near infrared)
- **Lidar MS Channel #3 (C3)**: 532nm wavelength (green)
- **DSM_C12**: a first surface model (DSM) generated from first returns detected on Titan channels 1 and 2. The elevations of the first returns were interpolated to a 50cm (0.5m) grid using Kriging with a search radius of 5 meters.
- **DEM_C123_3msr**: a bare-earth digital elevation model (DEM) generated from returns classified as ground from all three Titan channels. The elevations of the first returns were interpolated to a 50cm (0.5m) grid using Kriging with a search radius of 3 meters. This model represents terrain with data voids within the footprints of buildings and other manmade structures.
- **DEM_C123_TLI**: a bare-earth DEM generated from returns classified as ground from all three Titan channels. The elevations of the first returns were interpolated to a 50cm (0.5m) grid using a *triangulation with linear interpolation* algorithm. This model represents terrain where the voids of footprints of buildings and other manmade structures have been filled by the algorithm.
- **DEM+B_C123**: a hybrid DEM that combines returns that were classified as coming from buildings and the ground detected in all three Titan channels. The elevations of the returns were interpolated to a 50cm (0.5m) grid using Kriging with a search radius of 5 meters.

## *Hyperspectral*
- **Hyperspectral Image (HSI) sensor**: ITRES CASI 1500
- **Ground Sampling Distance (GSD)**: 1 meter
- **Number of spectral bands**: 48
- **Spectral range of data**: 380 - 1050nm
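
The hyperspectral cube has a 1 m GSD and the lidar rasters have a 0.5 m GSD, so the two need a common grid before they can be stacked. The block below is a minimal sketch of one way to do that, assuming nearest-neighbour upsampling by an integer factor of 2 and channels-last arrays. The helper name `upsample_nearest` and the toy array sizes are hypothetical, and the notebook may instead resample with `rasterio`.

```python
import numpy as np

def upsample_nearest(cube, factor):
    """Repeat each pixel `factor` times along both spatial axes."""
    return np.repeat(np.repeat(cube, factor, axis=0), factor, axis=1)

# Toy stand-ins for the real rasters (sizes are made up)
hsi = np.arange(2 * 3 * 48, dtype=np.float32).reshape(2, 3, 48)  # 1 m grid
lidar = np.zeros((4, 6), dtype=np.float32)                         # 0.5 m grid

# Bring the hyperspectral cube onto the 0.5 m grid, then stack the lidar
# raster as one extra feature channel
hsi_05m = upsample_nearest(hsi, factor=2)
stacked = np.concatenate([hsi_05m, lidar[..., np.newaxis]], axis=-1)
print(stacked.shape)  # (4, 6, 49)
```

Nearest-neighbour repetition keeps the original spectra unchanged, which may be preferable to interpolation when the spectra are later used for classification.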



# --- TO-DO List ---
- Create thresholding filter for lidar intensity raster and DSM data
    - The contest paper states: "Those pixel values that are greater than a threshold T are replaced with the minimum value in the data. We set T as 1*e*4 and 1*e*10 for LiDAR intensity raster data and DSM data, respectively."
- Create normalized DSM (NDSM) to feed to algorithm (NDSM = DSM - DEM)
- For all data (Hyperspectral, MS LiDAR, NDSM) normalize each feature dimension into range of \[0,1\]
- Contest winner conducted image partitioning, partitioning the main image into 40 sub-images with a size of 1202x300
    - During the test phase, there was no need to restore the gradient of the network anymore so the full test image is partitioned into 15 sub-images with a size of 2404 x 600
- Image blurring (smoothing) w/ Gaussian filter?
- In EDA, do histograms of Hyperspectral and lidar values, make it interactive to do over all classes as well as overall
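
The preprocessing items in this list (thresholding, NDSM, and per-feature normalization) can be sketched with NumPy as follows. This is a minimal sketch rather than the contest winners' code: the function names `clip_outliers`, `make_ndsm`, and `normalize_bands` are hypothetical, arrays are assumed to be channels-last, and the toy values are made up. Only the threshold value comes from the quoted paper.

```python
import numpy as np

def clip_outliers(raster, threshold):
    """Replace values above `threshold` with the raster minimum."""
    out = raster.astype(np.float64, copy=True)
    out[out > threshold] = out.min()
    return out

def make_ndsm(dsm, dem):
    """Normalized DSM: height above the bare-earth terrain."""
    return dsm - dem

def normalize_bands(cube):
    """Scale each band (last axis) of a (H, W, C) cube into [0, 1]."""
    cube = cube.astype(np.float64, copy=True)
    mins = cube.min(axis=(0, 1), keepdims=True)
    ranges = cube.max(axis=(0, 1), keepdims=True) - mins
    ranges[ranges == 0] = 1.0  # constant bands map to 0 instead of NaN
    return (cube - mins) / ranges

# Toy data standing in for the real rasters
dsm = np.array([[12.0, 1e11], [15.0, 20.0]])
dem = np.array([[10.0, 10.0], [11.0, 12.0]])

dsm = clip_outliers(dsm, threshold=1e10)   # DSM threshold from the paper
ndsm = make_ndsm(dsm, dem)
features = normalize_bands(np.stack([dsm, ndsm], axis=-1))
```

Normalizing per band, rather than over the whole cube, keeps features with very different units (reflectance, intensity, metres) on the same scale.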

# 1) Install/import required libraries and set up environment

In [None]:
#@title Install GDAL, Rasterio, and Spectral libraries
# Install GDAL
#!apt install gdal-bin python-gdal python3-gdal

# Install Rasterio
#!pip install rasterio

# Install Spectral (do I need git clone?)
#!pip install spectral
#!git clone https://github.com/spectralpython/spectral.git

In [None]:
#@title Import libraries to notebook
### Standard Python Libraries ###
import argparse
import collections
import copy
from copy import deepcopy
import datetime
import gc
import math
from operator import truediv
import os
from pathlib import Path
import time
import traceback

### Data Manipulation Libraries ###
import numpy as np
import pandas as pd

### Other Library Imports ###
import cpuinfo

### Band Selection Libraries ###
import cvxpy as cvx
from munkres import Munkres
import nimfa
from numpy import linalg
from scipy.ndimage import gaussian_filter, median_filter
from scipy.special import expit
from scipy.sparse.linalg import svds
from skfeature.function.sparse_learning_based import NDFS
from skfeature.function.similarity_based import lap_score, SPEC
from skfeature.utility import construct_W
from skfeature.utility.sparse_learning import feature_ranking
from sklearn import cluster
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.cluster import SpectralClustering, KMeans
from sklearn.decomposition import PCA, FastICA
from sklearn.linear_model import orthogonal_mp_gram, OrthogonalMatchingPursuit
from sklearn.metrics import accuracy_score, pairwise_distances
from sklearn.model_selection import cross_val_score, cross_val_predict, StratifiedKFold, train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
from sklearn.preprocessing import maxabs_scale, normalize
from sklearn.svm import SVC

### Hyperspectral Libraries ###
import rasterio
from rasterio.enums import Resampling
from rasterio.windows import Window
import spectral

### Data Visualization Libraries ###
import matplotlib.pyplot as plt
import seaborn as sns

# import Jupyter NB widgets
import ipywidgets as widgets
from IPython.display import display

### Machine Learning Libraries ###
import scipy.io as sio
from sklearn import metrics, preprocessing
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import regularizers
import tensorflow.keras.callbacks as kcallbacks
from tensorflow.keras.callbacks import (
    EarlyStopping,
    LearningRateScheduler,
    ModelCheckpoint,
)
from tensorflow.keras.layers import (
    Activation,
    Add,
    AveragePooling2D,
    AveragePooling3D,
    BatchNormalization,
    Concatenate,
    Conv1D,
    Conv2D,
    Conv3D,
    Dense,
    Dropout,
    Flatten,
    GlobalAveragePooling2D,
    GlobalAveragePooling3D,
    Input,
    MaxPooling2D,
    MaxPooling3D,
    Reshape,
    Resizing,
)
from tensorflow.keras.models import (
    Model,
    Sequential,
)
from tensorflow.keras.optimizers import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    Ftrl,
    Nadam,
    RMSprop,
    SGD,
)

from tensorflow.keras.utils import (
    Sequence,
    to_categorical,
)


In [None]:
#@title Environment Setup
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# 2) Dataset Setup Functions

## 2.1) Dataset Setup Utility Functions
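
As a quick orientation for the helpers in this section, the sketch below shows the effect of zero-padding only the two spatial axes of a channels-last cube, which is what `create_datasets` requests from `pad_img` through `ignore_dims`. It is a standalone illustration under those assumptions (the name `pad_spatial` and the array sizes are hypothetical), not a replacement for `pad_img`.

```python
import numpy as np

def pad_spatial(img, pad_width):
    """Zero-pad the first two axes by `pad_width` on each side."""
    padding = [(pad_width, pad_width)] * 2 + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, padding, mode='constant')

patch_size = 5
cube = np.ones((10, 12, 48))            # rows x columns x bands
gt = np.ones((10, 12), dtype=np.uint8)  # rows x columns

# Padding by patch_size // 2 lets a full patch fit around border pixels
padded_cube = pad_spatial(cube, patch_size // 2)
padded_gt = pad_spatial(gt, patch_size // 2)
print(padded_cube.shape, padded_gt.shape)  # (14, 16, 48) (14, 16)
```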

In [None]:
#@title Pad Image
def pad_img(img, pad_width, ignore_dims=[]):
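    """Zero-pad an image by `pad_width` on both sides of every dimension
    that is not listed in `ignore_dims`.
    """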

    padding = [(pad_width,) if dim not in ignore_dims else (0,) for dim in range(img.ndim)]
    img = np.pad(img, padding, mode='constant')

    return img

In [None]:
#@title Threshold Image
def threshold_image(img, threshold):
    """Replace values above `threshold` with the image minimum (in place).
    """
    img[img > threshold] = img.min()
    return img

In [None]:
#@title Normalize Image
def normalize_image(img):
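    """Min-max scale an image into the range [0, 1].

    Note: if `img` is already a float array it is modified in place.
    """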
    img = img.astype(float, copy=False)
    img -= img.min()
    img /= img.max()

    return img

In [None]:
#@title Ground Truth Sampling
def sample_gt(gt, train_size, mode='random'):
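    """Split the labeled pixels of a ground truth map into train and test maps.

    Args:
        gt: a 2D array of int labels (0 marks unlabeled pixels)
        train_size: fraction of samples to keep for training, or an
            absolute sample count if greater than 1
        mode: sampling strategy, one of 'random', 'fixed', or 'disjoint'
    Returns:
        train_gt, test_gt: 2D arrays of int labels

    """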
    indices = np.nonzero(gt)
    X = list(zip(*indices)) # x,y features
    y = gt[indices].ravel() # classes
    train_gt = np.zeros_like(gt)
    test_gt = np.zeros_like(gt)
    if train_size > 1:
       train_size = int(train_size)

    if mode == 'random':
       train_indices, test_indices = train_test_split(X, train_size=train_size, stratify=y)
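       # Index with a tuple of coordinate lists; a plain list of lists is
       # not interpreted as (rows, columns) by recent NumPy versions
       train_indices = tuple([list(t) for t in zip(*train_indices)])
       test_indices = tuple([list(t) for t in zip(*test_indices)])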
       train_gt[train_indices] = gt[train_indices]
       test_gt[test_indices] = gt[test_indices]
    elif mode == 'fixed':
       print(f'Sampling {mode} with train size = {train_size}')
       train_indices, test_indices = [], []
       for c in np.unique(gt):
           if c == 0:
              continue
           indices = np.nonzero(gt == c)
           X = list(zip(*indices)) # x,y features

           train, test = train_test_split(X, train_size=train_size)
           train_indices += train
           test_indices += test
       train_indices = tuple([list(t) for t in zip(*train_indices)])
       test_indices = tuple([list(t) for t in zip(*test_indices)])
       train_gt[train_indices] = gt[train_indices]
       test_gt[test_indices] = gt[test_indices]

    elif mode == 'disjoint':
        train_gt = np.copy(gt)
        test_gt = np.copy(gt)
        for c in np.unique(gt):
            mask = gt == c
            for x in range(gt.shape[0]):
                first_half_count = np.count_nonzero(mask[:x, :])
                second_half_count = np.count_nonzero(mask[x:, :])
                try:
                    ratio = first_half_count / (first_half_count + second_half_count)
                    if ratio > 0.9 * train_size:
                        break
                except ZeroDivisionError:
                    continue
            mask[:x, :] = 0
            train_gt[mask] = 0

        test_gt[train_gt > 0] = 0
    else:
        raise ValueError(f'{mode} sampling is not implemented yet.')
    return train_gt, test_gt

In [None]:
#@title Get Valid Ground Truth Indices
def get_valid_gt_indices(gt, ignored_labels=[]):

    mask = np.ones_like(gt)
    for label in ignored_labels:
        mask[gt == label] = 0

    x_pos, y_pos = np.nonzero(mask)
    indices = np.array([(x, y) for x, y in zip(x_pos, y_pos)])

    return indices

In [None]:
#@title Histogram Equalization
def histogram_equalization(img):
    """
    """

    # def _equalize_layer(layer):
    #     """
    #     Internal function to equalize a single image channel slice.
    #     """
    #     height, width = layer.shape
    #     h, bin = np.histogram(layer.flatten(), 256, [0, 256])

    #     cdf = np.cumsum(h)

    #     cdf_m = np.ma.masked_equal(cdf,0)
    #     cdf_m = (cdf_m - cdf_m.min())*255/(cdf_m.max()-cdf_m.min())
    #     cdf_final = np.ma.filled(cdf_m,0).astype('uint8')

    # if img.ndim == 2:
    #     return cv2.equalizeHist(img)
    # else:
    #     for channel in range(img.shape[-1]):
    #         img[:,:, channel] = cv2.equalizeHist(img[:,:,channel])
    #     return img
        #return cv2.equalizeHist(img)

    return img

In [None]:
#@title Image Filtering
def filter_image(img, filter_type, filter_size=None, normalize=False):
    """
    """

    failure = False

    if img.ndim == 2:
        height, width = img.shape
        channels = 1
    else:
        height, width, channels = img.shape

    if filter_type == 'median':
        if filter_size is None:
            # Match the filter footprint to the image dimensionality
            filter_size = (3, 3) if img.ndim == 2 else (3, 3, channels)
        img = median_filter(img, size=filter_size)
    elif filter_type == 'gaussian':
        img = gaussian_filter(img, 1)
    else:
        print(f'Bad filter argument {filter_type}! Image left alone...')
        failure = True

    if not failure and normalize:
        img=normalize_image(img)

    return img

## 2.2) Hyperspectral Dataset Class
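
Each sample produced by the class in this section is a `patch_size` x `patch_size` window centred on a labeled pixel. The block below is a minimal sketch of that indexing, assuming a channels-last cube that has already been padded. `extract_patch` is a hypothetical standalone version of the class's private patch helper, and the toy arrays are made up.

```python
import numpy as np

def extract_patch(data, gt, index, patch_size):
    """Return the data patch and label patch centred on `index`."""
    x, y = index
    x1, y1 = x - patch_size // 2, y - patch_size // 2  # top-left corner
    x2, y2 = x1 + patch_size, y1 + patch_size          # bottom-right corner
    return data[x1:x2, y1:y2].copy(), gt[x1:x2, y1:y2].copy()

data = np.arange(7 * 7 * 3, dtype=np.float32).reshape(7, 7, 3)
gt = np.arange(49, dtype=np.uint8).reshape(7, 7)

patch_size = 5
patch, labels = extract_patch(data, gt, index=(3, 3), patch_size=patch_size)

# With center-pixel supervision only the middle label is kept
center_label = labels[patch_size // 2, patch_size // 2]
print(patch.shape, center_label == gt[3, 3])  # (5, 5, 3) True
```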

In [None]:
#@title Hyperspectral Dataset Class (Extends Sequence)
class HyperspectralDataset(Sequence):
    def __init__(self, data, gt, shuffle=True, equal_class_distribution=False, **params):
        """
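        Args:
            data: 3D hyperspectral image
            gt: 2D array of labels
            shuffle: bool, set to True to shuffle data at each epoch
            equal_class_distribution: bool, set to True to draw the same
                number of samples from every class at each epoch
            params: extra hyperparameters for setting up dataset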
        """
        # super(HyperspectralDataset, self).__init__()
        self.data = data
        self.input_channels = params['input_channels']
        self.gt = gt
        self.shuffle = shuffle
        self.equal_class_distribution = equal_class_distribution
        self.batch_size = params['batch_size']
        self.patch_size = params['patch_size']
        self.supervision = params['supervision']
        self.ignored_labels = set(params['ignored_labels'])
        self.num_classes = params['n_classes']
        self.loss = params['loss']
        self.expand_dims = params['expand_dims']
        self.flip_augmentation = params["flip_augmentation"]
        self.radiation_augmentation = params["radiation_augmentation"]
        self.mixture_augmentation = params["mixture_augmentation"]
        self.center_pixel = params["center_pixel"]

        if self.input_channels is not None:
            self.multi_input = True
        else:
            self.multi_input = False

        if self.supervision == "full":
            mask = np.ones_like(gt)
            for label in self.ignored_labels:
                mask[gt == label] = 0
        # Semi-supervised : use all pixels, except padding
        elif self.supervision == "semi":
            mask = np.ones_like(gt)
        x_pos, y_pos = np.nonzero(mask)
        num_neighbors = self.patch_size // 2
        self.indices = np.array(
            [
                (x, y)
                for x, y in zip(x_pos, y_pos)
                if x > num_neighbors
                    and x < data.shape[0] - num_neighbors
                    and y > num_neighbors
                    and y < data.shape[1] - num_neighbors
            ]
        )

        self.labels = np.array([gt[x, y] for x, y in self.indices])

        self.sample_set = {label:[] for label in np.unique(self.gt) if label not in self.ignored_labels}

        for index in self.indices:
            self.sample_set[self.gt[tuple(index)]].append(tuple(index))

        self.num_samples_to_select = None
        for key, class_samples in self.sample_set.items():
            num_samples = len(class_samples)
            if self.num_samples_to_select is None or self.num_samples_to_select > num_samples:
                self.num_samples_to_select = num_samples

        # Run epoch end function to initialize dataset
        self.on_epoch_end()

    def on_epoch_end(self):
        if self.equal_class_distribution:
            # Randomly pick an equal number of indices from each class
            # to be used in the training for this epoch
            self.indices = np.array(
                [index for key in self.sample_set
                    for index in np.random.default_rng().choice(self.sample_set[key],
                                                            self.num_samples_to_select,
                                                            axis=0,
                                                            replace=False)]
            )

            # Make sure to shuffle all of the different class indices
            # around and apply the appropriate labels for those indices
            np.random.shuffle(self.indices)
            self.labels = np.array([self.gt[x, y] for x, y in self.indices])

        elif self.shuffle:
            # Shuffle the indices around and apply the appropriate
            # labels for those indices
            np.random.shuffle(self.indices)
            self.labels = np.array([self.gt[x, y] for x, y in self.indices])


    def __len__(self):
        return math.ceil(len(self.indices) / self.batch_size)

    def __getitem__(self, i):
        if self.multi_input:
            batch_data = [[] for _ in self.input_channels]
        else:
            batch_data = []
        batch_labels = []

        # Get all items in batch
        for item in range(i*self.batch_size,(i+1)*self.batch_size):

            # Make sure not to look for item id greater than number of
            # indices
            if item >= len(self.indices): break

            # Get index tuple from indices
            index = tuple(self.indices[item])

            # Get data patch for the index
            data, label= self.__get_patch(self.data,
                                          self.gt,
                                          index,
                                          self.patch_size)

            if self.flip_augmentation and self.patch_size > 1:
                # Perform data augmentation (only on 2D patches)
                data, label = self.flip(data, label)
            if self.radiation_augmentation and np.random.random() < 0.1:
                data = self.radiation_noise(data)
            if self.mixture_augmentation and np.random.random() < 0.2:
                data = self.mixture_noise(data, label)

            # Extract the center label if needed
            if self.center_pixel and self.patch_size > 1:
                label = label[self.patch_size // 2, self.patch_size // 2]
            # Remove unused dimensions when we work with individual spectra
            # (patches are channels-last, so keep the spectral axis)
            elif self.patch_size == 1:
                data = data[0, 0, :]
                label = label[0, 0]

            # Add a fourth dimension for 3D CNN
            if self.expand_dims and self.patch_size > 1:
                # Make 4D data ((Batch x) Planes x Channels x Width x Height)
                # E.g. adding a dimension for 'planes'
                axis = len(data.shape) if K.image_data_format() == 'channels_last' else 0
                data = np.expand_dims(data, axis)
                # patch = tf.expand_dims(patch, 0)

            # Break the data into inputs if needed
            if self.multi_input:
                data = [data.take(channels, axis=data.ndim-1) for channels in self.input_channels]

            # If categorical cross-entropy, make sure labels are one-hot
            # encoded
            if self.loss == 'categorical_crossentropy':
                label = to_categorical(label, num_classes = self.num_classes)

            # Add data to lists
            if self.multi_input:
                for idx in range(len(batch_data)):
                    batch_data[idx].append(tf.convert_to_tensor(data[idx], dtype='float32'))
            else:
                batch_data.append(tf.convert_to_tensor(data, dtype='float32'))
            batch_labels.append(label)

        if self.multi_input:
            for idx in range(len(batch_data)):
                batch_data[idx] = tf.convert_to_tensor(batch_data[idx])
            batch_data = (*batch_data,)
        else:
            batch_data = tf.convert_to_tensor(batch_data)

        batch_labels = tf.convert_to_tensor(batch_labels)

        return batch_data, batch_labels

    @staticmethod
    def __get_patch(data, gt, index, patch_size):
        x, y = index
        x1 = x - patch_size // 2    # Leftmost edge of patch
        y1 = y - patch_size // 2    # Topmost edge of patch
        x2 = x1 + patch_size        # Rightmost edge of patch
        y2 = y1 + patch_size        # Bottommost edge of patch

        patch = data[x1:x2, y1:y2]
        label = gt[x1:x2, y1:y2]

        # Copy the data into numpy arrays
        patch = np.asarray(np.copy(patch), dtype="float32")
        label = np.asarray(np.copy(label), dtype='uint8')

        return patch, label

    @staticmethod
    def flip(*arrays):
        horizontal = np.random.random() > 0.5
        vertical = np.random.random() > 0.5
        if horizontal:
            arrays = [np.fliplr(arr) for arr in arrays]
        if vertical:
            arrays = [np.flipud(arr) for arr in arrays]
        return arrays

    @staticmethod
    def radiation_noise(data, alpha_range=(0.9, 1.1), beta=1 / 25):
        alpha = np.random.uniform(*alpha_range)
        noise = np.random.normal(loc=0.0, scale=1.0, size=data.shape)
        return alpha * data + beta * noise

    def mixture_noise(self, data, label, beta=1 / 25):
        alpha1, alpha2 = np.random.uniform(0.01, 1.0, size=2)
        noise = np.random.normal(loc=0.0, scale=1.0, size=data.shape)
        data2 = np.zeros_like(data)
        for idx, value in np.ndenumerate(label):
            if value not in self.ignored_labels:
                l_indices = np.nonzero(self.labels == value)[0]
                l_indice = np.random.choice(l_indices)
                assert self.labels[l_indice] == value
                x, y = self.indices[l_indice]
                data2[idx] = self.data[x, y]
        return (alpha1 * data + alpha2 * data2) / (alpha1 + alpha2) + beta * noise


## 2.3) Hyperspectral Dataset Getter Function

In [None]:
#@title Get Hyperspectral Dataset
def get_hyperspectral_dataset(data, gt, shuffle=True, **params):
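    """Build a batched `tf.data.Dataset` of patches centred on the valid
    pixels of `gt`.

    Args:
        data: 3D hyperspectral image (channels last)
        gt: 2D array of labels
        shuffle: bool, set to True to reshuffle the pixel order each iteration
        params: dataset hyperparameters (random_seed, input_channels,
            batch_size, patch_size, supervision, ignored_labels, n_classes,
            loss, expand_dims)
    Returns:
        dataset, labels: the batched dataset and the labels of the selected
            pixels in their unshuffled order
    """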

    random_seed = params['random_seed']
    input_channels = params['input_channels']
    batch_size = params['batch_size']
    patch_size = params['patch_size']
    supervision = params['supervision']
    ignored_labels = set(params['ignored_labels'])
    num_classes = params['n_classes']
    loss = params['loss']
    expand_dims = params['expand_dims']

    # Determine if this dataset is feeding a multi-input model
    multi_input = True if input_channels is not None else False

    # Get expected data shapes to make sure data is properly formatted
    # if the Shape needs to be fixed
    if multi_input:
        x_shape = [[None, None, None, None, len(channels)] if expand_dims
                        else [None, None, None, len(channels)]
                        for channels in input_channels]
    else:
        x_shape = [None, None, None, None, data.shape[-1]] if expand_dims \
                    else [None, None, None, data.shape[-1]]

    if loss == 'categorical_crossentropy':
        y_shape = [None, num_classes]
    else:
        y_shape = [None, 1]


    # Fully supervised : use all pixels with label not ignored
    if supervision == "full":
        mask = np.ones_like(gt)
        for label in ignored_labels:
            mask[gt == label] = 0
    # Semi-supervised : use all pixels, except padding
    elif supervision == "semi":
        mask = np.ones_like(gt)
    x_pos, y_pos = np.nonzero(mask)
    num_neighbors = patch_size // 2
    indices = np.array(
        [
            (x, y)
            for x, y in zip(x_pos, y_pos)
            if x > num_neighbors
                and x < data.shape[0] - num_neighbors
                and y > num_neighbors
                and y < data.shape[1] - num_neighbors
        ]
    )

    labels = np.array([gt[x, y] for x, y in indices])

    class HSDataset:
        def __init__(self, data, gt, **params):

            # Save parameters
            self.data = data
            self.gt = gt

            self.input_channels = params['input_channels']
            self.patch_size = params['patch_size']
            self.num_classes = params['n_classes']
            self.loss = params['loss']
            self.expand_dims = params['expand_dims']

            # Determine if this dataset is feeding a multi-input model
            self.multi_input = True if self.input_channels is not None else False

        def __call__(self, i):
            i = tuple(i.numpy())

            x, y = i
            x1 = x - self.patch_size // 2    # Leftmost edge of patch
            y1 = y - self.patch_size // 2    # Topmost edge of patch
            x2 = x1 + self.patch_size        # Rightmost edge of patch
            y2 = y1 + self.patch_size        # Bottommost edge of patch

            patch = self.data[x1:x2, y1:y2]

            # Copy the data into numpy arrays
            patch = np.asarray(np.copy(patch), dtype="float32")
            # patch = tf.convert_to_tensor(patch, dtype="float32")

            # Patches are channels-last, so keep the spectral axis
            if self.patch_size == 1:
                patch = patch[0, 0, :]

            # Add a fourth dimension for 3D CNN
            if self.expand_dims and self.patch_size > 1:
                # Make 4D data ((Batch x) Planes x Channels x Width x Height)
                # E.g. adding a dimension for 'planes'
                axis = len(patch.shape) if K.image_data_format() == 'channels_last' else 0
                patch = np.expand_dims(patch, axis)

            if self.multi_input:
                # Break the data into inputs
                patch = [patch.take(channels, axis=data.ndim-1) for channels in self.input_channels]

            sample = patch

            # Get label for the patch
            label = self.gt[i]

            if self.loss == 'categorical_crossentropy':
                label = to_categorical(label, num_classes = self.num_classes)


            return sample, label

    class FixShape:
        def __init__(self, x_shape, y_shape, multi_input):
            self.x_shape = x_shape
            self.y_shape = y_shape
            self.multi_input = multi_input

        def __call__(self, x, y):
            if self.multi_input:
                _x = []
                for index, data in enumerate(x):
                    # set_shape() works in place and returns None
                    data.set_shape(self.x_shape[index])
                    _x.append(data)

                x = tuple(_x)
            else:
                x.set_shape(self.x_shape)

            y.set_shape(self.y_shape)

            return x, y

    hs_dataset = HSDataset(data, gt, **params)

    output_signature = tf.TensorSpec(shape=(2,), dtype=tf.uint32)
    dataset = tf.data.Dataset.from_generator(lambda: indices,
                                            output_signature=output_signature)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(indices), seed=random_seed,
                                reshuffle_each_iteration=True)

    if multi_input:
        Tout = [tf.TensorSpec(shape=(None, None, None, None), dtype=tf.float32),
                tf.TensorSpec(shape=(), dtype=tf.uint8)]
    else:
        Tout = [tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            tf.TensorSpec(shape=None, dtype=tf.uint8)]
    dataset = dataset.map(lambda i: tf.py_function(func=hs_dataset,
                                                   inp=[i],
                                                   Tout=Tout
                                                   ),
                          num_parallel_calls=tf.data.AUTOTUNE)

    # fixup_shape = FixShape(x_shape, y_shape, multi_input)
    # dataset = dataset.batch(batch_size).map(fixup_shape)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # dataset.prefetch(tf.data.AUTOTUNE)

    return dataset, labels


## 2.4) Dataset Creation Functions
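
Both functions in this section read their settings from keyword arguments. The dictionary below lists the keys they look up. Every value is an illustrative assumption, not a setting taken from the thesis experiments, and `num_samples` is a made-up count used only to show how the step counts are derived.

```python
import math

# Keys read by create_datasets / create_datasets_v2; all values are
# illustrative assumptions.
hyperparams = {
    'random_seed': 42,               # only read by create_datasets_v2
    'input_channels': None,          # None selects a single-input model
    'batch_size': 64,
    'patch_size': 5,                 # N in NxN patch per sample
    'train_split': 0.8,              # training share of the train/val split
    'split_mode': 'random',          # 'random', 'fixed', or 'disjoint'
    'supervision': 'full',           # 'full' or 'semi'
    'ignored_labels': [0],
    'n_classes': 21,
    'loss': 'categorical_crossentropy',
    'expand_dims': False,
    'flip_augmentation': True,
    'radiation_augmentation': False,
    'mixture_augmentation': False,
    'center_pixel': True,
    'skip_data_postprocessing': True,
}

# Steps per epoch follow the same ceil(samples / batch_size) rule that
# create_datasets_v2 applies to its label arrays
num_samples = 1000
steps = math.ceil(num_samples / hyperparams['batch_size'])
print(steps)  # 16
```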

In [None]:
#@title Create Datasets
def create_datasets(data, train_gt, test_gt, **hyperparams):
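    """Pad the image and ground truths, optionally split off a validation
    set, and build the train, validation, and test datasets.

    Returns:
        train_dataset, val_dataset, test_dataset: `HyperspectralDataset`
            objects (val_dataset is None when no validation split is made)
    """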

    # Get data from hyperparameters
    patch_size = hyperparams['patch_size']  # N in NxN patch per sample
    train_split = hyperparams['train_split']    # training percent in val/train split
    split_mode = hyperparams['split_mode']

    # Set pad length per dimension
    pad = patch_size // 2

    # Pad only first two dimensions
    ignore_dims = [x for x in range(data.ndim) if x >= 2]

    # Pad all images
    data = pad_img(data, pad, ignore_dims=ignore_dims)
    train_gt = pad_img(train_gt, pad)
    test_gt = pad_img(test_gt, pad)

    # Show updated padded dataset shapes
    print(f'padded data shape: {data.shape}')
    print(f'padded train_gt shape: {train_gt.shape}')
    print(f'padded test_gt shape: {test_gt.shape}')

    # Create validation dataset from training set
    if train_split is not None and train_split != 1.0:
        train_gt, val_gt = sample_gt(train_gt, train_split, mode=split_mode)

    dataset_params = (
        'input_channels',
        'batch_size',
        'patch_size',
        'supervision',
        'ignored_labels',
        'n_classes',
        'loss',
        'expand_dims',
        'flip_augmentation',
        'radiation_augmentation',
        'mixture_augmentation',
        'center_pixel',
    )

    # Create dataset parameter subset from hyperparameters
    params = {param: hyperparams[param] for param in dataset_params}

    train_dataset = HyperspectralDataset(data, train_gt, equal_class_distribution=True, **params)

    # Don't use augmentation for validation and test sets
    params['flip_augmentation'] = False
    params['radiation_augmentation'] = False
    params['mixture_augmentation'] = False
    if train_split is not None and train_split != 1.0:
        val_dataset = HyperspectralDataset(data, val_gt, shuffle=False, **params)
    else:
        val_dataset = None

    # If postprocessing is going to occur, change supervision parameter
    # to 'semi' so all pixels are used (so we can predict the full
    # image, the prediction then being used for postprocessing)
    if not hyperparams['skip_data_postprocessing']:
        params['supervision'] = 'semi'

    # Don't use augmentation for test set
    params['flip_augmentation'] = False
    params['radiation_augmentation'] = False
    params['mixture_augmentation'] = False
    test_dataset = HyperspectralDataset(data, test_gt, shuffle=False, **params)

    return train_dataset, val_dataset, test_dataset


In [None]:
#@title Create Datasets (Version 2)
def create_datasets_v2(data, train_gt, test_gt, **hyperparams):
    """
    """

    # Get data from hyperparameters
    patch_size = hyperparams['patch_size']  # N in NxN patch per sample
    train_split = hyperparams['train_split']    # training percent in val/train split
    split_mode = hyperparams['split_mode']
    batch_size = hyperparams['batch_size']

    # Set pad length per dimension
    pad = patch_size // 2

    # Pad only first two dimensions
    ignore_dims = [x for x in range(data.ndim) if x >= 2]

    # Pad all images
    data = pad_img(data, pad, ignore_dims=ignore_dims)
    train_gt = pad_img(train_gt, pad)
    test_gt = pad_img(test_gt, pad)

    # Show updated padded dataset shapes
    print(f'padded data shape: {data.shape}')
    print(f'padded train_gt shape: {train_gt.shape}')
    print(f'padded test_gt shape: {test_gt.shape}')

    # Create validation dataset from training set
    train_gt, val_gt = sample_gt(train_gt, train_split, mode=split_mode)

    dataset_params = (
        'random_seed',
        'input_channels',
        'batch_size',
        'patch_size',
        'supervision',
        'ignored_labels',
        'n_classes',
        'loss',
        'expand_dims',
        'flip_augmentation',
        'radiation_augmentation',
        'mixture_augmentation',
        'center_pixel',
    )

    # Create dataset parameter subset from hyperparameters
    params = {param: hyperparams[param] for param in dataset_params}

    train_dataset, train_labels = get_hyperspectral_dataset(data, train_gt, **params)
    val_dataset, val_labels = get_hyperspectral_dataset(data, val_gt, **params)

    # If postprocessing is going to occur, change supervision parameter
    # to 'semi' so all pixels are used (so we can predict the full
    # image, the prediction then being used for postprocessing)
    if not hyperparams['skip_data_postprocessing']:
        params['supervision'] = 'semi'

    test_dataset, target_test = get_hyperspectral_dataset(data, test_gt, shuffle=False, **params)

    datasets = {
        'train_dataset': train_dataset,
        'train_steps': math.ceil(len(train_labels) / batch_size),
        'val_dataset': val_dataset,
        'val_steps': math.ceil(len(val_labels) / batch_size),
        'test_dataset': test_dataset,
        'test_steps': math.ceil(len(target_test) / batch_size),
        'target_test': target_test,
    }

    return datasets

# 3) Set up GRSS DFC 2018 University of Houston Dataset

## 3.1 GRSS DFC 2018 UH constants
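
The tile constants in this section index a grid of 2 rows by 7 columns. The block below sketches how a `(row, column)` tile index could be turned into a pixel window on the 0.5 m rasters. The tile size is an assumption inferred from the coordinate spacing in the RGB file names (596 m east-west, 601 m north-south), and `tile_window` is a hypothetical helper that is not part of the notebook.

```python
TILE_WIDTH_M = 596    # assumed from file names: 272056 - 271460
TILE_HEIGHT_M = 601   # assumed from file names: 3290290 - 3289689
GSD_M = 0.5           # lidar ground sampling distance in meters

def tile_window(tile, gsd=GSD_M):
    """Return (row_offset, col_offset, height, width) in pixels."""
    row, col = tile
    height = round(TILE_HEIGHT_M / gsd)
    width = round(TILE_WIDTH_M / gsd)
    return row * height, col * width, height, width

print(tile_window((1, 1)))  # (1202, 1192, 1202, 1192)
```

Under these assumptions a tile is 1202 pixels tall at 0.5 m, which agrees with the 1202-pixel sub-image height quoted in the to-do list.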

In [None]:
#@title Initialize GRSS DFC 2018 UH constants

# Setup global variables for grss_dfc_2018 dataset
# Path to directory containing all GRSS 2018 Data Fusion Contest
# University of Houston image data
UH_2018_DATASET_DIRECTORY_PATH = '../datasets/grss_dfc_2018/'

# Following paths are assumed to be from the root UH 2018 dataset path
UH_2018_TRAINING_GT_IMAGE_PATH = 'TrainingGT/2018_IEEE_GRSS_DFC_GT_TR.tif'
UH_2018_TESTING_GT_IMAGE_PATH = 'TestingGT/Test_Labels.tif'
UH_2018_HS_IMAGE_PATH = 'FullHSIDataset/20170218_UH_CASI_S4_NAD83.pix'
UH_2018_LIDAR_DSM_PATH = 'Lidar GeoTiff Rasters/DSM_C12/UH17c_GEF051.tif'
UH_2018_LIDAR_DEM_3MSR_PATH = 'Lidar GeoTiff Rasters/DEM_C123_3msr/UH17_GEG051.tif'
UH_2018_LIDAR_DEM_TLI_PATH = 'Lidar GeoTiff Rasters/DEM_C123_TLI/UH17_GEG05.tif'
UH_2018_LIDAR_DEM_B_PATH = 'Lidar GeoTiff Rasters/DEM+B_C123/UH17_GEM051.tif'
UH_2018_LIDAR_INTENSITY_1550NM_PATH = 'Lidar GeoTiff Rasters/Intensity_C1/UH17_GI1F051.tif'
UH_2018_LIDAR_INTENSITY_1064NM_PATH = 'Lidar GeoTiff Rasters/Intensity_C2/UH17_GI2F051.tif'
UH_2018_LIDAR_INTENSITY_532NM_PATH = 'Lidar GeoTiff Rasters/Intensity_C3/UH17_GI3F051.tif'

# Paths, in order of tile, for the very high resolution RGB image
UH_2018_VHR_IMAGE_PATHS = [
    [
        'Final RGB HR Imagery/UH_NAD83_271460_3290290.tif',
        'Final RGB HR Imagery/UH_NAD83_272056_3290290.tif',
        'Final RGB HR Imagery/UH_NAD83_272652_3290290.tif',
        'Final RGB HR Imagery/UH_NAD83_273248_3290290.tif',
        'Final RGB HR Imagery/UH_NAD83_273844_3290290.tif',
        'Final RGB HR Imagery/UH_NAD83_274440_3290290.tif',
        'Final RGB HR Imagery/UH_NAD83_275036_3290290.tif',
    ],
    [
        'Final RGB HR Imagery/UH_NAD83_271460_3289689.tif',
        'Final RGB HR Imagery/UH_NAD83_272056_3289689.tif',
        'Final RGB HR Imagery/UH_NAD83_272652_3289689.tif',
        'Final RGB HR Imagery/UH_NAD83_273248_3289689.tif',
        'Final RGB HR Imagery/UH_NAD83_273844_3289689.tif',
        'Final RGB HR Imagery/UH_NAD83_274440_3289689.tif',
        'Final RGB HR Imagery/UH_NAD83_275036_3289689.tif',
    ],
]

# Number of rows and columns in the overall dataset image such that
# parts of the dataset can be matched with the ground truth
UH_2018_NUM_TILE_COLUMNS = 7
UH_2018_NUM_TILE_ROWS = 2

# Tuple of (row, column) tuples identifying the image tiles that match
# the ground truth for the training set
UH_2018_TRAINING_GT_TILES = ((1, 1), (1, 2), (1, 3), (1, 4))

# Tuple of (row, column) tuples identifying the image tiles that match
# the ground truth for the testing set
UH_2018_TESTING_GT_TILES = ((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (0,6),
                            (1,0),                             (1,5), (1,6))

# Mapping of overall dataset training ground truth tile indices to the
# tile indices of the actual ground truth image
UH_2018_TRAINING_GT_TILE_OFFSETS = {
    (1,1) : (0,0),
    (1,2) : (0,1),
    (1,3) : (0,2),
    (1,4) : (0,3)
}

# GSD = Ground Sampling Distance
UH_2018_VHR_GSD = 0.05  # Very high resolution image GSD in meters
UH_2018_HS_GSD = 1.0    # Hyperspectral image GSD in meters
UH_2018_LIDAR_GSD = 0.5 # LiDAR raster image GSD in meters
UH_2018_GT_GSD = 0.5    # Ground truth images GSD in meters

# Number of hyperspectral band channels
UH_2018_NUM_HS_BANDS = 48

# Threshold value for hyperspectral band intensities
UH_2018_HS_BAND_THRES = 100_000

# Number of multispectral LiDAR band channels
UH_2018_NUM_LIDAR_BANDS = 3

# Threshold value for LiDAR multispectral intensities
UH_2018_MS_INTENSITY_THRESHOLD = 1e4

# Threshold value for LiDAR DSM
UH_2018_DSM_THRESHOLD = 1e10

# A list of the wavelength values for each of the hyperspectral band
# channels
UH_2018_HS_BAND_WAVELENGTHS = [
    '374.395nm +/- 7.170nm',
    '388.733nm +/- 7.168nm',
    '403.068nm +/- 7.167nm',
    '417.401nm +/- 7.166nm',
    '431.732nm +/- 7.165nm',
    '446.061nm +/- 7.164nm',
    '460.388nm +/- 7.163nm',
    '474.712nm +/- 7.162nm',
    '489.035nm +/- 7.161nm',
    '503.356nm +/- 7.160nm',
    '517.675nm +/- 7.159nm',
    '531.992nm +/- 7.158nm',
    '546.308nm +/- 7.158nm',
    '560.622nm +/- 7.157nm',
    '574.936nm +/- 7.156nm',
    '589.247nm +/- 7.156nm',
    '603.558nm +/- 7.155nm',
    '617.868nm +/- 7.155nm',
    '632.176nm +/- 7.154nm',
    '646.484nm +/- 7.154nm',
    '660.791nm +/- 7.153nm',
    '675.097nm +/- 7.153nm',
    '689.402nm +/- 7.153nm',
    '703.707nm +/- 7.152nm',
    '718.012nm +/- 7.152nm',
    '732.316nm +/- 7.152nm',
    '746.620nm +/- 7.152nm',
    '760.924nm +/- 7.152nm',
    '775.228nm +/- 7.152nm',
    '789.532nm +/- 7.152nm',
    '803.835nm +/- 7.152nm',
    '818.140nm +/- 7.152nm',
    '832.444nm +/- 7.152nm',
    '846.749nm +/- 7.153nm',
    '861.054nm +/- 7.153nm',
    '875.360nm +/- 7.153nm',
    '889.666nm +/- 7.153nm',
    '903.974nm +/- 7.154nm',
    '918.282nm +/- 7.154nm',
    '932.591nm +/- 7.155nm',
    '946.901nm +/- 7.155nm',
    '961.212nm +/- 7.156nm',
    '975.525nm +/- 7.157nm',
    '989.839nm +/- 7.157nm',
    '1004.154nm +/- 7.158nm',
    '1018.471nm +/- 7.159nm',
    '1032.789nm +/- 7.160nm',
    '1047.109nm +/- 7.160nm',
]

# A list of hexadecimal color values corresponding to the wavelength of
# the hyperspectral bands
UH_2018_HS_BAND_RGB = [
    '#610061',  #374nm
    '#780088',  #389nm
    '#8300c0',  #403nm
    '#7100f4',  #417nm
    '#3300ff',  #432nm
    '#002fff',  #446nm
    '#007bff',  #460nm
    '#00c0ff',  #475nm
    '#00fbff',  #489nm
    '#00ff6e',  #503nm
    '#2dff00',  #518nm
    '#65ff00',  #532nm
    '#96ff00',  #546nm
    '#c6ff00',  #561nm
    '#f0ff00',  #575nm
    '#ffe200',  #589nm
    '#ffb000',  #604nm
    '#ff7e00',  #618nm
    '#ff4600',  #632nm
    '#ff0000',  #646nm
    '#fd0000',  #661nm
    '#fb0000',  #675nm
    '#fa0000',  #689nm
    '#f80000',  #704nm
    '#de0000',  #718nm
    '#c40000',  #732nm
    '#a70000',  #747nm
    '#8a0000',  #761nm
    '#6d0000',  #775nm
    '#610000',  #790nm (representation)
    '#5e0000',  #804nm (representation)
    '#5c0000',  #818nm (representation)
    '#590000',  #832nm (representation)
    '#570000',  #847nm (representation)
    '#540000',  #862nm (representation)
    '#510000',  #875nm (representation)
    '#4f0000',  #890nm (representation)
    '#4c0000',  #904nm (representation)
    '#4a0000',  #918nm (representation)
    '#470000',  #933nm (representation)
    '#440000',  #947nm (representation)
    '#420000',  #961nm (representation)
    '#3f0000',  #976nm (representation)
    '#3d0000',  #990nm (representation)
    '#3a0000',  #1004nm (representation)
    '#370000',  #1018nm (representation)
    '#350000',  #1033nm (representation)
    '#320000',  #1047nm (representation)
]

UH_2018_LIDAR_MS_BAND_WAVELENGTHS = [
    '1550nm',
    '1064nm',
    '532nm',
]

UH_2018_LIDAR_MS_BAND_RGB = [
    '#150000',  #1550nm (representation)
    '#2f0000',  #1064nm (representation)
    '#65ff00',  #532nm
]

UH_2018_VHR_CHANNELS = [
    'Red',
    'Green',
    'Blue',
]

UH_2018_VHR_RGB = [
    '#ff0000',  #Red
    '#00ff00',  #Green
    '#0000ff',  #Blue
]

UH_2018_IGNORED_CLASSES = [0]

# List of classes where the index is the value of the pixel in the
# ground truth image
UH_2018_CLASS_LIST = [
    'Undefined',
    'Healthy grass',
    'Stressed grass',
    'Artificial turf',
    'Evergreen trees',
    'Deciduous trees',
    'Bare earth',
    'Water',
    'Residential buildings',
    'Non-residential buildings',
    'Roads',
    'Sidewalks',
    'Crosswalks',
    'Major thoroughfares',
    'Highways',
    'Railways',
    'Paved parking lots',
    'Unpaved parking lots',
    'Cars',
    'Trains',
    'Stadium seats',
]

# Map of classes where the key is the value of the pixel in the
# ground truth image
UH_2018_CLASS_MAP = {index: label for index, label in enumerate(UH_2018_CLASS_LIST)}


# Number of class labels for the University of Houston 2018 dataset.
# The 'Undefined' class is currently counted; the commented-out line
# below is the alternative that excludes the ignored classes.
# UH_2018_NUM_CLASSES = len(UH_2018_CLASS_LIST) - len(UH_2018_IGNORED_CLASSES)
UH_2018_NUM_CLASSES = len(UH_2018_CLASS_LIST)

# The default resampling method to use when no method is indicated or
# the method indicated is unsupported
DEFAULT_RESAMPLING_METHOD = 'nearest'

# Mapping of resampling method strings to their associated value
#   - 'max', 'min', 'med', 'q1', 'q3' are only supported in GDAL >= 2.0.0.
#   - 'nearest', 'bilinear', 'cubic', 'cubic_spline', 'lanczos', 'average',
#      'mode' are always available (GDAL >= 1.10).
#   - 'sum' is only supported in GDAL >= 3.1.
#   - 'rms' is only supported in GDAL >= 3.3.
RESAMPLING_METHODS = {
    'nearest': Resampling.nearest,
    'bilinear': Resampling.bilinear,
    'cubic': Resampling.cubic,
    'cubic_spline': Resampling.cubic_spline,
    'lanczos': Resampling.lanczos,
    'average': Resampling.average,
    'mode': Resampling.mode,
    'gauss': Resampling.gauss,
    'max': Resampling.max,
    'min': Resampling.min,
    'med': Resampling.med,
    'q1': Resampling.q1,
    'q3': Resampling.q3,
    'sum': Resampling.sum,
    'rms': Resampling.rms
}
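
The tile constants above describe a 2 x 7 grid laid over the full scene, with the training ground truth covering only four tiles of the second row. The sketch below shows how a `(row, column)` tile could be turned into a pixel window and a flat list index. The function names and the 700 x 200 pixel raster size are hypothetical and chosen only to keep the arithmetic easy to follow; they are not the dataset's actual dimensions.

```python
# Illustrative copies of the grid constants defined above
NUM_TILE_ROWS = 2
NUM_TILE_COLUMNS = 7
TRAINING_GT_TILE_OFFSETS = {
    (1, 1): (0, 0),
    (1, 2): (0, 1),
    (1, 3): (0, 2),
    (1, 4): (0, 3),
}


def tile_window(tile_row, tile_col, image_width, image_height):
    """Return (col_off, row_off, width, height) in pixels for one tile."""
    tile_width = image_width / NUM_TILE_COLUMNS
    tile_height = image_height / NUM_TILE_ROWS
    return (tile_width * tile_col, tile_height * tile_row,
            tile_width, tile_height)


def flat_tile_index(tile_row, tile_col):
    """Return the position of a tile in a row-major list of tiles."""
    return tile_row * NUM_TILE_COLUMNS + tile_col


# Hypothetical full-scene raster of 700 x 200 pixels
print(tile_window(1, 2, 700, 200))  # (200.0, 100.0, 100.0, 100.0)
print(flat_tile_index(1, 2))        # 9

# The same tile inside the smaller training ground truth image is
# addressed through its offset entry rather than its scene position
offset_row, offset_col = TRAINING_GT_TILE_OFFSETS[(1, 2)]
print((100.0 * offset_col, 100.0 * offset_row))  # (100.0, 0.0)
```

The offset table exists because the training ground truth raster starts at scene tile (1, 1), so its own tile (0, 0) corresponds to that scene position.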

## 3.2) Class for loading and manipulating different parts of the GRSS 2018 Data Fusion Contest University of Houston dataset

In [None]:
#@title University of Houston 2018 Dataset Class
class UH_2018_Dataset:
    """
    Class for loading and manipulating different parts of the GRSS 2018
    Data Fusion Contest University of Houston dataset.
    """

    def __init__(self, dataset_path=UH_2018_DATASET_DIRECTORY_PATH):

        # Set dataset attributes
        self.name = 'GRSS_DFC_2018_UH'

        # Set dataset file paths
        self.path_to_dataset_directory = dataset_path
        self.path_to_training_gt_image = UH_2018_TRAINING_GT_IMAGE_PATH
        self.path_to_testing_gt_image = UH_2018_TESTING_GT_IMAGE_PATH
        self.path_to_hs_image = UH_2018_HS_IMAGE_PATH
        self.path_to_lidar_dsm = UH_2018_LIDAR_DSM_PATH
        self.path_to_lidar_dem_3msr = UH_2018_LIDAR_DEM_3MSR_PATH
        self.path_to_lidar_dem_tli = UH_2018_LIDAR_DEM_TLI_PATH
        self.path_to_lidar_dem_b = UH_2018_LIDAR_DEM_B_PATH
        self.path_to_lidar_1550nm_intensity = UH_2018_LIDAR_INTENSITY_1550NM_PATH
        self.path_to_lidar_1064nm_intensity = UH_2018_LIDAR_INTENSITY_1064NM_PATH
        self.path_to_lidar_532nm_intensity = UH_2018_LIDAR_INTENSITY_532NM_PATH
        self.paths_to_vhr_images = UH_2018_VHR_IMAGE_PATHS

        # Set dataset ground truth attributes
        self.gt_class_label_list = UH_2018_CLASS_LIST
        self.gt_class_value_mapping = UH_2018_CLASS_MAP
        self.gt_num_classes = UH_2018_NUM_CLASSES
        self.gt_ignored_labels = UH_2018_IGNORED_CLASSES

        # Set dataset hyperspectral image attributes
        self.hs_num_bands = UH_2018_NUM_HS_BANDS
        self.hs_band_rgb_list = UH_2018_HS_BAND_RGB
        self.hs_band_wavelength_labels = UH_2018_HS_BAND_WAVELENGTHS
        self.hs_band_val_thres = UH_2018_HS_BAND_THRES

        # Set dataset lidar data attributes
        self.lidar_ms_num_bands = UH_2018_NUM_LIDAR_BANDS
        self.lidar_ms_band_wavelength_labels = UH_2018_LIDAR_MS_BAND_WAVELENGTHS
        self.lidar_ms_band_rgb_list = UH_2018_LIDAR_MS_BAND_RGB
        self.lidar_ms_intensity_thres = UH_2018_MS_INTENSITY_THRESHOLD
        self.lidar_dsm_thres = UH_2018_DSM_THRESHOLD

        # Set dataset VHR RGB image attributes
        self.vhr_channel_labels = UH_2018_VHR_CHANNELS
        self.vhr_channel_rgb_list = UH_2018_VHR_RGB

        # Set miscellaneous dataset attributes
        self.gsd_gt = UH_2018_GT_GSD
        self.gsd_hs = UH_2018_HS_GSD
        self.gsd_lidar = UH_2018_LIDAR_GSD
        self.gsd_vhr = UH_2018_VHR_GSD

        self.dataset_tiled_subset_rows = UH_2018_NUM_TILE_ROWS
        self.dataset_tiled_subset_cols = UH_2018_NUM_TILE_COLUMNS
        self.dataset_training_subset = UH_2018_TRAINING_GT_TILES
        self.dataset_training_subset_map = UH_2018_TRAINING_GT_TILE_OFFSETS
        self.dataset_testing_subset = UH_2018_TESTING_GT_TILES

        # Initialize dataset variables
        self.gt_image = None
        self.gt_image_tiles = None
        self.hs_image = None
        self.hs_image_tiles = None
        self.lidar_ms_image = None
        self.lidar_ms_image_tiles = None
        self.lidar_dsm_image = None
        self.lidar_dsm_image_tiles = None
        self.lidar_dem_image = None
        self.lidar_dem_image_tiles = None
        self.lidar_ndsm_image = None
        self.lidar_ndsm_image_tiles = None
        self.vhr_image = None
        self.vhr_image_tiles = None



    def clear_all_images(self):
        """Clears values of all image variables to free memory."""

        # Delete the variables to mark the memory as unused
        del self.gt_image
        del self.gt_image_tiles
        del self.hs_image
        del self.hs_image_tiles
        del self.lidar_ms_image
        del self.lidar_ms_image_tiles
        del self.lidar_dsm_image
        del self.lidar_dsm_image_tiles
        del self.lidar_dem_image
        del self.lidar_dem_image_tiles
        del self.lidar_ndsm_image
        del self.lidar_ndsm_image_tiles
        del self.vhr_image
        del self.vhr_image_tiles

        # Run garbage collection to release the memory
        gc.collect()

        # Reinitialize the variables
        self.gt_image = None
        self.gt_image_tiles = None
        self.hs_image = None
        self.hs_image_tiles = None
        self.lidar_ms_image = None
        self.lidar_ms_image_tiles = None
        self.lidar_dsm_image = None
        self.lidar_dsm_image_tiles = None
        self.lidar_dem_image = None
        self.lidar_dem_image_tiles = None
        self.lidar_ndsm_image = None
        self.lidar_ndsm_image_tiles = None
        self.vhr_image = None
        self.vhr_image_tiles = None



    def clear_gt_images(self):
        """Clears values of ground truth image variables to free memory."""
        # Delete the variables to mark the memory as unused
        del self.gt_image
        del self.gt_image_tiles

        # Run garbage collection to release the memory
        gc.collect()

        # Reinitialize the variables
        self.gt_image = None
        self.gt_image_tiles = None



    def clear_hs_images(self):
        """Clears values of hyperspectral image variables to free memory."""

        # Delete the variables to mark the memory as unused
        del self.hs_image
        del self.hs_image_tiles

        # Run garbage collection to release the memory
        gc.collect()

        # Reinitialize the variables
        self.hs_image = None
        self.hs_image_tiles = None



    def clear_lidar_ms_images(self):
        """
        Clears values of lidar multispectral image variables to free
        memory.
        """

        # Delete the variables to mark the memory as unused
        del self.lidar_ms_image
        del self.lidar_ms_image_tiles

        # Run garbage collection to release the memory
        gc.collect()

        # Reinitialize the variables
        self.lidar_ms_image = None
        self.lidar_ms_image_tiles = None



    def clear_lidar_dsm_images(self):
        """
        Clears values of lidar digital surface model (DSM) image
        variables to free memory.
        """

        # Delete the variables to mark the memory as unused
        del self.lidar_dsm_image
        del self.lidar_dsm_image_tiles

        # Run garbage collection to release the memory
        gc.collect()

        # Reinitialize the variables
        self.lidar_dsm_image = None
        self.lidar_dsm_image_tiles = None

    def clear_lidar_dem_images(self):
        """
        Clears values of lidar digital elevation model (DEM) image
        variables to free memory.
        """

        # Delete the variables to mark the memory as unused
        del self.lidar_dem_image
        del self.lidar_dem_image_tiles

        # Run garbage collection to release the memory
        gc.collect()

        # Reinitialize the variables
        self.lidar_dem_image = None
        self.lidar_dem_image_tiles = None

    def clear_lidar_ndsm_images(self):
        """
        Clears values of lidar normalized digital surface model (NDSM)
        image variables to free memory.
        """

        # Delete the variables to mark the memory as unused
        del self.lidar_ndsm_image
        del self.lidar_ndsm_image_tiles

        # Run garbage collection to release the memory
        gc.collect()

        # Reinitialize the variables
        self.lidar_ndsm_image = None
        self.lidar_ndsm_image_tiles = None



    def clear_vhr_images(self):
        """
        Clears values of very high resolution image variables to free
        memory.
        """

        # Delete the variables to mark the memory as unused
        del self.vhr_image
        del self.vhr_image_tiles

        # Run garbage collection to release the memory
        gc.collect()

        # Reinitialize the variables
        self.vhr_image = None
        self.vhr_image_tiles = None



    def merge_tiles(self, tiles, num_rows = None, num_cols = None):
        """Merges a set of image tiles into a single image."""

        print('Merging image tiles...')

        # If rows or columns are not specified, use defaults
        if num_rows is None: num_rows = self.dataset_tiled_subset_rows
        if num_cols is None: num_cols = self.dataset_tiled_subset_cols

        # Initialize empty list of image row values
        image_rows = []

        # Loop through each tile and stitch them together into single image
        for row in range(0, num_rows):
            # Get first tile in row
            img_row = np.copy(tiles[row * num_cols])

            # Loop through remaining tiles in current row
            for col in range(1, num_cols):
                # Concatenate each subsequent tile in row to image row array
                img_row = np.concatenate((img_row, tiles[row*num_cols + col]), axis=1)

            # Append image row to list of image rows
            image_rows.append(img_row)

        # Concatenate all image rows together to create single image
        merged_image = np.concatenate(image_rows, axis=0)

        return merged_image



    def load_full_gt_image(self, train_only=False, test_only=False):
        """
        Loads the full-size ground truth image mask for the University
        of Houston 2018 dataset.
        """

        if train_only:
            print('Loading training ground truth image...')
        elif test_only:
            print('Loading test ground truth image...')
        else:
            print('Loading full ground truth image...')

        # Ground truth can only be loaded as tiles since there are two
        # images, so load the tiles and then merge them to create the
        # full GT image
        self.gt_image = self.merge_tiles(
            self.load_gt_image_tiles(train_only=train_only,
                                     test_only=test_only))

        return self.gt_image



    def load_gt_image_tiles(self, tile_list=None, train_only=False, test_only=False):
        """
        Loads the University of Houston 2018 dataset's ground truth
        images as a set of tiles. If no tile list is given, the whole
        image will be loaded as tiles.
        """

        if train_only:
            print('Loading training ground truth tiles...')
        elif test_only:
            print('Loading test ground truth tiles...')
        else:
            print('Loading training and test ground truth tiles...')

        self.gt_image_tiles = []

        # train_only flag takes priority
        if train_only: test_only = False

        # Get full path to dataset's ground truth images
        train_image_path = os.path.join(self.path_to_dataset_directory,
                                        self.path_to_training_gt_image)
        test_image_path = os.path.join(self.path_to_dataset_directory,
                                       self.path_to_testing_gt_image)

        # Throw error if file path does not exist
        if not os.path.isfile(train_image_path): raise FileNotFoundError(
            'Path to UH2018 training ground truth image is invalid! '
            f'Path={train_image_path}')

        if not os.path.isfile(test_image_path): raise FileNotFoundError(
            'Path to UH2018 testing ground truth image is invalid! '
            f'Path={test_image_path}')

        with rasterio.open(train_image_path) as train_src, \
             rasterio.open(test_image_path) as test_src:

            # Get the size of the tile windows (use full size test image)
            tile_width = test_src.width / self.dataset_tiled_subset_cols
            tile_height = test_src.height / self.dataset_tiled_subset_rows

            # Read in the image data for each image tile
            for tile_row in range(0, self.dataset_tiled_subset_rows):
                for tile_column in range(0, self.dataset_tiled_subset_cols):

                    # Check to see if current tile is one of the training
                    # ground truth tiles
                    if not test_only and (tile_row, tile_column) in self.dataset_training_subset:

                        offset_row, offset_column = self.dataset_training_subset_map[
                                                        (tile_row, tile_column)]

                        # Set the tile window to read from the image
                        window = Window(tile_width * offset_column ,
                                        tile_height * offset_row,
                                        tile_width, tile_height)

                        # Read the tile window from the image
                        tile = train_src.read(1, window = window)

                        # Copy the tile to the tiles array
                        self.gt_image_tiles.append(np.copy(tile))
                    elif not train_only and (tile_row, tile_column) in self.dataset_testing_subset:
                        # Set the tile window to read from the image
                        window = Window(tile_width * tile_column,
                                        tile_height * tile_row,
                                        tile_width, tile_height)

                        # Read the tile window from the image
                        tile = test_src.read(1, window = window)

                        # Copy the tile to the tiles array
                        self.gt_image_tiles.append(np.copy(tile))
                    else:
                        self.gt_image_tiles.append(np.zeros(
                            (int(tile_height), int(tile_width)),
                            dtype=np.uint8))

        return self.gt_image_tiles



    def save_full_gt_image_array(self, path, file_name='full_gt_image.npy'):
        """
        Saves the numpy array of the full ground truth image to a file
        for faster loading in the future.
        """

        print(f'Saving full ground truth image numpy array to file ({file_name})...')

        # If the gt image member variable is empty, then load the full
        # ground truth image
        if self.gt_image is None: self.load_full_gt_image()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.gt_image)



    def save_tiled_gt_image_array(self, path, file_name='tiled_gt_image.npy'):
        """
        Saves the numpy array of the tiled ground truth image to a file
        for faster loading in the future.
        """

        print(f'Saving tiled ground truth image numpy array to file ({file_name})...')

        # If the gt image tile member variable is empty, then load all
        # ground truth image tiles
        if self.gt_image_tiles is None: self.load_gt_image_tiles()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.gt_image_tiles)



    def load_full_gt_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset ground truth image.
        """

        print(f'Loading full ground truth image numpy array from file ({file_path})...')

        with open(file_path, 'rb') as infile:
            self.gt_image = np.load(infile)



    def load_tiled_gt_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset ground truth image tiles.
        """

        print(f'Loading tiled ground truth image numpy array from file ({file_path})...')

        with open(file_path, 'rb') as infile:
            self.gt_image_tiles = np.load(infile)



    def get_gt_class_statistics(self, tiles=None, print_results=False):
        """
        Outputs statistics for each class per ground truth tile (or, if
        no tile or set of tiles is specified, the whole ground truth).
        """

        # If the gt image tile member variable is empty, then load all
        # ground truth image tiles
        if self.gt_image_tiles is None: self.load_gt_image_tiles()

        # Initialize statistics dictionary
        statistics = {}

        # If no tiles are specified, use all ground truth tiles
        if tiles is None:
            tiles = []
            for row in range(0, self.dataset_tiled_subset_rows):
                for col in range(0, self.dataset_tiled_subset_cols):
                    tiles.append((row, col))

        # Iterate through the tile list to get statistics for each
        # individual tile
        for tile in tiles:
            row, col = tile
            index = row * self.dataset_tiled_subset_cols + col

            # Create tile statistics mapping
            tile_statistics = {x:0 for x in range(0,len(self.gt_class_label_list))}

            # Count the class of each pixel in the image tile
            for pixel in np.ravel(self.gt_image_tiles[index]):
                tile_statistics[pixel] += 1

            # Create key/value pair for statistics dictionary
            key = f'Tile ({row}, {col})'
            value = [tile_statistics[i] for i in range(1,len(self.gt_class_label_list))]

            # Add key value pair to dictionary
            statistics[key] = value

        # Create Pandas DataFrame from statistics dictionary and set the
        # index to be the class labels
        statistics_df = pd.DataFrame(data=statistics)
        statistics_df.index = self.gt_class_label_list[1:]

        # Print out statistics
        if print_results:
            print(statistics_df)
            print()
            print(statistics_df.T.describe())

        return statistics_df



    def load_full_hs_image(self, gsd=UH_2018_GT_GSD,
                           thres=True, normalize=True,
                           resampling=None):
        """
        Loads the full-size hyperspectral image for the University of
        Houston 2018 dataset sampled at the specified GSD.
        """

        print('Loading full hyperspectral image...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]

        # Set the factor of GSD resampling
        resample_factor = self.gsd_hs / float(gsd)

        # Get full path to dataset's hyperspectral image
        image_path = os.path.join(self.path_to_dataset_directory,
                                  self.path_to_hs_image)

        # Throw error if file path does not exist
        if not os.path.isfile(image_path): raise FileNotFoundError(
            f'Path to UH2018 hyperspectral image is invalid! Path={image_path}')

        # Create image variable
        self.hs_image = None

        # Open the training HSI Envi file as src
        with rasterio.open(image_path, format='ENVI') as src:

            # Read the image, resample it to the appropriate GSD,
            # arrange the numpy array to be (rows, cols, bands),
            # and remove unused bands
            if resample_factor == 1.0:
                self.hs_image = np.moveaxis(src.read(), 0, -1)[:,:,:-2]
            else:
                # Set the shape of the resampled image
                out_shape=(src.count,
                            int(src.height * resample_factor),
                            int(src.width * resample_factor))

                self.hs_image = np.moveaxis(src.read(
                            out_shape=out_shape,
                            resampling=resampling), 0, -1)[:,:,:-2]

            # Cast image array as float type for normalization
            if normalize:
                self.hs_image = self.hs_image.astype(float, copy=False)

            # Threshold the image so that any value over the threshold
            # is set to the image minimum
            if thres:
                self.hs_image[self.hs_image > self.hs_band_val_thres] = self.hs_image.min()

            # Normalize all intensities to the range [0.0, 1.0] using the
            # global minimum and maximum taken across every band
            if normalize:
                self.hs_image -= self.hs_image.min()
                self.hs_image /= self.hs_image.max()


        return self.hs_image



    def load_hs_image_tiles(self, gsd=UH_2018_GT_GSD, tile_list=None,
                            thres=True, normalize=True, resampling=None):
        """
        Loads the University of Houston 2018 dataset's hyperspectral
        images as a set of tiles sampled at a specified GSD. If no tile
        list is given, the whole image will be loaded as tiles.
        """

        print('Loading hyperspectral image as tiles...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        # Check tile_list parameter value
        if tile_list and not isinstance(tile_list, tuple): raise ValueError(
            "'tile_list' parameter should be a tuple of tuples!")

        # Initialize list for tiles
        self.hs_image_tiles = []

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]

        # Set the factor of GSD resampling
        resample_factor = self.gsd_hs / float(gsd)

        # Get full path to dataset's hyperspectral image
        image_path = os.path.join(self.path_to_dataset_directory,
                                  self.path_to_hs_image)

        # Throw error if file path does not exist
        if not os.path.isfile(image_path): raise FileNotFoundError(
            f'Path to UH2018 hyperspectral image is invalid! Path={image_path}')

        # Open the training HSI Envi file as src
        with rasterio.open(image_path, format='ENVI') as src:
            # Get the size of the tile windows
            tile_width = src.width / self.dataset_tiled_subset_cols
            tile_height = src.height / self.dataset_tiled_subset_rows

            # Set the shape of the resampled tile
            out_shape=(src.count,
                    int(tile_height * resample_factor),
                    int(tile_width * resample_factor))

            # Read in the image data for each image tile
            for tile_row in range(0, self.dataset_tiled_subset_rows):
                for tile_column in range(0, self.dataset_tiled_subset_cols):

                    # If specified tiles are desired, then skip any
                    # tiles that do not match the tile_list parameter
                    if tile_list and (tile_row, tile_column) not in tile_list:
                        continue

                    # Set the tile window to read from the image
                    window = Window(tile_width * tile_column,
                                    tile_height * tile_row,
                                    tile_width, tile_height)

                    # Read the tile window from the image, resample it to
                    # the appropriate GSD, arrange the numpy array to be
                    # (rows, cols, bands), and remove unused bands
                    if resample_factor == 1.0:
                        tile = np.moveaxis(src.read(window = window), 0, -1)[:,:,:-2]
                    else:
                        tile = np.moveaxis(src.read(
                            window = window,
                            out_shape=out_shape,
                            resampling=resampling), 0, -1)[:,:,:-2]

                    # Copy the tile to the tiles array
                    self.hs_image_tiles.append(np.copy(tile))

        # If no tiles were added to the tile list, then set image tiles
        # variable to 'None'
        if len(self.hs_image_tiles) == 0: self.hs_image_tiles = None
        else:
            # Turn list of numpy arrays into single numpy array
            self.hs_image_tiles = np.stack(self.hs_image_tiles)

            # Cast image array as float type for normalization
            if normalize:
                self.hs_image_tiles = self.hs_image_tiles.astype(float, copy=False)

            # Threshold the image tiles so that any value over the threshold
            # is set to the image minimum
            if thres:
                self.hs_image_tiles[self.hs_image_tiles > self.hs_band_val_thres] = self.hs_image_tiles.min()

            # Normalize the whole array to the range [0.0, 1.0] using its
            # global minimum and maximum (not per band)
            if normalize:
                self.hs_image_tiles -= self.hs_image_tiles.min()
                self.hs_image_tiles /= self.hs_image_tiles.max()

        return self.hs_image_tiles



    def save_full_hs_image_array(self, path, file_name='full_hs_image.npy'):
        """
        Saves the numpy array of the full hyperspectral image to a file
        for faster loading in the future.
        """

        # If the hs image member variable is empty, then load the full
        # hyperspectral image
        if self.hs_image is None: self.load_full_hs_image()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.hs_image)



    def save_tiled_hs_image_array(self, path, file_name='tiled_hs_image.npy'):
        """
        Saves the numpy array of the tiled hyperspectral image to a file
        for faster loading in the future.
        """

        # If the hs image tile member variable is empty, then load all
        # hyperspectral image tiles
        if self.hs_image_tiles is None: self.load_hs_image_tiles()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.hs_image_tiles)




    def load_full_hs_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset hyperspectral image.
        """

        with open(file_path, 'rb') as infile:
            self.hs_image = np.load(infile)



    def load_tiled_hs_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset hyperspectral image tiles.
        """

        with open(file_path, 'rb') as infile:
            self.hs_image_tiles = np.load(infile)



    def show_hs_image(self, rgb_channels=None, size=(15,9),
                      full_gt_overlay=False,
                      train_gt_overlay=False,
                      test_gt_overlay=False):
        """
        Displays the hyperspectral image using the specified band
        channels as the rgb values.
        """

        # If the hs image member variable is empty, then load the full
        # hyperspectral image
        if self.hs_image is None: self.load_full_hs_image()

        image = self.hs_image[:,:,:]

        if full_gt_overlay:
            # If the ground truth image member variable is empty, then
            # load the full ground truth image
            if self.gt_image is None: self.load_full_gt_image()

            classes = self.gt_image
            title = 'Hyperspectral image w/ ground truth overlay'
        elif test_gt_overlay or train_gt_overlay:
            # If the ground truth image tiles member variable is empty,
            # then load the ground truth image tiles
            if self.gt_image_tiles is None: self.load_gt_image_tiles()

            # create a copy of ground truth image tiles
            gt_tiles = self.gt_image_tiles.copy()

            # Get tile dimensions
            tile_shape = gt_tiles[0].shape

            # Choose which set of tiles to set to zero values and
            # set proper image title
            if train_gt_overlay:
                tiles_to_remove = self.dataset_testing_subset
                title = 'Hyperspectral image w/ training ground truth overlay'
            else:
                tiles_to_remove = self.dataset_training_subset
                title = 'Hyperspectral image w/ testing ground truth overlay'

            # Zero out tiles not in the desired subset
            for tile in tiles_to_remove:
                row, col = tile
                index = row * self.dataset_tiled_subset_cols + col
                gt_tiles[index] = np.zeros(tile_shape, dtype=np.uint8)

            # Create single ground truth image mask
            classes = self.merge_tiles(gt_tiles)

        else:
            classes = None
            title = 'Hyperspectral image'

        plt.close('all')

        view = spectral.imshow(image,
                               source=image,
                               bands=rgb_channels,
                               classes=classes,
                               figsize=size)
        if (full_gt_overlay
            or test_gt_overlay
            or train_gt_overlay): view.set_display_mode('overlay')

        view.set_title(title)

        plt.show(block=True)



    def visualize_hs_data_cube(self, size=(1200,900)):
        """
        Creates 3-D visualization of the hyperspectral data cube for
        the hyperspectral image.
        """

        import wx

        # If the hs image member variable is empty, then load the full
        # hyperspectral image
        if self.hs_image is None: self.load_full_hs_image()

        # Setup WxApp to display 3D spectral cube
        app = wx.App(False)

        # View 3D hyperspectral cube image
        spectral.view_cube(self.hs_image, size=size)

        # Prevent app from closing immediately
        app.MainLoop()



    def load_full_lidar_ms_image(self, gsd=UH_2018_GT_GSD,
                                 thres=True, normalize=True,
                                 resampling=None):
        """
        Loads the full-size lidar multispectral intensity image for the
        University of Houston 2018 dataset sampled at the specified GSD.
        """

        print('Loading full LiDAR multispectral intensity image...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]

        # Set the factor of GSD resampling
        resample_factor = self.gsd_lidar / float(gsd)

        # Get full paths to the dataset's LiDAR intensity images
        c1_image_path = os.path.join(self.path_to_dataset_directory,
                                     self.path_to_lidar_1550nm_intensity)
        c2_image_path = os.path.join(self.path_to_dataset_directory,
                                     self.path_to_lidar_1064nm_intensity)
        c3_image_path = os.path.join(self.path_to_dataset_directory,
                                     self.path_to_lidar_532nm_intensity)


        # Throw error if file paths do not exist
        if not os.path.isfile(c1_image_path): raise FileNotFoundError(
            f'Path to UH2018 1550nm LiDAR intensity image is invalid! Path={c1_image_path}')
        if not os.path.isfile(c2_image_path): raise FileNotFoundError(
            f'Path to UH2018 1064nm LiDAR intensity image is invalid! Path={c2_image_path}')
        if not os.path.isfile(c3_image_path): raise FileNotFoundError(
            f'Path to UH2018 532nm LiDAR intensity image is invalid! Path={c3_image_path}')


        # Create image variable
        self.lidar_ms_image = None

        # Open the LiDAR multispectral intensity image files as c1_src
        # (1550nm), c2_src (1064nm), and c3_src (532nm)
        with rasterio.open(c1_image_path) as c1_src, \
             rasterio.open(c2_image_path) as c2_src, \
             rasterio.open(c3_image_path) as c3_src:

            # Read the image, resample it to the appropriate GSD,
            # arrange the numpy array to be (rows, cols, bands)
            if resample_factor == 1.0:
                c1 = np.moveaxis(c1_src.read(), 0, -1)
                c2 = np.moveaxis(c2_src.read(), 0, -1)
                c3 = np.moveaxis(c3_src.read(), 0, -1)
            else:
                # Set the shape of the resampled image
                c1_out_shape=(c1_src.count,
                            int(c1_src.height * resample_factor),
                            int(c1_src.width * resample_factor))
                c2_out_shape=(c2_src.count,
                            int(c2_src.height * resample_factor),
                            int(c2_src.width * resample_factor))
                c3_out_shape=(c3_src.count,
                            int(c3_src.height * resample_factor),
                            int(c3_src.width * resample_factor))

                # Read the image, resample it to the appropriate GSD,
                # arrange the numpy array to be (rows, cols, bands)
                c1 = np.moveaxis(c1_src.read(out_shape=c1_out_shape,
                                resampling=resampling), 0, -1)
                c2 = np.moveaxis(c2_src.read(out_shape=c2_out_shape,
                                resampling=resampling), 0, -1)
                c3 = np.moveaxis(c3_src.read(out_shape=c3_out_shape,
                                resampling=resampling), 0, -1)

            # Stack each intensity band into a single cube
            self.lidar_ms_image = np.dstack((c1, c2, c3))

            # Threshold the image so that any value over the threshold
            # is set to the image minimum
            if thres:
                self.lidar_ms_image[self.lidar_ms_image > self.lidar_ms_intensity_thres] = self.lidar_ms_image.min()

            # Normalize the whole cube to the range [0.0, 1.0] using its
            # global minimum and maximum (not per band)
            if normalize:
                self.lidar_ms_image -= self.lidar_ms_image.min()
                self.lidar_ms_image /= self.lidar_ms_image.max()

        return self.lidar_ms_image



    def load_lidar_ms_image_tiles(self, gsd=UH_2018_GT_GSD, tile_list=None,
                                  thres=True, normalize=True,
                                  resampling=None):
        """
        Loads the University of Houston 2018 dataset's lidar
        multispectral intensity image as a set of tiles sampled at a
        specified GSD. If no tile list is given, the whole image will be
        loaded as tiles.
        """

        print('Loading LiDAR multispectral intensity image tiles...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        # Check tile_list parameter value
        if tile_list and not isinstance(tile_list, tuple): raise ValueError(
            "'tile_list' parameter should be a tuple of tuples!")

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]

        # Set the factor of GSD resampling
        resample_factor = self.gsd_lidar / float(gsd)

        # Get full paths to the dataset's LiDAR intensity images
        c1_image_path = os.path.join(self.path_to_dataset_directory,
                                     self.path_to_lidar_1550nm_intensity)
        c2_image_path = os.path.join(self.path_to_dataset_directory,
                                     self.path_to_lidar_1064nm_intensity)
        c3_image_path = os.path.join(self.path_to_dataset_directory,
                                     self.path_to_lidar_532nm_intensity)


        # Throw error if file paths do not exist
        if not os.path.isfile(c1_image_path): raise FileNotFoundError(
            f'Path to UH2018 1550nm LiDAR intensity image is invalid! Path={c1_image_path}')
        if not os.path.isfile(c2_image_path): raise FileNotFoundError(
            f'Path to UH2018 1064nm LiDAR intensity image is invalid! Path={c2_image_path}')
        if not os.path.isfile(c3_image_path): raise FileNotFoundError(
            f'Path to UH2018 532nm LiDAR intensity image is invalid! Path={c3_image_path}')


        # Create image variable
        self.lidar_ms_image_tiles = []

        # Open the LiDAR multispectral intensity image files as c1_src
        # (1550nm), c2_src (1064nm), and c3_src (532nm)
        with rasterio.open(c1_image_path) as c1_src, \
             rasterio.open(c2_image_path) as c2_src, \
             rasterio.open(c3_image_path) as c3_src:

            # Get the size of the tile windows for each channel
            c1_tile_width = c1_src.width / self.dataset_tiled_subset_cols
            c1_tile_height = c1_src.height / self.dataset_tiled_subset_rows
            c2_tile_width = c2_src.width / self.dataset_tiled_subset_cols
            c2_tile_height = c2_src.height / self.dataset_tiled_subset_rows
            c3_tile_width = c3_src.width / self.dataset_tiled_subset_cols
            c3_tile_height = c3_src.height / self.dataset_tiled_subset_rows

            # Set the shape of the resampled tile (based on the tile
            # size, not the full image size)
            c1_out_shape=(c1_src.count,
                        int(c1_tile_height * resample_factor),
                        int(c1_tile_width * resample_factor))
            c2_out_shape=(c2_src.count,
                        int(c2_tile_height * resample_factor),
                        int(c2_tile_width * resample_factor))
            c3_out_shape=(c3_src.count,
                        int(c3_tile_height * resample_factor),
                        int(c3_tile_width * resample_factor))

            # Read in the image data for each image tile
            for tile_row in range(0, self.dataset_tiled_subset_rows):
                for tile_column in range(0, self.dataset_tiled_subset_cols):

                    # If specified tiles are desired, then skip any
                    # tiles that do not match the tile_list parameter
                    if tile_list and (tile_row, tile_column) not in tile_list:
                        continue

                    # Set the tile window to read from the image
                    c1_window = Window(c1_tile_width * tile_column,
                                    c1_tile_height * tile_row,
                                    c1_tile_width, c1_tile_height)
                    c2_window = Window(c2_tile_width * tile_column,
                                    c2_tile_height * tile_row,
                                    c2_tile_width, c2_tile_height)
                    c3_window = Window(c3_tile_width * tile_column,
                                    c3_tile_height * tile_row,
                                    c3_tile_width, c3_tile_height)

                    # Read the tile window from the image, resample it to
                    # the appropriate GSD, arrange the numpy array to be
                    # (rows, cols, bands)
                    if resample_factor == 1.0:
                        c1_tile = np.moveaxis(c1_src.read(window = c1_window), 0, -1)
                        c2_tile = np.moveaxis(c2_src.read(window = c2_window), 0, -1)
                        c3_tile = np.moveaxis(c3_src.read(window = c3_window), 0, -1)
                    else:
                        c1_tile = np.moveaxis(c1_src.read(
                                        window = c1_window,
                                        out_shape=c1_out_shape,
                                        resampling=resampling), 0, -1)
                        c2_tile = np.moveaxis(c2_src.read(
                                        window = c2_window,
                                        out_shape=c2_out_shape,
                                        resampling=resampling), 0, -1)
                        c3_tile = np.moveaxis(c3_src.read(
                                        window = c3_window,
                                        out_shape=c3_out_shape,
                                        resampling=resampling), 0, -1)

                    # Copy the tile to the tiles array
                    self.lidar_ms_image_tiles.append(np.dstack((c1_tile,
                                                                c2_tile,
                                                                c3_tile)))

        # If no tiles were added to the tile list, then set image tiles
        # variable to 'None'
        if len(self.lidar_ms_image_tiles) == 0: self.lidar_ms_image_tiles = None
        else:
            # Turn list of numpy arrays into single numpy array
            self.lidar_ms_image_tiles = np.stack(self.lidar_ms_image_tiles)

            # Threshold the image tiles so that any value over the threshold
            # is set to the image minimum
            if thres:
                self.lidar_ms_image_tiles[self.lidar_ms_image_tiles > self.lidar_ms_intensity_thres] = self.lidar_ms_image_tiles.min()

            # Normalize the whole array to the range [0.0, 1.0] using its
            # global minimum and maximum (not per band)
            if normalize:
                self.lidar_ms_image_tiles -= self.lidar_ms_image_tiles.min()
                self.lidar_ms_image_tiles /= self.lidar_ms_image_tiles.max()

        return self.lidar_ms_image_tiles



    def save_full_lidar_ms_image_array(self, path, file_name='full_lidar_multispectral_image.npy'):
        """
        Saves the numpy array of the full lidar multispectral image to a
        file for faster loading in the future.
        """

        # If the lidar multispectral intensity image member variable is
        # empty, then load the full lidar multispectral intensity image
        if self.lidar_ms_image is None: self.load_full_lidar_ms_image()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.lidar_ms_image)



    def save_tiled_lidar_ms_image_array(self, path, file_name='tiled_lidar_multispectral_image.npy'):
        """
        Saves the numpy array of the tiled lidar multispectral image to
        a file for faster loading in the future.
        """

        # If the lidar multispectral intensity image tile member
        # variable is empty, then load all lidar multispectral intensity
        # image tiles
        if self.lidar_ms_image_tiles is None: self.load_lidar_ms_image_tiles()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.lidar_ms_image_tiles)




    def load_full_lidar_ms_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset lidar multispectral intensity image.
        """

        with open(file_path, 'rb') as infile:
            self.lidar_ms_image = np.load(infile)



    def load_tiled_lidar_ms_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset lidar multispectral intensity image tiles.
        """

        with open(file_path, 'rb') as infile:
            self.lidar_ms_image_tiles = np.load(infile)



    def show_lidar_ms_image(self, size=(15,9),
                            full_gt_overlay=False,
                            train_gt_overlay=False,
                            test_gt_overlay=False):
        """
        Displays the multispectral lidar image in RGB with red=1550nm,
        green=1064nm, and blue=532nm bands, with optional ground truth
        overlay on the image.
        """

        # If the LiDAR multispectral intensity image member variable is
        # empty, then load the full lidar multispectral intensity image
        if self.lidar_ms_image is None: self.load_full_lidar_ms_image()

        image = self.lidar_ms_image[:,:,:]

        if full_gt_overlay:
            # If the ground truth image member variable is empty, then
            # load the full ground truth image
            if self.gt_image is None: self.load_full_gt_image()

            classes = self.gt_image
            title = 'LiDAR multispectral intensity image w/ ground truth overlay'
        elif test_gt_overlay or train_gt_overlay:
            # If the ground truth image tiles member variable is empty,
            # then load the ground truth image tiles
            if self.gt_image_tiles is None: self.load_gt_image_tiles()

            # create a copy of ground truth image tiles
            gt_tiles = self.gt_image_tiles.copy()

            # Get tile dimensions
            tile_shape = gt_tiles[0].shape

            # Choose which set of tiles to set to zero values and
            # set proper image title
            if train_gt_overlay:
                tiles_to_remove = self.dataset_testing_subset
                title = 'LiDAR multispectral intensity image w/ training ground truth overlay'
            else:
                tiles_to_remove = self.dataset_training_subset
                title = 'LiDAR multispectral intensity image w/ testing ground truth overlay'

            # Zero out tiles not in the desired subset
            for tile in tiles_to_remove:
                row, col = tile
                index = row * self.dataset_tiled_subset_cols + col
                gt_tiles[index] = np.zeros(tile_shape, dtype=np.uint8)

            # Create single ground truth image mask
            classes = self.merge_tiles(gt_tiles)

        else:
            classes = None
            title = 'LiDAR multispectral intensity image'

        plt.close('all')

        view = spectral.imshow(image,
                               source=image,
                               classes=classes,
                               figsize=size)
        if (full_gt_overlay
            or test_gt_overlay
            or train_gt_overlay): view.set_display_mode('overlay')

        view.set_title(title)

        plt.show(block=True)



    def load_full_lidar_dsm_image(self, gsd=UH_2018_GT_GSD,
                                  thres=True, normalize=True,
                                  resampling=None):
        """
        Loads the full-size LiDAR digital surface model (DSM) image for
        the University of Houston 2018 dataset sampled at the specified
        GSD.
        """

        print('Loading full LiDAR DSM image...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]

        # Set the factor of GSD resampling
        resample_factor = self.gsd_lidar / float(gsd)

        # Get full path to dataset's LiDAR DSM image
        image_path = os.path.join(self.path_to_dataset_directory,
                                  self.path_to_lidar_dsm)

        # Throw error if file paths do not exist
        if not os.path.isfile(image_path): raise FileNotFoundError(
            f'Path to UH2018 LiDAR DSM image is invalid! Path={image_path}')

        # Create image variable
        self.lidar_dsm_image = None

        # Open the LiDAR DSM file as src
        with rasterio.open(image_path) as src:

            # Read the image, resample it to the appropriate GSD, and
            # arrange the numpy array to be (rows, cols, bands)
            if resample_factor == 1.0:
                self.lidar_dsm_image = np.moveaxis(src.read(), 0, -1)
            else:
                # Set the shape of the resampled image
                out_shape=(src.count,
                        int(src.height * resample_factor),
                        int(src.width * resample_factor))

                self.lidar_dsm_image = np.moveaxis(src.read(
                                                out_shape=out_shape,
                                                resampling=resampling), 0, -1)

            # Threshold the image so that any value over the threshold
            # is set to the image minimum
            if thres:
                self.lidar_dsm_image[self.lidar_dsm_image > self.lidar_dsm_thres] = self.lidar_dsm_image.min()

            # Normalize the DSM values to the range [0.0, 1.0] using the
            # global minimum and maximum
            if normalize:
                self.lidar_dsm_image -= self.lidar_dsm_image.min()
                self.lidar_dsm_image /= self.lidar_dsm_image.max()

        return self.lidar_dsm_image



    def load_lidar_dsm_image_tiles(self, gsd=UH_2018_GT_GSD, tile_list=None,
                                   thres=True, normalize=True, resampling=None):
        """
        Loads the University of Houston 2018 dataset's LiDAR digital
        surface model (DSM) image as a set of tiles sampled at a
        specified GSD. If no tile list is given, the whole image will be
        loaded as tiles.
        """

        print('Loading LiDAR DSM image tiles...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        # Check tile_list parameter value
        if tile_list and not isinstance(tile_list, tuple): raise ValueError(
            "'tile_list' parameter should be a tuple of tuples!")

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]


        # Set the factor of GSD resampling
        resample_factor = self.gsd_lidar / float(gsd)

        # Get full path to dataset's LiDAR DSM image
        image_path = os.path.join(self.path_to_dataset_directory,
                                  self.path_to_lidar_dsm)

        # Throw error if file paths do not exist
        if not os.path.isfile(image_path): raise FileNotFoundError(
            f'Path to UH2018 LiDAR DSM image is invalid! Path={image_path}')

        # Create image variable
        self.lidar_dsm_image_tiles = []

        # Open the LiDAR DSM file as src
        with rasterio.open(image_path) as src:

            # Get the size of the tile windows
            tile_width = src.width / self.dataset_tiled_subset_cols
            tile_height = src.height / self.dataset_tiled_subset_rows

            # Set the shape of the resampled tile (based on the tile
            # size, not the full image size)
            out_shape=(src.count,
                       int(tile_height * resample_factor),
                       int(tile_width * resample_factor))

            # Read in the image data for each image tile
            for tile_row in range(0, self.dataset_tiled_subset_rows):
                for tile_column in range(0, self.dataset_tiled_subset_cols):

                    # If specified tiles are desired, then skip any
                    # tiles that do not match the tile_list parameter
                    if tile_list and (tile_row, tile_column) not in tile_list:
                        continue

                    # Set the tile window to read from the image
                    window = Window(tile_width * tile_column,
                                    tile_height * tile_row,
                                    tile_width, tile_height)


                    # Read the tile window from the image, resample it to
                    # the appropriate GSD, and arrange the numpy array to
                    # be (rows, cols, bands)
                    if resample_factor == 1.0:
                        tile = np.moveaxis(src.read(window = window), 0, -1)
                    else:
                        tile = np.moveaxis(src.read(
                                        window = window,
                                        out_shape=out_shape,
                                        resampling=resampling), 0, -1)

                    # Copy the tile to the tiles array
                    self.lidar_dsm_image_tiles.append(tile)

        # If no tiles were added to the tile list, then set image tiles
        # variable to 'None'
        if len(self.lidar_dsm_image_tiles) == 0: self.lidar_dsm_image_tiles = None
        else:
            # Turn list of numpy arrays into single numpy array
            self.lidar_dsm_image_tiles = np.stack(self.lidar_dsm_image_tiles)

            # Threshold the image tiles so that any value over the threshold
            # is set to the image minimum
            if thres:
                self.lidar_dsm_image_tiles[self.lidar_dsm_image_tiles > self.lidar_dsm_thres] = self.lidar_dsm_image_tiles.min()

            # Normalize the DSM values to the range [0.0, 1.0] using the
            # global minimum and maximum
            if normalize:
                self.lidar_dsm_image_tiles -= self.lidar_dsm_image_tiles.min()
                self.lidar_dsm_image_tiles /= self.lidar_dsm_image_tiles.max()

        return self.lidar_dsm_image_tiles



    def save_full_lidar_dsm_image_array(self, path, file_name='full_lidar_dsm_image.npy'):
        """
        Saves the numpy array of the full LiDAR DSM image to a
        file for faster loading in the future.
        """

        # If the lidar dsm image member variable is empty, then load the full
        # lidar image
        if self.lidar_dsm_image is None: self.load_full_lidar_dsm_image()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.lidar_dsm_image)



    def save_tiled_lidar_dsm_image_array(self, path, file_name='tiled_lidar_dsm_image.npy'):
        """
        Saves the numpy array of the tiled LiDAR DSM image to
        a file for faster loading in the future.
        """

        # If the lidar dsm image tile member variable is empty, then load all
        # lidar dsm image tiles
        if self.lidar_dsm_image_tiles is None: self.load_lidar_dsm_image_tiles()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.lidar_dsm_image_tiles)




    def load_full_lidar_dsm_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset LiDAR digital surface model image.
        """

        with open(file_path, 'rb') as infile:
            self.lidar_dsm_image = np.load(infile)



    def load_tiled_lidar_dsm_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset LiDAR digital surface model image tiles.
        """

        with open(file_path, 'rb') as infile:
            self.lidar_dsm_image_tiles = np.load(infile)



    def show_lidar_dsm_image(self, size=(15,9),
                             full_gt_overlay=False,
                             train_gt_overlay=False,
                             test_gt_overlay=False):
        """
        Displays the LiDAR DSM image with optional ground truth
        overlay on the image.
        """

        # If the lidar dsm image member variable is empty, then load the
        # full lidar dsm image
        if self.lidar_dsm_image is None: self.load_full_lidar_dsm_image()

        image = self.lidar_dsm_image[:,:,:]

        if full_gt_overlay:
            # If the ground truth image member variable is empty, then
            # load the full ground truth image
            if self.gt_image is None: self.load_full_gt_image()

            classes = self.gt_image
            title = 'LiDAR Digital Surface Model (DSM) image w/ ground truth overlay'
        elif test_gt_overlay or train_gt_overlay:
            # If the ground truth tiles image member variable is empty,
            # then load the ground truth image tiles
            if self.gt_image_tiles is None: self.load_gt_image_tiles()

            # create a copy of ground truth image tiles
            gt_tiles = self.gt_image_tiles.copy()

            # Get tile dimensions
            tile_shape = gt_tiles[0].shape

            # Choose which set of tiles to set to zero values and
            # set proper image title
            if train_gt_overlay:
                tiles_to_remove = self.dataset_testing_subset
                title = 'LiDAR Digital Surface Model (DSM) image w/ training ground truth overlay'
            else:
                tiles_to_remove = self.dataset_training_subset
                title = 'LiDAR Digital Surface Model (DSM) image w/ testing ground truth overlay'

            # Zero out tiles not in the desired subset
            for tile in tiles_to_remove:
                row, col = tile
                index = row * self.dataset_tiled_subset_cols + col
                gt_tiles[index] = np.zeros(tile_shape, dtype=np.uint8)

            # Create single ground truth image mask
            classes = self.merge_tiles(gt_tiles)

        else:
            classes = None
            title = 'LiDAR Digital Surface Model (DSM) image'

        plt.close('all')

        view = spectral.imshow(image,
                               source=image,
                               classes=classes,
                               figsize=size)
        if (full_gt_overlay
            or test_gt_overlay
            or train_gt_overlay): view.set_display_mode('overlay')

        view.set_title(title)

        plt.show(block=True)



    def load_full_lidar_dem_image(self, gsd=UH_2018_GT_GSD,
                                  use_void_filling_model = False,
                                  use_hybrid_model = True,
                                  thres=True, normalize=True,
                                  resampling=None):
        """
        Loads the full-size lidar digital elevation model (DEM) image
        for the University of Houston 2018 dataset sampled at the
        specified GSD.
        """

        if use_hybrid_model:
            dem_path = self.path_to_lidar_dem_b
            print('Loading full LiDAR hybrid DEM image...')
        elif use_void_filling_model:
            dem_path = self.path_to_lidar_dem_tli
            print('Loading full LiDAR bare-earth elevation w/ void filling DEM image...')
        else:
            dem_path = self.path_to_lidar_dem_3msr
            print('Loading full LiDAR bare-earth elevation DEM image...')

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        # Set the factor of GSD resampling
        resample_factor = self.gsd_lidar / float(gsd)

        # Get full path to dataset's LiDAR DEM image
        image_path = os.path.join(self.path_to_dataset_directory, dem_path)

        # Throw error if file paths do not exist
        if not os.path.isfile(image_path): raise FileNotFoundError(
            f'Path to UH2018 LiDAR DEM image is invalid! Path={image_path}')

        # Create image variable
        self.lidar_dem_image = None

        # Open the LiDAR DEM file as src
        with rasterio.open(image_path) as src:

            # Read the image, resample it to the appropriate GSD,
            # arrange the numpy array to be (rows, cols, bands),
            # and remove unused bands
            if resample_factor == 1.0:
                self.lidar_dem_image = np.moveaxis(src.read(), 0, -1)
            else:
                # Set the shape of the resampled image
                out_shape=(src.count,
                        int(src.height * resample_factor),
                        int(src.width * resample_factor))

                self.lidar_dem_image = np.moveaxis(src.read(
                                                out_shape=out_shape,
                                                resampling=resampling), 0, -1)

            # Threshold the image so that any value over the threshold
            # is set to the image minimum
            if thres:
                self.lidar_dem_image[self.lidar_dem_image > self.lidar_dsm_thres] = self.lidar_dem_image.min()

            # Normalize the elevation values between 0.0 and 1.0
            if normalize:
                self.lidar_dem_image -= self.lidar_dem_image.min()
                self.lidar_dem_image /= self.lidar_dem_image.max()

        return self.lidar_dem_image



    def load_lidar_dem_image_tiles(self, gsd=UH_2018_GT_GSD, tile_list=None,
                                   use_void_filling_model = False,
                                   use_hybrid_model = True,
                                   thres=True, normalize=True,
                                   resampling=None):
        """
        Loads the University of Houston 2018 dataset's lidar digital
        elevation model (DEM) image as a set of tiles sampled at a
        specified GSD. If no tile list is given, the whole image will be
        loaded as tiles.
        """

        if use_hybrid_model:
            dem_path = self.path_to_lidar_dem_b
            print('Loading LiDAR hybrid DEM image tiles...')
        elif use_void_filling_model:
            dem_path = self.path_to_lidar_dem_tli
            print('Loading LiDAR bare-earth elevation w/ void filling DEM image tiles...')
        else:
            dem_path = self.path_to_lidar_dem_3msr
            print('Loading LiDAR bare-earth elevation DEM image tiles...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        # Check tile_list parameter value
        if tile_list and not isinstance(tile_list, tuple): raise ValueError(
            "'tile_list' parameter should be a tuple of tuples!")

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]


        # Set the factor of GSD resampling
        resample_factor = self.gsd_lidar / float(gsd)

        # Get full path to dataset's LiDAR DEM image
        image_path = os.path.join(self.path_to_dataset_directory, dem_path)

        # Throw error if file paths do not exist
        if not os.path.isfile(image_path): raise FileNotFoundError(
            f'Path to UH2018 LiDAR DEM image is invalid! Path={image_path}')

        # Create image variable
        self.lidar_dem_image_tiles = []

        # Open the LiDAR DEM file as src
        with rasterio.open(image_path) as src:

            # Get the size of the tile windows
            tile_width = src.width / self.dataset_tiled_subset_cols
            tile_height = src.height / self.dataset_tiled_subset_rows

            # Set the shape of each resampled tile; the window read below
            # covers a single tile, so the output shape must be the tile
            # size (not the full image size) scaled by the resample factor
            out_shape=(src.count,
                       int(tile_height * resample_factor),
                       int(tile_width * resample_factor))

            # Read in the image data for each image tile
            for tile_row in range(0, self.dataset_tiled_subset_rows):
                for tile_column in range(0, self.dataset_tiled_subset_cols):

                    # If specified tiles are desired, then skip any
                    # tiles that do not match the tile_list parameter
                    if tile_list and (tile_row, tile_column) not in tile_list:
                        continue

                    # Set the tile window to read from the image
                    window = Window(tile_width * tile_column,
                                    tile_height * tile_row,
                                    tile_width, tile_height)

                    # Read the tile window from the image, resample it to
                    # the appropriate GSD, arrange the numpy array to be
                    # (rows, cols, bands), and remove unused bands
                    if resample_factor == 1.0:
                        tile = np.moveaxis(src.read(window = window), 0, -1)
                    else:
                        tile = np.moveaxis(src.read(
                                        window = window,
                                        out_shape=out_shape,
                                        resampling=resampling), 0, -1)

                    # Copy the tile to the tiles array
                    self.lidar_dem_image_tiles.append(tile)

        # If no tiles were added to the tile list, then set image tiles
        # variable to 'None'
        if len(self.lidar_dem_image_tiles) == 0: self.lidar_dem_image_tiles = None
        else:
            # Turn list of numpy arrays into single numpy array
            self.lidar_dem_image_tiles = np.stack(self.lidar_dem_image_tiles)

            # Threshold the image tiles so that any value over the threshold
            # is set to the image minimum
            if thres:
                self.lidar_dem_image_tiles[self.lidar_dem_image_tiles > self.lidar_dsm_thres] = self.lidar_dem_image_tiles.min()

            # Normalize the elevation values between 0.0 and 1.0
            if normalize:
                self.lidar_dem_image_tiles -= self.lidar_dem_image_tiles.min()
                self.lidar_dem_image_tiles /= self.lidar_dem_image_tiles.max()

        return self.lidar_dem_image_tiles



    def save_full_lidar_dem_image_array(self, path, file_name='full_lidar_dem_image.npy'):
        """
        Saves the numpy array of the full lidar dem image to a
        file for faster loading in the future.
        """

        # If the lidar dem image member variable is empty, then load the full
        # lidar dem image
        if self.lidar_dem_image is None: self.load_full_lidar_dem_image()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.lidar_dem_image)



    def save_tiled_lidar_dem_image_array(self, path, file_name='tiled_lidar_dem_image.npy'):
        """
        Saves the numpy array of the tiled lidar dem image to
        a file for faster loading in the future.
        """

        # If the lidar dem image tile member variable is empty, then load all
        # lidar dem image tiles
        if self.lidar_dem_image_tiles is None: self.load_lidar_dem_image_tiles()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.lidar_dem_image_tiles)




    def load_full_lidar_dem_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset lidar digital elevation model image.
        """

        with open(file_path, 'rb') as infile:
            self.lidar_dem_image = np.load(infile)



    def load_tiled_lidar_dem_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset lidar digital elevation model image tiles.
        """

        with open(file_path, 'rb') as infile:
            self.lidar_dem_image_tiles = np.load(infile)



    def show_lidar_dem_image(self, size=(15,9),
                            full_gt_overlay=False,
                            train_gt_overlay=False,
                            test_gt_overlay=False):
        """
        Displays the lidar dem image with optional ground truth
        overlay on the image.
        """

        # If the lidar dem image member variable is empty, then load the
        # full lidar dem image
        if self.lidar_dem_image is None: self.load_full_lidar_dem_image()

        image = self.lidar_dem_image[:,:,:]

        if full_gt_overlay:
            # If the ground truth image member variable is empty, then
            # load the full ground truth image
            if self.gt_image is None: self.load_full_gt_image()

            classes = self.gt_image
            title = 'LiDAR Digital Elevation Model (DEM) image w/ ground truth overlay'
        elif test_gt_overlay or train_gt_overlay:
            # If the ground truth tiles image member variable is empty,
            # then load the ground truth image tiles
            if self.gt_image_tiles is None: self.load_gt_image_tiles()

            # create a copy of ground truth image tiles
            gt_tiles = self.gt_image_tiles.copy()

            # Get tile dimensions
            tile_shape = gt_tiles[0].shape

            # Choose which set of tiles to set to zero values and
            # set proper image title
            if train_gt_overlay:
                tiles_to_remove = self.dataset_testing_subset
                title = 'LiDAR Digital Elevation Model (DEM) image w/ training ground truth overlay'
            else:
                tiles_to_remove = self.dataset_training_subset
                title = 'LiDAR Digital Elevation Model (DEM) image w/ testing ground truth overlay'

            # Zero out tiles not in the desired subset
            for tile in tiles_to_remove:
                row, col = tile
                index = row * self.dataset_tiled_subset_cols + col
                gt_tiles[index] = np.zeros(tile_shape, dtype=np.uint8)

            # Create single ground truth image mask
            classes = self.merge_tiles(gt_tiles)

        else:
            classes = None
            title = 'LiDAR Digital Elevation Model (DEM) image'

        plt.close('all')

        view = spectral.imshow(image,
                               source=image,
                               classes=classes,
                               figsize=size)
        if (full_gt_overlay
            or test_gt_overlay
            or train_gt_overlay): view.set_display_mode('overlay')

        view.set_title(title)

        plt.show(block=True)



    def load_full_lidar_ndsm_image(self, gsd=UH_2018_GT_GSD,
                                   use_void_filling_model = False,
                                   use_hybrid_model = True,
                                   thres=True, normalize=True,
                                   resampling=None):
        """
        Loads the full-size LiDAR normalized digital surface model
        (NDSM) image for the University of Houston 2018 dataset sampled
        at the specified GSD.
        """

        if use_hybrid_model:
            dem_path = self.path_to_lidar_dem_b
            print('Loading full LiDAR NDSM image using hybrid DEM...')
        elif use_void_filling_model:
            dem_path = self.path_to_lidar_dem_tli
            print('Loading full LiDAR NDSM image using bare-earth elevation w/ void filling DEM...')
        else:
            dem_path = self.path_to_lidar_dem_3msr
            print('Loading full LiDAR NDSM image using bare-earth elevation DEM...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]


        # Set the factor of GSD resampling
        resample_factor = self.gsd_lidar / float(gsd)

        # Get full paths to dataset's LiDAR DSM and DEM images
        dsm_image_path = os.path.join(self.path_to_dataset_directory,
                                  self.path_to_lidar_dsm)
        dem_image_path = os.path.join(self.path_to_dataset_directory, dem_path)

        # Throw error if file paths do not exist
        if not os.path.isfile(dsm_image_path): raise FileNotFoundError(
            f'Path to UH2018 LiDAR DSM image is invalid! Path={dsm_image_path}')

        if not os.path.isfile(dem_image_path): raise FileNotFoundError(
            f'Path to UH2018 LiDAR DEM image is invalid! Path={dem_image_path}')

        # Create image variable
        self.lidar_ndsm_image = None

        # Open the LiDAR DSM and DEM files as dsm_src and dem_src
        with rasterio.open(dsm_image_path) as dsm_src, \
             rasterio.open(dem_image_path) as dem_src:

            # Read the images, resample them to the appropriate GSD,
            # and arrange the numpy arrays to be (rows, cols, bands)
            if resample_factor == 1.0:
                dsm_image = np.moveaxis(dsm_src.read(), 0, -1)
                dem_image = np.moveaxis(dem_src.read(), 0, -1)
            else:
                # Set the shape of the resampled image
                dsm_out_shape=(dsm_src.count,
                            int(dsm_src.height * resample_factor),
                            int(dsm_src.width * resample_factor))
                dem_out_shape=(dem_src.count,
                            int(dem_src.height * resample_factor),
                            int(dem_src.width * resample_factor))

                dsm_image = np.moveaxis(dsm_src.read(
                                        out_shape=dsm_out_shape,
                                        resampling=resampling), 0, -1)
                dem_image = np.moveaxis(dem_src.read(
                                        out_shape=dem_out_shape,
                                        resampling=resampling), 0, -1)

            # Threshold the images so that any value over the threshold
            # is set to the image minimum
            if thres:
                dsm_image[dsm_image > self.lidar_dsm_thres] = dsm_image.min()
                dem_image[dem_image > self.lidar_dsm_thres] = dem_image.min()

            # NDSM is the difference between the DSM and the DEM
            self.lidar_ndsm_image = dsm_image - dem_image

            # Normalize the NDSM height values between 0.0 and 1.0
            if normalize:
                self.lidar_ndsm_image -= self.lidar_ndsm_image.min()
                self.lidar_ndsm_image /= self.lidar_ndsm_image.max()

        return self.lidar_ndsm_image



    def load_lidar_ndsm_image_tiles(self, gsd=UH_2018_GT_GSD, tile_list=None,
                                    use_void_filling_model = False,
                                    use_hybrid_model = True,
                                    thres=True, normalize=True,
                                    resampling=None):
        """
        Loads the University of Houston 2018 dataset's LiDAR normalized
        digital surface model (NDSM) image as a set of tiles sampled at
        a specified GSD. If no tile list is given, the whole image will
        be loaded as tiles.
        """

        if use_hybrid_model:
            dem_path = self.path_to_lidar_dem_b
            print('Loading LiDAR NDSM image tiles using hybrid DEM...')
        elif use_void_filling_model:
            dem_path = self.path_to_lidar_dem_tli
            print('Loading LiDAR NDSM image tiles using bare-earth elevation w/ void filling DEM...')
        else:
            dem_path = self.path_to_lidar_dem_3msr
            print('Loading LiDAR NDSM image tiles using bare-earth elevation DEM...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        # Check tile_list parameter value
        if tile_list and not isinstance(tile_list, tuple): raise ValueError(
            "'tile_list' parameter should be a tuple of tuples!")

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]


        # Set the factor of GSD resampling
        resample_factor = self.gsd_lidar / float(gsd)

        # Get full paths to dataset's LiDAR DSM and DEM images
        dsm_image_path = os.path.join(self.path_to_dataset_directory,
                                  self.path_to_lidar_dsm)
        dem_image_path = os.path.join(self.path_to_dataset_directory, dem_path)

        # Throw error if file paths do not exist
        if not os.path.isfile(dsm_image_path): raise FileNotFoundError(
            f'Path to UH2018 LiDAR DSM image is invalid! Path={dsm_image_path}')

        if not os.path.isfile(dem_image_path): raise FileNotFoundError(
            f'Path to UH2018 LiDAR DEM image is invalid! Path={dem_image_path}')


        # Create image variable
        self.lidar_ndsm_image_tiles = []

        # Open the LiDAR DSM and DEM files as dsm_src and dem_src
        with rasterio.open(dsm_image_path) as dsm_src, \
             rasterio.open(dem_image_path) as dem_src:

            # Get the size of the tile windows
            dsm_tile_width = dsm_src.width / self.dataset_tiled_subset_cols
            dsm_tile_height = dsm_src.height / self.dataset_tiled_subset_rows
            dem_tile_width = dem_src.width / self.dataset_tiled_subset_cols
            dem_tile_height = dem_src.height / self.dataset_tiled_subset_rows

            # Set the shape of each resampled tile; the window reads below
            # cover a single tile, so the output shape must be the tile
            # size (not the full image size) scaled by the resample factor
            dsm_out_shape=(dsm_src.count,
                           int(dsm_tile_height * resample_factor),
                           int(dsm_tile_width * resample_factor))
            dem_out_shape=(dem_src.count,
                           int(dem_tile_height * resample_factor),
                           int(dem_tile_width * resample_factor))

            # Read in the image data for each image tile
            for tile_row in range(0, self.dataset_tiled_subset_rows):
                for tile_column in range(0, self.dataset_tiled_subset_cols):

                    # If specified tiles are desired, then skip any
                    # tiles that do not match the tile_list parameter
                    if tile_list and (tile_row, tile_column) not in tile_list:
                        continue

                    # Set the tile windows to read from the image
                    dsm_window = Window(dsm_tile_width * tile_column,
                                       dsm_tile_height * tile_row,
                                       dsm_tile_width, dsm_tile_height)
                    dem_window = Window(dem_tile_width * tile_column,
                                       dem_tile_height * tile_row,
                                       dem_tile_width, dem_tile_height)

                    # Read the tile window from the image, resample it to
                    # the appropriate GSD, arrange the numpy array to be
                    # (rows, cols, bands), and remove unused bands
                    if resample_factor == 1.0:
                        dsm_tile = np.moveaxis(dsm_src.read(
                                            window = dsm_window), 0, -1)

                        dem_tile = np.moveaxis(dem_src.read(
                                            window = dem_window), 0, -1)
                    else:
                        dsm_tile = np.moveaxis(dsm_src.read(
                                            window = dsm_window,
                                            out_shape=dsm_out_shape,
                                            resampling=resampling), 0, -1)

                        dem_tile = np.moveaxis(dem_src.read(
                                            window = dem_window,
                                            out_shape=dem_out_shape,
                                            resampling=resampling), 0, -1)


                    # Threshold the image tiles so that any value over
                    # the threshold is set to the image minimum
                    if thres:
                        dsm_tile[dsm_tile > self.lidar_dsm_thres] = dsm_tile.min()
                        dem_tile[dem_tile > self.lidar_dsm_thres] = dem_tile.min()

                    # NDSM is the difference between the DSM and the DEM
                    ndsm_tile = dsm_tile - dem_tile

                    # Copy the tile to the tiles array
                    self.lidar_ndsm_image_tiles.append(ndsm_tile)

        # If no tiles were added to the tile list, then set image tiles
        # variable to 'None'
        if len(self.lidar_ndsm_image_tiles) == 0: self.lidar_ndsm_image_tiles = None
        else:
            # Turn list of numpy arrays into single numpy array
            self.lidar_ndsm_image_tiles = np.stack(self.lidar_ndsm_image_tiles)

            # Normalize the NDSM height values between 0.0 and 1.0
            if normalize:
                self.lidar_ndsm_image_tiles -= self.lidar_ndsm_image_tiles.min()
                self.lidar_ndsm_image_tiles /= self.lidar_ndsm_image_tiles.max()

        return self.lidar_ndsm_image_tiles



    def save_full_lidar_ndsm_image_array(self, path, file_name='full_lidar_ndsm_image.npy'):
        """
        Saves the numpy array of the full LiDAR ndsm image to a
        file for faster loading in the future.
        """

        # If the lidar ndsm image member variable is empty, then load the full
        # lidar image
        if self.lidar_ndsm_image is None: self.load_full_lidar_ndsm_image()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.lidar_ndsm_image)



    def save_tiled_lidar_ndsm_image_array(self, path, file_name='tiled_lidar_ndsm_image.npy'):
        """
        Saves the numpy array of the tiled LiDAR ndsm image to
        a file for faster loading in the future.
        """

        # If the lidar ndsm image tile member variable is empty, then load all
        # lidar ndsm image tiles
        if self.lidar_ndsm_image_tiles is None: self.load_lidar_ndsm_image_tiles()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.lidar_ndsm_image_tiles)




    def load_full_lidar_ndsm_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset LiDAR normalized digital surface model image.
        """

        with open(file_path, 'rb') as infile:
            self.lidar_ndsm_image = np.load(infile)



    def load_tiled_lidar_ndsm_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset LiDAR normalized digital surface model image tiles.
        """

        with open(file_path, 'rb') as infile:
            self.lidar_ndsm_image_tiles = np.load(infile)



    def show_lidar_ndsm_image(self, size=(15,9),
                            full_gt_overlay=False,
                            train_gt_overlay=False,
                            test_gt_overlay=False):
        """
        Displays the LiDAR ndsm image with optional ground truth
        overlay on the image.
        """

        # If the lidar ndsm image member variable is empty, then load the
        # full lidar ndsm image
        if self.lidar_ndsm_image is None: self.load_full_lidar_ndsm_image()

        image = self.lidar_ndsm_image[:,:,:]

        if full_gt_overlay:
            # If the ground truth image member variable is empty, then
            # load the full ground truth image
            if self.gt_image is None: self.load_full_gt_image()

            classes = self.gt_image
            title = 'LiDAR Normalized Digital Surface Model (NDSM) image w/ ground truth overlay'
        elif test_gt_overlay or train_gt_overlay:
            # If the ground truth tiles image member variable is empty,
            # then load the ground truth image tiles
            if self.gt_image_tiles is None: self.load_gt_image_tiles()

            # create a copy of ground truth image tiles
            gt_tiles = self.gt_image_tiles.copy()

            # Get tile dimensions
            tile_shape = gt_tiles[0].shape

            # Choose which set of tiles to set to zero values and
            # set proper image title
            if train_gt_overlay:
                tiles_to_remove = self.dataset_testing_subset
                title = 'LiDAR Normalized Digital Surface Model (NDSM) image w/ training ground truth overlay'
            else:
                tiles_to_remove = self.dataset_training_subset
                title = 'LiDAR Normalized Digital Surface Model (NDSM) image w/ testing ground truth overlay'

            # Zero out tiles not in the desired subset
            for tile in tiles_to_remove:
                row, col = tile
                index = row * self.dataset_tiled_subset_cols + col
                gt_tiles[index] = np.zeros(tile_shape, dtype=np.uint8)

            # Create single ground truth image mask
            classes = self.merge_tiles(gt_tiles)

        else:
            classes = None
            title = 'LiDAR Normalized Digital Surface Model (NDSM) image'

        plt.close('all')

        view = spectral.imshow(image,
                               source=image,
                               classes=classes,
                               figsize=size)
        if (full_gt_overlay
            or test_gt_overlay
            or train_gt_overlay): view.set_display_mode('overlay')

        view.set_title(title)

        plt.show(block=True)

    def load_full_vhr_image(self, gsd=UH_2018_GT_GSD,
                           thres=True, normalize=True, resampling=None):
        """
        Loads the full-size VHR RGB image for the University of
        Houston 2018 dataset sampled at the specified GSD.
        """
        print('Loading full VHR RGB image...')

        # The VHR image can only be loaded as tiles since there are 14
        # images, so load the tiles and then merge them to create the
        # full VHR image
        self.vhr_image = self.merge_tiles(
            self.load_vhr_image_tiles(gsd=gsd, thres=thres,
                                      normalize=normalize,
                                      resampling=resampling))

        return self.vhr_image

    def load_vhr_image_tiles(self, gsd=UH_2018_GT_GSD, tile_list=None,
                             thres=True, normalize=True, resampling=None):
        """
        Loads the University of Houston 2018 dataset's VHR RGB
        images as a set of tiles sampled at a specified GSD. If no tile
        list is given, the whole image will be loaded as tiles.
        """

        print('Loading VHR RGB dataset tile images...')

        # Check GSD parameter value
        if gsd <= 0: raise ValueError("'gsd' parameter must be greater than 0!")

        # Initialize list for tiles
        self.vhr_image_tiles = []

        if resampling is None:
            print('No resampling method chosen, defaulting to '
                    f'{DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD
        elif resampling not in RESAMPLING_METHODS:
            print(f'Incompatible resampling method {resampling}, '
                    f'defaulting to {DEFAULT_RESAMPLING_METHOD}')
            resampling = DEFAULT_RESAMPLING_METHOD

        resampling = RESAMPLING_METHODS[resampling]


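        # Ratio of the native VHR GSD to the requested GSD; the tile
        # height and width are scaled by this factor (values below 1
        # downsample, values above 1 upsample)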
        resample_factor = self.gsd_vhr / float(gsd)

        for tile_row, tile_paths in enumerate(self.paths_to_vhr_images):
            for tile_column, tile_path in enumerate(tile_paths):

                # If specified tiles are desired, then skip any
                # tiles that do not match the tile_list parameter
                if tile_list and (tile_row, tile_column) not in tile_list:
                    continue

                image_path = os.path.join(self.path_to_dataset_directory, tile_path)

                # Throw error if file path does not exist
                if not os.path.isfile(image_path): raise FileNotFoundError(
                    f'Path to UH2018 VHR RGB tile ({tile_row}, {tile_column})'
                    f' image is invalid! Path={image_path}')


                with rasterio.open(image_path) as src:

                    # Read the tile image and resample it to
                    # the appropriate GSD, arrange the numpy array to be
                    # (rows, cols, bands)
                    if resample_factor == 1.0:
                        tile = np.moveaxis(src.read(), 0, -1)
                    else:
                        # Set the shape of the resampled tile
                        out_shape=(src.count,
                                int(src.height * resample_factor),
                                int(src.width * resample_factor))

                        tile = np.moveaxis(src.read(
                            out_shape=out_shape,
                            resampling=resampling), 0, -1)

                    # Copy the tile to the tiles array
                    self.vhr_image_tiles.append(np.copy(tile))

        # If no tiles were added to the tile list, then set image tiles
        # variable to 'None'
        if len(self.vhr_image_tiles) == 0: self.vhr_image_tiles = None
        else:
            # Turn list of numpy arrays into single numpy array
            self.vhr_image_tiles = np.stack(self.vhr_image_tiles)

            # Normalize intensities to the range 0.0 to 1.0 using the
            # global minimum and maximum over all tiles and bands
            if normalize:
                self.vhr_image_tiles = self.vhr_image_tiles.astype(float, copy=False)
                self.vhr_image_tiles -= self.vhr_image_tiles.min()
                self.vhr_image_tiles /= self.vhr_image_tiles.max()

        return self.vhr_image_tiles



    def save_full_vhr_image_array(self, path, file_name='full_vhr_image.npy'):
        """
        Saves the numpy array of the full vhr image to a file
        for faster loading in the future.
        """

        # If the vhr image member variable is empty, then load the full
        # vhr image
        if self.vhr_image is None: self.load_full_vhr_image()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.vhr_image)



    def save_tiled_vhr_image_array(self, path, file_name='tiled_vhr_image.npy'):
        """
        Saves the numpy array of the tiled vhr image to a file
        for faster loading in the future.
        """

        # If the vhr image tile member variable is empty, then load all
        # vhr image tiles
        if self.vhr_image_tiles is None: self.load_vhr_image_tiles()

        with open(os.path.join(path, file_name), 'wb') as outfile:
            np.save(outfile, self.vhr_image_tiles)




    def load_full_vhr_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset VHR image.
        """

        with open(file_path, 'rb') as infile:
            self.vhr_image = np.load(infile)



    def load_tiled_vhr_image_array(self, file_path):
        """
        Loads a saved numpy array for the University of Houston 2018
        dataset VHR image tiles.
        """

        with open(file_path, 'rb') as infile:
            self.vhr_image_tiles = np.load(infile)



    def show_vhr_image(self, size=(15,9),
                      full_gt_overlay=False,
                      train_gt_overlay=False,
                      test_gt_overlay=False):
        """
        Displays the vhr rgb image.
        """

        # If the vhr image member variable is empty, then load the full
        # vhr rgb image
        if self.vhr_image is None: self.load_full_vhr_image()

        image = self.vhr_image[:,:,:]

        if full_gt_overlay:
            # If the gt image member variable is empty, then load the full
            # ground truth image
            if self.gt_image is None: self.load_full_gt_image()

            classes = self.gt_image
            title = 'VHR RGB image w/ ground truth overlay'
        elif test_gt_overlay or train_gt_overlay:
            # If the ground truth tiles member variable is empty, then
            # load the ground truth image tiles
            if self.gt_image_tiles is None: self.load_gt_image_tiles()

            # create a copy of ground truth image tiles
            gt_tiles = self.gt_image_tiles.copy()

            # Get tile dimensions
            tile_shape = gt_tiles[0].shape

            # Choose which set of tiles to set to zero values and
            # set proper image title
            if train_gt_overlay:
                tiles_to_remove = self.dataset_testing_subset
                title = 'VHR RGB image w/ training ground truth overlay'
            else:
                tiles_to_remove = self.dataset_training_subset
                title = 'VHR RGB image w/ testing ground truth overlay'

            # Zero out tiles not in the desired subset
            for tile in tiles_to_remove:
                row, col = tile
                index = row * self.dataset_tiled_subset_cols + col
                gt_tiles[index] = np.zeros(tile_shape, dtype=np.uint8)

            # Create single ground truth image mask
            classes = self.merge_tiles(gt_tiles)

        else:
            classes = None
            title = 'VHR RGB image'

        plt.close('all')

        view = spectral.imshow(image,
                               source=image,
                               classes=classes,
                               figsize=size)
        if (full_gt_overlay
            or test_gt_overlay
            or train_gt_overlay):
            view.set_display_mode('overlay')
            view.class_alpha = 0.5

        view.set_title(title)

        plt.show(block=True)

    def get_tile_indices(self, tile, row_offset=0, col_offset=0):
        """
        Returns the indices where there is a ground truth defined for a
        specific tile. Tile row and column offsets for the x and y
        values can also be defined.
        """

        # If the ground truth tiles member variable is empty, then load
        # the ground truth image tiles
        if self.gt_image_tiles is None: self.load_gt_image_tiles()

        tile_height, tile_width = self.gt_image_tiles[0].shape

        col_offset = col_offset * tile_width
        row_offset = row_offset * tile_height

        indices = []

        row, col = tile
        index = row * self.dataset_tiled_subset_cols + col
        img = self.gt_image_tiles[index]

        for r in range(tile_height):
            for c in range(tile_width):
                if img[r][c] > 0:
                    indices.append((r+row_offset,c+col_offset))

        return indices

    def get_train_test_split(self, flatten=False):
        """
        Returns the training and testing indices of the dataset.
        """

        # If the ground truth tiles member variable is empty, then load
        # the ground truth image tiles
        if self.gt_image_tiles is None: self.load_gt_image_tiles()

        tile_height, tile_width = self.gt_image_tiles[0].shape

        image_width = tile_width * self.dataset_tiled_subset_cols

        train_indices = []
        test_indices = []

        for row in range(0, self.dataset_tiled_subset_rows):
            for col in range(0, self.dataset_tiled_subset_cols):
                tile = self.gt_image_tiles[row * self.dataset_tiled_subset_cols + col]
                for tr in range(tile_height):
                    for tc in range(tile_width):
                        r = tr + row*tile_height
                        c = tc + col*tile_width
                        if tile[tr][tc] > 0:
                            if flatten:
                                index = r*image_width + c
                            else:
                                index = (r,c)
                            if (row,col) in self.dataset_training_subset:
                                train_indices.append(index)
                            else:
                                test_indices.append(index)

        return train_indices, test_indices
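
The nested loops in `get_tile_indices` and `get_train_test_split` visit every pixel one at a time. As a minimal sketch of the same lookup on an already merged ground truth image, the labeled positions can be found in one call with `np.nonzero`. The function name `labeled_indices` and the toy array are illustrative assumptions and not part of `UH_2018_Dataset`. Note that the results come back in row-major order over the whole image rather than tile by tile.

```python
import numpy as np

def labeled_indices(gt_image, flatten=False):
    """Return positions of labeled (> 0) pixels in a 2-D ground truth image."""
    rows, cols = np.nonzero(gt_image > 0)
    if flatten:
        # Same flat index convention as the loops above: r * image_width + c
        return (rows * gt_image.shape[1] + cols).tolist()
    return list(zip(rows.tolist(), cols.tolist()))

# Toy 2x3 ground truth image where 0 means "no label"
toy_gt = np.array([[0, 1, 0],
                   [2, 0, 3]], dtype=np.uint8)

pairs = labeled_indices(toy_gt)
flat = labeled_indices(toy_gt, flatten=True)
```

Splitting such indices into training and testing sets would still require checking which tile each position belongs to, as the method above does.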



## 3.3) Dataset loading function

In [None]:
#@title Set up GRSS DFC 2018 UH loading function

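def load_grss_dfc_2018_uh_dataset(**hyperparams):
    """
    Loads the GRSS DFC 2018 University of Houston dataset and stacks
    the selected modalities (hyperspectral, LiDAR multispectral
    intensity, LiDAR NDSM, and VHR RGB) into a single data cube.

    Parameters
    ----------
    **hyperparams : dict
        Keyword arguments that select the dataset path, the modalities
        to use, and the resampling, normalization, histogram
        equalization, and filtering options for each modality.

    Returns
    -------
    data : nparray
        The data cube of all selected modalities stacked along the
        last (channel) axis.
    train_gt : nparray
        The training ground truth image.
    test_gt : nparray
        The testing ground truth image.
    dataset_info : dict
        Dataset name, class information, and the channel indices and
        labels of each modality in the data cube.
    """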

    path_to_dataset = hyperparams['path_to_dataset']

    skip_data_preprocessing = hyperparams['skip_data_preprocessing']

    hs_resampling = hyperparams['hs_resampling']
    lidar_ms_resampling = hyperparams['lidar_ms_resampling']
    lidar_ndsm_resampling = hyperparams['lidar_ndsm_resampling']
    vhr_resampling = hyperparams['vhr_resampling']

    normalize_hs_data = hyperparams['normalize_hs_data']
    normalize_lidar_ms_data = hyperparams['normalize_lidar_ms_data']
    normalize_lidar_ndsm_data = hyperparams['normalize_lidar_ndsm_data']
    normalize_vhr_data = hyperparams['normalize_vhr_data']

    hs_histogram_equalization = hyperparams['hs_histogram_equalization']
    lidar_ms_histogram_equalization = hyperparams['lidar_ms_histogram_equalization']
    lidar_dsm_histogram_equalization = hyperparams['lidar_dsm_histogram_equalization']
    lidar_dem_histogram_equalization = hyperparams['lidar_dem_histogram_equalization']
    lidar_ndsm_histogram_equalization = hyperparams['lidar_ndsm_histogram_equalization']
    vhr_histogram_equalization = hyperparams['vhr_histogram_equalization']

    hs_data_filter = hyperparams['hs_data_filter']
    lidar_ms_data_filter = hyperparams['lidar_ms_data_filter']
    lidar_dsm_data_filter = hyperparams['lidar_dsm_data_filter']
    lidar_dem_data_filter = hyperparams['lidar_dem_data_filter']
    vhr_data_filter = hyperparams['vhr_data_filter']


    use_all_data = hyperparams['use_all_data']
    if use_all_data:
        use_hs_data = True
        use_lidar_ms_data = True
        use_lidar_ndsm_data = True
        use_vhr_data = True
    else:
        use_hs_data = hyperparams['use_hs_data']
        use_lidar_ms_data = hyperparams['use_lidar_ms_data']
        use_lidar_ndsm_data = hyperparams['use_lidar_ndsm_data']
        use_vhr_data = hyperparams['use_vhr_data']

    hs_channels = []
    lidar_ms_channels = []
    lidar_ndsm_channels = []
    vhr_rgb_channels = []


    if path_to_dataset is not None:
        dataset = UH_2018_Dataset(dataset_path=path_to_dataset)
    else:
        dataset = UH_2018_Dataset()
    train_gt = dataset.load_full_gt_image(train_only=True)
    test_gt = dataset.load_full_gt_image(test_only=True)

    data = None
    channel_labels = None

    # Check to see if hyperspectral data is being used
    if use_hs_data:
        # Load hyperspectral data
        if dataset.hs_image is None:
            hs_data = dataset.load_full_hs_image(resampling=hs_resampling)
        else:
            hs_data = dataset.hs_image
        print(f'{dataset.name} hs_data shape: {hs_data.shape}')

        # Check for data equalization, filtering and normalization
        if hs_histogram_equalization and not skip_data_preprocessing:
            hs_data = histogram_equalization(hs_data)
        if hs_data_filter is not None and not skip_data_preprocessing:
            print(f'Filtering hyperspectral data with {hs_data_filter} filter...')
            hs_data = filter_image(hs_data, hs_data_filter)
        if normalize_hs_data:
            print('Normalizing hyperspectral data...')
            hs_data = normalize_image(hs_data)

        # Add hyperspectral data to data cube and save channel indices
        # for hyperspectral data
        if data is None:
            hs_channels = range(hs_data.shape[-1])
            data = np.copy(hs_data)
            channel_labels = dataset.hs_band_wavelength_labels
        else:
            hs_channels = [x + data.shape[-1] for x in range(hs_data.shape[-1])]
            data = np.dstack((data, hs_data))
            channel_labels = channel_labels + dataset.hs_band_wavelength_labels

    # Check to see if lidar multispectral intensity data is being used
    if use_lidar_ms_data:
        # Load LiDAR multispectral data
        if dataset.lidar_ms_image is None:
            lidar_ms_data = dataset.load_full_lidar_ms_image(normalize=normalize_lidar_ms_data,
                                                             resampling=lidar_ms_resampling)
        else:
            lidar_ms_data = dataset.lidar_ms_image
        print(f'{dataset.name} lidar_ms_data shape: {lidar_ms_data.shape}')

        # Check for data equalization, filtering and normalization
        if lidar_ms_histogram_equalization and not skip_data_preprocessing:
            lidar_ms_data = histogram_equalization(lidar_ms_data)
        if lidar_ms_data_filter is not None and not skip_data_preprocessing:
            print(f'Filtering LiDAR multispectral data with {lidar_ms_data_filter} filter...')
            lidar_ms_data = filter_image(lidar_ms_data, lidar_ms_data_filter)
        if normalize_lidar_ms_data:
            print('Normalizing LiDAR multispectral data...')
            lidar_ms_data = normalize_image(lidar_ms_data)

        # Add lidar multispectral data to data cube and save channel
        # indices for lidar multispectral data
        if data is None:
            lidar_ms_channels = range(lidar_ms_data.shape[-1])
            data = np.copy(lidar_ms_data)
            channel_labels = dataset.lidar_ms_band_wavelength_labels
        else:
            lidar_ms_channels = [x + data.shape[-1] for x in range(lidar_ms_data.shape[-1])]
            data = np.dstack((data, lidar_ms_data))
            channel_labels = channel_labels + dataset.lidar_ms_band_wavelength_labels

    # Check to see if lidar normalized digital surface model data is
    # being used
    if use_lidar_ndsm_data:
        if dataset.lidar_dsm_image is None or dataset.lidar_dem_image is None:
            lidar_dsm_data = dataset.load_full_lidar_dsm_image(resampling=lidar_ndsm_resampling)
            lidar_dem_data = dataset.load_full_lidar_dem_image(resampling=lidar_ndsm_resampling)
        else:
            lidar_dsm_data = dataset.lidar_dsm_image
            lidar_dem_data = dataset.lidar_dem_image
        print(f'{dataset.name} lidar_dsm_data shape: {lidar_dsm_data.shape}')
        print(f'{dataset.name} lidar_dem_data shape: {lidar_dem_data.shape}')

        # Check for DSM histogram equalization
        if lidar_dsm_histogram_equalization and not skip_data_preprocessing:
            lidar_dsm_data = histogram_equalization(lidar_dsm_data)
        # Check for DEM histogram equalization
        if lidar_dem_histogram_equalization and not skip_data_preprocessing:
            lidar_dem_data = histogram_equalization(lidar_dem_data)

        # Check for data filtering
        if lidar_dsm_data_filter is not None and not skip_data_preprocessing:
            print(f'Filtering LiDAR DSM data with {lidar_dsm_data_filter} filter...')
            lidar_dsm_data = filter_image(lidar_dsm_data, lidar_dsm_data_filter)
        if lidar_dem_data_filter is not None and not skip_data_preprocessing:
            print(f'Filtering LiDAR DEM data with {lidar_dem_data_filter} filter...')
            lidar_dem_data = filter_image(lidar_dem_data, lidar_dem_data_filter)

        # Create NDSM image
        print('Creating NDSM image from DSM and DEM (NDSM = DSM - DEM)...')
        lidar_ndsm_data = lidar_dsm_data - lidar_dem_data

        # Check for NDSM histogram equalization
        if lidar_ndsm_histogram_equalization and not skip_data_preprocessing:
            lidar_ndsm_data = histogram_equalization(lidar_ndsm_data)

        # Check for data normalization
        if normalize_lidar_ndsm_data:
            print('Normalizing LiDAR NDSM data...')
            lidar_ndsm_data = normalize_image(lidar_ndsm_data)

        # Add lidar NDSM data to data cube and save channel
        # index for lidar NDSM data
        if data is None:
            lidar_ndsm_channels = [0]
            data = np.copy(lidar_ndsm_data)
            channel_labels = ['NDSM']
        else:
            lidar_ndsm_channels = [data.shape[-1]]
            data = np.dstack((data, lidar_ndsm_data))
            channel_labels = channel_labels + ['NDSM']

    # Check to see if very high resolution RGB image data is being used
    if use_vhr_data:
        # Load Very High Resolution RGB image
        if dataset.vhr_image is None:
            vhr_data = dataset.load_full_vhr_image(normalize=normalize_vhr_data,
                                                   resampling=vhr_resampling)
        else:
            vhr_data = dataset.vhr_image
        print(f'{dataset.name} vhr_data shape: {vhr_data.shape}')

        # Check for data equalization, filtering and normalization
        if vhr_histogram_equalization and not skip_data_preprocessing:
            vhr_data = histogram_equalization(vhr_data)
        if vhr_data_filter is not None and not skip_data_preprocessing:
            print(f'Filtering VHR RGB data with {vhr_data_filter} filter...')
            vhr_data = filter_image(vhr_data, vhr_data_filter)
        if normalize_vhr_data:
            print('Normalizing VHR RGB data...')
            vhr_data = normalize_image(vhr_data)

        # Add VHR data to data cube and save channel indices for VHR
        # RGB data
        if data is None:
            vhr_rgb_channels = range(vhr_data.shape[-1])
            data = np.copy(vhr_data)
            channel_labels = ['vhr_red', 'vhr_green', 'vhr_blue']
        else:
            vhr_rgb_channels = [x + data.shape[-1] for x in range(vhr_data.shape[-1])]
            data = np.dstack((data, vhr_data))
            channel_labels = channel_labels + ['vhr_red', 'vhr_green', 'vhr_blue']

    # Verify that some data was loaded
    if data is not None:
        print(f'{dataset.name} full dataset shape: {data.shape}')
    else:
        print('No data was loaded! Training cancelled...')
        return


    print(f'{dataset.name} train_gt shape: {train_gt.shape}')
    print(f'{dataset.name} test_gt shape: {test_gt.shape}')

    dataset_info = {
        'name': dataset.name,
        'num_classes': dataset.gt_num_classes,
        'ignored_labels': dataset.gt_ignored_labels,
        'class_labels': dataset.gt_class_label_list,
        'label_mapping': dataset.gt_class_value_mapping,
        'hs_channels': list(hs_channels),
        'lidar_ms_channels': list(lidar_ms_channels),
        'lidar_ndsm_channels': list(lidar_ndsm_channels),
        'vhr_rgb_channels': list(vhr_rgb_channels),
        'channel_labels': channel_labels,
    }

    return data, train_gt, test_gt, dataset_info
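
The loading function repeats one pattern per modality: record which channel indices the modality will occupy, then append it to the data cube with `np.dstack`. The sketch below isolates that bookkeeping on dummy arrays. The helper name `stack_modalities`, the dictionary keys, and the array sizes are assumptions made for illustration only; the real function also handles loading, filtering, and normalization.

```python
import numpy as np

def stack_modalities(modalities):
    """Stack (rows, cols[, bands]) arrays and track each one's channels."""
    data = None
    channel_index = {}
    for name, cube in modalities.items():
        # Treat a single-band 2-D image as a cube with one channel
        if cube.ndim == 2:
            cube = cube[..., np.newaxis]
        # New channels start where the current cube ends
        start = 0 if data is None else data.shape[-1]
        channel_index[name] = list(range(start, start + cube.shape[-1]))
        data = cube.copy() if data is None else np.dstack((data, cube))
    return data, channel_index

# Dummy modalities on a shared 4x5 pixel grid
example = {
    'hs': np.zeros((4, 5, 48)),
    'lidar_ms': np.zeros((4, 5, 3)),
    'lidar_ndsm': np.zeros((4, 5)),
}
cube, channels = stack_modalities(example)
```

Keeping the channel indices alongside the cube makes it possible to pick out a single modality later with `cube[..., channels['lidar_ms']]`.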


# 4) Data Preprocessing Functions

In [None]:
#@title Preprocess Data
def preprocess_data(data, **hyperparams):
    """
    """
    #TODO
    return data

# 5) Data Postprocessing Functions

In [None]:
#@title Postprocess Data
def postprocess_data(pred_test, **hyperparams):
    """
    """
    #TODO
    return pred_test

# 6) Utility functions and classes

In [None]:
#@title Setup 'Get GPU Device' function
def get_device(ordinal):
    """
    Takes a GPU device identifier and, if available, returns the device,
    and if not returns the CPU device.

    Parameters
    ----------
    ordinal : int
        The TensorFlow GPU device ordinal ID. A negative value selects
        the CPU.

    Returns
    -------
    device : str
        The name of the TensorFlow device (e.g. '/GPU:0' or '/CPU:0')
        to use for newly created ops
    """
    if ordinal < 0:
        print("Computation on CPU")
        device = '/CPU:0'
    elif len(tf.config.list_physical_devices('GPU')) > 0:
        print(f'Computation on CUDA GPU device {ordinal}')
        device = f'/GPU:{ordinal}'
    else:
        print("<!> CUDA was requested but is not available! Computation will go on CPU. <!>")
        device = '/CPU:0'
    return device

In [None]:
#@title Setup 'Prime Generator' function
def prime_generator():
    """
    Generate an infinite sequence of prime numbers.

    Sieve of Eratosthenes
    Code by David Eppstein, UC Irvine, 28 Feb 2002
    http://code.activestate.com/recipes/117119/
    """
    # Maps composites to primes witnessing their compositeness.
    # This is memory efficient, as the sieve is not "run forward"
    # indefinitely, but only as long as required by the current
    # number being tested.
    #
    D = {}

    # The running integer that's checked for primeness
    q = 2

    while True:
        if q not in D:
            # q is a new prime.
            # Yield it and mark its first multiple that isn't
            # already marked in previous iterations
            #
            yield q
            D[q * q] = [q]
        else:
            # q is composite. D[q] is the list of primes that
            # divide it. Since we've reached q, we no longer
            # need it in the map, but we'll mark the next
            # multiples of its witnesses to prepare for larger
            # numbers
            #
            for p in D[q]:
                D.setdefault(p + q, []).append(p)
            del D[q]

        q += 1
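
Because the generator never terminates, it has to be consumed lazily. The sketch below shows one way to take the first few primes with `itertools.islice`. The generator is repeated in condensed form only so that the block runs on its own, and the dictionary name `composites` is an illustrative rename of `D`.

```python
from itertools import islice

def prime_generator():
    """Condensed copy of the incremental sieve above."""
    composites = {}  # maps each known composite to the primes dividing it
    q = 2
    while True:
        if q not in composites:
            # q is prime: yield it and mark its square as composite
            yield q
            composites[q * q] = [q]
        else:
            # q is composite: move each witness to its next multiple
            for p in composites[q]:
                composites.setdefault(p + q, []).append(p)
            del composites[q]
        q += 1

# Take only the first ten primes from the infinite sequence
first_primes = list(islice(prime_generator(), 10))
print(first_primes)
```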

In [None]:
#@title Setup 'Filter Prediction Results' function
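def filter_pred_results(test_gt, pred_test, ignored_labels):
    """
    Reshapes the predictions to the shape of the testing ground truth
    and returns the target and predicted labels at the valid
    (non-ignored) ground truth positions.
    """
    # Reshape pred_test to be the same size as test_gt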
    pred_test = np.reshape(pred_test, test_gt.shape)
    indices = get_valid_gt_indices(test_gt, ignored_labels=ignored_labels)
    target_test = np.array([test_gt[x, y] for x, y in indices])
    pred_test = np.array([pred_test[x, y] for x, y in indices])

    return target_test, pred_test
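
The same filtering can be sketched with a boolean mask instead of per-pixel list comprehensions. This assumes that `get_valid_gt_indices` (defined elsewhere in the notebook) returns every position whose label is not in `ignored_labels`, in row-major order; that behavior is an assumption here, and `filter_with_mask` is a hypothetical name used only for this example.

```python
import numpy as np

def filter_with_mask(test_gt, pred_test, ignored_labels):
    """Keep only the targets and predictions at non-ignored pixels."""
    pred_test = np.reshape(pred_test, test_gt.shape)
    # True wherever the ground truth label is not an ignored label
    valid = ~np.isin(test_gt, ignored_labels)
    return test_gt[valid], pred_test[valid]

# Toy 2x2 ground truth where label 0 is ignored
toy_gt = np.array([[0, 1],
                   [2, 0]])
toy_pred = np.array([1, 1, 3, 2])  # flat predictions, one per pixel

target, pred = filter_with_mask(toy_gt, toy_pred, ignored_labels=[0])
```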

In [None]:
#@title Setup 'Average Accuracy' and 'Each Class Accuracy' function

def AA_andEachClassAccuracy(confusion_matrix):
    counter = confusion_matrix.shape[0]
    list_diag = np.diag(confusion_matrix)
    list_raw_sum = np.sum(confusion_matrix, axis=1)
    each_acc = np.nan_to_num(truediv(list_diag, list_raw_sum))
    average_acc = np.mean(each_acc)
    return each_acc, average_acc
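
As a worked example of the metric above, each class accuracy is the diagonal entry of the confusion matrix divided by its row sum, and the average accuracy is the mean of those values. The sketch below is a self-contained variant that uses `np.divide` with a `where` mask so that empty rows give 0 without a division warning. The name `per_class_accuracy` and the toy matrix are assumptions for illustration.

```python
import numpy as np

def per_class_accuracy(confusion_matrix):
    """Return (per-class accuracy, average accuracy) from a confusion matrix."""
    cm = np.asarray(confusion_matrix, dtype=float)
    row_sums = cm.sum(axis=1)
    # Divide only where a class has samples; empty classes stay at 0
    each_acc = np.divide(np.diag(cm), row_sums,
                         out=np.zeros_like(row_sums),
                         where=row_sums > 0)
    return each_acc, each_acc.mean()

# Rows are true classes, columns are predictions; class 2 has no samples
toy_cm = [[8, 2, 0],
          [1, 9, 0],
          [0, 0, 0]]
each_acc, average_acc = per_class_accuracy(toy_cm)
```

Here the first two classes score 0.8 and 0.9, and the empty third class contributes 0 to the average, which matches the effect of `np.nan_to_num` in the function above.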

In [None]:
#@title Setup zero padding functions

def zeroPadding_1D(old_matrix, pad_length, pad_depth = 0):
    # Pad only the end of the 1-D array with zeros ('pad_depth' is unused)
    new_matrix = np.pad(old_matrix, (0, pad_length), 'constant', constant_values=0)
    return new_matrix

def zeroPadding_2D(old_matrix, pad_length):
    # Pad both sides of each spatial axis with zeros
    new_matrix = np.pad(old_matrix,
                        ((pad_length, pad_length), (pad_length, pad_length)),
                        'constant', constant_values=0)
    return new_matrix

def zeroPadding_3D(old_matrix, pad_length, pad_depth = 0):
    # Pad both sides of each spatial axis by 'pad_length' and both
    # sides of the last axis by 'pad_depth'
    new_matrix = np.pad(old_matrix,
                        ((pad_length, pad_length), (pad_length, pad_length), (pad_depth, pad_depth)),
                        'constant', constant_values=0)
    return new_matrix
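
To make the effect of the 3-D padding concrete, the sketch below pads a small dummy cube spatially while leaving the channel axis untouched, which corresponds to calling `zeroPadding_3D` with the default `pad_depth` of 0. The array sizes are arbitrary values chosen for illustration.

```python
import numpy as np

# Dummy data cube: 4 rows, 5 columns, 3 channels, all ones
cube = np.ones((4, 5, 3))
pad_length = 2

# Pad rows and columns on both sides; add nothing along the channel axis
padded = np.pad(cube,
                ((pad_length, pad_length), (pad_length, pad_length), (0, 0)),
                mode='constant', constant_values=0)

print(cube.shape, '->', padded.shape)
```

Each spatial dimension grows by `2 * pad_length`, so a pixel at `(r, c)` in the original cube sits at `(r + pad_length, c + pad_length)` in the padded one.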



In [None]:
#@title Setup model statistics recording functions

# KAPPA_3D_DenseNet, OA_3D_DenseNet, AA_3D_DenseNet, ELEMENT_ACC_3D_DenseNet,TRAINING_TIME_3D_DenseNet,
# TESTING_TIME_3D_DenseNet, history_3d_densenet, loss_and_metrics, CATEGORY,
def outputStats(KAPPA_AE, OA_AE, AA_AE, ELEMENT_ACC_AE, TRAINING_TIME_AE, TESTING_TIME_AE, history, loss_and_metrics, CATEGORY, path1, path2):


    f = open(path1, 'a')

    sentence0 = 'KAPPAs, mean_KAPPA ± std_KAPPA for each iteration are:' + str(KAPPA_AE) + str(np.mean(KAPPA_AE)) + ' ± ' + str(np.std(KAPPA_AE)) + '\n'
    f.write(sentence0)
    sentence1 = 'OAs, mean_OA ± std_OA for each iteration are:' + str(OA_AE) + str(np.mean(OA_AE)) + ' ± ' + str(np.std(OA_AE)) + '\n'
    f.write(sentence1)
    sentence2 = 'AAs, mean_AA ± std_AA for each iteration are:' + str(AA_AE) + str(np.mean(AA_AE)) + ' ± ' + str(np.std(AA_AE)) + '\n'
    f.write(sentence2)
    sentence3 = 'Total average Training time is :' + str(np.sum(TRAINING_TIME_AE)) + '\n'
    f.write(sentence3)
    sentence4 = 'Total average Testing time is:' + str(np.sum(TESTING_TIME_AE)) + '\n'
    f.write(sentence4)

    element_mean = np.mean(ELEMENT_ACC_AE, axis=0)
    element_std = np.std(ELEMENT_ACC_AE, axis=0)
    sentence5 = "Mean of all elements in confusion matrix:" + str(np.mean(ELEMENT_ACC_AE, axis=0)) + '\n'
    f.write(sentence5)
    sentence6 = "Standard deviation of all elements in confusion matrix" + str(np.std(ELEMENT_ACC_AE, axis=0)) + '\n'
    f.write(sentence6)

    f.close()

    print_matrix = np.zeros((CATEGORY), dtype=object)
    for i in range(CATEGORY):
        print_matrix[i] = str(element_mean[i]) + " ± " + str(element_std[i])

    np.savetxt(path2, print_matrix.astype(str), fmt='%s', delimiter="\t",
               newline='\n')

    print('Test score:', loss_and_metrics[0])
    print('Test accuracy:', loss_and_metrics[1])
    print(history.history.keys())


def outputStats_assess(KAPPA_AE, OA_AE, AA_AE, ELEMENT_ACC_AE, CATEGORY, path1, path2):


    f = open(path1, 'a')

    sentence0 = 'KAPPAs, mean_KAPPA ± std_KAPPA for each iteration are:' + str(KAPPA_AE) + str(np.mean(KAPPA_AE)) + ' ± ' + str(np.std(KAPPA_AE)) + '\n'
    f.write(sentence0)
    sentence1 = 'OAs, mean_OA ± std_OA for each iteration are:' + str(OA_AE) + str(np.mean(OA_AE)) + ' ± ' + str(np.std(OA_AE)) + '\n'
    f.write(sentence1)
    sentence2 = 'AAs, mean_AA ± std_AA for each iteration are:' + str(AA_AE) + str(np.mean(AA_AE)) + ' ± ' + str(np.std(AA_AE)) + '\n'
    f.write(sentence2)

    element_mean = np.mean(ELEMENT_ACC_AE, axis=0)
    element_std = np.std(ELEMENT_ACC_AE, axis=0)
    sentence5 = "Mean of all elements in confusion matrix:" + str(np.mean(ELEMENT_ACC_AE, axis=0)) + '\n'
    f.write(sentence5)
    sentence6 = "Standard deviation of all elements in confusion matrix" + str(np.std(ELEMENT_ACC_AE, axis=0)) + '\n'
    f.write(sentence6)

    f.close()

    print_matrix = np.zeros((CATEGORY), dtype=object)
    for i in range(CATEGORY):
        print_matrix[i] = str(element_mean[i]) + " ± " + str(element_std[i])

    np.savetxt(path2, print_matrix.astype(str), fmt='%s', delimiter="\t",
               newline='\n')


def outputStats_SVM(KAPPA_AE, OA_AE, AA_AE, ELEMENT_ACC_AE, TRAINING_TIME_AE, TESTING_TIME_AE, CATEGORY, path1, path2):


    f = open(path1, 'a')

    sentence0 = 'KAPPAs, mean_KAPPA ± std_KAPPA for each iteration are:' + str(KAPPA_AE) + str(np.mean(KAPPA_AE)) + ' ± ' + str(np.std(KAPPA_AE)) + '\n'
    f.write(sentence0)
    sentence1 = 'OAs, mean_OA ± std_OA for each iteration are:' + str(OA_AE) + str(np.mean(OA_AE)) + ' ± ' + str(np.std(OA_AE)) + '\n'
    f.write(sentence1)
    sentence2 = 'AAs, mean_AA ± std_AA for each iteration are:' + str(AA_AE) + str(np.mean(AA_AE)) + ' ± ' + str(np.std(AA_AE)) + '\n'
    f.write(sentence2)
    sentence3 = 'Total average Training time is :' + str(np.sum(TRAINING_TIME_AE)) + '\n'
    f.write(sentence3)
    sentence4 = 'Total average Testing time is:' + str(np.sum(TESTING_TIME_AE)) + '\n'
    f.write(sentence4)

    element_mean = np.mean(ELEMENT_ACC_AE, axis=0)
    element_std = np.std(ELEMENT_ACC_AE, axis=0)
    sentence5 = "Mean of all elements in confusion matrix:" + str(np.mean(ELEMENT_ACC_AE, axis=0)) + '\n'
    f.write(sentence5)
    sentence6 = "Standard deviation of all elements in confusion matrix" + str(np.std(ELEMENT_ACC_AE, axis=0)) + '\n'
    f.write(sentence6)

    f.close()

    print_matrix = np.zeros((CATEGORY), dtype=object)
    for i in range(CATEGORY):
        print_matrix[i] = str(element_mean[i]) + " ± " + str(element_std[i])

    np.savetxt(path2, print_matrix.astype(str), fmt='%s', delimiter="\t",
               newline='\n')

In [None]:
#@title Setup 1-D index to row/column assignment function

def indexToAssignment(indices, Row, Col, pad_length):
    """
    Takes a list of flat indices to samples in the dataset and creates a
    dictionary of row-column index pairs into the zero padded dataset.

    Parameters
    ----------
    indices : list of int
        A list of indices to the sample points on the dataset.
    Row : int
        The number of rows in the dataset.
    Col : int
        The number of columns in the dataset.
    pad_length : int
        The number of neighbors of the sample in each spatial direction.

    Returns
    -------
    new_assign : dictionary of lists of int
        A dictionary that maps each sample's position in 'indices' to
        its [row, column] index pair in the zero padded dataset.
    """

    # Initialize assignment dictionary
    new_assign = {}

    # Loop through the enumeration of the indices
    for counter, value in enumerate(indices):
        assign_0 = value // Col + pad_length    # Row assignment
        assign_1 = value % Col + pad_length     # Column assignment
        new_assign[counter] = [assign_0, assign_1] # Assign row-col pair

    return new_assign
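
The conversion above is integer division and remainder by the number of columns, shifted by the padding. A minimal worked example follows; the helper name `index_to_padded_position` and the numbers are illustrative assumptions.

```python
def index_to_padded_position(index, n_cols, pad_length):
    """Map a flat (row-major) index to a (row, col) pair in the padded image."""
    # divmod gives (index // n_cols, index % n_cols) in one step
    row, col = divmod(index, n_cols)
    return row + pad_length, col + pad_length

# In an image with 5 columns, flat index 7 is row 1, column 2.
# With 2 pixels of padding on each side it moves to row 3, column 4.
position = index_to_padded_position(7, n_cols=5, pad_length=2)
```

Note that `Row` is accepted by `indexToAssignment` but is not needed for the calculation, since a row-major flat index depends only on the column count.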

In [None]:
#@title Setup neighboring pixel patch selection function

def selectNeighboringPatch(matrix, pos_row, pos_col, ex_len):
    """
    Selects the patch of neighbors for a particular sample point.

    Parameters
    ----------
    matrix : zero padded nparray
        The dataset from which to select the neighborhood patch.
    pos_row : int
        Row index of sample to find neighborhood of.
    pos_col : int
        Column index of sample to find neighborhood of.
    ex_len : int
        The number of neighbors in each spatial direction.

    Returns
    -------
    selected_patch : nparray
        The (ex_len*2+1) by (ex_len*2+1) matrix of samples in the
        (pos_row, pos_col) sample neighborhood.
    """
    # Narrow down the data matrix to the rows that are in the sample's
    # neighborhood
    selected_rows = matrix[range(pos_row - ex_len, pos_row + ex_len + 1), :]

    # Of the set of rows that are in the neighborhood, select the set
    # of columns in the neighborhood
    selected_patch = selected_rows[:, range(pos_col - ex_len, pos_col + ex_len + 1)]

    return selected_patch
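
The same neighborhood can be selected with plain slices, as sketched below on a small zero padded image. The function name `select_patch` and the toy array are assumptions for illustration. One behavioral difference to keep in mind: slicing returns a view of the padded array, whereas indexing with `range` objects as above returns a copy.

```python
import numpy as np

def select_patch(padded, pos_row, pos_col, ex_len):
    """Slice the (2*ex_len+1) square patch centered on (pos_row, pos_col)."""
    return padded[pos_row - ex_len:pos_row + ex_len + 1,
                  pos_col - ex_len:pos_col + ex_len + 1]

# 3x4 toy image, zero padded by 1 pixel on every side
image = np.arange(12).reshape(3, 4)
padded = np.pad(image, 1, mode='constant', constant_values=0)

# Pixel (0, 0) of the original image sits at (1, 1) after padding
patch = select_patch(padded, pos_row=1, pos_col=1, ex_len=1)
```

The padding is what keeps the slice bounds non-negative for pixels on the image border.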

In [None]:
#@title Setup test/train/validation sampling function

def sampling(proportionVal, groundTruth):
    """
    Divides the dataset into training and testing datasets by randomly
    sampling each class and separating the samples by validation split.

    Parameters
    ----------
    proportionVal : float
        The 0.0 < 'proportionVal' < 1.0 proportion of the entire dataset
        that will be used for validation/test set.
    groundTruth : nparray of int
        The dataset of ground truth classes.

    Returns
    -------
    train_indices : list of int
        A list of whole dataset indices that will be used for the
        training dataset.
    test_indices : list of int
        A list of whole dataset indices that will be used for the
        testing/validation dataset.
    """

    # Initialize label - sample dictionaries
    labels_loc = {}
    train = {}
    test = {}

    # Get the number of classes in the ground truth (assumes the class
    # labels run from 1 to the maximum label value)
    m = int(np.max(groundTruth))
    print(m)

    # Get a random sampling of each class for the training and testing
    # sets
    for i in range(m):
        # Get indices of samples that belong to class i + 1
        indices = [j for j, x in enumerate(groundTruth.ravel().tolist()) if x == i + 1]

        # Shuffle the indices 'randomly' (repeatable due to random seed)
        np.random.shuffle(indices)

        # Save the locations of all the matching samples for current
        # label
        labels_loc[i] = indices

        # Get the number of samples dedicated to the testing/validation
        # set for this label
        nb_val = int(proportionVal * len(indices))

        # Index at which to split the samples; slicing with an explicit
        # split index stays correct when nb_val is 0
        split = len(indices) - nb_val

        # Set (1-proportionVal) fraction of samples for this label to
        # the training set
        train[i] = indices[:split]

        # Set proportionVal fraction of samples for this label to
        # the testing/validation set
        test[i] = indices[split:]

    # Initialize lists for training and testing point indices
    train_indices = []
    test_indices = []

    # Copy training and testing sample indices to their respective list
    for i in range(m):
        train_indices += train[i]
        test_indices += test[i]

    # Shuffle the order of the sample indices in the indices lists
    np.random.shuffle(train_indices)
    np.random.shuffle(test_indices)

    # Print number of testing and training samples
    print(len(test_indices))
    print(len(train_indices))

    return train_indices, test_indices
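
The per-class split performed by `sampling` can also be sketched with NumPy index operations. The version below is only an illustration under stated assumptions: the name `stratified_split`, the `seed` argument, and the toy ground truth are hypothetical, and label 0 is assumed to mark unlabeled pixels (consistent with classes running from 1 to `m` above).

```python
import numpy as np

def stratified_split(labels, test_fraction, seed=0):
    """Hypothetical helper: split flat indices class by class."""
    rng = np.random.default_rng(seed)
    flat = np.asarray(labels).ravel()
    train_idx, test_idx = [], []
    for cls in np.unique(flat):
        if cls == 0:  # assumption: 0 means "unlabeled"
            continue
        # Shuffle the flat indices of every sample in this class
        idx = rng.permutation(np.flatnonzero(flat == cls))
        n_test = int(test_fraction * idx.size)
        split = idx.size - n_test
        train_idx.extend(idx[:split].tolist())
        test_idx.extend(idx[split:].tolist())
    return train_idx, test_idx

gt = np.array([[0, 1, 1, 1, 1],
               [2, 2, 2, 2, 0]])
train_idx, test_idx = stratified_split(gt, test_fraction=0.25)
print(len(train_idx), len(test_idx))
```

Each class keeps the same train/test proportion, so rare classes are not accidentally left out of either set.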

# 7) Band Selection Classes & Functions

## 7.1) Preprocessing

In [None]:
#@title Rolling Window Function
def rolling_window(array, window=(0,), asteps=None, wsteps=None, axes=None, toend=True):
    """Create a view of `array` which for every point gives the n-dimensional
    neighbourhood of size window. New dimensions are added at the end of
    `array` or after the corresponding original dimension.

    Parameters
    ----------
    array : array_like
        Array to which the rolling window is applied.
    window : int or tuple
        Either a single integer to create a window of only the last axis or a
        tuple to create it for the last len(window) axes. 0 can be used to
        ignore a dimension in the window.
    asteps : tuple
        Aligned at the last axis, new steps for the original array, ie. for
        creation of non-overlapping windows. (Equivalent to slicing result)
    wsteps : int or tuple (same size as window)
        steps for the added window dimensions. These can be 0 to repeat values
        along the axis.
    axes: int or tuple
        If given, must have the same size as window. In this case window is
        interpreted as the size in the dimension given by axes. IE. a window
        of (2, 1) is equivalent to window=2 and axis=-2.
    toend : bool
        If False, the new dimensions are right after the corresponding original
        dimension, instead of at the end of the array. Adding the new axes at the
        end makes it easier to get the neighborhood, however toend=False will give
        a more intuitive result if you view the whole array.

    Returns
    -------
    A view on `array` which is smaller to fit the windows and has windows added
    dimensions (0s not counting), ie. every point of `array` is an array of size
    window.

    Examples
    --------
    >>> a = np.arange(9).reshape(3,3)
    >>> rolling_window(a, (2,2))
    array([[[[0, 1],
             [3, 4]],

            [[1, 2],
             [4, 5]]],


           [[[3, 4],
             [6, 7]],

            [[4, 5],
             [7, 8]]]])

    Or to create non-overlapping windows, but only along the first dimension:
    >>> rolling_window(a, (2,0), asteps=(2,1))
    array([[[0, 3],
            [1, 4],
            [2, 5]]])

    Note that the 0 is discarded, so that the output dimension is 3:
    >>> rolling_window(a, (2,0), asteps=(2,1)).shape
    (1, 3, 2)

    This is useful, for example, to calculate the maximum in all (overlapping)
    2x2 submatrices:
    >>> rolling_window(a, (2,2)).max((2,3))
    array([[4, 5],
           [7, 8]])

    Or delay embedding (3D embedding with delay 2):
    >>> x = np.arange(10)
    >>> rolling_window(x, 3, wsteps=2)
    array([[0, 2, 4],
           [1, 3, 5],
           [2, 4, 6],
           [3, 5, 7],
           [4, 6, 8],
           [5, 7, 9]])
    """
    array = np.asarray(array)
    orig_shape = np.asarray(array.shape)
    window = np.atleast_1d(window).astype(int)  # maybe crude to cast to int...

    if axes is not None:
        axes = np.atleast_1d(axes)
        w = np.zeros(array.ndim, dtype=int)
        for axis, size in zip(axes, window):
            w[axis] = size
        window = w

    # Check if window is legal:
    if window.ndim > 1:
        raise ValueError("`window` must be one-dimensional.")
    if np.any(window < 0):
        raise ValueError("All elements of `window` must be non-negative.")
    if len(array.shape) < len(window):
        raise ValueError("`window` length must be less than or equal to `array` dimension.")

    _asteps = np.ones_like(orig_shape)
    if asteps is not None:
        asteps = np.atleast_1d(asteps)
        if asteps.ndim != 1:
            raise ValueError("`asteps` must be either a scalar or one dimensional.")
        if len(asteps) > array.ndim:
            raise ValueError("`asteps` cannot be longer than the `array` dimension.")
        # does not enforce alignment, so that steps can be same as window too.
        _asteps[-len(asteps):] = asteps

        if np.any(asteps < 1):
            raise ValueError("All elements of `asteps` must be at least 1.")
    asteps = _asteps

    _wsteps = np.ones_like(window)
    if wsteps is not None:
        wsteps = np.atleast_1d(wsteps)
        if wsteps.shape != window.shape:
            raise ValueError("`wsteps` must have the same shape as `window`.")
        if np.any(wsteps < 0):
            raise ValueError("All elements of `wsteps` must be non-negative.")

        _wsteps[:] = wsteps
        _wsteps[window == 0] = 1  # make sure that steps are 1 for non-existing dims.
    wsteps = _wsteps

    # Check that the window would not be larger than the original:
    if np.any(orig_shape[-len(window):] < window * wsteps):
        raise ValueError("`window` * `wsteps` larger than `array` in at least one dimension.")

    new_shape = orig_shape  # just renaming...

    # For calculating the new shape 0s must act like 1s:
    _window = window.copy()
    _window[_window == 0] = 1

    new_shape[-len(window):] += wsteps - _window * wsteps
    new_shape = (new_shape + asteps - 1) // asteps
    # make sure the new_shape is at least 1 in any "old" dimension (i.e. steps
    # is (too) large, but we do not care).
    new_shape[new_shape < 1] = 1
    shape = new_shape

    strides = np.asarray(array.strides)
    strides *= asteps
    new_strides = array.strides[-len(window):] * wsteps

    # The full new shape and strides:
    if toend:
        new_shape = np.concatenate((shape, window))
        new_strides = np.concatenate((strides, new_strides))
    else:
        _ = np.zeros_like(shape)
        _[-len(window):] = window
        _window = _.copy()
        _[-len(window):] = new_strides
        _new_strides = _

        new_shape = np.zeros(len(shape) * 2, dtype=int)
        new_strides = np.zeros(len(shape) * 2, dtype=int)

        new_shape[::2] = shape
        new_strides[::2] = strides
        new_shape[1::2] = _window
        new_strides[1::2] = _new_strides

    new_strides = new_strides[new_shape != 0]
    new_shape = new_shape[new_shape != 0]

    return np.lib.stride_tricks.as_strided(array, shape=new_shape, strides=new_strides)
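
For the common case of plain overlapping windows (no `asteps`, `wsteps`, or `axes` options), NumPy 1.20 and later ship `np.lib.stride_tricks.sliding_window_view`, which returns a read-only view. The sketch below reproduces the 2x2 maximum example from the docstring above; it is a possible alternative for that case, not a drop-in replacement for `rolling_window`.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

a = np.arange(9).reshape(3, 3)

# Every overlapping 2x2 window of `a`; the window axes are appended at the end
windows = sliding_window_view(a, (2, 2))
print(windows.shape)

# Maximum of each 2x2 submatrix
window_max = windows.max(axis=(2, 3))
print(window_max)
```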

In [None]:
#@title Processor Class
import spectral as spy

class Processor:

    def __init__(self):
        pass

    def prepare_data(self, img_path, gt_path):
        if img_path[-3:] == 'mat':
            import scipy.io as sio
            img_mat = sio.loadmat(img_path)
            gt_mat = sio.loadmat(gt_path)
            img_keys = img_mat.keys()
            gt_keys = gt_mat.keys()
            img_key = [k for k in img_keys if k != '__version__' and k != '__header__' and k != '__globals__']
            gt_key = [k for k in gt_keys if k != '__version__' and k != '__header__' and k != '__globals__']
            return img_mat.get(img_key[0]).astype('float64'), gt_mat.get(gt_key[0]).astype('int8')
        else:
            import spectral as spy
            img = spy.open_image(img_path).load()
            gt = spy.open_image(gt_path)
            return img, gt.read_band(0)

    def get_correct(self, img, gt):
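        """
        keep only the labeled (non-zero ground truth) pixels
        :param img: 3D arr [n_rows, n_cols, n_bands]
        :param gt: 2D arr [n_rows, n_cols]
        :return: converted arr [n_samples, n_bands] and labels [n_samples]
        """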
        gt_1D = gt.reshape(-1)
        index = gt_1D.nonzero()
        gt_correct = gt_1D[index]
        img_2D = img.reshape(img.shape[0] * img.shape[1], img.shape[2])
        img_correct = img_2D[index]
        return img_correct, gt_correct

    def get_tr_tx_index(self, y, test_size=0.9):
        from sklearn.model_selection import train_test_split
        X_train_index, X_test_index, y_train_, y_test_ = \
            train_test_split(np.arange(0, y.shape[0]), y, test_size=test_size)
        return X_train_index, X_test_index

    def divide_img_blocks(self, img, gt, block_size=(5, 5)):
        """
        split image into a*b blocks, the edge filled with its mirror
        :param img:
        :param gt:
        :param block_size: tuple of size, it must be odd and >=3
        :return: correct image blocks
        """
        # TODO: padding edge with mirror
        w_1, w_2 = int((block_size[0] - 1) / 2), int((block_size[1] - 1) / 2)
        img_padding = np.pad(img, ((w_1, w_2),
                                   (w_1, w_2), (0, 0)), 'symmetric')
        gt_padding = np.pad(gt, ((w_1, w_2),
                                 (w_1, w_2)), 'symmetric')
        img_blocks = rolling_window(img_padding, block_size, axes=(1, 0))  # divide data into 5x5 blocks
        gt_blocks = rolling_window(gt_padding, block_size, axes=(1, 0))
        i_1, i_2 = int((block_size[0] - 1) / 2), int((block_size[0] - 1) / 2)
        nonzero_index = gt_blocks[:, :, i_1, i_2].nonzero()
        img_blocks_nonzero = img_blocks[nonzero_index]
        gt_blocks_nonzero = (gt_blocks[:, :, i_1, i_2])[nonzero_index]
        return img_blocks_nonzero, gt_blocks_nonzero

    def split_tr_tx(self, X, y, test_size=0.4):
        """
        X_train, X_test, y_train, y_test
        :param X:
        :param y:
        :param test_size:
        :return:
        """
        from sklearn.model_selection import train_test_split
        return train_test_split(X, y, test_size=test_size)

    def split_each_class(self, X, y, each_train_size=10):
        X_tr, y_tr, X_ts, y_ts = [], [], [], []
        for c in np.unique(y):
            y_index = np.nonzero(y == c)[0]
            np.random.shuffle(y_index)
            cho, non_cho = np.split(y_index, [each_train_size, ])
            X_tr.append(X[cho])
            y_tr.append(y[cho])
            X_ts.append(X[non_cho])
            y_ts.append(y[non_cho])
        # Concatenate the per-class pieces; this also works when the classes
        # contribute different numbers of test samples
        return np.concatenate(X_tr), np.concatenate(X_ts), \
               np.concatenate(y_tr), np.concatenate(y_ts)

    def save_experiment(self, y_pre, y_test, file_neme=None, parameters=None):
        """
        save classification results and experiment parameters into files for k-folds cross validation.
        :param y_pre:
        :param y_test:
        :param parameters:
        :return:
        """
        import os
        home = os.getcwd() + '/experiments'
        if not os.path.exists(home):
            os.makedirs(home)
        if parameters is None:
            parameters = [None]
        if file_neme is None:
            file_neme = home + '/scores.npz'
        else:
            file_neme = home + '/' + file_neme + '.npz'

        # save results and scores into a numpy file
        ca, oa, aa, kappa = [], [], [], []
        if np.array(y_pre).ndim > 1:  # that means test data tested k times
            for y in y_pre:
                ca_, oa_, aa_, kappa_ = self.score(y_test, y)
                ca.append(ca_)
                oa.append(oa_)
                aa.append(aa_)
                kappa.append(kappa_)
        else:
            ca, oa, aa, kappa = self.score(y_test, y_pre)
        np.savez(file_neme, y_test=y_test, y_pre=y_pre, CA=np.array(ca), OA=np.array(oa), AA=aa, Kappa=kappa,
                 param=parameters)
        print('the experiments have been saved in', file_neme)

    # def get_train_test_indexes(self, train_size, gt):
    #     """
    #
    #     :param train_size:
    #     :param gt:
    #     :return:
    #     """
    #     gt_1D = gt.reshape(-1)
    #     samples_correct = gt_1D[gt_1D.nonzero()]
    #     n_samples = samples_correct.shape[0]  # the num of available samples
    #     classes = {}
    #     for i in np.unique(samples_correct):
    #         classes[i] = len(np.nonzero(samples_correct == i)[0])
    #     if train_size >= min(classes.values()):
    #             train_size = min(classes.values())
    #     train_indexes = np.empty((0))
    #     test_indexes = np.empty((0))
    #     for key in classes:
    #         size_ci = classes[key]
    #         index_ci = np.nonzero(gt_1D == key)[0]  # 1 dim: (row,col=None)
    #         index_train__ = np.empty(0)
    #         if train_size > 0 and train_size < 1.:
    #             # slip data as percentage for each class
    #             index_train__ = np.random.choice(index_ci, int(size_ci * train_size), replace=False)
    #         else:
    #             # slip data as form of fixed numbers
    #             index_train__ = np.random.choice(index_ci, int(train_size), replace=False)
    #         index_test__ = np.setdiff1d(index_ci,index_train__)
    #         train_indexes = np.append(train_indexes,index_train__)
    #         test_indexes = np.append(test_indexes,index_test__)
    #     return train_indexes.astype(np.int64),test_indexes.astype(np.int64)

    def majority_filter(self, classes_map, selems):
        """
        :param classes_map: 2 dim image
        :param selems: elements: [disk(1),square(2)...]
        :return:
        """
        from skimage.filters.rank import modal
        # from skimage.morphology import disk,square
        classes_map__ = classes_map.astype(np.uint16)  # convert dtype to uint16
        out = classes_map__
        for selem in selems:
            out = modal(classes_map__, selem)
            classes_map__ = out
        return out.astype(np.int8)

    def score(self, y_test, y_predicted):
        """
        calculate the accuracy and other criteria from the predicted results
        :param y_test:
        :param y_predicted:
        :return: ca, oa, aa, kappa
        """
        from sklearn.metrics import accuracy_score
        # overall accuracy
        oa = accuracy_score(y_test, y_predicted)
        # accuracy of each class, and their average
        ca = []
        for c in np.unique(y_test):
            y_c = y_test[np.nonzero(y_test == c)]  # true labels of class c
            y_c_p = y_predicted[np.nonzero(y_test == c)]  # predictions for those samples
            accuracy = accuracy_score(y_c, y_c_p)
            ca.append(accuracy)
        aa = (np.array(ca)).mean()

        # kappa
        kappa = self.kappa(y_test, y_predicted)
        return ca, oa, aa, kappa

    def result2gt(self, y_predicted, test_indexes, gt):
        """

        :param y_predicted:
        :param test_indexes: indexes got from ground truth
        :param gt: 2-dim img
        :return:
        """
        n_row, n_col = gt.shape
        gt_1D = gt.reshape((n_row * n_col))
        gt_1D[test_indexes] = y_predicted
        return gt_1D.reshape(n_row, n_col)

    def extended_morphological_profile(self, components, disk_radius):
        """

        :param components:
        :param disk_radius:
        :return:2-dim emp
        """
        rows, cols, bands = components.shape
        n = disk_radius.__len__()
        import numpy as np
        emp = np.zeros((rows * cols, bands * (2 * n + 1)))
        from skimage.morphology import opening, closing, disk
        for band in range(bands):
            position = band * (n * 2 + 1) + n
            emp_ = np.zeros((rows, cols, 2 * n + 1))
            emp_[:, :, n] = components[:, :, band]
            i = 1
            for r in disk_radius:
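                # pass the structuring element positionally: its keyword was
                # renamed from `selem` to `footprint` in newer scikit-image
                closed = closing(components[:, :, band], disk(r))
                opened = opening(components[:, :, band], disk(r))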
                emp_[:, :, n - i] = closed
                emp_[:, :, n + i] = opened
                i += 1
            emp[:, position - n:position + n + 1] = emp_.reshape((rows * cols, 2 * n + 1))
        return emp.reshape(rows, cols, bands * (2 * n + 1))

    def texture_feature(self, components, theta_arr=None, frequency_arr=None):
        """
        extract the texture features
        :param components:
        :param theta_arr:
        :param frequency_arr:
        :return:
        """
        if theta_arr is None:
            theta_arr = np.arange(0, 8) * np.pi / 4  # 8 orientations
        if frequency_arr is None:
            frequency_arr = np.pi / (2 ** np.arange(1, 5))  # 4 frequencies

        from skimage.filters import gabor
        results = []
        for img in components.transpose():
            for theta in theta_arr:
                for fre in frequency_arr:
                    filt_real, filt_imag = gabor(img, frequency=fre, theta=theta)
                    results.append(filt_real)
        return np.array(results).transpose()

    def pca_transform(self, n_components, samples):
        """

        :param n_components:
        :param samples: [nb_samples, bands]/or [n_row, n_column, n_bands]
        :return:
        """
        HSI_or_not = samples.shape.__len__() == 3  # denotes HSI data
        n_row, n_column, n_bands = 0, 0, 0
        if HSI_or_not:
            n_row, n_column, n_bands = samples.shape
            samples = samples.reshape((n_row * n_column, n_bands))
        from sklearn.decomposition import PCA
        pca = PCA(n_components=n_components)
        trans_samples = pca.fit_transform(samples)
        if HSI_or_not:
            return trans_samples.reshape((n_row, n_column, n_components))
        return trans_samples

    def normlize_HSI(self, img):
        from sklearn.preprocessing import normalize
        n_row, n_column, n_bands = img.shape
        norm_img = normalize(img.reshape(n_row * n_column, n_bands))
        return norm_img.reshape(n_row, n_column, n_bands)

    def each_class_OA(self, y_test, y_predicted):
        """
        get the accuracy of each class respectively
        :param y_test:
        :param y_predicted:
        :return: array of per-class accuracies, ordered as np.unique(y_test)
        """
        from sklearn.metrics import accuracy_score
        classes = np.unique(y_test)
        results = []
        for c in classes:
            y_c = y_test[np.nonzero(y_test == c)]  # true labels of class c
            y_c_p = y_predicted[np.nonzero(y_test == c)]  # predictions for those samples
            # accuracy over the samples whose true label is c
            results.append(accuracy_score(y_c, y_c_p))
        return np.array(results)

    def kappa(self, y_test, y_predicted):
        from sklearn.metrics import cohen_kappa_score
        return round(cohen_kappa_score(y_test, y_predicted), 3)

    def color_legend(self, color_map, label):
        """

        :param color_map: 1-n color map in range 0-255
        :param label: label list
        :return:
        """
        import matplotlib.patches as mpatches
        import matplotlib.pyplot as plt
        size = len(label)
        patchs = []
        m = 255.  # float(color_map.max())
        color_map_ = (color_map / m)[1:]
        for i in range(0, size):
            patchs.append(mpatches.Patch(color=color_map_[i], label=label[i]))
        # plt.legend(handles=patchs)
        return patchs

    def get_tr_ts_index_num(self, y, n_labeled=10):
        import random
        classes = np.unique(y)
        X_train_index, X_test_index = np.empty(0, dtype='int8'), np.empty(0, dtype='int8')
        for c in classes:
            index_c = np.nonzero(y == c)[0]
            random.shuffle(index_c)
            X_train_index = np.append(X_train_index, index_c[:n_labeled])
            X_test_index = np.append(X_test_index, index_c[n_labeled:])
        return X_train_index, X_test_index

    def save_res_4kfolds_cv(self, y_pres, y_tests, file_name=None, verbose=False):
        """
        save experiment results for k-folds cross validation
        :param y_pres: predicted labels, k*Ntest
        :param y_tests: true labels, k*Ntest
        :param file_name: if not None, the scores are saved to an .npz file
        :param verbose: if True, print the mean and std of each score
        :return: ca, oa, aa (in percent) and kappa for every fold
        """
        ca, oa, aa, kappa = [], [], [], []
        for y_p, y_t in zip(y_pres, y_tests):
            ca_, oa_, aa_, kappa_ = self.score(y_t, y_p)
            ca.append(np.asarray(ca_))
            oa.append(np.asarray(oa_))
            aa.append(np.asarray(aa_))
            kappa.append(np.asarray(kappa_))
        ca = np.asarray(ca) * 100
        oa = np.asarray(oa) * 100
        aa = np.asarray(aa) * 100
        kappa = np.asarray(kappa)
        ca_mean, ca_std = np.round(ca.mean(axis=0), 2), np.round(ca.std(axis=0), 2)
        oa_mean, oa_std = np.round(oa.mean(), 2), np.round(oa.std(), 2)
        aa_mean, aa_std = np.round(aa.mean(), 2), np.round(aa.std(), 2)
        kappa_mean, kappa_std = np.round(kappa.mean(), 3), np.round(kappa.std(), 3)
        if file_name is not None:
            np.savez(file_name, y_test=y_tests, y_pre=y_pres,
                     ca_mean=ca_mean, ca_std=ca_std,
                     oa_mean=oa_mean, oa_std=oa_std,
                     aa_mean=aa_mean, aa_std=aa_std,
                     kappa_mean=kappa_mean, kappa_std=kappa_std)
            print('the experiments have been saved in ', file_name)

        if verbose:
            print('---------------------------------------------')
            print('ca:   ', ca_mean, '+-', ca_std)
            print('aa:   ', aa_mean, '+-', aa_std)
            print('oa:   ', oa_mean, '+-', oa_std)
            print('kappa:', kappa_mean, '+-', kappa_std)
        return ca, oa, aa, kappa

    # def view_clz_map(self, gt, y_index, y_predicted, save_path=None, show_error=False):
    #     """
    #     view HSI classification results
    #     :param gt:
    #     :param y_index: index of excluding 0th class
    #     :param y_predicted:
    #     :param show_error:
    #     :return:
    #     """
    #     n_row, n_column = gt.shape
    #     gt_1d = gt.reshape(-1).copy()
    #     nonzero_index = gt_1d.nonzero()
    #     gt_corrected = gt_1d[nonzero_index]
    #     if show_error:
    #         t = y_predicted.copy()
    #         correct_index = np.nonzero(y_predicted == gt_corrected[y_index])
    #         t[correct_index] = 0  # leave error
    #         gt_corrected[:] = 0
    #         gt_corrected[y_index] = t
    #         gt_1d[nonzero_index] = t
    #     else:
    #         gt_corrected[y_index] = y_predicted
    #         gt_1d[nonzero_index] = gt_corrected
    #     gt_map = gt_1d.reshape((n_row, n_column)).astype('uint8')
    #     spy.imshow(classes=gt_map)
    #     if save_path != None:
    #         spy.save_rgb(save_path, gt_map, colors=spy.spy_colors)
    #         print('the figure is saved in ', save_path)

    def split_source_target(self, X, y, split_attribute_index, split_threshold, save_name=None):
        """
        split source/target domain data for transfer learning according to attribute
        :param X:
        :param y:
        :param split_attribute_index:
        :param split_threshold: split condition. e.g if 1.2 those x[:,index] >= 1.2 are split into source
        :param save_name:
        :return:
        """
        source_index = np.nonzero(X[:, split_attribute_index] >= split_threshold)
        target_index = np.nonzero(X[:, split_attribute_index] < split_threshold)
        X_source = X[source_index]
        X_target = X[target_index]
        y_source = y[source_index].astype('int')
        y_target = y[target_index].astype('int')
        if save_name is not None:
            np.savez(save_name, X_source=X_source, X_target=X_target, y_source=y_source, y_target=y_target)
        return X_source, X_target, y_source, y_target

    def results_to_cvs(self, res_file_name, save_name):
        import csv
        dt = np.load(res_file_name)
        # report means and standard deviations on the same percent scale
        ca_mean = np.round(dt['CA'].mean(axis=0) * 100, 2)
        ca_std = np.round(dt['CA'].std(axis=0) * 100, 2)
        oa_mean = np.round(dt['OA'].mean() * 100, 2)
        oa_std = np.round(dt['OA'].std(axis=0) * 100, 2)
        aa_mean = np.round(dt['AA'].mean() * 100, 2)
        aa_std = np.round(dt['AA'].std(axis=0) * 100, 2)
        kappa_mean = np.round(dt['Kappa'].mean(), 3)
        kappa_std = np.round(dt['Kappa'].std(axis=0), 3)
        with open(save_name, 'w', newline='') as f:
            writer = csv.writer(f)
            for i in zip(ca_mean, ca_std):
                writer.writerow(i)
            writer.writerow([oa_mean, oa_std])
            writer.writerow([aa_mean, aa_std])
            writer.writerow([kappa_mean, kappa_std])

    def view_clz_map_spyversion4single_img(self, gt, y_index, y_predicted, save_path=None, show_error=False,
                                           show_axis=False):
        """
        view HSI classification results
        :param gt:
        :param y_index: test index of excluding 0th class
        :param y_predicted:
        :param show_error:
        :return:
        """
        n_row, n_column = gt.shape
        gt_1d = gt.reshape(-1).copy()
        nonzero_index = gt_1d.nonzero()
        gt_corrected = gt_1d[nonzero_index]
        if show_error:
            t = y_predicted.copy()
            correct_index = np.nonzero(y_predicted == gt_corrected[y_index])
            t[correct_index] = 0  # leave error
            gt_corrected[:] = 0
            gt_corrected[y_index] = t
            gt_1d[nonzero_index] = gt_corrected
        else:
            gt_corrected[y_index] = y_predicted
            gt_1d[nonzero_index] = gt_corrected
        gt_map = gt_1d.reshape((n_row, n_column)).astype('uint8')
        spy.imshow(classes=gt_map)
        if save_path is not None:
            import matplotlib.pyplot as plt
            spy.save_rgb('temp.png', gt_map, colors=spy.spy_colors)
            if not show_axis:
                plt.axis('off')
            plt.savefig(save_path, format='eps')
            print('the figure is saved in ', save_path)

    def view_clz_map_mlpversion(self, test_index, results, sub_indexes, labels, save_name=None):
        """ visualize image with 2 rows and 3 columns with the color legend for knn classification
            --------
            Usage:
                res = [gt, y_pre_spectral, y_pre_shape, y_pre_texture, y_pre_stack, y_pre_kernel]
                sub_index = [331, 332, 333, 334, 335, 336, 313]
                labels = ['(a) groundtruth', r'(b) $kNN_{spectral}$', r'(c) $kNN_{shape}$', r'(d) $kNN_{texture}$',
                r'(e) $kNN_{stack}$', r'(f) $kNN_{multi}$']
                view_clz_map_mlpversion(tx_index, res, sub_index, labels, save_name='./experiments/paviaU_class_map.eps')
        """
        import matplotlib.patches as mpatches
        import matplotlib.pyplot as plt
        import copy
        n_res = results.__len__()
        gt = copy.deepcopy(results[0])
        n_row, n_column = gt.shape
        gt_1d = gt.reshape(-1).copy()
        nonzero_index = gt_1d.nonzero()
        for i in range(n_res):
            if i == 0:
                gt_map = gt
            else:
                gt_corrected = copy.deepcopy(gt_1d[nonzero_index])
                gt_corrected[test_index] = results[i]
                gt_1d_temp = copy.deepcopy(gt.reshape(-1))
                gt_1d_temp[nonzero_index] = gt_corrected
                gt_map = gt_1d_temp.reshape((n_row, n_column)).astype('uint8')
            axe = plt.subplot(sub_indexes[i])
            im = axe.imshow(gt_map, cmap='jet')
            axe.set_title(labels[i], fontdict={'fontsize': 10})
            axe.axis('off')
        values = np.unique(gt.ravel())
        # get the colors of the values, according to the
        # colormap used by imshow
        colors = [im.cmap(im.norm(value)) for value in values]
        # create a patch (proxy artist) for every color
        patches = [mpatches.Patch(color=colors[i], label="{l}".format(l=values[i])) for i in range(len(values))]
        # put those patched as legend-handles into the legend
        axe_legend = plt.subplot(sub_indexes[-1])
        axe_legend.legend(handles=patches, loc=10, ncol=6)
        axe_legend.axis('off')

        # save image (only when a file name is given)
        if save_name is not None:
            plt.savefig(save_name, format='eps', dpi=1000)
            print('the figure is saved in ', save_name)

    def standardize_label(self, y):
        """
        standardize the class labels into 0..k-1, where k is the number of classes
        :param y:
        :return:
        """
        import copy
        classes = np.unique(y)
        standardize_y = copy.deepcopy(y)
        for i in range(classes.shape[0]):
            standardize_y[np.nonzero(y == classes[i])] = i
        return standardize_y

    def one2array(self, y):
        n_classes = np.unique(y).__len__()
        y_expected = np.zeros((y.shape[0], n_classes))
        for i in range(y.shape[0]):
            y_expected[i][y[i]] = 1
        return y_expected
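
The scores reported by `Processor.score` (per-class accuracy CA, overall accuracy OA, average accuracy AA, and Cohen's kappa) can all be read off a confusion matrix. The following NumPy-only sketch shows the arithmetic; the function name `classification_scores` and the toy labels are assumptions for illustration, whereas the class itself relies on scikit-learn's `accuracy_score` and `cohen_kappa_score` and rounds kappa to three decimals.

```python
import numpy as np

def classification_scores(y_true, y_pred):
    """Hypothetical helper: CA, OA, AA and kappa from a confusion matrix."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    classes = np.unique(np.concatenate([y_true, y_pred]))
    index = {c: k for k, c in enumerate(classes)}

    # Rows are true classes, columns are predicted classes
    cm = np.zeros((classes.size, classes.size), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1

    total = cm.sum()
    oa = np.trace(cm) / total                     # overall accuracy
    support = cm.sum(axis=1)
    present = support > 0                         # classes that occur in y_true
    ca = np.diag(cm)[present] / support[present]  # per-class accuracy
    aa = ca.mean()                                # average accuracy

    # Agreement expected by chance, from the row and column marginals
    pe = (cm.sum(axis=0) * cm.sum(axis=1)).sum() / total ** 2
    kappa = (oa - pe) / (1 - pe)
    return ca, oa, aa, kappa

ca, oa, aa, kappa = classification_scores([1, 1, 1, 2, 2, 3],
                                          [1, 1, 2, 2, 2, 3])
print(ca, oa, aa, kappa)
```

OA weights every sample equally, while AA weights every class equally, so the two diverge when the classes are imbalanced.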

## 7.2) Band Selection Utility Functions & Classes

In [None]:
#@title Evaluate Band
def eval_band(new_img, gt, train_inx, test_idx):
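    """
    Score a band subset by the accuracy of a 5-nearest-neighbor classifier.

    :param new_img: samples restricted to the selected bands, [n_samples, n_bands]
    :param gt: class label of each sample, [n_samples]
    :param train_inx: indices of the training samples
    :param test_idx: indices of the testing samples
    :return: overall accuracy on the testing samples
    """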
    p = Processor()
    # img_, gt_ = p.get_correct(new_img, gt)
    gt_ = gt
    img_ = maxabs_scale(new_img)
    # X_train, X_test, y_train, y_test = train_test_split(img_, gt_, test_size=0.4, random_state=42)
    X_train, X_test, y_train, y_test = img_[train_inx], img_[test_idx], gt_[train_inx], gt_[test_idx]
    knn_classifier = KNN(n_neighbors=5)
    knn_classifier.fit(X_train, y_train)
    # score = cross_val_score(knn_classifier, img_, y=gt_, cv=3)
    y_pre = knn_classifier.predict(X_test)
    score = accuracy_score(y_test, y_pre)
    # score = np.mean(score)
    return score
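
To make the evaluation procedure concrete, here is a self-contained sketch that follows the same steps as `eval_band` (max-abs scaling, a 5-nearest-neighbor classifier, accuracy on held-out samples). Everything except the scikit-learn calls is an assumption for illustration: the synthetic data, the `selected_bands` list, and the 40/20 split are hypothetical.

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import maxabs_scale

rng = np.random.default_rng(0)

# Synthetic stand-in: 60 samples, 6 "bands", 2 well-separated classes
y = np.repeat([0, 1], 30)
X = rng.normal(size=(60, 6)) + 5.0 * y[:, None]

selected_bands = [0, 3, 5]  # hypothetical output of a band selection method
idx = rng.permutation(60)
train_idx, test_idx = idx[:40], idx[40:]

# Scale each selected band by its maximum absolute value, then classify
X_sel = maxabs_scale(X[:, selected_bands])
knn = KNeighborsClassifier(n_neighbors=5).fit(X_sel[train_idx], y[train_idx])
acc = accuracy_score(y[test_idx], knn.predict(X_sel[test_idx]))
print(acc)
```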

In [None]:
#@title Evaluate Band with Cross Validation
def eval_band_cv(X, y, times=10):
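    """
    :param X: samples, [n_samples, n_bands]
    :param y: class label of each sample
    :param times: number of times the 3-fold cv is repeated
    :return: for each of knn/svm/elm, [OA, Kappa] over all folds
    """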
    p = Processor()
    img_ = maxabs_scale(X)
    estimator = [KNN(n_neighbors=5), SVC(C=1e4, kernel='rbf', gamma=1.), ELM_Classifier(200)]
    estimator_pre, y_test_all = [[], [], []], []
    for i in range(times):  # repeat N times K-fold CV
        skf = StratifiedKFold(n_splits=3, shuffle=True)
        for train_index, test_index in skf.split(img_, y):
            X_train, X_test = img_[train_index], img_[test_index]  # use the scaled samples
            y_train, y_test = y[train_index], y[test_index]
            y_test_all.append(y_test)
            for c in range(3):
                estimator[c].fit(X_train, y_train)
                estimator_pre[c].append(estimator[c].predict(X_test))
    clf = ['knn', 'svm', 'elm']
    score = []
    for z in range(3):
        ca, oa, aa, kappa = p.save_res_4kfolds_cv(estimator_pre[z], y_test_all, file_name=clf[z] + 'score.npz', verbose=True)
        score.append([oa, kappa])
    return score
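
Repeating a stratified k-fold split several times, as `eval_band_cv` does with an explicit outer loop, can also be expressed with scikit-learn's `RepeatedStratifiedKFold`. The sketch below is an illustration under assumptions (synthetic three-class data, a single KNN estimator, a fixed `random_state`), not a replacement for the function above, which also evaluates SVM and ELM classifiers and reports kappa.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import maxabs_scale

rng = np.random.default_rng(0)

# Synthetic stand-in: 60 samples, 4 "bands", 3 well-separated classes
y = np.repeat([0, 1, 2], 20)
X = maxabs_scale(rng.normal(size=(60, 4)) + 5.0 * y[:, None])

# 3-fold stratified CV repeated 10 times gives 30 accuracy values
cv = RepeatedStratifiedKFold(n_splits=3, n_repeats=10, random_state=0)
scores = cross_val_score(KNeighborsClassifier(n_neighbors=5), X, y, cv=cv)
print(scores.shape, scores.mean(), scores.std())
```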

In [None]:
#@title Extreme Learning Machine (ELM) Classifier Class
class ELM_Classifier(BaseEstimator, ClassifierMixin):
    upper_bound = .5
    lower_bound = -.5

    def __init__(self, n_hidden, dropout_prob=None):
        self.n_hidden = n_hidden
        self.dropout_prob = dropout_prob

    def fit(self, X, y, sample_weight=None):
        # convert a 1-dim label vector into a one-hot 2-dim array if needed
        X, y = copy.deepcopy(X), copy.deepcopy(y)
        self.sample_weight = None
        if y.ndim != 2:
            self.classes_ = np.unique(y)
            self.n_classes_ = self.classes_.__len__()
            y = self.one2array(y, self.n_classes_)
        else:
            self.classes_ = np.arange(y.shape[1])
            self.n_classes_ = self.classes_.__len__()
        self.W = np.random.uniform(self.lower_bound, self.upper_bound, size=(X.shape[1], self.n_hidden))
        if self.dropout_prob is not None:
            self.W = self.dropout(self.W, prob=self.dropout_prob)
            # X = self.dropout(X, prob=self.dropout_prob)
        self.b = np.random.uniform(self.lower_bound, self.upper_bound, size=self.n_hidden)
        H = expit(np.dot(X, self.W) + self.b)
        # H = self.dropout(H, prob=0.1)
        if sample_weight is not None:
            self.sample_weight = sample_weight / sample_weight.sum()
            extend_sample_weight = np.diag(self.sample_weight)
            inv_ = linalg.pinv(np.dot(
                np.dot(H.transpose(), extend_sample_weight), H))
            self.B = np.dot(np.dot(np.dot(inv_, H.transpose()), extend_sample_weight), y)
        else:
            self.B = np.dot(linalg.pinv(H), y)
        return self

    def one2array(self, y, n_dim):
        y_expected = np.zeros((y.shape[0], n_dim))
        for i in range(y.shape[0]):
            y_expected[i][y[i]] = 1
        return y_expected

    def predict(self, X, prob=False):
        X = copy.deepcopy(X)
        H = expit(np.dot(X, self.W) + self.b)
        output = np.dot(H, self.B)
        if prob:
            return output
        return output.argmax(axis=1)

    def get_params(self, deep=True):
        params = {'n_hidden': self.n_hidden, 'dropout_prob': self.dropout_prob}
        return params

    def set_params(self, **parameters):
        for name, value in parameters.items():
            setattr(self, name, value)
        return self

    def dropout(self, x, prob=0.2):
        if prob < 0. or prob >= 1:
            raise Exception('Dropout level must be in interval [0, 1)')
        retain_prob = 1. - prob
        sample = np.random.binomial(n=1, p=retain_prob, size=x.shape)
        x *= sample
        # x /= retain_prob
        return x
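
The core of the extreme learning machine above is a fixed random hidden layer followed by a least-squares solve for the output weights. The following is a minimal NumPy sketch of that idea only (no dropout, no sample weights); `elm_fit` and `elm_predict` are hypothetical names, and the toy data is synthetic. Unlike `ELM_Classifier.predict`, the sketch maps the winning column back to the original label, so labels do not need to be `0..n_classes-1`.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def elm_fit(X, y, n_hidden, rng):
    """Fit a single-hidden-layer ELM; returns (W, b, B, classes)."""
    classes, y_idx = np.unique(y, return_inverse=True)
    T = np.eye(classes.size)[y_idx]                           # one-hot targets
    W = rng.uniform(-0.5, 0.5, size=(X.shape[1], n_hidden))   # fixed random input weights
    b = rng.uniform(-0.5, 0.5, size=n_hidden)                 # fixed random biases
    H = sigmoid(X @ W + b)                                    # hidden-layer activations
    B = np.linalg.pinv(H) @ T                                 # least-squares output weights
    return W, b, B, classes


def elm_predict(X, W, b, B, classes):
    scores = sigmoid(X @ W + b) @ B
    return classes[scores.argmax(axis=1)]                     # column index -> original label


rng = np.random.default_rng(0)
# Two well-separated Gaussian blobs labelled 3 and 7.
X = np.vstack([rng.normal(-2.0, 0.5, size=(50, 2)),
               rng.normal(2.0, 0.5, size=(50, 2))])
y = np.array([3] * 50 + [7] * 50)
params = elm_fit(X, y, n_hidden=20, rng=rng)
pred = elm_predict(X, *params)
train_acc = np.mean(pred == y)
```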


## 7.3) CAE SSC Band Selection

In [None]:
#@title Setup 'CAE SSC Band Selection' Class
class CAE_BS(object):
    """
    :argument:
        Implementation of the L2-norm based sparse self-expressive clustering model
        with an affinity measure based on angular similarity
    """
    def __init__(self, n_band=10, coef_=1):
        self.n_band = n_band
        self.coef_ = coef_

    def fit(self, X):
        self.X = X
        return self

    def predict(self, X_cae_fea, X_origin):
        """
        :param X_cae_fea: shape [n_CAE_fea, n_band]
        :param X_origin: original HSI data with a 2-D shape of (n_row*n_clm, n_band)
        :return: selected band subset
        """
        cluster_res = self.__get_cluster_close(X_cae_fea)
        selected_band = self.__get_band(cluster_res, X_origin)
        return selected_band

    def __get_band(self, cluster_result, X):
        """
        select band according to the center of each cluster
        :param cluster_result:
        :param X:
        :return:
        """
        selected_band = []
        n_cluster = np.unique(cluster_result).__len__()
        # img_ = X.reshape((n_row * n_column, -1))  # n_sample * n_band
        for c in np.unique(cluster_result):
            idx = np.nonzero(cluster_result == c)
            center = np.mean(X[:, idx[0]], axis=1).reshape((-1, 1))
            distance = np.linalg.norm(X[:, idx[0]] - center, axis=0)
            band_ = X[:, idx[0]][:, distance.argmin()]
            print(f'idx[0] for c={c} :  {idx[0]}')
            print(f'distance.argmin() for c={c} :  {distance.argmin()}')
            print(f'band_ : {band_}')
            print()
            selected_band.append(band_)
        bands = np.asarray(selected_band).transpose()
        # bands = bands.reshape(n_cluster, n_row, n_column)
        # bands = np.transpose(bands, axes=(1, 2, 0))
        return bands

    def __get_cluster_close(self, X):
        """
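        cluster the bands using the closed-form (ridge regression) self-expressive solution
        :param X: CAE features with shape [n_CAE_fea, n_band]
        :return: cluster label of each band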
        """
        n_sample = X.transpose().shape[0]
        H = X.transpose()    # NRP_ELM(self.n_hidden, sparse=False).fit(X).predict(X)
        C = np.zeros((n_sample, n_sample))
        for i in range(n_sample):
            y_i = H[i]
            H_i = np.delete(H, i, axis=0).transpose()
            term_1 = np.linalg.inv(np.dot(H_i.transpose(), H_i) + self.coef_ * np.eye(n_sample - 1))
            w = np.dot(np.dot(term_1, H_i.transpose()), y_i.reshape((y_i.shape[0], 1)))
            w = w.flatten()
            # Normalize each column of C by its largest absolute entry: c_i = c_i / ||c_i||_inf.
            coef = w / np.max(np.abs(w))
            C[:i, i] = coef[:i]
            C[i + 1:, i] = coef[i:]
        # compute affinity matrix
        L = 0.5 * (np.abs(C) + np.abs(C.T))  # affinity graph
        self.affinity_matrix = L
        # spectral clustering
        sc = SpectralClustering(n_clusters=self.n_band, affinity='precomputed')
        sc.fit(self.affinity_matrix)
        return sc.labels_
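
The closed-form step in `__get_cluster_close` expresses each band as a ridge regression on all other bands and then symmetrizes the coefficients into an affinity matrix. Below is a minimal NumPy sketch of that step, assuming a feature matrix with one column per band; `self_expressive_affinity` and the toy data are hypothetical, and the spectral clustering stage is left out.

```python
import numpy as np


def self_expressive_affinity(F, lam=1.0):
    """F has shape (n_feature, n_band); returns a symmetric (n_band, n_band) affinity."""
    n_band = F.shape[1]
    C = np.zeros((n_band, n_band))
    for i in range(n_band):
        others = np.delete(F, i, axis=1)                   # every band except band i
        gram = others.T @ others + lam * np.eye(n_band - 1)
        w = np.linalg.solve(gram, others.T @ F[:, i])      # ridge regression weights
        w = w / np.max(np.abs(w))                          # scale by the largest magnitude
        C[:i, i] = w[:i]                                   # C[i, i] stays zero
        C[i + 1:, i] = w[i:]
    return 0.5 * (np.abs(C) + np.abs(C.T))


rng = np.random.default_rng(0)
a, b = rng.normal(size=50), rng.normal(size=50)
# Bands 0 and 1 are noisy copies of `a`; bands 2 and 3 are noisy copies of `b`.
F = np.column_stack([a, a + 0.1 * rng.normal(size=50),
                     b, b + 0.1 * rng.normal(size=50)])
affinity = self_expressive_affinity(F, lam=1.0)
```

In this toy case the affinity between bands of the same group should be much larger than across groups, which is what lets spectral clustering separate them. Using `np.linalg.solve` instead of an explicit inverse is a choice made for the sketch, not something taken from the class above.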

## 7.4) DSC NET Band Selection

Code authors:
- Pan Ji, University of Adelaide, pan.ji@adelaide.edu.au
- Tong Zhang, Australian National University, tong.zhang@anu.edu.au

Copyright reserved.

In [None]:
#@title DSC Net Class
class DSC_NET(object):
    def __init__(self, n_input, kernel_size, n_hidden, reg_const1=1.0, reg_const2=1.0, reg=None, batch_size=256,
                 max_iter=10, denoise=False, model_path=None, logs_path='./logs'):
        # n_hidden is an array containing the number of neurons in every layer
        tf.reset_default_graph()
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.reg = reg
        self.model_path = model_path
        self.kernel_size = kernel_size
        self.iter = 0
        self.batch_size = batch_size
        # weights = self._initialize_weights()
        # # Variable initialization
        weights = dict()
        with tf.variable_scope('weight', reuse=tf.AUTO_REUSE):
            weights['enc_w0'] = tf.get_variable("enc_w0",
                                                    shape=[self.kernel_size[0], self.kernel_size[0], 1,
                                                           self.n_hidden[0]],
                                                    initializer=layers.xavier_initializer_conv2d(),
                                                    regularizer=self.reg)
            weights['enc_b0'] = tf.Variable(tf.zeros([self.n_hidden[0]], dtype=tf.float32))

            weights['dec_w0'] = tf.get_variable("dec_w0",
                                                    shape=[self.kernel_size[0], self.kernel_size[0], 1,
                                                           self.n_hidden[0]],
                                                    initializer=layers.xavier_initializer_conv2d(),
                                                    regularizer=self.reg)
            weights['dec_b0'] = tf.Variable(tf.zeros([1], dtype=tf.float32))

        self.max_iter = max_iter
        # model
        self.x = tf.placeholder(tf.float32, [None, self.n_input[0], self.n_input[1], 1])
        self.learning_rate = tf.placeholder(tf.float32, [])

        if denoise:
            # corrupt the input with Gaussian noise before encoding
            x_input = tf.add(self.x, tf.random_normal(shape=tf.shape(self.x),
                                                      mean=0,
                                                      stddev=0.2,
                                                      dtype=tf.float32))
        else:
            x_input = self.x
        latent, shape = self.encoder(x_input, weights)
        self.z_conv = tf.reshape(latent, [batch_size, -1])
        self.z_ssc, Coef = self.selfexpressive_moduel(batch_size)
        self.Coef = Coef
        latent_de_ft = tf.reshape(self.z_ssc, tf.shape(latent))
        self.x_r_ft = self.decoder(latent_de_ft, weights, shape)

        self.saver = tf.train.Saver([v for v in tf.trainable_variables() if not (v.name.startswith("Coef"))])

        self.cost_ssc = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.z_conv, self.z_ssc), 2))
        self.recon_ssc = tf.reduce_sum(tf.pow(tf.subtract(self.x_r_ft, self.x), 2.0))
        self.reg_ssc = tf.reduce_sum(tf.pow(self.Coef, 2))
        tf.summary.scalar("ssc_loss", self.cost_ssc)
        tf.summary.scalar("reg_loss", self.reg_ssc)

        self.loss_ssc = self.cost_ssc * reg_const2 + reg_const1 * self.reg_ssc + self.recon_ssc

        self.merged_summary_op = tf.summary.merge_all()
        self.optimizer_ssc = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss_ssc)
        self.init = tf.global_variables_initializer()
        self.sess = tf.InteractiveSession()
        self.sess.run(self.init)
        self.summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())

    # # this function will raise an exception in higher version Tensorflow
    # def _initialize_weights(self):
    #     all_weights = dict()
    #     with tf.variable_scope('weight', reuse=tf.AUTO_REUSE):
    #         all_weights['enc_w0'] = tf.get_variable("enc_w0",
    #         shape=[self.kernel_size[0], self.kernel_size[0], 1, self.n_hidden[0]],
    #         initializer=layers.xavier_initializer_conv2d(), regularizer=self.reg)
    #         all_weights['enc_b0'] = tf.Variable(tf.zeros([self.n_hidden[0]], dtype=tf.float32))
    #
    #         all_weights['dec_w0'] = tf.get_variable("dec_w0",
    #         shape=[self.kernel_size[0], self.kernel_size[0], 1, self.n_hidden[0]],
    #         initializer=layers.xavier_initializer_conv2d(), regularizer=self.reg)
    #         all_weights['dec_b0'] = tf.Variable(tf.zeros([1], dtype=tf.float32))
    #     return all_weights

    # Building the encoder
    def encoder(self, x, weights):
        shapes = []
        # Encoder Hidden layer with relu activation #1
        shapes.append(x.get_shape().as_list())
        layer1 = tf.nn.bias_add(tf.nn.conv2d(x, weights['enc_w0'], strides=[1, 2, 2, 1], padding='SAME'),
                                weights['enc_b0'])
        layer1 = tf.nn.relu(layer1)
        return layer1, shapes

    # Building the decoder
    def decoder(self, z, weights, shapes):
        # Decoder hidden layer with relu activation #1
        shape_de1 = shapes[0]
        layer1 = tf.add(tf.nn.conv2d_transpose(z, weights['dec_w0'], tf.stack(
            [tf.shape(self.x)[0], shape_de1[1], shape_de1[2], shape_de1[3]]),
                                               strides=[1, 2, 2, 1], padding='SAME'), weights['dec_b0'])
        layer1 = tf.nn.relu(layer1)

        return layer1

    def selfexpressive_moduel(self, batch_size):

        Coef = tf.Variable(1.0e-8 * tf.ones([self.batch_size, self.batch_size], tf.float32), name='Coef')
        z_ssc = tf.matmul(Coef, self.z_conv)
        return z_ssc, Coef

    def finetune_fit(self, X, lr):
        C, l1_cost, l2_cost, total_loss, summary, _ = self.sess.run(
            (self.Coef, self.reg_ssc, self.cost_ssc, self.loss_ssc, self.merged_summary_op, self.optimizer_ssc), \
            feed_dict={self.x: X, self.learning_rate: lr})
        self.summary_writer.add_summary(summary, self.iter)
        self.iter = self.iter + 1
        return C, l1_cost, l2_cost, total_loss

    def initlization(self):
        tf.reset_default_graph()
        self.sess.run(self.init)

    def transform(self, X):
        return self.sess.run(self.z_conv, feed_dict={self.x: X})

    def save_model(self):
        save_path = self.saver.save(self.sess, self.model_path)
        print("model saved in file: %s" % save_path)

    def restore(self):
        self.saver.restore(self.sess, self.model_path)
        print("model restored")

    def best_map(self, L1, L2):
        # L1 should be the labels and L2 should be the clustering number we got
        Label1 = np.unique(L1)
        nClass1 = len(Label1)
        Label2 = np.unique(L2)
        nClass2 = len(Label2)
        nClass = np.maximum(nClass1, nClass2)
        G = np.zeros((nClass, nClass))
        for i in range(nClass1):
            ind_cla1 = L1 == Label1[i]
            ind_cla1 = ind_cla1.astype(float)
            for j in range(nClass2):
                ind_cla2 = L2 == Label2[j]
                ind_cla2 = ind_cla2.astype(float)
                G[i, j] = np.sum(ind_cla2 * ind_cla1)
        m = Munkres()
        index = m.compute(-G.T)
        index = np.array(index)
        c = index[:, 1]
        newL2 = np.zeros(L2.shape)
        for i in range(nClass2):
            newL2[L2 == Label2[i]] = Label1[c[i]]
        return newL2

    def thrC(self, C, ro):
        if ro < 1:
            N = C.shape[1]
            Cp = np.zeros((N, N))
            S = np.abs(np.sort(-np.abs(C), axis=0))
            Ind = np.argsort(-np.abs(C), axis=0)
            for i in range(N):
                cL1 = np.sum(S[:, i]).astype(float)
                stop = False
                csum = 0
                t = 0
                while not stop:
                    csum = csum + S[t, i]
                    if csum > ro * cL1:
                        stop = True
                        Cp[Ind[0:t + 1, i], i] = C[Ind[0:t + 1, i], i]
                    t = t + 1
        else:
            Cp = C
        return Cp

    def post_proC(self, C, K, d, alpha):
        # C: coefficient matrix, K: number of clusters, d: dimension of each subspace
        n = C.shape[0]
        C = 0.5 * (C + C.T)
        C = C - np.diag(np.diag(C)) + np.eye(n, n)  # for sparse C, this step will make the algorithm more numerically stable
        r = d * K + 1
        print('r = %s, C:%s' % (r, np.unique(C).shape))
        U, S, _ = svds(C, r)
        U = U[:, ::-1]
        S = np.sqrt(S[::-1])
        S = np.diag(S)
        U = U.dot(S)
        U = normalize(U, norm='l2', axis=1)
        Z = U.dot(U.T)
        Z = Z * (Z > 0)
        L = np.abs(Z ** alpha)
        L = L / L.max()
        L = 0.5 * (L + L.T)
        spectral = cluster.SpectralClustering(n_clusters=K, eigen_solver='arpack', affinity='precomputed',
                                              assign_labels='discretize')
        grp = spectral.fit_predict(L)  # fit_predict already fits the model, so a separate fit() call is redundant
        return grp, L

    def cluster(self, X, n_cluster):
        n_row, n_column, n_band = X.shape
        img_transposed = np.transpose(X, axes=(2, 0, 1))  # Img.transpose()
        img_transposed = np.reshape(img_transposed, (n_band, n_row, n_column, 1))
        # ft_times = 30
        alpha = 0.04
        learning_rate = 1e-3
        all_loss = []
        for i in range(0, 1):
            self.initlization()
            for iter_ft in range(self.max_iter):
                iter_ft = iter_ft + 1
                C, l1_cost, l2_cost, total_loss = self.finetune_fit(img_transposed, learning_rate)
                print('# epoch %s' % (iter_ft))
                all_loss.append(total_loss)
            C = self.thrC(C, alpha)
            y_x, CKSym_x = self.post_proC(C, n_cluster, 1, 4)
            print(all_loss)
            return y_x
                # all_loss.append(total_loss)
                # if iter_ft % display_step == 0:
                #     print("epoch: %.1d" % iter_ft,
                #           "L1 cost: %.8f, L2 cost: %.8f, total cost: %.8f" % (l1_cost, l2_cost, total_loss))
                #     C = self.thrC(C, alpha)
                #     y_x, CKSym_x = self.post_proC(C, n_cluster, 1, 4)
                    # bands = self.select_band(y_x, img)  # n_row * n_clm * n_class
                    # score = dsc_bs.eval_band(bands, gt, train_inx, test_idx)
                    # all_acc.append(score)
                    # print('eval score:', score)
            # print(all_loss)
            # print(all_acc)
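
The column-wise thresholding done by `thrC` keeps, for every column of the coefficient matrix, the largest-magnitude entries until their running sum first exceeds a fraction `ro` of the column's L1 norm. A minimal NumPy sketch of that rule is given below; `threshold_columns` is a hypothetical stand-in that also accepts non-square matrices, and the example matrix is made up.

```python
import numpy as np


def threshold_columns(C, ro):
    """Per column, keep the largest |entries| up to and including the one at which
    the running sum first exceeds ro * (column L1 norm); zero the rest."""
    C = np.asarray(C, dtype=float)
    if ro >= 1:
        return C.copy()
    Cp = np.zeros_like(C)
    for j in range(C.shape[1]):
        order = np.argsort(-np.abs(C[:, j]))       # row indices, largest magnitude first
        csum = np.cumsum(np.abs(C[order, j]))
        # position of the first entry whose running sum exceeds ro * total
        first_over = int(np.searchsorted(csum, ro * csum[-1], side='right'))
        keep = order[:first_over + 1]
        Cp[keep, j] = C[keep, j]
    return Cp


C = np.array([[0.5, 0.1],
              [0.3, 0.6],
              [0.2, 0.3]])
Cp = threshold_columns(C, ro=0.7)
```

With `ro=0.7`, each column here keeps its two largest entries, because the largest entry alone (0.5 and 0.6) does not exceed 70% of the column sum.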

In [None]:
#@title DSC Band Selection Class
class DSCBS(object):
    """
    Select band subset using DSC algorithm
    """
    def __init__(self, n_band, **kwargs_DSC):
        self.n_band = n_band
        self.dsc = DSC_NET(**kwargs_DSC)

    def fit(self, X):
        """
        :param X: Array-like with size (n_row, n_column, n_band)
        :return:
        """
        self.X = X
        return self

    def predict(self, X):
        cluster_result = self.dsc.cluster(X, self.n_band)
        selected_band = []
        n_row, n_column, n_band = X.shape
        n_cluster = np.unique(cluster_result).__len__()
        img_ = X.reshape((n_row * n_column, -1))  # n_sample * n_band
        for c in np.unique(cluster_result):
            idx = np.nonzero(cluster_result == c)
            center = np.mean(img_[:, idx[0]], axis=1).reshape((-1, 1))
            distance = np.linalg.norm(img_[:, idx[0]] - center, axis=0)
            band_ = img_[:, idx[0]][:, distance.argmin()]
            selected_band.append(band_)
        bands = np.asarray(selected_band)
        bands = bands.reshape(n_cluster, n_row, n_column)
        bands = np.transpose(bands, axes=(1, 2, 0))
        self.bands = bands
        return self.bands


# if __name__ == '__main__':
#     from Toolbox.Preprocessing import Processor
#     from sklearn.preprocessing import minmax_scale
#
#     root = 'F:\\Python\\HSI_Files\\'
#     # im_, gt_ = 'SalinasA_corrected', 'SalinasA_gt'
#     im_, gt_ = 'Indian_pines_corrected', 'Indian_pines_gt'
#     # im_, gt_ = 'Pavia', 'Pavia_gt'
#     # im_, gt_ = 'Botswana', 'Botswana_gt'
#     # im_, gt_ = 'KSC', 'KSC_gt'
#
#     img_path = root + im_ + '.mat'
#     gt_path = root + gt_ + '.mat'
#     print(img_path)
#
#     p = Processor()
#     img, gt = p.prepare_data(img_path, gt_path)
#     # Img, Label = Img[:256, :, :], Label[:256, :]
#     n_row, n_column, n_band = img.shape
#     train_inx, test_idx = p.get_tr_tx_index(p.get_correct(img, gt)[1], test_size=0.9)
#
#     img_train = minmax_scale(img.reshape(n_row * n_column, n_band)).reshape((n_row, n_column, n_band))
#     # img_train = np.transpose(img_train, axes=(2, 0, 1))  # Img.transpose()
#     # img_train = np.reshape(img_train, (n_band, n_row, n_column, 1))
#
#     n_input = [n_row, n_column]
#     kernel_size = [11]
#     n_hidden = [16]
#     batch_size = n_band
#     model_path = './pretrain-model-COIL20/model.ckpt'
#     ft_path = './pretrain-model-COIL20/model.ckpt'
#     logs_path = './pretrain-model-COIL20/logs'
#
#     num_class = 5  # how many class we sample
#     batch_size_test = n_band
#
#     iter_ft = 0
#     ft_times = 50
#     display_step = 1
#     alpha = 0.04
#     learning_rate = 1e-3
#
#     reg1 = 1e-4
#     reg2 = 150.0
#     kwargs = {'n_input': n_input, 'n_hidden': n_hidden, 'reg_const1': reg1, 'reg_const2': reg2,
#               'kernel_size': kernel_size,'batch_size': batch_size_test, 'model_path': model_path, 'logs_path': logs_path}
#     dscbs = DSCBS(10, **kwargs)
#     dscbs.fit(img_train)
#     bands = dscbs.predict(img_train)
#     print(bands.shape)
#
#
#
#     # CAE = ConvAE(n_input=n_input, n_hidden=n_hidden, reg_const1=reg1, reg_const2=reg2, kernel_size=kernel_size,
#     #              batch_size=batch_size_test, model_path=model_path, logs_path=logs_path)
#
#     # acc_ = []
#     # all_loss = []
#     # all_acc = []
#     # for i in range(0, 1):
#     #     # coil20_all_subjs = copy.deepcopy(Img)
#     #     # coil20_all_subjs = coil20_all_subjs.astype(float)
#     #     # label_all_subjs = copy.deepcopy(Label)
#     #     # label_all_subjs = label_all_subjs - label_all_subjs.min() + 1
#     #     # label_all_subjs = np.squeeze(label_all_subjs)
#     #     CAE.initlization()
#     #     # CAE.restore()
#     #     for iter_ft in range(ft_times):
#     #         iter_ft = iter_ft + 1
#     #         C, l1_cost, l2_cost, total_loss = CAE.finetune_fit(img_train, learning_rate)
#     #         all_loss.append(total_loss)
#     #         if iter_ft % display_step == 0:
#     #             print("epoch: %.1d" % iter_ft,
#     #                   "L1 cost: %.8f, L2 cost: %.8f, total cost: %.8f" % (l1_cost, l2_cost, total_loss))
#     #             C = thrC(C, alpha)
#     #             y_x, CKSym_x = post_proC(C, num_class, 1, 4)
#     #             bands = band_selection(y_x, img)  # n_row * n_clm * n_class
#     #             score = eval_band(bands, gt, train_inx, test_idx)
#     #             all_acc.append(score)
#     #             print('eval score:', score)
#     #     print(all_loss)
#     #     print(all_acc)
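
Several selectors in this section (`CAE_BS`, `DSCBS`, `ISSC_HSI`, `BandSelection_SNMF`) share the same last step: within each cluster of bands, keep the band closest to the cluster mean. A minimal NumPy sketch of that step follows; `representative_band_indices` is a hypothetical helper that returns band indices rather than band data, and the toy array is made up.

```python
import numpy as np


def representative_band_indices(X, labels):
    """X: (n_pixel, n_band); labels: cluster id of each band.
    Returns one band index per cluster, in order of ascending cluster id."""
    labels = np.asarray(labels)
    chosen = []
    for c in np.unique(labels):
        members = np.flatnonzero(labels == c)                # band indices in cluster c
        center = X[:, members].mean(axis=1, keepdims=True)   # mean band of the cluster
        distance = np.linalg.norm(X[:, members] - center, axis=0)
        chosen.append(int(members[distance.argmin()]))       # index on the original band axis
    return chosen


X = np.array([[0.0, 1.0, 5.0, 9.0],
              [0.0, 1.0, 5.0, 9.0]])
labels = [0, 0, 0, 1]
band_idx = representative_band_indices(X, labels)
selected = X[:, band_idx]
```

Returning indices instead of the band values makes it possible to report which bands were chosen, similar to the `band_list` that `ISSC_HSI` returns.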

## 7.5) ISSC Band Selection
Ref:
    [1]	W. Sun, L. Zhang, B. Du, W. Li, and Y. Mark Lai, "Band Selection Using Improved Sparse Subspace Clustering
    for Hyperspectral Imagery Classification," IEEE Journal of Selected Topics in Applied Earth Observations and
    Remote Sensing, vol. 8, pp. 2784-2797, 2015.

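Formula:

$$\hat{W} = \arg\min_{W} \|X - XW\|_F^2 + \lambda \|W\|_F^2 \quad \text{subject to} \quad \operatorname{diag}(W) = 0$$

Solution:

$$\hat{W} = -\left(X^T X + \lambda I\right)^{-1} \left(\operatorname{diag}\left(\left(X^T X + \lambda I\right)^{-1}\right)\right)^{-1}, \quad \operatorname{diag}(\hat{W}) = 0$$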
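
The closed-form solution stated above can be checked numerically. The sketch below assumes squared Frobenius norms in the objective (the case the closed form corresponds to) and writes `lam` for lambda; `issc_coefficients` and `angular_affinity` are hypothetical helper names, and the data is random. Column 0 of the result is compared against a direct ridge regression of band 0 on the remaining bands.

```python
import numpy as np


def issc_coefficients(X, lam=1.0):
    """X: (n_pixel, n_band). Returns the coefficient matrix W with a zero diagonal."""
    n_band = X.shape[1]
    P = np.linalg.inv(X.T @ X + lam * np.eye(n_band))
    W = -P / np.diag(P)            # divides column j by P[j, j], i.e. -P (diag(P))^-1
    np.fill_diagonal(W, 0.0)       # impose diag(W) = 0
    return W


def angular_affinity(W):
    """Squared cosine similarity between the columns of W."""
    norms = np.linalg.norm(W, axis=0)
    cosine = (W.T @ W) / np.outer(norms, norms)
    return cosine ** 2


rng = np.random.default_rng(0)
X = rng.normal(size=(30, 5))
W = issc_coefficients(X, lam=1.0)
affinity = angular_affinity(W)

# Cross-check: ridge regression of band 0 on bands 1..4 with the same lambda.
others = X[:, 1:]
w_direct = np.linalg.solve(others.T @ others + 1.0 * np.eye(4), others.T @ X[:, 0])
```

The agreement between `W[1:, 0]` and `w_direct` follows from the block-inverse identity for the regularized Gram matrix, so a single matrix inverse gives the regression of every band on all the others.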

In [None]:
#@title ISSC Hyperspectral Image Class
class ISSC_HSI(object):
    """
    :argument:
        Implementation of the L2-norm based sparse self-expressive clustering model
        with an affinity measure based on angular similarity
    """
    def __init__(self, n_band=10, coef_=1):
        self.n_band = n_band
        self.coef_ = coef_

    def fit(self, X):
        self.X = X
        return self

    def predict(self, X):
        """
        :param X: shape [n_row*n_clm, n_band]
        :return: selected band subset
        """
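        I = np.eye(X.shape[1])
        # P = (X^T X + lambda*I)^-1; the solution W = -P (diag(P))^-1 uses the diagonal of P itself
        P = np.linalg.inv(np.dot(X.transpose(), X) + self.coef_ * I)
        coefficient_mat = -1 * np.dot(P, np.linalg.inv(np.diag(np.diag(P))))
        np.fill_diagonal(coefficient_mat, 0)  # enforce the constraint diag(W) = 0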
        temp = np.linalg.norm(coefficient_mat, axis=0).reshape(1, -1)
        affinity = (np.dot(coefficient_mat.transpose(), coefficient_mat) /
                    np.dot(temp.transpose(), temp))**2

        sc = SpectralClustering(n_clusters=self.n_band, affinity='precomputed')
        sc.fit(affinity)
        selected_band, band_list = self.__get_band(sc.labels_, X)
        return selected_band, band_list

    def __get_band(self, cluster_result, X):
        """
        select band according to the center of each cluster
        :param cluster_result:
        :param X:
        :return:
        """
        selected_band = []
        band_list = []
        n_cluster = np.unique(cluster_result).__len__()
        # img_ = X.reshape((n_row * n_column, -1))  # n_sample * n_band
        for c in np.unique(cluster_result):
            idx = np.nonzero(cluster_result == c)
            center = np.mean(X[:, idx[0]], axis=1).reshape((-1, 1))
            distance = np.linalg.norm(X[:, idx[0]] - center, axis=0)
            band_list.append(idx[0][distance.argmin()])
            band_ = X[:, idx[0]][:, distance.argmin()]
            selected_band.append(band_)
        bands = np.asarray(selected_band).transpose()
        # bands = bands.reshape(n_cluster, n_row, n_column)
        # bands = np.transpose(bands, axes=(1, 2, 0))
        return bands, band_list

## 7.6) Lap Score Band Selection

In [None]:
#@title Lap Score Hyperspectral Image Class
class Lap_score_HSI(object):

    def __init__(self, n_band=10):
        self.n_band = n_band

    def fit(self, X):
        self.X = X
        return self

    def predict(self, X):
        """
        :param X: shape [n_row*n_clm, n_band]
        :return:
        """
        # n_row, n_column, __n_band = X.shape
        # XX = X.reshape((n_row * n_column, -1))  # n_sample * n_band
        XX = X

        kwargs_W = {"metric": "euclidean", "neighbor_mode": "knn", "weight_mode": "heat_kernel", "k": 5, 't': 1}
        W = construct_W.construct_W(XX, **kwargs_W)

        # obtain the scores of features
        score = lap_score.lap_score(X, W=W)

        # rank the features in ascending order of their Laplacian scores
        idx = lap_score.feature_ranking(score)

        # obtain the dataset on the selected features
        selected_features = X[:, idx[0:self.n_band]]

        # selected_features.reshape((self.n_band, n_row, n_column))
        # selected_features = np.transpose(selected_features, axes=(1, 2, 0))
        return selected_features
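
For reference, the Laplacian score of a feature measures how much the feature varies across connected samples in the affinity graph relative to its degree-weighted variance; smaller scores indicate features that respect the graph structure. The sketch below is a plain NumPy version that takes the affinity matrix `W` as given (the class above builds it with `construct_W`); `laplacian_score` here is a hypothetical re-implementation for illustration, not the library routine, and the graph is hand-built.

```python
import numpy as np


def laplacian_score(X, W):
    """X: (n_sample, n_feature); W: symmetric (n_sample, n_sample) affinity.
    Returns one score per feature (smaller is better)."""
    d = W.sum(axis=1)                                  # node degrees
    L = np.diag(d) - W                                 # graph Laplacian
    scores = np.empty(X.shape[1])
    for r in range(X.shape[1]):
        f = X[:, r]
        f_tilde = f - (f @ d) / d.sum()                # remove the degree-weighted mean
        scores[r] = (f_tilde @ L @ f_tilde) / (f_tilde @ (d * f_tilde))
    return scores


# Two fully connected groups of four samples each.
block = np.ones((4, 4)) - np.eye(4)
W = np.block([[block, np.zeros((4, 4))],
              [np.zeros((4, 4)), block]])
X = np.column_stack([
    [0, 0, 0, 0, 1, 1, 1, 1],      # feature 0 is constant within each group
    [0, 1, 0, 1, 0, 1, 0, 1],      # feature 1 ignores the group structure
]).astype(float)
scores = laplacian_score(X, W)
ranking = np.argsort(scores)       # ascending: best feature first
```

The sketch does not guard against constant features, for which the denominator is zero.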

## 7.7) NDFS Band Selection
Nonnegative Discriminative Feature Selection (NDFS)
Reference:
    Li, Zechao, et al. "Unsupervised Feature Selection Using Nonnegative Spectral Analysis." AAAI. 2012.

In [None]:
#@title NDFS Hyperspectral Image Class
class NDFS_HSI(object):

    def __init__(self, n_cluster, n_band=10):
        self.n_band = n_band
        self.n_cluster = n_cluster

    def fit(self, X):
        self.X = X
        return self

    def predict(self, X):
        """

        :param X: shape [n_row*n_clm, n_band]
        :return:
        """
        # construct affinity matrix
        # use the same snake_case option names as the Lap Score cell
        kwargs = {"metric": "euclidean", "neighbor_mode": "knn", "weight_mode": "heat_kernel", "k": 5, 't': 1}
        W = construct_W.construct_W(X, **kwargs)

        # obtain the feature weight matrix
        Weight = NDFS.ndfs(X, W=W, n_clusters=self.n_cluster)

        # rank the features by their weights in the feature weight matrix
        idx = feature_ranking(Weight)

        # obtain the dataset on the selected features
        selected_features = X[:, idx[0:self.n_band]]
        return selected_features

## 7.8) SNMF Band Selection

In [None]:
#@title SNMF Band Selection Class
class BandSelection_SNMF(object):
    def __init__(self, n_band):
        self.n_band = n_band

    def predict(self, X):
        """
        :param X: with shape (n_pixel, n_band)
        :return:
        """
        # # Note that X has to reshape to (n_fea., n_sample)
        # XX = X.transpose()  # (n_band, n_pixel)
        # snmf = nimfa.Snmf(X, seed="random_c", rank=self.n_band)  # remain para. default
        snmf = nimfa.Snmf(X, rank=self.n_band, max_iter=20, version='r', eta=1.,
                          beta=1e-4, i_conv=10, w_min_change=0)
        snmf_fit = snmf()
        W = snmf.basis()  # shape: n_pixel * k
        H = snmf.coef()  # shape: k * n_band

        #  get clustering res.
        H = np.asarray(H)
        indx_sort = np.argsort(H, axis=0)  # ascend order
        cluster_res = indx_sort[-1].reshape(-1)

        #  select band
        selected_band = []
        for c in np.unique(cluster_res):
            idx = np.nonzero(cluster_res == c)
            center = np.mean(X[:, idx[0]], axis=1).reshape((-1, 1))
            distance = np.linalg.norm(X[:, idx[0]] - center, axis=0)
            band_ = X[:, idx[0]][:, distance.argmin()]
            selected_band.append(band_)
        while selected_band.__len__() < self.n_band:
            selected_band.append(np.zeros(X.shape[0]))
        bands = np.asarray(selected_band).transpose()
        return bands

    # # Get W and H
    # def getWH(self, x_input, rank=10):
    #     snmf = nimfa.Snmf(x_input, seed="random_c", rank=rank, max_iter=12, version='r', eta=1.,
    #                         beta=1e-4, i_conv=10, w_min_change=0)
    #     snmf_fit = snmf()
    #     W = snmf.basis()
    #     H = snmf.coef()
    #     return W, H
    #
    # # Select the maximum value of each column of H
    # def maxh_selection(self, H):
    #     selection_h = []
    #     n_row, n_column = H.shape
    #     for i in range(n_column):
    #         max = H[0, i]
    #         for j in range(n_row-1):
    #             if H[j+1,i]>H[j,i]:
    #                 max = H[j+1, i]
    #         selection_h.append(max)
    #     return selection_h
    #
    # def fit(self, X):
    #     self.X = X
    #     return self
    #
    # # Band selection: choose the center band of each cluster
    # def predict(self, X):
    #     """
    #     Select band according to clustering center
    #     :param X: array like: shape (n_row, n_column, n_band)
    #     :return:
    #     """
    #     n_row, n_column, n_band = X.shape
    #     XX = X.reshape((n_row * n_column, -1))  # n_sample * n_band
    #     self.W, self.H = self.getWH(XX, rank=self.n_band)
    #     cluster_result = self.maxh_selection(self.H)
    #     selected_band = []
    #     n_cluster = np.unique(cluster_result).__len__()
    #     for c in np.unique(cluster_result):
    #         idx = np.nonzero(cluster_result == c)
    #         center = np.mean(XX[:, idx[0]], axis=1).reshape((-1, 1))
    #         distance = np.linalg.norm(XX[:, idx[0]] - center, axis=0)
    #         band_ = XX[:, idx[0]][:, distance.argmin()]
    #         selected_band.append(band_)
    #     bands = np.asarray(selected_band)
    #     bands = bands.reshape(n_cluster, n_row, n_column)
    #     bands = np.transpose(bands, axes=(1, 2, 0))
    #     return bands
    #
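
The clustering step in `BandSelection_SNMF.predict` assigns each band to the factor with the largest entry in its column of the coefficient matrix `H`. The sketch below illustrates this with plain NMF using Lee-Seung multiplicative updates in NumPy, swapped in for `nimfa.Snmf` so that the example is self-contained; it is not the sparse NMF variant used above, and `nmf_multiplicative` and the toy data are hypothetical.

```python
import numpy as np


def nmf_multiplicative(V, rank, n_iter=200, seed=0, eps=1e-9):
    """Factor a non-negative V (n_pixel, n_band) as W @ H with multiplicative updates."""
    rng = np.random.default_rng(seed)
    W = rng.random((V.shape[0], rank)) + eps
    H = rng.random((rank, V.shape[1])) + eps
    errors = []
    for _ in range(n_iter):
        H *= (W.T @ V) / (W.T @ W @ H + eps)       # update coefficients
        W *= (V @ H.T) / (W @ H @ H.T + eps)       # update basis
        errors.append(np.linalg.norm(V - W @ H))   # Frobenius reconstruction error
    return W, H, errors


rng = np.random.default_rng(1)
# Non-negative toy image: 40 pixels, 6 bands, built from 2 latent components.
V = rng.random((40, 2)) @ rng.random((2, 6))
W, H, errors = nmf_multiplicative(V, rank=2)
band_cluster = H.argmax(axis=0)    # each band joins the factor with the largest coefficient
```

Here `W` has shape `(n_pixel, rank)` and `H` has shape `(rank, n_band)`, so `band_cluster` holds one label per band.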

## 7.9) SpaBS Band Selection
Author: Zeng Meng

In [None]:
#@title Approximate KSVD Class
class ApproximateKSVD(object):
    def __init__(self, n_components, max_iter=10, tol=1e-6,
                 transform_n_nonzero_coefs=None):
        """
        Parameters
        ----------
        n_components:
            Number of dictionary elements
        max_iter:
            Maximum number of iterations
        tol:
            tolerance for error
        transform_n_nonzero_coefs:
            Number of nonzero coefficients to target
        """
        self.components_ = None
        self.max_iter = max_iter
        self.tol = tol
        self.n_components = n_components
        self.transform_n_nonzero_coefs = transform_n_nonzero_coefs

    def _update_dict(self, X, D, gamma):
        for j in range(self.n_components):
            I = gamma[:, j] > 0
            if np.sum(I) == 0:
                continue

            D[j, :] = 0
            g = gamma[I, j].T
            r = X[I, :] - gamma[I, :].dot(D)
            d = r.T.dot(g)
            d /= np.linalg.norm(d)
            g = r.dot(d)
            D[j, :] = d
            gamma[I, j] = g.T
        return D, gamma

    def _initialize(self, X):
        if min(X.shape) <= self.n_components:
            D = np.random.randn(self.n_components, X.shape[1])
        else:
            u, s, vt = sp.sparse.linalg.svds(X, k=self.n_components)
            D = np.dot(np.diag(s), vt)
        D /= np.linalg.norm(D, axis=1)[:, np.newaxis]
        return D

    def _transform(self, D, X):
        gram = D.dot(D.T)
        Xy = D.dot(X.T)

        n_nonzero_coefs = self.transform_n_nonzero_coefs
        if n_nonzero_coefs is None:
            n_nonzero_coefs = int(0.1 * X.shape[1])

        return orthogonal_mp_gram(
            gram, Xy, n_nonzero_coefs=n_nonzero_coefs).T

    def fit(self, X):
        """
        Parameters
        ----------
        X: shape = [n_samples, n_features]
        """
        D = self._initialize(X)
        for i in range(self.max_iter):
            gamma = self._transform(D, X)
            e = np.linalg.norm(X - gamma.dot(D))
            if e < self.tol:
                break
            D, gamma = self._update_dict(X, D, gamma)

        self.components_ = D
        return self

    def transform(self, X):
        return self._transform(self.components_, X)

In [None]:
#@title SpaBS Class
class SpaBS(object):

    def __init__(self, n_band, sparsity_level=0.5):
        self.n_band = n_band
        self.sparsity_level = sparsity_level

    def fit(self, X):
        self.X = X
        return self

    def predict(self, X):
        """
        Select band according to sparse representation
        :param X: array like: shape (n_row*n_column, n_band)
        :return:
        """
        # n_row, n_column, n_band = X.shape
        # XX = X.reshape((n_row * n_column, -1))  # n_sample * n_band
        # use the SpaBS algorithm
        # call K-SVD
        # TODO: according to ref., X has to be with shape (n_band, n_sample)
        # X = X.transpose()
        dico = ApproximateKSVD(n_components=X.shape[1])
        dico.fit(X)
        gamma_ = dico.transform(X)  # gamma_ is the coefficient matrix, shape (n_sample, n_atom)
        gamma = gamma_.transpose()
        sorted_inx = np.argsort(gamma, axis=0)  # ascending order for each column
        K = X.shape[0] * self.sparsity_level
        largest_k = sorted_inx[-self.n_band:, :]

        # # statistic
        element, freq = np.unique(largest_k, return_counts=True)
        selected_inx = element[np.argsort(freq)][-self.n_band:]
        selected_band = X[:, selected_inx]
        return selected_band


'''
---------------------------
        Test
'''

# X ~ gamma.dot(dictionary)
# X = np.random.randn(1000, 20)
# aksvd = ApproximateKSVD(n_components=20)
# dictionary = aksvd.fit(X).components_
# gamma = aksvd.transform(X)
# print(gamma.shape)

## 7.10) SPEC Band Selection
Ref:
    Zheng Zhao and Huan Liu. 2007. Spectral feature selection for supervised and unsupervised learning. In Proceedings
    of the 24th international conference on Machine learning (ICML '07), Zoubin Ghahramani (Ed.). ACM, New York, NY,
    USA, 1151-1157.
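
For orientation, the simplest ranking function in the cited paper scores a feature `f` as `f_hat^T L f_hat`, where `L` is the normalized graph Laplacian of a sample similarity graph and `f_hat` is the degree-weighted, unit-length version of `f`. The sketch below is only an illustration under stated assumptions, not the `SPEC.spec` routine called by the class in this section: it assumes an RBF similarity graph with a hypothetical bandwidth `sigma`, it implements only the first ranking function (the class requests `style=0`, a different variant), and the function name `spec_phi1_scores` is made up for this example. It builds a dense sample-by-sample matrix, so it is practical only for small pixel subsets.

```python
import numpy as np


def spec_phi1_scores(X, sigma=1.0):
    """Score each column of X with f_hat^T L f_hat (smaller is better)."""
    # Pairwise squared Euclidean distances between samples (rows of X)
    sq_norms = np.sum(X ** 2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (X @ X.T)
    sq_dists = np.maximum(sq_dists, 0.0)

    # Assumed similarity graph: RBF kernel with bandwidth sigma
    W = np.exp(-sq_dists / (2.0 * sigma ** 2))
    d_sqrt = np.sqrt(W.sum(axis=1))

    # Normalized Laplacian L = I - D^{-1/2} W D^{-1/2}
    L = np.eye(X.shape[0]) - W / np.outer(d_sqrt, d_sqrt)

    # Degree-weighted features, each column scaled to unit length
    F = d_sqrt[:, None] * X
    F = F / np.linalg.norm(F, axis=0, keepdims=True)

    return np.sum(F * (L @ F), axis=0)


# Toy data: column 0 follows a two-cluster structure, column 1 is pure noise
rng = np.random.default_rng(0)
informative = np.repeat([0.0, 5.0], 20) + 0.1 * rng.standard_normal(40)
noise = rng.standard_normal(40)
X_demo = np.column_stack([informative, noise])

scores = spec_phi1_scores(X_demo)
ranking = np.argsort(scores)  # ascending: most graph-consistent feature first
```

On this toy data the cluster-aligned column receives the lower score, so it is ranked ahead of the noise column.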

In [None]:
#@title SPEC Hyperspectral Image Class
class SPEC_HSI(object):

    def __init__(self, n_band=10):
        self.n_band = n_band

    def fit(self, X):
        self.X = X
        return self

    def predict(self, X):
        """

        :param X: shape [n_row*n_clm, n_band]
        :return:
        """
        # specify the second ranking function which uses all except the 1st eigenvalue
        kwargs = {'style': 0}
        # n_row, n_column, __n_band = X.shape
        # XX = X.reshape((n_row * n_column, -1))  # n_sample * n_band
        XX = X

        # obtain the scores of features
        score = SPEC.spec(XX, **kwargs)

        # sort the features in descending order according to their scores
        idx = SPEC.feature_ranking(score, **kwargs)

        # obtain the dataset on the selected features
        selected_features = XX[:, idx[0:self.n_band]]
        # selected_features.reshape((self.n_band, n_row, n_column))
        # selected_features = np.transpose(selected_features, axes=(1, 2, 0))
        return selected_features

## 7.11) SSR Band Selection
Ref:
    孙伟伟,蒋曼,李巍岳.利用稀疏自表达实现高光谱影像波段选择[J]. 武汉大学学报·信息科学版, 2017, 42(4): 441-448.
    SUN Weiwei,JIANG Man,LI Weiyue.Band Selection Using Sparse Self-representation for Hyperspectral Imagery[J].
    GEOMATICS AND INFORMATION SCIENCE OF WUHAN UNIVERS, 2017, 42(4): 441-448.
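
The closed-form variant in the class below expresses each band as a ridge-regularized combination of all the other bands, `w = (H^T H + lambda * I)^{-1} H^T y`, and stores `w` in one column of the coefficient matrix `C` with a zero on the diagonal. A minimal NumPy sketch of that single step, assuming a small synthetic matrix; the function name `ridge_self_representation` is hypothetical, and the max-abs rescaling and k-means stage of the class are left out:

```python
import numpy as np


def ridge_self_representation(X, lambda_coef=1.0):
    """Return C with column i holding the ridge weights that rebuild band i
    from the remaining bands. X has shape (n_pixels, n_bands)."""
    n_bands = X.shape[1]
    C = np.zeros((n_bands, n_bands))
    for i in range(n_bands):
        others = np.delete(X, i, axis=1)  # (n_pixels, n_bands - 1)
        gram = others.T @ others + lambda_coef * np.eye(n_bands - 1)
        # Solve the linear system instead of forming an explicit inverse
        w = np.linalg.solve(gram, others.T @ X[:, i])
        C[:i, i] = w[:i]       # weights of bands before i
        C[i + 1:, i] = w[i:]   # weights of bands after i; C[i, i] stays 0
    return C


# Toy data: band 3 is exactly band 0 + band 1
rng = np.random.default_rng(0)
base = rng.standard_normal((50, 3))
X_demo = np.column_stack([base, base[:, 0] + base[:, 1]])

C_demo = ridge_self_representation(X_demo, lambda_coef=1e-8)
```

With a very small `lambda_coef`, column 3 of `C_demo` recovers weights close to `[1, 1, 0]`, which shows how redundant bands end up with similar coefficient patterns before clustering.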

In [None]:
#@title SSC Band Selection Class
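class SSC_BS(BaseEstimator, ClassifierMixin):
    """
    Band selection via self-representation: each band is expressed as a
    combination of the other bands, the resulting coefficient matrix is
    clustered with k-means, and the band closest to each cluster centroid
    is kept.
    """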
    def __init__(self, n_hidden, n_clusters, lambda_coef=1):
        self.n_hidden = n_hidden
        self.n_clusters = n_clusters
        self.lambda_coef = lambda_coef

    def fit_predict_omp(self, X, y=None):
        n_sample = X.transpose().shape[0]
        H = X.transpose()      #NRP_ELM(self.n_hidden, sparse=False).fit(X).predict(X)
        C = np.zeros((n_sample, n_sample))
        # solve sparse self-expressive representation
        for i in range(n_sample):
            y_i = H[i]
            H_i = np.delete(H, i, axis=0)
            # H_T = H_i.transpose()  # M x (N-1)
            omp = OrthogonalMatchingPursuit(n_nonzero_coefs=int(n_sample * 0.5), tol=1e20)
            omp.fit(H_i.transpose(), y_i)
            #  Normalize the columns of C: ci = ci / ||ci||_inf.
            coef = omp.coef_ / np.max(np.abs(omp.coef_))
            C[:i, i] = coef[:i]
            C[i+1:, i] = coef[i:]
        # # compute affinity matrix
        # L = 0.5 * (np.abs(C) + np.abs(C.T))  # affinity graph
        # # L = 0.5 * (C + C.T)
        # self.affinity_matrix = L
        # # spectral clustering
        # sc = SpectralClustering(n_clusters=self.n_clusters, affinity='precomputed')
        # sc.fit(self.affinity_matrix)
        # K-means clustering
        kmeans = KMeans(n_clusters=self.n_clusters, max_iter=500).fit(C)
        label = kmeans.labels_
        C_ = C
        band_index = []
        for i in np.unique(label):
            index__ = np.nonzero(label == i)
            centroids_ = C_[index__]
            centroids = centroids_.mean(axis=0)
            dis = pairwise_distances(centroids_, centroids.reshape((1, centroids_.shape[1]))).flatten()
            index_min = np.argmin(dis)
            C_bestrow = centroids_[index_min, :]
            index = np.nonzero(np.all(C_ == C_bestrow, axis=1))
            band_index.append(index[0][0])
        BandData = X[:, band_index]  # BandData = self.X[:, band_index]
        print('selected band:', band_index)
        return BandData  #sc.labels_

    def fit_predict_close(self, X, raw_input_=False):
        """
        using close-form solution
        :param X:
        :param raw_input_:
        :return:
        """
        n_sample = X.transpose().shape[0]
        if raw_input_ is True:
            H = X.transpose()
        else:
            H = X.transpose()    #NRP_ELM(self.n_hidden, sparse=False).fit(X).predict(X)
        C = np.zeros((n_sample, n_sample))
        for i in range(n_sample):
            y_i = H[i]
            H_i = np.delete(H, i, axis=0).transpose()
            term_1 = np.linalg.inv(np.dot(H_i.transpose(), H_i) + self.lambda_coef * np.eye(n_sample - 1))
            w = np.dot(np.dot(term_1, H_i.transpose()), y_i.reshape((y_i.shape[0], 1)))
            w = w.flatten()
            #  Normalize the columns of C: ci = ci / ||ci||_inf.
            coef = w / np.max(np.abs(w))
            C[:i, i] = coef[:i]
            C[i + 1:, i] = coef[i:]
        # # compute affinity matrix
        # L = 0.5 * (np.abs(C) + np.abs(C.T))  # affinity graph
        # self.affinity_matrix = L
        # # spectral clustering
        # sc = SpectralClustering(n_clusters=self.n_clusters, affinity='precomputed')
        # sc.fit(self.affinity_matrix)
        kmeans = KMeans(n_clusters=self.n_clusters, max_iter=500).fit(C)
        label = kmeans.labels_
        C_ = C
        band_index = []
        for i in np.unique(label):
            index__ = np.nonzero(label == i)
            centroids_ = C_[index__]
            centroids = centroids_.mean(axis=0)
            dis = pairwise_distances(centroids_, centroids.reshape((1, centroids_.shape[1]))).flatten()
            index_min = np.argmin(dis)
            C_bestrow = centroids_[index_min, :]
            index = np.nonzero(np.all(C_ == C_bestrow, axis=1))
            band_index.append(index[0][0])
        BandData = X[:, band_index]  # BandData = self.X[:, band_index]
        print('selected band:', band_index)
        return BandData  # sc.labels_

    def fit_predict_cvx(self, X):
        n_sample = X.transpose().shape[0]
        H = X.transpose()  #NRP_ELM(self.n_hidden, sparse=False).fit(X).predict(X)
        C = np.zeros((n_sample, n_sample))
        # solve sparse self-expressive representation
        for i in range(n_sample):
            y_i = H[i]
            H_i = np.delete(H, i, axis=0)
            # H_T = H_i.transpose()  # M x (N-1)
            # omp = OrthogonalMatchingPursuit(n_nonzero_coefs=500)
            # omp.fit(H_i.transpose(), y_i)
            w = cvx.Variable(n_sample-1)
            objective = cvx.Minimize(0.5 * cvx.sum_squares(H_i.transpose() @ w - y_i) + 0.5 * self.lambda_coef * cvx.norm(w, 1))
            prob = cvx.Problem(objective)
            result = prob.solve()
            #  Normalize the columns of C: ci = ci / ||ci||_inf.
            ww = np.asarray(w.value).flatten()
            coef = ww / np.max(np.abs(ww))
            C[:i, i] = coef[:i]
            C[i + 1:, i] = coef[i:]
        # compute affinity matrix
        # L = 0.5 * (np.abs(C) + np.abs(C.T))  # affinity graph
        # # L = 0.5 * (C + C.T)
        # self.affinity_matrix = L
        # # spectral clustering
        # sc = SpectralClustering(n_clusters=self.n_clusters, affinity='precomputed')
        # sc.fit(self.affinity_matrix)
        # k-means clustering
        kmeans = KMeans(n_clusters=self.n_clusters, max_iter=500).fit(C)
        label = kmeans.labels_
        C_ = C
        band_index = []
        for i in np.unique(label):
            index__ = np.nonzero(label == i)
            centroids_ = C_[index__]
            centroids = centroids_.mean(axis=0)
            dis = pairwise_distances(centroids_, centroids.reshape((1, centroids_.shape[1]))).flatten()
            index_min = np.argmin(dis)
            C_bestrow = centroids_[index_min, :]
            index = np.nonzero(np.all(C_ == C_bestrow, axis=1))
            band_index.append(index[0][0])
        BandData = X[:, band_index]   #BandData = self.X[:, band_index]
        print('selected band:', band_index)
        return BandData  #sc.labels_


## 7.12) Band Selection Function
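
Most branches of the band selection function share one pattern: flatten the `(rows, columns, bands)` cube to `(pixels, bands)`, reduce the band axis, then restore the spatial dimensions. A minimal sketch of that round trip, assuming a tiny synthetic cube and a hypothetical hand-picked band list (neither comes from the dataset):

```python
import numpy as np

# Tiny synthetic cube: 2 rows, 3 columns, 5 bands
cube = np.arange(2 * 3 * 5, dtype=float).reshape(2, 3, 5)
rows, cols, n_bands = cube.shape

# Flatten the spatial dimensions so each row is one pixel spectrum
flat = cube.reshape(rows * cols, -1)

# Hypothetical selection, sorted as the 'manual' and 'issc' branches do
bands_selected = sorted([4, 1])

# Keep the chosen bands, then restore the spatial layout
reduced_flat = flat[:, bands_selected]
reduced = np.reshape(reduced_flat, (rows, cols, reduced_flat.shape[-1]))
```

Because the reshape only regroups the leading axes, `reduced` equals `cube[..., bands_selected]`; the flattening is needed only because the selectors expect a two-dimensional sample-by-band matrix.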

In [None]:
#@title Setup 'Band Selection' function
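def band_selection(data, gt, **hyperparams):
    """
    Reduces the spectral dimension of an image cube with the method
    named in the hyperparameters.

    Parameters
    ----------
    data : numpy.ndarray
        Image cube with shape (rows, columns, channels)
    gt : numpy.ndarray
        Ground truth labels (not used by the current methods)
    **hyperparams : dict
        Must contain 'band_reduction_method' and 'n_components';
        'random_seed' is read for 'pca' and 'ica', and 'selected_bands'
        for 'manual'

    Returns
    -------
    data : numpy.ndarray
        The reduced data
    band_selection_time : datetime.timedelta
        Total runtime of the band selection step
    bands_selected : list or None
        Sorted band indices for 'issc' and 'manual', otherwise None
    """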
    band_reduction_method = hyperparams['band_reduction_method']
    n_components = hyperparams['n_components']

    orig_rows, orig_cols, orig_channels = data.shape
    bands_selected = None

    band_selection_start = time.time()

    if band_reduction_method == 'pca':
        print('Using PCA dimensionality reduction on data...')

        # https://towardsdatascience.com/pca-on-hyperspectral-data-99c9c5178385
        print('Reshaping the data into two dimensions...')
        data = data.reshape(data.shape[0]*data.shape[1], -1)
        print(f'Reshaped data shape: {data.shape}')

        if n_components is None: n_components = 'mle'
        print('Fitting PCA to data...')
        pca = PCA(n_components=n_components,
                  svd_solver='auto',
                  tol=0.0,
                  iterated_power='auto',
                  random_state=hyperparams['random_seed'])
        fit_start = time.time()
        pca.fit(data)
        fit_end = time.time()
        fit_runtime = datetime.timedelta(seconds=(fit_end - fit_start))
        print(f'PCA fitting completed! Fit runtime: {fit_runtime}')

        print(f'PCA fit data to {pca.n_components_} components!')
        data = pca.transform(data)
        print(f'New data shape: {data.shape}')
        data = np.reshape(data, (orig_rows, orig_cols, data.shape[-1]))
        print(f'Reshaped new data shape: {data.shape}')
    elif band_reduction_method == 'ica':
        print('Using ICA dimensionality reduction on data...')

        print('Reshaping the data into two dimensions...')
        data = data.reshape(data.shape[0]*data.shape[1], -1)
        print(f'Reshaped data shape: {data.shape}')

        print('Fitting ICA to data...')
        ica = FastICA(n_components=n_components,
                  algorithm='parallel',
                  fun='logcosh',
                  fun_args=None,
                  max_iter=200,
                  tol=1e-4,
                  w_init=None,
                  random_state=hyperparams['random_seed'])
        fit_start = time.time()
        ica.fit(data.astype(np.float32))
        fit_end = time.time()
        fit_runtime = datetime.timedelta(seconds=(fit_end - fit_start))
        print(f'ICA fitting completed! Fit runtime: {fit_runtime}')
        print(f'ICA found {ica.n_features_in_} features while fitting!')
        data = ica.transform(data)
        print(f'New data shape: {data.shape}')
        data = np.reshape(data, (orig_rows, orig_cols, data.shape[-1]))
        print(f'Reshaped new data shape: {data.shape}')

    elif band_reduction_method == 'cae-ssc':
        print('Using CAE SSC dimensionality reduction on data...')
        cae_ssc = CAE_BS(n_band=n_components)
        # cae_ssc = CAE_BS(n_band=data.shape[-1])
        predict_start = time.time()
        # bands_selected = cae_ssc.predict(np.array([gt.flatten(), data.shape[-1]]),
        #                                  data.reshape(data.shape[0]*data.shape[1], -1))
        # bands_selected = cae_ssc.predict(data.reshape(data.shape[0]*data.shape[1], -1),
        #                                  data.reshape(data.shape[0]*data.shape[1], -1))
        reduced_data = cae_ssc.predict(data.reshape(data.shape[0]*data.shape[1], -1),
                                         data.reshape(data.shape[0]*data.shape[1], -1))
        predict_end = time.time()
        predict_runtime = datetime.timedelta(seconds=(predict_end - predict_start))
        print(f'CAE SSC prediction completed! Prediction runtime: {predict_runtime}')
        print(f'Bands selected by CAE SSC: {bands_selected}')
        # data = data[...,bands_selected]
        data = np.reshape(reduced_data, (orig_rows, orig_cols, reduced_data.shape[-1]))
        print(f'New data shape: {data.shape}')

    elif band_reduction_method == 'dsc-net':
        print('Using DSC-NET dimensionality reduction on data...')
        dscnet = DSCBS(n_band=n_components,
                       n_input=(data.shape[0]*data.shape[1], data.shape[2]),
                       kernel_size=(3,),
                       n_hidden=2)
        # dscnet = DSCBS(n_band=data.shape[-1])
        predict_start = time.time()
        # bands_selected = dscnet.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        reduced_data = dscnet.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        predict_end = time.time()
        predict_runtime = datetime.timedelta(seconds=(predict_end - predict_start))
        print(f'DSC-NET prediction completed! Prediction runtime: {predict_runtime}')
        print(f'Bands selected by DSC-NET: {bands_selected}')
        # data = data[...,bands_selected]
        data = np.reshape(reduced_data, (orig_rows, orig_cols, reduced_data.shape[-1]))
        print(f'New data shape: {data.shape}')

    elif band_reduction_method == 'issc':
        print('Using ISSC dimensionality reduction on data...')
        issc = ISSC_HSI(n_band=n_components)
        # issc = ISSC_HSI(n_band=data.shape[-1])
        predict_start = time.time()
        # bands_selected = issc.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        reduced_data, bands_selected = issc.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        predict_end = time.time()
        predict_runtime = datetime.timedelta(seconds=(predict_end - predict_start))

        # Sort bands
        bands_selected = sorted(bands_selected)

        print(f'ISSC prediction completed! Prediction runtime: {predict_runtime}')
        print(f'Bands selected by ISSC: {bands_selected}')
        # data = data[...,bands_selected]
        data = np.reshape(reduced_data, (orig_rows, orig_cols, reduced_data.shape[-1]))
        print(f'New data shape: {data.shape}')

    elif band_reduction_method == 'ndfs':
        print('Using NDFS dimensionality reduction on data...')
        ndfs = NDFS_HSI(n_band=data.shape[-1], n_cluster=n_components)
        # ndfs = NDFS_HSI(n_band=data.shape[-1])
        predict_start = time.time()
        # bands_selected = ndfs.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        reduced_data = ndfs.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        predict_end = time.time()
        predict_runtime = datetime.timedelta(seconds=(predict_end - predict_start))
        print(f'NDFS prediction completed! Prediction runtime: {predict_runtime}')
        print(f'Bands selected by NDFS: {bands_selected}')
        # data = data[...,bands_selected]
        data = np.reshape(reduced_data, (orig_rows, orig_cols, reduced_data.shape[-1]))
        print(f'New data shape: {data.shape}')

    elif band_reduction_method == 'snmf':
        print('Using SNMF dimensionality reduction on data...')
        snmf = BandSelection_SNMF(n_band=n_components)
        # snmf = BandSelection_SNMF(n_band=data.shape[-1])
        predict_start = time.time()
        # bands_selected = snmf.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        reduced_data = snmf.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        predict_end = time.time()
        predict_runtime = datetime.timedelta(seconds=(predict_end - predict_start))
        print(f'SNMF prediction completed! Prediction runtime: {predict_runtime}')
        print(f'Bands selected by SNMF: {bands_selected}')
        # data = data[...,bands_selected]
        data = np.reshape(reduced_data, (orig_rows, orig_cols, reduced_data.shape[-1]))
        print(f'New data shape: {data.shape}')

    elif band_reduction_method == 'spabs':
        print('Using SpaBS dimensionality reduction on data...')
        spabs = SpaBS(n_band=n_components, sparsity_level=0.5)
        # spabs = SpaBS(n_band=data.shape[-1], sparsity_level=0.5)
        predict_start = time.time()
        # bands_selected = spabs.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        reduced_data = spabs.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        predict_end = time.time()
        predict_runtime = datetime.timedelta(seconds=(predict_end - predict_start))
        print(f'SpaBS prediction completed! Prediction runtime: {predict_runtime}')
        print(f'Bands selected by SpaBS: {bands_selected}')
        # data = data[...,bands_selected]
        data = np.reshape(reduced_data, (orig_rows, orig_cols, reduced_data.shape[-1]))
        print(f'New data shape: {data.shape}')

    elif band_reduction_method == 'spec':
        print('Using SPEC dimensionality reduction on data...')
        spec = SPEC_HSI(n_band=n_components)
        # spec = SPEC_HSI(n_band=data.shape[-1])
        predict_start = time.time()
        # bands_selected = spec.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        reduced_data = spec.predict(data.reshape(data.shape[0]*data.shape[1], -1))
        predict_end = time.time()
        predict_runtime = datetime.timedelta(seconds=(predict_end - predict_start))
        print(f'SPEC prediction completed! Prediction runtime: {predict_runtime}')
        print(f'Bands selected by SPEC: {bands_selected}')
        # data = data[...,bands_selected]
        data = np.reshape(reduced_data, (orig_rows, orig_cols, reduced_data.shape[-1]))
        print(f'New data shape: {data.shape}')

    elif band_reduction_method == 'ssr' or band_reduction_method == 'ssr-close':
        print('Using SSR (closed form solution) dimensionality reduction on data...')
        ssc = SSC_BS(n_hidden=2, n_clusters=n_components, lambda_coef=1)
        fit_start = time.time()
        data = ssc.fit_predict_close(data.reshape(data.shape[0]*data.shape[1], -1))
        fit_end = time.time()
        fit_runtime = datetime.timedelta(seconds=(fit_end - fit_start))
        print(f'SSR fitting completed! Fit runtime: {fit_runtime}')
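        # Restore the spatial dimensions so the output matches the other branches
        data = np.reshape(data, (orig_rows, orig_cols, data.shape[-1]))
        print(f'New data shape: {data.shape}')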

    elif band_reduction_method == 'ssr-cvx':
        print('Using SSR (self-expressive representation) dimensionality reduction on data...')
        ssc = SSC_BS(n_hidden=2, n_clusters=n_components, lambda_coef=1)
        fit_start = time.time()
        data = ssc.fit_predict_cvx(data.reshape(data.shape[0]*data.shape[1], -1))
        fit_end = time.time()
        fit_runtime = datetime.timedelta(seconds=(fit_end - fit_start))
        print(f'SSR fitting completed! Fit runtime: {fit_runtime}')
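        # Restore the spatial dimensions so the output matches the other branches
        data = np.reshape(data, (orig_rows, orig_cols, data.shape[-1]))
        print(f'New data shape: {data.shape}')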

    elif band_reduction_method == 'ssr-omp':
        print('Using SSR (orthogonal matching pursuit) dimensionality reduction on data...')
        ssc = SSC_BS(n_hidden=2, n_clusters=n_components, lambda_coef=1)
        fit_start = time.time()
        data = ssc.fit_predict_omp(data.reshape(data.shape[0]*data.shape[1], -1))
        fit_end = time.time()
        fit_runtime = datetime.timedelta(seconds=(fit_end - fit_start))
        print(f'SSR fitting completed! Fit runtime: {fit_runtime}')
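        # Restore the spatial dimensions so the output matches the other branches
        data = np.reshape(data, (orig_rows, orig_cols, data.shape[-1]))
        print(f'New data shape: {data.shape}')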

    elif band_reduction_method == 'manual':
        print('Using manually selected bands...')
        if hyperparams['selected_bands'] is not None and isinstance(hyperparams['selected_bands'], list):
            # Sort bands
            bands_selected = sorted(hyperparams['selected_bands'])
            fit_start = time.time()
            data = data[...,np.array(bands_selected, dtype=int)]
            fit_end = time.time()
            fit_runtime = datetime.timedelta(seconds=(fit_end - fit_start))
            print(f'Manual fitting completed! Fit runtime: {fit_runtime}')
            print(f'Bands selected by manual selection: {bands_selected}')
            print(f'New data shape: {data.shape}')
        else:
            print('Selected bands list for manual selection is invalid! Dimensionality will be unaltered...')
    else:
        print('No valid band selection method chosen! Dimensionality will be unaltered...')

    band_selection_end = time.time()
    band_selection_time = datetime.timedelta(seconds=(band_selection_end - band_selection_start))

    return data, band_selection_time, bands_selected

# 8) Models

## 8.1) Common Model Functions & Classes
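
`get_optimizer` reads every optimizer-related key with `hyperparams[...]`, so each key has to be present in the dictionary even when its value is `None`. One way to satisfy that is sketched here with a hypothetical helper, `make_optimizer_hyperparams`, which is not part of this notebook; the key names are copied from the function body.

```python
# Keys that get_optimizer looks up unconditionally
OPTIMIZER_KEYS = (
    'optimizer', 'lr', 'momentum', 'epsilon', 'initial_accumulator_value',
    'beta_1', 'beta_2', 'amsgrad', 'rho', 'centered', 'nesterov',
    'learning_rate_power', 'l1_regularization_strength',
    'l2_regularization_strength', 'l2_shrinkage_regularization_strength',
    'beta',
)


def make_optimizer_hyperparams(**overrides):
    """Return a dict holding every optimizer key, defaulting to None.

    Hypothetical helper: values passed in `overrides` replace the None
    placeholders, and unrelated keys are passed through unchanged.
    """
    params = dict.fromkeys(OPTIMIZER_KEYS)  # every key present, value None
    params.update(overrides)
    return params


adam_params = make_optimizer_hyperparams(optimizer='adam', lr=0.01)
```

The resulting dictionary can then be unpacked into the call, for example `get_optimizer(**adam_params)`, without raising a `KeyError` for the settings that were left unspecified.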

In [None]:
#@title Get Optimizer
def get_optimizer(**hyperparams):
    """
    Returns appropriately constructed optimizer from hyperparameter
    inputs.

    Parameters
    ----------
    **hyperparams : dict
        dictionary of hyperparameter values to use to construct the
        optimizer

    Returns
    -------
    optimizer : tensorflow.keras.optimizers.Optimizer
        A keras Optimizer object constructed to hyperparam specification
    """

    # Get requisite hyperparameter values
    optimizer_name = hyperparams['optimizer']
    learning_rate = hyperparams['lr']
    momentum = hyperparams['momentum']
    epsilon = hyperparams['epsilon']
    initial_accumulator_value = hyperparams['initial_accumulator_value']
    beta_1 = hyperparams['beta_1']
    beta_2 = hyperparams['beta_2']
    amsgrad = hyperparams['amsgrad']
    rho = hyperparams['rho']
    centered = hyperparams['centered']
    nesterov = hyperparams['nesterov']
    learning_rate_power = hyperparams['learning_rate_power']
    l1_regularization_strength = hyperparams['l1_regularization_strength']
    l2_regularization_strength = hyperparams['l2_regularization_strength']
    l2_shrinkage_regularization_strength = hyperparams['l2_shrinkage_regularization_strength']
    beta = hyperparams['beta']

    # Set up the optimizers according to the input hyperparameters
    if optimizer_name == 'adadelta':
        if learning_rate is None: learning_rate = 0.001
        if rho is None: rho = 0.95
        if epsilon is None: epsilon = 1e-7
        optimizer = Adadelta(learning_rate=learning_rate,
                             rho=rho,
                             epsilon=epsilon)
    elif optimizer_name == 'adagrad':
        if learning_rate is None: learning_rate = 0.001
        if initial_accumulator_value is None: initial_accumulator_value = 0.1
        if epsilon is None: epsilon = 1e-7
        optimizer = Adagrad(learning_rate=learning_rate,
                            initial_accumulator_value=initial_accumulator_value,
                            epsilon=epsilon)
    elif optimizer_name == 'adam':
        if learning_rate is None: learning_rate = 0.001
        if beta_1 is None: beta_1 = 0.9
        if beta_2 is None: beta_2 = 0.999
        if epsilon is None: epsilon = 1e-7
        if amsgrad is None: amsgrad = False
        optimizer = Adam(learning_rate=learning_rate,
                         beta_1=beta_1,
                         beta_2=beta_2,
                         epsilon=epsilon,
                         amsgrad=amsgrad)
    elif optimizer_name == 'adamax':
        if learning_rate is None: learning_rate = 0.001
        if beta_1 is None: beta_1 = 0.9
        if beta_2 is None: beta_2 = 0.999
        if epsilon is None: epsilon = 1e-7
        optimizer = Adamax(learning_rate=learning_rate,
                           beta_1=beta_1,
                           beta_2=beta_2,
                           epsilon=epsilon)
    elif optimizer_name == 'ftrl':
        if learning_rate is None: learning_rate = 0.001
        if learning_rate_power is None: learning_rate_power = -0.5
        if initial_accumulator_value is None: initial_accumulator_value = 0.1
        if l1_regularization_strength is None: l1_regularization_strength = 0.0
        if l2_regularization_strength is None: l2_regularization_strength = 0.0
        if l2_shrinkage_regularization_strength is None: l2_shrinkage_regularization_strength = 0.0
        if beta is None: beta = 0.0
        optimizer = Ftrl(learning_rate=learning_rate,
                         learning_rate_power=learning_rate_power,
                         initial_accumulator_value=initial_accumulator_value,
                         l1_regularization_strength=l1_regularization_strength,
                         l2_regularization_strength=l2_regularization_strength,
                         l2_shrinkage_regularization_strength=l2_shrinkage_regularization_strength,
                         beta=beta)

    elif optimizer_name == 'nadam':
        if learning_rate is None: learning_rate = 0.001
        if beta_1 is None: beta_1 = 0.9
        if beta_2 is None: beta_2 = 0.999
        if epsilon is None: epsilon = 1e-7
        optimizer = Nadam(learning_rate=learning_rate,
                          beta_1=beta_1,
                          beta_2=beta_2,
                          epsilon=epsilon)
    elif optimizer_name == 'rmsprop':
        if learning_rate is None: learning_rate = 0.001
        if rho is None: rho = 0.9
        if momentum is None: momentum = 0.0
        if epsilon is None: epsilon = 1e-7
        if centered is None: centered = False
        optimizer = RMSprop(learning_rate=learning_rate,
                            rho=rho,
                            momentum=momentum,
                            epsilon=epsilon,
                            centered=centered)
    elif optimizer_name == 'sgd':
        if learning_rate is None: learning_rate = 0.001
        if momentum is None: momentum = 0.0
        if nesterov is None: nesterov = False
        optimizer = SGD(learning_rate=learning_rate,
                        momentum=momentum,
                        nesterov=nesterov)
    else:
        # This is the default value for the Tensorflow keras compile
        # function optimizer argument
        optimizer = 'rmsprop'

    return optimizer

In [None]:
#@title Handle Dimension Ordering
def _handle_dim_ordering():
    global CONV_DIM1
    global CONV_DIM2
    global CONV_DIM3
    global CHANNEL_AXIS
    if K.image_data_format() == 'channels_last':
        CONV_DIM1 = 1
        CONV_DIM2 = 2
        CONV_DIM3 = 3
        CHANNEL_AXIS = 4
    else:
        CHANNEL_AXIS = 1
        CONV_DIM1 = 2
        CONV_DIM2 = 3
        CONV_DIM3 = 4

## 8.2) Model Building Blocks

In [None]:
#@title 2D Convolutional Block
def conv_block_2d(x, growth_rate, name, activation='relu'):
    """A building block for a dense block.

    # Arguments
        x: input tensor.
        growth_rate: float, growth rate at dense layers.
        name: string, block label.

    # Returns
        output tensor for the block.
    """
    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
    x1 = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                            name=name + '_0_bn')(x)
    x1 = Activation(activation, name=name + f'_0_{activation}')(x1)
    x1 = Conv2D(4 * growth_rate, 1, use_bias=False,
                name=name + '_1_conv', padding='same')(x1)
    x1 = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                            name=name + '_1_bn')(x1)
    x1 = Activation(activation, name=name + f'_1_{activation}')(x1)
    x1 = Conv2D(growth_rate, 3, padding='same', use_bias=False,
                name=name + '_2_conv')(x1)
    x = Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
    return x

In [None]:
#@title 3D Convolutional Block
def conv_block_3d(x, growth_rate, name, activation='relu'):
    """A building block for a dense block.

    # Arguments
        x: input tensor.
        growth_rate: float, growth rate at dense layers.
        name: string, block label.

    # Returns
        output tensor for the block.
    """
    bn_axis = 4 if K.image_data_format() == 'channels_last' else 1
    x1 = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                            name=name + '_0_bn')(x)
    x1 = Activation(activation, name=name + f'_0_{activation}')(x1)
    x1 = Conv3D(4 * growth_rate, 1, use_bias=False,
                name=name + '_1_conv', padding='same')(x1)
    x1 = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                            name=name + '_1_bn')(x1)
    x1 = Activation(activation, name=name + f'_1_{activation}')(x1)
    x1 = Conv3D(growth_rate, 3, padding='same', use_bias=False,
                name=name + '_2_conv')(x1)
    x = Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
    return x

In [None]:
#@title 2D Transition Block
def transition_block_2d(x, reduction, name, activation='relu'):
    """A transition block.

    # Arguments
        x: input tensor.
        reduction: float, compression rate at transition layers.
        name: string, block label.

    # Returns
        output tensor for the block.
    """
    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
    x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                           name=name + '_bn')(x)
    x = Activation(activation, name=name + f'_{activation}')(x)
    x = Conv2D(int(K.int_shape(x)[bn_axis] * reduction), 1, use_bias=False,
               name=name + '_conv', padding='same')(x)
    x = AveragePooling2D(1, strides=(2, 2), name=name + '_pool', padding='same')(x)
    return x


In [None]:
#@title 3D Transition Block
def transition_block_3d(x, reduction, name, activation='relu'):
    """A transition block.

    # Arguments
        x: input tensor.
        reduction: float, compression rate at transition layers.
        name: string, block label.

    # Returns
        output tensor for the block.
    """
    bn_axis = 4 if K.image_data_format() == 'channels_last' else 1
    x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                           name=name + '_bn')(x)
    x = Activation(activation, name=name + f'_{activation}')(x)
    x = Conv3D(int(K.int_shape(x)[bn_axis] * reduction), 1, use_bias=False,
               name=name + '_conv', padding='same')(x)
    x = AveragePooling3D(1, strides=(2, 2, 2), name=name + '_pool', padding='same')(x)
    return x

In [None]:
#@title 2D Dense Block
def dense_block_2d(x, blocks, name, growth_rate=32, activation='relu'):
    """A dense block.

    # Arguments
        x: input tensor.
        blocks: integer, the number of building blocks.
        name: string, block label.

    # Returns
        output tensor for the block.
    """
    for i in range(blocks):
        x = conv_block_2d(x, growth_rate, activation=activation, name=name + '_block' + str(i + 1))
    return x

In [None]:
#@title 3D Dense Block
def dense_block_3d(x, blocks, name, growth_rate=32, activation='relu'):
    """A dense block.

    # Arguments
        x: input tensor.
        blocks: integer, the number of building blocks.
        name: string, block label.

    # Returns
        output tensor for the block.
    """
    for i in range(blocks):
        x = conv_block_3d(x, growth_rate, activation=activation, name=name + '_block' + str(i + 1))
    return x

In [None]:
#@title Fusion Fully Convolutional Network (FCN) Convolution Block
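def fusion_fcn_conv_block(x, branch_num, block_num, growth_rate=64, activation='relu'):
    """
    Convolution block for one fusion FCN branch: a 3x3 convolution, an
    activation, 2x2 average pooling, and a nearest-neighbor resize back
    to the spatial size of the input.
    """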

    if x.shape[0] == 3:
        height = x.shape[0]
        width = x.shape[1]
    else:
        height = x.shape[1]
        width = x.shape[2]

    x = Conv2D(growth_rate, kernel_size=(3,3), strides=(1,1), padding='same',
                            name=f'Branch_{branch_num}_Conv2D_{block_num}')(x)
    x = Activation(activation, name=f'Branch_{branch_num}_{activation}_{block_num}')(x)
    x = AveragePooling2D(pool_size=(2,2), padding='same',
                            name=f'Branch_{branch_num}_AveragePool2D_{block_num}')(x)
    x = Resizing(height, width, interpolation='nearest',
                            name=f'Branch_{branch_num}_Resizing_{block_num}')(x)

    return x

In [None]:
#@title Network-in-Network (NiN) Block
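def nin_block(x, filters, kernel_size, block_num, strides=(1,1), num_mlp_layers=2, activation='relu'):
    """
    Network-in-Network block: `num_mlp_layers` 1x1 convolutions that keep
    the input channel count, followed by one convolution with `filters`
    output channels. All layers use L2 kernel regularization.
    """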

    for layer in range(num_mlp_layers):
        x = Conv2D(x.shape[-1], kernel_size=(1,1), strides=(1,1),
                   name=f'MLPConv_{block_num}_layer_{layer}',
                   activation=activation, padding='valid',
                   kernel_regularizer=regularizers.l2(0.01))(x)

    x = Conv2D(filters, kernel_size=kernel_size, strides=strides,
               name=f'Conv_{block_num}', activation=activation, padding='same',
               kernel_regularizer=regularizers.l2(0.01))(x)

    return x

## 8.3) Densenet Model Builder

In [None]:
#@title Densenet 3D Builder Class
class Densenet3DBuilder(object):
    @staticmethod
    def build(input_shape, num_outputs):
        print('original input shape:', input_shape)
        _handle_dim_ordering()
        if len(input_shape) != 4:
            raise Exception("Input shape should be a tuple (nb_channels, kernel_dim1, kernel_dim2, kernel_dim3)")

        # Example original input shape: (1, 7, 7, 200)

        print(f'Image data format: {K.image_data_format()}')
        if K.image_data_format() == 'channels_last':
            input_shape = (input_shape[1], input_shape[2], input_shape[3], input_shape[0])
        print('change input shape:', input_shape)

        # Input tensor
        input = Input(shape=input_shape)

        # 3D Convolution and pooling
        conv1 = Conv3D(64, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding='SAME', kernel_initializer='he_normal')(
            input)
        pool1 = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding='same')(conv1)

        # Dense Block1
        # x = dense_block_3d(pool1, 6, name='conv1')
        # x = transition_block_3d(x, 0.5, name='pool1')
        # x = dense_block_3d(x, 6, name='conv2')
        # x = transition_block_3d(x, 0.5, name='pool2')
        # x = dense_block_3d(x, 6, name='conv3')
        x = dense_block_3d(pool1, 3, name='conv1')
        x = transition_block_3d(x, 0.5, name='pool1')
        x = dense_block_3d(x, 3, name='conv2')
        x = transition_block_3d(x, 0.5, name='pool2')
        x = dense_block_3d(x, 3, name='conv3')
        print(x.shape)
        x = GlobalAveragePooling3D(name='avg_pool')(x)
        print(x.shape)
        # x = Dense(16, activation='softmax')(x)

        # Classifier block
        dense = Dense(units=num_outputs, activation="softmax", kernel_initializer="he_normal")(x)

        model = Model(inputs=input, outputs=dense, name='3D-DenseNet')
        return model

    @staticmethod
    def build_resnet_8(input_shape, num_outputs):
        # Example arguments: input_shape=(1, 7, 7, 200), num_outputs=16
        return Densenet3DBuilder.build(input_shape, num_outputs)


In [None]:
#@title Densenet 3D Model Assignment
def densenet_3d_model(img_rows, img_cols, img_channels, nb_classes):

    model = Densenet3DBuilder.build_resnet_8(
        (1, img_rows, img_cols, img_channels), nb_classes)

    return model
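
The channel count of the 3D DenseNet can be traced by hand. The sketch below assumes that `conv_block_3d` and `transition_block_3d` (defined elsewhere in the notebook) follow the standard DenseNet pattern: each building block concatenates `growth_rate` new feature maps onto its input, and each transition block scales the channel count by the factor passed to it (0.5 here). The function names are illustrative only.

```python
def dense_block_channels(in_channels, blocks, growth_rate=32):
    """Channels after a dense block, assuming each block adds growth_rate maps."""
    return in_channels + blocks * growth_rate


def transition_channels(in_channels, reduction=0.5):
    """Channels after a transition block, assuming channel compression."""
    return int(in_channels * reduction)


def densenet_channel_trace(stem_filters=64, blocks_per_stage=(3, 3, 3),
                           growth_rate=32, reduction=0.5):
    """List the channel count after every dense and transition block."""
    trace = []
    channels = stem_filters
    for stage, blocks in enumerate(blocks_per_stage):
        channels = dense_block_channels(channels, blocks, growth_rate)
        trace.append(channels)
        # A transition block follows every dense block except the last one
        if stage < len(blocks_per_stage) - 1:
            channels = transition_channels(channels, reduction)
            trace.append(channels)
    return trace


print(densenet_channel_trace())  # [160, 80, 176, 88, 184]
```

Under these assumptions, the vector produced by `GlobalAveragePooling3D` has 184 entries for the three-block configuration used above.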

## 8.4) 3D CNN Model Builder

In [None]:
#@title CNN 3D Builder Class
class CNN3DBuilder(object):
    @staticmethod
    def build(input_shape, num_outputs):
        print('original input shape:', input_shape)
        _handle_dim_ordering()
        if len(input_shape) != 4:
            raise Exception("Input shape should be a tuple (nb_channels, kernel_dim1, kernel_dim2, kernel_dim3)")

        # Example original input shape: (1, 7, 7, 200)

        print(f'Image data format: {K.image_data_format()}')
        if K.image_data_format() == 'channels_last':
            input_shape = (input_shape[1], input_shape[2], input_shape[3], input_shape[0])
        print('change input shape:', input_shape)

        input = Input(shape=input_shape)

        conv1 = Conv3D(filters=128, kernel_size=(3, 3, 20), strides=(1, 1, 5), padding='same',
                       kernel_regularizer=regularizers.l2(0.01))(input)
        act1 = Activation('relu')(conv1)
        pool1 = MaxPooling3D(pool_size=(2, 2, 2), strides=(1, 1, 1), padding='same')(act1)

        conv2 = Conv3D(filters=192, kernel_size=(2, 2, 3), strides=(1, 1, 2), padding='same',
                       kernel_regularizer=regularizers.l2(0.01))(pool1)
        act2 = Activation('relu')(conv2)
        drop1 = Dropout(0.5)(act2)
        pool2 = MaxPooling3D(pool_size=(2, 2, 2), strides=(1, 1, 1), padding='same')(drop1)

        conv3 = Conv3D(filters=256, kernel_size=(3, 3, 3), strides=(1, 1, 2), padding='same',
                       kernel_regularizer=regularizers.l2(0.01))(pool2)
        act3 = Activation('relu')(conv3)
        drop2 = Dropout(0.5)(act3)

        flatten1 = Flatten()(drop2)
        fc1 = Dense(200, kernel_regularizer=regularizers.l2(0.01))(flatten1)
        act3 = Activation('relu')(fc1)

        # conv1 = Conv3D(filters=32, kernel_size=(3, 3, 20), strides=(1, 1, 5), padding='same',
        #                kernel_regularizer=regularizers.l2(0.01))(input)
        # act1 = Activation('relu')(conv1)
        # pool1 = MaxPooling3D(pool_size=(2, 2, 2), strides=(1, 1, 1), padding='same')(act1)

        # conv2 = Conv3D(filters=64, kernel_size=(2, 2, 3), strides=(1, 1, 2), padding='same',
        #                kernel_regularizer=regularizers.l2(0.01))(pool1)
        # act2 = Activation('relu')(conv2)
        # drop1 = Dropout(0.5)(act2)
        # pool2 = MaxPooling3D(pool_size=(2, 2, 2), strides=(1, 1, 1), padding='same')(drop1)

        # conv3 = Conv3D(filters=128, kernel_size=(3, 3, 3), strides=(1, 1, 2), padding='same',
        #                kernel_regularizer=regularizers.l2(0.01))(pool2)
        # act3 = Activation('relu')(conv3)
        # drop2 = Dropout(0.5)(act3)

        # flatten1 = Flatten()(drop2)
        # fc1 = Dense(num_outputs*2, kernel_regularizer=regularizers.l2(0.01))(flatten1)
        # act3 = Activation('relu')(fc1)


        # Classifier block
        dense = Dense(units=num_outputs, activation="softmax", kernel_initializer="he_normal")(act3)

        model = Model(inputs=input, outputs=dense, name='3D-CNN')
        return model

    @staticmethod
    def build_resnet_8(input_shape, num_outputs):
        # Example arguments: input_shape=(1, 7, 7, 200), num_outputs=16
        return CNN3DBuilder.build(input_shape, num_outputs)

In [None]:
#@title CNN 3D Model Assignment
def cnn_3d_model(img_rows, img_cols, img_channels, nb_classes):

    model = CNN3DBuilder.build_resnet_8(
        (1, img_rows, img_cols, img_channels), nb_classes)

    return model
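
To see how the 3D CNN reduces the spectral dimension, the following pure-Python sketch reorders the example input shape `(1, 7, 7, 200)` to channels-last and applies the `'same'`-padding rule (output length is the input length divided by the stride, rounded up) for the three convolutions. The pooling layers use stride 1 with `'same'` padding, so they leave the shape unchanged and are omitted. The helper names are illustrative.

```python
import math


def to_channels_last(shape):
    """Move the leading channel entry of (channels, d1, d2, d3) to the end."""
    return (shape[1], shape[2], shape[3], shape[0])


def same_conv_dims(dims, strides):
    """Spatial dims after a 'same'-padded convolution with the given strides."""
    return tuple(math.ceil(d / s) for d, s in zip(dims, strides))


input_shape = to_channels_last((1, 7, 7, 200))
print('channels-last input:', input_shape)

# (filters, strides) of conv1, conv2, and conv3 in CNN3DBuilder.build
conv_layers = [(128, (1, 1, 5)), (192, (1, 1, 2)), (256, (1, 1, 2))]

dims = input_shape[:3]
for filters, strides in conv_layers:
    dims = same_conv_dims(dims, strides)
    print(f'after conv with {filters} filters: {dims + (filters,)}')

# Number of units seen by the first Dense layer after Flatten
flatten_units = dims[0] * dims[1] * dims[2] * conv_layers[-1][0]
print('flatten units:', flatten_units)
```

For this example shape, the spectral axis shrinks from 200 to 40, 20, and then 10, while the 7x7 spatial window is preserved.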

## 8.5) 2D CNN Model Builder

In [None]:
#@title CNN 2D Builder Class
class CNN2DBuilder(object):
    @staticmethod
    def build(input_shape, num_outputs):
        print('original input shape:', input_shape)
        _handle_dim_ordering()
        if len(input_shape) != 3:
            raise Exception("Input shape should be a tuple (nb_channels, kernel_dim1, kernel_dim2)")

        print('original input shape:', input_shape)

        print(f'Image data format: {K.image_data_format()}')
        if K.image_data_format() == 'channels_last':
            input_shape = (input_shape[1], input_shape[2], input_shape[0])
        print('change input shape:', input_shape)

        input = Input(shape=input_shape)

        conv1 = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding='same',
                       kernel_regularizer=regularizers.l2(0.01))(input)
        act1 = Activation('relu')(conv1)
        pool1 = MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same')(act1)

        conv2 = Conv2D(filters=192, kernel_size=(2, 2), strides=(1, 1), padding='same',
                       kernel_regularizer=regularizers.l2(0.01))(pool1)
        act2 = Activation('relu')(conv2)
        drop1 = Dropout(0.5)(act2)
        pool2 = MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same')(drop1)

        conv3 = Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding='same',
                       kernel_regularizer=regularizers.l2(0.01))(pool2)
        act3 = Activation('relu')(conv3)
        drop2 = Dropout(0.5)(act3)

        flatten1 = Flatten()(drop2)
        fc1 = Dense(200, kernel_regularizer=regularizers.l2(0.01))(flatten1)
        act3 = Activation('relu')(fc1)

        # Classifier block
        dense = Dense(units=num_outputs, activation="softmax", kernel_initializer="he_normal")(act3)

        model = Model(inputs=input, outputs=dense, name='2D-CNN')
        return model

    @staticmethod
    def build_resnet_8(input_shape, num_outputs):
        return CNN2DBuilder.build(input_shape, num_outputs)

In [None]:
#@title CNN 2D Model Assignment
def cnn_2d_model(img_rows, img_cols, img_channels, nb_classes):

    model = CNN2DBuilder.build_resnet_8(
        (img_channels, img_rows, img_cols), nb_classes)

    return model

## 8.6) Baseline CNN Model Builder

In [None]:
#@title Baseline CNN Model Builder and Assignment
def baseline_cnn_model(img_rows, img_cols, img_channels,
                       patch_size, nb_filters, nb_classes):
    """
    Generates a baseline CNN model for classifying an HSI dataset.

    Parameters
    ----------
    img_rows : int
        Number of rows in neighborhood patch.
    img_cols : int
        Number of columns in neighborhood patch.
    img_channels : int
        Number of spectral bands.
    patch_size : int
        Side length of the square convolution kernel.
    nb_filters : int
        Number of filters in the convolution layer.
    nb_classes : int
        Number of label categories.

    Returns
    -------
    model : Model
        A Keras model of the constructed network.
    """

    model_input = Input(shape=(img_rows, img_cols, img_channels))
    conv_layer = Conv2D(nb_filters, (patch_size, patch_size),
                        strides=(1, 1),name='2d_convolution_layer', padding='same',
                        kernel_regularizer=regularizers.l2(0.01))(model_input)
    activation_layer = Activation('relu', name='activation_layer')(conv_layer)
    max_pool_layer = MaxPooling2D(pool_size=(2, 2), name='2d_max_pooling_layer', padding='same')(activation_layer)
    flatten_layer = Flatten(name='flatten_layer')(max_pool_layer)
    dense_layer = Dense(units=nb_classes, name='dense_layer')(flatten_layer)
    classifier_layer = Activation('softmax', name='classifier_layer')(dense_layer)

    model = Model(model_input, classifier_layer, name='baseline_cnn_model')

    return model
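
As a rough check on model size, the trainable parameters of the baseline CNN can be counted by hand: a convolution has one weight per kernel element, input channel, and filter, plus one bias per filter, and the dense layer has one weight per flattened unit and class, plus one bias per class. The sketch below applies those formulas. The argument values in the example call are illustrative, not settings taken from the notebook.

```python
import math


def baseline_cnn_param_count(img_rows, img_cols, img_channels,
                             patch_size, nb_filters, nb_classes):
    """Return (conv_params, dense_params) for the baseline CNN layout."""
    # Conv2D: kernel weights plus one bias per filter
    conv_params = patch_size * patch_size * img_channels * nb_filters + nb_filters

    # MaxPooling2D with pool size 2, default stride 2, and 'same' padding
    pooled_rows = math.ceil(img_rows / 2)
    pooled_cols = math.ceil(img_cols / 2)

    # Flatten followed by Dense: weights plus one bias per class
    flat_units = pooled_rows * pooled_cols * nb_filters
    dense_params = flat_units * nb_classes + nb_classes
    return conv_params, dense_params


# Illustrative values: 7x7 patches, 48 bands, 3x3 kernel, 32 filters, 20 classes
conv_params, dense_params = baseline_cnn_param_count(7, 7, 48, 3, 32, 20)
print(conv_params, dense_params)  # 13856 10260
```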

## 8.7) Fusion Fully Convolutional Network (FCN) Models

In [None]:
#@title Fusion FCN Model Builder and Assignment
def fusion_fcn_model(branch_1_shape, branch_2_shape, branch_3_shape, nb_classes):

    # Initialize Inputs
    branch_1_input = Input(shape=branch_1_shape)
    branch_2_input = Input(shape=branch_2_shape)
    branch_3_input = Input(shape=branch_3_shape)

    print(f'Branch 1 shape: {branch_1_input.shape}')
    print(f'Branch 2 shape: {branch_2_input.shape}')
    print(f'Branch 3 shape: {branch_3_input.shape}')

    # Set channel axis
    channel_axis = len(branch_1_input.shape) - 1 if K.image_data_format() == 'channels_last' else 1

    # First branch
    branch_1_a = fusion_fcn_conv_block(branch_1_input, 1, 1)
    branch_1_b = fusion_fcn_conv_block(branch_1_a, 1, 2)
    branch_1_c = fusion_fcn_conv_block(branch_1_b, 1, 3)

    branch_1 = Add(name='Branch_1_Add')([branch_1_a, branch_1_b, branch_1_c])

    # Second branch
    branch_2_a = fusion_fcn_conv_block(branch_2_input, 2, 1)
    branch_2_b = fusion_fcn_conv_block(branch_2_a, 2, 2)
    branch_2_c = fusion_fcn_conv_block(branch_2_b, 2, 3)

    branch_2 = Add(name='Branch_2_Add')([branch_2_a, branch_2_b, branch_2_c])

    # Third branch
    branch_3 = branch_3_input

    # Branch fusion
    fusion = Concatenate(axis=channel_axis, name='Fusion_Concatenate')([branch_1, branch_2, branch_3])
    fusion = Conv2D(nb_classes, (1,1), strides=(1,1), padding='same',
                        name='Fusion_Conv2D')(fusion)
    fusion = Activation('relu', name='Fusion_ReLU')(fusion)
    out = Activation('softmax', name='Fusion_Softmax')(fusion)

    model = Model(inputs=[branch_1_input, branch_2_input, branch_3_input],
                  outputs=out,
                  name='Fusion-FCN')

    return model

In [None]:
#@title Fusion FCN Model (Version 2) Model Builder and Assignment
def fusion_fcn_v2_model(branch_1_shape, branch_2_shape, branch_3_shape, nb_classes):

    # Initialize Inputs
    branch_1_input = Input(shape=branch_1_shape)
    branch_2_input = Input(shape=branch_2_shape)
    branch_3_input = Input(shape=branch_3_shape)

    print(f'Branch 1 shape: {branch_1_input.shape}')
    print(f'Branch 2 shape: {branch_2_input.shape}')
    print(f'Branch 3 shape: {branch_3_input.shape}')

    # Set channel axis
    channel_axis = len(branch_1_input.shape) - 1 if K.image_data_format() == 'channels_last' else 1

    # First branch
    branch_1_a = fusion_fcn_conv_block(branch_1_input, 1, 1)
    branch_1_b = fusion_fcn_conv_block(branch_1_a, 1, 2)
    branch_1_c = fusion_fcn_conv_block(branch_1_b, 1, 3)

    branch_1 = Add(name='Branch_1_Add')([branch_1_a, branch_1_b, branch_1_c])

    # Second branch
    branch_2_a = fusion_fcn_conv_block(branch_2_input, 2, 1)
    branch_2_b = fusion_fcn_conv_block(branch_2_a, 2, 2)
    branch_2_c = fusion_fcn_conv_block(branch_2_b, 2, 3)

    branch_2 = Add(name='Branch_2_Add')([branch_2_a, branch_2_b, branch_2_c])

    # Third branch
    branch_3_a = fusion_fcn_conv_block(branch_3_input, 3, 1)
    branch_3_b = fusion_fcn_conv_block(branch_3_a, 3, 2)
    branch_3_c = fusion_fcn_conv_block(branch_3_b, 3, 3)

    branch_3 = Add(name='Branch_3_Add')([branch_3_a, branch_3_b, branch_3_c])

    # Branch fusion
    fusion = Concatenate(axis=channel_axis, name='Fusion_Concatenate')([branch_1, branch_2, branch_3])
    fusion = Conv2D(fusion.shape[channel_axis], (1,1), strides=(1,1), padding='same',
                        name='Fusion_Conv2D')(fusion)
    fusion = Activation('relu', name='Fusion_ReLU')(fusion)
    fusion = Flatten(name='Fusion_Flatten')(fusion)
    fusion = Dense(units=nb_classes, name='Fusion_Dense')(fusion)
    out = Activation('softmax', name='Fusion_Softmax')(fusion)

    model = Model(inputs=[branch_1_input, branch_2_input, branch_3_input],
                  outputs=out,
                  name='Fusion-FCN-V2')

    return model
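
The two Fusion FCN variants differ in what they output. The first passes branch 3 through unchanged and ends in a 1x1 convolution, so it yields one class distribution per pixel. The second runs all three branches through convolution blocks and ends in `Flatten` and `Dense`, so it yields one class distribution per patch. The sketch below computes the resulting shapes (batch dimension omitted) for channels-last inputs. The function name and example values are illustrative.

```python
def fusion_fcn_shapes(height, width, branch_3_channels, nb_classes, growth_rate=64):
    """Summarize concatenation widths and output shapes of both Fusion FCN variants."""
    # Version 1: two convolved branches plus the raw third branch
    v1_concat_channels = 2 * growth_rate + branch_3_channels
    v1_output = (height, width, nb_classes)        # per-pixel softmax map

    # Version 2: all three branches pass through convolution blocks
    v2_concat_channels = 3 * growth_rate
    v2_flatten_units = height * width * v2_concat_channels
    v2_output = (nb_classes,)                      # one softmax vector per patch

    return {
        'v1_concat_channels': v1_concat_channels,
        'v1_output': v1_output,
        'v2_concat_channels': v2_concat_channels,
        'v2_flatten_units': v2_flatten_units,
        'v2_output': v2_output,
    }


summary = fusion_fcn_shapes(height=7, width=7, branch_3_channels=3, nb_classes=20)
for key, value in summary.items():
    print(f'{key}: {value}')
```

The labels used for training therefore need a matching shape: a label map per patch for the first variant and a single label vector per patch for the second.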

## 8.8) 2D Densenet Model

In [None]:
#@title Densenet 2D Builder Class
class Densenet2DBuilder(object):
    @staticmethod
    def build(input_shape, num_outputs):
        print('original input shape:', input_shape)
        _handle_dim_ordering()
        if len(input_shape) != 3:
            raise Exception("Input shape should be a tuple (nb_channels, kernel_dim1, kernel_dim2)")

        print('original input shape:', input_shape)

        print(f'Image data format: {K.image_data_format()}')
        if K.image_data_format() == 'channels_last':
            input_shape = (input_shape[1], input_shape[2], input_shape[0])
        print('change input shape:', input_shape)

        input = Input(shape=input_shape)

        # 2D Convolution and pooling
        conv1 = Conv2D(64, kernel_size=(3, 3), strides=(1, 1), padding='SAME', kernel_initializer='he_normal')(
            input)
        pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(conv1)

        # Dense Block1
        x = dense_block_2d(pool1, 6, name='conv1')
        x = transition_block_2d(x, 0.5, name='pool1')
        x = dense_block_2d(x, 6, name='conv2')
        x = transition_block_2d(x, 0.5, name='pool2')
        x = dense_block_2d(x, 6, name='conv3')
        x = GlobalAveragePooling2D(name='avg_pool')(x)

        # Classifier block
        dense = Dense(units=num_outputs, activation="softmax", kernel_initializer="he_normal")(x)

        model = Model(inputs=input, outputs=dense, name='2D-DenseNet')
        return model

    @staticmethod
    def build_resnet_8(input_shape, num_outputs):
        return Densenet2DBuilder.build(input_shape, num_outputs)


In [None]:
#@title Densenet 2D Model Assignment
def densenet_2d_model(img_rows, img_cols, img_channels, nb_classes):

    model = Densenet2DBuilder.build_resnet_8(
        (img_channels, img_rows, img_cols), nb_classes)

    return model

## 8.9) Network-in-Network (NiN) Models

In [None]:
#@title Network-in-Network Model Builder and Assignment
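def nin_model(img_rows, img_cols, img_channels, num_classes, num_mlp_layers=2):
    """Builds a Network-in-Network (NiN) classifier.

    Stacks three NiN blocks, then applies global average pooling and softmax
    over the `num_classes` feature maps produced by the last block.

    # Arguments
        img_rows: integer, number of rows in the input patch.
        img_cols: integer, number of columns in the input patch.
        img_channels: integer, number of spectral bands.
        num_classes: integer, number of label categories.
        num_mlp_layers: integer, number of 1x1 convolutions per NiN block.

    # Returns
        A Keras model.
    """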

    model_input = Input(shape=(img_rows, img_cols, img_channels))

    # Convolution block 1
    x = nin_block(model_input, img_channels, (5,5), 1, num_mlp_layers=num_mlp_layers)
    x = Dropout(0.5)(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same',
                     name='Spatial_Pooling_1')(x)

    # Convolution block 2
    x = nin_block(x, img_channels, (3,3), 2, num_mlp_layers=num_mlp_layers)
    x = Dropout(0.5)(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same',
                     name='Spatial_Pooling_2')(x)

    # Convolution block 3
    x = nin_block(x, num_classes, (3,3), 3, num_mlp_layers=num_mlp_layers)
    x = MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same',
                     name='Spatial_Pooling_3')(x)

    # Global Average Pooling
    x = GlobalAveragePooling2D(name='Global_Average_Pooling')(x)
    x = Activation('softmax', name='Softmax_Classification')(x)

    # Model creation
    model = Model(model_input, x, name='nin_model')

    return model

In [None]:
#@title NiN Band Selection Model Builder and Assignment
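def nin_band_selection_model(nb_channels, nb_classes, nb_layers=2):
    """Builds a NiN-style band selection model from stacked 1x1 convolutions.

    # Arguments
        nb_channels: integer, number of spectral bands.
        nb_classes: integer, number of label categories (currently unused;
            the output has `nb_channels` entries).
        nb_layers: integer, number of 1x1 convolution layers.

    # Returns
        A Keras model.
    """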
    model_input = Input(shape=(1, 1, nb_channels))
    x = model_input

    for layer in range(nb_layers):
        x = Conv2D(nb_channels, kernel_size=(1,1), strides=(1,1), name=f'mlp_conv_{layer}',
               padding='same', kernel_regularizer=regularizers.l2(0.01))(x)

    x = GlobalAveragePooling2D(name='global_average_pooling')(x)
    x = Activation('softmax', name='softmax_classification')(x)


    model = Model(model_input, x, name='nin_band_selection_model')

    return model
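
With an input of shape `(1, 1, nb_channels)`, each 1x1 convolution in this model acts as a dense layer on a single spectral vector: it multiplies the vector of bands by one weight matrix and adds a bias. The pure-Python sketch below demonstrates that equivalence on nested lists. The function `conv1x1` and the example numbers are illustrative and do not come from the notebook.

```python
def conv1x1(image, weights, bias):
    """Apply a 1x1 convolution.

    image: H x W x C_in nested lists, weights: C_in x C_out, bias: C_out.
    The same weight matrix is applied independently at every pixel.
    """
    out = []
    for row in image:
        out_row = []
        for pixel in row:
            out_row.append([
                sum(pixel[c] * weights[c][k] for c in range(len(pixel))) + bias[k]
                for k in range(len(bias))
            ])
        out.append(out_row)
    return out


image = [[[1.0, 2.0, 3.0]]]                       # one pixel with three bands
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]    # maps 3 bands to 2 outputs
bias = [0.5, -0.5]

print(conv1x1(image, weights, bias))  # [[[4.5, 4.5]]]
```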

## 8.10) Band Selection 3D Densenet Model

In [None]:
#@title Band Selection Densenet 3D Builder Class
class BSDensenet3DBuilder(object):
    @staticmethod
    def build(input_shape, num_outputs):
        print('original input shape:', input_shape)
        if len(input_shape) != 4:
            raise Exception("Input shape should be a tuple (nb_channels, kernel_dim1, kernel_dim2, kernel_dim3)")

        # Example original input shape: (1, 7, 7, 200)

        print(f'Image data format: {K.image_data_format()}')
        channels = input_shape[3]
        if K.image_data_format() == 'channels_last':
            input_shape = (input_shape[1], input_shape[2], input_shape[3], input_shape[0])
        print('change input shape:', input_shape)

        # Set input
        input = Input(shape=input_shape)

        # 1x1 convolution over the bands, then 3D convolution and pooling
        x = Conv2D(channels, kernel_size=(1, 1), strides=(1, 1), padding='valid', kernel_initializer='he_normal')(input)
        x = Conv3D(64, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding='SAME', kernel_initializer='he_normal')(x)
        x = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding='same')(x)

        # Dense Block1
        x = dense_block_3d(x, 6, name='conv1')
        x = transition_block_3d(x, 0.5, name='pool1')
        x = dense_block_3d(x, 6, name='conv2')
        x = transition_block_3d(x, 0.5, name='pool2')
        x = dense_block_3d(x, 6, name='conv3')
        print(x.shape)
        x = GlobalAveragePooling3D(name='avg_pool')(x)
        print(x.shape)
        # x = Dense(16, activation='softmax')(x)

        # Classifier block
        output = Dense(units=num_outputs, activation="softmax", kernel_initializer="he_normal")(x)

        model = Model(inputs=input, outputs=output, name='Band-Selection-3D-DenseNet')
        return model

    @staticmethod
    def build_resnet_8(input_shape, num_outputs):
        return BSDensenet3DBuilder.build(input_shape, num_outputs)

In [None]:
#@title Band Selection 3D Densenet Model Assignment
def bs_3d_densenet_model(img_rows, img_cols, img_channels, nb_classes):

    model = BSDensenet3DBuilder.build(
        (1, img_rows, img_cols, img_channels), nb_classes)

    return model

## 8.11) Densenet 3D Fusion Models

In [None]:
#@title Densenet 3D Fusion Model Builder and Assignment
def densenet_3d_fusion_model(img_rows, img_cols, img_channels_list,
                             nb_classes, num_dense_blocks=3):

    # Note - normal num_dense_blocks is 6

    branch_shapes = []

    # Initialize shapes
    for img_channels in img_channels_list:
        if K.image_data_format() == 'channels_last':
            branch_shapes.append((img_rows, img_cols, img_channels))
        else:
            branch_shapes.append((img_channels, img_rows, img_cols))

    # Initialize inputs
    branch_inputs = [Input(shape=shape) for shape in branch_shapes]

    # Print input shapes
    for index, branch_input in enumerate(branch_inputs):
        print(f'Branch {index+1} input shape: {branch_input.shape}')


    # Set up branches
    branches = []

    for index, branch_input in enumerate(branch_inputs):
        num_channels = img_channels_list[index]
        branch_num = index + 1
        x = branch_input

        # x = Conv2D(num_channels, kernel_size=(1, 1), strides=(1, 1), padding='valid',
        #             kernel_initializer='he_normal', name=f'Branch_{branch_num}_Conv1x1_1')(x)
        # x = Conv2D(num_channels, kernel_size=(1, 1), strides=(1, 1), padding='valid',
        #             kernel_initializer='he_normal', name=f'Branch_{branch_num}_Conv1x1_2')(x)
        if K.image_data_format() == 'channels_last':
            x = Reshape((*branch_input.shape[1:], 1), name=f'Branch_{branch_num}_Reshape')(x)
        else:
            x = Reshape((1, *branch_input.shape[1:]), name=f'Branch_{branch_num}_Reshape')(x)
        x = Conv3D(64, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding='same',
                    kernel_initializer='he_normal', name=f'Branch_{branch_num}_3DConv')(x)
        x = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding='same',
                    name=f'Branch_{branch_num}_MaxPooling3D')(x)

        # Dense Blocks
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv1')
        x = transition_block_3d(x, 0.5, name=f'Branch_{branch_num}__pool1')
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv2')
        x = transition_block_3d(x, 0.5, name=f'Branch_{branch_num}__pool2')
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv3')

        x = GlobalAveragePooling3D(name=f'Branch_{branch_num}__avg_pool')(x)

        branches.append(x)

    # Print the output shape of the branches
    for index, branch in enumerate(branches):
        print(f'Branch {index+1} output shape: {branch.shape}')


    # Branch fusion
    fusion = Concatenate(axis=1, name='Fusion_Concatenate')(branches)
    print(f'Shape after concatenation: {fusion.shape}')
    fusion = Dense(units=fusion.shape[-1], kernel_initializer='he_normal', name='Fusion_Dense_1')(fusion)
    fusion = Activation('relu', name='Fusion_ReLU')(fusion)
    fusion = Dense(units=nb_classes, kernel_initializer='he_normal', name='Fusion_Dense_2')(fusion)
    out = Activation('softmax', name='Fusion_Softmax')(fusion)

    print(f'Fusion output shape: {out.shape}')

    model = Model(inputs=branch_inputs,
                  outputs=out,
                  name='3D-Densenet-Fusion')

    return model
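
Each branch receives a 2D image with its bands in the channel axis, and the `Reshape` layer adds a singleton channel axis so that `Conv3D` treats the bands as a third spatial dimension. The sketch below reproduces the reshape target for both data formats and the width of the fused feature vector (the sum of the per-branch feature lengths). The function names and example values are illustrative.

```python
def conv3d_input_shape(branch_shape, data_format='channels_last'):
    """Shape passed to Conv3D after adding a singleton channel axis."""
    if data_format == 'channels_last':
        return (*branch_shape, 1)       # (rows, cols, bands) -> (rows, cols, bands, 1)
    return (1, *branch_shape)           # (bands, rows, cols) -> (1, bands, rows, cols)


def fused_feature_length(branch_feature_lengths):
    """Length of the vector produced by concatenating all branch outputs."""
    return sum(branch_feature_lengths)


print(conv3d_input_shape((7, 7, 48)))                    # (7, 7, 48, 1)
print(conv3d_input_shape((48, 7, 7), 'channels_first'))  # (1, 48, 7, 7)
print(fused_feature_length([184, 184, 184]))             # 552
```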

In [None]:
#@title Densenet 3D Fusion Model (Version 2) Builder and Assignment
def densenet_3d_fusion_model2(img_rows, img_cols, img_channels_list,
                             nb_classes, num_dense_blocks=3):

    # Note - normal num_dense_blocks is 6

    branch_shapes = []

    # Initialize shapes
    for img_channels in img_channels_list:
        if K.image_data_format() == 'channels_last':
            branch_shapes.append((img_rows, img_cols, img_channels))
        else:
            branch_shapes.append((img_channels, img_rows, img_cols))

    # Initialize inputs
    branch_inputs = [Input(shape=shape) for shape in branch_shapes]

    # Print input shapes
    for index, branch_input in enumerate(branch_inputs):
        print(f'Branch {index+1} input shape: {branch_input.shape}')


    # Set up branches
    branches = []

    for index, branch_input in enumerate(branch_inputs):
        num_channels = img_channels_list[index]
        branch_num = index + 1
        x = branch_input

        x = Conv2D(num_channels, kernel_size=(1, 1), strides=(1, 1), padding='valid',
                    kernel_initializer='he_normal', name=f'Branch_{branch_num}_Conv1x1_1')(x)
        if K.image_data_format() == 'channels_last':
            x = Reshape((*branch_input.shape[1:], 1), name=f'Branch_{branch_num}_Reshape')(x)
        else:
            x = Reshape((1, *branch_input.shape[1:]), name=f'Branch_{branch_num}_Reshape')(x)
        x = Conv3D(64, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding='same',
                    kernel_initializer='he_normal', name=f'Branch_{branch_num}_3DConv')(x)
        x = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding='same',
                    name=f'Branch_{branch_num}_MaxPooling3D')(x)

        # Dense Blocks
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv1')
        x = transition_block_3d(x, 0.5, name=f'Branch_{branch_num}__pool1')
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv2')
        x = transition_block_3d(x, 0.5, name=f'Branch_{branch_num}__pool2')
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv3')

        x = GlobalAveragePooling3D(name=f'Branch_{branch_num}__avg_pool')(x)

        branches.append(x)

    # Print the output shape of the branches
    for index, branch in enumerate(branches):
        print(f'Branch {index+1} output shape: {branch.shape}')


    # Branch fusion
    fusion = Concatenate(axis=1, name='Fusion_Concatenate')(branches)
    print(f'Shape after concatenation: {fusion.shape}')
    fusion = Dense(units=fusion.shape[-1], kernel_initializer='he_normal', name='Fusion_Dense_1')(fusion)
    fusion = Activation('relu', name='Fusion_ReLU')(fusion)
    fusion = Dense(units=nb_classes, kernel_initializer='he_normal', name='Fusion_Dense_2')(fusion)
    out = Activation('softmax', name='Fusion_Softmax')(fusion)

    print(f'Fusion output shape: {out.shape}')

    model = Model(inputs=branch_inputs,
                  outputs=out,
                  name='3D-Densenet-Fusion')

    return model

In [None]:
#@title Densenet 3D Fusion Model (Version 3) Builder and Assignment
def densenet_3d_fusion_model3(img_rows, img_cols, img_channels_list,
                             nb_classes, num_dense_blocks=3):

    # Note - normal num_dense_blocks is 6

    branch_shapes = []

    # Initialize shapes
    for img_channels in img_channels_list:
        if K.image_data_format() == 'channels_last':
            branch_shapes.append((img_rows, img_cols, img_channels))
        else:
            branch_shapes.append((img_channels, img_rows, img_cols))

    # Initialize inputs
    branch_inputs = [Input(shape=shape) for shape in branch_shapes]

    # Print input shapes
    for index, branch_input in enumerate(branch_inputs):
        print(f'Branch {index+1} input shape: {branch_input.shape}')


    # Set up branches
    branches = []

    for index, branch_input in enumerate(branch_inputs):
        num_channels = img_channels_list[index]
        branch_num = index + 1
        x = branch_input

        x = Conv2D(num_channels, kernel_size=(1, 1), strides=(1, 1), padding='valid',
                    kernel_initializer='he_normal', name=f'Branch_{branch_num}_Conv1x1_1')(x)
        x = Conv2D(num_channels, kernel_size=(1, 1), strides=(1, 1), padding='valid',
                    kernel_initializer='he_normal', name=f'Branch_{branch_num}_Conv1x1_2')(x)
        if K.image_data_format() == 'channels_last':
            x = Reshape((*branch_input.shape[1:], 1), name=f'Branch_{branch_num}_Reshape')(x)
        else:
            x = Reshape((1, *branch_input.shape[1:]), name=f'Branch_{branch_num}_Reshape')(x)
        x = Conv3D(64, kernel_size=(3, 3, 3), strides=(1, 1, 1), padding='same',
                    kernel_initializer='he_normal', name=f'Branch_{branch_num}_3DConv')(x)
        x = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding='same',
                    name=f'Branch_{branch_num}_MaxPooling3D')(x)

        # Dense Blocks
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv1')
        x = transition_block_3d(x, 0.5, name=f'Branch_{branch_num}__pool1')
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv2')
        x = transition_block_3d(x, 0.5, name=f'Branch_{branch_num}__pool2')
        x = dense_block_3d(x, num_dense_blocks, name=f'Branch_{branch_num}__conv3')

        x = GlobalAveragePooling3D(name=f'Branch_{branch_num}__avg_pool')(x)

        branches.append(x)

    # Print the output shape of the branches
    for index, branch in enumerate(branches):
        print(f'Branch {index+1} output shape: {branch.shape}')


    # Branch fusion
    fusion = Concatenate(axis=1, name='Fusion_Concatenate')(branches)
    print(f'Shape after concatenation: {fusion.shape}')
    fusion = Dense(units=fusion.shape[-1], kernel_initializer='he_normal', name='Fusion_Dense_1')(fusion)
    fusion = Activation('relu', name='Fusion_ReLU')(fusion)
    fusion = Dense(units=nb_classes, kernel_initializer='he_normal', name='Fusion_Dense_2')(fusion)
    out = Activation('softmax', name='Fusion_Softmax')(fusion)

    print(f'Fusion output shape: {out.shape}')

    model = Model(inputs=branch_inputs,
                  outputs=out,
                  name='3D-Densenet-Fusion')

    return model

## 8.12) Densenet 3D Modified Model

In [None]:
#@title Densenet 3D Modified Model Builder And Assignment
def densenet_3d_modified_model(img_rows, img_cols, img_channels_list, nb_classes,
                               num_dense_blocks=3,
                               growth_rate=32,
                               num_1x1_convs=0,
                               first_conv_filters=64,
                               first_conv_kernel=(3,3,3),
                               dropout_1=0.5,
                               dropout_2=0.5,
                               activation='relu'):

    branch_shapes = []

    # Initialize shapes
    for img_channels in img_channels_list:
        if K.image_data_format() == 'channels_last':
            branch_shapes.append((img_rows, img_cols, img_channels))
        else:
            branch_shapes.append((img_channels, img_rows, img_cols))

    # Initialize inputs
    branch_inputs = [Input(shape=shape) for shape in branch_shapes]

    # Print input shapes
    for index, branch_input in enumerate(branch_inputs):
        print(f'Branch {index+1} input shape: {branch_input.shape}')


    # Set up branches
    branches = []

    for index, branch_input in enumerate(branch_inputs):
        num_channels = img_channels_list[index]
        branch_num = index + 1
        x = branch_input

        for conv_1x1_num in range(num_1x1_convs):
            x = Conv2D(num_channels, kernel_size=(1, 1), strides=(1, 1), padding='valid',
                        kernel_initializer='he_normal', name=f'Branch_{branch_num}_Conv1x1_{conv_1x1_num}')(x)
            x = Activation(activation, name=f'Branch_{branch_num}_Conv1x1_{activation}_{conv_1x1_num}')(x)

        if K.image_data_format() == 'channels_last':
            x = Reshape((*branch_input.shape[1:], 1), name=f'Branch_{branch_num}_Reshape')(x)
        else:
            x = Reshape((1, *branch_input.shape[1:]), name=f'Branch_{branch_num}_Reshape')(x)


        x = Conv3D(first_conv_filters, kernel_size=first_conv_kernel, strides=(1, 1, 1), padding='same',
                    kernel_initializer='he_normal', name=f'Branch_{branch_num}_3DConv')(x)
        x = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding='same',
                    name=f'Branch_{branch_num}_MaxPooling3D')(x)

        # Dense Blocks
        x = dense_block_3d(x, num_dense_blocks, growth_rate=growth_rate, activation=activation, name=f'Branch_{branch_num}__conv1')
        x = transition_block_3d(x, dropout_1, activation=activation, name=f'Branch_{branch_num}__pool1')
        x = dense_block_3d(x, num_dense_blocks, growth_rate=growth_rate, activation=activation, name=f'Branch_{branch_num}__conv2')
        x = transition_block_3d(x, dropout_2, activation=activation, name=f'Branch_{branch_num}__pool2')
        x = dense_block_3d(x, num_dense_blocks, growth_rate=growth_rate, activation=activation, name=f'Branch_{branch_num}__conv3')

        x = GlobalAveragePooling3D(name=f'Branch_{branch_num}__avg_pool')(x)

        branches.append(x)

    # Print the output shape of the branches
    for index, branch in enumerate(branches):
        print(f'Branch {index+1} output shape: {branch.shape}')


    if len(img_channels_list) > 1:
        # Branch fusion
        fusion = Concatenate(axis=1, name='Fusion_Concatenate')(branches)
        print(f'Shape after concatenation: {fusion.shape}')
        fusion = Dense(units=fusion.shape[-1], kernel_initializer='he_normal', name='Fusion_Dense_1')(fusion)
        fusion = Activation(activation, name=f'Fusion_{activation}')(fusion)
        fusion = Dense(units=nb_classes, kernel_initializer='he_normal', name='Fusion_Dense_2')(fusion)
        out = Activation('softmax', name='Fusion_Softmax')(fusion)

        print(f'Fusion output shape: {out.shape}')

        model_name = '3D-Densenet-Fusion'
    else:
        # x = Dense(units=x.shape[-1], kernel_initializer='he_normal', name='Dense_1')(x)
        # x = Activation(activation, name=f'Dense_1_{activation}')(x)
        x = Dense(units=nb_classes, kernel_initializer='he_normal', name='Dense_2')(x)
        out = Activation('softmax', name='Softmax')(x)

        print(f'Output shape: {out.shape}')

        model_name = '3D-Densenet'

    model = Model(inputs=branch_inputs,
                  outputs=out,
                  name=model_name)

    return model
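
To make the tensor shapes at the start of each branch concrete, the sketch below traces one input patch through the reshape, the first `Conv3D`, and the `MaxPooling3D` layer. It assumes the `channels_last` data format and relies on the rule that a layer with `padding='same'` produces `ceil(input / stride)` elements along each axis. The helper names `same_padding_output` and `trace_branch_stem` and the 11x11 patch size are hypothetical, chosen only for illustration; the dense and transition blocks are left out because they are defined elsewhere in the notebook.

```python
import math


def same_padding_output(size, stride):
    """Size of one axis after a layer that uses padding='same'."""
    return math.ceil(size / stride)


def trace_branch_stem(img_rows, img_cols, img_channels, first_conv_filters=64):
    """Trace channels_last shapes (batch axis omitted) through a branch stem."""
    shapes = {}
    shapes['input'] = (img_rows, img_cols, img_channels)

    # Reshape appends a singleton feature axis, so the spectral bands
    # become a third "spatial" axis that Conv3D can slide over
    shapes['reshape'] = (img_rows, img_cols, img_channels, 1)

    # Conv3D with stride 1 and padding='same' keeps the three leading
    # axes and only changes the number of feature maps
    shapes['conv3d'] = (img_rows, img_cols, img_channels, first_conv_filters)

    # MaxPooling3D with stride 2 and padding='same' roughly halves
    # each of the three leading axes
    pooled = tuple(same_padding_output(size, 2) for size in shapes['conv3d'][:3])
    shapes['max_pool3d'] = pooled + (first_conv_filters,)

    return shapes


# Hypothetical 11x11 patch with the 48 hyperspectral bands
shapes = trace_branch_stem(11, 11, 48)
for layer, shape in shapes.items():
    print(f'{layer:<12} {shape}')
```

For this hypothetical patch the pooled tensor has shape `(6, 6, 24, 64)`, which is what the first dense block of the branch would receive.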

# 9) Training Functions

In [None]:
#@title Training Summary Plot Creation Function
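def create_training_summary_plot(history, experiment_name, output_path):
    """
    Plots the training (and, if present, validation) loss and sparse
    categorical accuracy curves from a Keras History object and saves
    the figure to the output path as
    '<experiment_name>_training_summary_plot.png'.
    """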
    print('Creating training summary plots...')

    # Plot loss
    plt.subplot(211)
    plt.title('Cross Entropy Loss')
    plt.plot(history.history['loss'], color='blue', label='train')
    if 'val_loss' in history.history:
        plt.plot(history.history['val_loss'], color='orange', label='validation')
    plt.legend()

    # Plot accuracy
    plt.subplot(212)
    plt.title('Sparse Categorical Accuracy')
    plt.plot(history.history['sparse_categorical_accuracy'], color='blue', label='train')
    if 'val_sparse_categorical_accuracy' in history.history:
        plt.plot(history.history['val_sparse_categorical_accuracy'], color='orange', label='validation')
    plt.legend()

    # save plot to file
    print('Saving training summary plot to file...')
    filename = os.path.join(output_path, f'{experiment_name}_training_summary_plot.png')
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()

In [None]:
#@title Model Training Function
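def train_model(model, train_dataset, val_dataset,
                iteration = None, **hyperparams):
    """
    Compiles and trains the model, attaching the periodic checkpoint,
    early stopping, best-weights checkpoint, and learning rate decay
    callbacks that the hyperparameters enable. The training history and
    a training summary plot are written to the output path.

    Returns the trained model and the training time as a
    datetime.timedelta.
    """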
    # Initialize variables from the hyperparameters
    experiment_name = hyperparams['experiment_name']
    epochs = hyperparams['epochs']
    epochs_before_decay = hyperparams['epochs_before_decay']
    lr_decay_rate = hyperparams['lr_decay_rate']
    patience = hyperparams['patience']
    loss = hyperparams['loss']
    model_metrics = hyperparams['metrics']
    output_path = hyperparams['output_path']
    model_save_period = hyperparams['model_save_period']
    optimizer = get_optimizer(**hyperparams)
    callbacks = []

    # Determine ID string for experiment
    if experiment_name is not None:
        experiment_id = experiment_name
    elif iteration is not None:
        experiment_id = f'experiment_{iteration+1}'
    else:
        experiment_id = 'experiment'

    # Create best weights path filename
    best_weights_path = os.path.join(output_path,
        f'{model.name}_best_weights_{experiment_id}.hdf5')

    checkpoint_path_prefix = os.path.join(output_path, f'{experiment_id}_checkpoint_')

    if model_save_period is not None:
        # Create callback to save model weights every 'period' number
        # of epochs
        cb_periodic_model_checkpoint = ModelCheckpoint(checkpoint_path_prefix + '{epoch:08d}.hdf5', period=model_save_period)
        callbacks.append(cb_periodic_model_checkpoint)

    if patience is not None:
        # Create callback to stop training early if metrics don't improve
        cb_early_stopping = EarlyStopping(monitor='val_loss',
                                        patience=patience,
                                        verbose=1,
                                        mode='auto',
                                        restore_best_weights=True)
        callbacks.append(cb_early_stopping)

    if val_dataset is not None:
        # Create callback to save model weights if the model performs
        # better than the previously trained models
        cb_save_best_model = ModelCheckpoint(best_weights_path,
                                            monitor='val_loss',
                                            verbose=1,
                                            save_best_only=True,
                                            mode='auto')
        callbacks.append(cb_save_best_model)

    if lr_decay_rate is not None and epochs_before_decay is not None:
        # This function keeps the initial learning rate for a set number of epochs
        # and reduces it at decay rate after that
        def scheduler(epoch, lr):
            if epoch < epochs_before_decay:
                return lr
            else:
                print(f'Learning rate reduced from {lr} to {lr*lr_decay_rate}...')
                return lr * lr_decay_rate
                # return lr * tf.math.exp(-0.1)

        # Create learning rate scheduler callback for learning rate decay
        cb_lr_decay = LearningRateScheduler(scheduler)

        callbacks.append(cb_lr_decay)

    # Compile the model with the appropriate loss function, optimizer,
    # and metrics
    model.compile(loss=loss,
                  optimizer=optimizer,
                  metrics=model_metrics,
                  loss_weights=None,
                  weighted_metrics=None,
                  run_eagerly=None,
                  )

    # Display a summary of the model being trained
    model.summary()

    # Record start time for model training
    model_train_start = time.process_time()

    # Train the model
    model_history = model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=epochs,
            verbose=1,
            shuffle=True,
            callbacks=callbacks
        )

    # Record end time for model training
    model_train_end = time.process_time()

    # Calculate the training time (CPU time, as measured by time.process_time)
    model_train_time = datetime.timedelta(seconds=(model_train_end - model_train_start))

    # Write model history to file
    with open(os.path.join(output_path,
         f'{experiment_id}_training_history.txt'), 'w') as hf:

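        # Fall back to the experiment ID when no iteration number is given
        experiment_label = f'#{iteration+1}' if iteration is not None else experiment_id
        hf.write(f'EXPERIMENT {experiment_label} MODEL HISTORY:\n')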
        hf.write('-----------------------------------------------\n')
        hf.write(f'MODEL: {model.name}\n')
        hf.write('-----------------------------------------------\n')

        # Save model summary to file as well
        model.summary(print_fn=lambda x: hf.write(x + '\n'))

        # Show epoch with best validation value
        if patience is not None:
            hf.write(f'Best Epoch: {cb_early_stopping.best_epoch}\n')

        # Get the number of epochs the model actually ran for (the
        # history holds one entry per completed epoch, even when early
        # stopping ends training before the requested epoch count)
        ran_epochs = len(model_history.history['loss'])

        # Save info from each epoch to file
        for epoch in range(ran_epochs):
            hf.write(f'EPOCH: {epoch+1}\n')
            for key in model_history.history.keys():
                hf.write(f'  {key}: {model_history.history[key][epoch]}\n')

    create_training_summary_plot(model_history, experiment_id, output_path)

    # Save final weights
    # final_weights_path = os.path.join(output_path,
    #     f'{model.name}_final_weights_{experiment_id}.hdf5')
    # model.save_weights(final_weights_path)

    return model, model_train_time
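
The learning rate schedule used in `train_model` can be previewed without training a model. The sketch below is a plain-Python simulation, assuming that the Keras `LearningRateScheduler` callback passes the optimizer's current learning rate to the schedule function at the start of each epoch, so that the decay compounds from one epoch to the next. The function name `simulate_lr_schedule` and the example values are hypothetical.

```python
def simulate_lr_schedule(initial_lr, epochs, epochs_before_decay, lr_decay_rate):
    """Return the learning rate in effect at each (0-indexed) epoch."""
    lr = initial_lr
    schedule = []
    for epoch in range(epochs):
        # Same rule as scheduler(): hold the rate for the first
        # epochs_before_decay epochs, then scale the current rate
        if epoch >= epochs_before_decay:
            lr = lr * lr_decay_rate
        schedule.append(lr)
    return schedule


# Hypothetical settings: hold for 3 epochs, then halve every epoch
rates = simulate_lr_schedule(initial_lr=0.001, epochs=6,
                             epochs_before_decay=3, lr_decay_rate=0.5)
for epoch, rate in enumerate(rates):
    print(f'epoch {epoch}: lr = {rate:g}')
```

Under these assumptions the rate at epoch `e` is `initial_lr * lr_decay_rate ** max(0, e - epochs_before_decay + 1)`, so a decay rate close to 1 is needed if training is expected to run for many epochs after the hold period.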

# 10) Evaluation Functions

In [None]:
#@title Model Testing Function
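def test_model(model, test_dataset, **hyperparams):
    """
    Runs the model on the test dataset and returns the predicted class
    index for each sample, along with the prediction time as a
    datetime.timedelta.
    """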

    print(f'Testing {model.name} with test dataset...')

    # Record start time for model evaluation
    model_test_start = time.process_time()

    # Get prediction values for test dataset
    pred_test = model.predict(test_dataset,
                              verbose=1,
                              ).argmax(axis=1)

    # Record end time for model evaluation
    model_test_end = time.process_time()

    # Get time elapsed for testing model
    model_test_time = datetime.timedelta(seconds=(model_test_end - model_test_start))

    print('Testing completed!')


    return pred_test, model_test_time

In [None]:
#@title Model Statistics Calculation Function
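def calculate_model_statistics(pred_test, target_test, labels,
                               **hyperparams):
    """
    Computes the overall accuracy, average (mean per-class) accuracy,
    micro-averaged precision and recall, Cohen's kappa score, confusion
    matrix, and classification report for the given predictions.

    Returns the results as a dictionary.
    """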
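    # Prefer the full class label list from the hyperparameters, falling
    # back to the labels argument when it is not provided
    all_class_labels = hyperparams.get('all_class_labels')
    if all_class_labels is not None:
        labels = all_class_labels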

    overall_acc = metrics.accuracy_score(target_test, pred_test)
    precision = metrics.precision_score(target_test, pred_test, average='micro')
    recall = metrics.recall_score(target_test, pred_test, average='micro')
    kappa = metrics.cohen_kappa_score(target_test, pred_test)
    confusion_matrix = metrics.confusion_matrix(target_test, pred_test, labels=range(len(labels)))

    # Suppress/hide invalid value warning
    # np.seterr(invalid='ignore')

    # Calculate average accuracy and per-class accuracies
    list_diag = np.diag(confusion_matrix)
    list_raw_sum = np.sum(confusion_matrix, axis=1)
    each_acc = np.nan_to_num(truediv(list_diag, list_raw_sum))
    average_acc = np.mean(each_acc)

    # Get classification report
    classification_report = metrics.classification_report(target_test,
                                                          pred_test,
                                                          labels=range(len(labels)),
                                                          target_names=labels,
                                                          digits=3)

    results = {
        'overall_accuracy': overall_acc,
        'average_accuracy': average_acc,
        'precision_score': precision,
        'recall_score': recall,
        'cohen_kappa_score': kappa,
        'confusion_matrix': confusion_matrix,
        'per_class_accuracies': each_acc,
        'labels': labels,
        'classification_report': classification_report,
    }

    return results
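
As a worked example of the accuracy figures computed above, the sketch below derives the overall accuracy, per-class accuracies, average accuracy, and Cohen's kappa directly from a small confusion matrix, using the textbook definitions in plain Python. The function name `summarize_confusion_matrix` and the example matrix are hypothetical. A class with no test samples is given an accuracy of 0.0, mirroring the effect of `np.nan_to_num` in `calculate_model_statistics`.

```python
def summarize_confusion_matrix(cm):
    """Compute accuracy statistics from a square confusion matrix.

    Rows are actual classes and columns are predicted classes.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    row_sums = [sum(row) for row in cm]
    col_sums = [sum(cm[i][j] for i in range(n)) for j in range(n)]
    diag = [cm[i][i] for i in range(n)]

    # Overall accuracy: fraction of all samples on the diagonal
    overall_acc = sum(diag) / total

    # Per-class accuracy: diagonal count over the row (actual) total;
    # classes without samples are reported as 0.0
    per_class = [d / r if r > 0 else 0.0 for d, r in zip(diag, row_sums)]

    # Average accuracy: unweighted mean of the per-class accuracies
    average_acc = sum(per_class) / n

    # Cohen's kappa: agreement beyond what chance alone would produce
    expected = sum(r * c for r, c in zip(row_sums, col_sums)) / total ** 2
    kappa = (overall_acc - expected) / (1 - expected)

    return overall_acc, average_acc, per_class, kappa


# Hypothetical three-class result where the third class has no samples
example_cm = [[8, 2, 0],
              [1, 9, 0],
              [0, 0, 0]]
overall_acc, average_acc, per_class, kappa = summarize_confusion_matrix(example_cm)
print(f'overall accuracy: {overall_acc:.3f}')
print(f'average accuracy: {average_acc:.3f}')
print(f'per-class:        {per_class}')
print(f'kappa:            {kappa:.3f}')
```

In this example the overall accuracy is 0.85 while the average accuracy drops to about 0.567, because the empty third class contributes 0.0 to the mean. This is worth keeping in mind when a test split lacks samples for some classes.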

In [None]:
#@title Confusion Matrix Plot Creation Function
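def create_confusion_matrix_plot(confusion_matrix, labels, model_name,
                                 output_path='./', iteration=None):
    """
    Creates a heatmap of the confusion matrix, annotated with raw counts
    and row-normalized percentages, and saves it as a PNG file in the
    output path.
    """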

    # Create filename for confusion matrix image file
    if iteration is not None:
        filename = f'experiment_{iteration+1}_{model_name}_confusion_matrix.png'
    else:
        filename = f'experiment_{model_name}_confusion_matrix.png'

    # Create full file name for confusion matrix image file
    cm_filename = os.path.join(output_path, filename)

    # Create annotations for confusion matrix
    print('Creating confusion matrix annotations...')
    cm_sum = np.sum(confusion_matrix, axis=1, keepdims=True)
    cm_perc = confusion_matrix / cm_sum.astype(float) * 100
    annot = np.empty_like(confusion_matrix).astype(str)
    nrows, ncols = confusion_matrix.shape
    for i in range(nrows):
        for j in range(ncols):
            c = confusion_matrix[i, j]
            p = cm_perc[i, j]
            if i == j:
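                # Index the single column so a scalar (not a 1-element
                # array) is passed to the string formatting below
                s = cm_sum[i, 0]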
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.1f%%\n%d' % (p, c)


    # Create confusion matrix dataframe
    print('Creating confusion matrix plot...')
    cm = pd.DataFrame(confusion_matrix, index=labels, columns=labels)
    cm.index.name = 'Actual'
    cm.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=(30,30))
    sns.heatmap(cm, annot=annot, fmt='', ax=ax)

    if iteration is not None:
        plt.title(f'Experiment #{iteration+1} {model_name} Confusion Matrix')
    else:
        plt.title(f'Experiment w/ {model_name} Confusion Matrix')

    print('Saving confusion matrix plot...')
    plt.savefig(cm_filename)

    # Close the figure to free its memory before the next plot
    plt.close(fig)
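
The annotation logic of the confusion matrix plot can be checked in isolation. The sketch below reproduces the cell labels in plain Python for a small hypothetical matrix: diagonal cells show the percentage and `count/row total`, non-zero off-diagonal cells show the percentage and count, and zero off-diagonal cells are left blank. The function name `annotate_confusion_matrix` is hypothetical, and rows without samples are reported as 0.0% here, which is an assumption of this sketch rather than the behavior of the plotting function.

```python
def annotate_confusion_matrix(cm):
    """Build heatmap annotation strings for a confusion matrix.

    Rows are actual classes and columns are predicted classes.
    """
    annotations = []
    for i, row in enumerate(cm):
        row_total = sum(row)
        row_labels = []
        for j, count in enumerate(row):
            # Percentage of this actual class assigned to column j
            percent = 100.0 * count / row_total if row_total else 0.0
            if i == j:
                row_labels.append(f'{percent:.1f}%\n{count}/{row_total}')
            elif count == 0:
                row_labels.append('')
            else:
                row_labels.append(f'{percent:.1f}%\n{count}')
        annotations.append(row_labels)
    return annotations


# Hypothetical two-class confusion matrix
annotations = annotate_confusion_matrix([[8, 2],
                                         [0, 10]])
for row_labels in annotations:
    print([label.replace('\n', ' ') for label in row_labels])
```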

In [None]:
#@title Experiment Results Output Function
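def output_experiment_results(experiment_info):
    """
    Prints the results of a single experiment (timings, accuracy
    statistics, per-class accuracies, and the classification report)
    to standard output.
    """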

    # Get variables from dictionary
    experiment_name = experiment_info['experiment_name']
    model_name = experiment_info['model_name']
    model_train_time = experiment_info['model_train_time']
    model_test_time = experiment_info['model_test_time']
    overall_acc = experiment_info['overall_accuracy']
    average_acc = experiment_info['average_accuracy']
    per_class_accuracies = experiment_info['per_class_accuracies']
    precision = experiment_info['precision_score']
    recall = experiment_info['recall_score']
    kappa = experiment_info['cohen_kappa_score']
    labels = experiment_info['labels']
    classification_report = experiment_info['classification_report']

    # Print results
    print('---------------------------------------------------')
    if experiment_name is None:
        print('             MODEL EXPERIMENT RESULTS              ')
    else:
        print(f'          "{experiment_name}" RESULTS              ')
    print('---------------------------------------------------')
    print(f' MODEL NAME: {model_name}')
    print('---------------------------------------------------')
    print(f'{model_name} train time: {model_train_time}')
    print(f'{model_name} test time:  {model_test_time}')
    print('...................................................')
    print(f'{model_name} overall accuracy:  {overall_acc}')
    print(f'{model_name} average accuracy:  {average_acc}')
    print(f'{model_name} precision score:   {precision}')
    print(f'{model_name} recall score:      {recall}')
    print(f'{model_name} cohen kappa score: {kappa}')
    print('...................................................')
    print(f'{model_name} Per-class accuracies:')
    for i, label in enumerate(labels):
        print(f'{label}: {per_class_accuracies[i]}')
    print('---------------------------------------------------')
    print('              CLASSIFICATION REPORT                ')
    print('...................................................')
    print(classification_report)
    print('---------------------------------------------------')
    print()

# 11) Test Harness Function

In [None]:
#@title Experiment Parameter List
PARAMETER_LIST = (
    "experiment_name",
    "experiment_number",
    "cuda",
    "restore",
    "output_path",
    "dataset",
    "path_to_dataset",
    "reuse_last_dataset",
    "predict_only",
    "skip_data_preprocessing",
    "skip_band_selection",
    "skip_data_postprocessing",
    "model_id",
    "add_branch",
    "random_seed",
    "epochs",
    "epochs_before_decay",
    "batch_size",
    "patch_size",
    "center_pixel",
    "train_split",
    "split_mode",
    "class_balancing",
    "iterations",
    "patience",
    "model_save_period",
    "optimizer",
    "lr",
    "lr_decay_rate",
    "momentum",
    "epsilon",
    "initial_accumulator_value",
    "beta",
    "beta_1",
    "beta_2",
    "amsgrad",
    "rho",
    "centered",
    "nesterov",
    "learning_rate_power",
    "l1_regularization_strength",
    "l2_regularization_strength",
    "l2_shrinkage_regularization_strength",
    "flip_augmentation",
    "radiation_augmentation",
    "mixture_augmentation",
    "use_hs_data",
    "use_lidar_ms_data",
    "use_lidar_ndsm_data",
    "use_vhr_data",
    "use_all_data",
    "normalize_hs_data",
    "normalize_lidar_ms_data",
    "normalize_lidar_ndsm_data",
    "normalize_vhr_data",
    "hs_resampling",
    "lidar_ms_resampling",
    "lidar_ndsm_resampling",
    "vhr_resampling",
    "hs_histogram_equalization",
    "lidar_ms_histogram_equalization",
    "lidar_dsm_histogram_equalization",
    "lidar_dem_histogram_equalization",
    "lidar_ndsm_histogram_equalization",
    "vhr_histogram_equalization",
    "hs_data_filter",
    "lidar_ms_data_filter",
    "lidar_dsm_data_filter",
    "lidar_dem_data_filter",
    "vhr_data_filter",
    "band_reduction_method",
    "n_components",
    "selected_bands",
    "select_only_hs_bands",
)
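
The test harness expects every name in `PARAMETER_LIST` to be present in the hyperparameter dictionary, and fills in `None` for any that are missing. A minimal sketch of that defaulting step is shown below. The helper name `fill_missing_parameters` and the shortened `EXAMPLE_PARAMETERS` tuple are hypothetical and exist only to keep the example self-contained.

```python
# Hypothetical subset of the full parameter list
EXAMPLE_PARAMETERS = ('experiment_name', 'epochs', 'patience')


def fill_missing_parameters(hyperparams, parameter_list):
    """Return a copy of hyperparams with missing parameters set to None."""
    filled = dict(hyperparams)
    for param in parameter_list:
        # Only adds the key when it is absent; given values are kept
        filled.setdefault(param, None)
    return filled


filled = fill_missing_parameters({'epochs': 10}, EXAMPLE_PARAMETERS)
print(filled)
```

Because absent parameters become `None`, code that reads a hyperparameter should check for `None` before comparing or doing arithmetic with the value.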

In [None]:
#@title Test Harness Runner Function
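def run_test_harness(**hyperparams):
    """
    Runs the test harness: executes one or more experiments, defined
    either by the given hyperparameters or by an experiments JSON or
    CSV file.
    """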

    experiment_name = hyperparams['experiment_name']

    # Get experiment number (numbering starts at 1)
    if hyperparams['experiment_number'] is None or hyperparams['experiment_number'] < 1:
        experiment_number = 1
    else:
        experiment_number = hyperparams['experiment_number']

    # Get output path
    if hyperparams['output_path'] is not None:
        output_path = hyperparams['output_path']
    else:
        output_path = './'

    # Get hyperparam derived variable values
    if 'experiments_json' in hyperparams and hyperparams['experiments_json'] is not None:
        # Transpose the json dataframe, since the experiments are read
        # in as columns instead of rows
        experiments = pd.read_json(hyperparams['experiments_json']).T
        iterations = experiments.shape[0]
        outfile_prefix = Path(hyperparams['experiments_json']).stem
    elif 'experiments_csv' in hyperparams and hyperparams['experiments_csv'] is not None:
        experiments = pd.read_csv(hyperparams['experiments_csv'])
        iterations = experiments.shape[0]
        outfile_prefix = Path(hyperparams['experiments_csv']).stem
    else:
        experiments = None
        iterations = hyperparams['iterations']
        if iterations is None:
            iterations = 1
        if hyperparams['experiment_name'] is None:
            outfile_prefix = 'experiment'
        else:
            outfile_prefix = hyperparams['experiment_name']

    # Get model name of CPU
    cpu_name = cpuinfo.get_cpu_info()['brand_raw']

    # Get model names of all GPUs on system
    gpu_names = []
    for gpu in tf.config.list_physical_devices(device_type = 'GPU'):
        gpu_names.append(tf.config.experimental.get_device_details(gpu)['device_name'])

    # Initialize data list variables for CSV output at end of program
    experiment_data_list = []
    per_class_data_lists = {}
    per_class_selected_band_lists = {}

    print('-------------------------------------------------------------------')
    print('-------------------------------------------------------------------')
    print('-------------------------------------------------------------------')
    print('BEGINNING EXPERIMENTS...')
    print('vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv')
    print()

    # Make a generator to create prime number random seeds
    primes = prime_generator()

    # Set variables that carry state over experiments to None
    dataset_choice = None
    data = None
    train_gt = None
    test_gt = None
    dataset_info = {
        'name': None,
        'num_classes': None,
        'ignored_labels': None,
        'class_labels': None,
        'label_mapping': None,
    }
    train_dataset = None
    val_dataset = None
    test_dataset = None
    target_test = None
    band_selection_time = None
    bands_selected = None

    # Go through experiment iterations
    for iteration in range(experiment_number - 1, experiment_number - 1 + iterations):

        # Clean memory in each iteration (otherwise the machine may
        # randomly run out of memory if it is being pushed to its
        # limit)
        gc.collect()

        print('*******************************************************')
        print(f'<<< EXPERIMENT #{iteration+1}  STARTING >>>')
        print('*******************************************************')
        print()

        experiment_data = {
            'experiment_number': iteration + 1,
            'experiment_name': experiment_name,
            'success': False,
            'random_seed': None,
            'dataset': None,
            'band_reduction_method': None,
            'band_selection_time': None,
            'bands_selected': None,
            'channels': None,
            'model': None,
            'device': None,
            'epochs': None,
            'batch_size': None,
            'patch_size': None,
            'train_split': None,
            'optimizer': None,
            'learning_rate': None,
            'loss': None,
            'train_time': 0.0,
            'test_time': 0.0,
            'overall_accuracy': 0.0,
            'average_accuracy': 0.0,
            'precision_score': 0.0,
            'recall_score': 0.0,
            'cohen_kappa_score': 0.0,
        }

        per_class_data = {
            'experiment_number': iteration + 1,
            'experiment_name': experiment_name,
            'random_seed': None,
            'band_reduction_method': None,
            'band_selection_time': None,
            'bands_selected': None,
            'model': None,
            'train_time': 0.0,
            'test_time': 0.0,
            'overall_accuracy': 0.0,
            'average_accuracy': 0.0,
            'precision_score': 0.0,
            'recall_score': 0.0,
            'cohen_kappa_score': 0.0,
        }

        per_class_selected_bands = None

        experiments_results_file = f'{outfile_prefix}_results.csv'
        class_results_file = f'{outfile_prefix}__{dataset_choice}__class_results.csv'
        selected_bands_file = f'{outfile_prefix}__{dataset_choice}__selected_bands.csv'

        # Experiment has begun, so make sure to catch any failures that
        # may occur
        try:

            # If loading experiments from a file, get new set of hyperparams
            if experiments is not None:
                # Get hyperparameters from dictionary
                hyperparams = experiments.iloc[iteration].to_dict()

                experiment_name = experiments.index[iteration]

                hyperparams['experiment_name'] = experiment_name
                experiment_data['experiment_name'] = experiment_name
                per_class_data['experiment_name'] = experiment_name

                # Fill in any missing parameters with None
                for param in PARAMETER_LIST:
                    if param not in hyperparams:
                        hyperparams[param] = None

                if (hyperparams['experiment_number'] is not None
                    and hyperparams['experiment_number'] > 0):
                    iteration = hyperparams['experiment_number'] - 1
                    experiment_data['experiment_number'] = hyperparams['experiment_number']
                    per_class_data['experiment_number'] = hyperparams['experiment_number']

                # Ignore the output path in the experiments, use the path
                # from command line arguments
                hyperparams['output_path'] = output_path


                print('<~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~>')
                print(f'EXPERIMENT NAME: {experiment_name}')
                print('<~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~>')
                print()
            else:
                # Fill in any missing parameters with None
                for param in PARAMETER_LIST:
                    if param not in hyperparams:
                        hyperparams[param] = None

                experiment_name = hyperparams['experiment_name']

            # Save hyperparameters to experiment file if argument
            # has been given
            if hyperparams.get('save_experiment_path') is not None:
                extension = Path(hyperparams['save_experiment_path']).suffix
                if extension == '.json':
                    experiments_params_df = pd.DataFrame.from_dict({experiment_name:hyperparams,}, orient='index')
                    experiments_params_df.to_json(hyperparams['save_experiment_path'], orient='index', indent=4)
                elif extension == '.csv':
                    # experiments_params_df = pd.DataFrame.from_dict(hyperparams)
                    experiments_params_df = pd.DataFrame.from_dict({experiment_name:hyperparams,}, orient='index')
                    # experiments_params_df.to_csv(hyperparams['save_experiment_path'])
                    experiments_params_df.to_csv(hyperparams['save_experiment_path'])
                else:
                    #TODO
                    pass


            # Print out parameters for experiment
            print('.......................................................')
            print('EXPERIMENT PARAMETERS')
            print('.......................................................')
            header = '{:<40} | {:<40}'.format('PARAMETER', 'VALUE')
            print(header)
            print('=' * len(header))
            for key in hyperparams:
                print('{:<40} | {:<40}'.format(key, str(hyperparams[key])))
                print('-' * len(header))
            print('-' * len(header))
            print('.......................................................')


            # Model checks
            if (hyperparams['model_id'] == 'fusion-fcn'
                and hyperparams['dataset'] != 'grss_dfc_2018'):
                print('Cannot use fusion-fcn model without the grss_dfc_2018 dataset!')
                exit(1)
            elif (hyperparams['model_id'] == 'fusion-fcn-v2'
                and hyperparams['dataset'] != 'grss_dfc_2018'):
                print('Cannot use fusion-fcn-v2 model without the grss_dfc_2018 dataset!')
                exit(1)
            elif (hyperparams['model_id'] == '3d-densenet-fusion'
                and hyperparams['dataset'] != 'grss_dfc_2018'):
                print('Cannot use 3d-densenet-fusion model without the grss_dfc_2018 dataset!')
                exit(1)

            # Initialize random seed for sampling function
            # Each random seed is a prime number, in order
            if hyperparams['random_seed'] is not None:
                seed = hyperparams['random_seed']
            else:
                seed = next(primes)
            print(f'< Iteration #{iteration+1} random seed: {seed} >')
            print()
            np.random.seed(seed)

            # Choose the appropriate device from the hyperparameters
            device = get_device(hyperparams['cuda'])

            if 'CPU' in device:
                device_name = cpu_name
            else:
                gpu_num = int(device.split(':')[-1])
                device_name = gpu_names[gpu_num]
                gpu = tf.config.list_physical_devices('GPU')[gpu_num]
                tf.config.experimental.set_memory_growth(gpu, True)
                # tf.config.experimental.set_virtual_device_configuration(
                #     gpu, [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)]
                # )

            # Device has been selected, so do all possible computation with device
            with tf.device(device):
                reuse_last_dataset = hyperparams['reuse_last_dataset']
                if reuse_last_dataset and dataset_choice is not None:
                    print()
                    print(f'< Reusing last dataset: {dataset_choice} >')
                else:

                    reuse_last_dataset = False

                    print()
                    print('-------------------------------------------------------------------')
                    print('LOADING DATASET...')
                    print('vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv')

                    # Get dataset choice parameter
                    dataset_choice = hyperparams['dataset']
                    print()
                    print(f' < Dataset Chosen: {dataset_choice} >')
                    print()

                    # Make sure dataset is in per-class data list dictionary
                    if dataset_choice not in per_class_data_lists:
                        per_class_data_lists[dataset_choice] = []

                    # Make sure dataset is in per-class selected band list dictionary
                    if dataset_choice not in per_class_selected_band_lists:
                        per_class_selected_band_lists[dataset_choice] = []

                    # Get selected dataset
                    if not reuse_last_dataset:
                        if dataset_choice == 'grss_dfc_2018':
                            # Determine what parts of dataset to use
                            if (not hyperparams['use_hs_data']
                                and not hyperparams['use_lidar_ms_data']
                                and not hyperparams['use_lidar_ndsm_data']
                                and not hyperparams['use_vhr_data']
                                and not hyperparams['use_all_data']):

                                print('<!> No specific data selected, defaulting to using only hyperspectral data... <!>')
                                hyperparams['use_hs_data'] = True

                            data, train_gt, test_gt, dataset_info = load_grss_dfc_2018_uh_dataset(**hyperparams)
                        elif dataset_choice == 'indian_pines':
                            # data, train_gt, test_gt, dataset_info = load_indian_pines_dataset(**hyperparams)
                            print("Indian Pines dataset currently unimplemented")
                            exit(1)
                        elif dataset_choice == 'pavia_center':
                            # data, train_gt, test_gt, dataset_info = load_pavia_center_dataset(**hyperparams)
                            print("Pavia Center dataset currently unimplemented")
                            exit(1)
                        elif dataset_choice == 'university_of_pavia':
                            # data, train_gt, test_gt, dataset_info = load_university_of_pavia_dataset(**hyperparams)
                            print("University of Pavia dataset currently unimplemented")
                            exit(1)
                        else:
                            print('No dataset chosen! Defaulting to only hyperspectral bands of grss_dfc_2018...')
                            dataset_choice = 'grss_dfc_2018'
                            hyperparams['use_hs_data'] = True
                            data, train_gt, test_gt, dataset_info = load_grss_dfc_2018_uh_dataset(**hyperparams)
                    print(f"Dataset Info: {dataset_info}")

                    print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^')
                    print('DATASET LOADED!')
                    print('-------------------------------------------------------------------')
                    print()

                    if not hyperparams['skip_data_preprocessing']:
                        print('-------------------------------------------------------------------')
                        print('PREPROCESS THE DATA')
                        print('-------------------------------------------------------------------')
                        data = preprocess_data(data, **hyperparams)
                        print('-------------------------------------------------------------------')
                        print()

                    if not hyperparams['skip_band_selection']:
                        print('-------------------------------------------------------------------')
                        print('RUN BAND SELECTION ALGORITHM')
                        print('-------------------------------------------------------------------')
                        # TODO: allow per-modality band selection

                        if hyperparams['select_only_hs_bands']:
                            hs_channels = dataset_info['hs_channels']
                            if hs_channels is not None:
                                # Get the non-hyperspectral data so it
                                # can be appended to the reduced
                                # hyperspectral data later
                                non_hs_channels = [channel for channel in range(data.shape[-1]) if channel not in hs_channels]
                                non_hs_data = data[..., non_hs_channels]
                                data = data[..., hs_channels]

                                hs_channel_labels = [label for channel, label in enumerate(dataset_info['channel_labels']) if channel in hs_channels]
                                non_hs_channel_labels = [label for channel, label in enumerate(dataset_info['channel_labels']) if channel not in hs_channels]

                                data, band_selection_time, bands_selected = band_selection(data, train_gt, **hyperparams)

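                                # Update the channel indices to reflect the
                                # reduced HS data. Store a list (not a range
                                # object) so the entry supports the same
                                # operations, e.g. .remove(), as the other
                                # modality channel lists.
                                num_hs_channels = data.shape[-1]
                                dataset_info['hs_channels'] = list(range(num_hs_channels))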
                                dataset_info['lidar_ms_channels'] = [new_channel + num_hs_channels for new_channel, channel in enumerate(non_hs_channels) if channel in dataset_info['lidar_ms_channels']]
                                dataset_info['lidar_ndsm_channels'] = [new_channel + num_hs_channels for new_channel, channel in enumerate(non_hs_channels) if channel in dataset_info['lidar_ndsm_channels']]
                                dataset_info['vhr_rgb_channels'] = [new_channel + num_hs_channels for new_channel, channel in enumerate(non_hs_channels) if channel in dataset_info['vhr_rgb_channels']]

                                # Stack the non hyperspectral data onto
                                # the data array
                                data = np.dstack((data, non_hs_data))

                                if bands_selected is not None:
                                    # Initialize dictionary of band selection
                                    # data to output to file
                                    per_class_selected_bands = {
                                        'experiment_number': iteration + 1,
                                        'experiment_name': experiment_name,
                                        'random_seed': seed,
                                        'band_reduction_method': hyperparams['band_reduction_method'],
                                        'band_selection_time': band_selection_time,
                                        'model': None,
                                        'overall_accuracy': 0.0,
                                        'average_accuracy': 0.0,
                                        'precision_score': 0.0,
                                        'recall_score': 0.0,
                                        'cohen_kappa_score': 0.0,
                                    }


                                    # Record, for every hyperspectral channel,
                                    # whether band selection kept it
                                    for channel, label in enumerate(hs_channel_labels):
                                        per_class_selected_bands[f'{label} (channel {channel})'] = channel in bands_selected

                                    for channel, label in enumerate(non_hs_channel_labels):
                                        per_class_selected_bands[f'{label} (channel {channel + len(hs_channels)})'] = True

                            else:
                                print('There are no hyperspectral channels in this experiment! Skipping band selection...')

                        else:
                            data, band_selection_time, bands_selected = band_selection(data, train_gt, **hyperparams)

                            if bands_selected is not None:
                                per_class_selected_bands = {
                                    'experiment_number': iteration + 1,
                                    'experiment_name': experiment_name,
                                    'random_seed': seed,
                                    'band_reduction_method': hyperparams['band_reduction_method'],
                                    'band_selection_time': band_selection_time,
                                    'model': None,
                                    'overall_accuracy': 0.0,
                                    'average_accuracy': 0.0,
                                    'precision_score': 0.0,
                                    'recall_score': 0.0,
                                    'cohen_kappa_score': 0.0,
                                }

                                if dataset_info['channel_labels'] is not None:
                                    for channel, label in enumerate(dataset_info['channel_labels']):
                                        if channel in bands_selected:
                                            per_class_selected_bands[f'{label} (channel {channel})'] = True
                                        else:
                                            per_class_selected_bands[f'{label} (channel {channel})'] = False

                                            # Remove any unselected channel from
                                            # the modality channel lists.
                                            # NOTE: the remaining entries keep their
                                            # original (pre-selection) indices; if
                                            # band_selection() returns a reduced
                                            # array, they may need to be remapped
                                            # before they are used to index `data`.
                                            if channel in dataset_info['hs_channels']:
                                                dataset_info['hs_channels'].remove(channel)
                                            if channel in dataset_info['lidar_ms_channels']:
                                                dataset_info['lidar_ms_channels'].remove(channel)
                                            if channel in dataset_info['lidar_ndsm_channels']:
                                                dataset_info['lidar_ndsm_channels'].remove(channel)
                                            if channel in dataset_info['vhr_rgb_channels']:
                                                dataset_info['vhr_rgb_channels'].remove(channel)

                        print('-------------------------------------------------------------------')
                        print()

                # Set dataset variables
                dataset_name = dataset_info['name']
                num_classes = dataset_info['num_classes']
                hs_channels = dataset_info['hs_channels']
                lidar_ms_channels = dataset_info['lidar_ms_channels']
                lidar_ndsm_channels = dataset_info['lidar_ndsm_channels']
                vhr_rgb_channels = dataset_info['vhr_rgb_channels']
                ignored_labels = dataset_info['ignored_labels']
                all_class_labels = dataset_info['class_labels']
                valid_class_labels = [label for index, label in enumerate(all_class_labels)
                                      if index not in ignored_labels]


                epochs = hyperparams['epochs']
                supervision = 'full'
                batch_size = hyperparams['batch_size']
                patch_size = hyperparams['patch_size']
                train_split = hyperparams['train_split']
                optimizer = hyperparams['optimizer']
                learning_rate = hyperparams['lr']
                loss = 'sparse_categorical_crossentropy'
                img_channels = data.shape[-1]
                img_rows = patch_size
                img_cols = patch_size

                if hyperparams['model_id'] in ('fusion-fcn', 'fusion-fcn-v2'):
                    branch_1_channels = (*vhr_rgb_channels, *lidar_ms_channels)
                    branch_2_channels = (*lidar_ndsm_channels,)
                    branch_3_channels = (*hs_channels,)
                    input_channels = (branch_1_channels, branch_2_channels, branch_3_channels)
                    input_sizes = [len(input_channel) for input_channel in input_channels]
                elif hyperparams['model_id'] in ('3d-densenet-fusion',
                                                 '3d-densenet-fusion2',
                                                 '3d-densenet-fusion3',
                                                 '3d-densenet-fusion4'):
                    # branch_1_channels = (*vhr_rgb_channels, )
                    # branch_2_channels = (*lidar_ms_channels, *lidar_ndsm_channels, )
                    # branch_3_channels = (*hs_channels, )
                    # input_channels = (branch_1_channels, branch_2_channels, branch_3_channels)
                    branch_1_channels = (*hs_channels, )
                    branch_2_channels = (*lidar_ms_channels, )
                    branch_3_channels = (*lidar_ndsm_channels, )
                    branch_4_channels = (*vhr_rgb_channels, )
                    input_channels = (branch_1_channels, branch_2_channels, branch_3_channels, branch_4_channels)
                    input_sizes = [len(input_channel) for input_channel in input_channels]
                elif hyperparams['model_id'] == '3d-densenet-modified':
                    if hyperparams['add_branch'] is not None:
                        input_channels = []
                        branch_list = []
                        for branch in hyperparams['add_branch']:
                            branch_channels = []
                            branch_modalities = []
                            for modality in str(branch).split(','):
                                if modality == 'hs' and (hyperparams['use_hs_data'] or hyperparams['use_all_data']):
                                    branch_channels += [*hs_channels, ]
                                elif modality == 'lidar_ms' and (hyperparams['use_lidar_ms_data'] or hyperparams['use_all_data']):
                                    branch_channels += [*lidar_ms_channels, ]
                                elif modality == 'lidar_ndsm' and (hyperparams['use_lidar_ndsm_data'] or hyperparams['use_all_data']):
                                    branch_channels += [*lidar_ndsm_channels, ]
                                elif modality == 'vhr_rgb' and (hyperparams['use_vhr_data'] or hyperparams['use_all_data']):
                                    branch_channels += [*vhr_rgb_channels, ]

                                branch_modalities.append(modality)
                            branch_list.append(branch_modalities)
                            if len(branch_channels) > 0:
                                input_channels.append(branch_channels)

                        for index, modalities in enumerate(branch_list):
                            print(f'Branch {index} modalities: {modalities}')

                        if len(input_channels) > 0:
                            input_sizes = [len(input_channel) for input_channel in input_channels]
                        else:
                            input_channels = None
                            input_sizes = [img_channels]

                    else:
                        input_channels = None
                        input_sizes = [img_channels]

                else:
                    input_channels = None
                    input_sizes = None

                # Check to see if model uses 3d convolutions - if so
                # then the input dimensions will need to be expanded
                # to include the 'planes' dimension
                # ('3d-densenet-fusion' is currently excluded from this list)
                expand_dims = hyperparams['model_id'] in ('3d-densenet', '3d-cnn')

                # Add and update hyperparameters for model training
                hyperparams.update(
                    {
                        'random_seed': seed,
                        'n_classes': num_classes,
                        'n_bands': img_channels,
                        'all_class_labels': all_class_labels,
                        'ignored_labels': ignored_labels,
                        'device': device,
                        'supervision': supervision,
                        'center_pixel': True,
                        'one_hot_encoding': True,
                        'metrics': ['sparse_categorical_accuracy'],
                        'loss': loss,
                        'input_channels': input_channels,
                        'expand_dims': expand_dims,
                    }
                )

                # Update experiment data
                experiment_data.update({
                    'random_seed': seed,
                    'dataset': dataset_name,
                    'band_reduction_method': hyperparams['band_reduction_method'],
                    'band_selection_time': band_selection_time,
                    'bands_selected': bands_selected,
                    'channels': img_channels,
                    'device': device_name,
                    'epochs': epochs,
                    'batch_size': batch_size,
                    'patch_size': patch_size,
                    'train_split': train_split,
                    'optimizer': optimizer,
                    'learning_rate': learning_rate,
                    'loss': loss,
                })

                # Update per-class data for experiment
                per_class_data.update({
                    'random_seed': seed,
                    'band_reduction_method': hyperparams['band_reduction_method'],
                    'band_selection_time': band_selection_time,
                    'bands_selected': bands_selected,
                })
                for label in all_class_labels:
                    per_class_data[label] = 0.0

                if not reuse_last_dataset:
                    print('-------------------------------------------------------------------')
                    print('SPLIT DATA FOR TRAINING, VALIDATION, AND TESTING')
                    print('-------------------------------------------------------------------')

                    print('Breaking down image into data patches and splitting data into train, validation, and test sets...')
                    train_dataset, val_dataset, test_dataset = create_datasets(data, train_gt, test_gt, **hyperparams)
                    target_test = np.array(test_dataset.labels)
                    # datasets = create_datasets_v2(data, train_gt, test_gt, **hyperparams)

                    # train_dataset = (datasets['train_dataset'], datasets['train_steps'])
                    # val_dataset = (datasets['val_dataset'], datasets['val_steps'])
                    # test_dataset = (datasets['test_dataset'], datasets['test_steps'])
                    # target_test = datasets['target_test']

                    print('-------------------------------------------------------------------')
                    print()


                print('-------------------------------------------------------------------')
                print('CREATE MODEL')
                print('-------------------------------------------------------------------')

                # Create specified model
                if hyperparams['model_id'] == '2d-densenet':
                    model = densenet_2d_model(img_rows=img_rows,
                                    img_cols=img_cols,
                                    img_channels=img_channels,
                                    nb_classes=num_classes)
                elif hyperparams['model_id'] == '2d-densenet-multi':
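                    # TODO: implement this model. Fail loudly here; otherwise
                    # `model` would be undefined (or left over from a previous
                    # iteration) when it is used below.
                    raise NotImplementedError("The '2d-densenet-multi' model is not yet implemented")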
                elif hyperparams['model_id'] == '3d-densenet':
                    model = densenet_3d_model(img_rows=img_rows,
                                    img_cols=img_cols,
                                    img_channels=img_channels,
                                    nb_classes=num_classes)
                elif hyperparams['model_id'] in ('3d-densenet-modified', '3d-densenet-fusion'):
                    # Both model ids currently build the same modified 3D
                    # DenseNet; they differ only in how the input branches
                    # (input_sizes) were assembled above
                    model = densenet_3d_modified_model(img_rows=img_rows,
                                                       img_cols=img_cols,
                                                       img_channels_list=input_sizes,
                                                       nb_classes=num_classes,
                                                       num_dense_blocks=3,
                                                       growth_rate=32,
                                                       num_1x1_convs=0,
                                                       first_conv_filters=64,
                                                       first_conv_kernel=(5,5,5),
                                                       dropout_1=0.5,
                                                       dropout_2=0.5,
                                                       activation='leaky_relu')
                elif hyperparams['model_id'] == '3d-densenet-fusion2':
                    model = densenet_3d_fusion_model2(img_rows=img_rows,
                                                     img_cols=img_cols,
                                                     img_channels_list=[
                                                            len(hs_channels),
                                                            len(lidar_ms_channels),
                                                            len(lidar_ndsm_channels),
                                                            len(vhr_rgb_channels),
                                                     ],
                                                     nb_classes=num_classes,
                                                     num_dense_blocks=3)
                elif hyperparams['model_id'] == '3d-densenet-fusion3':
                    model = densenet_3d_fusion_model3(img_rows=img_rows,
                                                     img_cols=img_cols,
                                                     img_channels_list=[
                                                            len(hs_channels),
                                                            len(lidar_ms_channels),
                                                            len(lidar_ndsm_channels),
                                                            len(vhr_rgb_channels),
                                                     ],
                                                     nb_classes=num_classes,
                                                     num_dense_blocks=3)
                # elif hyperparams['model_id'] == '3d-densenet-fusion':
                #     model = densenet_3d_fusion_model(img_rows=img_rows,
                #                                      img_cols=img_cols,
                #                                      img_channels_1=len(vhr_rgb_channels),
                #                                      img_channels_2=len(lidar_ms_channels) + len(lidar_ndsm_channels),
                #                                      img_channels_3=len(hs_channels),
                #                                      nb_classes=num_classes,
                #                                      num_dense_blocks=3)
                # elif hyperparams['model_id'] == '3d-densenet-fusion2':
                #     model = densenet_3d_fusion_model2(img_rows=img_rows,
                #                                      img_cols=img_cols,
                #                                      img_channels_1=len(vhr_rgb_channels),
                #                                      img_channels_2=len(lidar_ms_channels) + len(lidar_ndsm_channels),
                #                                      img_channels_3=len(hs_channels),
                #                                      nb_classes=num_classes,
                #                                      num_dense_blocks=3)
                # elif hyperparams['model_id'] == '3d-densenet-fusion3':
                #     model = densenet_3d_fusion_model3(img_rows=img_rows,
                #                                      img_cols=img_cols,
                #                                      img_channels_1=len(vhr_rgb_channels),
                #                                      img_channels_2=len(lidar_ms_channels) + len(lidar_ndsm_channels),
                #                                      img_channels_3=len(hs_channels),
                #                                      nb_classes=num_classes,
                #                                      num_dense_blocks=3)
                # elif hyperparams['model_id'] == '3d-densenet-fusion4':
                #     model = densenet_3d_fusion_model4(img_rows=img_rows,
                #                                      img_cols=img_cols,
                #                                      img_channels_1=len(vhr_rgb_channels),
                #                                      img_channels_2=len(lidar_ms_channels) + len(lidar_ndsm_channels),
                #                                      img_channels_3=len(hs_channels),
                #                                      nb_classes=num_classes,
                #                                      num_dense_blocks=3)
                elif hyperparams['model_id'] == '2d-cnn':
                    model = cnn_2d_model(img_rows=img_rows,
                                    img_cols=img_cols,
                                    img_channels=img_channels,
                                    nb_classes=num_classes)
                elif hyperparams['model_id'] == '3d-cnn':
                    model = cnn_3d_model(img_rows=img_rows,
                                    img_cols=img_cols,
                                    img_channels=img_channels,
                                    nb_classes=num_classes)
                elif hyperparams['model_id'] == 'cnn-baseline':
                    filter_size = patch_size // 2 + 1
                    num_filters = img_channels * 2
                    model = baseline_cnn_model(img_rows=img_rows,
                                            img_cols=img_cols,
                                            img_channels=img_channels,
                                            patch_size=filter_size,
                                            nb_filters=num_filters,
                                            nb_classes=num_classes)
                elif hyperparams['model_id'] == 'nin':
                    model = nin_model(img_rows=img_rows,
                                    img_cols=img_cols,
                                    img_channels=img_channels,
                                    num_classes=num_classes)
                elif hyperparams['model_id'] == 'fusion-fcn':
                    branch_1_shape = (img_rows, img_cols,
                                len(lidar_ms_channels) + len(vhr_rgb_channels))
                    branch_2_shape = (img_rows, img_cols,
                                len(lidar_ndsm_channels))
                    branch_3_shape = (img_rows, img_cols, len(hs_channels))
                    model = fusion_fcn_model(
                                    branch_1_shape=branch_1_shape,
                                    branch_2_shape=branch_2_shape,
                                    branch_3_shape=branch_3_shape,
                                    nb_classes=num_classes)
                elif hyperparams['model_id'] == 'fusion-fcn-v2':
                    branch_1_shape = (img_rows, img_cols,
                                len(lidar_ms_channels) + len(vhr_rgb_channels))
                    branch_2_shape = (img_rows, img_cols,
                                len(lidar_ndsm_channels))
                    branch_3_shape = (img_rows, img_cols, len(hs_channels))
                    model = fusion_fcn_v2_model(
                                    branch_1_shape=branch_1_shape,
                                    branch_2_shape=branch_2_shape,
                                    branch_3_shape=branch_3_shape,
                                    nb_classes=num_classes)
                else:
                    print('<!> No model specified, defaulting to 3d-densenet <!>')
                    model = densenet_3d_model(img_rows=img_rows,
                                    img_cols=img_cols,
                                    img_channels=img_channels,
                                    nb_classes=num_classes)

                # Record model name for output
                experiment_data['model'] = model.name
                per_class_data['model'] = model.name
                if per_class_selected_bands is not None:
                    per_class_selected_bands['model'] = model.name

                if hyperparams['restore'] is not None:
                    print(f'Restoring {model.name} weights from {hyperparams["restore"]}')
                    model.load_weights(hyperparams['restore'])


                print('-------------------------------------------------------------------')
                print()

                # Train unless 'predict_only' is set (None counts as not set)
                if not hyperparams['predict_only']:
                    print('-------------------------------------------------------------------')
                    print('TRAIN MODEL')
                    print('-------------------------------------------------------------------')

                    # Run experiment on model
                    model, model_train_time = train_model(model=model,
                                                        train_dataset=train_dataset,
                                                        val_dataset=val_dataset,
                                                        iteration=iteration,
                                                        **hyperparams)
                else:
                    model_train_time = None

                print('-------------------------------------------------------------------')
                print('TEST MODEL')
                print('-------------------------------------------------------------------')

                pred_test, model_test_time = test_model(model=model,
                                                        test_dataset=test_dataset,
                                                        **hyperparams)

                if not hyperparams['skip_data_postprocessing']:
                    print('-------------------------------------------------------------------')
                    print('POSTPROCESS THE TEST RESULTS')
                    print('-------------------------------------------------------------------')

                    # Check whether pred_test is the right size
                    print(f'pred_test shape: {pred_test.shape}')
                    print(f'pred_test size:  {pred_test.size}')
                    print(f'test_gt size:    {test_gt.size}')
                    if pred_test.size != test_gt.size:
                        print('Error! pred_test and test_gt do not have the same number of elements!')
                        print(f'       pred_test delta: {pred_test.size - test_gt.size:+d} elements relative to test_gt')

                    # Reshape pred_test to original gt image size so that
                    # postprocessing can occur
                    pred_test = np.reshape(pred_test, test_gt.shape)
                    print(f'reshaped pred_test shape: {pred_test.shape}')
                    print(f'test_gt shape:            {test_gt.shape}')

                    pred_test = postprocess_data(pred_test, **hyperparams)
                    print('-------------------------------------------------------------------')
                    print()

                    # Remove ignored labels from target and predicted data
                    target_test, pred_test = filter_pred_results(test_gt, pred_test, ignored_labels)



                # Calculate the model performance statistics
                experiment_results = calculate_model_statistics(pred_test, target_test, all_class_labels, **hyperparams)
                experiment_results.update({
                    'experiment_name': experiment_name,
                    'model_name': model.name,
                    'model_train_time': model_train_time,
                    'model_test_time': model_test_time,
                })

                # Copy results to output data
                experiment_data['train_time'] = experiment_results['model_train_time']
                experiment_data['test_time'] = experiment_results['model_test_time']
                experiment_data['overall_accuracy'] = experiment_results['overall_accuracy']
                experiment_data['average_accuracy'] = experiment_results['average_accuracy']
                experiment_data['precision_score'] = experiment_results['precision_score']
                experiment_data['recall_score'] = experiment_results['recall_score']
                experiment_data['cohen_kappa_score'] = experiment_results['cohen_kappa_score']

                per_class_data['train_time'] = model_train_time
                per_class_data['test_time'] = model_test_time
                per_class_data['overall_accuracy'] = experiment_results['overall_accuracy']
                per_class_data['average_accuracy'] = experiment_results['average_accuracy']
                per_class_data['precision_score'] = experiment_results['precision_score']
                per_class_data['recall_score'] = experiment_results['recall_score']
                per_class_data['cohen_kappa_score'] = experiment_results['cohen_kappa_score']

                for label, acc in zip(experiment_results['labels'],
                                      experiment_results['per_class_accuracies']):
                    per_class_data[label] = acc

                if per_class_selected_bands is not None:
                    per_class_selected_bands['overall_accuracy'] = experiment_results['overall_accuracy']
                    per_class_selected_bands['average_accuracy'] = experiment_results['average_accuracy']
                    per_class_selected_bands['precision_score'] = experiment_results['precision_score']
                    per_class_selected_bands['recall_score'] = experiment_results['recall_score']
                    per_class_selected_bands['cohen_kappa_score'] = experiment_results['cohen_kappa_score']

                # Output experimental results
                output_experiment_results(experiment_results)

                # Save image of confusion matrix
                create_confusion_matrix_plot(experiment_results['confusion_matrix'],
                                             all_class_labels,
                                             model.name,
                                             output_path = output_path,
                                             iteration=iteration)

                print('-------------------------------------------------------------------')
                print()

                experiment_data['success'] = True

        except Exception as e:
            print()
            print('###################################################')
            print('!!! EXCEPTION OCCURRED !!!')
            print('###################################################')
            print(f'Exception Type: {type(e)}')
            print(f'Exception Line: {e.__traceback__.tb_lineno}')
            print(f'Exception Desc: {e}')
            print()
            print('---------------------------------------------------')
            print('** Full Traceback **')
            print()
            # Print full exception
            traceback.print_exc()
            print('###################################################')
            print()

            # Write exception to file
            with open(os.path.join(output_path, f'experiment_{iteration+1}_exception.log'),'w') as ef:
                ef.write('\n')
                ef.write('###################################################\n')
                ef.write('!!! EXCEPTION OCCURRED !!!\n')
                ef.write('###################################################\n')
                ef.write(f'Exception Type: {type(e)}\n')
                ef.write(f'Exception Line: {e.__traceback__.tb_lineno}\n')
                ef.write(f'Exception Desc: {e}\n')
                ef.write('\n')
                ef.write('---------------------------------------------------\n')
                ef.write('** Full Traceback **\n')
                ef.write('\n')
                # Write full traceback to the log file
                ef.write(f'{traceback.format_exc()}\n')
                ef.write('###################################################\n')
                ef.write('\n')

            print(f'Experiment #{iteration+1} crashed and thus failed!')

        experiment_data_list.append(experiment_data)
        per_class_data_lists[dataset_choice].append(per_class_data)
        if per_class_selected_bands is not None:
            per_class_selected_band_lists[dataset_choice].append(per_class_selected_bands)

        print()
        print('-------------------------------------------------------------------')
        print('SAVING RESULTS...')

        experiment_results = pd.DataFrame(experiment_data_list)
        experiment_results.set_index('experiment_number', inplace=True)
        experiment_results.to_csv(os.path.join(output_path, experiments_results_file))

        print('  >>> Experiment results saved!')

        # Use a separate loop variable so the current dataset_choice is not overwritten
        for dataset_name in per_class_data_lists:
            if len(per_class_data_lists[dataset_name]) > 0:
                file_name = f'{outfile_prefix}__{dataset_name}__class_results.csv'
                per_class_data_results = pd.DataFrame(per_class_data_lists[dataset_name])
                per_class_data_results.set_index('experiment_number', inplace=True)
                per_class_data_results.to_csv(os.path.join(output_path, file_name))
                print(f'  >>> {dataset_name} per-class results saved!')

        # Use a separate loop variable so the current dataset_choice is not overwritten
        for dataset_name in per_class_selected_band_lists:
            if len(per_class_selected_band_lists[dataset_name]) > 0:
                file_name = f'{outfile_prefix}__{dataset_name}__selected_band_results.csv'
                per_class_selected_band_results = pd.DataFrame(per_class_selected_band_lists[dataset_name])
                per_class_selected_band_results.set_index('experiment_number', inplace=True)
                per_class_selected_band_results.to_csv(os.path.join(output_path, file_name))
                print(f'  >>> {dataset_name} per-class selected band results saved!')

        print('RESULTS SAVED!')
        print('-------------------------------------------------------------------')

        print()
        print('*******************************************************')
        print(f'<<< EXPERIMENT #{iteration+1} COMPLETE! >>>')
        print('*******************************************************')
        print()

    print()
    print()
    print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^')
    print('EXPERIMENTS COMPLETE!')
    print('-------------------------------------------------------------------')
    print('-------------------------------------------------------------------')
    print('-------------------------------------------------------------------')
    print()
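
The exception handler above builds the same report twice, once with `print` and once with `ef.write`. A minimal sketch of one way to build it once follows. The helper name `format_exception_report` is hypothetical and not part of the harness. It also reports the line of the innermost frame, where the exception was raised, whereas `e.__traceback__.tb_lineno` gives the line in the frame that caught it.

```python
import traceback


def format_exception_report(exc):
    """Build one text report for an exception (hypothetical helper)."""
    rule = '#' * 51
    # extract_tb lists frames outermost first; the last entry is the
    # frame in which the exception was actually raised.
    frames = traceback.extract_tb(exc.__traceback__)
    raised_line = frames[-1].lineno if frames else None
    full_trace = ''.join(
        traceback.format_exception(type(exc), exc, exc.__traceback__))
    return '\n'.join([
        rule,
        '!!! EXCEPTION OCCURRED !!!',
        rule,
        f'Exception Type: {type(exc)}',
        f'Exception Line: {raised_line}',
        f'Exception Desc: {exc}',
        '',
        '** Full Traceback **',
        full_trace,
        rule,
    ])


def _fail():
    # Stand-in for an experiment step that crashes
    raise ValueError('bad patch size')


try:
    _fail()
except Exception as e:
    report = format_exception_report(e)
    print(report)  # the same string could also be written to the log file
```

With a helper like this, the console output and the `experiment_N_exception.log` file cannot drift apart, because both come from the same string.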

# 12) Experiments

## 12.1) Set Parameters For Experiment

In [None]:
################################################################
#@title EXPERIMENT HYPERPARAMETERS
################################################################
hyperparameters = {}

# Name to use for the experiment and output files
hyperparameters["experiment_name"] = None

# The numerical identifier of this experiment (i.e. the sequence number of this experiment)
hyperparameters["experiment_number"] = 1

# File path to save the experimental parameters to
hyperparameters["save_experiment_path"] = None

# Specify CUDA device (-1 learns on CPU)
hyperparameters["cuda"] = 0

# Number of runs of experiment
hyperparameters["runs"] = 1

# Path to file containing weights to use for initialization, e.g. a checkpoint
hyperparameters["restore"] = None

# Path to where output files should be created
hyperparameters["output_path"] = "./"

# The hyperspectral or data fusion dataset to use for experiments
hyperparameters["dataset"] = 'grss_dfc_2018'

# The path to the dataset directory
hyperparameters["path_to_dataset"] = "../datasets/grss_dfc_2018"

# Reuse the last dataset generator
hyperparameters["reuse_last_dataset"] = False

# Skip training and only do prediction on the model
hyperparameters["predict_only"] = False

# Skip the data preprocessing step
hyperparameters["skip_data_preprocessing"] = False

# Skip the band selection step
hyperparameters["skip_band_selection"] = True

# Skip the data postprocessing step
hyperparameters["skip_data_postprocessing"] = True

# The identifier for the machine learning model to use on the dataset
# (2d-cnn, 3d-cnn, cnn-baseline,
#  2d-densenet, 2d-densenet-multi, 3d-densenet, 3d-densenet-modified,
#  3d-densenet-fusion, 3d-densenet-fusion2, 3d-densenet-fusion3, 3d-densenet-fusion4,
#  nin, fusion-fcn, fusion-fcn-v2)
hyperparameters["model_id"] = "3d-densenet"

# Add a branch to the machine learning model, with the branch modalities
# as a comma-separated string after the argument (ex. hs,vhr_rgb)
# [modalities: hs, lidar_ms, lidar_ndsm, vhr_rgb]
hyperparameters['add_branch'] = ["hs"]
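
The cells above do not show how the harness uses `save_experiment_path`. As a rough sketch, assuming the parameters are stored as JSON (an assumption, not something the notebook confirms), a dictionary like this one can be written out and read back as follows. The file name and the temporary directory are placeholders.

```python
import json
import os
import tempfile

# Stand-in values; the real dictionary is the one built in these cells.
hyperparameters = {
    'experiment_name': None,
    'experiment_number': 1,
    'model_id': '3d-densenet',
    'add_branch': ['hs'],
}

# Hypothetical location; the notebook leaves save_experiment_path as None.
save_path = os.path.join(tempfile.mkdtemp(), 'experiment_1_parameters.json')

# Write the parameters to disk
with open(save_path, 'w') as f:
    json.dump(hyperparameters, f, indent=2)

# Read them back to confirm that nothing was lost in the round trip
with open(save_path) as f:
    restored = json.load(f)

print(restored == hyperparameters)
```

JSON keeps `None`, numbers, strings and lists intact, which covers every value type used in these parameter cells.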

In [None]:
################################################################
#@title TRAINING OPTION PARAMETERS
################################################################

# Random number generator seed.
hyperparameters["random_seed"] = 123

# Training epochs
hyperparameters["epochs"] = 10

# Number of training epochs to pass before learning rate decay
hyperparameters["epochs_before_decay"] = 10

# Batch size
hyperparameters["batch_size"] = 32

# Size of the spatial neighborhood [e.g. patch_size X patch_size square]
hyperparameters["patch_size"] = 15

# Uses the label of the center pixel when training
hyperparameters["center_pixel"] = True

# The fraction of samples set aside for training in the train/validation split
hyperparameters["train_split"] = 0.60

# The mode by which to split datasets (random, fixed, or disjoint)
hyperparameters["split_mode"] = "fixed"

# Inverse median frequency class balancing
hyperparameters["class_balancing"] = False

# Number of iterations to run the model for
hyperparameters["iterations"] = None

# Number of epochs without improvement before stopping training
hyperparameters["patience"] = 3

# The number of epochs to pass before saving model again
hyperparameters["model_save_period"] = None
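
To illustrate how `train_split` and `random_seed` might interact, here is a minimal sketch of a seeded random split over sample indices. It corresponds loosely to the `random` split mode. The function name `split_indices` is hypothetical, and the harness's own splitting code may work differently.

```python
import random


def split_indices(n_samples, train_split=0.60, random_seed=123):
    """Shuffle sample indices reproducibly and cut them into two sets."""
    indices = list(range(n_samples))
    # A dedicated Random instance keeps the global RNG state untouched
    random.Random(random_seed).shuffle(indices)
    n_train = int(round(n_samples * train_split))
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = split_indices(10)
print(len(train_idx), len(val_idx))
```

Calling `split_indices` again with the same seed returns the same partition, which is what makes repeated runs comparable.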



In [None]:
################################################################
#@title MODEL OPTIMIZER PARAMETERS
################################################################

# The optimizer used by the machine learning model
hyperparameters["optimizer"] = "nadam"

# The model's learning rate
hyperparameters["lr"] = 0.00005

# The percentage rate at which the model's learning rate decays
hyperparameters["lr_decay_rate"] = 0.95

# The optimizer's momentum, if applicable
hyperparameters["momentum"] = None

# The optimizer's epsilon value, if applicable
hyperparameters["epsilon"] = None

# The optimizer's initial_accumulator_value value, if applicable
hyperparameters["initial_accumulator_value"] = None

# The optimizer's beta value, if applicable (Ftrl only)
hyperparameters["beta"] = None

# The optimizer's beta_1 value, if applicable
hyperparameters["beta_1"] = None

# The optimizer's beta_2 value, if applicable
hyperparameters["beta_2"] = None

# The optimizer's amsgrad value, if applicable
hyperparameters["amsgrad"] = None

# The optimizer's rho value, if applicable
hyperparameters["rho"] = None

# The optimizer's centered value, if applicable
hyperparameters["centered"] = None

# The optimizer's nesterov value, if applicable
hyperparameters["nesterov"] = None

# The optimizer's learning_rate_power value, if applicable
hyperparameters["learning_rate_power"] = None

# The optimizer's l1_regularization_strength value, if applicable
hyperparameters["l1_regularization_strength"] = None

# The optimizer's l2_regularization_strength value, if applicable
hyperparameters["l2_regularization_strength"] = None

# The optimizer's l2_shrinkage_regularization_strength value, if applicable
hyperparameters["l2_shrinkage_regularization_strength"] = None
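
The notebook does not show here how `lr`, `lr_decay_rate` and `epochs_before_decay` are combined. One plausible reading, offered as an assumption only, is a step decay that multiplies the learning rate by `lr_decay_rate` once every `epochs_before_decay` epochs. The function name below is hypothetical.

```python
def decayed_learning_rate(epoch, lr=0.00005, lr_decay_rate=0.95,
                          epochs_before_decay=10):
    """Step decay: scale lr by lr_decay_rate every epochs_before_decay epochs."""
    n_decays = epoch // epochs_before_decay
    return lr * lr_decay_rate ** n_decays


for epoch in (0, 9, 10, 20):
    print(epoch, decayed_learning_rate(epoch))
```

Under this reading, a run with `epochs = 10` and `epochs_before_decay = 10` never reaches the first decay step, since epochs 0 through 9 all use the initial rate.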

In [None]:
################################################################
#@title DATA AUGMENTATION PARAMETERS
################################################################

# Random flips (if patch_size > 1)
hyperparameters["flip_augmentation"] = False

# Random radiation noise (illumination)
hyperparameters["radiation_augmentation"] = False

# Random mixes between spectra
hyperparameters["mixture_augmentation"] = False

In [None]:
################################################################
#@title GRSS DATA FUSION CONTEST 2018 DATASET PARAMETERS
################################################################

# Load hyperspectral data for this experiment
hyperparameters["use_hs_data"] = True

# Load lidar multispectral intensity data for this experiment
hyperparameters["use_lidar_ms_data"] = True

# Load lidar NDSM data for this experiment
hyperparameters["use_lidar_ndsm_data"] = True

# Load very high resolution RGB data for this experiment
hyperparameters["use_vhr_data"] = True

# Load all data sources for this experiment
hyperparameters["use_all_data"] = False

# Normalize hyperspectral data
hyperparameters["normalize_hs_data"] = False

# Normalize LiDAR multispectral data
hyperparameters["normalize_lidar_ms_data"] = False

# Normalize LiDAR NDSM data
hyperparameters["normalize_lidar_ndsm_data"] = False

# Normalize VHR RGB data
hyperparameters["normalize_vhr_data"] = False

# Resampling method to use on the grss_dfc_2018 hyperspectral image
# (nearest, bilinear, cubic, cubic_spline, lanczos, average, mode,
#  gauss, max, min, med, q1, q3, rms)
hyperparameters["hs_resampling"] = "average"

# Resampling method to use on the grss_dfc_2018 LiDAR multispectral image
# (nearest, bilinear, cubic, cubic_spline, lanczos, average, mode,
#  gauss, max, min, med, q1, q3, rms)
hyperparameters["lidar_ms_resampling"] = "average"

# Resampling method to use on the grss_dfc_2018 LiDAR NDSM image
# (nearest, bilinear, cubic, cubic_spline, lanczos, average, mode,
#  gauss, max, min, med, q1, q3, rms)
hyperparameters["lidar_ndsm_resampling"] = "average"

# Resampling method to use on the grss_dfc_2018 VHR RGB image
# (nearest, bilinear, cubic, cubic_spline, lanczos, average, mode,
#  gauss, max, min, med, q1, q3, rms)
hyperparameters["vhr_resampling"] = "cubic_spline"

# Perform data equalization based on layer histogram for grss_dfc_2018 hyperspectral image
hyperparameters["hs_histogram_equalization"] = False

# Perform data equalization based on layer histogram for grss_dfc_2018 LiDAR multispectral image
hyperparameters["lidar_ms_histogram_equalization"] = False

# Perform data equalization based on layer histogram for grss_dfc_2018 LiDAR DSM image
hyperparameters["lidar_dsm_histogram_equalization"] = False

# Perform data equalization based on layer histogram for grss_dfc_2018 LiDAR DEM image
hyperparameters["lidar_dem_histogram_equalization"] = False

# Perform data equalization based on layer histogram for grss_dfc_2018 LiDAR NDSM image
hyperparameters["lidar_ndsm_histogram_equalization"] = False

# Perform data equalization based on layer histogram for grss_dfc_2018 VHR RGB image
hyperparameters["vhr_histogram_equalization"] = False

# Filtering method to use on the grss_dfc_2018 hyperspectral image
# (median, gaussian)
hyperparameters["hs_data_filter"] = None

# Filtering method to use on the grss_dfc_2018 LiDAR multispectral image
# (median, gaussian)
hyperparameters["lidar_ms_data_filter"] = "median"

# Filtering method to use on the grss_dfc_2018 LiDAR DSM image
# (median, gaussian)
hyperparameters["lidar_dsm_data_filter"] = "gaussian"

# Filtering method to use on the grss_dfc_2018 LiDAR DEM image
# (median, gaussian)
hyperparameters["lidar_dem_data_filter"] = "gaussian"

# Filtering method to use on the grss_dfc_2018 VHR RGB image
# (median, gaussian)
hyperparameters["vhr_data_filter"] = None
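
The parameters above refer to LiDAR DSM, DEM and NDSM layers. A normalized DSM is conventionally the DSM minus the DEM, which gives height above ground per grid cell. The sketch below shows that subtraction on toy grids. Whether the harness derives its NDSM this way is an assumption, and the function name is hypothetical.

```python
def normalized_dsm(dsm, dem):
    """Return the per-cell height above ground: DSM minus DEM."""
    return [[s - g for s, g in zip(dsm_row, dem_row)]
            for dsm_row, dem_row in zip(dsm, dem)]


# Toy 2x3 elevation grids in meters (made-up values)
dsm = [[12.0, 12.5, 20.0],
       [12.0, 18.0, 20.5]]
dem = [[12.0, 12.0, 12.0],
       [12.0, 12.0, 12.5]]

ndsm = normalized_dsm(dsm, dem)
print(ndsm)
```

Cells where the two models agree come out as zero (bare ground), while buildings and trees show up as positive heights.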

In [None]:
################################################################
#@title BAND SELECTION PARAMETERS
################################################################

# The band dimensionality reduction method to be used
hyperparameters["band_reduction_method"] = None

# The number of components to be used with the band reduction method
hyperparameters["n_components"] = None

# A list of channel indices for manual band selection (each channel separated by spaces)
hyperparameters["selected_bands"] = None

# Only perform band selection on the hyperspectral data
hyperparameters["select_only_hs_bands"] = False
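
As a small illustration of manual band selection, the sketch below parses a space-separated string of channel indices and keeps only those channels of a single pixel's spectrum. Both function names are hypothetical, and the harness may expect `selected_bands` in a different form.

```python
def parse_selected_bands(text):
    """Turn a space-separated string such as '0 2 5' into a list of ints."""
    return [int(token) for token in text.split()]


def select_bands(pixel_spectrum, band_indices):
    """Keep only the chosen channels of one pixel's spectrum."""
    return [pixel_spectrum[i] for i in band_indices]


bands = parse_selected_bands('0 2 5')
spectrum = [0.11, 0.14, 0.20, 0.35, 0.41, 0.38]  # made-up reflectances
print(select_bands(spectrum, bands))
```

On a full image cube the same idea applies along the band axis, so the spatial dimensions are unchanged and only the channel count shrinks.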

## 12.2) Run Experiment

In [None]:
# Start timing experiments
test_harness_start = time.time()

# Run the test harness
run_test_harness(**hyperparameters)

test_harness_end = time.time()
test_harness_runtime = datetime.timedelta(seconds=(test_harness_end - test_harness_start))

print(f' < Total Test Harness Runtime: {test_harness_runtime} >')



# References and Citations
1. Fernandez-Diaz, Juan Carlos, and Ramesh L. Shrestha. “Data Collection & Processing Report.” 2018 IEEE GRSS Data Fusion Challenge – Fusion of Multispectral LiDAR and Hyperspectral Data, University of Houston, 2017, https://hyperspectral.ee.uh.edu/2018IEEEDocs/DataReport.pdf.

2. Saurabh Prasad, Bertrand Le Saux, Naoto Yokoya, Ronny Hansch, December 18, 2020, "2018 IEEE GRSS Data Fusion Challenge – Fusion of Multispectral LiDAR and Hyperspectral Data", IEEE Dataport, doi: https://dx.doi.org/10.21227/jnh9-nz89.

3. Zhang, C., Li, G., Du, S., et al. “Three-Dimensional Densely Connected Convolutional Network for Hyperspectral Remote Sensing Image Classification.” Journal of Applied Remote Sensing, vol. 13, no. 1, 2019, 016519.

4. Leeguandong. “Leeguandong/3D-Densenet-for-HSI: Paper：Three-Dimensional Densely Connected Convolutional Network for Hyperspectral Remote Sensing Image Classification.” GitHub, Journal of Applied Remote Sensing, 13(1), 3 Feb. 2019, https://github.com/leeguandong/3D-DenseNet-for-HSI.