Skip to content

Commit a148310

Browse files
committed
conditional covgc + example
1 parent 9b72b07 commit a148310

File tree

5 files changed

+264
-15
lines changed

5 files changed

+264
-15
lines changed

docs/source/api.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,17 @@ Random data for directed connectivity measures
181181
sim_distant_cc_ms
182182
sim_distant_cc_ss
183183
sim_gauss_fit
184+
185+
.. raw:: html
186+
187+
<hr>
188+
189+
Autoregressive model
190+
++++++++++++++++++++
191+
192+
.. autosummary::
193+
:toctree: generated/
194+
184195
StimSpecAR
185196

186197
.. raw:: html

examples/armodel/plot_ar_condcovgc.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
AR : conditional covariance based Granger Causality
3+
===================================================
4+
5+
This example reproduces the results of Ding et al. 2006 :cite:`ding2006granger`
6+
where in Fig3 there's an indirect transfer of information from Y->X that is
7+
mediated by Z. The problem is that if the Granger Causality is used, there's
8+
indeed a transfer of information from Y->X while with the conditional Granger
9+
causality, conditioning by the past of other sources suppresses this indirect
10+
transfer.
11+
"""
12+
import numpy as np
13+
14+
from frites.simulations import StimSpecAR
15+
from frites.conn import conn_covgc
16+
17+
import matplotlib.pyplot as plt
18+
19+
20+
###############################################################################
21+
# Simulate 3 nodes 40hz oscillations
22+
# ----------------------------------
23+
#
24+
# Here, we use the class :class:`frites.simulations.StimSpecAR` to simulate an
25+
# stimulus-specific autoregressive model made of three nodes (X, Y and Z). This
26+
# network simulates a transfer Y->Z and Z->X such as an indirect transfer from
27+
# Y->X mediated by Z
28+
29+
ar_type = 'ding_3' # 40hz oscillations
30+
n_stim = 2 # number of stimulus
31+
n_epochs = 50 # number of epochs per stimulus
32+
33+
ss = StimSpecAR()
34+
ar = ss.fit(ar_type=ar_type, n_epochs=n_epochs, n_stim=n_stim)
35+
36+
# ###############################################################################
37+
# # plot the network
38+
39+
plt.figure(figsize=(5, 4))
40+
ss.plot_model()
41+
plt.show()
42+
43+
###############################################################################
44+
# Compute the Granger-Causality
45+
# -----------------------------
46+
#
47+
# We first compute the Granger Causality and then the conditional Granger
48+
# causality (i.e conditioning by the past coming from other sources)
49+
50+
dt, lag, step = 50, 5, 2
51+
t0 = np.arange(lag, ar.shape[-1] - dt, step)
52+
kw_gc = dict(dt=dt, lag=lag, step=1, t0=t0, roi='roi', times='times',
53+
n_jobs=-1)
54+
# granger causality
55+
gc = conn_covgc(ar, conditional=False, **kw_gc)[0]
56+
57+
# conditional granger causality
58+
gc_cond = conn_covgc(ar, conditional=True, **kw_gc)[0]
59+
60+
61+
###############################################################################
62+
# Plot the Granger causality
63+
64+
plt.figure(figsize=(12, 10))
65+
ss.plot_covgc(gc)
66+
plt.tight_layout()
67+
plt.show()
68+
69+
70+
###############################################################################
71+
# Plot the conditional Granger causality
72+
73+
plt.figure(figsize=(12, 10))
74+
ss.plot_covgc(gc_cond)
75+
plt.tight_layout()
76+
plt.show()
77+
78+
79+
###############################################################################
80+
# Direct comparison
81+
# -----------------
82+
#
83+
# In this plot, we only select the transfer of information from Y->X for both
84+
# granger and conditional granger causality
85+
86+
# select Y->X and mean per stimulus for the granger causality
87+
gc_yx = gc.sel(roi='x-y', direction='y->x').groupby('trials').mean('trials')
88+
gc_yx = gc_yx.rename({'trials': 'stimulus'})
89+
90+
# select Y->X and mean per stimulus for the conditional granger causality
91+
gc_cond_yx = gc_cond.sel(roi='x-y', direction='y->x').groupby('trials').mean(
92+
'trials')
93+
gc_cond_yx = gc_cond_yx.rename({'trials': 'stimulus'})
94+
95+
# get (min, max) of granger causality from Y->X
96+
gc_min = min(gc_yx.data.min(), gc_cond_yx.data.min())
97+
gc_max = max(gc_yx.data.max(), gc_cond_yx.data.max())
98+
99+
# sphinx_gallery_thumbnail_number = 4
100+
plt.figure(figsize=(10, 5))
101+
# plot granger causality from Y->X
102+
plt.subplot(121)
103+
gc_yx.plot.line(x='times', hue='stimulus')
104+
plt.title(r'Granger causality Y$\rightarrow$X', fontweight='bold')
105+
plt.axvline(0, color='k', lw=2)
106+
plt.ylim(gc_min, gc_max)
107+
# plot the conditional granger causality from Y->X
108+
plt.subplot(122)
109+
gc_cond_yx.plot.line(x='times', hue='stimulus')
110+
plt.title(r'Conditional Granger causality Y$\rightarrow$X|others',
111+
fontweight='bold')
112+
plt.axvline(0, color='k', lw=2)
113+
plt.ylim(gc_min, gc_max)
114+
plt.tight_layout()
115+
116+
plt.show()

frites/conn/conn_covgc.py

Lines changed: 126 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
from frites.core.copnorm import copnorm_nd
1010

1111

12+
13+
###############################################################################
14+
###############################################################################
15+
# COVGC ENTROPY
16+
###############################################################################
17+
###############################################################################
18+
19+
1220
LOG2 = np.log(2)
1321

1422

@@ -89,6 +97,13 @@ def _covgc(d_s, d_t, ind_tx, t0):
8997
return gc / (2. * LOG2)
9098

9199

100+
###############################################################################
101+
###############################################################################
102+
# GAUSSIAN COPULA COVGC
103+
###############################################################################
104+
###############################################################################
105+
106+
92107
def _gccovgc(d_s, d_t, ind_tx, t0):
93108
"""Compute the Gaussian-Copula based covGC for a single pair.
94109
@@ -128,12 +143,84 @@ def _gccovgc(d_s, d_t, ind_tx, t0):
128143

