In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing import image_dataset_from_directory

2023-05-30 09:24:43.736072: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
saved_model = '/Users/balazsmorvay/Downloads/ESPCN/ESPCN_model'

In [3]:
def get_model(upscale_factor=3, channels=1):
    conv_args = {
        "activation": "relu",
        "kernel_initializer": "Orthogonal",
        "padding": "same",
    }
    # run["conv_args"] = conv_args
    inputs = keras.Input(shape=(None, None, channels))
    x = layers.Conv2D(64, 5, **conv_args)(inputs)
    x = layers.Conv2D(64, 3, **conv_args)(x)
    x = layers.Conv2D(32, 3, **conv_args)(x)
    x = layers.Conv2D(channels * (upscale_factor ** 2), 3, **conv_args)(x)
    outputs = tf.nn.depth_to_space(x, upscale_factor)

    return keras.Model(inputs, outputs)

In [4]:
model = get_model(upscale_factor=4, channels=1)
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, None, None, 1)]   0         
                                                                 
 conv2d (Conv2D)             (None, None, None, 64)    1664      
                                                                 
 conv2d_1 (Conv2D)           (None, None, None, 64)    36928     
                                                                 
 conv2d_2 (Conv2D)           (None, None, None, 32)    18464     
                                                                 
 conv2d_3 (Conv2D)           (None, None, None, 16)    4624      
                                                                 
 tf.nn.depth_to_space (TFOpL  (None, None, None, 1)    0         
 ambda)                                                          
                                                             

In [5]:
model.load_weights(saved_model)

2023-05-30 09:25:13.127560: W tensorflow/core/util/tensor_slice_reader.cc:97] Could not open /Users/balazsmorvay/Downloads/ESPCN/ESPCN_model: FAILED_PRECONDITION: /Users/balazsmorvay/Downloads/ESPCN/ESPCN_model; Is a directory: perhaps your file is in a different file format and you need to use a different restore operator?


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7fc538750d60>

In [6]:
import coremltools as ct

mlmodel = ct.convert(
                    model,
                    source='tensorflow',
                    convert_to='mlprogram',
                    inputs=[ct.TensorType(shape=(1, 64, 64, 1))])


2023-05-30 09:25:30.509140: I tensorflow/core/grappler/devices.cc:75] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA or ROCm support)
2023-05-30 09:25:30.583021: I tensorflow/core/grappler/devices.cc:75] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA or ROCm support)
Running TensorFlow Graph Passes: 100%|██████████| 6/6 [00:00<00:00, 33.23 passes/s]
Converting TF Frontend ==> MIL Ops: 100%|██████████| 23/23 [00:00<00:00, 1248.50 ops/s]
Running MIL frontend_tensorflow2 pipeline: 100%|██████████| 7/7 [00:00<00:00, 4356.10 passes/s]
Running MIL default pipeline: 100%|██████████| 57/57 [00:00<00:00, 625.04 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 10/10 [00:00<00:00, 4545.19 passes/s]


In [None]:
mlmodel.save('ESPCN.mlpackage')

In [10]:
import numpy as np
import PIL
from PIL import Image
import sys
from fdf256dataset import FDF256Dataset

Could not load pyspng. Defaulting to pillow image backend.


In [12]:
def get_lowres_image(img, upscale_factor):
    """Return low-resolution image to use as model input."""
    return img.resize(
        (img.size[0] // upscale_factor, img.size[1] // upscale_factor),
        PIL.Image.BICUBIC,
    )

In [26]:
def upscale_image(model, img):
    """Predict the result based on input image and restore the image as RGB."""
    ycbcr = img.convert("YCbCr")
    y, cb, cr = ycbcr.split()
    y = img_to_array(y)
    y = y.astype("float32") / 255.0

    input = np.expand_dims(y, axis=0)
    out = model.predict({'input_1': input})

    print(out['Identity'].shape)
    out_img_y = out['Identity'][0]
    out_img_y *= 255.0

    # Restore the image in RGB color space.
    out_img_y = out_img_y.clip(0, 255)
    out_img_y = out_img_y.reshape((np.shape(out_img_y)[0], np.shape(out_img_y)[1]))
    out_img_y = PIL.Image.fromarray(np.uint8(out_img_y), mode="L")
    out_img_cb = cb.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, PIL.Image.BICUBIC)
    out_img = PIL.Image.merge("YCbCr", (out_img_y, out_img_cb, out_img_cr)).convert(
        "RGB"
    )
    return out_img

In [28]:
# Run the entire dataset through the network
from tqdm import tqdm
import torchvision.transforms as T
import torch
from einops import rearrange
from matplotlib import pyplot as plt
import os

crop_size = 256
upscale_factor = 4
input_size = crop_size // upscale_factor

images = []

impaths = '/Users/balazsmorvay/Downloads/val_images'

for f in os.listdir(impaths):
    if f == '.DS_Store':
        continue
    p = os.path.join(impaths, f)
    image = Image.open(p)
    images.append(image)

for index, img in tqdm(enumerate(images)):
    lowres_input = get_lowres_image(img, upscale_factor)
    prediction = upscale_image(mlmodel, lowres_input)
    img_array = img_to_array(prediction)
    img_array = img_array.astype("float32") / 255.0
    plt.imsave(f'{index}.png', img_array)
    

3it [00:01,  2.69it/s]

(1, 256, 256, 1)
(1, 256, 256, 1)
(1, 256, 256, 1)



