In [None]:
import numpy as np
import matplotlib as mplt
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.optimize import curve_fit
from scipy.integrate import cumulative_trapezoid
from scipy.interpolate import RegularGridInterpolator
import pandas as pd
import random
import math
import copy
from typing import Annotated, Any, Callable
from pydantic import BaseModel, Field, WithJsonSchema
import pydantic
import ipywidgets as widgets
import random



: 

In [None]:
%load_ext autoreload
%autoreload 2

import sys

# Make the in-repo package importable without installing it. Entries already on
# sys.path are skipped, so re-running this cell never duplicates them.
for p in ('../src',):
    if p in sys.path:
        continue
    sys.path.append(p)
        
import spikeml as sml
from spikeml.utils.nb_util import xdisplay, Markup
from spikeml.core.signal import signal_dc, signal_pulse, encode1_onehot, encode_onehot, signal_ranges, mean_per_input
from spikeml.core.ngram import build_ngram, ngram_find, ngram_msample, print_ngrams

from spikeml.plot.plot_util import plot_hist, plot_data, plot_lidata, plot_input, plot_xt, plot_mt, plot_spikes, imshow_matrix, imshow_nmatrix
from spikeml.core.params import Params, NNParams, ConnectorParams, SpikeParams, SSensorParams, SNNParams, SSNNParams

from spikeml.core.params import Params, NNParams, ConnectorParams, SpikeParams, SSensorParams, SNNParams, SSNNParams
from spikeml.ui.ipywidgets_ui import ui

from spikeml.core.matrix import matrix_split, normalize_matrix, _mult, cmask, cmask2, matrix_init, matrix_init2
from spikeml.core.monitor import Monitor

from spikeml.core.spikes import pspike, spike
from spikeml.core.base import Component, Module, Fan, Composite, Chain
from spikeml.core.snn import Layer, SimpleLayer, SNN, SSNN, SSensor, DSSNN, Connector, LinearConnector, RateConnector, LIConnector, LIConnector2
from spikeml.core.snn import make_snn_chain, make_ssnn_chain, chain_validate
from spikeml.core.snn import ssnn_apply_update
from spikeml.core.feedback import compute_error, xcompute_error, compute_sg
from spikeml.core.run import run, nrun, Context

from spikeml.core.optimize import make_params_spec, vec2dic, setattrs, params_search, set_all_attrs

from spikeml.core.snn_monitor import SSensorMonitor, SNNMonitor, SSNNMonitor, ConnectorMonitor, LIConnectorMonitor
from spikeml.core.snn_viewer import  SSensorMonitorViewer, SNNMonitorViewer, SSNNMonitorViewer, ConnectorMonitorViewer, LIConnectorMonitorViewer, ErrorMonitorViewer


In [None]:
SEED = 37

# Seed both global RNGs the notebook draws from (NumPy's legacy global generator
# and the stdlib `random` module) so stochastic cells reproduce under Run All.
np.random.seed(SEED)
random.seed(SEED)

# Compact, wide array printing for the many print(array) calls below.
np.set_printoptions(
    edgeitems=3,
    infstr='inf',
    linewidth=120,
    nanstr='nan',
    precision=4,
    suppress=False,
    threshold=1000,
    formatter=None,
)



# Display

In [None]:
def test_xdisplay():
    """Smoke-test ``xdisplay`` with labelled (``Markup``) values and plain values of several types."""
    row_a, row_b = [1/3, 2, 3], [3, 4, 3]
    A = np.array([row_a, row_b] * 2)
    B = np.array([[5, 6, 5], [7, 8, 5]] * 2)

    labelled = (
        Markup('t=0', np.zeros(5), np.ones(5)),
        Markup('A', A),
        Markup('B', B),
        Markup('AB', (A, B)),
    )
    plain = (np.zeros(5), [1, 2, 3], (1, 2, 3), "abc", 3, 1/3)
    xdisplay(*labelled, *plain)

test_xdisplay()

# N-grams

In [None]:
from spikeml.core.ngram import build_ngram, ngram_find, ngram_msample, print_ngrams

def test_ngram():
    """Build a small n-gram table, print it flat and nested, look one key up, then sample sequences."""
    nsym, order = 2, 3
    ngrams = build_ngram(nsym=nsym, n=order, sd=1)
    for flat in (True, False):
        print_ngrams(ngrams, flat=flat)

    print(ngram_find(np.array([2, 0]), ngrams))

    # m sequences of length n drawn from the table
    sample_len, n_samples = 10, 6
    print(ngram_msample(ngrams, nsym=nsym, n=sample_len, m=n_samples))


