In [None]:
import logging.handlers
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Any, Union, Any
import logging
import time
from IPython.display import clear_output
from dataclasses import dataclass, field
from tqdm.notebook import tqdm

# Root logger: INFO and above, "<timestamp> - <logger>.<LEVEL>: <message>".
logging.basicConfig(
    format="%(asctime)s - %(name)s.%(levelname)s: %(message)s",
    datefmt="%d-%m-%y %H:%M:%S",
    level=logging.INFO,
)

# Every simulation tensor below is created on this device; use the GPU when torch sees one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info(f"Using device {device}")

In [None]:
class BaseVecConwaysGOL:
    """Vectorised Conway's Game of Life on wrap-around (toroidal) boards.

    The board(s) are stored as a ``(B, 1, H, W)`` long tensor on the
    module-level ``device``. One generation is a 3x3 convolution, which counts
    the neighbours, followed by a lookup in the ``rules`` table.
    """

    # (curr_state, #neighbours) -> new_state   (B3/S23: birth on 3, survive on 2 or 3)
    rules = {
        (0, 0): 0,
        (0, 1): 0,
        (0, 2): 0,
        (0, 3): 1,
        (0, 4): 0,
        (0, 5): 0,
        (0, 6): 0,
        (0, 7): 0,
        (0, 8): 0,

        (1, 0): 0,
        (1, 1): 0,
        (1, 2): 1,
        (1, 3): 1,
        (1, 4): 0,
        (1, 5): 0,
        (1, 6): 0,
        (1, 7): 0,
        (1, 8): 0
    }
    # k=2 states (dead/alive), r=1 neighbourhood radius (3x3 Moore neighbourhood)

    def __init__(self, grid: Union[np.ndarray, torch.Tensor]) -> None:
        """Store a private snapshot of ``grid`` and build the simulation tensors.

        Args:
            grid: 0/1 board(s); the last two axes are (height, width).
        """
        if isinstance(grid, torch.Tensor):
            grid = grid.detach().cpu().numpy()
        # Copy so that later mutation of the caller's array (or of a CPU tensor,
        # whose .numpy() shares memory) cannot change what reset_grid restores.
        self._orig_grid: np.ndarray = np.array(grid, copy=True)
        # Ring kernel: sums the 8 neighbours, excluding the cell itself.
        self._sum_kernel: torch.Tensor = torch.tensor(
            [[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.float32, device=device,
        ).unsqueeze(0).unsqueeze(0)
        self._set_tensor_rules()
        self.reset_grid()

    def reset_grid(self) -> None:
        """Restore the initial board(s) and set the step counter back to 0."""
        self.__step: int = 0
        self._pt_grid: torch.Tensor = torch.tensor(
            self._orig_grid, dtype=torch.long, device=device
        ).view(-1, 1, self._orig_grid.shape[-2], self._orig_grid.shape[-1])

    @property
    def grid(self) -> np.ndarray:
        """Current state as a bool array with the same shape as the initial grid."""
        # reshape (not squeeze) so a batch holding a single board keeps its batch axis.
        return self._pt_grid.reshape(self._orig_grid.shape).bool().cpu().numpy()

    @property
    def step(self) -> int:
        """Number of generations advanced since the last reset."""
        return self.__step

    def _set_tensor_rules(self) -> None:
        """Build a (state, neighbour_count) -> new_state lookup tensor from ``rules``."""
        self._tensor_rules = torch.zeros((2, 9), dtype=torch.long, device=device)
        for (state, num_neighbours), new_state in self.rules.items():
            self._tensor_rules[state, num_neighbours] = new_state

    def _torch_update(self) -> None:
        """Advance every board by one generation."""
        # Circular padding wraps the edges, so the board behaves as a torus.
        padded_pt_grid = F.pad(self._pt_grid, (1, 1, 1, 1), "circular")
        sum_grid = F.conv2d(padded_pt_grid.float(), self._sum_kernel, padding=0).long()

        # A tuple of index tensors performs the element-wise rules[state, count] lookup.
        rule_index = (self._pt_grid.flatten(), sum_grid.flatten())
        self._pt_grid = self._tensor_rules[rule_index].view_as(self._pt_grid)

    def update_n(self, n: int) -> None:
        """Advance ``n`` generations."""
        for _ in range(n):
            self.update()

    def update(self) -> None:
        """Advance one generation and increment ``step``."""
        self._torch_update()
        self.__step += 1


class VecConwaysGOL(BaseVecConwaysGOL):
    """A single square Game of Life board with optional inline animation."""

    def __init__(self, grid: Union[np.ndarray, torch.Tensor]) -> None:
        """Validate that ``grid`` is a square 2D board of at least 3x3.

        Raises:
            ValueError: if ``grid`` is not 2D, is smaller than 3x3, or is not square.
        """
        if grid.ndim != 2:
            raise ValueError("Grid must be a 2D array")
        if grid.shape[0] < 3 or grid.shape[1] < 3:
            raise ValueError("Grid must be at least 3x3")
        if grid.shape[0] != grid.shape[1]:
            raise ValueError("Grid must be square")
        super().__init__(grid)

    def update_n(self, n: int, plot: bool = False, plot_every: int = 1,
                 sleep_s: float = 0., **plot_kwargs) -> None:
        """Advance ``n`` generations, optionally animating the board.

        Args:
            n: number of generations to advance.
            plot: when True, redraw the board (clearing the previous cell output)
                after the first generation, every ``plot_every`` generations
                after that, and always after the final generation.
            plot_every: redraw interval in generations (only used when ``plot``).
            sleep_s: pause after each redraw, in seconds (only used when ``plot``).
            **plot_kwargs: forwarded to :meth:`plot`.

        Raises:
            ValueError: if ``plot`` is True and ``sleep_s`` < 0 or ``plot_every`` < 1.
        """
        if not plot:
            super().update_n(n)
            return

        if sleep_s < 0:
            raise ValueError("Sleep must be non-negative")
        if plot_every < 1:
            raise ValueError("Plot every must be at least 1")

        for i in range(n):
            self.update()
            # Always draw the last generation so the animation ends on the final state.
            if i % plot_every == 0 or i == n - 1:
                clear_output(wait=True)
                self.plot(**plot_kwargs)
                if sleep_s != 0:
                    time.sleep(sleep_s)

    def plot(self, **kwargs) -> None:
        """Show the board titled with the current step.

        ``figsize`` (default ``(10, 10)``) sizes the figure; every other keyword
        is forwarded to ``imshow``.
        """
        fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (10, 10)))
        ax.set_title(f"Step {self.step}")
        ax.imshow(self.grid * 255, **kwargs)
        plt.show()


