In [None]:
!pip3 uninstall numpy -y
!pip3 install numpy
# Numerical / data libraries used throughout the notebook.
import autograd
import numpy as np
import pandas as pd
import autograd.numpy.random as npr
# Seed the global RNG exposed through autograd.numpy.random so stochastic steps
# are repeatable across runs (presumably also the RNG ssm draws from — confirm).
npr.seed(0)
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import seaborn as sns

# Global seaborn styling: white background, "talk"-sized context.
sns.set_style("white")
sns.set_context("talk")

# xkcd colour names; entry k is the colour used for discrete state k.
color_names = [
    "windows blue", "red", "amber", "faded green",
    "dusty purple", "orange", "clay", "pink",
    "greyish", "mint", "cyan", "steel blue",
    "forest green", "pastel purple", "salmon", "dark brown",
]

# Resolve the names to RGB tuples, then wrap them in a discrete colormap.
colors = sns.xkcd_palette(color_names)
cmap = ListedColormap(colors)

# State-space modelling library; provides the ssm.SLDS model fit below.
import ssm
# NOTE(review): random_rotation, find_permutation and plot_dynamics_2d are not
# referenced in the cells shown here — remove if unused elsewhere.
from ssm.util import random_rotation, find_permutation
from ssm.plots import plot_dynamics_2d

# Flag for saving figures to disk; not read by any cell shown here.
save_figures = False

First, we load the data to see what we are dealing with, and then we plot a short animation.

In [None]:
# Load the pose-tracking output.
# NOTE(review): the file name suggests a SLEAP analysis export, i.e. 'tracks'
# laid out as (track, coordinate, node, frame) — confirm the axis order.
import h5py

name = 'mouse_second_video.mp4.predictions.analysis.h5'
with h5py.File(name, 'r') as h5_file:
    occupancy_matrix = h5_file['track_occupancy'][:]
    # Read every track, then keep only the first one (index 0 on axis 0).
    tracks_matrix = h5_file['tracks'][:][0]

for loaded in (occupancy_matrix, tracks_matrix):
    print(loaded.shape)

In [None]:
# Scatter every tracked node at a single reference frame so each node index
# can be matched to a body part by eye.
reference_frame = 650
for node in range(7):
    node_x = tracks_matrix[0, node, reference_frame]
    node_y = tracks_matrix[1, node, reference_frame]
    plt.scatter(node_x, node_y, label=node)


plt.title("Plot to know what label is which")
plt.legend()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from moviepy.editor import VideoClip
from moviepy.video.io.bindings import mplfig_to_npimage

# Demo animation of a travelling sinc-plus-sine curve; it does not touch the
# tracking data, but VideoClip / mplfig_to_npimage imported above are reused
# by later cells.
duration = 10
x = np.linspace(-2, 2, 200)
fig, ax = plt.subplots()

def make_frame(t):
    """Redraw the curve at time ``t`` and return the figure as an RGB array."""
    phase = 2*np.pi/duration * t
    ax.clear()
    ax.plot(x, np.sinc(x ** 2) + np.sin(x + phase), lw=3)
    ax.set_ylim(-1.5, 2.5)
    return mplfig_to_npimage(fig)

animation = VideoClip(make_frame, duration=duration)
animation.ipython_display(fps=20, loop=True, autoplay=True)



In [None]:
# Overlay the raw (uncleaned) x and y trajectories of every node over time.
for node_idx in range(7):
    x_series = tracks_matrix[0, node_idx, :]
    y_series = tracks_matrix[1, node_idx, :]
    plt.plot(x_series, label="x_{}".format(node_idx))
    plt.plot(y_series, label="y_{}".format(node_idx))

plt.legend()




In [None]:
# Clean the raw tracks:
#   1. fill missing (NaN) keypoints by linear interpolation along time,
#   2. re-express every node relative to the head (node 0),
#   3. compute each non-head node's Euclidean distance to the head,
#   4. drop the first `start` frames, whose tracking is particularly bad.
start = 30
n_nodes = tracks_matrix.shape[1]
n_total_frames = tracks_matrix.shape[2]

# Interpolate every (coordinate, node) series independently: each column of
# this (frame x series) table is one series. limit_direction="both" also fills
# NaNs before a node's first detection (the pandas default only fills forward),
# so a node first seen after frame `start` does not leak NaNs into the model.
series_by_column = tracks_matrix[:2].reshape(2 * n_nodes, n_total_frames).T
filled = (pd.DataFrame(series_by_column)
          .interpolate(limit_direction="both")
          .to_numpy(dtype=float))
clean = filled.T.reshape(2, n_nodes, n_total_frames)

# Head trajectory before centring, shape (2, n_total_frames).
head_position = clean[:, 0].copy()
# Centre on the head by broadcasting its position over the node axis.
clean = clean - head_position[:, np.newaxis, :]

# Distance to the head for nodes 1..n_nodes-1; row k belongs to node k + 1.
distance_clean = np.hypot(clean[0, 1:], clean[1, 1:])