test_ngram()

# Signal

In [None]:
from spikeml.core.signal import signal_dc, signal_pulse, encode1_onehot, encode_onehot, signal_ranges, mean_per_input

def test_signal():
    """Print a DC signal and two pulse signals (2 and 3 channels) as nested lists."""
    ss = signal_dc(2, T=10, s=0, value=2)
    print(ss.tolist())
    ss = signal_pulse(2, T=5, L=2, s=[0,-1], value=1)
    print(ss.tolist())
    ss = signal_pulse(3, T=5, L=2, s=[0,1,2,-1], value=1)
    print(ss.tolist())

def test_signal2():
    """Print a 3-channel pulse signal whose sequence mixes an explicit pattern vector with symbol indices."""
    s0 = np.array([1, 1, 0])
    ss = signal_pulse(3, T=3, L=1, s=[s0,-1,2,-1], value=1)
    print(ss.tolist())

def test_encode_onehot(nsym=3):
    """One-hot encode a 3-channel pulse signal and print the first encoded step.

    Args:
        nsym: alphabet size passed to ``encode_onehot``. ``NSYM`` is not defined
            at notebook scope (it is only a local of ``test_ngram``), so it is an
            explicit parameter here. The default of 3 matches the width of the
            signal built below. NOTE(review): confirm nsym should equal the
            signal width.
    """
    print(encode1_onehot(0, 5))
    ss = signal_pulse(3, T=5, L=2, s=[0,1,2,-1], value=1)
    ess = encode_onehot(ss, nsym)
    print(ess[0])
    
test_signal()
test_signal2()
#test_encode_onehot()


In [None]:

def test_signal_unique():
    """Show the distinct input patterns of a pulse signal and the time ranges where each occurs."""
    # signal_changes / signal_unique are not in the notebook's top-level imports,
    # so the cell raised NameError on a fresh kernel; import them here.
    # NOTE(review): assumes both live in spikeml.core.signal next to signal_ranges — confirm.
    from spikeml.core.signal import signal_changes, signal_unique

    ss = signal_pulse(2, T=5, L=2, s=[0,1,-1], value=1)
    print(ss.tolist())
    plot_xt(ss)
    print(signal_changes(ss))
    u = signal_unique(ss)
    print(u)
    ranges = signal_ranges(ss, ref=u, E=0)
    for i, s in enumerate(u):
        print(i, ':', s, ranges[i])

def test_mean_per_input(monitor=None, nn=None):
    """Print the mean error for each distinct input pattern.

    Args:
        monitor: object exposing ``err`` (the error trace). Defaults to the
            notebook-global ``monitor``.
        nn: network exposing ``monitor.sx`` (the input trace). Defaults to the
            notebook-global ``nn``.
    """
    # Neither name is defined in this section of the notebook. Accept them
    # explicitly, but keep the old global lookup so existing calls still work.
    if monitor is None:
        monitor = globals()['monitor']
    if nn is None:
        nn = globals()['nn']
    ref, size, means = mean_per_input(monitor.err, nn.monitor.sx)
    for i in range(ref.shape[0]):
        print(f'{i}: {ref[i]} (#{size[i]}); Err: {means[i]:.4f}')

test_signal_unique()
#test_mean_per_input()


# Plots

In [None]:
def test_plot_data():
    """Exercise plot_data with a dict of series, a list of 2-vectors, and the same list as an array."""
    steps = range(0, 10)
    series = {'data0': [t**1.5 for t in steps], 'data1': [t**2 for t in steps]}
    pulses = signal_pulse(2, T=3, L=1, s=[0,1,-1], value=1)
    plot_data(series, title='yy', callback=lambda ax: plot_input(pulses, ax=ax))

    vectors = [np.array([t, 2*t])**1.5 for t in steps]
    plot_data(vectors, title='yy2', label='yy')
    plot_data(np.array(vectors), title='yy2.2', label='yy')

def test_plot_lidata():
    """Plot a noisy |sin(t)| signal through plot_lidata with k_x = 0.8."""
    t = np.linspace(0, 100, num=100)
    noise = np.random.normal(loc=0, scale=.2, size=t.shape[0])
    plot_lidata(np.abs(np.sin(t)) + noise, .8)

def test_imshow_matrix():
    """Show one random 5x5 Gaussian matrix."""
    imshow_matrix(np.random.normal(loc=0, scale=1, size=(5, 5)), title='M')

