In [None]:
import numpy as np
import pathlib
import rasterio
import rasterio.plot
import matplotlib.pyplot as plt
import pathlib
import random
#import PIL

# Root directory handed to list_imgs(), which searches it recursively for
# '*.tif' files. The path is relative to the notebook's working directory.
PATH = 'data/toulon-laspezia-tiles/WV02'

In [None]:
def list_imgs(path):
    """Collect the GeoTIFF files below ``path``, split into ms and pan lists.

    Parameters
    ----------
    path : str or pathlib.Path
        Directory that is searched recursively for ``*.tif`` files.

    Returns
    -------
    (ms_l, pan_l) : tuple of list of pathlib.Path
        ``ms_l`` holds every file whose parent directory stem is ``'ms'``;
        ``pan_l`` holds every other ``.tif`` file that was found.

    Notes
    -----
    Both lists are sorted by path so the result is deterministic (glob order
    is not guaranteed). Index-wise pairing of the two lists only lines up if
    the ms and pan folders mirror each other's file names -- that is an
    assumption about the data layout, not something checked here.
    """
    # Search the directory that was actually passed in, not the module-level PATH.
    tif_files = sorted(
        candidate
        for candidate in pathlib.Path(path).glob('**/*.tif')
        if candidate.is_file()
    )
    ms_l, pan_l = [], []
    for file in tif_files:
        target = ms_l if file.parent.stem == 'ms' else pan_l
        target.append(file)
    return ms_l, pan_l

In [None]:
# Gather the ms and pan tile paths and print how many of each were found.
# Later cells index both lists with the same position, so the two counts
# are expected to match.
ms_l, pan_l = list_imgs(PATH)
print(len(ms_l), len(pan_l))

In [None]:
def decode_geotiff(image_path):
    """Load a raster file with rasterio and return it in channels-last layout.

    Parameters
    ----------
    image_path : str or pathlib.Path
        Location of the file; a plain string is converted to ``pathlib.Path``.

    Returns
    -------
    The array returned by ``src.read()`` after passing through
    ``rasterio.plot.reshape_as_image`` (channels first -> channels last).
    """
    source_path = pathlib.Path(image_path) if isinstance(image_path, str) else image_path
    with rasterio.open(source_path) as dataset:
        bands_first = dataset.read()
    # Move the band axis from the front to the back.
    return rasterio.plot.reshape_as_image(bands_first)

def stretch(image, individual_bands = True):
    """Min-max rescale an image so its values span the range [0, 1].

    Parameters
    ----------
    image : numpy.ndarray
        Image data. With three or more dimensions, axis 2 is treated as the
        band axis. Arrays with fewer dimensions are treated as a single band.
    individual_bands : bool, default True
        If True (and the array has a band axis), every band is rescaled with
        its own minimum and maximum. If False, one global minimum and maximum
        are used for the whole array.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``image``. The per-band path always
        returns float64; the global path returns whatever dtype the true
        division produces.

    Notes
    -----
    A band (or image) whose values are all equal has zero range. It is mapped
    to all zeros instead of dividing by zero and producing NaN/inf.
    """
    if individual_bands and image.ndim >= 3:
        # Reduce over every axis except the band axis so each band gets its own
        # min/max, computed once instead of several times per band.
        reduce_axes = tuple(axis for axis in range(image.ndim) if axis != 2)
        low = np.min(image, axis=reduce_axes, keepdims=True)
        high = np.max(image, axis=reduce_axes, keepdims=True)
        span = high - low
        # Constant bands have span 0; dividing by 1 leaves them at 0.
        span = np.where(span == 0, 1, span)
        return ((image - low) / span).astype(np.float64)

    low = np.min(image)
    high = np.max(image)
    span = high - low
    if span == 0:
        # Constant image: avoid 0/0 and return zeros.
        span = 1
    return (image - low) / span

