# Apply the model to get state sequences
* This section requires GPU + CUDA
* use the top_bottom_moseq_37 env

In [2]:
import matplotlib.pyplot as plt
import jax, jax.numpy as jnp
import tqdm as tqdm
import numpy as np
import glob
import joblib
from os.path import join, exists

from keypoint_moseq.util import *
from keypoint_moseq.gibbs import *
from keypoint_moseq.initialize import *

In [7]:
# make a list of sessions to include
exclude_patterns = ['wavelet']
latents_paths = glob.glob('/n/groups/datta/Jonah/Thermistor_only_recordings/*/202210*/*latents.npy')
latents_paths = [lp for lp in latents_paths if not any([ep in lp for ep in exclude_patterns])]
overwrite = False
latents_paths

['/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221012_gmou81/20221012_gmou81.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221010_gmou81/20221010_gmou81.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221013_gmou81/20221013_gmou81.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221010_gmou78/20221010_gmou78.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221012_gmou78/20221012_gmou78.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221013_gmou78/20221013_gmou78.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221015_gmou83/20221015_gmou83.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221013_gmou83/20221013_gmou83.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221014_gmou83/20221014_gmou83.latents.npy',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/202210

In [8]:
# load model
# model_path = '/n/groups/datta/caleb/21_2_19_segmentation_redux/tb_jonah/moseq_model_only20_iters.p'
# model_path = '/n/groups/datta/Jonah/Thermistor_only_recordings/top_bottom_dataset3/moseq_model.p'  # kappa 1e6
model_path = '/n/groups/datta/Jonah/Thermistor_only_recordings/top_bottom_dataset3/moseq_model_kappa5e6.p'
saved_model = joblib.load(model_path)

In [9]:
data_dict = {p.split('.latents.npy')[0]:np.load(p).squeeze() for p in latents_paths}

In [13]:
# make sure to shut down any other notebooks using GPU resources before running this cell

key = jr.PRNGKey(0)
model_params = {k:jnp.array(v) for k,v in saved_model['params'].items()}
whitening_params = [jnp.array(v) for v in saved_model['whitening_params']]

for path in tqdm.tqdm(latents_paths):
    
    out_path = path.replace('.latents.npy','.stateseq.npy')
    if exists(out_path) and not overwrite:
        print(f'{out_path} exists, continuing...') 
        continue
    
    x = jnp.load(path).squeeze()[None]
    mask = jnp.ones(x.shape[:2])
    
    x = whiten_all(x,mask, params=whitening_params)[0]
    stateseq = resample_stateseqs(key, x=x, mask=mask, **model_params)[0]
    stateseq = np.array(stateseq).squeeze()
    
    np.save(out_path, stateseq)

  0%|                                                                                        | 0/15 [00:00<?, ?it/s]

/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221012_gmou81/20221012_gmou81.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221010_gmou81/20221010_gmou81.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221013_gmou81/20221013_gmou81.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221010_gmou78/20221010_gmou78.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221012_gmou78/20221012_gmou78.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221013_gmou78/20221013_gmou78.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221015_gmou83/20221015_gmou83.stateseq.npy exists, continuing...


 53%|██████████████████████████████████████████▋                                     | 8/15 [00:08<00:07,  1.02s/it]

/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221014_gmou83/20221014_gmou83.stateseq.npy exists, continuing...


100%|███████████████████████████████████████████████████████████████████████████████| 15/15 [00:14<00:00,  1.02it/s]

/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221010_gmou83/20221010_gmou83.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou77/20221013_gmou77/20221013_gmou77.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou77/20221015_gmou77/20221015_gmou77.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou77/20221010_gmou77/20221010_gmou77.stateseq.npy exists, continuing...
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou77/20221012_gmou77/20221012_gmou77.stateseq.npy exists, continuing...





# Get sorted usages and simple scalars
* This part can be run on its own without a GPU
* use the dataPy_NWB env

### Sorted usages

In [1]:
from copy import copy
from glob import glob
import numpy as np
from os.path import exists
from moseq_fo.util import timeseries_utils as tsu

In [2]:
def get_dict_map_np(my_dict):
    return np.vectorize(my_dict.get)

In [7]:
# Save states labeled by usage-sorted syllable

# use 77,78,81 data for sorting / ranking the syllables
prefixes_to_sort = ['/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221012_gmou81/20221012_gmou81',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221010_gmou81/20221010_gmou81',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221013_gmou81/20221013_gmou81',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221010_gmou78/20221010_gmou78',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221012_gmou78/20221012_gmou78',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221013_gmou78/20221013_gmou78',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou77/20221013_gmou77/20221013_gmou77',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou77/20221015_gmou77/20221015_gmou77',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou77/20221010_gmou77/20221010_gmou77',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou77/20221012_gmou77/20221012_gmou77']

# apply to 83 as well (and anyone else that this model is applied to)
prefixes_to_apply = glob('/n/groups/datta/Jonah/Thermistor_only_recordings/*/202210*/*.latents.npy')
prefixes_to_apply = [path.split('.')[0] for path in prefixes_to_apply]


overwrite = True
num_states = 100

prefixes_to_apply

['/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221012_gmou81/20221012_gmou81',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221010_gmou81/20221010_gmou81',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221013_gmou81/20221013_gmou81',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221010_gmou78/20221010_gmou78',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221012_gmou78/20221012_gmou78',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221013_gmou78/20221013_gmou78',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221015_gmou83/20221015_gmou83',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221013_gmou83/20221013_gmou83',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221014_gmou83/20221014_gmou83',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221011_gmou83/20221011_gmou83',
 '/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221010_gmou83/202210

In [8]:
# Load all the stateseqs to be used for usage ranking
stateseqs = np.hstack([np.load(prefix+'.stateseq.npy') for prefix in prefixes_to_sort])
uq_states = np.unique(stateseqs)
uq_states = uq_states[~np.isnan(uq_states)]

# bad old way -- uses total frame nums
# usage_rank = np.argsort(np.argsort(np.bincount(np.hstack(stateseqs), minlength=100))[::-1])

# good way -- uses num syl instances
durations, start_idx, stateseq_no_rep = tsu.rle(stateseqs)
syl_counts = np.bincount(stateseq_no_rep, minlength=num_states)
relabeled = np.argsort(np.argsort(syl_counts)[::-1])  # most-used is 0, next is 1, etc
orig2sorted = {orig:lab for orig,lab in zip(np.arange(num_states), relabeled)}
mapping = get_dict_map_np(orig2sorted)
print(f'{(syl_counts>0).sum()} total syllables used')

98 total syllables used


In [11]:
for prefix in prefixes_to_apply:
    stateseq = np.load(prefix+'.stateseq.npy')
    stateseq_usage = mapping(stateseq)
    usage_npy = prefix+'.state_usage_ranks.npy'
    stateseq_sorted_npy = prefix+'.stateseq_usage_sorted.npy'
    if exists(usage_npy) and not overwrite:
        continue
    else:
        np.save(usage_npy, relabeled)
        
    if exists(stateseq_sorted_npy) and not overwrite:
        continue
    else:
        np.save(stateseq_sorted_npy, stateseq_usage)

### Simple scalars

In [27]:
from top_bottom_moseq.util import *
from top_bottom_moseq.io import videoReader
import re 
from os.path import join, exists
import pickle

In [21]:
def vec_to_angle(v, degrees=False):
    a = np.arctan(v[:,1]/v[:,0]) + np.pi*(v[:,0]>0)
    if degrees: a = a / np.pi * 180
    return a

def angle_to_vec(a, degrees=False):
    if degrees: a = a / 180 * np.pi 
    return -np.hstack([np.cos(a)[:,None],np.sin(a)[:,None]])

def camera_project(points, camera_transform, intrinsics):
    R,t = camera_transform
    if points.shape[1]==2: points = np.hstack((points,np.zeros((points.shape[0],1))))
    return cv2.projectPoints((points-t).dot(R),(0,0,0),(0,0,0),*intrinsics)[0]

def scalars_to_cameraspace(scalars, camera_transform, intrinsics):
    centroid = camera_project(scalars[:,:2], camera_transform, intrinsics)[:,0,:]
    head = camera_project(scalars[:,:2] + angle_to_vec(scalars[:,2]), camera_transform, intrinsics)[:,0,:]
    angle = vec_to_angle(head-centroid)
    return np.hstack((centroid, angle[:,None], scalars[:,3][:,None]))


def get_calibration_file(calibration_dir, date=None, date_regexp=None, current_prefix=None):
    """ Find calibration file corresponding to given date, or via regexp + prefix.
    """
    if date is not None:
        cf = glob(join(calibration_dir, date, 'camera_3D_transforms.p'))
    else:    
        regexp = re.match(date_from_folder_regexp, prefix)
        cf = glob(join(calibration_dir, regexp.group('date'), 'camera_3D_transforms.p'))
    
    if len(cf) == 0:
        print(f'No calibn file found for {prefix}, continuing...')
    elif len(cf) > 1:
        print('Found two calibration files for {prefix}, skipping for now...')
    else:
        cf = cf[0]
    return cf

In [22]:
calibration_dir = '/n/groups/datta/Jonah/Thermistor_only_recordings/calibrations'  # in which must be nested folders called, eg, 20221008 (YYYYMMDD)
calibration_file_names = 'camera_3D_transforms.p'  # all must have same name!
date_from_folder_regexp = re.compile('.*/(?P<mouse>gmou\d*)/(?P<date>\d{8})_gmou.*')
intrinsics_prefix = '/n/groups/datta/Jonah/gh_topbottom/intrinsics/JP_rig'
intrinsics = {name:load_intrinsics(intrinsics_prefix+'.'+name+'.json') for name in ['top','bottom']}
camera_names = ['top','bottom']

overwrite = False

In [28]:
scalars_dict = {}
for prefix in prefixes_to_apply:
    print(prefix)
    scalars_dict[prefix] = {}
    
    # Load moseq data
    thetas = np.load(prefix+'.thetas.npy').squeeze()
    crop_origins = np.load(prefix+'.crop_centers.npy')
    
    # Load camera data
    if any([d in prefix for d in ['20221010', '20221012']]):
        cf = get_calibration_file(calibration_dir, date='20221010')
    else:
        cf = get_calibration_file(calibration_dir, date_regexp=date_from_folder_regexp, current_prefix=prefix)

    camera_transforms = pickle.load(open(cf,'rb'))
    
    # Scalars are (x, y, heading, ?)
    scalars = np.hstack([crop_origins - thetas[:,2:4], 
                        (vec_to_angle(thetas[:,:2])[:,None]+np.pi/2)%(2*np.pi),
                        thetas[:,4][:,None]])

    scalars2d = scalars_to_cameraspace(scalars, camera_transforms['bottom'], intrinsics['bottom'])
    vel = np.vstack([np.array([np.nan, np.nan]), np.diff(scalars2d[:, :2], axis=0)])
    speed = np.sqrt(np.sum(vel**2, axis=1))
    
    scalars_npy = prefix+'.scalars.npy'
    scalars2d_npy = prefix+'.scalars2d.npy'
    
    if exists(scalars_npy) and not overwrite:
        continue
    else:
        np.save(scalars_npy, scalars)
        
    if exists(scalars2d_npy) and not overwrite:
        continue
    else:
        np.save(scalars2d_npy, scalars2d)

/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221012_gmou81/20221012_gmou81
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221010_gmou81/20221010_gmou81
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou81/20221013_gmou81/20221013_gmou81
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221010_gmou78/20221010_gmou78
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221012_gmou78/20221012_gmou78
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou78/20221013_gmou78/20221013_gmou78
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221015_gmou83/20221015_gmou83
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221013_gmou83/20221013_gmou83
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221014_gmou83/20221014_gmou83
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221011_gmou83/20221011_gmou83
/n/groups/datta/Jonah/Thermistor_only_recordings/gmou83/20221010_gmou83/20221010_gmou83
/n/groups/datta/Jonah/Thermistor