In [None]:
# ---------------------------------------------------------------------------
# DATA PATHS
# Every input location hangs off TOP_DIR and ends with a trailing slash so a
# filename can be appended directly.
# ---------------------------------------------------------------------------
TOP_DIR = '/tf/Notebooks/Iwashita'

IR_PATH = f'{TOP_DIR}/Data/IR/'
RGB_PATH = f'{TOP_DIR}/Data/RGB/'
MASKS_PATH = f'{TOP_DIR}/Data/Masks/'
ANNOTATIONS_PATH = f'{TOP_DIR}/Data/Annotations/'

# ---------------------------------------------------------------------------
# OUTPUTS PATH
# ---------------------------------------------------------------------------
WEIGHTS_PATH = f'{TOP_DIR}/output/Weights/'
METRICS_PATH = f'{TOP_DIR}/output/Metrics/'

# IPython shell escape: list the contents of TOP_DIR (the $TOP_DIR Python variable is interpolated).
!cd $TOP_DIR && ls

In [None]:
from enum import Enum
import numpy as np
from numpy import asarray, save
import os
from PIL import Image
import re
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt 
from matplotlib.pyplot import axis, figure, imshow, show, subplot

In [None]:
# ===========================================================================
# SUPPORTING FUNCTIONS
# ===========================================================================

# CLASSES
# Terrain classes. Each member's value is the integer label id that the
# annotation decoding compares against, and also the index of that class's
# channel in the one-hot encoding.
classes = Enum('Classes', [
    'UNLABELED',
    'SAND',
    'SOIL',
    'BALLAST',
    'ROCK',
    'BEDROCK',
    'ROCKY_TERRAIN'
    ], start=0)

# Number of one-hot channels: the largest label id plus one.
num_classes = 1 + max(member.value for member in classes)

'''
POPULATE DATA SETS
'''
def populate_data_sets(X_list, y_list, rgb_dict, ir_dict, annotations_dict):
    '''Gather per-file arrays into stacked RGB, IR and label arrays.

    Args:
        X_list: filenames (dictionary keys) to gather, in output order.
        y_list: accepted for symmetry with the split outputs but not read;
            labels are looked up with the keys in X_list.
        rgb_dict: filename -> RGB array.
        ir_dict: filename -> IR array.
        annotations_dict: filename -> annotation array.

    Returns:
        Tuple (rgb, ir, labels), each an np.array built from the entries for
        X_list in that order.

    Raises:
        KeyError: if a filename in X_list is missing from any dictionary.
    '''
    rgb_stack = np.array([rgb_dict[name] for name in X_list])
    ir_stack = np.array([ir_dict[name] for name in X_list])
    label_stack = np.array([annotations_dict[name] for name in X_list])
    return rgb_stack, ir_stack, label_stack

'''
CALCULATE CLASS FREQUENCY
'''
def calculate_class_frequency_single(annotation_array):
    '''Compute and print per-class pixel frequencies of one one-hot annotation.

    Args:
        annotation_array: one-hot annotation indexed (row, col, channel); a
            pixel counts toward class i when channel i equals 1.

    Returns:
        1-D float array with one frequency per channel. The values sum to 1
        when at least one pixel is set; if no channel has any set pixel, all
        frequencies are 0 instead of NaN from a 0/0 division.

    Side effect:
        Prints "<CLASS NAME>: <percentage>" for every member of ``classes``.
    '''
    n_channels = annotation_array.shape[2]

    # Count one channel at a time so the boolean temporary stays one 2-D slice.
    pixel_count = np.array(
        [np.count_nonzero(annotation_array[:, :, i] == 1) for i in range(n_channels)],
        dtype=np.float64)

    total = pixel_count.sum()
    if total > 0:
        class_frequency = pixel_count / total
    else:
        # Nothing labelled at all: report zeros rather than NaN.
        class_frequency = np.zeros(n_channels)

    # Print out the frequency for each class
    for cls in classes:
        print(f"{cls.name}: {class_frequency[cls.value]*100:.4f}")

    return class_frequency

