In [1]:
import os
import random
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from matplotlib import pyplot as plt
from superresolution import Superresolution
from utils import get_img_paths, load_image, create_mask, plot_prediction
from model import DeeplabV3Plus
from tqdm import tqdm

In [2]:
BASE_DIR = os.getcwd()
DATA_DIR = os.path.join(BASE_DIR, "data")
PASCAL_ROOT = os.path.join(DATA_DIR, "VOCdevkit", "VOC2012")
IMGS_PATH = os.path.join(PASCAL_ROOT, "JPEGImages")

precomputed_dest_root = os.path.join(DATA_DIR, "precomputed_features")
if not os.path.exists(precomputed_dest_root):
    os.mkdir(precomputed_dest_root)

SEED = np.random.randint(0, 1000)
IMG_SIZE = (512, 512)
BATCH_SIZE = 64
BUFFER_SIZE = 1000
EPOCHS = 30
CLASSES = 21
RESHAPE_MASKS = True
NUM_AUG = 50
CLASS_ID = 8 # Cat class

In [3]:
def filter_by_class(img_paths, class_id, image_size=(512, 512)):
    """
    Given a list of image paths, return the images that contain the given class id in the respective mask

    Args:
        img_paths: List of image paths to check
        class_id: Class id used for filering
        image_size: Size of the image used to load and resize the image

    Returns: A dictionary whose keys are the image filename and values are the actual images

    """
    images_dict = {}
    for img_path in img_paths:
        image_name = os.path.splitext(os.path.basename(img_path))[0]
        mask_path = img_path.replace("JPEGImages", "SegmentationClassAug").replace("jpg", "png")
        mask = load_image(mask_path, image_size=image_size, normalize=False, is_png=True, resize_method="nearest")
        if np.any(mask == class_id):
            image = load_image(img_path, image_size=IMG_SIZE, normalize=True)
            images_dict[image_name] = image

    print(f"Valid images: {len(images_dict)} (Initial:  {len(img_paths)})")
    return images_dict

In [4]:
image_list_path = os.path.join(DATA_DIR, "augmented_file_lists", "valaug.txt")
image_paths = get_img_paths(image_list_path, IMGS_PATH)[:150]
images_dict = filter_by_class(image_paths, class_id=8)

valid_filenames = list(images_dict.keys())

model_no_upsample = DeeplabV3Plus(
    input_shape=(512, 512, 3),
    classes=21,
    OS=16,
    last_activation=None,
    load_weights=True,
    backbone="mobilenet",
    alpha=1.).build_model(final_upsample=False)

model_standard = DeeplabV3Plus(
    input_shape=(512, 512, 3),
    classes=21,
    OS=16,
    last_activation=None,
    load_weights=True,
    backbone="mobilenet",
    alpha=1.).build_model(final_upsample=True)

Valid images: 19 (Initial:  150)


# Save standard output for comparison

In [5]:
def get_prediction(model, input_image):

    prediction = model.predict(input_image[tf.newaxis, ...])
    mask = create_mask(prediction[0])

    return mask

def save_standard_output(image_dict, model, standard_out_folder, filter_class_id=None):
    standard_masks = {}
    if not os.path.exists(standard_out_folder):
        os.mkdir(standard_out_folder)

    for key in tqdm(image_dict):
        standard_mask = get_prediction(model, image_dict[key])
        if filter_class_id is not None:
            standard_mask = tf.where(standard_mask == filter_class_id, standard_mask, 0) # Set to 0 all predictions different from the given class
        tf.keras.utils.save_img(f"{standard_out_folder}/{key}.png", standard_mask, scale=False)
        standard_masks[key] = standard_mask
    return standard_masks

In [6]:
standard_out_folder = os.path.join(DATA_DIR, "standard_output")
standard_masks_dict = save_standard_output(images_dict, model_standard, standard_out_folder, filter_class_id=CLASS_ID)

100%|██████████| 19/19 [00:06<00:00,  2.98it/s]


# Precompute Augmented Output Features

In [7]:
def augment_images(batched_images, angles, shifts):

    rotated_images = tfa.image.rotate(batched_images, angles, interpolation="bilinear")
    translated_images = tfa.image.translate(rotated_images, shifts, interpolation="bilinear")

    return translated_images


def save_augmented_features(model, images_array, dest_folder, filter_class_id=None):
    if not os.path.exists(dest_folder):
        os.mkdir(dest_folder)

    predictions = model.predict(images_array, batch_size=2)

    for i, prediction in enumerate(predictions):
        mask = create_mask(prediction)
        if filter_class_id is not None:
            mask = tf.where(mask == filter_class_id, mask, 0) # Set to 0 all predictions different from the given class


        #mask_npy = mask.numpy()
        #mask_scaled = ((mask_npy - mask_npy.min()) * (1/(mask_npy.max() - mask_npy.min()) * 255)).astype('uint8')

        tf.keras.utils.save_img(f"{dest_folder}/{i}.png", mask, scale=True)

    return predictions


