# Image Segmentation


In [None]:
# Notebook shell command: install TensorFlow Addons (imported below as `tfa`;
# NOTE(review): `tfa` does not appear to be used anywhere in the visible code).
pip install tensorflow-addons


In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras.models as models
import tensorflow.keras.backend as K
import tensorflow.keras.layers as layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, CSVLogger
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.losses import CategoricalCrossentropy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import requests
import zipfile
import pickle
import time
import random
import shutil
import glob
import tarfile
from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from functools import partial

In [None]:
# Enable on-demand memory growth on every visible GPU. Any error raised while
# configuring (e.g. because the GPUs were already initialised) is printed and
# otherwise ignored, so the notebook keeps running.
available_gpus = tf.config.experimental.list_physical_devices('GPU')
if available_gpus:
    try:
        for gpu in available_gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as e:
        print("Не вдалося налаштувати пам'ять для GPU:", e)
def set_random_seed(seed_value):
    """Seed NumPy, TensorFlow and Python's ``random`` module with ``seed_value``.

    Each library is seeded independently and in that order. A failure in one
    is printed together with the exception and does not stop the others from
    being seeded.
    """
    seeders = (
        (lambda: np.random.seed(seed_value), "Numpy не вдалося імпортувати."),
        (lambda: tf.random.set_seed(seed_value), "TensorFlow не вдалося імпортувати."),
        (lambda: random.seed(seed_value), "Модуль random не вдалося імпортувати"),
    )
    for seed_fn, failure_msg in seeders:
        try:
            seed_fn()
        except Exception as exc:
            print(failure_msg, exc)

# Global seed: used by set_random_seed and by get_subset_filenames to make the
# val/test split of val.txt reproducible.
random_seed = 4321
set_random_seed(random_seed)
print("Поточна версія TensorFlow: {}".format(tf.__version__))


In [None]:
# Download (once) and extract (once) the PASCAL VOC 2012 train/val archive.
# All paths are relative to the working directory, like orig_dir/seg_dir below.
# The previous hard-coded '/content/...' paths only matched when the notebook
# ran from /content (Colab).
archive_path = os.path.join('data', 'VOCtrainval_11-May-2012.tar')
extracted_dir = os.path.join('data', 'VOCtrainval_11-May-2012')