class BatchedConwaysGOL(BaseVecConwaysGOL):
    """A batch of equally sized square boards simulated in lock-step."""

    def __init__(self, grids: Union[np.ndarray, torch.Tensor]) -> None:
        """Validate a ``(batch, height, width)`` stack of square boards of at least 3x3.

        Raises:
            ValueError: if ``grids`` is not 3D, the boards are smaller than 3x3,
                or the boards are not square.
        """
        if grids.ndim != 3:
            raise ValueError("Grid must be a 3D array")
        # Axis 0 is the batch; only the board axes (1, 2) have a minimum size.
        if grids.shape[1] < 3 or grids.shape[2] < 3:
            raise ValueError("Grids must be at least 3x3")
        if grids.shape[1] != grids.shape[2]:
            raise ValueError("Grids must be squares")
        super().__init__(grids)

    reset_grids = BaseVecConwaysGOL.reset_grid

    @property
    def grids(self) -> np.ndarray:
        """Current boards as a bool array of shape ``(batch, height, width)``."""
        # Reshape to the input shape so a batch of one keeps its batch axis.
        return super().grid.reshape(self._orig_grid.shape)

In [None]:
class KnownMethuselahs:
    """Named seed patterns as square 0/1 ``int`` arrays.

    Every pattern is square because ``Grider.grid`` only accepts square arrays.
    """

    # R-pentomino, 3x3.
    R_PENTOMINO: np.ndarray = np.array(
        [
            [0, 1, 1],
            [1, 1, 0],
            [0, 1, 0],
        ],
        dtype=int,
    )

    # Acorn: a 3x7 pattern (rows 2-4) padded with empty rows to 7x7 so it is square.
    ACORN: np.ndarray = np.array(
        [
            [0, 0, 0, 0, 0, 0, 0],  # Row 0: padding
            [0, 0, 0, 0, 0, 0, 0],  # Row 1: padding
            [0, 1, 0, 0, 0, 0, 0],  # Row 2: acorn, top row
            [0, 0, 0, 1, 0, 0, 0],  # Row 3: acorn, middle row
            [1, 1, 0, 0, 1, 1, 1],  # Row 4: acorn, bottom row
            [0, 0, 0, 0, 0, 0, 0],  # Row 5: padding
            [0, 0, 0, 0, 0, 0, 0],  # Row 6: padding
        ],
        dtype=int,
    )

    # Pi-heptomino, 3x3.
    PI_HEPTOMINO: np.ndarray = np.array(
        [
            [1, 1, 1],
            [1, 0, 1],
            [1, 0, 1],
        ],
        dtype=int,
    )

    # 16x16 seed. NOTE(review): the name presumably refers to its lifespan
    # (~2513 generations); source of the pattern not recorded here — TODO confirm.
    M2513: np.ndarray = np.array(
        [
            [1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1],
            [1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0],
            [0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1],
            [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1],
            [1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0],
            [0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1],
            [0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0],
            [1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1],
            [1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0],
            [1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0],
            [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1],
            [1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1],
            [0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1],
            [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0],
            [1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1]
        ],
        dtype=int
    )


