# Imports

In [1]:
import sys
import numpy as np
import typing as t
from functools import partial
import matplotlib.pyplot as plt

# VizUtils

In [15]:
import os
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
from IPython.display import HTML, display, clear_output

class VideoWriter:
  """Incrementally write frames to a video file via moviepy's FFMPEG writer.

  The underlying writer is created lazily on the first `add`, because the
  frame size is only known then. Usable as a context manager: `close` is
  called on exit.
  """

  def __init__(self, filename, fps=30.0, **kw):
    self.writer = None
    # Forwarded to FFMPEG_VideoWriter (together with `size`) on the first add.
    self.params = dict(filename=filename, fps=fps, **kw)

  def add(self, img):
    """Append one frame.

    Args:
      img: array-like of shape (H, W), (H, W, 1) or (H, W, 3). Floating-point
        frames (any float dtype) are clipped to [0, 1] and scaled to uint8;
        single-channel frames are replicated to RGB.
    """
    img = np.asarray(img)
    if self.writer is None:
      h, w = img.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)
    # Accept every floating dtype (float16 included), not only float32/64.
    if np.issubdtype(img.dtype, np.floating):
      img = np.uint8(img.clip(0, 1) * 255)
    if img.ndim == 3 and img.shape[-1] == 1:
      img = img[..., 0]
    if img.ndim == 2:
      # Grayscale -> RGB by replicating the channel.
      img = np.repeat(img[..., None], 3, -1)
    self.writer.write_frame(img)

  def close(self):
    """Finalize the video file. Safe to call more than once."""
    if self.writer is not None:
      self.writer.close()
      # Drop the handle so a second close (e.g. show() followed by __exit__)
      # is a no-op instead of closing the same writer again.
      self.writer = None

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()

  def show(self, **kw):
    """Close the file and display it inline; kwargs go to mvp.ipython_display."""
    self.close()
    fn = self.params['filename']
    display(mvp.ipython_display(fn, **kw))

# Utils

In [24]:
import torch
import torch.nn.functional as F
from torch.fft import fft2, ifft2, fftshift, ifftshift

def sigmoid(x):
    """Logistic sigmoid, evaluated through the identity sigmoid(x) = (tanh(x/2) + 1) / 2."""
    half_tanh = torch.tanh(x / 2)
    return (half_tanh + 1) * 0.5

def ker_f(x, a, w, b):
    """Sum of Gaussian bumps evaluated at every element of ``x``.

    Bump i has centre a[i], width w[i] and amplitude b[i]:
        sum_i b[i] * exp(-(x - a[i])**2 / w[i])
    The bump index is appended as a trailing axis and then reduced away.
    """
    offset = x[..., None] - a
    bumps = b * torch.exp(-offset ** 2 / w)
    return bumps.sum(-1)

def bell(x, m, s):
    """Unnormalised Gaussian exp(-((x - m) / s)**2 / 2); equals 1 at x == m."""
    z = (x - m) / s
    return torch.exp(-z ** 2 / 2)

def growth(U, m, s):
    """Growth mapping: 2 * exp(-((U - m) / s)**2 / 2) - 1, with values in (-1, 1].

    Peaks at 1 when U == m and tends to -1 far from m.
    """
    z = (U - m) / s
    return torch.exp(-z ** 2 / 2) * 2 - 1

# 3x3 Sobel kernels shaped (out_ch=1, in_ch=1, 3, 3) for F.conv2d.
# conv2d computes a cross-correlation, so kx responds positively to values
# increasing along dim 1 of the field and ky (its transpose) along dim 0.
kx = torch.tensor([
                [-1., 0., 1.],
                [-2., 0., 2.],
                [-1., 0., 1.]
])[None, None, :, :].double()

ky = torch.transpose(kx, 2, 3).double()


def _sobel_apply(A, kernel):
    """Cross-correlate every channel of A (shape (X, Y, C)) with a (1, 1, 3, 3) kernel.

    Borders are zero-padded ('same' output size). Returns shape (X, Y, 1, C).
    """
    # Match the input's dtype/device so float32 (or GPU) fields also work.
    kernel = kernel.to(dtype=A.dtype, device=A.device)
    # Channels become the batch dimension: (X, Y, C) -> (C, 1, X, Y).
    batched = A.permute(2, 0, 1).unsqueeze(1)
    out = F.conv2d(batched, kernel, padding='same')  # (C, 1, X, Y)
    return out.permute(2, 3, 1, 0)  # (X, Y, 1, C)


