In [1]:
import warnings

warnings.filterwarnings("ignore")

import os.path
import ssm
import pickle
import numpy as np
import pandas as pd
import scipy
import copy
from tqdm import tqdm

from runwise_ts_log_data import get_ts_log_data_blocked

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import to_rgba
import seaborn as sns

from sympy.utilities.iterables import multiset_permutations
from ssm.util import *
from scipy.stats import ttest_ind, wilcoxon, ranksums

In [2]:
# Load the pickled dataset and keep only the rows whose `pid` is in `subj_list`.
# NOTE(review): pickle.load executes arbitrary code; only open files produced by this project.
with open('pkl/emoprox2_dataset_timeseries+inputs_MAX85.pkl','rb') as f:
    orig_df = pickle.load(f)
subj_list = sorted(orig_df['pid'].unique())
# NOTE(review): the full slice `[:]` copies the list and drops nothing, so every
# subject is kept. Dropping the first 30 subjects would require `subj_list[30:]`
# -- confirm which behavior is intended.
subj_list = subj_list[:]
orig_df = orig_df[orig_df['pid'].isin(subj_list)]

In [None]:
# Configuration shared by the cells below; K, D, N and num_subjs are also
# interpolated into the model pickle file names and the regressor directory name.
K = 6  # number of discrete states (loops run over range(K))
D = 10  # size of the identity matrix np.eye(D) combined with As[k] in get_attractors
N = 85  # length of each attractor vector (the attractors array is K x N)
num_subjs = 92  # subject count as it appears in the file / directory names
num_resamples = 100  # resample files are indexed 1..num_resamples
M = 20  # M // 2 sets the number of proximity bins built in get_df

In [4]:
from sklearn.cluster import KMeans
from munkres import Munkres
from sklearn.metrics.pairwise import cosine_similarity


def reference_comms(num_states, X):
    """Cluster the pooled columns of ``X`` into ``num_states`` reference centroids.

    The arrays in ``X`` are concatenated along their last axis, so every column
    of every array becomes one k-means sample.

    Returns the fitted cluster centers transposed (one column per cluster).
    """
    pooled = np.concatenate(X, axis=-1)
    samples = pooled.T  # KMeans.fit takes one sample per row
    clusterer = KMeans(
        n_clusters=num_states, init="k-means++", n_init=50, random_state=74
    )
    clusterer.fit(X=samples)
    print("found reference")
    centers = clusterer.cluster_centers_
    return centers.T  # num_rois x num_comms


def align_two_partitions(source_comms, target_comms):
    """Reorder the columns of ``source_comms`` to best match ``target_comms``.

    Columns are paired by running the Munkres (Hungarian) solver on the cosine
    distance between target columns and source columns.

    Returns ``(aligned_source, best_pi)``: the source matrix with its columns
    permuted, and the tuple of source-column indices taken from the solver's
    (row, column) assignment pairs, in the order the solver returns them.
    """
    # distance matrix: rows index target columns, columns index source columns
    cost = 1 - cosine_similarity(target_comms.T, source_comms.T)
    # Hungarian permutation method: the solver yields (row, column) pairs
    pairs = Munkres().compute(cost)
    rows_and_cols = list(zip(*pairs))
    best_pi = rows_and_cols[1]
    aligned = source_comms[:, best_pi]
    return aligned, best_pi


def align_partitions(num_states, X):
    """Find, for every partition in ``X``, the permutation aligning it to a shared reference.

    A reference set of ``num_states`` centroids is computed once from all of
    ``X`` via ``reference_comms``; each element of ``X`` is then matched to that
    reference with ``align_two_partitions``.

    Parameters
    ----------
    num_states : int
        Number of reference centroids.
    X : sequence of arrays
        Partitions to align (each is passed as ``source_comms``).

    Returns
    -------
    list
        One permutation per element of ``X``, in the same order.
    """
    comm_cntrs = reference_comms(num_states, X)
    best_pis = []
    # Wrap X itself: tqdm(enumerate(X)) cannot see a length, so the progress bar
    # had no total, and the enumerate index was never used.
    for comms in tqdm(X):
        _, best_pi = align_two_partitions(source_comms=comms, target_comms=comm_cntrs)
        best_pis.append(best_pi)
    return best_pis


