# Adversarial Reprogramming of Growing Neural Cellular Automata

This notebook contains code to reproduce the Growing CA experiments and figures from the "Adversarial Reprogramming of Neural Cellular Automata" article.

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Common code for both experiments

In [None]:
#@title Imports and Notebook Utilities
%tensorflow_version 2.x

import os
import io
import PIL.Image, PIL.ImageDraw
import base64
import zipfile
import json
import requests
import numpy as np
import matplotlib.pylab as pl
import glob

import tensorflow as tf

from IPython.display import Image, HTML, clear_output
import tqdm

import os
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
clear_output()

def np2pil(a):
  """Convert a numpy image array into a PIL image.

  Float arrays (float32/float64) are treated as [0, 1] intensities: they are
  clipped to that range and scaled to uint8 first. Other dtypes are passed to
  PIL unchanged.
  """
  pixels = a
  if pixels.dtype in [np.float32, np.float64]:
    pixels = np.uint8(np.clip(pixels, 0, 1) * 255)
  return PIL.Image.fromarray(pixels)

def imwrite(f, a, fmt=None):
  """Write image array `a` to `f`.

  Args:
    f: a filesystem path or a writable binary file-like object. For a path,
      the format is taken from the file extension ('jpg' is mapped to 'jpeg')
      and overrides `fmt`; the file is opened and closed by this function.
    a: image array (see `np2pil` for dtype handling).
    fmt: PIL format name, used only when `f` is a file-like object.
  """
  a = np.asarray(a)
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    # Close the file we opened ourselves so the handle is not leaked.
    with open(f, 'wb') as fp:
      np2pil(a).save(fp, fmt, quality=95)
  else:
    np2pil(a).save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  """Encode image array `a` in memory and return the encoded bytes.

  A 3-D array with 4 channels (RGBA) is always encoded as PNG, regardless of
  `fmt`, so that the alpha channel is preserved.
  """
  arr = np.asarray(a)
  has_alpha = len(arr.shape) == 3 and arr.shape[-1] == 4
  out = io.BytesIO()
  imwrite(out, arr, 'png' if has_alpha else fmt)
  return out.getvalue()

def im2url(a, fmt='jpeg'):
  """Return image array `a` as a base64 ``data:image/...`` URL.

  `imencode` always encodes 3-D 4-channel (RGBA) arrays as PNG. The same
  override is applied here so the URL's declared image type matches the
  actual payload.
  """
  a = np.asarray(a)
  if len(a.shape) == 3 and a.shape[-1] == 4:
    fmt = 'png'  # mirror imencode's RGBA -> PNG switch
  encoded = imencode(a, fmt)
  base64_byte_string = base64.b64encode(encoded).decode('ascii')
  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string

def imshow(a, fmt='jpeg'):
  """Render image array `a` inline in the notebook output, encoded as `fmt`."""
  payload = imencode(a, fmt)
  display(Image(data=payload))

def tile2d(a, w=None):
  """Lay out a batch of images `a` (N, H, W, ...) as one row-major 2-D mosaic.

  `w` is the number of tiles per row; by default it is ceil(sqrt(N)). The
  batch is zero-padded up to a multiple of `w` images, so missing cells in
  the last row are black.
  """
  a = np.asarray(a)
  n = len(a)
  cols = int(np.ceil(np.sqrt(n))) if w is None else w
  tile_h, tile_w = a.shape[1:3]
  missing = (cols - n) % cols
  a = np.pad(a, [(0, missing)] + [(0, 0)] * (a.ndim - 1), 'constant')
  rows = len(a) // cols
  grid = a.reshape([rows, cols] + list(a.shape[1:]))
  trailing = list(grid.shape[4:])
  # (rows, cols, H, W, ...) -> (rows, H, cols, W, ...) so rows of pixels interleave.
  grid = np.rollaxis(grid, 2, 1)
  return grid.reshape([tile_h * rows, tile_w * cols] + trailing)

def zoom(img, scale=4):
  """Nearest-neighbour upscale: repeat every pixel `scale` times along both image axes."""
  for axis in (0, 1):
    img = np.repeat(img, scale, axis)
  return img

class VideoWriter:
  """Context-managed video writer that takes numpy frames one at a time.

  The underlying FFMPEG_VideoWriter is created lazily on the first `add`,
  once the frame size is known. Extra keyword arguments are forwarded to it.
  """

  def __init__(self, filename, fps=30.0, **kw):
    self.writer = None
    self.params = dict(filename=filename, fps=fps, **kw)

  def add(self, img):
    """Append one frame.

    Float frames are clipped to [0, 1] and converted to uint8; 2-D
    (grayscale) frames are replicated into 3 channels.
    """
    frame = np.asarray(img)
    if self.writer is None:
      height, width = frame.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(width, height), **self.params)
    if frame.dtype in [np.float32, np.float64]:
      frame = np.uint8(frame.clip(0, 1) * 255)
    if frame.ndim == 2:
      frame = np.repeat(frame[:, :, np.newaxis], 3, axis=-1)
    self.writer.write_frame(frame)

  def close(self):
    """Finalize the video; does nothing if no frame was ever added."""
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()

In [None]:
#@title Cellular Automata Parameters
CHANNEL_N = 16        # Number of CA state channels
TARGET_SIZE = 40      # Max width/height (px) that target images are thumbnailed to
BATCH_SIZE = 8        # Samples drawn from the pool per training step
POOL_SIZE = 1024      # Number of states kept in the training SamplePool
CELL_FIRE_RATE = 0.5  # Per-cell probability of applying an update in each CA step

# Experiment type only changes DAMAGE_N: "Regenerating" damages samples during training.
EXPERIMENT_TYPE = "Regenerating" #@param ["Persistent", "Regenerating"]
EXPERIMENT_MAP = {"Persistent":0, "Regenerating":1}
EXPERIMENT_N = EXPERIMENT_MAP[EXPERIMENT_TYPE]

DAMAGE_N = [0, 3][EXPERIMENT_N]  # Number of patterns to damage in a batch

In [None]:
#@title CA Model and Utilities
#@markdown This model doesn't have a bias in the last layer.
from tensorflow.keras.layers import Conv2D

def load_image(url, max_size=TARGET_SIZE):
  """Download an image and return it as a float32 RGBA array with premultiplied alpha.

  The image is converted to RGBA, so RGB and palette images also work. It is
  then shrunk in place, preserving aspect ratio, to fit in
  `max_size` x `max_size`. Values are scaled to [0, 1] and RGB is multiplied
  by alpha.

  Raises:
    requests.HTTPError: if the server answers with an HTTP error status.
  """
  r = requests.get(url)
  r.raise_for_status()
  img = PIL.Image.open(io.BytesIO(r.content)).convert('RGBA')
  # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
  img.thumbnail((max_size, max_size), PIL.Image.LANCZOS)
  img = np.float32(img)/255.0
  # premultiply RGB by Alpha
  img[..., :3] *= img[..., 3:]
  return img

def to_rgba(x):
  """Return the first four channels (RGBA) of a CA state, along the last axis."""
  return x[..., 0:4]

def to_alpha(x):
  """Return channel 3 (alpha) clipped to [0, 1], keeping a trailing size-1 axis."""
  alpha = x[..., 3:4]
  return tf.clip_by_value(alpha, clip_value_min=0.0, clip_value_max=1.0)

def to_rgb(x):
  """Composite a premultiplied-alpha state over a white background.

  Returns `1 - alpha + rgb`. Alpha is clipped to [0, 1]; rgb is not clipped.
  """
  alpha = to_alpha(x)
  return 1.0 - alpha + x[..., :3]

def get_living_mask(x):
  """Boolean mask of cells whose 3x3 neighbourhood has some alpha above 0.1.

  Expects a 4-D (batch, height, width, channels) tensor; channel 3 is alpha.
  """
  neighbourhood_max = tf.nn.max_pool2d(
      x[:, :, :, 3:4], ksize=3, strides=[1, 1, 1, 1], padding='SAME')
  return neighbourhood_max > 0.1

def make_seed(size, n=1):
  """Create `n` zero states of size x size x CHANNEL_N, each with one live centre cell.

  The centre cell has every channel from index 3 (alpha) onward set to 1.
  """
  seed = np.zeros((n, size, size, CHANNEL_N), dtype=np.float32)
  mid = size // 2
  seed[:, mid, mid, 3:] = 1.0
  return seed


class CAModel(tf.keras.Model):
  """Neural cellular automaton: a learned update rule applied to every cell.

  Calling the model performs one CA step on a batch of states with
  `channel_n` channels; channel 3 is treated as alpha (see get_living_mask).
  NOTE(review): the cell's markdown says the last layer has no bias, but the
  last Conv2D below does not override Keras' default `use_bias` — confirm.
  """

  def __init__(self, channel_n=CHANNEL_N, fire_rate=CELL_FIRE_RATE):
    super().__init__()
    self.channel_n = channel_n
    self.fire_rate = fire_rate

    # Per-cell update network: two 1x1 convolutions, i.e. the same small MLP
    # applied at every pixel. The output kernel is zero-initialized.
    self.dmodel = tf.keras.Sequential([
          Conv2D(128, 1, activation=tf.nn.relu),
          Conv2D(self.channel_n, 1, activation=None,
              kernel_initializer=tf.zeros_initializer),
    ])

    self(tf.zeros([1, 3, 3, channel_n]))  # dummy call to build the model

  @tf.function
  def perceive(self, x, angle=0.0):
    """Return each channel's value plus its Sobel x/y gradients rotated by `angle`.

    Uses a depthwise conv with 3 kernels per channel, so the output has
    3 * channel_n channels.
    """
    identify = np.float32([0, 1, 0])
    identify = np.outer(identify, identify)  # 3x3 kernel that picks the centre cell
    dx = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0  # Sobel filter
    dy = dx.T
    c, s = tf.cos(angle), tf.sin(angle)
    # Rotate the gradient kernels by `angle` (radians).
    kernel = tf.stack([identify, c*dx-s*dy, s*dx+c*dy], -1)[:, :, None, :]
    kernel = tf.repeat(kernel, self.channel_n, 2)
    y = tf.nn.depthwise_conv2d(x, kernel, [1, 1, 1, 1], 'SAME')
    return y

  @tf.function
  def call(self, x, fire_rate=None, angle=0.0, step_size=1.0):
    """Apply one stochastic CA update to the 4-D state batch `x`; return the new state.

    Args:
      x: batch of CA states (batch, height, width, channel_n).
      fire_rate: probability that a cell is updated this step; defaults to
        `self.fire_rate`.
      angle: rotation (radians) applied to the perception gradients.
      step_size: multiplier applied to the predicted update.
    """
    pre_life_mask = get_living_mask(x)

    y = self.perceive(x, angle)
    dx = self.dmodel(y)*step_size
    if fire_rate is None:
      fire_rate = self.fire_rate
    # Each cell independently applies its update with probability fire_rate.
    update_mask = tf.random.uniform(tf.shape(x[:, :, :, :1])) <= fire_rate
    x += dx * tf.cast(update_mask, tf.float32)

    # Keep only cells that are alive (per get_living_mask) both before and after the update.
    post_life_mask = get_living_mask(x)
    life_mask = pre_life_mask & post_life_mask
    return x * tf.cast(life_mask, tf.float32)