def sobel_x(A):
    """Sobel derivative along dim 1 of A (X, Y, C); returns (X, Y, 1, C)."""
    return _sobel_apply(A, kx)


def sobel_y(A):
    """Sobel derivative along dim 0 of A (X, Y, C); returns (X, Y, 1, C)."""
    return _sobel_apply(A, ky)


def sobel(A):
    """Stacked gradient of A (X, Y, C): index 0 = d/d(dim 0), index 1 = d/d(dim 1).

    Returns shape (X, Y, 2, C).
    """
    return torch.cat((sobel_y(A), sobel_x(A)), dim=-2)



def conn_from_matrix(mat):
    """Expand a channel connectivity matrix into kernel source/target lists.

    ``mat[src, dst]`` is the number of kernels reading channel ``src`` and
    writing into channel ``dst``. Kernels are numbered consecutively while
    walking ``mat`` in row-major order.

    Returns:
        (c0, c1): ``c0[k]`` is the source channel of kernel ``k``;
        ``c1[dst]`` lists the indices of the kernels that write into ``dst``.
    """
    n_channels = mat.shape[0]
    sources = []
    targets = [[] for _ in range(n_channels)]
    next_kernel = 0
    for src in range(n_channels):
        for dst in range(n_channels):
            count = mat[src, dst]
            if count:
                sources.extend([src] * count)
                targets[dst].extend(range(next_kernel, next_kernel + count))
            next_kernel += count
    return sources, targets

def conn_from_lists(c0, c1, C):
    """Build connectivity from per-kernel source/target channel lists.

    Args:
        c0: source channel of each kernel.
        c1: target channel of each kernel (indexed in parallel with c0).
        C: number of channels.

    Returns:
        (c0, per_channel) where ``per_channel[c]`` lists the indices of the
        kernels whose target is channel ``c`` -- the same format returned by
        ``conn_from_matrix`` and consumed as ``config.c1[c]`` by the step.
    """
    # The previous version compared the kernel index to c1[i] and never used
    # `c`, so every channel received the same meaningless boolean list.
    per_channel = [[k for k in range(len(c0)) if c1[k] == c] for c in range(C)]
    return c0, per_channel


# Reintegration Tracking

In [25]:
import torch
from torch.autograd import grad
from functools import partial
import torch.nn.functional as F

