In [None]:
import os
import gc
import random
import numpy as np
import seaborn as sns
from tqdm import tqdm
import tensorflow as tf
from model import DeeplabV3Plus
import tensorflow_addons as tfa
from matplotlib import pyplot as plt
from superresolution_scripts.optimizer import Optimizer
from superresolution_scripts.superresolution import Superresolution
from utils import *
from superresolution_scripts.superres_utils import *

SEED = 1234

# Pick the GPU and request on-demand GPU memory allocation first, so both settings are
# already in the environment when TensorFlow is first asked to do any work below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

# Seed every random number generator this notebook imports (Python, NumPy, TensorFlow)
# so that the random augmentations are reproducible after Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# --- Model / data geometry ---
MODEL_BACKBONE = "xception"   # backbone passed to DeeplabV3Plus
IMG_SIZE = (512, 512)         # size images are loaded at
FEATURE_SIZE = (128, 128)     # feature size handed to Superresolution
BATCH_SIZE = 16               # batch size used for model.predict

# --- What to extract ---
CLASS_ID = 8                  # class whose mask is extracted from each prediction
MODE = "argmax"               # mask extraction mode used by compute_augmented_features
NUM_SAMPLES = 500
USE_VALIDATION = False        # when True, output folder names get a "_validation" suffix

# --- Augmentation ---
NUM_AUG = 100                 # number of augmented copies per image
ANGLE_MAX = 0.15              # rotation angles are sampled in [-ANGLE_MAX, ANGLE_MAX]
SHIFT_MAX = 80                # translations are sampled in [-SHIFT_MAX, SHIFT_MAX]

# --- Super-resolution output ---
TH_FACTOR = 0.2               # forwarded to compute_SR as th_factor
SAVE_SLICE_OUTPUT = False     # forwarded to compute_SR as save_intermediate_output
SAVE_FINAL_SR_OUTPUT = True   # forwarded to compute_SR as save_final_output

In [None]:
# Input locations (Pascal VOC 2012 layout under ./data)
DATA_DIR = os.path.join(os.getcwd(), "data")
PASCAL_ROOT = os.path.join(DATA_DIR, "dataset_root", "VOCdevkit", "VOC2012")
IMGS_PATH = os.path.join(PASCAL_ROOT, "JPEGImages")

# Output locations, all rooted in ./thesis_images_folder
TEST_FOLDER_ROOT = os.path.join(os.getcwd(), "thesis_images_folder")
AUGMENTED_IMAGE_FOLDER = os.path.join(TEST_FOLDER_ROOT, f"augmented_images_{MODE}_{CLASS_ID}")

# Folder names carry a "_validation" suffix when working on the validation split
validation_suffix = "_validation" if USE_VALIDATION else ""

STANDARD_OUTPUT_ROOT = os.path.join(TEST_FOLDER_ROOT, "standard_output", "standard_output")
STANDARD_OUTPUT_DIR = os.path.join(
    STANDARD_OUTPUT_ROOT, f"{MODEL_BACKBONE}_{CLASS_ID}{validation_suffix}")


SUPERRES_ROOT = os.path.join(TEST_FOLDER_ROOT, "superres_root")
SUPERRES_OUTPUT_DIR = os.path.join(
    SUPERRES_ROOT, f"superres_output{validation_suffix}")

# exist_ok=True avoids the check-then-create race and still raises if the path
# exists but is not a directory.
os.makedirs(AUGMENTED_IMAGE_FOLDER, exist_ok=True)
os.makedirs(SUPERRES_OUTPUT_DIR, exist_ok=True)

In [None]:
# Name (without extension) of the single JPEG image processed by this notebook
filename = "2007_000528"
image_path = os.path.join(IMGS_PATH, filename + ".jpg")

In [None]:
def create_augmented_copies(image, num_aug, angle_max, shift_max):
    """Build a batch of randomly rotated and translated copies of one image.

    Args:
        image: A single image. It receives a leading batch axis and is tiled with
            multiples [num_aug, 1, 1, 1], so it is expected to be rank 3.
        num_aug: Number of copies to produce (the untouched copy included).
        angle_max: Rotation angles are drawn uniformly from [-angle_max, angle_max).
        shift_max: Each shift component is drawn uniformly from [-shift_max, shift_max).

    Returns:
        A tuple (images, angles, shifts): the augmented batch, the float32 angles with
        shape (num_aug,) and the float32 shifts with shape (num_aug, 2). Entry 0 always
        has zero angle and zero shift.
    """
    with_batch_axis = tf.expand_dims(image, axis=0)
    image_stack = tf.tile(with_batch_axis, [num_aug, 1, 1, 1])

    # Draw order (angles, then shifts) is kept so a seeded global NumPy RNG yields the same values
    rotation_angles = np.random.uniform(-angle_max, angle_max, num_aug)
    translations = np.random.uniform(-shift_max, shift_max, (num_aug, 2))

    # The first copy stays identical to the input image
    rotation_angles[0] = 0
    translations[0] = np.array([0, 0])

    rotation_angles = rotation_angles.astype("float32")
    translations = translations.astype("float32")

    # Rotate first, then translate the rotated result
    rotated_stack = tfa.image.rotate(
        image_stack, rotation_angles, interpolation="bilinear")
    augmented_stack = tfa.image.translate(
        rotated_stack, translations, interpolation="bilinear")

    return augmented_stack, rotation_angles, translations

