In [None]:
import sys
sys.path.append("..")

import datetime
import random
import math
import time
import json
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display, HTML, Audio
import plotly
plotly.io.templates.default = "plotly_dark"
import plotly.express as px
import pandas as pd

from src.datasets import *
from src.util.image import *
from src.util import *
from src.util.files import *
from src.util.embedding import *
from src.algo import *
from src.models.encoder import *
from src.models.decoder import *
from src.models.util import *
from src.models.reservoir import *

# Reproducing arXiv:2309.06815 (https://arxiv.org/pdf/2309.06815.pdf)

## Make a small, class-balanced MNIST dataset

In [None]:
@torch.no_grad()
def get_train_test_data(
        num_train_samples_per_class: int = 660,
        num_test_samples_per_class: int = 330,
        seed: int = 23,
        root: str = "~/prog/data/datasets",
        dataset=None,
        num_classes: int = 10,
        crop: Tuple[int, int, int, int] = (3, 6, 22, 16),
):
    """Build a small, class-balanced train/test split from MNIST.

    Samples are taken in the dataset's stored order.
    - The first ``num_train_samples_per_class`` samples of each class go to the
      train split.
    - The next ``num_test_samples_per_class`` samples of that class go to the
      test split.
    - All remaining samples are ignored.

    Every image is cropped and scaled from uint8 to [0, 1]. Both splits are then
    shuffled with a generator seeded by ``seed``: the train permutation is drawn
    first, then the test permutation.

    Args:
        num_train_samples_per_class: maximum number of train samples per class.
        num_test_samples_per_class: maximum number of test samples per class.
        seed: seed of the generator used to shuffle both splits.
        root: root directory passed to torchvision's ``MNIST`` (used only when
            ``dataset`` is None).
        dataset: optional object with ``.data`` (N, H, W) uint8 images and
            ``.targets`` (N,) integer labels. Defaults to ``MNIST(root)``.
        num_classes: width of the one-hot target vectors.
        crop: ``(top, left, height, width)`` crop applied to every image.

    Returns:
        ``(train_images, train_targets, test_images, test_targets)``.
        - Images are float32 tensors of shape (N, height, width).
        - Targets are float32 one-hot tensors of shape (N, num_classes).
        - A split may be empty (N == 0).

    Raises:
        ValueError: if ``crop`` does not fit inside the images.
    """
    if dataset is None:
        from torchvision.datasets import MNIST
        dataset = MNIST(root)

    data = torch.as_tensor(dataset.data)
    labels = torch.as_tensor(dataset.targets).long()

    top, left, height, width = crop
    if top < 0 or left < 0 or top + height > data.shape[-2] or left + width > data.shape[-1]:
        raise ValueError(f"crop {crop} does not fit into images of shape {tuple(data.shape[-2:])}")
    # plain slicing == VF.crop for in-bounds crops, applied to the whole batch at once
    images = data[:, top:top + height, left:left + width].float() / 255.

    one_hot = F.one_hot(labels, num_classes)
    # 0-based position of each sample among the samples of its own class (dataset order)
    rank = (one_hot.cumsum(dim=0) * one_hot).sum(dim=-1) - 1
    train_mask = rank < num_train_samples_per_class
    test_mask = ~train_mask & (rank < num_train_samples_per_class + num_test_samples_per_class)

    targets = one_hot.float()
    # boolean masks keep the original dataset order; empty splits stay valid empty tensors
    train_images, train_targets = images[train_mask], targets[train_mask]
    test_images, test_targets = images[test_mask], targets[test_mask]

    gen = torch.Generator().manual_seed(seed)
    train_permute = torch.randperm(len(train_images), generator=gen)
    test_permute = torch.randperm(len(test_images), generator=gen)

    return (
        train_images[train_permute],
        train_targets[train_permute],
        test_images[test_permute],
        test_targets[test_permute],
    )