def get_attractors(idx_resample):
    """Load one resampled model fit and return its per-state attractors.

    For each discrete state ``k`` the linear system ``(I - As[k]) x = bs[k]`` is
    solved, the solution is mapped through the first emission model as
    ``C @ x + d``, and each resulting vector is scaled to unit Euclidean norm.

    Parameters
    ----------
    idx_resample : int
        Resample index interpolated into the pickle file name.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, K)``: one unit-norm column per discrete state.

    Raises
    ------
    FileNotFoundError
        If the pickle for ``idx_resample`` does not exist.
    numpy.linalg.LinAlgError
        If ``I - As[k]`` is singular for some state.
    """
    model_path = f"pkl/rslds_emoprox2_K{K}_D{D}_N{N}_{num_subjs}subjs_resample{idx_resample}.pkl"
    with open(model_path, "rb") as f:
        # Only the fitted model is needed here; the other three entries are ignored.
        model, _, _, _ = pickle.load(f)
    As = model.dynamics.As
    bs = model.dynamics.bs
    C = model.emissions.Cs[0]
    d = model.emissions.ds[0]

    identity = np.eye(D)  # loop-invariant
    attractors = np.zeros((K, N))
    for k in range(K):
        # Solve the system directly instead of forming an explicit inverse and
        # multiplying: same solution, better conditioned and cheaper.
        latent_fixed_point = np.linalg.solve(identity - As[k], bs[k])
        attractors[k, :] = C @ latent_fixed_point + d
    # Row-wise normalisation to unit length.
    attractors /= np.linalg.norm(attractors, axis=1, keepdims=True)
    return attractors.T


# Gather the attractor matrix of every resample (files are numbered from 1),
# then compute one aligning permutation per resample.
all_attractors = [
    get_attractors(idx_resample)
    for idx_resample in tqdm(range(1, num_resamples + 1))
]

all_perms = align_partitions(K, all_attractors)

 20%|██        | 100/500 [03:14<12:58,  1.95s/it]


FileNotFoundError: [Errno 2] No such file or directory: 'pkl/rslds_emoprox2_K6_D10_N85_92subjs_resample101.pkl'