if not os.path.exists(archive_path):
    url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
    os.makedirs('data', exist_ok=True)
    # Download to a temporary name and rename only on success. An interrupted
    # download therefore never leaves a truncated file that the existence
    # check above would treat as complete.
    tmp_archive_path = archive_path + '.part'
    # Stream in chunks instead of buffering the whole response in memory.
    with requests.get(url, stream=True) as r:
        r.raise_for_status()  # don't save an HTML error page as a .tar
        with open(tmp_archive_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    os.replace(tmp_archive_path, archive_path)
else:
    print("Файл вже існує.")

if os.path.exists(archive_path):
    if not os.path.exists(extracted_dir):
        with tarfile.open(archive_path, 'r') as tar:
            tar.extractall(extracted_dir)
        print("Архів успішно розархівовано.")
    else:
        print("Розархівована директорія вже існує.")
else:
    print("Файл архіву не знайдено.")


## Завантаження даних
## Аугментація зображень

In [None]:
def get_subset_filenames(orig_dir, seg_dir, subset_dir, subset, seed=None):
    """Yield ``(image_path, mask_path)`` pairs for one PASCAL VOC segmentation subset.

    Args:
        orig_dir: directory holding the ``<id>.jpg`` input images.
        seg_dir: directory holding the ``<id>.png`` segmentation masks.
        subset_dir: directory containing ``train.txt`` and ``val.txt`` (one image id per line).
        subset: a name starting with ``'train'`` uses ``train.txt``. Names
            starting with ``'val'`` / ``'test'`` use the first / second half of
            ``val.txt`` after a seeded shuffle, so the two halves are disjoint
            and reproducible.
        seed: shuffle seed for the val/test split. Defaults to the module-level
            ``random_seed``.

    Yields:
        ``(os.path.join(orig_dir, id + '.jpg'), os.path.join(seg_dir, id + '.png'))``.

    Raises:
        NotImplementedError: for an unrecognised ``subset`` (raised on first iteration,
            since this is a generator).
    """
    def _read_ids(list_name):
        # One id per line; blank lines are skipped. Unlike
        # pd.read_csv(...).squeeze().tolist(), this also works for a
        # one-line file.
        with open(os.path.join(subset_dir, list_name)) as list_file:
            return [line.strip() for line in list_file if line.strip()]

    if subset.startswith('train'):
        ids = _read_ids("train.txt")
    elif subset.startswith(('val', 'test')):
        ids = _read_ids("val.txt")
        # A private RNG gives the same permutation as random.seed(seed) +
        # random.shuffle, without resetting the global `random` state.
        rng = random.Random(random_seed if seed is None else seed)
        rng.shuffle(ids)
        half = len(ids) // 2
        ids = ids[:half] if subset.startswith('val') else ids[half:]
    else:
        raise NotImplementedError("Subset={} не розпізнана".format(subset))

    for image_id in ids:
        yield os.path.join(orig_dir, f'{image_id}.jpg'), os.path.join(seg_dir, f'{image_id}.png')

def image_loader(image):
    """Load a segmentation mask file as a numpy array with label 255 replaced by 0.

    NOTE(review): mapping the 255 boundary label to 0 means the 255-ignore
    handling in the losses and metrics never sees such pixels; they are
    treated as class 0.
    """
    with Image.open(image) as png:
        mask = np.array(png)
    mask[mask == 255] = 0
    return mask

def adjust_shape(x, y, target_size):
    """Pin static shapes: ``x`` to (H, W, 3) and ``y`` to (H, W, 1), where (H, W) = ``target_size``.

    Returns the same ``(x, y)`` objects after calling ``set_shape`` on each.
    """
    height, width = target_size[0], target_size[1]
    x.set_shape((height, width, 3))
    y.set_shape((height, width, 1))
    return x, y

def generate_tf_dataset(
    subset_filename_gen_func, batch_size, epochs,
    input_size=(256, 256),  resize_to_before_crop=None,
    augmentation=False
):
    """Build a batched, repeated tf.data pipeline of (image, mask) pairs.

    Args:
        subset_filename_gen_func: zero-argument callable returning an iterator of
            ``(jpeg_path, png_path)`` string pairs (e.g. a ``partial`` of
            ``get_subset_filenames``).
        batch_size: samples per batch.
        epochs: number of times the batched dataset is repeated.
        input_size: (height, width) of the produced images and masks.
        resize_to_before_crop: (height, width) to resize to before taking a random
            ``input_size`` crop. Only used when ``augmentation`` is True and it is
            at least ``input_size`` in both dimensions and larger in one.
            Otherwise samples are simply resized.
        augmentation: enable random crop-or-resize and colour jitter (hue,
            brightness, contrast).

    Returns:
        A ``tf.data.Dataset`` of ``(images, masks)`` batches:
        - images: float32, scaled to [0, 1] (clipped after jitter), shape (batch, H, W, 3).
        - masks: float32 class ids with every size-1 axis removed by ``tf.squeeze``.
          That is the channel axis, and also the batch axis when ``batch_size == 1``.
    """
    filename_ds = tf.data.Dataset.from_generator(
        subset_filename_gen_func, output_types=(tf.string, tf.string)
    )

    # Decode once and cache: file I/O and decoding happen only on the first
    # pass, while the random augmentation below is re-drawn every epoch.
    image_ds = filename_ds.map(lambda x, y: (
        # channels=3 forces RGB, so adjust_shape's 3-channel shape always holds.
        tf.image.decode_jpeg(tf.io.read_file(x), channels=3),
        tf.numpy_function(image_loader, [y], [tf.uint8])
    )).cache()
    image_ds = image_ds.map(lambda x, y: (tf.cast(x, 'float32') / 255.0, y))

    # A random crop is only possible when the pre-crop size covers input_size
    # in both dimensions and exceeds it in at least one.
    use_random_crop = (
        augmentation
        and resize_to_before_crop is not None
        and input_size[0] <= resize_to_before_crop[0]
        and input_size[1] <= resize_to_before_crop[1]
        and (input_size[0] < resize_to_before_crop[0]
             or input_size[1] < resize_to_before_crop[1])
    )

    # NOTE: numpy_function was given a one-element Tout list, so the mask
    # arrives as a one-element structure. tf.transpose(..., [1, 2, 0]) turns it
    # into an (H, W, 1) tensor, exactly as in the original pipeline.
    def resize(x, y):
        x = tf.image.resize(x, input_size, method='bilinear')
        y = tf.cast(tf.image.resize(tf.transpose(y, [1, 2, 0]), input_size, method='nearest'), 'float32')
        return x, y

    def random_crop(x, y):
        x = tf.image.resize(x, resize_to_before_crop, method='bilinear')
        y = tf.cast(tf.image.resize(tf.transpose(y, [1, 2, 0]), resize_to_before_crop, method='nearest'), 'float32')

        # tf.random.uniform's maxval is exclusive. The +1 makes every valid
        # offset (0..size-crop) reachable and avoids an empty range when one
        # dimension needs no cropping.
        offset_h = tf.random.uniform([], 0, resize_to_before_crop[0] - input_size[0] + 1, dtype='int32')
        offset_w = tf.random.uniform([], 0, resize_to_before_crop[1] - input_size[1] + 1, dtype='int32')

        x = tf.image.crop_to_bounding_box(x, offset_h, offset_w, input_size[0], input_size[1])
        y = tf.image.crop_to_bounding_box(y, offset_h, offset_w, input_size[0], input_size[1])
        return x, y

    def crop_or_resize_randomly(x, y):
        if not use_random_crop:
            return resize(x, y)
        # 50/50 choice between a random crop and a plain resize.
        return tf.cond(
            tf.random.uniform([], 0.0, 1.0) < 0.5,
            lambda: random_crop(x, y),
            lambda: resize(x, y)
        )

    image_ds = image_ds.map(crop_or_resize_randomly)
    image_ds = image_ds.map(lambda x, y: adjust_shape(x, y, target_size=input_size))

    if augmentation:
        image_ds = image_ds.map(lambda x, y: (tf.image.random_hue(x, 0.1), y))
        image_ds = image_ds.map(lambda x, y: (tf.image.random_brightness(x, 0.1), y))
        image_ds = image_ds.map(lambda x, y: (tf.image.random_contrast(x, 0.8, 1.2), y))
        # Brightness/contrast jitter can push values outside [0, 1].
        image_ds = image_ds.map(lambda x, y: (tf.clip_by_value(x, 0.0, 1.0), y))

    image_ds = image_ds.batch(batch_size).repeat(epochs)
    # Plotting callers that use batch_size=1 rely on the squeezed, 2-D mask.
    image_ds = image_ds.map(lambda x, y: (x, tf.squeeze(y)))
    # Prefetch as the final stage so it overlaps the whole pipeline with the consumer.
    return image_ds.prefetch(tf.data.experimental.AUTOTUNE)

# Locations of the extracted PASCAL VOC 2012 images, masks and split lists.
orig_dir = os.path.join('data', 'VOCtrainval_11-May-2012', 'VOCdevkit', 'VOC2012', 'JPEGImages')
seg_dir = os.path.join('data', 'VOCtrainval_11-May-2012', 'VOCdevkit', 'VOC2012', 'SegmentationClass')
subset_dir = os.path.join('data', 'VOCtrainval_11-May-2012', 'VOCdevkit', 'VOC2012', 'ImageSets', "Segmentation")

# Zero-argument generator factories, as required by tf.data.Dataset.from_generator.
partial_subset_fn = partial(get_subset_filenames, orig_dir=orig_dir, seg_dir=seg_dir, subset_dir=subset_dir)
train_subset_fn = partial(partial_subset_fn, subset='train')
val_subset_fn = partial(partial_subset_fn, subset='val')
test_subset_fn = partial(partial_subset_fn, subset='test')

# Same (unshuffled) training order with and without augmentation, so the first
# batch of each shows the same two images for a side-by-side comparison.
tr_image_ds_no_aug = generate_tf_dataset(train_subset_fn, batch_size=2, epochs=1, augmentation=False)
tr_image_ds_aug = generate_tf_dataset(train_subset_fn, batch_size=2, epochs=1, augmentation=True, resize_to_before_crop=(300, 300))

# Only the first batch of each pipeline is needed.
orig_images, orig_targets = next(iter(tr_image_ds_no_aug.take(1)))
aug_images, aug_targets = next(iter(tr_image_ds_aug.take(1)))

def display_images(original, augmented):
    """Show two original/augmented image pairs in a 2x2 grid, one pair per row."""
    _, axes = plt.subplots(2, 2, figsize=(10, 10))
    columns = (
        (original, 'Оригінальне зображення'),
        (augmented, 'Аугментоване зображення'),
    )
    for row in range(2):
        for col, (images, title) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(images[row])
            ax.set_title(title)
            ax.axis('off')
    plt.show()

# Show the first batch's original/augmented pairs (the batches have 2 samples,
# so [-2:] keeps the whole batch).
display_images(orig_images.numpy()[-2:],  aug_images.numpy()[-2:])




## Перетворення масок у RGB-зображення
## Вивід завантажених даних





In [None]:
def rgb_image_from_pallette(image):
    pallette = annot_image.getpalette()
    pallette = np.array(pallette).reshape(-1,3)
    if isinstance(image, PngImageFile):
        h, w = image.height, image.width
        image = np.array(image).reshape(-1)
    elif isinstance(image, np.ndarray):
        h, w = image.shape[0], image.shape[1]
        image = image.reshape(-1)
    rgb_image = np.zeros(shape=(image.shape[0],3))
    rgb_image[(image != 0),:] = pallette[image[(image != 0)], :]
    rgb_image = rgb_image.reshape(h, w, 3)
    return rgb_image

# One sample image/annotation pair kept at module level, plus an augmented
# batch-size-1 training pipeline for plot_data.
orig_image_path = os.path.join(orig_dir, '2011_002200.jpg')
annot_image_path = os.path.join(seg_dir, '2011_002200.png')
orig_image, annot_image = Image.open(orig_image_path), Image.open(annot_image_path)
tr_image_ds = generate_tf_dataset(
    train_subset_fn, batch_size=1, epochs=1,
    augmentation=True, resize_to_before_crop=(384, 384),
)

n = 10
def plot_data(image_ds, n):
    """Plot ``n`` (image, ground-truth mask) pairs from ``image_ds``, two pairs per row.

    Args:
        image_ds: dataset from ``generate_tf_dataset`` built with ``batch_size=1``.
            Only the first image of each batch is shown, and the mask is expected
            to be 2-D after the pipeline's ``tf.squeeze``.
        n: number of samples to plot.
    """
    n_rows = (n + 1) // 2  # ceil(n / 2): odd n still gets a row for the last pair
    plt.subplots(n_rows, 4, figsize=(12, 12))
    # Iterate the dataset that was passed in, not a module-level global.
    for i, (img, y_true) in enumerate(image_ds.take(n)):
        y_rgb_true = rgb_image_from_pallette(y_true.numpy().astype('int'))

        plt.subplot(n_rows, 4, i * 2 + 1)
        # Clip before the uint8 cast: augmented values can leave [0, 1], and an
        # out-of-range float -> uint8 cast does not saturate.
        plt.imshow(np.clip(img[0, :, :, :].numpy() * 255.0, 0, 255).astype('uint8'))
        plt.axis('off')
        plt.subplot(n_rows, 4, i * 2 + 2)
        plt.imshow(y_rgb_true.astype('uint8'))
        plt.axis('off')

# Plot n samples from the augmented, batch-size-1 training pipeline built above.
plot_data(tr_image_ds, n)

## Завантаження зображень та їхніх сегментаційних міток
## Візуалізація вхідних даних


In [None]:
# Collect one example (image, mask) per object class from the training list and
# plot them, sorted by class id, as image/mask pairs on a 5x8 grid.
train_fn_gen = get_subset_filenames(orig_dir, seg_dir, subset_dir, 'training')
visual_images = {}

object_classes = {
    0: "Background", 1: "Aeroplane", 2 : "Bicycle", 3: "Bird", 4: "Boat", 5: "Bottle",
    6: "Bus", 7: "Car", 8: "Cat", 9: "Chair", 10: "Cow", 11: "Dining table",
    12: "Dog", 13: "Horse", 14: "Motorbike", 15: "Person", 16: "Potted plant", 17: "Sheep",
    18: "Sofa", 19: "Train", 20: "TV/Monitor", 255: "Boundaries / Unknown object"
}

for input_path, target_path in train_fn_gen:
    with Image.open(input_path) as input_file:
        input_image = np.array(input_file)
    with Image.open(target_path) as target_file:
        target_image = np.array(target_file)
    object_pixels = target_image[(target_image != 0) & (target_image != 255)]
    if object_pixels.size == 0:
        # Only background/boundary pixels: no object class to showcase
        # (np.max on the empty selection would raise).
        continue
    # Dominant (most frequent) object class, not merely the largest class id.
    major_class = int(np.argmax(np.bincount(object_pixels)))

    visual_images[major_class] = (input_image, target_image)
    if len(visual_images) >= 20:
        break

plt.subplots(5, 8, figsize=(16, 12))
for i, (class_id, (input_image, target_image)) in enumerate(
        sorted(visual_images.items(), key=lambda item: item[0])):
    plt.subplot(5, 8, (i * 2) + 1)
    plt.imshow(input_image)
    # Title from the actual class key, so labels stay correct if a class is missing.
    plt.title("{}".format(object_classes[class_id]))
    plt.axis('off')
    plt.subplot(5, 8, (i * 2) + 2)
    # rgb_image_from_pallette returns floats in 0..255. imshow clips floats to
    # [0, 1], so cast to uint8 for correct colours.
    plt.imshow(rgb_image_from_pallette(target_image).astype('uint8'))
    plt.axis('off')


## Побудова моделі
*   ResNet50
*   ASPP



In [8]:
# Reset Keras global state before (re)building the model in this cell.
K.clear_session()
# Number of segmentation classes: background (0) + the 20 VOC object classes.
num_classes = 21

def create_deeplabv3(target_size):
    """Build a DeepLabv3 segmentation model on a ResNet50 backbone.

    The ``tf.keras.applications.ResNet50`` graph (created with its default
    ``weights`` argument) is cut just before ``conv5_block1_1_conv``. The conv5
    stage is then rebuilt with dilated convolutions instead of striding. It is
    followed by an ASPP head, a 1x1 classifier and x16 bilinear upsampling.

    Args:
        target_size: (height, width) of the input images. The backbone output is
            expected to be ``target_size / 16`` (e.g. 24x24 for 384x384). Sizes
            divisible by 16 keep the final output the same size as the input.

    Returns:
        ``(model, weights_dict)``:
        - ``model`` maps (H, W, 3) images to per-pixel logits with ``num_classes``
          channels (``num_classes`` is read from the module-level global).
        - ``weights_dict`` maps the names of the rebuilt conv5 conv/BN layers to
          the corresponding ResNet50 weights. Copy them into the model with
          ``model.get_layer(name).set_weights(w)``.
    """
    def build_block_level3(inp, filters, kernel_size, rate, block_id, convlayer_id, activation=True):
        """Conv2D ('same' padding, dilation ``rate``) -> BatchNorm [-> ReLU].

        Layers are named ``conv5_block{block_id}_{convlayer_id}_{conv,bn,relu}``,
        mirroring the Keras ResNet50 naming.
        """
        prefix = 'conv5_block{}_{}'.format(block_id, convlayer_id)
        out = layers.Conv2D(
            filters, kernel_size, dilation_rate=rate, padding='same', name=prefix + '_conv'
        )(inp)
        out = layers.BatchNormalization(name=prefix + '_bn')(out)
        if activation:
            out = layers.Activation('relu', name=prefix + '_relu')(out)
        return out

    def build_block_level2(inp, rate, block_id):
        """Bottleneck body: 1x1 (512) -> 3x3 (512) -> 1x1 (2048, no ReLU)."""
        out = build_block_level3(inp, 512, (1, 1), rate, block_id, 1)
        out = build_block_level3(out, 512, (3, 3), rate, block_id, 2)
        return build_block_level3(out, 2048, (1, 1), rate, block_id, 3, activation=False)

    def build_block_level1(inp, rate):
        """Dilated conv5 stage: three residual bottleneck blocks."""
        # Projection shortcut for block 1 (1024 -> 2048 channels).
        shortcut = build_block_level3(inp, 2048, (1, 1), 1, block_id=1, convlayer_id=0, activation=False)
        x = inp
        for block_id in (1, 2, 3):
            body = build_block_level2(x, rate, block_id=block_id)
            added = layers.Add(name='conv5_block{}_add'.format(block_id))([shortcut, body])
            x = layers.Activation('relu', name='conv5_block{}_relu'.format(block_id))(added)
            # As in ResNet, the identity shortcut of the next block is this
            # block's post-ReLU output, not the pre-activation sum.
            shortcut = x
        return x

    def atrous_spatial_pyramid_pooling(inp):
        """ASPP: 1x1 + three dilated 3x3 branches + an image-pooling branch, concatenated."""
        feat_h, feat_w = inp.shape[1], inp.shape[2]
        branches = [
            build_block_level3(inp, 256, (1, 1), 1, '_aspp_a', 1, activation=True),
            build_block_level3(inp, 256, (3, 3), 6, '_aspp_a', 2, activation=True),
            build_block_level3(inp, 256, (3, 3), 12, '_aspp_a', 3, activation=True),
            build_block_level3(inp, 256, (3, 3), 18, '_aspp_a', 4, activation=True),
        ]

        global_pooling = layers.Lambda(lambda x: K.mean(x, axis=[1, 2], keepdims=True))(inp)
        global_conv = build_block_level3(global_pooling, 256, (1, 1), 1, '_aspp_b', 1, activation=True)
        # Upsample the projected pooled features (global_conv was previously
        # computed but unused), back to the feature-map size taken from the
        # static input shape instead of a hard-coded 24.
        global_up = layers.UpSampling2D((feat_h, feat_w), interpolation='bilinear')(global_conv)
        return layers.Concatenate()(branches + [global_up])

    inp = layers.Input(shape=tuple(target_size) + (3,))
    resnet50_base = tf.keras.applications.ResNet50(
        include_top=False, input_tensor=inp, pooling=None
    )
    # Keep the output of the last layer listed before conv5_block1_1_conv;
    # everything from conv5 on is rebuilt below.
    for layer in resnet50_base.layers:
        if layer.name == "conv5_block1_1_conv":
            break
        out = layer.output

    resnet50_model = models.Model(resnet50_base.input, out)
    resnet_block_output = build_block_level1(resnet50_model.output, 2)

    aspp_output = atrous_spatial_pyramid_pooling(resnet_block_output)
    final_output = layers.Conv2D(num_classes, (1, 1), padding='same')(aspp_output)
    final_output = layers.UpSampling2D((16, 16), interpolation='bilinear')(final_output)

    deeplabv3_model = models.Model(resnet50_model.input, final_output)

    # Every rebuilt conv5 conv/BN layer shares its name and weight shapes with
    # the ResNet50 layer it replaces. Collect all of them (blocks 1-3, not just
    # block 1) so the whole stage can start from the ResNet50 weights.
    pretrained_layer_names = ["conv5_block1_0_conv", "conv5_block1_0_bn"] + [
        "conv5_block{}_{}_{}".format(block_id, conv_id, kind)
        for block_id in (1, 2, 3)
        for conv_id in (1, 2, 3)
        for kind in ("conv", "bn")
    ]
    weights_dict = {
        layer_name: resnet50_base.get_layer(layer_name).get_weights()
        for layer_name in pretrained_layer_names
    }

    return deeplabv3_model, weights_dict


## Визначення функцій втрат

In [9]:
# Number of classes (background + 20 VOC objects); re-declared so this cell can
# run on its own. get_label_weights below reads it as a global.
num_classes = 21
def get_label_weights(y_true, y_pred):
    """Return a flat per-pixel weight vector that down-weights frequent classes.

    For image ``b`` and class ``c`` the weight is ``1 - count(b, c) / total(b)``.
    Each pixel receives the weight of its own label, and the result is
    flattened to shape (batch * H * W,). Label 255 is first mapped to 0. The
    one-hot sum over axes [1, 2] implies ``y_true`` is (batch, H, W) integer
    labels. ``y_pred`` is used only for its static H and W. ``num_classes`` is
    read from the module-level global.
    """
    ignore_label = 255
    y_true = tf.where(tf.not_equal(y_true, ignore_label), y_true, tf.zeros_like(y_true))

    class_counts = tf.reduce_sum(tf.one_hot(y_true, num_classes), axis=[1, 2])  # [b, classes]
    pixels_per_image = tf.reduce_sum(class_counts, axis=-1, keepdims=True)
    class_weights = (pixels_per_image - class_counts) / pixels_per_image  # [b, classes]

    flat_labels = tf.reshape(y_true, [-1, y_pred.shape[1] * y_pred.shape[2]])  # [b, H*W]
    pixel_weights = tf.gather(class_weights, flat_labels, batch_dims=1)
    return tf.reshape(pixel_weights, [-1])
def dice_loss_from_logits(num_classes):
    """Return ``loss_fn(y_true, y_pred)``: a class-frequency-weighted soft Dice loss on logits.

    ``y_true`` holds (batch, H, W) class ids; label 255 is treated as class 0.
    ``y_pred`` holds logits with the class axis last. Pixels are weighted by
    ``get_label_weights``. The loss is ``1 - (2*I + 1) / (U + 1)``, where I and U
    are the weighted intersection and sum of the one-hot labels and softmax
    probabilities.
    """
    smooth = 1.
    ignore_label = 255

    def loss_fn(y_true, y_pred):
        labels = tf.cast(y_true, 'int32')
        labels.set_shape([None, y_pred.shape[1], y_pred.shape[2]])
        labels = tf.where(tf.not_equal(labels, ignore_label), labels, tf.zeros_like(labels))

        pixel_weights = tf.reshape(get_label_weights(labels, y_pred), [-1, 1])
        probs = tf.nn.softmax(y_pred)
        onehot_flat = tf.cast(tf.one_hot(tf.reshape(labels, [-1]), num_classes), 'float32')
        probs_flat = tf.reshape(probs, [-1, num_classes])

        intersection = tf.reduce_sum(onehot_flat * probs_flat * pixel_weights)
        union = tf.reduce_sum((onehot_flat + probs_flat) * pixel_weights)
        return 1 - (2. * intersection + smooth) / (union + smooth)

    return loss_fn

def ce_weighted_from_logits(num_classes):
    """Return a pixel-wise softmax cross-entropy loss on logits that ignores label 255.

    Despite the name, no per-class weights are applied. ``num_classes`` is kept
    for API symmetry with the other loss factories; the class count comes from
    the logits' last axis.

    Returns:
        ``loss_fn(y_true, y_pred)``:
        - ``y_true`` holds integer class ids (255 = ignore).
        - ``y_pred`` holds unnormalised logits with the class axis last.
        - The loss is the mean cross-entropy over the non-ignored pixels only
          (0 if every pixel is ignored). Previously, ignored pixels were fed
          zeroed logits with label 0, so they still added log(num_classes) each
          to the mean.
    """
    ignore_label = 255

    def loss_fn(y_true, y_pred):
        y_true = tf.cast(y_true, tf.int32)
        valid = tf.not_equal(y_true, ignore_label)
        # Give ignored pixels a valid dummy label so the op does not fail; their
        # contribution is zeroed out by the mask below.
        safe_labels = tf.where(valid, y_true, tf.zeros_like(y_true))
        per_pixel = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=safe_labels,
            logits=y_pred
        )
        valid_f = tf.cast(valid, per_pixel.dtype)
        return tf.math.divide_no_nan(tf.reduce_sum(per_pixel * valid_f), tf.reduce_sum(valid_f))

    return loss_fn

