# Upscaler

In [None]:
import glob
import math
import os
import random
import shutil

import numpy as np
import PIL
import PIL.Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img


The source pages (PNG files) are stored in `/content/pages`.

In [4]:
def get_id(n):
    """Return a random identifier made of `n` lowercase hexadecimal characters.

    The ids are used as file names for extracted tiles. A collision would
    overwrite an existing tile, so the full 16-symbol alphabet is used.
    Returns "" when n <= 0.
    """
    charset = "0123456789abcdef"  # full hex alphabet (the old charset lacked "0")
    return "".join(random.choices(charset, k=n))


We cut each of the first 15 pages into non-overlapping 512×512 px tiles and save them as high-quality training images in `/content/training`.

In [5]:
output_dir = "/content/training"

# Tiles must match `crop_size` (512 px) used in the training section:
# image_dataset_from_directory resizes every image to crop_size × crop_size,
# so 256 px tiles would be bilinearly upsampled and the "high-resolution"
# targets would contain nothing but interpolation.
tile_size = 512
stride = 512  # stride == tile_size -> non-overlapping tiles
max_pages = 15  # only the first pages (in sorted order) are tiled

os.makedirs(output_dir, exist_ok=True)

# Sort so the same pages are selected on every run (glob order is filesystem-dependent).
files = sorted(glob.glob("/content/pages/*.png"))[:max_pages]


def _tiles_along(length, tile_size, stride):
    """Return how many complete tiles fit along an axis of `length` pixels."""
    if length < tile_size:
        # int((length - tile_size) / stride) + 1 would give 1 here, and crop()
        # would silently pad the tile beyond the page edge.
        return 0
    return (length - tile_size) // stride + 1


for filename in files:
    with PIL.Image.open(filename) as input_image:
        width, height = input_image.size

        # Extract each full tile and save it under a random file name.
        for i in range(_tiles_along(width, tile_size, stride)):
            for j in range(_tiles_along(height, tile_size, stride)):
                x, y = i * stride, j * stride
                tile = input_image.crop((x, y, x + tile_size, y + tile_size))
                tile.save(os.path.join(output_dir, f"{get_id(16)}.png"))


We split the tiles into a training set (75%) and a test set (25%).

In [6]:
# Destination folders for the train/test split. exist_ok=True makes the cell
# safe to re-run (a plain `mkdir images/test` fails once the folder exists).
os.makedirs("images/train", exist_ok=True)
os.makedirs("images/test", exist_ok=True)

In [7]:
# Relative to the working directory: this is the tile folder `/content/training`
# only when the cwd is /content (the Colab default) — adjust otherwise.
img_folder = "training"

train_folder = "images/train"
test_folder = "images/test"

train_ratio = 0.75
test_ratio = 1 - train_ratio  # every file not used for training goes to test

# Create the destinations here as well so this cell does not depend on the mkdir cell.
os.makedirs(train_folder, exist_ok=True)
os.makedirs(test_folder, exist_ok=True)

# Sort before shuffling: glob order is filesystem-dependent, so a seeded shuffle of
# an unsorted list is not reproducible. A private RNG avoids reseeding the global
# `random` module, which get_id() also draws from.
files = sorted(glob.glob(os.path.join(img_folder, "*.png")))
random.Random(42).shuffle(files)

num_train = int(len(files) * train_ratio)
num_test = len(files) - num_train


def _move_files(paths, dest_dir):
    """Move every file in `paths` into `dest_dir`, keeping its base name."""
    for path in paths:
        # basename works for any folder depth; split("/")[1] only worked for one level.
        shutil.move(path, os.path.join(dest_dir, os.path.basename(path)))


_move_files(files[:num_train], train_folder)
_move_files(files[num_train:], test_folder)


## Training the upscaler

In [None]:
# crop_size: side (px) of the high-resolution targets; upscale_factor: enlargement
# per side performed by the model; batch_size: images per training batch.
crop_size, upscale_factor, batch_size = 512, 2, 16

# Side (px) of the low-resolution images fed to the model.
input_size = crop_size // upscale_factor


