In [None]:
import pickle
import numpy as np
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt
import random
import scipy
from ite.cost.x_factory import co_factory
from tqdm import tqdm
import dill
import ast
import seaborn as sns
from scipy.spatial.distance import cosine
import scipy.spatial as sp
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.preprocessing import StandardScaler
from statannotations.Annotator import Annotator

In [None]:
import warnings
# Suppress every FutureWarning for the remainder of the session so that
# deprecation notices from the libraries used below do not flood cell output.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
def shuffle_data(data, shuffle_indices):
    """Return a copy of ``data`` with groups of columns row-permuted.

    Parameters
    ----------
    data : array-like with ``.copy()`` and ``.shape`` (indexed as ``data[rows, col]``)
        Input sample matrix; it is never modified.
    shuffle_indices : iterable of iterables of int
        Each inner iterable is one group of column indices. One random row
        permutation is drawn per group (from NumPy's global RNG) and applied
        to every column of that group, so columns inside a group stay aligned
        with each other while losing alignment with all other columns.

    Returns
    -------
    The shuffled copy.
    """
    result = data.copy()
    row_count = result.shape[0]

    for column_group in shuffle_indices:
        # One permutation per group, drawn even when the group is empty.
        row_order = np.random.permutation(row_count)
        for column in column_group:
            # Fancy indexing yields a fresh array, so reading and writing the
            # same column here is safe.
            reordered = result[row_order, column]
            result[:, column] = reordered

    return result

def Streitberg_4(X, div_func):
    """Combine divergence estimates over the partitions of four variables.

    Parameters
    ----------
    X : 2-D array whose four columns are the variables (indexed ``X[:, cols]``).
    div_func : callable ``div_func(sample_a, sample_b)`` returning a number.

    Returns
    -------
    ``joint - (sum of four 3-column terms) - (sum of three 2|2 split terms)
    + 2 * (sum of six 2-column terms)``, where every term is a ``div_func``
    value measured against a surrogate in which columns 0, 1 and 2 have been
    permuted independently (so no column stays row-aligned with another).
    """
    # Reference sample: columns 0-2 each get their own permutation, column 3 stays.
    decoupled = shuffle_data(X, [[0], [1], [2]])

    def on_columns(cols):
        # Divergence restricted to a subset of columns.
        return div_func(X[:, cols], decoupled[:, cols])

    def on_pair_split(pair):
        # Permute `pair` jointly: keeps the pair aligned, and the remaining
        # two columns aligned, but breaks the link between the two halves.
        return div_func(shuffle_data(X, [pair]), decoupled)

    # Evaluation order matches the RNG consumption of shuffle_data calls.
    joint = div_func(X, decoupled)
    triple = [on_columns(cols) for cols in ([1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2])]
    split = [on_pair_split(pair) for pair in ([0, 1], [0, 2], [0, 3])]
    double = [on_columns(cols) for cols in ([2, 3], [1, 3], [1, 2], [0, 3], [0, 2], [0, 1])]

    triple_total = triple[0] + triple[1] + triple[2] + triple[3]
    split_total = split[0] + split[1] + split[2]
    double_total = double[0] + double[1] + double[2] + double[3] + double[4] + double[5]

    return joint - triple_total - split_total + 2 * double_total

def Streitberg_3(X, div_func):
    """Combine divergence estimates over the partitions of three variables.

    Parameters
    ----------
    X : 2-D array whose three columns are the variables.
    div_func : callable ``div_func(sample_a, sample_b)`` returning a number.

    Returns
    -------
    The three-column ``div_func`` value minus the sum of the three two-column
    values, each measured against a surrogate in which columns 0 and 1 have
    been permuted independently.
    """
    # Reference sample: columns 0 and 1 each get their own permutation.
    decoupled = shuffle_data(X, [[0], [1]])

    joint = div_func(X, decoupled)
    pairwise = [
        div_func(X[:, cols], decoupled[:, cols])
        for cols in ([1, 2], [0, 2], [0, 1])
    ]

    return joint - (pairwise[0] + pairwise[1] + pairwise[2])

def Streitberg_2(X, div_func):
    """Return ``div_func`` between ``X`` and a copy with column 0 row-permuted.

    Parameters
    ----------
    X : 2-D array whose two columns are the variables.
    div_func : callable ``div_func(sample_a, sample_b)`` returning a number.
    """
    # Permuting one of the two columns is enough to break their row alignment.
    decoupled = shuffle_data(X, [[0]])
    return div_func(X, decoupled)