def ce_dice_loss_from_logits(num_classes):
    """Return ``loss_fn(y_true, y_pred)`` = cross-entropy + soft-Dice loss, both on logits.

    The two component losses are built once here, rather than on every call.
    """
    cross_entropy = ce_weighted_from_logits(num_classes)
    dice = dice_loss_from_logits(num_classes)

    def loss_fn(y_true, y_pred):
        ce_term = cross_entropy(tf.cast(y_true, 'int32'), y_pred)
        dice_term = dice(y_true, y_pred)
        return ce_term + dice_term

    return loss_fn

## Реалізація метрик оцінювання

In [10]:
class PixelAccuracyMetric(tf.keras.metrics.Accuracy):
  """Fraction of correctly classified pixels, skipping labels >= ``num_classes`` (e.g. 255).

  ``y_true`` holds (batch, H, W) class ids. ``y_pred`` holds per-pixel class
  scores with the class axis last; the argmax is taken over it.
  ``sample_weight`` is accepted for API compatibility but not used.
  """

  def __init__(self, num_classes, name='pixel_accuracy', **kwargs):
    super(PixelAccuracyMetric, self).__init__(name=name, **kwargs)
    # Use the constructor argument rather than a module-level global.
    self.num_classes = num_classes

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true.set_shape([None, y_pred.shape[1], y_pred.shape[2]])
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])

    valid_mask = y_true <= self.num_classes - 1

    y_true = tf.boolean_mask(y_true, valid_mask)
    y_pred = tf.boolean_mask(y_pred, valid_mask)
    super(PixelAccuracyMetric, self).update_state(y_true, y_pred)

