In [None]:
import sys
sys.path.append("..")

import random
import math
import time
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.decoder import *

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Scale the last two (spatial) dimensions of `img` by `scale`.

    Each new side length is truncated to an int and never drops below 1.
    Antialiasing is disabled, which together with the NEAREST default keeps
    pixel-art / cellular-automaton images crisp when enlarged.

    :param img: image tensor whose last two dimensions are resized
    :param scale: multiplicative factor applied to both spatial sides
    :param mode: torchvision interpolation mode
    :return: the resized image as returned by `VF.resize`
    """
    target_size = [max(1, int(side * scale)) for side in img.shape[-2:]]
    return VF.resize(img, target_size, interpolation=mode, antialias=False)

def plot_samples(
        iterable, 
        total: int = 32, 
        nrow: int = 8, 
        return_image: bool = False, 
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect up to `total` images from `iterable` and show them as one grid.

    Entries may be image tensors or list/tuple records whose first element is
    the image; a 4-dim image has its leading dimension squeezed if it is 1.
    Collection can be stopped early with KeyboardInterrupt; whatever was
    collected up to then is still plotted.

    :param iterable: source of images or (image, ...) records
    :param total: maximum number of samples to collect (also the tqdm total)
    :param nrow: number of images per grid row
    :param return_image: if True, return the PIL image instead of displaying it
    :param show_compression_ratio: label each sample with
        `ImageFilter().calc_compression_ratio(image)` rounded to 3 digits
    :param label: optional labelling; a callable is applied to the raw entry,
        any other non-None value labels each sample with its index.
        Ignored when `show_compression_ratio` is True.
    :return: the PIL grid image when `return_image` is True, otherwise None.
        Also None (and nothing is displayed) when no sample was collected.
    """
    samples = []
    labels = []
    # the filter is only needed for compression-ratio labels
    image_filter = ImageFilter() if show_compression_ratio else None
    try:
        for idx, entry in enumerate(tqdm(iterable, total=total)):
            image = entry[0] if isinstance(entry, (list, tuple)) else entry
            if image.ndim == 4:
                image = image.squeeze(0)
            samples.append(image)
            if image_filter is not None:
                labels.append(round(image_filter.calc_compression_ratio(image), 3))
            elif label is not None:
                labels.append(label(entry) if callable(label) else idx)

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # deliberate: interrupting a slow iterable still plots the partial result
        pass

    if not samples:
        # empty iterable or interrupted before the first sample;
        # there is nothing to arrange into a grid
        return None

    if labels:
        grid = make_grid_labeled(samples, nrow=nrow, labels=labels)
    else:
        grid = make_grid(samples, nrow=nrow)
    image = VF.to_pil_image(grid)
    if return_image:
        return image
    display(image)

