In [1]:
from pathlib import Path
from datetime import datetime
import random
import torch
from torch import nn
import torch.nn.functional as F
from stable_diffusion_sds import StableDiffusion
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import reduce
from PIL import Image
import numpy as np

  jax.tree_util.register_keypaths(


In [2]:
import os
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
from tqdm import tqdm_notebook, tnrange

In [3]:
#!pip install kornia
import kornia


In [4]:
device = "cuda"#"cpu"

In [5]:
def show(t):
    plt.imshow(t.permute(1,2,0).cpu().numpy())

In [6]:
sd_sds = StableDiffusion(device)

[INFO] loading stable diffusion...
[INFO] loaded stable diffusion!


In [7]:
# μNCA 
# from https://github.com/google-research/self-organising-systems/blob/master/notebooks/%CE%BCNCA_pytorch.ipynb

side = torch.tensor([[0.0, 0.0,0.0], [2.0,-2.0,0.0], [0.0, 0.0,0.0]]) 
sobel_x = torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]]) 
lap = torch.tensor([[1.0,2.0,1.0],[2.0,-12,2.0],[1.0,2.0,1.0]])

# filters = [lap]*2+[sobel_x]*1+[sobel_x.T]*1  # 68 params
# filters = [lap]*2+[sobel_x]*2+[sobel_x.T]*2 # 150 params
# filters = [lap]*4+[sobel_x]*2+[sobel_x.T]*2  # 264 params
filters = [lap]*4+[sobel_x]*4+[sobel_x.T]*4  # 588 params

filters = torch.stack(filters)[:,None]
filters = filters.to(device)
CHN = len(filters)

class CA(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(CHN, 4*CHN+1, 1, 1)*1e-3)

    def to_rgb(self, x):
        return x[...,:3,:,:]+0.5

    def forward(self, x, update_rate=1.0):
        y = torch.nn.functional.pad(x, [1, 1, 1, 1], 'circular')
        y = torch.nn.functional.conv2d(y, filters, groups=y.shape[1])
        y = torch.cat([x, y], 1)
        y = y = torch.cat([y, y.abs()], 1)
        #y = torch.cat([y.relu(), -(-y).relu()], 1)
        w, b = self.w[:,:-1], self.w[:,-1,0,0]
        y = torch.nn.functional.conv2d(y, w, b)
        if update_rate<1.0:
            y *= (torch.rand(*y.shape)+update_rate).floor()
        #return x+y, y
        return y
    
def seed_f(n, sz=128):
    return torch.rand(n, CHN, sz, sz)-0.5

ca = CA().to(device)
print('param count:', sum(p.numel() for p in ca.parameters()))

param count: 588


In [8]:
# regular nca, implementation from: https://github.com/PWhiddy/Growing-Neural-Cellular-Automata-Pytorch/blob/master/CA_Particles_V3/ca_particles/ca_model.py
# from 
class CAModel(nn.Module):
    
    def __init__(self, env_d, hidden_d, device):
        super(CAModel, self).__init__()
        self.env_d = env_d
        self.conv1 = nn.Conv2d(env_d*3, hidden_d, 1).to(device)
        self.conv2 = nn.Conv2d(hidden_d, env_d, 1).to(device)
        nn.init.zeros_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        return self.conv2(x)

    def to_rgb(self, x):
        return x[...,:3,:,:]+0.5

In [9]:
def np2pil(a):
    if a.dtype in [np.float32, np.float64]:
        a = np.uint8(np.clip(a, 0, 1)*255)
    return Image.fromarray(a)

In [10]:
#print(torch.cuda.memory_summary())

In [11]:
%%script false --no-raise-error
# dummy img training (random t)
init_img = torch.rand(1,3,512,512, device=device, requires_grad=True)
init_img.mean()
text_embeds = sd_sds.get_text_embeds(["raindrops on glass"]*batch_size, [""]*batch_size) 
opt_t = torch.optim.SGD([init_img], 0.2) #torch.optim.Adam([init_img], 0.1)
for i in tqdm(range(950)):
    # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
    override_t = None#torch.tensor([975-i], dtype=torch.long, device=device)
    loss, grad_mag = sd_sds.train_step(text_embeds, init_img, guidance_scale=50, override_t=override_t)
    loss = loss
    loss.backward()
    opt_t.step()
    opt_t.zero_grad()
result = init_img[0].detach()
show(result)

