# Distilling Stable Diffusion to Neural Cellular Automata

This notebook is partly adapted from the following two files:  
https://github.com/google-research/self-organising-systems/blob/master/notebooks/%CE%BCNCA_pytorch.ipynb  
https://github.com/ashawkey/stable-dreamfusion/blob/main/sd.py

In [2]:
from pathlib import Path
from datetime import datetime
from contextlib import nullcontext
import random
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
# adapted from https://github.com/ashawkey/stable-dreamfusion/blob/main/sd.py
from stable_diffusion_sds_share import StableDiffusion

In [3]:
import os
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from mediapy import VideoWriter as VW
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
from tqdm import tnrange

In [4]:
# Use the GPU when one is available and fall back to CPU otherwise, instead of
# crashing on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Context entered around the Stable Diffusion call in the training loop.
# For mixed precision, swap in e.g.
#   torch.amp.autocast(device_type=device, dtype=torch.bfloat16)
ctx = nullcontext()

In [5]:
def show(t):
    """Display a (C, H, W) image tensor with matplotlib.

    The tensor is detached first, so tensors that require grad can be passed
    directly. It is then clamped to [0, 1], the range imshow expects for float
    RGB data, which avoids clipping warnings.
    """
    plt.imshow(t.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy())

In [6]:
# Stable Diffusion wrapper from stable_diffusion_sds_share. It is used below for
# get_text_embeds() and train_step(), presumably a score-distillation (SDS) loss
# as in the adapted stable-dreamfusion sd.py; verify against that module.
sd_sds = StableDiffusion(device)

[INFO] loading stable diffusion...
[INFO] loaded stable diffusion!


In [7]:
# μNCA perception filter bank
# adapted from https://github.com/google-research/self-organising-systems/blob/master/notebooks/%CE%BCNCA_pytorch.ipynb

# 3x3 kernels. `side` is not part of the active configuration below.
side = torch.tensor([[0.0, 0.0, 0.0], [2.0, -2.0, 0.0], [0.0, 0.0, 0.0]])
sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
lap = torch.tensor([[1.0, 2.0, 1.0], [2.0, -12.0, 2.0], [1.0, 2.0, 1.0]])

# With N filters the CA has N * (4N + 1) parameters. Alternative banks:
#   [lap]*2 + [sobel_x]*1 + [sobel_x.T]*1   ->   68 params
#   [lap]*2 + [sobel_x]*2 + [sobel_x.T]*2   ->  150 params
#   [lap]*4 + [sobel_x]*2 + [sobel_x.T]*2   ->  264 params
#   [lap]*4 + [sobel_x]*4 + [sobel_x.T]*4   ->  588 params
# Active bank: 20 filters -> 1620 params. Shape (N, 1, 3, 3) for a depthwise conv.
filters = torch.stack([lap] * 8 + [sobel_x] * 4 + [sobel_x.T] * 8).unsqueeze(1).to(device)

# One state channel per filter.
CHN = filters.shape[0]

class CA(torch.nn.Module):
    """Minimal neural cellular automaton (μNCA) update rule.

    The only parameter is one 1x1 convolution applied to the perception vector
    [x, filters * x, |x|, |filters * x|]. The last input column of ``self.w``
    holds the bias. The state must have CHN channels, one per perception filter.
    """

    def __init__(self):
        super().__init__()
        # Shape (CHN, 4*CHN weights + 1 bias, 1, 1). The tiny init keeps early
        # updates close to zero.
        self.w = torch.nn.Parameter(torch.randn(CHN, 4 * CHN + 1, 1, 1) * 1e-3)

    def to_rgb(self, x):
        """Read the first three channels as RGB, offset so a zero state maps to 0.5."""
        return x[..., :3, :, :] + 0.5

    def forward(self, x, update_rate=1.0):
        """Return the residual update for state ``x``; callers compute ``x + ca(x)``.

        Args:
            x: state tensor of shape (batch, CHN, H, W).
            update_rate: probability that each element of the update is kept.
                Values below 1.0 apply a random binary mask.
        """
        # Depthwise 3x3 perception with wrap-around (circular) borders. The
        # filter bank is moved to x's device so the model still works after .to().
        kernels = filters.to(x.device)
        y = F.pad(x, [1, 1, 1, 1], 'circular')
        y = F.conv2d(y, kernels, groups=y.shape[1])
        y = torch.cat([x, y], 1)
        y = torch.cat([y, y.abs()], 1)
        w, b = self.w[:, :-1], self.w[:, -1, 0, 0]
        y = F.conv2d(y, w, b)
        if update_rate < 1.0:
            # floor(U[0,1) + p) is 1 with probability p and 0 otherwise.
            # rand_like keeps the mask on y's device and dtype.
            y = y * (torch.rand_like(y) + update_rate).floor()
        return y

