In [1]:
import math

import torch
from scipy.special import lambertw
from tqdm import tqdm

In [2]:
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle, Circle
import gif

In [3]:
# Seed torch's global RNG so that the torch-based random draws made further down
# (torch.randint, uniform_, normal_) give the same values on every
# Restart & Run All. NOTE(review): whether the rsnn sampler also draws from
# torch's global RNG is not visible from this notebook -- confirm.
SEED = 42
torch.manual_seed(SEED)

# Resolution used by the `gif` package when it rasterises matplotlib figures.
gif.options.matplotlib["dpi"] = 150
# NOTE(review): 'scientific' is not a style shipped with matplotlib; it is
# presumably a user-installed .mplstyle file -- confirm it is available.
plt.style.use('scientific')

In [4]:
from rsnn.firing_sequences.sampling import backward_filtering_forward_sampling
from rsnn.firing_sequences.utils import is_predictable
from rsnn.ikie.ikie import compute_box_prior, compute_posterior_means, compute_m_ary_prior
from rsnn.ikie.utils import compute_observation_matrix, get_mask_refractory_period, get_mask_around_firing, get_mask_at_firing

In [5]:
@gif.frame
def add_frame(itr, w, y, theta, eta, dymin, wmin, wmax, mask_at_firing, mask_around_firing, mask_refractory_period):
    """Draw one animation frame with three stacked panels.

    Panel 0 plots ``y[:, 0]`` (titled "level"), panel 1 plots ``y[:, 1]``
    (titled "slope") and panel 2 plots the entries of ``w`` in ascending order.

    Args:
        itr: iteration number inserted in the three panel titles.
        w: tensor of weights (callers pass one row of the weight matrix).
        y: tensor indexed as ``y[n, 0]`` and ``y[n, 1]``; ``y.size(0)`` gives
            the number of time steps that are shaded.
        theta, eta: heights of the dashed horizontal lines in the level panel.
        dymin: height of the dashed horizontal line in the slope panel.
        wmin, wmax: heights of the dashed lines in the weights panel; its
            y-range is ``[wmin - 0.5, wmax + 0.5]``.
        mask_at_firing, mask_around_firing, mask_refractory_period: per-step
            flags, indexed by time step and used as truth values.

    Nothing is returned; presumably the ``gif.frame`` decorator captures the
    current matplotlib figure -- confirm against the gif package.
    """
    n_steps = y.size(0)
    fig, (level_ax, slope_ax, weight_ax) = plt.subplots(3, figsize=(20, 20))

    def shade_regions(axis, bottom, height, mark_firing):
        # Orange band: step flagged "around firing". Blue band: step that is
        # neither around a firing nor refractory. Refractory steps stay blank.
        for step in range(n_steps):
            if mask_around_firing[step]:
                shade = "orange"
            elif mask_refractory_period[step]:
                continue
            else:
                shade = "blue"
            axis.add_patch(Rectangle((step, bottom), 1, height, facecolor=shade, alpha=0.3))
            # Firing instants are only marked inside an orange band.
            if mark_firing and shade == "orange" and mask_at_firing[step]:
                axis.axvline(step, color="orange")

    # --- level panel ---
    level_ax.set_title(f"action potential (level) at iteration {itr}")
    level_ax.set_ylim(-2, 3)
    shade_regions(level_ax, -2, 5, mark_firing=True)
    for height in (theta, eta):
        level_ax.axhline(height, linestyle="dashed", color="black")
    level_ax.plot(y[:, 0], color="red")

    # --- slope panel ---
    slope_ax.set_title(f"action potential (slope) at iteration {itr}")
    slope_ax.set_ylim(-1, 1)
    shade_regions(slope_ax, -1, 2, mark_firing=False)
    slope_ax.axhline(dymin, linestyle="dashed", color="black")
    slope_ax.plot(y[:, 1], color="red")

    # --- weights panel (sorted so the distribution is readable) ---
    weight_ax.set_title(f"weights at iteration {itr}")
    weight_ax.set_ylim(wmin-0.5, wmax+0.5)
    for bound in (wmin, wmax):
        weight_ax.axhline(bound, linestyle="dashed", color="black")
    weight_ax.plot(w.sort()[0], color="red")

In [6]:
# Problem dimensions used throughout the notebook:
#   L  - size of the leading (batch) axis of weights, masks and potentials
#   K  - number of weights per row (weights have shape (L, K))
#   N  - length of each firing sequence (sequences are viewed as (L, N))
#   Tr - passed to the rsnn helpers as the refractory-period argument
L = 100
K = 100
N = 100
Tr = 20

