This code is provided in tandem with our paper on graphical directed information (GDI), a model-free technique that infers causal influences between pairs of time series and, by conditioning on the other time series, captures the influences unique to each pair.
The directed information (DI) from a time series X(t) to another time series Y(t) is the mutual information (MI) between the past of X(t) and the present of Y(t), conditioned on the past of Y(t). DI therefore quantifies only the causal influence in the direction from X(t) to Y(t).
GDI extends DI by conditioning not only on the past of Y(t) but also on the pasts of other time series W(t), Z(t), etc. This additional conditioning makes it possible to identify and quantify a direct connection from one time series to another and to eliminate indirect connections.
GDI requires a history parameter, called M in our code, which sets how many past samples of X(t) and of the conditioning time series are used.
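Written for a single time step, the two quantities described above can be sketched as follows, where $X^{(M)}_t = \big(X(t-M), \ldots, X(t-1)\big)$ denotes the M most recent past samples of X(t). This notation is ours for illustration and assumes the same history M for every series; see our paper for the exact definitions (for example, how values are combined across t).

```latex
\mathrm{DI}_{X \to Y} = I\!\left( X^{(M)}_t ;\, Y(t) \,\middle|\, Y^{(M)}_t \right)
\qquad
\mathrm{GDI}_{X \to Y} = I\!\left( X^{(M)}_t ;\, Y(t) \,\middle|\, Y^{(M)}_t,\, W^{(M)}_t,\, Z^{(M)}_t,\, \ldots \right)
```

The only difference is the conditioning set: if the past of W(t) explains away the apparent influence of X(t) on Y(t), the GDI term drops to zero while the DI term does not.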
We offer Python and MATLAB implementations of GDI; both include five example GDI analyses and installation instructions. We also include the SNNAP files that were used to simulate the networks of neurons.
For all relevant references, please see our paper.
GDI was tested using Python 3.7, and requires the following packages:
- NumPy
- seaborn
- SciPy
- TensorFlow (<v2.0)
GDI also requires CCMI. Extract all files from the CIT folder of CCMI, and place them in the gdi_python directory of our code.
(Optional) If you want to run example 4, the arbitrary network, please download the simulation file here and add it to the gdi_python folder.
Note: Our GDI results on binned spike times are normalized using code from CTM-DI, which is only available for MATLAB. For this repository, we instead include normalization factors that were precalculated with CTM-DI and are loaded in our two binned-spike-time examples. These normalizations are only appropriate for the bin widths and M values already specified in those examples.
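To illustrate what "binned spike times" means here, the sketch below converts spike times in ms (one column per neuron, with columns that have fewer spikes zero-padded, the same layout described for CTM_spiketime_wrapper.m) into spike trains for a chosen bin width. The function name spike_times_to_binary_trains is hypothetical and is not the binning routine used by CTM-DI; marking a bin as 1 when it contains at least one spike is an assumption.

```python
import numpy as np


def spike_times_to_binary_trains(spike_times_ms, bin_width_ms, t_max_ms=None):
    """Hypothetical helper: bin zero-padded spike times into 0/1 spike trains.

    spike_times_ms: 2-D array, one column per neuron, entries are spike times
    in ms; columns with fewer spikes are padded with zeros.
    Returns an array of shape (num_bins, num_neurons) whose entry (b, n) is 1
    if neuron n spiked at least once in bin b (assumed binary, not counts).
    """
    spike_times_ms = np.asarray(spike_times_ms, dtype=float)
    if t_max_ms is None:
        t_max_ms = spike_times_ms.max()
    num_bins = int(np.floor(t_max_ms / bin_width_ms)) + 1
    trains = np.zeros((num_bins, spike_times_ms.shape[1]), dtype=int)
    for n in range(spike_times_ms.shape[1]):
        times = spike_times_ms[:, n]
        times = times[times > 0]  # drop zero padding (assumes no spike at exactly 0 ms)
        bins = np.floor(times / bin_width_ms).astype(int)
        bins = bins[bins < num_bins]  # ignore spikes after t_max_ms
        trains[bins, n] = 1
    return trains


# Neuron 0 spikes at 1.5, 7.2 and 9.0 ms; neuron 1 at 3.0 and 12.9 ms (zero-padded)
spikes = np.array([[1.5, 3.0],
                   [7.2, 12.9],
                   [9.0, 0.0]])
trains = spike_times_to_binary_trains(spikes, bin_width_ms=5.0)
print(trains)
```

Changing the bin width changes both the length and the content of the binned trains, which is why the precalculated normalization factors only apply to the bin widths used in our examples.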
The GDI.py file contains all functions used for GDI. The core functions are:
DI(X,M,B): Compute the pairwise (non-graphical) DI between the columns of X with history parameter M (i.e. the number of past samples relevant to the estimation) and B bootstrap iterations. The output is a DI matrix in which each element is the DI from the row to the column.
GDI(X,M,B): Compute the GDI between the columns of X with history parameter M (i.e. the number of past samples relevant to the estimation) and B bootstrap iterations. When computing the GDI between two columns, all other columns are conditioned on. The output is a GDI matrix in which each element is the GDI from the row to the column.
sign_inference(X,M): Determine the sign of the relationship between the columns of X with history parameter M. The first output is the partial sign matrix, in which each element is the sign of the relationship from the row to the column based on partial correlations; the second output is the same but based on regular correlations.
GDI_mask(X,M,B,mask): Same as GDI(), except that a mask, given as a square matrix of zeros and ones, specifies which pairs of time series GDI is computed for and which time series are conditioned on in each computation. For example, a mask of all zeros except for ones at elements (2,3), (4,3), and (5,4) means that GDI is only computed from column 2 to 3, 4 to 3, and 5 to 4 of X. The GDI from column 2 to 3 is then conditioned only on column 4, the GDI from 4 to 3 only on column 2, and the GDI from 5 to 4 on no other column, which makes it equivalent to the DI from 5 to 4. The output is a GDI matrix in which each element is the GDI from the row to the column.
For more detail, view the header for each function in the GDI.py file.
Minimal Working Example (Jupyter Notebook)
Here we construct an example where node 0 causally influences node 1 and node 1 causally influences node 2. This results in direct connections from 0 to 1 and 1 to 2 as well as an indirect connection from 0 to 2 which GDI correctly eliminates.
import GDI
import numpy as np
num_samples = 5000
mean = [0,0,0]
cov = np.eye(3)
X = np.random.multivariate_normal(mean, cov, num_samples)
X[2:,1] = X[2:,1] + X[:-2,0]
X[2:,2] = X[2:,2] + X[:-2,1]
M = 4
B = 10
X_DI = GDI.DI(X,M,B)
X_GDI = GDI.GDI(X,M,B)
print(X_DI)
print(X_GDI)
An example run produces a DI matrix of:
[[ 0.          0.36758208  0.15418205]
 [ 0.04643047  0.          0.46776479]
 [-0.00124419  0.01545924  0.        ]]
where the entry in row 0, column 2 (0.154) is the incorrectly identified indirect connection from 0 to 2.
The GDI matrix then looks like:
[[ 0.          0.29086685 -0.00878638]
 [-0.03832781  0.          0.21460593]
 [ 0.01501656 -0.00237674  0.        ]]
which shows that GDI eliminated the indirect connection from node 0 to 2 that was incorrectly identified by DI.
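As an independent sanity check of this behavior, the sketch below swaps in a much simpler estimator than the CCMI-based one used by GDI.py: a linear-Gaussian plug-in estimate, where conditional MI is half the log ratio of least-squares residual variances. The function names are hypothetical. For this chain with unit-variance noise and M = 4, the true values (in nats) are DI 0 to 1 = ½ ln 2 ≈ 0.347, DI 0 to 2 = ½ ln(3/2) ≈ 0.203, and GDI 0 to 2 = 0, since once the past of node 1 is known, the past of node 0 adds nothing about node 2.

```python
import numpy as np


def lagged(series, M):
    """Stack M past samples of each column; row r holds values at t-1..t-M for t = M + r."""
    T = series.shape[0]
    return np.hstack([series[M - k:T - k] for k in range(1, M + 1)])


def residual_variance(y, Z):
    """Variance of the least-squares residual of y regressed on Z plus an intercept."""
    A = np.column_stack([np.ones(len(y)), Z])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return np.var(y - A @ coef)


def gaussian_gdi(X, M, src, tgt, cond=()):
    """Linear-Gaussian estimate (nats) of I(src past; tgt present | tgt past, cond pasts).

    With cond empty this is a DI estimate; otherwise a GDI estimate.
    """
    y = X[M:, tgt]
    base = lagged(X[:, [tgt, *cond]], M)
    full = np.hstack([base, lagged(X[:, [src]], M)])
    return 0.5 * np.log(residual_variance(y, base) / residual_variance(y, full))


# Same chain as above: node 0 -> node 1 -> node 2, each with a lag of 2 samples
rng = np.random.default_rng(0)
X = rng.standard_normal((20000, 3))  # equivalent to zero mean, identity covariance
X[2:, 1] += X[:-2, 0]
X[2:, 2] += X[:-2, 1]
M = 4

di_01 = gaussian_gdi(X, M, 0, 1)
di_02 = gaussian_gdi(X, M, 0, 2)
gdi_02 = gaussian_gdi(X, M, 0, 2, cond=(1,))
print(f"DI 0->1: {di_01:.3f}  DI 0->2: {di_02:.3f}  GDI 0->2: {gdi_02:.3f}")
```

This estimator is only valid for linear-Gaussian data; the CCMI-based estimator in GDI.py is what makes GDI model-free.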
The following steps are required to set up the MATLAB implementation (all steps involving CTM-DI can be skipped if you are not running arb.m or cpg.m):
- Download CCMI and CTM-DI. Extract all files from the CIT folder of CCMI and place them in the ccdi_mat directory of our code. Place the CTM_DI_package folder of CTM-DI at the same directory level as the ccdi_mat directory. Ensure that you have the necessary Python packages installed, as listed in the Python section above.
- Modify a few lines of the CCMI code so that it also returns the CMI estimate for each bootstrap iteration, not only the average across bootstrap iterations. The lines to modify are right around the return statement at the end of the definition of get_cmi_est() in the file CCMI.py. When we downloaded CCMI, these were lines 102 to 106, which appear as:
    cmi_est = I_xyz - I_xz
else:
    raise NotImplementedError
return cmi_est
Change those lines so that the per-iteration CMI estimates (before averaging over bootstrap iterations) are returned as well:
    cmi_est_list = I_xyz_list - I_xz_list
    cmi_est = I_xyz - I_xz
else:
    raise NotImplementedError
return cmi_est, cmi_est_list
- Replace the contents of the file CTM_DI_package/Supporting_functions/CTM_spiketime_wrapper.m with the following:
function [CMatrix, w_i, HMatrix] = CTM_spiketime_wrapper(sig,M, IRI, causal)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% note this function takes vectors of spike times per channel as input.
% When a channel has fewer spikes than others, zero pad the rest of the
% entries
%
% sig is an array where rows are spike times in ms and columns are neurons
% M is the maximum depth that the tree can be (maximum number of past bins used)
% IRI is the bin_width to use in ms
% causal should be 1 (causal definition of DI)
% SPIKE TIMES SHOULD BE IN MS!
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
channel = size(sig,2);
CMatrix = zeros(channel,channel,length(IRI));
parfor cnt = 1:length(IRI)
cnt
[CMatrix(:,:,cnt) w_i(:,:,:,cnt) HMatrix(:,:,cnt)]= ...
Connect_CTM(spiketime2train(sig,IRI(cnt)),M,causal);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
return
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- Modify CTM_DI_package/CTM_DI/Connect_CTM.m as follows:
- Change the function definition to:
function [CMatrix,w_i,HMatrix] = Connect_CTM(Sig,D,causal,sumspike,sig_label)
- Add this line between lines 68 and 69 (which should be between weight_matrix = zeros(channel); and for cnt = 1:length(row)):
w_i = nan(channel,channel,D+1);
- Just after what should now be line 75 (weight_matrix(col(cnt),row(cnt))=sum(weights2(1+causal:end));), add these two lines, just before end:
w_i(row(cnt),col(cnt),1:length(weights1)) = weights1;
w_i(col(cnt),row(cnt),1:length(weights2)) = weights2;
- Just after the definition of CMatrix (CMatrix = zeros(channel,channel);), which should now be at line 83, add:
HMatrix = zeros(channel,channel);
- Just after the line CMatrix(col(cnt), row(cnt)) = DI21/H1; (which should now be at line 117), add these two lines:
HMatrix(row(cnt), col(cnt)) = H2;
HMatrix(col(cnt), row(cnt)) = H1;
- Since the core of this toolbox relies on the CCMI implementation, which is written in Python, you must insert your system path and Python path in the gdi_matlab/python_path_script.m file. This script is called by deeper functions to access Python. Copy the terminal output of the command echo $PATH and paste it between the '' for system_path in gdi_matlab/python_path_script.m, then copy the output of the command which python and paste it between the '' for python_path in the same file.
- (Optional) If you want to run example 4, the arbitrary network, please download the simulation file here and place it in the gdi_matlab folder.
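If you are unsure which values belong in python_path_script.m, the short snippet below prints the same two values from inside Python: the PATH environment variable (what echo $PATH shows in the shell that launched Python) and the full path of the running interpreter. This is only a convenience check, not part of the toolbox; run it with the python that has the required packages installed, because sys.executable reports the interpreter running the snippet, which can differ from what which python prints if you launch it another way.

```python
import os
import sys

# Candidate value for system_path: same as `echo $PATH` in the launching shell
system_path = os.environ.get("PATH", "")
# Candidate value for python_path: the interpreter currently running this code
python_path = sys.executable

print("system_path:", system_path)
print("python_path:", python_path)
```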
The ccdi_mat folder contains all of the files/functions for GDI. The core functions are:
di_compute(X,M,C,B): Computes the DI or GDI between each pair of columns of X using history parameter M (i.e. the number of past samples relevant to the estimation) and B bootstrap iterations. C is 0 (compute DI) or 1 (compute GDI).
di_compute_pair(X,M,C,B,pairs): Same as di_compute(), but computes DI/GDI only for the specified pairs.
di_compute_post(DI_uncond,thresh,M,X,B): Computes GDI by thresholding the (non-graphical) DI values. GDI is only computed between channels with DI values >= thresh, and is only conditioned on channels with DI values >= thresh. If thresholding leaves no other channels to condition on for a particular GDI analysis, that analysis is not performed and the DI value is taken as the GDI value.
sign_inference(X,M): Determine the sign of the relationship between the columns of X with history parameter M. The first output is the partial sign matrix, in which each element is the sign of the relationship from the row to the column based on partial correlations; the second output is the same but based on regular correlations.
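The thresholding rule of di_compute_post can be illustrated with the sketch below. It is a hypothetical Python rendering of one reading of the description above, not the MATLAB implementation: a pair i to j is kept if DI[i, j] >= thresh, it is conditioned on every other channel k whose DI[k, j] >= thresh, and a kept pair with nothing left to condition on reuses its DI value (marked None). Indices are 0-based, and the example DI matrix is the one from the MATLAB minimal example below.

```python
import numpy as np


def plan_post_threshold(di, thresh):
    """Hypothetical sketch: decide which GDI analyses a DI threshold implies.

    Returns a dict mapping (source, target) to the list of channels to
    condition on, or to None when there is nothing to condition on and the
    DI value is reused as the GDI value.
    """
    di = np.nan_to_num(np.asarray(di, dtype=float), nan=-np.inf)  # NaN diagonal never passes
    keep = di >= thresh
    np.fill_diagonal(keep, False)
    plan = {}
    for i, j in zip(*np.nonzero(keep)):
        cond = [int(k) for k in np.nonzero(keep[:, j])[0] if k != i]
        plan[(int(i), int(j))] = cond if cond else None
    return plan


di = [[np.nan, 0.3569, 0.1730],
      [-0.0298, np.nan, 0.4008],
      [-0.0060, -0.1157, np.nan]]
plan = plan_post_threshold(di, thresh=0.1)
print(plan)
```

Thresholding first keeps the number of GDI analyses, and the dimension of each conditioning set, small, which matters because the conditioning dimension affects estimation (see the scaling example).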
For more detail, view the header for each function/file in the ccdi_mat folder.
Minimal Working Example (MATLAB)
Here we construct an example where node 1 causally influences node 2 and node 2 causally influences node 3. This results in direct connections from 1 to 2 and 2 to 3 as well as an indirect connection from 1 to 3 which GDI correctly eliminates.
addpath ccdi_mat
num_samples = 5000;
mean_vec = [0,0,0];
cov_mat = eye(3);
X = mvnrnd(mean_vec, cov_mat, num_samples);
X(3:end,2) = X(3:end,2) + X(1:(end-2),1);
X(3:end,3) = X(3:end,3) + X(1:(end-2),2);
M = 4;
B = 10;
X_DI = di_compute(X,M,0,B)
X_GDI = di_compute(X,M,1,B)
An example run produces a DI matrix of:
      NaN    0.3569    0.1730
  -0.0298       NaN    0.4008
  -0.0060   -0.1157       NaN
where the entry in row 1, column 3 (0.1730) is the incorrectly identified indirect connection from 1 to 3.
The GDI matrix then looks like:
      NaN    0.2392   -0.0166
   0.0490       NaN    0.3186
  -0.0438   -0.1290       NaN
which shows that GDI eliminated the indirect connection from node 1 to 3 that was incorrectly identified by DI.
We included five example analyses in our code, which correspond to the five results figures in our paper; the links in parentheses go to the corresponding files in our repository.
1. Scaling (MATLAB, Jupyter Notebook)
GDI's performance with regard to sample size, the number of dimensions being conditioned on, and the number of bootstrap iterations (see Method section below) is analyzed.
2. Gaussian Network (MATLAB, Jupyter Notebook)
GDI is applied to a Gaussian network, which consists of causal influences between nodes that have their own Gaussian noise. The analytic solution for GDI is known for this network, and the accuracy of our GDI estimates is compared with the derived values.
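For reference, the standard closed form of conditional mutual information for jointly Gaussian variables is what makes analytic DI and GDI values available for Gaussian networks. Here A would be the past of the source, B the present of the target, and C the past of the target (plus the pasts of the other conditioned series for GDI); the specific derivation for our Gaussian network is in the paper.

```latex
I(A; B \mid C) = \frac{1}{2} \ln \frac{\det \Sigma_{B \mid C}}{\det \Sigma_{B \mid A, C}},
\qquad
\Sigma_{B \mid C} = \Sigma_{BB} - \Sigma_{BC}\, \Sigma_{CC}^{-1}\, \Sigma_{CB}
```

For example, if Y(t) = X(t-2) + N(t) with independent unit-variance white X and N and M >= 2, then the variance of Y(t) given its own past is 2 and given also the past of X is 1, so the DI from X to Y is ½ ln 2 ≈ 0.347 nats.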
3. Nonlinear Network (MATLAB, Jupyter Notebook)
GDI is applied to a nonlinear network, which has the same structure as the Gaussian network, except that the source nodes follow uniform distributions and the causal influences involve a squared relationship.
4. Arbitrary Network (MATLAB, Jupyter Notebook)
GDI is applied to binned spike times produced by an arbitrary model of a network of neurons.
5. CPG Network (MATLAB, Jupyter Notebook)
GDI is applied to binned spike times produced by a model of the central pattern generator (CPG) in Aplysia.
We have adopted the GPLv2 license for this toolbox (see LICENSE and GPLv2_note.txt files).