/
event.py
403 lines (318 loc) · 13.9 KB
/
event.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
# Copyright 2016 Princeton University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Event segmentation using a Hidden Markov Model
Given an ROI timeseries, this class uses an annealed fitting procedure to
segment the timeseries into events with stable activity patterns. After
learning the signature activity pattern of each event, the model can then be
applied to other datasets to identify a corresponding sequence of events.
Full details are available in the bioRxiv preprint:
Christopher Baldassano, Janice Chen, Asieh Zadbood,
Jonathan W Pillow, Uri Hasson, Kenneth A Norman
Discovering event structure in continuous narrative perception and memory
http://biorxiv.org/content/early/2016/10/14/081018
"""
# Authors: Chris Baldassano and Cătălin Iordan (Princeton University)
import numpy as np
from scipy import stats
import logging
import copy
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted, check_array
from sklearn.exceptions import NotFittedError
from . import _utils as utils # type: ignore
# Module-level logger named after this module; fit() reports its per-step
# log-likelihood through it at DEBUG level.
logger = logging.getLogger(__name__)
# Names exported by ``from <module> import *``.
__all__ = [
"EventSegment",
]
class EventSegment(BaseEstimator):
    """Class for event segmentation of continuous fMRI data

    Parameters
    ----------
    n_events: int
        Number of segments to learn

    step_var: Callable[[int], float] : default 4 * (0.98 ** (step - 1))
        The Gaussian variance to use during fitting, as a function of the
        number of steps. Should decrease slowly over time.

    n_iter: int : default 500
        Maximum number of steps to run during fitting

    Attributes
    ----------
    p_start, p_end: 1 by n_events+1 ndarray
        initial and final prior distributions over events (the extra
        entry is the final sink state)

    P: n_events+1 by n_events+1 ndarray
        HMM transition matrix

    ll_ : (accepted fitting steps) by (training datasets) ndarray
        Log-likelihood for training datasets over the course of training

    segments_: list of (time by event) ndarrays
        Learned (soft) segmentation for training datasets

    event_var_ : float
        Gaussian variance at the end of learning (None until fit() has
        been run)

    event_pat_ : voxel by event ndarray
        Learned mean patterns for each event
    """

    # Deliberately a plain function without ``self``: it is referenced by
    # name in the signature of __init__ below, while the class body is still
    # being executed, and is stored on the instance as an ordinary callable.
    def _default_var_schedule(step):
        """Default annealing schedule: variance 4 at step 1, decaying by 2%
        per step."""
        return 4 * (0.98 ** (step - 1))

    def __init__(self, n_events=2,
                 step_var=_default_var_schedule,
                 n_iter=500):
        self.n_events = n_events
        self.classes_ = np.arange(self.n_events)
        self.step_var = step_var
        self.n_iter = n_iter
        # Stays None until fit() has selected a variance
        self.event_var_ = None

        n_states = self.n_events + 1  # the events plus a final sink state

        # Set up transition matrix, with final sink state
        # For transition matrix of this form, the transition probability has
        # no impact on the final solution, since all valid paths must take
        # the same number of transitions
        self.p_start = np.zeros((1, n_states))
        self.p_start[0, 0] = 1

        # Every event either persists or advances to the next state with
        # equal probability; the sink state is absorbing.
        self.P = 0.5 * np.eye(n_states) + 0.5 * np.eye(n_states, k=1)
        self.P[-1, -1] = 1

        # A valid path must end in the last real event (not the sink)
        self.p_end = np.zeros((1, n_states))
        self.p_end[0, -2] = 1

    def fit(self, X, y=None):
        """Learn a segmentation on training data

        Fits event patterns and a segmentation to training data. After
        running this function, the learned event patterns can be used to
        segment other datasets using find_events

        Parameters
        ----------
        X: time by voxel ndarray, or a list of such ndarrays
            fMRI data to be segmented. If a list is given, then all datasets
            are segmented simultaneously with the same event patterns

        y: not used (added to comply with BaseEstimator definition)

        Returns
        -------
        self: the EventSegment object
        """
        if not isinstance(X, list):
            X = [check_array(X)]

        # Work on voxel-by-time views. Building a new list keeps the
        # caller's list untouched, and the z-scoring below allocates fresh
        # arrays, so the caller's data is never modified.
        data = [x.T for x in X]
        n_train = len(data)

        n_dim = data[0].shape[0]
        for d in data:
            assert d.shape[0] == n_dim, \
                "All datasets must have the same number of voxels"

        # Double-check that data is z-scored in time
        data = [stats.zscore(d, axis=1, ddof=1) for d in data]

        # Initialize variables for fitting
        log_gamma = [np.zeros((d.shape[1], self.n_events)) for d in data]

        step = 1
        best_ll = float("-inf")
        self.ll_ = np.empty((0, n_train))
        while step <= self.n_iter:
            iteration_var = self.step_var(step)

            # Based on the current segmentation, compute the mean pattern
            # for each event (averaged across training datasets)
            seg_prob = []
            for lg in log_gamma:
                gamma = np.exp(lg)
                seg_prob.append(gamma / np.sum(gamma, axis=0))
            mean_pat = np.mean(
                [d.dot(p) for d, p in zip(data, seg_prob)], axis=0)

            # Based on the current mean patterns, compute the event
            # segmentation
            step_ll = np.empty(n_train)
            for i, d in enumerate(data):
                logprob = self._logprob_obs(d, mean_pat, iteration_var)
                log_gamma[i], ll = self._forward_backward(logprob)
                # ll comes back as a length-1 array; extract the scalar
                # explicitly rather than relying on implicit conversion
                step_ll[i] = np.squeeze(ll)

            # If log-likelihood has started decreasing, discard this step
            # (keeping the previously stored results) and stop
            mean_ll = np.mean(step_ll)
            if mean_ll < best_ll:
                break

            self.ll_ = np.vstack((self.ll_, step_ll))
            self.segments_ = [np.exp(lg) for lg in log_gamma]
            self.event_var_ = iteration_var
            self.event_pat_ = mean_pat
            best_ll = mean_ll
            logger.debug("Fitting step %d, LL=%f", step, best_ll)

            step += 1

        return self

    def _logprob_obs(self, data, mean_pat, var):
        """Log probability of observing each timepoint under each event model

        Computes the log probability of each observed timepoint being
        generated by the Gaussian distribution for each event pattern

        Parameters
        ----------
        data: voxel by time ndarray
            fMRI data on which to compute log probabilities

        mean_pat: voxel by event ndarray
            Centers of the Gaussians for each event

        var: float or 1D array of length equal to the number of events
            Variance of the event Gaussians. If scalar, all events are
            assumed to have the same variance

        Returns
        -------
        logprob : time by event ndarray
            Log probability of each timepoint under each event Gaussian
        """
        n_vox = data.shape[0]
        n_time = data.shape[1]

        # z-score both data and mean patterns in space, so that Gaussians
        # are measuring Pearson correlations and are insensitive to overall
        # activity changes
        data_z = stats.zscore(data, axis=0, ddof=1)
        mean_pat_z = stats.zscore(mean_pat, axis=0, ddof=1)

        # Expand scalars (including 0-d arrays) to one variance per event;
        # 1D ndarrays are used as given
        if not isinstance(var, np.ndarray) or var.ndim == 0:
            var = var * np.ones(self.n_events)

        logprob = np.empty((n_time, self.n_events))
        for k in range(self.n_events):
            logprob[:, k] = -0.5 * n_vox * np.log(
                2 * np.pi * var[k]) - 0.5 * np.sum(
                (data_z.T - mean_pat_z[:, k]).T ** 2, axis=0) / var[k]

        # Normalize by the number of voxels
        logprob /= n_vox

        return logprob

    def _forward_backward(self, logprob):
        """Runs forward-backward algorithm on observation log probs

        Given the log probability of each timepoint being generated by
        each event, run the HMM forward-backward algorithm to find the
        probability that each timepoint belongs to each event (based on the
        transition priors in p_start, p_end, and P)

        See https://en.wikipedia.org/wiki/Forward-backward_algorithm for
        mathematical details

        Parameters
        ----------
        logprob : time by event ndarray
            Log probability of each timepoint under each event Gaussian

        Returns
        -------
        log_gamma : time by event ndarray
            Log probability of each timepoint belonging to each event

        ll : length-1 ndarray
            Log-likelihood of fit (length 1 because p_end is stored as a
            1 by n_events+1 array)
        """
        n_time = logprob.shape[0]
        # Add a column for the sink state, which has zero probability of
        # generating any observation. hstack allocates a new array, so the
        # caller's logprob is not modified.
        logprob = np.hstack(
            (logprob, float("-inf") * np.ones((n_time, 1))))

        # Initialize variables
        log_scale = np.zeros(n_time)
        log_alpha = np.zeros((n_time, self.n_events + 1))
        log_beta = np.zeros((n_time, self.n_events + 1))

        # Forward pass. The loop index is kept distinct from n_time so the
        # backward pass and the log-likelihood below see the true length.
        for i in range(n_time):
            if i == 0:
                log_alpha[0, :] = self._log(self.p_start) + logprob[0, :]
            else:
                log_alpha[i, :] = self._log(
                    np.exp(log_alpha[i - 1, :]).dot(self.P)) + logprob[i, :]

            # Rescale each row to keep the recursion numerically stable
            log_scale[i] = np.logaddexp.reduce(log_alpha[i, :])
            log_alpha[i] -= log_scale[i]

        # Backward pass: start from the final timepoint and walk back over
        # every earlier timepoint, n_time - 2 down to 0
        log_beta[-1, :] = self._log(self.p_end) - log_scale[-1]
        for i in reversed(range(n_time - 1)):
            obs_weighted = log_beta[i + 1, :] + logprob[i + 1, :]
            # Subtract the max before exponentiating to avoid underflow
            offset = np.max(obs_weighted)
            log_beta[i, :] = offset + self._log(
                np.exp(obs_weighted - offset).dot(self.P.T)) - log_scale[i]

        # Combine and normalize
        log_gamma = log_alpha + log_beta
        log_gamma -= np.logaddexp.reduce(log_gamma, axis=1, keepdims=True)

        ll = np.sum(log_scale[:(n_time - 1)]) + np.logaddexp.reduce(
            log_alpha[-1, :] + log_scale[-1] + self._log(self.p_end), axis=1)

        # Drop the sink-state column
        log_gamma = log_gamma[:, :-1]

        return log_gamma, ll

    def _log(self, x):
        """Modified version of np.log that manually sets values <=0 to -inf

        The input is flattened, passed to utils.masked_log, and the result
        is reshaped back to the input's shape.

        Parameters
        ----------
        x: ndarray of floats
            Input to the log function

        Returns
        -------
        log_ma: ndarray of floats
            log of x, with x<=0 values replaced with -inf
        """
        flat_log = utils.masked_log(x.flatten())
        return flat_log.reshape(x.shape)

    def set_event_patterns(self, event_pat):
        """Set HMM event patterns manually

        Rather than fitting the event patterns automatically using fit(), this
        function allows them to be set explicitly. They can then be used to
        find corresponding events in a new dataset, using find_events().

        Parameters
        ----------
        event_pat: voxel by event ndarray

        Raises
        ------
        ValueError
            If the number of columns of event_pat differs from n_events
        """
        if event_pat.shape[1] != self.n_events:
            raise ValueError(("Number of columns of event_pat must match "
                              "number of events"))
        # Store a copy so later changes to the caller's array have no effect
        self.event_pat_ = event_pat.copy()

    def find_events(self, testing_data, var=None, scramble=False):
        """Applies learned event segmentation to new testing dataset

        After fitting an event segmentation using fit() or setting event
        patterns directly using set_event_patterns(), this function finds the
        same sequence of event patterns in a new testing dataset.

        Parameters
        ----------
        testing_data: timepoint by voxel ndarray
            fMRI data to segment based on previously-learned event patterns

        var: float or 1D ndarray of length equal to the number of events
            default: uses variance that maximized training log-likelihood
            Variance of the event Gaussians. If scalar, all events are
            assumed to have the same variance. If fit() has not previously
            been run, this must be specified (cannot be None).

        scramble: bool : default False
            If true, the order of the learned events are shuffled before
            fitting, to give a null distribution

        Returns
        -------
        segments : time by event ndarray
            The resulting soft segmentation. segments[t,e] = probability
            that timepoint t is in event e

        test_ll : length-1 ndarray
            Log-likelihood of model fit

        Raises
        ------
        NotFittedError
            If no event patterns have been set, or if var is None and no
            variance has been learned by fit()
        """
        # NotFittedError derives from both ValueError and AttributeError, so
        # callers that handled the AttributeError previously raised by a
        # missing event_pat_ keep working.
        if not hasattr(self, "event_pat_"):
            raise NotFittedError(("The event patterns must first be set "
                                  "by fit() or set_event_patterns()"))

        if var is None:
            var = getattr(self, "event_var_", None)
            if var is None:
                raise NotFittedError(("var must be specified when fit() "
                                      "has not been run"))

        if scramble:
            mean_pat = self.event_pat_[:, np.random.permutation(self.n_events)]
        else:
            mean_pat = self.event_pat_

        logprob = self._logprob_obs(testing_data.T, mean_pat, var)
        lg, test_ll = self._forward_backward(logprob)
        segments = np.exp(lg)

        return segments, test_ll

    def predict(self, X):
        """Applies learned event segmentation to new testing dataset

        Alternative function for segmenting a new dataset after using
        fit() to learn a sequence of events, to comply with the sklearn
        Classifier interface

        Parameters
        ----------
        X: timepoint by voxel ndarray
            fMRI data to segment based on previously-learned event patterns

        Returns
        -------
        Event label for each timepoint (index of the most probable event)
        """
        check_is_fitted(self, ["event_pat_", "event_var_"])
        X = check_array(X)
        segments, _ = self.find_events(X)
        return np.argmax(segments, axis=1)