class MeanAccuracyMetric(tf.keras.metrics.Mean):
  """Mean per-class pixel accuracy, averaged over the batches seen.

  For each batch:
  - Labels >= ``num_classes`` (e.g. 255) are dropped.
  - Each class's accuracy is TP / (pixels labelled with that class).
  - These are averaged over the classes that actually occur in the batch's
    labels. The previous +1 smoothing scored every absent class as 100%
    accurate, which inflated the metric.
  - A batch with no valid pixels contributes with weight 0.

  ``sample_weight`` is accepted for API compatibility but not used.
  """
  def __init__(self, num_classes, name='mean_accuracy', **kwargs):
    super(MeanAccuracyMetric, self).__init__(name=name, **kwargs)
    # Use the constructor argument rather than a module-level global.
    self.num_classes = num_classes

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true.set_shape([None, y_pred.shape[1], y_pred.shape[2]])
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])

    valid_mask = y_true <= self.num_classes - 1

    y_true = tf.boolean_mask(y_true, valid_mask)
    y_pred = tf.boolean_mask(y_pred, valid_mask)

    conf_matrix = tf.cast(
        tf.math.confusion_matrix(y_true, y_pred, num_classes=self.num_classes), 'float32')
    true_pos = tf.linalg.diag_part(conf_matrix)
    support = tf.reduce_sum(conf_matrix, axis=1)  # pixels per ground-truth class
    present = tf.cast(support > 0, 'float32')
    n_present = tf.reduce_sum(present)

    per_class_acc = tf.math.divide_no_nan(true_pos, support)
    mean_accuracy = tf.math.divide_no_nan(tf.reduce_sum(per_class_acc * present), n_present)
    # Weight 0 when no class is present, so such batches don't drag the mean to 0.
    super(MeanAccuracyMetric, self).update_state(
        mean_accuracy, sample_weight=tf.cast(n_present > 0, 'float32'))

