In [2]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from math import *

import time

In [160]:
# DIM: number of rho (mass) channels in the automaton state.
# CARRIERS: number of sub-particles each cell's mass is split into by AdvCA.
DIM, CARRIERS = 3, 9

# Rotate N x N x N x 2 vectors by angle

def rotate(x, angle): 
    ca = torch.cos(angle)
    sa = torch.sin(angle)
    
    return torch.cat([ca*x[:,:,:,0:1]+sa*x[:,:,:,1:2], -sa*x[:,:,:,0:1]+ca*x[:,:,:,1:2]], 3)

def sigmoid(x):
    """NumPy logistic function with slope 1/4: 1 / (1 + exp(-0.25 * x)).

    Accepts Python scalars or NumPy arrays.
    """
    denom = 1 + np.exp(-0.25 * x)
    return 1 / denom

def mirror(p, res=None):
    """Fold coordinates back into [0, res - 1] by reflecting at both edges.

    Works elementwise on a tensor of any shape and dtype. Values below 0 are
    reflected about 0 and values above ``res - 1`` about ``res - 1``. Outputs
    stay inside [0, res - 1] for inputs in [-(res - 1), 3 * (res - 1)].

    Args:
        p: tensor of coordinates (float or integer).
        res: grid size; defaults to the notebook-level ``RES`` (looked up at
            call time, so ``RES`` may be defined in a later cell).

    Returns:
        A tensor of the same shape as ``p`` with the reflected coordinates.
    """
    if res is None:
        res = RES
    edge = res - 1
    # |edge - |edge - p|| reflects about `edge` (inner abs) and then about 0
    # (outer abs). torch.abs also stays finite for infinite inputs, where the
    # previous mask-and-blend formulation produced NaN (0 * inf).
    return torch.abs(edge - torch.abs(edge - p))

class ConvObj(nn.Module):
    """Fixed (non-learned) ring-shaped neighbourhood sensor.

    One convolution produces, for every input channel ``i``:

    * ``s[:, i]`` - the ring-weighted average of channel ``i``;
    * ``v[:, i]`` - the ring-weighted first moment of channel ``i``. The
      kernel offsets along H (component 0) and W (component 1) are divided by
      the kernel half-width ``MR``.

    The radial profile is ``sigmoid(s1 * (r - inner)) * sigmoid(s2 * (outer - r))``.
    All kernels are divided by the profile's total weight.

    Args:
        inner, outer: radii (in cells) of the ring's soft inner / outer edges.
        s1, s2: steepness of the inner and outer edges.
        dim: number of channels; defaults to the notebook-level ``DIM``.
    """

    def __init__(self, inner, outer, s1, s2, dim=None):
        super().__init__()
        if dim is None:
            dim = DIM
        self.dim = dim

        MR = int(outer) + 3  # kernel half-width: outer radius plus a margin
        self.conv = nn.Conv2d(dim, 3 * dim, 1 + 2 * MR, padding=MR,
                              padding_mode='circular', bias=False)

        # Integer offset grids: xx varies along the kernel's H axis, yy along W.
        yy, xx = np.meshgrid(np.arange(2 * MR + 1) - MR, np.arange(2 * MR + 1) - MR)
        r = torch.from_numpy(np.sqrt(xx ** 2 + yy ** 2)).float()
        xx = torch.from_numpy(xx).float()
        yy = torch.from_numpy(yy).float()

        rfilt = torch.sigmoid(s1 * (r - inner)) * torch.sigmoid((outer - r) * s2)

        with torch.no_grad():
            weight = self.conv.weight
            weight.zero_()
            for i in range(dim):
                weight[i, i] = rfilt
                weight[i + dim, i] = rfilt * xx / MR
                weight[i + 2 * dim, i] = rfilt * yy / MR
            # Normalise by the profile's total so that s is a weighted mean.
            weight /= weight[0, 0].sum()

    def forward(self, x, pool_factor=1):
        """Return ``(s, v)`` for ``x`` of shape (B, dim, H, W).

        ``s`` has shape (B, dim, H, W) and ``v`` has shape (B, dim, H, W, 2).
        With ``pool_factor != 1``, the convolution runs on an average-pooled
        copy. The result is then upsampled (bicubic) back to ``x``'s spatial
        size.
        """
        p = x
        if pool_factor != 1:
            p = F.avg_pool2d(p, pool_factor)
        p = self.conv(p)
        if pool_factor != 1:
            # F.upsample is deprecated. Asking for x's size (rather than a
            # scale factor) keeps the output aligned with x even when H or W
            # is not divisible by pool_factor.
            p = F.interpolate(p, size=tuple(x.shape[-2:]), mode='bicubic',
                              align_corners=False)

        d = self.dim
        s = p[:, :d]
        v = torch.stack([p[:, d:2 * d], p[:, 2 * d:3 * d]], dim=4)
        return s, v

