# [Emergent Dynamics in Neural Cellular Automata](https://arxiv.org/pdf/2404.06406)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import DTD
from torchvision import transforms
from torch.optim import Adam

## Define the Convolution Kernels
The paper uses Sobel filters for x and y directions, as well as a Laplacian filter, and an identity filter in the convolution. 

We define them as PyTorch tensors:

In [2]:
def get_conv_kernels():
  """Defines the convolution kernels used in the NCA."""
  kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
  ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
  klap = torch.tensor([[1, 1, 1], [1, -8, 1], [1, 1, 1]], dtype=torch.float32)
  kid = torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype = torch.float32) # Identity filter
  return kx, ky, klap, kid

## Implement the NCA Module

Create a custom PyTorch nn.Module to encapsulate the NCA logic, making it easier to integrate into the training loop.

In [3]:
class NCA(nn.Module):
    def __init__(self, num_channels, hidden_neurons):
        super(NCA, self).__init__()
        self.num_channels = num_channels
        self.hidden_neurons = hidden_neurons

        kx, ky, klap, kid = get_conv_kernels()
        self.conv_kernels = torch.stack([kx, ky, klap, kid]).unsqueeze(1)

        # Define the MLP for the update rule
        self.mlp = nn.Sequential(
            nn.Linear(4 * num_channels, hidden_neurons),
            nn.ReLU(),
            nn.Linear(hidden_neurons, num_channels)
        )


    def perception(self, state):
        """Applies the convolution kernels to the cell state."""
        b, c, h, w = state.shape
        # Move filters to the device
        conv_kernels = self.conv_kernels.to(state.device)
        # Perform convolution
        # Note: We need to pad the edges of the state in the correct manner 
        padded_state = F.pad(state, (1,1,1,1), mode='circular')
        # Apply depthwise convolution
        conv_out = F.conv2d(padded_state, conv_kernels, padding=0, groups=c)
        # Reshape conv_out to combine all channels to run through MLP
        conv_out = conv_out.permute(0, 2, 3, 1).reshape(b, h, w, -1)
        return conv_out

    def update(self, state):
       """Applies one update rule of the NCA."""
       b, c, h, w = state.shape

       # Compute Perception step using convolution
       perception_output = self.perception(state)

       # Apply MLP to get output
       mlp_output = self.mlp(perception_output)


       # Generate a random binary mask to achieve asynchronicity
       mask = (torch.rand(b, h, w, 1, device=state.device) > 0.5).float()

       # Apply the stochastic mask and update the state
       state = state + (mlp_output.permute(0, 3, 1, 2) * mask)
       return state

    def forward(self, initial_state, steps):
        """Applies the update rule for a specified number of steps."""
        state = initial_state
        for _ in range(steps):
            state = self.update(state)
        return state

In [4]:
import os
from urllib.request import urlretrieve
import zipfile
def download_file(url, destination):
    """Downloads a file to a local directory."""
    if not os.path.exists(destination):
        print(f"Downloading {url} to {destination}")
        urlretrieve(url, destination)

def download_spynet_weights():
     """Downloads the pretrained weights for the optical flow estimation."""
     spynet_url = "https://github.com/sniklaus/pytorch-spynet/releases/download/v1.0/spynet-sintel-final.pytorch"
     destination_dir = "./"
     destination_filename = os.path.join(destination_dir, "spynet-sintel-final.pytorch")
     
     download_file(spynet_url, destination_filename)
     return destination_filename

def compute_optical_flow(image1, image2, weights_path):
    """Computes the optical flow between two images."""
    try:
        from optical_flow import SpyNet
    except:
       print("Please move spynet.py file into the root of the project, refer to https://github.com/pytorch/examples/tree/main/optical_flow for reference")
       return None
    spynet = SpyNet().cuda()

    # Load the weights
    state_dict = torch.load(weights_path)
    spynet.load_state_dict(state_dict)

    image1 = image1.unsqueeze(0).cuda()
    image2 = image2.unsqueeze(0).cuda()
    flow = spynet(image1, image2).detach().cpu()
    return flow
 
def motion_strength(flow):
    """Calculates the motion strength of the given optical flow."""
    norm = torch.sqrt(flow[:,0,:,:]**2+flow[:,1,:,:]**2)
    return torch.mean(norm)

