# Dependencies

In [None]:
import os
from PIL import ImageOps, Image
import random
import tensorflow as tf

In [None]:
# Root folder of the dataset. Later cells read its 'original' sub-folder (file
# listing) and build paths into its 'halfsize' and 'real' sub-folders.
base_dir = 'PATH_TO_YOUR_DATASET' #@param
# Fixed seed for Python's `random` module, which drives the fold shuffle and
# the random training crops below, so runs are reproducible.
random.seed(408)

In [None]:
import tensorflow as tf
# Report the TensorFlow build and every device it can see.
print(tf.__version__)
print(tf.config.list_physical_devices())

# Cap the first GPU at 10 * 1024 MB by exposing it as one virtual device with
# an explicit `memory_limit`. Per the TensorFlow GPU guide this must happen
# before the GPU is initialised; otherwise TensorFlow raises RuntimeError,
# which is reported and ignored here.
# Ref: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
  try:
    first_gpu = gpus[0]
    capped = [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=10*1024)]
    tf.config.experimental.set_virtual_device_configuration(first_gpu, capped)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as err:
    print('>', err)

# Image selection

In [None]:
# Report how many entries the dataset's 'original' folder contains.
original_dir = base_dir + '/original/'
print(f">images: {len(os.listdir(original_dir))}")

In [None]:
# Training batch size (validation/test generators always use 1).
batch_size = 2#@param
epochs = 600#@param
# Must match a builder name; it is evaluated together with model_by_name below.
model_name = "halftone_edsr" #@param ["halftone_edsr", "halftone_net"]
# Spatial size (rows, cols) of the random training crops.
image_shape = (256,256)
# The ratio output_zoom / input_zoom sets the number of x2 upsampling stages.
input_zoom = 0.5
output_zoom = 1
load_pre_trained_model = "True" #@param ['True', 'False']
# The form yields the *string* 'True'/'False', and bool('False') is True.
# Compare the text instead; str() also accepts a real bool.
load_pre_trained_model = str(load_pre_trained_model).strip().lower() == 'true'
mode = 'CMYK' #BW

In [None]:
from tensorflow.python.tools import module_util
#@title Dataset training parameters
train_size = 0.7 #@param
validation_split = 2/3 #@param
k_fold = 5 #@param
example_folds = [{'train':[],'test':[]} for _ in range(k_fold)]
output_path = "OUTPUT_PATH"

# File names from the 'original' folder, shuffled with the seeded RNG, capped at 900.
example_images = os.listdir(base_dir+'/original/')

random.shuffle(example_images)
example_images = example_images[:900]

# Number of images in each fold's training split.
size_fold = int(len(example_images)*train_size)


print(size_fold)
for fold in range(k_fold):
  # Paths keep a '/%s/' placeholder. DataGenerator later replaces it with the
  # sub-folder to read ('halfsize' for inputs, 'real' for targets).
  example_folds[fold]['train'] += [base_dir+'/%s/'+v+'.npy' for v in example_images[:size_fold]]
  example_folds[fold]['test']  += [base_dir+'/%s/'+v+'.npy' for v in example_images[size_fold:]]
  # NOTE(review): the list is rotated by size_fold (the training size), not
  # by len(example_images)/k_fold, so different folds' held-out sets can
  # overlap. Confirm this is the intended resampling scheme.
  example_images = example_images[size_fold:] + example_images[:size_fold]

  if(validation_split):
    # The first `validation_split` fraction of the held-out images becomes
    # the validation set; the remainder stays as the test set.
    size_val = int(validation_split*len(example_folds[fold]['test']))
    example_folds[fold]['validation'] = example_folds[fold]['test'][:size_val]
    example_folds[fold]['test'] = example_folds[fold]['test'][size_val:]


for i in range(k_fold):
  print("#%d fold"%i)
  for t in ['train', 'validation', 'test']:
    # 'validation' only exists when validation_split is non-zero. Report it
    # as empty instead of raising KeyError.
    print('\t>',"%s:"%t,len(example_folds[i].get(t, [])))

In [None]:
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input

In [None]:
import numpy as np

