In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from IPython import display
import time

from math import *

from torchvision.transforms.functional import rotate

import pygame

# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Simulation grid size in cells, and the integer zoom factor used when drawing.
# XRES and YRES should be divisible by 4: AffinityCA average-pools the field by
# up to 4x and upsamples it back to full size.
XRES, YRES = 256, 256
MULT = 1

# Open a (XRES*MULT) x (YRES*MULT) window with 24-bit colour depth.
pygame.init()
window_size = (XRES * MULT, YRES * MULT)
screen = pygame.display.set_mode(window_size, 0, 24)

In [None]:
class RadConv(nn.Module):
    """Per-channel circular-padded convolution with approximately radially
    symmetric kernels.

    A randomly initialised (2*radius+1)^2 kernel is summed over `n_angles`
    evenly spaced rotations, each kernel is normalised to sum to 1, and all
    off-diagonal kernels weight[i, j] (i != j) are zeroed, so output channel i
    depends only on input channel i.

    Args:
        NI: number of input (= output) channels.
        radius: kernel half-width; padding equals radius, so the spatial
            size is preserved.
        n_angles: number of rotations averaged over (default 512, i.e. steps
            of pi/256 over [0, 2*pi), as in the original implementation).
    """

    def __init__(self, NI, radius, n_angles=512):
        super().__init__()
        self.conv = nn.Conv2d(NI, NI, 2 * radius + 1, padding=radius,
                              padding_mode='circular', bias=False)

        with torch.no_grad():
            base = self.conv.weight.detach().clone()
            acc = torch.zeros_like(base)
            # rotate() does not modify its input, so one copy of the base
            # kernel suffices for every angle.
            for theta in np.arange(n_angles) * (2 * pi / n_angles):
                acc += rotate(base, 180 * theta / pi)

            # Normalise every kernel to unit sum. A 1/n_angles averaging factor
            # would cancel here, so it is omitted.
            # NOTE(review): a kernel whose rotated sum is close to zero gets
            # very large weights after this division.
            acc /= acc.sum(dim=(2, 3), keepdim=True)

            # Keep only the diagonal (channel i -> channel i) kernels.
            off_diag = ~torch.eye(NI, dtype=torch.bool)
            acc[off_diag] = 0

            self.conv.weight.copy_(acc)

    def forward(self, x):
        """Apply the per-channel radial convolution to x (N x NI x H x W)."""
        return self.conv(x)
    
class BlurFields(nn.Module):
    """Per-channel isotropic blur with a soft ring/disc profile.

    The kernel value at distance r from the centre is
        sigmoid(s1 * (r - inner)) * sigmoid(s2 * (outer - r)),
    i.e. it rises around r = inner and falls around r = outer (a negative
    `inner` yields a soft disc). The kernel is normalised to sum to 1 and is
    applied only on the diagonal (channel i -> channel i), with circular
    padding so the spatial size is preserved.

    Args:
        N: number of channels.
        inner, outer: radii where the profile rises / falls.
        s1, s2: steepness of the rising / falling sigmoid edges.
    """

    def __init__(self, N, inner, outer, s1, s2):
        super().__init__()

        # Kernel half-width: two cells beyond the outer radius.
        MR = int(outer) + 2
        size = 2 * MR + 1
        self.conv = nn.Conv2d(N, N, [size, size], padding=[MR, MR],
                              padding_mode='circular', bias=False)

        # Distance of every kernel tap from the centre.
        offsets = np.arange(size, dtype=np.float64) - MR
        yy, xx = np.meshgrid(offsets, offsets)
        r = torch.FloatTensor(np.sqrt(xx ** 2 + yy ** 2))

        rfilt = torch.sigmoid(s1 * (r - inner)) * torch.sigmoid((outer - r) * s2)
        rfilt = rfilt / rfilt.sum()

        # Place the normalised profile on the diagonal kernels only.
        with torch.no_grad():
            self.conv.weight.copy_(torch.eye(N)[:, :, None, None] * rfilt)

    def forward(self, x):
        """Blur each channel of x (N_batch x N x H x W) independently."""
        return self.conv(x)

