# CalMS21 State-Annotation Comparison Analysis

In [None]:
%load_ext autoreload
%autoreload 2
# Environment variables are written before `import jax` below.
# NOTE(review): presumably so JAX/XLA/MuJoCo read them at initialisation — confirm.
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ['MUJOCO_GL'] = 'egl'
os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use GPU 0
os.environ["JAX_CAPTURED_CONSTANTS_REPORT_FRAMES"]="-1"
from pathlib import Path
import jax 
# The compilation cache directory is ./tmp/jax_cache relative to the current working directory.
jax.config.update("jax_compilation_cache_dir", (Path.cwd() / "tmp/jax_cache").as_posix())
# NOTE(review): the -1 / 0 thresholds presumably mean "cache every entry regardless of
# size or compile time" — confirm against the JAX persistent-cache documentation.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# Best-effort option: an AttributeError from `config.update` is ignored on purpose.
try: 
    jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")
except AttributeError:
    pass  # Skip if not available in this JAX version

try:
    import blackjax
except ModuleNotFoundError:
    print('installing blackjax')
    %pip install -qq blackjax
    import blackjax
import matplotlib.pyplot as plt
from natsort import natsorted
# from fastprogress.fastprogress import progress_bar
from functools import partial

# jax.config.update('jax_platform_name', 'cpu')
from jax import random as  jr
from jax import numpy as jnp
from jax import jit, vmap
from itertools import count
from flax import nnx

