In [None]:
# Mount Google Drive at /content/drive (Colab only); the paths configured
# further down in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
! pip install segmentation_models

In [None]:
import os
import tensorflow as tf
import numpy as np
from PIL import Image
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score  
from tensorflow.keras.preprocessing.image import load_img
from math import ceil
import threading
from time import sleep

In [None]:
# Lookup table handed to load_model(): maps the loss/metric names used as keys
# to the matching segmentation_models callables.
custom_objects = {
    'binary_crossentropy_plus_jaccard_loss': bce_jaccard_loss,
    'iou_score': iou_score,
}

In [None]:
# ---- Locations on the mounted Drive ----
# Folder that receives the colour-coded prediction PNGs
out_path = '/content/drive/MyDrive/CS507/unet/predictions/segmentation'
# Input data root; keeps its trailing slash because sub-folders are appended
# to it by plain string concatenation
in_path = '/content/drive/MyDrive/CS507/unet/data/'
# classify_model_path = './content/drive/MyDrive/CS507/unet/CNN_classification_0.828.h5'
# Saved U-Net model file
unet_model_path = '/content/drive/MyDrive/CS507/unet/resnet34_B8A_B11_B12_82.h5'

# ---- Spectral bands fed to the network (order must match training) ----
# classify_bands = ['B3', 'B4', 'B8', 'B8A', 'B11', 'B12']
unet_bands = ['B8A', 'B11', 'B12']

# ---- Prediction batch size ----
# Every image is split into 81 patches, so the batch size has to be a
# multiple of 81; here one batch covers two images.
batch_size = 2 * 81

In [None]:
# Create the output folder together with any missing parent folders.
# exist_ok=True makes re-running this cell a no-op and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(out_path, exist_ok=True)

In [None]:
# def classify_predict(img_arr):
#     classify_prediction = classify_model.predict(img_arr)
#     return classify_prediction.argmax(axis=1).reshape(9,9)

In [None]:
# classify_model = tf.keras.models.load_model(classify_model_path)
# Load the saved U-Net; custom_objects supplies the loss and metric callables
# that the saved model refers to by name.
unet_model = tf.keras.models.load_model(
    unet_model_path,
    custom_objects=custom_objects,
    compile=True,
)

In [None]:
# Sorted work list of (scene_name, scene_num) pairs, one per entry of the
# atm_penetration folder:
#   scene_name = first four '_'-separated fields of the file name
#   scene_num  = last field with its 4-character extension stripped, as int
f_list = sorted(
    ('_'.join(parts[:4]), int(parts[-1][:-4]))
    for parts in (
        name.split('_') for name in os.listdir(in_path + 'atm_penetration/')
    )
)
len(f_list)