In [None]:
class AffinityNetwork(nn.Module):
    """Per-pixel MLP mapping concatenated (rho, field) channels to DIM outputs.

    Architecture: 2*DIM -> 16 -> 16 -> 16 -> 16 -> DIM, with ELU after every
    layer except the last and orthogonal weight init (gain 3). The same MLP is
    applied independently at every pixel, like a stack of 1x1 convolutions.
    """

    def __init__(self, DIM):
        super().__init__()

        self.linput = nn.Linear(2 * DIM, 16)
        nn.init.orthogonal_(self.linput.weight, gain=3)

        self.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])
        for layer in self.layers:
            nn.init.orthogonal_(layer.weight, gain=3)

        self.loutput = nn.Linear(16, DIM)
        nn.init.orthogonal_(self.loutput.weight, gain=3)

    def forward(self, rho, x):
        """Evaluate the MLP at every pixel.

        Args:
            rho: tensor of shape B x C x XRES x YRES; together with x its
                channels must add up to the 2*DIM inputs of the first layer.
            x: tensor of shape B x DIM x XRES x YRES.

        Returns:
            Tensor of shape B x DIM x XRES x YRES (any batch size B).
        """
        B, _, XRES, YRES = x.shape

        # Move channels last and flatten all pixels of all batch items into rows.
        z = torch.cat([rho, x], 1).permute(0, 2, 3, 1)
        z = z.reshape(B * XRES * YRES, z.shape[-1])

        z = F.elu(self.linput(z))
        for layer in self.layers:
            z = F.elu(layer(z))

        out = self.loutput(z)
        return out.view(B, XRES, YRES, out.shape[-1]).permute(0, 3, 1, 2).contiguous()
        
