In [None]:
# Append a local checkout directory to sys.path so that packages inside it can be imported
# (presumably the ssm_ptc / project_ssms packages imported in the next cell — confirm).
# NOTE(review): this is an absolute, machine-specific path; it will not resolve elsewhere.
import sys
sys.path.append('/Users/leah/Columbia/courses/19summer/SocialBehavior/SocialBehaviorptc')

In [None]:
from ssm_ptc.models.hmm import HMM
from ssm_ptc.distributions.truncatednormal import TruncatedNormal
from ssm_ptc.utils import find_permutation, random_rotation, k_step_prediction

from project_ssms.ar_truncated_normal_observation import ARTruncatedNormalObservation
from project_ssms.coupled_transformations.grid_transformation import GridTransformation
from project_ssms.feature_funcs import feature_vec_func
from project_ssms.momentum_utils import filter_traj_by_speed, get_momentum_in_batch
from project_ssms.utils import k_step_prediction_for_grid_model
from project_ssms.plot_utils import plot_z, plot_2_mice, plot_4_traces

import torch
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt

import seaborn as sns
sns.set_style("white")
sns.set_context("talk")

#from tqdm import trange
from tqdm import tqdm_notebook as tqdm

import time

import joblib

from hips.plotting.colormaps import gradient_cmap, white_to_color_cmap
color_names = [
    "windows blue",
    "red",
    "amber",
    "faded green",
    "dusty purple",
    "orange"
    ]

colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)

# data

In [None]:
# Load the pre-processed sessions; per the original note this is a list of length 30,
# each element being a social_dataset.
# NOTE(review): joblib.load unpickles its input — only use it on trusted files.
datasets_processed = joblib.load('/Users/leah/Columbia/courses/19summer/SocialBehavior/tracedata/all_data_3_1')

# For every session, render trajectories [3, 8] and join the returned arrays along axis 1
# (the meaning of the indices 3 and 8 is not visible here — presumably tracked body parts
# or animals; confirm against render_trajectories).
trajs = [
    np.concatenate(session.render_trajectories([3, 8]), axis=1)
    for session in datasets_processed
]

# Stack all sessions on top of each other along axis 0.
trajs_all = np.concatenate(trajs, axis=0)

In [None]:
trajs_all.shape

In [None]:
1080000/30

In [None]:
# Keep only the first 3 * 36000 rows of trajs_all (36000 is the result of 1080000/30
# computed in the previous cell; presumably the number of frames per session — confirm).
# Note that this rebinds `trajs`, which until now was the per-session list.
trajs = trajs_all[36000*0:36000*3]

In [None]:
np.min(trajs, axis=0)

In [None]:
np.max(trajs, axis=0)

In [None]:
arena_xmin = 0
arena_xmax = 330
arena_ymin = -10
arena_ymax = 390

In [None]:
# make 3 by 3 grid world
x_grid_gap = (arena_xmax - arena_xmin) / 3
y_grid_gap = (arena_ymax - arena_ymin) / 3

In [None]:
# Four evenly spaced edges per axis (start + 0..3 gaps), i.e. three intervals per axis.
x_grids = [arena_xmin + i * x_grid_gap for i in range(4)]
y_grids = [arena_ymin + i * y_grid_gap for i in range(4)]

In [None]:
x_grids

In [None]:
y_grids

In [None]:
def _rows_per_grid_cell(xs, ys):
    """Split the rows of `trajs` by which of the 3 x 3 grid cells (xs, ys) falls in.

    Returns a list of 9 arrays ordered with the x index outermost (list index = 3*i + j).
    Cell (i, j) is the box x_grids[i] < x <= x_grids[i+1], y_grids[j] < y <= y_grids[j+1],
    so a point lying exactly on the lowest x or y edge belongs to no cell.
    """
    cells = []
    for i in range(3):
        in_x = (x_grids[i] < xs) & (xs <= x_grids[i+1])
        for j in range(3):
            in_y = (y_grids[j] < ys) & (ys <= y_grids[j+1])
            cells.append(trajs[in_x & in_y])
    return cells


# Rows binned by the position in columns 0/1 ("a") and by the position in columns 2/3 ("b").
data_grids_a = _rows_per_grid_cell(trajs[:, 0], trajs[:, 1])
data_grids_b = _rows_per_grid_cell(trajs[:, 2], trajs[:, 3])

In [None]:
[data_grid.shape[0]/trajs.shape[0] for data_grid in data_grids_a]

In [None]:
[data_grid.shape[0]/trajs.shape[0] for data_grid in data_grids_b]

In [None]:
[data_grid.shape[0] for data_grid in data_grids_a]

In [None]:
[data_grid.shape[0] for data_grid in data_grids_b]

# model

In [None]:
data = torch.tensor(trajs, dtype=torch.float64)

