In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.util.image import *
from src.util import *
from src.algo import ca1

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale the last two dimensions of `img` by the factor `scale`.

    Every target edge length is truncated to an integer and is never allowed
    to drop below one. Resampling uses `mode` with antialiasing switched off.
    """
    target_size = []
    for edge in img.shape[-2:]:
        target_size.append(max(1, int(edge * scale)))
    return VF.resize(img, target_size, mode, antialias=False)

In [None]:
# Rule set for 3 cell states and num_neighbours=1.
all_rules = ca1.Ca1ReplaceRules(num_neighbours=1, num_states=3)
# Display the number of rules with thousands separators.
format(len(all_rules), ",")

In [None]:
# Dry run over every rule index with an empty body: the tqdm bar shows how
# long merely iterating the whole rule set takes.
for r in tqdm(range(len(all_rules))):
    pass

Ouch! This is huge — maybe I'll come back to it another time...

# reservoir computing

### create dataset

In [None]:
def to_float_rep(f: float):
    """
    Return the 32 bits of `f`, stored as a float32, as a list of 0/1 ints.

    The bytes are taken in the order numpy keeps them in memory, and each byte
    is emitted least-significant bit first.
    """
    raw = np.array([f]).astype(np.float32).data.tobytes()
    return [(byte >> bit) & 1 for byte in raw for bit in range(8)]

to_float_rep(23.5)

In [None]:
NUM_BITS = 9
NUM_TEST_SAMPLES = 100
# Target function the readout has to learn:
# one of "copy", "xor", "square", "sqrt", "float-rep".
TARGET_KIND = "float-rep"

# Every NUM_BITS-bit integer, one row per value, least-significant bit first.
data = torch.tensor(
    [[(i >> k) & 1 for k in range(NUM_BITS)] for i in range(2**NUM_BITS)],
    dtype=torch.uint8,
)

if TARGET_KIND == "copy":
    targets = data
elif TARGET_KIND == "xor":
    # XOR of the lower half of the input bits with the upper half.
    assert NUM_BITS % 2 == 0, "the xor target needs an even NUM_BITS"
    half = NUM_BITS // 2
    targets = data[:, :half] ^ data[:, half:]
elif TARGET_KIND == "square":
    # Lowest 8 bits of (i * i) mod 251.
    targets = torch.tensor(
        [[((i * i % 251) >> j) & 1 for j in range(8)] for i in range(data.shape[0])],
        dtype=data.dtype,
    )
elif TARGET_KIND == "sqrt":
    # Integer part of sqrt(i); it fits into half as many bits (rounded up) as the input.
    targets = torch.tensor(
        [
            [(int(math.sqrt(i)) >> j) & 1 for j in range((NUM_BITS + 1) // 2)]
            for i in range(data.shape[0])
        ],
        dtype=data.dtype,
    )
elif TARGET_KIND == "float-rep":
    # The 32 bits of the float32 representation of i.
    targets = torch.tensor(
        [to_float_rep(float(i)) for i in range(data.shape[0])],
        dtype=data.dtype,
    )
else:
    raise ValueError(f"unknown TARGET_KIND {TARGET_KIND!r}")

assert 0 < NUM_TEST_SAMPLES < data.shape[0], "test split must be a proper, non-empty subset"

# Fixed-seed split: NUM_TEST_SAMPLES random rows for testing, all remaining rows for training.
test_indices = torch.randperm(data.shape[0], generator=torch.Generator().manual_seed(23))[:NUM_TEST_SAMPLES]
# Build the index tensor directly as int64 (no detour through float32, which
# would lose precision for large row counts).
train_indices = torch.tensor(
    sorted(set(range(data.shape[0])) - set(test_indices.tolist())),
    dtype=torch.int64,
)

data_train, data_test = data[train_indices], data[test_indices]
targets_train, targets_test = targets[train_indices], targets[test_indices]
print("input: ", data_train.shape, data_test.shape)
print("target:", targets_train.shape, targets_test.shape)
print("sample:")
for i in (0, 1, 2, 3, -3, -2, -1):
    print(f"{i:2}: {data[i].tolist()} -> {targets[i].tolist()}")

### render states for each input

In [None]:
STATE_WIDTH = 200
MIN_INPUT_DIST = 3

# NUM_BITS distinct cells that are pairwise at least MIN_INPUT_DIST apart only
# fit into STATE_WIDTH cells if (NUM_BITS - 1) * MIN_INPUT_DIST < STATE_WIDTH.
# Without this check the seed search below would loop forever.
if (NUM_BITS - 1) * MIN_INPUT_DIST >= STATE_WIDTH:
    raise ValueError(
        f"cannot place {NUM_BITS} inputs with distance >= {MIN_INPUT_DIST} in {STATE_WIDTH} cells"
    )

# Try consecutive seeds until the sorted random cell positions are spaced far enough apart.
seed = 50
while True:
    generator = torch.Generator().manual_seed(seed)
    mapping_indices = torch.randperm(STATE_WIDTH, generator=generator)[:NUM_BITS].sort().values
    diffs = mapping_indices[1:] - mapping_indices[:-1]
    # A single input cell has no neighbour to keep apart (and `.min()` would fail on empty `diffs`).
    if diffs.numel() == 0 or diffs.min() >= MIN_INPUT_DIST:
        break
    seed += 1

print(f"seed: {seed}\nindices: {mapping_indices}")

In [None]:
def calc_ca_states(
    rule: int,
    num_states: int = 2,
    num_neighbours: int = 1,
    iterations: int = 2000,
    output_steps: int = 100,
    wrap: bool = False,
    seq_input: str = "none",
    seq_input_stride: int = 7,
    verbose: bool = True,
):
    """
    Run one CA rule on every row of the global `data` and collect the late states.

    The bits of a sample are written to the cells listed in the global
    `mapping_indices` of an otherwise zero state of width `STATE_WIDTH`.
    How they reach `ca1.ca1_replace_step` depends on `seq_input`:

    - "none" (or None): the bits are placed in the initial state and no
      sequence tensor is passed.
    - "sequential": the initial state stays zero; a (batch, num_bits, STATE_WIDTH)
      sequence tensor is passed in which step `k` carries only bit `k`.
    - "repeat": the initial state stays zero; a (batch, 20, STATE_WIDTH)
      sequence tensor is passed in which every step carries all bits.

    Previously the direct-input branch always ran and reset `seq_input` to
    None, so the "sequential" and "repeat" modes silently behaved like "none".

    :param rule: rule number, resolved via `ca1.Ca1ReplaceRules(...).lookup(rule)`
    :param num_states: forwarded to the rule table and to `ca1_replace_step`
    :param num_neighbours: forwarded to the rule table and to `ca1_replace_step`
    :param iterations: forwarded to `ca1_replace_step`
    :param output_steps: how many trailing entries along dim 1 of the returned
        state are kept per sample
    :param wrap: forwarded to `ca1_replace_step`
    :param seq_input: input mode, see above
    :param seq_input_stride: forwarded to `ca1_replace_step`
    :param verbose: show a progress bar and a preview image of four samples
    :return: tensor with one row per sample; each row is the flattened kept
        part of that sample's state
    :raises ValueError: if `seq_input` is not one of the supported modes
    """
    if seq_input not in (None, "none", "sequential", "repeat"):
        raise ValueError(f"unknown seq_input mode {seq_input!r}")

    repeat_steps = 20  # length of the sequence in "repeat" mode

    # The lookup table only depends on the rule, so build it once instead of per batch.
    lookup = ca1.Ca1ReplaceRules(num_states=num_states, num_neighbours=num_neighbours).lookup(rule)

    states = []
    for batch in tqdm(iter_batches(data, batch_size=256), disable=not verbose):
        initial_state = torch.zeros(batch.shape[0], STATE_WIDTH, dtype=torch.uint8)
        # Per-batch sequence tensor; kept separate from the `seq_input` mode string
        # so the mode is still known when the next batch is processed.
        seq_tensor = None

        if seq_input in (None, "none"):
            initial_state[:, mapping_indices] = batch
        elif seq_input == "sequential":
            seq_tensor = torch.zeros(batch.shape[0], batch.shape[-1], STATE_WIDTH, dtype=torch.uint8)
            for idx in range(batch.shape[-1]):
                seq_tensor[:, idx, mapping_indices[idx]] = batch[:, idx]
        else:  # "repeat"
            seq_tensor = torch.zeros(batch.shape[0], repeat_steps, STATE_WIDTH, dtype=torch.uint8)
            seq_tensor[:, :, mapping_indices] = batch.unsqueeze(1).repeat(1, repeat_steps, 1)

        state = ca1.ca1_replace_step(
            input=initial_state,
            lookup=lookup,
            num_states=num_states,
            num_neighbours=num_neighbours,
            iterations=iterations,
            wrap=wrap,
            seq_input=seq_tensor,
            seq_input_stride=seq_input_stride,
            seq_input_mode="add",
        )
        states.append(state[:, -output_steps:, :].flatten(1))
    states = torch.concat(states)

    if verbose:
        # Preview four samples; clamp the divisor so an all-zero result does not become NaN.
        img = make_grid([
            states[i].view(1, -1, STATE_WIDTH)
            for i in (1, 23, 42, -1)
        ]).float() / states.max().clamp(min=1)
        display(VF.to_pil_image(resize(img, 5)))

    return states

# Other rule numbers to try here: 30, 225, 18, 60, 181
# Compute the CA states for every input row with rule 22 and default settings.
states = calc_ca_states(22)
# Alternative with a wider neighbourhood:
#states = calc_ca_states(193426, num_states=2, num_neighbours=2)
# Split the states with the same indices that split inputs and targets.
states_train, states_test = states[train_indices], states[test_indices]
states.shape

### linear readout

In [None]:
from sklearn.linear_model import Ridge

# Linear readout: ridge regression from the flattened CA states to the target bits.
ridge = Ridge()
ridge.fit(states_train, targets_train)

evaluation_splits = (
    ("train", states_train, targets_train),
    ("test", states_test, targets_test),
)
for kind, states_, target_ in evaluation_splits:
    output_real = ridge.predict(states_)
    # Threshold the real-valued prediction at 0.5 to obtain bits.
    output = (output_real >= .5).astype(np.float32)
    # A sample only counts as correct when every one of its bits matches.
    bits_matching = np.sum(output == np.array(target_), axis=-1)
    num_correct = np.sum(bits_matching == target_.shape[-1])
    mae = (torch.Tensor(output_real) - target_).abs().mean()
    print(f"correct {kind:5}: {num_correct:4} / {target_.shape[0]:4}  (MAE {mae:.5f})")

In [None]:
# One tile per target bit: the readout coefficients of that bit, reshaped into
# rows of STATE_WIDTH values.
readout_weights = torch.Tensor(ridge.coef_).view(targets.shape[-1], 1, -1, STATE_WIDTH)
weight_grid = make_grid(readout_weights, normalize=False, nrow=3)
# Amplify by 10 before clamping so that small coefficients stay visible.
img = signed_to_image(weight_grid) * 10
VF.to_pil_image(resize(img.clamp(0, 1), 5))

# run all rules

In [None]:
def _score_readout(ridge, states_, target_):
    """Return (number of samples with every bit predicted correctly, mean absolute error)."""
    output_real = ridge.predict(states_)
    # Threshold the real-valued prediction at 0.5 to obtain bits.
    output = (output_real >= .5).astype(np.float32)
    num_correct = np.sum(np.sum(output == np.array(target_), axis=-1) == target_.shape[-1])
    mae = (torch.Tensor(output_real) - target_).abs().mean()
    return num_correct, float(mae)


def run_all_rules():
    """
    Fit a ridge readout for every strong rule under every parameter combination.

    For each combination of (iterations, output_steps), `wrap`, `seq_input`
    and `seq_input_stride`, the CA states are computed with `calc_ca_states`,
    a `Ridge` model is fitted on the training split, and exact-match counts
    and MAE are recorded for the train and the test split.

    A stride of 0 is only paired with `seq_input == "none"`, and non-zero
    strides only with the sequential input modes.

    Pressing interrupt (KeyboardInterrupt) stops the sweep early; the rows
    collected so far are still returned.

    :return: DataFrame indexed by rule, iterations, output_steps, wrap,
        seq_input and seq_input_stride; empty (but with the same index levels
        and columns) if no row was collected
    """
    index_columns = ["rule", "iterations", "output_steps", "wrap", "seq_input", "seq_input_stride"]
    metric_columns = ["readout_var", "correct_train", "mae_train", "correct_test", "mae_test"]

    rows = []
    try:
        for params in iter_parameter_permutations({
            # (iterations, output_steps) pairs. Others tried before:
            # (200, 50), (500, 50..400), (2000, 50..500)
            "iter_steps": [
                (200, 100),
            ],
            "wrap": [True, False],
            "seq_input": ["none", "repeat", "sequential"],
            "seq_input_stride": [0, 1, 3, 7],
        }):
            iterations, output_steps = params["iter_steps"]
            wrap = params["wrap"]
            seq_input = params["seq_input"]
            seq_input_stride = params["seq_input_stride"]

            # Keep only meaningful pairings: no stride without sequence input,
            # and a non-zero stride with sequence input.
            if (seq_input == "none") == bool(seq_input_stride):
                continue

            # The stride was previously recorded in the results but never handed
            # to the simulation. It is only passed for the sequential modes, so
            # the "none" runs keep calc_ca_states' default stride.
            stride_kwargs = {}
            if seq_input != "none":
                stride_kwargs["seq_input_stride"] = seq_input_stride

            for rule in tqdm(
                # alternatives: range(len(ca1.Ca1ReplaceRules())), [22], [150, 105]
                ca1.Ca1ReplaceRules().strong_rules,
                desc=str(params),
            ):
                states = calc_ca_states(
                    rule=rule, iterations=iterations, output_steps=output_steps, wrap=wrap,
                    seq_input=seq_input, verbose=False, **stride_kwargs,
                )
                states_train, states_test = states[train_indices], states[test_indices]

                ridge = Ridge()
                ridge.fit(states_train, targets_train)

                row = {
                    "rule": rule, "iterations": iterations, "output_steps": output_steps, "wrap": wrap,
                    "seq_input": seq_input, "seq_input_stride": seq_input_stride,
                    # note: despite the column name this is the standard deviation
                    "readout_var": np.std(ridge.coef_),
                }
                for kind, states_, target_ in (
                        ("train", states_train, targets_train),
                        ("test", states_test, targets_test),
                ):
                    num_correct, mae = _score_readout(ridge, states_, target_)
                    row.update({
                        f"correct_{kind}": num_correct,
                        f"mae_{kind}": mae,
                    })
                rows.append(row)

    except KeyboardInterrupt:
        # Deliberate: an interrupted sweep still returns its partial results.
        pass

    if not rows:
        # set_index would raise KeyError on a frame without columns.
        return pd.DataFrame(columns=index_columns + metric_columns).set_index(index_columns)

    return pd.DataFrame(rows).set_index(index_columns)

# Run the sweep once, then rank it: exact training matches first, ties
# ordered by exact test matches (hence the stable second sort).
sweep_results = run_all_rules()
sweep_by_test = sweep_results.sort_values("correct_test", ascending=False)
df = sweep_by_test.sort_values("correct_train", ascending=False, kind="stable")
df.head(50)

In [None]:
# Rank by exact training matches; ties are ordered by exact test matches.
ranked_by_test = df.sort_values("correct_test", ascending=False)
ranked_by_test.sort_values("correct_train", ascending=False, kind="stable").head(50)

In [None]:
# Rank by exact training matches; ties are ordered by the lowest training MAE.
ranked_by_mae = df.sort_values("mae_train")
ranked_by_mae.sort_values("correct_train", ascending=False, kind="stable").head(50)
# to export the table, append: .to_csv("../experiments/logs/data/ca1-repl-copy.csv")

In [None]:
# NOTE(review): `df_xor` is not assigned anywhere in the visible notebook —
# presumably a saved copy of `df` from a sweep with the xor targets; confirm
# that it is defined before running this cell on a fresh kernel.
df_xor.head(50)

In [None]:
# Count the runs in `df_xor` that reach these exact-match counts.
# NOTE(review): 924 is presumably the training-set size of the run behind
# `df_xor` and 100 its test-set size (NUM_TEST_SAMPLES) — TODO confirm; with
# the current NUM_BITS the training set has a different size.
print((df_xor["correct_train"] == 924).sum())
print((df_xor["correct_test"] == 100).sum())

In [None]:
# Rank by exact training matches (ties ordered by exact test matches) and
# print the top 50 rows as a markdown table.
df = df.sort_values("correct_test", ascending=False).sort_values("correct_train", ascending=False, kind="stable")
#df = df_square.copy()
#df["correct_train"] = df["correct_train"].astype(np.str_)

# Only drop index levels that actually exist: "state_size" is not a level of
# the frame built by run_all_rules, so naming it unconditionally raises KeyError.
markdown_drop_levels = [name for name in ("iterations", "state_size") if name in df.index.names]
markdown_table = df.head(50)
if markdown_drop_levels:
    markdown_table = markdown_table.droplevel(markdown_drop_levels)
print(markdown_table.to_markdown(floatfmt="f", intfmt=""))

In [None]:
# Sorted first-index-level values (the rule, if the frames are indexed like
# the result of run_all_rules) of all rows reaching the given exact-match count.
# NOTE(review): `df_xor` and `df_square` are not assigned in the visible
# notebook, and 924 / 462 are presumably the training-set sizes of the runs
# that produced them — TODO confirm.
i1 = sorted((df_xor[df_xor["correct_train"] == 924]).index.get_level_values(0))
i2 = sorted((df_square[df_square["correct_train"] == 462]).index.get_level_values(0))
print(i1)
print(i2)

In [None]:
# Print the lookup table of every rule collected in `i1`.
rule_table = ca1.Ca1ReplaceRules()
for r in i1:
    lookup_values = rule_table.lookup(r).tolist()
    print(f"{r:3}: {lookup_values}")

In [None]:
def plot_rules(rules, mark=()):
    """
    Display a labeled grid with one image per rule.

    Each rule is run for 100 iterations with wrapping, starting from a
    101-cell state in which only the centre cell is set to 1.

    :param rules: iterable of rule numbers (it is materialised once, so a
        generator works as well)
    :param mark: rule numbers whose label gets a trailing " *"
    """
    rules = list(rules)
    if not rules:
        # Nothing to draw; avoids handing an empty image list to the grid helper.
        return

    marked = set(mark)
    # The rule table does not depend on the rule, so build it once.
    rule_table = ca1.Ca1ReplaceRules()

    images = []
    for rule in rules:
        initial_state = torch.zeros(101)
        initial_state[initial_state.shape[-1] // 2] = 1
        state = ca1.ca1_replace_step(
            input=initial_state,
            lookup=rule_table.lookup(rule),
            iterations=100,
            num_neighbours=1,
            wrap=True,
        )
        images.append(state.unsqueeze(0))
    labels = [str(r) + (" *" if r in marked else "") for r in rules]
    display(VF.to_pil_image(make_grid_labeled(images, labels=labels, nrow=5)))

plot_rules(i1, mark=[60, 102, 105, 150, 153, 195])
    