In [7]:
# One random integer pair per (row, weight) entry:
#   taus    in [1, 2*Tr)  -- presumably transmission delays, confirm
#   origins in [0, L)     -- presumably the index of the source row, confirm
taus = torch.randint(low=1, high=2 * Tr, size=(L, K))
origins = torch.randint(low=0, high=L, size=(L, K))

In [8]:
# Time constant of the kernel below. Using the -1 branch of the Lambert W
# function, beta solves (Tr / beta) * exp(1 - Tr / beta) = 1e-2, i.e. the
# kernel has decayed to 1e-2 (its peak value is 1, reached at t = beta) after
# Tr time steps.
beta = -Tr / lambertw(-1e-2 / math.exp(1), -1).real


def impulse_response(t_):
    """Evaluate the causal kernel ``(t / beta) * exp(1 - t / beta)`` for ``t > 0``.

    Args:
        t_: torch tensor of time offsets (any shape).

    Returns:
        Tensor of the same shape holding the kernel value for positive
        offsets and 0 for non-positive ones.

    The offsets are clamped at 0 before the exponential is evaluated. Masking
    after the fact (``(t_ > 0) * ...``) yields ``0 * inf = NaN`` once
    ``exp(1 - t_ / beta)`` overflows for very negative offsets; clamping gives
    the same values for ``t_ > 0`` and a clean 0 everywhere else.
    """
    t_pos = t_.clamp(min=0)
    return t_pos / beta * torch.exp(1 - t_pos / beta)


# theta: value the potential level is constrained to at firing times;
# eta: upper bound imposed on the level in the "before firing" region.
theta, eta = 1, 0
# eps: width argument given to the mask helpers; dymin: lower bound imposed on
# the potential slope around firing times.
eps, dymin = 5, 1e-1
# Admissible weight range.
wmin, wmax = -1, 1

In [9]:
# Sample the firing sequences with the rsnn sampler, then flatten the result
# to one row per sequence: shape (L, N).
sampled_sequences = backward_filtering_forward_sampling(1, L, N, Tr)
firing_sequences = sampled_sequences.view(L, N)

In [10]:
# Predictability check: is_predictable is called with a leading singleton
# axis added in front of the (L, N) sequences; its result is the cell output.
is_predictable(firing_sequences.unsqueeze(0), Tr)