In [9]:
def train_nca(nca_model, dataset, epochs=6000, learning_rate=1e-3, batch_size = 1, steps_per_frame = 32, frames_per_metric = 100):
   optimizer = Adam(nca_model.parameters(), lr=learning_rate, weight_decay=1e-5)
   motion_strengths = []
   for epoch in range(epochs):
        for i, data in enumerate(dataset):
             optimizer.zero_grad()
             image = data[0].to(device)
             b, c, h, w = image.shape
             initial_state = torch.randn(b,nca_model.num_channels,h,w, device=device)
             
             final_state = nca_model(initial_state, steps_per_frame * frames_per_metric)
             
             # Motion strength is calculated using frames_per_metric sequential frames
             avg_motion_strength = 0
             for frame_idx in range(frames_per_metric - 1):
                img1 = final_state[:, :3, :, :].clone()
                img2 = nca_model(initial_state, (frame_idx + 1) * steps_per_frame)[:, :3, :, :].clone()
                
                # Normalize between 0 and 1 to be useable in the optical flow computation
                img1 = (img1 - img1.min()) / (img1.max() - img1.min())
                img2 = (img2 - img2.min()) / (img2.max() - img2.min())
                
                flow = compute_optical_flow(img1[0], img2[0], weights_path)
                if flow is not None:
                 avg_motion_strength += motion_strength(flow)

             avg_motion_strength /= (frames_per_metric - 1)

             loss = -avg_motion_strength
             loss.backward()
             optimizer.step()
        
        if epoch % 100 == 0:
         print(f"Epoch: {epoch} Loss: {loss.item()}, avg_motion_strength: {avg_motion_strength.item()}")
        motion_strengths.append(avg_motion_strength.item())
   return motion_strengths

In [13]:
# # Download spynet weights
# weights_path = download_spynet_weights()

# Download the DTD dataset, and select one of the four textures as in the paper
# We are downloading this to './data'
dataset_path = './data'

if not os.path.isdir(dataset_path):
   os.makedirs(dataset_path)
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((128, 128))])


dtd_dataset = DTD(root=dataset_path, split='train', download=True, transform=transform)

dataset_dir = os.path.join(dataset_path, 'dtd')
labels_dir = os.path.join(dataset_dir, 'labels')   
texture_idx = 0
texture_names = ['bubbly_0101', 'chequered_0121', 'interlaced_0172', 'cracked_0085']
print(f"Dtd: {dtd_dataset}")

# Import module
import os

# Assign directory
directory = r"gfg-test"

# Iterate over files in directory
for name in os.listdir("./data/dtd/dtd/labels"):
    print(f"Content of '{name}'")


dtd_indices = [idx for idx, label in enumerate(dtd_dataset.labels) if label == texture_names[texture_idx]]
subset_dataset = torch.utils.data.Subset(dtd_dataset, dtd_indices)
dataloader = torch.utils.data.DataLoader(subset_dataset, batch_size=1, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
   
C_values = range(8,128+1, 8)
D_values = range(16, 128+1, 16)

results = {}
# Loop through the values
for C in C_values:
    for D in D_values:
        print(f"Training for C: {C}, D: {D}")
        nca_model = NCA(num_channels=C, hidden_neurons=D).to(device)
        motion_strengths = train_nca(nca_model, dataloader)
        results[(C,D)] = motion_strengths

plt.figure(figsize=(12,8))
for C,D in results.keys():
    plt.plot(results[(C,D)], label = f"C: {C}, D: {D}")
plt.xlabel("Epochs")
plt.ylabel("Motion Strength")
plt.title("Motion Strength vs Training Epochs for various C & D")
plt.legend()
plt.show()
# Now, you can use the trained model for visualization and experimentation

Dtd: Dataset DTD
    Number of datapoints: 1880
    Root location: ./data
    split=train, partition=1
    StandardTransform
Transform: Compose(
               ToTensor()
               Resize(size=(128, 128), interpolation=bilinear, max_size=None, antialias=True)
           )
Content of 'labels_joint_anno.txt'
Content of 'test1.txt'
Content of 'test10.txt'
Content of 'test2.txt'
Content of 'test3.txt'
Content of 'test4.txt'
Content of 'test5.txt'
Content of 'test6.txt'
Content of 'test7.txt'
Content of 'test8.txt'
Content of 'test9.txt'
Content of 'train1.txt'
Content of 'train10.txt'
Content of 'train2.txt'
Content of 'train3.txt'
Content of 'train4.txt'
Content of 'train5.txt'
Content of 'train6.txt'
Content of 'train7.txt'
Content of 'train8.txt'
Content of 'train9.txt'
Content of 'val1.txt'
Content of 'val10.txt'
Content of 'val2.txt'
Content of 'val3.txt'
Content of 'val4.txt'
Content of 'val5.txt'
Content of 'val6.txt'
Content of 'val7.txt'
Content of 'val8.txt'
Content of 'val9

AttributeError: 'DTD' object has no attribute 'labels'