class MeanIoUMetric(tf.keras.metrics.MeanIoU):
  """Mean intersection-over-union computed from logits, ignoring labels >= ``num_classes``.

  ``y_true`` holds (batch, H, W) class ids. ``y_pred`` holds per-pixel scores
  with the class axis last. ``sample_weight`` is accepted for API compatibility
  but not used.
  """
  def __init__(self, num_classes, name='mean_iou', **kwargs):
    super(MeanIoUMetric, self).__init__(num_classes=num_classes, name=name, **kwargs)

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true.set_shape([None, y_pred.shape[1], y_pred.shape[2]])
    y_true = tf.reshape(y_true, [-1])

    # argmax of the logits equals argmax of their softmax; no softmax needed.
    y_pred = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])

    # self.num_classes is set by MeanIoU.__init__; use it instead of a global.
    valid_mask = y_true <= self.num_classes - 1

    y_true = tf.boolean_mask(y_true, valid_mask)
    y_pred = tf.boolean_mask(y_pred, valid_mask)
    super(MeanIoUMetric, self).update_state(y_true, y_pred)

## Підготовка до навчання моделі

In [None]:
# Training hyper-parameters: samples per batch and number of training epochs.
batch_size = 8
epochs = 25
def get_steps_per_epoch(n_data, batch_size):
    """Return ceil(n_data / batch_size), the number of batches that cover ``n_data`` samples."""
    full_batches, leftover = divmod(n_data, batch_size)
    return full_batches + (1 if leftover else 0)