In [161]:
class AdvCAParams(nn.Module):
    """Shape template for the rule network used by ``AdvCA``.

    In this notebook the layers are not applied directly. ``load_params``
    only uses their names and shapes to carve a flat parameter vector into
    per-layer tensors.
    """

    def __init__(self):
        super().__init__()
        hidden = 6
        # Registration order fixes the layout of the flat vector in load_params.
        self.coe1 = nn.Linear(3 * DIM, hidden)
        self.coe2 = nn.Linear(hidden, hidden)
        self.coe_adv = nn.Linear(hidden, DIM * (3 * DIM + 1), bias=False)

    def load_params(self, eom):
        """Split the flat vector ``eom`` into a dict of parameter-shaped views.

        Keys are the ``named_parameters()`` names (e.g. ``'coe1.weight'``).
        Values are views into ``eom``, taken consecutively from index 0 in
        ``named_parameters()`` order. Trailing elements of ``eom`` beyond the
        total parameter count are ignored.
        """
        layout = {}
        start = 0
        for name, param in self.named_parameters():
            stop = start + param.numel()
            layout[name] = eom[start:stop].view(param.shape)
            start = stop
        return layout
            
class AdvCA(nn.Module):
    """Advection cellular automaton over ``DIM`` mass channels (``rho``).

    Each step:

    * senses ``rho`` at three ring radii;
    * maps the sensed features through a small MLP whose weights come from
      the flat vector ``eom``;
    * moves each cell's mass as ``CARRIERS`` sub-particles;
    * diffuses the result and renders a colour image.

    Uses the notebook globals ``RES``, ``DIM`` and ``CARRIERS``, and the
    helpers ``ConvObj``, ``AdvCAParams`` and ``mirror``.
    """

    def __init__(self):
        super().__init__()

        self.advp = AdvCAParams()

        # Integer grid coordinates, shape (1, 1, RES, RES, 2, 1): component 0
        # indexes dim 2 of the field tensors, component 1 indexes dim 3.
        yy, xx = np.meshgrid(np.arange(RES), np.arange(RES))
        zz = np.concatenate([xx[np.newaxis, np.newaxis, :, :, np.newaxis, np.newaxis],
                             yy[np.newaxis, np.newaxis, :, :, np.newaxis, np.newaxis]], axis=4)

        self.pos = nn.Parameter(torch.FloatTensor(zz))

        # Neighbourhood sensors at three ring radii.
        self.conv1 = ConvObj(-2, 3, 4, 4)
        self.conv2 = ConvObj(4, 7, 1, 1)
        self.conv3 = ConvObj(9, 17, 1, 1)

        # Display colour per rho channel, stored as (3, 4, 1, 1).
        colors = np.array(
            [[[[1.0, 0.0, 0.0, 1.0],
               [0.0, 1.0, 0.0, 1.0],
               [0.0, 0.0, 1.0, 1.0],
               ]]])

        # Sub-particle offsets, shape (1, 1, 1, 1, 2, CARRIERS). Carrier 0 is
        # the zero offset; the others are unit vectors evenly spaced around
        # the circle. Built on the CPU and moved to the data's device in
        # forward(), so the module is not tied to CUDA.
        carriers = torch.zeros((1, 1, 1, 1, 2, CARRIERS))
        for i in range(CARRIERS - 1):
            angle = 2 * pi * i / (CARRIERS - 1)
            carriers[:, :, :, :, 0, 1 + i] = cos(angle)
            carriers[:, :, :, :, 1, 1 + i] = sin(angle)
        self.carriers = carriers

        self.colors = nn.Parameter(torch.FloatTensor(colors).permute(2, 3, 0, 1))

        self.periodic = True

    def set_periodic(self, periodic):
        """Choose wrap-around (``True``) or mirrored (``False``) boundaries.

        Affects the neighbourhood convolutions, particle destinations and
        diffusion.
        """
        self.periodic = periodic
        mode = 'circular' if periodic else 'reflect'
        for sensor in (self.conv1, self.conv2, self.conv3):
            sensor.conv.padding_mode = mode

    def forward(self, fields, eom):
        """Advance the automaton by one step.

        Args:
            fields: tensor with ``DIM + 3`` channels. These are the ``DIM``
                rho channels, one stored-magnitude channel and two
                remembered-velocity channels. The reshapes below assume a
                batch size of 1 and a ``RES x RES`` grid.
            eom: flat rule-parameter vector, unpacked by ``AdvCAParams.load_params``.

        Returns:
            ``(new_fields, view)``: ``new_fields`` has the same channel layout
            as ``fields``; ``view`` is a ``(RES, RES, 4)`` uint8 image.
        """
        advals = self.advp.load_params(eom)

        rho = fields[:, :DIM, :, :]
        old_mag = fields[:, DIM:DIM + 1, :, :]
        vmem = fields[:, DIM + 1:DIM + 3, :, :].permute(0, 2, 3, 1).unsqueeze(1)

        # Ring-weighted averages (s*) and first moments (v*) at three radii.
        s1, v1 = self.conv1(rho)
        s2, v2 = self.conv2(rho)
        s3, v3 = self.conv3(rho)

        # Soft normalisation: across channels for s*, across the 2 components for v*.
        s1 = s1 / torch.sqrt(0.1 + torch.sum(s1 ** 2, 1).unsqueeze(1))
        s2 = s2 / torch.sqrt(0.1 + torch.sum(s2 ** 2, 1).unsqueeze(1))
        s3 = s3 / torch.sqrt(0.1 + torch.sum(s3 ** 2, 1).unsqueeze(1))

        v1 = v1 / torch.sqrt(0.1 + torch.sum(v1 ** 2, 4).unsqueeze(4))
        v2 = v2 / torch.sqrt(0.1 + torch.sum(v2 ** 2, 4).unsqueeze(4))
        v3 = v3 / torch.sqrt(0.1 + torch.sum(v3 ** 2, 4).unsqueeze(4))

        sym = torch.cat([s1, s2, s3], 0)  # 3 x DIM x RES x RES
        vec = torch.cat([v1, v2, v3, vmem], 1)  # 1 x (3*DIM+1) x RES x RES x 2

        z = F.elu(F.linear(sym.view(3 * DIM, RES * RES).transpose(1, 0),
                           weight=advals['coe1.weight'], bias=advals['coe1.bias']))
        z = F.elu(F.linear(z, weight=advals['coe2.weight'], bias=advals['coe2.bias']))

        adv = F.linear(z, weight=advals['coe_adv.weight'], bias=None)
        adv = adv.transpose(1, 0).view(DIM, 3 * DIM + 1, RES, RES, 1)
        adv = adv / (1 + torch.abs(adv))  # softsign: each coefficient in (-1, 1)

        # Advection velocity per rho channel: coefficient-weighted sum of the sensed vectors.
        adv = torch.sum(adv * vec, 1).unsqueeze(0)  # 1 x DIM x RES x RES x 2
        adv = adv.unsqueeze(5)  # 1 x DIM x RES x RES x 2 x 1

        advn = torch.sqrt(torch.sum(adv ** 2, 4).unsqueeze(4))

        # Where the new mass-weighted speed is >= the stored one, store it
        # together with the channel-mean velocity. Elsewhere keep the
        # remembered velocity and decay the stored magnitude by 1%.
        new_mag = torch.sum(rho * advn.view(1, DIM, RES, RES), 1)  # 1 x RES x RES
        gate = torch.ge(new_mag, old_mag).float()

        old_mag = gate * new_mag + (1 - gate) * old_mag * 0.99

        newv = gate * (adv.mean(1).view(1, RES, RES, 2).permute(0, 3, 1, 2).contiguous())
        vmem = newv + (1 - gate) * vmem.view(1, RES, RES, 2).permute(0, 3, 1, 2).contiguous()

        if self.carriers.device != adv.device:
            self.carriers = self.carriers.to(adv.device)
        adv = adv + self.carriers  # 1 x DIM x RES x RES x 2 x CARRIERS

        # Particle destinations (+0.5 then truncation to integer cells).
        if self.periodic:
            p_pos = ((self.pos + adv + 0.5 + 2 * RES).long()) % RES
        else:
            p_pos = mirror((self.pos + adv + 0.5)).long()

        p_pos = p_pos.transpose(5, 4).contiguous()
        p_pos = p_pos.view(DIM, CARRIERS * RES * RES, 2)
        p_pos = p_pos[:, :, 0] * RES + p_pos[:, :, 1]  # flat destination cell index

        # Each cell's mass is split evenly over its CARRIERS sub-particles.
        mass = rho.view(DIM, RES, RES, 1).expand(DIM, RES, RES, CARRIERS) / CARRIERS
        mass = mass.view(DIM, RES * RES * CARRIERS)

        zs = torch.zeros(DIM, RES * RES, device=rho.device)
        zs.scatter_add_(index=p_pos, src=mass, dim=1)
        zs = zs.view(1, DIM, RES, RES)

        # Per-channel diffusion with a 5-point stencil (one rate per channel;
        # hard-coded for DIM == 3).
        diff = torch.tensor([0.4, 0.3, 0.2], device=zs.device).view(1, 3, 1, 1)
        if self.periodic:
            neighbours = (torch.roll(zs, -1, 3) + torch.roll(zs, 1, 3) +
                          torch.roll(zs, -1, 2) + torch.roll(zs, 1, 2))
        else:
            # Replicated edges: no mass flows across the border, and none
            # wraps around to the opposite side.
            padded = F.pad(zs, (1, 1, 1, 1), mode='replicate')
            neighbours = (padded[:, :, 1:-1, 2:] + padded[:, :, 1:-1, :-2] +
                          padded[:, :, 2:, 1:-1] + padded[:, :, :-2, 1:-1])
        zs = (1 - diff) * zs + diff * 0.25 * neighbours

        zs = torch.cat([zs[:, :DIM, :, :], old_mag, vmem], 1)

        # Render: mass-weighted mix of per-channel colours on a constant base.
        basecolor = 0.1

        colors = (zs[0, :DIM].unsqueeze(1) * self.colors).sum(0)
        colors = colors.permute(1, 2, 0)
        norm = zs[0, :DIM].sum(0).unsqueeze(2)

        view = (basecolor + colors / (0.1 + norm))

        view = torch.clamp(256 * view, 0, 255).byte()

        return zs, view