from tqdm.auto import tqdm
# Report which platform JAX picked by default; anything that is not a GPU is
# treated as 'cpu'.
# `jax.default_backend()` is the public accessor for the default platform name.
# The previous `jax.extend.backend.get_backend()` went through the `jax.extend`
# submodule, which may not be loaded by a bare `import jax`.
device = 'gpu' if jax.default_backend() == 'gpu' else 'cpu'
n_gpus = jax.device_count(backend=device)
print(f"Using {n_gpus} device(s) on {device}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import sys

# Add TiDHy to path if needed
# sys.path.insert(0, str(Path.cwd()))

# Import TiDHy utilities
from TiDHy.utils.state_annotation_comparison import (
    match_states_to_annotations,
    compute_clustering_metrics,
    compute_state_purity,
    compute_per_behavior_metrics,
    plot_confusion_matrix,
    analyze_state_annotation_correspondence
)
from TiDHy.utils.slds_analysis import load_slds_results
from TiDHy.models.TiDHy_nnx_vmap import TiDHy
from TiDHy.models.TiDHy_nnx_vmap_training import train_model, evaluate_record, load_model, get_latest_checkpoint_epoch, list_checkpoints
from TiDHy.datasets.datasets_dynamax import *
from TiDHy.datasets.load_data import load_data, stack_data
from TiDHy.utils import io_dict_to_hdf5 as ioh5
from TiDHy.utils.path_utils import *

# Plotting settings
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 11

%matplotlib inline


##### Plotting settings ######
# Global matplotlib defaults for every figure in this notebook.
# This `font.size` (10) replaces the value 11 assigned to plt.rcParams earlier in this cell.
import matplotlib as mpl
mpl.rcParams.update({'font.size':          10,
                     'axes.linewidth':     2,
                     'xtick.major.size':   5,
                     'ytick.major.size':   5,
                     'xtick.major.width':  2,
                     'ytick.major.width':  2,
                     'axes.spines.right':  False,
                     'axes.spines.top':    False,
                     'pdf.fonttype':       42,
                     'ps.fonttype':        42,
                     'xtick.labelsize':    10,
                     'ytick.labelsize':    10,
                     'figure.facecolor':   'white',
                     'pdf.use14corefonts': False,  # Off because TrueType fonts are embedded instead
                     'font.family':        'sans-serif',
                     'font.sans-serif':    'Arial',  # NOTE(review): assumes Arial is installed on this machine
                     'axes.unicode_minus': True,  # Ensures proper minus sign rendering in PDFs
                    })

# Twelve-colour palette and the ListedColormap built from it.
from matplotlib.colors import ListedColormap
clrs = ['#1A237E','#7E57C2','#757575','#BDBDBD','#4CAF50','#FF9800','#795548','#FF4081','#00BCD4','#FF1744','#FFFFFF','#000000']
cmap = ListedColormap(clrs)
from sklearn.linear_model import RidgeClassifier,LogisticRegression, RidgeClassifierCV
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA

def map_discrete_cbar(cmap,N):
    """Return a colormap resampled to ``N + 1`` levels and a matching BoundaryNorm.

    Bin edges are placed at -0.5, 0.5, ... so that each integer from 0 to N
    falls in the middle of its own bin.

    Args:
        cmap: Colormap name or object accepted by ``plt.get_cmap``.
        N: Largest integer label to be represented.

    Returns:
        Tuple ``(cmap, norm)``.
    """
    n_levels = N + 1
    discrete_cmap = plt.get_cmap(cmap, n_levels)
    edges = np.arange(-.5, n_levels)
    return discrete_cmap, mpl.colors.BoundaryNorm(edges, discrete_cmap.N)

def unit_vector(vector):
    """Scale ``vector`` by the inverse of its norm.

    The norm is ``np.linalg.norm`` with default arguments. A zero-norm input
    is not special-cased, so it goes through the division unchanged.
    """
    length = np.linalg.norm(vector)
    return vector / length

def angle_between(v1, v2):
    """Return the angle, in radians, between vectors ``v1`` and ``v2``.

    Both inputs are normalised with ``unit_vector``. Their dot product is
    clipped to [-1, 1] before ``arccos`` so that rounding error cannot push it
    outside the function's domain.
    """
    cosine = np.dot(unit_vector(v1), unit_vector(v2))
    return np.arccos(np.clip(cosine, -1.0, 1.0))


import matplotlib.gridspec as gridspec
# Shared font size and colour palettes for the figures below.
fontsize = 13

clrs = np.array([
    '#1A237E', '#7E57C2', '#757575', '#BDBDBD', '#4CAF50', '#FF9800',
    '#795548', '#FF4081', '#00BCD4', '#FF1744', '#FFFFFF', '#000000',
])
sys_clrs = [
    '#E3A19F', '#E3BE53', "#32373B", '#90CCA9', '#B7522E', '#B0E0E6',
    '#A89AC2', '#556B2F', '#FF6F61', '#87CEEB', '#FFDAB9', '#40E0D0',
]
cmap_sys = ListedColormap(sys_clrs)

# Index lists into `sys_clrs`; only `clr_ind3` is consumed in this cell.
clr_ind = [2, 2, 8, 8, 9, 9]
clr_ind3 = [2, 8, 9]
clr2b = [sys_clrs[idx] for idx in clr_ind3]

# Nine-colour subset of `clrs`, plus colormaps for the full and reduced palettes.
clrs_b = clrs[[0, 1, 2, 9, 4, 6, 7, 8, 11]]
cmap = ListedColormap(clrs)
cmap_b = ListedColormap(clrs_b)
cmap_b


## 1. Load CalMS21 Data with Annotations

In [None]:
# Locate every TiDHy run config under `base_dir`, list them, then load one run's
# config and the dataset it points to.
dataset = 'CalMS21'
# version = 'HierarchicalMultiTimescale'
version = 'TiDHy'
# base_dir = Path(f'/gscratch/portia/eabe/biomech_model/Flybody/{dataset}/{version}')
base_dir = Path(f'/data2/users/eabe/TiDHy/{dataset}/{version}')
run_cfg_list = natsorted(list(Path(base_dir).rglob('run_config.yaml')))
# NOTE(review): `OmegaConf` is not imported explicitly in this notebook; presumably it
# arrives through one of the `from TiDHy... import *` lines — verify.
for n, run_cfg in enumerate(run_cfg_list):
    temp = OmegaConf.load(run_cfg)
    print(n, temp.dataset.name, temp.version, run_cfg)

# ###### Load and update config with specified paths template ###### 
# Index into the list printed above; raises IndexError if no config was found.
cfg_num = 0

# NEW APPROACH: Load config and replace paths using workstation.yaml template
# NOTE(review): the SLDS cell further down passes `Path.cwd().parent / "configs"` as
# `config_dir`, while this cell uses `Path.cwd() / "configs"` — confirm which is correct.
cfg = load_config_and_override_paths(
    config_path=run_cfg_list[cfg_num],
    new_paths_template="workstation",    # Use workstation.yaml for local paths
    config_dir=Path.cwd() / "configs",
)

print(f'✅ Loaded experiment: {cfg_num}, {cfg.dataset.name}: {cfg.version} from {run_cfg_list[cfg_num]}')

# Convert string paths to Path objects and create directories
cfg.paths = convert_dict_to_path(cfg.paths)
print("✅ Successfully converted all paths to Path objects and created directories")


data_dict = load_data(cfg)
# `[None]` prepends a length-1 axis, so `shape[1]` below is the first axis of the stored test inputs.
inputs_test = data_dict['inputs_test'][None]
max_seq_len = inputs_test.shape[1]


In [None]:
# Build a TiDHy model from the run config, then load the latest checkpoint.
# Create RNG
rngs = nnx.Rngs(42)

# Get model params as dict and unpack directly
model_params = OmegaConf.to_container(cfg.model, resolve=True)
# model_params.pop('batch_converge')
# The input dimension is taken from the data, not from the config.
model_params['input_dim'] = inputs_test.shape[-1]

model = TiDHy(**model_params, rngs=rngs)
# model.l0 = nnx.data(jnp.zeros(3))
# model.loss_weights = nnx.data(jnp.ones(3))
print(f"\nModel initialized successfully!")
print(f"input_dim: {model.input_dim}, r_dim: {model.r_dim}, r2_dim: {model.r2_dim}, mix_dim: {model.mix_dim}")
# NOTE(review): `jit_model` is not used anywhere else in this notebook.
jit_model = jax.jit(model)
# out = jit_model(inputs_train)
# The checkpoint directory name is `epoch_` plus the zero-padded (4-digit) epoch.
epoch=get_latest_checkpoint_epoch(cfg.paths.ckpt_dir)
loaded_model = load_model(model,cfg.paths.ckpt_dir/f'epoch_{epoch:04d}')

In [None]:
# Load the saved evaluation arrays and flatten the entry whose key sorts last
# (ts = -1) so that each array has a single leading sample axis.
result_dict = ioh5.load(cfg.paths.log_dir/'evaluation_results.h5')
seq_len = natsorted(list(result_dict.keys()))
ts=-1
_entry = result_dict['{}'.format(seq_len[ts])]


def _collapse_leading(arr, n_trailing=1):
    """Merge all axes except the last `n_trailing` into one leading axis."""
    return arr.reshape((-1,) + tuple(arr.shape[-n_trailing:]))


W = _collapse_leading(_entry['W'])
I = _collapse_leading(_entry['I'])
Ihat = _collapse_leading(_entry['I_hat'])
R_hat = _collapse_leading(_entry['R_hat'])
R_bar = _collapse_leading(_entry['R_bar'])
R2_hat = _collapse_leading(_entry['R2_hat'])
# Ut keeps its last two axes.
Ut = _collapse_leading(_entry['Ut'], n_trailing=2)
# Annotations are cut to the number of flattened R2_hat rows.
annotations = data_dict['annotations_test'][:R2_hat.shape[0]]
W.shape, R2_hat.shape, R_hat.shape, Ut.shape,

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import RidgeClassifier,LogisticRegression, RidgeClassifierCV
# For every evaluation file sharing this experiment's run_id, fit a 3-nearest-
# neighbour classifier from each latent representation to the behaviour
# annotations and keep the predictions of the best-scoring representation.
ts = -1
exp_path = cfg.paths.save_dir
# The run id is the value of the `run_id=<...>` component of the save path.
run_id = next(part.split('=')[1] for part in exp_path.parts if part.startswith('run_id='))
results_path_list = natsorted(list((cfg.paths.base_dir / f'run_id={run_id}').rglob(f'evaluation_results.h5')))
best_pred_all = []
for n, results_path in enumerate(results_path_list):
    print(n, results_path)
    result_dict = ioh5.load(results_path)
    seq_len = natsorted(list(result_dict.keys()))
    # Look the selected entry up once instead of on every access.
    res = result_dict['{}'.format(seq_len[ts])]
    W = res['W'].reshape(-1, res['W'].shape[-1])
    I = res['I'].reshape(-1, res['I'].shape[-1])
    Ihat = res['I_hat'].reshape(-1, res['I_hat'].shape[-1])
    R_hat = res['R_hat'].reshape(-1, res['R_hat'].shape[-1])
    R_bar = res['R_bar'].reshape(-1, res['R_bar'].shape[-1])
    R2_hat = res['R2_hat'].reshape(-1, res['R2_hat'].shape[-1])
    Ut = res['Ut'].reshape(-1, res['Ut'].shape[-2], res['Ut'].shape[-1])
    annotations = data_dict['annotations_test'][:R2_hat.shape[0]]
    y_all = annotations.reshape(-1)
    reg_variables = [np.concatenate([R2_hat, W, R_hat], axis=-1), R2_hat, W, R_hat, R_bar]
    labels = ['all', 'R2_hat', 'W', 'R_hat', 'R_bar']

    # Start below any attainable accuracy so the first representation is always
    # recorded. With the old `max_acc = 0` and a strict `>`, a run whose scores
    # were all 0 left `best_pred` undefined or carried over from the previous run.
    max_acc = -np.inf
    # Kept from the original cell; not used by the k-NN fit below.
    tr_batch_size = 40
    batch_size = 10
    for k, reg_vars in enumerate(reg_variables):
        X_all = reg_vars.reshape(-1, reg_vars.shape[-1])
        X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=0.25, random_state=42)
        neigh = KNeighborsClassifier(n_neighbors=3, n_jobs=-1)
        neigh.fit(X_train, y_train)
        scores = neigh.score(X_test, y_test)
        print(labels[k], scores)
        if scores > max_acc:
            # NOTE(review): predictions are made for ALL samples, including the 75%
            # the classifier was fitted on — confirm this is intended downstream.
            y_pred = neigh.predict(X_all)
            max_acc = scores
            best_reg_vars = reg_vars
            best_label = labels[k]
            best_pred = y_pred
    best_pred_all.append(best_pred)
