In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.util.image import *
from src.util import *
from src.algo import *

In [None]:
#def plot_state(state: torch.Tensor):
#    if state.shape[-1] >= 256:
#        image = state.view(1, *state.shape)
#    else:


In [None]:
# Scratch: the commented lines sketch wrap-around padding, i.e. prepending the
# last column and appending the first column of `state`.
#state = torch.Tensor([[1, 2, 3]])
#torch.concat([state[:, -1:], state, state[:, :1]], dim=-1)
# Quick arithmetic check; the value is shown as the cell output.
2 ** 5

In [None]:
def ca_expand(
        input: torch.Tensor,
        kernel: torch.Tensor,
        iterations: int = 100,
        num_states: int = 2,
):
    """
    Evolve a 1-D cellular automaton whose update is a convolution modulo `num_states`.

    Every step cross-correlates the current row with `kernel`, zero-padding
    `kernel.shape[-1] // 2` cells on both sides, and reduces the result
    modulo `num_states`.

    :param input: initial row; viewed as one row of `input.shape[-1]` cells
    :param kernel: neighbourhood weights; viewed as a (1, 1, K) convolution weight
    :param iterations: number of update steps to run
    :param num_states: modulus applied after each convolution
    :return: the initial row followed by one row per step, concatenated along
        the first dimension
    """
    taps = kernel.shape[-1]
    weight = kernel.view(1, 1, taps)
    row = input.view(1, input.shape[-1])
    rows = [row]
    for _ in range(iterations):
        row = F.conv1d(row, weight, padding=taps // 2) % num_states
        rows.append(row)
    return torch.concat(rows)
        
# Demo: one live cell in the middle of a 257-cell row, evolved for 1000 steps
# with a 3-tap kernel and 5 states.
kernel = torch.tensor([2, 4, 1], dtype=torch.uint8)
input = torch.zeros(257, dtype=torch.uint8)
input[128] = 1
#input[139] = 1
state = ca_expand(input=input, kernel=kernel, iterations=1000, num_states=5)
# Plot which cells are non-zero; the figure is 3 px tall per row plus 50 px.
px.imshow(state > 0, height=3 * state.shape[0] + 50)

In [None]:
def plot_rules(num_states: int = 2, num_n: int = 1):
    """
    Display the evolution of every `ca_expand` kernel in one labelled image grid.

    All ``num_states ** kernel_size`` kernels are enumerated by writing the
    counter in base ``num_states`` (least-significant digit first). Each kernel
    is run for 256 steps from a single live cell at index 128 of a 257-cell row,
    and the result is scaled so that the highest state maps to grey level 255.

    :param num_states: number of cell states; must be at least 2
    :param num_n: neighbours on each side, so a kernel has ``1 + 2 * num_n`` taps
    :raises ValueError: if ``num_states`` is smaller than 2
    """
    if num_states < 2:
        raise ValueError(f"num_states must be >= 2, got {num_states}")

    kernel_size = 1 + num_n * 2
    images = []
    labels = []
    for i in range(num_states ** kernel_size):
        digits = [(i // (num_states ** k)) % num_states for k in range(kernel_size)]
        kernel = torch.tensor(digits, dtype=torch.uint8)

        seed = torch.zeros(257, dtype=torch.uint8)
        seed[128] = 1
        state = ca_expand(seed, kernel, num_states=num_states, iterations=256)

        # Scale in a wider integer type: multiplying a uint8 tensor by 255 stays
        # uint8 and wraps around for any state above 1 (e.g. 2 * 255 -> 254).
        image = (state.to(torch.int32) * 255 // (num_states - 1)).to(torch.uint8)
        images.append(image.unsqueeze(0))
        labels.append(str(kernel.tolist()))

    display(VF.to_pil_image(
        make_grid_labeled(images, nrow=4, labels=labels)
    ))

# Show all 2 ** 3 = 8 two-state kernels with one neighbour on each side.
plot_rules(num_states=2, num_n=1)

In [None]:
# Scratch: for widths 1..9, print how many cells are missing to reach the
# next multiple of 3.
for i in range(1, 10):
    state = torch.ones(i)
    s = state.shape[-1]
    # (-s) % 3 equals (3 - s % 3) % 3 for every non-negative s.
    print(i, (-s) % 3)
    #state2 = F.pad(state, (1, (s - (s // 3) * 3)))
    #print(state.shape, state2.shape)

In [None]:
# Scratch: cut a padded 11-cell row into 3-cell windows for each of the three
# start offsets and combine every window with the weights 1, 2, 4.
state = torch.linspace(1, 11, 11, dtype=torch.uint8)
width = state.shape[-1]
l = (width // 3 + 1) * 3
print(l)
state = F.pad(state, (1, 1 + width % 3))
print(state)
# 2 ** n for n = 0, 1, 2 as a float row vector; identical for every offset.
window_weights = torch.Tensor([1, 2, 4]).view(1, -1)
for k in range(3):
    s = state[k:k + l].view(-1, 3)
    print(s)
    s = s * window_weights
    print(s.sum(dim=-1))

In [None]:
def ca_expand2(
        input: torch.Tensor,
        rule: int = 90,
        iterations: int = 100,
        num_states: int = 2,
        num_n: int = 1,
        dtype: torch.dtype = torch.uint8,
):
    """
    Evolve a 1-D cellular automaton given by a rule number, using a lookup table.

    The neighbourhood of a cell spans ``num_n`` cells on each side. It is encoded
    as a base-``num_states`` number in which the cell at offset ``n`` from the
    left edge of the neighbourhood contributes ``value * num_states ** n``; the
    new state is digit number ``code`` of ``rule`` written in base ``num_states``.
    Cells outside the row are treated as 0.

    NOTE(review): the leftmost neighbour is the *least* significant digit here.
    The usual Wolfram numbering treats it as the most significant one, so
    asymmetric rules (e.g. ``rule=30``) evolve mirrored compared to that
    convention -- confirm this is intended.

    :param input: initial row of cell states, each in ``[0, num_states)``;
        expected to be 1-D (TODO confirm no caller passes batched input)
    :param rule: rule number, ``0 <= rule < num_states ** num_states ** (1 + 2 * num_n)``
    :param iterations: number of update steps
    :param num_states: number of states a cell can take
    :param num_n: number of neighbours on each side of a cell
    :param dtype: dtype of the lookup table and therefore of every computed row
    :return: the input row followed by one row per step, concatenated along the
        first dimension (``iterations + 1`` rows for a 1-D input)
    """
    num_n2 = 1 + num_n * 2

    # lookup[code] is digit `code` of `rule` in base `num_states`.
    lookup = torch.tensor(
        [(rule // (num_states ** k)) % num_states for k in range(num_states ** num_n2)],
        dtype=dtype,
    )
    # Positional weight of every neighbourhood offset.
    weights = torch.tensor(
        [num_states ** n for n in range(num_n2)], dtype=torch.int64,
    )

    state = input.unsqueeze(0)
    history = [state]
    for _ in range(iterations):
        # Zero boundary: num_n extra cells on both sides, so that every cell of
        # the row has a complete neighbourhood.
        padded = F.pad(state.to(torch.int64), (num_n, num_n))
        # Sliding windows of num_n2 cells with stride 1 -> one window per cell.
        windows = padded.unfold(-1, num_n2, 1)
        index = (windows * weights).sum(dim=-1)
        state = lookup[index]
        history.append(state)
    return torch.concat(history)
        
# Demo: rule 30 for 400 steps on a 165-cell row seeded with two live cells,
# one in the middle and one 10 cells from the right edge.
input = torch.zeros(33 * 5, dtype=torch.uint8)
#input = torch.randint(0, 2, (33,))
input[-10] = 1
input[len(input) // 2] = 1
state = ca_expand2(input=input, rule=30, iterations=400)
# The figure is 3 px tall per row plus 50 px.
px.imshow(state, height=3 * state.shape[0] + 50)


In [None]:
# Scratch arithmetic; the value is shown as the cell output.
246 / 3

In [None]:
def plot_rules2(num_states: int = 2, num_n: int = 1):
    """
    Display the evolution of every `ca_expand2` rule in one labelled image grid.

    There are ``num_states ** (num_states ** kernel_size)`` rules. Each one is
    run for 128 steps from a single live cell in the middle of a 252-cell row,
    and the result is scaled so that the highest state maps to grey level 255.

    :param num_states: number of cell states; must be at least 2
    :param num_n: neighbours on each side, so a neighbourhood has ``1 + 2 * num_n`` cells
    :raises ValueError: if ``num_states`` is smaller than 2
    """
    if num_states < 2:
        raise ValueError(f"num_states must be >= 2, got {num_states}")

    kernel_size = 1 + num_n * 2
    # `**` is right-associative; the parentheses only make that explicit.
    num_rules = num_states ** (num_states ** kernel_size)

    images = []
    labels = []
    for rule in range(num_rules):
        seed = torch.zeros(252, dtype=torch.uint8)
        seed[seed.shape[-1] // 2] = 1
        # Forward num_n so the neighbourhood size matches the rule count above.
        state = ca_expand2(
            seed, rule=rule, num_states=num_states, num_n=num_n, iterations=128,
        )
        # Scale in a wider integer type: multiplying a uint8 tensor by 255 stays
        # uint8 and wraps around for any state above 1 (e.g. 2 * 255 -> 254).
        image = (state.to(torch.int32) * 255 // (num_states - 1)).to(torch.uint8)
        images.append(image.unsqueeze(0))
        labels.append(str(rule))

    display(VF.to_pil_image(
        make_grid_labeled(images, nrow=4, labels=labels)
    ))

# Show all 2 ** 2 ** 3 = 256 rules for two states and one neighbour per side.
plot_rules2(num_states=2, num_n=1)

# Next: use the automaton as a "reservoir".

In [None]:
SIZE = 9
NUM_TEST = 100
SPLIT_SEED = 0  # fixes the random train/test split so that re-runs are reproducible

# Row i holds the SIZE-bit binary expansion of i, least-significant bit first.
data = (
    (torch.arange(2 ** SIZE).unsqueeze(1) >> torch.arange(SIZE)) & 1
).to(torch.uint8)

# Hold out NUM_TEST random rows; a dedicated generator keeps the split
# independent of any other random state in the notebook.
test_indices = torch.randperm(
    data.shape[0], generator=torch.Generator().manual_seed(SPLIT_SEED),
)[:NUM_TEST]

# Every row that is not held out, in ascending order.
is_train = torch.ones(data.shape[0], dtype=torch.bool)
is_train[test_indices] = False
train_indices = torch.nonzero(is_train).squeeze(1)

data_train = data[train_indices]
data_test = data[test_indices]
data_train.shape, data_test.shape

In [None]:
RESERVOIR_WIDTH = 33 * 5   # cells in the automaton row
RESERVOIR_STEPS = 100      # update steps per sample
READOUT_FROM = 90          # rows from this step onwards form the feature vector

# Column at which each data row is written so that it sits in the middle.
offset = (RESERVOIR_WIDTH - data.shape[1]) // 2

features = []
# Wrap the tensor itself (not an `enumerate`) so that tqdm knows the total.
for row in tqdm(data):
    seed = torch.zeros(RESERVOIR_WIDTH, dtype=torch.uint8)
    seed[offset: offset + row.shape[0]] = row
    state = ca_expand2(seed, rule=30, iterations=RESERVOIR_STEPS)
    features.append(state[READOUT_FROM:].flatten())

# One flattened feature vector per data row.
states = torch.stack(features)
states_train, states_test = states[train_indices], states[test_indices]
states.shape

In [None]:
from sklearn.linear_model import Ridge

# Linear readout: ridge regression from the reservoir states to the input bits.
ridge = Ridge()
ridge.fit(states_train, data_train)

splits = {
    "train": (states_train, data_train),
    "test": (states_test, data_test),
}
for kind, (states_, data_) in splits.items():
    # Threshold the regression output at 0.5 to obtain predicted bits.
    output = (ridge.predict(states_) >= .5).astype(np.float32)
    # A sample counts as correct only when every one of its bits matches.
    matching_bits = np.sum(output == np.array(data_), axis=-1)
    num_correct = np.sum(matching_bits == data_.shape[-1])
    print(f"correct {kind:5}: {num_correct} / {data_.shape[0]}")