# System dynamics

$\pi_{i,t} = \tau_{i,t} + \delta_{i,1,t} \cdot seas_1(t) + \ldots + \delta_{i,4,t} \cdot seas_4(t) + \eta_{i,t}$

$\tau_{i,t} = \tau_{i,t-1}  + \epsilon_{i,t}$

$\delta_{i,j,t} = \delta_{i,j,t-1} + seas_j(t) \cdot \xi_{i,t},\quad j = 1,2,3,4$

(After each step, the deltas are demeaned — their mean is subtracted from each of them — so that they have mean zero.)

$\eta_t \sim \mathcal{N}(\mathbf{0}, \sigma_{\eta,t}^2)$

$\epsilon_t \sim \mathcal{N}(\mathbf{0}, \sigma_{\epsilon,t}^2)$

$\ln \sigma_{\eta,t}^2 = \ln \sigma_{\eta,t-1}^2 + \nu_{\eta, t}$

$\ln \sigma_{\epsilon,t}^2 = \ln \sigma_{\epsilon,t-1}^2 + \nu_{\epsilon, t}$

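where $seas_j(t) = 1$ if the season corresponding to timestamp $t$ is $j$ and $0$ otherwise, and $\xi_{i,t}$, $\nu_{\eta,t}$ and $\nu_{\epsilon,t}$ are the innovations of the seasonal and log-variance random walks, respectively.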

In [1]:
# Automatically reload edited project modules (e.g. seminartools) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings

# Number of host (CPU) devices XLA should expose to JAX.
# XLA reads XLA_FLAGS when the JAX backend is initialised, so this must run
# before the first JAX computation in the kernel; restart the kernel after
# changing it. For a standalone parallel run, set it from the shell instead, e.g.
#   env XLA_FLAGS=--xla_force_host_platform_device_count=28 python foo.py
NUM_CPU_DEVICES = 1
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={NUM_CPU_DEVICES}'

import jax
jax.config.update('jax_platform_name', 'cpu')

devices = jax.devices()
if len(devices) != NUM_CPU_DEVICES:
    # Typically means JAX was already initialised in this kernel before
    # XLA_FLAGS was set, so the requested device count was ignored.
    warnings.warn(
        f"Expected {NUM_CPU_DEVICES} CPU device(s) but JAX reports {len(devices)}; "
        "restart the kernel so XLA_FLAGS takes effect."
    )
devices

I0000 00:00:1712229555.658932 3933033 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7),
 CpuDevice(id=8),
 CpuDevice(id=9),
 CpuDevice(id=10),
 CpuDevice(id=11)]

In [3]:
import pandas as pd
from seminartools.data import read_inflation
from seminartools.models.mucsvss_model import MUCSVSSModel

# 1. Read data

In [4]:
# Load the country-level inflation panel (dates in the output are quarter starts);
# reset_index() moves the index into ordinary columns.
df_inflation = read_inflation(
    mergeable_format=True,
).reset_index()
df_inflation

Unnamed: 0,country,date,inflation
0,Portugal,1970-04-01,0.000000
1,New Zealand,1970-04-01,0.015421
2,Dominican Republic,1970-04-01,-0.001604
3,Finland,1970-04-01,0.008333
4,Ireland,1970-04-01,0.029871
...,...,...,...
7415,Canada,2023-01-01,0.006080
7416,Sweden,2023-01-01,0.014955
7417,Korea,2023-01-01,0.010538
7418,United Kingdom,2023-01-01,0.008814


# 2. Model

In [9]:
# NOTE(review): geo_distance is not used in any visible cell; remove if unneeded.
from seminartools.utils import geo_distance
import jax

# Set to True to make JAX use 64-bit floats (must be set before any JAX arrays are created).
ENABLE_X64 = False
if ENABLE_X64:
    jax.config.update("jax_enable_x64", True)

# NOTE(review): 10008 = 12 * 834, presumably chosen to split evenly over the
# 12 CPU devices reported earlier — revisit if the device count changes.
NUM_PARTICLES = 10008

model = MUCSVSSModel(
    num_particles=NUM_PARTICLES, stochastic_seasonality=True
)
model

<seminartools.models.mucsvss_model.MUCSVSSModel at 0x7f64e0972650>

In [10]:
# Fit the model on the full inflation panel. This is slow: the recorded progress
# bar shows ~37 s per step over 212 steps (roughly 2 hours in total).
model.full_fit(df_inflation)

  1%|          | 2/212 [01:14<2:10:21, 37.25s/it]

In [None]:
# Enables the %lprun line profiler. NOTE(review): no %lprun call appears in the
# cells below — drop this if it is leftover debugging.
%load_ext line_profiler

In [None]:
import seaborn as sns

# Display model.corr (presumably the estimated cross-country correlation matrix)
# with 2-decimal precision and a colour gradient over the whole table.
# vmin/vmax pin the diverging 'coolwarm' map to [-1, 1] so that 0 gets the
# neutral centre colour; without them the colours are stretched over the
# observed min/max, and e.g. weakly positive correlations can render as blue.
# NOTE(review): assumes every entry of model.corr lies in [-1, 1] — confirm.
model.corr.style.format(precision = 2).background_gradient(
    cmap='coolwarm', axis=None, vmin=-1, vmax=1
)

In [None]:
# Inspect the state means stored by the fit; later cells use its "etau" and
# "edelta1".."edelta4" columns, indexed by country and date.
model.stored_state_means

# 3. Evaluation

In [None]:
# Countries present in the inflation panel, in order of first appearance.
df_inflation.country.unique()

In [None]:
import matplotlib.pyplot as plt

# Overlay the estimated tau path on observed US inflation. Labels, legend and title
# are needed because both series share one set of axes.
model.stored_state_means["etau"]["United States"].plot(label="tau")
df_inflation.query("country == 'United States'").set_index("date")["inflation"].plot(label="actual")
plt.legend()
plt.title("Tau and actual inflation for the US")

# 4. h-period-ahead forecasting

In [None]:
# Re-display the inflation panel that the forecasting cells below operate on.
df_inflation

In [None]:
# Predict using only observations dated on or before 2022-10-01.
# The date literal is zero-padded to match the other date literals in this notebook.
model.predict(df_inflation.query("date <= '2022-10-01'"))

In [None]:
from seminartools.models.utils import h_period_ahead_forecast

# One-step-ahead (h=1) forecasts for every country; "2010-01-01" is presumably the
# first forecast date (later cells compare against actuals from that date on).
forecast = h_period_ahead_forecast(
    model,
    df_inflation,
    "2010-01-01",
    h=1,
)
forecast

In [None]:
# Seasonal delta paths for the US, one line per season.
delta_columns = [f"edelta{season}" for season in range(1, 5)]
model.stored_state_means.loc["United States"][delta_columns].plot()

In [None]:
import matplotlib.pyplot as plt

# Gather the three US series first, then draw them on the same axes.
us_forecast = forecast.query("country == 'United States'").set_index("date")['inflation']
us_actual = (
    df_inflation
    .query("country == 'United States' and date >= '2010-01-01'")
    .set_index("date")['inflation']
)
us_tau = model.stored_state_means.loc["United States"].loc["2010-01-01":]["etau"]

us_forecast.plot(label = "forecast 1 ahead")
us_actual.plot(label = "actual")
us_tau.plot(label = "tau")
plt.legend()
plt.title("Forecasts, tau and actual inflation for the US")

In [None]:
import seaborn as sns

# Long-format (date, country, etau) frame so seaborn can draw one line per country.
tau_long = model.stored_state_means["etau"].to_frame().reset_index()
sns.lineplot(data = tau_long, x = "date", y = "etau", hue = "country")
# A per-country legend is unreadable at this size, so drop it.
plt.legend().remove()
plt.title("Tau over time per country")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One-step-ahead (h=1) inflation forecasts, one line per country.
sns.lineplot(
    data = forecast,
    x = "date",
    y = "inflation",
    hue = "country"
)
# turn off legend: one entry per country is too many to read
plt.legend().remove()
# Title so the figure stands on its own when the notebook is skimmed.
plt.title("1-period-ahead inflation forecasts per country")

In [None]:
from pathlib import Path

# Output directory for figures, relative to the notebook's working directory.
FIGURES_DIR = Path("../../Figures")

model.stored_state_means.loc["United States"][["edelta1", "edelta2", "edelta3", "edelta4"]].plot()
plt.title("Delta over time for the US")
plt.tight_layout()
# savefig does not create missing directories; make sure the target exists first.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
plt.savefig(FIGURES_DIR / "delta_over_time.png", dpi = 300)

# 5. Compare the forecast with the actual data

In [None]:
# Attach realised inflation ("actual") to each forecast ("pred").
# rename() already returns a new frame, so no explicit copy is needed.
# validate="many_to_one" raises if df_inflation has duplicate (country, date)
# rows, which would otherwise silently duplicate forecast rows.
forecast_merged = forecast.rename(columns={"inflation": "pred"})
forecast_merged = forecast_merged.merge(
    df_inflation[["country", "date", "inflation"]],
    on=["country", "date"],
    how="left",
    validate="many_to_one",
).rename(columns={"inflation": "actual"})
forecast_merged.tail(100)

In [None]:
# Predicted vs realised inflation, coloured by country (legend suppressed).
sns.scatterplot(data = forecast_merged, x = "pred", y = "actual", hue = "country")
plt.legend().remove()

In [None]:
import statsmodels.api as sm

# Mincer-Zarnowitz regression: realised inflation on a constant and the forecast.
# The merge above is a left join, so some forecasts may have no matching actual.
# missing="drop" excludes those rows instead of letting NaNs break the fit.
sm.OLS(
    forecast_merged["actual"],
    sm.add_constant(forecast_merged["pred"]),
    missing="drop",
).fit().summary()

In [None]:
# Density forecasts: a separate particle-filter run that aggregates particles with
# aggregation_method="distribution" (presumably keeping full predictive densities).
# NOTE(review): this uses UCSVSSModel, not the MUCSVSSModel fitted above — confirm intended.
from seminartools.models.uc_sv_ss_model import UCSVSSModel

modelDistribution = UCSVSSModel(
    num_particles=10000,
    stochastic_seasonality=True,
)
modelDistribution.run_pf(
    df_inflation,
    aggregation_method = "distribution",
)


In [None]:
# One-step-ahead (h=1) forecasts from the distribution model, with the same
# "2010-01-01" start argument as the point forecasts above.
distributionForecast = h_period_ahead_forecast(modelDistribution, df_inflation, "2010-01-01", h=1)


In [None]:
# Plot the predictive density (pdf against inflation grid) from the last US row.
# NOTE(review): ['inflation'] selects a Series, so .iloc[-1] is a single cell; this
# assumes each cell of the 'inflation' column holds an object with "inflation_grid"
# and "pdf" entries — confirm against what h_period_ahead_forecast returns.
US_inflation = distributionForecast.query("country == 'United States'").set_index("date")['inflation']

plt.plot(
    US_inflation.iloc[-1]["inflation_grid"],
    US_inflation.iloc[-1]["pdf"]
    )