In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import glob
import os
import sys
from pathlib import Path
from PIL import Image

import tensorflow as tf
import tensorflow.keras as keras


In [2]:
# Root folder holding the per-band input sub-folders, and the folder that
# receives the predicted images.
IN_DIR = Path("./netin")
OUT_DIR = Path("./netout")

# Controls how the list of input files is partitioned for prediction.
LOAD_SIZE = 128

# Every entry of IN_DIR whose name starts with "B", in sorted path order.
# The position of a band in this list is the position of its channel in the
# arrays built by read_fname.
BAND_DIRS = sorted(IN_DIR.glob("B*"))
display(BAND_DIRS)

[PosixPath('netin/B1'),
 PosixPath('netin/B2'),
 PosixPath('netin/B3'),
 PosixPath('netin/B4'),
 PosixPath('netin/B5'),
 PosixPath('netin/B7')]

In [3]:
# Divisor applied to the raw pixel values of every band image on load.
MAX_X = 255

def read_fname(fname):
    """Load the image called `fname` from every band directory as one array.

    The file is opened from each directory in BAND_DIRS (in list order),
    the pixel values are divided by MAX_X, and the per-band arrays are
    combined with np.dstack.

    Assumes each band image is single-channel, so the result would be
    (height, width, number of bands) - TODO confirm against the data.
    """
    per_band = []
    for band_dir in BAND_DIRS:
        per_band.append(np.array(Image.open(band_dir / fname)))
    scaled = np.asarray(per_band) / MAX_X
    return np.dstack(scaled)

def rgb_transform(ds):
    """Pick channels 3, 2, 1 (in that order) from the last axis of a 4-D array.

    This is the slice ds[:, :, :, 1:4] reversed along axis 3.

    NOTE(review): presumably used to build a 3-channel image for plotting;
    confirm that these band indices are the intended colour channels.
    """
    # A negative-step slice selects indices 3, 2, 1 directly.
    return ds[:, :, :, 3:0:-1]

In [4]:
# The PNG files of the first band directory define which images get processed;
# read_fname opens the same file name in every band directory.
band1_paths = [*BAND_DIRS[0].glob("*.png")]

In [5]:
def readin_batch(band1pths):
    """Read the multi-band image for every path in `band1pths`.

    Only the file name of each path is used; read_fname resolves that name
    inside the band directories.

    Returns:
        A tuple (names, images): the list of file names, and np.asarray of
        the arrays returned by read_fname, in the same order.
    """
    names = [imgpth.name for imgpth in band1pths]
    images = np.asarray([read_fname(name) for name in names])
    return names, images

In [6]:
def write_out_img(name, img_data):
    """Encode `img_data` as a PNG and write it to OUT_DIR / name.

    The data is multiplied by 255 and cast to uint8 before encoding.
    Assumes `img_data` holds values in [0, 1] laid out as
    (height, width, channels) - TODO confirm against the model output.
    """
    pixels = tf.cast(img_data * 255, tf.uint8)
    png_bytes = tf.image.encode_png(pixels)
    target = (OUT_DIR / name).as_posix()
    tf.io.write_file(target, png_bytes)

In [7]:
from keras_unet.models import custom_unet

# Network input: 512 x 512 pixels with 6 channels (presumably one channel per
# band directory - confirm against BAND_DIRS).
input_shape = (512, 512, 6)

model = custom_unet(
    input_shape,
    filters=40,
    use_batch_norm=True,
    dropout=0.15,  # 0.3
    dropout_change_per_layer=0.0,
    num_layers=5
)

# summary() prints the layer table itself and returns None, so wrapping it in
# display() only appended a stray "None" to the cell output.
model.summary()

-----------------------------------------
keras-unet init: TF version is >= 2.0.0 - using `tf.keras` instead of `Keras`
-----------------------------------------
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 6) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 40) 2160        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 40) 160         conv2d[0][0]                     
__________________________________________________________________________________________________
spatial_dropout2d (SpatialDropo

None

In [8]:
# Load previously saved weights from the file named by model_filename into
# `model`; the path is relative to the notebook's working directory.
model_filename = 'segm_model_v3.h5'
model.load_weights(model_filename)

In [9]:
from keras_unet.utils import plot_imgs

# Process the input files in consecutive chunks of at most LOAD_SIZE images.
# np.array_split(paths, LOAD_SIZE) was wrong here: its second argument is the
# NUMBER of sections, so fewer than LOAD_SIZE files gave empty batches and
# model.predict then received an empty array. Stepping through the list by
# LOAD_SIZE never yields an empty batch and runs zero times when there are
# no input files.
for start in range(0, len(band1_paths), LOAD_SIZE):
    batch = band1_paths[start:start + LOAD_SIZE]
    names, x_pred = readin_batch(batch)
    y_pred = model.predict(x_pred)
    # Optional visual check of a batch:
    # plot_imgs(org_imgs=rgb_transform(x_pred), mask_imgs=y_pred, nm_img_to_plot=10, figsize=10)
    # Predictions come back in input order, so each one is saved under the
    # file name of the image it was computed from.
    for name, pred in zip(names, y_pred):
        write_out_img(name, pred)