Author: Daniel Lusk, University of Potsdam

Inspired by: Ankit Kariryaa ([github repo](https://github.com/ankitkariryaa/An-unexpectedly-large-count-of-trees-in-the-western-Sahara-and-Sahel))

In [167]:
import glob
import os
import warnings
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
from config import UNetTraining
from core.dataset_generator import DataGenerator
from core.frame_info import FrameInfo
from core.split_frames import split_dataset
from tqdm import tqdm_notebook as tqdm

# Silence every Python warning for the rest of the session. This is a blanket
# filter, so deprecation notices raised by the imported libraries are hidden too.
warnings.filterwarnings("ignore")  # ignore annoying warnings

# Magic commands
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# (Re)load the autoreload extension; mode 2 reloads all modules before each
# cell executes, so edits to the local modules imported above are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Display the value of every top-level expression in a cell, not only the last.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

Load the configuration and get the image directories.

In [169]:
config = UNetTraining.Config()
# One sub-directory per frame. glob returns matches in arbitrary order, so the
# listing is sorted to keep each frame's position in `im_dirs` stable across
# runs and machines, and non-directory entries are skipped because every entry
# is later treated as a frame directory.
# NOTE(review): if an existing split file was produced from an unsorted
# listing, its stored indices may refer to a different ordering -- confirm or
# regenerate the split.
im_dirs = sorted(
    path
    for path in glob.glob(os.path.join(config.image_dir, "*"))
    if os.path.isdir(path)
)

Read all images (a.k.a. frames) into memory.

In [132]:
def read_first_tif(frame_dir, sub_dir):
    """Read the first GeoTIFF found in ``frame_dir/sub_dir`` as a channels-last array.

    The dataset is opened in a ``with`` block so the file handle is released as
    soon as the pixels are in memory, instead of staying open for the lifetime
    of the kernel.

    Args:
        frame_dir: Directory of a single frame.
        sub_dir: Name of the sub-directory holding the raster.

    Returns:
        Array with the band axis moved from the first to the last position.

    Raises:
        FileNotFoundError: If the sub-directory contains no ``*.tif`` file.
    """
    tif_paths = glob.glob(os.path.join(frame_dir, sub_dir, "*.tif"))
    if not tif_paths:
        raise FileNotFoundError(
            f"No .tif file found in {os.path.join(frame_dir, sub_dir)}"
        )
    with rio.open(tif_paths[0]) as src:
        # rasterio returns (bands, rows, cols); move bands to the last axis
        return np.moveaxis(src.read(), 0, -1)


frames = []

for d in tqdm(im_dirs):
    # Scale to 0-1 (assumes 8-bit RGBI values -- TODO confirm)
    read_rgbi_im = read_first_tif(d, config.rgbi_dn) / 255
    # Scale to 0-1 (assumes NDVI values in [-1, 1] -- TODO confirm)
    read_ndvi_im = (read_first_tif(d, config.ndvi_dn) + 1) / 2
    read_label_im = read_first_tif(d, config.label_dn)
    read_weights_im = read_first_tif(d, config.weights_dn)

    if config.use_binary_labels:
        read_label_im[read_label_im > 0] = 1  # Binarize labels

    # Stack the RGBI and NDVI channels into a single input image
    comb_im = np.dstack((read_rgbi_im, read_ndvi_im))
    f = FrameInfo(comb_im, read_label_im, read_weights_im, d)
    frames.append(f)

  0%|          | 0/360 [00:00<?, ?it/s]

Split into train, validation, and test sets, and init generators.

In [170]:
train_frame_idx, val_frame_idx, test_frame_idx = split_dataset(
    frames, config.frames_json, config.patch_dir
)

# Sanity check for set sizes
for size_label, frame_indices in (
    ("Training set size:", train_frame_idx),
    ("Validation set size:", val_frame_idx),
    ("Testing set size:", test_frame_idx),
):
    print(size_label, len(frame_indices))

# Label channel(s) followed by weight channel(s)
annotation_channels = config.input_label_channel + config.input_weight_channel


def build_patch_generator(frame_indices):
    """Return a random patch generator restricted to the given frame indices."""
    patch_source = DataGenerator(
        config.input_image_channels,
        config.patch_size,
        frame_indices,
        frames,
        annotation_channels,
    )
    return patch_source.random_generator(config.BATCH_SIZE)


# Don't apply augmentation for now until the weighting scheme overwriting is figured out.
train_generator = build_patch_generator(train_frame_idx)  # Training data
val_generator = build_patch_generator(val_frame_idx)  # Validation data
test_generator = build_patch_generator(test_frame_idx)  # Testing data

Reading train-test split from file...
Training set size: 230
Validation set size: 58
Testing set size: 72


Inspect a batch of training images to verify that their labels and weights line up correctly.

In [None]:
# Draw one batch from the training generator and plot it for a visual check.
# NOTE(review): `display_images` is not imported in any cell visible above, so
# this cell raises NameError on a fresh kernel. It presumably lives in a local
# plotting module -- add the import to the top import cell.
for _ in range(1):
    train_images, real_label = next(train_generator)
    # Channel 0 of the target is used as the annotation, channel 1 as the weights
    ann = real_label[:, :, :, 0]
    wei = real_label[:, :, :, 1]
    # Overlay of annotation with boundary weights to check their alignment;
    # a trailing channel axis is added so it can be concatenated below.
    overlay = np.expand_dims(ann + wei, axis=-1)
    # Each row shows: the input channels (presumably RGBI + NDVI), the
    # annotation, the weights (boundary), and the annotation/weight overlay.
    display_images(np.concatenate((train_images, real_label, overlay), axis=-1))