Simple test run of Chronos-2: daily PM2.5 forecasts per sensor location

In [11]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from chronos import Chronos2Pipeline
from sklearn.metrics import r2_score, root_mean_squared_error

In [12]:
def regularize_series(group, freq="D"):
    """Reindex one time series onto a regular time grid, filling gaps with NaN.

    Parameters
    ----------
    group : pd.DataFrame
        Rows of a single series with ``timestamp`` and ``item_id`` columns.
        Timestamps must be unique: ``asfreq`` cannot reindex duplicate labels.
    freq : str, default "D"
        Pandas frequency alias of the target grid ("D" = daily).

    Returns
    -------
    pd.DataFrame
        Rows sorted by ``timestamp``, one per ``freq`` step from the first to the
        last timestamp. Inserted rows carry NaN in the data columns and the
        series' ``item_id``. An empty input is returned as-is (fresh index).
    """
    if group.empty:
        return group.reset_index(drop=True)

    group = group.sort_values("timestamp")
    # Read the id *before* asfreq: inserted gap rows make the column NaN and
    # upcast it to float, so reading it afterwards turns integer ids into floats.
    item_id = group["item_id"].iloc[0]
    group = group.set_index("timestamp").asfreq(freq)
    group["item_id"] = item_id  # fill item_id for inserted rows (keeps original dtype)
    return group.reset_index()

In [13]:
# ----------------------------
# Read data
# ----------------------------
df = pd.read_parquet("/storage/bln-aq/data/2024-citsci-pollutants-hourly.parquet")

# Ensure datetime
df['timestamp_hour'] = pd.to_datetime(df['timestamp_hour'])

# Average PM2.5 over all readings that share a location (lat, lon) and hour
df = df.groupby(['lat', 'lon', 'timestamp_hour'], as_index=False)['PM2_5'].mean()

# Give each distinct (lat, lon) a 1-based integer id. The ids follow the sorted
# (lat, lon) order produced by the groupby above, so they are only stable for a
# given input file.
unique_coords = df.drop_duplicates(subset=["lat", "lon"]).reset_index(drop=True)
unique_coords["loc_id"] = range(1, len(unique_coords) + 1)

# Merge only the key columns. Merging the whole frame duplicated timestamp_hour
# and PM2_5 into *_x / *_y columns that then had to be dropped and renamed.
df = df.merge(
    unique_coords[["lat", "lon", "loc_id"]],
    on=["lat", "lon"],
    how="left",
    validate="many_to_one",  # each (lat, lon) must map to exactly one loc_id
)

# Save locations: loc_id -> (lat, lon)
loc_dict = (
    df.groupby("loc_id")[["lat", "lon"]]
      .first()                     # take the first row per loc_id
      .apply(tuple, axis=1)        # convert to (lat, lon)
      .to_dict()                   # make into a dict
)

# Remove non-essentials and rename to fit chronos2
df = (
    df.drop(columns=["lat", "lon"])
      .rename(columns={
          "timestamp_hour": "timestamp",
          "loc_id": "item_id",
          "PM2_5": "target",
      })
)

# Remove problematic sensors.
# NOTE(review): 163 is a position in the loc_id ordering above, not a stable
# sensor identifier; if the input file changes it may point at a different
# location. Consider excluding by coordinates (see loc_dict) instead.
EXCLUDED_ITEM_IDS = [163]
df = df[~df['item_id'].isin(EXCLUDED_ITEM_IDS)]


In [14]:
# Ensure timestamp is datetime (already converted upstream; kept so this cell
# also works if df was produced some other way)
df['timestamp'] = pd.to_datetime(df['timestamp'])

# Aggregate to daily mean per sensor
daily_df = (
    df.groupby(['item_id', pd.Grouper(key='timestamp', freq='D')])['target']
      .mean()
      .reset_index()
)

# Get all sensors
sensor_ids = daily_df['item_id'].unique()

# Regularize per sensor: reindex each sensor onto a gap-free daily range from
# its own first to last day; days without data get target = NaN.
# Iterating the groupby visits each sensor's rows once, instead of re-scanning
# all of daily_df with a boolean mask for every sensor.
regularized = []
for sensor, sensor_df in daily_df.groupby('item_id'):
    sensor_df = sensor_df.set_index('timestamp')

    # Create full daily index
    full_idx = pd.date_range(sensor_df.index.min(), sensor_df.index.max(), freq='D')

    # Reindex to regular daily frequency, then restore timestamp as a column
    sensor_df = (
        sensor_df.reindex(full_idx)
                 .rename_axis('timestamp')
                 .reset_index()
    )
    sensor_df['item_id'] = sensor  # fill item_id on the inserted gap rows

    regularized.append(sensor_df)