def calculate_class_frequency_set(annotation_array):
    '''Compute and print per-class pixel frequencies over a stack of one-hot annotations.

    Args:
        annotation_array: stacked one-hot annotations indexed
            (image, row, col, channel); a pixel counts toward class i when
            channel i equals 1.

    Returns:
        1-D float array with one frequency per channel, pooled over all images.
        The values sum to 1 when at least one pixel is set; for an empty stack
        or one with no set pixel, all frequencies are 0 instead of NaN.

    Side effect:
        Prints "<CLASS NAME>: <percentage>" for every member of ``classes``.
    '''
    n_channels = annotation_array.shape[3]

    # Count one channel at a time: comparing the whole stack at once would
    # allocate a boolean copy as large as the (potentially multi-GB) input.
    pixel_count = np.array(
        [np.count_nonzero(annotation_array[:, :, :, i] == 1) for i in range(n_channels)],
        dtype=np.float64)

    total = pixel_count.sum()
    if total > 0:
        class_frequency = pixel_count / total
    else:
        # Empty stack or nothing labelled: report zeros rather than NaN.
        class_frequency = np.zeros(n_channels)

    # Print out the frequency for each class
    for cls in classes:
        print(f"{cls.name}: {class_frequency[cls.value]*100:.4f}")

    return class_frequency

'''
FIND MAX CLASS
'''
def find_max_class(annotation_array):
    '''Return the class that covers the most pixels in one one-hot annotation.

    Args:
        annotation_array: one-hot annotation indexed (row, col, channel); a
            pixel counts toward class i when channel i equals 1.

    Returns:
        Tuple (member of ``classes``, fraction of set pixels belonging to it).
        Ties go to the lowest class id. If no channel has any set pixel, the
        result is (classes(0), 0.0) instead of a NaN frequency.
    '''
    n_channels = annotation_array.shape[2]

    pixel_count = np.array(
        [np.count_nonzero(annotation_array[:, :, i] == 1) for i in range(n_channels)],
        dtype=np.float64)

    # argmax over raw counts equals argmax over frequencies (same positive
    # scale factor) and does not depend on a possibly-zero total.
    max_index = int(np.argmax(pixel_count))

    total = pixel_count.sum()
    max_freq = pixel_count[max_index] / total if total > 0 else 0.0

    return classes(max_index), max_freq

'''
ONE-HOT ANNOTATION CHECK
'''
def display_onehot_annotation(annotations_onehot):
    '''Plot a one-hot annotation as a class-index image with a labelled colorbar.

    Args:
        annotations_onehot: one-hot annotation whose last axis is the class
            channel; the plotted label at each pixel is the argmax over it.

    The colour range is pinned to the full class range (0 .. num_classes - 1),
    so a class is drawn in the same colour regardless of which classes happen
    to be present in this particular image.
    '''
    label = np.argmax(annotations_onehot, axis=-1)

    # "ROCKY_TERRAIN" -> "rocky terrain", etc.
    class_names = {cls.value: cls.name.lower().replace('_', ' ') for cls in classes}
    cmap = plt.get_cmap('tab10', num_classes)

    # Half-integer limits centre each class id in its own colour bin. Without
    # them matplotlib rescales to the labels actually present and the colours
    # stop matching the colorbar ticks.
    plt.imshow(label, cmap=cmap, vmin=-0.5, vmax=num_classes - 0.5)
    plt.colorbar(
        ticks=range(num_classes),
        format=plt.FuncFormatter(lambda val, loc: class_names.get(int(round(val)), '')))
    plt.show()

'''
AUGMENTED FILENAME
'''
def generate_augmented_filename(filename, tag):
    '''Insert ``tag`` between the stem and the extension of ``filename``.

    Only the last extension counts (os.path.splitext semantics), e.g.
    ('x.png', '_rot90') -> 'x_rot90.png'.
    '''
    stem, extension = os.path.splitext(filename)
    return ''.join((stem, tag, extension))