class ReintegrationTracking:
    """Transport a field along a flow with reintegration tracking.

    The mass of each cell is sent to the target position
    ``cell_centre + clamp(dt * F, -(dd - sigma), dd - sigma)`` and spread
    uniformly over an axis-aligned square of half-width ``sigma`` centred
    there. Every cell collects the part of each square that overlaps it.
    Boundaries are periodic (torus). The clamp keeps each square inside the
    searched ``dd`` neighbourhood, so total mass is conserved up to
    floating-point error.

    Args:
        SX (int): grid size along dim 0.
        SY (int): grid size along dim 1.
        dt (float): time step multiplying the flow.
        dd (int): radius, in cells, of the neighbourhood searched for incoming mass.
        sigma (float): half-width of the square each cell's mass is spread over.
        has_hidden (bool): stored only; unused by this implementation.
        mix (str): stored only; unused by this implementation.
    """

    def __init__(self, SX=256, SY=256, dt=.2, dd=5, sigma=.65, has_hidden=False, mix="softmax"):
        self.SX = SX
        self.SY = SY
        self.dt = dt
        self.dd = dd
        self.sigma = sigma
        self.has_hidden = has_hidden
        self.mix = mix

        self.apply = self._build_apply()

    def __call__(self, *args):
        return self.apply(*args)

    def _build_apply(self):
        """Precompute the cell-centre grid and neighbour offsets; return ``apply(X, F)``."""
        xs, ys = torch.arange(self.SX), torch.arange(self.SY)
        gx, gy = torch.meshgrid(xs, ys, indexing='ij')
        pos = torch.stack((gx, gy), dim=-1) + .5  # (SX, SY, 2) cell centres

        dd = self.dd
        offsets = [(dx, dy) for dx in range(-dd, dd + 1) for dy in range(-dd, dd + 1)]

        # Periodic images per axis, built once: row 0 (axis 0) uses
        # {-SX, 0, SX}, row 1 (axis 1) uses {-SY, 0, SY}.
        wrap = torch.tensor([[-self.SX, 0, self.SX],
                             [-self.SY, 0, self.SY]])  # (2, 3)

        def step(X, mu, dx, dy):
            """Mass that each cell receives from the source cell offset by (dx, dy)."""
            Xr = torch.roll(X, (dx, dy), dims=(0, 1))
            mur = torch.roll(mu, (dx, dy), dims=(0, 1))
            # Per-axis distance from each cell centre to the nearest periodic
            # image of the incoming target: (SX, SY, 2, C, 3) -> min over images.
            cand = pos[..., None, None] - (mur[..., None] + wrap[:, None, :])
            dpmu = torch.abs(cand).amin(dim=-1)  # (SX, SY, 2, C)

            # Per-axis overlap of the square [mu - sigma, mu + sigma] with the
            # unit cell, normalised by the square's area (2 * sigma)**2.
            sz = 0.5 - dpmu + self.sigma
            area = torch.prod(torch.clamp(sz, 0, min(1, 2 * self.sigma)), dim=2) / (4 * self.sigma ** 2)
            return Xr * area

        def apply(X, F):
            """Move X along flow F.

            Args:
                X: field; assumed (SX, SY, C).
                F: flow, broadcastable with (SX, SY, 2, C).

            Returns:
                Transported field with X's shape and dtype.
            """
            ma = self.dd - self.sigma  # upper bound of the flow magnitude
            mu = pos[..., None] + torch.clamp(self.dt * F, -ma, ma)  # (SX, SY, 2, C) target centres

            # Accumulate in the promoted dtype (e.g. float64 when the flow is
            # float64) and round once at the end, instead of rounding every
            # one of the (2*dd+1)**2 partial sums into X's dtype.
            acc_dtype = torch.promote_types(X.dtype, mu.dtype)
            nX = torch.zeros(X.shape, dtype=acc_dtype, device=X.device)
            for dx, dy in offsets:
                nX += step(X, mu, dx, dy)

            return nX.to(X.dtype)

        return apply


# Flow Lenia

In [26]:
import torch
from dataclasses import dataclass

@dataclass
class Params:
    """Flow Lenia update rule parameters (raw, before kernel compilation).

    Shapes as produced by ``RuleSpace.sample``: ``r, m, s, h`` are (nb_k,),
    ``a, w, b`` are (nb_k, 3) and ``R`` is a Python float. The field order
    below defines the generated ``__init__`` signature.
    """
    r: torch.Tensor  # per-kernel radius scale (divides the distance grid)
    b: torch.Tensor  # amplitudes of the Gaussian bumps of each kernel
    w: torch.Tensor  # widths of the bumps
    a: torch.Tensor  # centres of the bumps
    m: torch.Tensor  # growth-function centre per kernel
    s: torch.Tensor  # growth-function width per kernel
    h: torch.Tensor  # growth weight per kernel
    R: float  # global radius; distances are divided by (R + 15) * r