In [None]:
def get_df(idx_resample, perm):
    """Build the per-row state table for one resampled model fit.

    Loads the fit for ``idx_resample``, relabels its discrete states, gathers the
    rows of ``orig_df`` for every subject in the fit's resampled subject list,
    and annotates each row with continuous and discrete states.

    Parameters
    ----------
    idx_resample : int
        Resample index interpolated into the pickle file name.
    perm : array-like
        Permutation passed to ``model.permute`` before the fixed relabelling.

    Returns
    -------
    pandas.DataFrame
        Rows of ``orig_df`` (one group per entry of the resampled subject list,
        re-indexed from 0) with two added columns, ``continuous_states`` and
        ``discrete_states``, and with ``proximity`` replaced by its per-row
        min-max normalised version.

    Raises
    ------
    FileNotFoundError
        If the pickle for ``idx_resample`` does not exist.
    """
    model_path = f"pkl/rslds_emoprox2_K{K}_D{D}_N{N}_{num_subjs}subjs_resample{idx_resample}.pkl"
    with open(model_path, "rb") as f:
        # The third entry of the pickle is not used here.
        model, q, _, resampled_subj_list = pickle.load(f)

    model.permute(perm)
    # NOTE(review): fixed second relabelling with exactly six entries, so it is
    # only valid when K == 6 -- confirm before changing K.
    model.permute(np.array([1, 2, 4, 3, 5, 0]))

    # pd.concat returns a new frame, so the column assignments below do not
    # write into orig_df.
    df = pd.concat(
        [orig_df[orig_df["pid"] == pid] for pid in resampled_subj_list]
    ).reset_index(drop=True)

    df["continuous_states"] = [None] * df.shape[0]
    df["discrete_states"] = [None] * df.shape[0]

    hrflag = 0  # shift applied to the one-hot input along axis 0; 0 means no shift
    prox_bins = list(np.arange(M // 2) / (M // 2))[1:]
    dir_bins = [0]
    nprox = len(prox_bins) + 1
    ndir = len(dir_bins) + 1
    one_hot_rows = np.eye(nprox * ndir)  # loop-invariant lookup table

    for idx_row in range(df.shape[0]):
        # Min-max normalise proximity once; the result is both binned below and
        # stored back into the frame.
        # NOTE(review): a constant proximity trace makes prox.max() zero here and
        # yields NaNs -- confirm whether that can occur.
        prox = df.at[idx_row, "proximity"]
        prox = prox - prox.min()
        prox = prox / prox.max()

        proxd = np.digitize(prox, bins=prox_bins)
        # dird is 1 where direction < 0 and 0 where direction >= 0.
        dird = 1 - np.digitize(df.at[idx_row, "direction"], bins=dir_bins)
        # dird == 0 maps to categories 0..nprox-1 (increasing with proxd);
        # dird == 1 maps to nprox..2*nprox-1 (decreasing with proxd).
        stim_category = (nprox * ndir - 1) * dird + ((-1) ** dird) * proxd
        inputs = np.roll(one_hot_rows[stim_category], shift=hrflag, axis=0)

        y = df.at[idx_row, "timeseries"]
        # Row idx_row of the frame is paired with entry idx_row of the posterior.
        x = q.mean_continuous_states[idx_row]
        z = model.most_likely_states(x, y, input=inputs)

        df.at[idx_row, "continuous_states"] = x
        df.at[idx_row, "discrete_states"] = z
        df.at[idx_row, "proximity"] = prox

    return df

In [None]:
# One annotated DataFrame per resample, keyed by the 1-based resample index;
# all_perms is 0-based, hence the `- 1`.
all_dfs = {}
for idx_resample in tqdm(range(1, num_resamples + 1)):
    resample_perm = np.array(all_perms[idx_resample - 1])
    all_dfs[idx_resample] = get_df(idx_resample, resample_perm)

100%|██████████| 100/100 [04:23<00:00,  2.64s/it]


In [None]:
def create_regs(df, idx_resample, num_runs=6):
    """Write one 0/1 state regressor text file per subject and discrete state.

    For every subject, the discrete states of blocks 1 and 2 of each run are
    placed into a per-run vector through the blocks' ``block_mask``; positions
    covered by neither block stay at -1. Runs are concatenated in increasing
    ``rid`` order and, for each state in ``range(K)``, an indicator vector is
    saved with ``np.savetxt`` as ``CON{pid}_state{idx_state}.txt``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``pid``, ``rid``, ``block``, ``block_mask`` and
        ``discrete_states`` columns.
    idx_resample : int
        Resample index interpolated into the output directory name.
    num_runs : int, optional
        Run ids ``0..num_runs-1`` are scanned; runs absent for a subject are
        skipped. Defaults to 6, the previously hard-coded value.

    Raises
    ------
    IndexError
        If a present run lacks a row for block 1 or block 2.
    ValueError
        If a subject has no run in ``range(num_runs)`` (nothing to concatenate).
    """
    directory_path = f'./luiz_grant_analysis/regressors/rslds_K{K}_D{D}_N{N}_{num_subjs}subjs_emoprox_resample{idx_resample}/'
    # exist_ok replaces the check-then-create pair, which could race.
    os.makedirs(directory_path, exist_ok=True)

    for pid in df.pid.unique():
        df_subj = df[df.pid == pid]
        states_allruns = []
        for rid in range(num_runs):
            df_run = df_subj[df_subj.rid == rid]
            if df_run.empty:
                continue
            # If a (run, block) pair has several rows, the first one is used.
            block_masks = [
                df_run[df_run.block == i].block_mask.values[0] for i in [1, 2]
            ]
            states = [
                df_run[df_run.block == i].discrete_states.values[0] for i in [1, 2]
            ]
            # Template filled with -1, sized from this run's own mask rather than
            # the subject's first row, so runs of different lengths stay consistent.
            # Assumes both block masks of a run share one shape -- TODO confirm.
            states_run = -1 * np.ones_like(block_masks[0])
            for block_mask, block_states in zip(block_masks, states):
                states_run[block_mask] = block_states

            states_allruns.append(states_run)
        states_allruns = np.hstack(states_allruns)

        for idx_state in range(K):
            subj_state_regressor = (states_allruns == idx_state).astype('int')
            np.savetxt(directory_path + f'CON{pid}_state{idx_state}.txt', subj_state_regressor)

In [None]:
# Write the regressor files for every resample (indices 1..num_resamples).
for idx_resample in tqdm(range(1, num_resamples + 1)):
    resample_df = all_dfs[idx_resample]
    create_regs(resample_df, idx_resample)

100%|██████████| 100/100 [02:06<00:00,  1.27s/it]
