# Notebook for fitting a Poisson LDS (PLDS) with stimulus-pattern inputs
## One-hot encoding for each pattern
## Stim ONLY trials, no behavior

In [None]:
from scipy.linalg import block_diag
import autograd.numpy as np
import matplotlib.pyplot as plt
import ssm

from pathlib import Path
from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from real_spike.utils import get_spike_events, kalman_filter, bin_spikes, butter_filter, plot_dynamics_2d, plot_dynamics_3d
from matplotlib import colormaps
import pandas as pd

from scipy.ndimage import gaussian_filter1d
from mpl_toolkits.mplot3d import Axes3D 

import random

%matplotlib inline

# Get data

In [None]:
# Load the session .mat file and keep its top-level `data` struct variable.
mat = loadmat("/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250127/MAT_FILES/rb50_20250127_datastruct_pt2.mat")
mat = mat['data']

# Show the struct's field names.
print(mat.dtype.names)

In [None]:
data_struct = mat[0, 0]  # MATLAB structs are 2D arrays even if 1x1
# Unpack every struct field into a plain dict keyed by field name.
data = {field: data_struct[field] for field in mat.dtype.names}

In [None]:
data.keys()

# Get stim only trials

In [None]:
# Stim-only trials: pattern_id > 2. np.where(...)[0] keeps the row (trial) index;
# presumably pattern_id has shape (n_trials, 1) — confirm.
stim_idxs = np.where(data["pattern_id"] > 2)[0]

# NOTE(review): trials with index <= 150 are dropped; the reason is not recorded here — confirm.
stim_idxs = stim_idxs[stim_idxs > 150]
stim_idxs

In [None]:
# Distinct pattern ids among the selected trials.
np.unique(data["pattern_id"][stim_idxs])

In [None]:
stim_idxs.shape

## Get relevant time information

In [None]:
# Per-trial laser onset times. get_trials reads them as laser_times[:, trial] and converts
# with / 1_000 * 30_000, i.e. treats them as milliseconds on a 30 kHz AP clock — confirm.
laser_times = data["laser_rec_time"]
laser_times.shape

## Get the AP.bin file

In [None]:
from real_spike.utils import get_sample_data, get_meta
import tifffile

In [None]:
# AP-band binary and its metadata for probe imec0 (presumably SpikeGLX output, judging by the .ap.bin/.ap.meta pair).
file_path = Path("/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250127/rb50_20250127_g0/rb50_20250127_g0_t0.imec0.ap.bin")
meta_path = Path("/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250127/rb50_20250127_g0/rb50_20250127_g0_t0.imec0.ap.meta")

In [None]:
# Parsed metadata; the keys read below (imAiRangeMax, imMaxInt, imroTbl) come from here.
meta_data = get_meta(meta_path)

In [None]:
# Raw AP samples; get_trials indexes this as ap_data[channel, sample].
ap_data = get_sample_data(file_path, meta_data)
ap_data.shape

## Get conversion params

In [None]:
# get Vmax: analog input range from the metadata (imAiRangeMax)
vmax = float(meta_data["imAiRangeMax"])
# get Imax: maximum integer sample value (imMaxInt)
imax = float(meta_data["imMaxInt"])
# get gain: 4th space-separated field of the first per-channel imroTbl entry.
# NOTE(review): for Neuropixels 1.0 tables that field is channel 0's AP gain — confirm for this probe.
# NOTE(review): SpikeGLX converts counts to volts as counts * Vmax / Imax / gain.
gain = float(meta_data['imroTbl'].split(sep=')')[1].split(sep=' ')[3])

In [None]:
vmax

In [None]:
imax

In [None]:
gain

# Get trials

In [None]:
def _raw_to_filtered_uv(raw, fs):
    """Convert raw AP sample counts to microvolts, then apply ``butter_filter``.

    Reads the module-level ``vmax`` (imAiRangeMax), ``imax`` (imMaxInt) and
    ``gain`` parsed from the probe metadata.
    """
    # SpikeGLX scaling: volts = counts * Vmax / Imax / gain. Dividing by vmax
    # instead of multiplying mis-scales the signal by 1 / vmax**2.
    # NOTE(review): spike counts cached with the old scaling may differ if
    # get_spike_events uses absolute thresholds.
    uv = 1e6 * raw * vmax / imax / gain
    return butter_filter(uv, 1_000, fs)


