# Extract patches

In [None]:
from time import time
import tensorflow as tf

class RC:
    """Run-level settings for this notebook."""

    # Seed value stored alongside the other run settings.
    seed = 21392
    # When False, the notebook refuses to replace an existing patch directory.
    overwrite = True
    # tf.data constant passed as `num_parallel_calls` when mapping the dataset.
    AUTOTUNE = tf.data.experimental.AUTOTUNE

class DC:
    """Dataset locations and patch-extraction parameters."""

    # Root folder of the dataset; every other location is derived from it.
    path = '/tf/datasets/best-artworks-of-all-time'
    images = f'{path}/images/images'
    patches = f'{path}/images/patches'
    info = f'{path}/artists.csv'

    # Size of each extracted patch, the step between patches, and the number
    # of patches sampled per image.
    image_size = (299, 299)
    patch_strides = (50, 50)
    patches_count = 20

class Config:
    """Single entry point grouping the data and run settings."""

    data = DC
    run = RC

## Setup

In [None]:
from math import ceil
import os, shutil, pathlib, numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
from tensorflow.keras import Model, Sequential, Input
from tensorflow.keras.layers import (Conv2D, Dense, Dropout, BatchNormalization,
                                     Activation, Lambda)

In [None]:
def plot(y, titles=None, rows=1, i0=0):
    """Draw the images in `y` on a grid of matplotlib subplots.

    Args:
        y: sized sequence of images accepted by `plt.imshow`. A `None` entry
            leaves its grid cell empty.
        titles: optional sequence of per-image titles. It may be a list or a
            NumPy array and may be shorter than `y`; images without a
            matching title are drawn untitled.
        rows: number of grid rows. The column count is derived from
            `len(y)`.
        i0: offset added to every subplot index.
    """
    # The grid shape does not change while iterating, so compute it once.
    cols = ceil(len(y) / rows)

    for i, image in enumerate(y):
        index = i0 + i + 1

        if image is None:
            plt.subplot(rows, cols, index)
            plt.axis('off')
            continue

        # `titles is not None` instead of a truth test: the truth value of a
        # multi-element NumPy array is ambiguous and raises ValueError.
        has_title = titles is not None and i < len(titles)
        t = titles[i] if has_title else None
        plt.subplot(rows, cols, index, title=t)
        plt.imshow(image)
        plt.axis('off')

In [None]:
# Switch matplotlib to seaborn's default theme for the plots drawn below.
sns.set()

## Dataset

In [None]:
def get_label(file_path):
    """Return the index in `class_names` of the folder containing `file_path`.

    The class is taken from the second-to-last path component, i.e. the
    directory that directly holds the file.
    """
    segments = tf.strings.split(file_path, os.path.sep)
    class_dir = segments[-2]
    # Element-wise comparison against every known class name; argmax then
    # picks the position of the match.
    matches = class_dir == class_names
    return tf.argmax(matches)

def decode_img(img):
    """Decode a JPEG and cut it into patches.

    Args:
        img: the encoded JPEG contents, as returned by `tf.io.read_file`.

    Returns:
        A tensor of shape `(num_patches, height, width, 3)`, where
        `(height, width)` is `Config.data.image_size`. Patches are taken every
        `Config.data.patch_strides` pixels with 'VALID' padding, so only
        windows that fit entirely inside the image are produced.
    """
    height, width = Config.data.image_size

    # convert the compressed string to a 3D uint8 tensor
    pixels = tf.image.decode_jpeg(img, channels=3)
    # extract_patches expects a leading batch dimension
    batch = tf.expand_dims(pixels, 0)
    patches = tf.image.extract_patches(batch,
                                       sizes=(1, height, width, 1),
                                       strides=(1, *Config.data.patch_strides, 1),
                                       rates=(1, 1, 1, 1),
                                       padding='VALID')
    # Each patch comes back flattened along the last axis. Restore it using
    # the configured size rather than a hard-coded 299x299, so changing
    # `Config.data.image_size` keeps extraction and reshape in agreement.
    return tf.reshape(patches, (-1, height, width, 3))

def process_path(file_path):
    """Map an image path to `(file_path, patches, label)`.

    The patches come from `decode_img` applied to the file's contents, and
    the label from `get_label` applied to the path itself.
    """
    # read the file contents as a string tensor
    raw_bytes = tf.io.read_file(file_path)
    return file_path, decode_img(raw_bytes), get_label(file_path)

In [None]:
# Stop before touching an existing patch directory unless the run
# configuration explicitly allows overwriting it.
if not Config.run.overwrite and os.path.exists(Config.data.patches):
    raise ValueError(
        'Patches were already extracted, and overwrite is set to False. '
        'This procedure will not continue, as it might damage the '
        'reproducibility of the current experiments.')

In [None]:
# Start from a clean output directory. A missing directory is expected on a
# first run. Any other failure is reported rather than silently swallowed,
# because leftover files would end up mixed with the new patches.
try:
    shutil.rmtree(Config.data.patches)
except FileNotFoundError:
    pass
except OSError as error:
    print(f'Could not remove {Config.data.patches}: {error}')

data_dir = pathlib.Path(Config.data.images)
# NOTE(review): this counts only '*.jpg' files, while list_files below
# matches every file under the class folders; confirm the two agree.
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
print(image_count)

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
# Shuffle once with the configured seed so the file order is the same on
# every run.
list_ds = list_ds.shuffle(image_count,
                          seed=Config.run.seed,
                          reshuffle_each_iteration=False)

# Show a few paths as a sanity check.
for f in list_ds.take(5):
    print(f.numpy())

# Every entry directly under the images folder except the license file is
# treated as a class; sorting fixes the label indices.
class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"]))
images_ds = list_ds.map(process_path, num_parallel_calls=Config.run.AUTOTUNE)

In [None]:
# Create one output folder per class under the patches directory; folders
# that already exist are left alone.
for class_name in class_names.tolist():
    class_dir = os.path.join(Config.data.patches, class_name)
    os.makedirs(class_dir, exist_ok=True)

In [None]:
# Dedicated generator seeded from the run configuration, so the sequence of
# sampled patch indices does not depend on NumPy's unseeded global state.
rng = np.random.RandomState(Config.run.seed)

for filename, image, label in images_ds:
    # The path arrives as bytes. os.fsdecode recovers the real file name,
    # whereas slicing str(bytes) leaves escape sequences (e.g. '\\xc3\\xbc')
    # in the name whenever the path contains a non-ASCII character.
    filename = os.fsdecode(filename.numpy())
    image, label = image.numpy(), label.numpy()

    # No patch could be extracted from this image; report it and move on.
    if not len(image):
        print(f'Cannot extract from {filename}. Tensor: {image}')
        continue

    # NOTE: indices are drawn with replacement (the `choice` default), so
    # the same patch can be written more than once for a given image.
    choices = rng.choice(len(image), size=Config.data.patches_count)
    label_name = str(class_names[label])

    name, ext = os.path.splitext(os.path.basename(filename))

    # Write each sampled patch as <patches>/<class>/<image name>_<k><ext>.
    for ix, patch in enumerate(image[choices]):
        n = os.path.join(Config.data.patches, label_name, f'{name}_{ix}{ext}')
        tf.keras.preprocessing.image.save_img(
            n,
            patch,
            scale=False)
        