In [12]:
from functools import partial

import graspy
from graspy.utils import symmetrize
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import truncnorm, ks_2samp
from scipy.optimize import fmin_slsqp
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder

from twins import load_dataset
from src import generate_truncnorm_sbms, compute_pr_at_k
%matplotlib inline

In [2]:
# Load the fMRI graphs for the Desikan (2x2x2) parcellation. Only the first element
# returned by load_dataset is kept; later cells use it as a mapping from subject id to graph.
graphs = load_dataset(modality='fmri', parcellation='desikan_res-2x2x2', preprocess=None, ptr=None)[0]

# Subject-level table; the `Subject` and `Gender` columns are read in a later cell.
# NOTE(review): hard-coded relative path -- only resolves when the notebook is run from its own directory.
df = pd.read_csv('../../../../twins/data/raw/unrestricted_jaewonc78_1_20_2019_23_7_58.csv')

In [3]:
# Look up the recorded gender of every subject, in the iteration order of `graphs`.
gender = [
    df.loc[df.Subject == int(sub), 'Gender'].values[0]
    for sub in graphs.keys()
]

# Encode the gender values as integer class labels.
le = LabelEncoder()
labels = le.fit_transform(gender)
# 0 is female, 1 is male

In [4]:
# Class balance check: the distinct encoded labels and the number of subjects with each.
np.unique(labels, return_counts=True)

(array([0, 1]), array([407, 330]))

In [5]:
# Partition the graphs by encoded label; `labels` is aligned with the
# iteration order of `graphs`, so position i in one matches position i in the other.
male_graphs = np.array(
    [graph for pos, graph in enumerate(graphs.values()) if labels[pos] == 1]
)
female_graphs = np.array(
    [graph for pos, graph in enumerate(graphs.values()) if labels[pos] != 1]
)

# Element-wise average graph of each group (mean over the subject axis).
male_graphs_mean = male_graphs.mean(axis=0)
female_graphs_mean = female_graphs.mean(axis=0)

## Estimate mean and variance of truncnorm for each edge

In [6]:
def estimate_params(data):
    """Fit a truncated normal with support [0, 1] to ``data`` by constrained MLE.

    The four ``truncnorm`` parameters ``(a, b, loc, scale)`` are optimised with
    SLSQP to minimise the negative log-likelihood, subject to the equality
    constraints ``a*scale + loc == 0`` and ``b*scale + loc == 1`` so that the
    truncation bounds stay fixed while ``loc`` and ``scale`` move.

    Parameters
    ----------
    data : ndarray
        1-d sample of observations; its mean and standard deviation seed the
        optimiser.

    Returns
    -------
    mean, std : float
        The fitted ``loc`` and ``scale`` entries of the optimiser's solution.
    """
    lower, upper = 0, 1

    def neg_log_likelihood(params, samples, lo, hi):
        # lo/hi are unused here; fmin_slsqp passes the same `args` to every callback.
        return truncnorm.nnlf(params, samples)

    def support_residual(params, samples, lo, hi):
        # Both entries are zero exactly when the truncation bounds equal (lo, hi).
        a_std, b_std, loc, scale = params
        return np.array([a_std * scale + loc - lo, b_std * scale + loc - hi])

    # Start from the sample moments, with standardised bounds consistent with them.
    loc0 = data.mean()
    scale0 = data.std()
    start = [(lower - loc0) / scale0, (upper - loc0) / scale0, loc0, scale0]

    solution = fmin_slsqp(
        neg_log_likelihood,
        start,
        f_eqcons=support_residual,
        args=(data, lower, upper),
        iprint=False,
        iter=1000,
    )
    _, _, mean, std = solution

    return mean, std

In [7]:
verts = male_graphs.shape[-1]


def _fit_upper_triangle(population):
    """Run estimate_params on every edge (i, j) with i < j, in parallel over all cores.

    Results are returned in row-major order of the strict upper triangle.
    """
    edge_jobs = (
        delayed(estimate_params)(population[:, row, col])
        for row in range(verts)
        for col in range(row + 1, verts)
    )
    return Parallel(-1, verbose=1)(edge_jobs)


# One (mean, std) pair per edge: `res` for the male group, `res2` for the female group.
res = _fit_upper_triangle(male_graphs)

res2 = _fit_upper_triangle(female_graphs)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 64 concurrent workers.
[Parallel(n_jobs=-1)]: Done  72 tasks      | elapsed:    1.8s
[Parallel(n_jobs=-1)]: Done 322 tasks      | elapsed:    2.7s
[Parallel(n_jobs=-1)]: Done 672 tasks      | elapsed:    4.4s
[Parallel(n_jobs=-1)]: Done 1122 tasks      | elapsed:    6.4s
[Parallel(n_jobs=-1)]: Done 1672 tasks      | elapsed:    8.4s
[Parallel(n_jobs=-1)]: Done 2415 out of 2415 | elapsed:   16.4s finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 64 concurrent workers.
[Parallel(n_jobs=-1)]: Done  72 tasks      | elapsed:    0.8s
[Parallel(n_jobs=-1)]: Done 522 tasks      | elapsed:    4.1s
[Parallel(n_jobs=-1)]: Done 1045 tasks      | elapsed:    7.2s
[Parallel(n_jobs=-1)]: Done 1503 tasks      | elapsed:    9.6s
[Parallel(n_jobs=-1)]: Done 2056 tasks      | elapsed:   12.8s
[Parallel(n_jobs=-1)]: Done 2415 out of 2415 | elapsed:   18.4s finished


## Compute empirical trustworthiness using the estimated parameters as inputs