def get_trials(idxs, bin_size, *, n_channels=150, pre_ms=50, post_ms=50,
               baseline_ms=2_000, fs=30_000):
    """Detect spikes around each laser onset and return binned spike arrays.

    For every trial in ``idxs`` the AP data in [onset - pre_ms, onset + post_ms)
    is converted to microvolts and filtered. Spikes are detected with
    ``get_spike_events`` against the per-channel median of the filtered
    ``baseline_ms`` segment preceding onset. The 0/1 spike raster is then
    passed to ``bin_spikes`` with ``bin_size`` ms bins.

    Reads the module-level ``ap_data``, ``laser_times``, ``vmax``, ``imax``
    and ``gain``. The keyword defaults reproduce the previously hard-coded values.

    Args:
        idxs: trial indices (columns of ``laser_times``).
        bin_size: bin width in ms.
        n_channels: number of leading AP channels to use.
        pre_ms, post_ms: window length before / after laser onset, in ms.
        baseline_ms: length of the pre-onset segment used for the median, in ms.
        fs: AP sampling rate in Hz.

    Returns:
        List with one int array per trial (``binned_spikes.T``). Every trial has
        the same number of samples, and presumably shape (n_bins, n_channels).

    Raises:
        ValueError: if a trial's window or baseline falls outside ``ap_data``.
    """
    samples_per_ms = fs // 1_000
    window_samples = (pre_ms + post_ms) * samples_per_ms
    bin_samples = bin_size * samples_per_ms  # AP samples per bin
    n_samples = ap_data.shape[1]

    model_data = list()

    for trial_no in tqdm(idxs):
        # laser onset (ms) of this trial; .item() insists on exactly one value
        onset_ms = laser_times[:, trial_no].item()

        # Derive the end from the start so truncation can never make trials
        # differ in length; the cross-trial averages downstream need equal lengths.
        start_time = int((onset_ms - pre_ms) / 1_000 * fs)
        end_time = start_time + window_samples
        if start_time < 0 or end_time > n_samples:
            raise ValueError(f"trial {trial_no}: window [{start_time}, {end_time}) is outside the recording")

        filt_data = _raw_to_filtered_uv(ap_data[:n_channels, start_time:end_time], fs)

        # Baseline segment immediately preceding laser onset.
        c_start = int(onset_ms / 1_000 * fs)
        m_start = c_start - baseline_ms * samples_per_ms
        if m_start < 0:
            raise ValueError(f"trial {trial_no}: baseline starts before the recording (sample {m_start})")
        baseline = _raw_to_filtered_uv(ap_data[:n_channels, m_start:c_start], fs)
        median = np.median(baseline, axis=1)

        spike_ixs, _counts = get_spike_events(filt_data, median)

        # Binary raster: raster[channel, sample] = 1 at each detected spike.
        raster = np.zeros(filt_data.shape)
        for ch, sample_ixs in enumerate(spike_ixs):
            raster[ch, sample_ixs] = 1

        binned_spikes = bin_spikes(raster, bin_samples)
        model_data.append(np.asarray(binned_spikes.T, dtype=int))

    return model_data

In [None]:
bin_size = 1  # bin width in ms; get_trials bins bin_size * 30 AP samples (30 samples per ms)

In [None]:
#model_data = get_trials(stim_idxs, bin_size)

In [None]:
# save as a pickle file 
import pickle

In [None]:
# Cache for get_trials(stim_idxs, bin_size): a pickled list holding one binned
# spike array per stim trial.
filename = f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_stim/{bin_size}ms_data.pkl"
# Uncomment to (re)write the cache after running get_trials.
# with open(filename, "wb") as file:
#     pickle.dump(model_data, file)

In [None]:
# Load the cached binned spike arrays instead of re-running get_trials.
# NOTE: pickle can execute arbitrary code on load — only open files you produced.
with Path(filename).open("rb") as fh:
    model_data = pickle.load(fh)

# Design the input matrix

In [None]:
# get the pattern types

In [None]:
# Sorted distinct pattern ids present in the stim trials; one input column per entry.
p_ids = np.unique(data["pattern_id"][stim_idxs])

p_ids

In [None]:
p_ids.shape

