In [143]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/Users/Heysoos/Documents/Pycharm Projects/Dissertation/03_adaptiveCA')

import numpy as np
import matplotlib.pyplot as plt

import torch
from tqdm.auto import tqdm
import cv2
import time

from models.MNNCA import CA, totalistic
import pygame
from src.utils import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Init

In [144]:
# Hyper-parameters of the cellular automaton used by the interactive viewer.
CHANNELS = 4             # channels in the grid
FILTERS = 4              # filters per channel
NET_SIZE = [32, 32, 32]  # hidden neurons per layer

# RADIUS = 11  (scalar form; the kernel-plot cell only runs when RADIUS is not a list)
RADIUS = [5 for _ in range(4)]

# Build the model with the settings above and move it to the GPU.
ca = CA(
    CHANNELS=CHANNELS,
    FILTERS=FILTERS,
    NET_SIZE=NET_SIZE,
    RADIUS=RADIUS,
).cuda()

In [145]:
# Plot the CA kernels as a FILTERS x CHANNELS grid of images.
# Only done when RADIUS is not a list, i.e. when this is not the
# slackermanz implementation.
if not isinstance(RADIUS, list):
    # Concatenate all kernels along dim 0 and bring them to the host as numpy.
    kernels = torch.cat(list(ca.rule.kernels), dim=0).cpu().detach().numpy()

    num_plot_kernels = CHANNELS
    # squeeze=False keeps `axes` two-dimensional even when FILTERS == 1 or
    # CHANNELS == 1, so a single indexing scheme works for every grid shape
    # (the default squeezing made axes[i, j] fail for a single column and
    # axes[j] fail for a single Axes object).
    fig, axes = plt.subplots(
        FILTERS,
        num_plot_kernels,
        figsize=(CHANNELS, 1.3*FILTERS),
        squeeze=False,
    )

    for i in range(FILTERS):
        for j in range(num_plot_kernels):
            kplot = kernels[i, j, :, :]
            # symmetric colour limits so that zero maps to the colormap centre
            kmax = np.max(np.abs(kplot))
            axes[i, j].imshow(kplot, vmin=-kmax, vmax=kmax)
            axes[i, j].axis('off')

In [146]:
# Brush settings: forwarded to click() as r=... and s=... by the interactive loop.
# NOTE(review): presumably brush radius and strength - confirm in src.utils.click.
r, s = 20, 1

In [158]:
# Interactive viewer: steps the CA on the GPU and shows channels of the state
# in an upscaled pygame window, with mouse/keyboard controls (listed below).

# resolution of grid
RESX = 256
RESY = 256

# pygame stuff
######################################
pygame.init()
size = RESX, RESY
# screen = pygame.display.set_mode(size)
screen = pygame.Surface(size)  # off-screen surface at grid resolution
UPSCALE = 2
RESXup, RESYup = int(RESX*UPSCALE), int(RESY*UPSCALE)
upscaled_screen = pygame.display.set_mode([RESXup, RESYup])  # the visible window

running = True
time_ticking = True
LMB_trigger = False
RMB_trigger = False
WHEEL_trigger = False
brush_toggle = False
cdim_order = np.arange(0, CHANNELS)  # which channels are shown, in which order

clock = pygame.time.Clock()
font = pygame.font.SysFont("Noto Sans", 12)
######################################

