forked from awslabs/gluonts
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add SampleForecast and Predictor objects for TPPs (awslabs#819)
- Loading branch information
1 parent
0d963f7
commit a757256
Showing
9 changed files
with
597 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
from .forecast import PointProcessSampleForecast | ||
from .predictor import PointProcessGluonPredictor | ||
|
||
|
||
__all__ = ["PointProcessGluonPredictor", "PointProcessSampleForecast"] | ||
|
||
|
||
# fix Sphinx issues, see https://bit.ly/2K2eptM | ||
for item in __all__: | ||
if hasattr(item, "__module__"): | ||
setattr(item, "__module__", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
from typing import Dict, Optional, Union, cast | ||
|
||
# Third-party imports | ||
import mxnet as mx | ||
import numpy as np | ||
import pandas as pd | ||
|
||
# First-party imports | ||
from gluonts.model.forecast import OutputType, Forecast, Config | ||
from pandas import to_timedelta | ||
|
||
|
||
class PointProcessSampleForecast(Forecast): | ||
""" | ||
Sample forecast object used for temporal point process inference. | ||
Differs from standard forecast objects as it does not implement | ||
fixed length samples. Each sample has a variable length, that is | ||
kept in a separate :code:`valid_length` attribute. | ||
Importantly, PointProcessSampleForecast does not implement some | ||
methods (such as :code:`quantile` or :code:`plot`) that are available | ||
in discrete time forecasts. | ||
Parameters | ||
---------- | ||
samples | ||
A multidimensional array of samples, of shape | ||
(number_of_samples, max_pred_length, target_dim) or | ||
(number_of_samples, max_pred_length). For marked TPP, the | ||
target_dim is 2 with the first element keeping interarrival times | ||
and the second keeping marks. If samples are two-dimensional, each | ||
entry stands for the interarrival times in a (unmarked) TPP sample. | ||
valid_length | ||
An array of integers denoting the valid lengths of each sample | ||
in :code:`samples`. That is, :code:`valid_length[0] == 2` implies | ||
that only the first two entries of :code:`samples[0, ...]` are | ||
valid "points". | ||
start_date | ||
Starting timestamp of the sample | ||
freq | ||
The time unit of interarrival times | ||
prediction_interval_length | ||
The length of the prediction interval for which samples were drawn. | ||
item_id | ||
Item ID, if available. | ||
info | ||
Optional dictionary of additional information. | ||
""" | ||
|
||
prediction_interval_length: float | ||
|
||
# not used | ||
prediction_length = cast(int, None) | ||
mean = None | ||
_index = None | ||
|
||
def __init__( | ||
self, | ||
samples: Union[mx.nd.NDArray, np.ndarray], | ||
valid_length: Union[mx.nd.NDArray, np.ndarray], | ||
start_date: pd.Timestamp, | ||
freq: str, | ||
prediction_interval_length: float, | ||
item_id: Optional[str] = None, | ||
info: Optional[Dict] = None, | ||
) -> None: | ||
assert isinstance( | ||
samples, (np.ndarray, mx.nd.NDArray) | ||
), "samples should be either a numpy or an mxnet array" | ||
assert ( | ||
samples.ndim == 2 or samples.ndim == 3 | ||
), f"samples should be a 2-dimensional or 3-dimensional array. Dimensions found: {samples.ndim}" | ||
|
||
assert isinstance( | ||
valid_length, (np.ndarray, mx.nd.NDArray) | ||
), "samples should be either a numpy or an mxnet array" | ||
assert ( | ||
valid_length.ndim == 1 | ||
), "valid_length should be a 1-dimensional array" | ||
assert ( | ||
valid_length.shape[0] == samples.shape[0] | ||
), "valid_length and samples should have compatible dimensions" | ||
|
||
self.samples, self.valid_length = ( | ||
x if isinstance(x, np.ndarray) else x.asnumpy() | ||
for x in (samples, valid_length) | ||
) | ||
|
||
self._dim = samples.ndim | ||
self.item_id = item_id | ||
self.info = info | ||
|
||
assert isinstance( | ||
start_date, pd.Timestamp | ||
), "start_date should be a pandas Timestamp object" | ||
self.start_date = start_date | ||
|
||
assert isinstance(freq, str), "freq should be a string" | ||
self.freq = freq | ||
|
||
assert ( | ||
prediction_interval_length > 0 | ||
), "prediction_interval_length must be greater than 0" | ||
self.prediction_interval_length = prediction_interval_length | ||
|
||
self.end_date = ( | ||
start_date | ||
+ to_timedelta(1, self.freq) * prediction_interval_length | ||
) | ||
|
||
def dim(self) -> int: | ||
return self._dim | ||
|
||
@property | ||
def index(self) -> pd.DatetimeIndex: | ||
raise AttributeError( | ||
"Datetime index not defined for point process samples" | ||
) | ||
|
||
def as_json_dict(self, config: "Config") -> dict: | ||
result = super().as_json_dict(config) | ||
|
||
if OutputType.samples in config.output_types: | ||
result["samples"] = self.samples.tolist() | ||
result["valid_length"] = self.valid_length.tolist() | ||
|
||
return result | ||
|
||
def __repr__(self): | ||
return ", ".join( | ||
[ | ||
f"PointProcessSampleForecast({self.samples!r})", | ||
f"{self.valid_length!r}", | ||
f"{self.start_date!r}", | ||
f"{self.end_date!r}", | ||
f"{self.freq!r}", | ||
f"item_id={self.item_id!r}", | ||
f"info={self.info!r})", | ||
] | ||
) | ||
|
||
def quantile(self, q: Union[float, str]) -> np.ndarray: | ||
raise NotImplementedError( | ||
"Quantile function is not defined for point process samples" | ||
) | ||
|
||
def plot(self, **kwargs): | ||
raise NotImplementedError( | ||
"Plotting not implemented for point process samples" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
# Standard library imports | ||
from functools import partial | ||
from pathlib import Path | ||
from typing import Iterator, List, Optional, cast | ||
|
||
# Third-party imports | ||
import mxnet as mx | ||
import numpy as np | ||
|
||
# First-party imports | ||
from gluonts.core.component import DType | ||
from gluonts.dataset.common import Dataset | ||
from gluonts.dataset.loader import DataBatch, InferenceDataLoader | ||
from gluonts.dataset.parallelized_loader import batchify | ||
from gluonts.model.forecast import Forecast | ||
from gluonts.model.forecast_generator import ForecastGenerator | ||
from gluonts.model.predictor import ( | ||
GluonPredictor, | ||
SymbolBlockPredictor, | ||
OutputTransform, | ||
) | ||
from gluonts.transform import Transformation | ||
|
||
# Relative imports | ||
from .forecast import PointProcessSampleForecast | ||
|
||
|
||
class PointProcessForecastGenerator(ForecastGenerator): | ||
def __call__( | ||
self, | ||
inference_data_loader: InferenceDataLoader, | ||
prediction_net: mx.gluon.Block, | ||
input_names: List[str], | ||
freq: str, | ||
output_transform: Optional[OutputTransform], | ||
num_samples: Optional[int], | ||
**kwargs, | ||
) -> Iterator[Forecast]: | ||
|
||
for batch in inference_data_loader: | ||
inputs = [batch[k] for k in input_names] | ||
|
||
outputs, valid_length = ( | ||
x.asnumpy() for x in prediction_net(*inputs) | ||
) | ||
|
||
# sample until enough point process trajectories are collected | ||
if num_samples: | ||
num_collected_samples = outputs[0].shape[0] | ||
collected_samples, collected_vls = [outputs], [valid_length] | ||
while num_collected_samples < num_samples: | ||
outputs, valid_length = ( | ||
x.asnumpy() for x in prediction_net(*inputs) | ||
) | ||
|
||
collected_samples.append(outputs) | ||
collected_vls.append(valid_length) | ||
|
||
num_collected_samples += outputs[0].shape[0] | ||
|
||
outputs = [ | ||
np.concatenate(s)[:num_samples] | ||
for s in zip(*collected_samples) | ||
] | ||
valid_length = [ | ||
np.concatenate(s)[:num_samples] | ||
for s in zip(*collected_vls) | ||
] | ||
|
||
assert len(outputs[0]) == num_samples | ||
assert len(valid_length[0]) == num_samples | ||
|
||
assert len(batch["forecast_start"]) == len(outputs) | ||
|
||
for i, output in enumerate(outputs): | ||
yield PointProcessSampleForecast( | ||
output, | ||
valid_length=valid_length[i], | ||
start_date=batch["forecast_start"][i], | ||
freq=freq, | ||
prediction_interval_length=prediction_net.prediction_interval_length, | ||
item_id=batch["item_id"][i] | ||
if "item_id" in batch | ||
else None, | ||
info=batch["info"][i] if "info" in batch else None, | ||
) | ||
|
||
|
||
class PointProcessGluonPredictor(GluonPredictor): | ||
""" | ||
Predictor object for marked temporal point process models. | ||
TPP predictions differ from standard discrete-time models in several | ||
regards. First, at least for now, only sample forecasts implementing | ||
PointProcessSampleForecast are available. Similar to TPP Estimator | ||
objects, the Predictor works with :code:`prediction_interval_length` | ||
as opposed to :code:`prediction_length`. | ||
The predictor also accounts for the fact that the prediction network | ||
outputs a 2-tuple of Tensors, for the samples themselves and their | ||
`valid_length`. | ||
Finally, this class uses a VariableLengthInferenceDataLoader as opposed | ||
to the default InferenceDataLoader. | ||
Parameters | ||
---------- | ||
prediction_interval_length | ||
The length of the prediction interval | ||
""" | ||
|
||
def __init__( | ||
self, | ||
input_names: List[str], | ||
prediction_net: mx.gluon.Block, | ||
batch_size: int, | ||
prediction_interval_length: float, | ||
freq: str, | ||
ctx: mx.Context, | ||
input_transform: Transformation, | ||
dtype: DType = np.float32, | ||
forecast_generator: ForecastGenerator = PointProcessForecastGenerator(), | ||
**kwargs, | ||
) -> None: | ||
super().__init__( | ||
input_names=input_names, | ||
prediction_net=prediction_net, | ||
batch_size=batch_size, | ||
prediction_length=np.ceil( | ||
prediction_interval_length | ||
), # for validation only | ||
freq=freq, | ||
ctx=ctx, | ||
input_transform=input_transform, | ||
output_transform=None, | ||
dtype=dtype, | ||
lead_time=0, | ||
**kwargs, | ||
) | ||
|
||
# not used by TPP predictor | ||
self.prediction_length = cast(int, None) | ||
|
||
self.forecast_generator = forecast_generator | ||
self.prediction_interval_length = prediction_interval_length | ||
|
||
def hybridize(self, batch: DataBatch) -> None: | ||
raise NotImplementedError( | ||
"Point process models are currently not hybridizable" | ||
) | ||
|
||
def as_symbol_block_predictor( | ||
self, batch: DataBatch | ||
) -> SymbolBlockPredictor: | ||
raise NotImplementedError( | ||
"Point process models are currently not hybridizable" | ||
) | ||
|
||
def predict( | ||
self, | ||
dataset: Dataset, | ||
num_samples: Optional[int] = None, | ||
num_workers: Optional[int] = None, | ||
num_prefetch: Optional[int] = None, | ||
**kwargs, | ||
) -> Iterator[Forecast]: | ||
yield from super().predict( | ||
dataset=dataset, | ||
num_samples=num_samples, | ||
num_workers=num_workers, | ||
num_prefetch=num_prefetch, | ||
batchify_fn=partial(batchify, variable_length=True), | ||
) | ||
|
||
def serialize_prediction_net(self, path: Path) -> None: | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.