<a href="https://colab.research.google.com/github/fahim1703061/Hyperspectral-Image-Processing/blob/main/Research_data_preprocess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://github.com/ringochuchudull/CNN-Based-Hyperspectral-Image-Classification

In [None]:
def getdtype(t):
    """Map a dtype name to the matching NumPy scalar type.

    Recognised names are 'float64', 'float32', 'float16', 'int64', 'int32',
    'int16' and 'int8'.  Any other value falls back to ``np.float64``.
    """
    import numpy as np

    supported = ('float64', 'float32', 'float16',
                 'int64', 'int32', 'int16', 'int8')
    for name in supported:
        if t == name:
            return getattr(np, name)

    # Unrecognised name: use the default value
    return np.float64

#Get Dataset
def maybeExtract(data, patch_size):
    """Load the saved train/validation/test patch files of a dataset.

    Arguments:
        data: dataset name used as the file-name prefix (e.g. 'Indian_pines').
        patch_size: patch size used as the file-name suffix.

    Returns:
        (TRAIN, VALIDATION, TEST): the three dictionaries returned by
        ``scipy.io.loadmat`` for
        ``./data/<data>_{Train,Val,Test}_patch_<patch_size>.mat``.

    Raises:
        Exception: if any of the three files cannot be loaded; the original
        error is attached as ``__cause__``.
    """
    import scipy.io

    def split_path(split):
        return "./data/" + data + "_" + split + "_patch_" + str(patch_size) + ".mat"

    try:
        TRAIN = scipy.io.loadmat(split_path("Train"))
        VALIDATION = scipy.io.loadmat(split_path("Val"))
        TEST = scipy.io.loadmat(split_path("Test"))

    # `Exception` (not a bare except) so Ctrl-C / SystemExit are not rewritten
    # into a misleading "data files not existed" error.
    except Exception as err:
        raise Exception('--data options are: Indian_pines, Salinas, KSC, Botswana OR data files not existed') from err

    return TRAIN, VALIDATION, TEST


def maybeDownloadOrExtract(data):
    """Return the image cube and ground truth of a dataset, downloading if needed.

    The files ``./data/<data>.mat`` and ``./data/<data>_gt.mat`` are read
    first.  If that fails for any reason, both files are downloaded into
    ``./data/`` and read again.

    Arguments:
        data: one of 'Indian_pines', 'Salinas', 'KSC', 'Botswana'.

    Returns:
        (input_mat, target_mat): the arrays stored in the two .mat files.

    Raises:
        Exception: if the local files are unusable and `data` is not one of
        the downloadable datasets.  Download and load errors propagate.
    """
    import scipy.io as io
    import os
    from urllib.request import urlretrieve

    # The variable inside the .mat file is looked up under the dataset name
    # as given for KSC/Botswana and under the lowercase name otherwise.
    if data in ('KSC', 'Botswana'):
        filename = data
    else:
        filename = data.lower()

    print("Dataset: " + filename)

    def load_pair():
        input_mat = io.loadmat('./data/' + data + '.mat')[filename]
        target_mat = io.loadmat('./data/' + data + '_gt.mat')[filename + '_gt']
        return input_mat, target_mat

    try:
        print("Try using images from Data folder...")
        return load_pair()
    # Any load failure (missing file, missing key, unreadable file) falls back
    # to downloading; Ctrl-C / SystemExit are deliberately not swallowed.
    except Exception:
        print("Data not found, downloading input images and labelled images!\n\n")

    urls = {
        "Indian_pines": ("http://www.ehu.eus/ccwintco/uploads/2/22/Indian_pines.mat",
                         "http://www.ehu.eus/ccwintco/uploads/c/c4/Indian_pines_gt.mat"),
        "Salinas": ("http://www.ehu.eus/ccwintco/uploads/f/f1/Salinas.mat",
                    "http://www.ehu.eus/ccwintco/uploads/f/fa/Salinas_gt.mat"),
        "KSC": ("http://www.ehu.eus/ccwintco/uploads/2/26/KSC.mat",
                "http://www.ehu.eus/ccwintco/uploads/a/a6/KSC_gt.mat"),
        "Botswana": ("http://www.ehu.eus/ccwintco/uploads/7/72/Botswana.mat",
                     "http://www.ehu.eus/ccwintco/uploads/5/58/Botswana_gt.mat"),
    }
    if data not in urls:
        raise Exception("Available datasets are:: Indian_pines, Salinas, KSC, Botswana")

    # Download with the standard library instead of shelling out to `wget`:
    # no external binary is required and a failed download raises immediately.
    os.makedirs('./data/', exist_ok=True)
    for url in urls[data]:
        urlretrieve(url, './data/' + url.rsplit('/', 1)[-1])

    return load_pair()


def getListLabel(data):
    """Return the list of class numbers that are kept for a dataset.

    Arguments:
        data: 'Indian_pines', 'Salinas', 'Botswana' or 'KSC'.

    Returns:
        A new list of 1-based class numbers.

    Raises:
        Exception: for any other dataset name.
    """
    kept_classes = (
        ('Indian_pines', [2, 3, 4, 5, 6, 8, 10, 11, 12, 14, 15]),
        ('Salinas', list(range(1, 17))),
        ('Botswana', list(range(1, 15))),
        ('KSC', list(range(1, 14))),
    )
    for name, labels in kept_classes:
        if data == name:
            return labels

    raise Exception("Type error")