def test_imshow_nmatrix():
    """Show a sequence of 16 random 5x5 matrices with three different layout options."""
    frames = [np.random.normal(loc=0, scale=1, size=(5, 5)) for _ in range(16)]
    for layout in ({'tstep': 3}, {'tk': 10}, {'tk': 10, 'ncols': 10}):
        imshow_nmatrix(frames, title='M', **layout)

test_plot_data()
test_plot_lidata()
test_imshow_matrix()
test_imshow_nmatrix()

# Params

# UI

In [None]:

from spikeml.ui.params import Params, NNParams, ConnectorParams, SpikeParams, SSensorParams, SNNParams, SSNNParams

def test_ui():
    """Dump SSNNParams' pydantic fields, then open the parameter UI with printing callbacks.

    Not invoked anywhere in this notebook; call it manually to try the widget.
    """
    params = SSNNParams()
    print(params.fmt())
    for name, info in type(params).model_fields.items():
        ann = info.annotation
        print(name, info, ann, ann==float, info.default, info.metadata)

    def on_start(params):
        print('START', params)

    def on_pause(params):
        print('PAUSE', params)

    def on_stop():
        # takes no arguments; reports the params object created above
        print('STOP', params)

    def on_change(name, value, w):
        print('CHANGE:', name, value)

    ui(params, callback_start=on_start, callback_pause=on_pause,
       callback_stop=on_stop, callback_change=on_change)


# Matrix utilities

In [None]:
from spikeml.core.matrix import matrix_init, normalize_matrix, cmask, cmask2

def test_matrix_init():
    """Print the matrix produced by matrix_init for ConnectorParams(size=3)."""
    params = ConnectorParams(size=3)
    M = matrix_init(params)
    print('M:', M)

def test_normalize_matrix():
    """Print several 2x2 matrices before and after normalize_matrix(strict=False).

    Each case is (matrix, c_in, c_out). This replaces four copy-pasted stanzas
    (one of which printed a doubled separator).
    """
    cases = [
        (np.array([[2.5, 1], [1, .5]]), 2, 2),
        (np.array([[2.5, -1], [1, .5]]), 2, 2),
        (np.array([[2.5, -2.5], [1, -2]]), 2, 2),
        (np.array([[1, 0], [.2, 0]]), 0, 1),
    ]
    for M, c_in, c_out in cases:
        print(M)
        print(normalize_matrix(M, c_in=c_in, c_out=c_out, strict=False))
        print('-'*10)

def test_cmask(c_in=1, c_out=1):
    """Display the masks returned by cmask and cmask2 for the same 2x2 matrix.

    Args:
        c_in, c_out: passed positionally to both mask functions (defaults 1, 1
            as before).
    """
    base = np.array([[.8, .2], [.3, 0]])
    for mask_fn in (cmask, cmask2):
        # A fresh copy per call, so each mask function sees an unmodified matrix
        # even if the other one writes into its input.
        M = base.copy()
        d, d_in, d_out = mask_fn(M, c_in, c_out)
        xdisplay(M, d_out, d_in, d)

test_matrix_init()
#test_normalize_matrix()
#test_cmask()
# TODO: test_matrix_init2 is referenced here but not defined in this notebook.
#test_matrix_init2()


# Spikes

In [None]:
from spikeml.core.params import SSensorParams

from spikeml.core.spikes import pspike, spike

def test_spike():
    """Call spike() 100 times on 5 constant input levels spanning [vmin, vmax]; print and plot the outputs."""
    params = SSensorParams()
    print(params)
    levels = np.linspace(params.vmin, params.vmax, num=5)
    column = levels[:, np.newaxis]  # same levels as a column, for the input overlay
    outputs = []
    for step in range(100):
        fired = spike(levels, params)
        print(step, levels, fired)
        outputs.append(fired)

    plot_spikes(outputs, title='z', name=None, callback=lambda ax: plot_input(column, ax=ax))

test_spike()

# Feedback

In [None]:
from spikeml.core.feedback import compute_error, compute_sg


