In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from IPython import display
import time

from math import *

from torchvision.transforms.functional import rotate

import pygame

# Fall back to the CPU so the notebook still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Simulation grid size (cells) and on-screen zoom factor (screen pixels per cell).
XRES = 256
YRES = 256
MULT = 4

pygame.init()
# 24-bit window sized so every cell is shown as a MULT x MULT block.
screen = pygame.display.set_mode((XRES*MULT, YRES*MULT), 0, 24)

In [3]:
class RadConv(nn.Module):
    """Bias-free, circular-padded Conv2d whose kernels are made approximately rotation invariant.

    The randomly initialised kernel is rotated through ``n_angles`` evenly
    spaced angles in [0, 2*pi). The rotated copies are summed, each scaled by
    1/sqrt(n_angles). Note this is a scaled sum, not a mean.

    Args:
        NI: input channels.
        NO: output channels.
        radius: kernel half-width. The kernel is (2*radius+1) square, and the
            padding keeps the input's spatial size.
        n_angles: number of rotations to symmetrise over (default 512).
    """

    def __init__(self, NI, NO, radius, n_angles=512):
        super().__init__()
        self.conv = nn.Conv2d(NI, NO, 2*radius+1, padding=radius, padding_mode='circular', bias=False)

        with torch.no_grad():
            # The source kernel never changes inside the loop, so copy it once.
            base = self.conv.weight.detach().clone()
            sym_weight = torch.zeros_like(base)

            step = 2*pi/n_angles
            for theta in np.arange(n_angles)*step:
                # rotate() takes degrees and acts on the last two (spatial) dims.
                sym_weight += rotate(base, 180*theta/pi)/sqrt(n_angles)

            self.conv.weight.copy_(sym_weight)

    def forward(self, x):
        return self.conv(x)
    
class BlurFields(nn.Module):
    """Channel-wise soft-annulus blur with circular padding.

    Each channel is convolved with the same normalised radial kernel
    ``sigmoid(s1*(r - inner)) * sigmoid(s2*(outer - r))``. Channels are not
    mixed, because only the diagonal weight[i, i] is non-zero. The kernel sums
    to 1, so the per-channel total is preserved.

    Args:
        N: number of channels (input == output).
        inner: soft inner radius. A negative value (e.g. -2) gives an
            essentially filled disk.
        outer: soft outer radius. The kernel half-width is ``int(outer) + 2``.
        s1, s2: sharpness of the inner and outer edges.
    """

    def __init__(self, N, inner, outer, s1, s2):
        super().__init__()

        MR = int(outer)+2
        K = 2*MR+1
        self.conv = nn.Conv2d(N, N, [K, K], padding=[MR, MR], padding_mode='circular', bias=False)

        offsets = 1.0*np.arange(K)-MR
        yy, xx = np.meshgrid(offsets, offsets)
        # Euclidean distance of each tap from the kernel centre.
        r = torch.FloatTensor(np.sqrt(xx**2 + yy**2))

        rfilt = torch.sigmoid(s1*(r-inner))*torch.sigmoid((outer-r)*s2)
        rfilt = rfilt/torch.sum(rfilt)  # unit mass

        with torch.no_grad():
            self.conv.weight.zero_()
            for i in range(N):
                self.conv.weight[i, i] = rfilt

    def forward(self, x):
        return self.conv(x)

