In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image
from skimage.color import rgba2rgb
import cv2
from IPython.display import display, clear_output
from math import *
import time

from os import makedirs, path
from copy import deepcopy

from tqdm import tqdm

import pygame

  from .autonotebook import tqdm as notebook_tqdm


pygame 2.1.2 (SDL 2.0.18, Python 3.9.15)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Default brush settings for the interactive PyGame canvas:
#   r -> brush radius in grid cells, s -> smoothing / sigma of the radial falloff
r, s = 5, 1

def LMB_make(state, r=5, s=1):
    '''
    Left click to paint: add a soft, circular dab to channel 1 of `state`
    under the mouse cursor.

    state: grid tensor indexed as state[batch, channel, x, y] (the same x/y
        layout that pygame.surfarray uses for display); modified in place and
        returned.
    r: radius of brush in grid cells; the dab covers a 2r x 2r window.
    s: smoothing / sigma of the radial falloff. Each cell at normalised
        distance d <= 1 from the centre gains 5 * exp(-d**2 / s**2); cells
        outside the unit circle are left unchanged.

    Wraps around the grid edges. Reads the globals UPSCALE, RESX and RESY.
    '''
    xcl, ycl = pygame.mouse.get_pos()
    xcl, ycl = int(xcl/UPSCALE), int(ycl/UPSCALE)

    # radial blur over normalised coordinates in [-1, 1]
    xm, ym = torch.meshgrid(torch.linspace(-1, 1, 2*r), torch.linspace(-1, 1, 2*r))
    rm = torch.sqrt(xm**2 + ym**2).type(torch.double)
    blur = torch.exp(-rm**2 / s**2)
    blur = torch.where(rm <= 1., blur, 0.) # circular mask
    # plain Python floats, so each per-cell update stays in `state`'s dtype/device
    weights = blur.tolist()

    for count_i, i in enumerate(range(xcl - r, xcl + r)):
        for count_j, j in enumerate(range(ycl - r, ycl + r)):
            ii, jj = i % RESX, j % RESY  # toroidal wrap-around
            state[:, 1, ii, jj] = state[:, 1, ii, jj] + 5. * weights[count_i][count_j]
    return state


def RMB_del(state, r=5, s=1):
    '''
    Right click to erase: subtract a soft, circular dab from channel 1 of
    `state` under the mouse cursor.

    state: grid tensor indexed as state[batch, channel, x, y]; modified in
        place and returned.
    r: radius of eraser in grid cells; the dab covers a 2r x 2r window.
    s: smoothing / sigma of the radial falloff. Each cell at normalised
        distance d <= 1 loses 0.2 * exp(-d**2 / s**2); cells outside the unit
        circle are left unchanged.

    Wraps around the grid edges. Reads the globals UPSCALE, RESX and RESY.
    '''
    xcl, ycl = pygame.mouse.get_pos()
    xcl, ycl = int(xcl/UPSCALE), int(ycl/UPSCALE)

    # radial blur: ~0 at the centre, rising towards the rim, 1 outside the
    # unit circle -- so (1 - blur) is a soft circular footprint
    xm, ym = torch.meshgrid(torch.linspace(-1, 1, 2*r), torch.linspace(-1, 1, 2*r))
    rm = torch.sqrt(xm**2 + ym**2).type(torch.double)
    blur = (1 - torch.exp(-rm**2 / s**2))
    blur = torch.where(rm <= 1., blur, 1.) # circular mask
    footprint = (1. - blur).tolist()

    for count_i, i in enumerate(range(xcl - r, xcl + r)):
        for count_j, j in enumerate(range(ycl - r, ycl + r)):
            ii, jj = i % RESX, j % RESY  # toroidal wrap-around
            state[:, 1, ii, jj] = state[:, 1, ii, jj] - 0.2 * footprint[count_i][count_j]
    return state

def print_something(something):
    '''
    Render a number, formatted with 3 decimals, as white text on a small
    pygame surface for overlaying on the display.

    something: value accepted by the '.3f' format spec.
    Returns a pygame.Surface exactly the size of the rendered text.
    Uses the global `font`.
    '''
    fps = f'{something:.3f}'
    fps_text = font.render(fps, 1, pygame.Color("white"))
    # pygame.Surface takes (width, height): size both surfaces to the text.
    fps_bg = pygame.Surface(fps_text.get_size())
    fps_bg.set_alpha(50)                # alpha level
    fps_bg.fill((255,255,255))           # this fills the entire surface

    fps_surf = pygame.Surface(fps_text.get_size())
    fps_surf.blit(fps_bg, (0, 0))
    fps_surf.blit(fps_text, (0, 0))
    return fps_surf

def plot_classification_scores(class_scores, correct_classes, width=100, height=400, n_values=None):
    '''
    Draw a translucent bar chart of two-class scores: one pair of horizontal
    bars (class 0 above, class 1 below) per sample, captioned "N=...".

    class_scores: array of shape (num_samples, 2). Each bar is score * width
        pixels long, so scores are expected in [0, 1] (e.g. softmax output).
    correct_classes: length-num_samples sequence of true class indices (0/1).
    width, height: size of the returned surface in pixels.
    n_values: optional sequence of task sizes N used as captions; defaults to
        2, 3, ..., num_samples + 1.

    The bar of the predicted (argmax) class is green when it matches the
    correct class and red otherwise; the other bar is black.
    Uses the global `font`. Returns a pygame.Surface with per-pixel alpha.
    '''
    GREEN, RED, BLACK = (0, 255, 0), (255, 0, 0), (0, 0, 0)

    num_samples = class_scores.shape[0]
    graph_surface = pygame.Surface((width, height), pygame.SRCALPHA)
    graph_surface.fill((255, 255, 255, 63))  # translucent white background
    if num_samples == 0:
        # nothing to draw; also avoids the division by num_samples below
        return graph_surface
    if n_values is None:
        n_values = range(2, 2 + num_samples)

    bar_height = int(0.8 * height / (4 * num_samples))
    bar_spacing = bar_height // 2
    y_offset = int(0.1 * height)
    class_spacing = 3 * bar_height

    for i in range(num_samples):
        scores = (float(class_scores[i][0]), float(class_scores[i][1]))
        predicted = int(np.argmax(scores))
        bar_colors = [BLACK, BLACK]
        bar_colors[predicted] = GREEN if predicted == correct_classes[i] else RED

        # rect = (left, top, width, height)
        bar1_top = i * (2 * bar_height + class_spacing) + y_offset
        bar2_top = bar1_top + bar_height + bar_spacing
        for top, score, color in zip((bar1_top, bar2_top), scores, bar_colors):
            pygame.draw.rect(graph_surface, color, (0, top, int(score * width), bar_height))

        # caption centred above the pair of bars
        text = font.render(f'N={n_values[i]}', True, (0, 0, 0))
        text_rect = text.get_rect(center=(width // 2, int(bar1_top - 0.5 * y_offset)))
        graph_surface.blit(text, text_rect)

    return graph_surface

def WHEEL_permute(cdim_order, direction):
    '''
    Rotate every displayed-channel index by `direction`, modulo the number of
    channels, e.g. [0, 1, 2] with direction 1 -> array([1, 2, 0]).
    '''
    num_channels = len(cdim_order)
    shifted = np.asarray(cdim_order) + direction
    return shifted % num_channels

def min_max(mat):
    '''
    Linearly rescale `mat` so its minimum maps to 0 and its maximum to 1.

    A constant input (max == min) maps to all zeros instead of dividing by
    zero and producing NaNs (e.g. the all-zero grid from ca.initGrid).
    '''
    lo = mat.min()
    span = mat.max() - lo
    if span == 0:
        return (mat - lo) * 0.
    return (mat - lo) / span

In [3]:
class Rule(nn.Module):
    '''
    Container for all learnable parameters of the CA and of the task I/O layers.

    It defines no forward() of its own: CA.perception / CA.forward read the
    fixed kernels, `filters`, `ws` and `alpha`, and forward_pass uses
    `readin` and `readouts`.

    CHANNELS: number of state channels per cell.
    FILTERS: number of extra learnable 3x3 perception kernels.
    NET_SIZE: hidden widths of the per-cell MLP (1x1 convolutions).
    RES: grid side length, used to size the read-in / read-out layers.
    READIN_CHANNELS / READOUT_CHANNELS: channels of the read-in patch / of the
        patch each read-out head consumes.
    READIN_SCALE / READOUT_SCALE: side of the read-in / read-out patch as a
        fraction of RES.
    NUM_READOUT_HEADS: number of independent 2-way linear read-out heads.

    NOTE(review): hard-codes `.cuda()`, so constructing it requires a GPU.
    '''
    def __init__(self,
                 CHANNELS=8,
                 FILTERS=1,
                 NET_SIZE=[16],
                 RES=50,
                 READIN_CHANNELS=1,
                 READOUT_CHANNELS=1,
                 READIN_SCALE=1,
                 READOUT_SCALE=1,
                 NUM_READOUT_HEADS=10):
        super().__init__()
        self.channels = CHANNELS
        self.filters = FILTERS
        self.net_size = NET_SIZE
        # Mixing logit: CA.forward blends old state and update with weight sigmoid(alpha) (0.5 at init).
        self.alpha = torch.nn.Parameter(torch.tensor([0.]))

        self.rin_channels = READIN_CHANNELS
        self.rout_channels = READOUT_CHANNELS
        self.rin_scale = READIN_SCALE
        self.rout_scale = READOUT_SCALE

        # Fixed 3x3 perception kernels (identity, Sobel-x, Laplacian); CA.perception also uses sobel_x.T.
        # NOTE(review): these are plain tensor attributes, not registered buffers, so module.to()
        # does not move them and they are not part of state_dict().
        self.ident = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]).cuda()
        self.sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).cuda() / 8.0
        self.lap = torch.tensor([[1.0, 2.0, 1.0], [2.0, -12, 2.0], [1.0, 2.0, 1.0]]).cuda() / 16.0

        # Learnable 3x3 perception kernels; this list replaces the int `self.filters` assigned above.
        self.filters = [nn.Parameter(torch.randn(3, 3).cuda())
                        for i in range(FILTERS)]

        # Per-cell MLP as 1x1 convs: perception features (CHANNELS * (4 + FILTERS))
        # -> NET_SIZE[0] -> ... -> NET_SIZE[-1] -> CHANNELS.
        self.ws = [torch.nn.Conv2d(CHANNELS * (4 + FILTERS), NET_SIZE[0], 1)]
        self.ws += [torch.nn.Conv2d(NET_SIZE[i], NET_SIZE[i + 1], 1) for i in range(len(NET_SIZE) - 1)]
        self.ws += [torch.nn.Conv2d(NET_SIZE[-1], CHANNELS, 1)]


        # self.w1 = torch.nn.Conv2d(CHANNELS * (4 + FILTERS), HIDDEN, 1)
        # self.w1.bias.data.zero_()
        # self.w2 = torch.nn.Conv2d(HIDDEN, HIDDEN, 1)
        # self.w3 = torch.nn.Conv2d(HIDDEN, CHANNELS, 1)

        # read-in layer: projects one scalar input to READIN_CHANNELS x readin_res x readin_res values
        # read-out heads: each maps a flattened READOUT_CHANNELS x readout_res x readout_res patch to 2 class logits
        readin_res = int(READIN_SCALE*RES)
        readout_res = int(READOUT_SCALE*RES)
        self.readin = torch.nn.Linear(1, int(readin_res*readin_res*READIN_CHANNELS))
        self.readouts = [torch.nn.Linear(int(readout_res*readout_res*READOUT_CHANNELS), 2, bias=True).cuda() for i in range(NUM_READOUT_HEADS)]

        # Register the list-held layers / kernels so they appear in parameters() and move with .to();
        # module_list holds ws first, then the read-out heads.
        self.module_list = torch.nn.ModuleList(self.ws + self.readouts)
        self.parameter_list = torch.nn.ParameterList(self.filters)
        # self.w2.weight.data.zero_()
        ###########################################

class CA(nn.Module):
    '''
    Neural cellular automaton on a square grid with wrap-around (circular) borders.

    One step (forward):
      1. every channel of every cell is convolved with fixed 3x3 kernels
         (identity, Sobel-x, Sobel-y, Laplacian) plus the learnable
         `rule.filters` (perception);
      2. a per-cell MLP of 1x1 convolutions (`rule.ws`, leaky-ReLU between
         layers) turns those features into an update `y`;
      3. the update is restricted to "alive" cells (see get_living_mask),
         blended with the old state using weight sigmoid(rule.alpha), and the
         result is masked by liveness again.

    All learnable weights live in `self.rule` (a `Rule`); the constructor
    arguments are forwarded to it unchanged.
    '''
    def __init__(self,
                 CHANNELS=8,
                 FILTERS=1,
                 NET_SIZE=[16],
                 RES=50,
                 READIN_CHANNELS=1,
                 READOUT_CHANNELS=1,
                 READIN_SCALE=1,
                 READOUT_SCALE=1,
                 NUM_READOUT_HEADS=10):
        super().__init__()
        self.channels = CHANNELS
        self.filters = FILTERS
        self.net_size = NET_SIZE
        self.res = RES

        self.rule = Rule(CHANNELS, FILTERS, NET_SIZE, RES, READIN_CHANNELS, READOUT_CHANNELS, READIN_SCALE, READOUT_SCALE, NUM_READOUT_HEADS)

    def initGrid(self, BS):
        '''
        Return an all-zero grid of shape (BS, channels, res, res).

        The grid is created on the device of the model's parameters, so it
        matches the model after `ca.to(device)`.
        '''
        return torch.zeros(BS, self.channels, self.res, self.res,
                           device=self.rule.alpha.device)

    def seed(self, RES, n):
        '''Return n standard-normal random grids of shape (n, channels, RES, RES), on the CPU.'''
        seed = torch.randn(n, self.channels, RES, RES)
        return seed

    def perchannel_conv(self, x, filters):
        '''
        Convolve each channel of x independently with every kernel in `filters`.

        x: (b, ch, h, w) state.
        filters: (filter_n, kh, kw) kernels; the circular padding of 1 keeps
            the spatial size for 3x3 kernels.
        Returns (b, ch * filter_n, h, w); the responses of channel c occupy
        indices c * filter_n ... c * filter_n + filter_n - 1.
        '''
        b, ch, h, w = x.shape
        y = x.reshape(b * ch, 1, h, w)
        y = F.pad(y, [1, 1, 1, 1], 'circular')
        y = F.conv2d(y, filters[:, None])
        return y.reshape(b, -1, h, w)

    def perception(self, x):
        '''
        Per-channel responses to [identity, Sobel-x, Sobel-y (= Sobel-x
        transposed), Laplacian] followed by `rule.filters`;
        shape (b, ch * (4 + len(rule.filters)), h, w).
        '''
        filters = [self.rule.ident, self.rule.sobel_x, self.rule.sobel_x.T, self.rule.lap]
        filters = filters + self.rule.filters
        return self.perchannel_conv(x, torch.stack(filters))

    def get_living_mask(self, x, alive_thres=0):
        '''
        Boolean (b, 1, h, w) mask of "alive" cells, judged on channel 0:
        |max of x[:, 0] over the cell's wrap-around 3x3 neighbourhood| > alive_thres.

        NOTE(review): abs() is applied after the max, so e.g. a neighbourhood
        with values {-5, 0} counts as dead. If "any neighbour with
        |value| > alive_thres" was intended, take abs() before pooling; this
        is left unchanged because trained rules depend on it.
        '''
        R = 1
        alpha_channel = F.pad(x[:, 0:1, :, :], (R, R, R, R), mode='circular')
        # The circular pad already supplies the border, so pooling without
        # extra padding yields exactly h x w outputs.
        neighbourhood_max = F.max_pool2d(alpha_channel, kernel_size=2 * R + 1, stride=1)
        return neighbourhood_max.abs() > alive_thres

    def forward(self, x, dt=1, update_rate=1.):
        '''
        Advance the grid x of shape (b, ch, h, w) by one step and return the new state.

        dt: scale applied to the update.
        update_rate: per-cell probability (for values in [0, 1]) that a cell
            applies its update this step; 1. updates every cell.
        '''
        b, ch, h, w = x.shape
        pre_mask = self.get_living_mask(x)

        y = self.perception(x)
        for layer in self.rule.ws[:-1]:
            y = F.leaky_relu(layer(y))
        y = self.rule.ws[-1](y)

        # floor(U[0, 1) + update_rate) is 1 with probability update_rate;
        # sampled directly on x's device rather than forced onto CUDA.
        update_mask = (torch.rand(b, 1, h, w, device=x.device) + update_rate).floor()
        y = dt * y * update_mask * pre_mask

        alpha = torch.sigmoid(self.rule.alpha)
        res = (1 - alpha) * x + alpha * y
        return res * self.get_living_mask(res)

# Training

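Let's now train the model on the N-bit parity task: given a stream of random ±1 bits, predict the parity of the last N bits. At every CA step the current bit is projected by a learned linear layer (`readin`: 1 → READIN_CHANNELS × (READIN_SCALE·RES)² values) and added to the last `READIN_CHANNELS` channels of the grid, except inside a central RES/2 × RES/2 square where the input is masked out. After the whole sequence has been fed in, the first `READOUT_CHANNELS` channels of a central patch of side READOUT_SCALE·RES (the whole grid when READOUT_SCALE = 1) are flattened and passed to one linear readout head per N, each producing 2 class logits. The configuration below uses `CHANNELS = 3` with a single read-in channel and a single read-out channel.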

In [4]:
def generate_binary_sequence(M):
    '''Return a float tensor of shape (M,) whose entries are independently +1 or -1, each with probability 1/2.'''
    heads = torch.rand(M) < 0.5  # True -> +1, False -> -1
    return heads * 2. - 1.

def make_batch_Nbit_pair_parity(Ns, M, bs):
    '''
    Build a batch of random ±1 sequences together with their N-bit parity labels.

    Ns: iterable of task sizes N.
    M: sequence length.
    bs: batch size.

    Returns (sequences, labels):
        sequences -- tensor of shape (bs, M, 1);
        labels    -- list with one (bs,) long tensor per N in Ns, holding the
                     parity of the last N bits of each sequence.
    '''
    with torch.no_grad():
        sequences = []
        for _ in range(bs):
            sequences.append(generate_binary_sequence(M).unsqueeze(-1))
        labels = []
        for N in Ns:
            labels.append(torch.stack([get_parity(seq, N) for seq in sequences]))

    return torch.stack(sequences), labels

def get_parity(vec, N):
    '''
    Parity of the last N entries of a ±1 sequence.

    vec: tensor of ±1 values along dim 0 (e.g. shape (M,) or (M, 1)); +1
        counts as a 1-bit and -1 as a 0-bit.
    N: number of trailing entries to include. N == 0 gives parity 0; N larger
        than len(vec) uses the whole sequence.
    Returns a 0-dim long tensor: 1 if the number of +1 entries is odd, else 0.
    '''
    # An explicit start index avoids the `vec[-0:]` pitfall (which selects everything).
    start = max(len(vec) - N, 0)
    bits = (vec[start:] + 1) / 2
    return (bits.sum() % 2).long()

def pad_to(mat, shape_to):
    '''
    Zero-pad the last two dims of `mat` to `shape_to` = (H, W), centring the content.

    mat: tensor whose last two dims are (h, w), e.g. (batch, channels, h, w).
    shape_to: target (H, W). When a size difference is odd, the extra
        row/column goes to the bottom/right. A negative difference crops
        instead, since F.pad accepts negative padding.
    '''
    dh = shape_to[0] - mat.shape[-2]
    dw = shape_to[1] - mat.shape[-1]
    # F.pad lists padding starting from the LAST dim: (left, right, top, bottom).
    pad = [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]
    return F.pad(mat, pad, mode='constant')

In [5]:
# Cross-entropy on raw logits (mean over the batch); each readout head emits 2 class logits.
criterion = torch.nn.CrossEntropyLoss()

In [6]:
# Model / task configuration
CHANNELS = 3              # state channels per cell
FILTERS = 0               # extra learnable 3x3 perception kernels
NET_SIZE = [16]           # hidden widths of the per-cell 1x1-conv MLP
NUM_READOUT_HEADS = 100   # one 2-way classifier head per parity task size N

READIN_CHANNELS = 1       # input is added to the last READIN_CHANNELS channels
READIN_SCALE = 1          # read-in patch side as a fraction of RES
READOUT_CHANNELS = 1      # classification reads the first READOUT_CHANNELS channels
READOUT_SCALE = 1         # read-out patch side as a fraction of RES

RES = 20                  # the grid is RES x RES
BATCH_SIZE = 128

# Seed numpy and torch (torch.manual_seed also seeds the CUDA generators) so
# weight initialisation and sampled training data are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = 'cuda'
ca = CA(CHANNELS=CHANNELS,
        FILTERS=FILTERS,
        NET_SIZE=NET_SIZE,
        RES=RES,
        READIN_CHANNELS=READIN_CHANNELS,
        READOUT_CHANNELS=READOUT_CHANNELS,
        READIN_SCALE=READIN_SCALE,
        READOUT_SCALE=READOUT_SCALE,
        NUM_READOUT_HEADS=NUM_READOUT_HEADS,
        ).to(device)

In [7]:
# List every trainable tensor with its element count, then the grand total.
total_numel = 0
for n, p in ca.named_parameters():
    numel = p.numel()
    total_numel += numel
    print(f'{n:<20} {numel:>5}')

print(f'{"Total # parameters:":<20} {total_numel:>5}')

rule.alpha               1
rule.readin.weight     400
rule.readin.bias       400
rule.module_list.0.weight   192
rule.module_list.0.bias    16
rule.module_list.1.weight    48
rule.module_list.1.bias     3
rule.module_list.2.weight   800
rule.module_list.2.bias     2
rule.module_list.3.weight   800
rule.module_list.3.bias     2
rule.module_list.4.weight   800
rule.module_list.4.bias     2
rule.module_list.5.weight   800
rule.module_list.5.bias     2
rule.module_list.6.weight   800
rule.module_list.6.bias     2
rule.module_list.7.weight   800
rule.module_list.7.bias     2
rule.module_list.8.weight   800
rule.module_list.8.bias     2
rule.module_list.9.weight   800
rule.module_list.9.bias     2
rule.module_list.10.weight   800
rule.module_list.10.bias     2
rule.module_list.11.weight   800
rule.module_list.11.bias     2
rule.module_list.12.weight   800
rule.module_list.12.bias     2
rule.module_list.13.weight   800
rule.module_list.13.bias     2
rule.module_list.14.weight   800
rule.modul

In [8]:
def forward_pass(ca, sequences, num_readouts=1):
    '''
    Roll the CA over a batch of input sequences, starting from states sampled
    from the global POOL, and return the readout-head logits.

    sequences: tensor of shape (batch, T, 1), as built by
        make_batch_Nbit_pair_parity; one value is injected per CA step.
    num_readouts: number of heads (ca.rule.readouts[:num_readouts]) to evaluate.

    Returns a list of num_readouts logit tensors of shape (batch, 2).

    Side effects: afterwards the sampled POOL slots are, with equal
    probability, overwritten with the final (detached, CPU) states or
    re-seeded with ca.seed(...). Reads the globals POOL, RES, READIN_CHANNELS,
    READIN_SCALE, READOUT_CHANNELS and READOUT_SCALE.
    '''
    batch_size, timesteps = sequences.shape[0], sequences.shape[1]
    ridx = np.random.choice(POOL.shape[0], batch_size)
    state = POOL[ridx, ...].to(sequences.device)

    readin_res = int(READIN_SCALE*RES)
    # Loop-invariant mask: 0 inside the central RES//2 x RES//2 square, 1
    # elsewhere, so input is only injected around the centre.
    readin_mask = pad_to(-torch.ones(1, READIN_CHANNELS, RES//2, RES//2), (RES, RES)).to(sequences.device) + 1

    for t in range(timesteps):
        readin_patch = ca.rule.readin(sequences[:, [t], 0]).reshape(batch_size, READIN_CHANNELS, readin_res, readin_res)
        readin_patch = pad_to(readin_patch, (RES, RES)) * readin_mask
        state[:, -READIN_CHANNELS:, ...] = state[:, -READIN_CHANNELS:, ...] + readin_patch
        state = ca(state)

    if np.random.rand() > 0.5:
        POOL[ridx] = state.detach().cpu()
    else:
        POOL[ridx] = ca.seed(RES, batch_size)

    # Central square readout patch of side 2 * readout_radius.
    readout_radius = int(RES*READOUT_SCALE*0.5)
    lo, hi = RES // 2 - readout_radius, RES // 2 + readout_radius
    readout_patch = state[:, :READOUT_CHANNELS, lo:hi, lo:hi].reshape(batch_size, -1)
    return [head(readout_patch) for head in ca.rule.readouts[:num_readouts]]


In [9]:
num_epochs = 20
num_training_steps = 250
eval_every = 25   # evaluate on a fresh batch every `eval_every` training steps
warmup_time = 0   # not used during training; read by the PyGame visualisation cell
optim = torch.optim.Adam(ca.parameters(), lr=1e-2)
POOL_SIZE = 1000
POOL = ca.seed(RES, POOL_SIZE)  # CPU pool of grid states sampled / updated by forward_pass

# Curriculum: start with N = 2; whenever all current tasks (and the newest N)
# exceed 98% accuracy, add the next `num_extra_heads` task sizes.
Ns = [2]
num_extra_heads = 1
k_factor = 1  # each input bit is repeated k_factor times (time dilation)


def _T_bounds(N):
    '''Sequence-length range [min_T, max_T) used while the largest task size is N.'''
    min_T = 10 + N * k_factor
    return min_T, 10 + min_T + 3 * N * k_factor


def _make_batch(T):
    '''Random ±1 sequences of length T (time-dilated by k_factor) and one
    parity-label tensor per N in Ns, all on `device`.'''
    sequences, labels = make_batch_Nbit_pair_parity(Ns, T, BATCH_SIZE)
    sequences = sequences.to(device).repeat_interleave(k_factor, dim=1)
    return sequences, [l.to(device) for l in labels]


min_T, max_T = _T_bounds(Ns[-1])

loss_hist = []
accuracies = []
all_heads_solved = False
print('Training started...')
for i_epoch in tqdm(range(num_epochs)):
    for i in range(num_training_steps):
        optim.zero_grad()

        sequences, labels = _make_batch(np.random.randint(min_T, max_T))
        timesteps = sequences.shape[1]  # kept in sync for code that reads this global
        outs = forward_pass(ca=ca, sequences=sequences, num_readouts=len(Ns))

        # Backward and optimize on the sum of the per-N cross-entropy losses
        loss = sum(criterion(out, label) for out, label in zip(outs, labels))
        loss_hist.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(ca.parameters(), max_norm=2.0, norm_type=2)  # gradient clipping
        optim.step()

        if (i + 1) % eval_every != 0:
            continue

        # Test and measure per-N accuracy (%) on a fresh batch
        with torch.no_grad():
            sequences, labels = _make_batch(np.random.randint(min_T, max_T))
            timesteps = sequences.shape[1]
            outs = forward_pass(ca=ca, sequences=sequences, num_readouts=len(Ns))
            accuracy = np.array([100. * (out.argmax(dim=1) == label).float().mean().item()
                                 for out, label in zip(outs, labels)])
        accuracies.append(accuracy)

        print(f'Epoch: {i_epoch+1}/{num_epochs}, Step: {i+1}/{num_training_steps}, '
              f'Loss: {loss_hist[-1]:.4f}, Accuracy: {np.mean(accuracy):.2f}')
        print('({N}, accuracy):\n' + ''.join(f'({N}, {acc:.4f})\n' for N, acc in zip(Ns, accuracy)), flush=True)

        if np.mean(accuracy) > 98 and accuracy[-1] > 98:
            if len(Ns) >= NUM_READOUT_HEADS:
                # every readout head is in use and solved: stop training entirely
                all_heads_solved = True
                break
            # never add more tasks than there are readout heads
            n_new = min(num_extra_heads, NUM_READOUT_HEADS - len(Ns))
            print(f'Solved N = {Ns[-1]}, starting N = {Ns[-1]} + {n_new}')
            Ns += [Ns[-1] + k for k in range(1, n_new + 1)]
            min_T, max_T = _T_bounds(Ns[-1])
    if all_heads_solved:
        break


Training started...


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 1/20, Step: 25/250, Loss: 0.5920, Accuracy: 84.38
({N}, accuracy):
(2, 84.3750)

Epoch: 1/20, Step: 50/250, Loss: 0.0012, Accuracy: 100.00
({N}, accuracy):
(2, 100.0000)

Solved N = 2, starting N = 2 + 1
Epoch: 1/20, Step: 75/250, Loss: 0.0790, Accuracy: 100.00
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)

Solved N = 3, starting N = 3 + 1
Epoch: 1/20, Step: 100/250, Loss: 0.0540, Accuracy: 100.00
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)

Solved N = 4, starting N = 4 + 1
Epoch: 1/20, Step: 125/250, Loss: 0.0917, Accuracy: 100.00
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)

Solved N = 5, starting N = 5 + 1
Epoch: 1/20, Step: 150/250, Loss: 0.1080, Accuracy: 99.06
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)
(6, 95.3125)

Epoch: 1/20, Step: 175/250, Loss: 0.0323, Accuracy: 100.00
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)
(6, 100.0000)

Solved N = 6, starting N = 6 + 1
Epoch: 1/

  5%|▌         | 1/20 [00:26<08:19, 26.29s/it]

Epoch: 2/20, Step: 25/250, Loss: 0.0602, Accuracy: 99.67
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)
(6, 100.0000)
(7, 99.2188)
(8, 98.4375)

Solved N = 8, starting N = 8 + 1
Epoch: 2/20, Step: 50/250, Loss: 0.2088, Accuracy: 99.12
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)
(6, 100.0000)
(7, 100.0000)
(8, 100.0000)
(9, 92.9688)

Epoch: 2/20, Step: 75/250, Loss: 0.2794, Accuracy: 99.80
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)
(6, 100.0000)
(7, 100.0000)
(8, 100.0000)
(9, 98.4375)

Solved N = 9, starting N = 9 + 1
Epoch: 2/20, Step: 100/250, Loss: 0.2351, Accuracy: 99.31
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)
(6, 99.2188)
(7, 100.0000)
(8, 100.0000)
(9, 98.4375)
(10, 96.0938)

Epoch: 2/20, Step: 125/250, Loss: 0.1236, Accuracy: 99.91
({N}, accuracy):
(2, 100.0000)
(3, 100.0000)
(4, 100.0000)
(5, 100.0000)
(6, 100.0000)
(7, 100.0000)
(8, 100.0000)
(9, 99.2188)
(10, 

 10%|█         | 2/20 [01:11<10:40, 35.57s/it]

Solved N = 12, starting N = 12 + 1




KeyboardInterrupt



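# Test for longer timescales

Evaluate the trained readout heads on sequences of 1000–2000 input bits, much longer than the sequences seen during training.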

In [None]:
# Evaluate every trained head on much longer sequences (1000-2000 input bits)
# than those used for training. The result is reported separately and is not
# appended to the training `accuracies` history.
# NOTE: forward_pass also writes the final states back into POOL.
with torch.no_grad():
    timesteps = np.random.randint(1000, 2000)
    sequences, labels = make_batch_Nbit_pair_parity(Ns, timesteps, BATCH_SIZE)
    sequences = sequences.to(device).repeat_interleave(k_factor, dim=1)
    labels = [l.to(device) for l in labels]
    timesteps = sequences.shape[1]  # kept in sync for code that reads this global

    outs = forward_pass(ca=ca, sequences=sequences, num_readouts=len(Ns))
    accuracy = np.array([100. * (out.argmax(dim=1) == label).float().mean().item()
                         for out, label in zip(outs, labels)])

print(f'Long-sequence test: T={timesteps}, batch={BATCH_SIZE}, mean accuracy: {np.mean(accuracy):.2f}')
print('({N}, accuracy):\n' + ''.join(f'({N}, {acc:.4f})\n' for N, acc in zip(Ns, accuracy)), flush=True)

# Visualize in PyGame

In [None]:
# RES = 200
# ca.res = RES

In [11]:
# pygame stuff
# Set up the interactive window and the state shared with the event loop:
# the CA grid (one sample), display surfaces, input flags and task bookkeeping.
######################################
RESX, RESY = RES, RES
state = ca.initGrid(BS=1)  # all-zero starting grid

pygame.init()
size = RESX, RESY

# NOTE(review): this RES x RES window is immediately replaced by the upscaled
# set_mode call below; `win` is not used anywhere in the visible code.
win = pygame.display.set_mode((RESX, RESY))

screen = pygame.Surface(size)  # one pixel per grid cell; scaled up when drawn
UPSCALE = 20  # each cell is drawn as UPSCALE x UPSCALE pixels; LMB_make/RMB_del divide mouse coords by it
RESXup, RESYup = int(RESX*UPSCALE), int(RESY*UPSCALE)
upscaled_screen = pygame.display.set_mode([RESXup, RESYup])
FPS_init = 250
FPS = int(1*FPS_init)  # frame-rate cap passed to clock.tick in the loop


running = True
time_ticking = True     # False = CA paused (toggled with 'p')
LMB_trigger = False     # left button held -> paint
RMB_trigger = False     # right button held -> erase
WHEEL_trigger = False   # wheel moved this frame -> rotate displayed channels
cdim_order = np.arange(0, state.shape[1])  # channel order; the first 3 are shown as RGB

do_task = False       # live parity task running (toggled with 's')
thinking_time = 0     # extra steps to wait after the last relevant bit before scoring
task_ticker = 0       # steps since the task started
t = 0                 # index of the most recently injected input bit
max_readout = 10      # number of heads shown in the score overlay
correct = []          # per-readout lists of per-N correctness, averaged for the on-screen accuracy

clock = pygame.time.Clock()
font_h = pygame.font.SysFont("Noto Sans", 24)  # NOTE(review): unused in the visible code
font = pygame.font.SysFont("Noto Sans", 12)    # used by update_fps, print_something and plot_classification_scores
def update_fps(clock, font):
    '''
    Render the clock's current frame rate (as an int) in white on a small
    pygame surface for overlaying on the display.

    clock: pygame.time.Clock being ticked by the main loop.
    font: pygame font used to render the text.
    Returns a pygame.Surface exactly the size of the rendered text.
    '''
    fps = str(int(clock.get_fps()))
    fps_text = font.render(fps, 1, pygame.Color("white"))
    # pygame.Surface takes (width, height): size both surfaces to the text.
    fps_bg = pygame.Surface(fps_text.get_size())
    fps_bg.set_alpha(50)                # alpha level
    fps_bg.fill((255,255,255))           # this fills the entire surface

    fps_surf = pygame.Surface(fps_text.get_size())
    fps_surf.blit(fps_bg, (0, 0))
    fps_surf.blit(fps_text, (0, 0))
    return fps_surf
######################################


# Interactive PyGame loop (runs until the window is closed) with autograd disabled.
# Controls handled below:
#   left / right mouse button (held): paint / erase channel 1 under the cursor
#   mouse wheel: rotate which channels are shown as RGB; middle click: reset that order
#   p: pause / resume the CA;  r: reset the grid to zeros (ca.initGrid)
#   e: toggle recording the displayed frames into `imgs`
#   s: toggle a live parity task: a fresh 1000-bit sequence is fed in one bit per
#      step, starting from a random POOL state, and the readout heads are scored
#      against the true parity of the last N bits.
update_rate = 1.
ticker = 0.

export_imgs = False
imgs = []

with torch.no_grad():
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

            if event.type == pygame.MOUSEBUTTONDOWN:
                if event.button == 1:
                    LMB_trigger = True
                if event.button == 3:
                    RMB_trigger = True
            if event.type == pygame.MOUSEBUTTONUP:
                if event.button == 1:
                    LMB_trigger = False
                if event.button == 3:
                    RMB_trigger = False

            if event.type == pygame.MOUSEWHEEL:
                WHEEL_trigger = True
                direction = -event.y

            if event.type == pygame.MOUSEBUTTONUP and event.button == 2:
                # scroll through channel dims
                cdim_order = np.arange(0, state.shape[1])
            if event.type == pygame.KEYDOWN and event.key == pygame.K_e:
                export_imgs = not export_imgs
            if event.type == pygame.KEYDOWN and event.key == pygame.K_p:
                # pause/toggle time
                time_ticking = not time_ticking
            if event.type == pygame.KEYDOWN and event.key == pygame.K_s:
                # start a task
                # NOTE(review): pressing 's' to stop a running task also falls through
                # to regenerate a sequence and reset `state` from POOL below.
                do_task = not do_task
                if not do_task:
                    FPS = 120
                task_ticker = 0
                timesteps = 1_000

                sequence = generate_binary_sequence(timesteps)
                sequence = sequence.to(device)
                # labels[k][j] = parity of sequence[j : j + N] for N = Ns[k]
                labels = [torch.stack([get_parity(sequence[ii - N:ii], N) for ii in range(N, len(sequence))]) for N in Ns]
                labels = [l.to(device) for l in labels]
                sequence = sequence.repeat_interleave(k_factor)
                labels = [l.repeat_interleave(k_factor) for l in labels]
                timesteps = len(sequence)

                correct_N = []

                # state = ca.initGrid(BS=1)

                ridx = np.random.choice(POOL.shape[0])
                state = POOL[[ridx], ...].cuda()


            if event.type == pygame.KEYDOWN and event.key == pygame.K_r:
                # start from seed
                state = ca.initGrid(BS=1)

        mouse_pos = pygame.mouse.get_pos()
        if LMB_trigger:
            state = LMB_make(state, r=r, s=s)
        if RMB_trigger:
            state = RMB_del(state, r=r, s=s)

        if WHEEL_trigger:
            cdim_order = WHEEL_permute(cdim_order, direction)
            WHEEL_trigger = False

        # Frame to display: first 3 channels of cdim_order as an (x, y, RGB) image scaled to 0..255.
        # nx = state[0, cdim_order[0], :, :].cpu().numpy()
        nx = min_max(state[0, cdim_order[0:3], :, :].cpu().numpy().transpose(1, 2, 0))
        # nx = min_max(state[0, :, :, :].mean(dim=0).cpu().numpy())
        nx = nx * 255.

        if time_ticking:

            if do_task:
                # The second bound ends the task Ns[-1]*k_factor + 1 steps before the sequence runs out.
                if task_ticker < (timesteps + warmup_time) and task_ticker < timesteps - Ns[-1]*k_factor - 1:
                    # NOTE(review): with warmup_time == 0 this skips task_ticker == 0, so sequence[0] is never fed.
                    if task_ticker > warmup_time:
                        t = task_ticker - warmup_time

                        readin_res = int(READIN_SCALE*RES)
                        readin_patch = ca.rule.readin(sequence[t].unsqueeze(0)).reshape(1, READIN_CHANNELS, readin_res, readin_res)
                        readin_patch = pad_to(readin_patch, (RES, RES))

                        # 0 in the central RES//2 x RES//2 square, 1 elsewhere (same masking as forward_pass)
                        readin_mask = pad_to(-torch.ones(1, READIN_CHANNELS, RES//2, RES//2), (RES, RES)).cuda() + 1
                        readin_patch = readin_patch * readin_mask
                        state[:, -READIN_CHANNELS:, ...] = state[:, -READIN_CHANNELS:, ...] + readin_patch


                        # readin_res = int(READIN_SCALE*RES)
                        # readin_patch = ca.rule.readin(sequence[t].unsqueeze(0)).reshape(1, READIN_CHANNELS, readin_res, readin_res)
                        # readin_patch = pad_to(readin_patch, (RES, RES))
                        # state[:, -READIN_CHANNELS:, ...] = state[:, -READIN_CHANNELS:, ...] + readin_patch
                    task_ticker += 1
                else:
                    do_task = False
                    t = 0

            state = ca.forward(state)
            ticker += 1
            # if do_task:
                # readout_patch = state[...,center-r:center+r+1, center-r:center+r+1].reshape(BATCH_SIZE, -1)
                # out = ca.rule.readout(readout_patch)

            if export_imgs:
                imgs.append(nx)

        pygame.surfarray.blit_array(screen, nx)
        frame = pygame.transform.scale(screen, (RESXup, RESYup))

        upscaled_screen.blit(frame, frame.get_rect())
        upscaled_screen.blit(update_fps(clock, font), (10,0))
        if do_task:
            # >= 0 once at least Ns[-1] bits (plus thinking_time) have been fed, i.e. every head has a label
            minimum_wait_time = t - Ns[-1]*k_factor - thinking_time + 1

            if minimum_wait_time >= 0 and minimum_wait_time % k_factor == 0: # only do readout on the end of the time-dilation to be the same as training
                # NOTE(review): FPS is slowed to 10 here and only reset (to 120) when the task is stopped with 's'.
                FPS=10
                # readout_patch = state[:, :READOUT_CHANNELS, ...]
                # mask = torch.ones((1, 1, RES, RES)).cuda() - pad_to(torch.ones((1, 1, readin_res, readin_res)).cuda(), (RES, RES))
                # readout_patch = (readout_patch * mask).reshape(1, -1)

                readout_radius = int(RES*READOUT_SCALE*0.5)
                readout_patch = state[:, :READOUT_CHANNELS, RES//2 - readout_radius:RES//2 + readout_radius, RES // 2 - readout_radius:RES // 2 + readout_radius]
                readout_patch = (readout_patch).reshape(1, -1)

                outs = [l_r(readout_patch) for l_r in ca.rule.readouts[:len(Ns)]]

                # each task N has labels of different length (len(sequence) - N), so they are indexed differently.
                t_label_N = [t - N*k_factor - thinking_time + 1 for N in Ns]
                # used for the histograms
                label_t = torch.stack([l[t_label_N[il]].cpu() for il, l in enumerate(labels[:max_readout])]).numpy()

                correct_N = []
                for N_i in range(len(Ns)):
                    predicted = torch.max(outs[N_i], 1)[1][0]
                    correct_N.append((predicted == labels[N_i][t_label_N[N_i]]).sum().cpu().numpy())

                # NOTE(review): once full, this overwrites correct[0] every time instead of dropping the
                # oldest entry; `correct[1:] + [correct_N]` would give a rolling window of the last 1000.
                if len(correct) < 1000:
                    correct.append(correct_N)
                else:
                    correct = [correct_N] + correct[1:]

            if minimum_wait_time >= 0:
                # plot the histograms
                # (uses `outs` / `label_t` from the most recent readout above)
                graph_surf = plot_classification_scores(
                    F.softmax(
                        torch.stack(outs)[:max_readout, ...].squeeze(1).cpu(), dim=1
                    ).numpy(), label_t, width=RESXup//6, height=int(len(label_t) * RESYup//10))
                upscaled_screen.blit(graph_surf, (0,30))
                upscaled_screen.blit(print_something(100 * np.mean(correct)), (10,20))

        pygame.display.flip()
        clock.tick(FPS)

pygame.quit()


  return (mat - mat.min()) / (mat.max() - mat.min())


In [None]:
# Distribution of all cell values (every channel) in the last CA state left by the PyGame session.
fig, ax = plt.subplots()
ax.hist(state.reshape(-1).cpu().numpy(), bins=100)
ax.set_yscale('log')
ax.set(title='CA state values after the PyGame session',
       xlabel='cell value (all channels)', ylabel='count (log scale)');