# Neural CA Grafting

This is a modified version of [Self-Organizing Textures NCA](https://distill.pub/selforg/2021/textures) that was created for the [Neural CA Grafting video tutorial](https://youtu.be/Tbe-41HowwY).

*Copyright 2021 Google LLC*

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
#@title Imports and Notebook Utilities
%tensorflow_version 2.x

import os
import io
import PIL.Image, PIL.ImageDraw
import base64
import zipfile
import json
import requests
import numpy as np
import matplotlib.pylab as pl
import glob

from IPython.display import Image, HTML, clear_output
from tqdm import tqdm_notebook, tnrange

os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter


def imread(url, max_size=None, mode=None):
  '''Load an image from a URL or a local path as a float32 numpy array.

  Args:
    url: string; an 'http:'/'https:' URL is downloaded with `requests`,
      anything else is passed to `PIL.Image.open` as is.
    max_size: if not None, the image is reduced with `thumbnail` so that it
      fits into a max_size x max_size box.
    mode: if not None, PIL mode string (e.g. 'RGB') the image is converted to.

  Returns:
    np.float32 array holding the pixel values divided by 255.

  Raises:
    requests.HTTPError: when the server answers with an error status.
  '''
  if url.startswith(('http:', 'https:')):
    # wikimedia requires a user agent
    headers = {
      "User-Agent": "Requests in Colab/0.0 (https://colab.research.google.com/; no-reply@google.com) requests/0.0"
    }
    r = requests.get(url, headers=headers)
    # Fail with a clear HTTP error instead of letting PIL try to decode an
    # error page as image data.
    r.raise_for_status()
    f = io.BytesIO(r.content)
  else:
    f = url
  img = PIL.Image.open(f)
  if max_size is not None:
    # `ANTIALIAS` was removed in Pillow 10; `LANCZOS` names the same filter
    # and exists in both old and new Pillow releases.
    img.thumbnail((max_size, max_size), PIL.Image.LANCZOS)
  if mode is not None:
    img = img.convert(mode)
  img = np.float32(img)/255.0
  return img

def np2pil(a):
  '''Convert a numpy image array to a `PIL.Image`.

  float32/float64 arrays are clipped to [0, 1] and rescaled to uint8 first;
  arrays of any other dtype are handed to PIL unchanged.
  '''
  is_float = a.dtype in [np.float32, np.float64]
  pixels = np.uint8(np.clip(a, 0, 1)*255) if is_float else a
  return PIL.Image.fromarray(pixels)

def imwrite(f, a, fmt=None):
  '''Encode image `a` and write it to `f`.

  Args:
    f: destination; either a file path (str) or a writable binary file object.
    a: image data, converted with `np.asarray` and then `np2pil`.
    fmt: PIL format name. Ignored when `f` is a path: the format is then
      derived from the file extension ('jpg' is mapped to 'jpeg').
  '''
  a = np.asarray(a)
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    # The file is opened here, so it is also closed here: the handle is
    # released (and the data flushed) as soon as the image is written.
    with open(f, 'wb') as out:
      np2pil(a).save(out, fmt, quality=95)
  else:
    # Caller-owned file object: write into it and leave it open.
    np2pil(a).save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  '''Return image `a` encoded as bytes.

  3-D arrays whose last axis has 4 entries are always encoded as PNG,
  regardless of `fmt`; everything else uses `fmt`.
  '''
  a = np.asarray(a)
  four_channel = a.ndim == 3 and a.shape[-1] == 4
  buf = io.BytesIO()
  imwrite(buf, a, 'png' if four_channel else fmt)
  return buf.getvalue()

def im2url(a, fmt='jpeg'):
  '''Encode image `a` as a base64 `data:` URL.

  Args:
    a: image data accepted by `imencode`.
    fmt: requested image format; overridden with 'png' for 4-channel input.

  Returns:
    String of the form 'data:image/<FMT>;base64,<payload>'.
  '''
  a = np.asarray(a)
  # `imencode` switches 3-D, 4-channel arrays to PNG. Apply the same rule
  # here so the media type in the URL matches the bytes that follow it.
  if a.ndim == 3 and a.shape[-1] == 4:
    fmt = 'png'
  encoded = imencode(a, fmt)
  base64_byte_string = base64.b64encode(encoded).decode('ascii')
  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string

def imshow(a, fmt='jpeg'):
  '''Display image `a` inline in the notebook, encoded via `imencode`.'''
  encoded = imencode(a, fmt)
  display(Image(data=encoded))

def tile2d(a, w=None):
  '''Arrange a stack of equally sized tiles into one 2D grid image.

  Args:
    a: array-like of shape [n, tile_h, tile_w, ...].
    w: number of tiles per grid row; defaults to ceil(sqrt(n)).

  Returns:
    Array of shape [rows*tile_h, w*tile_w, ...]. Tiles are laid out row by
    row; grid cells without a tile are filled with zeros.
  '''
  tiles = np.asarray(a)
  n = len(tiles)
  if w is None:
    w = int(np.ceil(np.sqrt(n)))
  tile_h, tile_w = tiles.shape[1:3]
  # Append blank tiles so the tile count becomes a multiple of the row width.
  missing = (w-n) % w
  pad_spec = [(0, missing)] + [(0, 0)]*(tiles.ndim-1)
  tiles = np.pad(tiles, pad_spec, mode='constant')
  rows = len(tiles)//w
  extra = list(tiles.shape[3:])
  grid = tiles.reshape([rows, w, tile_h, tile_w]+extra)
  # [rows, w, th, tw, ...] -> [rows, th, w, tw, ...], so the final reshape
  # merges (rows, th) into image rows and (w, tw) into image columns.
  grid = grid.swapaxes(1, 2)
  return grid.reshape([rows*tile_h, w*tile_w]+extra)

def zoom(img, scale=4):
  '''Enlarge an image by an integer factor using pixel repetition.

  Every element is repeated `scale` times along axis 0 and then along
  axis 1; any further axes are left untouched.
  '''
  for axis in (0, 1):
    img = np.repeat(img, scale, axis=axis)
  return img

class VideoWriter:
  '''Write frames to a video file through moviepy's `FFMPEG_VideoWriter`.

  The underlying writer is created lazily on the first `add` call and takes
  its frame size from that first frame. The class can be used as a context
  manager; with the default '_autoplay.mp4' filename, leaving the `with`
  block displays the finished video in the notebook.
  '''

  def __init__(self, filename='_autoplay.mp4', fps=30.0, **kw):
    '''Store the writer settings; extra keywords go to `FFMPEG_VideoWriter`.'''
    self.writer = None
    self.params = dict(filename=filename, fps=fps, **kw)

  def add(self, img):
    '''Append one frame.

    float32/float64 frames are clipped to [0, 1] and scaled to uint8;
    2-D frames are replicated into 3 channels.
    '''
    img = np.asarray(img)
    if self.writer is None:
      h, w = img.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)
    if img.dtype in [np.float32, np.float64]:
      img = np.uint8(img.clip(0, 1)*255)
    if len(img.shape) == 2:
      img = np.repeat(img[..., None], 3, -1)
    self.writer.write_frame(img)

  def close(self):
    '''Close the underlying writer, if one was ever created.'''
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()
    # Autoplay only when at least one frame was written. Without a frame
    # there is no fresh file: displaying it would either fail (hiding the
    # exception that aborted the `with` body) or replay a stale video left
    # over from an earlier run.
    if self.writer is not None and self.params['filename'] == '_autoplay.mp4':
      self.show(loop=True)

  def show(self, **kw):
    '''Finalize the file and display it inline.

    Keyword arguments are forwarded to `mvp.ipython_display`.
    '''
    self.close()
    fn = self.params['filename']
    display(mvp.ipython_display(fn, **kw))