def process_input(input, input_size):
    """Return the luma channel of an RGB image batch, downsampled to input_size × input_size.

    The image is converted with tf.image.rgb_to_yuv (YUV, not YCbCr), only the
    first (Y) channel is kept, and it is resized with area interpolation.
    The result keeps a trailing channel axis of size 1.
    """
    luma = tf.split(tf.image.rgb_to_yuv(input), num_or_size_splits=3, axis=-1)[0]
    return tf.image.resize(luma, [input_size, input_size], method="area")


def process_target(input):
    """Return the full-resolution luma (Y) channel of an RGB image batch.

    The conversion is tf.image.rgb_to_yuv (YUV, not YCbCr). The result keeps a
    trailing channel axis of size 1, matching process_input's output layout.
    """
    yuv = tf.image.rgb_to_yuv(input)
    return yuv[..., :1]


def create_dataset(root_dir, batch_size, crop_size, validation_split=0.3, seed=42, input_size=None):
    """Build (low-res Y input, high-res Y target) datasets from a folder of images.

    Every image under `root_dir` is loaded, resized to crop_size × crop_size,
    scaled to [0, 1] and assigned to a training or validation subset.

    Args:
        root_dir: directory containing the images (no labels are used).
        batch_size: number of images per batch.
        crop_size: side length (px) of the high-resolution targets.
        validation_split: fraction of the images reserved for validation.
        seed: passed to both subset calls so the two subsets are complementary.
        input_size: side length of the low-resolution inputs. Defaults to
            `crop_size // upscale_factor` (module-level `upscale_factor`).

    Returns:
        (train_ds, valid_ds): batched, prefetched tf.data datasets yielding
        (input_y, target_y) pairs.
    """
    if input_size is None:
        # Derive from the requested crop size instead of the global `input_size`,
        # which is only correct for the global `crop_size`.
        input_size = crop_size // upscale_factor

    dataset_options = dict(
        batch_size=batch_size,
        image_size=(crop_size, crop_size),
        validation_split=validation_split,
        seed=seed,
        label_mode=None,
    )

    train_ds = keras.preprocessing.image_dataset_from_directory(
        root_dir, subset="training", **dataset_options
    )
    valid_ds = keras.preprocessing.image_dataset_from_directory(
        root_dir, subset="validation", **dataset_options
    )

    def to_training_pair(images):
        # Scale to [0, 1] (the range tf.image.rgb_to_yuv expects), then derive the
        # downsampled input and the full-resolution target from the same batch.
        images = images / 255.0
        return process_input(images, input_size), process_target(images)

    train_ds = train_ds.map(to_training_pair, num_parallel_calls=tf.data.AUTOTUNE)
    valid_ds = valid_ds.map(to_training_pair, num_parallel_calls=tf.data.AUTOTUNE)

    return train_ds.prefetch(tf.data.AUTOTUNE), valid_ds.prefetch(tf.data.AUTOTUNE)


train_ds, valid_ds = create_dataset(train_folder, batch_size, crop_size)

# Sorted so that test_img_paths[0] (used by ESPCNCallback) is the same image on every run.
test_img_paths = sorted(glob.glob(os.path.join(test_folder, "*.png")))


In [10]:
def get_lr(img, upscale_factor):
    """Return a low-resolution copy of the PIL image `img`.

    Each dimension is floor-divided by `upscale_factor`, and the image is
    resampled with bicubic interpolation.
    """
    width, height = img.size
    low_res_size = (width // upscale_factor, height // upscale_factor)
    return img.resize(low_res_size, PIL.Image.BICUBIC)


def upscale_image(model, img):
    """Super-resolve a PIL image with a model that predicts only the luma (Y) channel.

    The Y channel of the YCbCr image is scaled to [0, 1] and passed through
    `model`. The Cb and Cr channels are bicubically resized to the predicted Y
    size, then the three channels are merged and converted back to RGB.

    Args:
        model: Keras model fed a (1, H, W, 1) float32 batch. The first element
            of its output must reshape to (H', W'), presumably (H', W', 1) — confirm.
        img: PIL image to upscale.

    Returns:
        RGB PIL image whose size matches the model's predicted Y channel.
    """
    y, cb, cr = img.convert("YCbCr").split()

    # Model input: a batch of one single-channel image with values in [0, 1].
    y_scaled = img_to_array(y).astype("float32") / 255.0
    prediction = model.predict(np.expand_dims(y_scaled, axis=0), verbose=0)

    # Back to 8-bit: clip, then round to nearest. A plain uint8 cast truncates,
    # which darkens pixels by up to one level.
    y_pixels = np.rint(np.clip(prediction[0] * 255.0, 0, 255)).astype(np.uint8)
    y_pixels = y_pixels.reshape(y_pixels.shape[0], y_pixels.shape[1])
    out_img_y = PIL.Image.fromarray(y_pixels)  # 2-D uint8 array -> mode "L"

    # Chroma is not predicted; upscale it bicubically to the new luma size.
    out_img_cb = cb.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, PIL.Image.BICUBIC)

    return PIL.Image.merge("YCbCr", (out_img_y, out_img_cb, out_img_cr)).convert("RGB")