'''
AUGMENT IMAGES
'''
def augment_image(filename, img):
    '''Create the five rotated/flipped variants of one image.

    Args:
        filename: key of the original image; each variant's key is
            generate_augmented_filename(filename, tag).
        img: array to augment. Rotations act on the first two axes, so a
            trailing channel axis is carried along unchanged.

    Returns:
        dict {augmented filename: array} in the order _rot90, _rot180, _rot270,
        _flipH, _flipV. The arrays are numpy views of ``img`` (no copy). For
        non-square images _rot90 and _rot270 swap the first two axis lengths.
    '''
    transforms = (
        ("_rot90", lambda a: np.rot90(a, k=1)),
        ("_rot180", lambda a: np.rot90(a, k=2)),
        ("_rot270", lambda a: np.rot90(a, k=3)),
        ("_flipH", np.fliplr),  # mirror left-right
        ("_flipV", np.flipud),  # mirror top-bottom
    )

    # Insertion order matters: print_augmented_images plots values in this order.
    return {generate_augmented_filename(filename, tag): transform(img)
            for tag, transform in transforms}

def print_augmented_images(original_img, augmented_imgs):
    '''Show an image next to its five augmentations in a 2x3 grid.

    Args:
        original_img: the un-augmented array.
        augmented_imgs: dict as returned by augment_image. The first five values
            are plotted in insertion order (rot90, rot180, rot270, flipH, flipV).

    The layout is chosen from the first augmented array's shape:
      * 2-D           -> IR image, drawn in grayscale
      * 3-D, 3 chans  -> RGB image
      * 3-D, 7 chans  -> one-hot annotation, drawn as argmax class index with a
                         labelled colorbar on the original panel
    Any other shape prints a message and leaves the panels empty.
    '''
    imgs = list(augmented_imgs.values())
    shape = imgs[0].shape

    panels = [original_img] + imgs[:5]
    titles = ["Original", "Rotated 90", "Rotated 180",
              "Rotated 270", "Flipped Horizontal", "Flipped Vertical"]

    fig, axes = plt.subplots(2, 3, figsize=(10, 6))
    flat_axes = axes.flatten()

    if len(shape) == 2:
        # IR images
        for ax, panel in zip(flat_axes, panels):
            ax.imshow(panel, cmap='gray')

    elif len(shape) == 3 and shape[2] == 3:
        # RGB images
        for ax, panel in zip(flat_axes, panels):
            ax.imshow(panel)

    elif len(shape) == 3 and shape[2] == 7:
        # Annotation images
        n_cls = shape[2]
        cmap = plt.get_cmap('tab10', n_cls)
        label_names = {
            0: "unlabeled",
            1: "sand",
            2: "soil",
            3: "ballast",
            4: "rock",
            5: "bedrock",
            6: "rocky terrain"
        }

        # Fixed vmin/vmax keep each class the same colour in every panel, even
        # when a panel contains only some of the classes.
        drawn = [ax.imshow(np.argmax(panel, axis=-1), cmap=cmap, vmin=-0.5, vmax=n_cls - 0.5)
                 for ax, panel in zip(flat_axes, panels)]

        # Showing colorbar for original image only
        fig.colorbar(drawn[0], ax=flat_axes[0], ticks=range(n_cls),
                     format=plt.FuncFormatter(lambda val, loc: label_names.get(int(round(val)), '')))

    else:
        print("Unexpected image dimensions: " + repr(len(shape)))

    for ax, title in zip(flat_axes, titles):
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
'''
IMAGE PROPERTIES - Original image dimensions are 800x600
'''
IMG_HEIGHT = 572
IMG_WIDTH = 572

RGB_CHANNELS = 3
IR_CHANNELS = 1

'''
IMAGE LISTS
'''
# Sorted so that file order, and therefore the seeded train/val/test splits
# further down, does not depend on os.listdir's arbitrary directory order.
img_list = sorted(file for file in os.listdir(RGB_PATH) if file.lower().endswith('0000.png'))

# filename -> preprocessed array; filled by the loading loops below.
rgb_imgs = {}
ir_imgs = {}

# ---------------------------------------------------------------------------
# LOAD MASKS
# Both masks are scaled from 0..255 to 0..1; each image is multiplied by its
# mask before resizing.
# ---------------------------------------------------------------------------
with Image.open(os.path.join(MASKS_PATH, 'rgb_mask.ppm')) as mask_file:
    rgb_mask = np.array(mask_file) / 255.0

with Image.open(os.path.join(MASKS_PATH, 'ir_mask.png')) as mask_file:
    # Only the first channel of the IR mask image is used.
    ir_mask = np.array(mask_file)[:, :, 0] / 255.0

'''
LOAD AND NORMALIZE
'''
print("Processing RGB images...")