# Number of trainable scalars in a freshly constructed CA (its single weight tensor).
print('param count:', sum(map(torch.numel, CA().parameters())))

param count: 1620


In [8]:

def make_vid(out_path, ca, iters, x, y, steps_per_frame=1, init_state=None):
    """Roll out ``ca`` from an initial state and write one frame per rollout step.

    Args:
        out_path: output video path, passed to mediapy's VideoWriter.
        ca: model with ``to_rgb`` and ``forward(state, update_rate=...)``
            returning the residual update, as ``CA`` does.
        iters: number of frames to write.
        x, y: frame height and width. Also the spatial size of the default
            initial state.
        steps_per_frame: CA updates applied between consecutive frames.
        init_state: optional starting state of shape (batch, CHN, x, y).
            Defaults to a fresh single-sample state from ``make_state``.

    Raises:
        ValueError: if ``init_state`` does not have spatial size (x, y).
    """
    if init_state is None:
        # Build the default state at the requested size, not the training
        # size, so x and y really control the video resolution.
        init_state = make_state(1, CHN, x, y)["state"]
    elif tuple(init_state.shape[-2:]) != (x, y):
        raise ValueError(
            f"init_state has spatial size {tuple(init_state.shape[-2:])}, "
            f"expected {(x, y)}"
        )
    with VW(out_path, shape=(x, y)) as writer, torch.no_grad():
        s = init_state
        for _ in range(iters):
            # First batch element as an (x, y, 3) float frame in [0, 1].
            frame = ca.to_rgb(s)[0].clamp(0, 1).permute(1, 2, 0).cpu().numpy()
            writer.add_image(frame)
            for _ in range(steps_per_frame):
                s = s + ca(s, update_rate=0.5)


In [None]:
# TODO
# 1. Increase model capacity.
# 2. (done) Use a pool of states to re-initialise from, for long-term stability.
# 3. Try a simpler CA model?
# 4. Reduce memory usage?
# 5. Is quality affected by using fp16/bf16?
# 6. Increase resolution.

def make_state(batch, channels, x, y, target_device=None):
    """Create a fresh all-zero CA state entry for the pool.

    Args:
        batch: number of states (first tensor dimension). Previously this
            argument was ignored in favour of the global ``batch_size``.
        channels: number of CA channels. Should equal CHN for use with ``CA``.
        x, y: spatial size (height, width) of the grid.
        target_device: device for the tensor. Defaults to the notebook-wide
            ``device``.

    Returns:
        dict with ``"state"``, a zero tensor of shape (batch, channels, x, y),
        and ``"steps"``, the number of CA steps applied so far (0).
    """
    if target_device is None:
        target_device = device
    # Alternative init that was tried: torch.rand(...) - 0.5
    state = torch.zeros(batch, channels, x, y, device=target_device)
    return {"state": state, "steps": 0}

def norm_loss(cur_img, scale):
    """Return ``scale`` times the mean squared deviation of ``cur_img`` from 0.5.

    This penalises every deviation from mid-grey, not only values outside
    [0, 1]; it was switched to this form to test for a stable loss. The earlier
    range-only variant was
    ``scale * (relu(cur_img - 1) ** 2).mean() + scale * (relu(-cur_img) ** 2).mean()``.
    """
    deviation = cur_img - 0.5
    return scale * deviation.pow(2).mean()

# ---- hyperparameters ----
batch_size = 1
pool_size = 32      # number of CA states kept and trained every iteration
reset_prob = 0.05   # per-iteration chance that a pool state is re-initialised
ca_steps = 20  # 10
# The state needs exactly one channel per perception filter (CA's weight is
# sized from CHN), so derive it from the filter bank instead of hard-coding 20.
channels = CHN

iterations = 4000
norm_loss_scale = 10
sd_loss_scale = 0.005

dim = 128  # 256

