9
9
from frites .core .copnorm import copnorm_nd
10
10
11
11
12
+
13
+ ###############################################################################
14
+ ###############################################################################
15
+ # COVGC ENTROPY
16
+ ###############################################################################
17
+ ###############################################################################
18
+
19
+
12
20
LOG2 = np .log (2 )
13
21
14
22
@@ -89,6 +97,13 @@ def _covgc(d_s, d_t, ind_tx, t0):
89
97
return gc / (2. * LOG2 )
90
98
91
99
100
+ ###############################################################################
101
+ ###############################################################################
102
+ # GAUSSIAN COPULA COVGC
103
+ ###############################################################################
104
+ ###############################################################################
105
+
106
+
92
107
def _gccovgc (d_s , d_t , ind_tx , t0 ):
93
108
"""Compute the Gaussian-Copula based covGC for a single pair.
94
109
@@ -128,12 +143,84 @@ def _gccovgc(d_s, d_t, ind_tx, t0):
128
143
129
144
130
145
146
+ ###############################################################################
147
+ ###############################################################################
148
+ # CONDITIONAL GAUSSIAN COPULA COVGC
149
+ ###############################################################################
150
+ ###############################################################################
151
+
152
+
153
+ def _cond_gccovgc (data , s , t , ind_tx , t0 , conditional = True ):
154
+ """Compute the Gaussian-Copula based covGC for a single pair.
155
+
156
+ This function computes the covGC for a single pair, across multiple trials,
157
+ at different time indices.
158
+ """
159
+ conditional = conditional if data .shape [1 ] > 2 else False
160
+ kw = CONFIG ["KW_GCMI" ]
161
+ d_s , d_t = data [:, s , :], data [:, t , :]
162
+ n_lags , n_dt = ind_tx .shape
163
+ n_trials , n_times = d_s .shape [0 ], len (t0 )
164
+ gc = np .empty ((n_trials , n_times , 3 ), dtype = d_s .dtype , order = 'C' )
165
+ # define z past
166
+ roi_range = np .array ([k for k in range (data .shape [1 ]) if k not in [s , t ]])
167
+ z_roi = data [:, roi_range , :] # other roi selection
168
+ rsh = int (len (roi_range ) * (n_lags - 1 ))
169
+ for n_ti , ti in enumerate (t0 ):
170
+ # force starting indices at t0 + force row-major slicing
171
+ ind_t0 = np .ascontiguousarray (ind_tx + ti )
172
+ x = d_s [:, ind_t0 ]
173
+ y = d_t [:, ind_t0 ]
174
+ # temporal selection
175
+ x_pres , x_past = x [:, [0 ], :], x [:, 1 :, :]
176
+ y_pres , y_past = y [:, [0 ], :], y [:, 1 :, :]
177
+ xy_past = np .concatenate ((x [:, 1 :, :], y [:, 1 :, :]), axis = 1 )
178
+ # conditional granger causality case
179
+ if conditional :
180
+ # condition by the past of every other possible sources
181
+ z_past = z_roi [..., ind_t0 [1 :, :]] # (lag_past, dt) selection
182
+ z_past = z_past .reshape (n_trials , rsh , n_dt )
183
+ # cat with past
184
+ yz_past = np .concatenate ((y_past , z_past ), axis = 1 )
185
+ xz_past = np .concatenate ((x_past , z_past ), axis = 1 )
186
+ xyz_past = np .concatenate ((xy_past , z_past ), axis = 1 )
187
+ else :
188
+ yz_past , xz_past , xyz_past = y_past , x_past , xy_past
189
+ # copnorm over the last axis (avoid copnorming several times)
190
+ x_pres = copnorm_nd (x_pres , axis = - 1 )
191
+ x_past = copnorm_nd (x_past , axis = - 1 )
192
+ y_pres = copnorm_nd (y_pres , axis = - 1 )
193
+ y_past = copnorm_nd (y_past , axis = - 1 )
194
+ yz_past = copnorm_nd (yz_past , axis = - 1 )
195
+ xz_past = copnorm_nd (xz_past , axis = - 1 )
196
+ xyz_past = copnorm_nd (xyz_past , axis = - 1 )
197
+
198
+ # -----------------------------------------------------------------
199
+ # Granger Causality measures
200
+ # -----------------------------------------------------------------
201
+ # gc(pairs(:,1) -> pairs(:,2))
202
+ gc [:, n_ti , 0 ] = cmi_nd_ggg (y_pres , x_past , yz_past , ** kw )
203
+ # gc(pairs(:,2) -> pairs(:,1))
204
+ gc [:, n_ti , 1 ] = cmi_nd_ggg (x_pres , y_past , xz_past , ** kw )
205
+ # gc(pairs(:,2) . pairs(:,1))
206
+ gc [:, n_ti , 2 ] = cmi_nd_ggg (x_pres , y_pres , xyz_past , ** kw )
207
+
208
+ return gc
209
+
210
+
211
+ ###############################################################################
212
+ ###############################################################################
213
+ # HIGH-LEVEL CONN_COVGC
214
+ ###############################################################################
215
+ ###############################################################################
216
+
217
+
131
218
def conn_covgc (data , dt , lag , t0 , step = 1 , roi = None , times = None , method = 'gc' ,
132
- n_jobs = - 1 , verbose = None ):
219
+ conditional = False , n_jobs = - 1 , verbose = None ):
133
220
r"""Single-trial covariance-based Granger Causality for gaussian variables.
134
221
135
- This function computes the covariance-based Granger Causality (covgc) for
136
- each trial.
222
+ This function computes the (conditional) covariance-based Granger Causality
223
+ (covgc) for each trial.
137
224
138
225
.. note::
139
226
**Total Granger interdependence**
@@ -180,6 +267,9 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
180
267
Method for the estimation of the covgc. Use either 'gauss' which
181
268
assumes that the time-points are normally distributed or 'gc' in order
182
269
to use the gaussian-copula.
270
+ conditional : bool | False
271
+ If True, the conditional Granger Causality is computed i.e the past is
272
+ also conditioned by the past of other sources.
183
273
n_jobs : int | -1
184
274
Number of jobs to use for parallel computing (use -1 to use all
185
275
jobs). The parallel loop is set at the pair level.
@@ -211,18 +301,21 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
211
301
t0 , CONFIG ['FLOAT_DTYPE' ]):
212
302
t0 = np .array ([t0 ])
213
303
t0 = np .asarray (t0 ).astype (int )
214
- dt , lag , step = int (dt ), int (lag ), int (step )
304
+ dt , lag , step , trials = int (dt ), int (lag ), int (step ), None
215
305
# handle dataarray input
216
306
if isinstance (data , xr .DataArray ):
217
307
if isinstance (roi , str ):
218
308
roi = data [roi ].data
219
309
if isinstance (times , str ):
220
310
times = data [times ].data
311
+ trials = data ['trials' ].data
221
312
data = data .data
222
313
# force C contiguous array because operations on row-major
223
314
if not data .flags .c_contiguous :
224
315
data = np .ascontiguousarray (data )
225
316
n_epochs , n_roi , n_times = data .shape
317
+ if trials is None :
318
+ trials = np .arange (n_epochs )
226
319
# default roi vector
227
320
if roi is None :
228
321
roi = np .array ([f"roi_{ k } " for k in range (n_roi )])
@@ -262,16 +355,21 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
262
355
logger .debug (f"Index shape : { ind_tx .shape } " )
263
356
264
357
# -------------------------------------------------------------------------
358
+ ext = 'conditional' if conditional else ''
265
359
# compute covgc and parallel over pairs
266
- logger .info (f"Compute the covgc (method={ method } , n_pairs={ len (x_s )} ; "
267
- f"n_windows={ len (t0 )} , lag={ lag } , dt={ dt } , step={ step } )" )
268
- gc = Parallel (n_jobs = n_jobs )(delayed (fcn )(
269
- data [:, s , :], data [:, t , :], ind_tx , t0 ) for s , t in zip (x_s , x_t ))
360
+ logger .info (f"Compute the { ext } covgc (method={ method } , n_pairs={ len (x_s )} "
361
+ f"; n_windows={ len (t0 )} , lag={ lag } , dt={ dt } , step={ step } )" )
362
+ if not conditional :
363
+ gc = Parallel (n_jobs = n_jobs )(delayed (fcn )(
364
+ data [:, s , :], data [:, t , :], ind_tx , t0 ) for s , t in zip (
365
+ x_s , x_t ))
366
+ else :
367
+ gc = Parallel (n_jobs = n_jobs )(delayed (_cond_gccovgc )(
368
+ data , s , t , ind_tx , t0 ) for s , t in zip (x_s , x_t ))
270
369
gc = np .stack (gc , axis = 1 )
271
370
272
371
# -------------------------------------------------------------------------
273
372
# change output type
274
- trials = np .arange (n_epochs )
275
373
dire = np .array (['x->y' , 'y->x' , 'x.y' ])
276
374
gc = xr .DataArray (gc , dims = ('trials' , 'roi' , 'times' , 'direction' ),
277
375
coords = (trials , roi_p , times_p , dire ))
@@ -280,5 +378,24 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
280
378
gc .attrs ['step' ] = step
281
379
gc .attrs ['dt' ] = dt
282
380
gc .attrs ['t0' ] = t0
381
+ gc .attrs ['conditional' ] = conditional
283
382
284
383
return gc , pairs , roi_p , times_p
384
+
385
+
386
+ if __name__ == '__main__' :
387
+ from frites .simulations import StimSpecAR
388
+ import matplotlib .pyplot as plt
389
+
390
+ ss = StimSpecAR ()
391
+ ar = ss .fit (ar_type = 'ding_3' , n_stim = 2 , n_epochs = 20 )
392
+ # plot the model
393
+ # plt.figure(figsize=(7, 8))
394
+ # ss.plot()
395
+ # compute covgc
396
+ dt , lag , step = 50 , 5 , 2
397
+ t0 = np .arange (lag , ar .shape [- 1 ] - dt , step )
398
+ gc = conn_covgc (ar , roi = 'roi' , times = 'times' , dt = dt , lag = lag , t0 = t0 ,
399
+ n_jobs = - 1 , conditional = False )[0 ]
400
+ ss .plot_covgc (gc = gc )
401
+ plt .show ()
0 commit comments