In [None]:
# Seed both random number generators before any model object is constructed.
np.random.seed(0)
torch.manual_seed(0)

# Sizes handed to the constructors below (K, D, M are presumably the number of discrete
# states, the observation dimension and the input dimension — confirm against HMM).
K, D, M = 2, 4, 0
Df = 5

# Momentum settings: 30 lags and a float64 weight vector built from arange(0.55, 2.05, 0.05).
momentum_lags = 30
momentum_weights = torch.tensor(np.arange(0.55, 2.05, 0.05), dtype=torch.float64)

# One [min, max] row per observed coordinate: the arena's x and y limits, listed twice
# (once for each (x, y) pair in the 4-dimensional observation).
bounds = np.array([[arena_xmin, arena_xmax], [arena_ymin, arena_ymax]] * 2)

# transformation
tran = GridTransformation(
    K=K, D=D, x_grids=x_grids, y_grids=y_grids,
    Df=Df, feature_vec_func=feature_vec_func,
    lags=momentum_lags, momentum_weights=momentum_weights,
)

# observation
obs = ARTruncatedNormalObservation(
    K=K, D=D, M=M, lags=momentum_lags, bounds=bounds, transformation=tran,
)

# model
model = HMM(K=K, D=D, M=M, observation=obs)

In [None]:
model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64)

In [None]:
# compute memories
# Everything in this cell is derived from all frames except the last one.
history = data[:-1]
pos_a = history[:, 0:2]   # columns 0-1
pos_b = history[:, 2:4]   # columns 2-3

masks_a, masks_b = tran.get_masks(history)

momentum_vecs_a = get_momentum_in_batch(pos_a, lags=momentum_lags, weights=momentum_weights)
momentum_vecs_b = get_momentum_in_batch(pos_b, lags=momentum_lags, weights=momentum_weights)

# The feature function takes (own position, other position), so the arguments swap for b.
feature_vecs_a = feature_vec_func(pos_a, pos_b)
feature_vecs_b = feature_vec_func(pos_b, pos_a)

# Bundles passed to the model as memory_kwargs_a / memory_kwargs_b.
m_kwargs_a = {"momentum_vecs": momentum_vecs_a, "feature_vecs": feature_vecs_a}
m_kwargs_b = {"momentum_vecs": momentum_vecs_b, "feature_vecs": feature_vecs_b}