In [11]:
class ESPCNCallback(keras.callbacks.Callback):
    """Print a PSNR estimate derived from the validation loss at the end of each epoch.

    PSNR is computed as 10 * log10(1 / MSE), i.e. with a peak value of 1.0,
    which matches the [0, 1] scaling of the Y targets used in this notebook.
    """

    def __init__(self):
        super().__init__()
        # Low-resolution copy of the first test image. It is not used by the
        # methods below; it is skipped (None) when there are no test images
        # instead of raising IndexError.
        self.test_img = (
            get_lr(load_img(test_img_paths[0]), upscale_factor)
            if test_img_paths
            else None
        )
        self.psnr = []

    def on_epoch_begin(self, epoch, logs=None):
        # Start a fresh list of PSNR values for this epoch.
        self.psnr = []

    def on_epoch_end(self, epoch, logs=None):
        # Nothing to report if no validation batch produced a usable loss
        # (np.mean([]) would print nan with a RuntimeWarning).
        if self.psnr:
            print("Mean PSNR for epoch: %.2f" % (np.mean(self.psnr)))

    def on_test_batch_end(self, batch, logs=None):
        # NOTE: in tf.keras, logs["loss"] at batch end is the running mean over
        # the validation batches seen so far, not this batch's loss alone.
        loss = (logs or {}).get("loss")
        # Skip missing or non-positive losses (1 / 0 would raise).
        if loss is not None and loss > 0:
            self.psnr.append(10 * math.log10(1 / loss))


In [20]:
# Monitor the validation loss: the training loss keeps falling while the model
# overfits, so it cannot drive early stopping or best-model selection.
early_stopping_callback = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)

# Keep checkpoints in the working directory rather than the filesystem root
# ("/checkpoint" needs root privileges outside Colab). The ".weights.h5" suffix
# is accepted for weights-only checkpoints by both tf.keras and Keras 3.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_filepath = os.path.join(checkpoint_dir, "espcn.weights.h5")

model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)

# Settings shared by every convolution in the network.
conv_args = dict(activation="relu", kernel_initializer="Orthogonal", padding="same")

# Fully convolutional, so any input height/width works; one (luma) channel in.
inputs = keras.Input(shape=(None, None, 1))

# Feature extraction, ending with upscale_factor**2 channels that depth_to_space
# rearranges into an image enlarged by upscale_factor per side.
x = inputs
for filters, kernel_size in ((64, 5), (64, 3), (32, 3), (upscale_factor**2, 3)):
    x = layers.Conv2D(filters, kernel_size, **conv_args)(x)
outputs = tf.nn.depth_to_space(x, upscale_factor)

model = keras.Model(inputs=inputs, outputs=outputs)

model.summary()


Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, None, None, 1)]   0         
                                                                 
 conv2d_4 (Conv2D)           (None, None, None, 64)    1664      
                                                                 
 conv2d_5 (Conv2D)           (None, None, None, 64)    36928     
                                                                 
 conv2d_6 (Conv2D)           (None, None, None, 32)    18464     
                                                                 
 conv2d_7 (Conv2D)           (None, None, None, 4)     1156      
                                                                 
 tf.nn.depth_to_space_1 (TFO  (None, None, None, 1)    0         
 pLambda)                                                        
                                                           

In [None]:
epochs = 100

# Pixel-wise MSE on the Y channel, optimized with Adam.
model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
)

training_callbacks = [ESPCNCallback(), early_stopping_callback, model_checkpoint_callback]

history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=valid_ds,
    callbacks=training_callbacks,
    verbose=2,
)

# Restore the best weights written by model_checkpoint_callback (save_best_only).
model.load_weights(checkpoint_filepath)
