- [From Conway to Lenia - Notebook](https://colab.research.google.com/github/OpenLenia/Lenia-Tutorial/blob/main/Tutorial_From_Conway_to_Lenia.ipynb#scrollTo=VAt144SoGZZr)

- [Neural Cellular Automata (NCA) implementation walkthrough - Notebook](https://github.com/Mayukhdeb/differentiable-morphogenesis/blob/main/notebooks/basic_walkthrough.ipynb)

- [ALIFE2023: Flow-Lenia](https://www.youtube.com/watch?v=605DcOMwFLM)

- [ipywidgets: Play animation widget](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20List.html#play-animation-widget)

In [1]:
# Reload imported modules automatically before each cell runs, so edits to
# source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from aesthetic_tensor import aesthetify

In [4]:
# One-time global setup from the aesthetic_tensor package (imported in the cell above).
# NOTE(review): presumably changes how torch tensors are rendered in the notebook — confirm
# against the aesthetic_tensor docs.
aesthetify()

In [73]:
class EncodeDecode(nn.Module):
    """Pair of small MLPs mapping a message vector to a square image and back.

    The encoder maps a ``(bs, msg_size)`` batch to a ``(bs, 1, img_size, img_size)``
    image batch. The decoder maps such an image batch back to ``(bs, msg_size)``.
    An Adam optimizer over the parameters of both networks is created alongside
    and stored as ``self.optim``.

    Args:
        msg_size: Length of each message vector.
        img_size: Side length of the square, single-channel image.
        hidden_size: Width of the two hidden layers in each MLP (default 20).
        lr: Learning rate of the Adam optimizer (default 0.01).
    """

    def __init__(self, msg_size, img_size, hidden_size=20, lr=0.01):
        super().__init__()

        self.msg_size = msg_size
        self.img_size = img_size
        n_pixels = img_size * img_size

        # message -> flattened image (final layer is linear, no activation)
        self.encoder = nn.Sequential(
            nn.Linear(msg_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_pixels),
        )

        # flattened image -> message (final layer is linear, no activation)
        self.decoder = nn.Sequential(
            nn.Linear(n_pixels, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, msg_size),
        )

        # Built after both sub-networks are registered so that
        # self.parameters() covers the encoder and the decoder.
        self.optim = torch.optim.Adam(self.parameters(), lr=lr)

    def encode(self, msg):
        """Encode a batch of messages into single-channel images.

        Args:
            msg: Tensor of shape ``(bs, msg_size)``.

        Returns:
            Tensor of shape ``(bs, 1, img_size, img_size)`` holding the raw
            (un-activated) encoder outputs.
        """
        bs = msg.size(0)
        x = self.encoder(msg)
        return x.reshape(bs, 1, self.img_size, self.img_size)

    def decode(self, img):
        """Decode a batch of images back into message vectors.

        Args:
            img: Tensor whose first dimension is the batch size and which holds
                ``img_size * img_size`` values per item, e.g. the
                ``(bs, 1, img_size, img_size)`` output of :meth:`encode`.

        Returns:
            Tensor of shape ``(bs, msg_size)``.
        """
        bs = img.size(0)
        x = img.reshape(bs, self.img_size * self.img_size)
        return self.decoder(x)

    @staticmethod
    def msg_loss(msg_true, msg_pred):
        """Mean-squared error between the true and the reconstructed messages."""
        # F.mse_loss follows PyTorch's (input, target) convention; MSE is
        # symmetric, so the value is unchanged by the argument order.
        return F.mse_loss(msg_pred, msg_true)

In [74]:
def sample_msg(bs, msg_size):
    """Draw a batch of random message vectors.

    Every entry is drawn independently and uniformly from [0, 1) with
    ``torch.rand``, using torch's global random number generator.

    Returns:
        Tensor of shape ``(bs, msg_size)`` in torch's default float dtype.
    """
    shape = (bs, msg_size)
    return torch.rand(shape)

In [75]:
# Smoke test: push one random batch through the untrained encoder/decoder
# and check the shapes before any training happens.
torch.manual_seed(0)  # reproducible messages and weight initialisation

msg_size = 16  # length of each message vector
img_size = 20  # side length of the square image a message is encoded into
bs = 3         # batch size

msg = sample_msg(bs, msg_size)
encode_decode = EncodeDecode(msg_size, img_size)
encoded_msg = encode_decode.encode(msg)
decoded_msg = encode_decode.decode(encoded_msg)
loss = EncodeDecode.msg_loss(msg, decoded_msg)

assert encoded_msg.shape == (bs, 1, img_size, img_size)
assert decoded_msg.shape == msg.shape

loss