!nvidia-smi -L

In [None]:
import torch
import torchvision.models as models

# Make CUDA float tensors the default, so tensors created without an explicit
# device land on the GPU. Guarded so that a runtime without a usable GPU keeps
# the CPU default and the notebook still runs (slowly) instead of failing.
# NOTE(review): `set_default_tensor_type` is deprecated in recent PyTorch
# releases; kept because the replacement is not available in older ones.
if torch.cuda.is_available():
  torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [None]:
#@title VGG16-based Style Model
# Pretrained VGG16 convolutional stack, used only as a fixed feature
# extractor for the style loss.
vgg16 = models.vgg16(pretrained=True).features
# Freeze the weights: the optimizer only owns the CA parameters, so gradients
# accumulated on the VGG weights during backward would never be used or
# zeroed. Gradients still flow through the network to its input.
vgg16.requires_grad_(False)

def calc_styles(imgs):
  '''Compute Gram matrices of VGG16 activations for a batch of images.

  `imgs` is normalized with the per-channel mean/std constants below (shaped
  [3, 1, 1], so a 3-channel, channels-first batch is expected) and pushed
  through the `vgg16` layers up to the last style layer. After every layer
  whose index is a style layer, the channel-by-channel Gram matrix of the
  activation, averaged over the h*w spatial positions, is collected.

  Returns:
    List of Gram tensors of shape [b, c, c], one per style layer, in layer
    order.
  '''
  style_layers = [1, 6, 11, 18, 25]
  mean = torch.tensor([0.485, 0.456, 0.406])[:,None,None]
  std = torch.tensor([0.229, 0.224, 0.225])[:,None,None]
  feats = (imgs-mean) / std
  grams = []
  last_layer = max(style_layers)
  for idx, layer in enumerate(vgg16[:last_layer+1]):
    feats = layer(feats)
    if idx not in style_layers:
      continue
    h, w = feats.shape[-2:]
    snapshot = feats.clone()  # workaround for pytorch in-place modification bug(?)
    grams.append(torch.einsum('bchw, bdhw -> bcd', snapshot, snapshot) / (h*w))
  return grams