class RuleSpace :

    """Rule space for the Flow Lenia system: uniform box bounds per parameter.

    Attributes:
        nb_k (int): number of kernels of the system.
        spaces (dict): for each parameter name, its uniform sampling range
            (``'low'``, ``'high'``), its per-kernel shape (``'shape'``, None
            for one value per kernel) and a ``'mut_std'`` value that
            ``sample`` does not use.
    """

    def __init__(self, nb_k: int):
        """
        Args:
            nb_k (int): number of kernels in the update rule
        """
        self.nb_k = nb_k
        self.spaces = {
            "r" : {'low' : .2, 'high' : 1., 'mut_std' : .2, 'shape' : None},
            "b" : {'low' : .001, 'high' : 1., 'mut_std' : .2, 'shape' : (3,)},
            "w" : {'low' : .01, 'high' : .5, 'mut_std' : .2, 'shape' : (3,)},
            "a" : {'low' : .0, 'high' : 1., 'mut_std' : .2, 'shape' : (3,)},
            "m" : {'low' : .05, 'high' : .5, 'mut_std' : .2, 'shape' : None},
            "s" : {'low' : .001, 'high' : .18, 'mut_std' : .01, 'shape' : None},
            "h" : {'low' : .01, 'high' : 1., 'mut_std' : .2, 'shape' : None},
            'R' : {'low' : 2., 'high' : 25., 'mut_std' : .2, 'shape' : None},
        }

    def _uniform(self, name, shape, generator):
        """Draw ``shape`` values uniformly in [low, high) of ``spaces[name]``."""
        spec = self.spaces[name]
        return torch.rand(shape, generator=generator) * (spec['high'] - spec['low']) + spec['low']

    def _kernel_shape(self, name):
        """Shape of a per-kernel parameter: (nb_k,) + spaces[name]['shape']."""
        extra = self.spaces[name]['shape'] or ()
        return (self.nb_k, *extra)

    def sample(self, key: torch.Tensor = None) -> Params:
        """Sample a random set of parameters uniformly inside ``spaces``.

        Args:
            key: seed for the draw (int or single-element integer tensor).
                When given, sampling uses a private ``torch.Generator`` seeded
                with it, so equal keys give equal parameters and the global
                torch RNG is left untouched. When None, the global RNG is used.

        Returns:
            Params: sampled parameters
        """
        generator = None
        if key is not None:
            generator = torch.Generator().manual_seed(int(key))

        kernels = {}
        # The draw order (r, m, s, h, then a, w, b, then R) is fixed so that a
        # given seed always maps to the same parameters.
        for k in 'rmsh':
            kernels[k] = self._uniform(k, self._kernel_shape(k), generator)
        for k in "awb":
            kernels[k] = self._uniform(k, self._kernel_shape(k), generator)
        R = self._uniform('R', (1,), generator)
        return Params(R=R.item(), **kernels)



import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift
from dataclasses import dataclass

@dataclass
class CompiledParams:
    """Flow Lenia compiled parameters, as returned by KernelComputer.

    Per-kernel tensors are stacked along the last dimension (one slice per
    kernel); the step function multiplies ``fK`` with the FFT of the state.
    """
    fK: torch.Tensor  # FFT over dims (0, 1) of the centred, normalised kernels nK
    m: torch.Tensor  # growth centre per kernel (copied from Params.m)
    s: torch.Tensor  # growth width per kernel (copied from Params.s)
    h: torch.Tensor  # growth weight per kernel (copied from Params.h)
    K: torch.Tensor  # unnormalised spatial kernels
    nK: torch.Tensor  # K divided by its sum over dims (0, 1)

class KernelComputer:

    """Compile raw update-rule parameters (Params) into FFT'd kernels (CompiledParams).

    Attributes:
        apply (Callable): function mapping Params -> CompiledParams.
        SX (int): grid size along dim 0.
        SY (int): grid size along dim 1.
    """

    def __init__(self, SX: int, SY: int, nb_k: int):
        """
        Args:
            SX (int): grid size along dim 0.
            SY (int): grid size along dim 1.
            nb_k (int): number of kernels to build (uses the first nb_k
                entries of each per-kernel parameter).
        """
        self.SX = SX
        self.SY = SY

        # Euclidean distance (float64) from every cell to the grid centre at
        # index (SX // 2, SY // 2). For even sizes this is the same grid as
        # np.mgrid[-size//2 : size//2]; it also has the right size when the
        # dimensions are odd or differ (the old grid was SX x SX, or SX-1 when odd).
        xs = torch.arange(SX, dtype=torch.float64) - SX // 2
        ys = torch.arange(SY, dtype=torch.float64) - SY // 2
        dist = torch.sqrt(xs[:, None] ** 2 + ys[None, :] ** 2)  # (SX, SY)

        def compute_kernels(params: Params) -> CompiledParams:
            """Compute the kernels and their FFTs.

            Args:
                params (Params): raw params of the system

            Returns:
                CompiledParams: compiled params which can be used as update rule
            """
            kernels = []
            for k in range(nb_k):
                D = dist / ((params.R + 15) * params.r[k])
                # Smooth cut-off around D == 1 times a sum of Gaussian bumps.
                kernels.append(sigmoid(-(D - 1) * 10) * ker_f(D, params.a[k], params.w[k], params.b[k]))
            K = torch.stack(kernels, dim=-1)  # (SX, SY, nb_k)

            nK = K / torch.sum(K, dim=(0, 1), keepdim=True)

            # The kernel centre sits at index (SX // 2, SY // 2); ifftshift
            # moves it to (0, 0) for both even and odd sizes (fftshift only
            # does so for even sizes).
            fK = torch.fft.fft2(torch.fft.ifftshift(nK, dim=(0, 1)), dim=(0, 1))

            return CompiledParams(fK=fK, m=params.m, s=params.s, h=params.h, K=K, nK=nK)

        self.apply = compute_kernels

    def __call__(self, params: Params):
        """Compile ``params`` (alias for ``apply``)."""
        return self.apply(params)
    
