import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional, Union, Sequence

from typing_extensions import Protocol
import torch

from ..lstnet.lstnet_modules import LSTNetwork
from ..mlp.mlp_modules import MultiLayerPerceptron
from ... import ActivationFunction
from ....util.string import object_repr, ToStringMixin

log = logging.getLogger(__name__)


class EncoderProtocol(Protocol):
    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        :param x: a tensor of shape (batch_size, seq_length=max(lengths), history_features) containing the sequence of
            history features to encode
        :param lengths: an optional tensor of shape (batch_size) containing the lengths of the sequences in `x`
        :return: a tensor of shape (batch_size, latent_dim) containing the encodings
        """
        pass


class DecoderProtocol(Protocol):
    def forward(self,
            latent: torch.Tensor,
            target_features: Optional[torch.Tensor] = None,
            target_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        :param latent: a tensor of shape (batch_size, latent_dim) containing the latent representations
        :param target_features: a tensor of shape (batch_size, target_seq_length=max(target_lengths), target_feature_dim)
        :param target_lengths: a tensor of shape (batch_size) containing the lengths of sequences in `target_features`
        :return: a tensor of shape (batch_size, output_dim) or (batch_size, target_seq_length, output_dim) containing the predictions,
            where the shape depends on the use case and can vary depending on the needs
        """
        pass


class PredictorProtocol(Protocol):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: a tensor of shape (batch_size, input_dim) containing an intermediate representation
        :return: a tensor of shape (batch_size, output_dim)
        """
        pass


# TODO: Should use intersection type A & B once we switch to Python 3.9+
TDecoder = Union[DecoderProtocol, torch.nn.Module]
TEncoder = Union[EncoderProtocol, torch.nn.Module]
TPredictor = Union[PredictorProtocol, torch.nn.Module]
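
# Note: the protocols above are structural (duck-typed): any torch.nn.Module whose forward
# matches the respective signature can serve as an encoder/decoder/predictor. A minimal
# illustrative sketch of a conforming encoder (not part of this module; the class name and
# pooling strategy are hypothetical):
#
#     class MeanPoolingEncoder(torch.nn.Module):
#         """Encodes a sequence by mean-pooling its feature vectors and projecting to latent_dim."""
#         def __init__(self, input_dim: int, latent_dim: int):
#             super().__init__()
#             self.proj = torch.nn.Linear(input_dim, latent_dim)
#
#         def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
#             return self.proj(x.mean(dim=1))  # (batch_size, latent_dim)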


class EncoderFactory(ToStringMixin, ABC):
    """
    Represents a factory for encoder modules that map a sequence of items to a latent vector
    """
    @abstractmethod
    def create_encoder(self, input_dim: int, latent_dim: int) -> TEncoder:
        """
        :param input_dim: the input dimension per sequence item
        :param latent_dim: the latent vector dimension that is to be generated by the encoder
        :return: a torch module satisfying :class:`EncoderProtocol`
        """
        pass


class DecoderFactory(ToStringMixin, ABC):
    @abstractmethod
    def create_decoder(self, latent_dim: int, target_feature_dim: int) -> TDecoder:
        """
        :param latent_dim: the latent vector size which is used for the representation of the history
        :param target_feature_dim: the number of dimensions/features that are given for each prediction to be made
            (each future sequence item)
        :return: a torch module satisfying :class:`DecoderProtocol`
        """
        pass


class PredictorFactory(ToStringMixin, ABC):
    """
    Represents a factory for predictor components which map from an intermediate representation to the
    desired output dimension.
    """
    @abstractmethod
    def create_predictor(self, input_dim: int, output_dim: int) -> TPredictor:
        """
        :param input_dim: the input dimension
        :param output_dim: the output dimension
        :return: a module which maps an input with dimension `input_dim` to the desired prediction dimension (`output_dim`)
        """
        pass


class LinearPredictorFactory(PredictorFactory):
    """A factory for predictors consisting only of a linear layer (without subsequent activation)"""
    def create_predictor(self, input_dim: int, output_dim: int) -> torch.nn.Module:
        return torch.nn.Linear(input_dim, output_dim)


class MLPPredictorFactory(PredictorFactory):
    """A factory for predictors that are multi-layer perceptrons"""
    def __init__(self,
            hidden_dims: Sequence[int] = (),
            hid_activation_fn: ActivationFunction = ActivationFunction.RELU,
            output_activation_fn: ActivationFunction = ActivationFunction.NONE,
            p_dropout: Optional[float] = None):
        self.hidden_dims = hidden_dims
        self.hid_activation_fn = hid_activation_fn
        self.output_activation_fn = output_activation_fn
        self.p_dropout = p_dropout

    def create_predictor(self, input_dim: int, output_dim: int) -> TPredictor:
        return MultiLayerPerceptron(input_dim, output_dim, self.hidden_dims, hid_activation_fn=self.hid_activation_fn.get_torch_function(),
            output_activation_fn=self.output_activation_fn.get_torch_function(), p_dropout=self.p_dropout)
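
# Illustrative usage sketch (not part of this module): both predictor factories produce modules
# mapping (batch_size, input_dim) to (batch_size, output_dim); the dimensions below are arbitrary
# example values:
#
#     predictor = MLPPredictorFactory(hidden_dims=(64, 32)).create_predictor(input_dim=16, output_dim=1)
#     y = predictor(torch.randn(8, 16))  # (8, 1)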


class RnnEncoderModule(torch.nn.Module):
    """
    Encodes a sequence of feature vectors, outputting a latent vector.
    The input sequence may either be fixed-length or variable-length.
    """
    class RnnType:
        GRU = "gru"
        """gated recurrent unit"""
        LSTM = "lstm"
        """long short-term memory"""

    def __init__(self, input_dim: int, latent_dim: int, rnn_type: RnnType = RnnType.LSTM):
        """
        :param input_dim: the input dimension per time slice
        :param latent_dim: the dimension of the latent output vector
        :param rnn_type: the type of recurrent network to use
        """
        super().__init__()
        self.window_dim_per_item = input_dim
        self.latent_dim = latent_dim
        self.rnn_type = rnn_type
        if rnn_type == self.RnnType.GRU:
            self.rnn = torch.nn.GRU(input_size=self.window_dim_per_item, hidden_size=latent_dim, batch_first=True)
        elif rnn_type == self.RnnType.LSTM:
            self.rnn = torch.nn.LSTM(input_size=self.window_dim_per_item, hidden_size=latent_dim, batch_first=True)
        else:
            raise ValueError(f"Unknown rnn type '{rnn_type}', use either 'gru' or 'lstm'")

    def __str__(self):
        return object_repr(self, dict(rnn=self.rnn))

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None):
        """
        :param x: a tensor of size (batch_size, seq_length, dim_per_item)
        :param lengths: an optional tensor containing the lengths of the sequences; if None,
            all sequences are assumed to have the same full length
        :return: a tensor of size (batch_size, latent_dim)
        """
        if lengths is not None:
            x = torch.nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        if self.rnn_type == self.RnnType.GRU:
            _, l = self.rnn(x)  # use the final hidden state
        elif self.rnn_type == self.RnnType.LSTM:
            _o, (l, _c) = self.rnn(x)  # use the final hidden state (ignoring the cell state)
        else:
            raise ValueError(self.rnn_type)
        # l has shape (1, batch_size, latent_dim)
        l = l.squeeze(0)
        return l  # (batch_size, latent_dim)
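
# Illustrative usage sketch (not part of this module; shapes are arbitrary example values):
#
#     encoder = RnnEncoderModule(input_dim=8, latent_dim=32, rnn_type=RnnEncoderModule.RnnType.GRU)
#     x = torch.randn(4, 10, 8)               # (batch_size, seq_length, input_dim)
#     lengths = torch.tensor([10, 7, 10, 3])  # optional per-sequence lengths
#     z = encoder(x, lengths)                 # (4, 32)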


class RnnEncoderFactory(EncoderFactory):
    def __init__(self,
            input_dim: int,
            latent_dim: int,
            rnn_type: RnnEncoderModule.RnnType = RnnEncoderModule.RnnType.GRU):
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.rnn_type = rnn_type

    def create_encoder(self, input_dim: int, latent_dim: int):
        return RnnEncoderModule(input_dim, latent_dim, self.rnn_type)


class LSTNetworkEncoder(torch.nn.Module):
    """
    Adapts an LSTNetwork instance to the encoder interface
    """
    def __init__(self, lstnet: LSTNetwork):
        super().__init__()
        self.lstnet = lstnet

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None):
        """
        :param x: a tensor of size (batch_size, seq_length, dim_per_item)
        :param lengths: an optional tensor containing the lengths of the sequences; if None,
            all sequences are assumed to have the same full length
        :return: a tensor of size (batch_size, latent_dim)
        """
        if lengths is not None:
            unique_lengths = torch.unique(lengths)
            if len(unique_lengths) != 1:
                raise ValueError("LSTNetwork does not support variable-length inputs")
        l = self.lstnet(x)
        return l


class LSTNetworkEncoderFactory(EncoderFactory):
    def __init__(self,
            num_input_time_slices: int,
            num_convolutions: int,
            num_cnn_time_slices: int,
            hid_rnn: int,
            skip: int,
            hid_skip: int,
            dropout: float = 0.2):
        self.dropout = dropout
        self.num_input_time_slices = num_input_time_slices
        self.num_convolutions = num_convolutions
        self.num_cnn_time_slices = num_cnn_time_slices
        self.hid_rnn = hid_rnn
        self.skip = skip
        self.hid_skip = hid_skip

    def create_encoder(self, input_dim: int, latent_dim: int) -> torch.nn.Module:
        lstnet = LSTNetwork(num_input_time_slices=self.num_input_time_slices,
            input_dim_per_time_slice=input_dim,
            num_convolutions=self.num_convolutions,
            num_cnn_time_slices=self.num_cnn_time_slices,
            hid_rnn=self.hid_rnn,
            skip=self.skip,
            hid_skip=self.hid_skip,
            dropout=self.dropout,
            mode=LSTNetwork.Mode.ENCODER)
        if lstnet.get_encoder_dim() != latent_dim:
            raise ValueError(f"LSTNetwork produces latent_dim={lstnet.get_encoder_dim()}, but {latent_dim} was requested; "
                f"adjust the factory parameters accordingly (see get_latent_dim)")
        return LSTNetworkEncoder(lstnet)

    def get_latent_dim(self) -> int:
        return LSTNetwork.compute_encoder_dim(self.hid_rnn, self.skip, self.hid_skip)


class SingleTargetDecoderModule(torch.nn.Module, DecoderProtocol):
    """
    Represents a decoder that outputs a single value for a single target item, taking as input the concatenation of the
    latent tensor (generated by the encoder) and the target item's feature vector.
    """
    def __init__(self, target_feature_dim, latent_dim, predictor_factory: PredictorFactory, output_dim=1):
        """
        :param target_feature_dim: the number of target item features
        :param latent_dim: the dimension of the latent vector generated by the encoder, which we receive as input
        :param predictor_factory: a factory for the creation of the predictor that will map the combined
            latent vector and target feature vector to the prediction of size `output_dim`
        :param output_dim: the output (prediction) dimension
        """
        super().__init__()
        self.target_feature_dim = target_feature_dim
        self.latent_dim = latent_dim
        self.predictor_input_dim = self.latent_dim + self.target_feature_dim
        self.predictor = predictor_factory.create_predictor(self.predictor_input_dim, output_dim)

    def __str__(self):
        return object_repr(self, dict(predictor=self.predictor))

    def forward(self, latent, target_features=None, target_lengths=None):
        if target_features is not None:
            # target_features must have shape (batch_size, 1, target_feature_dim)
            assert target_features.shape[1] == 1, "target_features must contain exactly one sequence item"
            target_features = torch.squeeze(target_features, 1)
            lf = torch.cat((latent, target_features), dim=1)
        else:
            lf = latent
        return self.predictor(lf)
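
# Illustrative usage sketch (not part of this module; dimensions are arbitrary example values):
#
#     decoder = SingleTargetDecoderModule(target_feature_dim=4, latent_dim=32,
#         predictor_factory=LinearPredictorFactory())
#     latent = torch.randn(8, 32)             # encoder output
#     target_features = torch.randn(8, 1, 4)  # exactly one target item per instance
#     y = decoder(latent, target_features)    # (8, 1)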


class TargetSequenceDecoderModule(torch.nn.Module, DecoderProtocol, ToStringMixin):
    """
    Wrapper for decoders that take as input a latent representation (generated by an encoder)
    and a sequence of target features.
    It can generate either a single prediction for the entire sequence of target features or
    a sequence of predictions (one for each target sequence item), depending on the prediction/output mode.
    """
    class PredictionMode(Enum):
        """Defines how the prediction works"""
        SINGLE_LATENT = "single_latent"
        """
        Use an LSTM to process the target feature sequence and use only the final hidden state for prediction,
        outputting a single prediction only (for OutputMode.SINGLE_OUTPUT only)
        """
        MULTI_LATENT = "multi_latent"
        """Use an LSTM to process the target feature sequence and use all hidden states (full output) for prediction"""
        DIRECT = "direct"
        """Directly use the latent vector and target features to make predictions for each target sequence
        item (use with LatentPassOnMode.CONCAT_INPUT & NO_LATENT only)
        """

    class LatentPassOnMode(Enum):
        """Defines how the latent state from the encoder stage is passed on to the decoder"""
        INIT_HIDDEN = "init_hidden"
        """
        Pass on the encoder output as the initial hidden state of the LSTM
        (only possible for PredictionMode in {SINGLE_LATENT, MULTI_LATENT})
        """
        CONCAT_INPUT = "concat_input"
        """Pass on the encoder output by concatenating it with each target feature input vector"""
        NO_LATENT = "no_latent"
        """
        Do not pass on the latent vector at all (it is ignored by the subsequent decoder component).
        This is mostly useful for ablation testing.
        """

    class OutputMode(Enum):
        """Defines how to treat multiple predictions (for PredictionMode != SINGLE_LATENT)"""
        SINGLE_OUTPUT = "single"
        """Output a single result from a single input (for PredictionMode.SINGLE_LATENT only)"""
        SINGLE_OUTPUT_MEAN = "mean"
        """Output the mean of multiple (intermediate) predictions"""
        MULTI_OUTPUT = "multi"
        """Output multiple predictions directly"""

    def __init__(self,
            target_feature_dim: int,
            latent_dim: int,
            predictor_factory: PredictorFactory,
            output_dim: int = 1,
            prediction_mode: PredictionMode = PredictionMode.MULTI_LATENT,
            latent_pass_on_mode: LatentPassOnMode = LatentPassOnMode.CONCAT_INPUT,
            output_mode: OutputMode = OutputMode.MULTI_OUTPUT,
            p_recurrent_dropout: float = 0.0):
        super().__init__()
        if not ((prediction_mode == self.PredictionMode.SINGLE_LATENT) ==
                (output_mode == self.OutputMode.SINGLE_OUTPUT)):  # SINGLE_LATENT <=> SINGLE_OUTPUT
            raise ValueError(f"{self.PredictionMode.SINGLE_LATENT} must coincide with {self.OutputMode.SINGLE_OUTPUT}; "
                f"got {prediction_mode} and {output_mode}")
        if prediction_mode == self.PredictionMode.DIRECT and \
                latent_pass_on_mode not in (self.LatentPassOnMode.CONCAT_INPUT, self.LatentPassOnMode.NO_LATENT):
            raise ValueError(f"{prediction_mode} requires {self.LatentPassOnMode.CONCAT_INPUT} or {self.LatentPassOnMode.NO_LATENT}")
        if latent_pass_on_mode == self.LatentPassOnMode.INIT_HIDDEN and \
                prediction_mode not in (self.PredictionMode.SINGLE_LATENT, self.PredictionMode.MULTI_LATENT):
            raise ValueError(f"{latent_pass_on_mode} requires {self.PredictionMode.SINGLE_LATENT} or {self.PredictionMode.MULTI_LATENT}")
        if latent_pass_on_mode == self.LatentPassOnMode.NO_LATENT:
            latent_dim = 0
        self.latent_pass_on_mode = latent_pass_on_mode
        self.prediction_mode = prediction_mode
        self.output_mode = output_mode
        self.target_feature_dim = target_feature_dim
        self.latent_dim = latent_dim
        if prediction_mode == self.PredictionMode.DIRECT:
            self.lstm = None
            predictor_input_dim = self.latent_dim + self.target_feature_dim
        else:
            if latent_pass_on_mode == self.LatentPassOnMode.INIT_HIDDEN:
                rnn_input_dim = target_feature_dim
            else:
                rnn_input_dim = target_feature_dim + latent_dim
            self.lstm = torch.nn.LSTM(rnn_input_dim, self.latent_dim, batch_first=True, dropout=p_recurrent_dropout)
            predictor_input_dim = self.latent_dim
        self.predictor = predictor_factory.create_predictor(predictor_input_dim, output_dim)

    def _tostring_exclude_private(self) -> bool:
        return True

    def forward(self, latent, target_features=None, target_lengths=None):
        """
        :param latent: a tensor of shape (batch_size, latent_dim)
        :param target_features: a tensor of shape (batch_size, max_seq_length, target_feature_dim)
        :param target_lengths: a tensor indicating the lengths of the sequences in target_features
        :return: a tensor containing the prediction(s): shape (batch_size, output_dim) when a single
            output is produced, or (batch_size, max_seq_length, output_dim) for OutputMode.MULTI_OUTPUT
        """
        if target_features is None:
            raise ValueError(f"target_features cannot be None when using {self.__class__}")
        # latent has shape (batch_size, latent_dim)
        # target_features has shape (batch_size, max_seq_length, target_feature_dim)
        batch_size = target_features.shape[0]
        use_lstm = self.prediction_mode != self.PredictionMode.DIRECT
        lstm_input, s0, latent_plus_target_features = None, None, None
        if self.latent_pass_on_mode == self.LatentPassOnMode.INIT_HIDDEN:
            if target_lengths is not None:
                lstm_input = torch.nn.utils.rnn.pack_padded_sequence(target_features, target_lengths, batch_first=True,
                    enforce_sorted=False)
            else:
                lstm_input = target_features
            c0 = torch.zeros((1, batch_size, self.latent_dim)).to(latent.device)
            h0 = latent.unsqueeze(0)
            s0 = (h0, c0)
        elif self.latent_pass_on_mode in (self.LatentPassOnMode.CONCAT_INPUT, self.LatentPassOnMode.NO_LATENT):
            if self.latent_pass_on_mode == self.LatentPassOnMode.NO_LATENT:
                latent_plus_target_features = target_features
            else:
                latent = latent.unsqueeze(1)  # (batch_size, 1, latent_dim)
                latent = latent.expand(-1, target_features.shape[1], -1)  # (batch_size, max_seq_length, latent_dim)
                latent_plus_target_features = torch.cat((latent, target_features), dim=2)
            # latent_plus_target_features has shape (batch_size, max_seq_length, latent_dim + target_feature_dim)
            if use_lstm:
                lstm_input = torch.nn.utils.rnn.pack_padded_sequence(latent_plus_target_features, target_lengths, batch_first=True,
                    enforce_sorted=False)
                s0 = None
        else:
            raise ValueError(f"Unknown latent pass-on mode '{self.latent_pass_on_mode}'")
        if self.prediction_mode == self.PredictionMode.SINGLE_LATENT:  # use only final latent state and produce a single output
            _, (hn, _) = self.lstm(lstm_input, s0)
            encoding = hn.squeeze(0)
            result = self.predictor(encoding)
        else:  # compute multiple predictions (and optionally compute their mean)
            if self.prediction_mode == self.PredictionMode.MULTI_LATENT:  # use all latent states
                hseq, _ = self.lstm(lstm_input, s0)
                encodings, lengths = torch.nn.utils.rnn.pad_packed_sequence(hseq, batch_first=True)  # (batch_size, max_seq_length, latent_dim)
                predictions = self.predictor(encodings)  # (batch_size, max_seq_length, output_dim)
            elif self.prediction_mode == self.PredictionMode.DIRECT:  # directly map concatenated values to outputs via predictor
                predictions = self.predictor(latent_plus_target_features)
            else:
                raise ValueError(f"Unknown prediction mode '{self.prediction_mode}'")
            if self.output_mode == self.OutputMode.SINGLE_OUTPUT_MEAN:
                mean_predictions = predictions.data.new(batch_size, 1)
                for i, l in enumerate(target_lengths):
                    mean_predictions[i] = predictions[i][:l].sum() / l
                result = mean_predictions
            elif self.output_mode == self.OutputMode.MULTI_OUTPUT:
                return predictions
            else:
                raise ValueError(self.output_mode)
        return result
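
# Illustrative usage sketch (not part of this module; dimensions are arbitrary example values):
#
#     decoder = TargetSequenceDecoderModule(target_feature_dim=4, latent_dim=32,
#         predictor_factory=LinearPredictorFactory(),
#         prediction_mode=TargetSequenceDecoderModule.PredictionMode.MULTI_LATENT,
#         latent_pass_on_mode=TargetSequenceDecoderModule.LatentPassOnMode.CONCAT_INPUT,
#         output_mode=TargetSequenceDecoderModule.OutputMode.MULTI_OUTPUT)
#     latent = torch.randn(8, 32)
#     target_features = torch.randn(8, 5, 4)  # 5 target items per instance
#     target_lengths = torch.full((8,), 5)    # here: all sequences at full length
#     y = decoder(latent, target_features, target_lengths)  # (8, 5, 1)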


class TargetSequenceDecoderFactory(DecoderFactory):
    """
    A factory for :class:`TargetSequenceDecoderModule` which takes the latent encoding and a sequence of target
    items as input
    """
    def __init__(self,
            prediction_mode: TargetSequenceDecoderModule.PredictionMode = TargetSequenceDecoderModule.PredictionMode.MULTI_LATENT,
            output_mode: TargetSequenceDecoderModule.OutputMode = TargetSequenceDecoderModule.OutputMode.MULTI_OUTPUT,
            latent_pass_on_mode: TargetSequenceDecoderModule.LatentPassOnMode = TargetSequenceDecoderModule.LatentPassOnMode.CONCAT_INPUT,
            predictor_factory: Optional[PredictorFactory] = None,
            p_recurrent_dropout: float = 0.0,
            output_dim: int = 1):
        if predictor_factory is None:
            predictor_factory = LinearPredictorFactory()
        self.output_dim = output_dim
        self.p_recurrent_dropout = p_recurrent_dropout
        self.prediction_mode = prediction_mode
        self.output_mode = output_mode
        self.latent_pass_on_mode = latent_pass_on_mode
        self.predictor_factory = predictor_factory

    def create_decoder(self, latent_dim: int, target_feature_dim: int) -> torch.nn.Module:
        return TargetSequenceDecoderModule(target_feature_dim, latent_dim, self.predictor_factory,
            prediction_mode=self.prediction_mode,
            output_mode=self.output_mode,
            latent_pass_on_mode=self.latent_pass_on_mode,
            output_dim=self.output_dim,
            p_recurrent_dropout=self.p_recurrent_dropout)


class SingleTargetDecoderFactory(DecoderFactory):
    """
    A factory for :class:`SingleTargetDecoderModule` which takes the latent encoding and a single-element sequence of
    target items as input, producing a single prediction
    """
    def __init__(self, predictor_factory: PredictorFactory):
        self.predictor_factory = predictor_factory

    def create_decoder(self, latent_dim: int, target_feature_dim: int) -> torch.nn.Module:
        return SingleTargetDecoderModule(target_feature_dim, latent_dim, self.predictor_factory)


class EncoderDecoderModule(torch.nn.Module):
    """
    Represents an encoder-decoder (where both components can be injected). It takes a history sequence and a sequence of
    target feature vectors as input. Both sequences are potentially of variable length, and for the target sequence,
    the common special case where there is only one target and thus one prediction to be made is specifically catered for
    using dedicated decoders (see :class:`SingleTargetDecoderModule`).
    The module first encodes the history sequence to a latent vector and then uses the decoder to map this latent vector
    along with the target features to a prediction.
    """
    def __init__(self, encoder: TEncoder, decoder: TDecoder, variable_history_length: bool):
        """
        :param encoder: a torch module satisfying :class:`EncoderProtocol`
        :param decoder: a torch module satisfying :class:`DecoderProtocol`
        :param variable_history_length: whether the history sequence is variable-length.
            If it is not, then the model will not pass on the lengths tensor to the encoder, allowing it to simplify
            its handling of this case (even if the original input provides the lengths).
        """
        super().__init__()
        self.variable_history_length = variable_history_length
        self.encoder = encoder
        self.decoder = decoder

    def __str__(self):
        return object_repr(self, dict(encoder=self.encoder, predictor=self.decoder))

    def forward(self, window_features: torch.Tensor,
            window_lengths: Optional[torch.Tensor] = None,
            target_features: Optional[torch.Tensor] = None,
            target_lengths: Optional[torch.Tensor] = None):
        """
        :param window_features: a tensor of size (batch_size, max(window_lengths), dim_per_window_item) containing the window features
        :param window_lengths: a tensor containing the lengths of the windows in `window_features`
        :param target_features: an optional tensor containing target features with shape
            (batch_size, max_target_seq_length, target_feature_dim).
            For the case where there is only one target item (no actual sequence), `max_target_seq_length`
            should be 1.
        :param target_lengths: an optional tensor containing the lengths of the target sequences, allowing
            the actual sequence lengths to differ
        """
        if self.variable_history_length:
            latent = self.encoder(window_features, window_lengths)
        else:
            latent = self.encoder(window_features)
        return self.decoder(latent, target_features, target_lengths)
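
# Illustrative end-to-end sketch (not part of this module; dimensions are arbitrary example values):
#
#     encoder = RnnEncoderFactory(input_dim=8, latent_dim=32).create_encoder(input_dim=8, latent_dim=32)
#     decoder = SingleTargetDecoderFactory(LinearPredictorFactory()).create_decoder(latent_dim=32, target_feature_dim=4)
#     model = EncoderDecoderModule(encoder, decoder, variable_history_length=True)
#     window_features = torch.randn(8, 10, 8)  # history sequences
#     window_lengths = torch.full((8,), 10)    # here: all histories at full length
#     target_features = torch.randn(8, 1, 4)   # one target item per instance
#     y = model(window_features, window_lengths, target_features)  # (8, 1)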