# Print the layer summary of the per-cell update network (builds a throwaway instance).
CAModel().dmodel.summary()

In [None]:
#@title Train Utilities (SamplePool, Model Export, Damage)
from google.protobuf.json_format import MessageToDict
from tensorflow.python.framework import convert_to_constants

class SamplePool:
  """A pool of training samples stored as parallel numpy arrays ("slots").

  Each keyword argument becomes an attribute holding a numpy array; all slots
  must have the same length. `sample` returns a sub-pool that remembers which
  parent rows it came from, and `commit` writes the sub-pool's (possibly
  modified) rows back into the parent.
  """

  def __init__(self, *, _parent=None, _parent_idx=None, **slots):
    self._parent = _parent
    self._parent_idx = _parent_idx
    self._slot_names = slots.keys()
    self._size = None
    for name, values in slots.items():
      if self._size is None:
        self._size = len(values)
      assert self._size == len(values)
      setattr(self, name, np.asarray(values))

  def sample(self, n):
    """Return a child SamplePool of `n` distinct rows chosen at random."""
    chosen = np.random.choice(self._size, n, False)
    subset = {}
    for name in self._slot_names:
      subset[name] = getattr(self, name)[chosen]
    return SamplePool(_parent=self, _parent_idx=chosen, **subset)

  def commit(self):
    """Copy this sample's slot arrays back into the parent rows they came from."""
    parent, rows = self._parent, self._parent_idx
    for name in self._slot_names:
      getattr(parent, name)[rows] = getattr(self, name)

@tf.function
def make_circle_masks(n, h, w):
  """Return `n` float masks of shape (h, w), each 1 inside one random disc.

  The image is mapped to [-1, 1] along both axes. Each disc's centre is drawn
  from [-0.5, 0.5)^2 and its radius from [0.1, 0.4).
  """
  xs = tf.linspace(-1.0, 1.0, w)[None, None, :]
  ys = tf.linspace(-1.0, 1.0, h)[None, :, None]
  centre = tf.random.uniform([2, n, 1, 1], -0.5, 0.5)
  radius = tf.random.uniform([n, 1, 1], 0.1, 0.4)
  du = (xs - centre[0]) / radius
  dv = (ys - centre[1]) / radius
  inside = du * du + dv * dv < 1.0
  return tf.cast(inside, tf.float32)

def export_model(ca, base_fn):
  """Save `ca`'s weights to `base_fn` and a frozen TF.js graph model to `base_fn + '.json'`.

  The exported graph is `ca.call` traced for states of shape
  [None, None, None, CHANNEL_N], with fire_rate/angle/step_size as scalar inputs.
  """
  ca.save_weights(base_fn)

  traced = ca.call.get_concrete_function(
      x=tf.TensorSpec([None, None, None, CHANNEL_N]),
      fire_rate=tf.constant(0.5),
      angle=tf.constant(0.0),
      step_size=tf.constant(1.0))
  frozen = convert_to_constants.convert_variables_to_constants_v2(traced)
  graph_json = MessageToDict(frozen.graph.as_graph_def())
  graph_json['versions'] = dict(producer='1.14', minConsumer='1.14')
  model_json = dict(
      format='graph-model',
      modelTopology=graph_json,
      weightsManifest=[],
  )
  with open(base_fn + '.json', 'w') as out:
    json.dump(model_json, out)

def generate_pool_figures(pool, step_i):
  """Save a mosaic of the first 49 pool states to train_log/<step>_pool.jpg.

  The outer 72 pixels on each side of the mosaic are blended linearly toward
  white (fully white at the very edge).
  """
  tiled_pool = tile2d(to_rgb(pool.x[:49]))
  fade = np.linspace(1.0, 0.0, 72)
  ones = np.ones(72) 
  # Each line computes x += (1 - x) * fade: left, right, top, then bottom border.
  tiled_pool[:, :72] += (-tiled_pool[:, :72] + ones[None, :, None]) * fade[None, :, None] 
  tiled_pool[:, -72:] += (-tiled_pool[:, -72:] + ones[None, :, None]) * fade[None, ::-1, None]
  tiled_pool[:72, :] += (-tiled_pool[:72, :] + ones[:, None, None]) * fade[:, None, None]
  tiled_pool[-72:, :] += (-tiled_pool[-72:, :] + ones[:, None, None]) * fade[::-1, None, None]
  imwrite('train_log/%04d_pool.jpg'%step_i, tiled_pool)

def visualize_batch(x0, x, step_i):
  """Save and display a batch before (top row) and after (bottom row) a training step."""
  before = np.hstack(to_rgb(x0).numpy())
  after = np.hstack(to_rgb(x).numpy())
  grid = np.vstack([before, after])
  imwrite('train_log/batches_%04d.jpg' % step_i, grid)
  print('batch (before/after):')
  imshow(grid)

def plot_loss(loss_log):
  """Scatter-plot the log10 of every recorded training loss."""
  log_loss = np.log10(loss_log)
  pl.figure(figsize=(10, 4))
  pl.title('Loss history (log10)')
  pl.plot(log_loss, '.', alpha=0.1)
  pl.show()


In [None]:
# Download and unpack the pretrained Growing-CA models (presumably into models/,
# the default `prefix` of get_model — confirm against the archive layout).
!wget -O models.zip 'https://github.com/google-research/self-organising-systems/blob/master/assets/growing_ca/models.zip?raw=true'
!unzip -oq models.zip

# available pretrained emoji are:
#EMOJI = '🦎😀💥👁🐠🦋🐞🕸🥨🎄'

def get_model(emoji='ü¶é', fire_rate=0.5, use_pool=1, damage_n=3, run=0,
              prefix='models/', output='model'):
  path = prefix
  assert fire_rate in [0.5, 1.0]
  if fire_rate==0.5:
    path += 'use_sample_pool_%d damage_n_%d '%(use_pool, damage_n)
  elif fire_rate==1.0:
    path += 'fire_rate_1.0 '
  code = hex(ord(emoji))[2:].upper()
  path += 'target_emoji_%s run_index_%d/08000'%(code, run)
  assert output in ['model', 'json']
  if output == 'model':
    ca = CAModel(channel_n=16, fire_rate=fire_rate)
    ca.load_weights(path)
    return ca
  elif output == 'json':
    return open(path+'.json', 'r').read()

def get_local_model(path, output='model'):
  """Load a CA checkpoint from `path`.

  Args:
    path: checkpoint prefix, as written by `export_model`.
    output: 'model' returns a 16-channel CAModel with the weights loaded;
      'json' returns the text of `path + '.json'`.
  """
  assert output in ['model', 'json']
  if output == 'model':
    ca = CAModel(channel_n=16)
    ca.load_weights(path)
    return ca
  # output == 'json': close the file instead of leaking the handle.
  with open(path + '.json', 'r') as f:
    return f.read()


## CA Targets

In [None]:
# Download the target images and unpack them into target_pics/.
!wget -O growing_ca_target_images.zip 'https://github.com/google-research/self-organising-systems/blob/master/adversarial_reprogramming_ca/assets/growing_ca_target_images.zip?raw=true'
!unzip -oq "growing_ca_target_images.zip" -d "target_pics"
!ls target_pics/

In [None]:
# If you want a new picture, modify it with either:
# - https://pixlr.com/
# - https://www.piskelapp.com/
def load_image_from_file(fp, max_size=TARGET_SIZE):
  """Load an image file and return it as a float32 RGBA array with premultiplied alpha.

  The image is converted to RGBA, so RGB and palette images also work. It is
  then shrunk in place, preserving aspect ratio, to fit in
  `max_size` x `max_size`. Values are scaled to [0, 1] and RGB is multiplied
  by alpha.
  """
  img = PIL.Image.open(fp).convert('RGBA')
  # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10).
  img.thumbnail((max_size, max_size), PIL.Image.LANCZOS)
  img = np.float32(img)/255.0
  # premultiply RGB by Alpha
  img[..., :3] *= img[..., 3:]
  return img

pic_prefix = "target_pics/"
lizard_no_tail_fp = pic_prefix + "lizard_no_tail.png"
lizard_no_leg_fp = pic_prefix + "lizard_no_leg.png"
lizard_no_head_fp = pic_prefix + "lizard_no_head.png"
lizard_no_arm_fp = pic_prefix + "lizard_no_arm.png"
lizard_red_fp = pic_prefix + "lizard_red.png"
lizard_blue_fp = pic_prefix + "lizard_blue.png"
# Label -> target image path. The `target_label` form options use both the
# "liz_" and the "lizard_" prefixes, so both spellings are accepted for the
# no_tail / no_leg / no_head targets.
target_map = {
    "liz_no_tail": lizard_no_tail_fp,
    "liz_no_leg": lizard_no_leg_fp,
    "liz_no_head": lizard_no_head_fp,
    "lizard_no_tail": lizard_no_tail_fp,
    "lizard_no_leg": lizard_no_leg_fp,
    "lizard_no_head": lizard_no_head_fp,
    "lizard_no_arm": lizard_no_arm_fp,
    "lizard_red": lizard_red_fp,
    "lizard_blue": lizard_blue_fp}

lizard_complete_fp = pic_prefix + "lizard_complete.png"

#butterfly_all_orange_fp = pic_prefix + "butterfly_all_orange.png"
#butterfly_left_orange_fp = pic_prefix + "butterfly_left_orange.png"


# Experiment: Virus takeover

We take a pretrained model (the lizard) and insert a small number of cells running a different, adversarial CA rule. This adversarial CA has to reprogram the entire collective of cells so that they grow a different lizard.

In [None]:
# Pick the target the virus should make the lizard grow; the label must be a key of target_map.
target_label = 'liz_no_tail' #@param ["liz_no_tail", "liz_no_leg", "liz_no_head", "lizard_no_arm", "lizard_red", "lizard_blue"]

target_img = load_image_from_file(target_map[target_label])
imshow(zoom(to_rgb(target_img), 2), fmt='png')  # preview at 2x