for filename in tqdm(img_list):

    # Open image from file and normalize it into a 600x800x3 float array in [0, 1]
    with Image.open(os.path.join(RGB_PATH, filename)) as rgb_img:
        rgb_array = np.array(rgb_img, dtype=np.float32) / 255.0

    # Apply mask
    rgb_array = rgb_array * rgb_mask

    # Resize. skimage's output_shape is (rows, cols) = (height, width); the
    # previous (width, height) order only worked because the target is square.
    rgb_array = resize(rgb_array, (IMG_HEIGHT, IMG_WIDTH), mode='reflect', anti_aliasing=True)

    # Save to dictionary
    rgb_imgs[filename] = rgb_array
    
print("Processing IR images...")

for filename in tqdm(img_list):

    # Open image from file and normalize it into a 600x800 float array in [0, 1]
    with Image.open(os.path.join(IR_PATH, filename)) as ir_img:
        ir_array = np.array(ir_img, dtype=np.float32) / 255.0

    # Apply mask
    ir_array = ir_array * ir_mask

    # Resize. skimage's output_shape is (rows, cols) = (height, width).
    ir_array = resize(ir_array, (IMG_HEIGHT, IMG_WIDTH), mode='reflect', anti_aliasing=True)

    # Save to dictionary
    ir_imgs[filename] = ir_array

In [None]:
# Spot-check one preprocessed RGB image (the last file in img_list) and its
# value statistics. Index explicitly instead of relying on the loop variable
# `filename` left over from the loading cell.
img = rgb_imgs[img_list[-1]]
plt.imshow(img)

print(np.min(img))
print(np.max(img))
print(np.mean(img))
print(np.std(img))

In [None]:
'''
LOAD/DECODE ANNOTATION FILES
'''
annotations = {}

print("Loading annotation files...")

for filename in tqdm(img_list):

    # Label images must be resized with nearest-neighbour sampling. PIL's
    # default filter for RGB images is bicubic, which blends neighbouring label
    # codes into values belonging to the wrong class (or to no class at all).
    with Image.open(os.path.join(ANNOTATIONS_PATH, filename)) as ann_img:
        img = ann_img.resize((IMG_WIDTH, IMG_HEIGHT), resample=Image.NEAREST)

    encoded = np.array(img)

    # Decode the label id packed into the R, G, B bytes: R | G << 8 | B << 16.
    label = (encoded[:, :, 0].astype(np.uint32)
             | (encoded[:, :, 1].astype(np.uint32) << 8)
             | (encoded[:, :, 2].astype(np.uint32) << 16))

    annotations[filename] = label

# ---------------------------------------------------------------------------
# ONE-HOT ENCODE
# ---------------------------------------------------------------------------
annotations_onehot = {}
# Running per-class pixel totals across all annotation images.
class_freq = dict.fromkeys(range(num_classes), 0)

print("One-hot encoding annotation files...")

for filename in tqdm(img_list, total=len(img_list)):

    label_map = annotations[filename]
    onehot_annotation = np.zeros((IMG_HEIGHT, IMG_WIDTH, num_classes), dtype=np.uint8)

    # Channel c is 1 exactly where the decoded label equals c; label values
    # outside 0..num_classes-1 leave every channel at 0.
    for class_id in range(num_classes):
        in_class = (label_map == class_id)
        onehot_annotation[..., class_id] = in_class
        class_freq[class_id] += np.sum(in_class)

    annotations_onehot[filename] = onehot_annotation

In [None]:
# CLASS FREQUENCY PRE-AUGMENTATION
# Stack every one-hot annotation into one array and report the pooled class shares.
annotation_imgs = np.array([*annotations_onehot.values()])
pre_aug_freq = calculate_class_frequency_set(annotation_imgs)

In [None]:
# FIND CANDIDATE IMAGES TO AUGMENT
# Only images whose most frequent class is one of these are augmented; images
# dominated by SOIL (the dominant class) or UNLABELED are skipped.
AUGMENT_TARGET_CLASSES = {
    classes.SAND,
    classes.BALLAST,
    classes.ROCK,
    classes.BEDROCK,
    classes.ROCKY_TERRAIN,
}