def precompute_augmented_features(image_filenames, dest_root_folder, model, class_id=None, num_aug=100, angle_max=0.5, shift_max=30):
    for filename in tqdm(image_filenames):
        image_path = os.path.join(IMGS_PATH, f"{filename}.jpg")
        image = load_image(image_path, image_size=(512, 512), normalize=True)
        batched_image = tf.tile(tf.expand_dims(image, axis=0), [num_aug, 1, 1, 1])  # Size [num_aug, 512, 512, 3]
        angles = np.random.uniform(-angle_max, angle_max, num_aug)
        shifts = np.random.uniform(-shift_max, shift_max, (num_aug, 2))
        # First sample is not augmented
        angles[0] = 0
        shifts[0] = np.array([0, 0])
        angles = angles.astype("float32")
        shifts = shifts.astype("float32")

        augmented_images = augment_images(batched_image, angles, shifts)

        dest_folder = os.path.join(dest_root_folder, filename)

        save_augmented_features(model, augmented_images, dest_folder=dest_folder, filter_class_id=class_id)
        np.save(os.path.join(dest_folder, f"{filename}_angles"), angles)
        np.save(os.path.join(dest_folder, f"{filename}_shifts"), shifts)

In [8]:
angle_max = 0.5  # in radians
shift_max = 30

precompute_augmented_features(valid_filenames, precomputed_dest_root, model_no_upsample, class_id=CLASS_ID, num_aug=NUM_AUG,
                              angle_max=angle_max, shift_max=shift_max)

100%|██████████| 19/19 [01:20<00:00,  4.23s/it]


# Compute Super-Resolution Output

In [11]:
def load_images(img_folder):
    images = []
    # Sort images based on their filename which is an integer indicating the augmented copy number
    image_list = sorted([name.replace(".png", "") for name in os.listdir(img_folder) if ".npy" not in name], key=int)

    for img_name in image_list:
        if ".npy" in img_name:
            continue
        image = load_image(os.path.join(img_folder, f"{img_name}.png"), normalize=False, is_png=True)
        images.append(image)

    return images


def get_precomputed_folders_path(root_dir, num_aug=100):
    valid_folders = []
    for path in os.listdir(root_dir):
        full_path = os.path.join(root_dir, path)
        if len(os.listdir(full_path)) == (num_aug + 2):
            valid_folders.append(full_path)
        else:
            print(f"Skipped folder named {path} as it is not valid")

    return valid_folders


def compute_save_final_output(superresolution_obj, image_filenames, precomputed_root_dir, output_folder, num_aug=100):

    superres_masks = {}
    losses = {}

    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    for filename in tqdm(image_filenames):
        precomputed_folder_path = os.path.join(precomputed_root_dir, filename)

        if not len(os.listdir(precomputed_folder_path)) == (num_aug + 2):
            print(f"Skipped folder named {filename} as it is not valid")
            continue

        augmented_images = tf.stack(load_images(precomputed_folder_path))
        max_values = np.max(augmented_images, axis=(1, 2, 3), keepdims=True)
        max_values[max_values == 0.] = 1.
        augmented_images = augmented_images / max_values # TODO: Testing only
        augmented_images = tf.cast(augmented_images, tf.float32)

        base_name = os.path.basename(os.path.normpath(precomputed_folder_path))
        angles = np.load(os.path.join(precomputed_folder_path, f"{base_name}_angles.npy"))
        shifts = np.load(os.path.join(precomputed_folder_path, f"{base_name}_shifts.npy"))
        target_image, loss = superresolution_obj.compute_output(augmented_images, angles, shifts)
        target_image = target_image[0]

        #target_image_npy = target_image[0].numpy()
        #target_image_scaled = ((target_image_npy - target_image_npy.min()) * (1/(target_image_npy.max() - target_image_npy.min()) * 255)).astype('uint8')

        tf.keras.utils.save_img(f"{output_folder}/{base_name}.png", target_image, scale=True)

        superres_masks[base_name] = target_image
        losses[base_name] = loss

    return superres_masks, losses

In [12]:
# super resolution parameters
learning_rate = 1e-3
lambda_eng = 0.0001 * NUM_AUG
lambda_tv = 0.002 * NUM_AUG
num_iter = 400

superresolution = Superresolution(
    lambda_tv=lambda_tv,
    lambda_eng=lambda_eng,
    num_iter=num_iter,
    num_aug=NUM_AUG,
    learning_rate=learning_rate
)

precomputed_root_dir = os.path.join(DATA_DIR, "precomputed_features")
output_folder = os.path.join(DATA_DIR, "superres_output")

superres_masks_dict, losses = compute_save_final_output(superresolution, valid_filenames, precomputed_root_dir, output_folder, num_aug=NUM_AUG)

  0%|          | 0/19 [00:00<?, ?it/s]

1/400 -- loss = 40348.0
11/400 -- loss = 39563.8125
21/400 -- loss = 38789.5078125
31/400 -- loss = 38026.50390625
41/400 -- loss = 37274.140625
51/400 -- loss = 36533.7734375
61/400 -- loss = 35806.30859375
71/400 -- loss = 35092.35546875
81/400 -- loss = 34391.73046875
91/400 -- loss = 33704.05078125
101/400 -- loss = 33029.05859375
111/400 -- loss = 32366.548828125
121/400 -- loss = 31716.228515625
131/400 -- loss = 31077.95703125
141/400 -- loss = 30451.5
151/400 -- loss = 29836.623046875
161/400 -- loss = 29233.189453125
171/400 -- loss = 28641.0078125
181/400 -- loss = 28059.87109375
191/400 -- loss = 27489.6171875
201/400 -- loss = 26930.056640625
211/400 -- loss = 26381.033203125
221/400 -- loss = 25842.412109375
231/400 -- loss = 25313.9453125
241/400 -- loss = 24795.529296875
251/400 -- loss = 24287.001953125
261/400 -- loss = 23788.201171875
271/400 -- loss = 23298.974609375
281/400 -- loss = 22819.208984375
291/400 -- loss = 22348.720703125
301/400 -- loss = 21887.373046875

  5%|▌         | 1/19 [00:17<05:10, 17.26s/it]