class Grider:
    """Helpers that build square 2D Game of Life boards as numpy arrays."""

    @staticmethod
    def grid(grd: np.ndarray) -> np.ndarray:
        """Return ``grd`` unchanged after checking it is a square 2D numpy array.

        Raises:
            ValueError: if ``grd`` is not a numpy array, not 2D, or not square.
        """
        if not isinstance(grd, np.ndarray):
            raise ValueError("grid must be a numpy array")
        if grd.ndim != 2:
            raise ValueError("grid must be a 2D array")
        if grd.shape[0] != grd.shape[1]:
            raise ValueError("grid must be a square array")
        return grd

    @staticmethod
    def get_random_grid(size: int, states: List[int] = [0, 1],
                        probs: Union[float, List[float]] = 0.7) -> np.ndarray:
        """Sample a ``size x size`` uint8 board whose cells are drawn from ``states``.

        Args:
            size: side length of the board.
            states: possible cell values.
            probs: probability of each state. A scalar is the probability of
                ``states[0]``. If there is one probability fewer than there are
                states, the last one is filled in so they sum to 1. The caller's
                sequence is never modified.

        Raises:
            ValueError: if ``states`` and ``probs`` lengths cannot be reconciled.
        """
        # Copy into a fresh list: appending the complement must not mutate the caller's list.
        probs = [float(probs)] if np.ndim(probs) == 0 else list(probs)
        if len(states) == len(probs) + 1:
            probs.append(1 - sum(probs))
        if len(states) != len(probs):
            raise ValueError("States and probs must have the same length")

        return np.random.choice(states, size=(size, size), p=probs).astype(np.uint8)

    @classmethod
    def _enter_center_to_grid(cls, grid: np.ndarray, center: np.ndarray) -> None:
        """Write ``center`` into the middle of ``grid`` in place.

        Raises:
            ValueError: if either array is not square 2D, or ``center`` is larger than ``grid``.
        """
        grid = cls.grid(grid)
        center = cls.grid(center)
        if center.shape[0] > grid.shape[0] or center.shape[1] > grid.shape[1]:
            raise ValueError("Center must be smaller than grid")

        # Offsets of the centred window; when the sizes' parity differs the extra cell goes bottom/right.
        top = (grid.shape[0] - center.shape[0]) // 2
        left = (grid.shape[1] - center.shape[1]) // 2
        grid[top:top + center.shape[0], left:left + center.shape[1]] = center

    @classmethod
    def get_random_center_grid(cls,
            size: int, center_size: int, states: List[int] = [0, 1],
            probs: Union[float, List[float]] = 0.7
        ) -> np.ndarray:
        """Empty ``size x size`` board with a random ``center_size`` square in the middle."""
        grid = np.zeros((size, size), dtype=np.uint8)
        center = cls.get_random_grid(center_size, states, probs)
        cls._enter_center_to_grid(grid, center)

        return grid

    @classmethod
    def get_empty_grid_with_defined_center(cls, size: int, center: np.ndarray) -> np.ndarray:
        """Empty ``size x size`` uint8 board with ``center`` copied into its middle."""
        grid = np.zeros((size, size), dtype=np.uint8)
        cls._enter_center_to_grid(grid, center)

        return grid

