In [None]:
import tensorflow as tf
import numpy as np
from IPython import display
from pathlib import Path
import png
import datetime
import math
import random
import itertools
import zlib

In [None]:
# Load the source image. Every later cell indexes rows[y][x] as a single 8-bit grey
# level per pixel, so reject any other PNG layout (RGB, palette, alpha, 16-bit) here
# instead of failing or silently mis-indexing later.
reader = png.Reader('pigw.png')
(width, height, rows_gen, info) = reader.read()
print(width, height, info)
if not info['greyscale'] or info['alpha'] or info['bitdepth'] != 8:
    raise ValueError(f"pigw.png must be 8-bit greyscale without alpha, got {info}")
# Materialize the row iterator so rows can be indexed randomly as rows[y][x].
rows = list(rows_gen)


In [None]:
# Number of harmonic sin/cos pairs per axis in the positional encoding (get_harmonics_vect).
harmonics = 250
# Causal-context offsets: neighbours this many pixels to the left, above, and up-left.
backward = (1, 2, 3, 4, 5, 7)

# Model input width, matching MySequence.get_input: two positional encodings of
# (1 + 2 * harmonics) values each, plus len(backward) values for each of the three
# context directions (left, up, diagonal).
input_dim = 2 + 2 * 2 * harmonics + 3 * len(backward)

# NOTE(review): this layer is NOT part of `model` below, so it has no effect on
# training; it is only kept so existing references to `dropout` keep working.
dropout = tf.keras.layers.Dropout(0.03)