#==================================================================================================================
#==================================================FLOW LENIA======================================================
#==================================================================================================================


@dataclass
class Config :

    """Configuration of the Flow Lenia system (grid, connectivity and dynamics).
    """
    SX: int  # grid size along dim 0
    SY: int  # grid size along dim 1
    nb_k: int  # total number of kernels
    C: int  # number of channels
    c0: t.Iterable  # c0[k]: source channel of kernel k (indexes the FFT of the state)
    c1: t.Iterable  # c1[c]: indices of the kernels whose growth is summed into channel c
    dt: float  # time step applied to the flow by ReintegrationTracking
    dd: int = 5  # neighbourhood radius (cells) searched by ReintegrationTracking
    sigma: float = .65  # spread half-width used by ReintegrationTracking
    n: int = 2  # exponent in alpha = clip((A / theta_A)**n, 0, 1)
    theta_A : float = 1.  # mass level at which alpha saturates to 1

@dataclass
class State :

    """State of the system.

    Note: a later cell redefines ``State`` with the same single field, and
    that later definition is the one in effect afterwards.
    """
    A: torch.Tensor  # field treated by FlowLenia's step as (x, y, c)


import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift
from torch.autograd import grad
from functools import partial
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Callable, Tuple


import torch
from torch.autograd import grad
from functools import partial
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Tuple

@dataclass
class State:
    """State of the Flow Lenia system (supersedes the earlier State definition)."""
    A: torch.Tensor  # field indexed (x, y, c) by FlowLenia's step


class FlowLenia:

    """Flow Lenia system: wires the rule space, kernel compiler and transport together.

    Attributes:
        config (Config): config of the system
        rule_space (RuleSpace): rule space raw parameters are sampled from
        kernel_computer (KernelComputer): compiles Params into CompiledParams
        RT (ReintegrationTracking): mass transport operator applied each step
        step_fn (Callable): one update, ``(State, CompiledParams) -> State``
        rollout_fn (Callable): ``(CompiledParams, State, steps) -> (final_state, states)``
    """

    def __init__(self, config: Config):
        """
        Args:
            config (Config): config of the system
        """
        self.config = config

        self.rule_space = RuleSpace(config.nb_k)

        self.kernel_computer = KernelComputer(self.config.SX, self.config.SY, self.config.nb_k)

        self.RT = ReintegrationTracking(self.config.SX, self.config.SY, self.config.dt,
            self.config.dd, self.config.sigma)

        self.step_fn = self._build_step_fn()

        self.rollout_fn = self._build_rollout()

    def __call__(self, state: State, params: CompiledParams) -> State:
        """Advance ``state`` by one step (alias for ``step_fn``).

        Args:
            state (State): current state
            params (CompiledParams): compiled update-rule parameters

        Returns:
            State: next state
        """
        return self.step_fn(state, params)

    def _build_step_fn(self) -> Callable[[State, CompiledParams], State]:
        """Build the step function of the system according to config.

        Returns:
            Callable[[State, CompiledParams], State]: step function which
            outputs the next state given a state and params
        """

        def step(state: State, params: CompiledParams) -> State:
            """One Flow Lenia update.

            Args:
                state (State): state of the system
                params (CompiledParams): compiled params

            Returns:
                State: new state of the system
            """
            #---------------------------Original Lenia------------------------------------
            A = state.A  # (x, y, c)

            # Convolve source channel c0[k] with kernel k, in Fourier space.
            fA = torch.fft.fft2(A, dim=(0, 1))  # (x, y, c)
            fAk = fA[:, :, self.config.c0]  # (x, y, k)
            U = torch.fft.ifft2(params.fK * fAk, dim=(0, 1)).real  # (x, y, k)

            U = growth(U, params.m, params.s) * params.h  # (x, y, k)

            # Channel c receives the summed growth of the kernels listed in c1[c].
            U = torch.stack([U[:, :, self.config.c1[c]].sum(dim=-1)
                             for c in range(self.config.C)], dim=-1)  # (x, y, c)

            #-------------------------------FLOW------------------------------------------
            nabla_U = sobel(U)  # (x, y, 2, c)
            nabla_A = sobel(A.sum(dim=-1, keepdim=True).double())  # (x, y, 2, 1)

            # alpha -> 1 where mass approaches theta_A: there the flow follows
            # -grad(total mass) instead of grad(U).
            alpha = torch.clip((A[:, :, None, :] / self.config.theta_A) ** self.config.n, .0, 1.)

            # Named `flow` (not `F`) so it does not shadow torch.nn.functional.
            flow = nabla_U * (1 - alpha) - nabla_A * alpha

            nA = self.RT(A, flow)

            return State(A=nA)

        return step

    def _build_rollout(self) -> Callable[[CompiledParams, State, int], Tuple[State, t.List[State]]]:
        """Build the rollout function.

        Returns:
            Callable: ``rollout(params, init_state, steps)`` returning
            ``(final_state, states)``, where ``states`` holds the initial state
            followed by the ``steps`` successive states.
        """

        def rollout(params: CompiledParams, init_state: State, steps: int) -> Tuple[State, t.List[State]]:
            """Run ``steps`` updates starting from ``init_state``.

            Args:
                params (CompiledParams): compiled update-rule parameters
                init_state (State): initial state
                steps (int): number of updates

            Returns:
                Tuple[State, List[State]]: the last state and all
                ``steps + 1`` states (initial state included).
            """
            states = [init_state]
            for _ in range(steps):
                states.append(self.step_fn(states[-1], params))
            # The final state is the last one; returning states[0] (as before)
            # handed callers the initial state under the name "final_state".
            return states[-1], states

        return rollout


        