train_images, train_targets, test_images, test_targets = get_train_test_data()
# preview the first 20 digits of the train split, then of the test split, each as one row
for preview_images in (train_images, test_images):
    display(VF.to_pil_image(make_grid(preview_images[:20].unsqueeze(1), nrow=20)))
train_images.shape, train_targets.shape, test_images.shape, test_targets.shape

In [None]:
def expand_time(series, steps: int = 3):
    """Stretch ``series`` along dim 1 by repeating every entry ``steps`` times consecutively.

    For example, rows ``[a, b]`` along dim 1 become ``[a, a, a, b, b, b]`` with the default
    ``steps=3``. ``steps=1`` returns an equal copy.
    """
    return series.repeat_interleave(steps, dim=1)

# preview: the first 20 training images after expand_time with its default number of steps
stretched_preview = expand_time(train_images)[:20].unsqueeze(1)
display(VF.to_pil_image(make_grid(stretched_preview, nrow=20)))

## Baseline: ridge (linear) regression on raw pixels

In [None]:
import matplotlib.pyplot

def dump_error(label: str, targets, prediction, ret: bool = False):
    """Report L1/L2 error and argmax accuracy of ``prediction`` against ``targets``.

    The class of a row is taken as the argmax over the last dimension, so both
    tensors are expected to hold per-class scores (e.g. one-hot targets) in
    their last dim.

    Args:
        label: prefix of the printed summary line (unused when ``ret`` is True).
        targets: target tensor.
        prediction: predicted tensor, same shape as ``targets``.
        ret: if True, return the metrics instead of printing and plotting them.

    Returns:
        ``(l1, l2, accuracy)`` as Python floats when ``ret`` is True, otherwise None.
        When ``ret`` is False, a summary line is printed and a small
        square-root-scaled confusion matrix is drawn (rows = predicted class,
        columns = target class).
    """
    error_l1 = F.l1_loss(prediction, targets)
    error_l2 = F.mse_loss(prediction, targets)
    target_classes = targets.argmax(dim=-1)
    predicted_classes = prediction.argmax(dim=-1)
    accuracy = (target_classes == predicted_classes).float().mean()
    if ret:
        return float(error_l1), float(error_l2), float(accuracy)
    print(f"{label} error l1={error_l1:.3f}, l2={error_l2:.3f}, accuracy={accuracy:.3f}")

    # Count (predicted, target) pairs on a fixed num_classes x num_classes grid.
    # np.histogram2d with default bins spreads its bins over the *observed* min..max
    # of each axis, which misaligns rows/columns whenever a class never occurs.
    num_classes = targets.shape[-1]
    pair_index = (predicted_classes * num_classes + target_classes).flatten().cpu()
    confusion = torch.bincount(pair_index, minlength=num_classes * num_classes)
    confusion = confusion.view(num_classes, num_classes).float()
    peak = confusion.max()
    if peak > 0:  # avoid 0/0 on empty input
        confusion = (confusion / peak).sqrt()

    _, ax = matplotlib.pyplot.subplots(figsize=(2, 2))
    ax.imshow(confusion.numpy())
    ax.set_xlabel("target")
    ax.set_ylabel("prediction")

from sklearn.linear_model import Ridge

# ridge regression directly on the flattened pixels
train_pixels = train_images.flatten(1).numpy()
test_pixels = test_images.flatten(1).numpy()

ridge = Ridge()
ridge.fit(train_pixels, train_targets.numpy())

dump_error("train", train_targets, torch.Tensor(ridge.predict(train_pixels)))
dump_error("test ", test_targets, torch.Tensor(ridge.predict(test_pixels)))

## Reservoir computing / echo state network (ESN)

In [None]:
def _act(x):
    return torch.sin(x * 5.)