400/400 -- loss = 17784.142578125
1/400 -- loss = 17677.0
11/400 -- loss = 17350.291015625
21/400 -- loss = 17028.458984375
31/400 -- loss = 16714.6015625
41/400 -- loss = 16409.578125
51/400 -- loss = 16114.033203125
61/400 -- loss = 15828.2939453125
71/400 -- loss = 15552.0830078125
81/400 -- loss = 15285.091796875
91/400 -- loss = 15025.703125
101/400 -- loss = 14774.8505859375
111/400 -- loss = 14533.0078125
121/400 -- loss = 14299.8046875
131/400 -- loss = 14074.919921875
141/400 -- loss = 13858.0
151/400 -- loss = 13648.7353515625
161/400 -- loss = 13446.96484375
171/400 -- loss = 13252.4033203125
181/400 -- loss = 13064.81640625
191/400 -- loss = 12883.9716796875
201/400 -- loss = 12709.7060546875
211/400 -- loss = 12541.7373046875
221/400 -- loss = 12379.91015625
231/400 -- loss = 12224.0224609375
241/400 -- loss = 12073.8408203125
251/400 -- loss = 11929.197265625
261/400 -- loss = 11789.953125
271/400 -- loss = 11655.8701171875
281/400 -- loss = 11526.8017578125
291/400 -- lo

 11%|█         | 2/19 [00:32<04:37, 16.31s/it]

400/400 -- loss = 10318.373046875
1/400 -- loss = 43483.0
11/400 -- loss = 42633.42578125
21/400 -- loss = 41793.8046875
31/400 -- loss = 40965.671875
41/400 -- loss = 40149.1328125
51/400 -- loss = 39344.15234375
61/400 -- loss = 38550.15625
71/400 -- loss = 37768.74609375
81/400 -- loss = 37000.171875
91/400 -- loss = 36244.31640625
101/400 -- loss = 35500.9921875
111/400 -- loss = 34769.890625
121/400 -- loss = 34050.90625
131/400 -- loss = 33343.91796875
141/400 -- loss = 32648.716796875
151/400 -- loss = 31965.150390625
161/400 -- loss = 31293.14453125
171/400 -- loss = 30632.548828125
181/400 -- loss = 29983.17578125
191/400 -- loss = 29344.9765625
201/400 -- loss = 28717.763671875
211/400 -- loss = 28101.423828125
221/400 -- loss = 27495.83984375
231/400 -- loss = 26900.8359375
241/400 -- loss = 26316.369140625
251/400 -- loss = 25742.24609375
261/400 -- loss = 25178.34765625
271/400 -- loss = 24624.57421875
281/400 -- loss = 24080.796875
291/400 -- loss = 23546.83984375
301/400

 16%|█▌        | 3/19 [00:49<04:19, 16.25s/it]

400/400 -- loss = 18333.333984375
1/400 -- loss = 13948.0
11/400 -- loss = 13677.1181640625
21/400 -- loss = 13409.369140625
31/400 -- loss = 13145.73046875
41/400 -- loss = 12886.505859375
51/400 -- loss = 12630.7041015625
61/400 -- loss = 12377.8876953125
71/400 -- loss = 12128.1875
81/400 -- loss = 11881.7255859375
91/400 -- loss = 11639.8173828125
101/400 -- loss = 11402.3046875
111/400 -- loss = 11169.0673828125
121/400 -- loss = 10939.9560546875
131/400 -- loss = 10714.8134765625
141/400 -- loss = 10493.6220703125
151/400 -- loss = 10276.2314453125
161/400 -- loss = 10062.6640625
171/400 -- loss = 9852.806640625
181/400 -- loss = 9646.6220703125
191/400 -- loss = 9444.08203125
201/400 -- loss = 9245.1083984375
211/400 -- loss = 9049.7412109375
221/400 -- loss = 8857.802734375
231/400 -- loss = 8669.3798828125
241/400 -- loss = 8484.2763671875
251/400 -- loss = 8302.5654296875
261/400 -- loss = 8124.14208984375
271/400 -- loss = 7949.01123046875
281/400 -- loss = 7777.11083984375


 21%|██        | 4/19 [01:05<04:05, 16.35s/it]

400/400 -- loss = 5963.8193359375
1/400 -- loss = 40192.0
11/400 -- loss = 39406.67578125
21/400 -- loss = 38629.95703125
31/400 -- loss = 37863.671875
41/400 -- loss = 37107.84375
51/400 -- loss = 36363.64453125
61/400 -- loss = 35630.54296875
71/400 -- loss = 34908.859375
81/400 -- loss = 34199.24609375
91/400 -- loss = 33501.6875
101/400 -- loss = 32815.9453125
111/400 -- loss = 32141.7109375
121/400 -- loss = 31478.888671875
131/400 -- loss = 30827.306640625
141/400 -- loss = 30186.76953125
151/400 -- loss = 29557.216796875
161/400 -- loss = 28938.427734375
171/400 -- loss = 28330.326171875
181/400 -- loss = 27732.734375
191/400 -- loss = 27145.537109375
201/400 -- loss = 26568.609375
211/400 -- loss = 26001.83203125
221/400 -- loss = 25445.056640625
231/400 -- loss = 24898.189453125
241/400 -- loss = 24361.056640625
251/400 -- loss = 23833.568359375
261/400 -- loss = 23315.572265625
271/400 -- loss = 22807.0
281/400 -- loss = 22307.66796875
291/400 -- loss = 21817.515625
301/400 -

 26%|██▋       | 5/19 [01:21<03:47, 16.23s/it]

