In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pymc as pm
import pymc.sampling_jax
import bambi as bmb
import scipy.stats as stats
from scipy.stats import gaussian_kde
from sklearn.preprocessing import scale
import arviz as az
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%config InlineBackend.figure_format = 'retina'
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

In [3]:
# Initialize a random number generator for reproducibility
rng = np.random.default_rng(seed=42)

# Generate date range
dates = pd.date_range(start="2020-05-01", end="2020-07-01")
n_dates = len(dates)

# Cities and their average temperatures
cities = ["Berlin", "Paris", "Rome"]
avg_temps = [15, 17, 20]
n_cities = len(cities)

# Pre-allocate a NumPy array to hold the temperature data
temperature_array = np.zeros(n_dates * n_cities)

# Populate the temperature array
for i, avg_temp in enumerate(avg_temps):
    temperature_array[i * n_dates : (i + 1) * n_dates] = rng.normal(
        loc=avg_temp, scale=2, size=n_dates
    )

# Create date and city columns, repeating as necessary
date_column = np.tile(dates, n_cities)
city_column = np.repeat(cities, n_dates)

# Create a DataFrame in long format
df_data_long = pd.DataFrame(
    {"date": date_column, "city": city_column, "temperature": temperature_array}
)

# Show first few rows
print(df_data_long.head())

        date    city  temperature
0 2020-05-01  Berlin    15.609434
1 2020-05-02  Berlin    12.920032
2 2020-05-03  Berlin    16.500902
3 2020-05-04  Berlin    16.881129
4 2020-05-05  Berlin    11.097930


In [4]:
# The "coordinates" are the unique values that these dimensions can take
coords = {
    "date": pd.Categorical(df_data_long["date"]).categories,
    "city": pd.Categorical(df_data_long["city"]).categories,
}
print(coords)

{'date': DatetimeIndex(['2020-05-01', '2020-05-02', '2020-05-03', '2020-05-04',
               '2020-05-05', '2020-05-06', '2020-05-07', '2020-05-08',
               '2020-05-09', '2020-05-10', '2020-05-11', '2020-05-12',
               '2020-05-13', '2020-05-14', '2020-05-15', '2020-05-16',
               '2020-05-17', '2020-05-18', '2020-05-19', '2020-05-20',
               '2020-05-21', '2020-05-22', '2020-05-23', '2020-05-24',
               '2020-05-25', '2020-05-26', '2020-05-27', '2020-05-28',
               '2020-05-29', '2020-05-30', '2020-05-31', '2020-06-01',
               '2020-06-02', '2020-06-03', '2020-06-04', '2020-06-05',
               '2020-06-06', '2020-06-07', '2020-06-08', '2020-06-09',
               '2020-06-10', '2020-06-11', '2020-06-12', '2020-06-13',
               '2020-06-14', '2020-06-15', '2020-06-16', '2020-06-17',
               '2020-06-18', '2020-06-19', '2020-06-20', '2020-06-21',
               '2020-06-22', '2020-06-23', '2020-06-24', '2020-06-25

In [5]:
RANDOM_SEED = 123  

with pm.Model(coords=coords) as model:
    # Constant Data
    data = pm.Data("observed_temp", df_data_long["temperature"], dims=("obs_id"))
    
    # Coordinates for the observed data
    date_idx = pd.Categorical(df_data_long["date"]).codes
    city_idx = pd.Categorical(df_data_long["city"]).codes
    
    # Priors
    europe_mean = pm.Normal("europe_mean_temp", mu=15.0, sigma=3.0)
    city_offset = pm.Normal("city_offset", mu=0.0, sigma=3.0, dims="city")
    
    # Expected city temperature
    city_temperature = pm.Deterministic(
        "expected_city_temp", europe_mean + city_offset[city_idx], dims="obs_id"
    )
    
    # Model Error
    sigma = pm.Exponential("sigma", 1)
    
    # Likelihood
    pm.Normal("temperature", mu=city_temperature, sigma=sigma, observed=data, dims=("obs_id"))
    
    # Sampling
    idata = pm.sampling_jax.sample_numpyro_nuts()




Compiling...


I0000 00:00:1699560306.839348       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


Compilation time = 0:00:05.465431


Sampling...


  0%|                                                                                                | 0/2000 [00:00<?, ?it/s]

Compiling.. :   0%|                                                                                  | 0/2000 [00:00<?, ?it/s]




  0%|                                                                                                | 0/2000 [00:00<?, ?it/s]

[A




Compiling.. :   0%|                                                                                  | 0/2000 [00:00<?, ?it/s]

[A





  0%|                                                                                                | 0/2000 [00:00<?, ?it/s]

[A[A





Compiling.. :   0%|                                                                                  | 0/2000 [00:00<?, ?it/s]

[A[A






  0%|                                                                                                | 0/2000 [00:00<?, ?it/s]

[A[A[A






Compiling.. :   0%|                                                                                  | 0/2000 [00:00<?, ?it/s]

[A[A[A





Running chain 2:   0%|                                                                               | 0/2000 [00:01<?, ?it/s]

[A[A






Running chain 3:   0%|                                                                               | 0/2000 [00:01<?, ?it/s]

[A[A[A




Running chain 1:   0%|                                                                               | 0/2000 [00:01<?, ?it/s]

[A

Running chain 0:   0%|                                                                               | 0/2000 [00:01<?, ?it/s]






Running chain 3:  95%|█████████████████████████████████████████████████████████████▊   | 1900/2000 [00:01<00:00, 18716.26it/s]

[A[A[A

Running chain 0:  90%|██████████████████████████████████████████████████████████▌      | 1800/2000 [00:01<00:00, 17708.57it/s]





Running chain 2:  90%|██████████████████████████████████████████████████████████▌      | 1800/2000 [00:01<00:00, 17516.94it/s]

[A[A




Running chain 1:  90%|██████████████████████████████████████████████████████████▌      | 1800/2000 [00:01<00:00, 17575.95it/s]

[A

Running chain 0: 100%|██████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1263.32it/s]


Running chain 1: 100%|██████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1264.48it/s]


Running chain 2: 100%|██████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1265.51it/s]


Running chain 3: 100%|██████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1266.43it/s]


Sampling time = 0:00:01.917180


Transforming variables...


Transformation time = 0:00:00.117639


In [6]:
az.summary(idata)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
europe_mean_temp,16.687,1.536,13.853,19.556,0.065,0.046,555.0,681.0,1.00
city_offset[Berlin],-1.625,1.539,-4.549,1.206,0.065,0.046,558.0,707.0,1.00
city_offset[Paris],0.044,1.543,-2.755,2.936,0.065,0.046,564.0,679.0,1.00
city_offset[Rome],3.309,1.544,0.497,6.220,0.066,0.046,560.0,686.0,1.00
sigma,1.744,0.091,1.578,1.918,0.003,0.002,1155.0,1132.0,1.01
...,...,...,...,...,...,...,...,...,...
expected_city_temp[181],19.997,0.223,19.580,20.423,0.003,0.002,4687.0,2695.0,1.00
expected_city_temp[182],19.997,0.223,19.580,20.423,0.003,0.002,4687.0,2695.0,1.00
expected_city_temp[183],19.997,0.223,19.580,20.423,0.003,0.002,4687.0,2695.0,1.00
expected_city_temp[184],19.997,0.223,19.580,20.423,0.003,0.002,4687.0,2695.0,1.00
