In [None]:
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import tensorflow as tf
import tensorflow_addons as tfa
from model import DeeplabV3Plus
from superresolution_scripts.superresolution import Superresolution
from superresolution_scripts.optimizer import Optimizer
from utils import *
from superresolution_scripts.superres_utils import list_precomputed_data_paths, load_SR_data, normalize_coefficients, threshold_image, get_img_paths, filter_images_by_class
# Restrict CUDA to the device with index 0.
# NOTE(review): this assignment runs after `import tensorflow` and the project
# imports above; presumably GPU discovery happens lazily so it still takes
# effect — confirm, or move it above the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# One seed shared by every random number generator this notebook imports, so
# that "Restart Kernel -> Run All" reproduces the same augmentation draws.
SEED = 1234

# The stdlib generator was imported but never seeded; it is only referenced by
# the commented-out `random.choice` image pick, which would otherwise be
# non-reproducible as soon as it is re-enabled.
random.seed(SEED)
np.random.seed(SEED)      # drives the rotation angles and shifts drawn below
tf.random.set_seed(SEED)  # TensorFlow global seed

In [None]:
# ---------------------------------------------------------------------------
# Sizes and counts
# ---------------------------------------------------------------------------
IMG_SIZE = (512, 512)       # passed as `image_size` to load_image
FEATURE_SIZE = (128, 128)
NUM_AUG = 100               # part of the precomputed-output folder name
CLASS_ID = 16               # class kept in the ground-truth mask / folder name
NUM_SAMPLES = 5

# ---------------------------------------------------------------------------
# Run options
# ---------------------------------------------------------------------------
MODE = "argmax"
MODEL_BACKBONE = "xception"
USE_VALIDATION = True
SAVE_SLICE_OUTPUT = False

# Folder names carry this suffix when the validation split is selected.
_split_suffix = "_validation" if USE_VALIDATION else ""

# ---------------------------------------------------------------------------
# Dataset locations (everything lives under <cwd>/data)
# ---------------------------------------------------------------------------
DATA_DIR = os.path.join(os.getcwd(), "data")
PASCAL_ROOT = os.path.join(DATA_DIR, "dataset_root", "VOCdevkit", "VOC2012")
IMGS_PATH = os.path.join(PASCAL_ROOT, "JPEGImages")

# ---------------------------------------------------------------------------
# Super-resolution inputs and outputs
# ---------------------------------------------------------------------------
SUPERRES_ROOT = os.path.join(DATA_DIR, "superres_root")

AUGMENTED_COPIES_ROOT = os.path.join(SUPERRES_ROOT, "augmented_copies")
PRECOMPUTED_OUTPUT_DIR = os.path.join(
    AUGMENTED_COPIES_ROOT,
    "_".join([MODEL_BACKBONE, MODE, str(NUM_AUG)]) + _split_suffix,
)

STANDARD_OUTPUT_ROOT = os.path.join(SUPERRES_ROOT, "standard_output")
STANDARD_OUTPUT_DIR = os.path.join(
    STANDARD_OUTPUT_ROOT,
    "_".join([MODEL_BACKBONE, str(CLASS_ID)]) + _split_suffix,
)

SUPERRES_OUTPUT_DIR = os.path.join(
    SUPERRES_ROOT,
    "superres_output_" + MODE + _split_suffix,
)

In [None]:
# Scratch folder (under the working directory) that receives the images
# written by this notebook.
TEST_FOLDER = os.path.join(os.getcwd(), "test_folder")

# Fixed sample used for the augmentation demo.
test_image_name = "2007_000528.jpg"
# test_image_name = "2010_004327.jpg"
test_image_path = os.path.join(IMGS_PATH, test_image_name)

# Alternative: pick a random image from the (val/train) augmented file list.
# image_list_path = os.path.join(DATA_DIR, "augmented_file_lists",
#                                 f"{'valaug' if USE_VALIDATION else 'trainaug'}.txt")
# image_paths = get_img_paths(
#     image_list_path, IMGS_PATH, is_png=False, sort=True)
# test_image_path = random.choice(image_paths)
# print(test_image_path)

# Load the sample at IMG_SIZE with normalization enabled.
image = load_image(
    test_image_path,
    image_size=IMG_SIZE,
    normalize=True,
)

In [None]:
# Symmetric sampling ranges for the random augmentations: angles are drawn
# from [-angle_max, angle_max] and shifts from [-shift_max, shift_max].
# (tfa.image.rotate interprets angles as radians; shifts are handed to
# tfa.image.translate.)
angle_max, shift_max = 0.15, 80

# Number of copies of the test image to generate (copy 0 is left unchanged).
num_aug = 100