In [None]:
def compute_augmented_features(images_paths, model, dest_folder, filter_class_id, mode="slice", num_aug=100,
                               angle_max=0.15, shift_max=80, save_output=False, image_size=(512, 512),
                               batch_size=None):
    """Run the model on augmented copies of each image and extract per-copy class masks.

    For every path in `images_paths` the image is loaded, `num_aug` augmented copies are
    created with `create_augmented_copies`, the model predicts on all of them and one mask
    per copy is extracted according to `mode`.

    Args:
        images_paths: Iterable of image file paths.
        model: Object exposing `predict(inputs, batch_size=...)`.
        dest_folder: Root folder; masks go to `<dest_folder>/<mode>/<image name>/`.
            The folder is created even when nothing is saved.
        filter_class_id: Index of the class of interest along the last prediction axis.
        mode: Mask extraction strategy.
            - "slice": the class channel is kept (with a trailing axis) and, in addition,
              the pixel-wise maximum over all the other channels is stored in `max_masks`.
            - "slice_var": the class channel is rescaled with `min_max_normalization`
              to [0, 1] using the global min/max of the whole prediction.
            - anything else (e.g. "argmax"): `create_mask(prediction)` is computed, every
              value different from `filter_class_id` is set to 0 and the result is
              returned as a float32 NumPy array.
        num_aug: Number of augmented copies per image.
        angle_max: Maximum absolute rotation passed to `create_augmented_copies`.
        shift_max: Maximum absolute shift passed to `create_augmented_copies`.
        save_output: When truthy, every mask is written as `<i>_class.png`
            (and `<i>_max.png` in "slice" mode) inside the output folder.
        image_size: Size passed to `load_image`.
        batch_size: Batch size for `model.predict`. Defaults to the notebook-level
            `BATCH_SIZE` when left as None.

    Returns:
        A tuple (class_masks, max_masks, angles, shifts) for the LAST image processed:
        the list of class masks, the list of max masks (empty unless mode is "slice"),
        and the angles/shifts used for the augmentation. If `images_paths` is empty the
        result is ([], [], None, None).
    """
    if batch_size is None:
        # Keep the previous behaviour: fall back to the notebook-level configuration
        batch_size = BATCH_SIZE

    # Defined up-front so an empty `images_paths` returns empty results instead of
    # failing on unbound locals at the return statement.
    class_masks, max_masks, angles, shifts = [], [], None, None

    for image_path in tqdm(images_paths):
        image_name = os.path.splitext(os.path.basename(image_path))[0]

        # Load image
        image = load_image(image_path, image_size=image_size, normalize=True)
        # Create augmented copies
        augmented_copies, angles, shifts = create_augmented_copies(image, num_aug=num_aug, angle_max=angle_max,
                                                                   shift_max=shift_max)

        # Create destination folder
        output_folder = os.path.join(dest_folder, mode, image_name)
        os.makedirs(output_folder, exist_ok=True)

        class_masks = []
        max_masks = []

        predictions = model.predict(augmented_copies, batch_size=batch_size)
        # Used to clear memory as it appears that there is a memory leak with something related to model.predict
        gc.collect()

        if mode == "slice":
            # Indexes of every channel except the class of interest; identical for all
            # predictions, so computed once per image instead of once per copy.
            other_class_indexes = np.delete(
                np.arange(np.shape(predictions)[-1]), filter_class_id)

        for i, prediction in enumerate(predictions):
            if mode == "slice":
                # Get the slice corresponding to the class id
                class_mask = tf.gather(
                    prediction, filter_class_id, axis=-1)[..., tf.newaxis]

                # Get all the other slices and compute the max pixel-wise
                max_mask = tf.reduce_max(
                    tf.gather(prediction, other_class_indexes, axis=-1), axis=-1)[..., tf.newaxis]

                max_masks.append(max_mask)

            elif mode == "slice_var":
                # Get the slice corresponding to the class id
                class_mask = tf.gather(
                    prediction, filter_class_id, axis=-1)[..., tf.newaxis]

                global_max = tf.reduce_max(prediction)
                global_min = tf.reduce_min(prediction)

                class_mask = min_max_normalization(class_mask.numpy(), new_min=0.0, new_max=1.0, global_min=global_min,
                                                   global_max=global_max)

            else:
                class_mask = create_mask(prediction)
                # Set to 0 all predictions different from the given class
                class_mask = tf.where(
                    class_mask == filter_class_id, class_mask, 0)
                # Necessary for super-resolution operations
                class_mask = tf.cast(class_mask, tf.float32)
                class_mask = class_mask.numpy()

            class_masks.append(class_mask)

            # Previously written as `save_output == 0`, which saved only when the flag was False
            if save_output:
                tf.keras.utils.save_img(
                    f"{output_folder}/{i}_class.png", class_mask, scale=True)
                if mode == "slice":
                    tf.keras.utils.save_img(
                        f"{output_folder}/{i}_max.png", max_mask, scale=True)

    return class_masks, max_masks, angles, shifts