In [None]:
# R-pentomino centred on an otherwise empty 200x200 torus.
grid = Grider.get_empty_grid_with_defined_center(size=200, center=KnownMethuselahs.R_PENTOMINO)
gol = VecConwaysGOL(grid)
gol.plot(cmap="gray", figsize=(7, 7))

In [None]:
gol.update_n(n=1500, plot=True, plot_every=99, sleep_s=0.0, figsize=(14, 14), cmap="gray")  # animate 1500 generations

In [None]:
class classproperty:
    """A decorator to create class-level properties.

    Like ``property``, but the getter receives the owner class, so the value is
    available as both ``Cls.attr`` and ``instance.attr``. No ``__set__`` is
    defined, so assigning to the attribute replaces or shadows the descriptor.
    """
    def __init__(self, fget):
        self.fget = fget
        # Expose the getter's docstring, as ``property`` does.
        self.__doc__ = getattr(fget, "__doc__", None)

    def __get__(self, obj, owner=None):
        # The descriptor protocol lets callers omit ``owner``; fall back to the instance's type.
        if owner is None:
            owner = type(obj)
        return self.fget(owner)


class MMetric:
    """Base Methuselah Metric class.

    Subclasses implement ``__call__``. Each ``__call__`` a subclass defines is
    wrapped automatically, so every result is checked to be a non-negative
    real number or numpy array.
    """
    def __init_subclass__(cls, **kwargs: Any) -> None:
        import functools
        import numbers

        super().__init_subclass__(**kwargs)
        # An inherited __call__ was already wrapped by the class that defined it;
        # wrapping it again would only stack redundant checks.
        if "__call__" not in cls.__dict__:
            return
        original_call = cls.__call__

        @functools.wraps(original_call)
        def new_call(self, *args: Any, **kwds: Any) -> Union[float, np.ndarray]:
            result = original_call(self, *args, **kwds)
            # numbers.Real covers Python and numpy int/float scalars; bool is excluded.
            is_scalar = isinstance(result, numbers.Real) and not isinstance(result, bool)
            is_array = isinstance(result, np.ndarray)
            if not (is_scalar or is_array):
                raise RuntimeError("Metric result must be a number or numpy array")
            if (is_scalar and result < 0) or (is_array and np.any(result < 0)):
                raise RuntimeError("Metric result must be non-negative")

            return result

        cls.__call__ = new_call

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        raise NotImplementedError


class MMetricFuncs:
    """Namespace of ready-made MMetric instances (each attribute access builds a new one)."""

    @classproperty
    def FperI(cls) -> MMetric:
        class _FperI(MMetric):
            def __call__(self, gol: BaseVecConwaysGOL) -> Union[float, np.ndarray]:
                """F/I: Final population per Initial population.

                Returns a float for a single 2D board, or an array with one value
                per board for a batch.
                """
                orig_grid = np.asarray(gol._orig_grid)
                # Align the current state with the initial grid's shape, so a batch
                # of one still yields one value per board.
                grid = np.asarray(gol.grid).reshape(orig_grid.shape)
                f = grid.astype(bool).sum(axis=(-2, -1))
                i = orig_grid.astype(bool).sum(axis=(-2, -1))
                # An empty board stays empty (f == 0), so dividing by 1 instead of 0
                # yields 0 without an epsilon skewing every other ratio downward.
                return f / np.maximum(i, 1)

            def __repr__(self) -> str:
                return "F/I Metric"

        return _FperI()