class AffinityCA(nn.Module):
    """Stochastic multi-channel cellular automaton driven by a learned affinity.

    State is a pair of tensors of shape 1 x DIM x XRES x YRES:
      * rho   -- per-channel density that is moved around the grid;
      * field -- per-channel signal that rho sources and blurring spreads.

    One forward() call performs one step:
      1. Field update: add rho, damp the field by
         clip(rho^T @ interaction, 0, 0.5), blur at 4x and 2x down-sampled
         scales, then shift each channel so that its minimum is 0.
      2. Rho update: an MLP maps the field to an affinity map. Each cell makes
         UPDATES random jump proposals per channel. A proposal is accepted when
         the destination's affinity is >= the cell's own; either way it carries
         1/UPDATES of the cell's mass (to the destination if accepted, back
         home otherwise). The result is blended 50/50 with the previous rho.

    Mass is only moved with scatter-add within a channel, so the total of each
    rho channel is conserved. Only batch size 1 is supported; XRES and YRES
    should be divisible by 4 for the multiscale blur to round-trip.
    Tensors are created on rho's device.
    """

    def __init__(self, DIM, UPDATES):
        super().__init__()
        self.UPDATES = UPDATES

        # interaction[j, k]: how strongly density of channel j damps field channel k.
        self.interaction = nn.Parameter(torch.rand(DIM, DIM))

        self.propagator = BlurFields(DIM, -2, 2, 4, 1)
        self.aff_propagator = BlurFields(DIM, -2, 2, 4, 1)

        self.affinity = AffinityNetwork(DIM)

    def forward(self, rho, field):
        """Advance the automaton by one step.

        Args:
            rho: density tensor, shape 1 x DIM x XRES x YRES.
            field: field tensor, same shape as rho.

        Returns:
            (rho, field): updated density and field, same shapes as the inputs.
        """
        field = self._update_field(rho, field)
        rho = self._move_mass(rho, field)
        return rho, field

    def _update_field(self, rho, field):
        """Source, damp, multiscale-blur and re-floor the field (batch size 1)."""
        _, DIM, XRES, YRES = rho.shape

        # Per-cell damping factor for each field channel, from the channel densities.
        coupling = torch.matmul(rho.view(DIM, XRES * YRES).permute(1, 0), self.interaction)
        damping = torch.clip(coupling.permute(1, 0).view(1, DIM, XRES, YRES), 0, 0.5)
        field = field + rho - field * damping

        # Blur at 4x then 2x down-sampled resolution, mixing each back in 50/50.
        for scale in range(2, 0, -1):
            factor = 2 ** scale
            z = F.avg_pool2d(field, factor)
            z = self.propagator(z)
            # Same as the deprecated F.upsample_bilinear.
            z = F.interpolate(z, scale_factor=factor, mode='bilinear', align_corners=True)
            field = 0.5 * field + 0.5 * z

        # Shift each channel so that its minimum is exactly zero.
        mfield, _ = torch.min(field.view(DIM, XRES * YRES), 1)
        return field - mfield.view(1, DIM, 1, 1)

    def _move_mass(self, rho, field):
        """Move rho towards sites of non-lower affinity by random proposals."""
        _, DIM, XRES, YRES = rho.shape
        n_cells = XRES * YRES
        device = rho.device

        # Integer (x, y) coordinates of every cell, shape 1 x 2 x XRES x YRES.
        yy, xx = np.meshgrid(np.arange(YRES), np.arange(XRES))
        pos0 = torch.as_tensor(np.stack([xx, yy])[np.newaxis], dtype=torch.float32, device=device)
        # Flat index x * YRES + y of every cell (row-major order), shape 1 x n_cells.
        home = torch.arange(n_cells, device=device).view(1, n_cells)

        # Affinity of every site. The rho input is zeroed, so it depends on the field only.
        affinity = self.aff_propagator(self.affinity(0 * rho, field))
        # Penalise sites with a large summed squared density. Not in-place,
        # so a pass-through affinity module cannot alias and corrupt `field`.
        affinity = affinity - 0.01 * torch.sum(rho ** 2, 1).unsqueeze(1)
        flat_affinity = affinity.view(DIM, n_cells)

        grid_size = torch.tensor([XRES, YRES], dtype=torch.long, device=device).view(1, 2, 1, 1)
        dest = torch.zeros(DIM, n_cells, dtype=rho.dtype, device=device)
        # Each proposal carries an equal 1/UPDATES share of the cell's mass
        # (max() only guards the division; no proposals run when UPDATES == 0).
        share = rho.view(DIM, n_cells) / max(self.UPDATES, 1)

        for _ in range(self.UPDATES):
            # Gaussian jump with std 2 cells. Adding 0.5 before truncating rounds to
            # the nearest site; the +grid_size offset keeps values positive before
            # the periodic wrap.
            jump = 2 * torch.randn(DIM, 2, XRES, YRES, device=device)
            target = (pos0 + 0.5 + jump + grid_size.float()).long() % grid_size
            target = target.view(DIM, 2, n_cells)
            target = target[:, 1] + YRES * target[:, 0]

            # Accept moves to sites of equal or higher affinity; otherwise stay home.
            accept = torch.gather(flat_affinity, 1, target) >= flat_affinity
            target = torch.where(accept, target, home)

            dest.scatter_add_(1, target, share)

        # Blend with last frame to remove jitter, slow down the dynamics.
        return 0.5 * rho + 0.5 * dest.view(1, DIM, XRES, YRES)
    