## Get colors for plotting

In [None]:
# Palette of 27 named matplotlib colours, one per stimulation pattern; the
# notebook indexes it as c[pattern_id - 3].
c = """
    red blue green orange purple brown pink gray olive
    cyan magenta gold teal navy maroon lime indigo coral
    turquoise salmon orchid chocolate crimson darkgreen
    mediumblue slategray deeppink
""".split()

## Create encodings 

In [None]:
# One-hot input codes: row k is the input vector for the k-th entry of p_ids,
# giving one column per distinct pattern.
encodings = np.eye(p_ids.shape[0])
encodings.shape

In [None]:
# For every stim trial, build an input matrix that repeats the trial's one-hot
# pattern code on every time bin, and record the trial's plotting colour.
# Rows of `encodings` follow the order of `p_ids`, so look the row up by the
# pattern's position in p_ids instead of assuming ids run contiguously from 3.
row_of_pattern = {int(pid): row for row, pid in enumerate(p_ids)}

inputs = list()
colors = list()
for i, d in zip(stim_idxs, model_data):
    # get the pattern id
    pattern = int(data["pattern_id"][i][0])

    # the palette is indexed from pattern 3, matching the legend patches
    colors.append(c[pattern - 3])

    encoding = encodings[row_of_pattern[pattern]]

    # one copy of the encoding per time bin -> shape (n_bins, n_patterns)
    inputs.append(np.tile(encoding, (d.shape[0], 1)))

# Fit the model

In [None]:
state_dim = 3
obs_dim = model_data[0].shape[1]  # observation dimension (channels per bin)
input_dim = inputs[0].shape[1]  # width of the one-hot pattern inputs

# Poisson LDS with a softplus link, driven by the one-hot pattern inputs.
# M follows the actual input width rather than a hard-coded 27.
# NOTE(review): ssm.LDS has a single discrete state; confirm whether K=2 is used or silently ignored.
plds = ssm.LDS(obs_dim, state_dim, M=input_dim, K=2, emissions="poisson", emission_kwargs=dict(link="softplus"))

elbos, q = plds.fit(model_data, inputs=inputs, method="laplace_em", num_iters=8)

# Visualize the results

In [None]:
# plot my elbos

In [None]:
# ELBO per Laplace-EM iteration.
elbo_fig, elbo_ax = plt.subplots(figsize=(8, 6))

elbo_ax.plot(elbos)
elbo_ax.set_xlabel("Iteration")
elbo_ax.set_ylabel("ELBO")
elbo_ax.set_title("ELBO Curve")

plt.show()

## Dynamics

In [None]:
# Fitted dynamics parameters, used by the phase-portrait plots below.
A_est, b_est = plds.dynamics.A, plds.dynamics.b

In [None]:
fig = plt.figure(figsize=(16, 8))

# Left panel: dynamics restricted to the first two latent dimensions.
ax1 = fig.add_subplot(1, 2, 1)
plot_dynamics_2d(A_est[:2, :2], b_est[:2], ax1, npts=12)
ax1.set_title('Dynamics 2D')
ax1.set_xlabel("$x_1$")
ax1.set_ylabel("$x_2$")
ax1.set_xticks([])
ax1.set_yticks([])

# Right panel: full three-dimensional dynamics.
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
plot_dynamics_3d(A_est, b_est, ax2, npts=13, colors="blue")
ax2.set_title('Dynamics 3D')

for clear_ticks in (ax2.set_xticks, ax2.set_yticks, ax2.set_zticks):
    clear_ticks([])

for set_label, text in zip((ax2.set_xlabel, ax2.set_ylabel, ax2.set_zlabel),
                           ("$x_1$", "$x_2$", "$x_3$")):
    set_label(text)

plt.tight_layout()

plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_stim/dynamics_{bin_size}ms.png")

plt.show()

## Posterior Means 2D

In [None]:
from matplotlib.patches import Patch

In [None]:
# Posterior means of the continuous latents, one array per trial in model_data order;
# the plots index them as p[:, k] for latent dimension k (presumably (n_bins, state_dim)).
state_means = q.mean_continuous_states