[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.


True

In [11]:
# Region masks derived from the firing sequences by the rsnn helpers.
mask_at_firing = get_mask_at_firing(firing_sequences)
mask_refractory_period = get_mask_refractory_period(firing_sequences, Tr, eps)
mask_around_firing = get_mask_around_firing(firing_sequences, eps)
# "Before firing" is everything that is neither refractory nor around a firing.
mask_before_firing = ~mask_refractory_period & ~mask_around_firing

In [12]:
# Observation matrix C, built by the rsnn helper from the (L, N) firing
# sequences, taus, origins, Tr and the impulse response. Further down it maps
# weights reshaped to (L, 1, K, 1) to potentials reshaped to (L, N, 2).
C = compute_observation_matrix(
    firing_sequences.view(L, N),
    taus,
    origins,
    Tr,
    impulse_response,
)

# 1. Box constraint

In [None]:
# Start from weights drawn uniformly in [wmin, wmax] and the potentials
# (last axis: 0 = level, 1 = slope) they produce through C.
mw = torch.empty(L, K, dtype=torch.float32).uniform_(wmin, wmax)
my = torch.matmul(C, mw.view(L, 1, K, 1)).view(L, N, 2)

# Only row `l` is animated; frame 0 shows the random starting point.
l = 42
frames = [
    add_frame(
        0, mw[l], my[l], theta, eta, dymin, wmin, wmax,
        mask_at_firing[l], mask_around_firing[l], mask_refractory_period[l],
    )
]

for itr in tqdm(range(1, 51)):
    # Prior message on the weights: box [wmin, wmax].
    mw_f, Vw_f = compute_box_prior(mw, wmin, wmax)

    # Messages on the potentials: zero mean with a huge variance (1e9)
    # everywhere, then overwritten region by region below.
    my_b = torch.zeros(L, N, 2)
    Vy_b = torch.full((L, N, 2), 1e9)

    # Views into the current potentials and into the message tensors; writing
    # through the views updates my_b / Vy_b in place.
    level, slope = my[..., 0], my[..., 1]
    level_mean, slope_mean = my_b[..., 0], my_b[..., 1]
    level_var, slope_var = Vy_b[..., 0], Vy_b[..., 1]

    # NOTE(review): the 4th argument of compute_box_prior (50 / 100) is not
    # documented in this notebook -- confirm its meaning in rsnn.ikie.
    # Before firing: level bounded above by eta.
    level_mean[mask_before_firing], level_var[mask_before_firing] = compute_box_prior(
        level[mask_before_firing], None, eta, 50
    )
    # Around firing: slope bounded below by dymin.
    slope_mean[mask_around_firing], slope_var[mask_around_firing] = compute_box_prior(
        slope[mask_around_firing], dymin, None, 100
    )
    # At firing: level pinned to theta (written last, after the "before" region).
    level_mean[mask_at_firing], level_var[mask_at_firing] = compute_box_prior(
        level[mask_at_firing], theta, theta, 50
    )

    # Posterior means of weights and potentials.
    mw, _, my = compute_posterior_means(mw_f, Vw_f, my_b, Vy_b, C)

    # Record the state after this iteration.
    frames.append(
        add_frame(
            itr, mw[l], my[l], theta, eta, dymin, wmin, wmax,
            mask_at_firing[l], mask_around_firing[l], mask_refractory_period[l],
        )
    )

# Constraint violations of the final potentials, summed and divided by L.
level, slope = my[..., 0], my[..., 1]
zero = torch.tensor(0)
err_at_firing = (level[mask_at_firing] - theta).abs().sum().item() / L
err_around_firing = torch.maximum(dymin - slope[mask_around_firing], zero).sum().item() / L
err_before_firing = torch.maximum(level[mask_before_firing] - eta, zero).sum().item() / L
print(f"error at firing time is {err_at_firing}")
print(f"error around firing time is {err_around_firing}")
print(f"error before firing time is {err_before_firing}")

In [None]:
# Write the collected frames to box.gif; nothing is written when the list is empty.
if len(frames) > 0:
    gif.save(frames, "box.gif", duration=150)

# 2. M-level constraint

In [None]:
# Random initialisation of the prior messages: each weight is represented as
# the sum of M-1 components (summed over dim 1 to obtain mw).
M = 7
mws = torch.FloatTensor(L,M-1,K).normal_((wmin + wmax)/(M-1), (wmax - wmin)/(M-1) * 1e-2)
mw = mws.sum(dim=1)
# Potentials (last axis: 0 = level, 1 = slope) produced by the initial weights.
my = (C @ mw.view(L, 1, K, 1)).view(L, N, 2)

# Only row `l` is animated; frame 0 shows the random starting point.
l = 42
frames = []
frame = add_frame(0, mw[l], my[l], theta, eta, dymin, wmin, wmax, mask_at_firing[l], mask_around_firing[l], mask_refractory_period[l])
frames.append(frame)

for itr in tqdm(range(1, 101)):
    # priors
    ## weights: per-component M-ary prior in "am" mode (no variances are fed
    ## in, hence the None), then summed over the component axis.
    mws_f, Vws_f = compute_m_ary_prior(mws, None, wmin, wmax, M, "am")
    mw_f, Vw_f = mws_f.sum(dim=1), Vws_f.sum(dim=1)

    ## action potentials: zero mean with a huge variance (1e9) everywhere,
    ## then overwritten region by region.
    ## NOTE(review): the 4th argument of compute_box_prior (100 / 200) is not
    ## documented in this notebook -- confirm its meaning in rsnn.ikie.
    my_b = torch.zeros(L,N, 2)
    Vy_b = 1e9 * torch.ones(L,N, 2)
    # before firing: level bounded above by eta
    my_b[...,0][mask_before_firing], Vy_b[...,0][mask_before_firing] = compute_box_prior(my[...,0][mask_before_firing], None, eta, 100)
    # around firing: slope bounded below by dymin
    my_b[...,1][mask_around_firing], Vy_b[...,1][mask_around_firing] = compute_box_prior(my[...,1][mask_around_firing], dymin, None, 200)
    # at firing: level pinned to theta
    my_b[...,0][mask_at_firing], Vy_b[...,0][mask_at_firing] = compute_box_prior(my[...,0][mask_at_firing], theta, theta, 200)

    # posteriors (the posterior weight variance is not needed in "am" mode)
    mw, _, my = compute_posterior_means(mw_f, Vw_f, my_b, Vy_b, C)
    xiw = Vw_f.pow(-1) * (mw_f - mw) # (IV.9) in Loeliger2016
    # push the summed-weight correction back onto every component
    mws = mws_f - Vws_f * xiw.unsqueeze(1) # (IV.9) in Loeliger2016

    # add frames to gif
    frame = add_frame(itr, mw[l], my[l], theta, eta, dymin, wmin, wmax, mask_at_firing[l], mask_around_firing[l], mask_refractory_period[l])
    frames.append(frame)

# Constraint violations of the final potentials, summed and divided by L.
# The firing-time error is the absolute deviation from theta, the same metric
# that the box and "em" runs print under this label, so the numbers compare.
print(f"error at firing time is {(my[...,0][mask_at_firing] - theta).abs().sum().item()/L}")
print(f"error around firing time is {torch.maximum(dymin - my[...,1][mask_around_firing], torch.tensor(0)).sum().item()/L}")
print(f"error before firing time is {torch.maximum(my[...,0][mask_before_firing] - eta, torch.tensor(0)).sum().item()/L}")

In [None]:
# Write the collected frames to 7_ary_am.gif; nothing is written when the list is empty.
if len(frames) > 0:
    gif.save(frames, "7_ary_am.gif", duration=150)

In [None]:
# Random initialisation of the per-component prior messages (means and
# variances); each weight is the sum of its M-1 components along dim 1.
M = 7
component_shape = (L, M - 1, K)
spread = (wmax - wmin) / (M - 1)
mws = torch.empty(component_shape, dtype=torch.float32).normal_((wmin + wmax) / (M - 1), spread * 1e-2)
Vws = torch.empty(component_shape, dtype=torch.float32).normal_(spread * 1e-1, spread * 1e-3)
mw = mws.sum(1)
# Potentials (last axis: 0 = level, 1 = slope) produced by the initial weights.
my = torch.matmul(C, mw.view(L, 1, K, 1)).view(L, N, 2)

# Only row `l` is animated; frame 0 shows the random starting point.
l = 42
frames = [
    add_frame(
        0, mw[l], my[l], theta, eta, dymin, wmin, wmax,
        mask_at_firing[l], mask_around_firing[l], mask_refractory_period[l],
    )
]

for itr in tqdm(range(1, 301)):
    # Per-component M-ary prior in "em" mode (means and variances are both
    # fed in), then summed over the component axis.
    mws_f, Vws_f = compute_m_ary_prior(mws, Vws, wmin, wmax, M, "em")
    mw_f = mws_f.sum(1)
    Vw_f = Vws_f.sum(1)

    # Messages on the potentials: zero mean with a huge variance (1e9)
    # everywhere, then overwritten region by region below.
    my_b = torch.zeros(L, N, 2)
    Vy_b = torch.full((L, N, 2), 1e9)

    # Views into the current potentials and into the message tensors; writing
    # through the views updates my_b / Vy_b in place.
    level, slope = my[..., 0], my[..., 1]
    level_mean, slope_mean = my_b[..., 0], my_b[..., 1]
    level_var, slope_var = Vy_b[..., 0], Vy_b[..., 1]

    # NOTE(review): the 4th argument of compute_box_prior (100 / 200) is not
    # documented in this notebook -- confirm its meaning in rsnn.ikie.
    # Before firing: level bounded above by eta.
    level_mean[mask_before_firing], level_var[mask_before_firing] = compute_box_prior(
        level[mask_before_firing], None, eta, 100
    )
    # Around firing: slope bounded below by dymin.
    slope_mean[mask_around_firing], slope_var[mask_around_firing] = compute_box_prior(
        slope[mask_around_firing], dymin, None, 200
    )
    # At firing: level pinned to theta (written last, after the "before" region).
    level_mean[mask_at_firing], level_var[mask_at_firing] = compute_box_prior(
        level[mask_at_firing], theta, theta, 200
    )

    # Posterior means and variances of the summed weights, posterior potentials.
    mw, Vw, my = compute_posterior_means(mw_f, Vw_f, my_b, Vy_b, C)

    # Dual quantities of the summed weight.
    xiw = Vw_f.pow(-1) * (mw_f - mw)  # (IV.9) in Loeliger2016
    Ww = Vw_f.pow(-2) * (Vw_f - Vw)  # (IV.13) in Loeliger2016

    # Push the summed-weight corrections back onto every component.
    mws = mws_f - Vws_f * xiw[:, None]  # (IV.9) in Loeliger2016
    Vws = Vws_f - Vws_f.pow(2) * Ww[:, None]  # (IV.13) in Loeliger2016

    # Record the state after this iteration.
    frames.append(
        add_frame(
            itr, mw[l], my[l], theta, eta, dymin, wmin, wmax,
            mask_at_firing[l], mask_around_firing[l], mask_refractory_period[l],
        )
    )

# Constraint violations of the final potentials, summed and divided by L.
level, slope = my[..., 0], my[..., 1]
zero = torch.tensor(0)
err_at_firing = (level[mask_at_firing] - theta).abs().sum().item() / L
err_around_firing = torch.maximum(dymin - slope[mask_around_firing], zero).sum().item() / L
err_before_firing = torch.maximum(level[mask_before_firing] - eta, zero).sum().item() / L
print(f"error at firing time is {err_at_firing}")
print(f"error around firing time is {err_around_firing}")
print(f"error before firing time is {err_before_firing}")

In [None]:
# Write the collected frames to 7_ary_em.gif; nothing is written when the list is empty.
if len(frames) > 0:
    gif.save(frames, "7_ary_em.gif", duration=150)