In [1]:
import os
import sys

# Presumably needed so the project packages imported in the next cell
# (ssm_ptc, project_ssms) can be found. Set SOCIAL_BEHAVIOR_PTC_ROOT to point at
# your own checkout instead of editing this machine-specific absolute path;
# the default is the original location.
PTC_ROOT = os.environ.get(
    "SOCIAL_BEHAVIOR_PTC_ROOT",
    "/Users/leah/Columbia/courses/19summer/SocialBehavior/SocialBehaviorptc",
)
# Append only once, so re-running this cell does not keep growing sys.path.
if PTC_ROOT not in sys.path:
    sys.path.append(PTC_ROOT)

In [2]:
from ssm_ptc.models.hmm import HMM
from ssm_ptc.distributions.truncatednormal import TruncatedNormal
from ssm_ptc.utils import find_permutation, random_rotation, k_step_prediction

from project_ssms.direction_observation import DirectionObservation, DirectionTransformation
from project_ssms.feature_funcs import feature_vec_func
from project_ssms.momentum_utils import filter_traj_by_speed
from project_ssms.utils import k_step_prediction_for_direction_model
from project_ssms.plot_utils import plot_z, plot_2_mice, plot_4_traces

import torch
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt

import seaborn as sns
sns.set_style("white")
sns.set_context("talk")

#from tqdm import trange
from tqdm import tqdm_notebook as tqdm

import time

from hips.plotting.colormaps import gradient_cmap, white_to_color_cmap
# Six xkcd colors (presumably one per hidden state; K = 6 below). `colors` is the
# seaborn palette and `cmap` is built from it with hips' gradient_cmap.
color_names = ["windows blue", "red", "amber", "faded green", "dusty purple", "orange"]
colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)

ModuleNotFoundError: No module named 'project_ssms.direction_observation'

# data

In [None]:
import os

import joblib

# Location of the preprocessed tracking data. Override with SOCIAL_BEHAVIOR_DATA
# rather than editing this machine-specific absolute path (default unchanged).
DATA_PATH = os.environ.get(
    "SOCIAL_BEHAVIOR_DATA",
    "/Users/leah/Columbia/courses/19summer/SocialBehavior/tracedata/all_data_3_1",
)

# A list of length 30, each element a social_dataset (presumably one per session).
# NOTE(review): joblib.load unpickles arbitrary objects -- only load trusted files.
datasets_processed = joblib.load(DATA_PATH)

# Trajectories of the first dataset only. render_trajectories returns a sequence
# of arrays that are joined column-wise (presumably into one (T, 4) array).
# NOTE(review): the meaning of the [3, 8] argument is not documented here -- confirm.
session_data = datasets_processed[0].render_trajectories([3, 8])
traj0 = np.concatenate(session_data, axis=1)

In [None]:
# Release the raw dataset objects; only traj0 is used from here on.
# globals().pop(..., None) instead of `del` keeps this cell idempotent:
# re-running it no longer raises NameError once the names are gone.
globals().pop("datasets_processed", None)
globals().pop("session_data", None)

In [None]:
# Filter the raw trajectories by speed before modeling.
# NOTE(review): q1 / q2 look like speed-quantile thresholds (possibly one per
# mouse) -- confirm against project_ssms.momentum_utils.filter_traj_by_speed.
speed_quantile = 0.99
f_traj = filter_traj_by_speed(traj0, q1=speed_quantile, q2=speed_quantile)

In [None]:
# Arena extent, in the same units as the trajectories. Used as plot limits and,
# padded, as the observation bounds when the model is built.
arena_xmin, arena_xmax = 10, 320
arena_ymin, arena_ymax = -10, 390

# Landmark (x, y) locations as float64 tensors.
# NOTE(review): not referenced elsewhere in this notebook -- confirm before removing.
WATER, FOOD, NEST, CORNER = (
    torch.tensor(xy, dtype=torch.float64)
    for xy in ([50, 50], [270, 50], [270, 330], [50, 330])
)

In [None]:
# Observations for the HMM: a float64 copy of the filtered trajectories, one row
# per frame. The model below is built with D = 4 observation dimensions and
# 4 rows of bounds, so fail fast here if the columns do not match.
data = torch.tensor(f_traj, dtype=torch.float64)
assert data.dim() == 2 and data.shape[1] == 4, (
    f"expected (T, 4) trajectories, got {tuple(data.shape)}"
)

# model

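$$x^a_t \sim x^a_{t-1} + s \cdot \left[ \sigma(W^a_0)\, m_t + \sum_{i=1}^{D_f} \sigma(W^a_i)\, f_i \right]$$

$$x^b_t \sim x^b_{t-1} + s \cdot \left[ \sigma(W^b_0)\, m_t + \sum_{i=1}^{D_f} \sigma(W^b_i)\, f_i \right]$$

Here $a$ and $b$ index the two mice, $m_t$ is the momentum vector, $f_i$ is the $i$-th of the $D_f$ feature vectors, $\sigma$ is the logistic sigmoid, and $s$ is a scale factor.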

In [None]:
# Number of past steps aggregated into the momentum vector m_t.
momentum_lags = 30
# Presumably one weight per lag, evenly spaced from 0.55 to 2.0 (step 0.05 for
# 30 lags). linspace ties the length to momentum_lags explicitly; the previous
# np.arange(0.55, 2.05, 0.05) derived its length from floating-point division
# of a float step, which can silently be off by one.
momentum_weights = np.linspace(0.55, 2.0, momentum_lags)

In [None]:
# Seed the torch and numpy RNGs for reproducibility.
torch.manual_seed(0)
np.random.seed(0)

# Observation bounds: the arena box widened by 5 on each side, one [low, high]
# row per observed dimension (x, y, then x, y again for the second mouse).
bound_pad = 5
xy_rows = [[arena_xmin - bound_pad, arena_xmax + bound_pad],
           [arena_ymin - bound_pad, arena_ymax + bound_pad]]
bounds = np.array(xy_rows * 2)

# NOTE(review): max_v is not passed to the model or used elsewhere in this notebook.
max_v = np.array([5.0] * 4)

acc_factor = 2

K = 6        # number of hidden states
D = 4        # observation dimension
Df = 5       # number of feature vectors f_i in the model equation
T = 36000    # length of the sequence drawn by model.sample below

observation = DirectionObservation(
    K=K, D=D, M=0, bounds=bounds,
    momentum_lags=momentum_lags, momentum_weights=momentum_weights,
    Df=Df, feature_vec_func=feature_vec_func, acc_factor=acc_factor,
)

model = HMM(K=K, D=D, M=0, observation=observation)
# Shortcut to the transformation whose weights Ws are inspected at the end.
m_tran = model.observation.transformation

In [None]:
# Set mus_init for all K states to the first observed frame: every row of the
# (K, D) float64 tensor equals data[0].
# NOTE(review): presumably the mean of the initial-observation distribution -- confirm.
first_frame = data[0]
model.observation.mus_init = torch.ones(K, D, dtype=torch.float64) * first_frame

In [None]:
# Precompute momentum vectors and features from every frame except the last.
# Both are passed to model.fit and to the direction-model prediction helper below.
# NOTE(review): these are private (_-prefixed) DirectionTransformation methods.
preceding_frames = data[:-1]
momentum_vecs = DirectionTransformation._compute_momentum_vecs(
    preceding_frames, lags=momentum_lags, weights=momentum_weights
)
features = DirectionTransformation._compute_features(m_tran.feature_vec_func, preceding_frames)

In [None]:
# Sanity check: log-likelihood of the data under the model before any training.
initial_log_likelihood = model.log_likelihood(data)
initial_log_likelihood

# training

In [None]:
# Stage 1 of training: 2000 iterations of model.fit with learning rate 0.005,
# using the momentum vectors and features precomputed above.
# The returned optimizer `opt` is handed to the next cell to resume training.
num_iters = 2000
losses, opt = model.fit(
    data,
    num_iters=num_iters,
    lr=0.005,
    momentum_vecs=momentum_vecs,
    features=features,
)

In [None]:
# Stage 2: 1000 more iterations, resuming with the stage-1 optimizer `opt`
# (no lr is passed here, unlike stage 1).
num_iters = 1000
losses_1, _ = model.fit(
    data,
    optimizer=opt,
    num_iters=num_iters,
    momentum_vecs=momentum_vecs,
    features=features,
)

In [None]:
# Loss curves for both training stages. The original cell showed only stage 2,
# hiding the first 2000 iterations. Separate panels keep the stage-2 scale readable.
fig, (ax_stage1, ax_stage2) = plt.subplots(1, 2, figsize=(12, 4))
for ax, stage_losses, title in (
    (ax_stage1, losses, "stage 1"),
    (ax_stage2, losses_1, "stage 2 (resumed optimizer)"),
):
    ax.plot(stage_losses)
    ax.set(title=title, xlabel="iteration", ylabel="loss")
fig.tight_layout()

In [None]:
# Most likely hidden-state sequence for the observed data under the fitted model
# (presumably Viterbi decoding -- confirm in ssm_ptc.models.hmm).
z = model.most_likely_states(data)

In [None]:
# Predictions given the decoded states z, from the direction-model helper with the
# precomputed momentum vectors and features. They are compared frame-for-frame
# with `data` below, so presumably one-step-ahead.
x_predict = k_step_prediction_for_direction_model(
    model, z, data,
    momentum_vecs=momentum_vecs,
    features=features,
)

In [None]:
# Visualize the decoded state sequence with the project plotting helper.
plot_z(z)

In [None]:
# One-step prediction vs. ground truth for column 0 (the first mouse's x).
fig, ax = plt.subplots(figsize=(20, 2))
ax.plot(x_predict[:, 0], label="prediction")
ax.plot(data[:, 0].numpy(), label="truth")
ax.set(title="1-step prediction, column 0", xlabel="frame")
ax.legend();

In [None]:
# 5-step-ahead predictions from the generic ssm_ptc helper.
# NOTE(review): unlike the 1-step direction-model helper, this is called without
# momentum_vecs / features -- confirm it handles DirectionObservation correctly.
x_predict_5_step = k_step_prediction(model, z, data, k=5)

In [None]:
# 5-step-ahead prediction vs. ground truth for column 0 (the first mouse's x).
# The k-step predictions start `horizon` frames in, so truth is data[horizon:].
horizon = 5
fig, ax = plt.subplots(figsize=(20, 2))
ax.plot(x_predict_5_step[:, 0], label="prediction")
ax.plot(data[horizon:, 0].numpy(), label="truth")
ax.set(title=f"{horizon}-step prediction, column 0", xlabel=f"frame - {horizon}")
ax.legend();

In [None]:
# 10-step-ahead predictions from the generic ssm_ptc helper.
# NOTE(review): unlike the 1-step direction-model helper, this is called without
# momentum_vecs / features -- confirm it handles DirectionObservation correctly.
x_predict_10_step = k_step_prediction(model, z, data, k=10)

In [None]:
# 10-step-ahead prediction vs. ground truth for column 0 (the first mouse's x).
# The k-step predictions start `horizon` frames in, so truth is data[horizon:].
horizon = 10
fig, ax = plt.subplots(figsize=(20, 2))
ax.plot(x_predict_10_step[:, 0], label="prediction")
ax.plot(data[horizon:, 0].numpy(), label="truth")
ax.set(title=f"{horizon}-step prediction, column 0", xlabel=f"frame - {horizon}")
ax.legend();

In [None]:
# Mean absolute prediction error per coordinate, keyed by prediction horizon.
# The k-step arrays start k frames in, so they are compared against data[k:].
# (Replaces three copy-pasted cells; np.abs(...).mean(axis=0) equals the
# previous np.average(abs(...), axis=0).)
prediction_mae = {
    1: np.abs(x_predict - data.numpy()).mean(axis=0),
    5: np.abs(x_predict_5_step - data[5:].numpy()).mean(axis=0),
    10: np.abs(x_predict_10_step - data[10:].numpy()).mean(axis=0),
}
prediction_mae

# samples

In [None]:
# Draw a synthetic sequence of length T from the fitted model:
# hidden states sample_z and observations sample_x.
sample_z, sample_x = model.sample(T)

In [None]:
# Number of sampled time steps spent in each of the K states.
# `(mask).sum()` is vectorized (works for numpy arrays and torch tensors); the
# builtin sum() iterated over all T samples element by element in Python.
[int((sample_z == k).sum()) for k in range(K)]

In [None]:
# Plot the sampled observations of both mice with the project plotting helper.
plot_2_mice(sample_x)

In [None]:
# Sampled paths of each mouse inside the arena: columns (0, 1) are the "virgin"
# mouse's (x, y), columns (2, 3) the "mother" mouse's.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, (x_col, title) in zip(axes, [(0, "virgin"), (2, "mother")]):
    ax.plot(sample_x[:, x_col], sample_x[:, x_col + 1])
    ax.set_xlim(arena_xmin, arena_xmax)
    ax.set_ylim(arena_ymin, arena_ymax)
    ax.set_title(title)

fig.tight_layout()

In [None]:
# Marginal histogram of each sampled coordinate, one panel per column.
plt.figure(figsize=(20, 4))
for col, coord_name in enumerate(["x1", "y1", "x2", "y2"]):
    plt.subplot(1, 4, col + 1)
    plt.hist(sample_x[:, col], bins=100)
    plt.title(coord_name)

plt.tight_layout()

In [None]:
# Marginal histogram of each coordinate of the filtered data, for comparison with
# the sampled histograms (same panel layout and bin count).
plt.figure(figsize=(20, 4))
for col, coord_name in enumerate(["x1", "y1", "x2", "y2"]):
    plt.subplot(1, 4, col + 1)
    plt.hist(f_traj[:, col], bins=100)
    plt.title(coord_name)

plt.tight_layout()

In [None]:
# Sigmoid of Ws[k, 0] for every state k = 0..K-1 in one table (row k = state k),
# replacing six copy-pasted cells, one per state. Per the model equation this is
# presumably sigma(W_0), the weight on the momentum term m_t.
# detach() only drops the autograd history from the displayed tensor.
torch.sigmoid(m_tran.Ws[:K, 0]).detach()

The momentum weight $\sigma(W_0)$ does not vary much across the $K$ states.

In [None]:
# Noise scales learned by the observation model: exp of the stored log_sigmas.
model.observation.log_sigmas.exp()