class HalftoneModel(tf.keras.Model):
    """Keras model with a custom training step and an image-producing predict step.

    `train_step` performs compiled loss -> gradients -> optimizer update ->
    compiled metrics. `predict_step` runs the network one sample at a time and
    converts each output to uint8 pixel values. It reads the notebook-global
    `mode` to decide how channels are handled.
    """

    def train_step(self, data):
        """Run one optimisation step on an `(x, y)` batch and return metric results."""
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, xs):
        """Predict every sample of `xs` separately and return uint8 outputs.

        Args:
          xs: batch of inputs. In 'BW' mode a rank-3 batch (N, H, W) gets a
            trailing channel axis; a batch that already has it (N, H, W, 1)
            is used unchanged.

        Returns:
          np.ndarray stacking one output per sample. Each output is the
          prediction scaled by 255, clipped to [0, 255], rounded and cast to
          uint8. In 'BW' mode the per-sample batch axis of 1 is squeezed away;
          otherwise it is kept.
        """
        xs = tf.cast(xs, tf.float32)
        # The data generator already yields (N, H, W, 1) batches in 'BW' mode.
        # Expanding unconditionally produced rank-5 input, so only add the
        # channel axis when it is missing.
        if mode == 'BW' and len(xs.shape) == 3:
            xs = tf.expand_dims(xs, axis=3)

        outputs = []
        for x in xs:
            # Re-add a batch axis of 1 so the model sees a one-sample batch.
            x = tf.expand_dims(x, axis=0)
            output = self(x, training=False)
            output = tf.round(tf.clip_by_value(255 * output, 0, 255))

            if mode == 'BW':
                output = tf.squeeze(tf.cast(output, tf.uint8), axis=0)
            else:
                output = tf.cast(output, tf.uint8)

            outputs.append(output)
        return np.array(outputs)

# Residual Block
def ResBlock(inputs, kernel = 3):
    """Residual block: Conv2D(ReLU) -> Conv2D, added back onto `inputs`.

    Args:
      inputs: Keras tensor fed to Conv2D (batch, H, W, C).
      kernel: square kernel size for both convolutions. ``None`` is treated
        as the default 3. Callers may pass it explicitly, and ``2**(9 - None)``
        would otherwise raise TypeError.

    Returns:
      Keras tensor with the same shape as `inputs`.
    """
    if kernel is None:
        kernel = 3
    # Hidden width depends on the kernel: 64 filters for 3, 16 for 5, ...
    x = layers.Conv2D(2**(9-kernel), kernel, padding="same", activation="relu")(inputs)
    # The second conv must output as many channels as `inputs` for the Add.
    # Fall back to 64 (the former hard-coded width) if the depth is unknown.
    channels = inputs.shape[-1] or 64
    x = layers.Conv2D(channels, kernel, padding="same")(x)
    x = layers.Add()([inputs, x])
    return x

# Upsampling Block
def Upsampling(inputs, qt_inputs, factor=2, **kwargs):
    """Sub-pixel upsampling: Conv2D to factor**2 x the channels, then depth_to_space.

    Args:
      inputs: Keras tensor (batch, H, W, C).
      qt_inputs: multiplier on the 64-channel base width. The result has
        ``qt_inputs * 64`` channels.
      factor: spatial upscaling factor. It must be integer-valued; floats
        such as ``2.0`` (e.g. from ``2**np.log2(...)``) are converted to int
        because Conv2D filters and depth_to_space's block size are integers.
      **kwargs: forwarded to the Conv2D layer.

    Returns:
      Keras tensor with H and W multiplied by `factor`.

    Raises:
      ValueError: if `factor` is not integer-valued.
    """
    if int(factor) != factor:
        raise ValueError("Upsampling factor must be an integer, got %r" % (factor,))
    factor = int(factor)
    x = layers.Conv2D(qt_inputs * 64 * factor ** 2, 3, padding="same", **kwargs)(inputs)
    x = tf.nn.depth_to_space(x, block_size=factor)

    return x