clean = clean[:, :, start:]
distance_clean = distance_clean[:, start:]

# Only a node with no valid detection at all can still be NaN here.
assert not np.isnan(clean).any(), "some node has no valid detections at all"

In [None]:
# One panel per (node, coordinate) pair of the cleaned, head-centred tracks.
# The node count is read from `clean` rather than hard-coded.
n_nodes = clean.shape[1]
fig, axs = plt.subplots(2 * n_nodes, figsize=(20, 30), sharex=True)

for node in range(n_nodes):
    for coord, coord_name in enumerate(("x", "y")):
        ax = axs[2 * node + coord]
        ax.plot(clean[coord, node, :], label="{}_{}".format(coord_name, node))
        ax.legend()

# x = 0 is recording frame `start`; earlier frames were dropped during cleaning.
axs[-1].set_xlabel("frame (offset by start)")
plt.show()

In [None]:
# One panel per non-head node, showing its Euclidean distance to the head.
# Row k of `distance_clean` holds node k + 1 (node 0, the head, is the origin),
# so each panel is labelled as a distance with the true node index.
n_distances = distance_clean.shape[0]
fig, axs = plt.subplots(n_distances, figsize=(20, 30), sharex=True, squeeze=False)

for row in range(n_distances):
    ax = axs[row, 0]
    ax.plot(distance_clean[row, :], label="dist_{}".format(row + 1))
    ax.legend()

axs[-1, 0].set_xlabel("frame (offset by start)")
plt.show()

In [None]:
# Flatten the (coordinate, node) axes of `clean` into one feature axis so that
# each row is one frame: (coordinate, node, frame) -> (frame, coordinate*node).
n_frames = clean.shape[-1]
emissions = clean.reshape(-1, n_frames).T
emissions.shape

In [None]:
# Distance features: `distance_clean` is 2-D (node, frame), so its transpose is
# already one row per frame. The frame count is no longer borrowed from `clean`,
# which only matched because both arrays were trimmed identically.
# NOTE: this overwrites the `emissions` built in the previous cell.
emissions = distance_clean.T
emissions.shape

In [None]:
# Model 1: SLDS fit to the head-centred keypoint *positions*.
# The emissions are rebuilt from `clean` here instead of reading the global
# `emissions`: the preceding cell overwrites `emissions` with the distance
# features. Under Restart & Run All this model would otherwise be fit to
# distances, and the "position_*" clips exported from `q_lem_z` would be
# mislabelled.
position_emissions = clean.reshape(-1, clean.shape[-1]).T  # one row per frame

emissions_dim = position_emissions.shape[-1]
n_disc_states = 6      # number of discrete states
latent_dim = 5         # dimensionality of the continuous latent state
emissions_func = 'gaussian_orthog'
slds = ssm.SLDS(emissions_dim, n_disc_states, latent_dim, emissions=emissions_func)
# Fit the model using Laplace-EM with a structured variational posterior
q_lem_elbos, q_lem = slds.fit(position_emissions, method="laplace_em",
                              variational_posterior="structured_meanfield",
                              num_iters=100, alpha=0.0)

# Posterior mean of the continuous states; [0] presumably selects the single
# sequence passed to fit
q_lem_x = q_lem.mean_continuous_states[0]

# Most likely discrete state for every frame (there are no ground-truth states,
# so no permutation matching is done here)
q_lem_z = slds.most_likely_states(q_lem_x, position_emissions)

# Smooth the data under the variational posterior
q_lem_y = slds.smooth(q_lem_x, position_emissions)

In [None]:
# Print the first model's discrete state for every frame, labelled with the
# frame number in the original recording: the first `start` frames were
# dropped during cleaning, so sequence index 0 is recording frame `start`.
for frame_number, z_t in enumerate(q_lem_z, start=start):
    print(frame_number, z_t)

In [None]:
# ELBO trace of the first SLDS fit. The first two iterations are left out,
# presumably so their much lower values do not squash the y-axis — confirm.
fig, ax = plt.subplots()
ax.plot(np.arange(2, len(q_lem_elbos)), q_lem_elbos[2:])
ax.set(title="Laplace-EM ELBO (first SLDS fit)", xlabel="iteration", ylabel="ELBO")
plt.show()


In [None]:
# Number of frames the first model assigns to each discrete state. This is a
# summary rather than the raw sequence (one entry per frame), which is already
# printed frame by frame two cells above.
pd.Series(np.bincount(q_lem_z, minlength=n_disc_states), name="n_frames").rename_axis("state")

In [None]:
# Model 2: SLDS fit to the head-to-node *distances*.
# The emissions are taken directly from `distance_clean`, so this fit does not
# depend on which cell last assigned the global `emissions`.
distance_emissions = distance_clean.T  # one row per frame