In [None]:
model.log_likelihood(data, masks=(masks_a, masks_b),
                                  memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

# training

In [None]:
##################### training ############################

# First training run: 2000 iterations with lr=0.005.
# fit() returns two values, kept here as `losses` (presumably the per-iteration loss
# trace — confirm) and `opt` (the optimizer, handed back to fit() in the second run).
num_iters = 2000
losses, opt = model.fit(data, num_iters=num_iters, lr=0.005, masks=(masks_a, masks_b),
                                  memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

In [None]:
# Keep a handle on the model parameters as they stand at this point.
# NOTE(review): if model.params returns the live parameter tensors rather than copies,
# any later training will change these values too — confirm, and clone if a real
# snapshot is wanted.
params1 = model.params

In [None]:
##################### training ############################

# Second training run: 1000 more iterations, passing in the optimizer `opt` returned by
# the first fit() call. No lr is given here, and the second return value is discarded.
# Note that `num_iters` is rebound from 2000 to 1000.
num_iters = 1000
losses_1, _ = model.fit(data, optimizer=opt, num_iters=num_iters, masks=(masks_a, masks_b),
                                  memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

In [None]:
params2 = model.params

In [None]:
plt.plot(losses[1200:])

In [None]:
plt.plot(losses_1)

In [None]:
# inference: ask the fitted model for the most likely state sequence for `data`,
# using the same masks and memory kwargs that were used for training.
print("inferring most likely states...")
z = model.most_likely_states(data, masks=(masks_a, masks_b),
                             memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)


In [None]:
plot_z(z, ylim=[0, 380])

plt.plot(data[:,0].numpy(), color='white')

In [None]:
print("0 step prediction")
x_predict = k_step_prediction_for_grid_model(model, z, data, memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

In [None]:
plt.figure(figsize=(20,2))
plt.plot(x_predict[:3600,0], label='prediction')
plt.plot(data[:3600,0].numpy(), label='truth')
plt.legend();

In [None]:
np.average(abs(x_predict - data.numpy()), axis=0)

In [None]:
# sampling
print("sampling")
sample_z, sample_x = model.sample(50000)

In [None]:
plot_2_mice(sample_x, 0.5)

In [None]:
plot_2_mice(data[:50000].numpy(), 0.5)

In [None]:
# Histograms (100 bins each) of the four columns of the sampled trajectory, side by side.
plt.figure(figsize=(20, 4))

for col, name in enumerate(["x1", "y1", "x2", "y2"]):
    plt.subplot(1, 4, col + 1)
    plt.hist(sample_x[:, col], bins=100)
    plt.title(name)

plt.tight_layout()

In [None]:
# Histograms (100 bins each) of the four columns of the real trajectories, side by side.
# Panel titles match the sampled-data figure so the two can be compared directly
# (the last panel was previously mislabelled "yw").
plt.figure(figsize=(20, 4))

for col, name in enumerate(["x1", "y1", "x2", "y2"]):
    plt.subplot(1, 4, col + 1)
    plt.hist(trajs[:, col], bins=100)
    plt.title(name)

plt.tight_layout()

# dynamics

In [None]:
torch.sigmoid(tran.transformations_a[0].Ws)

In [None]:
torch.sigmoid(tran.transformations_a[1].Ws)

In [None]:
torch.sigmoid(tran.transformations_a[2].Ws)

In [None]:
torch.sigmoid(tran.transformations_a[3].Ws)

In [None]:
torch.sigmoid(tran.transformations_a[4].Ws)

In [None]:
torch.sigmoid(tran.transformations_a[5].Ws)

In [None]:
torch.sigmoid(tran.transformations_a[6].Ws)

In [None]:
torch.sigmoid(tran.transformations_a[7].Ws)

In [None]:
torch.sigmoid(tran.transformations_a[8].Ws[0])

In [None]:
# Stack 2 * sigmoid(Ws) for every transformation in tran.transformations_a into one array.
# Since sigmoid lies in (0, 1), every entry lies in (0, 2).
weights_a = 2 * np.array([torch.sigmoid(t.Ws).detach().numpy() for t in tran.transformations_a])

In [None]:
weights_b = 2 * np.array([torch.sigmoid(t.Ws).detach().numpy() for t in tran.transformations_b])

In [None]:
weights_a.shape

In [None]:
def plot_weights(weights):
    """Draw grouped bar charts of the weights for each of the 9 grid cells.

    Parameters
    ----------
    weights : indexable
        ``weights[g][k]`` must hold the 6 bar heights for grid cell ``g`` (0-8) and
        state ``k``. The number of states is taken from ``len(weights[g])``, so the
        function no longer depends on a global ``K``.

    Notes
    -----
    A new 16 x 12 figure is created with a 3 x 3 panel layout. Reading the panels
    row by row they show grids 2, 5, 8 / 1, 4, 7 / 0, 3, 6. Within a panel the bars
    of the different states are placed side by side around each integer x position
    (for two states: blue at -0.2 and red at +0.2, width 0.4). Only the first panel
    carries the feature names on the x axis and a legend; the y range is fixed to
    [0, 2].
    """
    # Subplot position (row-major, 1-based) -> grid index shown in that panel.
    grid_layout = [2, 5, 8, 1, 4, 7, 0, 3, 6]
    feature_names = ["m", "other", "water", "nest", "food", "corner"]
    n_features = len(feature_names)
    positions = np.arange(n_features)

    plt.figure(figsize=(16, 12))

    for panel, grid in enumerate(grid_layout, start=1):
        plt.subplot(3, 3, panel)
        plt.title("Grid {}".format(grid))

        grid_weights = weights[grid]
        n_states = len(grid_weights)
        # Share a total width of 0.8 between the states and centre the group on the tick.
        width = 0.8 / n_states
        for k in range(n_states):
            offset = (k - (n_states - 1) / 2.0) * width
            # Blue / red for the first two states, matplotlib's colour cycle beyond that.
            color = ('b', 'r')[k] if k < 2 else 'C{}'.format(k)
            plt.bar(positions + offset, grid_weights[k], width=width,
                    color=color, label='k={}'.format(k))

        plt.plot([0, n_features], [0, 0], '-k')
        plt.ylim(0, 2)
        if panel == 1:
            plt.xticks(positions, feature_names)
        else:
            plt.xticks(positions)
        plt.grid()
        if panel == 1:
            plt.legend()

    plt.tight_layout()

In [None]:
plot_weights(weights_a)

In [None]:
plot_weights(weights_b)

In [None]:
data.shape

In [None]:
plt.figure(figsize=(20,2))
plt.plot(trajs[:,2])

In [None]:
plot_2_mice(trajs, alpha=0.5)
plt.scatter([50, 270, 50, 270], [50, 50, 330, 330])


In [None]:
trajs.shape

In [None]:
# show occupations of K
# check consistency between data and samples

In [None]:
torch.sigmoid(tran.transformations_a[3].Ws[1]).detach().numpy()

In [None]:
# Prints 0, 1, 2, 3.
# NOTE(review): looks like a leftover scratch cell that does not use any of the
# analysis objects — candidate for removal.
for i in range(4):
    print(i)