In [None]:
class BaseMethuselahsFinder:
    """Genetic search for methuselah seeds.

    Each generation does the following:
    - Places the ``pop_size`` candidate seeds (``center_side_len`` square) in
      the middle of empty ``grid_side_len`` boards.
    - Simulates them together for ``methuselah_min_life`` steps.
    - Scores them with an MMetric.
    - Replaces them with children produced by the subclass's ``_mate``.

    Full initial boards whose score reaches the threshold are kept in ``methuselahs``.
    """
    __version__: str

    def __init__(self, pop_size: int, center_side_len: int = 8,
                 grid_side_len: int = 250, methuselah_min_life: int = 100) -> None:
        self.pop_size = pop_size
        self.center_side_len = center_side_len
        self.grid_side_len = grid_side_len
        self.methuselah_min_life = methuselah_min_life
        # Full-size initial boards (not just the seeds) that met the threshold.
        self.methuselahs: List[np.ndarray] = []
        self.reset_finder()
        self.__pbar: tqdm

    def __repr__(self) -> str:
        return (
            f"MethuselahsFinder_v{self.__version__}(\n"
            f"    pop_size={self.pop_size}, "
            f"center_side_len={self.center_side_len}, "
            f"grid_side_len={self.grid_side_len}, "
            f"methuselah_min_life={self.methuselah_min_life}\n"
            f"    -------------------------\n"
            f"    #found_methuselahs={len(self.methuselahs)} "
            f"finder_iter={self.iter}\n"
            f")"
        )

    def reset_finder(self) -> None:
        """Draw a fresh random population and clear the score history and counter.

        Previously found ``methuselahs`` are kept.
        """
        # Each seed gets its own random probability of a cell being dead (state 0).
        self.candidates = np.array([
            Grider.get_random_grid(self.center_side_len, probs=np.random.rand())
            for _ in range(self.pop_size)
        ])
        self.metric_values: List[np.ndarray] = []
        self.__iter: int = 0

    @property
    def iter(self) -> int:
        """Number of generations run since the last reset."""
        return self.__iter

    def find(self, iters: int, metric: MMetric, metric_ths: float) -> None:
        """Run ``iters`` more generations, continuing from ``self.iter``.

        Args:
            iters: number of generations to run.
            metric: scores every board of the simulated batch.
            metric_ths: boards scoring ``>= metric_ths`` are added to ``methuselahs``.
        """
        self.__pbar = tqdm(range(iters), initial=self.iter, total=self.iter + iters,
                           desc="Finding Methuselahs")
        for _ in self.__pbar:
            gol = BatchedConwaysGOL(self._expand_candidates())
            gol.update_n(self.methuselah_min_life)
            self._update_candidates(gol, metric, metric_ths)
            self.__iter += 1

    def _expand_candidates(self) -> np.ndarray:
        """Stack the candidates, each centred on an empty ``grid_side_len`` board."""
        return np.array([
            Grider.get_empty_grid_with_defined_center(
                self.grid_side_len, candidate
            ) for candidate in self.candidates
        ])

    @staticmethod
    def _selection_probs(metric_vals: np.ndarray) -> np.ndarray:
        """Fitness-proportional selection probabilities; uniform if no candidate scored."""
        total = metric_vals.sum()
        # An all-zero (or non-finite) total has no proportional distribution. Dividing by
        # an epsilon-padded total also leaves a sum off from 1 by more than
        # np.random.choice tolerates when the total is small.
        if not np.isfinite(total) or total <= 0:
            return np.full(metric_vals.shape, 1.0 / metric_vals.size)
        return metric_vals / total

    def _update_candidates(self, gol: BaseVecConwaysGOL, metric: MMetric, metric_ths: float) -> None:
        """Score the simulated batch, record winners, and breed the next population."""
        # atleast_1d: a metric may return a scalar for a batch of one board.
        metric_vals: np.ndarray = np.atleast_1d(np.asarray(metric(gol), dtype=float))
        self.metric_values.append(metric_vals)
        self.__pbar.set_postfix({"mean_gen_score": round(metric_vals.mean(), 2),
                                 "max_gen_score": round(metric_vals.max(), 2)})

        winners = metric_vals >= metric_ths
        if np.any(winners):
            self.__pbar.write(f"Found methuselah(s) with: {metric}={metric_vals[winners]}")
            self.methuselahs.extend(gol._orig_grid[winners])

        # One row of two parent indices per child.
        parent_idxs = np.random.choice(
            self.pop_size, size=(self.pop_size, 2), p=self._selection_probs(metric_vals)
        )

        new_candidates = np.zeros_like(self.candidates)
        for i, (idx1, idx2) in enumerate(parent_idxs):
            new_candidates[i] = self._mate(self.candidates[idx1], self.candidates[idx2])

        self.candidates = new_candidates

    def plot_metric(self, figsize: Tuple[int, int] = (10, 5)) -> None:
        """Plot the mean and max metric value of every generation run so far.

        Raises:
            ValueError: if no generation has been scored yet.
        """
        if not self.metric_values:
            raise ValueError("No metric values recorded yet; run find() first")
        metric_vals = np.array(self.metric_values)
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(metric_vals.mean(axis=1), label="Mean")
        ax.plot(metric_vals.max(axis=1), label="Max")
        ax.set_xlabel("Generation")
        ax.set_ylabel("Metric Value")
        ax.legend()
        plt.show()