imgs_to_augment = [
    key
    for key in annotations_onehot
    if find_max_class(annotations_onehot[key])[0] in AUGMENT_TARGET_CLASSES
]

# There should be 103 candidate images
print("OK" if len(imgs_to_augment) == 103 else "Something went wrong")

In [None]:
# GENERATE AUGMENTED IMAGES
# For each candidate, add its five rotated/flipped variants to the RGB, IR and
# one-hot annotation dictionaries (the same augmented keys in all three).
for img in imgs_to_augment:
    for image_store in (rgb_imgs, ir_imgs, annotations_onehot):
        image_store.update(augment_image(img, image_store[img]))

stores_ok = all(len(store) == 989 for store in (rgb_imgs, ir_imgs, annotations_onehot))
print("OK" if stores_ok else "Something went wrong")

In [None]:
# Display the last original (non-augmented) RGB image after augmentation.
last_original = img_list[-1]
img = rgb_imgs[last_original]
plt.imshow(img)

In [None]:
# CLASS FREQUENCY POST-AUGMENTATION
# Re-stack all one-hot annotations (originals plus augmented variants) and
# report the pooled class shares.
annotation_imgs = np.array([*annotations_onehot.values()])
post_aug_freq = calculate_class_frequency_set(annotation_imgs)

In [None]:
# GENERATE NEW IMAGE LIST
# Keys of the augmented annotation dictionary: original and augmented filenames.
aug_img_list = [*annotations_onehot]

In [None]:
'''
FILTER EXPERIMENT 1 DATA
'''
# Experiment 1: files whose digits after "2017-11-17-" start with 1640..1659
# (16:40-16:59 if the stamp is HHMMSS - assumed filename format), plus their
# rotated/flipped variants. The dot before "png" is escaped so it only
# matches a literal ".".
exp1_pattern = r'^\d{2}__2017-11-17-16[45]\d[0-9]{2}-0000(?:_(?:rot90|rot180|rot270|flipH|flipV))?\.png$'
exp1_img_list = [file for file in aug_img_list if re.match(exp1_pattern, file)]

if len(exp1_img_list) == 112:
    print("OK")
else:
    print("Something went wrong")

In [None]:
'''
SPLIT DATA
'''
# Augmented files are rotated/flipped copies of a base image ("x_rot90.png" is
# made from "x.png"). Splitting individual files would put copies of the same
# scene in train AND in validation/test, leaking training data into evaluation.
# Instead, group every file with its base image, split the base images (50%
# train, the rest halved into validation and test), then expand each base image
# back to all of its variants.
exp1_groups = {}
for fn in exp1_img_list:
    base_fn = re.sub(r'_(?:rot90|rot180|rot270|flipH|flipV)(?=\.[^.]*$)', '', fn)
    exp1_groups.setdefault(base_fn, []).append(fn)

base_train, base_temp = train_test_split(
    sorted(exp1_groups),
    test_size=0.50,
    train_size=0.50,
    random_state=42,
    shuffle=True)

base_val, base_test = train_test_split(
    base_temp,
    test_size=0.50,
    train_size=0.50,
    random_state=42,
    shuffle=False)

X_train = [fn for base_fn in base_train for fn in exp1_groups[base_fn]]
# Mix the variants of different scenes so the training arrays are not ordered by scene.
X_train = [X_train[i] for i in np.random.default_rng(42).permutation(len(X_train))]
X_val = [fn for base_fn in base_val for fn in exp1_groups[base_fn]]
X_test = [fn for base_fn in base_test for fn in exp1_groups[base_fn]]

# Labels are looked up with the same filenames, so the label key lists mirror the inputs.
y_train, y_val, y_test = list(X_train), list(X_val), list(X_test)

print("Populating experiment 1 training data sets...")
exp1_rgb_X_train, exp1_ir_X_train, exp1_y_train = populate_data_sets(
    X_train, y_train, rgb_imgs, ir_imgs, annotations_onehot)

print("Populating experiment 1 validation data sets...")
exp1_rgb_X_val, exp1_ir_X_val, exp1_y_val = populate_data_sets(
    X_val, y_val, rgb_imgs, ir_imgs, annotations_onehot)

print("Populating experiment 1 test data sets...")
exp1_rgb_X_test, exp1_ir_X_test, exp1_y_test = populate_data_sets(
    X_test, y_test, rgb_imgs, ir_imgs, annotations_onehot)

