In [None]:
import torch

from controllable_nca.image.emoji_dataset import EmojiDataset

In [None]:
import matplotlib.pyplot as plt
import torch
from einops import rearrange

from controllable_nca.dataset import MultiClass2DDataset
from controllable_nca.utils import load_emoji, rgb


class EmojiDataset(MultiClass2DDataset):
    """Multi-class 2D dataset whose samples are rendered emoji images.

    Each character of ``EMOJI`` is rendered with ``load_emoji`` and becomes one
    sample. Its integer class target is its position in ``EMOJI``. The digit
    glyphs "0"-"9" are also rendered and stored on the instance as
    ``self.digits``.

    NOTE(review): this notebook-local class shadows the ``EmojiDataset``
    imported from ``controllable_nca.image.emoji_dataset`` in the first cell.
    Keep only one definition once the two have been reconciled.
    """

    # Alternative, smaller emoji set: '🦎😀💥'
    EMOJI = "🦎😀👁🕸🎄"

    # Unicode code points (hex strings) for the ASCII digits "0" through "9".
    # NOTE: __init__ assigns a tensor to ``self.digits``, which shadows this
    # class attribute on instances. The code list stays reachable as
    # ``EmojiDataset.digits``.
    digits = [
        "0030",  # 0
        "0031",  # 1
        "0032",  # 2
        "0033",  # 3
        "0034",  # 4
        "0035",  # 5
        "0036",  # 6
        "0037",  # 7
        "0038",  # 8
        "0039",  # 9
    ]

    def __init__(self, image_size=64):
        """Render every emoji and every digit glyph.

        Args:
            image_size: size forwarded to ``load_emoji`` for each glyph
                (presumably the output side length in pixels -- TODO confirm).
        """
        emojis = torch.stack(
            [load_emoji(e, image_size) for e in EmojiDataset.EMOJI], dim=0
        )
        # One class per emoji: the target is the emoji's index in EMOJI.
        targets = torch.arange(emojis.size(0))
        # Zero-argument super() binds to this class object itself. The
        # explicit super(EmojiDataset, self) form would raise TypeError for
        # instances created before this notebook cell was re-executed.
        super().__init__(emojis, targets)
        self.digits = torch.stack(
            [load_emoji(None, image_size, code=e) for e in EmojiDataset.digits],
            dim=0,
        )

    def visualize(self, idx=0):
        """Plot sample ``idx`` of ``self.x``. Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self.x), len(self.x))``.
        """
        num_samples = len(self.x)
        if not -num_samples <= idx < num_samples:
            raise IndexError(
                f"sample index {idx} out of range for {num_samples} samples"
            )
        # Normalize the index before slicing. Without this, x[idx : idx + 1]
        # is empty for idx == -1 because it becomes x[-1:0].
        start = idx % num_samples
        self.plot_img(self.x[start : start + 1])

    def plot_img(self, img):
        """Display ``img`` with matplotlib after converting it via ``rgb``.

        NOTE(review): assumes ``rgb(img, False)`` returns a channels-first
        tensor with a leading batch dimension of 1. TODO confirm against ``rgb``.
        """
        with torch.no_grad():
            rgb_image = rgb(img, False).squeeze().detach().cpu().numpy()
        # Channels-first -> channels-last, the layout plt.imshow expects.
        rgb_image = rearrange(rgb_image, "c h w -> h w c")
        _ = plt.imshow(rgb_image)
        plt.show()


In [None]:
# Render the emoji targets; 64 is the constructor's default image size.
dataset = EmojiDataset(image_size=64)

In [None]:
# Sanity check: plot the first target emoji.
dataset.visualize(idx=0)

### Import NCA

In [None]:
from controllable_nca.image.nca import ControllableImageNCA

In [None]:
# NOTE(review): living_channel_dim=3 presumably selects the alpha channel of
# RGBA targets -- confirm against ControllableImageNCA.
nca = ControllableImageNCA(
    target_shape=dataset.target_size(),
    living_channel_dim=3,
    num_hidden_channels=8,
)

In [None]:
# Display the model's repr to inspect its submodules.
nca

In [None]:
# Total parameter count. Keying on data_ptr() counts each storage address only
# once, so parameters that start at the same memory address are not double-counted.
sum({param.data_ptr(): param.numel() for param in nca.parameters()}.values())

### Move the dataset and model to the compute device (GPU if available)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (torch.device('cuda') alone fails there on first use).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the return value of dataset.to is discarded, so this presumably
# moves the dataset in place -- confirm against MultiClass2DDataset.to.
dataset.to(device)
nca = nca.to(device)

### Trainer

In [None]:
from controllable_nca.image.trainer import ControllableNCAImageTrainer

In [None]:
# NOTE(review): with num_damaged=0, damage_radius is presumably unused; confirm
# against ControllableNCAImageTrainer.
trainer = ControllableNCAImageTrainer(
    nca,
    dataset,
    nca_steps=[32, 48],
    lr=1e-3,
    num_damaged=0,
    damage_radius=3,
    device=device,
    pool_size=1024,
)

In [None]:
# Long-running: 50000 training epochs.
trainer.train(
    batch_size=5,
    epochs=50000,
)