# Lenia

# Demo

## Flow Lenia

In [39]:
#@title Initialize state and parameters

#@title Configuration
number_of_kernels = 10 #@param {type:"raw"}
nb_k = number_of_kernels
world_size = 128 #@param {type : "integer"}
SX = SY = world_size
C = 1 # @param {type : "integer"}
dt = 0.2 # @param
theta_A = 2.0 # @param
sigma = 0.6 #@param
# Connectivity: number_of_kernels kernels for every (source, target) channel
# pair, so the total kernel count becomes C * C * number_of_kernels.
M = torch.ones((C, C), dtype=int) * nb_k
nb_k = int(M.sum())
c0, c1 = conn_from_matrix(M)
config = Config(SX=SX, SY=SY, nb_k=nb_k, C=C, c0=c0, c1=c1,
                dt=dt, theta_A=theta_A, dd=5, sigma=sigma)
fl = FlowLenia(config)
# A JAX version jitted the rollout (static steps argument); here it runs eagerly.
roll_fn = fl.rollout_fn

import torch

seed = 7  # another seed tried before: 21
torch.manual_seed(seed)

params_seed = torch.tensor([seed])
state_seed = torch.tensor([seed+1])

params = fl.rule_space.sample(params_seed)

# Initial state: a 40x40 patch of values in [0.5, 1.0) in the middle of the world.
mx, my = SX//2, SY//2  # center coordinates
A0 = torch.zeros((SX, SY, C))
# Draw the patch from a generator seeded with state_seed so the initial state
# is reproducible (np.random is not seeded anywhere above).
state_rng = np.random.default_rng(int(state_seed))
A0[mx-20:mx+20, my-20:my+20, :] = torch.tensor(
     state_rng.random((40, 40, 1)) * 0.5 + 0.5
)
# Alternative initial state: a uniform square, e.g.
# A0[mx-5:mx+5, my-5:my+5, :] = 1.0
state = State(A=A0)


In [40]:
# Inspect the total kernel count after expanding the C x C connectivity matrix M.
nb_k

10

In [29]:
# Same inspection; the recorded output comes from an earlier execution (In [29]),
# presumably with a different configuration.
nb_k

90

In [30]:
# Compile the sampled raw params into kernels and their FFTs (CompiledParams).
c_params = fl.kernel_computer(params=params)