In [79]:
class AffinityNetwork(nn.Module):
    """Per-site MLP mapping (density, field) features to a DIM x DIM affinity matrix.

    Every site is processed independently, which is equivalent to a stack of
    1x1 convolutions. The 2*DIM input features pass through ``n_layers + 1``
    ELU-activated layers of width ``hidden`` and then a linear layer with
    DIM*DIM outputs. All weight matrices use orthogonal init with gain ``gain``.

    Args:
        DIM: number of species / channels.
        hidden: hidden width (default 32).
        n_layers: number of hidden-to-hidden layers (default 3).
        gain: orthogonal-init gain (default 3).
    """

    def __init__(self, DIM, hidden=32, n_layers=3, gain=3):
        super().__init__()

        self.linput = nn.Linear(DIM*2, hidden)
        nn.init.orthogonal_(self.linput.weight, gain=gain)

        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(n_layers)])
        for layer in self.layers:
            nn.init.orthogonal_(layer.weight, gain=gain)

        self.loutput = nn.Linear(hidden, DIM*DIM)
        nn.init.orthogonal_(self.loutput.weight, gain=gain)

    def forward(self, x, field):
        """Compute the affinity matrix at every site.

        Args:
            x: density of shape (1, DIM, XRES, YRES). The reshapes below
                require batch size 1.
            field: tensor of the same shape as ``x``.

        Returns:
            Contiguous tensor of shape (DIM, DIM, XRES, YRES). ``out[:, :, i, j]``
            is the affinity matrix of site (i, j).
        """
        _, dim, xres, yres = x.shape
        n_sites = xres*yres

        # (1, 2*DIM, X, Y) -> (X*Y, 2*DIM): one feature row per site.
        z = torch.cat([x, field], 1).permute(0, 2, 3, 1).contiguous().view(n_sites, 2*dim)

        z = F.elu(self.linput(z))
        for layer in self.layers:
            z = F.elu(layer(z))

        # (X*Y, DIM*DIM) -> (DIM, DIM, X, Y)
        out = self.loutput(z).view(1, xres, yres, dim*dim).permute(0, 3, 1, 2)
        return out.contiguous().view(dim, dim, xres, yres)
        