In [None]:
def generate_padded_xy(p, target_img, ca):
  """Pad the target by `p` pixels per side and grow a matching CA state with `ca`.

  Returns (x0_seed, pad_target): the 1-sample state reached after 200 steps of
  `ca`, starting from a single live centre cell, and the padded target image.
  """
  pad_target = tf.pad(target_img, [(p, p), (p, p), (0, 0)])
  height, width = pad_target.shape[:2]
  grid = np.zeros([height, width, CHANNEL_N], np.float32)
  grid[height // 2, width // 2, 3:] = 1.0
  state = grid[None, ...]
  for _ in tf.range(200):
    state = ca(state)
  return state, pad_target

In [None]:
#@title Initialize Training { vertical-output: true}

#@markdown Recommended: 0.1 percentage_virus for no_tail, and 0.6 for red lizard.
percentage_virus = 0.1 #@param [0.1, 0.6] {allow-input: true}
# Probability that each cell inside the target mask becomes a virus cell.

# Number of pixels used to pad the target image border.
# Use 32 for the no_tail experiment
TARGET_PADDING = 32
# YOU CAN use 16 for the red experiment
# TARGET_PADDING = 16

def loss_f(x):
  """Per-sample mean squared error between the RGBA channels of `x` and `pad_target`.

  Reads the notebook-global `pad_target`.
  """
  diff = to_rgba(x) - pad_target
  return tf.reduce_mean(tf.square(diff), [-2, -3, -1])

target_emoji = 'ü¶é' # 'ü¶ã' 
ca = get_model(emoji=target_emoji)
x0_seed, pad_target = generate_padded_xy(TARGET_PADDING, target_img, ca)
h, w = pad_target.shape[:2]

print("This is the initial configuration")
imshow(zoom(to_rgb(x0_seed[0]), 2), fmt='png')


# We create a mask: you can't place viruses outside this.
print("Mask visualization.")
pad_target_mask = tf.cast(pad_target[...,3:4] >= 0.1, tf.float32)  # 1 where target alpha >= 0.1
pl.imshow(pad_target_mask[:,:,0])
pl.show()

@tf.function
def random_mask_not_out_target(bsize, perc, pad_target_mask):
  """Sample `bsize` random virus masks restricted to the target region.

  Each cell is selected with probability `perc` and then zeroed wherever
  `pad_target_mask` is 0. Returns float32 masks of shape [bsize, h, w, 1].
  """
  h, w = pad_target_mask.shape[0:2]
  draws = tf.random.uniform([bsize, h, w, 1])
  selected = tf.cast(draws < perc, tf.float32)
  return selected * pad_target_mask

print("Example target with random mask.")
random_mask = random_mask_not_out_target(1, percentage_virus, pad_target_mask)
imshow(zoom(to_rgb(x0_seed[0] * random_mask[0]), 2), fmt='png')

# Single example mask; the training loop below uses per-sample masks stored in the pool.
adv_mask = random_mask
inv_adv_mask = 1. - adv_mask

loss_log = []

# Adam with a piecewise-constant schedule: lr up to step 2000, lr*0.1 afterwards.
lr = 2e-3
lr_sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [2000], [lr, lr*0.1])
trainer = tf.keras.optimizers.Adam(lr_sched)

loss0 = loss_f(x0_seed).numpy()  # loss of the grown, un-attacked pattern

# Every pool entry starts as the grown original pattern, each with its own random virus mask.
initial_pool_elems = np.repeat(x0_seed, POOL_SIZE, 0)
initial_random_mask = random_mask_not_out_target(
    POOL_SIZE, percentage_virus, pad_target_mask).numpy()
pool = SamplePool(x=initial_pool_elems, m=initial_random_mask)

# The spooky virus
adversarial_ca = CAModel()