# Small MLP that regresses one normalized grey level; the sigmoid keeps outputs in (0, 1).
model = tf.keras.models.Sequential([
    # The input is already 1-D, so an explicit Input replaces the former no-op
    # Flatten(input_shape=...) (newer Keras warns when a layer gets `input_shape`).
    tf.keras.Input(shape=(input_dim,)),
    tf.keras.layers.Dense(300, activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

model.compile(optimizer='Adam',
              loss=tf.keras.losses.MeanSquaredError(),
              )

tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True, show_dtype=True,
                          show_layer_names=True, expand_nested=True,
                          show_layer_activations=True, show_trainable=True)
model.summary()
display.Image('model.png')

In [None]:

def get_harmonics_vect(x, total, n_harmonics=None):
    """Fourier-style positional encoding of coordinate ``x`` on an axis of length ``total``.

    ``x`` is first scaled to the angle ``a = x / total * pi`` (so 0 <= a < pi for
    0 <= x < total). For ``n = n_harmonics >= 1`` the result is

        [a, sin(a), cos(a), sin(2a), cos(2a), sin(4a), cos(4a), ...,
         sin(2(n-1)a), cos(2(n-1)a)]

    i.e. ``1 + 2 * n`` values: the raw angle, the base frequency, then the even
    multiples 2, 4, ..., 2(n-1).

    Args:
        x: coordinate (pixel index) along the axis.
        total: axis length; must be non-zero.
        n_harmonics: number of sin/cos pairs. Defaults to the notebook-level
            ``harmonics`` setting, from which the model's input width is derived.

    Returns:
        np.ndarray with dtype float32 and shape ``(1 + 2 * n_harmonics,)``.
    """
    if n_harmonics is None:
        n_harmonics = harmonics
    angle = x / total * math.pi
    vect = [angle, math.sin(angle), math.cos(angle)]
    for k in range(2, 2 * n_harmonics, 2):
        vect.append(math.sin(angle * k))
        vect.append(math.cos(angle * k))
    return np.array(vect, np.float32)

def normalize_color(color):
    """Map an 8-bit grey level (0..255) to the model's target scale.

    Linear map ``color / 320 + 0.1``: 0 -> 0.1 and 255 -> 0.896875. This presumably
    keeps targets well inside the (0, 1) range of the sigmoid output layer.
    Works elementwise on NumPy arrays as well as on scalars.
    """
    scale, offset = 0.003125, 0.1
    return offset + scale * color

def denormalize_color(color):
    """Inverse of ``normalize_color``: map a model output back to the 0..255 grey scale.

    Computes ``(color - 0.1) * 320``. No rounding or clipping is done, so inputs
    outside [0.1, 0.896875] land outside 0..255; callers must clamp.
    """
    shifted = color - 0.1
    return 320 * shifted

class MySequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving one training example per pixel of the global image.

    For pixel (x, y) the input is the concatenation of
    ``get_harmonics_vect(x, width)``, ``get_harmonics_vect(y, height)`` and three
    causal-context vectors. Those hold the normalized grey levels of the pixels
    ``backward[i]`` steps to the left, above, and diagonally up-left. A neighbour
    outside the image is 0, which is distinguishable because normalized values are
    >= 0.1. The target is the pixel's own normalized grey level.

    Reads the notebook-level globals ``width``, ``height``, ``rows``, ``backward``,
    ``get_harmonics_vect`` and ``normalize_color``.

    Args:
        batch_size: number of pixel positions per batch (before border filtering).
        randomize: shuffle the pixel order at construction and after every epoch.
        with_border: if False, drop pixels whose context would be zero-padded, so
            batches may contain fewer than ``batch_size`` examples.
    """

    def __init__(self, batch_size, randomize, with_border):
        # In Keras 3, Sequence is an alias of PyDataset, which warns when the base
        # constructor is skipped; with older Keras this is just object.__init__().
        super().__init__()
        self.batch_size = batch_size
        # Positional encodings depend only on the column / row, so compute them once.
        self.x_dim = [get_harmonics_vect(x, width) for x in range(width)]
        self.y_dim = [get_harmonics_vect(y, height) for y in range(height)]
        # self.map[k] is the flat pixel index (y * width + x) served at position k.
        self.map = list(range(width * height))
        self.with_border = with_border
        self.randomize = randomize
        if randomize:
            random.shuffle(self.map)

    def on_epoch_end(self):
        """Reshuffle the pixel order between epochs when randomizing."""
        if self.randomize:
            random.shuffle(self.map)

    def __len__(self):
        """Number of batches: ceil(width * height / batch_size)."""
        return (width * height + self.batch_size - 1) // self.batch_size

    def get_input(self, x, y):
        """Build the model input vector for pixel (x, y); layout as in the class docstring."""
        n = len(backward)
        left = np.zeros(n, np.float32)
        up = np.zeros(n, np.float32)
        diag = np.zeros(n, np.float32)
        for i, offset in enumerate(backward):
            xx = x - offset
            yy = y - offset
            if xx >= 0:
                left[i] = normalize_color(rows[y][xx])
            if yy >= 0:
                up[i] = normalize_color(rows[yy][x])
            if xx >= 0 and yy >= 0:
                diag[i] = normalize_color(rows[yy][xx])
        return np.concatenate((self.x_dim[x], self.y_dim[y], left, up, diag), axis=0)

    def __getitem__(self, index):
        """Return ``(inputs, targets)`` for batch ``index``.

        ``inputs`` is a float32 array with one row per kept pixel; ``targets`` is a
        1-D float32 array with those pixels' normalized grey levels.
        """
        start = index * self.batch_size
        stop = min(width * height, start + self.batch_size)
        pixels = self.map[start:stop]
        if not self.with_border:
            # Every context offset stays inside the image exactly when x and y are
            # >= the largest offset (the old hard-coded `> 7` also dropped row/col 7).
            margin = max(backward)
            pixels = [p for p in pixels if p // width >= margin and p % width >= margin]
        inputs = np.array([self.get_input(p % width, p // width) for p in pixels], np.float32)
        targets = np.array([normalize_color(rows[p // width][p % width]) for p in pixels], np.float32)
        return inputs, targets

In [None]:
# Launch TensorBoard separately with: tensorboard --logdir logs/fit

# Seed Python's RNG so MySequence's pixel shuffle (random.shuffle) is reproducible.
SHUFFLE_SEED = 0
random.seed(SHUFFLE_SEED)

# 1000 pixels per batch, reshuffled every epoch, pixels with incomplete context skipped.
train_data = MySequence(1000, True, False)

# Knobs tried during experimentation (kept for reference, not applied):
#   model.optimizer.learning_rate = 0.001
#   dropout.rate = 0.99   # `dropout` is not a layer of `model`, so this has no effect

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Keep the History object so per-epoch losses can be inspected after training.
history = model.fit(x=train_data,
                    epochs=30,
                    callbacks=[tensorboard_callback]
                    )


In [None]:
# Predict every pixel in raster order: no shuffling and border pixels included, so
# outputs[i] corresponds to pixel (i % width, i // width).
pred_data = MySequence(14000, False, True)

# Batches are (inputs, targets); predict() uses only the inputs and keeps batch
# order. Flatten the (N, 1) result to one normalized value per pixel.
outputs = model.predict(pred_data).reshape(-1)
assert outputs.shape == (width * height,), outputs.shape


In [None]:

# Write the model's reconstruction and a residual visualization, and report how
# compressible the residual is.
w = png.Writer(width, height, greyscale=True, bitdepth=8)

# Predicted grey levels as integers in 0..255. Round (np.rint, like Python's round(),
# rounds halves to even) and clip *before* casting: the sigmoid output denormalizes
# to roughly (-32, 288), and casting out-of-range floats straight to uint8 wraps
# around instead of saturating.
pred_pixels = np.clip(np.rint(denormalize_color(np.asarray(outputs, np.float64))), 0, 255).astype(np.int32)
# Ground-truth grey levels in the same raster order as `outputs`.
true_pixels = np.asarray(rows, dtype=np.int32).reshape(-1)
residual = pred_pixels - true_pixels

with open('out.png', 'wb') as fd:
    w.write(fd, pred_pixels.astype(np.uint8).reshape(height, width))

# Contrast gain for the residual image (zero residual maps to mid-grey 128).
mul = 10

# Residual modulo 256, presumably the stream a predictive lossless coder would store.
bytes_diff = bytearray((residual & 0xFF).astype(np.uint8).tobytes())
# Amplified, saturated residual for viewing.
img_diff = bytearray(np.clip(mul * residual + 128, 0, 255).astype(np.uint8).tobytes())

# zlib-compressed size of the residual stream: smaller means better prediction.
print(len(zlib.compress(bytes_diff, level=9))) # wbits=-15

with open('out-diff.png', 'wb') as fd:
    w.write(fd, [img_diff[i:i+width] for i in range(0, width * height, width)])

display.Image('out-diff.png')


In [None]:
# Show the reference image lena-gr.png (not the pigw.png the model was trained on;
# presumably kept for visual comparison with out.png / out-diff.png).
display.Image('lena-gr.png')