In [233]:
from pathlib import Path

import polars as pl
import numpy as np
from scipy.stats import (dirichlet, poisson, beta, multinomial)
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.util as util
from numpyro.infer import MCMC, NUTS, init_to_feasible, init_to_median
import jax
from jax import random
import jax.numpy as jnp


In [234]:
# Use 64-bit floats in JAX. This must run before any JAX arrays are created;
# arrays produced earlier in the session keep their float32 dtype.
jax.config.update("jax_enable_x64", True)

# Research Question: What is the distribution of vehicle model years in the target population?

Note that the target population includes vehicles that are driven in Utah County but are not registered there.  Our analysis therefore adds information beyond what is publicly available from government registration records.

## Strategy
Use the registration counts as the concentration parameters of a Dirichlet prior.  Because the Dirichlet distribution is [conjugate to the categorical and multinomial distributions](https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical_or_multinomial), these concentration parameters act as pseudocounts that are added to our observed counts.  The summed counts are then the concentration parameters of the posterior Dirichlet distribution of the relative frequencies of the vehicle model years in the population.

## ETL
New vehicles of model years 2024, 2025, and 2026 are still being sold, but new vehicles of model year 2023 are not.  Our registration data cover registrations from 2024 through February 17, 2025.  Thus, registrations of newer-model-year vehicles made between February 17, 2025 and March 2025 are missing from our dataset.  We adjust the registration counts for these newer model years using the count for model year 2023.

Assume that if the registration data extended to February 17, 2026, there would be as many registrations expiring for model year 2024 as there currently are for model year 2023.  Note that approximately 1/12 of this period had elapsed at the time of data collection.

Assume that if the registration data extended to February 17, 2027, there would be as many registrations expiring for model year 2025 as there currently are for model year 2023.  Note that approximately 1/24 of this period had elapsed at the time of data collection.

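Assume that if the registration data extended to February 17, 2028, there would be as many registrations expiring for model year 2026 as there currently are for model year 2023.  Note that approximately 1/36 of this period had elapsed at the time of data collection.  Accordingly, each newer model year's count $n_y$ is moved the elapsed fraction $p_y$ of the way toward the 2023 count: $\tilde{n}_y = n_y + p_y\,(n_{2023} - n_y)$, with $p_{2024} = 1/12$, $p_{2025} = 1/24$, and $p_{2026} = 1/36$.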

In [99]:
source = Path("..", "raw_data", "registrations", "registrations.csv")

# `num_registrations` may contain thousands separators (commas), so read it
# explicitly as text. If no row happened to contain a comma, Polars would infer
# an integer column and `.str.replace_all` would fail.
reg = (
    pl.scan_csv(source, schema_overrides={"num_registrations": pl.String})
    .with_columns(
        pl.col("num_registrations").str.replace_all(",", "").cast(pl.Int64)
    )
    # Materialize once so later cells don't re-read and re-parse the CSV.
    .collect()
    .lazy()
)

reg.collect().tail()

model_year,num_registrations
i64,i64
2022,30312
2023,31266
2024,27037
2025,5830
2026,8


In [100]:
# Registration count for model year 2023, the most recent model year that is
# no longer sold new; it is the reference for the newer model years.
reg_2023 = (
    reg
    .filter(pl.col("model_year") == 2023)
    .select("num_registrations")
    .collect()
    .item()
)

# Model years after 2023: move the observed count a fraction
# 1 / (12 * (model_year - 2023)) of the way toward the 2023 count
# (1/12 for 2024, 1/24 for 2025, 1/36 for 2026). Other years are unchanged.
# The result is cast back to an integer count.
reg_2 = reg.with_columns(
    num_registrations=pl.when(pl.col("model_year") > 2023)
    .then(
        pl.col("num_registrations")
        + (reg_2023 - pl.col("num_registrations"))
        / (12 * (pl.col("model_year") - 2023))
    )
    .otherwise(pl.col("num_registrations"))
    .cast(pl.Int64)
)

reg_2.tail().collect()

model_year,num_registrations
i64,i64
2022,30312
2023,31266
2024,27389
2025,6889
2026,876


## Extrapolate the registration counts for pre-1913 model year vehicles
The first steam-powered vehicle dates back to 1672 ([Wikipedia](https://en.wikipedia.org/wiki/History_of_the_automobile#Steam-powered_wheeled_vehicles)).  The Utah registration data starts for model year 1913.  We assume that the number of vehicle registrations is 0 for model years not listed in the government data.  Also, we assume that the model year of a vehicle must be between 1672 and 2026.

In [None]:
# Every admissible model year: 1672 (first steam-powered vehicle) through 2026.
yr_range = range(1672, 2027)

# One row per admissible model year, so years absent from the registration
# data still appear.
reg_3 = pl.LazyFrame(
    data={"model_year": list(yr_range)},
    schema={"model_year": pl.Int64},
)

# Attach the adjusted registration counts to every model year; years absent
# from the registration data get 0. Polars does not guarantee the row order of
# a join's output unless asked. Later cells pull `model_year` and
# `num_registrations` out of this frame in separate `.collect()` calls and pair
# them up by position, so sort explicitly to keep those arrays aligned.
reg_4 = (
    reg_3
    .join(other=reg_2, on="model_year", how="left")
    .with_columns(pl.col("num_registrations").fill_null(0))
    .sort("model_year")
)

reg_4.collect()

num_registrations,model_year
i64,i64
0,1672
0,1673
0,1674
0,1675
0,1676
…,…
30312,2022
31266,2023
27389,2024
6889,2025


In [136]:
# Total number of (adjusted) registrations across all model years.
reg_4.select(pl.col("num_registrations").sum()).collect()

num_registrations
i64
584221


## Non-business registrations

In [180]:
def richards_curve(t, A, K, B, nu, Q, C, M):
    """Evaluate the generalised logistic (Richards) curve.

    Computes ``A + (K - A) / (C + Q * exp(-B * (t - M))) ** (1 / nu)``, see
    https://en.wikipedia.org/wiki/Generalised_logistic_function

    Args:
        t: Point(s) at which to evaluate the curve (scalar or array).
        A: Left asymptote (the limit as t -> -inf when B > 0).
        K: Together with C, sets the right asymptote A + (K - A) / C**(1/nu);
            equals K when C == 1.
        B: Growth rate.
        nu: Positive exponent controlling near which asymptote the steepest
            change occurs.
        Q: Scales the exponential term.
        C: Additive constant in the denominator (typically 1).
        M: Horizontal shift of the curve.

    Returns:
        The curve value(s) at ``t``.
    """
    growth = jnp.exp(-B * (t - M))
    denominator = (C + Q * growth) ** (1.0 / nu)
    return A + (K - A) / denominator

In [None]:
# For model years earlier than 2005-ish, the registration
# data probably better captures what is going on in the population
# because vehicles registered by commercial entities are
# more likely to be newer.  For the pre-2005 model years,
# most were probably registered by actual individuals
# instead of businesses.

# One illustrative set of Richards-curve parameters. With B > 0 and C = 1 the
# curve moves from A (oldest model years) to K (newest model years). The ranges
# in the comments are the uniform priors used in `non_business_regs`.
params = {
    "t": reg_4.select("model_year").collect().to_series().to_numpy(),
    "A": 1,     # proportion for the oldest model years (fixed)
    "K": 0.75,  # proportion for the newest model years; prior Uniform(0.75, 0.85)
    "B": 0.4,   # growth rate; prior Uniform(0.3, 1)
    "nu": 1,    # asymmetry exponent; prior Uniform(0.05, 1)
    "Q": 1,     # fixed
    "C": 1,     # fixed
    "M": 2008,  # horizontal shift (centre of the transition); prior Uniform(2004, 2008)
}

y = richards_curve(**params)

px.scatter(
    x=params["t"],
    y=y,
    title="Richards Curve for the Proportion of Non-Business Registrations",
    labels={"x": "Model Year", "y": "Proportion of Utah-County Registrations"},
)

In [None]:
def non_business_regs(PRNG_key=None):
    """Prior model for the proportion of registrations that are non-business.

    Samples the Richards-curve parameters K, B, nu and M from uniform priors
    (A, Q and C are fixed at 1). It records the curve evaluated at every model
    year in ``reg_4`` as the deterministic site ``"regs"``.

    Args:
        PRNG_key: JAX PRNG key used when the function is called directly,
            outside any NumPyro handler. It is split into one independent
            subkey per sample site. Pass ``None`` to let an enclosing NumPyro
            handler (e.g. ``numpyro.handlers.seed``) supply the randomness.

    Returns:
        Array of curve values, one per model year in ``reg_4``.
    """
    if PRNG_key is None:
        key_K = key_B = key_nu = key_M = None
    else:
        # Reusing a single key for every site would make the four uniform
        # draws perfectly correlated (they would share the same underlying
        # uniform variate), so give each site its own subkey.
        key_K, key_B, key_nu, key_M = random.split(PRNG_key, 4)

    params = {
        "t": reg_4.select("model_year").collect().to_series().to_numpy(),
        "A": 1,
        "K": numpyro.sample("K", dist.Uniform(0.75, 0.85), rng_key=key_K),
        "B": numpyro.sample("B", dist.Uniform(0.3, 1), rng_key=key_B),
        "nu": numpyro.sample("nu", dist.Uniform(0.05, 1), rng_key=key_nu),
        "Q": 1,
        "C": 1,
        "M": numpyro.sample("M", dist.Uniform(2004, 2008), rng_key=key_M),
    }

    regs = numpyro.deterministic(
        name="regs",
        value=richards_curve(**params),
    )
    return regs

In [155]:
# Separate keys for the sampler and for the model's explicit draws, plus a
# fresh `key` carried forward to later cells, so no key is used twice.
key, mcmc_key, model_key = random.split(random.key(7), 3)

# Run NUTS. The model has no observed data, so these are draws from the prior.
kernel = NUTS(non_business_regs)
num_samples = 100
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
    rng_key=mcmc_key,
    PRNG_key=model_key,
)

sample: 100%|██████████| 1100/1100 [00:02<00:00, 470.57it/s, 3 steps of size 7.37e-01. acc. prob=0.88] 


In [156]:
# Draws from the sampler, keyed by site name (including the "regs" curve).
non_business_regs_samples = mcmc.get_samples(group_by_chain=False)

In [174]:
# Floating-point precision of the stored curve draws (float32 unless
# jax_enable_x64 was active when they were produced).
non_business_regs_samples["regs"].dtype

dtype('float32')

In [157]:
# Model years on the x-axis: the same grid the curve was evaluated on.
model_years = reg_4.select("model_year").collect().to_series().to_numpy()
regs_draws = non_business_regs_samples["regs"]
# Count the draws actually stored rather than trusting the global
# `num_samples`, which later cells reassign.
num_draws = regs_draws.shape[0]

fig = go.Figure()
for draw in regs_draws:
    fig.add_trace(
        go.Scatter(
            x=model_years,
            y=draw,
            opacity=0.05,
            mode="lines",
            line=dict(color="black"),
        )
    )

fig.update_layout(
    showlegend=False,
    title=f"{num_draws} Samples from Prior Distribution for Proportion of Non-Business Registrations",
    xaxis={"title": "Model Year"},
    yaxis={"title": "Proportion of Utah-County Registrations"},
)

fig.show()

## Dirichlet
Assume that the proportion of vehicles in the population for each model year follows a Dirichlet distribution whose parameters are a function of the registration counts.



In [235]:
def get_alpha(PRNG_key):
    """Dirichlet concentration vector for the model-year proportions.

    Starts from the registration counts in ``reg_4`` (one entry per model
    year) and replaces zero counts with pseudocounts:

    * model years before 1913 (before the registration data start):
      ``exp(-sqrt(1913 - model_year))``, which shrinks toward 0 for older
      years and is not affected by ``reg_weight``;
    * model years from 1913 on: ``reg_weight`` (a pseudocount of 1 before
      scaling).

    Non-zero counts are multiplied by ``reg_weight``. The result is then
    multiplied element-wise by a draw of the non-business-registration curve
    from ``non_business_regs`` and recorded as the deterministic site
    ``"alpha"``.

    Args:
        PRNG_key: JAX PRNG key; a subkey split from it is passed to
            ``non_business_regs``.

    Returns:
        Array of concentrations, one per model year in ``reg_4`` (positive
        provided the registration counts are non-negative).
    """
    # Weight given to the registrations. Smaller values mean that we trust
    # the usefulness of the registration data less.
    reg_weight = 1
    # First model year present in the Utah registration data.
    first_reg_year = 1913

    is_zero = pl.col("num_registrations") == 0
    # pre_alpha is not the real alpha: it is still multiplied by the
    # non-business proportion curve below.
    pre_alpha = (
        reg_4
        .cast({"num_registrations": pl.Float64})
        .with_columns(
            pl.when(is_zero & (pl.col("model_year") < first_reg_year))
            # Decaying pseudocount for years before the data start; the
            # 1/reg_weight cancels the reg_weight factor applied below.
            .then(
                1.0 / reg_weight
                * np.exp(-(first_reg_year - pl.col("model_year")) ** 0.5)
            )
            # Any remaining zero count (model_year >= first_reg_year) gets a
            # pseudocount of 1. Model year 1913 itself must be covered: a
            # Dirichlet concentration of 0 is invalid (the decay formula
            # above would also give 1 there).
            .when(is_zero)
            .then(pl.lit(1.0, dtype=pl.Float64))
            .otherwise(pl.col("num_registrations"))
            .alias("num_registrations")
        )
        .select(reg_weight * pl.col("num_registrations"))
        .collect()
        .to_series()
        .to_numpy()
    )

    _, subkey = random.split(PRNG_key)
    alpha = numpyro.deterministic(
        name="alpha",
        value=non_business_regs(PRNG_key=subkey) * pre_alpha,
    )
    return alpha

In [236]:
# Fresh keys: one for the sampler, one for the model's explicit draws, and a
# new `key` to carry forward so no key is used twice.
key, mcmc_key, model_key = random.split(key, 3)

# Run NUTS. get_alpha has no observed data, so these are draws from the prior.
kernel = NUTS(model=get_alpha)
num_samples = 100
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
    rng_key=mcmc_key,
    PRNG_key=model_key,
)

sample: 100%|██████████| 1100/1100 [00:02<00:00, 418.20it/s, 7 steps of size 7.12e-01. acc. prob=0.90]


In [237]:
# Draws from the sampler, keyed by site name (including "alpha").
alpha_samples = mcmc.get_samples(group_by_chain=False)

In [238]:
# Smallest concentration over all draws and model years; Dirichlet
# concentrations must stay strictly positive.
alpha_samples["alpha"].min()

Array(1.81107566e-07, dtype=float64)

In [239]:
# Total concentration sum(alpha) for each draw; a larger total means a more
# tightly concentrated Dirichlet prior.
alpha_totals = np.asarray(alpha_samples["alpha"].sum(axis=1))
px.histogram(
    x=alpha_totals,
    title="Prior Distribution of the Total Dirichlet Concentration",
    labels={"x": "Sum of concentration parameters"},
)


response variable: vector of means (proportion for each model year)
posterior parameter(s): concentration vector for Dirichlet distribution


## View Samples from Prior Predictive Distribution
The prior predictive distribution shows what we believe, before seeing any survey data, the proportion of each model year in the population to be.
The July 1, 2024 [estimate](https://www.census.gov/quickfacts/fact/table/utahcountyutah/PST045223) of the number of residents of Utah County, Utah, is 747,234.  The five-year 2019–2023 American Community Survey (ACS) estimate of the number of households in Utah County is 195,602.  According to [Data USA](https://datausa.io/profile/geo/utah), Utah households have about 2 vehicles on average.  Note that vehicles can be registered by businesses, not just by households.

In [None]:
# total_num_regs = reg_4.select(pl.col("num_registrations").sum()).collect().item()
# total_num_regs

In [None]:
# beta_rv = beta(a=30, b=5)
# go.Figure(go.Histogram(x=beta_rv.rvs(300)))

In [None]:
# estimated_prop_household_vehicles_registered = beta_rv.rvs(1)
# estimated_prop_household_vehicles_registered.item()

In [None]:
# total_num_households = 196000
# pop_mean_vehicles_per_household = poisson(mu=2)
# total_num_regs / total_num_households

In [None]:
# (total_num_regs - 2*total_num_households) / total_num_regs

In [None]:
# What is the population size?
# What are the possible numbers of vehicles per household?
# poisson_rv = poisson(mu=2)
# np.sum(poisson_rv.rvs(196000))
# mu1 =
# mu2 = 
# skellam_rv = skellam()
# binom_rv.rvs(size=3)

In [None]:
# px.histogram(x=poisson_rv.rvs(196000))

In [223]:
# Fresh subkeys: subkey_1 drives a direct (non-MCMC) draw of the concentration
# vector; subkey_2 is set aside for a Dirichlet draw.
key, subkey_1, subkey_2 = random.split(key, 3)
alpha = get_alpha(PRNG_key=subkey_1)

In [224]:
# Smallest concentration in the single direct draw.
alpha.min()

Array(1.8110757e-07, dtype=float32)

In [212]:
# `util.clamp_probs` clamps to [tiny, 1 - eps], a range meant for probabilities.
# Applied to concentrations it would clip every count above 1 (e.g. tens of
# thousands of registrations) down to ~1. For alpha only the lower bound is
# meaningful, so floor it at the smallest positive normal float.
jnp.maximum(alpha, jnp.finfo(alpha.dtype).tiny)

Array([1.81107566e-07, 1.87041863e-07, 1.93183595e-07, 1.99540494e-07,
       2.06120617e-07, 2.12932321e-07, 2.19984287e-07, 2.27285582e-07,
       2.34845615e-07, 2.42674162e-07, 2.50781397e-07, 2.59177966e-07,
       2.67874839e-07, 2.76883526e-07, 2.86215965e-07, 2.95884576e-07,
       3.05902319e-07, 3.16282666e-07, 3.27039629e-07, 3.38187817e-07,
       3.49742464e-07, 3.61719430e-07, 3.74135197e-07, 3.87006992e-07,
       4.00352718e-07, 4.14191078e-07, 4.28541483e-07, 4.43424256e-07,
       4.58860541e-07, 4.74872365e-07, 4.91482751e-07, 5.08715573e-07,
       5.26595954e-07, 5.45149874e-07, 5.64404559e-07, 5.84388374e-07,
       6.05130879e-07, 6.26663052e-07, 6.49017068e-07, 6.72226577e-07,
       6.96326708e-07, 7.21354127e-07, 7.47347201e-07, 7.74345835e-07,
       8.02391753e-07, 8.31528723e-07, 8.61802221e-07, 8.93259880e-07,
       9.25951497e-07, 9.59929025e-07, 9.95246978e-07, 1.03196192e-06,
       1.07013352e-06, 1.10982387e-06, 1.15109799e-06, 1.19402375e-06,
      

In [None]:
# numpyro.distributions.Dirichlet(alpha).sample(subkey_2)

In [240]:
def model(PRNG_key):
    """Prior model for the vector of model-year proportions.

    Draws a concentration vector with ``get_alpha`` and then a proportion
    vector from ``Dirichlet(alpha)`` at the sample site ``"prior"``.

    Args:
        PRNG_key: JAX PRNG key; split into one subkey for ``get_alpha`` and
            one for the Dirichlet draw.

    Returns:
        The sampled proportion vector (same length as the concentration
        vector returned by ``get_alpha``).
    """
    _, alpha_key, dirichlet_key = random.split(PRNG_key, 3)
    concentration = get_alpha(alpha_key)
    return numpyro.sample(
        name="prior",
        fn=numpyro.distributions.Dirichlet(concentration),
        rng_key=dirichlet_key,
    )

In [241]:
# Fresh keys: one for the sampler, one for the model's explicit draws, and a
# new `key` to carry forward so no key is used twice.
key, mcmc_key, model_key = random.split(key, 3)

# Run NUTS on the full prior model (no observed data), initialising every
# latent site at the median of its prior.
kernel = NUTS(model=model, init_strategy=init_to_median())
num_samples = 100
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)

mcmc.run(
    rng_key=mcmc_key,
    PRNG_key=model_key,
)

sample: 100%|██████████| 1100/1100 [00:19<00:00, 56.49it/s, 34 steps of size 1.34e-02. acc. prob=0.76] 


In [248]:
# Proportion-vector draws from the "prior" site, one row per draw.
prior_samples = mcmc.get_samples(group_by_chain=False)["prior"]

In [255]:
x = reg_4.select("model_year").collect().to_series().to_numpy()
# NOTE: this overwrites `num_samples` (100 in the MCMC cell); only the first
# three prior draws are plotted here and in the next cell.
num_samples = 3
# Model years shown in this figure; computed once instead of per subplot.
is_later_year = x > 2000

fig = make_subplots(
    rows=num_samples,
    shared_xaxes=True,
    x_title="Model Year",
    y_title="Relative Frequency",
)

for row, sample in enumerate(prior_samples[:num_samples], start=1):
    # https://stackoverflow.com/questions/65910725/plotly-bar-chart-opacity-changes-with-longer-time-range
    fig.add_trace(
        go.Bar(
            x=x[is_later_year],
            y=sample[is_later_year],
            orientation="v",
        ),
        row=row,
        col=1,
    )

fig.update_traces(marker_line_width=0)

# https://stackoverflow.com/questions/56712486/how-to-hide-legend-with-plotly-express-and-plotly
fig.update_layout(
    barmode="overlay",
    bargap=0,
    showlegend=False,
    title="Samples from Prior Predictive Distribution (Model Years After 2000)",
)

fig.show()

In [254]:
# Plot earlier years (model years up to and including 2000), reusing `x` and
# `num_samples` from the previous cell. The mask is computed once, not per subplot.
is_earlier_year = x <= 2000

fig = make_subplots(
    rows=num_samples,
    shared_xaxes=True,
    x_title="Model Year",
    y_title="Relative Frequency",
)

for row, sample in enumerate(prior_samples[:num_samples], start=1):
    # https://stackoverflow.com/questions/65910725/plotly-bar-chart-opacity-changes-with-longer-time-range
    fig.add_trace(
        go.Bar(
            x=x[is_earlier_year],
            y=sample[is_earlier_year],
            orientation="v",
        ),
        row=row,
        col=1,
    )

fig.update_traces(marker_line_width=0)
# https://stackoverflow.com/questions/56712486/how-to-hide-legend-with-plotly-express-and-plotly
fig.update_layout(
    barmode="overlay",
    bargap=0,
    showlegend=False,
    title="Samples from Prior Predictive Distribution (Model Years 2000 and Earlier)",
)

fig.show()

# Sensitivity Analysis for Non-sampling Error

In [None]:
# Standard error of a sample proportion, sqrt(p * (1 - p) / n),
# with p = 0.28 and n = 300.
np.sqrt(0.28 * 0.72 / 300)

In [None]:
# Population size N and sample size n for the sensitivity calculation.
N = 1_000_000
n = 300
# Error term of the form rho * sqrt(((N - 1) / n) * (1 - n / N)) * sigma.
# NOTE(review): 0.007 and 0.03 presumably are the assumed correlation and
# standard deviation — TODO confirm.
0.007 * np.sqrt(((N - 1)/n) * (1 - n/N)) * 0.03

In [None]:
# Response indicator: the first n units are in the sample (R = 1), the rest are not.
R = np.concat((np.ones((n,)), np.zeros((N - n,))))


def _binary_block(size, prop_ones):
    """Return `size` values: zeros followed by round(prop_ones * size) ones."""
    num_ones = int(round(prop_ones * size))
    # Derive the zero count as the complement so the block always has exactly
    # `size` entries (taking ceil of both parts can overshoot by one).
    return np.concat((np.zeros((size - num_ones,)), np.ones((num_ones,))))


# Attribute values: 10% of sampled units and 25% of non-sampled units have it.
y = np.concat((
    _binary_block(n, 0.1),
    _binary_block(N - n, 0.25),
))
assert y.shape == R.shape, "R and y must describe the same N units"

In [None]:
# Correlation matrix of the response indicator R and the attribute y; the
# off-diagonal entry is presumably the data-defect correlation — TODO confirm.
np.corrcoef(R, y)

In [None]:
np.concat((np.ones((1,)), np.zeros((100,))))