# CENTER SEED
# All-zero state with channels 3.. set to 1 at the centre cell.
# Named `seed_state` (not `seed`) so re-running this cell does not clobber the
# `seed` function that a later cell of this notebook defines.
seed_state = torch.zeros((1, CHANNELS, RESX, RESY), dtype=torch.float32, device='cuda')
seed_state[:, 3:, RESX//2, RESY//2] = 1
state = seed_state.clone()

update_rate = 1.
ticker = 0.
sink = False
export_imgs = False
imgs = []

# Radial field subtracted from the state on every step while `sink` is on.
# It does not depend on the state, so it is built once here instead of once
# per frame. Broadcasting reproduces meshgrid's default ('ij') layout:
# R[i, j] = sqrt(xv[i]**2 + yv[j]**2).
xv = torch.linspace(-1, 1, RESX)
yv = torch.linspace(-RESY/RESX, RESY/RESX, RESY)
R = torch.sqrt(xv[:, None]**2 + yv[None, :]**2)
sink_field = R.cuda()/3

# Controls
#   [ / ]        : decrease / increase update_rate by 0.5
#   mouse wheel  : adjust update_rate through WHEEL_param
#   space        : replace `ca` with a newly constructed CA
#   t            : toggle ca.rule.totalistic
#   b            : toggle brush_toggle (passed on to click)
#   g            : toggle ca.rule.use_growth_kernel
#   s            : toggle sink
#   e            : toggle export_imgs (append the state to `imgs` every step)
#   o            : save the current window as png
#   p            : pause / resume
#   r            : reset the state to the seed
#   middle click : reset cdim_order to the default channel order
#   LMB / RMB    : make / delete (through click)
# NOTE: update_rate is displayed but not passed to the active update call,
# ca.forward_slacker(state).

try:
    with torch.no_grad():
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False

                elif event.type == pygame.MOUSEBUTTONDOWN:
                    if event.button == 1:
                        LMB_trigger = True
                    elif event.button == 3:
                        RMB_trigger = True

                elif event.type == pygame.MOUSEBUTTONUP:
                    if event.button == 1:
                        LMB_trigger = False
                    elif event.button == 3:
                        RMB_trigger = False
                    elif event.button == 2:
                        # restore the default channel order
                        cdim_order = np.arange(0, state.shape[1])

                elif event.type == pygame.MOUSEWHEEL:
                    WHEEL_trigger = True
                    direction = event.y

                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_LEFTBRACKET:
                        update_rate -= 0.5
                    elif event.key == pygame.K_RIGHTBRACKET:
                        update_rate += 0.5
                    elif event.key == pygame.K_SPACE:
                        # pick another random CA
                        ca = CA(CHANNELS=CHANNELS, FILTERS=FILTERS, NET_SIZE=NET_SIZE, RADIUS=RADIUS).cuda()
                    elif event.key == pygame.K_t:
                        ca.rule.totalistic = not ca.rule.totalistic
                    elif event.key == pygame.K_b:
                        brush_toggle = not brush_toggle
                    elif event.key == pygame.K_g:
                        ca.rule.use_growth_kernel = not ca.rule.use_growth_kernel
                    elif event.key == pygame.K_s:
                        sink = not sink
                    elif event.key == pygame.K_e:
                        export_imgs = not export_imgs
                    elif event.key == pygame.K_o:
                        # save current state as png
                        timestr = time.strftime("%Y%m%d-%H%M%S")
                        pygame.image.save(upscaled_screen, f"../figures/state_{timestr}.png")
                    elif event.key == pygame.K_p:
                        # pause/toggle time
                        time_ticking = not time_ticking
                    elif event.key == pygame.K_r:
                        # start from seed
                        state = seed_state.clone()

            # Tensor.clamp is out-of-place: the result has to be assigned back,
            # otherwise the clamp is silently dropped.
            if LMB_trigger:
                state = click(state, rmb=False, r=r, s=s, upscale=UPSCALE, brush_toggle=brush_toggle)
                state = state.clamp(0, 1)
            if RMB_trigger:
                state = click(state, rmb=True, r=r, s=s, upscale=UPSCALE, brush_toggle=brush_toggle)
                state = state.clamp(0, 1)

            # the wheel currently drives update_rate; channel permutation is disabled
            if WHEEL_trigger:
                # cdim_order = WHEEL_permute(cdim_order, direction, CHANNELS)
                WHEEL_trigger = False

                update_rate = WHEEL_param(update_rate, direction, 1./3.)

            # Frame to draw: first three channels (in cdim_order) of batch item 0,
            # rearranged to (x, y, channel) and rescaled with min_max to 0..255.
            # It is taken before the update below, so the display shows the
            # state as it was at the start of this iteration.
            nx = state.cpu()[:, cdim_order].numpy()[0, 0:3, :, :].transpose(1, 2, 0)
            # nx = np.clip(nx, 0, 1)*255
            nx = min_max(nx) * 255

            if time_ticking:
                # state = ca.forward_perception(state, dt=1)
                # state = ca.forward_masked(state)
                # state = ca.forward(state, update_rate= update_rate)
                state = ca.forward_slacker(state)
                ticker += 1

            if sink and time_ticking:
                state = state - sink_field

            if export_imgs and time_ticking:
                # imgs.append(nx) # export img
                imgs.append(state)  # export state (the tensor is stored as-is)

            pygame.surfarray.blit_array(screen, nx)
            frame = pygame.transform.scale(screen, (RESXup, RESYup))
            upscaled_screen.blit(frame, frame.get_rect())
            upscaled_screen.blit(update_fps(clock, font), (10,0))
            upscaled_screen.blit(show_param_info(update_rate, 'update_rate'), (RESXup - 100,0))
            pygame.display.flip()
            clock.tick(15)
finally:
    # Always shut pygame down, even when the loop raises, so the window
    # does not stay open and unresponsive.
    pygame.quit()


In [154]:
def overflowloss(state):
    """Sum of how far each element of `state` lies outside [-1, 1].

    Elements inside the interval contribute zero; an element v outside it
    contributes |v| - 1. Returns a zero-dimensional tensor.
    """
    clipped = torch.clamp(state, -1, 1)
    excess = torch.abs(state - clipped)
    return torch.sum(excess)

In [155]:
# NOTE(review): `xxx` is an undefined name, so this cell raises NameError.
# Presumably a deliberate stopper so that "Run All" halts before the
# training cells below - confirm, and consider removing it before sharing.
xxx

NameError: name 'xxx' is not defined

In [None]:
# Settings for the training experiment in the following cells.
# These rebind names that the interactive cells above also use
# (CHANNELS, FILTERS, NET_SIZE, RESX, RESY, RADIUS).
CHANNELS, FILTERS = 4, 4
NET_SIZE = [32, 32]
RESX = RESY = 32
RADIUS = 2   # a scalar here, unlike the list used by the interactive viewer
BS = 128     # number of pool states taken per training step

In [None]:
# Fresh model for the training experiment, moved to the GPU.
ca = CA(
    CHANNELS=CHANNELS,
    FILTERS=FILTERS,
    NET_SIZE=NET_SIZE,
    RADIUS=RADIUS,
).cuda()

In [None]:
def seed(BS):
    """Return a batch of `BS` seed states on the GPU.

    The result is a float32 CUDA tensor of shape (BS, CHANNELS, RESX, RESY)
    that is zero everywhere except at the centre cell (RESX//2, RESY//2),
    where every channel from index 3 onward is set to 1.
    """
    # Allocate directly on the device instead of building a float64 numpy
    # array on the host and converting it with the legacy FloatTensor
    # constructor.
    batch = torch.zeros((BS, CHANNELS, RESX, RESY), dtype=torch.float32, device='cuda')
    batch[:, 3:, RESX//2, RESY//2] = 1

    return batch

# Pool of 500 random-normal states; the training loop reads a window of BS
# states from it each step and writes the evolved states back.
POOL = torch.randn(500, CHANNELS, RESX, RESY)

In [None]:
# Train the CA so that its states do not explode: roll a batch of pool states
# forward for a random number of steps and penalise values outside [-1, 1].
num_epochs = 200
training_steps = 1000  # not used by this loop

lr = 1e-3
# optim = torch.optim.SGD(ca.parameters(), lr=lr)
optim = torch.optim.Adam(ca.parameters(), lr=lr)

loss_hist = []
for epoch in range(num_epochs):
    optim.zero_grad()

    # random rollout length, then a random contiguous window of the pool
    forward_steps = np.random.randint(10, 25)
    pool_idx = np.random.randint(len(POOL) - BS)
    pool_window = slice(pool_idx, pool_idx + BS)

    state = POOL[pool_window].cuda()
    state[0] = seed(1)  # the first sample of every batch restarts from the seed

    for t in range(forward_steps):
        state = ca.forward(state, update_rate=1.)

    loss = overflowloss(state)
    loss_hist.append(loss.item())
    loss.backward()

    # write the evolved states back so later steps continue from them
    POOL[pool_window] = state.detach().cpu()

    # flatten every available gradient into one vector (logging only)
    grads = []
    for n, p in ca.named_parameters():
        if not p.requires_grad or p.grad is None:
            continue
        grads.append(p.grad.reshape(-1).cpu().data.numpy())
    grads = np.concatenate(grads)

    # logged before clipping, so the gradient statistic refers to raw gradients
    print(f'Epoch: {epoch}/{num_epochs},'
          f'Loss: {loss.item():.3f},'
          f'M_Activity: {state.detach().abs().mean():.3f},'
          f'|Grads*LR|: {np.mean(np.abs(grads)) * lr:.4f}')

    torch.nn.utils.clip_grad_norm_(ca.parameters(), 2.0)
    # xx
    optim.step()
        


In [None]:
# Remove the name `loss` left behind by the training loop.
# NOTE(review): presumably so the last loss tensor and what it references
# can be freed - confirm. Raises NameError if `loss` does not exist.
del loss

In [None]:
# Report the mean absolute gradient of every trainable parameter,
# flagging the ones that have no gradient at all.
for n, p in ca.named_parameters():
    if not p.requires_grad:
        continue
    if p.grad is None:
        print(f'Name: {n}, <|Grad|>: None!')
        continue
    print(f'Name: {n}, <|Grad|>: {p.grad.data.abs().mean():.4f}')

In [None]:
# Training loss per loop iteration on a logarithmic y axis.
# Labels and a title are set so the figure can be read on its own.
plt.plot(loss_hist, '.')
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('overflow loss')
plt.title('Training loss');

In [None]:
# List the name of every parameter of the model, one per line.
for entry in ca.named_parameters():
    print(entry[0])

In [None]:
# Inspect the shape of the first entry of ca.rule.kernels.
ca.rule.kernels[0].shape