In [3]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import xgi
from IPython.display import display
from matplotlib.pyplot import cm
xgi.__version__
import networkx as nx
import numpy as np
import pickle
from scipy.io import loadmat
import h5py
from scipy.ndimage import gaussian_filter
from scipy.signal import correlate
from tqdm import tqdm
import itertools

In [4]:

#this analysis produces plots with various windows (Dec 29: 5,10,15,20 done but need up to 50 ms)

ALL_FILES = [
    "2950_spike_mat_or_rand",
    "2953_spike_mat_or_rand",
    "2957_spike_mat_or_rand",
    "5116_spike_mat_or_rand",
    # Add more files if necessary
]

def run_analysis(file_name, use_rand_data, lag_window=None, X=0.5):
    try:
        f = loadmat(f"/Users/pispisenka/Downloads/organoid data/{file_name}.mat")
        if use_rand_data:
            ar = np.array(f['t_spk_mat_rand']).T.astype(np.float32)
        else:
            ar = np.array(f['t_spk_mat']).T.astype(np.float32)
            ar = ar[:, :3 * 60 * 1000]  # 3 minutes only
    except:
        f = h5py.File(f"/Users/pispisenka/Downloads/organoid data/{file_name}.mat")
        if use_rand_data:
            ar = np.array(f['t_spk_mat_rand']).astype(np.float32)
        else:
            ar = np.array(f['t_spk_mat']).astype(np.float32)
            ar = ar[:, :3 * 60 * 1000]  # 3 minutes only

    br = gaussian_filter(ar, sigma=100.0, axes=-1)
    if lag_window is None:
        br = br[:, ::10]
    N = br.shape[1]
    var = np.einsum('it,it->i', br, br)

    if lag_window is None:
        corr = np.zeros((br.shape[0], br.shape[0]))
        corr_idx = []
        corr_idx = np.zeros((br.shape[0], br.shape[0]), dtype=np.int32)
        for i in range(br.shape[0]):
            for j in range(br.shape[0]):
                c = correlate(br[i], br[j])
                corr[i, j] = np.max(c) / np.sqrt(var[i] * var[j])
                corr_idx[i, j] = np.argmax(c) - N + 1
    else:
        shifts = np.stack([np.roll(br, shift, 1) for shift in range(-lag_window, lag_window + 1)])
        corr = np.einsum('it,sjt->ijs', br, shifts)
        corr_idx = np.argmax(corr, -1) - lag_window
        corr = np.max(corr, -1) / np.maximum(np.sqrt(var * var[None]), 1e-8)
    pairs = [[] for _ in range(br.shape[0])]
    for i in range(br.shape[0]):
        for j in range(br.shape[0]):
            pairs[i].append(br[i] * np.roll(br[j], corr_idx[i, j]))

    varp = np.einsum('ijt,ijt->ij', pairs, pairs)

    # # Three-way correlation calculation
    # corr3 = np.zeros((br.shape[0], br.shape[0], br.shape[0]))
    # if lag_window is None:
    #     for i, j in list(itertools.product(range(br.shape[0]), range(br.shape[0]))):
    #         for k in range(br.shape[0]):
    #             corr3[i, j, k] = np.max(correlate(pairs[i][j], br[k])) / np.maximum(np.sqrt(varp[i, j] * var[k]), 1e-8)
    # else:
    #     shifts = np.stack([np.roll(br, shift, 1) for shift in range(-lag_window, lag_window + 1)])
    #     corr3 = np.einsum('ijt,skt->ijks', pairs, shifts)
    #     corr3 = np.max(corr3, -1) / np.maximum(np.sqrt(varp[:, :, None] * var[None, None]), 1e-8)


  # Processing for hist2d graph
    mask = np.sum(ar, -1) > 10
    N = mask.size  # Correct size for the mask array
    hist2d_data = []

    for i in range(N):
        if not mask[i]:
            continue
        for j in range(N):
            if not mask[j]:
                continue
            for k in range(N):
                if not mask[k]:
                    continue
                if i < j and i != k and j != k:
                    hist2d_data.append((max(corr[i, j], corr[i, k], corr[j, k]), corr3[i, j, k]))

    hist2d_data = np.array(hist2d_data)
    corr_idx = np.array(corr_idx)

    # print("corr_idx v 2", corr_idx)
    # Network construction and clustering coefficients calculation
    THRESHOLD = np.quantile(corr[np.triu_indices_from(corr, 1)], X)
    a = np.where(corr > THRESHOLD, corr - np.eye(corr.shape[0]), 0.0)
    C = np.einsum('ij,ik,jk->i', a, a, a) / np.maximum(np.einsum('ij,ik->i', a, a), 1.0)
    
    # Constructing the graph
    G = nx.Graph()
    G.add_nodes_from(range(br.shape[0]))
    for i in range(br.shape[0]):
        for j in range(br.shape[0]):
            if i != j and corr[i, j] > THRESHOLD:
                G.add_edge(i, j)

    # Constructing the random graph
    a_rnd = np.random.permutation(a.flatten()).reshape(a.shape)
    G_rnd = nx.Graph()
    G_rnd.add_nodes_from(range(br.shape[0]))
    for i in range(br.shape[0]):
        for j in range(i + 1, br.shape[0]):
            if a_rnd[i, j] > 0:
                G_rnd.add_edge(i, j)

    # Additional network analysis can be added here

    return {
        "C": C,
        "corr": corr,
        # "corr3": corr3,
        "G": G,
        "G_rnd": G_rnd,
        "hist2d_data": hist2d_data, 
        "corr_idx_data": corr_idx
        # Add more return values if necessary
    }

from multiprocessing import Pool

def process(args):
    fn, use_rand, lag_window = args
    res = run_analysis(fn, use_rand, lag_window)
    with open(f"{fn + '_rnd' if use_rand else fn}_lag_window_{lag_window}.pkl", "wb") as f:
        pickle.dump(res, f)

if __name__ == "__main__":
    tasks = []
    for fn in ALL_FILES:
        for use_rand in [False]:
            for lag_window in [None, 10, 20]:
                tasks.append((fn, use_rand, lag_window))
    # process(tasks[0])
    with Pool(8) as p:
        for _ in tqdm(p.imap_unordered(process, tasks), total=len(tasks)):
            pass


Process SpawnPoolWorker-10:
Process SpawnPoolWorker-9:
Process SpawnPoolWorker-11:
Process SpawnPoolWorker-12:
Traceback (most recent call last):
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/queues.py", line 368, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute 'process' on <module '__main__' (built-in)>
Traceback (most recent call last):
T

KeyboardInterrupt: 