In [None]:
# Class distribution of each experiment 1 split, printed in train, test,
# validation order with a blank line after each.
cf_train = calculate_class_frequency_set(exp1_y_train)
print()

cf_test = calculate_class_frequency_set(exp1_y_test)
print()

cf_val = calculate_class_frequency_set(exp1_y_val)
print()

In [None]:
# Spot-check the first experiment 1 training RGB image and its array shape.
img = exp1_rgb_X_train[0]
plt.imshow(img)
print(np.shape(img))

In [None]:
# Derived from TOP_DIR (same value as before) so relocating the project only
# requires changing the path cell.
EXP1_DIR = TOP_DIR + '/Data/Preprocessed_wAugmentation/Experiment1'

# split sub-directory -> {output file name: array}, saved in this order.
exp1_outputs = {
    'Train': {
        'exp1_rgb_X_train.npy': exp1_rgb_X_train,
        'exp1_ir_X_train.npy': exp1_ir_X_train,
        'exp1_y_train.npy': exp1_y_train,
    },
    'Validate': {
        'exp1_rgb_X_val.npy': exp1_rgb_X_val,
        'exp1_ir_X_val.npy': exp1_ir_X_val,
        'exp1_y_val.npy': exp1_y_val,
    },
    'Test': {
        'exp1_rgb_X_test.npy': exp1_rgb_X_test,
        'exp1_ir_X_test.npy': exp1_ir_X_test,
        'exp1_y_test.npy': exp1_y_test,
    },
}

for split_dir, arrays in exp1_outputs.items():
    out_dir = os.path.join(EXP1_DIR, split_dir)
    # np.save does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)
    for out_name, array in arrays.items():
        save(os.path.join(out_dir, out_name), array)

print("Done")

In [None]:
# FILTER EXPERIMENT 2 DATA
# Experiment 2 uses every image, original and augmented (same list object).
exp2_img_list = aug_img_list

print("OK" if len(exp2_img_list) == 989 else "Something went wrong")

In [None]:
'''
SPLIT DATA
'''
# Augmented files are rotated/flipped copies of a base image ("x_rot90.png" is
# made from "x.png"). Splitting individual files would put copies of the same
# scene in train AND in validation/test, leaking training data into evaluation.
# Instead, group every file with its base image, split the base images (50%
# train, the rest halved into validation and test), then expand each base image
# back to all of its variants.
exp2_groups = {}
for fn in exp2_img_list:
    base_fn = re.sub(r'_(?:rot90|rot180|rot270|flipH|flipV)(?=\.[^.]*$)', '', fn)
    exp2_groups.setdefault(base_fn, []).append(fn)

base_train, base_temp = train_test_split(
    sorted(exp2_groups),
    test_size=0.50,
    train_size=0.50,
    random_state=42,
    shuffle=True)

base_val, base_test = train_test_split(
    base_temp,
    test_size=0.50,
    train_size=0.50,
    random_state=42,
    shuffle=False)

X_train = [fn for base_fn in base_train for fn in exp2_groups[base_fn]]
# Mix the variants of different scenes so the training arrays are not ordered by scene.
X_train = [X_train[i] for i in np.random.default_rng(42).permutation(len(X_train))]
X_val = [fn for base_fn in base_val for fn in exp2_groups[base_fn]]
X_test = [fn for base_fn in base_test for fn in exp2_groups[base_fn]]

# Labels are looked up with the same filenames, so the label key lists mirror the inputs.
y_train, y_val, y_test = list(X_train), list(X_val), list(X_test)

print("Populating experiment 2 training data sets...")
exp2_rgb_X_train, exp2_ir_X_train, exp2_y_train = populate_data_sets(
    X_train, y_train, rgb_imgs, ir_imgs, annotations_onehot)

print("Populating experiment 2 validation data sets...")
exp2_rgb_X_val, exp2_ir_X_val, exp2_y_val = populate_data_sets(
    X_val, y_val, rgb_imgs, ir_imgs, annotations_onehot)

print("Populating experiment 2 test data sets...")
exp2_rgb_X_test, exp2_ir_X_test, exp2_y_test = populate_data_sets(
    X_test, y_test, rgb_imgs, ir_imgs, annotations_onehot)

