In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import PIL.Image
import IPython.display
import io
import statistics as stats
from matplotlib import pyplot as plt
from copy import copy

# TODO: search Springer for books on tensors, in particular tensors in machine learning.

class CellularAutomata(object):
    """A square grid of float32 cells updated by a 3x3 neighbourhood rule.

    ``self.world`` has shape ``(img_size, img_size, channel_count)``. The first
    three channels are treated as RGB in [0, 1] by ``imagefill`` and
    ``to_image``; any further channels are hidden state.
    """

    def __init__(self, img_size: int, channel_count: int):
        self.img_size = img_size
        self.channel_count = channel_count
        self.zerofill()

    def _world_shape(self):
        """Return the shape of ``self.world``."""
        return (self.img_size, self.img_size, self.channel_count)

    def imagefill(self, image_path):
        """Fill the world with image data from disk.

        The image is converted to RGB, resized to ``img_size`` x ``img_size``
        if its size differs, scaled to [0, 1] and written into the first three
        channels; all remaining channels are zero. Requires channel_count >= 3.
        """
        # `with` releases the file handle PIL keeps open after Image.open.
        with PIL.Image.open(image_path) as img:
            rgb = img.convert("RGB")
        if rgb.size != (self.img_size, self.img_size):
            rgb = rgb.resize((self.img_size, self.img_size))
        self.zerofill()
        self.world[:, :, :3] = np.asarray(rgb, dtype=np.float32) / 255.0

    def zerofill(self):
        """Fill the world with zeros."""
        self.world = np.zeros(self._world_shape(), dtype=np.float32)

    def onefill(self):
        """Fill the world with ones."""
        self.world = np.ones(self._world_shape(), dtype=np.float32)

    def pointfill(self):
        """Fill the world with zeros except the centre cell, whose channels are all one."""
        self.zerofill()
        centre = self.img_size // 2
        self.world[centre, centre, :] = 1.0

    def to_image(self, scale=1):
        """Render the RGB channels as a PNG ``IPython.display.Image``.

        Hidden channels are dropped and each cell becomes a ``scale`` x
        ``scale`` block of pixels.
        """
        # Clip first: NumPy's float -> uint8 cast does not saturate, so values
        # pushed outside [0, 1] by the automaton would become garbage colours.
        rgb = np.clip(self.world[:, :, :3], 0.0, 1.0)
        rgb_array = np.uint8(rgb * 255.0)

        # Nearest-neighbour upscaling: repeat rows, then columns.
        rgb_array = np.repeat(np.repeat(rgb_array, scale, axis=0), scale, axis=1)

        out = io.BytesIO()
        PIL.Image.fromarray(rgb_array).save(out, 'png')
        return IPython.display.Image(data=out.getvalue())

    def display(self, scale=1):
        """Show the world inline in the notebook."""
        IPython.display.display(self.to_image(scale))

    def convolve(self, f):
        """Apply ``f`` to every interior 3x3 neighbourhood and return the results.

        ``f`` receives a (3, 3, channel_count) view of the world and should
        return ``channel_count`` values; a leading axis of size 1 (as a Keras
        model returns for a batch of one) is accepted. The returned array has
        the world's shape; border cells, which lack a full neighbourhood, are 0.
        """
        result = np.zeros(self._world_shape(), dtype=np.float32)
        for x in range(self.img_size - 2):
            for y in range(self.img_size - 2):
                neighbourhood = self.world[x:x + 3, y:y + 3, :]
                # reshape(-1) flattens a (1, C) output to (C,); scalars still
                # broadcast across all channels.
                result[x + 1, y + 1] = np.asarray(f(neighbourhood)).reshape(-1)
        return result

    def _neighbourhoods(self):
        """Stack every interior 3x3 neighbourhood into an (N, 3, 3, channel_count) batch.

        Rows follow the same x-major order as ``convolve``, so N outputs
        reshape to (img_size - 2, img_size - 2, channel_count).
        """
        inner = max(self.img_size - 2, 0)
        patches = [self.world[x:x + 3, y:y + 3, :]
                   for x in range(inner) for y in range(inner)]
        if not patches:
            return np.empty((0, 3, 3, self.channel_count), dtype=np.float32)
        return np.stack(patches)

    def tick(self, model: "keras.Model", num_steps: int):
        """Advance the automaton ``num_steps`` times using ``model`` as the update rule.

        Each step feeds all interior neighbourhoods to ``model`` as one batch,
        treats the (N, channel_count) output as per-cell deltas and adds them
        to the interior cells; border cells never change. ``self.world`` is
        replaced rather than modified in place, so another object sharing the
        old array (e.g. via a shallow ``copy``) is left untouched.
        """
        inner = self.img_size - 2
        if inner <= 0:
            return  # no cell has a full 3x3 neighbourhood
        for _ in range(num_steps):
            deltas = np.asarray(model(self._neighbourhoods()), dtype=np.float32)
            world = self.world.copy()
            world[1:-1, 1:-1] += deltas.reshape(inner, inner, self.channel_count)
            self.world = world

    def train(self, num_steps: int, model: "keras.Model", target, kwargs=None):
        """Run the automaton, then fit ``model`` to move each cell onto ``target``.

        After ``num_steps`` ticks, every interior 3x3 neighbourhood becomes a
        training input. Its label is ``target``'s value for that cell minus
        the current value, i.e. the delta ``tick`` would need to add to land
        exactly on the target. ``kwargs`` is forwarded to ``model.fit``.

        Returns whatever ``model.fit`` returns (a ``History`` for Keras models).
        Raises ValueError if ``target.world`` differs in shape from this world.
        """
        if kwargs is None:
            kwargs = {}
        if target.world.shape != self.world.shape:
            raise ValueError(
                f"target world shape {target.world.shape} does not match {self.world.shape}")

        self.tick(model, num_steps)

        x_tensor = self._neighbourhoods()
        # Same x-major order as _neighbourhoods, one (channel_count,) row per cell.
        y_tensor = (target.world[1:-1, 1:-1] - self.world[1:-1, 1:-1]).reshape(
            -1, self.channel_count)

        return model.fit(x=x_tensor, y=y_tensor, **kwargs)