def OnehotTransform(labels):
    """One-hot encode a sequence of class labels.

    Arguments:
        labels: 1-D sequence (or an (n, 1) array) of n labels.

    Returns:
        A uint8 array of shape (n, k), where k is the number of distinct
        labels.  Column c corresponds to the c-th distinct label in sorted
        order, and each row has a single 1.

    The encoding is done with NumPy only: the previous
    ``OneHotEncoder(sparse=False)`` call is rejected by scikit-learn releases
    that renamed the keyword to ``sparse_output``.
    """
    import numpy as np

    # Same shape validation as before: exactly len(labels) scalar labels.
    flat = np.reshape(labels, (len(labels),))
    classes, column = np.unique(flat, return_inverse=True)

    onehot = np.zeros((len(flat), len(classes)), dtype=np.uint8)
    onehot[np.arange(len(flat)), column] = 1

    return onehot



def getTestDataset(test, test_label, size=250):
    '''
    Draw a random subset of the test data.

    Arguments: test = whole test data, test_label = matching labels
               (same length along the first axis), size = samples to keep
    Returns: (data, labels) arrays holding at most `size` randomly chosen,
             still aligned, samples
    '''
    from numpy import array
    from random import shuffle

    assert test.shape[0] == test_label.shape[0]

    # Shuffle all positions, then keep the first `size` of them
    order = list(range(test.shape[0]))
    shuffle(order)
    chosen = order[:size]

    accuracy_x = [test[k] for k in chosen]
    accuracy_y = [test_label[k] for k in chosen]

    return array(accuracy_x), array(accuracy_y)


def plot_random_spec_img(pic, true_label):
    '''
    Take the first hyperspectral image of a dataset and plot where its values
    are at or above the image mean, as a 3-D scatter plot.

    Arguments: pic = list of images in size (?, height, width, bands), where ? represents any number > 0
               true_label = list of one-hot ground truths corresponding to pic

    The input is not modified: thresholding is done on a copy of pic[0].
    '''
    from matplotlib import pyplot as plt
    # Imported for its side effect only; older matplotlib releases need it
    # before projection='3d' can be used.
    from mpl_toolkits.mplot3d import Axes3D
    from numpy import argmax, array, mean, nonzero

    # Copy the first image: pic[0] of an ndarray is a view, so zeroing values
    # in place would silently overwrite the caller's dataset.
    pic = array(pic[0])

    print("Image Shape: " + str(pic.shape) )
    print("Label of this image is -> " + str(true_label[0] ) )

    title = argmax(true_label[0], axis=0)
    # Zero every element below the mean of the whole 3-D image
    mean_value = mean(pic)
    pic[pic < mean_value] = 0

    # Coordinates of the remaining non-zero elements, in row-major order
    # (first axis -> z, second -> x, third -> y)
    z, x, y = nonzero(pic)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title('True class = '+ str(title))
    ax.scatter(x, y, z, color='#0606aa', marker='o', s=0.5)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Spectral Label')
    ax.set_zlabel('Y Label')
    plt.show()


def GroundTruthVisualise(data):
    """Show `data` as an image using the 'tab20b' colormap, with a colorbar."""
    import matplotlib.pyplot as plt

    plt.imshow(data)
    plt.set_cmap('tab20b')
    plt.colorbar()
    plt.show()


# Argument: data = 3D image in size (h,w,bands)
def plotStatlieImage(data, bird=False):
    """Plot nine single-band images of a hyperspectral cube in a 3x3 grid.

    Band indices 12, 23, ..., 100 (a stride of 11) are shown, one per subplot.

    Arguments:
        data: 3-D array; the band axis is the last one, or the first one
              when `bird` is True.
        bird: set True when `data` is laid out as (bands, h, w).

    Subplots whose band index is beyond the number of bands are left empty
    instead of raising an IndexError.
    """
    from matplotlib.pyplot import show, subplots
    print('\nPlotting a band image')
    fig, ax = subplots(nrows=3, ncols=3)
    n_bands = data.shape[0] if bird else data.shape[2]
    for k, col in enumerate(ax.ravel(), start=1):
        i = 1 + 11 * k
        # Turn the axis off on THIS subplot; the pyplot-level axis('off')
        # acts on the current axes, not on the subplot being drawn.
        col.axis('off')
        if i >= n_bands:
            continue
        if bird:
            col.imshow(data[i,:,:])
        else:
            col.imshow(data[:,:,i])
    show()


def showClassTable(number_of_list, title='Number of samples'):
    """Print a two-column table of class numbers (1..n) and their values.

    Arguments:
        number_of_list: one value per class, in class order.
        title: header of the value column.
    """
    import pandas as pd
    print("\n+------------Show Table---------------+")
    class_ids = range(1, len(number_of_list) + 1)
    frame = pd.DataFrame({'Class#': class_ids, title: number_of_list})
    print(frame.to_string(index=False))
    print("+-----------Close Table---------------+")



# Debugging hook only: running this file directly intentionally does nothing
if __name__ == '__main__':
    pass

In [None]:
import numpy as np
from random import shuffle
import scipy.io as io
import argparse
from helper import *
import threading
import time
import itertools
import sys