In [None]:
# Load the pre-computed data set; later cells index it as movement[day][direction].
# NOTE(review): pickle.load can execute arbitrary code — only open files from a trusted source.
with open('data/movement/rate_data_20ms.pkl', 'rb') as f:
    movement = pickle.load(f)

In [None]:
# Record, for every day and direction, the length of the first axis of the
# stored array (used here as the trial count).
# NOTE(review): the 44 is hard-coded; presumably `movement` holds days 0-43 — confirm.
n_trials = defaultdict(lambda: defaultdict(int))
for day in range(44):
    for direction in ['DownLeft', 'Left', 'UpLeft', 'Up', 'UpRight', 'Right', 'DownRight']:
        n_trials[day][direction] = np.shape(movement[day][direction])[0]

In [None]:
# Show the per-direction counts for day index 3 (noted by the author as the day with the most trials).
n_trials[3]

In [None]:
def sampledata(my_list, samples=500, k=2, seed=331):
    """Draw distinct size-``k`` subsets of ``my_list`` by rejection sampling.

    Parameters
    ----------
    my_list : sequence
        Population to sample from (e.g. ``range(12)``).
    samples : int
        Maximum number of distinct subsets to return.
    k : int
        Size of every subset.
    seed : int
        Seed of a private ``random.Random`` instance, so the global ``random``
        state is untouched and the same arguments always give the same result.

    Returns
    -------
    list of lists
        ``min(samples, C(len(my_list), k))`` distinct subsets. Subsets are
        collected in a set of frozensets, so neither the order of the subsets
        nor the order of elements inside a subset is sorted.
    """
    rng = random.Random(seed)

    # Never wait for more distinct subsets than actually exist.
    target = min(samples, scipy.special.comb(len(my_list), k))

    seen = set()
    while len(seen) < target:
        draw = rng.sample(my_list, k=k)
        seen.add(frozenset(draw))  # duplicates are silently rejected

    return [list(subset) for subset in seen]

In [None]:
# Build the estimator whose `estimation` method is passed as the divergence
# function to the Streitberg_* helpers in find_hoi.
# NOTE(review): the name suggests a kNN-based Tsallis divergence estimator from
# the ITE toolbox, with alpha and k as its parameters — confirm against the ITE docs.
cost_name = 'BDTsallis_KnnK'  # dim >= 1
co = co_factory(cost_name, mult=True, alpha=0.5, k=20)  # cost object 

In [None]:
# The seven direction keys used to index each day's data (same list as in the loops above and below).
directions = ['DownLeft', 'Left', 'UpLeft', 'Up', 'UpRight', 'Right', 'DownRight']

# Compare Prep with Action

In [None]:
def find_hoi(data):
    """Estimate 2-, 3- and 4-way interaction values for one day of data.

    Parameters
    ----------
    data : mapping
        ``data[direction]`` must be a 3-D array indexed as ``[:, :, time]``
        with at least 12 entries on the second axis and 60 on the third.
        NOTE(review): presumably (trials, nodes, time bins) — confirm.

    Returns
    -------
    nested defaultdict
        ``hoi_dict[direction][time][tuple(column_indices)]`` holds the value
        returned by ``Streitberg_2`` / ``Streitberg_3`` / ``Streitberg_4``
        (chosen by the number of indices), each called with ``co.estimation``
        as the divergence function.

    Notes
    -----
    Relies on the notebook-level names ``co``, ``sampledata`` and the
    ``Streitberg_*`` functions. The Streitberg helpers draw permutations from
    NumPy's global RNG, so results are only reproducible if that RNG is seeded
    by the caller.
    """
    n_nodes = 12
    n_times = 60
    max_sets = 500  # >= C(12, 4) = 495, so every combination is enumerated

    # sampledata is deterministic for a fixed seed, so the index sets are the
    # same for every direction and time: build them once instead of 7 * 60 times.
    interaction_plan = [
        (interaction, sampledata(range(n_nodes), samples=max_sets, k=order, seed=331))
        for order, interaction in ((2, Streitberg_2), (3, Streitberg_3), (4, Streitberg_4))
    ]

    hoi_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
    for direction in tqdm(['DownLeft', 'Left', 'UpLeft', 'Up', 'UpRight', 'Right', 'DownRight']):
        for time in range(n_times):
            X = data[direction][:, :, time]
            # Orders are processed 2 -> 3 -> 4, exactly as before, so the
            # sequence of random permutations consumed is unchanged.
            for interaction, unique_sets in interaction_plan:
                for sample_ind in unique_sets:
                    info = interaction(X[:, sample_ind], co.estimation)
                    hoi_dict[direction][time][tuple(sample_ind)] = info
    return hoi_dict