In [12]:
torch.clip(torch.tensor([0.4]), 0.5, 1)

tensor([0.5000])

In [None]:
# todo
# 1. increase model capacity
# 2. use pool to reinit from for long term stability
# 3. use simple ca model?
# 4. make memory usage less?
# 5. increase resolution

def make_state(batch, channels, x, y):
    #return {"state": torch.rand(batch_size, channels, x, y, device=device)-0.5, "steps": 0}
    s = torch.zeros(batch_size, channels, x, y, device=device)-0.5
    s[:, :, x//2, y//2] = 0.5
    return {"state": s, "steps": 0}

def norm_loss(cur_img, scale):
    #above = F.relu( cur_img - 1 ) ** 2
    #below = F.relu( -cur_img ) ** 2
    #return scale * (above.mean() + below.mean())
    return scale * ((cur_img - 0.5) ** 2).mean() # for testing for stable loss

torch.manual_seed(120)
ca =  CA().to(device) #  CAModel().to(device)
batch_size = 1
pool_size = 32
reset_prob = 0.05
m_state = lambda : make_state(batch_size, 12, 128, 128)
pool = [m_state() for _ in range(pool_size)]
# gorgeous rainforrest brush |  water ripples from above | "red hot lava" | ripples moving out from a drop of water in clear blue, viewed from above
text_embeds = sd_sds.get_text_embeds(["raindrops on glass"]*batch_size, [""]*batch_size) 
opt = torch.optim.Adam(ca.parameters(), 1e-3)
ca_steps = 20
iterations = 6000
norm_loss_scale = 10
sd_loss_scale = 0.001 #0.00005
run_dir = Path("output/{:%Y_%m_%d_%H_%M_%S}".format(datetime.now()))
run_dir.mkdir(exist_ok=True, parents=True)
#res = torch.zeros(3, 32*5,32*5)
for i in tnrange(iterations):
    
    track_grad_mag = 0
    track_norm_loss = 0
    loss_steps = 0
    for p in range(pool_size):
        if random.uniform(0, 1) < reset_prob or i < pool[p]["steps"]:
            pool[p] = m_state()
        else:
            pool[p]["state"] = pool[p]["state"].detach()
        state = pool[p]["state"]
        
        for step in range(ca_steps):
            state = state + ca(state)
            
        sd_loss, grad_mag = sd_sds.train_step(text_embeds, ca.to_rgb(state), guidance_scale=200, override_t=None)
        time_sd_loss_scale = min(1, 0.005*i) ** 2
        sd_loss = time_sd_loss_scale * sd_loss_scale * sd_loss

        # similar to MSE loss for going outside of range
        # https://www.desmos.com/calculator/74bzpyt0ho
        cur_img = ca.to_rgb(state)
        total_norm_loss = norm_loss(cur_img, norm_loss_scale)

        track_norm_loss += total_norm_loss.item()
        track_grad_mag += time_sd_loss_scale * sd_loss_scale * grad_mag.item()
        loss_steps += 1

        loss = sd_loss
        loss += total_norm_loss

        loss.backward()
        pool[p]["state"] = state
        pool[p]["steps"] += ca_steps
        
    print(f"i{i} step{step} grad_mag: {track_grad_mag / loss_steps:06f} norm loss: {track_norm_loss / loss_steps:06f} oldest state: {max([s['steps'] for s in pool])}")
    opt.step()
    opt.zero_grad()
    
    if i % 10 == 0:
        #small = reduce(result, "c (h a) (w b) -> c h w", "mean", a=16, b=16)
        out_dir = run_dir / "states" 
        out_dir.mkdir(exist_ok=True, parents=True)
        path = out_dir / f"state-{i:06d}.jpg"
        np2pil(ca.to_rgb(pool[0]["state"])[0].detach().permute(1,2,0).cpu().numpy()).save(path)


  0%|          | 0/5000 [00:00<?, ?it/s]

i0 step19 grad_mag: 0.000000 norm loss: 2.188672 oldest state: 20
i1 step19 grad_mag: 0.000273 norm loss: 0.390217 oldest state: 20
i2 step19 grad_mag: 0.001500 norm loss: 0.074955 oldest state: 20
i3 step19 grad_mag: 0.003390 norm loss: 0.564858 oldest state: 20
i4 step19 grad_mag: 0.005474 norm loss: 0.442759 oldest state: 20
i5 step19 grad_mag: 0.005056 norm loss: 0.159118 oldest state: 20
i6 step19 grad_mag: 0.008946 norm loss: 0.020747 oldest state: 20
i7 step19 grad_mag: 0.013269 norm loss: 0.007818 oldest state: 20
i8 step19 grad_mag: 0.017806 norm loss: 0.043868 oldest state: 20
i9 step19 grad_mag: 0.020567 norm loss: 0.083728 oldest state: 20
i10 step19 grad_mag: 0.038897 norm loss: 0.108526 oldest state: 20
i11 step19 grad_mag: 0.033834 norm loss: 0.114437 oldest state: 20
i12 step19 grad_mag: 0.039163 norm loss: 0.104395 oldest state: 20
i13 step19 grad_mag: 0.043790 norm loss: 0.084338 oldest state: 20
i14 step19 grad_mag: 0.070230 norm loss: 0.060393 oldest state: 20
i15 s

In [None]:
show(ca.to_rgb(state)[0].detach())

In [None]:
ca.to_rgb(state)[0].mean()

In [48]:
#del ca, state, loss, opt, new_loss

In [47]:
#del sd_sds

In [12]:
import gc
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            if torch.numel(obj) > 1024*1024*32: 
                print(type(obj), obj.size())
                #print(obj)
            
    except:
        pass

<class 'torch.nn.parameter.Parameter'> torch.Size([49408, 1024])
<class 'torch.Tensor'> torch.Size([256, 12, 128, 128])
<class 'torch.Tensor'> torch.Size([1, 128, 513, 513])
<class 'torch.Tensor'> torch.Size([1, 128, 513, 513])
<class 'torch.Tensor'> torch.Size([1, 128, 513, 513])




In [46]:
#torch.cuda.empty_cache()

In [21]:
torch.cuda.memory_allocated()

18846907392

In [16]:
def zoom(img, scale=4):
  img = np.repeat(img, scale, 0)
  img = np.repeat(img, scale, 1)
  return img

class VideoWriter:
  def __init__(self, filename='_autoplay.mp4', fps=30.0, **kw):
    self.writer = None
    self.params = dict(filename=filename, fps=fps, **kw)

  def add(self, img):
    img = np.asarray(img)
    if self.writer is None:
      h, w = img.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)
    if img.dtype in [np.float32, np.float64]:
      img = np.uint8(img.clip(0, 1)*255)
    if len(img.shape) == 2:
      img = np.repeat(img[..., None], 3, -1)
    self.writer.write_frame(img)

  def close(self):
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()
    if self.params['filename'] == '_autoplay.mp4':
      self.show()

  def show(self, **kw):
      self.close()
      fn = self.params['filename']
      display(mvp.ipython_display(fn, **kw))

