# Notebook for running a Poisson LDS (PLDS)

## Goal: see the system come back to rest
## bin size = 1 ms (set by `bin_size` below)

In [None]:
from scipy.linalg import block_diag
import autograd.numpy as np
import matplotlib.pyplot as plt
import ssm

from pathlib import Path
from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from real_spike.utils import get_spike_events, kalman_filter, bin_spikes, butter_filter, plot_dynamics_2d, plot_dynamics_3d
from matplotlib import colormaps
import pandas as pd

from scipy.ndimage import gaussian_filter1d
from mpl_toolkits.mplot3d import Axes3D 

import random

%matplotlib inline

# Get data

In [None]:
# Read the session's .mat file and keep only its top-level "data" entry.
mat = loadmat("/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250127/MAT_FILES/rb50_20250127_datastruct_pt2.mat")["data"]

# Show which fields the struct array exposes.
print(mat.dtype.names)

In [None]:
# The struct loads as a 2D array even when it is 1x1, so take its only element.
data_struct = mat[0, 0]

# Re-key the struct as a plain dict: field name -> field contents.
field_names = mat.dtype.names
data = {name: data_struct[name] for name in field_names}

In [None]:
# List the field names available in the loaded struct.
data.keys()

# Get single-reach trials

In [None]:
# First-axis indices of the entries whose pattern id is exactly 0.
pattern_is_zero = data["pattern_id"] == 0
control_idxs = np.where(pattern_is_zero)[0]
control_idxs

In [None]:
# First-axis indices, within the first 150 entries, whose pattern id is above 2.
stim_idxs = np.where(data["pattern_id"][:150] > 2)[0]

# Pool both selections into a single ascending index array.
# NOTE(review): after this line `control_idxs` also contains the pattern > 2 trials.
pooled_idxs = np.concatenate((control_idxs, stim_idxs))
control_idxs = np.sort(pooled_idxs)

In [None]:
# Keep only the already-selected trials whose "single" flag equals 1
# (presumably the single-reach trials named in the section heading -- confirm).
# np.where(...)[1] yields the second-axis indices of the matching entries.
single_idxs = np.where(data["single"] == 1)[1]
control_idxs = np.intersect1d(control_idxs, single_idxs)

In [None]:
# Display the final trial selection.
control_idxs

In [None]:
# Sanity check: distinct pattern ids present among the selected trials.
np.unique(data["pattern_id"][control_idxs])

## Get relevant time information

In [None]:
# Cue times, limited to the first 150 columns (get_trials treats these values as ms).
cue_times = data["cue_rec_time"][:, :150]
# Display the cue times of the selected trials.
cue_times[:, control_idxs]

In [None]:
# Lift times, limited to the first 150 columns (get_trials adds these to the cue time, in ms).
lift_times = data["lift_ms"][:, :150]
# Display the lift times of the selected trials.
lift_times[:, control_idxs]

In [None]:
# Mouth times, limited to the first 150 columns (get_trials adds these to the cue time, in ms).
mouth_times = data["mouth_ms"][:, :150]
# Display the mouth times of the selected trials.
mouth_times[:, control_idxs]

## Get the AP.bin file

In [None]:
from real_spike.utils import get_sample_data, get_meta
import tifffile

In [None]:
# Raw AP-band recording (.ap.bin) for this session.
file_path = Path("/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250127/rb50_20250127_g0/rb50_20250127_g0_t0.imec0.ap.bin")
# Companion metadata file (.ap.meta) sitting next to the recording.
meta_path = Path("/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250127/rb50_20250127_g0/rb50_20250127_g0_t0.imec0.ap.meta")

In [None]:
# Parse the metadata file; later cells look values up by key (e.g. "imAiRangeMax").
meta_data = get_meta(meta_path)

In [None]:
# Load the AP recording using its path and metadata; later indexed as ap_data[rows, sample_range].
ap_data = get_sample_data(file_path, meta_data)
# Display the array dimensions.
ap_data.shape

## Get conversion params

In [None]:
# Values used by get_trials to scale raw samples.
# Range maximum, from the "imAiRangeMax" metadata entry.
vmax = float(meta_data["imAiRangeMax"])
# Maximum integer value, from the "imMaxInt" metadata entry.
imax = float(meta_data["imMaxInt"])

# Gain: split the "imroTbl" string on ')' and take piece [1], then take the
# space-separated field at index 3 of that piece
# (presumably the AP gain of the first channel entry -- confirm against the imro table format).
first_imro_entry = meta_data['imroTbl'].split(')')[1]
gain = float(first_imro_entry.split(' ')[3])

In [None]:
# Display the parsed range maximum.
vmax

In [None]:
# Display the parsed maximum integer value.
imax

In [None]:
# Display the parsed gain.
gain

# Get trials

