# Sandbox for Support Distribution Aggregation #41

In [14]:
import polars as pl
import pandas as pd
import numpy as np
from functools import reduce

import importlib

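Import `AggregationManager`, force-reloading the `aggregation` module first so that local edits to it are picked up without restarting the kernel: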

In [47]:
# Import the module object itself so it can be handed to importlib.reload().
import aggregation
# Re-execute the module so the kernel uses its current source.
importlib.reload(aggregation)
# Bind the name after the reload so it refers to the freshly loaded class.
from aggregation import AggregationManager

Test of basic pipeline:

In [58]:
# Toy fixtures: the same two-row model output, once as a polars frame and once
# as a pandas frame, so that add_model() is exercised with both input types.
pdf = pl.DataFrame(
    {
        "time": [1, 1],
        "entity_id": [1, 2],
        "conflict_prob": pl.Series(
            "conflict_prob", [[0.1, 0.2], [0.3, 0.4]], dtype=pl.List(pl.Float64)
        ),
        "death_count": pl.Series(
            "death_count", [[0.5, 0.6], [0.7, 0.8]], dtype=pl.List(pl.Float64)
        ),
    }
)

pdf_pd = pd.DataFrame(
    {
        "time": [1, 1],
        "entity_id": [1, 2],
        "conflict_prob": [[0.1, 0.2], [0.3, 0.4]],
        "death_count": [[0.5, 0.6], [0.7, 0.8]],
    }
)

# NOTE(review): the two weights sum to 1.1 -- confirm whether AggregationManager
# normalises them or expects them to sum to 1.
manager = AggregationManager(
    index_cols=["time", "entity_id"],
    target_cols=["conflict_prob", "death_count"],
    weights=[0.3, 0.8],
)

# Register the polars model first, then the pandas one (same order as the weights).
for model_frame in (pdf, pdf_pd):
    manager.add_model(model_frame)

print(manager.models)

# NOTE(review): no random seed is set in this cell; if the "weighted" aggregation
# samples randomly, the printed result may differ between runs -- TODO confirm.
ensemble_distributions = manager.aggregate_distributions("weighted", n_samples=100)

print(ensemble_distributions)

[shape: (2, 4)
┌──────┬───────────┬───────────────┬─────────────┐
│ time ┆ entity_id ┆ conflict_prob ┆ death_count │
│ ---  ┆ ---       ┆ ---           ┆ ---         │
│ i64  ┆ i64       ┆ list[f64]     ┆ list[f64]   │
╞══════╪═══════════╪═══════════════╪═════════════╡
│ 1    ┆ 1         ┆ [0.1, 0.2]    ┆ [0.5, 0.6]  │
│ 1    ┆ 2         ┆ [0.3, 0.4]    ┆ [0.7, 0.8]  │
└──────┴───────────┴───────────────┴─────────────┘, shape: (2, 4)
┌──────┬───────────┬───────────────┬─────────────┐
│ time ┆ entity_id ┆ conflict_prob ┆ death_count │
│ ---  ┆ ---       ┆ ---           ┆ ---         │
│ i64  ┆ i64       ┆ list[f64]     ┆ list[f64]   │
╞══════╪═══════════╪═══════════════╪═════════════╡
│ 1    ┆ 1         ┆ [0.1, 0.2]    ┆ [0.5, 0.6]  │
│ 1    ┆ 2         ┆ [0.3, 0.4]    ┆ [0.7, 0.8]  │
└──────┴───────────┴───────────────┴─────────────┘]
shape: (2, 4)
┌──────┬───────────┬───────────────────┬───────────────────┐
│ time ┆ entity_id ┆ conflict_prob     ┆ death_count       │
│ ---  ┆ ---     

## Other tests

Test with path to parquet file of actual prediction output:

In [56]:
# Read a stored prediction file; the path is relative to the notebook's
# working directory.
test_parquet_path = "data/predictions_forecasting_20250807.parquet"
test_parquet = pl.read_parquet(test_parquet_path)

# Separate manager for this data set: different index columns, a single
# target column, and no weights argument.
pq_manager = AggregationManager(
    index_cols=['month_id', 'country_id'],
    target_cols=['pred_ln_ged_sb_dep']
)

pq_manager.add_model(test_parquet)

# Bare last expression: the list of registered models is the cell output.
pq_manager.models

[shape: (6_876, 3)
 ┌──────────┬────────────┬────────────────────┐
 │ month_id ┆ country_id ┆ pred_ln_ged_sb_dep │
 │ ---      ┆ ---        ┆ ---                │
 │ i64      ┆ i64        ┆ list[f64]          │
 ╞══════════╪════════════╪════════════════════╡
 │ 547      ┆ 1          ┆ [0.000706]         │
 │ 547      ┆ 2          ┆ [0.00158]          │
 │ 547      ┆ 3          ┆ [0.072077]         │
 │ 547      ┆ 4          ┆ [0.011363]         │
 │ 547      ┆ 5          ┆ [0.001575]         │
 │ …        ┆ …          ┆ …                  │
 │ 582      ┆ 242        ┆ [0.327009]         │
 │ 582      ┆ 243        ┆ [0.229127]         │
 │ 582      ┆ 244        ┆ [0.16187]          │
 │ 582      ┆ 245        ┆ [4.861599]         │
 │ 582      ┆ 246        ┆ [2.111845]         │
 └──────────┴────────────┴────────────────────┘]

Test with invalid dataframe:

In [57]:
# Negative test: 'entity_id' holds strings, so add_model() is expected to
# reject the frame with a TypeError.
pdf_test = pd.DataFrame({
    "time": [1, 1],
    "entity_id": ["dsf", "sdf"],
    "conflict_prob": [[0.5, 0.6], [0.7, 0.8]],
    "death_count": [[0.5, 0.6], [0.7, 0.8]]
})

try:
    manager.add_model(pdf_test)
except TypeError as exc:
    # The exception is the demonstration: show it without aborting a
    # Restart & Run All of the notebook.
    print(f"{type(exc).__name__}: {exc}")
else:
    # Fail loudly if validation silently accepted the invalid frame.
    raise AssertionError(
        "add_model() accepted a frame with a non-integer 'entity_id' index column"
    )

TypeError: Index column 'entity_id' must be integer, got String

Test with invalid aggregation method:

In [34]:
# Negative test: "hello" is not a supported aggregation name, so a ValueError
# is expected.
try:
    manager.aggregate_point_predictions("hello")
except ValueError as exc:
    # The exception is the demonstration: show it without aborting a
    # Restart & Run All of the notebook.
    print(f"{type(exc).__name__}: {exc}")
else:
    # Fail loudly if the unsupported method name was silently accepted.
    raise AssertionError(
        'aggregate_point_predictions() accepted the unsupported method "hello"'
    )

ValueError: Unsupported aggregation function: "hello", must be one of 'mean', 'median', 'min', 'max' or custum aggregation function of form Callable[[pl.Series], float]