class AffinityCA(nn.Module):
    """Stochastic, mass-conserving CA in which species hop towards sites they have affinity for.

    The state is a density ``rho`` (1 x DIM x XRES x YRES) and a momentum ``p``
    (1 x 2 x XRES x YRES) shared by all species. One ``forward`` step works as
    follows:

    1. Two approximately rotation-invariant convolutions turn ``rho`` into a
       per-species ``field``.
    2. ``AffinityNetwork`` gives, at every site s, a DIM x DIM matrix
       ``affinity[a, b, s]``. It weights how much species ``a`` values the
       field of species ``b``.
    3. ``UPDATES`` times, each species proposes a Gaussian jump (std 8 cells)
       on top of the drift ``p``.
       - If the affinity-weighted field at the proposed site is >= the value
         at the current site, the jump is kept.
       - Otherwise the material only drifts by ``p``.
       A 1/UPDATES share of the density (and of its momentum) is deposited at
       the resulting torus-wrapped site.
    4. The result is blended 60/40 with the previous state to damp jitter.

    Args:
        DIM: number of species (density channels).
        UPDATES: number of proposal rounds per step.
    """

    def __init__(self, DIM, UPDATES):
        super().__init__()
        self.UPDATES = UPDATES

        # Not used by forward(). Kept so the module's parameters and RNG consumption stay unchanged.
        self.interaction = nn.Parameter(torch.rand(DIM, DIM))

        self.field_conv1 = RadConv(DIM, 16, 11)
        self.field_conv2 = RadConv(16, DIM, 11)

        # Not used by forward() (applying it to the affinity is disabled). Kept for experiments.
        self.aff_propagator = BlurFields(DIM, -2, 4, 4, 1)

        self.affinity = AffinityNetwork(DIM)

    def forward(self, rho, p):
        """Advance the CA by one step.

        Args:
            rho: density of shape (1, DIM, XRES, YRES). Batch size must be 1.
            p: momentum of shape (1, 2, XRES, YRES), on the same device as
                ``rho``. Channel 0 is displacement along the XRES axis and
                channel 1 along the YRES axis.

        Returns:
            ``(rho, p)`` with the same shapes, on ``rho``'s device.
        """
        _, DIM, XRES, YRES = rho.shape
        n_sites = XRES*YRES
        device = rho.device

        # Per-species field derived from the current density.
        field = self.field_conv2(F.elu(self.field_conv1(rho)))
        field_flat = field.reshape(DIM, n_sites)

        # Integer grid coordinates of every cell: channel 0 = XRES index, channel 1 = YRES index.
        xs = torch.arange(XRES, device=device, dtype=torch.float32).view(XRES, 1).expand(XRES, YRES)
        ys = torch.arange(YRES, device=device, dtype=torch.float32).view(1, YRES).expand(XRES, YRES)
        pos0 = torch.stack([xs, ys]).unsqueeze(0)  # 1 x 2 x XRES x YRES
        res = torch.tensor([XRES, YRES], device=device).view(1, 2, 1, 1)
        res_f = res.float()

        def flat_sites(eta):
            """Turn per-species displacements (DIM x 2 x XRES x YRES) into flat, torus-wrapped site indices."""
            # floor (not truncation toward zero) so negative coordinates round the same
            # way as positive ones; integer % wraps negatives into [0, res).
            pos = torch.floor(pos0 + 0.5 + eta + p + res_f).long() % res
            pos = pos.view(DIM, 2, n_sites)
            return pos[:, 1] + YRES*pos[:, 0]

        # affinity[a, b, s]: how strongly species a at site s is drawn to the field of species b.
        affinity = self.affinity(rho, field).view(DIM, DIM, n_sites)
        affinity_src = (affinity*field_flat.unsqueeze(0)).sum(1)  # DIM x n_sites

        dest = torch.zeros(DIM, n_sites, device=device)
        dest_p = torch.zeros(DIM, 2, n_sites, device=device)

        rho_share = rho.reshape(DIM, n_sites)/self.UPDATES
        p_flat = p.reshape(1, 2, n_sites)

        for _ in range(self.UPDATES):
            # Propose a random jump for every species at every site.
            eta = 8*torch.randn(DIM, 2, XRES, YRES, device=device)
            proposal = flat_sites(eta)

            # Evaluate species a's affinity at ITS OWN proposed site, for every field channel b:
            # field_at_dest[a, b, s] = field[b, proposal[a, s]].
            field_at_dest = field_flat[:, proposal].transpose(0, 1)
            affinity_dest = (affinity*field_at_dest).sum(1)

            # Keep jumps that do not lower the affinity; rejected material only drifts with p.
            accept = torch.ge(affinity_dest, affinity_src).to(eta.dtype)
            eta = eta*accept.view(DIM, 1, XRES, YRES)
            target = flat_sites(eta)

            # Deposit this round's share of mass, and mass-weighted momentum, at the target sites.
            dp = p_flat + eta.view(DIM, 2, n_sites)
            dest.scatter_add_(1, target, rho_share)
            dest_p.scatter_add_(2, target.unsqueeze(1).expand(DIM, 2, n_sites), rho_share.unsqueeze(1)*dp)

        # Mass-weighted mean momentum per site (0 where no mass arrived).
        dest_p = dest_p.sum(0)/(1e-8 + dest.sum(0, keepdim=True))
        dest_p = dest_p.view(1, 2, XRES, YRES)

        # Blend with the last frame to remove jitter and slow down the dynamics.
        rho_next = 0.4*rho + 0.6*dest.view(1, DIM, XRES, YRES)
        p_next = 0.4*p + 0.6*dest_p
        return rho_next, p_next
    