# Per-cell update rule: flattens a 3x3 neighbourhood of `channel_count`
# channels and maps it to `channel_count` outputs.
channel_count = 8
model = keras.models.Sequential()
model.add(tf.keras.Input(shape=(3, 3, channel_count)))
model.add(keras.layers.Flatten())
# Without a nonlinearity the two Dense layers compose to a single affine map,
# so the hidden layer would add parameters but no expressive power.
model.add(keras.layers.Dense(16, activation='relu'))
model.add(keras.layers.Dense(channel_count))
def loss(state, target):
    """Return the mean squared difference over every element of the batch.

    Keras invokes a loss as ``loss(y_true, y_pred)``, so ``state`` receives
    the labels and ``target`` the predictions; squaring makes the order
    irrelevant to the result.
    """
    residual = state - target
    return tf.reduce_mean(tf.square(residual))
model.compile(optimizer='adam', loss=loss)
model.summary()

IMG_SIZE = 16           # cells per side of the square world
NUM_STEPS = 30          # automaton ticks per rollout
UPDATES_PER_ROUND = 20  # model.fit calls between progress reports
DISPLAY_SCALE = 3       # pixels per cell when rendering

target_ca = CellularAutomata(img_size=IMG_SIZE, channel_count=channel_count)
target_ca.imagefill("earth16.png")
target_ca.display(scale=DISPLAY_SCALE)

# A fresh instance instead of copy(target_ca): a shallow copy shares the
# `world` array with the target, so any in-place update to one would
# silently corrupt the other.
ca = CellularAutomata(img_size=IMG_SIZE, channel_count=channel_count)
ca.pointfill()
ca.display(scale=DISPLAY_SCALE)

losses = []
# Trains indefinitely; interrupt the kernel to stop.
while True:
    for _ in range(UPDATES_PER_ROUND):
        ca.pointfill()
        hist = ca.train(num_steps=NUM_STEPS, model=model, target=target_ca,
                        kwargs=dict(epochs=1, verbose=0))
        losses.append(stats.mean(hist.history['loss']))

    # Show a rollout from the single-point seed with the current model.
    ca.pointfill()
    ca.tick(num_steps=NUM_STEPS, model=model)
    ca.display(scale=DISPLAY_SCALE)

    plt.plot(losses)
    plt.yscale('log')
    plt.grid()
    plt.show()