class BaseBaseMethuselahsFinder(BaseMethuselahsFinder):
    """Contract-enforcing layer for concrete finders.

    Any subclass must provide a ``__version__`` attribute and a ``_mate``
    method; otherwise its class statement raises ``AttributeError``.
    """

    def __init_subclass__(cls) -> None:
        required = {
            "__version__": "MethuselahsFinder subclasses must have a __version__ attribute",
            "_mate": "MethuselahsFinder subclasses must have a _mate method",
        }
        # Insertion order is preserved, so the version check is reported first.
        missing = [message for attr, message in required.items() if not hasattr(cls, attr)]
        if missing:
            raise AttributeError(missing[0])
        super().__init_subclass__()


class MethuselahsFinderV1(BaseBaseMethuselahsFinder):
    """Finder whose crossover is a plain single-point cut (no mutation)."""
    __version__ = "1"

    def _mate(self, parent1: np.ndarray, parent2: np.ndarray) -> np.ndarray:
        """Single-point crossover of two equally shaped parents.

        The flattened child takes ``parent1`` before a random cut and ``parent2``
        from the cut onwards. It keeps the parents' dtype and shape.
        """
        cut = np.random.randint(0, parent1.size)
        child = np.concatenate((parent1.ravel()[:cut], parent2.ravel()[cut:]))

        return child.reshape(parent1.shape)


class MethuselahsFinderV2(BaseBaseMethuselahsFinder):
    """Finder whose crossover re-rolls the gene at the cut point (a one-cell mutation)."""
    __version__ = "2"

    def _mate(self, parent1: np.ndarray, parent2: np.ndarray) -> np.ndarray:
        """Single-point crossover with a random bit at the cut.

        The flattened child takes ``parent1`` before the cut, a random 0/1 at the
        cut, and ``parent2`` after it. The child is uint8 with the parents' shape.
        """
        genes_a = parent1.flatten()
        genes_b = parent2.flatten()
        pivot = np.random.randint(0, parent1.size - 1)

        offspring = np.zeros(parent1.size, dtype=np.uint8)
        offspring[:pivot] = genes_a[:pivot]
        offspring[pivot] = np.random.randint(0, 2, dtype=np.uint8)
        offspring[pivot + 1:] = genes_b[pivot + 1:]

        return offspring.reshape(parent1.shape)

In [None]:
m_finder_v1 = MethuselahsFinderV1(pop_size=50, center_side_len=5, grid_side_len=150, methuselah_min_life=500)
m_finder_v1.find(iters=100, metric=MMetricFuncs.FperI, metric_ths=50)

In [None]:
m_finder_v2 = MethuselahsFinderV2(pop_size=50, center_side_len=5, grid_side_len=150, methuselah_min_life=500)
m_finder_v2.find(iters=100, metric=MMetricFuncs.FperI, metric_ths=50)

In [None]:
m_finder_v2.find(iters=25, metric=MMetricFuncs.FperI, metric_ths=40)  # 25 more generations, lower threshold

In [None]:
m_finder_v1.plot_metric((8, 4))  # mean/max score per generation for V1

In [None]:
m_finder_v2.plot_metric((8, 4))  # mean/max score per generation for V2

In [None]:
# First member of V1's current population (not necessarily a found methuselah).
first_v1_candidate = m_finder_v1._expand_candidates()[0]
gol = VecConwaysGOL(first_v1_candidate)
gol.plot(cmap="gray", figsize=(4, 4))

In [None]:
gol.update_n(n=1_000, plot=True, plot_every=47, sleep_s=1e-2, figsize=(6, 6), cmap="gray")

In [None]:
unique_v2_methuselahs = np.unique(m_finder_v2.methuselahs, axis=0)
unique_v2_methuselahs.shape

In [None]:
# Most recently found V2 methuselah (raises IndexError if none was found).
latest_v2_methuselah = m_finder_v2.methuselahs[-1]
gol = VecConwaysGOL(latest_v2_methuselah)
gol.plot(cmap="gray", figsize=(4, 4))

In [None]:
gol.update_n(n=1_000, plot=True, plot_every=47, sleep_s=1e-2, figsize=(6, 6), cmap="gray")