def str2bool(value):
    """Turn a command-line string such as 'true' / 'False' / '0' into a bool.

    argparse's ``type=bool`` cannot be used for this: it calls ``bool()`` on
    the raw string, so every non-empty value (including 'False') is True.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


parser = argparse.ArgumentParser()
# Accepts the "-f <connection file>" argument that a Jupyter kernel is started with
parser.add_argument('-f')
parser.add_argument('--data', type=str, default='Indian_pines', help='Default: Indian_pines, options: Salinas, KSC, Botswana')
parser.add_argument('--train_ratio', type=float, default=0.2)
parser.add_argument('--validation_ratio', type=float, default=0.05)
parser.add_argument('--channel_first', type=str2bool, default=False, help='Image channel located on the last dimension')
parser.add_argument('--dtype', type=str, default='float32', help='Data type (Eg float64, float32, float16, int64...')
parser.add_argument('--plot', type=str2bool, default=False, help='TRUE to plot satellite images and ground truth at the end')
# parse_known_args: ignore arguments injected by the host process instead of exiting
opt, _ = parser.parse_known_args()

# Read the image cube and ground truth from ./data, downloading them if absent
input_mat, target_mat = maybeDownloadOrExtract(opt.data)

# Output data type requested on the command line
datatype = getdtype(opt.dtype)

# Cube dimensions (rows, columns, spectral bands) and task constants
HEIGHT, WIDTH, BAND = input_mat.shape[:3]
OUTPUT_CLASSES = np.max(target_mat)   # highest class number in the ground truth
PATCH_SIZE = 5

CHANNEL_FIRST = opt.channel_first

# Cast to the output type, then rescale so the minimum maps to 0 and the maximum to 1
input_mat = input_mat.astype(datatype)
input_mat = input_mat - input_mat.min()
input_mat = input_mat / input_mat.max()

# Class numbers that are kept (described upstream as having sufficient training samples)
list_labels = getListLabel(opt.data)

# Console spinner shown while the patches are being extracted
end_loading = False
def animate():
    """Draw a text spinner on stdout until the global `end_loading` is true.

    Meant to run in a background thread.  The flag is checked once per tick
    (every 0.1 s); when it is set the loop ends and 'Finished!' is written
    exactly once.
    """
    global end_loading
    for c in itertools.cycle(['|', '/', '-', '\\']):
        if end_loading:
            break
        sys.stdout.write('\rExtracting '+ opt.data + ' dataset features...' + c)
        sys.stdout.flush()
        time.sleep(0.1)
    # Written after the loop: inside it, the message was emitted on every tick
    # and immediately overwritten by the next spinner frame.
    sys.stdout.write('\rFinished!\t')
    sys.stdout.flush()

print("+-------------------------------------+")
print('Input_mat shape: ' + str(input_mat.shape))

# Put the band axis first: (HEIGHT, WIDTH, BAND) -> (BAND, HEIGHT, WIDTH)
input_mat = np.transpose(input_mat, (2, 0, 1))

# Mean of every band over the unpadded image, stored as a (BAND, 1) column
MEAN_ARRAY = np.ndarray(shape=(BAND, 1))
for i in range(BAND):
    MEAN_ARRAY[i] = np.mean(input_mat[i, :, :])

# Zero-pad both spatial axes of every band by half a patch on each side
calib_val_pad = int((PATCH_SIZE - 1)/2)
new_input_mat = [
    np.pad(band, calib_val_pad, 'constant', constant_values=0)
    for band in input_mat
]

input_mat = np.array(new_input_mat)

def Patch(height_index, width_index):
    """Cut one mean-centred patch out of the global `input_mat`.

    Arguments:
        height_index, width_index: top-left corner of the window in the
        (padded) spatial axes of `input_mat`.

    Returns:
        Array of type `datatype` holding, for every band, the
        PATCH_SIZE x PATCH_SIZE window starting at that corner minus the
        band's entry in `MEAN_ARRAY`.  Only the data cube is returned; the
        label is looked up by the caller.
    """
    rows = slice(height_index, height_index + PATCH_SIZE)
    cols = slice(width_index, width_index + PATCH_SIZE)
    window = input_mat[:, rows, cols]

    # MEAN_ARRAY is (bands, 1); a trailing axis lets each band's mean
    # broadcast over that band's window.
    centred = window - MEAN_ARRAY[:window.shape[0], :, np.newaxis]

    return centred.astype(datatype)


# One list of patches per class number 1..OUTPUT_CLASSES (index = class number - 1)
CLASSES = [[] for _ in range(OUTPUT_CLASSES)]

# Number of samples collected for each class
class_label_counter = [0] * OUTPUT_CLASSES

# Show the spinner in a background thread while extracting, and time the extraction.
# The Thread object is kept (Thread.start() returns None) so it can be joined below.
t = threading.Thread(target=animate)
t.start()
start = time.time()

count = 0
try:
    # Visit every labelled pixel in row-major order.  This includes the last
    # row and column (range(HEIGHT-1) / range(WIDTH-1) skipped them) and never
    # builds a patch for background pixels (label 0).
    for i, j in zip(*np.nonzero(target_mat)):
        curr_tar = int(target_mat[i, j])
        CLASSES[curr_tar-1].append(Patch(i, j))
        class_label_counter[curr_tar-1] += 1
        count += 1
finally:
    # Always stop the spinner, even if extraction fails
    end_loading = True

end = time.time()
# Let the spinner finish its last write before printing the report
t.join()
print("Total excution time..." + str(end-start)+'seconds')
print('Total number of samples: ' + str(count))
showClassTable(class_label_counter)

TRAIN_PATCH, TRAIN_LABELS = [], []
TEST_PATCH, TEST_LABELS =[], []
VAL_PATCH, VAL_LABELS = [], []

train_ratio = opt.train_ratio
val_ratio = opt.validation_ratio
# test_ratio = remainder of data

counter = 0  # Contiguous label (0, 1, 2, ...) given to the next accepted class
for i, data in enumerate(CLASSES):
    # Only the classes listed in list_labels are used
    if i + 1 not in list_labels:
        print('-Class ' + str(i + 1) + ' is rejected due to insufficient samples')
        continue

    shuffle(data)
    print('Class ' + str(i + 1) + ' is accepted')

    # Per-class split sizes; whatever is left after train + validation is test
    size = round(class_label_counter[i]*train_ratio)
    size1 = round(class_label_counter[i]*val_ratio)

    train_part = data[:size]
    val_part = data[size:size+size1]
    test_part = data[size+size1:]

    # Labels are sized from the slices actually taken, so patches and labels
    # cannot get out of step when the ratios ask for more than is available.
    TRAIN_PATCH += train_part
    TRAIN_LABELS += [counter] * len(train_part)

    VAL_PATCH += val_part
    VAL_LABELS += [counter] * len(val_part)

    TEST_PATCH += test_part
    TEST_LABELS += [counter] * len(test_part)

    counter += 1

TRAIN_LABELS = np.array(TRAIN_LABELS)
TRAIN_PATCH = np.array(TRAIN_PATCH)
TEST_PATCH = np.array(TEST_PATCH)
TEST_LABELS = np.array(TEST_LABELS)
VAL_PATCH = np.array(VAL_PATCH)
VAL_LABELS = np.array(VAL_LABELS)

print("+-------------------------------------+")
print("Size of Training data: " + str(len(TRAIN_PATCH)) )
print("Size of Validation data: " + str(len(VAL_PATCH))  )
print("Size of Testing data: " + str(len(TEST_PATCH)) )
print("+-------------------------------------+")


# Each split below: shuffle samples and labels with one shared permutation,
# move the band axis last unless channel-first output was requested,
# one-hot encode the labels and save the result as a .mat file.

# --- training split ---
train_idx = list(range(len(TRAIN_PATCH)))
shuffle(train_idx)
TRAIN_PATCH = TRAIN_PATCH[train_idx]
if not CHANNEL_FIRST:
    TRAIN_PATCH = TRAIN_PATCH.transpose(0, 2, 3, 1)
TRAIN_LABELS = OnehotTransform(TRAIN_LABELS[train_idx])
train = {"train_patch": TRAIN_PATCH, "train_labels": TRAIN_LABELS}
io.savemat("./data/{}_Train_patch_{}.mat".format(opt.data, PATCH_SIZE), train)


# --- test split ---
test_idx = list(range(len(TEST_PATCH)))
shuffle(test_idx)
TEST_PATCH = TEST_PATCH[test_idx]
if not CHANNEL_FIRST:
    TEST_PATCH = TEST_PATCH.transpose(0, 2, 3, 1)
TEST_LABELS = OnehotTransform(TEST_LABELS[test_idx])
test = {"test_patch": TEST_PATCH, "test_labels": TEST_LABELS}
io.savemat("./data/{}_Test_patch_{}.mat".format(opt.data, PATCH_SIZE), test)


# --- validation split ---
val_idx = list(range(len(VAL_PATCH)))
shuffle(val_idx)
VAL_PATCH = VAL_PATCH[val_idx]
if not CHANNEL_FIRST:
    VAL_PATCH = VAL_PATCH.transpose(0, 2, 3, 1)
    print(VAL_PATCH.shape)
VAL_LABELS = OnehotTransform(VAL_LABELS[val_idx])
val = {"val_patch": VAL_PATCH, "val_labels": VAL_LABELS}
io.savemat("./data/{}_Val_patch_{}.mat".format(opt.data, PATCH_SIZE), val)

print("+-------------------------------------+")
print("Summary")
print('Train_patch.shape: {}'.format(TRAIN_PATCH.shape))
print('Train_label.shape: {}'.format(TRAIN_LABELS.shape))
print('Test_patch.shape: {}'.format(TEST_PATCH.shape))
print('Test_label.shape: {}'.format(TEST_LABELS.shape))
print("Validation batch Shape: {}".format(VAL_PATCH.shape))
print("Validation label Shape: {}".format(VAL_LABELS.shape))
print("+-------------------------------------+")
print("\nFinished processing.......")


# Optional visual check of one sample per split and of the ground-truth map
if opt.plot:
    print('\n Looking at some sample images')
    for sample_patches, sample_labels in ((TRAIN_PATCH, TRAIN_LABELS),
                                          (TEST_PATCH, TEST_LABELS),
                                          (VAL_PATCH, VAL_LABELS)):
        plot_random_spec_img(sample_patches, sample_labels)

    GroundTruthVisualise(target_mat)

Dataset: indian_pines
Try using images from Data folder...
Data not found, downloading input images and labelled images!


+-------------------------------------+
Input_mat shape: (145, 145, 220)
Extracting Indian_pines dataset features.../Total excution time...13.840296745300293seconds
Total number of samples: 10249
Finished!	
+------------Show Table---------------+
 Class#  Number of samples
      1                 46
      2               1428
      3                830
      4                237
      5                483
      6                730
      7                 28
      8                478
      9                 20
     10                972
     11               2455
     12                593
     13                205
     14               1265
     15                386
     16                 93
+-----------Close Table---------------+
-Class 1 is rejected due to insufficient samples
Class 2 is accepted
Class 3 is accepted
Class 4 is accepted
Class 5 is accepted
Cl

In [None]:
import numpy as np
from random import shuffle
import scipy.io as io
import argparse
from helper import *
import threading
import time
import itertools
import sys

def str2bool(value):
    """Turn a command-line string such as 'true' / 'False' / '0' into a bool.

    argparse's ``type=bool`` cannot be used for this: it calls ``bool()`` on
    the raw string, so every non-empty value (including 'False') is True.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='Indian_pines', help='Default: Indian_pines, options: Salinas, KSC, Botswana')