emissions_dim = distance_emissions.shape[-1]
n_disc_states = 6      # number of discrete states
latent_dim = 5         # dimensionality of the continuous latent state
emissions_func = 'gaussian_orthog'
slds = ssm.SLDS(emissions_dim, n_disc_states, latent_dim, emissions=emissions_func)
# Fit the model using Laplace-EM with a structured variational posterior
q_lem_elbos2, q_lem2 = slds.fit(distance_emissions, method="laplace_em",
                                variational_posterior="structured_meanfield",
                                num_iters=100, alpha=0.0)

# Posterior mean of the continuous states; [0] presumably selects the single
# sequence passed to fit
q_lem_x2 = q_lem2.mean_continuous_states[0]

# Most likely discrete state for every frame (there are no ground-truth states,
# so no permutation matching is done here)
q_lem_z2 = slds.most_likely_states(q_lem_x2, distance_emissions)

# Smooth the data under the variational posterior
q_lem_y2 = slds.smooth(q_lem_x2, distance_emissions)

In [None]:
# Render a state-sequence video in which the figure background takes the colour
# of the current discrete state (state k -> color_names[k]).
# Clip length and colour lookup both use `states_to_show`. Set it to `q_lem_z2`
# to view the second model instead.
state_video_fps = 15
states_to_show = q_lem_z
duration = len(states_to_show) / state_video_fps

fig, ax = plt.subplots()
def make_state_frame(t):
    """Return the figure as an RGB array, coloured by the state at time ``t``."""
    # t is a float time, so int(t * fps) can truncate k - 1e-15 down to k - 1
    # and repeat/skip frames. Round to the nearest frame and clamp to the last.
    frame_idx = min(int(round(t * state_video_fps)), len(states_to_show) - 1)
    ax.clear()
    ax.set_facecolor("xkcd:" + color_names[states_to_show[frame_idx]])
    return mplfig_to_npimage(fig)

animation = VideoClip(make_state_frame, duration=duration)
animation.ipython_display(fps=state_video_fps, maxduration=1000, loop=True, autoplay=True)

In [None]:
import moviepy.editor as mpe

# Export one clip per discrete state of the second (distance) model, q_lem_z2,
# built from every source-video frame assigned to that state.
VIDEO_PATH = '/home/asifmallik/Downloads/concat_mouse_2.mp4'  # TODO: avoid absolute local paths
CLIP_FPS = 15
# Sequence index i corresponds to tracking frame i + start, because the first
# `start` frames were dropped during cleaning.
# NOTE(review): assumes VIDEO_PATH is frame-aligned with the tracked video and
# that tracking frame k is video frame k (time k / video.fps); set
# frame_offset = 0 if this video was already trimmed.
frame_offset = start

video = mpe.VideoFileClip(VIDEO_PATH)
try:
    # np.unique visits only states that actually occur, so no empty clip is built.
    for state in np.unique(q_lem_z2):
        # Computed once per state: rebuilding this mask inside the frame
        # callback made the export quadratic in the sequence length.
        state_frames = np.flatnonzero(q_lem_z2 == state)
        duration = len(state_frames) / CLIP_FPS

        def make_frame_mouse(t):
            """Return the source frame shown at time ``t`` of this state's clip."""
            k = min(int(round(t * CLIP_FPS)), len(state_frames) - 1)
            return video.get_frame((state_frames[k] + frame_offset) / video.fps)

        animation = VideoClip(make_frame_mouse, duration=duration)
        animation.write_videofile("distance_" + str(state) + ".mp4", fps=CLIP_FPS)
finally:
    # Release the resources held by the source VideoFileClip.
    video.close()

In [None]:
import moviepy.editor as mpe

# Export one clip per discrete state of the first (position) model, q_lem_z,
# built from every source-video frame assigned to that state. The state mask
# and the frame indices now both come from q_lem_z.
VIDEO_PATH = '/home/asifmallik/Downloads/concat_mouse_2.mp4'  # TODO: avoid absolute local paths
CLIP_FPS = 15
# Sequence index i corresponds to tracking frame i + start, because the first
# `start` frames were dropped during cleaning.
# NOTE(review): assumes VIDEO_PATH is frame-aligned with the tracked video and
# that tracking frame k is video frame k (time k / video.fps); set
# frame_offset = 0 if this video was already trimmed.
frame_offset = start

video = mpe.VideoFileClip(VIDEO_PATH)
try:
    # np.unique visits only states that actually occur, so no empty clip is built.
    for state in np.unique(q_lem_z):
        # Computed once per state: rebuilding this mask inside the frame
        # callback made the export quadratic in the sequence length.
        state_frames = np.flatnonzero(q_lem_z == state)
        duration = len(state_frames) / CLIP_FPS

        def make_frame_mouse(t):
            """Return the source frame shown at time ``t`` of this state's clip."""
            k = min(int(round(t * CLIP_FPS)), len(state_frames) - 1)
            return video.get_frame((state_frames[k] + frame_offset) / video.fps)

        animation = VideoClip(make_frame_mouse, duration=duration)
        animation.write_videofile("position_" + str(state) + ".mp4", fps=CLIP_FPS)
finally:
    # Release the resources held by the source VideoFileClip.
    video.close()