# ---- setup ----
seed = 0
random.seed(seed)
torch.manual_seed(seed)
ca = CA().to(device)  # CAModel().to(device)
text_embeds = sd_sds.get_text_embeds(["raindrops on glass"] * batch_size, [""] * batch_size)
opt = torch.optim.Adam(ca.parameters(), 1e-4)


def m_state():
    """Return a fresh pool entry at the training batch size and resolution."""
    return make_state(batch_size, channels, dim, dim)


pool = [m_state() for _ in range(pool_size)]

run_dir = Path("output/{:%Y_%m_%d_%H_%M_%S}".format(datetime.now()))
run_dir.mkdir(exist_ok=True, parents=True)

for i in tnrange(iterations):

    track_grad_mag = 0
    track_norm_loss = 0
    loss_steps = 0
    for p in range(pool_size):
        # Re-initialise a state at random (reset_prob), or when it has already
        # run more CA steps than the current iteration index. The maximum
        # rollout length therefore grows gradually with training, so gradients
        # don't explode early on. Pool states are stored detached (below), so a
        # surviving state starts a fresh graph here.
        if random.uniform(0, 1) < reset_prob or i < pool[p]["steps"]:
            pool[p] = m_state()
        state = pool[p]["state"]

        for step in range(ca_steps):
            state = state + ca(state, update_rate=0.5)

        pixels = ca.to_rgb(state)

        with ctx:
            sd_loss, grad_mag = sd_sds.train_step(
                text_embeds, pixels, guidance_scale=200, override_t=None
            )

        sd_loss = sd_loss_scale * sd_loss
        # Pixel-value regulariser; see norm_loss for exactly what it penalises.
        total_norm_loss = norm_loss(pixels, norm_loss_scale)
        # Out-of-place sum, so sd_loss itself is not mutated.
        loss = sd_loss + total_norm_loss

        track_norm_loss += total_norm_loss.item()
        track_grad_mag += sd_loss_scale * grad_mag.item()
        loss_steps += 1

        # Gradients accumulate over the whole pool; one optimizer step follows.
        loss.backward()
        # Keep only the values: the pool no longer holds this rollout's
        # autograd graph alive until the next iteration.
        pool[p]["state"] = state.detach()
        pool[p]["steps"] += ca_steps

    print(f"i{i} step{step} grad_mag: {track_grad_mag / loss_steps:06f} " +
          f"norm loss: {track_norm_loss / loss_steps:06f} " +
          f"oldest state: {max(s['steps'] for s in pool)}")
    opt.step()
    opt.zero_grad()

    # save a video every 10 iterations (this can add up quickly)
    if i % 10 == 0:
        out_dir = run_dir / "states"
        out_dir.mkdir(exist_ok=True, parents=True)
        path = out_dir / f"run-{i:06d}.mp4"
        make_vid(path, ca, max(entry["steps"] for entry in pool), dim, dim)


  0%|          | 0/4000 [00:00<?, ?it/s]

i0 step19 grad_mag: 0.006064 norm loss: 0.000339 oldest state: 20
i1 step19 grad_mag: 0.005820 norm loss: 0.000453 oldest state: 20
i2 step19 grad_mag: 0.008385 norm loss: 0.000594 oldest state: 20
i3 step19 grad_mag: 0.007729 norm loss: 0.000776 oldest state: 20
i4 step19 grad_mag: 0.008035 norm loss: 0.001006 oldest state: 20
i5 step19 grad_mag: 0.009237 norm loss: 0.001299 oldest state: 20
i6 step19 grad_mag: 0.007377 norm loss: 0.001664 oldest state: 20
i7 step19 grad_mag: 0.008850 norm loss: 0.002107 oldest state: 20
i8 step19 grad_mag: 0.006770 norm loss: 0.002651 oldest state: 20
i9 step19 grad_mag: 0.006065 norm loss: 0.003306 oldest state: 20
i10 step19 grad_mag: 0.007045 norm loss: 0.004085 oldest state: 20
i11 step19 grad_mag: 0.003510 norm loss: 0.004993 oldest state: 20
i12 step19 grad_mag: 0.004306 norm loss: 0.005978 oldest state: 20
i13 step19 grad_mag: 0.004028 norm loss: 0.007128 oldest state: 20
i14 step19 grad_mag: 0.004170 norm loss: 0.008475 oldest state: 20
i15 s

