In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.util.image import *
from src.util import *
from src.models.util import *
from src.algo import ca1

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """Scale the two trailing (spatial) dimensions of tensor `img` by `scale`.

    Each new size is `int(old_size * scale)` (truncated), clamped to at least 1.
    Defaults to nearest-neighbour interpolation without antialiasing, so
    upscaled pixel art keeps hard pixel edges.
    """
    new_size = []
    for old_size in img.shape[-2:]:
        new_size.append(max(1, int(old_size * scale)))
    return VF.resize(img, new_size, interpolation=mode, antialias=False)

In [None]:
# Sobel x-kernel as a column smoothing filter [1, 2, 1] times a row
# central-difference filter [-1, 0, 1] (same values as the kernel in NeuralCA).
np.array([[1], [2], [1]]) * np.array([-1, 0, 1])

In [None]:
class NeuralCA(nn.Module):
    """Neural cellular automaton with a fixed Sobel perception stage.

    On every step each cell sees its own state plus the horizontal and
    vertical Sobel gradients of all channels (fixed depthwise 3x3
    convolutions). That perception vector goes through a per-cell two-layer
    MLP, and the MLP output is added to the state (residual update).

    Args:
        channels: number of state channels per cell. Must be >= 4 because
            `generate` renders channels 0-2 as RGB and channel 3 as alpha.
        channels_hidden: width of the hidden layer of the per-cell MLP.
        default_shape: (height, width) used by `generate` when neither a
            state nor a shape is given.
        activation: converted with `activation_to_callable` and stored as
            `self.activation`. NOTE(review): it is currently not applied in
            `forward`; the residual update is un-activated.
    """

    def __init__(
        self,
        channels: int = 16,
        channels_hidden: int = 128,
        default_shape: Optional[Tuple[int, int]] = None,
        activation: Union[str, Callable] = "tanh",
    ):
        assert channels >= 4, channels

        super().__init__()

        self.channels = channels
        self.default_shape = default_shape
        self.activation = activation_to_callable(activation)

        # Sobel x-derivative kernel (scaled by 1/2); its transpose is the
        # y-derivative kernel.
        sobel = torch.Tensor([
            [-1,  0,  1],
            [-2,  0,  2],
            [-1,  0,  1]
        ]) / 2.
        # Fixed (non-trainable) depthwise filters: one copy of each kernel per
        # channel, applied with `groups=channels` in `forward`.
        self.sobel1 = nn.Parameter(
            sobel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1),
            requires_grad=False,
        )
        self.sobel2 = nn.Parameter(
            sobel.T.view(1, 1, 3, 3).repeat(channels, 1, 1, 1),
            requires_grad=False,
        )
        # Per-cell MLP: perception (state + 2 gradients) -> hidden -> state update.
        self.lin1 = nn.Linear(channels * 3, channels_hidden)
        self.lin2 = nn.Linear(channels_hidden, channels)

    def forward(self, state: torch.Tensor, iterations: int = 32) -> torch.Tensor:
        """Apply `iterations` residual CA update steps to `state`.

        Args:
            state: tensor with `self.channels` entries in dim -3. The
                `transpose(1, 3)` calls below require a 4-d
                (batch, channels, height, width) layout.
            iterations: number of update steps.

        Returns:
            The state after the last step, with the same shape as `state`.
        """
        for _ in range(iterations):
            # Depthwise perception: x- and y-gradients of every channel.
            grad_x = F.conv2d(state, self.sobel1, padding=1, groups=self.channels)
            grad_y = F.conv2d(state, self.sobel2, padding=1, groups=self.channels)
            perception = torch.concat([state, grad_x, grad_y], dim=-3)

            # nn.Linear acts on the last dim, so move channels there and back.
            update = self.lin1(perception.transpose(1, 3))
            update = F.relu6(update)
            update = self.lin2(update).transpose(1, 3)

            state = state + update

        return state

    def generate(self, state: Optional[torch.Tensor] = None, shape: Optional[Tuple[int, int]] = None, **kwargs):
        """Run the CA and render its first four channels as an RGB image.

        Args:
            state: initial state. If given, its last two dims define the shape.
            shape: (height, width) of a fresh seed state, used when `state` is
                None. Falls back to `self.default_shape`.
            **kwargs: forwarded to `forward` (e.g. `iterations`).

        Returns:
            (batch, 3, H, W) tensor with values in [0, 1]. Channels 0-2 are
            squashed to (0, 1) with `tanh * .5 + .5` and multiplied by the
            equally squashed channel 3 (alpha).

        Raises:
            ValueError: if no shape can be determined.
        """
        if state is not None:
            shape = state.shape[-2:]
        if shape is None:
            shape = self.default_shape
        if shape is None:
            raise ValueError("Need to define `shape` or `default_shape` in constructor")
        if state is None:
            # Seed: all-zero grid with every channel set to 1 in the centre cell.
            # Uses this instance's channel count (not a global `model`).
            state = torch.zeros(1, self.channels, *shape).to(self.sobel1)
            state[..., :, shape[-2] // 2, shape[-1] // 2] = 1
        image = torch.tanh(self(state, **kwargs)) * .5 + .5
        return image[:, :3, :, :] * image[:, 3:4, :, :]
        
model = NeuralCA()
print(f"params: {num_module_parameters(model, trainable=True):,} / {num_module_parameters(model):,}")
# Preview the untrained CA. no_grad avoids building an autograd graph over the
# 20 unrolled update steps, which is only needed for training.
with torch.no_grad():
    image = model.generate(shape=(100, 100), iterations=20)
VF.to_pil_image(resize(image[0, :3, :, :], 5))

# Training: grow a target image from a single seed cell

In [None]:
def _load_target_image(path: str, shape: Tuple[int, int]) -> torch.Tensor:
    """Load the image file at `path` as a 3-channel float tensor, passed through `image_resize_crop(..., shape)`."""
    # `with` closes the file handle. convert("RGB") drops alpha (as the former
    # `[:3]` slice did) and also accepts grayscale / palette / CMYK files.
    with PIL.Image.open(Path(path).expanduser()) as pil_image:
        target = VF.to_tensor(pil_image.convert("RGB"))
    # NOTE(review): presumably resizes and crops to `shape` - confirm in src.util.image.
    return image_resize_crop(target, shape)


def train_image(
        model: nn.Module,
        image: str,
        shape: Tuple[int, int] = (64, 64),
        epochs: int = 1000,
        device: str = "auto",
        lr: float = .0001,
        lr_decay: float = .999,
        iterations: Tuple[int, int] = (32, 96),
):
    """Train `model` in place so that it grows `image` from a single seed cell.

    Each epoch runs the CA for a random number of steps, starting from a fixed
    seed (all channels = 1 in the centre cell). It then minimises the MSE
    between the rendered RGB output and the target.

    Display behaviour:
    - The output of every 5th epoch is collected and shown in grids of up to 10.
    - The loss / learning-rate history is plotted at the end.
    - Ctrl-C (KeyboardInterrupt) stops training early but still shows the
      remaining snapshots and the plot.

    Args:
        model: module providing `channels` and `generate(state, iterations=...)`.
            It is moved to `device` and its parameters are updated.
        image: path of the target image file (`~` is expanded).
        shape: (height, width) of the training grid and the target.
        epochs: number of optimisation steps.
        device: passed to `to_torch_device`.
        lr: initial Adam learning rate.
        lr_decay: per-epoch multiplicative decay (ExponentialLR gamma).
        iterations: inclusive (min, max) range for the random number of CA
            steps per epoch.
    """
    target_image = _load_target_image(image, shape)
    display(VF.to_pil_image(target_image))

    device = to_torch_device(device)
    model.to(device)
    target_image = target_image.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, lr_decay)

    print(f"params: {num_module_parameters(model):,}")

    seed_image = torch.zeros(1, model.channels, *shape).to(device)
    seed_image[..., :, seed_image.shape[-2] // 2, seed_image.shape[-1] // 2] = 1

    # One record per epoch: a single list.append cannot be split by Ctrl-C,
    # so the columns of the final DataFrame always have equal length.
    history = []
    snapshots = []
    try:
        for epoch in tqdm(range(epochs)):
            output = model.generate(
                seed_image,
                iterations=random.randint(*iterations),
            )

            loss = F.mse_loss(output[:, :3], target_image.unsqueeze(0).expand(output.shape[0], -1, -1, -1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            last_lr = scheduler.get_last_lr()
            if isinstance(last_lr, (list, tuple)):
                last_lr = last_lr[0]
            history.append({"loss": float(loss), "lr": last_lr})

            if epoch % 5 == 0:
                # detach(): otherwise every stored snapshot keeps the whole
                # autograd graph of its unrolled CA run (and device memory) alive.
                snapshots.append(output[0].detach().cpu())

            if len(snapshots) >= 10:
                display(VF.to_pil_image(make_grid(snapshots, nrow=len(snapshots))))
                snapshots.clear()

    except KeyboardInterrupt:
        # Deliberate: lets a long run be stopped from the notebook while
        # still showing the collected results below.
        pass

    if snapshots:
        display(VF.to_pil_image(make_grid(snapshots, nrow=len(snapshots))))

    display(px.line(pd.DataFrame(history, columns=["loss", "lr"])))
    
# Fresh model trained on a local target picture.
# NOTE(review): machine-specific path; "~/Pictures/kali.png" was an
# alternative target used earlier.
model = NeuralCA()
train_image(
    model=model,
    image="~/Pictures/matt.png",
)

In [None]:
# Render a 64x64 sample from the trained model. `shape` must be passed by
# keyword: the first positional parameter of `generate` is the initial state
# tensor, so a bare tuple there fails with `'tuple' object has no attribute 'shape'`.
with torch.no_grad():
    display(VF.to_pil_image(resize(
        model.generate(shape=(64, 64))[0],
        4,
    )))