parser.add_argument('--train_ratio', type=float, default=0.2)
parser.add_argument('--validation_ratio', type=float, default=0.05)
parser.add_argument('--channel_first', type=str2bool, default=False, help='Image channel located on the last dimension')
parser.add_argument('--dtype', type=str, default='float32', help='Data type (Eg float64, float32, float16, int64...')
parser.add_argument('--plot', type=str2bool, default=False, help='TRUE to plot satellite images and ground truth at the end')
# parse_known_args instead of parse_args: inside a Jupyter kernel sys.argv
# carries "-f <connection file>", which parse_args rejects with
# "unrecognized arguments" and SystemExit.
opt, _ = parser.parse_known_args()

# Read the image cube and ground truth from ./data, downloading them if absent
input_mat, target_mat = maybeDownloadOrExtract(opt.data)

# Output data type requested on the command line
datatype = getdtype(opt.dtype)

# Cube dimensions (rows, columns, spectral bands) and task constants
HEIGHT, WIDTH, BAND = input_mat.shape[:3]
OUTPUT_CLASSES = np.max(target_mat)   # highest class number in the ground truth
PATCH_SIZE = 5

CHANNEL_FIRST = opt.channel_first

# Cast to the output type, then rescale so the minimum maps to 0 and the maximum to 1
input_mat = input_mat.astype(datatype)
input_mat = input_mat - input_mat.min()
input_mat = input_mat / input_mat.max()