In [176]:
import pygame

# Display setup. RES is the simulation grid size (read at call time by the
# model code above); each cell is drawn as a MULT x MULT pixel block.
RES = 512
MULT = 2
pygame.init()
screen = pygame.display.set_mode((RES * MULT,) * 2, 0, 24)

In [177]:
# Initial state: every channel is zero except a centred disk (radius 128
# cells) in which each rho channel is 1.5. A vectorised mask replaces the
# per-pixel Python loop over all RES * RES positions.
fields = torch.zeros(1, DIM + 3, RES, RES, device='cuda')
rows, cols = np.ogrid[:RES, :RES]
disk = (rows - RES // 2) ** 2 + (cols - RES // 2) ** 2 < 128 * 128
fields[0, :DIM, torch.from_numpy(disk).cuda()] = 1.5

# Random rule vector. Its length is the parameter count of AdvCAParams
# (282 for DIM = 3), so it stays correct if DIM changes.
eom = 4 * torch.randn(sum(p.numel() for p in AdvCAParams().parameters())).cuda()

In [178]:
# torch.save(eom, "trails.pth")  # uncomment to save the current rule vector

In [179]:
model = AdvCA().cuda()

# Length of the flat rule vector `eom` that model.forward() unpacks.
print(torch.cat(tuple(param.reshape(-1) for param in model.advp.parameters())).shape)

torch.Size([282])


In [180]:
from PIL import Image

from pathlib import Path

# Frame recording (off by default). When enabled, existing PNGs in FRAMES_DIR
# are removed first, then every second frame is written as %.6d.png.
SAVE_FRAMES = False
FRAMES_DIR = Path("/sata/frames")

if SAVE_FRAMES:
    for old_frame in FRAMES_DIR.glob("*.png"):
        old_frame.unlink()

frame = 0
running = True

# Main loop: step the automaton, handle input, draw. Closing the window ends
# the loop and shuts pygame down (re-run the display-setup cell to restart).
while running:
    with torch.no_grad():
        fields, view = model.forward(fields, eom)

    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.KEYDOWN and event.key == pygame.K_n:
            # 'n': new random rule of the same length. The state is reset to a
            # 0.1 background with a 0.5 disk of radius 64. The magnitude and
            # velocity channels are zeroed; they live on axis 1, not axis 0.
            eom = 4 * torch.randn(eom.numel(), device=eom.device)
            fields = torch.full((1, DIM + 3, RES, RES), 0.1, device=fields.device)
            fields[:, DIM:] = 0
            rows, cols = np.ogrid[:RES, :RES]
            disk = (rows - RES // 2) ** 2 + (cols - RES // 2) ** 2 < 64 * 64
            fields[0, :DIM, torch.from_numpy(disk).to(fields.device)] = 0.5

    # Field mode: keep the first 3 colour components and upscale each cell to
    # a MULT x MULT pixel block.
    rgb = view.cpu().numpy().astype(np.uint8)[:, :, :3]
    s_image = np.repeat(np.repeat(rgb, MULT, axis=0), MULT, axis=1)

    if SAVE_FRAMES and frame % 2 == 0:
        frame_path = FRAMES_DIR / ("%.6d.png" % (frame // 2))
        Image.fromarray(s_image.transpose(1, 0, 2)).save(str(frame_path))
    frame += 1

    screen.blit(pygame.pixelcopy.make_surface(s_image), (0, 0))
    pygame.display.update()
    time.sleep(0.005)

pygame.quit()
    

KeyboardInterrupt: 