train_filenames = pd.read_csv(os.path.join(subset_dir, "train.txt"), index_col=None, header=None).squeeze().tolist()
val_filenames = pd.read_csv(os.path.join(subset_dir, "val.txt"), index_col=None, header=None).squeeze().tolist()

# Step counts (batches per epoch), not sample counts. The validation subset is
# one half of the shuffled val.txt, hence the `// 2`.
n_train = get_steps_per_epoch(len(train_filenames), batch_size)
n_valid = get_steps_per_epoch(len(val_filenames) // 2, batch_size)

input_size = (384, 384)
tr_image_ds = generate_tf_dataset(
    train_subset_fn, batch_size, epochs,
    input_size=input_size, resize_to_before_crop=(444, 444),
    augmentation=True
)
val_image_ds = generate_tf_dataset(
    val_subset_fn, batch_size, epochs,
    input_size=input_size,
)
test_image_ds = generate_tf_dataset(
    test_subset_fn, batch_size, 1,
    input_size=input_size,
)
deeplabv3, w_dict = create_deeplabv3(input_size)

# `lr` is a deprecated alias that newer Keras versions reject; use `learning_rate`.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
deeplabv3.compile(
    loss=ce_dice_loss_from_logits(num_classes),
    optimizer=optimizer,
    metrics=[
        PixelAccuracyMetric(num_classes),
        MeanIoUMetric(num_classes),
        MeanAccuracyMetric(num_classes)
    ])

# Initialise the rebuilt conv5 layers from the matching pretrained ResNet50 layers.
for k, w in w_dict.items():
    deeplabv3.get_layer(k).set_weights(w)
deeplabv3.summary()



## Навчання моделі

In [None]:
# Output directories for the training log and the saved model.
os.makedirs('eval', exist_ok=True)
os.makedirs('models', exist_ok=True)

csv_logger = tf.keras.callbacks.CSVLogger(os.path.join('eval','2_deeplab_v3.log'))
monitor_metric = 'val_loss'
# Loss-like metrics improve downwards; anything else is assumed to improve upwards.
mode = 'min' if 'loss' in monitor_metric else 'max'
# The monitored metric drives ReduceLROnPlateau; no EarlyStopping callback is used.
print("Using metric={} and mode={} for ReduceLROnPlateau".format(monitor_metric, mode))

lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor=monitor_metric, factor=0.1, patience=3, mode=mode, min_lr=1e-8
)
t1 = time.time()

