## Inspection plots for assimilation and data simulation

In [None]:
from __future__ import print_function, division

import imp  # NOTE(review): unused in the visible cells; `imp` was removed in Python 3.12
import os
import sys

import numpy as np
from corner import corner
from munch import Munch

from dstrf import filters, mle, simulate, spikes, strf

# plotting packages
%matplotlib notebook
import matplotlib as mpl
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("ticks")
sns.set_context("paper", font_scale=0.7)
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['lines.linewidth'] = 0.5
mpl.rcParams['axes.linewidth'] = 0.5
mpl.rcParams['xtick.major.width'] = 0.5
mpl.rcParams['ytick.major.width'] = 0.5
mpl.rcParams['xtick.major.size'] = 1.5
mpl.rcParams['ytick.major.size'] = 1.5
#print(mpl.rcParams.keys())
outdir = os.path.join("..", "figures")
est_clr = ["darkmagenta", "goldenrod", "darkcyan"]
names = ["posp", "tonic", "phasic"]

In [None]:
# Saved assimilation results (npz archive) and the YAML config used to produce them
results = "../results/song_phasic_samples_cold.npz"
config_path = "../config/song_dynamical.yml"
data = np.load(results)  # lazily-loaded archive; arrays are read on key access
with open(config_path, "rt") as fp:
    cf = Munch.fromYAML(fp)

In [None]:
# Arrays stored in the results archive (read in this order)
stim, spike_v, duration = (data[key] for key in ("stim", "spike_v", "duration"))
# RF filter settings: factorization rank and cosine temporal basis
filter_cf = cf.model.filter
krank = filter_cf.rank
kcosbas = strf.cosbasis(filter_cf.len, filter_cf.ncos)
# time steps of the stimulus and of the model, taken from the config
stim_dt, model_dt = cf.data.dt, cf.model.dt

In [None]:
# Stimulus spectrogram (top) and spike raster (bottom).
# spike_v is indexed as [bin, trial] below — assumed shape (bins, ntrials); TODO confirm.
# NOTE: `spikes` is the imported dstrf.spikes module, so the array must come from spike_v.
bins, ntrials = spike_v.shape
t_stim = np.linspace(0, duration, stim.shape[1])
t_spike = np.linspace(0, duration, spike_v.shape[0])

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(16, 4))

# frequency axis spans the configured spectrogram range
axes[0].imshow(stim,
               extent=(0, duration, cf.data.stimulus.spectrogram.f_min, cf.data.stimulus.spectrogram.f_max),
               cmap='jet', origin='lower', aspect='auto')
# one row of ticks per trial; bin indices are converted to time with the model time step
for i in range(ntrials):
    spk = np.nonzero(spike_v[:, i])[0]
    axes[1].vlines(spk * model_dt, i, i + 0.5)

axes[0].set_xlim(0, duration);

In [None]:
# Compare the true RF with the MLE and (if available) the posterior-median RF estimates.
w0 = data["mle"]
# the first 3 parameters are reported as rate/adaptation; the remainder are factorized RF coefficients
print("MLE rate and adaptation parameters:", w0[:3])

fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(6, 6))

k1, k1t, k1f = simulate.get_filter(cf)
k1c = strf.to_basis(k1, kcosbas)  # NOTE(review): currently unused; axes[0, 1] is left empty
rf_mle = strf.from_basis(strf.defactorize(w0[3:], cf.data.filter.nfreq, krank), kcosbas)

axes[0, 0].imshow(k1, cmap='jet', aspect='auto')
axes[0, 0].set_title("True RF")
axes[1, 0].imshow(rf_mle, cmap='jet', aspect='auto')
axes[1, 0].set_title("MLE (rank-{})".format(krank));

# Only the archive lookup is guarded: a missing "samples" key means no posterior samples were saved.
try:
    samples = data["samples"]
except KeyError:
    pass  # leave the posterior panel empty
else:
    # median over axis 0 — assumes samples is (n_samples, n_params); TODO confirm
    w1 = np.median(samples, 0)
    print("Posterior median rate and adaptation parameters:", w1[:3])
    rf_map = strf.from_basis(strf.defactorize(w1[3:], cf.data.filter.nfreq, krank), kcosbas)
    # plot the posterior estimate itself (previously rf_mle was drawn here a second time)
    axes[1, 1].imshow(rf_map, cmap='jet', aspect='auto')
    axes[1, 1].set_title("Posterior median (rank-{})".format(krank))