In [None]:
# One legend entry per pattern, coloured like the trajectories (palette index pattern - 3).
# int() keeps list indexing valid even if loadmat returned the ids as floats.
custom_patches = [
    Patch(facecolor=c[int(pid) - 3], edgecolor='black', label=f'Pattern {int(pid)}')
    for pid in p_ids
]

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(21, 10))

# Flatten axes array for easy iteration
axes = axes.flatten()

# One panel per pair of latent dimensions: (dim_a, dim_b, xlabel, ylabel).
panels = [
    (0, 1, "$x_1$", "$x_2$"),
    (0, 2, "$x_1$", "$x_3$"),
    (1, 2, "$x_2$", "$x_3$"),
]

# Every trial's smoothed posterior trajectory, coloured by its pattern;
# a dot marks the first bin and a star the last.
for trial, traj in enumerate(state_means):
    smooth = gaussian_filter1d(traj, 4, axis=0)
    for ax, (da, db, _, _) in zip(axes, panels):
        ax.plot(smooth[:, da], smooth[:, db], c=colors[trial], zorder=0, alpha=0.8)
        ax.scatter(smooth[0, da], smooth[0, db], s=35, marker='o', c="black", zorder=1, alpha=1)
        ax.scatter(smooth[-1, da], smooth[-1, db], s=100, marker='*', c="black", zorder=1, alpha=1)

for ax, (_, _, xlabel, ylabel) in zip(axes, panels):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Posterior State Means ({bin_size}ms)")
    ax.legend(handles=custom_patches)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

plt.tight_layout()


#plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_stim/state_means_{bin_size}ms_2D.png")

## Posterior Means 3D

In [None]:
# a single 3d plot of all 

In [None]:
fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111, projection='3d')

# Every trial's smoothed 3D posterior trajectory; dot = first bin, star = last bin.
for trial, traj in enumerate(state_means):
    smooth = gaussian_filter1d(traj, 4, axis=0)
    x1, x2, x3 = smooth[:, 0], smooth[:, 1], smooth[:, 2]

    ax.plot(x1, x2, x3, c=colors[trial], zorder=0, alpha=0.8)
    ax.scatter(x1[0], x2[0], x3[0], s=35, c="black", zorder=1, alpha=1)
    ax.scatter(x1[-1], x2[-1], x3[-1], s=100, marker='*', c="black", zorder=1, alpha=1)

ax.set_title(f"Posterior State Means ({bin_size}ms)")

ax.legend(handles=custom_patches, loc='center left')

for clear_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
    clear_ticks([])

ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")


plt.tight_layout()

#plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_stim/state_means_{bin_size}ms_3D.png")

# Plot individual patterns together in subplots

In [None]:
# Inspect the pattern ids and selected trial indices used by the per-pattern plots below.
p_ids

In [None]:
stim_idxs

In [None]:
# create a subplot for each pattern
# create a subplot for each pattern (3 x 9 grid -> room for 27 patterns)
fig, axes = plt.subplots(3, 9, figsize=(21, 10))

# Flatten axes array for easy iteration
axes = axes.flatten()

# Plotted span is bins 50 .. t_end - 1. NOTE(review): with 1 ms bins and the
# 50 ms pre-onset window used by get_trials, bin 50 is laser onset — confirm.
t_end = 50 + 5 + 10 + 1

for i, j in enumerate(p_ids):
    # `colors` has one entry per *trial*, so colors[i] would be trial i's colour;
    # index the palette by pattern instead, as the legend patches do.
    pattern_color = c[int(j) - 3]

    # positions within stim_idxs (and hence state_means) of this pattern's trials
    idxs = np.where(data["pattern_id"][stim_idxs] == j)[0]

    for z in idxs:
        p = gaussian_filter1d(state_means[z], 4, axis=0)
        axes[i].plot(p[50:t_end, 0], p[50:t_end, 1], c=pattern_color, zorder=0, alpha=0.8)
        # dot = bin 50, square = bin 55, star = last plotted bin
        axes[i].scatter(p[50, 0], p[50, 1], s=15, c="black", zorder=1, alpha=1)
        axes[i].scatter(p[55, 0], p[55, 1], s=15, marker='s', c="black", zorder=1, alpha=1)
        axes[i].scatter(p[t_end-1, 0], p[t_end-1, 1], s=25, marker='*', c="black", zorder=1, alpha=1)

    axes[i].set_title(f"Pattern {j}")

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")


#plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_stim/state_means_patterns_{bin_size}ms_2D.png")

In [None]:
# create a subplot for each pattern
# create a subplot for each pattern (3 x 9 grid -> room for 27 patterns)
fig, axes = plt.subplots(3, 9, figsize=(21, 10), subplot_kw={"projection": "3d"})

# Flatten axes array for easy iteration
axes = axes.flatten()

# Plotted span is bins 50 .. t_end - 1. NOTE(review): with 1 ms bins and the
# 50 ms pre-onset window used by get_trials, bin 50 is laser onset — confirm.
t_end = 50 + 5 + 10 + 1

for i, j in enumerate(p_ids):
    # `colors` has one entry per *trial*, so colors[i] would be trial i's colour;
    # index the palette by pattern instead, as the legend patches do.
    pattern_color = c[int(j) - 3]

    # positions within stim_idxs (and hence state_means) of this pattern's trials
    idxs = np.where(data["pattern_id"][stim_idxs] == j)[0]

    for z in idxs:
        p = gaussian_filter1d(state_means[z], 4, axis=0)
        axes[i].plot(p[50:t_end, 0], p[50:t_end, 1], p[50:t_end, 2], c=pattern_color, zorder=0, alpha=0.8)
        # dot = bin 50, square = bin 55, star = last plotted bin
        axes[i].scatter(p[50, 0], p[50, 1], p[50, 2], s=15, c="black", zorder=1, alpha=1)
        axes[i].scatter(p[55, 0], p[55, 1], p[55, 2], s=5, c="black", marker="s", zorder=1, alpha=1)
        axes[i].scatter(p[t_end-1, 0], p[t_end-1, 1], p[t_end-1, 2], s=25, marker='*', c="black", zorder=1, alpha=1)

    axes[i].set_title(f"Pattern {j}")


for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_xlabel("x1", labelpad=-15)
    ax.set_ylabel("x2", labelpad=-15)
    ax.set_zlabel("x3", labelpad=-15)


plt.tight_layout()


#plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_stim/state_means_patterns_{bin_size}ms_3D.png")

# Plot from stim to 10ms after stim

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(21, 10))

# Flatten axes array for easy iteration
axes = axes.flatten()

# One panel per pair of latent dimensions: (dim_a, dim_b, xlabel, ylabel).
panels = [
    (0, 1, "$x_1$", "$x_2$"),
    (0, 2, "$x_1$", "$x_3$"),
    (1, 2, "$x_2$", "$x_3$"),
]

# Plotted span is bins 50 .. t_end - 1 (presumably laser onset onward with 1 ms bins).
t_end = 55 + 10 + 1

for j in p_ids:
    # positions within stim_idxs (and hence state_means) of this pattern's trials
    idxs = np.where(data["pattern_id"][stim_idxs] == j)[0]

    # Pattern-averaged posterior mean; stacking requires every trial to have
    # the same number of bins.
    d = np.array([state_means[z] for z in idxs]).mean(axis=0)
    p = gaussian_filter1d(d, 4, axis=0)

    # `colors` is per trial; the palette indexed by pattern matches the legend.
    pattern_color = c[int(j) - 3]

    for ax, (da, db, _, _) in zip(axes, panels):
        ax.plot(p[50:t_end, da], p[50:t_end, db], c=pattern_color, zorder=0, alpha=0.8)
        # dot = bin 50, square = bin 55, star = last plotted bin
        ax.scatter(p[50, da], p[50, db], s=35, marker='o', c="black", zorder=1, alpha=1)
        ax.scatter(p[55, da], p[55, db], s=35, marker='s', c="black", zorder=1, alpha=1)
        ax.scatter(p[t_end-1, da], p[t_end-1, db], s=55, marker='*', c="black", zorder=1, alpha=1)

for ax, (_, _, xlabel, ylabel) in zip(axes, panels):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Posterior State Means ({bin_size}ms)")
    ax.legend(handles=custom_patches)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

plt.tight_layout()

In [None]:
fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111, projection='3d')

# Plotted span is bins 50 .. t_end - 1 (presumably laser onset onward with 1 ms bins).
t_end = 55 + 10 + 1

