In [1]:
import math

import torch
from scipy.special import lambertw

import matplotlib.pyplot as plt

from tqdm import tqdm

In [2]:
from rsnn.firing_sequences.sampling import backward_filtering_forward_sampling
from rsnn.firing_sequences.utils import is_predictable
from rsnn.alternating_maximization.alternating_maximization import compute_prior_messages, compute_posterior_means
from rsnn.alternating_maximization.utils import compute_observation_matrix, get_mask_refractory_period, get_mask_around_firing, get_mask_at_firing


In [3]:
import torch
import torch.nn.functional as F

In [4]:
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle, Circle

import gif
# Render every @gif.frame matplotlib figure at 300 dpi.
gif.options.matplotlib.update(dpi=300)

In [5]:
@gif.frame
def add_action_potential(y, theta, firing_times, mask_around_firing, mask_refractory_period):
    """Draw one GIF frame of a potential trace together with its firing constraints.

    Args:
        y: 1-D tensor of N potential values, plotted against the sample index.
        theta: firing threshold, drawn as a dashed horizontal line.
        firing_times: indices of the firing samples (in this notebook the
            ``(k, 1)`` output of ``torch.argwhere``); squeezed and drawn as red
            vertical lines.
        mask_around_firing: per-sample boolean mask of length >= N; samples where
            it is set are shaded orange.
        mask_refractory_period: per-sample boolean mask of length >= N; samples
            outside both this mask and ``mask_around_firing`` are shaded blue.

    The figure is left as pyplot's current figure; the ``@gif.frame`` decorator
    turns it into the frame returned to the caller.
    """
    N = y.size(0)
    y_lo, y_hi = -5, 5  # fixed vertical range of the axes and of the shaded bands

    # N // 10 is 0 for N < 10, which would request a zero-width figure.
    fig, ax = plt.subplots(1, figsize=(max(N // 10, 1), 10))

    ax.set_xlim(0, N - 1)
    ax.set_ylim(y_lo, y_hi)

    # One unit-wide band per sample: orange around firings, blue where the sample
    # is neither around a firing nor in the refractory period.
    for n in range(N):
        if mask_around_firing[n]:
            ax.add_patch(Rectangle((n, y_lo), 1, y_hi - y_lo, facecolor="orange", alpha=0.3))
        elif not mask_refractory_period[n]:
            ax.add_patch(Rectangle((n, y_lo), 1, y_hi - y_lo, facecolor="blue", alpha=0.3))

    ax.vlines(firing_times.squeeze(), y_lo, y_hi, colors="red")
    ax.hlines(theta, 0, N - 1, linestyle="dashed", colors="black")
    ax.plot(y)
    ax.set(xlabel="sample index", ylabel="potential")
    
@gif.frame
def add_weight(w, ylim=(-1.2, 1.2)):
    """Draw one GIF frame scattering the K weights against their index.

    Args:
        w: 1-D tensor of K weights.
        ylim: y-axis limits; defaults to the previously hard-coded (-1.2, 1.2).

    The figure is left as pyplot's current figure; the ``@gif.frame`` decorator
    turns it into the frame returned to the caller.
    """
    K = w.size(0)
    # Start from a fresh figure (as add_action_potential does) instead of drawing
    # onto whatever figure pyplot currently treats as active.
    fig, ax = plt.subplots(1)
    ax.set_ylim(*ylim)
    ax.scatter(torch.arange(K), w, s=5)
    ax.set(xlabel="weight index", ylabel="weight")

In [16]:
# Problem dimensions.
L = 100   # number of firing sequences (rows of firing_sequences and of mw)
K = 100   # weights per row (third axis of mw)
N = 400   # samples per firing sequence
Tr = 25   # passed as Tr to the sampler and mask helpers (presumably the refractory period)

In [17]:
# Seed torch's global RNG so the draws below (and any later draws from torch's
# default generator) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Random integer parameters per (row, weight) pair, consumed by
# compute_observation_matrix: taus in [1, 2*Tr), origins in [0, L)
# (presumably delays and presynaptic row indices -- TODO confirm).
taus = torch.randint(1, 2 * Tr, (L, K))
origins = torch.randint(0, L, (L, K))

In [18]:
# Weights are kept within [wlim[0], wlim[1]].
wlim = (-1, 1)

# Time constant of the impulse response, chosen so that h(Tr) = 1e-2.
# With x = Tr / beta, solve x * exp(1 - x) = 1e-2 on the decaying side (x > 1):
# -x * exp(-x) = -1e-2 / e  =>  -x = W_{-1}(-1e-2 / e)  =>  beta = -Tr / W_{-1}(-1e-2 / e).
beta = -Tr / lambertw(-1e-2 / math.exp(1), -1).real


def impulse_response(t_):
    """Causal kernel h(t) = (t / beta) * exp(1 - t / beta) for t > 0, and 0 otherwise.

    It peaks at h(beta) = 1 and has decayed to 1e-2 at t = Tr.
    """
    return (t_ > 0) * t_ / beta * torch.exp(1 - t_ / beta)


theta, eta = 1, 1e-5   # theta: firing threshold; eta: passed to compute_prior_messages
eps, dymin = 5, 1e-1   # eps: passed to the firing-window masks; dymin: minimum slope around firing

In [19]:
# Sample the firing sequences and view the result as an (L, N) tensor
# (the leading argument 1 presumably requests a single batch -- TODO confirm).
firing_sequences = backward_filtering_forward_sampling(1, L, N, Tr).view(L, N)
# Displayed sanity check; is_predictable receives a (1, L, N) view here.
is_predictable(firing_sequences.unsqueeze(0), Tr)

True

In [20]:
# Masks over firing_sequences used by the plots and by the constraint checks.
mask_at_firing = get_mask_at_firing(firing_sequences)
mask_refractory_period = get_mask_refractory_period(firing_sequences, Tr, eps)
mask_around_firing = get_mask_around_firing(firing_sequences, eps)
# Samples neither in the refractory period nor around a firing (De Morgan of ~(a | b)).
mask_before_firing = ~mask_refractory_period & ~mask_around_firing

In [21]:
# Observation matrix: C @ mw gives the potentials my, whose [..., 0, 0] entries
# are compared with theta and [..., 1, 0] entries with dymin further below.
C = compute_observation_matrix(
    firing_sequences.view(L, N),
    taus,
    origins,
    Tr,
    impulse_response,
)

In [22]:
# Initialise the weights uniformly at random in [wlim[0], wlim[1]), shaped (L, 1, K, 1).
# torch.empty(..., dtype=torch.float32) replaces the legacy torch.FloatTensor(...)
# constructor with the same shape and dtype.
mw = torch.empty(L, 1, K, 1, dtype=torch.float32).uniform_(*wlim)
# Potentials implied by the initial weights: my[..., 0, 0] is later compared with
# theta and my[..., 1, 0] with dymin.
my = C @ mw

In [23]:
# Prior messages for the initial (mw, my). The training loop below calls
# compute_prior_messages with these same arguments in its first iteration,
# so this cell is only needed to inspect the initial messages.
Vw_f, mw_f, Vy_b, my_b = compute_prior_messages(
    mw, my, wlim, theta, eta, Tr, eps, dymin, firing_sequences
)

In [24]:
weight_frames = []
action_potential_frames = []

max_iter = 20

l = 42  # index of the row whose weights and potential are animated

# Plotting inputs for row l do not change across iterations: compute them once.
firing_times_l = torch.argwhere(mask_at_firing[l])
mask_around_firing_l = mask_around_firing[l]
mask_refractory_period_l = mask_refractory_period[l]

# NOTE: mw and my are rebound on every iteration, so re-running this cell continues
# from the current estimates rather than from the initial random weights.
for _ in tqdm(range(max_iter)):
    # One alternating step: prior messages from the current means, then new
    # posterior means through the observation matrix C.
    Vw_f, mw_f, Vy_b, my_b = compute_prior_messages(mw, my, wlim, theta, eta, Tr, eps, dymin, firing_sequences)
    mw, my = compute_posterior_means(Vw_f, mw_f, Vy_b, my_b, C)

    frame = add_weight(mw[l, 0, :, 0])
    weight_frames.append(frame)

    frame = add_action_potential(my[l, :, 0, 0], theta, firing_times_l, mask_around_firing_l, mask_refractory_period_l)
    action_potential_frames.append(frame)

100%|███████████████████████████████████████████| 20/20 [01:11<00:00,  3.59s/it]


In [25]:
# Write each collected frame list to its own looping GIF.
gif.save(
    weight_frames,
    "am_learning_weight.gif",
    loop=True,
)
gif.save(
    action_potential_frames,
    "am_learning_action_potential.gif",
    loop=True,
)

# Check all conditions
## Bounded weights

In [26]:
# True iff every learned weight lies in [wlim[0], wlim[1]] (displayed as a 0-d bool tensor).
torch.logical_and(mw.min() >= wlim[0], mw.max() <= wlim[1])

tensor(True)

## Potential equals the threshold at firing times

In [34]:
# Sum of squared deviations of the potential from theta at the firing samples (ideally 0).
mask = get_mask_at_firing(firing_sequences)
torch.square(my[..., 0, 0][mask] - theta).sum()

tensor(0.0983)

## Minimum slope around firing

In [35]:
# Smallest margin of the slope over dymin inside the firing windows
# (presumably the constraint is slope >= dymin, so a negative value means a violation).
mask = get_mask_around_firing(firing_sequences, eps)
torch.min(my[..., 1, 0][mask] - dymin)

tensor(-1.0935)

## Before firing (outside the refractory period and the firing windows)

In [37]:
mask1 = get_mask_refractory_period(firing_sequences, Tr, eps)
mask2 = get_mask_around_firing(firing_sequences, eps)
# NOTE(review): this reports max(eta - y) over samples outside both masks, i.e. how
# far the potential drops *below* eta. If the intended constraint here is an upper
# bound (e.g. y <= eta or y <= theta - eta), the violation measure would be
# max(y - bound), which should be <= 0 -- confirm the intended sign.
torch.max(eta - my[..., 0, 0][~mask1 & ~mask2])

tensor(6.0077)