def test_compute_error():
    """Tabulate the feedback pipeline for a 2-channel input against several feedback vectors.

    For each amplitude A in (.1, .3, .5) the input is sx = [A, 0]. It is paired
    with a fixed list of feedback vectors y. For each pair the function records:
    compute_error, compute_sg, the clipped drive s = sx + g*y, the clipped gated
    drive sm = s*sg, pspike(sm), and one spike() draw. Each step is printed.

    Returns:
        pandas.DataFrame with one row per (sx, y) pair. Every column except
        'err' and 'sg' is stored as the str() of the array.
    """
    params = SSNNParams()
    print(params.fmt())
    params = SSNNParams(g=1, e_err=5, pmax=1, e_z=2)
    print(params.fmt())
    data = {k: [] for k in ('sx', 'y', 'err', 'sg', 's', 'sm', 'ps', 'zs')}

    def _err(sx, y):
        # Use the argument rather than the enclosing loop variable: the parameter
        # used to be named `s`, was ignored by the body, and was then shadowed.
        err = compute_error(sx, y)
        sg = compute_sg(err, params)
        s = np.clip(sx + params.g*y, params.vmin, params.vmax).round(2)
        sm = np.clip(s*sg, params.vmin, params.vmax).round(2)
        ps = pspike(sm, params).round(2)
        zs = spike(sm, params)
        data['sx'].append(str(sx))
        data['y'].append(str(y))
        data['err'].append(err)
        data['sg'].append(sg)
        data['s'].append(str(s))
        data['sm'].append(str(sm))
        data['ps'].append(str(ps))
        data['zs'].append(str(zs))
        print(f'sx: {sx} y: {y}', f'-> err: {err:.2f}', f'; sg: {sg:.2f}', f'=> s: {s}', f'=> sm: {sm}', f'=> ps: {ps} ; zs: {zs}')

    for A in [.1, .3, .5]:
        sx = np.array([A, 0.0])
        targets = (
            np.array([0.0, 0.0]),
            sx,
            np.array([1.0, 0]),
            np.array([0, 1.0]),
            np.array([.5, 0]),
            np.array([0, .5]),
            np.array([.5, .5]),
        )
        for y in targets:
            _err(sx, y)

    df = pd.DataFrame(data)
    return df

def test_compute_error2():
    """Compare the xcompute_error reduction methods ('dp', 'mean', 'sum+clip', 'max').

    In every case y has R entries per entry of s. For each case, each method's
    error is printed next to the (s, y) view it corresponds to: s repeated R
    times for 'dp', and y reduced over each run of R entries for the others.

    Returns:
        pandas.DataFrame with columns f (method), R, s, y, err.
    """
    params = SSNNParams()
    data = {'f': [], 'R': [], 's': [], 'y': [], 'err': []}
    def _add(f, s, y, R, err, debug=True):
        data['f'].append(f)
        data['s'].append(s)
        data['y'].append(y)
        data['R'].append(R)
        data['err'].append(err)
        if debug:
            print(f'{f}: R: {R} ; s: {s} ; y: {y}', f'-> err: {err:.2f}')

    def _err(s, y, R):
        # Group y into consecutive runs of R; computed once, reduced three ways.
        runs = y.reshape(y.shape[0] // R, R)
        views = (
            ('dp', np.repeat(s, R), y),
            ('mean', s, runs.mean(axis=1)),
            ('sum+clip', s, np.clip(runs.sum(axis=1), params.vmin, params.vmax)),
            ('max', s, runs.max(axis=1)),
        )
        errs = [xcompute_error(s, y, R=R, method=method) for method, _, _ in views]
        print('-'*4)
        for (method, s_shown, y_shown), err in zip(views, errs):
            _add(method, s_shown, y_shown, R, err)

    cases = [
        ([1,0], [1,0], 1),
        ([1,0], [1,1,0,0], 2),
        ([1,0], [1,0,0,0], 2),
        ([1,0], [0,1,0,0], 2),
        ([1,0], [0,0,0,0], 2),
        ([1,0], [0,0,1,1], 2),
        ([1,0], [0,0,1,0], 2),
        ([1,0], [0,0,0,1], 2),
        ([1,0], [1,1,1,1], 2),
    ]
    for s, y, R in cases:
        _err(np.array(s), np.array(y), R=R)

    df = pd.DataFrame(data)
    return df


#df = test_compute_error()
#display(df)

df=test_compute_error2()
display(df)


# Bias

In [None]:
def test_bias_update():
    """Apply bias_update 10 times (debug=True) from two (b, y) starting points, separated by a rule."""
    # NOTE(review): bias_update is not imported anywhere in the visible notebook,
    # so on a fresh kernel this raises NameError. TODO: import it from the
    # spikeml module that defines it.
    params = SSNNParams(t_b=10)
    print(params)

    def _bias_update(b, y, n=10):
        for _ in range(n):
            b = bias_update(b, y, params, debug=True)

    starts = ((np.array([1]), np.array([0])), (np.array([0]), np.array([1])))
    for idx, (b0, y) in enumerate(starts):
        if idx > 0:
            print('-' * 10)
        _bias_update(b0, y)

test_bias_update()