In [None]:
# Build the segmentation network. The input shape is derived from IMG_SIZE so the model
# always matches the size images are loaded at, instead of repeating the literal 512x512.
# NOTE(review): final_upsample=False presumably leaves the output at feature resolution
# (cf. FEATURE_SIZE) - confirm against the DeeplabV3Plus implementation.
model = DeeplabV3Plus(
    input_shape=(*IMG_SIZE, 3),
    classes=21,
    OS=16,
    last_activation=None,
    load_weights=True,
    backbone=MODEL_BACKBONE,
    alpha=1.).build_model(final_upsample=False)

In [None]:
# Extract the per-copy masks (plus the angles/shifts that generated them) for the selected image
class_masks, max_masks, angles, shifts = compute_augmented_features(
    images_paths=[image_path],
    model=model,
    dest_folder=AUGMENTED_IMAGE_FOLDER,
    filter_class_id=CLASS_ID,
    mode=MODE,
    num_aug=NUM_AUG,
    angle_max=ANGLE_MAX,
    shift_max=SHIFT_MAX,
    save_output=True,
    image_size=IMG_SIZE,
)


In [None]:
# Weights of the individual loss terms, unpacked into Superresolution below
coeff_dict = {
    "lambda_df": 1.0,
    "lambda_tv": 0.3,
    "lambda_L2": 0.7,
    "lambda_L1": 0,
}

# NOTE(review): decay_rate=30 looks unusually large if this is a multiplicative decay
# factor - confirm against the Optimizer implementation.
optimizer_obj = Optimizer(optimizer="adam", learning_rate=1e-3, amsgrad=True, lr_scheduler=True, decay_steps=60, decay_rate=30)

# num_aug comes from NUM_AUG (instead of a repeated literal 100) so it always matches the
# number of augmented copies produced by compute_augmented_features.
superresolution_obj = Superresolution(**coeff_dict, num_iter=300, num_aug=NUM_AUG, optimizer=optimizer_obj, feature_size=FEATURE_SIZE)

In [None]:
# true_mask_path = os.path.join(
#     PASCAL_ROOT, "SegmentationClassAug", f"{filename}.png")
# true_mask = load_image(true_mask_path, image_size=IMG_SIZE, normalize=False,
#                         is_png=True, resize_method="nearest")

# standard_mask_path = os.path.join(
#     STANDARD_OUTPUT_DIR, f"{filename}.png")
# standard_mask = load_image(standard_mask_path, image_size=IMG_SIZE, normalize=False, is_png=True,
#                             resize_method="nearest")

# NOTE(review): this discards the max masks returned by compute_augmented_features;
# every compute_SR call below therefore receives max_masks=None - confirm this is intended.
max_masks = None

# Positional inputs and keyword options shared by all the super-resolution runs
sr_inputs = (superresolution_obj, class_masks, angles, shifts, filename)
sr_shared_kwargs = dict(
    max_masks=max_masks,
    save_final_output=SAVE_FINAL_SR_OUTPUT,
    save_intermediate_output=SAVE_SLICE_OUTPUT,
    class_id=CLASS_ID,
    dest_folder=SUPERRES_OUTPUT_DIR,
    th_factor=TH_FACTOR,
)

# One run per SR_type: "aug", "max" and "mean"
target_augmented_SR = compute_SR(*sr_inputs, SR_type="aug", **sr_shared_kwargs)

target_max_SR = compute_SR(*sr_inputs, SR_type="max", **sr_shared_kwargs)

target_mean_SR = compute_SR(*sr_inputs, SR_type="mean", **sr_shared_kwargs)