regularized_df = pd.concat(regularized, ignore_index=True)

print(regularized_df.head(10))

df = regularized_df

   timestamp  item_id    target
0 2024-06-06        1  5.678618
1 2024-06-07        1  7.761885
2 2024-06-08        1  8.217329
3 2024-06-09        1  2.668462
4 2024-06-10        1  3.713228
5 2024-06-11        1  2.240476
6 2024-06-12        1  2.140483
7 2024-06-13        1  2.916410
8 2024-06-14        1  4.856910
9 2024-06-15        1  3.868851


In [15]:
# Train/test cutoff: days strictly before it are training data; the cutoff day
# and everything after it are held out for evaluation.
cutoff = pd.Timestamp("2024-12-17 00:00:00")

# Split every sensor (item_id) at the same cutoff. Sorting once and applying
# boolean masks yields the same rows, in the same per-sensor / time order, as
# looping over the groups, and it also works when df is empty (where
# concatenating an empty list of parts would raise).
df_sorted = df.sort_values(["item_id", "timestamp"])
train_df = df_sorted[df_sorted["timestamp"] < cutoff].reset_index(drop=True)
test_df = df_sorted[df_sorted["timestamp"] >= cutoff].reset_index(drop=True)

print(f"Train: {train_df.shape}, Test: {test_df.shape}")

Train: (63369, 3), Test: (2763, 3)


In [16]:
# Peek at the first five rows of the training split.
train_df.iloc[:5]


Unnamed: 0,timestamp,item_id,target
0,2024-06-06,1,5.678618
1,2024-06-07,1,7.761885
2,2024-06-08,1,8.217329
3,2024-06-09,1,2.668462
4,2024-06-10,1,3.713228


In [17]:
import os

# Expose only GPU 0 to CUDA. The variable is read when CUDA initialises in this
# process, so it has to be set before the Chronos-2 pipeline touches the GPU.
# (Model loading: Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="cuda"))
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [18]:
# Regenerating forecasts needs a GPU and the Chronos-2 weights, so it is off by
# default; the next cell loads previously saved predictions instead.
# Set RUN_INFERENCE = True to rerun the model on the training split.
RUN_INFERENCE = False

if RUN_INFERENCE:
    pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="cuda")

    # Generate predictions from the target history only (no covariates are passed)
    pred_df = pipeline.predict_df(
        train_df,
        # TODO(review): the cached predictions loaded in the next cell contain
        # several forecast days per sensor, so a horizon of 1 does not reproduce
        # them. Confirm the horizon that was actually used.
        prediction_length=1,  # Number of steps to forecast
        quantile_levels=[0.1, 0.5, 0.9],  # Quantiles for probabilistic forecast
        id_column="item_id",  # Column identifying different time series
        timestamp_column="timestamp",  # Column with datetime information
        target="target",  # Column(s) with time series values to predict
    )

In [19]:
# Load the cached Chronos-2 forecasts and parse their timestamps.
preds = (
    pd.read_csv("/storage/bln-aq/data/aq-predictions-2024.csv")
      .assign(timestamp=lambda frame: pd.to_datetime(frame["timestamp"]))
)

preds.head(n=5)

Unnamed: 0,item_id,timestamp,target_name,predictions,0.1,0.5,0.9
0,1,2024-12-17,target,2.866671,2.308423,2.866671,8.159126
1,1,2024-12-18,target,3.627684,2.030792,3.627684,11.995113
2,1,2024-12-19,target,4.33415,2.078264,4.33415,13.996067
3,1,2024-12-20,target,4.685068,2.183544,4.685068,14.757867
4,1,2024-12-21,target,5.006974,2.108205,5.006974,15.142235


In [20]:
# ----------------------------
# Load predictions and true data
# ----------------------------
# Re-read the cached forecasts so this cell does not depend on the preview cell above.
preds = pd.read_csv("/storage/bln-aq/data/aq-predictions-2024.csv")
preds['timestamp'] = pd.to_datetime(preds['timestamp'])

# Full observed history (train + test). The plots show days *before* the first
# forecast for context, and test_df alone starts at the cutoff.
truth_all = pd.concat([train_df, test_df], ignore_index=True)

CONTEXT_DAYS = 2         # days of observed history shown before the first forecast
MAX_SENSOR_PLOTS = None  # cap on the number of per-sensor figures; None plots all