In [10]:
def compute_statistic(test, pop1, pop2):
    """Apply a two-sample test to every edge of two graph populations.

    Parameters
    ----------
    test : callable
        The two-sample test; dispatch is on ``test.__name__``:

        * ``"ttest_ind"`` -- called once on the full arrays with ``axis=0``.
        * ``"multiscale_graphcorr"`` -- called per edge with
          ``is_twosamp=True, reps=1``; must return three values.
        * anything else -- called per edge as ``test(x, y)``; must return
          ``(statistic, pvalue)``.
    pop1, pop2 : ndarray
        Populations indexed as ``pop[:, i, j]`` (first axis = samples).

    Returns
    -------
    test_statistics, pvals : ndarray
        Per-edge statistics and p-values. In the per-edge branch only the
        strict upper triangle is computed and then mirrored with
        ``symmetrize``; the diagonal keeps its initial value of 0.
    """
    if test.__name__ == "ttest_ind":
        # Call the callable that was passed in. The previous code referenced the
        # bare name `ttest_ind`, which is not imported and raised NameError.
        test_statistics, pvals = test(pop1, pop2, axis=0)
        # Replace NaNs (e.g. from zero-variance edges) with 0.
        # NOTE(review): a NaN p-value becomes 0, i.e. maximally significant --
        # confirm this is the intended ranking for degenerate edges.
        test_statistics = np.nan_to_num(test_statistics)
        pvals = np.nan_to_num(pvals)
    else:  # for other tests, do by edge
        n = pop1.shape[-1]
        test_statistics = np.zeros((n, n))
        pvals = np.zeros((n, n))

        for i in range(n):
            for j in range(i + 1, n):
                x_ij = pop1[:, i, j]
                y_ij = pop2[:, i, j]

                if test.__name__ == "multiscale_graphcorr":
                    tmp, pval, _ = test(x_ij, y_ij, is_twosamp=True, reps=1)
                else:
                    tmp, pval = test(x_ij, y_ij)

                test_statistics[i, j] = tmp
                pvals[i, j] = pval

        # Mirror the upper triangle onto the lower one.
        test_statistics = symmetrize(test_statistics, method="triu")
        pvals = symmetrize(pvals, method="triu")

    return test_statistics, pvals

def run_experiment(mean_1, var_1, mean_2, var_2, 
                   samp_1=330, samp_2=407,
                   test=ks_2samp, 
                   block_1=5, block_2=15,
                   a=0, b=1, reps=100):
    """Estimate precision/recall at k=10 for one pair of truncnorm parameters.

    For each of ``reps`` replicates, two populations are simulated with
    ``generate_truncnorm_sbms`` (``samp_1`` graphs for population 1, ``samp_2``
    for population 2), every edge is tested with ``compute_statistic``, and the
    resulting p-values are scored with ``compute_pr_at_k``.

    Parameters
    ----------
    mean_1, var_1, mean_2, var_2 : float or None
        Truncated-normal parameters for the two populations. If either mean is
        None the simulation is skipped and zeros are reported.
    samp_1, samp_2 : int
        Number of graphs drawn for population 1 and population 2.
    test : callable
        Two-sample test forwarded to ``compute_statistic``.
    block_1, block_2, a, b
        Forwarded unchanged to ``generate_truncnorm_sbms``.
    reps : int
        Number of Monte Carlo replicates to average over.

    Returns
    -------
    list
        ``[(samp_1, samp_2), mean_1, mean_2, var_1, var_2, precision, recall]``
        where precision and recall are means over the replicates.
    """
    # The original row started with an undefined global `m` (NameError on a
    # fresh kernel). The sample sizes actually used are reported in that slot.
    sample_sizes = (samp_1, samp_2)

    if mean_1 is None or mean_2 is None:
        return [sample_sizes, mean_1, mean_2, var_1, var_2, 0, 0]

    precisions, recalls = np.zeros((2, reps))
    for rep in range(reps):
        # Draw fresh populations for every replicate. Sampling once outside the
        # loop made every replicate re-test the same data.
        # NOTE(review): assumes the third value returned by generate_truncnorm_sbms
        # is the ground-truth edge labels -- confirm against src. Previously
        # `true_labels` was an undefined global.
        pop1, _, true_labels = generate_truncnorm_sbms(
            samp_1, block_1, block_2, mean_1, mean_2, var_1, var_2, a=a, b=b
        )
        _, pop2, _ = generate_truncnorm_sbms(
            samp_2, block_1, block_2, mean_1, mean_2, var_1, var_2, a=a, b=b
        )

        _, pvalues = compute_statistic(test, pop1, pop2)
        precision, recall = compute_pr_at_k(
            k=[10], true_labels=true_labels, pvalues=pvalues
        )
        # k is a one-element list, so the helper presumably returns one-element
        # sequences; squeeze accepts both that and a plain scalar.
        precisions[rep] = np.squeeze(precision)
        recalls[rep] = np.squeeze(recall)

    # .mean() yields a scalar; star-unpacking it (as before) raises TypeError.
    return [sample_sizes, mean_1, mean_2, var_1, var_2, precisions.mean(), recalls.mean()]

In [None]:
# One row per edge: the two fitted values from `res` (male fits) followed by
# the two from `res2` (female fits).
res_arr = np.hstack([res, res2])
args = [
    {
        "mean_1": loc_1,
        "var_1": scale_1**2,  # fitted values are standard deviations; square to get variances
        "mean_2": loc_2,
        "var_2": scale_2**2,
    }
    for loc_1, scale_1, loc_2, scale_2 in res_arr
]
# NOTE(review): this rebinds `res`, which the first line of this cell reads, so
# re-running the cell without re-running the fitting cell stacks the wrong data.
res = Parallel(n_jobs=-1, verbose=5)(
    delayed(run_experiment)(**kwargs) for kwargs in args
)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 64 concurrent workers.