def style_loss(grams_x, grams_y):
  '''Sum, over paired Gram matrices, of their mean squared difference.

  The two lists are paired with `zip`; for empty input the result is the
  float 0.0.
  '''
  per_layer = ((gx-gy).square().mean() for gx, gy in zip(grams_x, grams_y))
  return sum(per_layer, 0.0)

def to_nchw(img):
  '''Convert a channels-last image (or batch of images) to a channels-first tensor.

  A 3-D input first gets a leading batch axis of size 1; the 4-D result has
  its axes permuted as (0, 3, 1, 2).
  '''
  t = torch.as_tensor(img)
  if t.dim() == 3:
    t = t.unsqueeze(0)
  return t.permute(0, 3, 1, 2)

In [None]:
import os

#@title Minimalistic Neural CA
# 3x3 kernels that `perception` applies to every channel of the CA state.
# Identity: passes each cell's own value through.
ident = torch.tensor([[0.0,0.0,0.0],[0.0,1.0,0.0],[0.0,0.0,0.0]])
# Sobel kernel along x; `perception` uses its transpose for the y direction.
sobel_x = torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]])
# Laplacian-style stencil: neighbour weights (1 and 2) and the centre (-12) sum to zero.
lap = torch.tensor([[1.0,2.0,1.0],[2.0,-12,2.0],[1.0,2.0,1.0]])

def perchannel_conv(x, filters):
  '''Convolve every channel of x separately with every filter.

  x: [b, ch, h, w]
  filters: [filter_n, h, w]

  The input is padded by one pixel on each side with wrap-around (circular)
  padding, which keeps the spatial size for 3x3 filters. The result is
  reshaped to [b, -1, h, w], with the filter responses of each input channel
  stored next to each other.
  '''
  batch, chn, height, width = x.shape
  # Fold channels into the batch axis so that a single one-input-channel
  # convolution processes all of them independently.
  planes = x.reshape(batch*chn, 1, height, width)
  planes = torch.nn.functional.pad(planes, [1, 1, 1, 1], mode='circular')
  kernels = filters[:, None]
  responses = torch.nn.functional.conv2d(planes, kernels)
  return responses.reshape(batch, -1, height, width)

def perception(x):
  '''Per-channel responses of x to the identity, Sobel-x, Sobel-y and Laplacian kernels.'''
  kernels = [ident, sobel_x, sobel_x.T, lap]
  return perchannel_conv(x, torch.stack(kernels))

class CA(torch.nn.Module):
  '''Neural cellular automaton with a small per-cell update network.

  One step perceives the state with `perception` (4 filter responses per
  channel), maps it through 1x1 conv -> ReLU -> 1x1 conv, and adds the result
  to a random subset of the cells.
  '''

  def __init__(self, chn=12, hidden_n=96):
    super().__init__()
    self.chn = chn
    perception_n = chn*4
    self.w1 = torch.nn.Conv2d(perception_n, hidden_n, kernel_size=1)
    self.w2 = torch.nn.Conv2d(hidden_n, chn, kernel_size=1, bias=False)
    # The output layer starts at zero (and has no bias), so a freshly
    # constructed CA produces a zero update.
    self.w2.weight.data.zero_()

  def forward(self, x, update_rate=0.5):
    '''Apply one CA step to state x and return the new state.'''
    hidden = torch.relu(self.w1(perception(x)))
    delta = self.w2(hidden)
    b, _, h, w = delta.shape
    # floor(u + update_rate) with u drawn from [0, 1) yields a 0/1 mask per
    # cell, shared by all channels of that cell.
    mask = torch.floor(torch.rand(b, 1, h, w)+update_rate)
    return x+delta*mask

  def seed(self, n, sz=128):
    '''Return an all-zero initial state of shape [n, chn, sz, sz].'''
    return torch.zeros(n, self.chn, sz, sz)