In [None]:
def _filtered_microvolts(raw):
    """Scale raw AP samples to microvolts and pass them through ``butter_filter``.

    Relies on the notebook-level ``vmax``, ``imax`` and ``gain`` values parsed
    from the recording metadata.

    Parameters
    ----------
    raw : array
        Samples sliced from ``ap_data`` (rows x samples).

    Returns
    -------
    Whatever ``butter_filter`` returns for the scaled data.
    """
    # SpikeGLX conversion is volts = int * Vmax / Imax / gain: Vmax multiplies,
    # it does not divide. The 1e6 factor expresses the result in microvolts.
    # NOTE(review): any absolute threshold inside get_spike_events that was
    # tuned against a "/ vmax" scaling should be re-checked.
    microvolts = 1e6 * raw * vmax / imax / gain
    # Presumably (data, cutoff in Hz, sampling rate in Hz) -- confirm against butter_filter.
    return butter_filter(microvolts, 1_000, 30_000)


def get_trials(idxs, bin_size):
    """Extract binned spike counts for each requested trial.

    For every trial index this cuts the AP recording from 50 ms before lift up
    to the mouth time, scales and filters it, detects spike events against a
    per-row median taken from the 2 s preceding the cue, bins the events, and
    finally drops the leading 50 ms worth of bins.

    Reads the notebook-level ``cue_times``, ``lift_times``, ``mouth_times``
    and ``ap_data``.

    Parameters
    ----------
    idxs : iterable of int
        Trial indices (column indices into the timing arrays).
    bin_size : int
        Bin width in ms.

    Returns
    -------
    list of ndarray
        One integer array per trial: the transposed output of ``bin_spikes``
        with the pre-lift bins removed (presumably time bins x channels).
    """
    sample_rate_hz = 30_000  # samples per second used for every ms -> sample conversion
    samples_per_ms = 30
    n_channels = 150         # only the first 150 rows of ap_data are used
    pre_lift_ms = 50         # lead-in extracted before lift, dropped again after binning
    baseline_ms = 2_000      # pre-cue window used for the per-row median

    # Loop invariants.
    samples_per_bin = bin_size * samples_per_ms
    pre_lift_bins = int(pre_lift_ms / bin_size)

    model_data = []

    for trial_no in tqdm(idxs):
        # Scalar indexing: int() on a 1-element array is deprecated in NumPy >= 1.25.
        cue_ms = float(cue_times[0, trial_no])
        lift_ms = float(lift_times[0, trial_no])
        mouth_ms = float(mouth_times[0, trial_no])

        # Trial window in sample indices: 50 ms before lift up to the mouth time.
        # NOTE(review): no padding is applied after the mouth time; add it here
        # if the return-to-rest period should be part of the window.
        window_start = int((cue_ms + lift_ms - pre_lift_ms) / 1_000 * sample_rate_hz)
        window_end = int((cue_ms + mouth_ms) / 1_000 * sample_rate_hz)
        filt_data = _filtered_microvolts(ap_data[:n_channels, window_start:window_end])

        # Baseline: the 2 s immediately preceding the cue, processed identically.
        cue_sample = int(cue_ms / 1_000 * sample_rate_hz)
        # Clamp at 0 so an early cue cannot turn the start into a negative (wrapping) index.
        baseline_start = max(cue_sample - samples_per_ms * baseline_ms, 0)
        baseline = _filtered_microvolts(ap_data[:n_channels, baseline_start:cue_sample])
        median = np.median(baseline, axis=1)

        # Only the spike sample indices are needed; the second return value is unused.
        spike_ixs, _ = get_spike_events(filt_data, median)

        # Binary raster with a 1 at every detected spike sample.
        raster = np.zeros(filt_data.shape)
        for channel, spike_samples in enumerate(spike_ixs):
            raster[channel, spike_samples] = 1

        # bin_size ms per bin at 30 samples per ms.
        binned_spikes = bin_spikes(raster, samples_per_bin)

        # Drop the pre-lift lead-in and transpose for the model.
        model_data.append(np.asarray(binned_spikes[:, pre_lift_bins:].T, dtype=int))

    return model_data

In [None]:
# Bin width in ms (get_trials turns it into samples by multiplying by 30).
bin_size = 1

In [None]:
# Build one binned spike-count array per selected trial.
model_data = get_trials(control_idxs, bin_size)

# Fit the model

In [None]:
# Latent dimensionality of the model.
state_dim = 3
# Observation dimensionality: number of columns in each trial's count array.
obs_dim = model_data[0].shape[1]

# LDS with Poisson emissions and a softplus link.
# NOTE(review): K=2 is passed to ssm.LDS -- confirm how (or whether) LDS uses K.
plds = ssm.LDS(
    obs_dim,
    state_dim,
    K=2,
    M=0,
    emissions="poisson",
    emission_kwargs={"link": "softplus"},
)

# Fit with Laplace-EM for 8 iterations; keep the ELBO trace and the posterior object.
elbos, q = plds.fit(model_data, method="laplace_em", num_iters=8)

# Visualize the results

In [None]:
# Plot the ELBO values returned by the fit against their index.
elbo_fig, elbo_ax = plt.subplots(figsize=(8, 6))
elbo_ax.plot(elbos)
elbo_ax.set_xlabel("Iteration")
elbo_ax.set_ylabel("ELBO")
elbo_ax.set_title("ELBO Curve")

plt.show()

## Dynamics