400/400 -- loss = 17035.833984375
1/400 -- loss = 38300.0
11/400 -- loss = 37552.55859375
21/400 -- loss = 36813.68359375
31/400 -- loss = 36083.83984375
41/400 -- loss = 35363.6875
51/400 -- loss = 34654.46875
61/400 -- loss = 33956.6953125
71/400 -- loss = 33270.35546875
81/400 -- loss = 32595.283203125
91/400 -- loss = 31931.455078125
101/400 -- loss = 31278.8828125
111/400 -- loss = 30637.443359375
121/400 -- loss = 30006.98828125
131/400 -- loss = 29387.373046875
141/400 -- loss = 28778.45703125
151/400 -- loss = 28180.044921875
161/400 -- loss = 27592.037109375
171/400 -- loss = 27014.3203125
181/400 -- loss = 26446.765625
191/400 -- loss = 25889.23828125
201/400 -- loss = 25341.583984375
211/400 -- loss = 24803.75390625
221/400 -- loss = 24275.560546875
231/400 -- loss = 23756.912109375
241/400 -- loss = 23247.669921875
251/400 -- loss = 22747.75390625
261/400 -- loss = 22257.01171875
271/400 -- loss = 21775.32421875
281/400 -- loss = 21302.609375
291/400 -- loss = 20838.7304687

 32%|███▏      | 6/19 [01:37<03:29, 16.13s/it]

400/400 -- loss = 16323.8994140625
1/400 -- loss = 81021.0
11/400 -- loss = 79434.4765625
21/400 -- loss = 77865.4375
31/400 -- loss = 76317.640625
41/400 -- loss = 74791.46875
51/400 -- loss = 73288.5390625
61/400 -- loss = 71809.6328125
71/400 -- loss = 70354.828125
81/400 -- loss = 68924.015625
91/400 -- loss = 67516.9453125
101/400 -- loss = 66133.3125
111/400 -- loss = 64772.83984375
121/400 -- loss = 63435.16796875
131/400 -- loss = 62120.1328125
141/400 -- loss = 60827.42578125
151/400 -- loss = 59556.78125
161/400 -- loss = 58307.94921875
171/400 -- loss = 57080.640625
181/400 -- loss = 55874.60546875
191/400 -- loss = 54689.63671875
201/400 -- loss = 53525.4140625
211/400 -- loss = 52381.703125
221/400 -- loss = 51258.30859375
231/400 -- loss = 50154.90234375
241/400 -- loss = 49071.28125
251/400 -- loss = 48007.17578125
261/400 -- loss = 46962.3359375
271/400 -- loss = 45936.58984375
281/400 -- loss = 44929.6015625
291/400 -- loss = 43941.203125
301/400 -- loss = 42971.113281

 37%|███▋      | 7/19 [01:53<03:12, 16.05s/it]

400/400 -- loss = 34306.4140625
1/400 -- loss = 35585.0
11/400 -- loss = 34902.06640625
21/400 -- loss = 34224.30078125
31/400 -- loss = 33554.2578125
41/400 -- loss = 32896.4921875
51/400 -- loss = 32251.2890625
61/400 -- loss = 31617.974609375
71/400 -- loss = 30997.333984375
81/400 -- loss = 30388.994140625
91/400 -- loss = 29792.697265625
101/400 -- loss = 29208.15234375
111/400 -- loss = 28635.181640625
121/400 -- loss = 28073.529296875
131/400 -- loss = 27522.986328125
141/400 -- loss = 26983.40234375
151/400 -- loss = 26454.556640625
161/400 -- loss = 25936.212890625
171/400 -- loss = 25428.267578125
181/400 -- loss = 24930.51953125
191/400 -- loss = 24442.771484375
201/400 -- loss = 23964.84765625
211/400 -- loss = 23496.623046875
221/400 -- loss = 23037.916015625
231/400 -- loss = 22588.564453125
241/400 -- loss = 22148.431640625
251/400 -- loss = 21717.3671875
261/400 -- loss = 21295.181640625
271/400 -- loss = 20881.76171875
281/400 -- loss = 20476.95703125
291/400 -- loss =

 42%|████▏     | 8/19 [02:09<02:55, 15.98s/it]

400/400 -- loss = 16272.185546875
1/400 -- loss = 24505.0
11/400 -- loss = 24029.681640625
21/400 -- loss = 23558.556640625
31/400 -- loss = 23093.861328125
41/400 -- loss = 22635.80078125
51/400 -- loss = 22184.384765625
61/400 -- loss = 21738.974609375
71/400 -- loss = 21300.6328125
81/400 -- loss = 20870.18359375
91/400 -- loss = 20447.375
101/400 -- loss = 20031.90234375
111/400 -- loss = 19623.646484375
121/400 -- loss = 19222.498046875
131/400 -- loss = 18828.26953125
141/400 -- loss = 18440.931640625
151/400 -- loss = 18060.326171875
161/400 -- loss = 17686.423828125
171/400 -- loss = 17319.107421875
181/400 -- loss = 16958.275390625
191/400 -- loss = 16603.830078125
201/400 -- loss = 16255.7041015625
211/400 -- loss = 15913.8369140625
221/400 -- loss = 15578.11328125
231/400 -- loss = 15248.4794921875
241/400 -- loss = 14924.8125
251/400 -- loss = 14607.06640625
261/400 -- loss = 14295.16015625
271/400 -- loss = 13989.0146484375
281/400 -- loss = 13688.5595703125
291/400 -- los

 47%|████▋     | 9/19 [02:24<02:37, 15.70s/it]