# Both datasets are repeated `epochs` times, so explicit step counts delimit each epoch.
deeplabv3.fit(
    tr_image_ds, steps_per_epoch=n_train,
    validation_data=val_image_ds, validation_steps=n_valid,
    epochs=epochs, callbacks=[lr_callback, csv_logger]
)
t2 = time.time()
print("It took {} seconds to complete the training".format(t2-t1))

## Збереження моделі DeepLabv3

In [None]:

tf.keras.models.save_model(deeplabv3, os.path.join('models', 'deeplabv3.h5'))
# compile=False skips restoring the training configuration; the model is
# recompiled below with the custom loss and metrics.
deeplabv3 = tf.keras.models.load_model(os.path.join('models', 'deeplabv3.h5'), compile=False)
# `lr` is a deprecated alias that newer Keras versions reject; use `learning_rate`.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
deeplabv3.compile(
    loss=ce_dice_loss_from_logits(num_classes),
    optimizer=optimizer,
    metrics=[
        MeanIoUMetric(num_classes),
        MeanAccuracyMetric(num_classes),
        PixelAccuracyMetric(num_classes)
    ])
# test_image_ds was built with epochs=1 and is therefore finite. Evaluate it
# to exhaustion rather than for a step count derived from the validation
# split, whose size can differ by one sample.
deeplabv3.evaluate(test_image_ds)