def to_rgb(x):
  '''Map CA state to an image: the first three channels (axis -3), shifted by +0.5.'''
  rgb = x[..., :3, :, :]
  return rgb + 0.5

# Report the number of scalar parameters of a CA built with default settings.
param_n = sum(p.numel() for p in CA().parameters())
print('CA param count:', param_n)
# Save an untrained CA once; the training setup cell can load it as the
# 'init' parent to train from scratch. An existing file is left untouched.
if not os.path.exists('init.pt'):
  torch.save(CA(), 'init.pt')
  print('saved init.pt')

In [None]:
#@title targets {vertical-output: true}
# Candidate style images, keyed by the short names used to select a target.
style_urls = {
  'dots': 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/dotted/dotted_0090.jpg',
  'chess': 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/chequered/chequered_0121.jpg',
  'bubbles': 'https://www.robots.ox.ac.uk/~vgg/data/dtd/thumbs/bubbly/bubbly_0101.jpg',
}

# Download every image, reduced to fit within 128x128, and show them side by side.
# NOTE(review): np.hstack needs equal heights — presumably all thumbnails have
# the same size; confirm.
imgs = [imread(url, max_size=128) for url in style_urls.values()]
imshow(np.hstack(imgs))

In [None]:
#@title Target image {vertical-output: true}
# Key into style_urls selecting the texture to train towards.
target_name = 'bubbles'
# Force 3-channel RGB: calc_styles normalizes with 3-channel mean/std
# constants, so a grayscale or RGBA source image would otherwise not fit.
style_img = imread(style_urls[target_name], max_size=128, mode='RGB')
with torch.no_grad():
  # Gram matrices of the target image; computed once and used as the fixed
  # reference of the style loss.
  target_style = calc_styles(to_nchw(style_img))
imshow(style_img)

In [None]:
#@title Loading pretrained models
!wget -nc https://github.com/google-research/self-organising-systems/raw/master/assets/grafting_nca.zip && unzip grafting_nca.zip

In [None]:
#@title setup training
parent = 'dots' # replace this with 'init' to train from scratch
# Name under which the fine-tuned model is saved: '<parent>_<target>'.
model_name = parent+'_'+target_name
# NOTE(review): torch.load unpickles a complete CA object here, so the CA
# class must already be defined in this session, and unpickling can execute
# arbitrary code — only load checkpoints from a trusted source.
ca = torch.load(parent+'.pt')
opt = torch.optim.Adam(ca.parameters(), 1e-3)
# The learning rate is multiplied by 0.3 once, at scheduler step 2000.
lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, [2000], 0.3)
loss_log = []
with torch.no_grad():
  # Pool of 256 CA states, all starting from the zero seed.
  pool = ca.seed(256)

In [None]:
#@title training loop {vertical-output: true}
# Pool-based training: each iteration continues a few states from the pool,
# scores the resulting images against the target style and writes the evolved
# states back, so later iterations start from already developed patterns.
for i in range(4000):
  with torch.no_grad():
    # Draw 4 distinct pool entries for this batch.
    batch_idx = np.random.choice(len(pool), 4, replace=False)
    x = pool[batch_idx]
    if i%8 == 0:
      # Every 8th iteration restart the first sample from the seed state.
      x[:1] = ca.seed(1)
  # Run the CA for a random number of steps (32..95).
  step_n = np.random.randint(32, 96)
  for k in range(step_n):
    x = ca(x)
  imgs = to_rgb(x)
  styles = calc_styles(imgs)
  # Penalize state values that leave the [-1, 1] range.
  overflow_loss = (x-x.clamp(-1.0, 1.0)).abs().sum()
  loss = style_loss(styles, target_style)+overflow_loss
  with torch.no_grad():
    loss.backward()
    for p in ca.parameters():
      p.grad /= (p.grad.norm()+1e-8)   # normalize gradients 
    opt.step()
    opt.zero_grad()
    lr_sched.step()
    pool[batch_idx] = x                # update pool
    
    loss_log.append(loss.item())
    if i%32==0:
      # Refresh the loss plot and the preview, and checkpoint the model.
      clear_output(True)
      pl.plot(loss_log, '.', alpha=0.1)
      pl.yscale('log')
      pl.ylim(np.min(loss_log), loss_log[0])
      pl.show()
      imgs = to_rgb(x).permute([0, 2, 3, 1]).cpu()
      imshow(np.hstack(imgs))
      torch.save(ca, model_name+'.pt')
    if i%10 == 0:
      # get_last_lr() reports the rate set by the most recent scheduler
      # step; get_lr() is not meant to be called outside of step().
      print('\rstep_n:', len(loss_log),
        ' loss:', loss.item(), 
        ' lr:', lr_sched.get_last_lr()[0], end='')