# One predicted-label array per run.
states_tidhy = best_pred_all
behavior_names = list(data_dict['vocabulary'].keys())

## 2. Load SLDS Results

In [None]:
# Locate every SLDS run config under `base_dir`, list them, then load one run's config.
# Note that `dataset`, `version`, `base_dir` and `run_cfg_list` from the TiDHy cell are overwritten here.
dataset = 'CalMS21'
# version = 'HierarchicalMultiTimescale'
version = 'SLDS'
# base_dir = Path(f'/gscratch/portia/eabe/biomech_model/Flybody/{dataset}/{version}')
base_dir = Path(f'/data2/users/eabe/TiDHy/{dataset}/{version}')
run_cfg_list = natsorted(list(Path(base_dir).rglob('run_config.yaml')))
for n, run_cfg in enumerate(run_cfg_list):
    temp = OmegaConf.load(run_cfg)
    print(n, temp.dataset.name, temp.version, run_cfg)

# ###### Load and update config with specified paths template ###### 
# Index into the list printed above; raises IndexError if fewer than four configs were found.
cfg_num = 3

# NEW APPROACH: Load config and replace paths using workstation.yaml template
# NOTE(review): `config_dir` is `Path.cwd().parent / "configs"` here but
# `Path.cwd() / "configs"` in the TiDHy cell — confirm which is correct.
cfg_ssm = load_config_and_override_paths(
    config_path=run_cfg_list[cfg_num],
    new_paths_template="workstation",    # Use workstation.yaml for local paths
    config_dir=Path.cwd().parent / "configs",
)