129144

130145

146+
###############################################################################
147+
###############################################################################
148+
# CONDITIONAL GAUSSIAN COPULA COVGC
149+
###############################################################################
150+
###############################################################################
151+
152+
153+
def _cond_gccovgc(data, s, t, ind_tx, t0, conditional=True):
154+
"""Compute the Gaussian-Copula based covGC for a single pair.
155+
156+
This function computes the covGC for a single pair, across multiple trials,
157+
at different time indices.
158+
"""
159+
conditional = conditional if data.shape[1] > 2 else False
160+
kw = CONFIG["KW_GCMI"]
161+
d_s, d_t = data[:, s, :], data[:, t, :]
162+
n_lags, n_dt = ind_tx.shape
163+
n_trials, n_times = d_s.shape[0], len(t0)
164+
gc = np.empty((n_trials, n_times, 3), dtype=d_s.dtype, order='C')
165+
# define z past
166+
roi_range = np.array([k for k in range(data.shape[1]) if k not in [s, t]])
167+
z_roi = data[:, roi_range, :] # other roi selection
168+
rsh = int(len(roi_range) * (n_lags - 1))
169+
for n_ti, ti in enumerate(t0):
170+
# force starting indices at t0 + force row-major slicing
171+
ind_t0 = np.ascontiguousarray(ind_tx + ti)
172+
x = d_s[:, ind_t0]
173+
y = d_t[:, ind_t0]
174+
# temporal selection
175+
x_pres, x_past = x[:, [0], :], x[:, 1:, :]
176+
y_pres, y_past = y[:, [0], :], y[:, 1:, :]
177+
xy_past = np.concatenate((x[:, 1:, :], y[:, 1:, :]), axis=1)
178+
# conditional granger causality case
179+
if conditional:
180+
# condition by the past of every other possible sources
181+
z_past = z_roi[..., ind_t0[1:, :]] # (lag_past, dt) selection
182+
z_past = z_past.reshape(n_trials, rsh, n_dt)
183+
# cat with past
184+
yz_past = np.concatenate((y_past, z_past), axis=1)
185+
xz_past = np.concatenate((x_past, z_past), axis=1)
186+
xyz_past = np.concatenate((xy_past, z_past), axis=1)
187+
else:
188+
yz_past, xz_past, xyz_past = y_past, x_past, xy_past
189+
# copnorm over the last axis (avoid copnorming several times)
190+
x_pres = copnorm_nd(x_pres, axis=-1)
191+
x_past = copnorm_nd(x_past, axis=-1)
192+
y_pres = copnorm_nd(y_pres, axis=-1)
193+
y_past = copnorm_nd(y_past, axis=-1)
194+
yz_past = copnorm_nd(yz_past, axis=-1)
195+
xz_past = copnorm_nd(xz_past, axis=-1)
196+
xyz_past = copnorm_nd(xyz_past, axis=-1)
197+
198+
# -----------------------------------------------------------------
199+
# Granger Causality measures
200+
# -----------------------------------------------------------------
201+
# gc(pairs(:,1) -> pairs(:,2))
202+
gc[:, n_ti, 0] = cmi_nd_ggg(y_pres, x_past, yz_past, **kw)
203+
# gc(pairs(:,2) -> pairs(:,1))
204+
gc[:, n_ti, 1] = cmi_nd_ggg(x_pres, y_past, xz_past, **kw)
205+
# gc(pairs(:,2) . pairs(:,1))
206+
gc[:, n_ti, 2] = cmi_nd_ggg(x_pres, y_pres, xyz_past, **kw)
207+
208+
return gc
209+
210+
211+
###############################################################################
212+
###############################################################################
213+
# HIGH-LEVEL CONN_COVGC
214+
###############################################################################
215+
###############################################################################
216+
217+
131218
def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
132-
n_jobs=-1, verbose=None):
219+
conditional=False, n_jobs=-1, verbose=None):
133220
r"""Single-trial covariance-based Granger Causality for gaussian variables.
134221
135-
This function computes the covariance-based Granger Causality (covgc) for
136-
each trial.
222+
This function computes the (conditional) covariance-based Granger Causality
223+
(covgc) for each trial.
137224
138225
.. note::
139226
**Total Granger interdependence**
@@ -180,6 +267,9 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
180267
Method for the estimation of the covgc. Use either 'gauss' which
181268
assumes that the time-points are normally distributed or 'gc' in order
182269
to use the gaussian-copula.
270+
conditional : bool | False
271+
If True, the conditional Granger Causality is computed i.e the past is
272+
also conditioned by the past of other sources.
183273
n_jobs : int | -1
184274
Number of jobs to use for parallel computing (use -1 to use all
185275
jobs). The parallel loop is set at the pair level.
@@ -211,18 +301,21 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
211301
t0, CONFIG['FLOAT_DTYPE']):
212302
t0 = np.array([t0])
213303
t0 = np.asarray(t0).astype(int)
214-
dt, lag, step = int(dt), int(lag), int(step)
304+
dt, lag, step, trials = int(dt), int(lag), int(step), None
215305
# handle dataarray input
216306
if isinstance(data, xr.DataArray):
217307
if isinstance(roi, str):
218308
roi = data[roi].data
219309
if isinstance(times, str):
220310
times = data[times].data
311+
trials = data['trials'].data
221312
data = data.data
222313
# force C contiguous array because operations on row-major
223314
if not data.flags.c_contiguous:
224315
data = np.ascontiguousarray(data)
225316
n_epochs, n_roi, n_times = data.shape
317+
if trials is None:
318+
trials = np.arange(n_epochs)
226319
# default roi vector
227320
if roi is None:
228321
roi = np.array([f"roi_{k}" for k in range(n_roi)])
@@ -262,16 +355,21 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
262355
logger.debug(f"Index shape : {ind_tx.shape}")
263356

264357
# -------------------------------------------------------------------------
358+
ext = 'conditional' if conditional else ''
265359
# compute covgc and parallel over pairs
266-
logger.info(f"Compute the covgc (method={method}, n_pairs={len(x_s)}; "
267-
f"n_windows={len(t0)}, lag={lag}, dt={dt}, step={step})")
268-
gc = Parallel(n_jobs=n_jobs)(delayed(fcn)(
269-
data[:, s, :], data[:, t, :], ind_tx, t0) for s, t in zip(x_s, x_t))
360+
logger.info(f"Compute the {ext} covgc (method={method}, n_pairs={len(x_s)}"
361+
f"; n_windows={len(t0)}, lag={lag}, dt={dt}, step={step})")
362+
if not conditional:
363+
gc = Parallel(n_jobs=n_jobs)(delayed(fcn)(
364+
data[:, s, :], data[:, t, :], ind_tx, t0) for s, t in zip(
365+
x_s, x_t))
366+
else:
367+
gc = Parallel(n_jobs=n_jobs)(delayed(_cond_gccovgc)(
368+
data, s, t, ind_tx, t0) for s, t in zip(x_s, x_t))
270369
gc = np.stack(gc, axis=1)
271370

272371
# -------------------------------------------------------------------------
273372
# change output type
274-
trials = np.arange(n_epochs)
275373
dire = np.array(['x->y', 'y->x', 'x.y'])
276374
gc = xr.DataArray(gc, dims=('trials', 'roi', 'times', 'direction'),
277375
coords=(trials, roi_p, times_p, dire))
@@ -280,5 +378,24 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
280378
gc.attrs['step'] = step
281379
gc.attrs['dt'] = dt
282380
gc.attrs['t0'] = t0
381+
gc.attrs['conditional'] = conditional
283382

284383
return gc, pairs, roi_p, times_p
384+
385+
386+
if __name__ == '__main__':
387+
from frites.simulations import StimSpecAR
388+
import matplotlib.pyplot as plt
389+
390+
ss = StimSpecAR()
391+
ar = ss.fit(ar_type='ding_3', n_stim=2, n_epochs=20)
392+
# plot the model
393+
# plt.figure(figsize=(7, 8))
394+
# ss.plot()
395+
# compute covgc
396+
dt, lag, step = 50, 5, 2
397+
t0 = np.arange(lag, ar.shape[-1] - dt, step)
398+
gc = conn_covgc(ar, roi='roi', times='times', dt=dt, lag=lag, t0=t0,
399+
n_jobs=-1, conditional=False)[0]
400+
ss.plot_covgc(gc=gc)
401+
plt.show()

frites/conn/tests/test_conn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ def test_conn_covgc(self):
5959
t0 = [50, 80]
6060

6161
_ = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gc')[0]
62-
gc = conn_covgc(x, dt, lag, t0, n_jobs=1)[0]
62+
gc = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gauss')[0]
6363
assert gc.shape == (n_epochs, 3, len(t0), 3)
64-
gc = conn_covgc(x, dt, lag, t0, n_jobs=1)[0]
6564
assert isinstance(gc, xr.DataArray)
65+
gc = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gc',
66+
conditional=True)[0]
6667

frites/simulations/sim_ar.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ def plot_model(self):
420420
edge_labels = {(u, v): rf"{u}$\rightarrow${v}={d['weight']}" for
421421
u, v, d in G.edges(data=True)}
422422
# fix ding_5 for bidirectional connectivity between 4 <-> 5
423-
if self._ar_type is 'ding_5':
423+
if self._ar_type is 'ding_3':
424+
edge_labels[('Y', 'X')] = "indirect\n(mediated by Z)"
425+
elif self._ar_type is 'ding_5':
424426
edge_labels[('X5', 'X4')] = (r"X4$\rightarrow$X5=2" + "\n" +
425427
r'X5$\rightarrow$X4=1')
426428

@@ -466,6 +468,8 @@ def plot_covgc(self, gc=None, plot_mi=False):
466468
gcm = gcm.rename({'trials': 'Stimulus'})
467469
else:
468470
gcm = self._mi
471+
# conditional covgc
472+
ext = '|others' if self._gc.attrs['conditional'] else ''
469473

470474
y_min, y_max = gcm.data.min(), gcm.data.max()
471475
direction, roi = gcm['direction'].data, gcm['roi'].data
@@ -480,11 +484,11 @@ def plot_covgc(self, gc=None, plot_mi=False):
480484
plt.axvline(0., lw=2., color='k')
481485
r_sp = r.split('-')
482486
if d == 'x->y':
483-
tit = fr'{r_sp[0]}$\rightarrow${r_sp[1]}'
487+
tit = fr'{r_sp[0]}$\rightarrow${r_sp[1]}{ext}'
484488
elif d == 'y->x':
485-
tit = fr'{r_sp[1]}$\rightarrow${r_sp[0]}'
489+
tit = fr'{r_sp[1]}$\rightarrow${r_sp[0]}{ext}'
486490
elif d == 'x.y':
487-
tit = fr'{r_sp[0]} . {r_sp[1]}'
491+
tit = fr'{r_sp[0]} . {r_sp[1]}{ext}'
488492
plt.title(tit, fontweight='bold', fontsize=15)
489493
if n_r >= 1:
490494
plt.ylabel('')

0 commit comments

Comments
 (0)