for j in p_ids:
    # positions within stim_idxs (and hence state_means) of this pattern's trials
    idxs = np.where(data["pattern_id"][stim_idxs] == j)[0]

    # Pattern-averaged posterior mean; stacking requires every trial to have
    # the same number of bins.
    d = np.array([state_means[z] for z in idxs]).mean(axis=0)
    p = gaussian_filter1d(d, 4, axis=0)

    # `colors` is per trial; the palette indexed by pattern matches the legend.
    pattern_color = c[int(j) - 3]

    ax.plot(p[50:t_end, 0], p[50:t_end, 1], p[50:t_end, 2], c=pattern_color, zorder=0, alpha=0.8)

    # dot = bin 50, square = bin 55, star = last plotted bin
    ax.scatter(p[50, 0], p[50, 1], p[50, 2], s=35, c="black", zorder=1, alpha=1)
    ax.scatter(p[55, 0], p[55, 1], p[55, 2], marker="s", s=35, c="black", zorder=1, alpha=1)
    ax.scatter(p[t_end-1, 0], p[t_end-1, 1], p[t_end-1, 2], marker="*", s=55, c="black", zorder=1, alpha=1)

ax.set_title(f"Posterior State Means Pattern Averaged ({bin_size}ms)")

ax.legend(handles=custom_patches, loc='center left')


ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")


plt.tight_layout()

# Pattern average

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(21, 10))

# Flatten axes array for easy iteration
axes = axes.flatten()

# One panel per pair of latent dimensions: (dim_a, dim_b, xlabel, ylabel).
panels = [
    (0, 1, "$x_1$", "$x_2$"),
    (0, 2, "$x_1$", "$x_3$"),
    (1, 2, "$x_2$", "$x_3$"),
]

for j in p_ids:
    # positions within stim_idxs (and hence state_means) of this pattern's trials
    idxs = np.where(data["pattern_id"][stim_idxs] == j)[0]

    # Pattern-averaged posterior mean over the whole trial; stacking requires
    # every trial to have the same number of bins.
    d = np.array([state_means[z] for z in idxs]).mean(axis=0)
    p = gaussian_filter1d(d, 3, axis=0)

    # `colors` is per trial; the palette indexed by pattern matches the legend.
    pattern_color = c[int(j) - 3]

    for ax, (da, db, _, _) in zip(axes, panels):
        ax.plot(p[:, da], p[:, db], c=pattern_color, zorder=0, alpha=0.8)
        # dot = first bin, star = last bin
        ax.scatter(p[0, da], p[0, db], s=35, marker='o', c="black", zorder=1, alpha=1)
        ax.scatter(p[-1, da], p[-1, db], s=100, marker='*', c="black", zorder=1, alpha=1)

for ax, (_, _, xlabel, ylabel) in zip(axes, panels):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Posterior State Means ({bin_size}ms)")
    ax.legend(handles=custom_patches)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

plt.tight_layout()

#plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_stim/state_means_pattern_avg_{bin_size}ms_2D.png")

In [None]:
fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111, projection='3d')

for j in p_ids:
    # positions within stim_idxs (and hence state_means) of this pattern's trials
    idxs = np.where(data["pattern_id"][stim_idxs] == j)[0]

    # Pattern-averaged posterior mean over the whole trial; stacking requires
    # every trial to have the same number of bins.
    d = np.array([state_means[z] for z in idxs]).mean(axis=0)
    p = gaussian_filter1d(d, 3, axis=0)

    # `colors` is per trial; the palette indexed by pattern matches the legend.
    pattern_color = c[int(j) - 3]

    ax.plot(p[:, 0], p[:, 1], p[:, 2], c=pattern_color, zorder=0, alpha=0.8)

    # dot = first bin, star = last bin
    ax.scatter(p[0, 0], p[0, 1], p[0, 2], s=35, c="black", zorder=1, alpha=1)
    ax.scatter(p[-1, 0], p[-1, 1], p[-1, 2], s=100, marker='*', c="black", zorder=1, alpha=1)

ax.set_title(f"Posterior State Means Pattern Averaged ({bin_size}ms)")

ax.legend(handles=custom_patches, loc='center left')


ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")


plt.tight_layout()

#plt.savefig(f"/home/clewis/repos/realSpike/data/rb50_20250127/plds_stim/state_means_pattern_avg_{bin_size}ms_3D.png")

# 