## Вивід результуючих зображень

In [None]:
def generate_results_plot(image_ds, n, model, skip=220):
    """Plot ``n`` input images next to the model's predicted segmentation, two pairs per row.

    Args:
        image_ds: dataset of ``(image_batch, mask)`` from ``generate_tf_dataset``.
            Only the first image of each batch is predicted and shown, so
            ``batch_size=1`` is the intended use.
        n: number of samples to plot.
        model: Keras model returning per-pixel class logits (class axis last).
        skip: number of leading dataset elements to skip (default 220, the
            previously hard-coded value).
    """
    n_rows = (n + 1) // 2  # ceil(n / 2): odd n still gets a row for the last pair
    plt.subplots(n_rows, 4, figsize=(32, 32))

    for i, (img, _) in enumerate(image_ds.skip(skip).take(n)):
        img_pred = model.predict(img, verbose=0)
        y_pred = np.argmax(img_pred[0, :, :, :], axis=-1)
        y_rgb_pred = rgb_image_from_pallette(y_pred)

        row, pair_col = divmod(i, 2)
        first_cell = row * 4 + pair_col * 2 + 1

        plt.subplot(n_rows, 4, first_cell)
        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        plt.imshow(np.clip(img[0, :, :, :].numpy() * 255.0, 0, 255).astype('uint8'))
        plt.axis('off')
        if i < 2:  # titles on the first row only
            plt.title('Оригінальне зображення', fontsize=18)

        plt.subplot(n_rows, 4, first_cell + 1)
        plt.imshow(y_rgb_pred.astype('uint8'))
        plt.axis('off')
        if i < 2:
            plt.title('Результатуюче зображення', fontsize=18)
    plt.show()

# Rebuild the test pipeline with batch size 1 (as generate_results_plot expects)
# and plot n predictions from it.
test_image_ds = generate_tf_dataset(
    test_subset_fn, batch_size=1, epochs=1, input_size=(384, 384),
)
n = 8
generate_results_plot(test_image_ds, n, deeplabv3)
