# Parameter estimation for a linear dynamical system (LDS) with Poisson observations, using `ssm` from the Linderman lab

In [None]:
from scipy.linalg import block_diag
import autograd.numpy as np
import matplotlib.pyplot as plt
import ssm

from pathlib import Path
from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from real_spike.utils import get_spike_events, kalman_filter, bin_spikes, butter_filter
from scipy.ndimage import gaussian_filter1d
from matplotlib import colormaps
from ssm.plots import plot_dynamics_2d

%matplotlib inline

# Get data

In [None]:
# Load the session's MATLAB file and keep only its top-level `data` struct.
mat = loadmat(
    "/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250127/MAT_FILES/rb50_20250127_datastruct_pt2.mat"
)["data"]

# List the field names stored in the struct.
print(mat.dtype.names)

In [None]:
# loadmat returns MATLAB structs as 2-D arrays even when they are 1x1, so
# unwrap the single record first.
data_struct = mat[0][0]
# Re-pack the record into a plain dict keyed by field name.
data = {name: data_struct[name] for name in mat.dtype.names}

In [None]:
# Show which fields were copied out of the MATLAB struct.
data.keys()

# Get single-reach trials

In [None]:
# Control trials are those with pattern_id == 0.
# The lookup is limited to the first 150 trials to agree with `stim_idxs` and
# with `cue_times` / `lift_times` / `mouth_times`, which are all sliced with
# `:150`; an index >= 150 would raise an IndexError when used on those arrays.
control_idxs = np.where(data["pattern_id"][:150] == 0)[0]
control_idxs

In [None]:
# Stimulation trials: pattern_id greater than 2, searched among the first
# 150 trials only.
stim_idxs = np.nonzero(data["pattern_id"][:150] > 2)[0]

# Merge the stimulation trials into the selection and keep it in ascending
# order. NOTE: from here on `control_idxs` holds control *and* stimulation
# trials, despite its name.
control_idxs = np.concatenate((control_idxs, stim_idxs))
control_idxs.sort()

In [None]:
# Keep only the selected trials whose `single` flag equals 1 (presumably the
# single-reach trials named in the section heading -- confirm). The flag is
# searched along the second axis of `data["single"]`.
control_idxs = np.intersect1d(control_idxs, np.nonzero(data["single"] == 1)[1])

In [None]:
# Display the trial indices that remain after the filters above.
control_idxs

In [None]:
# Sanity check: list the distinct pattern_id values among the selected trials.
np.unique(data["pattern_id"][control_idxs])

## Get relevant time information

In [None]:
# Cue times, limited to the first 150 trials along the second axis. The
# trial-extraction loop treats these values as milliseconds when converting
# them to sample indices.
cue_times = data["cue_rec_time"][:, :150]
# Preview the cue times of the selected trials.
cue_times[:, control_idxs]

In [None]:
# Lift times, limited to the first 150 trials along the second axis. The
# trial-extraction loop adds them to the cue time, i.e. uses them as offsets
# from the cue.
lift_times = data["lift_ms"][:, :150]
# Preview the lift times of the selected trials.
lift_times[:, control_idxs]

In [None]:
# Mouth times, limited to the first 150 trials along the second axis. The
# trial-extraction loop adds them to the cue time, i.e. uses them as offsets
# from the cue.
mouth_times = data["mouth_ms"][:, :150]
# Preview the mouth times of the selected trials.
mouth_times[:, control_idxs]

## Get the AP.bin file

In [None]:
from real_spike.utils import get_sample_data, get_meta
import tifffile

In [None]:
# Raw AP-band recording; its metadata file sits next to it with the final
# ".bin" suffix swapped for ".meta".
file_path = Path(
    "/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250127/rb50_20250127_g0/rb50_20250127_g0_t0.imec0.ap.bin"
)
meta_path = file_path.with_suffix(".meta")

In [None]:
# Read the recording metadata from the .ap.meta file; later cells look up
# entries in the result by field name.
meta_data = get_meta(meta_path)

In [None]:
# Load the AP-band samples described by the metadata. Later cells slice the
# result as [row, sample index]; rows are presumably channels -- confirm.
ap_data = get_sample_data(file_path, meta_data)
# Show the array dimensions.
ap_data.shape