In [None]:
def hafltone_edsr(num_filters, num_of_residual_blocks):
    """Build the EDSR-style halftone model.

    Head conv -> `num_of_residual_blocks` ResBlocks -> conv, with a global
    skip connection around the residual trunk. This is followed by
    log2(output_zoom / input_zoom) x2 sub-pixel upsampling stages and a final
    conv to 4 channels ('CMYK' mode) or 1 channel. Reads the notebook globals
    `mode`, `input_zoom` and `output_zoom`, and clears the Keras session first.

    Args:
      num_filters: width of the head/trunk convolutions.
      num_of_residual_blocks: number of ResBlock stages.

    Returns:
      A HalftoneModel.

    Raises:
      Exception: if output_zoom / input_zoom is not a power of two.
    """
    tf.keras.backend.clear_session()

    channels = 4 if mode == 'CMYK' else 1
    input_layer = layers.Input(shape=(None, None, channels))

    x = x_new = layers.Conv2D(num_filters, 3, padding="same")(input_layer)

    for _ in range(num_of_residual_blocks):
        # Use ResBlock's default kernel; passing None breaks 2**(9 - kernel).
        x_new = ResBlock(x_new)

    x_new = layers.Conv2D(num_filters, 3, padding="same")(x_new)
    x = layers.Add()([x, x_new])

    up_layers = np.log2(output_zoom/input_zoom)

    if(int(up_layers) != up_layers):
      raise Exception("The input and output resolution must be in base 2! Current ratio: %.2f "%(output_zoom/input_zoom))

    # Each stage doubles H and W, so int(up_layers) stages reach the target
    # ratio. A per-stage factor of 2**up_layers would compound (4x -> 16x).
    for _ in range(int(up_layers)):
        x = Upsampling(x, 1, factor=2)

    output_layer = layers.Conv2D(channels, 3, padding="same")(x)

    model =  HalftoneModel(input_layer, output_layer)

    return model


# model_by_name / model_name spell the builder "halftone_edsr". Expose it under
# that name so eval(model_name + model_by_name[model_name]) resolves.
halftone_edsr = hafltone_edsr

In [None]:
def hafltone_net(num_filters=64, num_of_residual_blocks=16):
    """Build the plain residual halftone model (no global skip connection).

    Head conv -> `num_of_residual_blocks` ResBlocks -> conv. This is followed by
    log2(output_zoom / input_zoom) x2 sub-pixel upsampling stages (each
    producing 2 * 64 channels) and a final conv to 4 channels ('CMYK' mode)
    or 1 channel. Reads the notebook globals `mode`, `input_zoom` and
    `output_zoom`, and clears the Keras session first.

    Args:
      num_filters: width of the head/trunk convolutions.
      num_of_residual_blocks: number of ResBlock stages.

    The defaults let the "()" entry for 'halftone_net' in model_by_name build
    a model; they mirror the 'halftone_edsr' settings.
    NOTE(review): confirm these are the intended halftone_net values.

    Returns:
      A HalftoneModel.

    Raises:
      Exception: if output_zoom / input_zoom is not a power of two.
    """
    tf.keras.backend.clear_session()

    channels = 4 if mode == 'CMYK' else 1
    input_layer = layers.Input(shape=(None, None, channels))

    x_new = layers.Conv2D(num_filters, 3, padding="same")(input_layer)

    for _ in range(num_of_residual_blocks):
        # Use ResBlock's default kernel; passing None breaks 2**(9 - kernel).
        x_new = ResBlock(x_new)

    x = layers.Conv2D(num_filters, 3, padding="same")(x_new)

    up_layers = np.log2(output_zoom/input_zoom)

    if(int(up_layers) != up_layers):
      raise Exception("The input and output resolution must be in base 2! Current ratio: %.2f "%(output_zoom/input_zoom))

    # Each stage doubles H and W, so int(up_layers) stages reach the target
    # ratio. A per-stage factor of 2**up_layers would compound (4x -> 16x).
    for _ in range(int(up_layers)):
        x = Upsampling(x, 2, factor=2)

    output_layer = layers.Conv2D(channels, 3, padding="same")(x)

    model =  HalftoneModel(input_layer, output_layer)

    return model


# model_by_name / model_name spell the builder "halftone_net". Expose it under
# that name so eval(model_name + model_by_name[model_name]) resolves.
halftone_net = hafltone_net

## Model types

In [None]:
# Argument-list suffix for each builder, keyed by the `model_name` form value.
# The model is created with eval(model_name + model_by_name[model_name]).
model_by_name = dict(
    halftone_net="()",
    halftone_edsr="(num_filters=64, num_of_residual_blocks=16)",
)

## Model Info