In [None]:
# Class distribution of each experiment 2 split, printed in train, test,
# validation order with a blank line after each.
cf_train = calculate_class_frequency_set(exp2_y_train)
print()

cf_test = calculate_class_frequency_set(exp2_y_test)
print()

cf_val = calculate_class_frequency_set(exp2_y_val)
print()

In [None]:
# Derived from TOP_DIR (same value as before) so relocating the project only
# requires changing the path cell.
EXP2_DIR = TOP_DIR + '/Data/Preprocessed_wAugmentation/Experiment2'

# split sub-directory -> {output file name: array}, saved in this order.
exp2_outputs = {
    'Train': {
        'exp2_rgb_X_train.npy': exp2_rgb_X_train,
        'exp2_ir_X_train.npy': exp2_ir_X_train,
        'exp2_y_train.npy': exp2_y_train,
    },
    'Validate': {
        'exp2_rgb_X_val.npy': exp2_rgb_X_val,
        'exp2_ir_X_val.npy': exp2_ir_X_val,
        'exp2_y_val.npy': exp2_y_val,
    },
    'Test': {
        'exp2_rgb_X_test.npy': exp2_rgb_X_test,
        'exp2_ir_X_test.npy': exp2_ir_X_test,
        'exp2_y_test.npy': exp2_y_test,
    },
}

for split_dir, arrays in exp2_outputs.items():
    out_dir = os.path.join(EXP2_DIR, split_dir)
    # np.save does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)
    for out_name, array in arrays.items():
        save(os.path.join(out_dir, out_name), array)

print("Done")

In [None]:
'''
FILTER EXPERIMENT 3 DATA
'''
# The patterns below select files by the digits after "2017-11-17-", read here
# as an HHMMSS timestamp (assumed filename format - TODO confirm). Each pattern
# also accepts the augmented _rot90/_rot180/_rot270/_flipH/_flipV variants.
# NOTE(review): the "." before "png" is unescaped, so it matches any character.

# Train/validation pool: time digits 1410xx-1659xx (14:10-16:59).
exp3_pattern = r'^\d{2}__2017-11-17-(?:14(?:1[0-9]|[2-9][0-9])|15\d{2}|16(?:[0-4][0-9]|5[0-9]))[0-5][0-9]-0000(?:_(?:rot90|rot180|rot270|flipH|flipV))?.png$'
exp3_img_list = [file for file in aug_img_list if re.match(exp3_pattern, file)]

# Test set A: time digits 1000xx-1199xx (10:00-11:59).
# NOTE(review): the alternatives beginning "112" need seven or eight time
# digits before "-0000", so they can never match a six-digit stamp and no
# 12:xx file is selected - if 12:xx times were meant to be included, this
# pattern needs fixing (the 338 count check reflects the current behaviour).
exp3_test_a_pattern = r'^\d{2}__2017-11-17-(?:1(?:[0]\d{2}|1(?:[0-9][0-9]|2(?:[0-9][0-9]|3[0-5][0-9]))))[0-5][0-9]-0000(?:_(?:rot90|rot180|rot270|flipH|flipV))?.png$'
exp3_test_a_list = [file for file in aug_img_list if re.match(exp3_test_a_pattern, file)]

# Test set B: time digits 1400xx-1739xx (14:00-17:39).
# NOTE(review): every filename matched by exp3_pattern is also matched here,
# so test set B contains the whole train/validation pool. With both counts
# expected to be 431 the two lists hold the same files, i.e. test B is not
# held out - confirm this overlap is intended.
exp3_test_b_pattern = r'^\d{2}__2017-11-17-(?:1(?:[4-6]\d{2}|7(?:[0-2][0-9]|3[0-9])))[0-5][0-9]-0000(?:_(?:rot90|rot180|rot270|flipH|flipV))?.png$'
exp3_test_b_list = [file for file in aug_img_list if re.match(exp3_test_b_pattern, file)]

if(len(exp3_img_list) == 431 and len(exp3_test_a_list) == 338 and len(exp3_test_b_list) == 431):
    print("OK")
else:
    print("Something went wrong")