400/400 -- loss = 10522.796875
1/400 -- loss = 15275.0
11/400 -- loss = 14989.8828125
21/400 -- loss = 14707.8310546875
31/400 -- loss = 14431.5654296875
41/400 -- loss = 14162.01953125
51/400 -- loss = 13899.4287109375
61/400 -- loss = 13642.23046875
71/400 -- loss = 13390.5517578125
81/400 -- loss = 13145.55078125
91/400 -- loss = 12907.3203125
101/400 -- loss = 12675.634765625
111/400 -- loss = 12450.3037109375
121/400 -- loss = 12231.0703125
131/400 -- loss = 12017.7294921875
141/400 -- loss = 11810.1591796875
151/400 -- loss = 11608.1904296875
161/400 -- loss = 11411.662109375
171/400 -- loss = 11220.4658203125
181/400 -- loss = 11034.439453125
191/400 -- loss = 10853.4794921875
201/400 -- loss = 10677.466796875
211/400 -- loss = 10506.2841796875
221/400 -- loss = 10339.76171875
231/400 -- loss = 10177.876953125
241/400 -- loss = 10020.435546875
251/400 -- loss = 9867.384765625
261/400 -- loss = 9718.603515625
271/400 -- loss = 9573.9912109375
281/400 -- loss = 9433.447265625
291/

 53%|█████▎    | 10/19 [02:39<02:20, 15.57s/it]

400/400 -- loss = 8040.50732421875
1/400 -- loss = 33248.0
11/400 -- loss = 32601.6953125
21/400 -- loss = 31962.57421875
31/400 -- loss = 31330.5625
41/400 -- loss = 30707.296875
51/400 -- loss = 30093.484375
61/400 -- loss = 29489.810546875
71/400 -- loss = 28896.681640625
81/400 -- loss = 28314.072265625
91/400 -- loss = 27741.67578125
101/400 -- loss = 27179.318359375
111/400 -- loss = 26626.888671875
121/400 -- loss = 26084.1640625
131/400 -- loss = 25551.01953125
141/400 -- loss = 25027.32421875
151/400 -- loss = 24512.990234375
161/400 -- loss = 24007.861328125
171/400 -- loss = 23511.794921875
181/400 -- loss = 23024.685546875
191/400 -- loss = 22546.416015625
201/400 -- loss = 22076.83984375
211/400 -- loss = 21615.8671875
221/400 -- loss = 21163.400390625
231/400 -- loss = 20719.28515625
241/400 -- loss = 20283.427734375
251/400 -- loss = 19855.693359375
261/400 -- loss = 19436.009765625
271/400 -- loss = 19024.216796875
281/400 -- loss = 18620.283203125
291/400 -- loss = 182

 58%|█████▊    | 11/19 [02:54<02:03, 15.48s/it]

400/400 -- loss = 14375.1494140625
1/400 -- loss = 7937.0
11/400 -- loss = 7786.2841796875
21/400 -- loss = 7638.3623046875
31/400 -- loss = 7493.52734375
41/400 -- loss = 7351.6611328125
51/400 -- loss = 7212.98583984375
61/400 -- loss = 7077.28466796875
71/400 -- loss = 6945.25537109375
81/400 -- loss = 6816.7578125
91/400 -- loss = 6690.86083984375
101/400 -- loss = 6567.2666015625
111/400 -- loss = 6447.2568359375
121/400 -- loss = 6330.8994140625
131/400 -- loss = 6217.88671875
141/400 -- loss = 6108.10546875
151/400 -- loss = 6001.45556640625
161/400 -- loss = 5897.8173828125
171/400 -- loss = 5797.11962890625
181/400 -- loss = 5699.322265625
191/400 -- loss = 5604.28369140625
201/400 -- loss = 5512.021484375
211/400 -- loss = 5422.41455078125
221/400 -- loss = 5335.443359375
231/400 -- loss = 5250.99267578125
241/400 -- loss = 5169.03369140625
251/400 -- loss = 5089.49853515625
261/400 -- loss = 5012.3447265625
271/400 -- loss = 4937.46630859375
281/400 -- loss = 4864.869140625


 63%|██████▎   | 12/19 [03:10<01:47, 15.42s/it]

400/400 -- loss = 4154.5087890625
1/400 -- loss = 45289.0
11/400 -- loss = 44403.2890625
21/400 -- loss = 43527.87890625
31/400 -- loss = 42664.109375
41/400 -- loss = 41812.48046875
51/400 -- loss = 40973.28515625
61/400 -- loss = 40147.33984375
71/400 -- loss = 39333.45703125
81/400 -- loss = 38532.5390625
91/400 -- loss = 37745.01171875
101/400 -- loss = 36970.62109375
111/400 -- loss = 36209.1015625
121/400 -- loss = 35460.30078125
131/400 -- loss = 34723.9921875
141/400 -- loss = 34000.02734375
151/400 -- loss = 33288.27734375
161/400 -- loss = 32588.611328125
171/400 -- loss = 31900.8515625
181/400 -- loss = 31224.828125
191/400 -- loss = 30560.50390625
201/400 -- loss = 29907.625
211/400 -- loss = 29266.138671875
221/400 -- loss = 28635.87109375
231/400 -- loss = 28016.6796875
241/400 -- loss = 27408.46484375
251/400 -- loss = 26811.052734375
261/400 -- loss = 26224.341796875
271/400 -- loss = 25648.171875
281/400 -- loss = 25082.46875
291/400 -- loss = 24527.017578125
301/400 -

 68%|██████▊   | 13/19 [03:25<01:32, 15.37s/it]