In [31]:
# Display the compiled parameters.
c_params

CompiledParams(fK=tensor([[[ 1.0000+0.0000e+00j,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
           ...,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
           1.0000+0.0000e+00j],
         [ 0.8749+3.8687e-18j,  0.9636+1.7980e-18j,  0.8553+4.9595e-18j,
           ...,  0.8148+9.3501e-18j,  0.8836+6.1311e-18j,
           0.7604+1.2435e-17j],
         [ 0.5630+3.8736e-18j,  0.8603-2.6840e-18j,  0.5142+4.7417e-18j,
           ...,  0.4019+2.9430e-18j,  0.5921+5.3545e-18j,
           0.2718+3.0010e-18j],
         ...,
         [ 0.2136+1.9660e-17j,  0.7059+1.6157e-17j,  0.1748+2.2579e-17j,
           ...,  0.0460+1.8716e-17j,  0.2617+1.9020e-17j,
          -0.0619+2.2773e-17j],
         [ 0.5630-6.5341e-18j,  0.8603-1.3890e-18j,  0.5142-5.5907e-18j,
           ...,  0.4019-4.1494e-18j,  0.5921-6.3090e-18j,
           0.2718-2.8889e-18j],
         [ 0.8749-4.7683e-18j,  0.9636-5.2495e-18j,  0.8553-3.4638e-18j,
           ...,  0.8148-9.4201e-18j,  0.8836-6.1203e-18j,
           0.7604-1.27

In [32]:
#@title Utils
def state2img(A):
    """Turn a (x, y, c) state into something VideoWriter can display.

    1 channel  -> the single 2-D slice (grayscale);
    2 channels -> RGB stack (ch0, ch0, ch1);
    3+ channels -> the first three channels as RGB.
    """
    channels = A.shape[-1]
    if channels == 1:
        img = A[..., 0]
    elif channels == 2:
        first, second = A[..., 0], A[..., 1]
        img = np.dstack([first, first, second])
    else:
        img = A[..., :3]
    return img

In [33]:
#@title Simulate {vertical-output : true}
# Collect rollout and visualize
T = 100 #@param {type : 'integer'}
final_state, states = roll_fn(c_params, state, T)

# `states` holds the initial state plus T updates; write all T + 1 frames
# (iterating range(T) silently dropped the final state).
with VideoWriter("example.mp4", 10) as vid:
    for s in states:
        vid.add(state2img(s.A))

    # `height` was misspelled `heigth`, so the player height was never applied.
    vid.show(width = 360, height = 360)

In [7]:
SX = 128
SY = 128
C = 1
import math
import torch

# Divisor giving the side lengths of the centred block of ones.
scale_init_state = 2

# Zero initial state laid out as (SY, SX, C).
init_state = torch.zeros((SY, SX, C), dtype=torch.float64, requires_grad=False)

scaled_SY = SY // scale_init_state
scaled_SX = SX // scale_init_state

# Half-open bounds of the centred block: each spans ceil(n/2) + floor(n/2) = n cells.
lo_dim0 = SY // 2 - math.ceil(scaled_SY / 2)
hi_dim0 = SY // 2 + scaled_SY // 2
lo_dim1 = SX // 2 - math.ceil(scaled_SX / 2)
hi_dim1 = SX // 2 + scaled_SX // 2
print("rangeX", lo_dim0, hi_dim0)
print("rangeY", lo_dim1, hi_dim1)

# Block of ones to paste into the centre of init_state.
newstate = torch.ones((scaled_SY, scaled_SX, C), dtype=torch.float64, requires_grad=False)

print(init_state.shape)

rangeX 32 96
rangeY 32 96
torch.Size([128, 128, 1])


In [9]:
# Paste the block of ones into the centre of init_state. Each slice spans
# exactly scaled_S* cells (ceil + floor of half the block size).
row_lo = SY // 2 - math.ceil(scaled_SY / 2)
row_hi = SY // 2 + scaled_SY // 2
col_lo = SX // 2 - math.ceil(scaled_SX / 2)
col_hi = SX // 2 + scaled_SX // 2
init_state[row_lo:row_hi, col_lo:col_hi, :] = newstate

print(init_state.shape)

torch.Size([128, 128, 1])