# NOTE(review): 16 appears to match the cropped image width and 100 to be the reservoir
# size — confirm against Reservoir's signature.
# Alternative config tried: 16, 1000, activation="sigmoid", rec_prob=0.1, rec_std=1.0,
# leak_rate=.5, input_prob=0.5, input_std=1.5
reservoir = Reservoir(
    16,
    100,
    activation="sigmoid",
    rec_prob=.1,
    rec_std=1.5,
    leak_rate=.1,
    input_prob=0.5,
    input_std=1.5,
)
esn = ReservoirReadout(reservoir)

def _trans_images(images):
    #images = expand_time(images, 3)
    #images = F.pad(images, (0, 0, 0, 20))
    return images

def _trans_state(state, reshape: bool = True):
    # state = state[:, state.shape[1] // 2:, :]
    # state = state[:, :, -100:]
    state = torch.sin(state * 6.)
    if reshape:
        state = state.view(state.shape[0], -1)
    return state

# drive the reservoir with the (optionally preprocessed) training images
train_state = esn.run_reservoir(_trans_images(train_images))
print(f"train_state.shape = {train_state.shape}")

# readout features are computed once and reused for fitting and for evaluation
train_features = _trans_state(train_state)

ridge = Ridge()
ridge.fit(train_features.numpy(), train_targets.numpy())

test_state = esn.run_reservoir(_trans_images(test_images))
test_features = _trans_state(test_state)

dump_error("train", train_targets, torch.Tensor(ridge.predict(train_features.numpy())))
dump_error("test ", test_targets, torch.Tensor(ridge.predict(test_features.numpy())))

# "auto" stretches the heatmap to the plot area (what the previous aspect=False resulted in)
display(px.imshow(_trans_state(train_state, reshape=False)[0].T, aspect="auto", title="example reservoir state"))

# Train a neural linear readout layer (initialized from the ridge solution)

In [None]:
NUM_STEPS = 30_000
LOG_EVERY = 500
# fall back to CPU so the cell also runs on machines without a GPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# linear readout initialized from the ridge solution of the reservoir cell
model = nn.Linear(math.prod(train_state.shape[-2:]), train_targets.shape[-1])
with torch.no_grad():
    model.weight[:] = torch.Tensor(ridge.coef_)
    model.bias[:] = torch.Tensor(ridge.intercept_)

# alternatives tried: a 2-layer readout (Linear -> 300 -> Linear) and
# Adam(lr=.001, weight_decay=.9)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, NUM_STEPS)

model.to(DEVICE)
train_state_trans = _trans_state(train_state).to(DEVICE)
train_targets_device = train_targets.to(DEVICE)
losses = []
try:
    # full-batch gradient steps; interrupting the kernel stops early but keeps the result
    for step in tqdm(range(NUM_STEPS)):
        prediction = model(train_state_trans)
        loss = F.mse_loss(prediction, train_targets_device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())
        if step % LOG_EVERY == 0:
            recent = losses[-LOG_EVERY:]
            # tqdm.write keeps the progress bar intact
            tqdm.write(f"{min(recent)} {max(recent)}")
except KeyboardInterrupt:
    pass
finally:
    model.cpu()

display(px.line(losses, title="training loss"))

with torch.no_grad():
    train_prediction = model(_trans_state(train_state))
    test_prediction = model(_trans_state(test_state))

dump_error("train", train_targets, train_prediction)
dump_error("test ", test_targets, test_prediction)

In [None]:
# ridge readout vs. the trained linear layer: biases first, then weights and their difference
bias_comparison = pd.DataFrame({
    "ridge_bias": ridge.intercept_.flatten(),
    "nn_bias": model.bias.detach().flatten(0),
})
display(px.line(bias_comparison))

nn_weight = model.weight.detach().flatten(0)
weight_comparison = pd.DataFrame({
    "ridge_weight": ridge.coef_.flatten(),
    "nn_weight": nn_weight,
    "diff": torch.Tensor(ridge.coef_.flatten()) - nn_weight,
})
px.line(weight_comparison)