400/400 -- loss = 19105.4375
1/400 -- loss = 15711.0
11/400 -- loss = 15406.310546875
21/400 -- loss = 15105.2412109375
31/400 -- loss = 14809.1630859375
41/400 -- loss = 14517.5703125
51/400 -- loss = 14229.0966796875
61/400 -- loss = 13944.015625
71/400 -- loss = 13662.837890625
81/400 -- loss = 13386.6708984375
91/400 -- loss = 13115.646484375
101/400 -- loss = 12849.548828125
111/400 -- loss = 12588.216796875
121/400 -- loss = 12331.48046875
131/400 -- loss = 12079.298828125
141/400 -- loss = 11831.5390625
151/400 -- loss = 11588.16015625
161/400 -- loss = 11349.1123046875
171/400 -- loss = 11114.30859375
181/400 -- loss = 10883.685546875
191/400 -- loss = 10657.2001953125
201/400 -- loss = 10434.783203125
211/400 -- loss = 10216.4345703125
221/400 -- loss = 10001.9912109375
231/400 -- loss = 9791.505859375
241/400 -- loss = 9584.8779296875
251/400 -- loss = 9382.0625
261/400 -- loss = 9182.978515625
271/400 -- loss = 8987.6572265625
281/400 -- loss = 8795.9326171875
291/400 -- los

 74%|███████▎  | 14/19 [03:40<01:16, 15.36s/it]

400/400 -- loss = 6777.83447265625
1/400 -- loss = 158574.0
11/400 -- loss = 155506.9375
21/400 -- loss = 152474.625
31/400 -- loss = 149481.671875
41/400 -- loss = 146531.6875
51/400 -- loss = 143626.09375
61/400 -- loss = 140765.296875
71/400 -- loss = 137949.25
81/400 -- loss = 135177.71875
91/400 -- loss = 132450.34375
101/400 -- loss = 129766.734375
111/400 -- loss = 127126.453125
121/400 -- loss = 124529.09375
131/400 -- loss = 121974.1953125
141/400 -- loss = 119461.296875
151/400 -- loss = 116989.96875
161/400 -- loss = 114559.7734375
171/400 -- loss = 112170.2578125
181/400 -- loss = 109820.984375
191/400 -- loss = 107511.4921875
201/400 -- loss = 105241.3515625
211/400 -- loss = 103010.1484375
221/400 -- loss = 100817.390625
231/400 -- loss = 98662.6875
241/400 -- loss = 96545.5703125
251/400 -- loss = 94465.625
261/400 -- loss = 92422.4140625
271/400 -- loss = 90415.4921875
281/400 -- loss = 88444.4296875
291/400 -- loss = 86508.828125
301/400 -- loss = 84608.2421875
311/400

 79%|███████▉  | 15/19 [03:55<01:01, 15.30s/it]

400/400 -- loss = 67592.03125
1/400 -- loss = 174838.0
11/400 -- loss = 171481.15625
21/400 -- loss = 168162.234375
31/400 -- loss = 164886.109375
41/400 -- loss = 161656.546875
51/400 -- loss = 158474.953125
61/400 -- loss = 155341.921875
71/400 -- loss = 152257.484375
81/400 -- loss = 149221.390625
91/400 -- loss = 146233.28125
101/400 -- loss = 143292.75
111/400 -- loss = 140399.359375
121/400 -- loss = 137552.625
131/400 -- loss = 134752.046875
141/400 -- loss = 131997.234375
151/400 -- loss = 129287.65625
161/400 -- loss = 126622.890625
171/400 -- loss = 124002.4296875
181/400 -- loss = 121425.8125
191/400 -- loss = 118892.5703125
201/400 -- loss = 116402.2421875
211/400 -- loss = 113954.328125
221/400 -- loss = 111548.390625
231/400 -- loss = 109183.96875
241/400 -- loss = 106860.5703125
251/400 -- loss = 104577.7421875
261/400 -- loss = 102335.015625
271/400 -- loss = 100131.9453125
281/400 -- loss = 97968.0625
291/400 -- loss = 95842.921875
301/400 -- loss = 93756.03125
311/400

 84%|████████▍ | 16/19 [04:11<00:45, 15.29s/it]

400/400 -- loss = 75064.3984375
1/400 -- loss = 70780.0
11/400 -- loss = 69391.8984375
21/400 -- loss = 68020.4375
31/400 -- loss = 66667.1875
41/400 -- loss = 65331.9375
51/400 -- loss = 64015.84375
61/400 -- loss = 62719.85546875
71/400 -- loss = 61444.25
81/400 -- loss = 60188.9609375
91/400 -- loss = 58953.87109375
101/400 -- loss = 57738.76171875
111/400 -- loss = 56543.34765625
121/400 -- loss = 55367.41796875
131/400 -- loss = 54210.78125
141/400 -- loss = 53073.16796875
151/400 -- loss = 51954.37890625
161/400 -- loss = 50854.2734375
171/400 -- loss = 49772.578125
181/400 -- loss = 48709.12890625
191/400 -- loss = 47663.70703125
201/400 -- loss = 46636.11328125
211/400 -- loss = 45626.15234375
221/400 -- loss = 44633.60546875
231/400 -- loss = 43658.30078125
241/400 -- loss = 42700.0390625
251/400 -- loss = 41758.62109375
261/400 -- loss = 40833.8203125
271/400 -- loss = 39925.45703125
281/400 -- loss = 39033.35546875
291/400 -- loss = 38157.3046875
301/400 -- loss = 37297.1132

 89%|████████▉ | 17/19 [04:26<00:30, 15.24s/it]