def _aligned_pairs(sensor_truth, sensor_preds):
    """Pair observed and median-forecast PM2.5 for one sensor.

    The observed ``target`` is looked up at every forecast ``timestamp`` (in
    forecast order). Pairs where either side is NaN (a gap day in the
    regularised series, or a missing forecast) are dropped, because
    ``r2_score`` rejects NaN input.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(y_true, y_pred)`` of equal length, possibly empty.
    """
    y_true = (
        sensor_truth.set_index("timestamp")["target"]
        .reindex(sensor_preds["timestamp"])
        .to_numpy(dtype=float)
    )
    y_pred = sensor_preds["0.5"].to_numpy(dtype=float)
    valid = ~(np.isnan(y_true) | np.isnan(y_pred))
    return y_true[valid], y_pred[valid]


def _forecast_figure(sensor, pred_sensor, true_window, r2, rmse):
    """Build the plotly figure for one sensor.

    It contains the observed PM2.5 (``true_window``), the median forecast, and a
    shaded band between the 0.1 and 0.9 forecast quantiles.
    """
    fig = go.Figure()

    # True PM2.5
    fig.add_trace(go.Scatter(
        x=true_window['timestamp'],
        y=true_window['target'],
        mode='lines+markers',
        name='True'
    ))

    # Predicted median
    fig.add_trace(go.Scatter(
        x=pred_sensor['timestamp'],
        y=pred_sensor['0.5'],
        mode='lines+markers',
        name='Predicted'
    ))

    # Quantile shading: trace the 0.9 line forward and the 0.1 line backward so
    # that fill='toself' closes the polygon between them.
    fig.add_trace(go.Scatter(
        x=pd.concat([pred_sensor['timestamp'], pred_sensor['timestamp'][::-1]]),
        y=pd.concat([pred_sensor['0.9'], pred_sensor['0.1'][::-1]]),
        fill='toself',
        fillcolor='rgba(0,100,80,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo='skip',
        showlegend=True,
        name='10-90% quantile'
    ))

    fig.update_layout(
        title=f"Sensor {sensor} - R2={r2:.3f}, RMSE={rmse:.3f}",
        xaxis_title="Timestamp",
        yaxis_title="PM2.5",
        width=1000,
        height=400
    )
    return fig


# ----------------------------
# Per-sensor and aggregated metrics (single pass over the test set)
# ----------------------------
preds_by_sensor = {sensor: group for sensor, group in preds.groupby("item_id")}

r2_per_sensor = {}
rmse_per_sensor = {}
all_true = []
all_pred = []

for sensor, group in test_df.groupby("item_id"):
    sensor_preds = preds_by_sensor.get(sensor)
    if sensor_preds is None:
        continue

    y_true, y_pred = _aligned_pairs(group, sensor_preds)
    if len(y_true) == 0:
        continue

    r2_per_sensor[sensor] = r2_score(y_true, y_pred)
    rmse_per_sensor[sensor] = root_mean_squared_error(y_true, y_pred)
    all_true.append(y_true)
    all_pred.append(y_pred)

# Pool all aligned pairs across sensors; guard the no-overlap case, where
# np.concatenate([]) would raise.
all_true = np.concatenate(all_true) if all_true else np.array([])
all_pred = np.concatenate(all_pred) if all_pred else np.array([])

if len(all_true) > 0:
    agg_r2 = r2_score(all_true, all_pred)
    agg_rmse = root_mean_squared_error(all_true, all_pred)
else:
    agg_r2 = agg_rmse = np.nan

print(f"Aggregated R2: {agg_r2:.3f}, RMSE: {agg_rmse:.3f}")

# ----------------------------
# Interactive plotting per sensor
# ----------------------------
truth_by_sensor = {sensor: group for sensor, group in truth_all.groupby("item_id")}
no_truth = truth_all.iloc[0:0]  # empty frame for sensors without observations
sensors = sorted(preds_by_sensor)

for sensor in sensors[:MAX_SENSOR_PLOTS]:
    pred_sensor = preds_by_sensor[sensor].sort_values("timestamp")
    true_sensor = truth_by_sensor.get(sensor, no_truth).sort_values("timestamp")

    # Include the previous CONTEXT_DAYS of observations for context
    start_time = pred_sensor['timestamp'].min() - pd.Timedelta(days=CONTEXT_DAYS)
    true_window = true_sensor[true_sensor['timestamp'] >= start_time]

    fig = _forecast_figure(
        sensor,
        pred_sensor,
        true_window,
        r2_per_sensor.get(sensor, np.nan),
        rmse_per_sensor.get(sensor, np.nan),
    )
    fig.show()


Aggregated R2: 0.993, RMSE: 4.168
