In [None]:
import numpy as np
import pathlib
import rasterio
import rasterio.plot
import matplotlib.pyplot as plt
#import PIL

# Directories of the multispectral (MS) and panchromatic (PAN) .tif tiles,
# given relative to the notebook's current working directory.
PATH_MS, PATH_PAN = 'data/noise-test/ms', 'data/noise-test/pan'

In [None]:
def decode_geotiff(image_path):
    """Read all bands of a GeoTIFF and return them channels-last.

    Parameters
    ----------
    image_path : str or path-like
        Location of the raster file.

    Returns
    -------
    The array from ``rasterio``'s ``read()`` (bands first), reordered by
    ``rasterio.plot.reshape_as_image`` so the band axis is last.
    """
    path = pathlib.Path(image_path)
    with rasterio.open(path) as dataset:
        bands_first = dataset.read()
    # Move the band axis from first to last position.
    return rasterio.plot.reshape_as_image(bands_first)

def stretch(image, individual_bands = True):
    """Min-max rescale pixel values to the range [0, 1].

    Parameters
    ----------
    image : np.ndarray
        Image array. With ``individual_bands=True`` the band axis is axis 2
        (channels-last, e.g. ``(H, W, C)``); a 2-D ``(H, W)`` array is treated
        as a single band.
    individual_bands : bool, default True
        If True, each band (each index along axis 2) is rescaled with its own
        minimum and maximum. If False, one global minimum and maximum over
        the whole array is used.

    Returns
    -------
    np.ndarray
        float64 array with the same shape as ``image``, values in [0, 1].
        A band (or, in global mode, the whole image) whose values are all
        equal maps to 0 instead of NaN from a 0/0 division.
    """
    # Work in float64 so `max - min` and `image - min` cannot overflow for
    # signed integer inputs (e.g. int16 spanning -30000..30000).
    values = np.asarray(image, dtype=np.float64)
    if individual_bands:
        # Reduce over every axis except the band axis (2) so each band gets
        # its own min/max; keepdims lets the stats broadcast back.
        reduce_axes = tuple(ax for ax in range(values.ndim) if ax != 2)
    else:
        reduce_axes = None
    lo = np.min(values, axis=reduce_axes, keepdims=True)
    hi = np.max(values, axis=reduce_axes, keepdims=True)
    span = hi - lo
    # Zero-span (constant) bands are left at 0 rather than divided by zero.
    return np.divide(values - lo, span,
                     out=np.zeros(values.shape, dtype=np.float64),
                     where=span != 0)

def ms_to_rgb(ms, sensor = 'WV02'):
    """Build a stretched RGB composite from a channels-last multispectral image.

    Three bands are picked from axis 2 of ``ms`` according to ``sensor`` and
    stacked as R, G, B, then each channel is min-max rescaled by ``stretch``.

    Parameters
    ----------
    ms : np.ndarray
        Multispectral image with bands on axis 2.
    sensor : {'WV02', 'GE01', 'WV03_VNIR'}
        Selects which zero-based band indices become R, G and B.

    Raises
    ------
    ValueError
        If ``sensor`` is not one of the supported configurations.
    """
    # Zero-based band indices used, in order, as the R, G, B channels.
    if sensor == 'WV02':
        band_order = (4, 2, 1)
    elif sensor == 'GE01':
        band_order = (2, 1, 0)
    elif sensor == 'WV03_VNIR':
        # NOTE(review): unlike WV02 this takes bands 1, 2, 3 in ascending
        # order; confirm it matches the band layout of the WV03 VNIR inputs.
        band_order = (1, 2, 3)
    else:
        raise ValueError('Only WV02, GE01 and WV03_VNIR band configurations implemented') 

    channels = [np.expand_dims(ms[:, :, band], -1) for band in band_order]
    return stretch(np.concatenate(channels, axis = 2))

In [None]:
# Load two multispectral tiles and build stretched RGB previews of each.
# NOTE(review): ms_to_rgb is called with its default sensor ('WV02');
# confirm these tiles use that band layout.
img_sea = decode_geotiff(PATH_MS + '/00009.tif')
img_sea_rgb = ms_to_rgb(img_sea)
img_build = decode_geotiff(PATH_MS + '/00010.tif')
img_build_rgb = ms_to_rgb(img_build)

# Panchromatic tile, displayed in grayscale.
img = decode_geotiff(PATH_PAN + '/00011.tif')
plt.imshow(img, cmap = 'gray')
plt.title('PAN tile 00011');

In [None]:
# RGB preview of MS tile 00009 (img_sea); trailing ';' hides the AxesImage repr.
plt.imshow(img_sea_rgb)
plt.title('MS tile 00009 (img_sea): RGB composite');

In [None]:
# RGB preview of MS tile 00010 (img_build); trailing ';' hides the AxesImage repr.
plt.imshow(img_build_rgb)
plt.title('MS tile 00010 (img_build): RGB composite');

In [None]:
# Flatten each multispectral tile (all bands pooled) into a 1-D array,
# printing the shape before and after flattening.
print(img_build.shape)
img_build_fl = img_build.flatten()
print(img_build_fl.shape)

print(img_sea.shape)
img_sea_fl = img_sea.flatten()
print(img_sea_fl.shape)

In [None]:
# Histogram of all pixel values of MS tile 00010 (bands pooled);
# trailing ';' hides the (counts, bins, patches) tuple repr.
plt.hist(img_build_fl)
plt.title('MS tile 00010 (img_build): pixel values, all bands')
plt.xlabel('Pixel value')
plt.ylabel('Count');

In [None]:
# Histogram of all pixel values of MS tile 00009 (bands pooled);
# trailing ';' hides the (counts, bins, patches) tuple repr.
plt.hist(img_sea_fl)
plt.title('MS tile 00009 (img_sea): pixel values, all bands')
plt.xlabel('Pixel value')
plt.ylabel('Count');

In [None]:
# Summary statistics of the flattened sea tile instead of the raw array repr.
{
    'dtype': img_sea_fl.dtype,
    'size': img_sea_fl.size,
    'min': img_sea_fl.min(),
    'max': img_sea_fl.max(),
    'mean': img_sea_fl.mean(),
    'std': img_sea_fl.std(),
}