400/400 -- loss = 29596.275390625
1/400 -- loss = 32409.0
11/400 -- loss = 31774.984375
21/400 -- loss = 31148.701171875
31/400 -- loss = 30530.177734375
41/400 -- loss = 29920.939453125
51/400 -- loss = 29321.287109375
61/400 -- loss = 28730.58984375
71/400 -- loss = 28148.91796875
81/400 -- loss = 27576.05078125
91/400 -- loss = 27012.7265625
101/400 -- loss = 26459.115234375
111/400 -- loss = 25915.041015625
121/400 -- loss = 25380.30078125
131/400 -- loss = 24854.681640625
141/400 -- loss = 24338.087890625
151/400 -- loss = 23830.412109375
161/400 -- loss = 23331.546875
171/400 -- loss = 22841.337890625
181/400 -- loss = 22359.71484375
191/400 -- loss = 21886.53125
201/400 -- loss = 21421.689453125
211/400 -- loss = 20965.1171875
221/400 -- loss = 20516.705078125
231/400 -- loss = 20076.326171875
241/400 -- loss = 19643.892578125
251/400 -- loss = 19219.296875
261/400 -- loss = 18802.45703125
271/400 -- loss = 18393.271484375
281/400 -- loss = 17991.60546875
291/400 -- loss = 17597

 95%|█████████▍| 18/19 [04:41<00:15, 15.30s/it]

400/400 -- loss = 13756.6396484375
1/400 -- loss = 7108.0
11/400 -- loss = 6988.078125
21/400 -- loss = 6869.2529296875
31/400 -- loss = 6753.07177734375
41/400 -- loss = 6641.2666015625
51/400 -- loss = 6534.30419921875
61/400 -- loss = 6432.03515625
71/400 -- loss = 6333.37548828125
81/400 -- loss = 6238.728515625
91/400 -- loss = 6148.185546875
101/400 -- loss = 6061.544921875
111/400 -- loss = 5978.66552734375
121/400 -- loss = 5899.48388671875
131/400 -- loss = 5823.98388671875
141/400 -- loss = 5751.904296875
151/400 -- loss = 5683.04248046875
161/400 -- loss = 5617.17431640625
171/400 -- loss = 5554.21923828125
181/400 -- loss = 5494.00390625
191/400 -- loss = 5436.4013671875
201/400 -- loss = 5381.263671875
211/400 -- loss = 5328.50048828125
221/400 -- loss = 5278.01123046875
231/400 -- loss = 5229.65966796875
241/400 -- loss = 5183.36865234375
251/400 -- loss = 5139.0302734375
261/400 -- loss = 5096.5654296875
271/400 -- loss = 5055.88720703125
281/400 -- loss = 5016.896484375

100%|██████████| 19/19 [04:59<00:00, 15.76s/it]

400/400 -- loss = 4659.61669921875





# Evaluation

In [13]:
def Mean_IOU(y_true, y_pred):
    nb_classes = 21  # TODO: set this as a parameter
    ious = []
    for i in range(0, nb_classes):  # exclude last label (void)
        y_true_squeeze = tf.squeeze(y_true)
        y_pred_squeeze = tf.squeeze(y_pred)
        true_labels = tf.equal(y_true_squeeze, i)
        pred_labels = tf.equal(y_pred_squeeze, i)
        inter = tf.cast(true_labels & pred_labels, tf.int32)
        union = tf.cast(true_labels | pred_labels, tf.int32)

        iou = tf.reduce_sum(inter) / tf.reduce_sum(union)
        # returns average IoU of the same objects
        ious.append(iou)

    ious = tf.stack(ious)
    legal_labels = ~tf.math.is_nan(ious)
    ious = tf.gather(ious, indices=tf.where(legal_labels))
    return tf.reduce_mean(ious)


def custom_IOU(y_true, y_pred, class_id):
    y_true_squeeze = tf.squeeze(y_true)
    y_pred_squeeze = tf.squeeze(y_pred)
    classes = [0, class_id] # Only check in background and given class

    y_true_squeeze = tf.where(y_true_squeeze != class_id, 0, y_true_squeeze)

    ious = []
    for i in classes:
        true_labels = tf.equal(y_true_squeeze, i)
        pred_labels = tf.equal(y_pred_squeeze, i)
        inter = tf.cast(true_labels & pred_labels, tf.int32)
        union = tf.cast(true_labels | pred_labels, tf.int32)

        iou = tf.reduce_sum(inter) / tf.reduce_sum(union)
        ious.append(iou)

    ious = tf.stack(ious)
    legal_labels = ~tf.math.is_nan(ious)
    ious = tf.gather(ious, indices=tf.where(legal_labels))
    return tf.reduce_mean(ious)


def evaluate_IOU(true_mask, standard_mask, superres_mask, img_size=(512, 512)):
    true_mask = tf.reshape(true_mask, (img_size[0] * img_size[1], 1))
    standard_mask = tf.reshape(standard_mask, (img_size[0] * img_size[1], 1))
    superres_mask = tf.reshape(superres_mask, (img_size[0] * img_size[1], 1))

    # standard_IOU = Mean_IOU(true_mask, standard_mask)
    # superres_IOU = Mean_IOU(true_mask, superres_mask)

    standard_IOU = custom_IOU(true_mask, standard_mask, class_id=CLASS_ID)
    superres_IOU = custom_IOU(true_mask, superres_mask, class_id=CLASS_ID)

    return standard_IOU.numpy(), superres_IOU.numpy()