In [None]:
# Pull the fitted dynamics parameters (A and b) off the model for plotting.
fitted_dynamics = plds.dynamics
A_est, b_est = fitted_dynamics.A, fitted_dynamics.b

In [None]:
fig = plt.figure(figsize=(16, 8))

# Left panel: 2D dynamics from the top-left 2x2 block of A and the first two entries of b.
ax1 = fig.add_subplot(1, 2, 1)
plot_dynamics_2d(A_est[:2, :2], b_est[:2], ax1, npts=12)
ax1.set(
    title='Dynamics 2D',
    xlabel="$x_1$",
    ylabel="$x_2$",
    xticks=[],
    yticks=[],
)

# Right panel: 3D dynamics from the full A and b.
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
plot_dynamics_3d(A_est, b_est, ax2, npts=13, colors="blue")
ax2.set(
    title='Dynamics 3D',
    xticks=[],
    yticks=[],
    zticks=[],
    xlabel="$x_1$",
    ylabel="$x_2$",
    zlabel="$x_3$",
)

plt.tight_layout()

# Uncomment to write the figure to disk.
#plt.savefig("/home/clewis/repos/realSpike/data/rb50_20250127/plds/dynamics.png")

plt.show()

## Posterior Means 2D

In [None]:
# Posterior mean of the continuous latent states; iterated per trial in the plots below.
state_means = q.mean_continuous_states

In [None]:
# Display the dimensions of the first trial's posterior means.
state_means[0].shape

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(21, 10))

# Work with a flat sequence of the three panels.
axes = axes.flatten()

# Latent-dimension pair (x-axis, y-axis) drawn in each panel, left to right.
panel_dims = [(0, 1), (0, 2), (1, 2)]
dim_labels = ["$x_1$", "$x_2$", "$x_3$"]

for i, p in enumerate(state_means):
    trial_idx = control_idxs[i]

    # Behavioural event times of this trial expressed in bins.
    # They are not drawn at the moment; kept so event markers can be added.
    lift = int(data["lift_ms"][0, trial_idx] / bin_size)
    grab = int(data["grab_ms"][0, trial_idx] / bin_size)
    mouth = int(data["mouth_ms"][0, trial_idx] / bin_size)

    # Event bins measured from a starting bin of 50 / bin_size.
    lift_start = int(50 / bin_size)
    grab_start = lift_start + (grab - lift)
    mouth_start = lift_start + (mouth - lift)

    # Gaussian-smooth along the first axis (sigma = 4).
    p = gaussian_filter1d(p, 4, axis=0)

    for panel, (x_dim, y_dim) in zip(axes, panel_dims):
        # Teal circle: first row of the smoothed means.
        panel.scatter(p[0, x_dim], p[0, y_dim], s=35, c="teal", zorder=1, alpha=1)
        # Orange square: last row of the smoothed means.
        panel.scatter(p[-1, x_dim], p[-1, y_dim], s=50, marker='s', c="orange", zorder=1, alpha=1)

# Shared cosmetics plus the per-panel axis labels.
for ax, (x_dim, y_dim) in zip(axes, panel_dims):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Posterior State Means ({bin_size}ms)")
    ax.set_xlabel(dim_labels[x_dim])
    ax.set_ylabel(dim_labels[y_dim])

plt.tight_layout()

# Uncomment to write the figure to disk.
#plt.savefig("/home/clewis/repos/realSpike/data/rb50_20250127/plds/state_means_2D.png")

## Posterior Means 3D

In [None]:
# A single 3D plot showing every trial's posterior state means

In [None]:
fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111, projection='3d')

for i, p in enumerate(state_means):
    trial_idx = control_idxs[i]

    # Behavioural event times of this trial expressed in bins.
    # They are not drawn at the moment; kept so event markers can be added.
    lift = int(data["lift_ms"][0, trial_idx] / bin_size)
    grab = int(data["grab_ms"][0, trial_idx] / bin_size)
    mouth = int(data["mouth_ms"][0, trial_idx] / bin_size)

    # Event bins measured from a starting bin of 50 / bin_size.
    lift_start = int(50 / bin_size)
    grab_start = lift_start + (grab - lift)
    mouth_start = lift_start + (mouth - lift)

    # Gaussian-smooth along the first axis (sigma = 4).
    p = gaussian_filter1d(p, 4, axis=0)

    first_row, last_row = p[0], p[-1]
    # Teal circle: first row of the smoothed means.
    ax.scatter(first_row[0], first_row[1], first_row[2], s=35, c="teal", zorder=1, alpha=1)
    # Orange square: last row of the smoothed means.
    ax.scatter(last_row[0], last_row[1], last_row[2], s=50, marker='s', c="orange", zorder=1, alpha=1)

# Title, hidden ticks and axis labels in one call.
ax.set(
    title=f"Posterior State Means ({bin_size}ms)",
    xticks=[],
    yticks=[],
    zticks=[],
    xlabel="$x_1$",
    ylabel="$x_2$",
    zlabel="$x_3$",
)

plt.tight_layout()

# Uncomment to write the figure to disk.
#plt.savefig("/home/clewis/repos/realSpike/data/rb50_20250127/plds/state_means_3D.png")