!mkdir -p train_log && rm -f train_log/*

In [None]:
#@title Training Loop {vertical-output: true}

@tf.function
def train_step(x, m):
  """Roll out 64-95 steps of the mixed CA and take one optimizer step on `adversarial_ca`.

  Cells where `m` is 1 take `adversarial_ca`'s update; all others take the
  original `ca`'s update. Only `adversarial_ca`'s weights are updated.
  Returns the final state and the mean loss against `pad_target`.
  """
  iter_n = tf.random.uniform([], 64, 96, tf.int32)
  with tf.GradientTape() as g:
    for i in tf.range(iter_n):
      # Both CAs see the whole grid; the mask chooses whose update each cell keeps.
      x_orig = ca(x)
      x_adv = adversarial_ca(x)
      x = x_orig * (1. - m) + x_adv * m
    loss = tf.reduce_mean(loss_f(x))
  grads = g.gradient(loss, adversarial_ca.weights)
  # Normalize each gradient tensor to unit L2 norm before the Adam step.
  grads = [g/(tf.norm(g)+1e-8) for g in grads]
  trainer.apply_gradients(zip(grads, adversarial_ca.weights))
  return x, loss

for i in range(100000+1):
  batch = pool.sample(BATCH_SIZE)
  x0 = batch.x
  m0 = batch.m
  # Sort the batch by descending loss, then replace the worst sample with the
  # grown seed pattern and a fresh random virus mask.
  loss_rank = loss_f(x0).numpy().argsort()[::-1]
  x0 = x0[loss_rank]
  m0 = m0[loss_rank]
  x0[:1] = x0_seed[0]
  m0[:1] = random_mask_not_out_target(
      1, percentage_virus, pad_target_mask)[0].numpy()

  if DAMAGE_N:
    # Cut random circular holes into the DAMAGE_N lowest-loss samples.
    damage = 1.0-make_circle_masks(DAMAGE_N, h, w).numpy()[..., None]
    x0[-DAMAGE_N:] *= damage

  x, loss = train_step(x0, m0)

  # Write states and masks back to the pool (both in sorted order, so they stay paired).
  batch.x[:] = x
  batch.m[:] = m0
  batch.commit()

  step_i = len(loss_log)
  loss_log.append(loss.numpy())
  
  # Logging: pool mosaic every 10 steps, batch/loss plots every 100, model export every 1000.
  if step_i%10 == 0:
    generate_pool_figures(pool, step_i)
  if step_i%100 == 0:
    clear_output()
    visualize_batch(x0, x, step_i)
    plot_loss(loss_log)
  if step_i%1000 == 0:
    export_model(adversarial_ca, 'train_log/%06d'%step_i)

  print('\r step: %d, log10(loss): %.3f'%(len(loss_log), np.log10(loss)), end='')

In [None]:
# useful code if you end up interrupting the run.
# Exports the current adversarial CA under the last step index the loop reached.
print(step_i)
export_model(adversarial_ca, 'train_log/%06d'%step_i)

In [None]:
# @title Load saved models from Github
# Download the pretrained adversarial CAs into saved_adversarial/ and load them.
!wget -O growing_ca_adversarial_models.zip 'https://github.com/google-research/self-organising-systems/blob/master/adversarial_reprogramming_ca/assets/growing_ca_adversarial_models.zip?raw=true'
!unzip -oq "growing_ca_adversarial_models.zip" -d "saved_adversarial"

no_tail_ca = get_local_model("saved_adversarial/no_tail_model", output='model')
red_10p_ca = get_local_model("saved_adversarial/red_model", output='model')
red_60p_ca = get_local_model("saved_adversarial/red_model_60perc", output='model')

In [None]:
#@title TensorFlow.js Demo {run:"auto", vertical-output: true}
#@markdown Select "CHECKPOINT" model to load the checkpoint created by running cells from the "Training" section of this notebook

#@markdown Shift-click to seed the pattern
import IPython.display

# These are the original models' parameters. Modify them manually
# if you want to see different behaviors.
# Only the hex code after the space is used; the emoji in front is a label.
model = "\U0001F98E 1F98E"  #['😀 1F600', '💥 1F4A5', '👁 1F441', '🦎 1F98E', '🐠 1F420', '🦋 1F98B', '🐞 1F41E', '🕸 1F578', '🥨 1F968', '🎄 1F384']
model_type = '3 regenerating'  #['1 naive', '2 persistent', '3 regenerating']

code = model.split(' ')[1]
emoji = chr(int(code, 16))
experiment_i = int(model_type.split()[0])-1
use_pool = (0, 1, 1)[experiment_i]
damage_n = (0, 0, 3)[experiment_i]
model_str = get_model(emoji, use_pool=use_pool, damage_n=damage_n, output='json')

adversarial_model_source = "no_tail_model"  #@param ['CHECKPOINT', 'no_tail_model', 'red_model', 'red_model_60perc' ]

if adversarial_model_source != "CHECKPOINT":
  adv_model_str = get_local_model("saved_adversarial/" + adversarial_model_source,output="json")
else:
  checkpoints = sorted(glob.glob('train_log/*.json'))
  if not checkpoints:
    raise FileNotFoundError(
        'No exported checkpoint in train_log/; run the training cells first.')
  last_checkpoint_fn = checkpoints[-1]
  with open(last_checkpoint_fn) as f:
    adv_model_str = f.read()


# Hand both graph JSONs to the page as blob URLs for tf.loadGraphModel.
data_js = '''
  window.GRAPH_URL = URL.createObjectURL(new Blob([`%s`], {type: 'application/json'}));
  window.ADV_GRAPH_URL = URL.createObjectURL(new Blob([`%s`], {type: 'application/json'}));
'''%(model_str, adv_model_str)

display(IPython.display.Javascript(data_js))


# Interactive demo: every frame, both CAs run on the whole grid, and `adv_mask`
# selects which cells take the adversary's update (virus cells are drawn with a
# pulsing red tint). Click/drag damages the pattern, and shift-click plants a
# seed. With "Draw adversary" checked, a click turns that cell into a virus cell.
# Fixes over the original script: plantSeed now rejects a negative y (it
# checked y2 twice), drag-damage uses the same click-position mapping as
# mouse-down, and adversary clicks outside the grid are ignored instead of
# producing negative padding.
IPython.display.HTML('''
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.3.0/dist/tf.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/cash/4.1.2/cash.min.js"></script>

<canvas id='canvas' style="border: 1px solid black; image-rendering: pixelated;"></canvas>
<div><button type="button" id="removeadv">Remove adversaries</button></div>

<div class="boxcontainer">
<input type="checkbox" id="drawadversary" name="drawadversary">
<label for="drawadversary">Draw adversary</label><br>
</div>
<div class="boxcontainer">
<input type="checkbox" id="showadvmask" name="showadvmask">
<label for="showadvmask">Show adversary mask</label><br>
</div>

<script>
  "use strict";
  
  const sleep = (ms)=>new Promise(resolve => setTimeout(resolve, ms));
  
  const parseConsts = model_graph=>{
    const dtypes = {'DT_INT32':['int32', 'intVal', Int32Array],
                    'DT_FLOAT':['float32', 'floatVal', Float32Array]};
    
    const consts = {};
    model_graph.modelTopology.node.filter(n=>n.op=='Const').forEach((node=>{
      const v = node.attr.value.tensor;
      const [dtype, field, arrayType] = dtypes[v.dtype];
      if (!v.tensorShape.dim) {
        consts[node.name] = [tf.scalar(v[field][0], dtype)];
      } else {
        // if there is a 0-length dimension, the exported graph json lacks "size"
        const shape = v.tensorShape.dim.map(d=>(!d.size) ? 0 :parseInt(d.size));
        let arr;
        if (v.tensorContent) {
          const data = atob(v.tensorContent);
          const buf = new Uint8Array(data.length);
          for (var i=0; i<data.length; ++i) {
            buf[i] = data.charCodeAt(i);
          }
          arr = new arrayType(buf.buffer);
        } else {
          const size = shape.reduce((a, b)=>a*b);
          arr = new arrayType(size);
          if (size){
            arr.fill(v[field][0]);
          }
        }
        consts[node.name] = [tf.tensor(arr, shape, dtype)];
      }
    }));
    return consts;
  }
  
  let drawadversaryCkbx = document.getElementById("drawadversary");
  let showadvmaskCkbx = document.getElementById("showadvmask");

  const run = async ()=>{
    const r = await fetch(GRAPH_URL);
    const consts = parseConsts(await r.json());
    
    const model = await tf.loadGraphModel(GRAPH_URL);
    Object.assign(model.weights, consts);
    
    // Adversarial model now.
    const rad = await fetch(ADV_GRAPH_URL);
    const adv_consts = parseConsts(await rad.json());
    const adv_model = await tf.loadGraphModel(ADV_GRAPH_URL);

    console.log("Loaded adv model")
    Object.assign(adv_model.weights, adv_consts);
    
    let seed = new Array(16).fill(0).map((x, i)=>i<3?0:1);
    seed = tf.tensor(seed, [1, 1, 1, 16]);
    
    const D = 96;
    const initState = tf.tidy(()=>{
      const D2 = D/2;
      const a = seed.pad([[0, 0], [D2-1, D2], [D2-1, D2], [0,0]]);
      return a;
    });
    
    const state = tf.variable(initState);
    // this is where we keep track of where is which CA.
    const adv_mask = tf.variable(tf.zeros([1, D, D, 1]));
    // store this to avoid recomputations.
    const orig_mask = tf.variable(tf.ones([1, D, D, 1]));
    const [_, h, w, ch] = state.shape;
    
    const damage = (x, y, r)=>{
      tf.tidy(()=>{
        const rx = tf.range(0, w).sub(x).div(r).square().expandDims(0);
        const ry = tf.range(0, h).sub(y).div(r).square().expandDims(1);
        const mask = rx.add(ry).greater(1.0).expandDims(2);
        state.assign(state.mul(mask));
      });
    }
    
    const plantSeed = (x, y)=>{
      const x2 = w-x-seed.shape[2];
      const y2 = h-y-seed.shape[1];
      if (x<0 || x2<0 || y<0 || y2<0)
        return;
      tf.tidy(()=>{
        const a = seed.pad([[0, 0], [y, y2], [x, x2], [0,0]]);
        state.assign(state.add(a));
      });
    }

    $('#removeadv').on('click', e=>{
        tf.tidy(()=>{
          adv_mask.assign(tf.zeros([1, D, D, 1]));
          orig_mask.assign(tf.ones([1, D, D, 1]));
        });
    });

    
    const scale = 4;
    
    const canvas = document.getElementById('canvas');
    const ctx = canvas.getContext('2d');
    canvas.width = w;
    canvas.height = h;
    canvas.style.width = `${w*scale}px`;
    canvas.style.height = `${h*scale}px`;

    const getClickPos = e=>{
        const x = Math.floor((e.pageX-e.target.offsetLeft) / scale);
        const y = Math.floor((e.pageY-e.target.offsetTop) / scale);
        return [x, y];
    }

    canvas.onmousedown = e=>{
      const [x, y] = getClickPos(e);
      //const x = Math.floor(e.clientX/scale);
      //const y = Math.floor(e.clientY/scale);
      if (drawadversaryCkbx.checked) {
        if (x<0 || x>=w || y<0 || y>=h)
          return;
        // perform surgical insertions!
        tf.tidy(()=> {
          const insertion = tf.ones([1, 1]).pad(
            [[y, h-y-1], [x, w-x-1]]).expandDims(0).expandDims(3);
          const mask = tf.tensor(1.).sub(insertion);
          adv_mask.assign(adv_mask.mul(mask).add(insertion));
          orig_mask.assign(adv_mask.sub(1).mul(-1));
        });
        return;
      }
      if (e.buttons == 1) {
        if (e.shiftKey) {
          plantSeed(x, y);  
        } else {
          damage(x, y, 8);
        }
      }
    }
    canvas.onmousemove = e=>{
      const [x, y] = getClickPos(e);
      if (e.buttons == 1 && !e.shiftKey && !drawadversaryCkbx.checked) {
        damage(x, y, 8);
      }
    }

    function step() {
      tf.tidy(()=>{
        const orig_state = model.execute(
            {x:state, fire_rate:tf.tensor(0.5),
            angle:tf.tensor(0.0), step_size:tf.tensor(1.0)}, 
            ['Identity']);
        const adv_state = adv_model.execute(
            {x:state, fire_rate:tf.tensor(0.5),
            angle:tf.tensor(0.0), step_size:tf.tensor(1.0)}, 
            ['Identity']);

        state.assign(orig_state.mul(orig_mask).add(adv_state.mul(adv_mask)));
      });
    }

    const initT = new Date().getTime() / 1000;
    function render() {
      step();

      const imageData = tf.tidy(()=>{
        const rgba = state.slice([0, 0, 0, 0], [-1, -1, -1, 4]);
        const a = state.slice([0, 0, 0, 3], [-1, -1, -1, 1]);
        let img = tf.tensor(1.0).sub(a).add(rgba).mul(255);

        const advred = adv_mask.gather([0], 3) // R
                  .pad([[0, 0], [0, 0], [0, 0], [0, 2]], 0) // GB
                  .pad([[0, 0], [0, 0], [0, 0], [0, 1]], 1) // A
                  .mul(255);
        const seconds = new Date().getTime() / 1000 - initT;
        const t = tf.tensor(seconds).sin().abs();
        const onemt = tf.tensor(1.).sub(t);
        const adv_mask_period = adv_mask.mul(t);
        const img_behind_mask = img.mul(adv_mask);
        const adv_image = advred.mul(adv_mask_period).add(
          img_behind_mask.mul(onemt));
        //const orig_mask_period = orig_mask.mul(tf.tensor(seconds).sin().abs());
        //const adv_mask_period = tf.tensor(1.0).sub(orig_mask_period)
        img = img.mul(orig_mask).add(adv_image);

        const rgbaBytes = new Uint8ClampedArray(img.dataSync());
        return new ImageData(rgbaBytes, w, h);
      });
      ctx.putImageData(imageData, 0, 0);

      requestAnimationFrame(render);
    }
    render();
  }
  run();
  
</script>
''')

## Figures

In [None]:
#@title viz targets
# Show the complete lizard target, then the no-tail and red targets side by side (2x zoom).
liz_compl = load_image_from_file(lizard_complete_fp)
imshow(zoom(to_rgb(liz_compl), 2), fmt='png')

liz_notail = load_image_from_file(lizard_no_tail_fp)
liz_red = load_image_from_file(lizard_red_fp)
imshow(zoom(to_rgb(np.hstack([liz_notail, liz_red])), 2), fmt='png')

In [None]:
#@title viz current run
import PIL.ImageFont

from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

# Roll out the model trained above. `ca` is the original CA the virus was
# trained against; `orig_ca` is bound here because the cells that otherwise
# define it come later in the notebook.
orig_ca = ca

x = x0_seed
exp_percentage_virus = 0.1
# Virus cells may only be placed where the padded training target is opaque.
new_mask = random_mask_not_out_target(
    1, exp_percentage_virus, pad_target_mask).numpy()
inv_new_mask = 1. - new_mask

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(10001):
    if i%20 == 0:
      # On steps 500 and 1000, blank out the virus cells to show the original cells alone.
      if i == 500 or i == 1000:
        x_curr = x[0] * inv_new_mask[0]
      else:
        x_curr = x[0]
      vis = zoom(to_rgb(x_curr), 4).clip(0, 1)

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)

      draw = PIL.ImageDraw.Draw(im)

      # Overlay the step counter and the virus percentage.
      steptext = "Step: {}".format(i)
      perturbtext = "Perc virus: {}%".format(int(exp_percentage_virus * 100))
      perturbcolor = (255,0,0)
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)
      vid.add(np.uint8(im))
      if i == 500 or i == 1000:
        # add many frames to effectively pause the video.
        for _ in range(50):
          vid.add(np.uint8(im))
    x_orig = orig_ca(x)
    x_adv = adversarial_ca(x)
    x = x_orig * inv_new_mask + x_adv * new_mask


mvp.ipython_display(fn, loop=True)

In [None]:
#@title viz no tail saved run
# Replay the saved no-tail virus (no_tail_ca) on 6 runs with 24px padding and
# 10% virus cells, render a 3-wide mosaic video, and tint virus cells with a pulsing red.
import PIL.ImageFont
from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

orig_ca = get_model(target_emoji)

no_tail_target_img = load_image_from_file(lizard_no_tail_fp)
p = 24
exp_x0_seed, no_tail_pad_target = generate_padded_xy(p, no_tail_target_img, orig_ca)

# Virus cells may only be placed where the padded target's alpha is >= 0.1.
no_tail_pad_target_mask = tf.cast(no_tail_pad_target[...,3:4] >= 0.1, tf.float32)

exp_percentage_virus = 0.1
num_runs = 6
x =  np.repeat(exp_x0_seed, num_runs, 0)
wtile = 3
new_mask = random_mask_not_out_target(num_runs, exp_percentage_virus, 
                                      no_tail_pad_target_mask).numpy()
inv_new_mask = 1. - new_mask

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(10001):
    if i%20 == 0:
      # (disabled code, kept as a bare string literal)
      """
      # We want to show something different on step 500:
      if i == 500 or i == 1000:
        x_curr = x[0] * inv_new_mask[0]
      else:
        x_curr = x[0]
      """
      x_curr = x
      rgb = to_rgb(x_curr)
      # Make it blinking with adversaries!
      # t oscillates in [0, 1]; virus cells blend between pure red and their own colour.
      adv_rgb_color = tf.constant([[[[1.,0., 0.]]]])
      t = 0.5 + tf.sin(i/20 / 10) / 2.
      adv_mask_period = new_mask * t
      img_behind_mask = rgb * new_mask
      adv_rgb = adv_rgb_color * adv_mask_period + img_behind_mask * (1 - t)

      rgb = rgb * inv_new_mask + adv_rgb

      if num_runs == 1:
        rgb = rgb[0]
      else:
        rgb = tile2d(rgb, w=wtile)



      vis = zoom(rgb, 4).clip(0, 1)
      #vis = np.concatenate((vis, np.ones((164, vis.shape[1], 3))), axis=0) 

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)
      
      draw = PIL.ImageDraw.Draw(im)

      # compute text:
      steptext = "Step: {}".format(i)
      perturbtext = "Perc virus: {}%".format(int(exp_percentage_virus * 100))
      perturbcolor = (255,0,0)# if perturb_step(i) else (0,255,0)
      # draw.text((x, y),"Sample Text",(r,g,b))
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)
      vid.add(np.uint8(im))
      # (disabled pause-frames code, kept as a bare string literal)
      """
      if i == 500 or i == 1000:
        # add many frames to effectively pause the video.
        for _ in range(50):
          vid.add(np.uint8(im))
      """
    # Mixed step: virus cells follow no_tail_ca, all other cells follow orig_ca.
    x_orig = orig_ca(x)
    x_adv = no_tail_ca(x)
    x = x_orig * inv_new_mask + x_adv * new_mask


mvp.ipython_display(fn, loop=True)

In [None]:
#@title viz red 10% saved run
# Replay the saved 10% red-lizard virus (red_10p_ca) on 6 runs with 24px
# padding; every 5th of 1001 steps is rendered into a mosaic video.
import PIL.ImageFont
from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

orig_ca = get_model(target_emoji)
adv_ca = red_10p_ca

red_target_img = load_image_from_file(lizard_red_fp)
p = 24
exp_x0_seed, red_pad_target = generate_padded_xy(p, red_target_img, orig_ca)

# Virus cells may only be placed where the padded target's alpha is >= 0.1.
red_pad_target_mask = tf.cast(red_pad_target[...,3:4] >= 0.1, tf.float32)

exp_percentage_virus = 0.1
num_runs = 6
x =  np.repeat(exp_x0_seed, num_runs, 0)
wtile = 3
new_mask = random_mask_not_out_target(num_runs, exp_percentage_virus, 
                                      red_pad_target_mask).numpy()
inv_new_mask = 1. - new_mask

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(1001):
    if i%5 == 0:
      
      # On steps 500 and 1000, blank out the virus cells to show the original cells alone.
      if i == 500 or i == 1000:
        x_curr = x * inv_new_mask
      else:
        x_curr = x
      #x_curr = x

      rgb = to_rgb(x_curr)
      # Blinking red virus overlay (disabled; kept as a bare string literal):
      """
      adv_rgb_color = tf.constant([[[[1.,0., 0.]]]])
      t = 0.5 + tf.sin(i/20 / 10) / 2.
      adv_mask_period = new_mask * t
      img_behind_mask = rgb * new_mask
      adv_rgb = adv_rgb_color * adv_mask_period + img_behind_mask * (1 - t)

      rgb = rgb * inv_new_mask + adv_rgb
      """

      if num_runs == 1:
        rgb = rgb[0]
      else:
        # Default grid width is ceil(sqrt(num_runs)), which equals wtile for 6 runs.
        rgb = tile2d(rgb)#, w=wtile)



      vis = zoom(rgb, 4).clip(0, 1)
      #vis = np.concatenate((vis, np.ones((164, vis.shape[1], 3))), axis=0) 

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)
      
      draw = PIL.ImageDraw.Draw(im)

      # compute text:
      steptext = "Step: {}".format(i)
      perturbtext = "Perc virus: {}%".format(int(exp_percentage_virus * 100))
      perturbcolor = (255,0,0)# if perturb_step(i) else (0,255,0)
      # draw.text((x, y),"Sample Text",(r,g,b))
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)
      vid.add(np.uint8(im))
      
      if i == 500 or i == 1000:
        # add many frames to effectively pause the video.
        for _ in range(50):
          vid.add(np.uint8(im))
      
    # Mixed step: virus cells follow adv_ca, all other cells follow orig_ca.
    x_orig = orig_ca(x)
    x_adv = adv_ca(x)
    x = x_orig * inv_new_mask + x_adv * new_mask


mvp.ipython_display(fn, loop=True)

In [None]:
#@title viz red 60% saved run
import PIL.ImageFont
from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

orig_ca = get_model(target_emoji)
# Replay the *saved* 60% red-lizard virus (loaded by "Load saved models from
# Github"), consistent with the other "saved run" cells. To view the model
# trained in this notebook instead, use `red_ca = adversarial_ca`.
red_ca = red_60p_ca

red_target_img = load_image_from_file(lizard_red_fp)
p = 24
exp_x0_seed, red_pad_target = generate_padded_xy(p, red_target_img, orig_ca)

# Virus cells may only be placed where the padded target's alpha is >= 0.1.
red_pad_target_mask = tf.cast(red_pad_target[...,3:4] >= 0.1, tf.float32)

exp_percentage_virus = 0.60
num_runs = 1
x =  np.repeat(exp_x0_seed, num_runs, 0)
wtile = 3  # grid width, only used when num_runs > 1
new_mask = random_mask_not_out_target(num_runs, exp_percentage_virus,
                                      red_pad_target_mask).numpy()
inv_new_mask = 1. - new_mask

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(4001):
    # Record every 5th step for the first 300 steps, then every 20th.
    if (i<300 and i%5 == 0) or i%20==0:
      # On step 500, blank out the virus cells to show the original cells alone.
      if i == 500:
        x_curr = x * inv_new_mask
      else:
        x_curr = x

      rgb = to_rgb(x_curr)

      if num_runs == 1:
        rgb = rgb[0]
      else:
        rgb = tile2d(rgb, w=wtile)

      vis = zoom(rgb, 4).clip(0, 1)

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)

      draw = PIL.ImageDraw.Draw(im)

      # Overlay the step counter and the virus percentage.
      steptext = "Step: {}".format(i)
      perturbtext = "Perc virus: {}%".format(int(exp_percentage_virus * 100))
      perturbcolor = (255,0,0)
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)
      vid.add(np.uint8(im))

      if i == 500:
        # add many frames to effectively pause the video.
        for _ in range(80):
          vid.add(np.uint8(im))

    # Mixed step: virus cells follow red_ca, all other cells follow orig_ca.
    x_orig = orig_ca(x)
    x_adv = red_ca(x)
    x = x_orig * inv_new_mask + x_adv * new_mask


mvp.ipython_display(fn, loop=True)

In [None]:
#@title Vanilla run
# Grow the unperturbed pretrained model from a single seed cell for 200 steps
# and write every step to a video.
# The seed is kept in its own variable so that the shared global `seed`
# (a single [h, w, CHANNEL_N] state that later cells batch as
# `seed[None, ...]`) is not overwritten with a 4D array when this cell is
# re-run out of order. `h` and `w` come from whichever cell defined them last.

orig_ca = get_model(target_emoji)

vanilla_seed = np.zeros([h, w, CHANNEL_N], np.float32)
# One living cell in the centre: channels 3 onwards (alpha + hidden) set to 1.
vanilla_seed[h//2, w//2, 3:] = 1.0
x = vanilla_seed[None, ...]  # add a batch dimension of size 1
fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(201):
    rgb = to_rgb(x)[0]  # drop the batch dimension

    vis = zoom(rgb, 4).clip(0, 1)
    im = PIL.Image.fromarray(np.uint8(vis*255))
    vid.add(np.uint8(im))

    x = orig_ca(x)


mvp.ipython_display(fn, loop=True)

# Experiment: State perturbations

We take a pretrained model (the lizard) and train a state mutation that makes it grow with missing limbs or different colors.

The mutation is applied at every CA step, to the state of every living cell.

In [None]:
target_label = 'lizard_blue' #@param ["liz_no_tail", "liz_no_leg", "liz_no_head", "lizard_no_arm", "lizard_red", "lizard_blue"]

# Load the target image for the selected perturbation and preview it at 2x.
target_img = load_image_from_file(target_map[target_label])
preview = zoom(to_rgb(target_img), 2)
imshow(preview, fmt='png')

In [None]:
# Preview every target perturbation side by side (tiled, 4x zoom).
perts_fp_list = [
    lizard_no_tail_fp, lizard_no_leg_fp, lizard_no_head_fp,
    lizard_no_arm_fp, lizard_red_fp, lizard_blue_fp,
]
perts_img_list = list(
    map(lambda fp: to_rgb(load_image_from_file(fp)), perts_fp_list))

imshow(zoom(tile2d(perts_img_list), 4))

In [None]:
#@title Initialize Training { vertical-output: true}

p = TARGET_PADDING
pad_target = tf.pad(target_img, [(p, p), (p, p), (0, 0)])
h, w = pad_target.shape[:2]
seed = np.zeros([h, w, CHANNEL_N], np.float32)
seed[h//2, w//2, 3:] = 1.0

def loss_f(x):
  return tf.reduce_mean(tf.square(to_rgba(x)-pad_target), [-2, -3, -1])

target_emoji = 'ü¶é' # 'ü¶ã'
ca = get_model(emoji=target_emoji)

# Generate the final state to modify.
x0_seed = seed[None, ...]
for i in tf.range(200):
  x0_seed = ca(x0_seed)
imshow(zoom(to_rgb(x0_seed[0]), 2), fmt='png')

loss_log = []

lr = 2e-3
lr_sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [2000], [lr, lr*0.1])
trainer = tf.keras.optimizers.Adam(lr_sched)

loss0 = loss_f(seed).numpy()

initial_pool_elems = np.concatenate([
    np.repeat(seed[None, ...], POOL_SIZE // 2, 0),
    np.repeat(x0_seed, POOL_SIZE // 2, 0)], 0)
pool = SamplePool(x=initial_pool_elems)

# New mutation tensor. it's a 2d matrix we'll treat as symmetric: [numch, numch]
mutation_tensor = tf.Variable(tf.eye(CHANNEL_N))

!mkdir -p train_log && rm -f train_log/*

In [None]:
#@title Training Loop {vertical-output: true}

@tf.function
def train_step(x):
  """Run the mutated CA for a random number of steps and update the mutation.

  Only `mutation_tensor` receives gradients; the pretrained `ca` is left
  unchanged. Returns the final batch state and the mean loss.
  """
  # Rollout length is sampled uniformly from [64, 96).
  iter_n = tf.random.uniform([], 64, 96, tf.int32)
  with tf.GradientTape() as g:
    # Symmetrize: upper triangle (incl. diagonal) + its transpose, minus the
    # diagonal, which would otherwise be counted twice.
    m_upper = tf.linalg.band_part(mutation_tensor, 0, -1)
    mutation_symm_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(
        tf.linalg.diag_part(mutation_tensor))
    for i in tf.range(iter_n):
      # Apply mutation.
      # make sure you don't touch stuff outside.
      # (the living mask is taken before mutating, and the mutated state is
      # multiplied by it, so cells outside that mask are zeroed)
      lm = tf.cast(get_living_mask(x), tf.float32)
      x = x @ mutation_symm_t
      # saturate mutation past abs(3)
      x = tf.clip_by_value(x, -3., 3.)
      x *= lm
      x = ca(x)
    loss = tf.reduce_mean(loss_f(x))
  grads = g.gradient(loss, [mutation_tensor])
  # Normalize each gradient to unit L2 norm (epsilon avoids division by zero).
  grads = [g/(tf.norm(g)+1e-8) for g in grads]
  trainer.apply_gradients(zip(grads, [mutation_tensor]))
  return x, loss

for i in range(8000+1):
  batch = pool.sample(BATCH_SIZE)
  x0 = batch.x
  # Sort the batch by loss, highest first; the worst sample is replaced by the
  # single-cell seed and the next one by the grown pretrained state.
  loss_rank = loss_f(x0).numpy().argsort()[::-1]
  x0 = x0[loss_rank]
  x0[:1] = seed
  x0[1:2] = x0_seed[0]
  if DAMAGE_N:
    # Cut circular holes into the last DAMAGE_N samples.
    damage = 1.0-make_circle_masks(DAMAGE_N, h, w).numpy()[..., None]
    x0[-DAMAGE_N:] *= damage

  x, loss = train_step(x0)

  # Write the results back into the pool.
  batch.x[:] = x
  batch.commit()

  step_i = len(loss_log)
  loss_log.append(loss.numpy())
  
  if step_i%10 == 0:
    generate_pool_figures(pool, step_i)
  if step_i%100 == 0:
    clear_output()
    visualize_batch(x0, x, step_i)
    plot_loss(loss_log)
    # NOTE(review): this exports the base `ca`, which this loop never
    # updates; the trained parameter is `mutation_tensor`.
    export_model(ca, 'train_log/%04d'%step_i)

  print('\r step: %d, log10(loss): %.3f'%(len(loss_log), np.log10(loss)), end='')

In [None]:
# Save the symmetric matrix, in case you want to use combinations of them.
# Showing an example: uncomment the line for the perturbation you just trained
# (the red one is active here). Each expression mirrors the upper triangle of
# `mutation_tensor` into a symmetric matrix, as done in `train_step`.
m_upper = tf.linalg.band_part(mutation_tensor, 0, -1)
#tail_sym_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(tf.linalg.diag_part(mutation_tensor))
#leg_sym_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(tf.linalg.diag_part(mutation_tensor))
#head_sym_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(tf.linalg.diag_part(mutation_tensor))
#arm_sym_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(tf.linalg.diag_part(mutation_tensor))
red_sym_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(tf.linalg.diag_part(mutation_tensor))
#blue_sym_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(tf.linalg.diag_part(mutation_tensor))


In [None]:
#@title visualize training run
# EXP selects which form of the learned mutation is applied:
#   "default": the full symmetric matrix, as used during training;
#   "diag":    only its diagonal (per-channel scaling, no channel mixing);
#   "eigvec":  its eigen-decomposition with each eigenvalue's deviation from
#              1 halved (a half-strength version of the mutation).
EXP = "default"
import PIL.ImageFont

from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

x = seed[None, ...]
# Symmetric mutation exactly as built in `train_step`.
m_upper = tf.linalg.band_part(mutation_tensor, 0, -1)
full_symm_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(
    tf.linalg.diag_part(mutation_tensor))
if EXP == "default":
  mutation_symm_t = full_symm_t
elif EXP == "diag":
  # Experiment: only use the diagonal!
  mutation_symm_t = tf.linalg.tensor_diag(
      tf.linalg.diag_part(mutation_tensor))
elif EXP == "eigvec":
  eigval, eigvec = np.linalg.eigh(full_symm_t)
  eigdelta = eigval - tf.ones([CHANNEL_N])
  eigval = tf.ones([CHANNEL_N]) + eigdelta * 0.5
  mutation_symm_t = eigvec @ tf.linalg.tensor_diag(eigval) @ np.transpose(eigvec) 
else:
  # Without this, an unknown EXP would silently reuse whatever
  # `mutation_symm_t` another cell left behind.
  raise ValueError("Unknown EXP: {!r}".format(EXP))

# Mutation ON for the first 500 steps, OFF for steps 500-999, ON again after.
perturb_step = lambda i: i < 500 or i >= 1000

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(2500):
    # Record every step for the first 200 steps, then every 5th.
    if i<200 or i%5 == 0:
      vis = zoom(to_rgb(x[0]), 4).clip(0, 1)
      #vis = np.concatenate((vis, np.ones((164, vis.shape[1], 3))), axis=0) 

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)
      
      draw = PIL.ImageDraw.Draw(im)

      # compute text:
      steptext = "Step: {}".format(i)
      perturbtext = "Perturbation: {}".format("ON" if perturb_step(i) else "OFF")
      perturbcolor = (255,0,0) if perturb_step(i) else (0,255,0)
      # draw.text((x, y),"Sample Text",(r,g,b))
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)
      vid.add(np.uint8(im))
      if i == 500 or i == 1000:
        # add many frames to effectively pause the video.
        for _ in range(50):
          vid.add(np.uint8(im))
    if perturb_step(i):
      # Same masked, clipped mutation as in training.
      lm = tf.cast(get_living_mask(x), tf.float32)
      x = x @ mutation_symm_t
      x = tf.clip_by_value(x, -3., 3.)
      x *= lm
    x = ca(x)

mvp.ipython_display(fn, loop=True)

In [None]:
# Create a mosaic of direction mutations!
import PIL.ImageFont

from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

# Same schedule as the "visualize training run" cell. It is defined here so
# this cell does not depend on a `perturb_step` left over from whichever cell
# ran last. (Within the 500 steps below it is always ON.)
perturb_step = lambda i: i < 500 or i >= 1000

# Symmetric mutation learned during training (upper triangle mirrored).
# It is computed once because it does not depend on the direction `d`.
m_upper = tf.linalg.band_part(mutation_tensor, 0, -1)
trained_symm_t = m_upper + tf.transpose(m_upper) - tf.linalg.tensor_diag(
    tf.linalg.diag_part(mutation_tensor))

# Generate 9 mutation matrices along the line through the identity (d=0, a
# no-op) and the trained mutation (d=1): d*M + (1-d)*I.
all_distortions = [1.0, 0.5, 0.3, 0.1, 0.0, -0.1, -0.3, -0.5, -1.0]
all_mutation_symm_t = [
    d * trained_symm_t + (1. - d) * tf.eye(CHANNEL_N) for d in all_distortions]
mutation_symm_t = np.stack(all_mutation_symm_t)  # one matrix per run
print(mutation_symm_t.shape)

# One run per direction, all starting from the single-cell seed.
x = np.repeat(seed[None, ...], 9, 0)

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(500):
    if i<200 or i%5 == 0:
      vis = zoom(tile2d(to_rgb(x), 3), 4).clip(0, 1)
      #vis = np.concatenate((vis, np.ones((164, vis.shape[1], 3))), axis=0) 

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)
      
      draw = PIL.ImageDraw.Draw(im)

      # compute text:
      steptext = "Step: {}".format(i)
      perturbtext = "Perturbation: {}".format("ON" if perturb_step(i) else "OFF")
      perturbcolor = (255,0,0) if perturb_step(i) else (0,255,0)
      # draw.text((x, y),"Sample Text",(r,g,b))
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)

      # Add mutation direction text near the bottom of each 3x3 tile.
      for idx, d in enumerate(all_distortions):
        disttext = "direction: {}".format(d)
        if d == 1.0:
          disttext += " (train config)"
        if d == 0.0:
          disttext += " (NOOP)"
        x_unit, y_unit = im.width // 3, im.height // 3
        x_displacement = x_unit // 5
        y_displacement = int(y_unit * 0.9)
        placement = (x_unit * (idx % 3) + x_displacement,
                     y_unit * (idx // 3) + y_displacement)
        draw.text(placement,disttext,(0,0,0),font=font)


      vid.add(np.uint8(im))
      # Only reached if the step count above is raised past 500.
      if i == 500 or i == 1000:
        # add many frames to effectively pause the video.
        for _ in range(50):
          vid.add(np.uint8(im))
    if perturb_step(i):
      lm = tf.cast(get_living_mask(x), tf.float32)
      # Apply each run's own matrix in one batched op:
      # x[b] @ mutation_symm_t[b] for every run b.
      x = tf.einsum('bhwc,bcd->bhwd', x, mutation_symm_t)
      x = tf.clip_by_value(x, -3., 3.)
      x *= lm
    x = ca(x)

mvp.ipython_display(fn, loop=True)

In [None]:
# @title Save data for external demo usage
# Writes gca_data.js, which defines window.GRAPH_URL (the pretrained model
# graph as a Blob URL) and window.PERTURBATIONS (the six perturbation
# matrices, in the order tail, leg, head, arm, red, blue that the demo
# indexes). All six *_sym_t names must already be defined, e.g. by the
# "Load saved matrices" cell.

perturbations_str = json.dumps(np.stack(
    [tail_sym_t, leg_sym_t, head_sym_t, 
     arm_sym_t, red_sym_t, blue_sym_t]).tolist())

# U+1F98E is the lizard emoji (escaped so it survives file re-encoding).
emoji = '\U0001F98E'
model_str = get_model(emoji, output='json')

data_js = '''
  window.GRAPH_URL = URL.createObjectURL(new Blob([`%s`], {type: 'application/json'}));
  window.PERTURBATIONS = %s
'''%(model_str, perturbations_str)

with open("gca_data.js", "w", encoding="utf-8") as f:
  f.write(data_js)


In [None]:
# @title Download trained perturbations from Github
# Fetch the zip of pretrained perturbation matrices and extract it into
# ./saved_perturbations/ (unzip -o: overwrite without prompting, -q: quiet).
!wget -O growing_ca_saved_perturbations.zip 'https://github.com/google-research/self-organising-systems/blob/master/adversarial_reprogramming_ca/assets/growing_ca_saved_perturbations.zip?raw=true'
!unzip -oq "growing_ca_saved_perturbations.zip" -d "saved_perturbations"

In [None]:
# Note: alternatively, you can upload the zip file to Colab and unzip it yourself:
#!unzip -oq "/content/growing_ca_saved_perturbations.zip" -d "saved_perturbations"

In [None]:
# @title Load saved matrices (for demos and visualizations)

load_dir = "saved_perturbations/"

# Each perturbation lives in "<load_dir><name>_sym_t.npy"; loaded in the same
# order as before (leg first).
leg_sym_t, tail_sym_t, head_sym_t, arm_sym_t, red_sym_t, blue_sym_t = (
    np.load(load_dir + name + "_sym_t.npy")
    for name in ("leg", "tail", "head", "arm", "red", "blue"))

# Canonical order used by the demo and the mosaic videos.
perts_mat_list = [tail_sym_t, leg_sym_t, head_sym_t, arm_sym_t, red_sym_t, blue_sym_t]


In [None]:
#@title TensorFlow.js Demo {run:"auto", vertical-output: true}
#@markdown Interactive demo of the pretrained lizard CA; the sliders mix in the six perturbation matrices loaded above.
import IPython.display

# U+1F98E is the lizard emoji (escaped so it survives file re-encoding).
emoji = '\U0001F98E'

# Stacked in the order the JS side indexes them: tail, leg, head, arm, red, blue.
perturbations_str = json.dumps(np.stack(
    [tail_sym_t, leg_sym_t, head_sym_t, 
     arm_sym_t, red_sym_t, blue_sym_t]).tolist())

#@markdown Shift-click to seed the pattern; click or drag to erase cells.

model_str = get_model(emoji, output='json')

# Expose the model graph (as a Blob URL) and the matrices as page globals.
data_js = '''
  window.GRAPH_URL = URL.createObjectURL(new Blob([`%s`], {type: 'application/json'}));
  window.PERTURBATIONS = %s
'''%(model_str, perturbations_str)

display(IPython.display.Javascript(data_js))


# In-browser demo. Each animation frame runs one model step, then multiplies
# every cell's state by
#   kI*I + kTail*M_tail + kLeg*M_leg + kHead*M_head + kArm*M_arm
#        + kRed*M_red + kBlue*M_blue,   with kI = 1 - (sum of the k's),
# where the k's come from the sliders and the M's from window.PERTURBATIONS.
# The result is clipped to [-3, 3].
# plantSeed rejects any seed position outside the grid on either axis
# (the y >= 0 check was previously missing: `y2<0` was tested twice).
# NOTE(review): mouse positions use e.clientX / e.clientY (viewport-relative),
# so they only line up with canvas pixels when the canvas sits at the
# top-left of the output frame; e.offsetX / e.offsetY may be more robust.
IPython.display.HTML('''
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.3.0/dist/tf.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/cash/4.1.2/cash.min.js"></script>

<canvas id='canvas' style="border: 1px solid black; image-rendering: pixelated;"></canvas>
<div><button type="button" id="reset">Reset</button></div>
<div class="slidecontainer">
    tailPerturbation:
    <input type="range" min="-1.0" max="1.0" value="0.0" class="slider" 
      step="0.05" id="tailSlider">
    <span id='tailPerturbation'>0.0</span>
</div>
<div class="slidecontainer">
    legPerturbation:
    <input type="range" min="-1.0" max="1.0" value="0.0" class="slider" 
      step="0.05" id="legSlider">
    <span id='legPerturbation'>0.0</span>
</div>
<div class="slidecontainer">
    headPerturbation:
    <input type="range" min="-1.0" max="1.0" value="0.0" class="slider" 
      step="0.05" id="headSlider">
    <span id='headPerturbation'>0.0</span>
</div>
<div class="slidecontainer">
    armPerturbation:
    <input type="range" min="-1.0" max="1.0" value="0.0" class="slider" 
      step="0.05" id="armSlider">
    <span id='armPerturbation'>0.0</span>
</div>
<div class="slidecontainer">
    redPerturbation:
    <input type="range" min="-1.0" max="1.0" value="0.0" class="slider" 
      step="0.05" id="redSlider">
    <span id='redPerturbation'>0.0</span>
</div>
<div class="slidecontainer">
    bluePerturbation:
    <input type="range" min="-1.0" max="1.0" value="0.0" class="slider" 
      step="0.05" id="blueSlider">
    <span id='bluePerturbation'>0.0</span>
</div>
<input type="checkbox" id="forcesum1" name="forcesum1">
<label for="forcesum1"> Force sum leq 1</label><br>
<script>
  "use strict";
  
  const sleep = (ms)=>new Promise(resolve => setTimeout(resolve, ms));
  
  const parseConsts = model_graph=>{
    const dtypes = {'DT_INT32':['int32', 'intVal', Int32Array],
                    'DT_FLOAT':['float32', 'floatVal', Float32Array]};
    
    const consts = {};
    model_graph.modelTopology.node.filter(n=>n.op=='Const').forEach((node=>{
      const v = node.attr.value.tensor;
      const [dtype, field, arrayType] = dtypes[v.dtype];
      if (!v.tensorShape.dim) {
        consts[node.name] = [tf.scalar(v[field][0], dtype)];
      } else {
        const shape = v.tensorShape.dim.map(d=>parseInt(d.size));
        let arr;
        if (v.tensorContent) {
          const data = atob(v.tensorContent);
          const buf = new Uint8Array(data.length);
          for (var i=0; i<data.length; ++i) {
            buf[i] = data.charCodeAt(i);
          }
          arr = new arrayType(buf.buffer);
        } else {
          const size = shape.reduce((a, b)=>a*b);
          arr = new arrayType(size);
          arr.fill(v[field][0]);
        }
        consts[node.name] = [tf.tensor(arr, shape, dtype)];
      }
    }));
    return consts;
  }

  let kTail = 0.0;
  let kLeg = 0.0;
  let kHead = 0.0;
  let kArm = 0.0;
  let kRed = 0.0;
  let kBlue = 0.0;

  let forcesum1Ckbx = document.getElementById("forcesum1");
  let tailSlider = document.getElementById("tailSlider");
  let legSlider = document.getElementById("legSlider");
  let headSlider = document.getElementById("headSlider");
  let armSlider = document.getElementById("armSlider");
  let redSlider = document.getElementById("redSlider");
  let blueSlider = document.getElementById("blueSlider");

  $('#tailSlider').on('input', e=>{
      updateK("tail", parseFloat(e.target.value));
  });
  $('#legSlider').on('input', e=>{
      updateK("leg", parseFloat(e.target.value));
  });
  $('#headSlider').on('input', e=>{
      updateK("head", parseFloat(e.target.value));
  });
  $('#armSlider').on('input', e=>{
      updateK("arm", parseFloat(e.target.value));
  });
  $('#redSlider').on('input', e=>{
      updateK("red", parseFloat(e.target.value));
  });
  $('#blueSlider').on('input', e=>{
      updateK("blue", parseFloat(e.target.value));
  });

  const updateKUnchecked = (kid, v) => {
    if (kid == "tail"){
      kTail = v;
      $('#tailPerturbation').text(kTail);
    } else if (kid == "leg"){
      kLeg = v;
      $('#legPerturbation').text(kLeg);
    } else if (kid == "head"){
      kHead = v;
      $('#headPerturbation').text(kHead);
    } else if (kid == "arm"){
      kArm = v;
      $('#armPerturbation').text(kArm);
    } else if (kid == "red"){
      kRed = v;
      $('#redPerturbation').text(kRed);
    } else if (kid == "blue"){
      kBlue = v;
      $('#bluePerturbation').text(kBlue);
    } else {
      console.log("ERROR!");
    }
  }

  const updateK = (kid, v) => {
      if (forcesum1Ckbx.checked == false) {
        updateKUnchecked(kid, v);
      } else {
        // You cannot go over 1.
        let vAbs = Math.abs(v);
        const vSign = Math.sign(v);
        const kTailAbs = Math.abs(kTail);
        const kTailSign = Math.sign(kTail);
        const kLegAbs = Math.abs(kLeg);
        const kLegSign = Math.sign(kLeg);
        const kHeadAbs = Math.abs(kHead);
        const kHeadSign = Math.sign(kHead);
        const kArmAbs = Math.abs(kArm);
        const kArmSign = Math.sign(kArm);
        const kRedAbs = Math.abs(kRed);
        const kRedSign = Math.sign(kRed);
        const kBlueAbs = Math.abs(kBlue);
        const kBlueSign = Math.sign(kBlue);

        let kCurrAbs;
        if (kid == "tail"){
          kCurrAbs = kTailAbs;
        } else if (kid == "leg"){
          kCurrAbs = kLegAbs;
        } else if (kid == "head"){
          kCurrAbs = kHeadAbs;
        } else if (kid == "arm"){
          kCurrAbs = kArmAbs;
        } else if (kid == "red"){
          kCurrAbs = kRedAbs;
        } else if (kid == "blue"){
          kCurrAbs = kBlueAbs;
        }
        let totK = vAbs + kTailAbs + kLegAbs + kHeadAbs + kArmAbs
              + kRedAbs + kBlueAbs - kCurrAbs;
        if (totK <= 1.0) {
          // No problem here, do just like you did before.
          updateKUnchecked(kid, v);
        } else {
          // Prevent v from going over 1.
          if (vAbs > 1.0) {
            vAbs = 1.0;
            totK = vAbs + kTailAbs + kLegAbs + kHeadAbs + kArmAbs
              + kRedAbs + kBlueAbs - kCurrAbs;
          }
          // Subtract the excess from the rest.
          const excess = totK - 1.0;

          const tailContrib = kid == "tail" ? 0.0 : kTailAbs;
          const legContrib = kid == "leg" ? 0.0 : kLegAbs;
          const headContrib = kid == "head" ? 0.0 : kHeadAbs;
          const armContrib = kid == "arm" ? 0.0 : kArmAbs;
          const redContrib = kid == "red" ? 0.0 : kRedAbs;
          const blueContrib = kid == "blue" ? 0.0 : kBlueAbs;
          const totContrib = tailContrib + legContrib + headContrib +
              armContrib + redContrib + blueContrib;

          let tailDecr = 0.0;
          let legDecr = 0.0;
          let headDecr = 0.0;
          let armDecr = 0.0;
          let redDecr = 0.0;
          let blueDecr = 0.0;
          if (totContrib > 1e-6) {
            tailDecr = tailContrib / totContrib * excess;
            legDecr = legContrib / totContrib * excess;
            headDecr = headContrib / totContrib * excess;
            armDecr = armContrib / totContrib * excess;
            redDecr = redContrib / totContrib * excess;
            blueDecr = blueContrib / totContrib * excess;
          }

          kTail = kid == "tail" ? vAbs * vSign : kTailSign * (kTailAbs - tailDecr);
          kLeg = kid == "leg" ? vAbs * vSign : kLegSign * (kLegAbs - legDecr);
          kHead = kid == "head" ? vAbs * vSign : kHeadSign * (kHeadAbs - headDecr);
          kArm = kid == "arm" ? vAbs * vSign : kArmSign * (kArmAbs - armDecr);
          kRed = kid == "red" ? vAbs * vSign : kRedSign * (kRedAbs - redDecr);
          kBlue = kid == "blue" ? vAbs * vSign : kBlueSign * (kBlueAbs - blueDecr);
          $('#tailPerturbation').text(kTail);
          tailSlider.value = kTail;
          $('#legPerturbation').text(kLeg);
          legSlider.value = kLeg;
          $('#headPerturbation').text(kHead);
          headSlider.value = kHead;
          $('#armPerturbation').text(kArm);
          armSlider.value = kArm;
          $('#redPerturbation').text(kRed);
          redSlider.value = kRed;
          $('#bluePerturbation').text(kBlue);
          blueSlider.value = kBlue;
        }
      }
      updatePerturbation();
  }


  let perturbationMatrix = tf.eye(16);

  const perturbations = tf.tensor(PERTURBATIONS);
  const tailPertM = perturbations.gather([0]).squeeze();
  const legPertM = perturbations.gather([1]).squeeze();
  const headPertM = perturbations.gather([2]).squeeze();
  const armPertM = perturbations.gather([3]).squeeze();
  const redPertM = perturbations.gather([4]).squeeze();
  const bluePertM = perturbations.gather([5]).squeeze();

  const I = tf.eye(16);
  const updatePerturbation = () => {
      let kI = 1.0 - kTail - kLeg - kHead - kArm - kRed - kBlue;
      perturbationMatrix = I.mul(kI).
        add(tailPertM.mul(kTail)).
        add(legPertM.mul(kLeg)).
        add(headPertM.mul(kHead)).
        add(armPertM.mul(kArm)).
        add(redPertM.mul(kRed)).
        add(bluePertM.mul(kBlue));
  }

  const run = async ()=>{
    const r = await fetch(GRAPH_URL);
    const consts = parseConsts(await r.json());
    
    const model = await tf.loadGraphModel(GRAPH_URL);
    Object.assign(model.weights, consts);

    
    let seed = new Array(16).fill(0).map((x, i)=>i<3?0:1);
    seed = tf.tensor(seed, [1, 1, 1, 16]);
    
    const D = 96;
    const initState = tf.tidy(()=>{
      const D2 = D/2;
      const a = seed.pad([[0, 0], [D2-1, D2], [D2-1, D2], [0,0]]);
      return a;
    });
    
    const state = tf.variable(initState);
    const [_, h, w, ch] = state.shape;
    

    $('#reset').on('click', e=>{
        tf.tidy(()=>{
          state.assign(initState);
        });
    });


    const damage = (x, y, r)=>{
      tf.tidy(()=>{
        const rx = tf.range(0, w).sub(x).div(r).square().expandDims(0);
        const ry = tf.range(0, h).sub(y).div(r).square().expandDims(1);
        const mask = rx.add(ry).greater(1.0).expandDims(2);
        state.assign(state.mul(mask));
      });
    }
    
    const plantSeed = (x, y)=>{
      const x2 = w-x-seed.shape[2];
      const y2 = h-y-seed.shape[1];
      if (x<0 || x2<0 || y<0 || y2<0)
        return;
      tf.tidy(()=>{
        const a = seed.pad([[0, 0], [y, y2], [x, x2], [0,0]]);
        state.assign(state.add(a));
      });
    }
    
    const scale = 4;
    
    const canvas = document.getElementById('canvas');
    const ctx = canvas.getContext('2d');
    canvas.width = w;
    canvas.height = h;
    canvas.style.width = `${w*scale}px`;
    canvas.style.height = `${h*scale}px`;
    
    canvas.onmousedown = e=>{
      const x = Math.floor(e.clientX/scale);
        const y = Math.floor(e.clientY/scale);
        if (e.buttons == 1) {
          if (e.shiftKey) {
            plantSeed(x, y);  
          } else {
            damage(x, y, 8);
          }
        }
    }
    canvas.onmousemove = e=>{
      const x = Math.floor(e.clientX/scale);
      const y = Math.floor(e.clientY/scale);
      if (e.buttons == 1 && !e.shiftKey) {
        damage(x, y, 8);
      }
    }

    function step() {
      tf.tidy(()=>{
        let new_state = model.execute(
            {x:state, fire_rate:tf.tensor(0.5),
            angle:tf.tensor(0.0), step_size:tf.tensor(1.0)}, ['Identity']);
        new_state = new_state.reshape([-1, 16]);
        new_state = new_state.matMul(perturbationMatrix).reshape([1, D, D, 16]);
        new_state = new_state.clipByValue(-3., +3.);
        state.assign(new_state);
      });
    }

    function render() {
      step();

      const imageData = tf.tidy(()=>{
        const rgba = state.slice([0, 0, 0, 0], [-1, -1, -1, 4]);
        const a = state.slice([0, 0, 0, 3], [-1, -1, -1, 1]);
        const img = tf.tensor(1.0).sub(a).add(rgba).mul(255);
        const rgbaBytes = new Uint8ClampedArray(img.dataSync());
        return new ImageData(rgbaBytes, w, h);
      });
      ctx.putImageData(imageData, 0, 0);

      requestAnimationFrame(render);
    }
    render();
  }
  run();
  
</script>
''')

## Make videos for article

In [None]:
#@title visualize runs
# For each saved perturbation in `perts_mat_list`, grow the CA from the seed
# for 1500 steps (mutation ON, then OFF for steps 500-999, then ON again) and
# record frames. The runs are then tiled into a single mosaic video.
import PIL.ImageFont

from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)


perturb_step = lambda i: i < 500 or i >= 1000

runs_per_perturb = []
for sym_t in perts_mat_list:
  x = seed[None, ...]

  run_frames = []
  run_texts = []
  for i in tqdm.trange(1500):
    # Record every step for the first 200 steps, then every 5th.
    if i<200 or i%5 == 0:
      vis = zoom(to_rgb(x[0]), 4).clip(0, 1)
      #vis = np.concatenate((vis, np.ones((164, vis.shape[1], 3))), axis=0) 

      im = np.uint8(vis*255)
      run_frames.append(im)

      # Save text for rendering.
      steptext = "Step: {}".format(i)
      perturbtext = "Perturbation: {}".format("ON" if perturb_step(i) else "OFF")
      perturbcolor = (255,0,0) if perturb_step(i) else (0,255,0)
      run_texts.append((steptext, perturbtext,perturbcolor))
      if i == 500 or i == 1000:
        # add many frames to effectively pause the video.
        for _ in range(50):
          run_frames.append(im)
          run_texts.append((steptext, perturbtext,perturbcolor))
    if perturb_step(i):
      # Same masked, clipped mutation as in training.
      lm = tf.cast(get_living_mask(x), tf.float32)
      x = x @ sym_t
      x = tf.clip_by_value(x, -3., 3.)
      x *= lm
    x = ca(x)

  runs_per_perturb.append(run_frames)

# generate text for each frame with a skeleton like the one above:
# (`run_texts` here is from the last run only; that is enough because the
# frame schedule and the ON/OFF text depend only on the step index, which
# is identical for every run.)


num_frames = len(runs_per_perturb[0])

fn = 'mosaic_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(num_frames):
    # Frame i of every run, tiled into one image.
    frame_images = [run_frames[i] for run_frames in runs_per_perturb]
    im = tile2d(frame_images)
  
    im = PIL.Image.fromarray(im)
    draw = PIL.ImageDraw.Draw(im)

    # compute text:
    steptext, perturbtext, perturbcolor = run_texts[i]
    # draw.text((x, y),"Sample Text",(r,g,b))
    draw.text((0, 0),steptext,(0,0,0),font=font)
    draw.text((0, 20),perturbtext,perturbcolor,font=font)
    vid.add(np.uint8(im))

mvp.ipython_display(fn, loop=True)

In [None]:
#@title interpolating directions
import PIL.ImageFont

from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

# The mutation is applied at every step of these runs.
perturb_step = lambda i: True

# One mutation matrix per weight d, linearly interpolating between two trained
# perturbations: d*pert_1 + (1-d)*pert_2 (d=1: pure tail, d=0: pure leg).
pert_1 = tail_sym_t
pert_2 = leg_sym_t
all_distortions = [1.0, 0.75, 0.5, 0.25, 0.0]
num_interp = len(all_distortions)
all_mutation_symm_t = [d * pert_1 + (1. - d) * pert_2 for d in all_distortions]
mutation_symm_t = np.stack(all_mutation_symm_t)  # one matrix per run
print(mutation_symm_t.shape)

# One run per interpolation weight, all starting from the single-cell seed.
x = np.repeat(seed[None, ...], num_interp, 0)

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(1500):
    if i<200 or i%5 == 0:
      # All runs side by side in a single row.
      vis = zoom(tile2d(to_rgb(x), num_interp), 4).clip(0, 1)
      #vis = np.concatenate((vis, np.ones((164, vis.shape[1], 3))), axis=0) 

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)
      
      draw = PIL.ImageDraw.Draw(im)

      # compute text:
      steptext = "Step: {}".format(i)
      perturbtext = "Perturbation: {}".format("ON" if perturb_step(i) else "OFF")
      perturbcolor = (255,0,0) if perturb_step(i) else (0,255,0)
      # draw.text((x, y),"Sample Text",(r,g,b))
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)

      # Label each tile near its bottom edge with its interpolation weights.
      x_unit, y_unit = im.width // num_interp, im.height
      x_displacement = x_unit // 5
      y_displacement = int(y_unit * 0.9)
      for idx, d in enumerate(all_distortions):
        disttext = "tail k: {} | leg k: {}".format(d, 1.0 - d)
        placement = (x_unit * idx + x_displacement, y_displacement)
        draw.text(placement,disttext,(0,0,0),font=font)


      vid.add(np.uint8(im))
    if perturb_step(i):
      lm = tf.cast(get_living_mask(x), tf.float32)
      # Apply each run's own matrix in one batched op:
      # x[b] @ mutation_symm_t[b] for every run b.
      x = tf.einsum('bhwc,bcd->bhwd', x, mutation_symm_t)
      x = tf.clip_by_value(x, -3., 3.)
      x *= lm
    x = ca(x)

mvp.ipython_display(fn, loop=True)

In [None]:
# @title combinations of mutations
import PIL.ImageFont

from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

x = seed[None, ...]

# Combine two perturbations by summing their deviations from the identity:
# (M_tail - I) + (M_leg - I) + I == M_tail + M_leg - I. This is the demo's
# kI*I + kTail*M_tail + kLeg*M_leg with kTail = kLeg = 1 (so kI = -1).
# The identity is sized by CHANNEL_N, as in the other cells, rather than by a
# hard-coded 16.
mutation_symm_t = tail_sym_t + leg_sym_t - tf.eye(CHANNEL_N)

# Mutation ON for the first 500 steps, OFF for steps 500-999, ON again after.
perturb_step = lambda i: i < 500 or i >= 1000

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(1500):
    if i<200 or i%5 == 0:
      vis = zoom(to_rgb(x[0]), 4).clip(0, 1)
      #vis = np.concatenate((vis, np.ones((164, vis.shape[1], 3))), axis=0) 

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)
      
      draw = PIL.ImageDraw.Draw(im)

      # compute text:
      steptext = "Step: {}".format(i)
      perturbtext = "Perturbation: {}".format("ON" if perturb_step(i) else "OFF")
      perturbcolor = (255,0,0) if perturb_step(i) else (0,255,0)
      # draw.text((x, y),"Sample Text",(r,g,b))
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)
      vid.add(np.uint8(im))
      if i == 500 or i == 1000:
        # add many frames to effectively pause the video.
        for _ in range(50):
          vid.add(np.uint8(im))
    if perturb_step(i):
      # Same masked, clipped mutation as in training.
      lm = tf.cast(get_living_mask(x), tf.float32)
      x = x @ mutation_symm_t
      x = tf.clip_by_value(x, -3., 3.)
      x *= lm
    x = ca(x)

mvp.ipython_display(fn, loop=True)

In [None]:
# @title Create a mosaic of direction mutations!
import PIL.ImageFont

from matplotlib import font_manager as fm
font_fn = fm.findfont(fm.FontProperties())
font = PIL.ImageFont.truetype(font_fn, 16)

# Defined here so the cell does not depend on a `perturb_step` left over from
# whichever cell ran last. (Within the 500 steps below it is always ON.)
perturb_step = lambda i: i < 500 or i >= 1000

# Generate 9 mutation matrices along the line through the identity (d=0, a
# no-op) and the tail perturbation (d=1), using Luca's faster way:
# d*M + (1-d)*I. A costlier alternative rescales the eigenvalues instead:
# eigh(M), then 1 + d*(eigval - 1), then reconstruct (compare the "eigvec"
# option of the "visualize training run" cell).
all_distortions = [1.0, 0.5, 0.3, 0.1, 0.0, -0.1, -0.3, -0.5, -1.0]
all_mutation_symm_t = [
    d * tail_sym_t + (1. - d) * tf.eye(CHANNEL_N) for d in all_distortions]
mutation_symm_t = np.stack(all_mutation_symm_t)  # one matrix per run
print(mutation_symm_t.shape)

# One run per direction, all starting from the single-cell seed.
x = np.repeat(seed[None, ...], 9, 0)

fn = 'example_run.mp4'
with VideoWriter(fn) as vid:
  for i in tqdm.trange(500):
    if i<200 or i%5 == 0:
      vis = zoom(tile2d(to_rgb(x), 3), 4).clip(0, 1)
      #vis = np.concatenate((vis, np.ones((164, vis.shape[1], 3))), axis=0) 

      im = np.uint8(vis*255)

      im = PIL.Image.fromarray(im)
      
      draw = PIL.ImageDraw.Draw(im)

      # compute text:
      steptext = "Step: {}".format(i)
      perturbtext = "Perturbation: {}".format("ON" if perturb_step(i) else "OFF")
      perturbcolor = (255,0,0) if perturb_step(i) else (0,255,0)
      # draw.text((x, y),"Sample Text",(r,g,b))
      draw.text((0, 0),steptext,(0,0,0),font=font)
      draw.text((0, 20),perturbtext,perturbcolor,font=font)

      # Add mutation direction text near the bottom of each 3x3 tile.
      for idx, d in enumerate(all_distortions):
        disttext = "direction: {}".format(d)
        if d == 1.0:
          disttext += " (train config)"
        if d == 0.0:
          disttext += " (NOOP)"
        x_unit, y_unit = im.width // 3, im.height // 3
        x_displacement = x_unit // 5
        y_displacement = int(y_unit * 0.9)
        placement = (x_unit * (idx % 3) + x_displacement,
                     y_unit * (idx // 3) + y_displacement)
        draw.text(placement,disttext,(0,0,0),font=font)


      vid.add(np.uint8(im))
      # Only reached if the step count above is raised past 500.
      if i == 500 or i == 1000:
        # add many frames to effectively pause the video.
        for _ in range(50):
          vid.add(np.uint8(im))
    if perturb_step(i):
      lm = tf.cast(get_living_mask(x), tf.float32)
      # Apply each run's own matrix in one batched op:
      # x[b] @ mutation_symm_t[b] for every run b.
      x = tf.einsum('bhwc,bcd->bhwd', x, mutation_symm_t)
      x = tf.clip_by_value(x, -3., 3.)
      x *= lm
    x = ca(x)

mvp.ipython_display(fn, loop=True)