In [None]:
# Draw one rotation angle and one 2-D shift per copy. The draw order (angles
# first, then shifts) matters because both come from the global NumPy RNG.
angles = np.random.uniform(low=-angle_max, high=angle_max, size=num_aug)
shifts = np.random.uniform(low=-shift_max, high=shift_max, size=(num_aug, 2))

# Copy 0 is the un-augmented reference: no rotation, no shift.
angles[0] = 0
shifts[0, :] = 0

# The augmentation ops are fed float32 parameters.
angles = angles.astype(np.float32)
shifts = shifts.astype(np.float32)

# Stack `num_aug` identical copies of the test image along a new batch axis.
batched_images = tf.tile(image[tf.newaxis, ...], multiples=[num_aug, 1, 1, 1])

# Rotate first, then translate the rotated result.
rotated_images = tfa.image.rotate(
    batched_images, angles=angles, interpolation="bilinear")
augmented_images = tfa.image.translate(
    rotated_images, translations=shifts, interpolation="bilinear")

In [None]:
# Dump every augmented copy to <TEST_FOLDER>/augmented_images for inspection.
augmented_dir = os.path.join(TEST_FOLDER, "augmented_images")
# The target folder has to exist before the first image is written.
os.makedirs(augmented_dir, exist_ok=True)

# The loop variable is deliberately NOT called `image`: that name holds the
# source tensor loaded earlier, and overwriting it would make a re-run of the
# augmentation cell tile the last augmented copy instead of the original.
for i, augmented_image in enumerate(augmented_images):
    tf.keras.utils.save_img(
        os.path.join(augmented_dir, f"{i}_class.png"), augmented_image, scale=True)

In [None]:
# Configure the DeepLabV3+ network: 21 output classes, output stride 16, no
# final activation, weights loaded, backbone chosen by MODEL_BACKBONE.
deeplab = DeeplabV3Plus(
    input_shape=(512, 512, 3),
    classes=21,
    OS=16,
    last_activation=None,
    load_weights=True,
    backbone=MODEL_BACKBONE,
    alpha=1.,
)

# Build the model with the final upsampling stage enabled.
model = deeplab.build_model(final_upsample=True)

# Run inference on the whole batch of augmented copies.
predictions = model.predict(augmented_images)

In [None]:
# Turn each raw prediction into a mask and stack the results into one array.
# NOTE(review): create_mask comes from utils; presumably it reduces the class
# axis to a single label per pixel — confirm there.
segmentation_masks = np.array([create_mask(mask).numpy() for mask in predictions])

# Convert the augmented batch to a NumPy array for plotting. np.array accepts
# both the eager tensor and an ndarray, so re-running this cell no longer
# fails on `.numpy()` once the name has already been rebound to an ndarray.
augmented_images = np.array(augmented_images)

In [None]:
# Preview the mask of sample 0 (the un-augmented copy) at IMG_SIZE.
# The mask holds class ids, so it is resized with nearest-neighbour sampling:
# the default bilinear method would average neighbouring ids and invent
# labels along object borders.
preview_mask = tf.image.resize(segmentation_masks[0], IMG_SIZE, method="nearest")

plt.imshow(preview_mask, cmap="gray")
plt.axis("off")
plt.show()

In [None]:
from matplotlib.colors import LinearSegmentedColormap, BoundaryNorm
from matplotlib.colorbar import ColorbarBase

# Number of label values to colour (ids 0..20), matching `classes=21` used
# when the model is built.
num_classes = 21

# Start from the "jet" palette and force its first entry to opaque black so
# that label 0 is drawn black.
base_cmap = plt.cm.jet
cmaplist = [base_cmap(i) for i in range(base_cmap.N)]
cmaplist[0] = (0, 0, 0, 1.0)
# create the new map
cmap = LinearSegmentedColormap.from_list('Custom cmap', cmaplist, base_cmap.N)

# define the bins and normalize
# One bin per class needs num_classes + 1 edges (0, 1, ..., 21). With only 21
# edges (0..20) the last class id sat exactly on the upper boundary, was
# treated as out of range, and ended up with the same colour as class 19.
bounds = np.linspace(0, num_classes, num_classes + 1)
norm = BoundaryNorm(bounds, cmap.N)

In [None]:
# 2 x 4 overview: top row shows augmented inputs, bottom row the predicted
# masks for the same samples.
n_cols = 4
# The batch holds many more samples than the grid has columns; iterating over
# all of them asked matplotlib for subplot slots beyond the 8 available and
# raised. Only the first `n_cols` samples are drawn.
n_shown = min(n_cols, len(augmented_images), len(segmentation_masks))

fig, axes = plt.subplots(2, n_cols, figsize=(15, 8))

for col in range(n_shown):
    # Top row: the augmented input image.
    axes[0, col].imshow(augmented_images[col])

    # Bottom row: the input with its class mask drawn on top (alpha=1, so the
    # mask fully covers the image), coloured with the per-class colormap.
    axes[1, col].imshow(augmented_images[col])
    axes[1, col].imshow(segmentation_masks[col], cmap=cmap, norm=norm,
                        alpha=1, interpolation="nearest")