In [None]:
'''
SPLIT DATA
'''
# Split base images rather than individual files so rotated/flipped copies of a
# training image ("x_rot90.png" is made from "x.png") cannot land in the
# validation set.
# NOTE(review): train_size=0.30 / test_size=0.70 keeps only 30% of the base
# images for training and 70% for validation - confirm this ratio is intended.
# NOTE(review): every filename matched by exp3_pattern also matches
# exp3_test_b_pattern, so test set B contains all train/validation images and
# is not a held-out set - confirm this is intended.
exp3_groups = {}
for fn in exp3_img_list:
    base_fn = re.sub(r'_(?:rot90|rot180|rot270|flipH|flipV)(?=\.[^.]*$)', '', fn)
    exp3_groups.setdefault(base_fn, []).append(fn)

base_train, base_val = train_test_split(
    sorted(exp3_groups),
    test_size=0.70,
    train_size=0.30,
    random_state=42,
    shuffle=True)

X_train = [fn for base_fn in base_train for fn in exp3_groups[base_fn]]
# Mix the variants of different scenes so the training arrays are not ordered by scene.
X_train = [X_train[i] for i in np.random.default_rng(42).permutation(len(X_train))]
X_val = [fn for base_fn in base_val for fn in exp3_groups[base_fn]]

# Labels are looked up with the same filenames, so the label key lists mirror the inputs.
y_train, y_val = list(X_train), list(X_val)

print("Populating experiment 3 training data sets...")
exp3_rgb_X_train, exp3_ir_X_train, exp3_y_train = populate_data_sets(X_train, y_train, rgb_imgs, ir_imgs, annotations_onehot)

print("Populating experiment 3 validation data sets...")
exp3_rgb_X_val, exp3_ir_X_val, exp3_y_val = populate_data_sets(X_val, y_val, rgb_imgs, ir_imgs, annotations_onehot)

print("Populating experiment 3 test data sets...")
exp3_rgb_X_test_a, exp3_ir_X_test_a, exp3_y_test_a = populate_data_sets(exp3_test_a_list, exp3_test_a_list, rgb_imgs, ir_imgs, annotations_onehot)
exp3_rgb_X_test_b, exp3_ir_X_test_b, exp3_y_test_b = populate_data_sets(exp3_test_b_list, exp3_test_b_list, rgb_imgs, ir_imgs, annotations_onehot)


In [None]:
# Class distribution of each experiment 3 split, printed in train, test A,
# test B, validation order with a blank line after each. cf_test ends up
# holding the test B frequencies.
cf_train = calculate_class_frequency_set(exp3_y_train)
print()

cf_test = calculate_class_frequency_set(exp3_y_test_a)
print()

cf_test = calculate_class_frequency_set(exp3_y_test_b)
print()

cf_val = calculate_class_frequency_set(exp3_y_val)
print()

In [None]:
# Derived from TOP_DIR (same value as before) so relocating the project only
# requires changing the path cell.
EXP3_DIR = TOP_DIR + '/Data/Preprocessed_wAugmentation/Experiment3'

# split sub-directory -> {output file name: array}, saved in this order.
exp3_outputs = {
    'Train': {
        'exp3_rgb_X_train.npy': exp3_rgb_X_train,
        'exp3_ir_X_train.npy': exp3_ir_X_train,
        'exp3_y_train.npy': exp3_y_train,
    },
    'Validate': {
        'exp3_rgb_X_val.npy': exp3_rgb_X_val,
        'exp3_ir_X_val.npy': exp3_ir_X_val,
        'exp3_y_val.npy': exp3_y_val,
    },
    'Test': {
        'exp3_rgb_X_test_a.npy': exp3_rgb_X_test_a,
        'exp3_ir_X_test_a.npy': exp3_ir_X_test_a,
        'exp3_y_test_a.npy': exp3_y_test_a,
        'exp3_rgb_X_test_b.npy': exp3_rgb_X_test_b,
        'exp3_ir_X_test_b.npy': exp3_ir_X_test_b,
        'exp3_y_test_b.npy': exp3_y_test_b,
    },
}

for split_dir, arrays in exp3_outputs.items():
    out_dir = os.path.join(EXP3_DIR, split_dir)
    # np.save does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)
    for out_name, array in arrays.items():
        save(os.path.join(out_dir, out_name), array)

print("Done")