Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bayesnf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# A new PyPI release will be pushed every time `__version__` is increased.
# When changing this, also update the CHANGELOG.md
__version__ = '0.1.2'
__version__ = '0.1.3'

from .spatiotemporal import BayesianNeuralFieldMAP
from .spatiotemporal import BayesianNeuralFieldMLE
Expand Down
80 changes: 58 additions & 22 deletions src/bayesnf/spatiotemporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ def seasonality_to_float(seasonality: str, freq: str) -> float:


def seasonalities_to_array(
seasonalities: Sequence[float | str], freq: str
) -> np.ndarray:
seasonalities: Sequence[float | str],
freq: str
) -> np.ndarray:
"""Convert a list of floats or strings to durations relative to a frequency.

Args:
Expand Down Expand Up @@ -99,6 +100,10 @@ def _convert_datetime_col(table, time_column, timetype, freq, time_min=None):
first_date = pd.to_datetime('2020-01-01').to_period(freq)
table[time_column] = table[time_column].dt.to_period(freq)
table[time_column] = (table[time_column] - first_date).apply(lambda x: x.n)
elif timetype == 'float':
table[time_column] = table[time_column].apply(float)
else:
raise ValueError(f'Unknown timetype: {timetype}')
if time_min is None:
time_min = table[time_column].min()
table[time_column] = table[time_column] - time_min
Expand Down Expand Up @@ -217,7 +222,7 @@ def __init__(
num_seasonal_harmonics: Sequence[int] | None = None,
fourier_degrees: Sequence[float] | None = None,
interactions: Sequence[tuple[int, int]] | None = None,
freq: str,
freq: str | None = None,
timetype: str = 'index',
depth: int = 2,
width: int = 512,
Expand All @@ -237,16 +242,18 @@ def __init__(

seasonality_periods:
A list of numbers representing the seasonal frequencies of the data
in the time domain. It is also possible to specify a string such as
'W', 'D', etc. corresponding to a valid Pandas frequency: see the
Pandas [Offset Aliases](
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
for valid values.
in the time domain. If timetype == 'index', then it is possible
to specify numeric frequencies by using string short hands such as
'W', 'D', etc., which correspond to a valid Pandas frequency.
See Pandas [Offset Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
for valid string values.

num_seasonal_harmonics:
A list of seasonal harmonics, one for each entry in
`seasonality_periods`. The number of seasonal harmonics (h) for a
given seasonal period `p` must satisfy `h < p//2`.
given seasonal period `p` must satisfy `h < p//2`. It is an error
fir `len(num_seasonal_harmonics) != len(seasonality_periods)`.
Should be used only if `timetype == 'index'`.

fourier_degrees:
A list of integer degrees for the Fourier features of the inputs.
Expand All @@ -263,13 +270,13 @@ def __init__(
freq:
A frequency string for the sampling rate at which the data is
collected. See the Pandas
[Offset Aliases](
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
for valid values.
[Offset Aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
for valid values. Should be used if and only if `timetype == 'index'`.

timetype:
Must be specified as `index`. The general versions will be
integrated pending https://github.com/google/bayesnf/issues/16.
Either `index` or `float`. If `index`, then the time column must
be a `datetime` type and `freq` must be given.
Otherwise, if `float`, then the time column must be `float`.

depth:
The number of hidden layers in the BayesNF architecture.
Expand Down Expand Up @@ -337,18 +344,47 @@ def _get_interactions(self) -> np.ndarray:
f' passed shape was {interactions.shape})')
return interactions

def _get_seasonality_periods(self):
"""Return array of seasonal periods."""
if (
(self.timetype == 'index' and self.freq is None) or
(self.timetype == 'float' and self.freq is not None)):
raise ValueError(f'Invalid {self.freq=} with {self.timetype=}.')
if self.seasonality_periods is None:
return np.zeros(0)
if self.timetype == 'index':
return seasonalities_to_array(self.seasonality_periods, self.freq)
if self.timetype == 'float':
return np.asarray(self.seasonality_periods, dtype=float)
assert False, f'Impossible {self.timetype=}.'

def _get_num_seasonal_harmonics(self):
"""Return array of seasonal harmonics per seasonal period."""
# Discrete time.
if self.timetype == 'index':
return (
np.array(self.num_seasonal_harmonics)
if self.num_seasonal_harmonics is not None else
np.zeros(0))
# Continuous time.
if self.timetype == 'float':
if self.num_seasonal_harmonics is not None:
raise ValueError(
f'Cannot use num_seasonal_harmonics with {self.timetype=}.')
# HACK: models.make_seasonal_frequencies assumes the data is discrete
# time where each harmonic h is between 1, ..., p/2 and the harmonic
# factors are np.arange(1, h + 1). Since our goal with continuous
# time data is exactly 1 harmonic per seasonal factor, any h between
# 0 and min(0.5, p/2) will work, as np.arange(1, 1+h) = [1]
return np.fmin(.5, self._get_seasonality_periods() / 2)
assert False, f'Impossible {timetype=}.'

def _model_args(self, batch_shape):
return {
'depth': self.depth,
'input_scales': self.data_handler.get_input_scales(),
'num_seasonal_harmonics':
np.array(self.num_seasonal_harmonics)
if self.num_seasonal_harmonics is not None
else np.zeros(0),
'seasonality_periods':
seasonalities_to_array(self.seasonality_periods, self.freq)
if self.seasonality_periods is not None
else np.zeros(0),
'num_seasonal_harmonics': self._get_num_seasonal_harmonics(),
'seasonality_periods': self._get_seasonality_periods(),
'width': self.width,
'init_x': batch_shape,
'fourier_degrees': self._get_fourier_degrees(batch_shape),
Expand Down
45 changes: 0 additions & 45 deletions tests/spatiotemporal_test.py

This file was deleted.

Loading