In [None]:
# Split an image into a num x num grid of equally sized tiles
def crop_img(img, num, save_prefix='', ret_numpy=False):
    """Cut ``img`` into ``num * num`` tiles.

    Tiles are produced column by column: tile index ``k`` corresponds to
    grid column ``k // num`` and grid row ``k % num``.

    Args:
        img: image object exposing ``size`` as (width, height) and ``crop(box)``.
        num: number of tiles along each side.
        save_prefix: path prefix used when tiles are written to disk; tile
            ``k`` is saved as ``f'{save_prefix}{k}.jpg'``.
        ret_numpy: when True, return the tiles instead of saving them.

    Returns:
        A stacked array with the tiles along axis 0 when ``ret_numpy`` is
        True, otherwise None (tiles are saved as JPEG with quality 80).
    """
    width, height = img.size

    # Crop boxes as (left, top, right, bottom), columns in the outer loop
    boxes = []
    for col in range(num):
        left = col * width // num
        for row in range(num):
            top = row * height // num
            boxes.append((left, top, left + width // num, top + height // num))

    if ret_numpy:
        return np.stack([np.array(img.crop(box)) for box in boxes], axis=0)

    for idx, box in enumerate(boxes):
        img.crop(box).save(save_prefix + f'{idx}.jpg', format='JPEG', quality=80)

# Load the requested bands of one scene as a single stacked numpy array
def get_bands_arr(path, scene_name, scene_num, bands, crop_num, img_size):
    """Read, resize and tile every band image of a scene.

    For each band the file ``f'{path}{scene_name}_{band}_{scene_num}.jpg'`` is
    opened, resized to ``img_size`` and cut into ``crop_num * crop_num`` tiles
    with ``crop_img``.

    Args:
        path: folder prefix of the band images (concatenated as-is, so it
            must end with a path separator).
        scene_name: scene identifier used in the file name.
        scene_num: scene number used in the file name.
        bands: band names; their order becomes the order along axis 3.
        crop_num: tiles per image side.
        img_size: (width, height) the image is resized to before tiling.

    Returns:
        The per-band tile arrays stacked along axis 3.
        # assumes each band JPEG is single-channel, giving
        # (tiles, height, width, bands) - TODO confirm
    """
    per_band = [
        crop_img(
            Image.open(f'{path}{scene_name}_{band}_{scene_num}.jpg').resize(img_size),
            crop_num,
            ret_numpy=True,
        )
        for band in bands
    ]
    return np.stack(per_band, axis=3)

In [None]:
class Sequence_generator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields prediction inputs as NumPy arrays.

    Every scene is loaded band by band from ``in_path + 'bands/'``, resized to
    ``9 * img_size`` pixels per side and cut into a 9 x 9 grid, so one scene
    contributes 81 patches of ``img_size`` x ``img_size`` pixels. A batch
    always contains whole scenes.

    Args:
        f_list: sequence of ``(scene_name, scene_num)`` tuples.
        bands: band names; their order is the channel order of the output.
        batch_size: requested number of patches per batch. It is expected to
            be a multiple of 81; other values are rounded down to whole
            scenes, with a minimum of one scene per batch.
        img_size: side length of one square patch, in pixels.
    """

    GRID = 9                         # patches per image side
    PATCHES_PER_SCENE = GRID * GRID  # 81 patches per scene

    def __init__(self, f_list, bands, batch_size, img_size):
        # Lets the Keras base class initialise its own state where it has any
        super().__init__()
        self.batch_size = batch_size
        self.img_size = img_size
        self.f_list = f_list
        self.bands = bands
        # Whole scenes per batch. Deriving both __len__ and __getitem__ from
        # this single value keeps them consistent even when batch_size is not
        # a multiple of 81, and max(1, ...) guarantees progress for
        # batch_size < 81.
        self._scenes_per_batch = max(1, batch_size // self.PATCHES_PER_SCENE)

    def __len__(self):
        """Number of batches needed to cover every scene in ``f_list``."""
        return ceil(len(self.f_list) / self._scenes_per_batch)

    def __getitem__(self, idx):
        """Return the inputs of batch #idx.

        Only inputs are produced (no targets); the generator is meant for
        prediction. The result is a float32 array of shape
        ``(scenes_in_batch * 81, img_size, img_size, len(bands))``.
        """
        start = idx * self._scenes_per_batch
        batch_files = self.f_list[start:start + self._scenes_per_batch]

        patches = self.PATCHES_PER_SCENE
        x = np.zeros(
            (len(batch_files) * patches, self.img_size, self.img_size, len(self.bands)),
            dtype="float32",
        )

        full_side = self.img_size * self.GRID
        for j, (scene_name, scene_num) in enumerate(batch_files):
            # Use the bands given to the constructor (not a module-level
            # list) so the loaded data always matches the buffer allocated
            # above.
            x[j * patches:(j + 1) * patches] = get_bands_arr(
                in_path + 'bands/', scene_name, scene_num, self.bands,
                self.GRID, (full_side, full_side))

        return x


In [None]:
# Sanity check: build a generator over the first five scenes and display the
# shape of its first batch.
predict_gen = Sequence_generator(f_list[:5], unet_bands, batch_size, 256)
predict_gen[0].shape

In [None]:
# Assemble the predicted patches of each scene into one PNG and save it
def save_files(files, unet_preds, c):
    """Write one 2700 x 2700 RGB mosaic per scene.

    Args:
        files: list of ``(scene_id, num)`` tuples; entry ``j`` owns the
            patches ``unet_preds[j*81 : (j+1)*81]``.
        unet_preds: per-patch predictions, 81 consecutive entries per scene.
        c: number of images saved before this call; only used for the
            progress message printed at the end.
    """
    for scene_idx, (scene_id, num) in enumerate(files):
        canvas = Image.new('RGB', (2700, 2700), 'black')
        first_patch = scene_idx * 81
        for k in range(81):
            paste_img(unet_preds[first_patch + k], canvas, k, (300, 300))
        canvas.save(f'{out_path}/{scene_id}_{num}.png')
        c += 1
    print(f'\nsaved: {c}')

# Paste image patch (from numpy array) on output image at appropriate position
def paste_img(nparr, out_image, i, resize):
    """Colour-code one predicted patch and paste it into the output mosaic.

    Class ids are rendered as: 1 (burned) -> red, 2 (vegetation) -> green,
    3 (unknown) -> white; every other value stays black.

    Args:
        nparr: array of per-pixel class ids for one patch.
        out_image: image being assembled; it is modified in place.
        i: patch index on a 9-wide grid; ``i // 9`` selects the column and
            ``i % 9`` the row.
        resize: (width, height) the patch is scaled to before pasting. The
            paste offset is derived from the same size.
    """
    pred_arr = np.zeros(nparr.shape + (3,), dtype='uint8')
    pred_arr[nparr == 2] = [0, 255, 0]        # vegetation
    pred_arr[nparr == 1] = [255, 0, 0]        # burned
    pred_arr[nparr == 3] = [255, 255, 255]    # unknown

    # The patch is a categorical map: nearest-neighbour scaling keeps every
    # pixel one of the exact class colours. An interpolating filter would
    # blend colours along class boundaries into values that belong to no class.
    temp_img = Image.fromarray(pred_arr).resize(resize, resample=Image.NEAREST)

    # Offset in output pixels, based on the actual patch size being pasted
    box = (i // 9 * resize[0], i % 9 * resize[1])
    out_image.paste(temp_img, box)

In [None]:
# Predict the scenes in chunks of `div` and write each chunk's images from a
# background thread, so saving overlaps with the next prediction.
div = 30      # scenes handled per predict() call
c = 0         # number of scenes handed to save threads so far
threads = []  # save threads that may still be running
for i in range(ceil(len(f_list)/div)):
    # Divide f_list into batches of size 'div'
    print(f'batch: {i}')
    files = f_list[i*div: (i+1)*div]

    # Unet prediction on files
    predict_gen = Sequence_generator(files, unet_bands, batch_size, 256)
    predictions = unet_model.predict(
        predict_gen,
        use_multiprocessing=True,
        workers=4,
        batch_size=batch_size
    )
    # Class id per pixel = index of the largest value along axis 3
    unet_preds = predictions.argmax(
        axis=3).astype('uint8')

    # Thread to save predictions as images (it gets its own copies of the
    # file list and the predictions)
    t = threading.Thread(target=save_files, args=(
        files[:], unet_preds.copy(), c))
    t.start()
    threads.append(t)
    # Advance by the real chunk size; the last chunk can be shorter than div
    c += len(files)

    # Limit maximum running threads to prevent RAM usage.
    # The list is rebuilt instead of calling remove() while iterating over it:
    # removing during iteration skips elements and miscounts live threads.
    while True:
        threads = [worker for worker in threads if worker.is_alive()]
        if len(threads) < 3:
            break
        sleep(1)
    # np.savez(f'./cropped/{i}.npz', unet_preds)

# Wait for the remaining save threads before moving on
for t in threads:
    t.join()


In [None]:
# Save outputs in normalized form
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
# Every entry of the segmentation output folder is queued for normalization
files = os.listdir(out_path)

# Names of files that normalize() could not split into R, G, B channels
err_files = []

def normalize(file):
    """Convert one colour-coded prediction image into a class-id mask.

    The RGB image ``out_path + '/' + file`` is read and mapped pixel-wise to
    red -> 1, green -> 2, white -> 3, anything else -> 0. The resulting
    8-bit greyscale image is written to the same path with 'segmentation'
    replaced by 'normalized'.

    Args:
        file: file name inside ``out_path``.

    Returns:
        None. Files whose pixel data is not H x W x 3 are appended to
        ``err_files`` and skipped.
    """
    in_file = out_path + '/' + file
    # Read the pixels inside a context manager so the file handle is released
    with Image.open(in_file) as img:
        data = np.array(img)

    # Only three-channel images can be split into red/green/blue
    if data.ndim != 3 or data.shape[2] != 3:
        err_files.append(file)
        return

    red, green, blue = data[..., 0], data[..., 1], data[..., 2]

    # Allocate the mask from the array shape (rows, cols) rather than from
    # img.size, which is (width, height) and only matches for square images.
    mask = np.zeros(data.shape[:2], dtype='uint8')
    mask[(red > 245) & (green < 10) & (blue < 10)] = 1     # red
    mask[(red < 10) & (green > 245) & (blue < 10)] = 2     # green
    mask[(red > 245) & (green > 245) & (blue > 245)] = 3   # white

    out_file = in_file.replace('segmentation', 'normalized')
    # Make sure the target folder exists; exist_ok keeps this safe when
    # several worker threads reach this line at the same time.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    Image.fromarray(mask).convert('L').save(out_file)

# Normalize all files on a pool of 16 worker threads; tqdm shows progress as
# the results are consumed in submission order.
with ThreadPoolExecutor(max_workers=16) as pool:
    progress = tqdm(pool.map(normalize, files), total=len(files))
    results = list(progress)

In [None]:
# Display the files that were skipped during normalization
err_files