print(f'✅ Loaded experiment: {cfg_num}, {cfg_ssm.dataset.name}: {cfg_ssm.version} from {run_cfg_list[cfg_num]}')

# Convert string paths to Path objects and create directories
cfg_ssm.paths = convert_dict_to_path(cfg_ssm.paths)
print("✅ Successfully converted all paths to Path objects and created directories")


In [None]:
# List the SLDS result files found under this run_id.
# NOTE(review): when cells run top to bottom, `run_id` still holds the value derived
# from the TiDHy config (`cfg.paths.save_dir`); it is re-derived from `cfg_ssm` only
# in a later cell — confirm this is intended.
natsorted(list((cfg_ssm.paths.base_dir / f'run_id={run_id}').rglob(f'ssm_slds_*.h5')))

In [None]:
# Display the SLDS save directory (the run id is parsed from its path components below).
cfg_ssm.paths.save_dir

In [None]:
# Take the value of the `run_id=<...>` component of the SLDS save path.
# `next` raises StopIteration if no path component starts with 'run_id='.
run_id = next(part.split('=')[1] for part in cfg_ssm.paths.save_dir.parts if part.startswith('run_id='))
run_id

In [None]:

# Collect the inferred discrete states from every SLDS result file that shares
# this experiment's run_id.
# NOTE(review): `rslds_dict` is loaded but not used elsewhere in this notebook.
rslds_dict = ioh5.load(list(cfg_ssm.paths.log_dir.glob('ssm_rslds_*.h5'))[0])
run_id = next(part.split('=')[1] for part in cfg_ssm.paths.save_dir.parts if part.startswith('run_id='))

results_path_list = natsorted(list((cfg_ssm.paths.base_dir / f'run_id={run_id}').rglob(f'ssm_slds_*.h5')))
slds_states = []
for n, results_path in enumerate(results_path_list):
    print(n, results_path)
    # Load the file being iterated over. The previous code reloaded the first
    # `ssm_slds_*.h5` in `log_dir` on every pass, so every "run" was the same
    # file and the across-run spread computed later was always zero.
    slds_dict = ioh5.load(results_path)
    emission = slds_dict['SLDS_emission']
    latents = slds_dict['SLDS_latents']
    slds_states.append(slds_dict['SLDS_states'])
# Stacking requires every run to have a state sequence of the same shape.
slds_states = jnp.stack(slds_states)
annotations = data_dict['annotations_test']
behavior_names = list(data_dict['vocabulary'].keys())
slds_states.shape

## 3. Quick Data Check

In [None]:
# Sanity-check one TiDHy run against the annotations.
# This cell used to read a variable named `states` that is never defined in this
# notebook. `states_tidhy` is a list with one predicted-label array per run, so
# the last run is picked here for the summary.
states_check = np.asarray(states_tidhy[-1])

print(f"States shape: {states_check.shape}")
print(f"Annotations shape: {annotations.shape}")
print(f"Number of unique states: {len(np.unique(states_check))}")
print(f"Number of unique behaviors: {len(np.unique(annotations))}")
print(f"\nBehavior vocabulary: {behavior_names}")
print(f"\nState distribution:")
# `n_steps` replaces the old loop variable `count`, which shadowed
# `itertools.count` imported in the first cell.
state_ids, state_counts = np.unique(states_check, return_counts=True)
for s, n_steps in zip(state_ids, state_counts):
    print(f"  State {s}: {n_steps} timesteps ({n_steps/len(states_check)*100:.1f}%)")
print(f"\nAnnotation distribution:")
for a in np.unique(annotations):
    n_steps = np.sum(annotations == a)
    # Annotation ids index into the vocabulary; ids past its end are reported as unknown.
    behavior = behavior_names[int(a)] if a < len(behavior_names) else f'Unknown_{a}'
    print(f"  {behavior}: {n_steps} timesteps ({n_steps/len(annotations)*100:.1f}%)")

## 4. Comprehensive State-Annotation Analysis

Run the full analysis pipeline using the main analysis function.

In [None]:
# Run the state/annotation correspondence analysis once per TiDHy run and once
# per SLDS run, collecting the matched states and per-behaviour F1 of each run.
# The TiDHy loop iterates `states_tidhy` (one predicted-label array per run);
# it used to read `states`, which is never defined in this notebook.
# NOTE(review): `annotations` is the full test annotation array at this point,
# while each TiDHy prediction was produced from a truncated copy — confirm the
# lengths agree.
matched_states, f1_tidhy = [], []
for n, run_states in enumerate(states_tidhy):
    print(f"\nAnalyzing TiDHy run {n+1}/{len(states_tidhy)}")
    results_tidhy = analyze_state_annotation_correspondence(
        states=run_states,
        annotations=annotations,
        behavior_names=behavior_names,
        verbose=True
    )
    matched_states.append(results_tidhy['matching']['matched_states'])
    f1_tidhy.append(results_tidhy['per_behavior_metrics']['per_behavior_f1'])