In halftone photography, the term amplitude modulation is used to refer to a halftoning technique (the conventional form of halftone screening) in which the sizes of the halftone dots are varied according to whether they correspond to shadows (large dots), middle tones (medium-sized dots), or highlights (small dots). An alternate means of halftone screening is known as stochastic screening, or FM screening. See Halftone and Stochastic Screening.
<a href="http://printwiki.org/Amplitude_Modulation#:~:text=In%20halftone%20photography%2C%20the%20term,highlights%20(small%20dots).">source</a>

In [None]:
# Build the selected architecture. The form value names the builder function
# and model_by_name supplies the argument list appended to it before eval.
# NOTE(review): eval only resolves if a function with exactly that name is in scope.
builder_call = model_name + model_by_name[model_name]
model = eval(builder_call)
model.summary()

In [None]:
from tensorflow import keras
# Render the layer graph with tensor shapes; as the cell's last expression,
# the returned image is displayed inline.
keras.utils.plot_model(model=model, show_shapes=True)

# Prepare the dataset for the created model

In [None]:
# Spatial dims of the model output (axes 1 and 2 of the channels-last output),
# reversed, plus the channel count unless the notebook runs in 'BW' mode.
# The condition must test the global `mode`; comparing the Keras `model`
# object to 'BW' was always True.
model_shape = tuple(model.output.shape[1:3])[::-1]
if(mode != 'BW'):
  # Append as a 1-tuple so this also works when shapes are plain tuples.
  model_shape += (model.output.shape[3],)
print(model_shape)

In [None]:
#@title DataGenerator
import random

import numpy as np
from tensorflow import keras

