In [65]:
%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np

from nibabel.loadsave import ImageFileError

from nilearn.image import load_img
from nilearn import datasets
from nilearn.maskers import NiftiMasker, NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure

from sklearn.preprocessing import StandardScaler

from nilearn import plotting

from itertools import combinations, product
from tqdm.notebook import tqdm

from scipy.stats import spearmanr, entropy

from scipy.ndimage import gaussian_filter1d, uniform_filter1d, maximum_filter1d, minimum_filter1d

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from catboost import CatBoostClassifier

from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from scipy.linalg import logm

import pandas as pd

import ripser

In [66]:
# Load the precomputed arrays (the file names suggest subjects 1 and 5, NSD vs. rest)
# and keep index 1 along the second axis of T and X, and along the first axis of y.
# NOTE(review): what index 1 selects (subject? session?) is not visible here — confirm.
T_raw = np.load("T_subjects_1,5_nsd_vs_rest.npy")[:, 1]
X_raw = np.load("X_subjects_1,5_nsd_vs_rest.npy")[:, 1]
y = np.load("y_subjects_1,5_nsd_vs_rest.npy")[1]
# Recorded output of this cell: T_raw (956, 226, 116), X_raw (956, 116, 116), y (956,).
T_raw.shape, X_raw.shape, y.shape

((956, 226, 116), (956, 116, 116), (956,))

In [119]:
# Region indices (columns of T_raw / rows and columns of X_raw) for each brain network.

# visual cortex
visual = [*range(42, 54)]

# sensorimotor network, somatomotor network (SMN)
smn = [0, 1, 6, 7, 18, 19, 56, 57, 62, 63, 68, 69]

# dorsal attention network (DAN), dorsal frontoparietal network (D-FPN)
dan = [28, 29, 30, 31, 84, 85]

# ventral attention network (VAN), ventral frontoparietal network (VFN), ventral attention system (VAS)
van = [32, 33, 34, 35, 36, 37, 52, 53, 62, 63, 64, 65]

# frontoparietal network (FPN), central executive network (CEN), lateral frontoparietal network (L-FPN)
fpn = [6, 7, 10, 11, 12, 13, 60, 61, 64, 65]

# limbic system, paleomammalian cortex
limbic = [*range(30, 40), *range(80, 88)]

# default mode network (DMN), default network, default state network, medial frontoparietal network (M-FPN)
dmn = [22, 23, 34, 35, 36, 37, 38, 39, 64, 65, 66, 67]
# NOTE(review): the DMN definition above is immediately discarded; from here on `dmn`
# is the very same list object as `van`, so everything labelled "DMN" below actually
# uses the VAN regions. Confirm whether this override is intentional.
dmn = van

brain_networks = [visual, smn, dan, van, fpn, limbic, dmn]

# Long display names, in the same order as `brain_networks`.
brain_network_names = [
    "visual cortex (12 regions)",
    "sensorimotor network (SMN) (12 regions)",
    "dorsal attention network (DAN) (6 regions)",
    "ventral attention network (VAN) (12 regions)",
    "frontoparietal network (FPN) (10 regions)",
    "limbic system (18 regions)",
    "default mode network (DMN) (12 regions)",
]

# Short display names, in the same order as `brain_networks`.
brain_short_network_names = [
    "visual (12)",
    "SMN (12)",
    "DAN (6)",
    "VAN (12)",
    "FPN (10)",
    "limbic (18)",
    "DMN (12)",
]

# (region indices, long name) pairs.
brain_networks_with_names = list(zip(brain_networks, brain_network_names))
brain_networks_with_names

[([42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],
  'visual cortex (12 regions)'),
 ([0, 1, 6, 7, 18, 19, 56, 57, 62, 63, 68, 69],
  'sensorimotor network (SMN) (12 regions)'),
 ([28, 29, 30, 31, 84, 85], 'dorsal attention network (DAN) (6 regions)'),
 ([32, 33, 34, 35, 36, 37, 52, 53, 62, 63, 64, 65],
  'ventral attention network (VAN) (12 regions)'),
 ([6, 7, 10, 11, 12, 13, 60, 61, 64, 65],
  'frontoparietal network (FPN) (10 regions)'),
 ([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 80, 81, 82, 83, 84, 85, 86, 87],
  'limbic system (18 regions)'),
 ([32, 33, 34, 35, 36, 37, 52, 53, 62, 63, 64, 65],
  'default mode network (DMN) (12 regions)')]

In [120]:
def gaussI(S):
    """Return ``-0.5 * log(det(S))`` computed in a numerically stable way.

    For a correlation matrix ``S`` this is presumably meant as the Gaussian
    multi-information (total correlation) in nats.

    The determinant is never formed explicitly: ``np.linalg.slogdet`` returns
    its sign and log-magnitude separately, so large matrices whose determinant
    would underflow to 0 (or overflow) in floating point still give a finite
    result.

    Parameters
    ----------
    S : array_like, shape (..., n, n)
        Real square matrix, or a stack of such matrices.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        ``-0.5 * log(det(S))`` per matrix. As with the direct formula, the
        value is ``inf`` when the determinant is exactly zero and ``nan`` when
        it is negative.
    """
    sign, logabsdet = np.linalg.slogdet(S)
    # sign == 0 -> logabsdet == -inf -> result +inf; sign < 0 -> log undefined -> nan.
    value = np.where(sign >= 0, -0.5 * logabsdet, np.nan)
    # Indexing with () turns a 0-d array into a scalar and leaves stacks untouched.
    return value[()]

In [121]:
def apply_brain_network(T_raw, X_raw, brain_network):
    """Restrict the time series and the square matrices to one brain network.

    Parameters
    ----------
    T_raw : ndarray
        3-D array; its last axis is indexed by region.
    X_raw : ndarray
        3-D array; its last two axes are both indexed by region.
    brain_network : sequence of int
        Region indices to keep.

    Returns
    -------
    (T_bn, X_bn)
        ``T_raw`` with only the selected entries of its last axis, and
        ``X_raw`` with only the selected rows and columns of every matrix.
    """
    # Pick the rows first, then the columns, so each matrix keeps the full
    # brain_network x brain_network sub-block.
    rows_selected = X_raw[:, brain_network]
    X_bn = rows_selected[:, :, brain_network]
    T_bn = T_raw[:, :, brain_network]
    return T_bn, X_bn

In [122]:
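# Disabled dionysus-based draft (kept for reference only; the ripser-based function in the next cell is what is actually used)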

# def get_simplices_from_corrmat(T_bn, X_bn, k=2, modif=abs, check=lambda w: True):
#     # to prune lightweight edges use check=lambda w: w > threshold
#     # to prune negative correlations use modif=id, check=lambda w: w > 0
#     n = X_bn.shape[0]
#     simplices = {}
#     for i in range(n):
#         simplices[(i,)] = 0
#         for j in range(i):
#             w = modif(X_bn[j, i])
#             if check(w):
#                 simplices[(j, i)] = w
#     for dim in range(2, k + 1):
#         ...
    
    
    
# def get_simplices_functor_with_params(func, *args, **kwargs):
#     return (lambda T_bn, X_bn: func(*args, **kwargs))

# def get_persistence_diagram(simplices, reverse=True):
#     filtration = dionysus.Filtration()
#     for vertices, time in simplices:
#         filtration.append(dionysus.Simplex(vertices, time))
#     filtration.sort(reverse=reverse)

#     diagram = dionysus.init_diagrams(dionysus.homology_persistence(filtration), filtration)
#     return diagram
    
# def get_persistence_diagram_functor_with_params(func, *args, **kwargs):
#     return (lambda simplices: func(*args, **kwargs))

In [123]:
def get_persistence_diagram_from_corrmat(T_bn, X_bn, dim=2, threshold=-2):
    """Compute persistence diagrams from a correlation matrix with ripser.

    The matrix is converted to a dissimilarity ``1 - |X_bn|`` and handed to
    ``ripser.ripser`` as a precomputed distance matrix.

    Parameters
    ----------
    T_bn : any
        Unused; kept so the signature matches the (T, X) calling convention.
    X_bn : ndarray
        Square correlation matrix of one sample.
    dim : int
        Passed to ripser as ``maxdim``.
    threshold : float
        Correlation cut-off; ripser receives ``thresh = 1 - threshold``.
        With the default of -2 the cut-off is 3, which is above any value
        ``1 - |corr|`` can take when ``|corr| <= 1``.

    Returns
    -------
    list of ndarray
        The ``'dgms'`` entry of ripser's result.
    """
    dissimilarity = 1 - np.abs(X_bn)
    result = ripser.ripser(
        dissimilarity,
        maxdim=dim,
        thresh=1 - threshold,
        distance_matrix=True,
    )
    return result['dgms']

In [124]:
# Restrict the data to the regions listed in `dmn` and compute the persistence
# diagrams of the first sample as a quick sanity check.
# NOTE(review): `dmn` is rebound to `van` where the networks are defined, so the
# `*_dmn` variables currently hold the VAN regions despite their names.
T_dmn, X_dmn = apply_brain_network(T_raw, X_raw, dmn)
persistence_diagram = get_persistence_diagram_from_corrmat(T_dmn[0], X_dmn[0])
persistence_diagram

[array([[0.        , 0.12529169],
        [0.        , 0.22025172],
        [0.        , 0.26729095],
        [0.        , 0.28290591],
        [0.        , 0.35668072],
        [0.        , 0.3819766 ],
        [0.        , 0.40466523],
        [0.        , 0.44850785],
        [0.        , 0.45469671],
        [0.        , 0.4919467 ],
        [0.        , 0.51505613],
        [0.        ,        inf]]),
 array([[0.40878159, 0.43009564]]),
 array([], shape=(0, 2), dtype=float64)]

In [125]:
# Number of points in the first three diagrams (indices 0, 1, 2) of every sample.
l0, l1, l2 = [], [], []
for i in range(len(T_dmn)):
    persistence_diagram = get_persistence_diagram_from_corrmat(T_dmn[i], X_dmn[i])
    n_points = [len(persistence_diagram[k]) for k in (0, 1, 2)]
    l0.append(n_points[0])
    l1.append(n_points[1])
    l2.append(n_points[2])

In [126]:
from collections import Counter

# One line per diagram index: sorted (number of points, number of samples) pairs.
for lengths in (l0, l1, l2):
    print(sorted(Counter(lengths).items()))

[(12, 956)]
[(0, 139), (1, 340), (2, 298), (3, 127), (4, 47), (5, 5)]
[(0, 829), (1, 115), (2, 10), (3, 2)]


In [127]:
def conv_pd(diagrams):
    """Flatten a list of persistence diagrams into one ``(n_points, 3)`` array.

    Each output row is ``(col0, col1, k)``, where the first two columns are
    copied from the k-th diagram and ``k`` is that diagram's position in
    ``diagrams``. Rows containing an infinite value are dropped.

    Parameters
    ----------
    diagrams : sequence of ndarray, each of shape (m_k, 2)

    Returns
    -------
    ndarray of shape (n_points, 3); ``(0, 3)`` when nothing remains.
    """
    # Start from an empty (0, 3) block so an empty input still yields a (0, 3) array.
    blocks = [np.zeros((0, 3))]

    for hom_dim, points in enumerate(diagrams):
        has_inf = np.isinf(points).any(axis=1)
        finite_points = points[~has_inf]
        dim_column = hom_dim * np.ones((finite_points.shape[0], 1))
        blocks.append(np.concatenate((finite_points, dim_column), axis=1))

    return np.concatenate(blocks)

# Build one (n_points, 3) point-cloud per sample from its persistence diagrams.
X = []

for x_pc in tqdm(X_dmn):
    # The first argument (time series) is not used by the diagram function, so pass
    # None explicitly instead of IPython's last-output variable `_`, which is hidden
    # kernel state and undefined outside an interactive session.
    diagram = conv_pd(get_persistence_diagram_from_corrmat(None, x_pc))
    X.append(diagram)

100%|███████████████████████████████████████████████████████████████████████████████| 956/956 [00:00<00:00, 2063.69it/s]


In [128]:
from __future__ import print_function

%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.metrics import accuracy_score

import warnings
warnings.filterwarnings('ignore')

from ripser import lower_star_img
from ripser import Rips

import persim

from scipy.ndimage import gaussian_filter

from sklearn.datasets import make_circles
from sklearn.manifold import MDS

import pickle
from tqdm import tqdm

import torch
from torch.nn import Linear
from torch.nn.functional import relu

from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

In [129]:
class Orbit2kDataset(Dataset):
    """Map-style dataset pairing each sample ``X[idx]`` with its label ``y[idx]``.

    Parameters
    ----------
    X : indexable
        Samples; in this notebook a list of per-sample point arrays.
    y : indexable with ``len``
        Labels; its length defines the dataset size.
    transform : callable, optional
        If given, applied to ``X[idx]`` every time an item is fetched.
        (Previously this argument was accepted but silently ignored.)
    """

    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sample = self.X[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.y[idx]
    
def collate_fn(data):
    """Collate ``(points, label)`` pairs into a zero-padded batch.

    Parameters
    ----------
    data : list of (ndarray of shape (m_i, n_features), label)

    Returns
    -------
    (inputs, labels)
        ``inputs`` is a float tensor of shape ``(batch, max_i m_i, n_features)``
        in which samples with fewer points are padded with all-zero rows at the
        end; ``labels`` is an int64 tensor of shape ``(batch,)``.
    """
    n_batch = len(data)
    # The feature width is taken from the first sample.
    n_features_pd = data[0][0].shape[1]
    n_points_pd = max(len(points) for points, _ in data)

    inputs_pd = np.zeros((n_batch, n_points_pd, n_features_pd), dtype=float)
    labels = np.zeros(n_batch)

    for row, (points, label) in enumerate(data):
        inputs_pd[row, :len(points)] = points
        labels[row] = label

    return torch.Tensor(inputs_pd), torch.Tensor(labels).long()

In [130]:
class DeepSets(torch.nn.Module):
    """Set classifier: an ``Encoder`` followed by an ``MLP`` decoder.

    Parameters
    ----------
    n_in : int
        Number of features per set element.
    n_hidden_enc, n_out_enc : int
        Hidden and output widths of the encoder.
    n_hidden_dec, n_out_dec : int
        Hidden and output widths of the decoder.
    """

    def __init__(self, n_in, n_hidden_enc, n_out_enc, n_hidden_dec=16, n_out_dec=2):
        super().__init__()
        self.encoder = Encoder(n_in, n_hidden_enc, n_out_enc)
        self.decoder = MLP(n_out_enc, n_hidden_dec, n_out_dec)

    def forward(self, X):
        """Encode the batch ``X`` and return the decoder output for it."""
        return self.decoder(self.encoder(X))
    
class MLP(torch.nn.Module):
    """Two-layer perceptron: ``Linear -> ReLU -> Linear``.

    Parameters
    ----------
    n_in, n_hidden, n_out : int
        Input, hidden and output widths.
    """

    def __init__(self, n_in, n_hidden, n_out):
        super().__init__()
        self.linear1 = Linear(n_in, n_hidden)
        self.linear2 = Linear(n_hidden, n_out)
        # NOTE(review): this BatchNorm layer is registered (its parameters appear in
        # parameters()/state_dict()) but forward() never applies it.
        self.bn = torch.nn.BatchNorm1d(n_out)

    def forward(self, X):
        """Apply both linear layers along the last axis with a ReLU in between."""
        hidden = relu(self.linear1(X))
        return self.linear2(hidden)
    
class Encoder(torch.nn.Module):
    """Embed every set element with an ``MLP`` and average over the set axis.

    Parameters
    ----------
    n_in, n_hidden, n_out : int
        Widths forwarded to the underlying ``MLP``.
    """

    def __init__(self, n_in, n_hidden, n_out):
        super().__init__()
        self.mlp = MLP(n_in, n_hidden, n_out)

    def forward(self, X):
        """Return the mean over ``dim=1`` of the per-element MLP outputs."""
        per_point = self.mlp(X)
        # Aggregation over the set axis.
        # NOTE(review): every row along dim 1 is averaged, so the all-zero padding rows
        # that collate_fn appends to shorter samples contribute to the mean as well.
        return per_point.mean(dim=1)

In [131]:
%%time
n_repeats = 3
n_epochs = 100
batch_size = 32
lr = 0.0005

n_train, n_test = 1600, 400

history = np.zeros((n_repeats, n_epochs, 3))
criterion = CrossEntropyLoss()

dataset = Orbit2kDataset(X, y)

ret = [0] * n_repeats
for repeat_idx in range(n_repeats):
    
    # data init
    dataset_train, dataset_test = random_split(dataset, [700, 256])
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    dataloader_test =  DataLoader(dataset_test, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    # model init
    model = DeepSets(n_in=3, n_hidden_enc=16, n_out_enc=8)
    optimizer = Adam(model.parameters(), lr=lr)
    
    #print("{:3} {:6} {:6} {:6}".format(repeat_idx, "Loss", "Train", "Test"))
    
    mx = 0
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for batch in dataloader_train:
            loss_batch = criterion(model(batch[0]), batch[1])
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for batch in dataloader_train:
            y_hat = model(batch[0]).argmax(dim=1)
            correct += int((y_hat == batch[1]).sum())
        accuracy_train = correct / len(dataloader_train.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train

        correct = 0
        for batch in dataloader_test:
            y_hat = model(batch[0]).argmax(dim=1)
            correct += int((y_hat == batch[1]).sum())
        accuracy_test = correct / len(dataloader_test.dataset)
        history[repeat_idx,epoch_idx,2] = accuracy_test
        
        mx = max(mx, accuracy_test)
        
        print("{:3} {:.4f} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train, accuracy_test))
    print(mx)
    ret[repeat_idx] = mx
    print("\r")
ret = np.array(ret)
print(f"{ret.mean():.5f} ± {ret.std():.5f}")

  0 0.6936 0.5157 0.4648
  1 0.6932 0.5157 0.4648
  2 0.6929 0.5157 0.4648
  3 0.6931 0.5157 0.4648
  4 0.6927 0.5157 0.4648
  5 0.6927 0.5157 0.4648
  6 0.6927 0.5157 0.4648
  7 0.6926 0.5157 0.4648
  8 0.6925 0.5157 0.4648
  9 0.6924 0.5157 0.4648
 10 0.6920 0.5157 0.4648
 11 0.6919 0.5157 0.4648
 12 0.6917 0.5157 0.4648
 13 0.6915 0.5157 0.4648
 14 0.6913 0.5157 0.4648
 15 0.6910 0.5157 0.4648
 16 0.6909 0.5157 0.4648
 17 0.6907 0.5157 0.4648
 18 0.6901 0.5157 0.4648
 19 0.6901 0.5157 0.4648
 20 0.6894 0.5157 0.4648
 21 0.6891 0.5157 0.4648
 22 0.6885 0.5229 0.4648
 23 0.6878 0.5357 0.4922
 24 0.6872 0.5386 0.4883
 25 0.6871 0.5514 0.5000
 26 0.6861 0.5543 0.5391
 27 0.6851 0.5543 0.5156
 28 0.6851 0.5457 0.5391
 29 0.6838 0.5514 0.5547
 30 0.6841 0.5643 0.5742
 31 0.6831 0.5686 0.5938
 32 0.6829 0.5614 0.5703
 33 0.6826 0.5843 0.6055
 34 0.6819 0.5829 0.5977
 35 0.6823 0.5743 0.6055
 36 0.6813 0.5843 0.5938
 37 0.6810 0.5843 0.6055
 38 0.6804 0.5757 0.5898
 39 0.6805 0.5771 0.6211


In [132]:
def full_test(T_raw, X_raw, y, brain_networks_with_names, get_simplices):
    for brain_network, brain_network_name in brain_networks_with_names:
        T_bn, X_bn = apply_brain_network(T_raw, X_raw, brain_network)
        persistence_diagram = get_persistence_diagram(T_bn, X_bn)
        