In [None]:
# Single-day run: interaction estimates for day index 17 only.
hoi_dict = find_hoi(movement[17])

In [None]:
def sum_hoi(hoi_dict):
    """Sum absolute interaction values per direction and time, split by order.

    Parameters
    ----------
    hoi_dict : mapping
        ``hoi_dict[direction][time][nodes] -> value`` where ``nodes`` is a
        sized key (tuple of indices).

    Returns
    -------
    tuple of three dicts
        ``(sums_2, sums_3, sums_4)``; each maps ``direction -> time -> sum of
        abs(value)`` over the keys of length 2, 3 and 4 respectively. Keys of
        any other length are ignored, and a direction/time only appears in a
        dict if at least one key of that length was seen.
    """
    totals_by_order = {2: {}, 3: {}, 4: {}}

    for direction, per_time in hoi_dict.items():
        for time, per_nodes in per_time.items():
            for nodes, info in per_nodes.items():
                magnitude = abs(info)
                bucket = totals_by_order.get(len(nodes))
                if bucket is None:
                    continue  # not a 2-, 3- or 4-node entry
                per_direction = bucket.setdefault(direction, {})
                per_direction[time] = per_direction.get(time, 0) + magnitude

    return totals_by_order[2], totals_by_order[3], totals_by_order[4]

In [None]:
# Per-direction, per-time totals of |interaction| for the 2-, 3- and 4-node sets of the single-day run.
sums_2, sums_3, sums_4 = sum_hoi(hoi_dict)

In [None]:
def normalise_sum(sums_2_nodes, length):
    """Return a copy of a ``direction -> time -> value`` mapping divided by ``length``.

    Parameters
    ----------
    sums_2_nodes : mapping of mappings
        ``sums_2_nodes[direction][time]`` is a number. Despite the name, the
        function is used for sums of any interaction order.
    length : number
        Divisor applied to every value (e.g. the number of index sets summed).

    Returns
    -------
    dict of dicts
        A new nested dict with every value divided by ``length``. The input is
        left untouched, so calling the function again on the same input gives
        the same answer instead of dividing a second time.

    Raises
    ------
    ZeroDivisionError
        If ``length`` is zero and the mapping is not empty.
    """
    return {
        direction: {time: value / length for time, value in time_dict.items()}
        for direction, time_dict in sums_2_nodes.items()
    }

In [None]:
# Turn the totals into per-set averages. The divisors are the number of index
# sets of each size drawn from 12 columns: C(12,2)=66, C(12,3)=220, C(12,4)=495
# (find_hoi requests up to 500 sets, so every combination is present).
sums_2 = normalise_sum(sums_2, 66)
sums_3 = normalise_sum(sums_3, 220)
sums_4 = normalise_sum(sums_4, 495)

# Classification using all combinations instead of the sum

In [None]:
def create_df(hoi_dict, stage_split=25):
    """Flatten a nested interaction dictionary into one row per (direction, time).

    Parameters
    ----------
    hoi_dict : mapping
        ``hoi_dict[direction][time][combination] -> value``.
    stage_split : int, default 25
        Within each direction, the first ``stage_split`` time entries (in
        iteration order) get ``label`` 0 and the remaining ones get 1.
        NOTE(review): presumably separates preparation from action bins — confirm.

    Returns
    -------
    pandas.DataFrame
        One column per combination key, followed by two label columns:
        ``label`` (0/1 as described above) and ``direction`` (0, 1, 2, ... in
        the iteration order of ``hoi_dict``). For the usual 7 directions x 60
        time bins this reproduces the previously hard-coded labels, but any
        other shape now works as well. Every row has index 0.
    """
    frames = []
    stage_labels = []
    direction_labels = []

    for direction_code, (direction, time_data) in enumerate(hoi_dict.items()):
        for position, (time, comb_data) in enumerate(time_data.items()):
            # Single-row frame: combination keys become the columns.
            temp_df = pd.DataFrame(list(comb_data.values()), index=list(comb_data.keys()),
                                   columns=[f'{direction}_{time}'])
            temp_df.index.names = ['Comb']
            frames.append(temp_df.T.reset_index(drop=True))

            stage_labels.append(0 if position < stage_split else 1)
            direction_labels.append(direction_code)

    # Concatenate once: growing a DataFrame inside the loop is quadratic and
    # relies on the deprecated behaviour of concatenating with an empty frame.
    df = pd.concat(frames, axis=0) if frames else pd.DataFrame()

    df['label'] = stage_labels
    df['direction'] = direction_labels

    return df