class DataGenerator(tf.keras.utils.Sequence):
  """Keras Sequence yielding (input, target) batches loaded from .npy files.

  Each entry of `dataset_images` is a path template containing '/%s/'. Inputs
  are read from the '/halfsize/' variant and targets from the '/real/'
  variant; pixel values are mapped to ``1 - v / 255``.

  * name == 'train': each sample is a random `image_shape` crop of the input
    plus the matching 2x-sized crop of the target, `batch_size` per batch.
  * any other name: whole images trimmed to a common size, one per batch.

  In 'BW' mode (notebook-global `mode`) only the first channel is kept;
  otherwise 4 channels are used.
  """

  def slice_images(self,xs,ys):
    """Load one input/target pair and return an aligned random crop of each.

    Args:
      xs: path of the input .npy file.
      ys: path of the target .npy file (assumed to be 2x the input resolution).

    Returns:
      (x_crop, y_crop): the input crop spans `image_shape` (rows, cols); the
      target crop covers the same region at twice the resolution.
    """
    x = 1-np.load(xs)/255
    y = 1-np.load(ys)/255
    if(mode == 'BW'):
      x = np.expand_dims(x[:,:,0],-1)
      y = np.expand_dims(y[:,:,0],-1)

    # randint is inclusive at both ends, so the largest valid top-left offset
    # is size - crop. The former "- 1" skipped it and made randint(0, -1)
    # fail on images exactly the crop size.
    x1 = random.randint(0, x.shape[0]-image_shape[0])
    y1 = random.randint(0, x.shape[1]-image_shape[1])
    x2, y2 = x1+image_shape[0], y1+image_shape[1]
    return x[x1:x2, y1:y2], y[2*x1:2*x2, 2*y1:2*y2]

  def __init__(self, dataset_images, oversampling = 1, shuffle=True, batch_size = 3, name = "", only_class = None, slice_once = False):
    """Resolve input/target paths and compute the first epoch's order.

    Args:
      dataset_images: list of path templates containing '/%s/'.
      oversampling: how many times the list is repeated ('train' only).
      shuffle: reshuffle the sample order in on_epoch_end.
      batch_size: samples per batch for 'train'; other splits always use 1.
      name: split name; 'train' enables random crops and batching.
      only_class: accepted for compatibility; unused.
      slice_once: stored on the instance; not read by this class.
    """
    self.slice_once = slice_once
    images = dataset_images*oversampling if name == 'train' else dataset_images
    self.dataset_images = ([v.replace("/%s/",'/halfsize/') for v in images], [v.replace("/%s/",'/real/') for v in images])

    self.batch_size = batch_size if name == 'train' else 1

    self.name = name
    self.data_aug = ['train']
    self.shuffle = shuffle
    self.on_epoch_end()


  def __len__(self):
      return self.steps_per_epoch

  def __getitem__(self, index):

    if(index >= np.floor(self.steps_per_epoch)):
      indexes = self.indexes[index*self.batch_size:]
    else:
      indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

    list_IDs_temp = [(self.dataset_images[0][k],self.dataset_images[1][k]) for k in indexes]
    return self.__data_generation(list_IDs_temp)

  def data_augmentation(self,a,b):
    pass

  def on_epoch_end(self):
    """Recompute the number of batches and (optionally) reshuffle the order."""
    n = len(self.dataset_images[0])
    if(self.batch_size == 1):
      self.steps_per_epoch = n
    elif(self.batch_size > 1):
      # Ceil division. `n // batch_size + 1` appended an empty final batch
      # whenever n was a multiple of batch_size.
      self.steps_per_epoch = -(-n // self.batch_size)
    else:
      self.steps_per_epoch = 1
    self.indexes = np.arange(n)
    if self.shuffle == True:
      np.random.shuffle(self.indexes)

  def _datagen(self, list_IDs_temp):
    """Build float32 (X, Y) arrays for the given (input_path, target_path) pairs."""
    channels = 1 if mode == 'BW' else 4
    n = len(list_IDs_temp)
    if(self.name == 'train'):
      # Buffers follow image_shape's (rows, cols) order, matching the crops in
      # slice_images. The previous (cols, rows) order only worked for square crops.
      X = np.empty((n, image_shape[0], image_shape[1], channels), dtype = np.float32)
      Y = np.empty((n, 2*image_shape[0], 2*image_shape[1], channels), dtype = np.float32)

      for i,d in enumerate(list_IDs_temp):
        X[i,],Y[i,] = self.slice_images(*d)
    else:
      X,Y = None,None

      for i,d in enumerate(list_IDs_temp):
        x,y = 1-np.load(d[0])/255, 1-np.load(d[1])/255
        if(X is None):
          # Size the batch from the first image: each spatial dim becomes its
          # largest even value minus one, and the target is twice that.
          X = np.empty((n, (x.shape[0]//2)*2 -1 , (x.shape[1]//2)*2 - 1, channels), dtype = np.float32)
          Y = np.empty((n, X.shape[1]*2, X.shape[2]*2, channels), dtype = np.float32)
        if(mode == 'BW'):
          x = np.expand_dims(x[:,:,0],-1)
          y = np.expand_dims(y[:,:,0],-1)
        X[i,] = x[:X.shape[1], :X.shape[2], :]
        Y[i,] = y[:Y.shape[1], :Y.shape[2], :]
    return X,Y

  def __data_generation(self, list_IDs_temp):
    return self._datagen(list_IDs_temp)


In [None]:
# One dict per fold mapping each split name present in the fold ('train',
# 'test' and, when created, 'validation') to its DataGenerator.
all_datasets = [
    {key: DataGenerator(fold[key], batch_size = batch_size, name=key) for key in fold}
    for fold in example_folds
]

In [None]:
import tensorflow.keras.backend as K
def PSNR(im1, im2):
  """Peak signal-to-noise ratio in decibels for images scaled to [0, 1].

  PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1. `K.log` is the natural
  logarithm, so the result is divided by ln(10); without this the value is
  ln(10) ~ 2.30x too large to be in dB.

  Args:
    im1, im2: arrays/tensors of equal shape and floating dtype.

  Returns:
    Scalar tensor (+inf when the images are identical, i.e. MSE == 0).
  """
  import math
  max_pixel = 1.0
  return (10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(im2 - im1))))) / math.log(10.0)

In [None]:
from skimage.metrics import structural_similarity as ssim
def SSIM(im1, im2):
  """Structural similarity of two images scaled to [0, 1], via tf.image.ssim."""
  return tf.image.ssim(im1, im2, max_val=1)

# Main Routine

In [None]:
# All artefacts (weights, histories, images, reports) go to <output_path>/<model_name>/.
base_resultados = output_path+'/'
nome_pasta = '/'+model_name+'/'
export_folder_name = base_resultados+nome_pasta
# makedirs also creates a missing output_path, where os.mkdir raised
# FileNotFoundError. exist_ok keeps re-runs from failing.
os.makedirs(export_folder_name, exist_ok=True)
export_folder_name

In [None]:
# Checkpoint callback writing to <export_folder_name><model_name>.h5 while
# monitoring val_loss (mode='min').
# NOTE(review): it is not passed to model.fit in this notebook, so it has no
# effect at the moment.
loss_checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath=export_folder_name + model_name + '.h5',
    monitor='val_loss',
    verbose=1,
    mode='min',
)