## Get conversion params

In [None]:
# Parameters used to convert raw integer samples to voltages.
# Full-scale input range, from the `imAiRangeMax` metadata field.
vmax = float(meta_data["imAiRangeMax"])
# Maximum integer sample value, from the `imMaxInt` metadata field.
imax = float(meta_data["imMaxInt"])
# Gain: the 4th space-separated token of the text between the first and
# second ")" of `imroTbl`. Only that one entry is read, so the same gain is
# applied to every row of the recording.
gain = float(meta_data["imroTbl"].split(")")[1].split(" ")[3])

In [None]:
# Display the value parsed from the `imAiRangeMax` metadata field.
vmax

In [None]:
# Display the value parsed from the `imMaxInt` metadata field.
imax

In [None]:
# Display the gain parsed from the `imroTbl` metadata field.
gain

# Get trials

In [None]:
# Build one binned spike-count array per selected trial and collect them in
# `all_data` (same order as `control_idxs`).
all_data = list()

# HYPERPARAMETERS TO TOGGLE
bin_size = 5  # bin width in ms

# One plot color per selected trial, looked up by the trial's pattern_id.
p_colors = list()
c = {0: "indigo", 14: "teal", 17: "magenta", 20: "orange"}

# Every ms -> sample conversion below uses 30 samples per ms (30_000 per
# second), so one bin of `bin_size` ms spans `bin_size * 30` samples.
# Loop-invariant, so computed once.
b = bin_size * 30

for trial_no in tqdm(control_idxs):
    p_colors.append(c[data["pattern_id"][trial_no][0]])

    # Pull the per-trial times out as scalars. Calling int() on the length-1
    # arrays produced by `x[:, trial_no]` is deprecated since NumPy 1.25.
    cue_ms = cue_times[0, trial_no]
    lift_ms = lift_times[0, trial_no]
    mouth_ms = mouth_times[0, trial_no]

    # Trial window in sample indices: from 50 ms before the lift ...
    lift_time = int((cue_ms + lift_ms - 50) / 1_000 * 30_000)
    # ... to 260 ms after the mouth time.
    end_behavior = int((cue_ms + mouth_ms + 260) / 1_000 * 30_000)

    trial = ap_data[:150, lift_time:end_behavior]

    # Raw integers -> microvolts: value * Vmax / Imax / gain, scaled by 1e6.
    # The previous expression divided by `vmax`, which scaled the data by an
    # extra constant factor of 1 / vmax**2.
    # NOTE(review): if `get_spike_events` applies thresholds in absolute
    # units, re-check them against the corrected scale.
    conv_data = 1e6 * trial * vmax / imax / gain

    # Project filter helper; arguments are presumably the cutoff and the
    # sampling rate in Hz -- confirm against real_spike.utils.
    filt_data = butter_filter(conv_data, 1_000, 30_000)

    # Baseline segment: the 2000 ms that precede the cue, converted and
    # filtered exactly like the trial itself.
    c_start = int(cue_ms / 1_000 * 30_000)
    # Clamp at 0 so a cue earlier than 2000 ms into the recording cannot
    # yield a negative start index, which would wrap around when slicing.
    m_start = max(c_start - (30 * 2000), 0)
    trial_median = ap_data[:150, m_start:c_start]

    trial_median = 1e6 * trial_median * vmax / imax / gain
    trial_median = butter_filter(trial_median, 1_000, 30_000)

    # Per-row median of the baseline segment, passed to the spike detector.
    median = np.median(trial_median, axis=1)

    spike_ixs, counts = get_spike_events(filt_data, median)

    # Binary raster: 1 at every detected spike sample, 0 elsewhere.
    a = np.zeros((filt_data.shape[0], filt_data.shape[1]))

    # A separate index name keeps the outer loop variable intact.
    for row, sc in enumerate(spike_ixs):
        a[row, sc] = 1

    binned_spikes = bin_spikes(a, b)

    all_data.append(binned_spikes)

# Shorten trial to just around dynamics we care about

In [None]:
# Trim every trial to the window from the lift bin to the mouth bin
# (inclusive) and store it transposed, as integer counts.
model_data = list()