In [None]:
#@title NCA video {vertical-output: true}
def show_ca(ca):
  '''Render a video of `ca` developing from the zero seed on a 256x256 grid.

  300 frames are written. The number of CA steps per frame starts at 1,
  doubles every 30 frames and is capped at 16. Frames are enlarged 2x.
  '''
  with VideoWriter() as vid, torch.no_grad():
    state = ca.seed(1, 256)
    for frame_i in tnrange(300, leave=False):
      steps_per_frame = min(2**(frame_i//30), 16)
      for _ in range(steps_per_frame):
        state[:] = ca(state)
      frame = to_rgb(state[0]).permute(1, 2, 0).cpu()
      vid.add(zoom(frame, 2))

# Preview two of the saved models on their own.
show_ca(torch.load('dots.pt'))
show_ca(torch.load('chess.pt'))

In [None]:
# Side length of the square grid used for the circular grafting mask.
W = 256
with torch.no_grad():
  # Squared coordinate along one axis, for coordinates in [-1, 1].
  r = torch.linspace(-1, 1, W)**2
  # Broadcast to a W x W grid: distance of every cell from the grid centre.
  r = (r+r[:,None]).sqrt()
  # Soft disc: close to 1 for r well below 0.6, falling towards 0 beyond it.
  mask = ((0.6-r)*8.0).sigmoid()
  pl.contourf(mask.cpu())
  pl.colorbar()
  pl.axis('equal')

In [None]:
# Graft two models onto one grid, using the W x W `mask` defined above in
# this notebook.
ca1 = torch.load('dots_chess.pt')
ca2 = torch.load('dots_bubbles.pt')
with VideoWriter() as vid, torch.no_grad():
  x = torch.zeros([1, ca1.chn, W, W])
  for i in tnrange(600):
    # 8 CA steps per video frame.
    for k in range(8):
      # Both models step the same state; the mask blends their results:
      # ca1's update where mask is 0, ca2's where it is 1.
      x1, x2 = ca1(x), ca2(x)
      x = x1 + (x2-x1)*mask
    img = to_rgb(x[0]).permute(1, 2, 0).cpu()
    vid.add(zoom(img, 2))

    

In [None]:
def show_pair(fn1, fn2, W=512, H=128):
  '''Render a video of two saved CAs blended horizontally on one grid.

  Args:
    fn1, fn2: paths of models saved with `torch.save`; loaded via `torch.load`.
    W: grid width in cells.
    H: grid height in cells.

  The blend mask varies along the width only: it is 0 at the left and right
  edges (first model) and rises to 1 in the middle (second model). 300 frames
  are written with 8 CA steps per frame, enlarged 2x.
  '''
  ca1 = torch.load(fn1)
  ca2 = torch.load(fn2)
  with VideoWriter() as vid, torch.no_grad():
    x = torch.zeros([1, ca1.chn, H, W])
    # 0.5 - 0.5*cos over one full period: 0 -> 1 -> 0 across the width.
    mask = 0.5-0.5*torch.linspace(0, 2.0*np.pi, W).cos()
    for i in tnrange(300, leave=False):
      for k in range(8):
        # Both models step the same state; the mask picks between them.
        x1, x2 = ca1(x), ca2(x)
        x = x1 + (x2-x1)*mask
      img = to_rgb(x[0]).permute(1, 2, 0).cpu()
      vid.add(zoom(img, 2))

# Compare pairing 'dots_bubbles' with 'chess' versus with 'dots_chess'.
show_pair('chess.pt', 'dots_bubbles.pt')
show_pair('dots_chess.pt', 'dots_bubbles.pt')

In [None]:
# Compare pairing 'dots' with 'chess' versus with 'dots_chess'.
show_pair('dots.pt', 'chess.pt')
show_pair('dots.pt', 'dots_chess.pt')

In [None]:
# Pair 'dots' with 'dots_bubbles'.
show_pair('dots.pt', 'dots_bubbles.pt')