def compare_results(image_dict, standard_dict, superres_dict, image_size=(512, 512)):
    standard_IOUs = []
    superres_IOUs = []

    for key in image_dict:
        true_mask_path = os.path.join(DATA_DIR, "VOCdevkit/VOC2012/SegmentationClassAug", f"{key}.png")
        true_mask = load_image(true_mask_path, image_size=image_size, normalize=False,
                               is_png=True, resize_method="nearest")

        standard_mask = standard_dict[key]
        superres_image = superres_dict[key]

        standard_IOU, superres_IOU = evaluate_IOU(true_mask, standard_mask, superres_image, img_size=image_size)
        standard_IOUs.append(standard_IOU)
        superres_IOUs.append(superres_IOU)
        print(f"IOUs for image {key} - Standard: {str(standard_IOU)}, Superres: {str(superres_IOU)}")

    return standard_IOUs, superres_IOUs

In [14]:
superres_masks_dict_th = {}

for key in superres_masks_dict:
    sample_th = tf.cast(tf.reduce_max(superres_masks_dict[key]), tf.float32) * 0.15
    th_mask = tf.where(superres_masks_dict[key] > sample_th, CLASS_ID, 0)
    superres_masks_dict_th[key] = th_mask

standard_IOUs, superres_IOUs = compare_results(images_dict, standard_masks_dict, superres_masks_dict_th, image_size=IMG_SIZE)

IOUs for image 2008_002152 - Standard: 0.9085044608441453, Superres: 0.8785154176785852
IOUs for image 2007_008815 - Standard: 0.8569716559605662, Superres: 0.8741334572472753
IOUs for image 2008_000345 - Standard: 0.9379868804976746, Superres: 0.9282654059331144
IOUs for image 2010_005421 - Standard: 0.8882681156496353, Superres: 0.8795964839128017
IOUs for image 2010_002531 - Standard: 0.9380899481960936, Superres: 0.9183538439434562
IOUs for image 2010_001351 - Standard: 0.9413018640887326, Superres: 0.9238459305879041
IOUs for image 2010_002025 - Standard: 0.9285347152054473, Superres: 0.9152976890914408
IOUs for image 2011_000661 - Standard: 0.8515296926500249, Superres: 0.7888941627223187
IOUs for image 2009_000080 - Standard: 0.9089036926950496, Superres: 0.8707403819896253
IOUs for image 2010_000724 - Standard: 0.7795417381062342, Superres: 0.7452689934024626
IOUs for image 2007_009346 - Standard: 0.918779741187651, Superres: 0.9062308822417484
IOUs for image 2010_001913 - Stan

In [15]:
np.mean(standard_IOUs)

0.8809881397904697

In [16]:
np.mean(superres_IOUs)

0.8583120755924544

# Tests

In [None]:
def plot_standard_superres(input_image, standard_mask, superres_mask):
    plt.figure(figsize=(18, 18))

    plt.subplot(1, 3, 1)
    plt.title("Input Image")
    plt.imshow(tf.keras.preprocessing.image.array_to_img(input_image))
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title("Sandard predicted Mask")
    plt.imshow(tf.keras.preprocessing.image.array_to_img(input_image))
    plt.imshow(tf.keras.preprocessing.image.array_to_img(standard_mask), alpha=0.5)
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title("Superresolution Mask")
    plt.imshow(tf.keras.preprocessing.image.array_to_img(input_image))
    plt.imshow(tf.keras.preprocessing.image.array_to_img(superres_mask), alpha=0.5)
    plt.axis('off')

    plt.show()


def plot_histogram(image):
    plt.figure(figsize=(18, 18))
    vals = image.flatten()
    b, bins, patches = plt.hist(vals, 255)
    plt.show()


def print_labels(masks):
    title = ["Standard Labels: ", "Superres Labels: "]
    for i in range(2):
        values, count = np.unique(masks[i], return_counts=True)
        print(title[i] + str(dict(zip(values, count))))

In [None]:
sample_key = random.choice(valid_filenames)
sample_image = images_dict[sample_key]
sample_standard = standard_masks_dict[sample_key]
sample_superres = superres_masks_dict_th[sample_key]

true_mask_path = os.path.join(DATA_DIR, "VOCdevkit/VOC2012/SegmentationClassAug", f"{sample_key}.png")
true_mask = load_image(true_mask_path, image_size=IMG_SIZE, normalize=False, is_png=True, resize_method="nearest")

plot_prediction([sample_image, true_mask, sample_standard], only_prediction=False, show_overlay=True)
print_labels([true_mask, sample_standard])

In [None]:
plot_standard_superres(sample_image, sample_standard, sample_superres)

In [None]:
superres_numpy = sample_superres
plot_histogram(superres_numpy)

In [None]:
sample_th = np.max(superres_numpy) * 0.15
th_mask = tf.where(sample_superres > sample_th, sample_superres, 0).numpy()
th_mask_class = tf.where(sample_superres > sample_th, CLASS_ID, 0).numpy()

In [None]:
plot_histogram(th_mask)

In [None]:
plot_standard_superres(sample_image, sample_standard, sample_superres)
plot_standard_superres(sample_image, sample_standard, th_mask_class)

In [None]:
evaluate_IOU(true_mask, sample_standard, sample_superres)

In [None]:
np.unique(true_mask)