In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Any, Union
import logging
import time
from IPython.display import clear_output
from dataclasses import dataclass, field


logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s.%(levelname)s: %(message)s")

# Use the GPU when PyTorch can see one; VecConwaysGOL allocates all of its tensors on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lazy %-style argument: the message is only formatted if INFO records are actually emitted.
logging.info("Using device %s", device)

In [None]:
class VecConwaysGOL:
    """Vectorised Conway's Game of Life on a square, wrap-around grid, run with PyTorch.

    Each step counts the 8 neighbours of every cell with a 3x3 convolution over a
    circularly padded copy of the grid (so opposite edges touch). It then looks up
    every cell's next state in a dense ``(num_states, 9)`` rule table built from
    ``rules``.
    """

    # (curr_state, #neighbours) -> new_state
    # Standard Conway rule: a dead cell is born with exactly 3 neighbours; a live
    # cell survives with 2 or 3; every other case is dead.
    rules = {
        (0, 0): 0,
        (0, 1): 0,
        (0, 2): 0,
        (0, 3): 1,
        (0, 4): 0,
        (0, 5): 0,
        (0, 6): 0,
        (0, 7): 0,
        (0, 8): 0,

        (1, 0): 0,
        (1, 1): 0,
        (1, 2): 1,
        (1, 3): 1,
        (1, 4): 0,
        (1, 5): 0,
        (1, 6): 0,
        (1, 7): 0,
        (1, 8): 0
    }
    num_states: int = 2
    # Neighbourhood radius r=1 is hard-coded (3x3 kernel, 1-cell circular padding).

    def __init__(self, grid: Union[np.ndarray, torch.Tensor]) -> None:
        """Copy ``grid`` onto ``device`` as the initial state.

        Args:
            grid: square 2D array or tensor of integer cell states in
                ``[0, num_states)``. The caller's object is never aliased.

        Raises:
            ValueError: if ``grid`` is not 2D, is smaller than 3x3, is not square,
                or contains a state outside ``[0, num_states)``.
        """
        if grid.ndim != 2:
            raise ValueError("Grid must be a 2D array")
        if grid.shape[0] < 3 or grid.shape[1] < 3:
            raise ValueError("Grid must be at least 3x3")
        if grid.shape[0] != grid.shape[1]:
            raise ValueError("Grid must be square")

        if isinstance(grid, torch.Tensor):
            # torch.tensor() on an existing tensor emits a UserWarning; PyTorch's
            # recommended equivalent is an explicit detach + copy.
            pt_grid = grid.detach().to(device=device, dtype=torch.long).clone()
        else:
            pt_grid = torch.tensor(grid, dtype=torch.long, device=device)
        # Fail early here instead of with an opaque IndexError (or a CUDA device
        # assert) from the rule-table lookup in the first update.
        if pt_grid.min() < 0 or pt_grid.max() >= self.num_states:
            raise ValueError(f"Grid states must lie in [0, {self.num_states})")

        self._pt_grid: torch.Tensor = pt_grid
        # Shape (1, 1, 3, 3) (out_channels, in_channels, kH, kW) as F.conv2d expects;
        # the 0 centre excludes the cell itself from its neighbour count.
        self._sum_kernel: torch.Tensor = torch.tensor(
            [[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.float32, device=device,
            requires_grad=False
        ).unsqueeze(0).unsqueeze(0)
        self._set_tensor_rules()
        self.__step: int = 0

    @property
    def grid(self) -> np.ndarray:
        """Current grid as a float NumPy array with states scaled onto [0, 255]."""
        # 255 (not 256) so the top state maps to the maximum 8-bit value.
        return self._pt_grid.cpu().numpy() * 255 / (self.num_states - 1)

    @property
    def step(self) -> int:
        """Number of generations computed so far."""
        return self.__step

    def _set_tensor_rules(self) -> None:
        """Build the dense lookup table ``_tensor_rules[state, n_neighbours]`` from ``rules``."""
        # 9 columns: a cell can have 0..8 live neighbours with r=1.
        self._tensor_rules = torch.zeros(
            (self.num_states, 9), dtype=torch.long, device=device, requires_grad=False
        )
        for (state, num_neighbours), new_state in self.rules.items():
            self._tensor_rules[state, num_neighbours] = new_state

    def _torch_update(self) -> None:
        """Advance the grid by one generation (does not touch the step counter)."""
        # Circular padding wraps the board so edge cells see the opposite edge.
        padded = F.pad(self._pt_grid[None, None], (1, 1, 1, 1), mode="circular")
        # The neighbour sums are small integers (<= 8), so the float conv is exact.
        neighbour_counts = F.conv2d(padded.float(), self._sum_kernel)[0, 0].long()
        # Advanced indexing with a tuple of two same-shaped index tensors gathers
        # rules[state, count] element-wise and keeps the grid's shape.
        self._pt_grid = self._tensor_rules[self._pt_grid, neighbour_counts]

    def update_n(self, n: int, plot: bool = False, sleep_s: float = 0., **plot_kwargs) -> None:
        """Advance ``n`` generations, optionally redrawing the grid after each one.

        Args:
            n: number of generations to compute.
            plot: if True, clear the cell output and redraw after every step.
            sleep_s: pause between frames; must be > 0 when plotting.
            **plot_kwargs: ``figsize`` (default (10, 10)) plus any ``plt.imshow``
                keyword arguments; only used when ``plot`` is True.

        Raises:
            ValueError: if ``sleep_s`` is negative, or zero while plotting.
        """
        if sleep_s < 0:
            raise ValueError("Sleep must be non-negative")
        if plot and sleep_s == 0:
            raise ValueError("If plotting, sleep must be positive")

        if not plot:
            for _ in range(n):
                self.update()
            return

        figsize = plot_kwargs.pop("figsize", (10, 10))
        for _ in range(n):
            self.update()
            clear_output(wait=True)
            fig = plt.figure(figsize=figsize)
            plt.imshow(self.grid, **plot_kwargs)
            plt.show()
            # Release each frame explicitly so long animations don't pile up open figures.
            plt.close(fig)
            time.sleep(sleep_s)

    def update(self) -> None:
        """Advance one generation and increment ``step``."""
        self._torch_update()
        self.__step += 1

    def plot(self, plt_show: bool = True, **kwargs) -> None:
        """Draw the current grid with ``plt.imshow``.

        Args:
            plt_show: call ``plt.show()`` after drawing.
            **kwargs: ``figsize`` (default (10, 10)) plus ``plt.imshow`` keyword arguments.
        """
        plt.figure(figsize=kwargs.pop("figsize", (10, 10)))
        plt.imshow(self.grid, **kwargs)
        if plt_show:
            plt.show()

In [None]:
class KnownMethuselahs:
    """Named seed patterns, each a small square 0/1 NumPy array.

    Intended as the ``center`` argument of
    ``Grider.get_empty_grid_with_defined_center``.
    """

    # R-pentomino: five live cells inside a 3x3 bounding box.
    R_PENTOMINO: np.ndarray = np.array(
        [
            [0, 1, 1],
            [1, 1, 0],
            [0, 1, 0],
        ]
    )


class Grider:
    """Helpers that build square uint8 NumPy grids for seeding ``VecConwaysGOL``."""

    @staticmethod
    def grid(grd: np.ndarray) -> np.ndarray:
        """Validate that ``grd`` is a square 2D NumPy array and return it unchanged.

        Raises:
            ValueError: if ``grd`` is not an ndarray, is not 2D, or is not square.
        """
        if not isinstance(grd, np.ndarray):
            raise ValueError("grid must be a numpy array")
        if grd.ndim != 2:
            raise ValueError("grid must be a 2D array")
        if grd.shape[0] != grd.shape[1]:
            raise ValueError("grid must be a square array")
        return grd

    @staticmethod
    def get_random_grid(size: int, states: List[int] = [0, 1],
                        probs: Union[float, List[float]] = 0.7) -> np.ndarray:
        """Return a ``size x size`` uint8 grid whose cells are drawn i.i.d. from ``states``.

        Args:
            size: side length of the grid.
            states: possible cell values (not modified).
            probs: probability of each state. A scalar is treated as a one-element
                list. If exactly one probability is missing, the last state gets
                the remainder ``1 - sum(probs)``. The caller's list is never modified.

        Raises:
            ValueError: if ``states`` and the completed ``probs`` differ in length.
        """
        # Work on a private copy: appending to the caller's list would leak the
        # completed probabilities back into their object.
        if np.ndim(probs) == 0:
            probs = [float(probs)]
        else:
            probs = list(probs)
        # One probability short -> fill in the last one. (The previous check was
        # inverted and made the default call with probs=0.7 always raise.)
        if len(probs) == len(states) - 1:
            probs.append(1 - sum(probs))
        if len(states) != len(probs):
            raise ValueError("States and probs must have the same length")

        return np.random.choice(states, size=(size, size), p=probs).astype(np.uint8)

    @classmethod
    def _enter_center_to_grid(cls, grid: np.ndarray, center: np.ndarray) -> None:
        """Copy ``center`` into the middle of ``grid`` in place.

        When the size difference is odd, the extra row/column of margin ends up
        on the bottom/right.

        Raises:
            ValueError: if either array is not a square 2D ndarray, or if
                ``center`` is larger than ``grid``.
        """
        grid = cls.grid(grid)
        center = cls.grid(center)
        if center.shape[0] > grid.shape[0] or center.shape[1] > grid.shape[1]:
            raise ValueError("Center must be smaller than grid")

        center_rows, center_cols = center.shape
        top = (grid.shape[0] - center_rows) // 2
        left = (grid.shape[1] - center_cols) // 2
        grid[top:top + center_rows, left:left + center_cols] = center

    @classmethod
    def get_random_center_grid(cls,
            size: int, center_size: int, states: List[int] = [0, 1],
            probs: Union[float, List[float]] = 0.7
        ) -> np.ndarray:
        """Return an empty ``size x size`` grid with a random ``center_size`` square in the middle.

        ``states`` and ``probs`` are passed to ``get_random_grid``.
        """
        grid = np.zeros((size, size), dtype=np.uint8)
        center = cls.get_random_grid(center_size, states, probs)
        cls._enter_center_to_grid(grid, center)

        return grid

    @classmethod
    def get_empty_grid_with_defined_center(cls, size: int, center: np.ndarray) -> np.ndarray:
        """Return an empty ``size x size`` uint8 grid with ``center`` copied into its middle."""
        grid = np.zeros((size, size), dtype=np.uint8)
        cls._enter_center_to_grid(grid, center)

        return grid

In [None]:
# Seed a 100x100 board with an R-pentomino in the middle and show generation 0.
grid = Grider.get_empty_grid_with_defined_center(
    100, KnownMethuselahs.R_PENTOMINO
)
gol = VecConwaysGOL(grid)
gol.plot(cmap="gray", figsize=(7, 7))

In [None]:
# Same seed on a larger 250x250 board; this `gol` replaces the previous one and
# is the instance the following cells advance and plot.
grid = Grider.get_empty_grid_with_defined_center(
    250, KnownMethuselahs.R_PENTOMINO
)
gol = VecConwaysGOL(grid)
gol.plot(cmap="gray", figsize=(7, 7))

In [None]:
# Advance 1100 generations without drawing. With plot=False, update_n never sleeps
# and ignores plotting kwargs, so none are passed here. To animate instead, call
# update_n with plot=True, a positive sleep_s, and imshow kwargs such as cmap="gray".
gol.update_n(1100)

In [None]:
# Board state after the simulation above.
gol.plot(cmap="gray", figsize=(7, 7))