Autoregressive models form a simple but useful class of time series models. Consider a time series $\{x_t\}_{t=1}^N$ of $N$ observations of the Billboard ranking of a particular song. Perhaps the simplest possible model predicts the next ranking as an affine function of the current ranking, i.e.
$$
\hat{x}_{t+1} = wx_t + b
$$
This is a first-order autoregressive model, written AR(1). More generally, we can use the $m$ most recent observations, so that
$$
\hat{x}_{t+1} = \sum^m_{i=1}{w_i x_{t-i+1}} + b,
$$
giving us an $m$-th order autoregressive model, or AR($m$).
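
As a minimal sketch of the general case, the function below computes a one-step AR($m$) forecast. The name `predict_ar`, the weights, and the short series are made up for illustration, and the code assumes the convention that $w_1$ multiplies the most recent observation $x_t$.

```python
import jax.numpy as jnp

def predict_ar(w: jnp.ndarray, b: float, history: jnp.ndarray) -> jnp.ndarray:
    """One-step AR(m) forecast from the m most recent observations.

    `history` is ordered oldest to newest; w[0] multiplies the newest value.
    """
    m = w.shape[0]
    recent = history[-m:][::-1]  # newest first: x_t, x_{t-1}, ..., x_{t-m+1}
    return jnp.dot(w, recent) + b

# Hypothetical AR(2) weights and a short made-up series.
w = jnp.array([0.5, 0.25])
b = 1.0
history = jnp.array([4.0, 2.0, 8.0])
x_next = predict_ar(w, b, history)  # 0.5 * 8 + 0.25 * 2 + 1
```

With $m = 1$ this reduces to the AR(1) forecast $wx_t + b$.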

These models use no information other than the series itself, which makes them a useful baseline against which to compare more complex models. For instance, Google search volume has been shown to have utility for predicting a number of things, such as unemployment levels, stock prices, auto and home sales, and even disease prevalence. One would assume that search volume would be useful for predicting Billboard rankings, and indeed it is. However, as Goel et al. show in their paper [Predicting consumer behavior with Web search](https://www.pnas.org/doi/10.1073/pnas.1005962107), the search-based model is outperformed by a simple autoregressive model.

Suppose we have $N$ observations $\{x_t\}$ of some time series and want to fit an AR(1) process to the data. As it turns out, we can determine the parameters $w$ and $b$ using trusty old least squares. Taking $x_1$ as given, we solve the following system in the least-squares sense:

$$
\begin{bmatrix}
x_2 \\
\vdots \\
x_N
\end{bmatrix} = w\begin{bmatrix}
x_1 \\ \vdots \\ x_{N-1}
\end{bmatrix}
+
\begin{bmatrix}
b \\ \vdots \\ b
\end{bmatrix}
$$
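
For comparison, the system above could also be solved directly. The sketch below is one possible way to do that with `jnp.linalg.lstsq`, treating the slope (written $w$ earlier) and $b$ as the two unknowns. The short series is made up so that the exact answer is known; the variable names are illustrative only.

```python
import jax.numpy as jnp

# Made-up series that follows x_{t+1} = 0.5 * x_t + 1 exactly.
xs = jnp.array([0.0, 1.0, 1.5, 1.75, 1.875])

# Design matrix: a column of lagged values x_1..x_{N-1} and a column of ones for b.
A = jnp.stack([xs[:-1], jnp.ones_like(xs[:-1])], axis=1)
y = xs[1:]  # targets x_2..x_N

# Least-squares solution of A @ [w, b] = y.
coef, residuals, rank, svals = jnp.linalg.lstsq(A, y)
w_hat, b_hat = coef
```

Here `w_hat` and `b_hat` should come out close to 0.5 and 1, up to floating-point error.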

With the formalism out of the way, let's implement this with JAX. We'll do it "ML style" and fit the model by optimizing a loss function rather than using a closed-form formula.

In [1]:
import jax
import jax.numpy as jnp
from jax import random
from collections import namedtuple
randkey = random.PRNGKey(0)
AR1Params = namedtuple('AR1Params', 'w b')

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In order to test the method, let's generate some AR(1) data and try to guess what the true parameters are.

In [2]:
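# Draw the true slope at random and fix the true intercept at 1.
w = random.normal(randkey)
b = jnp.array(1.0)
# Derive a new key before drawing the starting value x_1.
randkey, _ = random.split(randkey)
x_1 = random.randint(randkey, (), 1, 100)
true_params = AR1Params(w, b)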

In [3]:
@jax.jit
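def predict(params: AR1Params, x: jnp.ndarray) -> jnp.ndarray:
    # One-step AR(1) forecast: w * x_t + b.
    return params.w * x + params.b
# Vectorize over the observations (axis 0 of x) while sharing the same params.
vpredict = jax.vmap(predict, in_axes=(None, 0))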

In [4]:
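# Roll the true model forward from x_1; no noise is added, so the series is deterministic.
data = [x_1]
for _ in range(1_000):
    data.append(predict(true_params, data[-1]))
data = jnp.array(data)
# Hold out the last 100 points; the remaining 901 are used for fitting.
test_data = data[-100:]
train_data = data[:-100]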

We'll fit the model by optimizing the mean-squared-error loss function.
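
Written out for the AR(1) model, and assuming the average runs over the $N-1$ consecutive pairs in the series, the loss would be

$$
L(w, b) = \frac{1}{N-1}\sum_{t=1}^{N-1}\left(x_{t+1} - (w x_t + b)\right)^2
$$

and a plain gradient-descent step with learning rate $\eta$ would update the parameters as

$$
w \leftarrow w - \eta \frac{\partial L}{\partial w}, \qquad b \leftarrow b - \eta \frac{\partial L}{\partial b}
$$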

In [5]:
@jax.jit
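def loss(params: AR1Params, xs: jnp.ndarray) -> jnp.ndarray:
    # Compare each x_{t+1} with its forecast from x_t; the forecast from the last point has no target.
    return jnp.square(xs[1:] - vpredict(params, xs)[:-1]).mean()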

In [6]:
params = AR1Params(jnp.array(0.1), jnp.array(0.1))  # initial guess
eta = 1e-2  # learning rate

In [7]:
loss_grad_fn = jax.grad(loss)  # gradient with respect to params (the first argument)
for i in range(10_000):
    if i % 50 == 0:
        # Report progress every 50 steps.
        ell = loss(params, train_data)
        print(f"loss {ell:.2}")
        print(f"w is now {params.w:.5}\tb is now {params.b:.5}")
    G = loss_grad_fn(params, train_data)
    old_params = params
    # Plain gradient-descent update.
    params = AR1Params(params.w - eta * G.w, params.b - eta * G.b)
    # Stop once w has effectively stopped moving.
    if abs(old_params.w - params.w) < 1e-6:
        break

loss 1.4
w is now 0.1	b is now 0.1
loss 0.12
w is now -0.17454	b is now 0.63699
loss 0.019
w is now -0.19348	b is now 0.85669
loss 0.003
w is now -0.20096	b is now 0.94342
loss 0.00046
w is now -0.20392	b is now 0.97767
loss 7.2e-05
w is now -0.20508	b is now 0.99118
loss 1.1e-05
w is now -0.20554	b is now 0.99652
loss 1.8e-06
w is now -0.20572	b is now 0.99863
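
The loop above evaluates the loss and its gradient in separate calls. As a possible variation, `jax.value_and_grad` returns both in one pass, and `jax.tree_util.tree_map` applies the update to every field of the parameter tuple. The sketch below is self-contained and runs on a short made-up series; the `step` helper and the step size of 0.05 are assumptions chosen for this toy data, not values from the notebook.

```python
import jax
import jax.numpy as jnp
from collections import namedtuple

AR1Params = namedtuple("AR1Params", "w b")

def loss(params, xs):
    # One-step-ahead mean squared error for an AR(1) model.
    preds = params.w * xs[:-1] + params.b
    return jnp.mean(jnp.square(xs[1:] - preds))

@jax.jit
def step(params, xs, eta):
    # Loss value and gradient in a single pass.
    value, grads = jax.value_and_grad(loss)(params, xs)
    # Apply p <- p - eta * g to each field of the namedtuple.
    new_params = jax.tree_util.tree_map(lambda p, g: p - eta * g, params, grads)
    return new_params, value

# Made-up series that follows x_{t+1} = 0.5 * x_t + 1 exactly.
xs = jnp.array([0.0, 1.0, 1.5, 1.75, 1.875])
params = AR1Params(jnp.array(0.1), jnp.array(0.1))
first_loss = loss(params, xs)
for _ in range(2000):
    params, last_loss = step(params, xs, 0.05)
```

On this toy series the fitted `params.w` and `params.b` should approach 0.5 and 1.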


In [8]:
print(f"true w: {true_params.w:.5}\ncomputed w: {params.w:.5}\ntrue b: {true_params.b:.5}\ncomputed b: {params.b:.5}")

true w: -0.20584
computed w: -0.20579
true b: 1.0
computed b: 0.99938


We're in the ballpark! As a sanity check, let's use the `statsmodels` library to fit a first-order autoregressive model on the same data.

In [9]:
from statsmodels.tsa.ar_model import AutoReg
mod = AutoReg(list(train_data), 1)
res = mod.fit()
print(res.summary())

                            AutoReg Model Results                             
Dep. Variable:                      y   No. Observations:                  901
Model:                     AutoReg(1)   Log Likelihood               15292.762
Method:               Conditional MLE   S.D. of innovations              0.000
Date:                Sat, 12 Aug 2023   AIC                         -30579.523
Time:                        13:32:54   BIC                         -30565.116
Sample:                             1   HQIC                        -30574.020
                                  901                                         
                 coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
const          1.0000    3.5e-10   2.86e+09      0.000       1.000       1.000
y.L1          -0.2058   1.03e-10     -2e+09      0.000      -0.206      -0.206
                                    Roots           

The `y.L1` coefficient (-0.2058) agrees with the slope we computed and `const` (1.0000) agrees with the intercept, as hoped.
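
The notebook sets aside `test_data` but never scores the fitted model on it. A minimal sketch of a one-step-ahead evaluation on held-out points is given below. To stay self-contained it uses a short made-up series and made-up parameter values in place of the notebook's variables, and the helper name `one_step_mse` is hypothetical.

```python
import jax.numpy as jnp

def one_step_mse(w, b, xs):
    """Mean squared error of the one-step-ahead forecasts w * x_t + b over a series."""
    preds = w * xs[:-1] + b
    return jnp.mean(jnp.square(xs[1:] - preds))

# Hypothetical held-out series generated by x_{t+1} = 0.5 * x_t + 1.
held_out = jnp.array([1.0, 1.5, 1.75, 1.875])

exact_fit = one_step_mse(0.5, 1.0, held_out)  # the parameters that generated the series
poor_fit = one_step_mse(0.1, 0.1, held_out)   # an untrained initial guess
```

In the notebook, the same helper could be called as `one_step_mse(params.w, params.b, test_data)` and compared with the training loss.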