matched_states = jnp.stack(matched_states)
f1_tidhy = jnp.stack(f1_tidhy)

slds_matched_states, f1_slds = [], []
for n, run_states in enumerate(slds_states):
    print(f"\nAnalyzing SLDS run {n+1}/{len(slds_states)}")
    results_slds = analyze_state_annotation_correspondence(
        states=run_states,
        annotations=annotations,
        behavior_names=behavior_names,
        verbose=True
    )
    slds_matched_states.append(results_slds['matching']['matched_states'])
    f1_slds.append(results_slds['per_behavior_metrics']['per_behavior_f1'])
slds_matched_states = jnp.stack(slds_matched_states)
f1_slds = jnp.stack(f1_slds)

print("This will print a comprehensive analysis summary including:")
print("  - Data summary (timesteps, num states, num behaviors)")
print("  - Clustering quality metrics (ARI, NMI, V-measure, accuracy)")
print("  - Purity metrics (state purity, annotation purity)")
print("  - Per-behavior performance (precision, recall, F1)")

## 5. Visualize Confusion Matrices and Per-Behavior F1 Scores



In [None]:
import matplotlib.gridspec as gridspec
from TiDHy.utils.state_annotation_comparison import analyze_state_annotation_correspondence
# Font size and colour palettes used by the comparison figure.
fontsize = 13

clrs = np.array([
    '#1A237E', '#7E57C2', '#757575', '#BDBDBD',
    '#4CAF50', '#FF9800', '#795548', '#FF4081',
    '#00BCD4', '#FF1744', '#FFFFFF', '#000000',
])
sys_clrs = [
    '#E3A19F', '#E3BE53', "#32373B", '#90CCA9',
    '#B7522E', '#B0E0E6', '#A89AC2', '#556B2F',
    '#FF6F61', '#87CEEB', '#FFDAB9', '#40E0D0',
]
cmap_sys = ListedColormap(sys_clrs)

# Index lists into `sys_clrs`; `clr2b` holds the colours selected by `clr_ind3`.
clr_ind = [2, 2, 8, 8, 9, 9]
clr_ind3 = [2, 8, 9]
clr2b = [sys_clrs[idx] for idx in clr_ind3]

# Colormaps over the full palette and over a nine-colour subset of it.
clrs_b = clrs[[0, 1, 2, 9, 4, 6, 7, 8, 11]]
cmap = ListedColormap(clrs)
cmap_b = ListedColormap(clrs_b)

In [None]:
# Comparison figure: confusion matrices for TiDHy and SLDS on the top row,
# across-run mean ± std of the per-behaviour F1 score on the bottom row.
fontsize=13
fig = plt.figure(constrained_layout=True, figsize=(8.75,5))
gs  = gridspec.GridSpec(nrows=4, ncols=6,hspace=5,wspace=.1) 
gs0 = gridspec.GridSpecFromSubplotSpec(1, 6, subplot_spec=gs[:2,:], wspace=.1,hspace=.1)
gs1 = gridspec.GridSpecFromSubplotSpec(1, 8, subplot_spec=gs[2:4,:], wspace=10,hspace=1.5)
# gs2 = gridspec.GridSpecFromSubplotSpec(3, 2, subplot_spec=gs[4:,:], wspace=.1,hspace=.1)

# Top-left: TiDHy confusion matrix for the last run.
ax = fig.add_subplot(gs0[:, :3])
fig1 = plot_confusion_matrix(
    matched_states=matched_states[-1],
    annotations=annotations,
    behavior_names=behavior_names,
    normalize='true',  # Row-normalize (shows recall)
    model='TiDHy',
    ax=ax
)
# Removes the second axes of the returned figure.
# NOTE(review): presumably the colorbar added by plot_confusion_matrix — this relies
# on axes creation order; verify. The SLDS panel's counterpart is not removed.
fig1.delaxes(fig1.axes[1])
# Top-right: SLDS confusion matrix for the last run, y labels hidden.
ax = fig.add_subplot(gs0[:, 3:])
fig2 = plot_confusion_matrix(
    matched_states=slds_matched_states[-1],
    annotations=annotations,
    behavior_names=behavior_names,
    normalize='true',  # Row-normalize (shows recall)
    model='SLDS',
    ax=ax
)
ax.set_ylabel('')
ax.set_yticklabels([])
# Bottom-left: TiDHy per-behaviour F1, bar = mean over runs, error bar = std over runs.
ax = fig.add_subplot(gs1[:, :4])
ax.bar(x=np.arange(len(results_tidhy['per_behavior_metrics']['per_behavior_f1'])), height=np.mean(f1_tidhy, axis=0),color='k')
ax.errorbar(x=np.arange(len(results_tidhy['per_behavior_metrics']['per_behavior_f1'])), y=np.mean(f1_tidhy, axis=0), yerr=np.std(f1_tidhy, axis=0), capsize=5, color='gray', fmt='none')
ax.set_xticks(np.arange(len(behavior_names)))
ax.set_xticklabels(behavior_names,rotation=45,ha='right',fontsize=fontsize-2)
ax.set_ylabel('F1 Score',fontsize=fontsize-2)
# NOTE(review): no y-limit is set on this panel, while the SLDS panel below is fixed
# to [0, 1]; the two bar charts may therefore be drawn on different scales.