# Each trial in `all_data` starts 50 ms before the lift, so the lift falls in
# this bin. The value does not depend on the trial, so it is computed once.
lift_start = int(50 / bin_size)
# Not used in this cell; kept because other cells may read it.
after_lift = int(300 / bin_size) + lift_start

for i, d in enumerate(all_data):
    # Lift and mouth times of this trial, in bins.
    lift = int(data["lift_ms"][0, control_idxs[i]] / bin_size)
    mouth = int(data["mouth_ms"][0, control_idxs[i]] / bin_size)

    # Bin of the mouth event relative to the start of the trial window.
    mouth_start = lift_start + (mouth - lift)

    model_data.append(np.asarray(d[:, lift_start:mouth_start+1].T, dtype=int))

In [None]:
# Report the shape of every trimmed trial; the first dimension varies with
# the length of the lift-to-mouth window.
for trial_counts in model_data:
    print(trial_counts.shape)

# Visualize trials

In [None]:
import random

In [None]:
# Heatmaps of the binned spike counts for a random subset of trials.
fig, axes = plt.subplots(2, 5, figsize=(20, 8))

# Flatten axes array for easy iteration
axes = axes.flatten()

# Draw at most one trial per panel. Capping the sample size at the number of
# trials avoids the ValueError random.sample raises when asked for more items
# than exist. The sample is taken over `model_data`, the list that is
# actually indexed below.
n_shown = min(len(axes), len(model_data))
ixs = sorted(random.sample(range(len(model_data)), n_shown))

for ax, x in zip(axes, ixs):
    im = ax.imshow(model_data[x].T, aspect="auto", interpolation="none", cmap="inferno")
    # The horizontal axis is in bin indices, not milliseconds.
    ax.set_xlabel(f"Time (bins of {bin_size} ms)")
    ax.set_ylabel("Channel")

    # NOTE(review): bin 10 equals 50 ms only when bin_size == 5, and the
    # plotted window already starts at the lift bin -- confirm which event
    # this line is meant to mark.
    ax.axvline(x=10, linestyle='--', color='red')

    fig.colorbar(im, ax=ax)

    ax.set_title(f"Trial {control_idxs[x]}")

# Hide any panel that did not receive a trial.
for ax in axes[n_shown:]:
    ax.set_visible(False)

# Adjust layout
plt.tight_layout()

plt.savefig("/home/clewis/repos/realSpike/data/rb50_20250127/binned_spikes_heatmap.png")

plt.show()

# Create model

In [None]:
# Number of selected trials.
control_idxs.shape

In [None]:
# Number of extracted trials; one entry is appended to `all_data` per
# element of `control_idxs`, so this should match the count above.
len(all_data)

In [None]:
# Index of the first selected trial.
control_idxs[0]

In [None]:
# Rebuild `model_data` from Gaussian-smoothed spike counts, trimmed to the
# window from the lift bin to the mouth bin (inclusive).
model_data = list()

# Each trial in `all_data` starts 50 ms before the lift, so the lift falls in
# this bin. The value does not depend on the trial, so it is computed once.
lift_start = int(50 / bin_size)
# Not used in this cell; kept because other cells may read it.
after_lift = int(300 / bin_size) + lift_start

for i, d in enumerate(all_data):
    # Lift and mouth times of this trial, in bins.
    lift = int(data["lift_ms"][0, control_idxs[i]] / bin_size)
    mouth = int(data["mouth_ms"][0, control_idxs[i]] / bin_size)

    # Bin of the mouth event relative to the start of the trial window.
    mouth_start = lift_start + (mouth - lift)

    # Smooth along axis 1 (the axis that is sliced by bin below).
    d = gaussian_filter1d(d, sigma=8, axis=1)
    # NOTE(review): the cast to int discards the fractional part of the
    # smoothed values. With small per-bin counts most entries may end up 0 --
    # confirm this is intended before fitting the Poisson model.
    model_data.append(np.asarray(d[:, lift_start:mouth_start+1].T, dtype=int))

In [None]:
# Latent dimensionality is chosen by hand; the observation dimensionality is
# the second dimension of a trial array in `model_data`.
state_dim = 3
obs_dim = np.shape(model_data[0])[1]

# Show both sizes.
(state_dim, obs_dim)