In [None]:
def run_experiment(
    activation,
    rec_prob,
    rec_std,
    leak_rate,
):
    """Fit a ridge readout on a freshly built reservoir and score it on train and test.

    Reads the notebook-level ``train_images``/``train_targets``/``test_images``/
    ``test_targets``. Raw reservoir states are flattened per sample; the
    ``_trans_state`` feature transform is not applied here.

    Returns:
        dict with keys ``train_error_l1``, ``train_error_l2``, ``train_accuracy``,
        ``test_error_l1``, ``test_error_l2`` and ``test_accuracy`` (floats from
        ``dump_error(..., ret=True)``).
    """
    def _flatten(state):
        # one feature row per sample
        return state.view(state.shape[0], -1)

    esn = ReservoirReadout(
        Reservoir(16, 1000, activation=activation, rec_prob=rec_prob, rec_std=rec_std, leak_rate=leak_rate),
    )

    train_state = esn.run_reservoir(expand_time(train_images, 1))
    ridge = Ridge()
    ridge.fit(_flatten(train_state).numpy(), train_targets.numpy())
    test_state = esn.run_reservoir(expand_time(test_images, 1))

    results = {}
    for prefix, label, targets, state in (
        ("train", "train", train_targets, train_state),
        ("test", "test ", test_targets, test_state),
    ):
        prediction = torch.Tensor(ridge.predict(_flatten(state).numpy()))
        error_l1, error_l2, accuracy = dump_error(label, targets, prediction, ret=True)
        results.update({
            f"{prefix}_error_l1": error_l1,
            f"{prefix}_error_l2": error_l2,
            f"{prefix}_accuracy": accuracy,
        })
    return results

matrix = {
    "activation": ["sigmoid", "tanh"],
    "rec_prob": [0.1, .5, 1.],
    "rec_std": [.5, 1., 1.5, 2.],
    "leak_rate": [0.1, .5, .9, (.1, .9)],
}
rows = []
try:
    for params in tqdm(list(iter_parameter_permutations(matrix))):
        # human-readable key for this parameter combination (used for grouping below)
        run_id = " ".join(f"{key}={value}" for key, value in params.items())
        # 5 repeats per combination, averaged per id below
        for _ in range(5):
            rows.append({"id": run_id, **run_experiment(**params)})
except KeyboardInterrupt:
    pass

df = (
    pd.DataFrame(rows)
    .groupby("id")
    .mean(numeric_only=True)
    .sort_values("test_accuracy", ascending=False)
)
df

In [None]:
# NOTE(review): pandas' to_markdown relies on the optional `tabulate` package — confirm it is installed
markdown_table = df.to_markdown()
print(markdown_table)

In [None]:
# df is already indexed by "id" after the sweep cell; grouping by that index level again
# keeps one row per id, ranked by test accuracy
df2 = (
    df.groupby("id")
    .mean(numeric_only=True)
    .sort_values("test_accuracy", ascending=False)
)
df2

In [None]:
for klass in range(3):
    # index of the first training sample of this class
    sample_idx = int((train_targets.argmax(dim=-1) == klass).nonzero()[0, 0])
    target_height = train_state.shape[-1]

    # the digit image next to its transposed reservoir state, both scaled to the same height
    panels = [train_images[sample_idx][None], train_state[sample_idx].T[None]]
    resized_panels = []
    for panel in panels:
        width = int(target_height / panel.shape[-2] * panel.shape[-1])
        resized_panels.append(
            VF.resize(panel, (target_height, width), VF.InterpolationMode.NEAREST, antialias=False)
        )
    display(VF.to_pil_image(torch.concat(resized_panels, dim=-1)))

In [None]:
# Scratch check: display what the project helper `activation_to_callable` returns for "gelu".
# NOTE(review): exploratory leftover; its result is not used anywhere — consider removing.
activation_to_callable("gelu")