# Bottom-right: SLDS per-behaviour F1, same layout.
ax = fig.add_subplot(gs1[:, 4:8])
ax.bar(x=np.arange(len(results_slds['per_behavior_metrics']['per_behavior_f1'])), height=np.mean(f1_slds, axis=0),color='k')
ax.errorbar(x=np.arange(len(results_slds['per_behavior_metrics']['per_behavior_f1'])), y=np.mean(f1_slds, axis=0), yerr=np.std(f1_slds, axis=0), capsize=5, color='gray', fmt='none')
ax.set_xticks(np.arange(len(behavior_names)))
ax.set_xticklabels(behavior_names,rotation=45,ha='right',fontsize=fontsize-2)
ax.set_ylabel('F1 Score',fontsize=fontsize-2)
ax.set_ylim(0,1)
# The figure is written to the TiDHy config's figure directory.
fig.savefig(cfg.paths.fig_dir / f'state_annotation_analysis_{cfg.dataset.name}_{cfg.version}_vs_{cfg_ssm.version}.pdf', bbox_inches='tight', dpi=300, transparent=True)

In [None]:
import cupy
import cuml
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

In [None]:
# Reload the evaluation results and build, for each quantity, one array stacked
# over all entries of the file (in natural-sorted key order).
result_dict = ioh5.load(cfg.paths.log_dir/'evaluation_results.h5')
seq_len = natsorted(list(result_dict.keys()))
_per_seq = [result_dict[str(seq)] for seq in seq_len]


def _stack_last_axis(key):
    """Flatten each entry's `key` array to (-1, last_dim) and stack the results."""
    return jnp.stack([res[key].reshape(-1, res[key].shape[-1]) for res in _per_seq])


W = _stack_last_axis('W')
I = _stack_last_axis('I')
Ihat = _stack_last_axis('I_hat')
Ibar = _stack_last_axis('I_bar')
R_hat = _stack_last_axis('R_hat')
R_bar = _stack_last_axis('R_bar')
R2_hat = _stack_last_axis('R2_hat')
# Ut merges only its first two axes; every axis from index 2 on is kept.
Ut = jnp.stack([res['Ut'].reshape((-1,) + res['Ut'].shape[2:]) for res in _per_seq])

In [None]:
# Plot one input feature against its reconstruction, and every R_hat / R2_hat
# dimension, over the window [t, t + dt) of the entry selected by `ts`.
ts, t, dt, feature = -1, 0, 50000, 10
window = slice(t, t + dt)
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(6, 6))
input_ax, r_ax, r2_ax = axs
input_ax.plot(I[ts, window, feature], label='I1')
input_ax.plot(Ihat[ts, window, feature], label='Ihat1')
r_ax.plot(R_hat[ts, window, :], label='R_hat')
ax = r2_ax
ax.plot(R2_hat[ts, window, :], label='R_2hat')

In [None]:

# Standardise R_hat for the selected entry, embed it with UMAP (10 components),
# cluster the embedding with HDBSCAN, and derive a per-sample soft label.
ts = -1
scaler = StandardScaler()
X_scaled = scaler.fit_transform(R_hat[ts])

umap_params = dict(n_components=10, n_neighbors=500, min_dist=0.0, metric='euclidean', random_state=42)
umap = cuml.manifold.UMAP(**umap_params)
reduced_data = umap.fit_transform(X_scaled)

hdbscan_params = dict(min_cluster_size=5000, min_samples=5, cluster_selection_epsilon=0.05, prediction_data=True)
clusterer = cuml.cluster.hdbscan.HDBSCAN(**hdbscan_params)
clusterer.fit(reduced_data)
soft_clusters = cuml.cluster.hdbscan.all_points_membership_vectors(clusterer)

# Hard labels and their frequencies.
# NOTE(review): the last print treats the first unique label as a separate group
# (presumably the noise label -1) — confirm.
labels = clusterer.labels_
l_lab, l_counts = np.unique(labels, return_counts=True)
print(l_lab)
print(l_counts)
print(l_counts[0], sum(l_counts[1:]))


def _to_host(arr):
    """Return `arr.get()` for a CuPy array, and `arr` itself for anything else."""
    return arr.get() if isinstance(arr, cupy.ndarray) else arr


# NOTE(review): column 0 of the membership matrix is dropped before the argmax.
# If the matrix has one column per cluster and no noise column, this discards
# cluster 0 and shifts every index by one — check against the cuML/hdbscan docs.
soft_label = jnp.argmax(_to_host(soft_clusters)[:, 1:], axis=1)
reduced_data = _to_host(reduced_data)

In [None]:
# Keep at most as many annotations as there are embedded samples (rows of reduced_data).
annotations = data_dict['annotations_test'][:reduced_data.shape[0]]
# full_state_z = data_dict['annotations_test'][:reduced_data.shape[0]]
# 24-colour palette.
clrs = np.array(['#1A237E','#7E57C2','#757575','#BDBDBD','#4CAF50','#FF9800','#795548','#FF4081','#00BCD4','#FF1744','#FFFFFF','#000000',
                 '#E3A19F','#E3BE53',"#32373B",'#90CCA9','#B7522E','#B0E0E6','#A89AC2','#556B2F','#FF6F61','#87CEEB','#FFDAB9','#40E0D0'])