In [None]:
# Preview the RGB image of the last trained state (first batch element).
show(ca.to_rgb(state.detach())[0])

In [None]:
# Mean pixel value of that state's RGB image.
torch.mean(ca.to_rgb(state)[0])

In [None]:
class VideoWriter:
    """Incrementally write frames to a video file via moviepy's FFMPEG_VideoWriter.

    The ffmpeg writer is created lazily on the first frame, so the frame size
    comes from the data. As a context manager the file is closed on exit. With
    the default "_autoplay.mp4" filename the video is also displayed inline.
    """

    def __init__(self, filename="_autoplay.mp4", fps=30.0, **kw):
        self.writer = None
        # Extra keyword arguments are forwarded to FFMPEG_VideoWriter.
        self.params = dict(filename=filename, fps=fps, **kw)

    def add(self, img):
        """Append one frame, either (H, W) greyscale or (H, W, 3) RGB.

        Floating-point frames of any precision are treated as [0, 1]
        intensities: they are clipped and converted to uint8. Greyscale frames
        are replicated to three channels.
        """
        img = np.asarray(img)
        if self.writer is None:
            h, w = img.shape[:2]
            self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)
        # Cover float16 and other float dtypes, not only float32/float64.
        if np.issubdtype(img.dtype, np.floating):
            img = np.uint8(img.clip(0, 1) * 255)
        if img.ndim == 2:
            img = np.repeat(img[..., None], 3, -1)
        self.writer.write_frame(img)

    def close(self):
        if self.writer:
            self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()
        if self.params["filename"] == "_autoplay.mp4":
            self.show()

    def show(self, **kw):
        """Close the file and display it inline (``display`` is provided by IPython/Jupyter)."""
        self.close()
        fn = self.params["filename"]
        display(mvp.ipython_display(fn, **kw))


class LoopWriter(VideoWriter):
    """VideoWriter whose output loops seamlessly.

    The first ``fade_len`` seconds of frames are held back. On close they are
    cross-faded with the last ``fade_len`` seconds, so the end blends into the
    start.
    """

    def __init__(self, *a, fade_len=1.0, **kw):
        super().__init__(*a, **kw)
        self._intro = []  # first fade_len frames, held back until close()
        self._outro = []  # most recent fade_len frames, not yet written
        self.fade_len = int(fade_len * self.params["fps"])  # in frames

    def add(self, img):
        if len(self._intro) < self.fade_len:
            self._intro.append(img)
            return
        self._outro.append(img)
        if len(self._outro) > self.fade_len:
            super().add(self._outro.pop(0))

    def close(self):
        if len(self._outro) < len(self._intro):
            # Fewer than 2 * fade_len frames were added, so there is nothing to
            # cross-fade against. Write the frames unchanged, in order.
            for img in self._intro + self._outro:
                super().add(img)
        else:
            for t in np.linspace(0, 1, len(self._intro)):
                start = np.asarray(self._intro.pop(0))
                end = np.asarray(self._outro.pop(0))
                img = start * t + end * (1.0 - t)
                if np.issubdtype(start.dtype, np.integer):
                    # Keep integer (e.g. uint8) frames in their own range.
                    # As 0..255 floats the base class would clip them to 1.
                    img = np.rint(img).astype(start.dtype)
                super().add(img)
        # Empty the buffers so a repeated close() (e.g. via show()) writes nothing twice.
        self._intro.clear()
        self._outro.clear()
        super().close()


In [27]:
step_n = 1  # CA updates between consecutive frames

with LoopWriter('final_ca.mp4', fade_len=0.0) as vid, torch.no_grad():
    # Render on a larger canvas than training. The channel count must match the
    # model, so take it from CHN rather than hard-coding 20.
    x = make_state(1, CHN, 512, 512)["state"]
    for _frame in tnrange(200, leave=False):
        img = ca.to_rgb(x)[0].detach().permute(1, 2, 0).cpu().numpy()
        # RGB -> greyscale with Rec. 601 luma weights.
        img = np.dot(img, [0.2989, 0.5870, 0.1140])
        vid.add(img)
        # Use throwaway loop names so the training loop's `i` is not clobbered.
        for _ in range(step_n):
            x = x + ca(x, update_rate=0.5)
vid.show()

  0%|          | 0/200 [00:00<?, ?it/s]