In [None]:
import sys
sys.path.append("..")

import datetime
import random
import math
import time
import json
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display, HTML, Audio
import plotly
plotly.io.templates.default = "plotly_dark"
import plotly.express as px
import pandas as pd

from src.datasets import *
from src.util.image import *
from src.util import *
from src.util.files import *
from src.util.embedding import *
from src.algo import *
from src.models.encoder import *
from src.models.decoder import *
from src.models.util import *
from src.util.text_encoder import TextEncoder

In [None]:
class Reservoir(nn.Module):
    """
    Randomly initialised recurrent reservoir with leaky state integration.

    One update step (see `forward`) computes

        recurrent  = activation((state + bias_recurrent) @ weight_recurrent)
        recurrent += activation(input @ weight_input)      # only if an input is given
        next_state = state * (1 - leak_rate) + recurrent * leak_rate

    :param num_inputs: size of the last dimension of the input
    :param num_cells: number of reservoir cells, also exposed as `num_outputs`
    :param leak_rate: a single number shared by all cells, or a `(low, high)` pair
        from which one rate per cell is drawn uniformly (stored as `nn.Parameter`)
    :param rec_std: factor applied to the normal-distributed recurrent weights
    :param rec_prob: probability of a recurrent weight being kept (others are zero)
    :param input_std: factor applied to the normal-distributed input weights
    :param input_prob: probability of an input weight being kept (others are zero)
    :param activation: name or callable, resolved through `activation_to_callable`
    """

    def __init__(
        self,
        num_inputs: int,
        num_cells: int,
        leak_rate: Union[float, Tuple[float, float]] = .5,
        rec_std: float = 1.,
        rec_prob: float = .5,
        input_std: float = 1.,
        input_prob: float = .5,
        activation: Union[str, Callable] = "tanh",
    ):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_cells = self.num_outputs = num_cells
        if isinstance(leak_rate, (int, float)):
            self.leak_rate = leak_rate
        else:
            leak_min, leak_max = leak_rate
            self.leak_rate = nn.Parameter(torch.rand(self.num_cells) * (leak_max - leak_min) + leak_min)

        self.activation = activation_to_callable(activation)
        self.bias_recurrent = nn.Parameter(torch.randn(self.num_cells))
        # dense normal weights, sparsified by a random keep-mask and then scaled
        self.weight_recurrent = nn.Parameter(
            torch.randn(self.num_cells, self.num_cells) * (torch.rand(self.num_cells, self.num_cells) < rec_prob) * rec_std
        )
        self.weight_input = nn.Parameter(
            torch.randn(self.num_inputs, self.num_cells) * (torch.rand(self.num_inputs, self.num_cells) < input_prob) * input_std
        )

    def forward(self, state: torch.Tensor, input: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Advance the reservoir by a single time step.

        :param state: tensor of shape (B, num_cells)
        :param input: optional tensor of shape (B, num_inputs); without it only
            the recurrent term drives the update
        :return: next state, tensor of shape (B, num_cells)
        """
        assert state.ndim == 2, state.shape
        assert state.shape[-1] == self.num_cells, state.shape

        # the bias is added to the state *before* the recurrent projection
        rec_state = (state + self.bias_recurrent) @ self.weight_recurrent
        rec_state = self.activation(rec_state)

        if input is not None:
            assert input.ndim == 2, input.shape
            assert input.shape[-1] == self.num_inputs, input.shape

            in_state = input @ self.weight_input
            rec_state = rec_state + self.activation(in_state)

        next_state = state * (1. - self.leak_rate) + rec_state * self.leak_rate
        return next_state

    def run(self, input: torch.Tensor, state: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Feed a whole input sequence through the reservoir and collect every state.

        :param input: tensor of shape (B, T, num_inputs)
        :param state: optional start state of shape (B, num_cells), defaults to zeros
        :return: tensor of shape (B, T, num_cells); T may be zero
        """
        assert input.ndim == 3, f"Expecting `input` of shape (B, T, N), got {input.shape}"

        if state is None:
            # allocate next to the weights so the module also runs after `.to(device)`
            state = self.weight_recurrent.new_zeros(input.shape[0], self.num_cells)

        states = []
        for i in range(input.shape[1]):
            state = self(state, input[:, i])
            states.append(state)

        if not states:
            # `torch.stack` rejects an empty list, return an explicit empty sequence
            return state.new_zeros(input.shape[0], 0, self.num_cells)

        return torch.stack(states, dim=1)
    
# Quick visual check: drive a small sine-activated reservoir with a constant input
# and plot the states of the first batch item (one image row per cell).
res = Reservoir(
    num_inputs=2,
    num_cells=40,
    leak_rate=(.1, .2),
    rec_std=.5,
    activation="sin",
)
with torch.no_grad():
    state = res.run(torch.ones(3, 1000, 2))

print(state.shape)
px.imshow(state[0].T, aspect=False)

In [None]:
# Render the state of the first (at most ten) cells as audio:
# 44100 steps of constant input, played back at 44100 Hz.
series = torch.ones(1, 44100, res.num_inputs)
with torch.no_grad():
    state = res.run(series)

for cell_signal in state[0].T[:10]:
    display(Audio(cell_signal, rate=44100))


In [None]:
# Response to two short negative pulses (steps 10-19 and 110-119) on an otherwise silent input
series = torch.zeros(1, 300, 2)
for pulse_start in (10, 110):
    series[:, pulse_start:pulse_start + 10] = -1

with torch.no_grad():
    state = res.run(series)

px.imshow(state[0].T, aspect=False)

# compare speed with reservoirpy

In [None]:
# Benchmark settings, also read by the reservoirpy comparison cell below
NUM_INPUTS = 2
NUM_CELLS = 1000
NUM_BATCHES = 30
NUM_TIMESTEPS = 1000

# The timed span covers construction of the reservoir and the full run.
start_time = time.time()
# Fixed: the class is called `Reservoir` and its keyword is `activation`
# (`Reservior(..., act=...)` only worked with stale kernel state).
res = Reservoir(NUM_INPUTS, NUM_CELLS, rec_std=.5, leak_rate=(.1, .5), activation="sin")
with torch.no_grad():
    state = res.run(torch.ones(NUM_BATCHES, NUM_TIMESTEPS, NUM_INPUTS))
print(f"{time.time() - start_time:,.3f}sec ")

In [None]:
# Timing comparison against reservoirpy, using the NUM_* constants of the previous cell
# (same input shape and number of cells).
# NOTE(review): `np` is not imported explicitly anywhere in this notebook; presumably it
# arrives through one of the `from src... import *` lines - confirm or add `import numpy as np`.
import reservoirpy
from reservoirpy import nodes
reservoirpy.verbosity(0)  # presumably silences reservoirpy's own console output - TODO confirm

# As above, the timed span includes building the model, not only running it
start_time = time.time()
# NOTE(review): `>>` presumably chains the input node into the reservoir node - confirm in the reservoirpy docs
ens = nodes.Input() >> nodes.Reservoir(NUM_CELLS)
ens.run(np.ones((NUM_BATCHES, NUM_TIMESTEPS, NUM_INPUTS)))
print(f"{time.time() - start_time:,.3f}sec ")

In [None]:
from sklearn.linear_model import Ridge

class ReservoirReadout:
    """
    A reservoir combined with a ridge-regression readout.

    The reservoir itself is never trained here. `fit` runs it over the input,
    then fits a scikit-learn `Ridge` model that maps each reservoir state to the
    target of the same time step.

    The reservoir must be callable as `reservoir(state, input_or_None) -> next_state`
    and provide the attributes `num_inputs`, `num_outputs` and `num_cells`.
    """

    def __init__(
            self,
            reservoir: nn.Module,
            verbose: bool = True,
    ):
        """
        :param reservoir: the reservoir module to wrap
        :param verbose: show progress bars and status messages
        """
        assert hasattr(reservoir, "num_inputs"), "reservoir needs `num_inputs` attribute"
        assert hasattr(reservoir, "num_outputs"), "reservoir needs `num_outputs` attribute"
        self.reservoir = reservoir
        self.verbose = verbose
        self.ridge = None

    def _initial_state(self, batch_size: int) -> torch.Tensor:
        """Zero state of shape (B, num_cells), allocated next to the reservoir's parameters if it has any."""
        reference = next(self.reservoir.parameters(), None)
        if reference is None:
            return torch.zeros(batch_size, self.reservoir.num_cells)
        return reference.new_zeros(batch_size, self.reservoir.num_cells)

    def _readout(self, state_2d: torch.Tensor) -> torch.Tensor:
        """Apply the fitted ridge model to states of shape (N, C); the result has the states' dtype and device."""
        output = self.ridge.predict(state_2d.detach().cpu().numpy())
        return torch.as_tensor(output).to(device=state_2d.device, dtype=state_2d.dtype)

    @torch.no_grad()
    def run_reservoir(
            self,
            input: Optional[torch.Tensor] = None,
            state: Optional[torch.Tensor] = None,
            steps: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Run the reservoir and return every visited state.

        :param input: optional tensor of shape (B, T, num_inputs)
        :param state: optional start state of shape (B, num_cells), defaults to zeros
        :param steps: number of steps to run. Defaults to T if `input` is given,
            otherwise to 1. Steps beyond the end of `input` run without input.
        :return: tensor of shape (B, steps, num_cells)
        """
        if input is not None:
            assert input.ndim == 3, f"Expecting `input` of shape (B, T, N), got {input.shape}"
            assert input.shape[-1] == self.reservoir.num_inputs, \
                f"Expecting final dimension of `input` to match `reservior.num_inputs` {self.reservoir.num_inputs}, got {input.shape}"
            batch_size = input.shape[0]
            if steps is None:
                steps = input.shape[1]
        elif state is not None and state.ndim == 2:
            # without an input the start state defines the batch size
            batch_size = state.shape[0]
        else:
            batch_size = 1

        if state is not None:
            assert state.ndim == 2, f"Expecting `state` of shape (B, C), got {state.shape}"
            assert state.shape[-1] == self.reservoir.num_cells, \
                f"Expecting last `state` dimension to match reservoir size {self.reservoir.num_cells}, got {state.shape}"
            assert state.shape[0] == batch_size, \
                f"Expecting first `state` dimension to match batch size {batch_size}, got {state.shape}"
        else:
            state = self._initial_state(batch_size)

        if steps is None:
            steps = 1

        states = []
        for i in tqdm(range(steps), desc="running reservoir", disable=not self.verbose):
            state = self.reservoir(state, input[:, i] if input is not None and i < input.shape[1] else None)
            states.append(state)

        if not states:
            # `torch.stack` rejects an empty list, return an explicit empty sequence
            return state.new_zeros(batch_size, 0, self.reservoir.num_cells)

        return torch.stack(states, dim=1)

    @torch.no_grad()
    def fit(self, input: torch.Tensor, target: torch.Tensor, alpha: float = 1.) -> Tuple[float, float]:
        """
        Fit the ridge readout so that the reservoir state at step t predicts `target[:, t]`.

        :param input: tensor of shape (B, T, num_inputs)
        :param target: tensor of shape (B, T, N)
        :param alpha: `alpha` argument of `Ridge`
        :return: tuple of (mean absolute error, L2 norm of the error) on the training data
        """
        assert input.ndim == 3, f"Expecting input of shape (B, T, N), got {input.shape}"
        assert input.shape[-1] == self.reservoir.num_inputs, \
                f"Expecting final dimension of `input` to match `reservior.num_inputs` {self.reservoir.num_inputs}, got {input.shape}"
        assert target.ndim == 3, f"Expecting target of shape (B, T, N), got {target.shape}"
        assert input.shape[:2] == target.shape[:2], \
            f"Expecting first 2 dimensions of `target` to be equal to `input` {input.shape[:2]}, got {target.shape[:2]}"

        batch_size, time_steps = input.shape[:2]

        state = self.run_reservoir(input=input)

        # the readout treats every (batch, time) position as an independent sample
        state = state.reshape(batch_size * time_steps, -1)
        target = target.reshape(batch_size * time_steps, -1)

        if self.verbose:
            print("fitting output...")
        self.ridge = Ridge(alpha=alpha)
        self.ridge.fit(state.detach().cpu().numpy(), target.detach().cpu().numpy())

        prediction = self._readout(state)
        error = target.to(prediction.device) - prediction

        error_l1 = error.abs().mean()
        error_l2 = torch.sqrt((error ** 2).sum())
        return float(error_l1), float(error_l2)

    @torch.no_grad()
    def predict(
            self,
            input: Optional[torch.Tensor] = None,
            state: Optional[torch.Tensor] = None,
            steps: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Run the reservoir (see `run_reservoir`) and apply the fitted readout to every state.

        :return: tensor of shape (B, steps, N)
        """
        assert self.ridge is not None, "Must call `fit` before `predict`"

        state = self.run_reservoir(input=input, state=state, steps=steps)

        batch_size, time_steps = state.shape[:2]

        prediction = self._readout(state.reshape(batch_size * time_steps, state.shape[-1]))

        return prediction.view(batch_size, time_steps, -1)

    @torch.no_grad()
    def generate(
            self,
            steps: int,
            input: Optional[torch.Tensor] = None,
            state: Optional[torch.Tensor] = None,
            adjust_prediction: Optional[Callable] = None,
            lookahead: int = 1,
    ) -> torch.Tensor:
        """
        Warm up on `input`, then let the model run on its own predictions.

        This assumes the readout was fitted to predict the input `lookahead`
        steps ahead, so that predictions can be fed back as inputs.

        :param steps: number of self-driven steps after the warm-up
        :param input: optional warm-up tensor of shape (B, T, num_inputs), T >= lookahead
        :param state: optional start state of shape (B, num_cells)
        :param adjust_prediction: optional callable applied to every prediction of
            shape (B, N) before it is fed back (and before it is stored in the output)
        :param lookahead: how many steps ahead the readout predicts
        :return: tensor of shape (B, T + steps, N): the warm-up predictions followed
            by the generated ones
        """
        assert self.ridge is not None, "Must call `fit` before `generate`"
        assert lookahead >= 1, f"Expecting `lookahead` >= 1, got {lookahead}"

        state = self.run_reservoir(input=input, state=state)

        batch_size, time_steps = state.shape[:2]
        assert time_steps >= lookahead, \
            f"Need at least `lookahead` ({lookahead}) warm-up steps, got {time_steps}"

        prediction = self._readout(state.reshape(batch_size * time_steps, state.shape[-1]))
        prediction = prediction.view(batch_size, time_steps, -1)

        # Continue from the last warm-up state. The last `lookahead` predictions are
        # exactly the next `lookahead` inputs, oldest first.
        # (`range(lookahead, 0, -1)` avoids index `-0`, which would select the FIRST step.)
        state_slice = state[:, -1]
        future_predictions = [prediction]
        input_slices = [prediction[:, -i, :] for i in range(lookahead, 0, -1)]
        if adjust_prediction is not None:
            input_slices = [adjust_prediction(s) for s in input_slices]

        for i in tqdm(range(steps), desc="generating", disable=not self.verbose):
            state_slice = self.reservoir(state_slice, input_slices.pop(0))
            predicted_slice = self._readout(state_slice)
            if adjust_prediction is not None:
                predicted_slice = adjust_prediction(predicted_slice)
            future_predictions.append(predicted_slice[:, None, :])
            # this prediction becomes an input again `lookahead` steps from now
            input_slices.append(predicted_slice)

        return torch.concat(future_predictions, dim=-2)
    
# Toy experiment: learn to predict a sine wave `lookahead` steps ahead,
# then let the model continue the curve from its own predictions.
# Fixed: the class is called `Reservoir` (`Reservior` only existed in stale kernel state).
esn = ReservoirReadout(Reservoir(1, 1000))
t = torch.linspace(0, 8*torch.pi, 400)
curve = torch.sin(t) #+ torch.sin(t * 3) + torch.sin(t * 7)
lookahead = 10
# NOTE: `input` shadows the builtin; the name is kept because the following cell reads it
input = curve[:-lookahead].view(1, -1, 1)
target = curve[lookahead:].view(1, -1, 1)
print(esn.fit(input, target))
prediction = esn.predict(input)
display(px.line(pd.DataFrame({
    "target": target[0, :, 0],
    "prediction": prediction[0, :, 0],
    "error": (target - prediction)[0, :, 0],
})))
prediction = esn.generate(1000, input, lookahead=lookahead)
display(px.line(prediction[0, :1000, 0], title="generate"))

In [None]:
# Long free-running generation, played back as audio plus a zoomed-in plot of 3000 samples
prediction = esn.generate(50_000, input, lookahead=lookahead)
waveform = prediction[0, :, 0]
display(Audio(waveform, rate=44100))
display(px.line(waveform[20000:23000], title="generate"))

# class prediction

In [None]:
def encode_sequence(sequence, num_classes: int):
    """
    One-hot encode a sequence of class indices.

    :param sequence: sized iterable of integer class indices
    :param num_classes: width of the encoding
    :return: float tensor of shape (len(sequence), num_classes) with a single 1 per row
    """
    one_hot = torch.zeros((len(sequence), num_classes))
    # iterating a tensor yields row views, so writing to `row` fills `one_hot`
    for row, class_index in zip(one_hot, sequence):
        row[class_index] = 1.
    return one_hot

# Character <-> class-index tables: punctuation and digits first, then the letters a-z
CHAR_MAPPING = {
    char: index
    for index, char in enumerate(" ,.:/[]|&;0123456789" + "abcdefghijklmnopqrstuvwxyz")
}
CODE_MAPPING = {index: char for char, index in CHAR_MAPPING.items()}
def encode_text(text: str) -> torch.Tensor:
    """
    One-hot encode a text with the `CHAR_MAPPING` alphabet.

    The text is lower-cased first; characters that are not in `CHAR_MAPPING`
    are dropped, so the result can be shorter than the text.

    :return: whatever `encode_sequence` returns for the class indices, i.e. a tensor
        of shape (number of kept characters, len(CHAR_MAPPING))
    """
    classes = [
        CHAR_MAPPING[c]
        for c in text.lower()
        if c in CHAR_MAPPING
    ]
    return encode_sequence(classes, num_classes=len(CHAR_MAPPING))

def decode_text(code: torch.Tensor) -> str:
    """
    Turn a sequence of class-score rows back into text.

    Each row of `code` is mapped to the character of its largest entry via
    `CODE_MAPPING`; indices without a character become "?".
    """
    return "".join(
        CODE_MAPPING.get(int(torch.argmax(row)), "?")
        for row in code
    )
    
# Round-trip check: encoding and decoding a text made only of mapped characters
code = encode_text("abc defz")
decode_text(code)

In [None]:
import torchtext.datasets

# From the first 1000 lines of EnWik9 keep those longer than 300 characters,
# each truncated to its first 3000 characters.
ds = torchtext.datasets.EnWik9()
wiki_texts = [
    line[:3000]
    for line, _ in zip(ds, range(1000))
    if len(line) > 300
]
# Preview the first texts the way the model sees them (after the encode/decode round trip)
display([decode_text(encode_text(t))[:100] for t in wiki_texts[:10]])

In [None]:
@torch.no_grad()
def predict_text(
        *text,
        max_length: int = 700,
        num_cells: int = 1000,
        alpha: float = 1.,
        act="sin",
        lookahead: int = 1,
        generate_steps: int = 300,
):
    """
    Fit a reservoir readout on character sequences and print generated continuations.

    Every text is one-hot encoded with `encode_text`, cut to `max_length` characters
    and zero-padded to the longest text. The readout is fitted to predict the character
    `lookahead` steps ahead. Generation is warmed up with the first half of each
    input sequence and then runs for `generate_steps` steps on its own predictions.

    Prints the training errors and, per text, the decoded warm-up part and the
    decoded remainder. Returns nothing.

    :param text: one or more texts
    :param max_length: maximum number of encoded characters used per text
    :param num_cells: number of reservoir cells
    :param alpha: `alpha` passed to `ReservoirReadout.fit`
    :param act: activation of the reservoir
    :param lookahead: how many characters ahead the readout predicts, >= 1
    :param generate_steps: number of self-driven generation steps
    :raises ValueError: if no text is given, `lookahead` is below 1 or the texts are
        too short for the requested `lookahead`
    """
    if not text:
        raise ValueError("predict_text needs at least one text")
    if lookahead < 1:
        raise ValueError(f"`lookahead` must be >= 1, got {lookahead}")

    encoded = [
        encode_text(t)[None, :max_length, :]
        for t in text
    ]
    max_len = max(code.shape[1] for code in encoded)
    # pad the time dimension at the end so all texts share one batch tensor
    series = torch.concat([
        F.pad(s, (0, 0, 0, max_len - s.shape[1]))
        for s in encoded
    ])
    target_series = series[:, lookahead:, :]
    input_series = series[:, :-lookahead, :]

    # only the first half of the input is used to warm up the generation
    input_len = input_series.shape[1] // 2
    if input_len < lookahead:
        raise ValueError(
            f"texts are too short ({max_len} encoded characters) for lookahead {lookahead}"
        )

    esn = ReservoirReadout(
        Reservoir(
            input_series.shape[-1], num_cells,
            leak_rate=(.1, .9),
            activation=act,
        )
    )

    error_l1, error_l2 = esn.fit(input_series, target_series, alpha=alpha)

    print(f"error l1 {error_l1:.3f}, l2 {error_l2:.3f}")

    def _adjust_prediction(p: torch.Tensor) -> torch.Tensor:
        # snap the scores to a strict one-hot vector (ties resolve to a single class)
        return F.one_hot(p.argmax(dim=-1), num_classes=p.shape[-1]).to(p.dtype)

    generated = esn.generate(
        generate_steps, input_series[:, :input_len, :],
        adjust_prediction=_adjust_prediction,
        lookahead=lookahead,
    )

    for generated_code in generated:
        decoded = decode_text(generated_code)
        print(repr(decoded[:input_len]), "REPR:", repr(decoded[input_len:]))
        
        
# Run the text experiment on the first wiki text.
# The commented-out strings are alternative hand-written inputs.
predict_text(
    *wiki_texts[:1],
    #"a simple text to learn to predict the next character. this is usually done with recurrent networks"
    #" which are kind of hard to train. in reservoir computing we only train the readout module"
    #" and let the reservoir rnn simply do its magic without interfering",
    #"a second text that has nothing to do with the first except that it uses the same characters"
    #" and the same language.",
    alpha=1.,
    num_cells=1000,
    max_length=200,
    act="tanh",
    lookahead=1,
)