# Find all data

In [None]:
# Run the (slow) interaction estimation for the twelve selected days.
# NOTE(review): the next cell replaces hoi_dict_full with a cached copy loaded
# from disk, so this cell only needs to run when that cache must be rebuilt.
hoi_dict_full = {}
for day in [1, 2, 3, 5, 11, 17, 19, 24, 26, 27, 28, 31]:
     hoi_dict_full[day] = find_hoi(movement[day])

In [None]:
# Load the cached 12-day result; this overwrites any hoi_dict_full computed above.
# dill is used for loading (find_hoi returns defaultdicts built from lambdas).
# NOTE(review): like pickle, dill.load can execute arbitrary code — only open trusted files.
with open('data/movement/hoi_12day.pkl', 'rb') as f:
    hoi_dict_full = dill.load(f)

In [None]:
def _score_feature_subset(X_sub, y, n_components, test_size, **clf_kwargs):
    """Hold-out accuracy of logistic regression on a PCA projection of ``X_sub``."""
    X_train, X_test, y_train, y_test = train_test_split(
        X_sub, y, test_size=test_size, random_state=986)

    # Fit the scaler and the PCA on the training rows only, then apply the
    # fitted transforms to the held-out rows, so no test information leaks
    # into the preprocessing.
    scaler = StandardScaler().fit(X_train)
    pca = PCA(n_components=n_components).fit(scaler.transform(X_train))

    clf = LogisticRegression(random_state=10, **clf_kwargs)
    clf.fit(pca.transform(scaler.transform(X_train)), y_train)

    predictions = clf.predict(pca.transform(scaler.transform(X_test)))
    return accuracy_score(y_test, predictions)


def clf_allnodes_label(df, label, n_components):
    """Classify ``df[label]`` from different groups of interaction columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature columns followed by exactly two trailing label columns (as
        produced by ``create_df``). The feature columns are expected in the
        order: 66 two-node columns, 220 three-node columns, then the four-node
        columns (66 = C(12,2), 286 = 66 + C(12,3)).
    label : str
        Name of the target column, e.g. ``'label'`` or ``'direction'``.
    n_components : int
        Number of principal components kept for every feature group.

    Returns
    -------
    dict
        Test accuracy keyed by feature group: ``'2'``, ``'23'``, ``'234'``,
        ``'3'`` and ``'4'``.

    Notes
    -----
    Standardisation and PCA are fitted on the training split only; fitting
    them on all rows before splitting (as done previously) lets the held-out
    rows influence the features and biases the reported accuracy.
    NOTE(review): groups '3' and '4' use a 20 % test split and the liblinear
    solver while the others use 15 % and the default solver; these settings
    are kept as they were — confirm whether the difference is intentional.
    """
    X = df.iloc[:, :-2].to_numpy()
    y = df[label]

    liblinear = {'penalty': 'l2', 'solver': 'liblinear'}
    # (result key, column slice, test fraction, extra LogisticRegression arguments)
    feature_groups = (
        ('2', slice(None, 66), 0.15, {}),
        ('23', slice(None, 286), 0.15, {}),
        ('234', slice(None), 0.15, {}),
        ('3', slice(66, 286), 0.20, liblinear),
        ('4', slice(286, None), 0.20, liblinear),
    )

    acc = {}
    for key, columns, test_size, clf_kwargs in feature_groups:
        acc[key] = _score_feature_subset(X[:, columns], y, n_components, test_size, **clf_kwargs)
    return acc

In [None]:
# For each of the twelve days, build the feature table and record the test
# accuracy of predicting the binary 'label' column from each feature group,
# keeping 66 principal components per group.
acc_stage_var = {}
for day in [1, 2, 3, 5, 11, 17, 19, 24, 26, 27, 28, 31]:
    df1 = create_df(hoi_dict_full[day])
    acc_stage_var[day] = clf_allnodes_label(df1, 'label', 66)