# Hide the axes everywhere, including panels left empty for a short batch.
for ax in axes.ravel():
    ax.axis("off")

# ax2 = fig.add_axes([0.95, 0.1, 0.03, 0.8])
# cb = ColorbarBase(ax2, cmap=cmap, norm=norm,
#     spacing='proportional', ticks=bounds, boundaries=bounds, format='%1i')

plt.tight_layout()
plt.show()

In [None]:
def comparison_paths(filename, class_id):
    """Return the four file paths compared for one sample.

    Args:
        filename: Image name without extension.
        class_id: Class id used in the name of the standard-output folder.

    Returns:
        Tuple ``(original, ground_truth, standard, superres)`` of paths.
    """
    split_suffix = "_validation" if USE_VALIDATION else ""
    # Same naming scheme as STANDARD_OUTPUT_DIR, but for an arbitrary class,
    # so no machine-specific absolute path is needed.
    standard_dir = os.path.join(
        STANDARD_OUTPUT_ROOT, f"{MODEL_BACKBONE}_{class_id}{split_suffix}")
    return (
        os.path.join(IMGS_PATH, f"{filename}.jpg"),
        os.path.join(PASCAL_ROOT, "SegmentationClassAug", f"{filename}.png"),
        os.path.join(standard_dir, f"{filename}.png"),
        os.path.join(SUPERRES_OUTPUT_DIR, "aug_SR", f"{filename}_aug_SR.png"),
    )


def load_comparison_images(paths, class_id):
    """Load the four images of one sample as NumPy arrays.

    Args:
        paths: Tuple as returned by ``comparison_paths``.
        class_id: Class of interest for this sample.

    Returns:
        Tuple ``(original, ground_truth, standard, superres)`` of arrays. In
        the ground truth every pixel that is not ``class_id`` is set to 0; in
        the super-resolution output the value 255 is replaced by ``class_id``.
    """
    original_path, gt_path, standard_path, superres_path = paths

    original = load_image(original_path, image_size=IMG_SIZE, normalize=True).numpy()

    gt = load_image(gt_path, image_size=IMG_SIZE, normalize=False,
                    is_png=True, resize_method="nearest").numpy()
    gt[gt != class_id] = 0.0

    # NOTE(review): these two are PNG files loaded without is_png / a nearest
    # resize, exactly as before — confirm that load_image handles this.
    standard = load_image(standard_path, image_size=IMG_SIZE, normalize=False).numpy()
    superres = load_image(superres_path, image_size=IMG_SIZE, normalize=False).numpy()
    superres[superres == 255] = class_id

    return original, gt, standard, superres


# Sample for class CLASS_ID.
filename_16 = "2008_002379"
(original_image_path_16, gt_image_path_16,
 standard_image_path_16, superres_image_path_16) = comparison_paths(filename_16, CLASS_ID)

# Sample for class 4. The standard-output path used to be a hard-coded
# absolute path on one machine; it is now derived from STANDARD_OUTPUT_ROOT.
filename_4 = "2008_001260"
(original_image_path_4, gt_image_path_4,
 standard_image_path_4, superres_image_path_4) = comparison_paths(filename_4, 4)

(original_image_16, gt_image_16,
 standard_image_16, superres_image_16) = load_comparison_images(
    (original_image_path_16, gt_image_path_16,
     standard_image_path_16, superres_image_path_16), CLASS_ID)

# The class-4 foreground was previously relabelled with CLASS_ID (a copy-paste
# slip); it now receives 4, consistent with its ground-truth mask.
(original_image_4, gt_image_4,
 standard_image_4, superres_image_4) = load_comparison_images(
    (original_image_path_4, gt_image_path_4,
     standard_image_path_4, superres_image_path_4), 4)

In [None]:
# 2 x 4 comparison figure: one row per sample, columns are
# original | ground truth | standard output | super-resolution output.
fig = plt.figure(figsize=(15,8))

# (image, colormap) per panel in row-major order; None keeps the default
# colormap, only the ground-truth masks are drawn in gray.
comparison_panels = [
    (original_image_16, None),
    (gt_image_16, "gray"),
    (standard_image_16, None),
    (superres_image_16, None),
    (original_image_4, None),
    (gt_image_4, "gray"),
    (standard_image_4, None),
    (superres_image_4, None),
]

for panel_slot, (panel_image, panel_cmap) in enumerate(comparison_panels, start=1):
    plt.subplot(2, 4, panel_slot)
    plt.imshow(panel_image, cmap=panel_cmap)
    plt.axis("off")

# plt.imshow(segmentation_masks[k], cmap=cmap, norm=norm, alpha=1, interpolation="nearest")
plt.tight_layout()