1
+ """Transfer entropy using the Gaussian-Copula."""
2
+ import numpy as np
3
+ import xarray as xr
4
+
5
+ from frites .core import cmi_nd_ggg , copnorm_nd
6
+ from frites .config import CONFIG
7
+
8
+
9
+ def conn_transfer_entropy (x , max_delay = 30 , pairs = None , gcrn = True ):
10
+ """Compute the transfer entropy.
11
+
12
+ The transfer entropy represents the amount of information that is send
13
+ from a source to a target. It is defined as :
14
+
15
+ .. math::
16
+
17
+ TE = I(source_{past}; target_{present} | target_{past})
18
+
19
+ Where :math:`past` is defined using the `max_delay` input parameter. Note
20
+ that the transfer entropy only provides about the amount of information
21
+ that is sent, not on the content.
22
+
23
+ Parameters
24
+ ----------
25
+ x : array_like
26
+ Array of data of shape (n_roi, n_times, n_epochs). Must be a gaussian
27
+ variable
28
+ max_delay : int | 30
29
+ Number of time points defining where to stop looking at in the past.
30
+ Increasing this maximum delay input can lead to slower computations
31
+ pairs : array_like
32
+ Array of pairs to consider for computing the transfer entropy. It
33
+ should be an array of shape (n_pairs, 2) where the first column refers
34
+ to sources and the second to targets. If None, all pairs will be
35
+ computed
36
+ gcrn : bool | True
37
+ Apply a Gaussian Copula rank normalization
38
+
39
+ Returns
40
+ -------
41
+ te : array_like
42
+ The transfer entropy array of shape (n_pairs, n_times - max_delay)
43
+ pairs : array_like
44
+ Pairs vector use for computations of shape (n_pairs, 2)
45
+ """
46
+ # -------------------------------------------------------------------------
47
+ # check pairs
48
+ n_roi , n_times , n_epochs = x .shape
49
+ if not isinstance (pairs , np .ndarray ):
50
+ pairs = np .c_ [np .where (~ np .eye (n_roi , dtype = bool ))]
51
+ assert isinstance (pairs , np .ndarray ) and (pairs .ndim == 2 ) and (
52
+ pairs .shape [1 ] == 2 ), ("`pairs` should be a 2d array of shape "
53
+ "(n_pairs, 2) where the first column refers to "
54
+ "sources and the second to targets" )
55
+ x_all_s , x_all_t = pairs [:, 0 ], pairs [:, 1 ]
56
+ n_pairs = len (x_all_s )
57
+ # check max_delay
58
+ assert isinstance (max_delay , (int , np .int )), ("`max_delay` should be an "
59
+ "integer" )
60
+ # check input data
61
+ assert (x .ndim == 3 ), ("input data `x` should be a 3d array of shape "
62
+ "(n_roi, n_times, n_epochs)" )
63
+ x = x [..., np .newaxis , :]
64
+
65
+ # -------------------------------------------------------------------------
66
+ # apply copnorm
67
+ if gcrn :
68
+ x = copnorm_nd (x , axis = - 1 )
69
+
70
+ # -------------------------------------------------------------------------
71
+ # compute the transfer entropy
72
+ te = np .zeros ((n_pairs , n_times - max_delay ), dtype = float )
73
+ for n_s , x_s in enumerate (x_all_s ):
74
+ # select targets
75
+ is_source = x_all_s == x_s
76
+ x_t = x_all_t [is_source ]
77
+ targets = x [x_t , ...]
78
+ # tile source
79
+ source = np .tile (x [[x_s ], ...], (targets .shape [0 ], 1 , 1 , 1 ))
80
+ # loop over remaining time points
81
+ for n_d , d in enumerate (range (max_delay + 1 , n_times )):
82
+ t_pres = np .tile (targets [:, [d ], :], (1 , max_delay , 1 , 1 ))
83
+ past = slice (d - max_delay - 1 , d - 1 )
84
+ s_past = source [:, past , ...]
85
+ t_past = targets [:, past , ...]
86
+ # compute the transfer entropy
87
+ _te = cmi_nd_ggg (s_past , t_pres , t_past , ** CONFIG ["KW_GCMI" ])
88
+ # take the sum over delays
89
+ te [is_source , n_d ] = _te .mean (1 )
90
+
91
+ return te , pairs
0 commit comments