def ms_to_rgb(ms, sensor = 'WV02'):
    """Pick three bands of a multispectral image and stretch them for display.

    Parameters
    ----------
    ms : numpy.ndarray
        Channels-last multispectral image; bands are taken along axis 2.
    sensor : str, default 'WV02'
        Band layout to use: 'WV02', 'GE01' or 'WV03_VNIR'.

    Returns
    -------
    The three selected bands joined along axis 2 and passed through
    ``stretch`` with its default arguments.

    Raises
    ------
    ValueError
        If ``sensor`` is not one of the implemented layouts.
    """
    # Band indices in the order they are placed in the output image.
    # NOTE(review): WV03_VNIR uses ascending indices (1, 2, 3) while the other
    # two layouts use descending ones -- confirm this is intended.
    band_orders = {
        'WV02': (4, 2, 1),
        'GE01': (2, 1, 0),
        'WV03_VNIR': (1, 2, 3),
    }
    if sensor not in band_orders:
        raise ValueError('Only WV02, GE01 and WV03_VNIR band configurations implemented')

    # Give each selected band a trailing axis, then join them along axis 2.
    selected = [ms[:, :, band][..., np.newaxis] for band in band_orders[sensor]]
    return stretch(np.concatenate(selected, axis=2))

def plot_image_pair(ms, pan):
    """Draw a multispectral image and a panchromatic image side by side.

    The left panel shows ``ms_to_rgb(ms)`` (default sensor layout); the right
    panel shows ``pan`` with a gray colormap. Nothing is returned.
    """
    _, (ms_ax, pan_ax) = plt.subplots(
        nrows=1,
        ncols=2,
        figsize=(12, 12),
        constrained_layout=True,
    )
    ms_ax.imshow(ms_to_rgb(ms))
    pan_ax.imshow(pan, cmap='gray')

def plot_hist_img_pairs(ms, pan):
    """Plot value histograms and images for a multispectral/panchromatic pair.

    Draws a 2x2 figure: the top row holds histograms of all ms values (left)
    and all pan values (right), each titled with its mean and standard
    deviation rounded to three decimals; the bottom row shows
    ``ms_to_rgb(ms)`` and ``pan`` (gray colormap).

    Parameters
    ----------
    ms, pan : numpy.ndarray
        Image arrays; each is flattened across all axes for its histogram.
    """
    fig, axs = plt.subplots(2, 2, constrained_layout = True, figsize = (10, 10))

    # One histogram per image. The label names the image the statistics belong
    # to (the pan histogram was previously mislabelled as 'ms').
    histogram_panels = (
        ('ms', ms.flatten(), axs[0, 0]),
        ('pan', pan.flatten(), axs[0, 1]),
    )
    for label, values, ax in histogram_panels:
        ax.hist(values)
        ax.set_title('{} - mean:{},  sd:{}'.format(
            label,
            round(np.mean(values), 3),
            round(np.std(values), 3),
        ))

    axs[1, 0].imshow(ms_to_rgb(ms))
    axs[1, 1].imshow(pan, cmap='gray')

In [None]:
def random_img_pair(ms_l, pan_l):
    """Pick one random position and decode the ms and pan images stored there.

    Parameters
    ----------
    ms_l, pan_l : sequence of path-like
        Equal-length sequences; entry ``i`` of each is assumed to belong to
        the same scene (presumably guaranteed by the caller -- not verified).

    Returns
    -------
    (ms, pan)
        The results of ``decode_geotiff`` for the two selected paths.

    Raises
    ------
    AssertionError
        If the two sequences differ in length.
    IndexError
        If the sequences are empty.
    """
    assert len(ms_l) == len(pan_l), 'ms_l and pan_l must have the same length'
    n = len(ms_l)
    if n == 0:
        raise IndexError('cannot pick an image pair from empty lists')
    # randrange excludes n; randint(0, n) could return n and index past the end.
    i = random.randrange(n)
    ms = decode_geotiff(ms_l[i])
    pan = decode_geotiff(pan_l[i])
    return ms, pan

# Decode one randomly chosen ms/pan pair and plot its histograms and images.
# The choice is unseeded, so each run of this cell may show a different pair.
ms, pan = random_img_pair(ms_l, pan_l)
plot_hist_img_pairs(ms, pan)