In [None]:
# Adam with a step learning-rate schedule: 5e-4 for the first 5000 optimizer
# steps, then 5e-5. The per-fold training loop later builds its own optimizer.
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[5000], values=[5e-4, 5e-5]
)
optim_edsr = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(
    optimizer=optim_edsr,
    loss="mse",
    metrics=[tf.keras.losses.MeanAbsoluteError()],
)

In [None]:
def test_model(datasets, model):
  """Evaluate `model` on the 'test' and 'validation' splits and summarise it.

  Each uint8 prediction is rescaled to [0, 1] and compared with its target:
  - MAE;
  - PSNR, via `PSNR`, with non-finite values masked;
  - skimage SSIM, using `channel_axis=2` unless `mode == 'BW'`.

  Args:
    datasets: dict mapping split name -> DataGenerator. Splits that are
      absent are skipped.
    model: model whose `predict_step` produces the outputs.

  Returns:
    One-line string with the mean MAE, PSNR and SSIM over all images.
  """
  m_ss, m_mae, m_dbs = [], [], []
  for k in ['test', 'validation']:
    if k not in datasets:
      # 'validation' is only created when validation_split is non-zero.
      continue
    for batch_x, batch_y in datasets[k]:
      if(mode == 'CMYK2BW' and len(batch_x.shape) > 2):
          # Predict each channel separately and restack them as channels.
          result = np.array([model.predict_step(batch_x[:,:,:,c]).squeeze() for c in range(batch_x.shape[-1])])
          result = np.expand_dims(np.moveaxis(result, 0, -1), axis=0)
      else:
          result = model.predict_step(batch_x)
      for img,res in zip(batch_y, result):
        img = img.squeeze()
        res = res.squeeze().astype(np.float32)/255
        psnr_masked = np.ma.masked_invalid(PSNR(img, res).numpy())
        m_dbs.append(psnr_masked.mean())
        m_mae.append(abs(img-res).mean())
        if(mode == 'BW'):
          m_ss.append(ssim(img,res, data_range = 1.0))
        else:
          m_ss.append(ssim(img,res, data_range = 1.0, channel_axis = 2))
  a , b , c = np.mean(m_mae), np.mean(m_dbs), np.mean(m_ss)

  return "All images mean: MAE: %.8f, PSNR: %.8fdB, SSIM: %.8f"%(a,b,c)