class LoopWriter(VideoWriter):
  def __init__(self, *a, fade_len=1.0, **kw):
    super().__init__(*a, **kw)
    self._intro = []
    self._outro = []
    self.fade_len = int(fade_len*self.params['fps'])

  def add(self, img):
    if len(self._intro) < self.fade_len:
      self._intro.append(img)
      return
    self._outro.append(img)
    if len(self._outro) > self.fade_len:
      super().add(self._outro.pop(0))
  
  def close(self):
    for t in np.linspace(0, 1, len(self._intro)):
      img = self._intro.pop(0)*t + self._outro.pop(0)*(1.0-t)
      super().add(img)
    super().close()

In [17]:
step_n = 1
with torch.no_grad():
    with LoopWriter('final_ca.mp4', fade_len=0.0) as vid, torch.no_grad():
        x = m_state()["state"]
        #x[:] = 0.0
        #x[:,:,100, 100] = 1.0
        for k in tnrange(600, leave=False):
            img = ca.to_rgb(x)[0].detach().permute(1,2,0).cpu().numpy()
            #print(img.mean())
            vid.add(zoom(img, 2))
            #print(norm_loss(x, 1))
            #step_n = int(min(2**(k/30), 8))
            for i in range(step_n):
                x = x + ca(x)
    vid.show()

  0%|          | 0/600 [00:00<?, ?it/s]

In [43]:
img.mean()

14934.515

In [25]:
norm_loss(x, 1)

tensor(3.4056e+16, device='cuda:0')

In [43]:
del ca, state, loss, opt

NameError: name 'ca' is not defined

In [40]:
x = torch.randn(1,12,128,128, device=device)
for i in range(100):
    x = x + ca(x)