[![Open In Wolfram Cloud](https://raw.githubusercontent.com/gvarnavi/generative-art-iap/master/PR/wolfram-badge.svg)](https://www.wolframcloud.com/obj/gvarnavi/Published/03_differentiable-cellular-automata.nb)

# Differentiable CAs (NNs)

So far, we've mostly looked at discrete-state CAs, e.g. our 'binary' elementary CAs, whose cells could be either "alive" or "dead".

I recently stumbled upon this [excellent blog post](https://distill.pub/2020/growing-ca/) about a differentiable model of morphogenesis using neural networks and cellular automata, and figured it would be a great demonstration for the class!

For our neural network to minimize a loss function by gradient descent, the update rule must be differentiable. We therefore need continuous-state CAs, i.e. cells that can take any real value between 0 and 1.
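
To see why, we can compare the derivative of a hard 0/1 'alive/dead' threshold with that of a smooth, continuous-valued rule. This is only an illustration in plain Wolfram Language: the threshold of 1/2 and the logistic curve with steepness 10 are arbitrary choices, not part of the model below.

```wolfram
(* a hard 'alive/dead' threshold at 1/2: its derivative is a DiracDelta *)
hardGrad = D[HeavisideTheta[x - 1/2], x] /. x -> 1/4;

(* a smooth, continuous-state alternative; the steepness 10 is arbitrary *)
softGrad = D[LogisticSigmoid[10 (x - 1/2)], x] /. x -> 1/4;

(* hardGrad is 0 away from the threshold; softGrad is strictly positive *)
{hardGrad, N[softGrad]}
```

Away from the threshold, the hard rule gives gradient descent nothing to follow, while the smooth rule always does.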

## Problem Statement

I recommend reading the blog post above after class, but in essence we're training a neural network to learn the update rule of a continuous-state cellular automaton on a 3x3 Moore neighborhood, so that it grows and reproduces a target pattern.

While the blog post uses emojis, we'll query the Wolfram Knowledgebase and use first-generation Pokemon instead.
By the end, we'll have cellular automata models like this one:

![charizard](https://raw.githubusercontent.com/gvarnavi/generative-art-iap/master/01.27-Thursday/gifs/charizard.png)

## Target Images

As mentioned, we'll use the Generation I Pokemon as our target images. We set some parameters, then query the Knowledgebase for images and standardize them to a common size:

```wolfram
numChannels = 16;
prepaddedWidth = prepaddedHeight = 40;
padding = 8;
width = height = prepaddedWidth + 2 padding;
```

```wolfram
(* pad to a square using edge pixels, resize to 40x40, then pad by 8 pixels on every side *)
standardize[img_Image] := Block[{width, height}, ImagePad[
   ImageResize[ImagePad[img,
     {width, height} =
      ImageDimensions[img]; {Table[Round[1/2 Ramp[height - width]],
       2], Table[Round[1/2 Ramp[width - height]], 2]},
     "Fixed"], {prepaddedWidth, prepaddedHeight}], padding]]
```
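
The padding arithmetic inside standardize can be checked on a synthetic image. In this sketch, a blank 60x40 image stands in for a Knowledgebase image (an assumption for illustration), and squarePadding is a hypothetical helper that reproduces the padding specification used above. The image is padded to a square, resized to the 40x40 prepadded size, then padded by 8 on every side, giving the 56x56 lattice used throughout.

```wolfram
(* hypothetical helper: the {{left, right}, {bottom, top}} padding that makes an image square *)
squarePadding[{w_, h_}] := {Table[Round[Ramp[h - w]/2], 2], Table[Round[Ramp[w - h]/2], 2]};

img = Image[ConstantArray[0., {40, 60}]];   (* 40 rows, 60 columns: width 60, height 40 *)
squarePadding[ImageDimensions[img]]          (* {{0, 0}, {10, 10}}: pad 10 pixels at bottom and top *)

square = ImagePad[img, squarePadding[ImageDimensions[img]], "Fixed"];
final = ImagePad[ImageResize[square, {40, 40}], 8];
{ImageDimensions[square], ImageDimensions[final]}   (* {{60, 60}, {56, 56}} *)
```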

```wolfram
generationOnePokemonClass = FilteredEntityClass["Pokemon", EntityFunction[p, p["Generation"] == Entity["PokemonGeneration", "GenerationI"]]];
EntityValue[generationOnePokemonClass = ComplementedEntityClass[generationOnePokemonClass, {Entity["Pokemon", "Pokedex0133:PartnerEevee"], Entity["Pokemon", "Pokedex0025:PartnerPikachu"]}], "EntityCount"]
generationOnePokemonNames = EntityList[generationOnePokemonClass];
generationOnePokemonImgs = standardize /@ EntityValue[generationOnePokemonClass, EntityProperty["Pokemon", "Image"]];
ImageCollage[generationOnePokemonImgs, Background -> None]
```

## Cell State

- We'll represent each cell (or pixel) by 16 channels
  - The first four are 'visible' and represent the RGBA values
  - Channels 5-16 are 'hidden' and allow the neural net to learn our self-assembly or growth pattern
- The alpha (opacity) channel a also encodes each cell's state
  - If a > 0.1, the cell is considered 'mature'
  - If a ≤ 0.1, but at least one of the cell's 3x3 neighbors is 'mature', the cell is considered 'growing'
  - Otherwise (a ≤ 0.1 and none of the cell's 3x3 neighbors are 'mature'), the cell is considered 'dead'
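
The three states can be computed from the alpha channel alone. In the sketch below, cellStates is a hypothetical helper: it takes the maximum alpha over each cell's 3x3 neighborhood, treating cells outside the grid as empty (zero padding, unlike the wraparound used by our Moore function), and labels cells 2 (mature), 1 (growing) or 0 (dead).

```wolfram
(* hypothetical helper: 2 = mature, 1 = growing, 0 = dead; cells outside the grid count as alpha = 0 *)
cellStates[a_?MatrixQ] := Module[{neighborhoodMax},
  (* maximum alpha over each cell's 3x3 Moore neighborhood, the cell itself included *)
  neighborhoodMax = Map[Max, Partition[ArrayPad[a, 1], {3, 3}, {1, 1}], {2}];
  MapThread[Which[#1 > 0.1, 2, #2 > 0.1, 1, True, 0] &, {a, neighborhoodMax}, 2]]

(* a 5x5 grid with a single mature cell in the middle *)
alpha = ReplacePart[ConstantArray[0., {5, 5}], {3, 3} -> 0.5];
states = cellStates[alpha];
Counts[Flatten[states]]   (* one mature cell, eight growing neighbors, sixteen dead cells *)
```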
  
Note that our Pokemon images already follow this scheme: they have a transparent background, so the alpha channel marks the 'whitespace'. For visualization purposes, we can compute these states using our Moore neighborhood function:

```wolfram
Moore[func__, lat_] :=
  MapThread[func,
   Map[RotateRight[lat, #] &,
    {{0, 0}, {1, 0}, {0, -1}, {-1, 0},
     {0, 1}, {1, -1}, {-1, -1}, {-1, 1},
     {1, 1}}], 2];
```
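
Moore relies on RotateRight, which wraps around the edges, so the lattice behaves like a torus. A quick check on a 3x3 matrix shows which neighbor each shift brings to a cell: a shift of {1, 0} places each cell's upper neighbor at its position, and {0, -1} places its right-hand neighbor there. The Moore definition is repeated so the cell runs on its own.

```wolfram
m = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
RotateRight[m, {1, 0}]    (* {{7, 8, 9}, {1, 2, 3}, {4, 5, 6}}: each cell now holds the value above it *)
RotateRight[m, {0, -1}]   (* {{2, 3, 1}, {5, 6, 4}, {8, 9, 7}}: each cell now holds the value to its right *)

(* same definition as in the notebook *)
Moore[func__, lat_] := MapThread[func,
   Map[RotateRight[lat, #] &,
    {{0, 0}, {1, 0}, {0, -1}, {-1, 0}, {0, 1}, {1, -1}, {-1, -1}, {-1, 1}, {1, 1}}], 2];

(* with wraparound, every cell of a 3x3 lattice sees all nine values *)
Moore[Max, m]             (* every entry is 9 *)
```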

```wolfram
(* 2 = mature, 1 = growing, 0 = dead *)
cellStateVisualization[self_, east_, south_, west_, north_,
  southEast_, southWest_, northWest_, northEast_] := Which[
  self > 0.1, 2,
  Or @@ Thread[{east, south, west, north, southEast, southWest,
      northWest, northEast} > 0.1], 1,
  True, 0
  ]

ImageCollage[
 ArrayPlot[
    Moore[cellStateVisualization, Map[Last, ImageData[#], {2}]],
    Frame -> False] & /@ RandomSample[generationOnePokemonImgs, 12],
 Background -> None, ImageSize -> 600]
```

## Neural Network (Single Update Step)

We follow the prescription from the blog above quite closely.
A single update step for our cellular automaton can be summarized graphically by:

![nn](https://raw.githubusercontent.com/gvarnavi/generative-art-iap/master/01.27-Thursday/gifs/neural-network-CAs.png)

which consists of the following steps:

1. Perception
   - We use three fixed kernels to let each cell perceive its local environment
   - In particular, we use two Sobel kernels to estimate the gradient of each channel in the x and y directions
   - and concatenate them with the cell's own state (the identity kernel) to give a 16 × 3 = 48-dimensional perception vector for each cell
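
To get a feel for what the two Sobel kernels measure, we can apply them to a toy single-channel grid whose values increase by 1 from column to column. ListCorrelate is used here as a stand-in for the convolution layers below (it applies the kernel without flipping and without padding, so a 5x5 grid gives a 3x3 result); the sign convention of the actual layers may differ.

```wolfram
sobelX = {{-1, 0, 1}, {-2, 0, 2}, {-1, 0, 1}};
sobelY = Transpose[sobelX];
grid = Table[j, {i, 5}, {j, 5}];   (* values increase left to right, constant down each column *)

ListCorrelate[sobelX, grid]   (* every entry is 8: a constant horizontal gradient *)
ListCorrelate[sobelY, grid]   (* every entry is 0: no vertical change *)

(* perception of the center cell {3, 3} for this one channel: {identity, x-gradient, y-gradient};
   output position {2, 2} of the unpadded correlation corresponds to input position {3, 3} *)
perception = {grid[[3, 3]], ListCorrelate[sobelX, grid][[2, 2]], ListCorrelate[sobelY, grid][[2, 2]]}
(* {3, 8, 0}; with 16 channels the same stacking gives 48 values per cell *)
```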

```wolfram
kernel["sobel-x"]={{-1,0,1},{-2,0,2},{-1,0,1}};
kernel["sobel-y"]=Transpose[kernel["sobel-x"]];
kernel["identity"]=BoxMatrix[0,{3,3}];

conv["sobel-x"]=ConvolutionLayer["Weights"->Table[KroneckerDelta[i,j]Reverse[kernel["sobel-x"],{1,2}]+ConstantArray[0,{3,3}],{i,numChannels},{j,numChannels}],"Biases"->None,PaddingSize->1,"Interleaving"->True,LearningRateMultipliers->None];
conv["sobel-y"]=ConvolutionLayer["Weights"->Table[KroneckerDelta[i,j]Reverse[kernel["sobel-y"],{1,2}]+ConstantArray[0,{3,3}],{i,numChannels},{j,numChannels}],"Biases"->None,PaddingSize->1,"Interleaving"->True,LearningRateMultipliers->None];
```
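
The Table over KroneckerDelta[i, j] builds a weight tensor whose only nonzero 3x3 blocks lie on the diagonal, so output channel i only sees input channel i: each channel is filtered independently (a depthwise convolution). A small 4-channel check, written with plain lists and ListCorrelate rather than ConvolutionLayer (channels first, and the channel count is arbitrary), shows the effect.

```wolfram
k = {{-1, 0, 1}, {-2, 0, 2}, {-1, 0, 1}};
c = 4;   (* a small channel count for illustration *)
w = Table[KroneckerDelta[i, j] k, {i, c}, {j, c}];
Dimensions[w]   (* {4, 4, 3, 3}, read here as {output channel, input channel, row, column} *)

(* a generic multi-channel correlation: output channel i sums over every input channel j *)
input = RandomInteger[9, {c, 6, 6}];
output = Table[Sum[ListCorrelate[w[[i, j]], input[[j]]], {j, c}], {i, c}];

(* with the diagonal tensor, this reduces to filtering each channel on its own *)
output == Table[ListCorrelate[k, input[[i]]], {i, c}]   (* True *)
```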

```wolfram
net["perception"] =
 NetGraph[<|"sobel-x" -> conv["sobel-x"],
   "sobel-y" -> conv["sobel-y"],
   "catenate" -> CatenateLayer[3]|>, {NetPort["Input"] -> "sobel-x",
   NetPort["Input"] ->
    "sobel-y", {NetPort["Input"], "sobel-x", "sobel-y"} ->
    "catenate" -> NetPort["percepted"]},
  "Input" -> {width, height, numChannels}]
```

```wolfram
net["perception"][RandomReal[{0, 1}, {56, 56, 16}]] // Dimensions
```

Note that this part is fully initialized (i.e. it has no trainable parameters). Also note that we use interleaving for convenience, which keeps the channels as the last tensor dimension.

2. Update Rule
   - For each cell, we then apply a dense feed-forward network that maps this perception vector back to a 16-dimensional state update
   - In particular, we use a 128-unit dense layer, followed by a ReLU activation, and a final 16-unit dense layer
   - We then use NetMapOperator twice (once per spatial dimension) to apply this network at every cell

```wolfram
net["update"] =
 NetMapOperator[
  NetMapOperator[
   NetChain[{LinearLayer[numChannels 8], Ramp,
     LinearLayer[numChannels, "Weights" -> 0]},
    "Input" -> numChannels 3]]]
```

Note that we've initialized the weights of the last dense layer to 0. This ensures 'do-nothing' initial behavior and keeps the gradients low. Also note that this update network is the only part of the model with trainable parameters.
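
A framework-free sketch of the per-cell update rule, with the weights and biases written out as explicit arrays, shows why a zero final layer means 'do-nothing' behavior and how many trainable numbers the network has. The random initialization range for the first layer is an arbitrary assumption (not what NetInitialize does), and the count assumes both dense layers carry biases, as LinearLayer does by default: 48 × 128 + 128 + 128 × 16 + 16 = 8336.

```wolfram
c = 16;
(* per-cell update MLP: 48 -> 128 -> ReLU -> 16 *)
w1 = RandomReal[{-0.1, 0.1}, {8 c, 3 c}]; b1 = ConstantArray[0., 8 c];
w2 = ConstantArray[0., {c, 8 c}]; b2 = ConstantArray[0., c];   (* zero-initialized final layer *)

update[p_] := w2 . Map[Ramp, w1 . p + b1] + b2;

update[RandomReal[1, 3 c]]          (* the zero vector: no cell changes at first *)
Length[Flatten[{w1, b1, w2, b2}]]   (* 8336 trainable numbers *)
```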

```wolfram
NetInitialize[net["update"]][net["perception"][RandomReal[{0, 1}, {56, 56, 16}]]]//Dimensions
```

3. Dropout layer (per cell)
   - Next, we apply a per-cell dropout layer so that, during training, each cell applies its update with probability 1/2 at every step. This simulates the lack of a global clock in self-organizing systems (see the blog post for more details)

```wolfram
net["cell-dropout"] =
 NetChain[{NetArrayLayer[
    "Array" -> ConstantArray[1/2, {width, height}],
    LearningRateMultipliers -> 0], DropoutLayer[],
   ReplicateLayer[numChannels, 3]}]
```

```wolfram
Tally[Flatten[net["cell-dropout"][]]]
```

```wolfram
Tally[Flatten[net["cell-dropout"][NetEvaluationMode -> "Train"]]]
```

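Note that we used a constant array of 1/2 because DropoutLayer sets its input elements to zero with probability p during training and multiplies the remainder by 1/(1-p). With the default p = 1/2, each cell's mask is therefore either 0 or 1 during training. Outside of training, DropoutLayer passes its input through unchanged, so the mask is 1/2 everywhere.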
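
A framework-free sketch of the stochastic update, assuming each cell is kept independently with probability 1/2 (the 4x4 grid and the variable names are illustrative). Rescaling a kept value of 1/2 by 1/(1-p) gives exactly 1, so the mask is 0 or 1 per cell; replicating it across channels means a cell updates all 16 channels or none.

```wolfram
p = 1/2;
(1/2)/(1 - p)   (* 1: a kept entry of 1/2 is rescaled to exactly 1 *)

(* one keep/drop decision per cell, replicated across all 16 channels *)
cellMask = RandomChoice[{0, 1}, {4, 4}];
mask = Map[ConstantArray[#, 16] &, cellMask, {2}];
Dimensions[mask]   (* {4, 4, 16} *)

(* the residual update is applied only where the mask is 1 *)
state = RandomReal[1, {4, 4, 16}];
delta = RandomReal[{-0.1, 0.1}, {4, 4, 16}];
newState = state + mask delta;
```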

4. Living cell mask
   - Finally, we apply a living-cell mask both before and after the update, where 'living' means either 'mature' or 'growing'; a cell keeps its state only if it is living both before and after the update

```wolfram
net["living"] =
 FunctionLayer[
  PartLayer[{All, All, 1}][
     PoolingLayer[3, "PaddingSize" -> 1,
       "Input" -> {width, height, 1}, Interleaving -> True][
      PartLayer[{All, All, 4 ;; 4}][#]]] > 0.1 &]
```
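
The max-pooling trick can be mimicked with plain lists: a cell is living if the largest alpha in its 3x3 neighborhood exceeds 0.1. In this sketch, livingMask is a hypothetical helper (zero padding at the borders), and the 'before' and 'after' grids are made-up states in which the update dims the only mature cell below the threshold, so multiplying the pre- and post-update masks clears the whole grid.

```wolfram
(* hypothetical helper: 1 where the 3x3 neighborhood maximum of alpha exceeds 0.1, else 0 *)
livingMask[a_?MatrixQ] := Map[Boole[Max[#] > 0.1] &, Partition[ArrayPad[a, 1], {3, 3}, {1, 1}], {2}]

before = {{0, 0, 0}, {0, 0.5, 0}, {0, 0, 0}};
after = {{0, 0, 0}, {0, 0.05, 0}, {0, 0, 0}};   (* the update dimmed the only mature cell *)

livingMask[before]                       (* all ones: every cell is mature or growing *)
livingMask[after]                        (* all zeros *)
livingMask[before] livingMask[after]     (* the combined mask zeros the whole grid *)
```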

```wolfram
Image[net["living"][ImageData[generationOnePokemonImgs[[25]]]]]
```

Putting it all together, we obtain our single-update net:

```wolfram
net["single-update"] =
 NetGraph[<|"perception" -> net["perception"],
   "update" -> net["update"], "cell-dropout" -> net["cell-dropout"],
   "dot-plus" -> FunctionLayer[Apply[#1 #2 + #3 &]],
   "pre-living" -> net["living"], "post-living" -> net["living"],
   "times" -> FunctionLayer[Apply[#1 #2 #3 &]]|>,
  {NetPort["Input"] -> "pre-living",
   NetPort["Input"] ->
    "perception" -> "update", {"cell-dropout", "update",
     NetPort["Input"]} ->
    "dot-plus" -> "post-living", {"pre-living", "post-living",
     "dot-plus"} -> "times" -> NetPort["Output"]},
  "Input" -> {width, height, numChannels},
  "Output" -> {width, height, numChannels}]
```
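
Reading the graph's edges directly, one application of this net maps the state $x_t$ to $x_{t+1}$ as follows. Here $P$ is the "perception" net, $f_\theta$ the per-cell "update" network, $m_t$ the per-cell dropout mask, $L$ the living mask, and $\odot$ an elementwise product broadcast over channels (the symbols are our own labels for the graph's nodes):

$$
\begin{aligned}
\tilde{x}_t &= x_t + m_t \odot f_\theta\big(P(x_t)\big),\\
x_{t+1} &= L(x_t) \odot L(\tilde{x}_t) \odot \tilde{x}_t .
\end{aligned}
$$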

## Nested Network

We now want to iterate this network, reusing the same set of trainable parameters at every step. NetNestOperator seems like a perfect fit:

```wolfram
net["nested"] = NetChain[{NetNestOperator[net["single-update"], 64], PartLayer[{All, All, ;; 4}]}]
```

Note that after iterating 64 times, we drop the hidden channels, so that the output (the four visible RGBA channels) can be compared directly with the target image.
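
NetNestOperator[net, n] applies the same net n times, in the same way that Nest applies an ordinary function n times. In this plain Wolfram Language sketch, a toy contraction stands in for the single-update net (an arbitrary assumption, chosen so the result has a closed form). We nest it 64 times on a 56x56x16 state and then keep only the first four channels, as the PartLayer does.

```wolfram
(* toy stand-in for one update step: move every value 10% of the way toward 1 *)
step = # + 0.1 (1 - #) &;

state0 = ConstantArray[0., {56, 56, 16}];
nested = Nest[step, state0, 64];

visible = nested[[All, All, ;; 4]];   (* drop the 12 hidden channels *)
Dimensions[visible]                    (* {56, 56, 4} *)
Max[Abs[visible - (1 - 0.9^64)]]       (* close to 0: each step shrinks the gap to 1 by a factor 0.9 *)
```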

### Training

We could now train this net from a single seed cell (placed at the center, with its alpha and hidden channels set to 1 and its RGB channels set to 0) against a single target image:

```wolfram
singleSeed = Normal@SparseArray[Table[{width/2, height/2, i} -> 1, {i, 4, numChannels}], {width,height, numChannels}];
targetPokemon = ImageData[generationOnePokemonImgs[[6]]];

(*
NetTrain[net["nested"],<|"Input"->{singleSeed},"Target"->
{targetPokemon}|>,MaxTrainingRounds->10,RandomSeeding->1996]
*)
```

### Parallel Nesting

However, while this will likely learn the target image well, the fixed number of iterations means the model has little incentive to behave sensibly beyond 64 iterations, i.e. the growth pattern is unlikely to be stable.

The authors of the blog post deal with this in two ways:
- Using a sample-pool/batch technique in their second ('persistent') experiment
- Nesting for a random number of iterations between 64 and 96
  - This should at least keep the system stable for a limited number of iterations
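
The blog post's second remedy amounts to drawing a fresh iteration count for every training step. In this sketch, the same kind of toy contraction stands in for the update net and the target is the constant 1 (both assumptions for illustration); five training steps each see a different number of iterations.

```wolfram
(* toy stand-in for one update step: move every value 10% of the way toward the target 1 *)
step = # + 0.1 (1 - #) &;

SeedRandom[1996];
(* a fresh iteration count in [64, 96] for each of five training steps *)
counts = RandomInteger[{64, 96}, 5];

(* squared error of the toy model after each randomly chosen number of iterations *)
losses = (Nest[step, 0., #] - 1)^2 & /@ counts;
```

Because the count changes from step to step, the model is rewarded for reaching the target and then staying there, rather than only at one fixed iteration.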

We'll use a (much more expensive) variant of the second technique, since I couldn't figure out how to get NetNestOperator to take a random number of iterations efficiently. We nest three copies of the net in parallel for 50, 64, and 78 iterations and use the maximum of their three mean-squared losses as the loss to minimize.
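
The max-of-losses idea can be sketched without the neural net framework. A toy contraction again stands in for the single-update net and the target is the constant 1 (assumptions for illustration); Subdivide produces the same three nesting lengths as the graph below, and the loss is the worst of the three squared errors, as ThreadingLayer[Max] computes.

```wolfram
(* toy stand-in for one update step: move every value 10% of the way toward the target *)
step = # + 0.1 (1 - #) &;
target = 1.;
lengths = Subdivide[50, 78, 2]   (* {50, 64, 78} *)

(* squared error of the toy model after each nesting length *)
losses = (Nest[step, 0., #] - target)^2 & /@ lengths;

(* the training loss is the worst of the three *)
loss = Max[losses]   (* the 50-iteration error dominates here, about 0.9^100 *)
```

Taking the maximum means the gradient at each step comes from whichever nesting length currently matches the target worst.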

We must ensure that our parallel nested nets share the same trainable parameters, which we do with NetInsertSharedArrays:

```wolfram
net["single-shared"] = NetInsertSharedArrays[net["single-update"]];
net["nested-parallel"] =
 NetGraph[Association[
   Join[("nest-" <> ToString[#]) ->
       NetGraph[<|
         "nested" -> NetNestOperator[net["single-shared"], #],
         "part" -> PartLayer[{All, All, ;; 4}],
         "loss" -> MeanSquaredLossLayer[]|>, {"nested" ->
          "part" -> "loss"}] & /@
     Subdivide[50, 78, 2], {"max" ->
      ThreadingLayer[Max]}]], {{"nest-50", "nest-64", "nest-78"} ->
    "max" -> NetPort["Loss"]}]
```

We train a net for each target image (~1 hr per net on my GPU) and obtain:

```wolfram
(*
NetTrain[net["nested-parallel"],<|"Input"->{singleSeed},"Target"->{targetPokemon}|>,MaxTrainingRounds->25000,RandomSeeding->1996,TargetDevice->"GPU"]
*)
```

![animation](https://raw.githubusercontent.com/gvarnavi/generative-art-iap/master/01.27-Thursday/gifs/differentiable-CAs.gif)

As we can see, while some growth patterns are stable (like Charizard, Venusaur, and Arbok), others (like Rattata, Caterpie, and Spearow) are not.