In [None]:
# LDS with "poisson_orthog" emissions and a softplus link.
plds = ssm.LDS(
    obs_dim,
    state_dim,
    emissions="poisson_orthog",
    emission_kwargs={"link": "softplus"},
)

# Optional: set the emission bias vector to 0
# plds.emissions.ds = 0 * np.ones(obs_dim)

# Fit the model

In [None]:
# Fit on all trials with the "laplace_em" method for 15 iterations. `elbos`
# is plotted per iteration below and `q` is the fitted posterior object whose
# continuous-state means are read later.
elbos, q = plds.fit(model_data, method="laplace_em", num_iters=15)

In [None]:
# Plot the ELBOs
plt.plot(elbos, label="Laplace-EM")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.legend()

plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_elbo_{bin_size}ms.png")

In [None]:
# Dynamics parameters `A` and `b` of the fitted model.
A_est, b_est = plds.dynamics.A, plds.dynamics.b

f, ax = plt.subplots(figsize=(6, 6))

# Only the first two latent dimensions are drawn: the top-left 2x2 block of
# `A_est` and the first two entries of `b_est`. With state_dim = 3 this
# leaves out every term that involves the third dimension.
plot_dynamics_2d(A_est[:2, :2], b_est[:2], npts=15, axis=ax, color="red")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_title("Inferred Dynamics")

f.tight_layout()

f.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_dynamics_{bin_size}ms.png")
plt.show()

## Estimated expected values of the latent (hidden) continuous states, given the observed data and the current model parameters

The posterior mean of the continuous states is the expected value of the latent trajectory under the (approximate) posterior:

$$\hat{x}_{1:T} = \mathbb{E}\left[x_{1:T} \mid y_{1:T}, \theta\right]$$

This mean can be interpreted as the best guess (under mean squared error) of the hidden latent state trajectory explaining the observed data, taking into account both the system dynamics and the Poisson likelihood of the observations.

In [None]:
# Get the posterior mean of the continuous states
state_means = q.mean_continuous_states

# Plot all the posterior means together in 2D

In [None]:
# Overlay the posterior mean trajectory of every trial in the plane of the
# first two latent dimensions.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

for p in state_means:
    # The model was fit on windows already trimmed to start at the lift bin
    # (`d[:, lift_start:mouth_start+1]`), so row 0 of each trajectory is the
    # lift and the last row is the mouth bin. Offsetting by `lift_start`
    # again would drop the first 50 ms after the lift, misplace the start
    # marker, and raise an IndexError on windows shorter than `lift_start`.
    ax.plot(p[:, 0], p[:, 1], c="black", zorder=0, alpha=0.8)

    # Start of the window (teal dot) and end of the window (magenta star).
    ax.scatter(p[0, 0], p[0, 1], s=35, c="teal", zorder=1, alpha=1)
    ax.scatter(p[-1, 0], p[-1, 1], s=100, marker='*', c="magenta", zorder=1, alpha=1)

ax.set_title(f"Posterior State Means ({bin_size}ms)")

# Hide the ticks.
ax.set_xticks([])
ax.set_yticks([])

plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_posterior_means_{bin_size}ms_2d.png")
plt.show()

# Plot all the posterior means together in 3D

In [None]:
# Overlay the posterior mean trajectory of every trial in the space of the
# first three latent dimensions.
fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111, projection='3d')

for p in state_means:
    # The model was fit on windows already trimmed to start at the lift bin
    # (`d[:, lift_start:mouth_start+1]`), so row 0 of each trajectory is the
    # lift and the last row is the mouth bin. Offsetting by `lift_start`
    # again would drop the first 50 ms after the lift, misplace the start
    # marker, and raise an IndexError on windows shorter than `lift_start`.
    ax.plot(p[:, 0], p[:, 1], p[:, 2], c="black", zorder=0, alpha=0.8)

    # Start of the window (teal dot) and end of the window (magenta star).
    ax.scatter(p[0, 0], p[0, 1], p[0, 2], s=35, c="teal", zorder=1, alpha=1)
    ax.scatter(p[-1, 0], p[-1, 1], p[-1, 2], s=100, marker='*', c="magenta", zorder=1, alpha=1)

ax.set_title(f"Posterior State Means ({bin_size}ms)")

# Hide the ticks.
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_posterior_means_{bin_size}ms_3d.png")
plt.show()