class Renderer(nn.Module):
    """Turn a 1 x DIM x XRES x YRES density tensor into an RGB image.

    Each channel gets a colour from a cosine palette. A cell's ink colour is the
    density-weighted mix of its channels' colours, scaled by the total density
    (clamped to 1). The ink is slightly desaturated and subtracted from a 0.8
    grey background.
    """

    def __init__(self, DIM):
        super().__init__()
        u = np.arange(DIM) / (DIM)
        u = u.reshape((DIM, 1, 1, 1))

        # Cosine palette: three phase-shifted cosines give one RGB colour per channel.
        w = 0.25
        v = 0.5
        color = np.clip(np.concatenate([w + v * np.cos(2 * pi * u),
                                        w + v * np.cos(2 * pi * u + 2 * pi / 3),
                                        w + v * np.cos(2 * pi * u + 4 * pi / 3)], axis=3), 0, 1)

        # Shape DIM x 1 x 1 x 3. Kept as a plain CPU attribute (not a buffer), so
        # .to(device) leaves it on the CPU where forward() does its work.
        self.colors = torch.FloatTensor(color).cpu()

    def forward(self, rho):
        """Render rho (1 x DIM x XRES x YRES) to a CPU XRES x YRES x 3 image in [0, 1]."""
        _, _, xres, yres = rho.shape
        rho_cpu = rho.cpu()

        rhosum = rho_cpu.sum(1).view(xres, yres, 1)
        # Per-cell channel proportions, shape DIM x XRES x YRES x 1.
        weights = (rho_cpu / (1e-8 + rhosum.view(1, 1, xres, yres))).squeeze(0).unsqueeze(3)
        urho = (weights * self.colors).sum(0)  # XRES x YRES x 3

        urho = urho * torch.clamp(rhosum, 0, 1)

        # Reduce saturation
        urho = 0.9 * urho + 0.1 * urho.mean(2, keepdim=True)

        return torch.clamp(0.8 - urho, 0, 1)

In [None]:
# Number of density / field channels.
DIM = 12

# Automaton with 10 random-move proposals per cell per step, plus its renderer.
ca = AffinityCA(DIM=DIM, UPDATES=10).to(DEVICE)
render = Renderer(DIM=DIM).to(DEVICE)

In [None]:
# Initial state: an empty field, and low-level uniform noise in every density
# channel plus one solid disc (radius 16 cells, +1 density) per channel at a
# random centre. Discs are clipped at the grid edge, not wrapped around.
field = torch.zeros(1, DIM, XRES, YRES).to(DEVICE)
rho = torch.rand(1, DIM, XRES, YRES).to(DEVICE) * 0.05

SEED_RADIUS = 16
# Cell coordinates, shape XRES x YRES (grid_x[x, y] == x, grid_y[x, y] == y).
grid_x, grid_y = np.meshgrid(np.arange(XRES), np.arange(YRES), indexing='ij')

for i in range(DIM):
    # Same draw order as before: x centre, then y centre, per channel.
    x = np.random.randint(XRES)
    y = np.random.randint(YRES)

    # Vectorised mask, instead of one scalar device write per cell.
    disc = (grid_x - x) ** 2 + (grid_y - y) ** 2 <= SEED_RADIUS ** 2
    rho[0, i] += torch.from_numpy(disc.astype(np.float32)).to(DEVICE)

In [None]:
# Main loop: step the automaton, handle window events, and redraw every
# RENDER_EVERY steps. Press R to re-create the CA with new random parameters
# (rho and field carry over). Closing the window ends the loop and calls
# pygame.quit(), so re-run the display-setup cell before running this again.
RENDER_EVERY = 1
frame = 0
running = True

with torch.no_grad():
    while running:
        rho, field = ca(rho, field)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYUP and event.key == pygame.K_r:
                ca = AffinityCA(DIM, 10).to(DEVICE)

        frame += 1

        if running and frame % RENDER_EVERY == 0:
            im = render(rho).cpu().numpy()
            im = np.clip(255 * im, 0, 255).astype(np.uint8)

            # Integer zoom: repeat each cell MULT times along both axes.
            im = im.repeat(MULT, axis=0).repeat(MULT, axis=1)

            screen.blit(pygame.pixelcopy.make_surface(im), (0, 0))
            pygame.display.update()
            time.sleep(0.005)

pygame.quit()