# Poisson rSLDS 
## Behaving dataset
## The stimulus serves as the cue; the system should return to rest before lift
## Fitting the model from cue to grab: can two phases of dynamics be seen?

In [None]:
from scipy.linalg import block_diag
import autograd.numpy as np
import matplotlib.pyplot as plt
import ssm

from pathlib import Path
from scipy.io import loadmat
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from real_spike.utils import get_spike_events, kalman_filter, bin_spikes, butter_filter, plot_dynamics_2d, plot_dynamics_3d
from matplotlib import colormaps
import pandas as pd

from scipy.ndimage import gaussian_filter1d
from mpl_toolkits.mplot3d import Axes3D 

import random
import h5py

%matplotlib inline

In [None]:
# Output directory for saved figures. Keep the trailing slash: later cells
# build file names as f"{path}name.png".
path = "/home/clewis/repos/realSpike/data/behavior/rb50_20250125/"

# Get data

In [None]:
# The .mat file is opened with h5py (i.e. it is an HDF5-based MAT file).
# The handle stays open because later cells keep reading from `data`.
f = h5py.File(
    "/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250125/MAT_FILES/rb50_20250125_datastruct_pt3.mat",
    mode='r',
)
data = f['data']
print(data.keys())

# Plot the patterns

In [None]:
# visualize the patterns again
pattern_ids = list(np.unique(data['pattern_id'][:]))
len(pattern_ids)

In [None]:
from scipy.ndimage import zoom

In [None]:
reshape_size = 12

In [None]:
all_patterns = list()

for p_id in pattern_ids[3:]:
    ix = np.where(data['pattern_id'][:] == p_id)[1][0].astype(np.int32)
    pattern = np.zeros((data['pattern_xy'][0][ix].astype(np.int32), data['pattern_xy'][1][ix].astype(np.int32)))
    # use the pattern fill to set elements to one
    object_ref = f[data['pattern_fill'][ix, 0]]
    for x, y in zip(object_ref[0, :], object_ref[1, :]):
        pattern[x.astype(np.int32)-1, y.astype(np.int32)-1] = 1

    pattern = zoom(pattern, (reshape_size / pattern.shape[0], reshape_size / pattern.shape[1]), order=0)

    all_patterns.append(pattern)


In [None]:
len(all_patterns)

In [None]:
# Show every stimulus pattern in a grid with 6 columns. The number of rows
# follows the number of patterns, and unused trailing axes are removed.
n_cols = 6
n_rows = max(1, int(np.ceil(len(all_patterns) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 2 * n_rows))
axes = axes.flatten()

for ax in axes[len(all_patterns):]:
    ax.remove()

for i, p in enumerate(all_patterns):
    axes[i].matshow(p, cmap="binary")
    # all_patterns[i] was built from pattern_ids[3 + i]; title with that id
    axes[i].set_title(f"Pattern {int(pattern_ids[3 + i])}")

    axes[i].set_xticks([])
    axes[i].set_yticks([])

plt.tight_layout()

plt.savefig(f"{path}patterns.png")

plt.show()

# Get single-reach trials

In [None]:
stim_idxs = np.where(data["pattern_id"][:] > 2)[1]
single_reach_idxs = np.where(data["single"][:, 0] == 1)

In [None]:
# get no laser trials
behavior_idxs = np.intersect1d(stim_idxs, single_reach_idxs)
behavior_idxs

In [None]:
behavior_idxs.shape

## Get relevant time information

In [None]:
# Timing for the selected trials only: row r belongs to trial behavior_idxs[r].
# Cue time in ms on the recording clock (get_trials converts ms -> samples at 30 kHz).
cue_times = data["aligned_cue_rec_time"][behavior_idxs, :]
cue_times

In [None]:
# Lift times, binned like the grab times below.
# NOTE(review): assumed to be ms relative to the cue — confirm.
lift_times = data["lift_ms"][behavior_idxs, :]
lift_times

In [None]:
# The 'mouth_ms' field is used as the grab time: get_trials adds it to the cue
# time to find the end of each trial window.
grab_times = data["mouth_ms"][behavior_idxs, :]
grab_times

## Get the AP.bin file

In [None]:
from real_spike.utils import get_sample_data, get_meta
import tifffile

In [None]:
file_path = Path("/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250125/rb50_20250125_g0/rb50_20250125_g0_t0.imec0.ap.bin")
meta_path = Path("/home/clewis/wasabi/reaganbullins2/ProjectionProject/rb50/20250125/rb50_20250125_g0/rb50_20250125_g0_t0.imec0.ap.meta")

In [None]:
meta_data = get_meta(meta_path)

In [None]:
ap_data = get_sample_data(file_path, meta_data)
ap_data.shape

## Get conversion params

In [None]:
vmax = float(meta_data["imAiRangeMax"])
# get Imax
imax = float(meta_data["imMaxInt"])
# get gain
gain = float(meta_data['imroTbl'].split(sep=')')[1].split(sep=' ')[3])

In [None]:
vmax

In [None]:
imax

In [None]:
gain

# Get the model data

In [None]:
(cue_times[i, 0] - 5) 

In [None]:
(cue_times[i, 0]) / 1_000

In [None]:
def get_trials(idxs, bin_size, fs=30_000, n_channels=150, pre_cue_ms=50, baseline_ms=1_000):
    """Extract binned spike counts for each selected trial, from just before the cue to the grab.

    For every trial, the raw AP data are processed as follows:
    - cut from ``pre_cue_ms`` before the cue to the grab;
    - scaled to microvolts;
    - passed through ``butter_filter`` (cutoff argument 1000; the filter type
      is defined in real_spike.utils);
    - thresholded with ``get_spike_events`` against the per-channel median of
      the filtered ``baseline_ms`` that precede the window;
    - binned with ``bin_spikes``.

    Relies on notebook globals: ``ap_data``, ``cue_times``, ``grab_times``,
    ``vmax``, ``imax`` and ``gain``. ``cue_times``/``grab_times`` are already
    restricted to the selected trials, so only ``len(idxs)`` is used here.

    Args:
        idxs: selected trial indices (only their count is used).
        bin_size: bin width in ms.
        fs: sampling rate of ``ap_data`` in Hz.
        n_channels: number of leading channels of ``ap_data`` to use.
        pre_cue_ms: ms of data kept before the cue.
        baseline_ms: ms of data before the window used for the median.

    Returns:
        List with one int array per trial; presumably shaped (n_bins, n_channels),
        the transpose of what ``bin_spikes`` returns — confirm in real_spike.utils.
    """
    # SpikeGLX conversion: volts = int * Vmax / Imax / gain (multiply by Vmax,
    # do not divide); the 1e6 gives microvolts.
    to_uv = 1e6 * vmax / imax / gain
    # fs / 1000 samples per ms, so this many samples fall into one bin
    samples_per_bin = int(bin_size * fs / 1_000)
    baseline_samples = int(baseline_ms / 1_000 * fs)

    model_data = list()

    for trial in tqdm(range(len(idxs))):
        # window in sample indices: pre_cue_ms before the cue up to the grab
        start = int((cue_times[trial, 0] - pre_cue_ms) / 1_000 * fs)
        stop = int((cue_times[trial, 0] + grab_times[trial, 0]) / 1_000 * fs)

        filt_data = butter_filter(to_uv * ap_data[:n_channels, start:stop], 1_000, fs)

        # Per-channel median of the filtered baseline before the window.
        # Clamp the start at 0 so an early trial does not slice with a negative
        # index, which would wrap around to the end of the recording.
        baseline_start = max(start - baseline_samples, 0)
        baseline = butter_filter(to_uv * ap_data[:n_channels, baseline_start:start], 1_000, fs)
        median = np.median(baseline, axis=1)

        spike_ixs, _ = get_spike_events(filt_data, median)

        # binary raster: spike_ixs[ch] holds the spike sample indices of channel ch
        raster = np.zeros(filt_data.shape)
        for ch, sample_ixs in enumerate(spike_ixs):
            raster[ch, sample_ixs] = 1

        binned_spikes = bin_spikes(raster, samples_per_bin)

        model_data.append(np.asarray(binned_spikes.T, dtype=int))

    return model_data

In [None]:
bin_size = 5

In [None]:
model_data = get_trials(behavior_idxs, bin_size)

In [None]:
model_data[0].shape

In [None]:
lift = int(lift_times[0, 0] / bin_size) + 10
lift

grab = int(grab_times[0, 0] / bin_size) + 10 - 1

In [None]:
lift

In [None]:
grab

In [None]:
plt.matshow(model_data[0].T)
plt.axvline(10-1, c="red", linestyle="--", lw=1)
plt.axvline(lift, c="red", linestyle="--")
plt.axvline(grab, c="red", linestyle="--")

# Design the input matrix

In [None]:
# get the pattern types

In [None]:
# raw per-trial pattern ids (an h5py dataset inside the opened MAT file)
data["pattern_id"]

In [None]:
# sorted distinct pattern ids among the selected (stimulated, single-reach) trials
p_ids = np.unique(data["pattern_id"][0, behavior_idxs])

p_ids

In [None]:
p_ids.shape

## Get colors for plotting

In [None]:
# Plot color for each stimulus pattern, indexed by (pattern id - 3).
# Entries must be unique, otherwise two patterns look identical in the plots.
c = [
    "maroon",
    "deeppink",
    "palevioletred",
    "blue",
    "orange",
    "green",
    "red",
    "purple",
    "brown",
    "pink",
    "turquoise",
    "olive",
    "cyan",
    "gold",
    "lime",
    "navy",
    "magenta",
    "teal",
    "royalblue",
    "darkgreen",
    "dimgray",
    "darkgoldenrod",
    "midnightblue",
    "plum",
    "salmon",
    "cadetblue",
    "steelblue",
    "peru",
]

In [None]:
len(c)

In [None]:
from matplotlib.colors import ListedColormap

cmap = ListedColormap(c)

plt.imshow(np.arange(27).reshape(1,-1), aspect="auto", cmap=cmap)

# Create encodings 

In [None]:
model_data[1].shape

In [None]:
import math

In [None]:
inputs = list()
colors = list()
for i, d in zip(stim_idxs, model_data):
    # get the pattern id 
    p_id = int(data["pattern_id"][0][i]) - 3

    colors.append(c[p_id])

    encoding = all_patterns[p_id].ravel() 

    # stack the encoding for every timepoint (each bin)
    nput = np.zeros((d.shape[0], reshape_size**2))

    for z in range(math.ceil(5 / bin_size)):
        nput[10 + z] = encoding

    inputs.append(nput)

In [None]:
# Input matrix of the first trial: pixels on the y axis, bins on the x axis.
plt.imshow(inputs[0].T, cmap="binary")
plt.gca().set(title="Pattern Encoding", xlabel="Bins", ylabel="Pixel Location")
plt.show()

# Fit the model

In [None]:
state_dim = 3
obs_dim = model_data[0].shape[1]  # columns of each trial's count matrix
input_dim = inputs[0].shape[1]  # flattened pattern image (reshape_size**2 pixels)

# Poisson SLDS with:
# - 2 discrete states and `state_dim` latent dimensions;
# - softplus emission link;
# - recurrent transitions, so switching depends on the continuous state.
plds = ssm.SLDS(N=obs_dim,
                K=2,
                D=state_dim,
                M=input_dim,
                emissions="poisson",
                emission_kwargs=dict(link="softplus"),
                dynamics="diagonal_gaussian",
                transitions="recurrent",
               )

elbos, q = plds.fit(model_data, inputs=inputs, method="laplace_em", num_iters=5)

# Visualize the results

In [None]:
# plot my elbos

In [None]:
plt.figure(figsize=(8, 6))

plt.plot(elbos)

plt.xlabel("Iteration")
plt.ylabel("ELBO")

plt.title("ELBO Curve")

plt.show()

## Dynamics

In [None]:
As = plds.dynamics.As
As.shape

In [None]:
b = plds.dynamics.b
b.shape

# Eigenvalue Decomposition on `A`

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Flatten axes array for easy iteration
axes = axes.flatten()

# Shared color limits for both A matrices. These are deliberately not named
# vmin/vmax: `vmax` holds imAiRangeMax, which get_trials reads.
a_min = min(mat.min() for mat in As[:2])
a_max = max(mat.max() for mat in As[:2])

for i in range(2):
    ax = axes[i]
    im = ax.matshow(As[i], vmin=a_min, vmax=a_max)
    ax.set_title(f"$A_{i}$")

    eigvals = np.linalg.eigvals(As[i])
    print(eigvals.real)

    eig_ax = axes[i + 2]
    eig_ax.axhline(0, color='gray', linestyle='--', linewidth=0.7)
    # Re = 1: under x_{t+1} = A x_t, a real eigenvalue beyond this line grows
    eig_ax.axvline(1, color='gray', linestyle='--', linewidth=0.7)

    eig_ax.scatter(eigvals.real, eigvals.imag, c='red', s=50, marker="x")

    eig_ax.set_xlim((0, 2))
    eig_ax.set_title(f"Eigenvalues $A_{i}$")
    eig_ax.set_xlabel("Real")
    eig_ax.set_ylabel("Imaginary")


fig.colorbar(im, ax=axes[:2])

plt.savefig(f"{path}A.png")

plt.show()

# Plot the dynamics 

In [None]:
def _plot_flow_2d(ax, A, b, title, xlabel, ylabel, lim=5, n_grid=20):
    """Draw unit arrows of the one-step displacement (A - I) x + b on a 2-D grid.

    The latent dynamics are discrete-time (x_{t+1} = A x_t + b + input + noise),
    so the flow at x is x_{t+1} - x_t, not A x. This matches
    plot_most_likely_dynamics.
    """
    grid = np.linspace(-lim, lim, n_grid)
    X, Y = np.meshgrid(grid, grid)
    M = A - np.eye(2)
    U = M[0, 0] * X + M[0, 1] * Y + b[0]
    V = M[1, 0] * X + M[1, 1] * Y + b[1]

    # normalize to unit length (direction only); guard zero-length vectors
    N = np.hypot(U, V)
    N[N == 0] = 1

    ax.quiver(X, Y, U / N, V / N, angles="xy", scale=25)
    ax.axhline(0, color='k', linewidth=0.5)
    ax.axvline(0, color='k', linewidth=0.5)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_aspect('equal')


fig, axes = plt.subplots(2, 3, figsize=(14, 10))
axes = axes.flatten()

# Row k: discrete state k. Each column shows one pair of latent dims (0-based),
# as a 2-D slice with the remaining dim held at 0.
dim_pairs = [(0, 1), (0, 2), (1, 2)]

for k in range(2):
    for j, (d0, d1) in enumerate(dim_pairs):
        dims = [d0, d1]
        A_sub = As[k][np.ix_(dims, dims)]
        b_sub = plds.dynamics.bs[k][dims]
        _plot_flow_2d(axes[3 * k + j], A_sub, b_sub,
                      f'Dynamics 2D \n $A_{k}$', f"$x_{d0 + 1}$", f"$x_{d1 + 1}$")

plt.tight_layout()

plt.savefig(f"{path}dynamics2.png")
plt.show()

In [None]:
from matplotlib.patches import Patch

In [None]:
state_means = q.mean_continuous_states

In [None]:
custom_patches = [Patch(facecolor=c[int(i)-3], edgecolor='black', label=f'{int(i)}') for i in p_ids]

# Plot the state means 

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(21, 10))

# Flatten axes array for easy iteration
axes = axes.flatten()

for i, p in enumerate(state_means):

    lift = int(lift_times[i, 0] / bin_size) + 10
    grab = int(grab_times[i, 0] / bin_size) + 10 - 1

    p = gaussian_filter1d(p, 4, axis=0)

    axes[0].plot(p[:grab+1, 0], p[:grab+1, 1], c=colors[i], zorder=0, alpha=0.8)

    axes[0].scatter(p[0, 0], p[0, 1], s=25, marker='o', c="black", zorder=1, alpha=1)
    axes[0].scatter(p[lift, 0], p[lift, 1], s=25, marker='s', c="black", zorder=1, alpha=1, label="lift")
    axes[0].scatter(p[grab, 0], p[grab, 1], s=45, marker='*', c="black", zorder=1, alpha=1, label="grab")

    axes[1].plot(p[:grab+1, 0], p[:grab+1, 2], c=colors[i], zorder=0, alpha=0.8)

    axes[1].scatter(p[0, 0], p[0, 2], s=25, marker='o', c="black", zorder=1, alpha=1)
    axes[1].scatter(p[lift, 0], p[lift, 2], s=25, marker='s', c="black", zorder=1, alpha=1)
    axes[1].scatter(p[grab, 0], p[grab, 2], s=45, marker='*', c="black", zorder=1, alpha=1)

    axes[2].plot(p[:grab+1, 1], p[:grab+1, 2], c=colors[i], zorder=0, alpha=0.8)

    axes[2].scatter(p[0, 1], p[0, 2], s=25, marker='o', c="black", zorder=1, alpha=1)
    axes[2].scatter(p[lift, 1], p[lift, 2], s=25, marker='s', c="black", zorder=1, alpha=1)
    axes[2].scatter(p[grab, 1], p[grab, 2], s=45, marker='*', c="black", zorder=1, alpha=1)




for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Posterior State Means ({bin_size}ms)")
    ax.legend(handles=custom_patches)


# 0, 0 = x1 vs x2
axes[0].set_xlabel(f"$x_0$")
axes[0].set_ylabel(f"$x_1$")


# 0, 1 = x1 vs x3
axes[1].set_xlabel(f"$x_0$")
axes[1].set_ylabel(f"$x_2$")

# 0, 2 = x2 vs x3
axes[2].set_xlabel(f"$x_1$")
axes[2].set_ylabel(f"$x_2$")

plt.savefig(f"{path}state_means_all.png")

plt.tight_layout()

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(21, 10))

# Flatten axes array for easy iteration
axes = axes.flatten()

for i, p in enumerate(state_means):

    lift = int(lift_times[i, 0] / bin_size) + 10
    grab = int(grab_times[i, 0] / bin_size) + 10 - 1

    p = gaussian_filter1d(p, 4, axis=0)

    axes[0].plot(p[lift:grab+1, 0], p[lift:grab+1, 1], c=colors[i], zorder=0, alpha=0.8)

    axes[0].scatter(p[lift, 0], p[lift, 1], s=25, marker='o', c="black", zorder=1, alpha=1)
    axes[0].scatter(p[grab, 0], p[grab, 1], s=45, marker='*', c="black", zorder=1, alpha=1)

    axes[1].plot(p[lift:grab+1, 0], p[lift:grab+1, 2], c=colors[i], zorder=0, alpha=0.8)

    axes[1].scatter(p[lift, 0], p[lift, 2], s=25, marker='o', c="black", zorder=1, alpha=1)
    axes[1].scatter(p[grab, 0], p[grab, 2], s=45, marker='*', c="black", zorder=1, alpha=1)

    axes[2].plot(p[lift:grab+1, 1], p[lift:grab+1, 2], c=colors[i], zorder=0, alpha=0.8)

    axes[2].scatter(p[lift, 1], p[lift, 2], s=25, marker='o', c="black", zorder=1, alpha=1)
    axes[2].scatter(p[grab, 1], p[grab, 2], s=45, marker='*', c="black", zorder=1, alpha=1)




for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Posterior State Means ({bin_size}ms)")
    ax.legend(handles=custom_patches)


# 0, 0 = x1 vs x2
axes[0].set_xlabel(f"$x_0$")
axes[0].set_ylabel(f"$x_1$")


# 0, 1 = x1 vs x3
axes[1].set_xlabel(f"$x_0$")
axes[1].set_ylabel(f"$x_2$")

# 0, 2 = x2 vs x3
axes[2].set_xlabel(f"$x_1$")
axes[2].set_ylabel(f"$x_2$")

plt.savefig(f"{path}state_means_behavior.png")

plt.tight_layout()

# Plot discrete states

In [None]:
discrete_states = list()

for i in range(len(state_means)):
    z_s = plds.most_likely_states(state_means[i], model_data[i], inputs[i])
    discrete_states.append(z_s)

In [None]:
max_len = max(len(arr) for arr in discrete_states)

# 2. Initialize the 2D array with NaNs
# The number of rows is the number of 1D arrays
# The number of columns is the max_len
s = np.full((len(discrete_states), max_len), np.nan)

# 3. Populate the 2D array
for i, arr in enumerate(discrete_states):
    s[i, :len(arr)] = arr

In [None]:
# plot each row 
s.shape

In [None]:
s[0]

In [None]:
import matplotlib.colors as mcolors

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 10))

masked_data = np.ma.array(s, mask=np.isnan(s))

# Take a copy of the registered colormap and customize that. Calling
# plt.cm.binary.set_bad would change the global colormap for every later plot.
# Masked padding (bins past a trial's end) is drawn in red.
cmap = colormaps["binary"].with_extremes(bad="red")

ax.matshow(masked_data, cmap=cmap, aspect="auto")

ax.set_xlabel("Bin")
ax.set_ylabel("Trial")

ax.xaxis.set_ticks_position('bottom')

ax.set_title("Most Likely States")

cs = {0: "white", 1: "black"}
labels = {0: '$A_0$', 1: '$A_1$'}

patches = [Patch(facecolor=cs[val], label=labels[val], edgecolor="black") for val in cs]

plt.legend(handles=patches, title="States", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()

plt.savefig(f"{path}states.png")

plt.show()

# Plotting B

In [None]:
Bs = plds.dynamics.params[2]
Bs.shape

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import CenteredNorm

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Flatten axes array for easy iteration
axes = axes.flatten()

for i, ax in enumerate(axes):
    im = ax.matshow(Bs[i], cmap="seismic", aspect='auto', norm=CenteredNorm(vcenter=0))

    ax.set_title(f"$B_{i}$")
    ax.set_ylabel("Dim")
    ax.set_xlabel("Pixel Position")

    ax.xaxis.tick_bottom()

    fig.colorbar(im, ax=ax)

plt.tight_layout()

plt.savefig(f"{path}B.png")

plt.show()

In [None]:
Bs.shape

In [None]:
n_states, n_dims = Bs.shape[:2]
fig, axes = plt.subplots(n_states, n_dims, figsize=(11, 6), squeeze=False)

# Row k shows B_k. Column d shows latent dim d's input weights, reshaped back
# onto the reshape_size x reshape_size pattern grid (inputs are raveled
# all_patterns images).
for k in range(n_states):
    for d in range(n_dims):
        ax = axes[k, d]
        im = ax.matshow(Bs[k][d].reshape(reshape_size, reshape_size), cmap='seismic', aspect='auto', norm=CenteredNorm(vcenter=0))
        ax.set_title(f'Dim {d}')
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])

        fig.colorbar(im, ax=ax)

    axes[k, 0].set_ylabel(f"$B_{k}$")

plt.tight_layout()

plt.savefig(f"{path}B2.png")
plt.show()

In [None]:
p_ids

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(11, 3))

# Flatten axes array for easy iteration
axes = axes.flatten()

for i in range(3):
    ax = axes[i] 
    ax.matshow(all_patterns[int(p_ids[i] - 3)], cmap="binary")
    ax.set_title(f"Pattern {int(p_ids[i])}")
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()

plt.savefig(f"{path}/p_id.png")

plt.show()

# Plot the Rs

Where in latent space is each discrete state (phase) active?

In [None]:
a = state_means[0] @ plds.transitions.Rs.T 
a.shape

In [None]:
b = plds.transitions.Rs @ state_means[0].T
b.shape

In [None]:
np.argmax(b)

In [None]:
np.exp(b[:, 25])

In [None]:
plds.transitions.log_transition_matrices(state_means[0], inputs[0], None, None).shape

In [None]:
plds.transitions.log_transition_matrices(state_means[0], inputs[0], None, None)[0]

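Each matrix holds log transition probabilities. Entry `[i, j]` is $\log p(z_{t+1} = j \mid z_t = i)$, so after `np.exp` every row sums to 1.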

In [None]:
np.exp(plds.transitions.log_transition_matrices(state_means[0], inputs[0], None, None)[0])

In [None]:
np.argmanz

In [None]:
def plot_most_likely_dynamics(model,
    xlim=(-4, 4), ylim=(-3, 3), nxpts=30, nypts=30,
    alpha=0.8, ax=None, figsize=(3, 3), dims=(0, 1), state_colors=None):
    """Quiver-plot each discrete state's one-step flow where that state is most likely.

    A grid is laid over the latent plane spanned by ``dims``, with all other
    latent dims held at 0. Each grid point gets the most likely next discrete
    state under the transition model, and that state's displacement
    ``A_k x + b_k - x`` is drawn there. Inputs are held at zero.

    Args:
        model: fitted ssm SLDS.
        xlim, ylim: grid extent along dims[0] / dims[1].
        nxpts, nypts: grid resolution.
        alpha: arrow transparency.
        ax: axes to draw on; a new figure is created when None.
        figsize: size of that new figure.
        dims: the two latent dimensions spanning the plotted plane.
        state_colors: colors indexed by discrete state (cycled). Defaults to
            the matplotlib property cycle.

    Returns:
        The axes that were drawn on.

    Raises:
        ValueError: if ``dims`` does not name exactly two dimensions.
    """
    if len(dims) != 2:
        raise ValueError("dims must name exactly two latent dimensions")
    d0, d1 = dims
    if state_colors is None:
        state_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    x = np.linspace(*xlim, nxpts)
    y = np.linspace(*ylim, nypts)
    X, Y = np.meshgrid(x, y)
    n_pts = nxpts * nypts

    # full-dimensional latent points on the plotted plane (other dims = 0)
    pts = np.zeros((n_pts, model.D))
    pts[:, d0] = X.ravel()
    pts[:, d1] = Y.ravel()

    # Zero inputs must still have the model's input width M, otherwise the
    # input term of the recurrent transitions cannot be evaluated.
    n_inputs = getattr(model, "M", 0)
    log_Ps = model.transitions.log_transition_matrices(
        pts, np.zeros((n_pts, n_inputs)), np.ones_like(pts, dtype=bool), None)

    # Use the row for previous state 0. With transitions="recurrent" (not
    # "recurrent_only"), the rows for other previous states can differ.
    z = np.argmax(log_Ps[:, 0, :], axis=-1)
    # ssm builds matrix t from data[t] (it uses data[:-1]), so pad at the end to
    # keep z aligned with pts — NOTE(review): verify against the installed ssm.
    z = np.concatenate([z, [z[-1]]])

    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

    for k, (A, b) in enumerate(zip(model.dynamics.As, model.dynamics.bs)):
        # one-step displacement x_{t+1} - x_t, projected onto the plotted plane
        dxydt_m = (pts @ A.T + b - pts)[:, [d0, d1]]

        zk = z == k
        if zk.any():
            ax.quiver(pts[zk, d0], pts[zk, d1],
                      dxydt_m[zk, 0], dxydt_m[zk, 1],
                      color=state_colors[k % len(state_colors)], alpha=alpha)

    ax.set_xlabel(f'$x_{{{d0 + 1}}}$')
    ax.set_ylabel(f'$x_{{{d1 + 1}}}$')

    plt.tight_layout()

    return ax