# Morphing Emojis

In [1]:
from controllable_nca.experiments.morphing_image.trainer import MorphingImageNCATrainer
from controllable_nca.experiments.morphing_image.emoji_dataset import EmojiDataset
from controllable_nca.nca import ControllableNCA

import torch

In [None]:
import matplotlib.pyplot as plt
import torch
from einops import rearrange

from controllable_nca.dataset import NCADataset
from controllable_nca.utils import load_emoji, rgb


class EmojiDataset(NCADataset):
    """Fixed set of emoji images used as growth targets.

    Every character of ``EMOJI`` is rendered once with ``load_emoji`` and the
    results are stacked along a new leading dimension, so the goal index of
    an image is the position of its character in ``EMOJI``.
    """

    # EMOJI = '🦎😀💥'
    EMOJI = "🦎😀💥👁🐠🦋🐞🕸🥨🎄"
    # EMOJI = "🦎😀👁🕸🥨🎄"

    # Hex code points of the ASCII digits 0-9; not referenced inside this class.
    digits = [
        "0030",  # 0
        "0031",  # 1
        "0032",  # 2
        "0033",  # 3
        "0034",  # 4
        "0035",  # 5
        "0036",  # 6
        "0037",  # 7
        "0038",  # 8
        "0039",  # 9
    ]

    def __init__(self, image_size=64, thumbnail_size=32, use_one_hot: bool = False):
        """Load and stack one image per character of ``EMOJI``.

        Args:
            image_size: Forwarded to ``load_emoji`` as its second argument.
            thumbnail_size: Forwarded to ``load_emoji`` as its third argument.
            use_one_hot: Accepted for interface compatibility; currently
                ignored by this class.
        """
        # NOTE(review): NCADataset.__init__ is not invoked (same as before);
        # confirm the base class needs no initialisation of its own.
        self.emojis = torch.stack(
            [load_emoji(e, image_size, thumbnail_size) for e in EmojiDataset.EMOJI],
            dim=0,
        )
        self.num_samples = len(self)
        self._target_size = self.emojis.size()[-3:]

    def num_goals(self):
        """Return the number of distinct target images."""
        return self.emojis.size(0)

    def __getitem__(self, idx):
        """Return ``(images, idx)`` where ``images`` is a copy of the selection.

        An ``int`` index yields a batch of one image (leading dimension 1);
        any other index type is passed straight to tensor indexing.

        Raises:
            IndexError: If an ``int`` index is out of range.
        """
        if isinstance(idx, int):
            # Index first, then restore the batch axis. The former slice
            # ``idx : idx + 1`` silently produced an EMPTY tensor for
            # ``idx == -1`` and for out-of-range indices; plain indexing
            # supports negative positions and raises IndexError instead.
            return self.emojis[int(idx)].unsqueeze(0).clone(), idx
        return self.emojis[idx].clone(), idx

    def __len__(self):
        """Return the number of stored images."""
        return self.emojis.size(0)

    def target_size(self):
        """Return the last three dimensions of the stacked image tensor."""
        if self._target_size is None:
            self._target_size = self.emojis.size()[-3:]
        return self._target_size

    def to(self, device: torch.device):
        """Move the stored images to ``device`` in place.

        Returns:
            ``self``, so the call can be chained or re-assigned in the same
            way as ``torch.nn.Module.to``.
        """
        self.emojis = self.emojis.to(device)
        return self

    def visualize(self, idx=0):
        """Plot the image at position ``idx`` (negative positions allowed)."""
        image, _ = self[idx]
        self.plot_img(image)

    def plot_img(self, img):
        """Convert ``img`` with ``rgb`` and show it with matplotlib."""
        with torch.no_grad():
            rgb_image = rgb(img, False).squeeze().detach().cpu().numpy()
        # Move the channel axis last, which is the layout plt.imshow is given.
        rgb_image = rearrange(rgb_image, "c w h -> w h c")
        _ = plt.imshow(rgb_image)
        plt.show()


In [None]:
# Build the emoji targets; both sizes are forwarded to load_emoji for every emoji.
dataset = EmojiDataset(image_size=64, thumbnail_size=40)

In [None]:
# Plot the target image stored at index 9.
dataset.visualize(9)

In [None]:
# Plot the target image stored at index 1.
dataset.visualize(1)

### Make NCA

In [None]:
# Last three dimensions of the stacked emoji tensor; passed to the NCA as target_shape.
dataset.target_size()

In [None]:
from typing import Optional, Tuple  # noqa

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding

from controllable_nca.utils import build_conv2d_net

class DeepEncoder(nn.Module):
    """Map integer goal indices to dense conditioning vectors.

    An embedding lookup is followed by a small MLP:
    ``Linear -> ReLU -> Linear`` where the final layer has no bias.
    """

    def __init__(self, num_embeddings: int, out_channels: int, embedding_dim: int = 32):
        """Build the embedding table and the MLP on top of it.

        Args:
            num_embeddings: Number of rows in the embedding table, i.e. the
                number of distinct indices that can be encoded.
            out_channels: Size of the vector produced for each index.
            embedding_dim: Width of the embedding and of the hidden layer.
                Defaults to 32, the value that used to be hard-coded, so
                existing callers and saved state dicts are unaffected.
        """
        super().__init__()
        self.num_embeddings = num_embeddings
        # Sub-module names and Sequential positions are kept as they were so
        # state-dict keys ("embedding.weight", "encoder.0.*", "encoder.2.*")
        # do not change.
        self.embedding = Embedding(num_embeddings, embedding_dim)
        self.encoder = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, out_channels, bias=False),
        )

    def forward(self, indices):
        """Return the encoding of ``indices``.

        The output has the shape of ``indices`` with one trailing dimension
        of size ``out_channels`` appended.
        """
        return self.encoder(self.embedding(indices))

# Shared width: the encoder's output size and the NCA's hidden channel count.
NUM_HIDDEN_CHANNELS = 16

# One embedding row per goal image in the dataset.
encoder = DeepEncoder(dataset.num_goals(), NUM_HIDDEN_CHANNELS)

# Goal-conditioned NCA driven by the index encoder above rather than an
# image encoder.
nca = ControllableNCA(
    num_goals=dataset.num_goals(),
    use_image_encoder=False,
    encoder=encoder,
    target_shape=dataset.target_size(),
    living_channel_dim=3,
    num_hidden_channels=NUM_HIDDEN_CHANNELS,
    cell_fire_rate=0.5,
)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nca = nca.to(device)
dataset.to(device)

In [None]:
# Trainer wired to the model and dataset created earlier; num_damaged=0 is
# passed alongside a damage_radius of 3.
trainer = MorphingImageNCATrainer(
    nca,
    dataset,
    nca_steps=[48, 96],
    lr=1e-3,
    num_damaged=0,
    damage_radius=3,
    device=device,
    pool_size=1024,
)

In [None]:
# Bare expression: the notebook shows the model's repr as the cell output.
nca

In [None]:
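# Training call left disabled; the following cell loads weights from a saved
# checkpoint. Uncomment to train from scratch.
# trainer.train(batch_size=8, epochs=100000)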

In [None]:
# Load weights from a checkpoint path relative to the notebook's working directory.
# NOTE(review): the file name mentions a "linear" encoder while this notebook
# builds a DeepEncoder (Linear -> ReLU -> Linear) — confirm the checkpoint matches.
nca.load("../saved_models/100k_linear_encoder.pt")

### Interactive Visualization

In [None]:
from threading import Event, Thread

import cv2
import numpy as np
import torch
from einops import rearrange
from ipycanvas import Canvas, hold_canvas  # noqa
from ipywidgets import Button, HBox, VBox

from controllable_nca.utils import create_2d_circular_mask, rgb


def to_numpy_rgb(x, use_rgb=False):
    """Convert ``x`` with ``rgb`` and return it as a channels-last NumPy array.

    Args:
        x: Input forwarded to ``rgb``.
        use_rgb: Second argument forwarded to ``rgb``.

    Returns:
        The converted image with size-1 axes removed and the channel axis
        moved to the end.
    """
    converted = rgb(x, use_rgb)
    as_array = converted.detach().cpu().numpy()
    # Drop singleton axes (e.g. a batch of one) before reordering the axes.
    channels_first = np.squeeze(as_array)
    return rearrange(channels_first, "c x y -> x y c")


class MorphingImageVisualizer:
    """Interactive notebook widget that animates the NCA on an ipycanvas canvas.

    A background thread repeatedly draws the current state and then calls the
    NCA's ``grow`` with the currently selected goal. One button per emoji
    switches the goal, the "Stop" button halts the loop, and a mouse press on
    the canvas zeroes a disc of cells around the pointer.
    """

    def __init__(
        self,
        trainer,
        image_size,
        rgb: bool = False,
        canvas_scale=5,
        damage_radius: int = 5,
    ):
        """Create the canvas, the control buttons and the initial state.

        Args:
            trainer: Object providing ``device``, ``nca``, ``target_dataset``
                and ``to_rgb``.
            image_size: Side length, in cells, of the square grid; also used
                for the damage mask.
            rgb: Stored on the instance; not otherwise used by this class.
            canvas_scale: Canvas pixels per grid cell.
            damage_radius: Radius of the disc zeroed by a mouse press.
        """
        self.trainer = trainer
        self.device = self.trainer.device

        self.image_size = image_size
        self.rgb = rgb
        self.canvas_scale = canvas_scale
        self.canvas_size = self.image_size * self.canvas_scale
        self.damage_radius = damage_radius

        self.canvas = Canvas(width=self.canvas_size, height=self.canvas_size)
        self.canvas.on_mouse_down(self.handle_mouse_down)
        # Set while the animation is halted; polled by loop().
        self.stopped = Event()

        self.current_goal = torch.tensor(0, device=self.device)
        self.current_state = self.trainer.nca.generate_seed(1).to(self.device)

        def button_fn(class_num):
            # Factory so each button captures its own goal index.
            def start(btn):
                self.current_goal = torch.tensor(class_num, device=self.device)
                # Only spawn a new loop when the previous one was stopped.
                if self.stopped.is_set():
                    self.stopped.clear()
                    Thread(target=self.loop).start()

            return start

        button_list = []
        for class_num, symbol in enumerate(self.trainer.target_dataset.EMOJI):
            button = Button(description=symbol)
            button.on_click(button_fn(class_num))
            button_list.append(button)

        self.vbox = VBox(button_list)

        self.stop_btn = Button(description="Stop")

        def stop(btn):
            if not self.stopped.is_set():
                self.stopped.set()

        self.stop_btn.on_click(stop)

    def handle_mouse_down(self, xpos, ypos):
        """Zero every channel of the cells within ``damage_radius`` of the click.

        Args:
            xpos: Click x position in canvas pixels.
            ypos: Click y position in canvas pixels.
        """
        # Convert canvas pixels to grid cells.
        in_x = int(xpos / self.canvas_scale)
        in_y = int(ypos / self.canvas_scale)

        # NOTE(review): the centre is passed as (x, y); confirm this matches
        # create_2d_circular_mask's convention and the state's axis order.
        mask = create_2d_circular_mask(
            self.image_size,
            self.image_size,
            (in_x, in_y),
            radius=self.damage_radius,
        )
        self.current_state[0][:, mask] *= 0.0

    def draw_image(self, rgb):
        """Paint one frame on the canvas.

        Args:
            rgb: Array laid out as "b c w h"; it is scaled by 255 before
                drawing, so values are presumably expected in [0, 1] —
                TODO confirm against ``trainer.to_rgb``.
        """
        with hold_canvas(self.canvas):
            rgb = np.squeeze(rearrange(rgb, "b c w h -> b w h c"))
            self.canvas.clear()  # Clear the old animation step
            self.canvas.put_image_data(
                # Nearest-neighbour upscaling keeps individual cells crisp.
                cv2.resize(
                    rgb * 255.0,
                    (self.canvas_size, self.canvas_size),
                    interpolation=cv2.INTER_NEAREST,
                ),
                0,
                0,
            )

    def loop(self):
        """Animate from a fresh seed until the stop event is set."""
        with torch.no_grad():
            self.current_state = self.trainer.nca.generate_seed(1).to(self.device)
            # wait() returns False after the 0.02 s timeout and True as soon
            # as the stop event is set, which ends the loop.
            while not self.stopped.wait(0.02):
                self.draw_image(self.trainer.to_rgb(self.current_state))
                self.current_state = self.trainer.nca.grow(
                    self.current_state, 1, self.current_goal
                )

    def visualize(self):
        """Start the animation thread and display the canvas with its controls.

        NOTE: calling this while a loop is already running starts a second
        loop thread; press "Stop" first.
        """
        # Clear a previous Stop. Without this, calling visualize() again after
        # stopping started a thread whose loop exited on its first check.
        self.stopped.clear()
        Thread(target=self.loop).start()
        display(self.canvas, HBox([self.stop_btn, self.vbox]))  # noqa


In [None]:
# 64 is the grid side length; it matches the image_size given to EmojiDataset earlier.
viz = MorphingImageVisualizer(trainer, 64)

In [None]:
# Start the background update loop and show the canvas with its buttons.
viz.visualize()