<a href="https://colab.research.google.com/github/ekgren/neural_ca/blob/main/neural_ca.ipynb" target="_parent">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Neural CA

In [None]:
# Clone over HTTPS: GitHub no longer serves the unauthenticated git:// protocol,
# and -l/-s only apply when cloning from a local path.
!git clone https://github.com/ekgren/neural_ca.git neural_ca
%cd neural_ca

In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import time
import torch.nn.functional as F
import torch.optim as optim
import random
import torchvision
from torchvision import datasets, models, transforms

import numpy as np

import matplotlib.pylab as pl
import glob
from IPython.display import Image, HTML, clear_output
import tqdm

from src.tools import *

In [None]:
class Net(nn.Module):
    """Per-cell update rule of a neural cellular automaton.

    The state grid is laid out as ``(in_dim, 1, H, W)``: every state channel
    sits on the batch axis, so one shared single-channel 3x3 convolution with
    circular (wrap-around) padding is applied to each channel independently.
    Each cell's perception vector is its own state plus the two convolution
    responses for every channel (``3 * in_dim`` values). A two-layer MLP maps
    that vector to an ``in_dim``-dimensional update.

    Args:
        in_dim: Number of state channels per cell.
        image_dim: Nominal side length of the grid. Kept for backward
            compatibility; ``forward`` infers the spatial size from its input,
            so grids of any (also non-square) size are accepted.
    """

    def __init__(self, in_dim=16, image_dim=64):
        super().__init__()
        self.in_dim = in_dim
        self.image_dim = image_dim
        # Two learned 3x3 filters shared by every channel (in_channels=1).
        self.conv = torch.nn.Conv2d(in_channels=1,
                                    out_channels=2,
                                    kernel_size=3,
                                    padding=(1, 1),
                                    groups=1,
                                    padding_mode='circular',
                                    bias=False)
        self.ff1 = nn.Linear(3 * self.in_dim, 3 * 4 * self.in_dim, bias=False)
        self.ff2 = nn.Linear(3 * 4 * self.in_dim, self.in_dim, bias=False)
        # Zero-initialize the output layer so the untrained rule outputs a
        # zero update for every cell.
        nn.init.zeros_(self.ff2.weight)

    def forward(self, state_grid):
        """Compute the update for every cell of ``state_grid``.

        Args:
            state_grid: Tensor of shape ``(in_dim, 1, H, W)``.

        Returns:
            Tensor of shape ``(in_dim, 1, H, W)`` holding each cell's update.

        Raises:
            ValueError: If ``state_grid`` is not of shape ``(in_dim, 1, H, W)``.
        """
        if state_grid.dim() != 4 or tuple(state_grid.shape[:2]) != (self.in_dim, 1):
            raise ValueError(
                f"expected state_grid of shape ({self.in_dim}, 1, H, W), "
                f"got {tuple(state_grid.shape)}")
        height, width = state_grid.shape[-2:]
        grad = self.conv(state_grid)                            # (in_dim, 2, H, W)
        perception_grid = torch.cat([state_grid, grad], dim=1)  # (in_dim, 3, H, W)
        # One row per cell: (H*W, 3*in_dim), ordered channel-major within a row.
        perception_vectors = perception_grid.reshape(3 * self.in_dim, height * width).t()
        update = self.ff2(torch.relu(self.ff1(perception_vectors)))  # (H*W, in_dim)
        return update.t().reshape(self.in_dim, 1, height, width)

In [None]:
# Number of state channels per cell (the first three are shown as RGB) and the
# side length of the square CA grid.
in_dim = 16
im_dim = 224

# Pretrained ImageNet classifier used only to score CA states. Its weights are
# not in the optimizer, so freeze them: backward() then computes gradients only
# with respect to its input instead of also accumulating .grad on every ResNet
# weight (which opt.zero_grad() would never clear).
# NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`; kept for older torchvision.
model_ft = models.resnet18(pretrained=True).eval().cuda()
model_ft.requires_grad_(False)
net = Net(in_dim=in_dim, image_dim=im_dim).train().cuda()

# Random initial state laid out as (in_dim, 1, im_dim, im_dim), as Net expects.
# It is only previewed below; the training cells draw a fresh state per episode
# before using it, so it is not handed to the optimizer.
a = torch.rand(in_dim, 1, im_dim, im_dim).cuda()
# Normalization and classification loss; only referenced by the commented-out
# alternative objective in the training cells.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

loss = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
# Preview the first three channels of the initial state as an (H, W, 3) image.
plt.imshow(a[:3, 0].permute(1, 2, 0).cpu())

In [None]:
# Train the CA rule for 10 episodes. Each episode starts from a fresh random
# state and rolls it out in 4 chunks of 32-96 steps. After every chunk the norm
# of the ResNet output is backpropagated through that chunk and minimized with
# one optimizer step. Detaching the state at the start of each chunk truncates
# backprop, so gradients never cross chunk boundaries.
for i in range(10):
    # Plain tensor: the initial state itself is not optimized, only `net`.
    a = torch.rand(in_dim, 1, im_dim, im_dim).cuda()
    for j in range(4):
        opt.zero_grad()
        a = a.detach()
        for k in range(random.randint(32, 96)):
            c = net(a)
            a = a + c
            # Channels 0-2 are the visible RGB image; keep them in [0, 1].
            a = torch.cat([a[:3].clamp(0, 1), a[3:]], dim=0)
        # Alternative objective (disabled): classify the normalized RGB image.
        #d = normalize(a.transpose(0, 1)[0, :3, :, :])
        #d = model_ft(d.unsqueeze(0))
        #l = loss(d, y_true)
        # (in_dim, 1, H, W) -> (1, in_dim, H, W); feed channels 0-2 as one image.
        d = model_ft(a.transpose(0, 1)[:, :3, :, :])
        l = d.norm()
        l.backward()
        opt.step()

# Record a rollout to video, adding every second step as a frame. VideoWriter
# and its `add` method come from src.tools (not shown here).
# NOTE(review): this cell still calls opt.step(), so the CA rule keeps training
# while it is being recorded; confirm that is intended.
with VideoWriter(fps=24) as vid:
    for i in range(1):
        # Plain tensor: the initial state itself is not optimized, only `net`.
        a = torch.rand(in_dim, 1, im_dim, im_dim).cuda()
        for j in range(4):
            opt.zero_grad()
            # Truncated backprop: start a fresh graph for this chunk.
            a = a.detach()
            for k in range(random.randint(32, 96)):
                c = net(a)
                a = a + c
                # Channels 0-2 are the visible RGB image; keep them in [0, 1].
                a = torch.cat([a[:3].clamp(0, 1), a[3:]], dim=0)
                if k % 2 == 0:
                    # (in_dim, 1, H, W) -> (H, W, 3) RGB frame on the CPU.
                    plot = a.detach().transpose(0, 2).transpose(1, 3)[:, :, :3, 0].cpu()
                    #vid.add(zoom(plot))
                    vid.add(plot)
            # Alternative objective (disabled): classify the normalized RGB image.
            #d = normalize(a.transpose(0, 1)[0, :3, :, :])
            #d = model_ft(d.unsqueeze(0).clamp(0, 1))
            #l = loss(d, y_true)
            d = model_ft(a.transpose(0, 1)[:, :3, :, :])
            l = d.norm()
            l.backward()
            opt.step()