class Renderer(nn.Module):
    """Map a multi-species density to an RGB image computed on the CPU.

    Species k of DIM gets a colour from a cosine palette, with phase k/DIM and
    RGB channels offset by 2*pi/3. Each pixel is built in four steps:
    1. Take the density-weighted mean species colour.
    2. Scale it by the total density, clamped to [0, 1].
    3. Desaturate it slightly.
    4. Subtract it from a uniform 0.8 background.
    """

    def __init__(self, DIM):
        super().__init__()
        u = np.arange(DIM)/(DIM)
        u = u.reshape((DIM,1,1,1))

        w = 0.25
        v = 0.5
        color = np.clip(np.concatenate([w+v*np.cos(2*pi*u), w+v*np.cos(2*pi*u+2*pi/3), w+v*np.cos(2*pi*u+4*pi/3)], axis=3), 0, 1)

        # Plain attribute (not a registered buffer), so Module.to() leaves it on the CPU.
        self.colors = torch.FloatTensor(color).cpu()  # DIM x 1 x 1 x 3

    def forward(self, rho):
        """Render a density field.

        Args:
            rho: density of shape (1, DIM, XRES, YRES), on any device. The grid
                size is taken from ``rho`` itself, not from module-level globals.

        Returns:
            CPU float tensor of shape (XRES, YRES, 3) with values in [0, 1].
        """
        rho = rho.cpu()
        rhosum = rho.sum(1, keepdim=True)  # 1 x 1 x XRES x YRES

        # Density-weighted mean colour: (DIM, X, Y, 1) * (DIM, 1, 1, 3), summed over species.
        frac = (rho/(1e-8+rhosum)).squeeze(0).unsqueeze(3)
        urho = (frac*self.colors).sum(0)  # XRES x YRES x 3

        # Fade to the background where there is little material.
        urho = urho*torch.clamp(rhosum[0, 0].unsqueeze(2), 0, 1)

        # Reduce saturation
        urho = 0.9*urho + 0.1*torch.mean(urho, 2).unsqueeze(2)

        return torch.clamp(0.8 - urho, 0, 1)

In [76]:
# Number of species (density channels).
DIM = 12

ca = AffinityCA(DIM, UPDATES=10).to(DEVICE)  # 10 proposal rounds per step
render = Renderer(DIM).to(DEVICE)

In [77]:
# Initial state: "field" is the momentum passed to AffinityCA.forward as p (starts at rest);
# rho is a faint uniform-random density background.
field = torch.zeros(1,2,XRES,YRES).to(DEVICE)
rho = torch.rand(1,DIM,XRES,YRES).to(DEVICE)*0.01

# Stamp one disk (radius 16, density +2) per species at a random centre at least
# 16 cells from the border. One vectorized mask per species replaces the
# per-pixel Python loop, which issued a separate tiny device write for every pixel.
grid_x = torch.arange(XRES, device=DEVICE).view(XRES, 1)
grid_y = torch.arange(YRES, device=DEVICE).view(1, YRES)
for i in range(DIM):
    x = np.random.randint(XRES-32)+16
    y = np.random.randint(YRES-32)+16
    disk = (grid_x - x)**2 + (grid_y - y)**2 <= 16*16
    rho[0, i] += 2*disk.float()

In [78]:
from PIL import Image

# Set SAVE_FRAMES = True to also write each rendered frame to FRAME_DIR as a PNG.
SAVE_FRAMES = False
FRAME_DIR = "/sata/frames"
RENDER_EVERY = 1  # render (and optionally save) every N-th simulation step

frame = 0
running = True

# Interactive loop. Close the window or press Esc to stop. Press R to re-randomize
# the CA's weights; the current density/momentum state is kept.
with torch.no_grad():
    while running:
        rho, field = ca(rho, field)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYUP:
                if event.key == pygame.K_r:
                    ca = AffinityCA(DIM, 10).to(DEVICE)
                elif event.key == pygame.K_ESCAPE:
                    running = False

        frame += 1

        if frame % RENDER_EVERY == 0:
            im = render(rho).cpu().detach().numpy()
            im = np.clip(255*im, 0, 255).astype(np.uint8)

            if SAVE_FRAMES:
                Image.fromarray(im).save("%s/%.8d.png" % (FRAME_DIR, frame // RENDER_EVERY))

            # Nearest-neighbour upscale so each cell covers MULT x MULT screen pixels.
            im = im.repeat(MULT, axis=0).repeat(MULT, axis=1)

            screen.blit(pygame.pixelcopy.make_surface(im), (0, 0))
            pygame.display.update()
            time.sleep(0.005)

KeyboardInterrupt: 

In [75]:
# Mean of the momentum's channel-0 (XRES-axis) component.
field.select(1, 0).mean()

tensor(-0.0706, device='cuda:0')

In [82]:
# Channel-1 (YRES-axis) component of the momentum, shape (1, XRES, YRES).
field.select(1, 1)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')