# Class numbers that are kept (described upstream as having sufficient training samples)
list_labels = getListLabel(opt.data)

# Console spinner shown while the patches are being extracted
end_loading = False
def animate():
    """Draw a text spinner on stdout until the global `end_loading` is true.

    Meant to run in a background thread.  The flag is checked once per tick
    (every 0.1 s); when it is set the loop ends and 'Finished!' is written
    exactly once.
    """
    global end_loading
    for c in itertools.cycle(['|', '/', '-', '\\']):
        if end_loading:
            break
        sys.stdout.write('\rExtracting '+ opt.data + ' dataset features...' + c)
        sys.stdout.flush()
        time.sleep(0.1)
    # Written after the loop: inside it, the message was emitted on every tick
    # and immediately overwritten by the next spinner frame.
    sys.stdout.write('\rFinished!\t')
    sys.stdout.flush()

print("+-------------------------------------+")
print('Input_mat shape: ' + str(input_mat.shape))

# Put the band axis first: (HEIGHT, WIDTH, BAND) -> (BAND, HEIGHT, WIDTH)
input_mat = np.transpose(input_mat, (2, 0, 1))

# Mean of every band over the unpadded image, stored as a (BAND, 1) column
MEAN_ARRAY = np.ndarray(shape=(BAND, 1))
for i in range(BAND):
    MEAN_ARRAY[i] = np.mean(input_mat[i, :, :])

# Zero-pad both spatial axes of every band by half a patch on each side
calib_val_pad = int((PATCH_SIZE - 1)/2)
new_input_mat = [
    np.pad(band, calib_val_pad, 'constant', constant_values=0)
    for band in input_mat
]

input_mat = np.array(new_input_mat)

def Patch(height_index, width_index):
    """Cut one mean-centred patch out of the global `input_mat`.

    Arguments:
        height_index, width_index: top-left corner of the window in the
        (padded) spatial axes of `input_mat`.

    Returns:
        Array of type `datatype` holding, for every band, the
        PATCH_SIZE x PATCH_SIZE window starting at that corner minus the
        band's entry in `MEAN_ARRAY`.  Only the data cube is returned; the
        label is looked up by the caller.
    """
    rows = slice(height_index, height_index + PATCH_SIZE)
    cols = slice(width_index, width_index + PATCH_SIZE)
    window = input_mat[:, rows, cols]

    # MEAN_ARRAY is (bands, 1); a trailing axis lets each band's mean
    # broadcast over that band's window.
    centred = window - MEAN_ARRAY[:window.shape[0], :, np.newaxis]

    return centred.astype(datatype)


# One list of patches per class number 1..OUTPUT_CLASSES (index = class number - 1)
CLASSES = [[] for _ in range(OUTPUT_CLASSES)]

# Number of samples collected for each class
class_label_counter = [0] * OUTPUT_CLASSES

# Show the spinner in a background thread while extracting, and time the extraction.
# The Thread object is kept (Thread.start() returns None) so it can be joined below.
t = threading.Thread(target=animate)
t.start()
start = time.time()

count = 0
try:
    # Visit every labelled pixel in row-major order.  This includes the last
    # row and column (range(HEIGHT-1) / range(WIDTH-1) skipped them) and never
    # builds a patch for background pixels (label 0).
    for i, j in zip(*np.nonzero(target_mat)):
        curr_tar = int(target_mat[i, j])
        CLASSES[curr_tar-1].append(Patch(i, j))
        class_label_counter[curr_tar-1] += 1
        count += 1
finally:
    # Always stop the spinner, even if extraction fails
    end_loading = True