# Four fixed colours for the annotation panel.
# NOTE(review): presumably one per behaviour class — confirm the class count.
clrs_1 = clrs[[0,2,5,8]]
# One colour per distinct soft label.
clrs_2 = clrs[:len(np.unique(soft_label))]
def map_discrete_cbar(colors):
    """Return a ListedColormap over ``colors`` and a BoundaryNorm with one bin per colour.

    Bin edges sit at -0.5, 0.5, ..., N - 0.5 for N colours, so each integer
    0..N-1 falls in the middle of its own bin. This definition replaces the
    earlier two-argument ``map_discrete_cbar(cmap, N)`` from the setup cell.
    """
    n_colors = len(colors)
    edges = np.arange(n_colors + 1) - 0.5
    listed = mpl.colors.ListedColormap(colors)
    return listed, mpl.colors.BoundaryNorm(edges, n_colors)

# Two 3-D scatter plots of the first three embedding dimensions:
# left coloured by annotation, right coloured by soft cluster label.
cmap1, norm1 = map_discrete_cbar(clrs_1)
cmap2, norm2 = map_discrete_cbar(clrs_2)

fig = plt.figure(figsize=(8.5,3), dpi=300)
ax = fig.add_subplot(121, projection='3d')
# NOTE(review): the scatter receives `cmap` but not `norm`; only the colorbar is given
# the norm's boundaries. Confirm point colours and colorbar bins agree.
im = ax.scatter(reduced_data[:,0], reduced_data[:,1], reduced_data[:,2], c=annotations, cmap=cmap1, alpha=.05, rasterized=True)
# ax.scatter(reduced_data[:1000,0], reduced_data[:1000,1], reduced_data[:1000,2], c=np.arange(1000), cmap='turbo', alpha=0.1)
cbar = fig.colorbar(im, boundaries=norm1.boundaries, spacing='uniform', extend='neither', aspect=15, pad=.2, shrink=0.75)
# One tick per distinct annotation value, labelled with the behaviour names.
# NOTE(review): assumes len(behavior_names) equals the number of distinct annotations.
cbar.set_ticks(np.arange(0,len(np.unique(annotations)),1))
cbar.set_ticklabels([key for key in behavior_names],fontsize=fontsize)
cbar.outline.set_linewidth(1)
cbar.minorticks_off()
cbar.ax.tick_params(width=1,which="major")
# Draw the colorbar fully opaque even though the points use a low alpha.
cbar.solids.set_alpha(1)
ax.set_xlabel('UMAP 1',fontsize=fontsize)
ax.set_ylabel('UMAP 2',fontsize=fontsize)
ax.set_zlabel('UMAP 3',fontsize=fontsize)
# Fixed axis limits; points outside [-5.5, 5.5] are not shown.
ax.set_xlim([-5.5,5.5])
ax.set_ylim([-5.5,5.5])
ax.set_zlim([-5.5,5.5])

ax = fig.add_subplot(122, projection='3d')
im = ax.scatter(reduced_data[:,0], reduced_data[:,1], reduced_data[:,2], c=soft_label, cmap=cmap2, alpha=.01, rasterized=True)
cbar = fig.colorbar(im, boundaries=norm2.boundaries, spacing='uniform', extend='neither', aspect=15, pad=.2, shrink=0.75)
cbar.set_ticks(np.arange(0,len(np.unique(soft_label)),1))
cbar.outline.set_linewidth(1)
cbar.minorticks_off()
cbar.ax.tick_params(width=1,which="major")
cbar.solids.set_alpha(1)
ax.set_xlabel('UMAP 1',fontsize=fontsize)
ax.set_ylabel('UMAP 2',fontsize=fontsize)
ax.set_zlabel('UMAP 3',fontsize=fontsize)
ax.set_xlim([-5.5,5.5])
ax.set_ylim([-5.5,5.5])
ax.set_zlim([-5.5,5.5])
plt.tight_layout()
plt.show()
fig.savefig(cfg.paths.fig_dir/'UMAP_Clustering.pdf',dpi=300, transparent=True)

In [None]:

# Project R_hat for the selected entry onto 8 principal components, cluster the
# projection with HDBSCAN, and derive a per-sample soft label.
ts = -1
pca_input = R_hat[ts]
reduced_data = PCA(n_components=8).fit_transform(pca_input)

hdbscan_params = dict(min_cluster_size=500, min_samples=100, cluster_selection_epsilon=0.075, prediction_data=True)
clusterer = cuml.cluster.hdbscan.HDBSCAN(**hdbscan_params)
clusterer.fit(reduced_data)
soft_clusters = cuml.cluster.hdbscan.all_points_membership_vectors(clusterer)

# Hard labels and their frequencies.
# NOTE(review): the last print treats the first unique label as a separate group
# (presumably the noise label -1) — confirm.
labels = clusterer.labels_
l_lab, l_counts = np.unique(labels, return_counts=True)
print(l_lab)
print(l_counts)
print(l_counts[0], sum(l_counts[1:]))


