Skip to content

Commit

Permalink
add SampleForecast and Predictor objects for TPPs (awslabs#819)
Browse files Browse the repository at this point in the history
  • Loading branch information
canerturkmen committed May 20, 2020
1 parent 0d963f7 commit a757256
Show file tree
Hide file tree
Showing 9 changed files with 597 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/gluonts/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ class GluonPredictor(Predictor):
ctx
MXNet context to use for computation
forecast_generator
Class to generate forecasts from network ouputs
Class to generate forecasts from network outputs
"""

BlockType = mx.gluon.Block
Expand Down
Empty file added src/gluonts/model/tpp/.typesafe
Empty file.
24 changes: 24 additions & 0 deletions src/gluonts/model/tpp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from .forecast import PointProcessSampleForecast
from .predictor import PointProcessGluonPredictor


__all__ = ["PointProcessGluonPredictor", "PointProcessSampleForecast"]


# fix Sphinx issues, see https://bit.ly/2K2eptM
for item in __all__:
    # ``__all__`` holds the exported *names* (plain strings). A ``str``
    # instance never has a ``__module__`` attribute, so testing the name
    # itself would make this loop a silent no-op. Resolve each name to the
    # object it refers to in this module before re-homing it.
    exported = globals().get(item)
    if exported is not None and hasattr(exported, "__module__"):
        setattr(exported, "__module__", __name__)
163 changes: 163 additions & 0 deletions src/gluonts/model/tpp/forecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Dict, Optional, Union, cast

# Third-party imports
import mxnet as mx
import numpy as np
import pandas as pd

# First-party imports
from gluonts.model.forecast import OutputType, Forecast, Config
from pandas import to_timedelta


class PointProcessSampleForecast(Forecast):
    """
    Sample forecast object used for temporal point process inference.

    Differs from standard forecast objects as it does not implement
    fixed length samples. Each sample has a variable length, that is
    kept in a separate :code:`valid_length` attribute.

    Importantly, PointProcessSampleForecast does not implement some
    methods (such as :code:`quantile` or :code:`plot`) that are available
    in discrete time forecasts.

    Parameters
    ----------
    samples
        A multidimensional array of samples, of shape
        (number_of_samples, max_pred_length, target_dim) or
        (number_of_samples, max_pred_length). For marked TPP, the
        target_dim is 2 with the first element keeping interarrival times
        and the second keeping marks. If samples are two-dimensional, each
        entry stands for the interarrival times in a (unmarked) TPP sample.
    valid_length
        An array of integers denoting the valid lengths of each sample
        in :code:`samples`. That is, :code:`valid_length[0] == 2` implies
        that only the first two entries of :code:`samples[0, ...]` are
        valid "points".
    start_date
        Starting timestamp of the sample
    freq
        The time unit of interarrival times
    prediction_interval_length
        The length of the prediction interval for which samples were drawn.
    item_id
        Item ID, if available.
    info
        Optional dictionary of additional information.
    """

    prediction_interval_length: float

    # not used
    prediction_length = cast(int, None)
    mean = None
    _index = None

    def __init__(
        self,
        samples: Union[mx.nd.NDArray, np.ndarray],
        valid_length: Union[mx.nd.NDArray, np.ndarray],
        start_date: pd.Timestamp,
        freq: str,
        prediction_interval_length: float,
        item_id: Optional[str] = None,
        info: Optional[Dict] = None,
    ) -> None:
        assert isinstance(
            samples, (np.ndarray, mx.nd.NDArray)
        ), "samples should be either a numpy or an mxnet array"
        assert (
            samples.ndim == 2 or samples.ndim == 3
        ), f"samples should be a 2-dimensional or 3-dimensional array. Dimensions found: {samples.ndim}"

        # The message names ``valid_length`` (it previously said "samples",
        # which pointed users at the wrong argument).
        assert isinstance(
            valid_length, (np.ndarray, mx.nd.NDArray)
        ), "valid_length should be either a numpy or an mxnet array"
        assert (
            valid_length.ndim == 1
        ), "valid_length should be a 1-dimensional array"
        assert (
            valid_length.shape[0] == samples.shape[0]
        ), "valid_length and samples should have compatible dimensions"

        # Both arrays are stored as numpy arrays regardless of input type.
        self.samples, self.valid_length = (
            x if isinstance(x, np.ndarray) else x.asnumpy()
            for x in (samples, valid_length)
        )

        self._dim = samples.ndim
        self.item_id = item_id
        self.info = info

        assert isinstance(
            start_date, pd.Timestamp
        ), "start_date should be a pandas Timestamp object"
        self.start_date = start_date

        assert isinstance(freq, str), "freq should be a string"
        self.freq = freq

        assert (
            prediction_interval_length > 0
        ), "prediction_interval_length must be greater than 0"
        self.prediction_interval_length = prediction_interval_length

        # ``freq`` is used as the timedelta unit, so the interval length is
        # expressed in multiples of one ``freq`` unit.
        self.end_date = (
            start_date
            + to_timedelta(1, self.freq) * prediction_interval_length
        )

    def dim(self) -> int:
        """
        Return the number of array dimensions of the ``samples`` passed to
        the constructor (2 or 3).
        """
        return self._dim

    @property
    def index(self) -> pd.DatetimeIndex:
        """
        Not available for point process samples; always raises
        ``AttributeError``.
        """
        raise AttributeError(
            "Datetime index not defined for point process samples"
        )

    def as_json_dict(self, config: "Config") -> dict:
        """
        Extend the parent's JSON dictionary with ``samples`` and
        ``valid_length`` (as nested lists) when ``OutputType.samples`` is
        among ``config.output_types``.
        """
        result = super().as_json_dict(config)

        if OutputType.samples in config.output_types:
            result["samples"] = self.samples.tolist()
            result["valid_length"] = self.valid_length.tolist()

        return result

    def __repr__(self):
        # The opening parenthesis is closed only once, by the last element.
        return ", ".join(
            [
                f"PointProcessSampleForecast({self.samples!r}",
                f"{self.valid_length!r}",
                f"{self.start_date!r}",
                f"{self.end_date!r}",
                f"{self.freq!r}",
                f"item_id={self.item_id!r}",
                f"info={self.info!r})",
            ]
        )

    def quantile(self, q: Union[float, str]) -> np.ndarray:
        """Not defined for point process samples; always raises."""
        raise NotImplementedError(
            "Quantile function is not defined for point process samples"
        )

    def plot(self, **kwargs):
        """Not implemented for point process samples; always raises."""
        raise NotImplementedError(
            "Plotting not implemented for point process samples"
        )
189 changes: 189 additions & 0 deletions src/gluonts/model/tpp/predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
from functools import partial
from pathlib import Path
from typing import Iterator, List, Optional, cast

# Third-party imports
import mxnet as mx
import numpy as np

# First-party imports
from gluonts.core.component import DType
from gluonts.dataset.common import Dataset
from gluonts.dataset.loader import DataBatch, InferenceDataLoader
from gluonts.dataset.parallelized_loader import batchify
from gluonts.model.forecast import Forecast
from gluonts.model.forecast_generator import ForecastGenerator
from gluonts.model.predictor import (
GluonPredictor,
SymbolBlockPredictor,
OutputTransform,
)
from gluonts.transform import Transformation

# Relative imports
from .forecast import PointProcessSampleForecast


class PointProcessForecastGenerator(ForecastGenerator):
    """
    Forecast generator that turns the output of a point process prediction
    network into :code:`PointProcessSampleForecast` objects.

    The prediction network is expected to return exactly two arrays per
    call: the samples and their valid lengths.
    """

    def __call__(
        self,
        inference_data_loader: InferenceDataLoader,
        prediction_net: mx.gluon.Block,
        input_names: List[str],
        freq: str,
        output_transform: Optional[OutputTransform],
        num_samples: Optional[int],
        **kwargs,
    ) -> Iterator[Forecast]:
        """
        Yield one forecast per entry of every batch in the data loader.

        When ``num_samples`` is truthy, the network is invoked repeatedly on
        the same inputs until at least ``num_samples`` samples have been
        collected; the surplus is cut off. ``output_transform`` and extra
        keyword arguments are accepted but not used here.
        """
        for batch in inference_data_loader:
            net_inputs = [batch[name] for name in input_names]

            def draw():
                # One forward pass, converted to numpy; must yield a pair.
                return tuple(
                    arr.asnumpy() for arr in prediction_net(*net_inputs)
                )

            outputs, valid_length = draw()

            # sample until enough point process trajectories are collected
            if num_samples:
                sample_chunks = [outputs]
                length_chunks = [valid_length]
                n_drawn = outputs[0].shape[0]

                while n_drawn < num_samples:
                    more_outputs, more_lengths = draw()
                    sample_chunks.append(more_outputs)
                    length_chunks.append(more_lengths)
                    n_drawn += more_outputs[0].shape[0]

                # Regroup chunk-wise results per batch entry, then truncate.
                outputs = [
                    np.concatenate(per_entry)[:num_samples]
                    for per_entry in zip(*sample_chunks)
                ]
                valid_length = [
                    np.concatenate(per_entry)[:num_samples]
                    for per_entry in zip(*length_chunks)
                ]

                assert len(outputs[0]) == num_samples
                assert len(valid_length[0]) == num_samples

            assert len(batch["forecast_start"]) == len(outputs)

            for i, output in enumerate(outputs):
                if "item_id" in batch:
                    item_id = batch["item_id"][i]
                else:
                    item_id = None

                if "info" in batch:
                    info = batch["info"][i]
                else:
                    info = None

                yield PointProcessSampleForecast(
                    output,
                    valid_length=valid_length[i],
                    start_date=batch["forecast_start"][i],
                    freq=freq,
                    prediction_interval_length=prediction_net.prediction_interval_length,
                    item_id=item_id,
                    info=info,
                )


class PointProcessGluonPredictor(GluonPredictor):
    """
    Predictor object for marked temporal point process models.

    TPP predictions differ from standard discrete-time models in several
    regards. First, at least for now, only sample forecasts implementing
    PointProcessSampleForecast are available. Similar to TPP Estimator
    objects, the Predictor works with :code:`prediction_interval_length`
    as opposed to :code:`prediction_length`.

    The predictor also accounts for the fact that the prediction network
    outputs a 2-tuple of Tensors, for the samples themselves and their
    `valid_length`.

    Finally, :code:`predict` passes a :code:`batchify` function with
    :code:`variable_length=True` to the parent implementation.

    Parameters
    ----------
    prediction_interval_length
        The length of the prediction interval. Its ceiling is handed to the
        parent constructor as :code:`prediction_length` (for validation
        only); afterwards :code:`prediction_length` is reset to None.
    forecast_generator
        Object used to generate forecasts; defaults to a
        PointProcessForecastGenerator instance.

    All remaining arguments are forwarded unchanged to the parent
    constructor.
    """

    def __init__(
        self,
        input_names: List[str],
        prediction_net: mx.gluon.Block,
        batch_size: int,
        prediction_interval_length: float,
        freq: str,
        ctx: mx.Context,
        input_transform: Transformation,
        dtype: DType = np.float32,
        forecast_generator: ForecastGenerator = PointProcessForecastGenerator(),
        **kwargs,
    ) -> None:
        # for validation only
        validation_length = np.ceil(prediction_interval_length)

        super().__init__(
            input_names=input_names,
            prediction_net=prediction_net,
            batch_size=batch_size,
            prediction_length=validation_length,
            freq=freq,
            ctx=ctx,
            input_transform=input_transform,
            output_transform=None,
            dtype=dtype,
            lead_time=0,
            **kwargs,
        )

        # not used by TPP predictor
        self.prediction_length = cast(int, None)

        self.forecast_generator = forecast_generator
        self.prediction_interval_length = prediction_interval_length

    def hybridize(self, batch: DataBatch) -> None:
        """Unsupported for point process models; always raises."""
        message = "Point process models are currently not hybridizable"
        raise NotImplementedError(message)

    def as_symbol_block_predictor(
        self, batch: DataBatch
    ) -> SymbolBlockPredictor:
        """Unsupported for point process models; always raises."""
        message = "Point process models are currently not hybridizable"
        raise NotImplementedError(message)

    def predict(
        self,
        dataset: Dataset,
        num_samples: Optional[int] = None,
        num_workers: Optional[int] = None,
        num_prefetch: Optional[int] = None,
        **kwargs,
    ) -> Iterator[Forecast]:
        """
        Yield forecasts for ``dataset`` by delegating to the parent
        ``predict`` with a variable-length ``batchify`` function.

        Note that additional ``**kwargs`` are accepted but are not forwarded
        to the parent implementation.
        """
        variable_length_batchify = partial(batchify, variable_length=True)
        forecasts = super().predict(
            dataset=dataset,
            num_samples=num_samples,
            num_workers=num_workers,
            num_prefetch=num_prefetch,
            batchify_fn=variable_length_batchify,
        )
        yield from forecasts

    def serialize_prediction_net(self, path: Path) -> None:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()
2 changes: 2 additions & 0 deletions src/gluonts/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"cdf_to_gaussian_forward_transform",
"CDFtoGaussianTransform",
"ConcatFeatures",
"ContinuousTimeInstanceSplitter",
"ContinuousTimeUniformSampler",
"ExpandDimArray",
"ExpectedNumInstanceSampler",
"FilterTransformation",
Expand Down
Loading

0 comments on commit a757256

Please sign in to comment.