end = time.time()
# Let the spinner finish its last write before printing the report
t.join()
print("Total excution time..." + str(end-start)+'seconds')
print('Total number of samples: ' + str(count))
showClassTable(class_label_counter)

TRAIN_PATCH, TRAIN_LABELS = [], []
TEST_PATCH, TEST_LABELS =[], []
VAL_PATCH, VAL_LABELS = [], []

train_ratio = opt.train_ratio
val_ratio = opt.validation_ratio
# test_ratio = remainder of data

counter = 0  # Contiguous label (0, 1, 2, ...) given to the next accepted class
for i, data in enumerate(CLASSES):
    # Only the classes listed in list_labels are used
    if i + 1 not in list_labels:
        print('-Class ' + str(i + 1) + ' is rejected due to insufficient samples')
        continue

    shuffle(data)
    print('Class ' + str(i + 1) + ' is accepted')

    # Per-class split sizes; whatever is left after train + validation is test
    size = round(class_label_counter[i]*train_ratio)
    size1 = round(class_label_counter[i]*val_ratio)

    train_part = data[:size]
    val_part = data[size:size+size1]
    test_part = data[size+size1:]

    # Labels are sized from the slices actually taken, so patches and labels
    # cannot get out of step when the ratios ask for more than is available.
    TRAIN_PATCH += train_part
    TRAIN_LABELS += [counter] * len(train_part)

    VAL_PATCH += val_part
    VAL_LABELS += [counter] * len(val_part)

    TEST_PATCH += test_part
    TEST_LABELS += [counter] * len(test_part)

    counter += 1

TRAIN_LABELS = np.array(TRAIN_LABELS)
TRAIN_PATCH = np.array(TRAIN_PATCH)
TEST_PATCH = np.array(TEST_PATCH)
TEST_LABELS = np.array(TEST_LABELS)
VAL_PATCH = np.array(VAL_PATCH)
VAL_LABELS = np.array(VAL_LABELS)

print("+-------------------------------------+")
print("Size of Training data: " + str(len(TRAIN_PATCH)) )
print("Size of Validation data: " + str(len(VAL_PATCH))  )
print("Size of Testing data: " + str(len(TEST_PATCH)) )
print("+-------------------------------------+")


# Each split below: shuffle samples and labels with one shared permutation,
# move the band axis last unless channel-first output was requested,
# one-hot encode the labels and save the result as a .mat file.

# --- training split ---
train_idx = list(range(len(TRAIN_PATCH)))
shuffle(train_idx)
TRAIN_PATCH = TRAIN_PATCH[train_idx]
if not CHANNEL_FIRST:
    TRAIN_PATCH = TRAIN_PATCH.transpose(0, 2, 3, 1)
TRAIN_LABELS = OnehotTransform(TRAIN_LABELS[train_idx])
train = {"train_patch": TRAIN_PATCH, "train_labels": TRAIN_LABELS}
io.savemat("./data/{}_Train_patch_{}.mat".format(opt.data, PATCH_SIZE), train)


# --- test split ---
test_idx = list(range(len(TEST_PATCH)))
shuffle(test_idx)
TEST_PATCH = TEST_PATCH[test_idx]
if not CHANNEL_FIRST:
    TEST_PATCH = TEST_PATCH.transpose(0, 2, 3, 1)
TEST_LABELS = OnehotTransform(TEST_LABELS[test_idx])
test = {"test_patch": TEST_PATCH, "test_labels": TEST_LABELS}
io.savemat("./data/{}_Test_patch_{}.mat".format(opt.data, PATCH_SIZE), test)


# --- validation split ---
val_idx = list(range(len(VAL_PATCH)))
shuffle(val_idx)
VAL_PATCH = VAL_PATCH[val_idx]
if not CHANNEL_FIRST:
    VAL_PATCH = VAL_PATCH.transpose(0, 2, 3, 1)
    print(VAL_PATCH.shape)
VAL_LABELS = OnehotTransform(VAL_LABELS[val_idx])
val = {"val_patch": VAL_PATCH, "val_labels": VAL_LABELS}
io.savemat("./data/{}_Val_patch_{}.mat".format(opt.data, PATCH_SIZE), val)

print("+-------------------------------------+")
print("Summary")
print('Train_patch.shape: {}'.format(TRAIN_PATCH.shape))
print('Train_label.shape: {}'.format(TRAIN_LABELS.shape))
print('Test_patch.shape: {}'.format(TEST_PATCH.shape))
print('Test_label.shape: {}'.format(TEST_LABELS.shape))
print("Validation batch Shape: {}".format(VAL_PATCH.shape))
print("Validation label Shape: {}".format(VAL_LABELS.shape))
print("+-------------------------------------+")
print("\nFinished processing.......")


# Optional visual check of one sample per split and of the ground-truth map
if opt.plot:
    print('\n Looking at some sample images')
    for sample_patches, sample_labels in ((TRAIN_PATCH, TRAIN_LABELS),
                                          (TEST_PATCH, TEST_LABELS),
                                          (VAL_PATCH, VAL_LABELS)):
        plot_random_spec_img(sample_patches, sample_labels)

    GroundTruthVisualise(target_mat)

usage: ipykernel_launcher.py [-h] [--data DATA] [--train_ratio TRAIN_RATIO]
                             [--validation_ratio VALIDATION_RATIO]
                             [--channel_first CHANNEL_FIRST] [--dtype DTYPE]
                             [--plot PLOT]
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-574540a8-69a1-4d7b-baa4-d88daa92bfc0.json


SystemExit: ignored