def _to_host(arr):
    """Return `arr.get()` for a CuPy array, and `arr` itself for anything else."""
    return arr.get() if isinstance(arr, cupy.ndarray) else arr


# NOTE(review): column 0 of the membership matrix is dropped before the argmax.
# If the matrix has one column per cluster and no noise column, this discards
# cluster 0 and shifts every index by one — check against the cuML/hdbscan docs.
soft_label = jnp.argmax(_to_host(soft_clusters)[:, 1:], axis=1)
reduced_data = _to_host(reduced_data)

In [None]:
# Keep at most as many annotations as there are projected samples (rows of reduced_data).
annotations = data_dict['annotations_test'][:reduced_data.shape[0]]
# full_state_z = data_dict['annotations_test'][:reduced_data.shape[0]]
# 24-colour palette.
clrs = np.array(['#1A237E','#7E57C2','#757575','#BDBDBD','#4CAF50','#FF9800','#795548','#FF4081','#00BCD4','#FF1744','#FFFFFF','#000000',
                 '#E3A19F','#E3BE53',"#32373B",'#90CCA9','#B7522E','#B0E0E6','#A89AC2','#556B2F','#FF6F61','#87CEEB','#FFDAB9','#40E0D0'])
# Four fixed colours for the annotation panel.
# NOTE(review): presumably one per behaviour class — confirm the class count.
clrs_1 = clrs[[0,2,5,8]]
# One colour per distinct soft label.
clrs_2 = clrs[:len(np.unique(soft_label))]
def map_discrete_cbar(colors):
    """Build a discrete colormap from ``colors`` and a norm with unit-wide bins.

    For N colours the bin edges are -0.5, 0.5, ..., N - 0.5, which centres each
    integer 0..N-1 in its own bin.

    Returns:
        Tuple ``(cmap, norm)``: a ListedColormap and a BoundaryNorm.
    """
    n_colors = len(colors)
    edges = np.arange(n_colors + 1) - 0.5
    listed = mpl.colors.ListedColormap(colors)
    return listed, mpl.colors.BoundaryNorm(edges, n_colors)

# Two 3-D scatter plots of the first three principal components:
# left coloured by annotation, right coloured by soft cluster label.
cmap1, norm1 = map_discrete_cbar(clrs_1)
cmap2, norm2 = map_discrete_cbar(clrs_2)

fig = plt.figure(figsize=(8.5,3), dpi=300)
ax = fig.add_subplot(121, projection='3d')
# NOTE(review): the scatter receives `cmap` but not `norm`; only the colorbar is given
# the norm's boundaries. Confirm point colours and colorbar bins agree.
im = ax.scatter(reduced_data[:,0], reduced_data[:,1], reduced_data[:,2], c=annotations, cmap=cmap1, alpha=.05, rasterized=True)
# ax.scatter(reduced_data[:1000,0], reduced_data[:1000,1], reduced_data[:1000,2], c=np.arange(1000), cmap='turbo', alpha=0.1)
cbar = fig.colorbar(im, boundaries=norm1.boundaries, spacing='uniform', extend='neither', aspect=15, pad=.2, shrink=0.75)
# One tick per distinct annotation value, labelled with the behaviour names.
# NOTE(review): assumes len(behavior_names) equals the number of distinct annotations.
cbar.set_ticks(np.arange(0,len(np.unique(annotations)),1))
cbar.set_ticklabels([key for key in behavior_names],fontsize=fontsize)
cbar.outline.set_linewidth(1)
cbar.minorticks_off()
cbar.ax.tick_params(width=1,which="major")
# Draw the colorbar fully opaque even though the points use a low alpha.
cbar.solids.set_alpha(1)
ax.set_xlabel('PCA 1',fontsize=fontsize)
ax.set_ylabel('PCA 2',fontsize=fontsize)
ax.set_zlabel('PCA 3',fontsize=fontsize)
# Axis limits are left to matplotlib here (the fixed limits are commented out).
# ax.set_xlim([-5.5,5.5])
# ax.set_ylim([-5.5,5.5])
# ax.set_zlim([-5.5,5.5])

ax = fig.add_subplot(122, projection='3d')
im = ax.scatter(reduced_data[:,0], reduced_data[:,1], reduced_data[:,2], c=soft_label, cmap=cmap2, alpha=.01, rasterized=True)
cbar = fig.colorbar(im, boundaries=norm2.boundaries, spacing='uniform', extend='neither', aspect=15, pad=.2, shrink=0.75)
cbar.set_ticks(np.arange(0,len(np.unique(soft_label)),1))
cbar.outline.set_linewidth(1)
cbar.minorticks_off()
cbar.ax.tick_params(width=1,which="major")
cbar.solids.set_alpha(1)
ax.set_xlabel('PCA 1',fontsize=fontsize)
ax.set_ylabel('PCA 2',fontsize=fontsize)
ax.set_zlabel('PCA 3',fontsize=fontsize)
# ax.set_xlim([-5.5,5.5])
# ax.set_ylim([-5.5,5.5])
# ax.set_zlim([-5.5,5.5])
plt.tight_layout()
plt.show()
fig.savefig(cfg.paths.fig_dir/'PCA_Clustering.pdf',dpi=300, transparent=True)