In [None]:
import pickle
# ===========================================================
# Running index used to name the exported images across splits and folds.
count = 1
for idx_data, datasets in enumerate(all_datasets):
  # A fresh model is built for every fold.
  model = eval(model_name+model_by_name[model_name])

  # Per-fold optimizer: 5e-4 for the first 5000 steps, then 1e-5.
  optim_edsr = tf.keras.optimizers.Adam(
    learning_rate=tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[5000], values=[5e-4, 1e-5])
  )

  model.compile(optimizer=optim_edsr, loss="mse", metrics=[tf.keras.losses.MeanAbsoluteError(), tf.keras.losses.MeanSquaredError()])
  if(not load_pre_trained_model):

    history = model.fit(
        datasets['train'],
        validation_data=datasets['validation'],
        epochs=epochs,
        shuffle=True,
        verbose=1,
    )

    # Persist the fold's weights and generators (file lists and split state).
    model.save_weights(export_folder_name+model_name+'_%d.h5'%idx_data)
    with open(export_folder_name+'dataset_%d.pck'%idx_data,'wb+') as f:
      pickle.dump(datasets,f)

    import matplotlib.pyplot as plt
    import json

    history_path = export_folder_name+'history_%d.json'%idx_data
    try:
      with open(history_path) as f:
        hist = json.load(f)
      for key in hist:
        hist[key] += history.history[key]
      with open(history_path, 'w+') as f:
        json.dump(hist,f)
    except FileNotFoundError:
      with open(history_path, 'w+') as f:
        json.dump(history.history,f)
      hist = history.history

    # NOTE(review): this unconditionally overwrites the file just written with
    # only the current run's history, so the merge above never persists.
    # Confirm which behaviour is intended.
    with open(history_path, 'w+') as f:
      json.dump(history.history,f)
    hist = history.history

    # Plot every metric that has a validation counterpart (or all metrics if
    # none do), with train and validation curves on shared axes.
    metrics = [metric for metric in hist if 'val_'+metric in hist]
    if not metrics:
      metrics = list(hist)

    # squeeze=False keeps `ax` an array even for a single metric, where a bare
    # Axes object has no .ravel().
    fig, ax = plt.subplots(nrows = len(metrics), ncols= 1, figsize=(10, 5*len(metrics)), squeeze=False)
    ax = ax.ravel()

    for i, metric in enumerate(metrics):
      ax[i].plot(hist[metric])
      val_key = "val_" + metric
      if(val_key in hist):
        ax[i].plot(hist[val_key])
      ax[i].set_title("Model {}".format(metric))
      ax[i].set_xlabel("epochs")
      ax[i].set_ylabel(metric)
      ax[i].legend(["train", "val"])
    plt.savefig(export_folder_name+'/histórico_%d.png'%idx_data)
    # Release the figure; otherwise one stays open per fold.
    plt.close(fig)

  else:
    model.load_weights(export_folder_name+model_name+'_%d.h5'%idx_data)


  relatorio = test_model(datasets,model)
  print(relatorio)

  from skimage.metrics import structural_similarity as ssim
  # ===========================================================
  # PIL colour mode matching the channel layout of the arrays.
  if(mode == 'BW'):
    c_mode = 'L'
  elif(mode == 'CMYK'):
    c_mode = 'CMYK'
  # ===========================================================
  buffer_tests = ""
  # ===========================================================
  for t_dataset in ['validation','test']:
    for batch_x, batch_y in datasets[t_dataset]:
      if(mode == 'CMYK2BW' and len(batch_x.shape) > 2):
          result = np.array([model.predict_step(batch_x[:,:,:,c]).squeeze() for c in range(batch_x.shape[-1])])
          result = np.expand_dims(np.moveaxis(result, 0, -1), axis=0)
      else:
          result = model.predict_step(batch_x)
      # ===========================================================
      # Inputs and targets are stored as 1 - v/255; invert before saving.
      for i,img in enumerate(batch_x):
        name = export_folder_name+'/%d_input.png'%(i+(count-1))
        img = 1-img.squeeze()
        Image.fromarray(np.array(255*img, dtype=np.uint8), mode = c_mode).convert('RGB').save(name)
      # ===========================================================
      for i,img in enumerate(batch_y):
        name = export_folder_name+'/%d_ideal.png'%(i+(count-1))
        img = 1-img.squeeze()
        Image.fromarray(np.array(255*img, dtype=np.uint8), mode = c_mode).convert('RGB').save(name)
      # ===========================================================
      for i,(img,res_i) in enumerate(zip(batch_y, result)):
        res = res_i.squeeze().astype(np.float32)/255
        img = img.squeeze()
        psnr_masked = np.ma.masked_invalid(PSNR(img, res).numpy())
        dbs = psnr_masked.mean()
        mae = abs(img-res).mean()
        if(mode == 'BW'):
          ss = ssim(img,res, data_range = 1.0)
        else:
          ss = ssim(img,res, data_range = 1.0, channel_axis = 2)
        buffer_tests += "Imagem: %d | MAE: %.5f | PSNR: %.8f dBs | SSIM: %.8f\n"%((i+(count-1)), mae,dbs, ss)
      # ===========================================================
      # predict_step returns uint8 values, so invert with 255 - v.
      for i,img in enumerate(result):
        img = 255-img.squeeze()
        name = export_folder_name+"%d_resultado.png"%(i+(count-1))
        Image.fromarray(np.array(img.squeeze(), dtype=np.uint8), mode = c_mode).convert('RGB').save(name)
      count += datasets[t_dataset].batch_size

  # Write the fold report once, after both splits have been processed.
  with open(export_folder_name+'/relatorio_%d.txt'%idx_data, 'w+') as f:
    f.write(buffer_tests+relatorio+'\n')