In [None]:
import numpy as np
from random import shuffle
import scipy.io as io
import argparse
from helper import *
import threading
import time
import itertools
import sys

def str2bool(value):
    """Turn a command-line string such as 'true' / 'False' / '0' into a bool.

    argparse's ``type=bool`` cannot be used for this: it calls ``bool()`` on
    the raw string, so every non-empty value (including 'False') is True.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='Indian_pines', help='Default: Indian_pines, options: Salinas, KSC, Botswana')
parser.add_argument('--train_ratio', type=float, default=0.2)
parser.add_argument('--validation_ratio', type=float, default=0.05)
parser.add_argument('--channel_first', type=str2bool, default=False, help='Image channel located on the last dimension')
parser.add_argument('--dtype', type=str, default='float32', help='Data type (Eg float64, float32, float16, int64...')
parser.add_argument('--plot', type=str2bool, default=False, help='TRUE to plot satellite images and ground truth at the end')
# parse_known_args instead of parse_args: inside a Jupyter kernel sys.argv
# carries "-f <connection file>", which parse_args rejects with
# "unrecognized arguments" and SystemExit.
opt, _ = parser.parse_known_args()

# Read the image cube and ground truth from ./data, downloading them if absent
input_mat, target_mat = maybeDownloadOrExtract(opt.data)

# Output data type requested on the command line
datatype = getdtype(opt.dtype)

# Cube dimensions (rows, columns, spectral bands) and task constants
HEIGHT, WIDTH, BAND = input_mat.shape[:3]
OUTPUT_CLASSES = np.max(target_mat)   # highest class number in the ground truth
PATCH_SIZE = 5

CHANNEL_FIRST = opt.channel_first

# Cast to the output type, then rescale so the minimum maps to 0 and the maximum to 1
input_mat = input_mat.astype(datatype)
input_mat = input_mat - input_mat.min()
input_mat = input_mat / input_mat.max()

# Class numbers that are kept (described upstream as having sufficient training samples)
list_labels = getListLabel(opt.data)

# Console spinner shown while the patches are being extracted
end_loading = False
def animate():
    """Draw a text spinner on stdout until the global `end_loading` is true.

    Meant to run in a background thread.  The flag is checked once per tick
    (every 0.1 s); when it is set the loop ends and 'Finished!' is written
    exactly once.
    """
    global end_loading
    for c in itertools.cycle(['|', '/', '-', '\\']):
        if end_loading:
            break
        sys.stdout.write('\rExtracting '+ opt.data + ' dataset features...' + c)
        sys.stdout.flush()
        time.sleep(0.1)
    # Written after the loop: inside it, the message was emitted on every tick
    # and immediately overwritten by the next spinner frame.
    sys.stdout.write('\rFinished!\t')
    sys.stdout.flush()

print("+-------------------------------------+")
print('Input_mat shape: ' + str(input_mat.shape))

# Put the band axis first: (HEIGHT, WIDTH, BAND) -> (BAND, HEIGHT, WIDTH)
input_mat = np.transpose(input_mat, (2, 0, 1))

# Mean of every band over the unpadded image, stored as a (BAND, 1) column
MEAN_ARRAY = np.ndarray(shape=(BAND, 1))
for i in range(BAND):
    MEAN_ARRAY[i] = np.mean(input_mat[i, :, :])

# Zero-pad both spatial axes of every band by half a patch on each side
calib_val_pad = int((PATCH_SIZE - 1)/2)
new_input_mat = [
    np.pad(band, calib_val_pad, 'constant', constant_values=0)
    for band in input_mat
]

input_mat = np.array(new_input_mat)

def Patch(height_index, width_index):
    """Cut one mean-centred patch out of the global `input_mat`.

    Arguments:
        height_index, width_index: top-left corner of the window in the
        (padded) spatial axes of `input_mat`.

    Returns:
        Array of type `datatype` holding, for every band, the
        PATCH_SIZE x PATCH_SIZE window starting at that corner minus the
        band's entry in `MEAN_ARRAY`.  Only the data cube is returned; the
        label is looked up by the caller.
    """
    rows = slice(height_index, height_index + PATCH_SIZE)
    cols = slice(width_index, width_index + PATCH_SIZE)
    window = input_mat[:, rows, cols]

    # MEAN_ARRAY is (bands, 1); a trailing axis lets each band's mean
    # broadcast over that band's window.
    centred = window - MEAN_ARRAY[:window.shape[0], :, np.newaxis]

    return centred.astype(datatype)


# One list of patches per class number 1..OUTPUT_CLASSES (index = class number - 1)
CLASSES = [[] for _ in range(OUTPUT_CLASSES)]

# Number of samples collected for each class
class_label_counter = [0] * OUTPUT_CLASSES

# Show the spinner in a background thread while extracting, and time the extraction.
# The Thread object is kept (Thread.start() returns None) so it can be joined below.
t = threading.Thread(target=animate)
t.start()
start = time.time()

count = 0
try:
    # Visit every labelled pixel in row-major order.  This includes the last
    # row and column (range(HEIGHT-1) / range(WIDTH-1) skipped them) and never
    # builds a patch for background pixels (label 0).
    for i, j in zip(*np.nonzero(target_mat)):
        curr_tar = int(target_mat[i, j])
        CLASSES[curr_tar-1].append(Patch(i, j))
        class_label_counter[curr_tar-1] += 1
        count += 1
finally:
    # Always stop the spinner, even if extraction fails
    end_loading = True

end = time.time()
# Let the spinner finish its last write before printing the report
t.join()
print("Total excution time..." + str(end-start)+'seconds')
print('Total number of samples: ' + str(count))
showClassTable(class_label_counter)

TRAIN_PATCH, TRAIN_LABELS = [], []
TEST_PATCH, TEST_LABELS =[], []
VAL_PATCH, VAL_LABELS = [], []

train_ratio = opt.train_ratio
val_ratio = opt.validation_ratio
# test_ratio = remainder of data

counter = 0  # Contiguous label (0, 1, 2, ...) given to the next accepted class
for i, data in enumerate(CLASSES):
    # Only the classes listed in list_labels are used
    if i + 1 not in list_labels:
        print('-Class ' + str(i + 1) + ' is rejected due to insufficient samples')
        continue

    shuffle(data)
    print('Class ' + str(i + 1) + ' is accepted')

    # Per-class split sizes; whatever is left after train + validation is test
    size = round(class_label_counter[i]*train_ratio)
    size1 = round(class_label_counter[i]*val_ratio)

    train_part = data[:size]
    val_part = data[size:size+size1]
    test_part = data[size+size1:]

    # Labels are sized from the slices actually taken, so patches and labels
    # cannot get out of step when the ratios ask for more than is available.
    TRAIN_PATCH += train_part
    TRAIN_LABELS += [counter] * len(train_part)

    VAL_PATCH += val_part
    VAL_LABELS += [counter] * len(val_part)

    TEST_PATCH += test_part
    TEST_LABELS += [counter] * len(test_part)

    counter += 1

TRAIN_LABELS = np.array(TRAIN_LABELS)
TRAIN_PATCH = np.array(TRAIN_PATCH)
TEST_PATCH = np.array(TEST_PATCH)
TEST_LABELS = np.array(TEST_LABELS)
VAL_PATCH = np.array(VAL_PATCH)
VAL_LABELS = np.array(VAL_LABELS)

print("+-------------------------------------+")
print("Size of Training data: " + str(len(TRAIN_PATCH)) )
print("Size of Validation data: " + str(len(VAL_PATCH))  )
print("Size of Testing data: " + str(len(TEST_PATCH)) )
print("+-------------------------------------+")


# Each split below: shuffle samples and labels with one shared permutation,
# move the band axis last unless channel-first output was requested,
# one-hot encode the labels and save the result as a .mat file.

# --- training split ---
train_idx = list(range(len(TRAIN_PATCH)))
shuffle(train_idx)
TRAIN_PATCH = TRAIN_PATCH[train_idx]
if not CHANNEL_FIRST:
    TRAIN_PATCH = TRAIN_PATCH.transpose(0, 2, 3, 1)
TRAIN_LABELS = OnehotTransform(TRAIN_LABELS[train_idx])
train = {"train_patch": TRAIN_PATCH, "train_labels": TRAIN_LABELS}
io.savemat("./data/{}_Train_patch_{}.mat".format(opt.data, PATCH_SIZE), train)


# --- test split ---
test_idx = list(range(len(TEST_PATCH)))
shuffle(test_idx)
TEST_PATCH = TEST_PATCH[test_idx]
if not CHANNEL_FIRST:
    TEST_PATCH = TEST_PATCH.transpose(0, 2, 3, 1)
TEST_LABELS = OnehotTransform(TEST_LABELS[test_idx])
test = {"test_patch": TEST_PATCH, "test_labels": TEST_LABELS}
io.savemat("./data/{}_Test_patch_{}.mat".format(opt.data, PATCH_SIZE), test)


# --- validation split ---
val_idx = list(range(len(VAL_PATCH)))
shuffle(val_idx)
VAL_PATCH = VAL_PATCH[val_idx]
if not CHANNEL_FIRST:
    VAL_PATCH = VAL_PATCH.transpose(0, 2, 3, 1)
    print(VAL_PATCH.shape)
VAL_LABELS = OnehotTransform(VAL_LABELS[val_idx])
val = {"val_patch": VAL_PATCH, "val_labels": VAL_LABELS}
io.savemat("./data/{}_Val_patch_{}.mat".format(opt.data, PATCH_SIZE), val)

print("+-------------------------------------+")
print("Summary")
print('Train_patch.shape: {}'.format(TRAIN_PATCH.shape))
print('Train_label.shape: {}'.format(TRAIN_LABELS.shape))
print('Test_patch.shape: {}'.format(TEST_PATCH.shape))
print('Test_label.shape: {}'.format(TEST_LABELS.shape))
print("Validation batch Shape: {}".format(VAL_PATCH.shape))
print("Validation label Shape: {}".format(VAL_LABELS.shape))
print("+-------------------------------------+")
print("\nFinished processing.......")


# Optional visual check of one sample per split and of the ground-truth map
if opt.plot:
    print('\n Looking at some sample images')
    for sample_patches, sample_labels in ((TRAIN_PATCH, TRAIN_LABELS),
                                          (TEST_PATCH, TEST_LABELS),
                                          (VAL_PATCH, VAL_LABELS)):
        plot_random_spec_img(sample_patches, sample_labels)

    GroundTruthVisualise(target_mat)

usage: ipykernel_launcher.py [-h] [--data DATA] [--train_ratio TRAIN_RATIO]
                             [--validation_ratio VALIDATION_RATIO]
                             [--channel_first CHANNEL_FIRST] [--dtype DTYPE]
                             [--plot PLOT]
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-574540a8-69a1-4d7b-baa4-d88daa92bfc0.json


SystemExit: ignored