In [None]:
class TotalCALayer(nn.Module):
    """
    Totalitarian Cellular Automaton as torch layer.

    A cell counts as alive when its value is >= 1 and as dead otherwise.
    One step counts the 8 neighbours with a 3x3 convolution, truncates the
    count to an integer clamped to 0..8 and looks it up in the `survive`
    table for living cells and in the `birth` table for dead cells.
    Every channel is simulated independently.
    """

    def __init__(
            self,
            birth: Union[None, torch.Tensor, Iterable[int]] = None,
            survive: Union[None, torch.Tensor, Iterable[int]] = None,
            iterations: int = 1,
            learn_kernel: bool = False,
            learn_rules: bool = False,
            wrap: bool = True,
    ):
        """
        Totalitarian Cellular Automaton as torch layer.

        :param birth: optional birth rule: 9 values indexed by neighbour count
            (0..8); the looked-up value becomes the new state of a dead cell.
            Random 0/1 values when None.
        :param survive: optional survival rule, same layout as `birth`,
            applied to living cells. Random 0/1 values when None.
        :param iterations: number of CA steps per forward call
        :param learn_kernel: do train the neighbourhood kernel
            (NOTE: the integer cast of the neighbour count blocks gradients,
            so the kernel currently never receives one)
        :param learn_rules: do train the rules
        :param wrap: if True, edges wrap around, otherwise cells outside
            the image count as dead (zero padding)
        :raises ValueError: if a rule does not have exactly 9 entries
        """
        super().__init__()
        for name, value in (("birth", birth), ("survive", survive)):
            if value is None:
                value = torch.rand(9).bernoulli()
            elif not isinstance(value, torch.Tensor):
                # list() also accepts one-shot iterables such as generators
                value = list(value)
            # Always store rules in the default float dtype (like the kernel):
            # integer/bool/double tables would turn the CA state into that dtype
            # after one step, and the next conv2d with the float kernel would fail.
            value = torch.as_tensor(value, dtype=torch.get_default_dtype())

            if value.shape != torch.Size((9, )):
                raise ValueError(f"Expected `{name}` to have shape (9), got {value.shape}")

            setattr(self, name, nn.Parameter(value, requires_grad=learn_rules))

        self.iterations = iterations
        self.wrap = wrap
        # Moore neighbourhood: the 8 surrounding cells, centre excluded
        self.kernel = nn.Parameter(torch.Tensor([[[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]]), requires_grad=learn_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run `self.iterations` CA steps on every channel of `x`.

        :param x: CxHxW or BxCxHxW tensor; values >= 1 are alive
        :return: tensor of the same shape holding the new states
        :raises ValueError: if `x` is not 3- or 4-dimensional
        """
        ndim = x.ndim
        if ndim not in (3, 4):
            raise ValueError(f"Expected x.ndim == 3 or 4, got {x.shape}")
        if ndim == 3:
            x = x.unsqueeze(0)

        # Fold channels into the batch so the single-channel kernel treats each
        # channel separately. reshape (unlike view) also handles non-contiguous
        # input such as a permuted NHWC tensor.
        y = x.reshape(-1, 1, *x.shape[-2:])  # (BxC)x1xHxW

        for _ in range(self.iterations):
            y = self._ca_step(y)

        y = y.reshape(x.shape)

        return y if ndim == 4 else y.squeeze(0)

    def _ca_step(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply one CA step to an Nx1xHxW state tensor.
        """
        if self.wrap:
            # pad with the opposite edge: columns first, then rows of the
            # column-padded tensor, so corners come from the opposite corner
            xp = torch.concat([x[..., -1, None], x, x[..., 0, None]], dim=-1)
            xp = torch.concat([xp[..., -1, None, :], xp, xp[..., 0, None, :]], dim=-2)
            neighbour_count = F.conv2d(xp, self.kernel)
        else:
            neighbour_count = F.conv2d(x, self.kernel, padding=1)

        # Truncate to an integer table index. This cast is not differentiable,
        # so no gradient reaches self.kernel.
        index = neighbour_count.long().clamp(0, 8).flatten()

        birth = torch.index_select(self.birth, 0, index).view(x.shape)
        survive = torch.index_select(self.survive, 0, index).view(x.shape)

        # dead cells (< 1) take the birth value, living cells (>= 1) the survive value
        return birth * (x < 1.) + survive * (x >= 1.)


inp = torch.zeros(1, 3, 10, 10)
# one cell at row 3, column 0 that is set in every channel
inp[:, :, 3, 0] = 1
# further live cells, only in channel 0
for row, col in ((0, 4), (5, 5), (5, 6), (6, 6), (0, 9)):
    inp[:, 0, row, col] = 1
print(inp)
# one step with randomly drawn birth/survive rules; shown as the cell output
TotalCALayer()(inp)

In [None]:
with torch.no_grad():
    # B3/S23 rule tables: born with 3 neighbours, survive with 2 or 3 (Conway's Life)
    ca = TotalCALayer(
        birth=  (0, 0, 0, 1, 0, 0, 0, 0, 0),
        survive=(0, 0, 1, 1, 0, 0, 0, 0, 0),
        iterations=5,
    )
    # Binary random soup. The layer treats values >= 1 as alive, so plain
    # torch.rand() noise (all values < 1) would start with every cell dead.
    state = torch.rand(8, 3, 32, 32).bernoulli()
    states = []
    for _ in range(20):
        # record every image of the batch; with make_grid's default of 8 images
        # per row, each grid row shows one recorded generation
        states.extend(state)
        state = ca(state)

VF.to_pil_image(resize(make_grid(states), 2))