In [1]:
from typing import Tuple

import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.formula.api import ols
import pandas as pd
import pickle

from param import *
from dataloader import UniformSegmentDataset, BaseDataset
from util import segment, get_place_cell


output_dir = ParamDir().OUTPUT_ROOT/ "data_exploration/"
data_list = ParamDir().data_list

## place cell ratio

In [2]:
def get_place_cell_ratio(data_name):
    # print(data_name)
    with open(output_dir/data_name/"MI_all.pickle","rb") as f:
        results_all = pickle.load(f)

    pc_beh_id, pc_event_id = get_place_cell(results_all, 0.001)

    return len(pc_beh_id)/len(results_all['original MI'])

In [3]:
anova_data = []

KO_names = [] # mice that passed the test
KO_data_list = [data_dir for data_dir in ParamDir().data_list if "KO" in str(data_dir).split('/')[-1]]
for data_dir in KO_data_list:
    data_name = str(data_dir).split('/')[-1]
    ratio = get_place_cell_ratio(data_name)
    if data_name not in KO_names:
        anova_data.append([ratio, "KO", "didnot pass"])
    else:
        anova_data.append([ratio, "KO", "pass"])

WT_names = ["M45_042718_OF", "M46_042718_OF", "081117 OF B6J M27-n1"] # mice that didn't pass the test
WT_data_list = [data_dir for data_dir in ParamDir().data_list if "KO" not in str(data_dir).split('/')[-1]]
for data_dir in WT_data_list:
    data_name = str(data_dir).split('/')[-1]
    ratio = get_place_cell_ratio(data_name)
    if data_name in WT_names:
        anova_data.append([ratio, "WT", "didnot pass"])
    else:
        anova_data.append([ratio, "WT", "pass"])

df = pd.DataFrame(anova_data, columns=["place_cell_ratio", "mouse_type", "permutation_test"])

model = ols('place_cell_ratio ~ C(mouse_type) + C(permutation_test) + C(mouse_type):C(permutation_test)', data=df).fit()
sm.stats.anova_lm(model, typ=2)

Unnamed: 0,sum_sq,df,F,PR(>F)
C(mouse_type),0.0624,1.0,2.374496,0.15772
C(permutation_test),0.001642,1.0,0.0625,0.8082
C(mouse_type):C(permutation_test),0.001218,1.0,0.046359,0.834324
Residual,0.236512,9.0,,


## Median time of staying in one place

In [4]:
anova_data = []

KO_data_list = [data_dir for data_dir in ParamDir().data_list if "KO" in  str(data_dir).split('/')[-1]]

segment_len_all = []
for i, data_dir in enumerate(KO_data_list):
    segment_len = []
    data_name = str(data_dir).split('/')[-1]
    dataset = UniformSegmentDataset(data_dir, ParamData().mobility, ParamData().shuffle, ParamData().random_state)
    (X_train, y_train), (X_test, y_test) = dataset.load_all_data(ParamData().window_size, ParamData().K, ParamData().train_ratio)

    segment_ind = segment(dataset.y_train_base)
    segment_len.append(round((segment_ind[0]+1)/3, 2))
    for i in range(1, len(segment_ind)):
        segment_len.append(round((segment_ind[i]-segment_ind[i-1])/3, 2))
    segment_len_all.append([segment_len, data_name])
    

for item in segment_len_all:
    if item[1] in []:
        anova_data.append([np.median(item[0]), "KO", "pass"])
    else:
        anova_data.append([np.median(item[0]), "KO", "didnot pass"])

WT_names = ["M45_042718_OF", "M46_042718_OF", "081117 OF B6J M27-n1"] # mice that didn't pass the test
WT_data_list = [data_dir for data_dir in ParamDir().data_list if "KO" not in  str(data_dir).split('/')[-1]]

segment_len_all = []
for i, data_dir in enumerate(WT_data_list):
    segment_len = []
    data_name = str(data_dir).split('/')[-1]
    dataset = UniformSegmentDataset(data_dir, ParamData().mobility, ParamData().shuffle, ParamData().random_state)
    (X_train, y_train), (X_test, y_test) = dataset.load_all_data(ParamData().window_size, ParamData().K, ParamData().train_ratio)

    segment_ind = segment(dataset.y_train_base)
    segment_len.append(round((segment_ind[0]+1)/3, 2))
    for i in range(1, len(segment_ind)):
        segment_len.append(round((segment_ind[i]-segment_ind[i-1])/3, 2))
    segment_len_all.append([segment_len, data_name])
    

for item in segment_len_all:
    if item[1] not in WT_names:
        anova_data.append([np.median(item[0]), "WT", "didnot pass"])
    else:
        anova_data.append([np.median(item[0]), "WT", "pass"])

df = pd.DataFrame(anova_data, columns=["median_time_in_one_position", "mouse_type", "permutation_test"])

model = ols('median_time_in_one_position ~ C(mouse_type) + C(permutation_test) + C(mouse_type):C(permutation_test)', data=df).fit()
sm.stats.anova_lm(model, typ=2)

Unnamed: 0,sum_sq,df,F,PR(>F)
C(mouse_type),29.68295,1.0,32.87142,0.000282
C(permutation_test),-6.380083e-14,1.0,-7.065416e-14,1.0
C(mouse_type):C(permutation_test),1.449657,1.0,1.605376,0.236946
Residual,8.127015,9.0,,
