In [None]:
# Reload edited modules (e.g. the `src.*` helpers imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import geopandas as gpd
from src.constants import EXTERNAL_PATH

def load_ukraine_admin_polygons(adm_level=4):
    """Load Ukraine administrative boundary polygons at a given level.

    Reads the first (in sorted path order) shapefile matching ``*_adm{adm_level}*.shp`` in
    ``EXTERNAL_PATH / 'UKR_admin_boundaries'`` and keeps the English names of every level from
    1 down to ``adm_level`` plus the geometry.

    Args:
        adm_level: Administrative level, one of 1, 2, 3, 4.

    Returns:
        GeoDataFrame with columns ``region_id`` (the positional row id), ``ADM1_EN`` ...
        ``ADM{adm_level}_EN`` and ``geometry``.

    Raises:
        AssertionError: if ``adm_level`` is not one of 1-4.
        FileNotFoundError: if no matching shapefile exists.
    """
    assert adm_level in [1, 2, 3, 4]
    boundaries_dir = EXTERNAL_PATH / 'UKR_admin_boundaries'
    candidates = sorted(boundaries_dir.glob(f'*_adm{adm_level}*.shp'))
    if not candidates:
        # Fail with an explicit message rather than an IndexError from indexing an empty list.
        raise FileNotFoundError(f'No *_adm{adm_level}*.shp shapefile found in {boundaries_dir}')
    columns = [f'ADM{i}_EN' for i in range(1, adm_level + 1)] + ['geometry']
    ukr_regions = gpd.read_file(candidates[0])[columns]
    ukr_regions.index.name = 'region_id'
    return ukr_regions.reset_index()

# Predictions per settlement

In [None]:
import numpy as np
from src.postprocessing.preds_buildings import vectorize_xarray_with_gdf
from src.postprocessing.utils import read_fp_within_geo
from src.data.settlements import MSFT_SETTLEMENTS_PATH

def aggregate_preds_settlement(settlement_id, gds, fp_preds):
    """Summarise building-level damage predictions for one settlement.

    Args:
        settlement_id: Settlement identifier; also the stem of its building footprint file in
            ``MSFT_SETTLEMENTS_PATH``.
        gds: Row of the settlements GeoDataFrame (provides ``geometry`` and the ``ADM*`` fields).
        fp_preds: Path to the prediction raster.

    Returns:
        None when the settlement has no buildings. Otherwise a dict with, for every threshold
        t in 0.50, 0.55, ..., 0.95, the number of buildings whose ``weighted_mean`` / ``max``
        prediction exceeds ``255 * t`` (keys ``count_mean_{t:.2f}`` / ``count_max_{t:.2f}``).
        The dict also holds ``n_buildings``, ``settlement_id``, ``geometry`` and every field of
        ``gds`` whose name starts with ``ADM``.
    """
    footprints = gpd.read_file(MSFT_SETTLEMENTS_PATH / f'{settlement_id}.geojson')
    if footprints.empty:
        return None

    # Crop the prediction raster to the settlement, then aggregate it per building footprint
    # (weighted mean and max).
    raster = read_fp_within_geo(fp_preds, gds.geometry)
    per_building = vectorize_xarray_with_gdf(raster, footprints, name_id="building_id", verbose=0)
    buildings = footprints.merge(per_building, on="building_id")

    summary = {}
    for thr in np.arange(0.5,1,0.05):
        cutoff = 255*thr  # predictions appear to be on a 0-255 scale
        summary[f'count_mean_{thr:.2f}'] = (buildings['weighted_mean'] > cutoff).sum()
        summary[f'count_max_{thr:.2f}'] = (buildings['max'] > cutoff).sum()
    summary['n_buildings'] = buildings.shape[0]
    summary['settlement_id'] = settlement_id
    summary['geometry'] = gds.geometry

    # Carry over the administrative names of the settlement.
    summary.update({key: val for key, val in gds.items() if key.startswith('ADM')})
    return summary

In [None]:
import multiprocessing as mp
from src.constants import PREDS_PATH
from src.data.settlements import load_gdf_settlements

def aggregate_preds_all_settlements(run_name, oblasts=None):
    """Aggregate the predictions of one run for every settlement, in parallel.

    Args:
        run_name: Sub-folder of ``PREDS_PATH`` that contains ``ukraine_padded.tif``.
        oblasts: Optional list of ``ADM1_EN`` names. When given, only settlements in these
            oblasts are processed.

    Returns:
        GeoDataFrame in the settlements' CRS with one row per settlement that has at least one
        building (columns as produced by ``aggregate_preds_settlement``). The frame is empty,
        with only a geometry column, when no settlement qualifies.
    """
    fp_preds = PREDS_PATH / run_name / "ukraine_padded.tif"

    gdf_settlements = load_gdf_settlements()
    if oblasts is not None:
        gdf_settlements = gdf_settlements[gdf_settlements['ADM1_EN'].isin(oblasts)]

    print(f'Processing {len(gdf_settlements)} settlements...')

    args = [(id_, row, fp_preds) for id_, row in gdf_settlements.iterrows()]
    results = []
    if args:  # no point starting a worker pool with nothing to do
        with mp.Pool(mp.cpu_count()) as pool:
            results = pool.starmap(aggregate_preds_settlement, args)
    results = [r for r in results if r is not None]  # drop settlements without buildings

    if not results:
        # GeoDataFrame([], crs=...) would have no geometry column, and recent geopandas versions
        # refuse to attach a CRS to such a frame. Return an explicitly empty GeoDataFrame instead.
        return gpd.GeoDataFrame(geometry=[], crs=gdf_settlements.crs)
    return gpd.GeoDataFrame(results, crs=gdf_settlements.crs)

In [None]:
# Aggregate the 2022-02-24_2023-02-23 predictions per settlement, writing one GeoJSON per oblast.
# Oblasts whose output already exists are skipped, so the cell can be re-run to resume.
run_name = '240224/2022-02-24_2023-02-23'
assert (PREDS_PATH / run_name).exists()

folder = PREDS_PATH / run_name / 'oblasts_with_preds_agg'
folder.mkdir(exist_ok=True, parents=True)

adm1 = load_ukraine_admin_polygons(adm_level=1)
oblasts = adm1.ADM1_EN.unique()  # computed once instead of on every iteration
for i, o in enumerate(oblasts):

    fp = folder / f"preds_agg_{o}.geojson"
    if fp.exists():
        print(f'Skipping {o}...')
        continue
    print(f'Processing {o} ({i+1}/{len(oblasts)})...')
    gdf = aggregate_preds_all_settlements(run_name=run_name, oblasts=[o])
    gdf.to_file(fp, driver='GeoJSON')
    print(f'Saved {gdf.shape[0]} settlements for {o}')

In [None]:
import pandas as pd
# Collect the per-oblast settlement aggregates of the 2022 run into one GeoDataFrame.
run_name = '240224/2022-02-24_2023-02-23'
folder = PREDS_PATH / run_name / 'oblasts_with_preds_agg'
adm1 = load_ukraine_admin_polygons(adm_level=1)
# ignore_index: every per-oblast file is read with its own 0-based index, which would
# otherwise leave duplicate index labels after concatenation.
gdf = pd.concat(
    [gpd.read_file(folder / f"preds_agg_{o}.geojson") for o in adm1.ADM1_EN.unique()],
    ignore_index=True,
)
gdf

In [None]:
def aggregate_preds_region(adm_level, threshold=0.5, agg='mean', data=None):
    """Sum settlement-level damage counts per administrative region.

    Args:
        adm_level: Administrative level (1-4) to aggregate to.
        threshold: Damage threshold; selects the ``count_{agg}_{threshold:.2f}`` column.
        agg: ``'mean'`` or ``'max'``, the per-building aggregation the counts are based on.
        data: Settlement-level (Geo)DataFrame with ``count_*``, ``n_buildings`` and ``ADM*_EN``
            columns. Defaults to the notebook-level ``gdf``, for backward compatibility.

    Returns:
        GeoDataFrame with one row per region of ``adm_level``. Regions without settlements are
        included and filled with 0. Columns are the summed count and ``n_buildings`` plus
        ``count_{agg}_{threshold:.2f}_relative`` = count / n_buildings.
    """
    if data is None:
        data = gdf  # notebook-level global; pass `data` explicitly to avoid hidden state
    ukr_regions = load_ukraine_admin_polygons(adm_level)

    count_col = f'count_{agg}_{threshold:.2f}'
    # Keep every non-count column plus the single count column of interest.
    c_to_keep = [c for c in data.columns if not c.startswith('count')] + [count_col]
    gdf_ = data[c_to_keep]

    adm_cols = [f'ADM{i}_EN' for i in range(1, adm_level + 1)]
    sum_cols = [c for c in gdf_.columns if c.startswith(('count', 'n_buildings'))]
    gdf_agg_regions = (
        gdf_.groupby(adm_cols)[sum_cols]
        .sum()
        .reset_index()
        .merge(ukr_regions, on=adm_cols, how='right')  # keep regions without settlements
    )
    gdf_agg_regions = gpd.GeoDataFrame(gdf_agg_regions, crs=ukr_regions.crs)

    gdf_agg_regions[f'{count_col}_relative'] = gdf_agg_regions[count_col] / gdf_agg_regions['n_buildings']
    # Regions without settlements get NaN from the right merge (and 0/0 ratios); report them as 0.
    return gdf_agg_regions.fillna(0)

In [None]:
def explore_preds(adm_level, threshold=0.5, agg='mean', relative=True, show_zero=False):
    """Return an interactive map (``GeoDataFrame.explore``) of damage counts at ``adm_level``.

    Regions are coloured by the relative count (count / n_buildings) when ``relative`` is True,
    otherwise by the absolute count. Regions with a zero count are dropped unless ``show_zero``.
    """
    col = f'count_{agg}_{threshold:.2f}'
    col_ = col + '_relative'

    regions = aggregate_preds_region(adm_level, threshold=threshold, agg=agg)
    if not show_zero:
        regions = regions[regions[col] > 0]

    value_col = col_ if relative else col
    shown = regions[[col, col_, f'ADM{adm_level}_EN', 'geometry']]
    # Alternative basemap: tiles='CartoDB.DarkMatterNoLabels'
    return shown.explore(value_col, cmap='YlOrRd', vmin=0, tiles='Esri.WorldGrayCanvas')

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output


# Interactive wrapper: choose the parameters, then press the button to (re)draw the map.
def interact_explore_preds():
    """Display controls for explore_preds and an "Update Map" button that renders the map below them."""
    controls = {
        'adm_level': widgets.Dropdown(options=[1, 2, 3, 4], value=1, description='Admin Level:'),
        'threshold': widgets.Dropdown(options=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95], value=0.65, description='Threshold:'),
        'agg': widgets.Dropdown(options=['mean', 'max'], value='mean', description='Aggregation:'),
        'relative': widgets.Checkbox(value=True, description='Relative'),
        'show_zero': widgets.Checkbox(value=False, description='Show Zero Values'),
    }
    output = widgets.Output()

    def on_click(_button):
        # Render into the Output widget so that each click replaces the previous map.
        with output:
            clear_output(wait=True)
            display(explore_preds(**{name: w.value for name, w in controls.items()}))

    button = widgets.Button(description="Update Map")
    button.on_click(on_click)

    display(widgets.VBox([*controls.values(), button, output]))


interact_explore_preds()



In [None]:
from tqdm import tqdm
# Threshold sweep: for each threshold 0.50, 0.55, ..., 0.95, the total number of buildings
# flagged as damaged (weighted-mean aggregation), summed over all ADM4 regions.
# Populates `n_buildings_damaged`, `thresholds` and `n_buildings`, which later cells read.
n_buildings_damaged = []
thresholds = np.arange(0.5,1,0.05)
for t in tqdm(thresholds):
    # NOTE(review): aggregate_preds_region reloads the ADM4 polygons on every iteration.
    gdf_preds = aggregate_preds_region(4, t, 'mean')
    n_buildings_damaged.append(gdf_preds[f'count_mean_{t:.2f}'].sum())
# The building total does not depend on the threshold, so the last iteration's frame is used.
n_buildings = gdf_preds['n_buildings'].sum()

In [None]:
# Total number of buildings (summed over ADM4 regions) from the threshold sweep.
n_buildings

In [None]:
# Keep the sweep of the 240224/2022-02-24_2023-02-23 run under its own name for the comparison plot.
n_buildings22 = n_buildings_damaged

In [None]:
# Intended to hold the sweep of the 2021-02-24_2022-02-23 run.
# NOTE(review): nothing visible here switches `run_name` to the 2021 run before this cell. On
# Restart & Run All this binds the same list as `n_buildings22`. Re-run the load and sweep
# cells for the 2021 run before executing it.
n_buildings21 = n_buildings_damaged

In [None]:
import matplotlib.pyplot as plt

# Country-wide number of damaged buildings per threshold, annotated with the share of all buildings.
fig, ax = plt.subplots(figsize=(8, 10))
ax.bar(thresholds, n_buildings_damaged, width=0.04)
ax.set_xlabel('Threshold')
ax.set_ylabel('Number of buildings damaged over entire country')
ax.set_title(f'Total number of buildings damaged for different thresholds (total: {n_buildings:.2e}) - 2022')
# Annotate each bar with the percentage of all buildings.
for t, n in zip(thresholds, n_buildings_damaged):
    ax.text(t, n, f'{100*n/n_buildings:.2f}%', ha='center', va='bottom')
plt.show()

## Building predictions

In [None]:
from src.data.settlements import load_gdf_settlements
import geopandas as gpd
from src.postprocessing.utils import read_fp_within_geo
from src.postprocessing.preds_buildings import vectorize_xarray_with_gdf
from src.data.settlements import MSFT_SETTLEMENTS_PATH
from src.constants import PREDS_PATH


def create_buildings_with_preds(settlement_id, row, dates, run_name=None):
    """Attach per-building predictions for several periods and save them as GeoJSON.

    For each entry of ``dates``:
    - the raster ``PREDS_PATH / run_name / date_ / 'ukraine_padded.tif'`` is cropped to the
      settlement geometry;
    - it is aggregated per building footprint;
    - the resulting columns (except ``building_id``) get the suffix ``_{date_}``.

    The result, plus the settlement's ``ADM*`` fields and ``settlement_id``, is written to
    ``PREDS_PATH / run_name / 'buildings_with_preds' / f'{settlement_id}.geojson'``.

    Args:
        settlement_id: Settlement identifier (stem of its footprint file in MSFT_SETTLEMENTS_PATH).
        row: Settlement row providing ``geometry`` and the ``ADM*`` fields.
        dates: Sub-folder names of the run, one per prediction period.
        run_name: Run folder under PREDS_PATH. Defaults to the notebook-level ``run_name``
            global (the previous implicit behaviour).

    Returns:
        None. Settlements without buildings are skipped. Errors are printed rather than raised,
        so one bad settlement does not abort a multiprocessing pool.
    """
    try:
        if run_name is None:
            run_name = globals()['run_name']
        geo = row.geometry
        gdf_buildings = gpd.read_file(MSFT_SETTLEMENTS_PATH / f'{settlement_id}.geojson')
        if gdf_buildings.empty:
            return None
        gdf_buildings_with_preds = gdf_buildings.copy()

        for date_ in dates:
            fp_preds = PREDS_PATH / run_name / date_ / "ukraine_padded.tif"
            preds = read_fp_within_geo(fp_preds, geo)
            preds_vectorized = vectorize_xarray_with_gdf(preds, gdf_buildings, name_id="building_id", verbose=0)
            preds_vectorized = preds_vectorized.rename(
                columns={c: c + '_' + date_ for c in preds_vectorized.columns if c not in ['building_id']}
            )
            # Default inner merge: buildings absent from a period's vectorized table are dropped.
            gdf_buildings_with_preds = gdf_buildings_with_preds.merge(preds_vectorized, on="building_id")

        for k, v in row.items():
            if k.startswith('ADM'):
                gdf_buildings_with_preds[k] = v
        gdf_buildings_with_preds['settlement_id'] = settlement_id

        folder_to_save = PREDS_PATH / run_name / 'buildings_with_preds'
        folder_to_save.mkdir(exist_ok=True, parents=True)
        gdf_buildings_with_preds.to_file(folder_to_save / f"{settlement_id}.geojson", driver='GeoJSON')
    except Exception as e:
        # Best effort: report and move on so a single settlement cannot kill the whole pool.
        print(f'Error processing {settlement_id}: {e}')

In [None]:
import multiprocessing as mp

# (Re)compute the per-building predictions of run 240224 for both periods.
run_name = '240224'
dates = ['2021-02-24_2022-02-23', '2022-02-24_2023-02-23']
gdf_settlements = load_gdf_settlements()

args = [(id_, row, dates) for id_, row in gdf_settlements.iterrows()]
print(len(args))

folder_to_save = PREDS_PATH / run_name / 'buildings_with_preds'
# `bad_settlements` is not defined anywhere in this notebook, so a fresh kernel raised a NameError.
# NOTE(review): it now defaults to the settlements that have no saved output yet. That covers
# settlements never processed and those whose processing failed before writing. Settlements
# without buildings never get a file and are simply re-checked. Confirm this matches the intent.
# A `bad_settlements` already defined in the kernel is still honoured.
if 'bad_settlements' not in globals():
    bad_settlements = {id_ for id_, _, _ in args if not (folder_to_save / f"{id_}.geojson").exists()}
args = [a for a in args if a[0] in bad_settlements]
print(len(args))
with mp.Pool(mp.cpu_count()) as pool:
    pool.starmap(create_buildings_with_preds, args)


In [None]:
import multiprocessing as mp
from src.constants import PREDS_PATH
from src.data.settlements import load_gdf_settlements

import numpy as np
from src.postprocessing.preds_buildings import vectorize_xarray_with_gdf
from src.postprocessing.utils import read_fp_within_geo


def aggregate_precomputed_preds_settlement(settlement_id, gds, run_name, dates=('2021-02-24_2022-02-23', '2022-02-24_2023-02-23')):
    """Summarise the saved per-building predictions of one settlement for two periods.

    Reads ``PREDS_PATH / run_name / 'buildings_with_preds' / f'{settlement_id}.geojson'``, which has
    ``weighted_mean_{date}`` / ``max_{date}`` columns per period. For every threshold t in
    0.50, 0.55, ..., 0.95 it counts buildings whose prediction exceeds ``255 * t``:

    * ``count_{mean,max}_{t:.2f}_{date}`` -- damaged in that period;
    * ``count_{mean,max}_{t:.2f}`` -- damaged in ``dates[1]`` but not in ``dates[0]``
      (building-wise difference).

    Args:
        settlement_id: Settlement identifier.
        gds: Settlement row providing ``geometry`` and the ``ADM*`` fields.
        run_name: Run folder under PREDS_PATH.
        dates: The two periods; ``dates[0]`` is the baseline, ``dates[1]`` the period of interest.

    Returns:
        None if the settlement's file is missing or empty. Otherwise a dict of the counts above
        plus ``n_buildings``, ``settlement_id``, ``geometry`` and the ``ADM*`` fields.
    """
    fp = PREDS_PATH / run_name / "buildings_with_preds" / f"{settlement_id}.geojson"
    if not fp.exists():
        return None
    buildings = gpd.read_file(fp)
    if buildings.empty:
        return None

    thresholds = np.arange(0.5, 1, 0.05)
    d = {}
    for date_ in dates:
        for t in thresholds:
            d[f"count_mean_{t:.2f}_{date_}"] = (buildings[f"weighted_mean_{date_}"] > 255 * t).sum()
            d[f"count_max_{t:.2f}_{date_}"] = (buildings[f"max_{date_}"] > 255 * t).sum()

    before, after = dates[0], dates[1]
    for t in thresholds:
        for agg, prefix in (("mean", "weighted_mean"), ("max", "max")):
            # "Not damaged" in the baseline is the complement of the damage test (> 255*t),
            # i.e. <= 255*t. A strict < would drop buildings sitting exactly on the threshold.
            new_damage = (buildings[f"{prefix}_{after}"] > 255 * t) & (buildings[f"{prefix}_{before}"] <= 255 * t)
            d[f"count_{agg}_{t:.2f}"] = int(new_damage.sum())
    d["n_buildings"] = buildings.shape[0]
    d["settlement_id"] = settlement_id
    d["geometry"] = gds.geometry

    # Keep track of administrative names
    for k, v in gds.items():
        if k.startswith("ADM"):
            d[k] = v
    return d


def aggregate_preds_all_settlements(run_name, dates=('2021-02-24_2022-02-23', '2022-02-24_2023-02-23'), oblasts=None):
    """Aggregate the precomputed per-building predictions of a run for every settlement, in parallel.

    Args:
        run_name: Run folder under PREDS_PATH containing ``buildings_with_preds``.
        dates: The two periods passed to ``aggregate_precomputed_preds_settlement``
            (baseline first).
        oblasts: Optional list of ``ADM1_EN`` names. When given, only settlements in these
            oblasts are processed.

    Returns:
        GeoDataFrame in the settlements' CRS with one row per settlement that has saved
        predictions. The frame is empty, with only a geometry column, when no settlement qualifies.
    """
    gdf_settlements = load_gdf_settlements()
    if oblasts is not None:
        gdf_settlements = gdf_settlements[gdf_settlements['ADM1_EN'].isin(oblasts)]

    print(f'Processing {len(gdf_settlements)} settlements...')

    args = [(id_, row, run_name, dates) for id_, row in gdf_settlements.iterrows()]
    results = []
    if args:  # no point starting a worker pool with nothing to do
        with mp.Pool(mp.cpu_count()) as pool:
            results = pool.starmap(aggregate_precomputed_preds_settlement, args)
    results = [r for r in results if r is not None]  # drop settlements without saved predictions

    if not results:
        # GeoDataFrame([], crs=...) would have no geometry column, and recent geopandas versions
        # refuse to attach a CRS to such a frame. Return an explicitly empty GeoDataFrame instead.
        return gpd.GeoDataFrame(geometry=[], crs=gdf_settlements.crs)
    return gpd.GeoDataFrame(results, crs=gdf_settlements.crs)

# Aggregate the precomputed per-building predictions of run 240224 per settlement, writing one
# GeoJSON per oblast. Existing outputs are skipped, so the cell can be re-run to resume.
run_name = "240224"
assert (PREDS_PATH / run_name).exists()

folder = PREDS_PATH / run_name / "oblasts_with_preds_agg"
folder.mkdir(exist_ok=True, parents=True)

adm1 = load_ukraine_admin_polygons(adm_level=1)
oblasts = adm1.ADM1_EN.unique()  # computed once instead of on every iteration
for i, o in enumerate(oblasts):

    fp = folder / f"preds_agg_{o}.geojson"
    if fp.exists():
        print(f"Skipping {o}...")
        continue
    print(f"Processing {o} ({i+1}/{len(oblasts)})...")
    gdf = aggregate_preds_all_settlements(run_name=run_name, oblasts=[o])
    gdf.to_file(fp, driver="GeoJSON")
    print(f"Saved {gdf.shape[0]} settlements for {o}")

In [None]:
# Inspect the saved per-building predictions of a single settlement.
settlement_id = 1
# Per-building files are written to 'buildings_with_preds' by create_buildings_with_preds.
# No cell creates an 'oblasts_with_preds' folder.
folder = PREDS_PATH / run_name / "buildings_with_preds"
fp = folder / f"{settlement_id}.geojson"
gdf_buildings_with_preds = gpd.read_file(fp)
gdf_buildings_with_preds

In [None]:
import pandas as pd

# Collect the per-oblast settlement aggregates of run 240224 into one GeoDataFrame.
run_name = "240224"
folder = PREDS_PATH / run_name / "oblasts_with_preds_agg"
adm1 = load_ukraine_admin_polygons(adm_level=1)
# ignore_index: every per-oblast file is read with its own 0-based index, which would
# otherwise leave duplicate index labels after concatenation.
gdf = pd.concat(
    [gpd.read_file(folder / f"preds_agg_{o}.geojson") for o in adm1.ADM1_EN.unique()],
    ignore_index=True,
)
gdf.shape

In [None]:
from tqdm import tqdm

# Threshold sweep on run 240224. `gdf` now holds that run's settlement aggregates, whose
# `count_mean_{t}` columns are the building-wise 2022-minus-2021 counts.
# Overwrites `n_buildings_damaged`, `thresholds` and `n_buildings` from the earlier sweep.
n_buildings_damaged = []
thresholds = np.arange(0.5, 1, 0.05)
for t in tqdm(thresholds):
    gdf_preds = aggregate_preds_region(4, t, "mean")
    n_buildings_damaged.append(gdf_preds[f"count_mean_{t:.2f}"].sum())
# The building total does not depend on the threshold; the last iteration's frame is used.
n_buildings = gdf_preds["n_buildings"].sum()

In [None]:
# Total number of buildings with saved predictions for run 240224 (summed over ADM4 regions).
n_buildings

In [None]:
# NOTE(review): hardcoded total that overrides the value computed by the sweep above. It is
# presumably the building count of an earlier run, used as the common denominator of the
# comparison plot. Confirm its provenance or compute it instead.
n_buildings = 12106657.0

In [None]:
import matplotlib.pyplot as plt

# Compare country-wide damaged-building counts across thresholds: 2022, 2021 and the
# building-wise 2022-minus-2021 difference.
fig, ax = plt.subplots(figsize=(8, 6))
width = 0.01
x_adjustments = [-width, 0, width]
series = [n_buildings22, n_buildings21, n_buildings_damaged]
# Labels follow the order of `series` (index 0 holds the 2022 counts, index 1 the 2021 counts).
labels = ["2022", "2021", "2022 - 2021 (building-wise)"]
for i, (dmgs, label) in enumerate(zip(series, labels)):
    ax.bar(thresholds + x_adjustments[i], dmgs, width=width, label=label)

    # Annotate thresholds 0.50 and 0.65 with the share of all buildings (not for 2021).
    if i != 1:
        for t, n in zip(thresholds, dmgs):
            # round(): np.arange values carry float noise (e.g. 0.6500000000000001).
            if round(t, 2) in (0.5, 0.65):
                ax.text(t + x_adjustments[i], n, f"{100*n/n_buildings:.2f}%", ha="center", va="bottom")
ax.set_xlabel("Threshold")
ax.set_ylabel("Number of buildings damaged over entire country")
ax.set_title(f"Total number of buildings damaged (total: {n_buildings:.2e}) - (2022 - 2021)")

ax.legend()
plt.show()

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output


def aggregate_preds_region(adm_level, threshold=0.5, agg="mean", data=None):
    """Sum settlement-level damage counts per administrative region.

    Args:
        adm_level: Administrative level (1-4) to aggregate to.
        threshold: Damage threshold; selects the ``count_{agg}_{threshold:.2f}`` column.
        agg: ``"mean"`` or ``"max"``, the per-building aggregation the counts are based on.
        data: Settlement-level (Geo)DataFrame with ``count_*``, ``n_buildings`` and ``ADM*_EN``
            columns. Defaults to the notebook-level ``gdf``, for backward compatibility.

    Returns:
        GeoDataFrame with one row per region of ``adm_level``. Regions without settlements are
        included and filled with 0. Columns are the summed count and ``n_buildings`` plus
        ``count_{agg}_{threshold:.2f}_relative`` = count / n_buildings.
    """
    if data is None:
        data = gdf  # notebook-level global; pass `data` explicitly to avoid hidden state
    ukr_regions = load_ukraine_admin_polygons(adm_level)

    count_col = f"count_{agg}_{threshold:.2f}"
    # Keep every non-count column plus the single count column of interest.
    c_to_keep = [c for c in data.columns if not c.startswith("count")] + [count_col]
    gdf_ = data[c_to_keep]

    adm_cols = [f"ADM{i}_EN" for i in range(1, adm_level + 1)]
    sum_cols = [c for c in gdf_.columns if c.startswith(("count", "n_buildings"))]
    gdf_agg_regions = (
        gdf_.groupby(adm_cols)[sum_cols]
        .sum()
        .reset_index()
        .merge(ukr_regions, on=adm_cols, how="right")  # keep regions without settlements
    )
    gdf_agg_regions = gpd.GeoDataFrame(gdf_agg_regions, crs=ukr_regions.crs)

    gdf_agg_regions[f"{count_col}_relative"] = gdf_agg_regions[count_col] / gdf_agg_regions["n_buildings"]
    # Regions without settlements get NaN from the right merge (and 0/0 ratios); report them as 0.
    return gdf_agg_regions.fillna(0)


def explore_preds(adm_level, threshold=0.5, agg="mean", relative=True, show_zero=False):
    """Return an interactive map (``GeoDataFrame.explore``) of damage counts at ``adm_level``.

    Regions are coloured by the relative count (count / n_buildings) when ``relative`` is True,
    otherwise by the absolute count. Regions with a zero count are dropped unless ``show_zero``.
    """
    col = f"count_{agg}_{threshold:.2f}"
    col_ = col + "_relative"

    regions = aggregate_preds_region(adm_level, threshold=threshold, agg=agg)
    if not show_zero:
        regions = regions[regions[col] > 0]

    value_col = col_ if relative else col
    shown = regions[[col, col_, f"ADM{adm_level}_EN", "geometry"]]
    # Alternative basemap: tiles='CartoDB.DarkMatterNoLabels'
    return shown.explore(value_col, cmap="YlOrRd", vmin=0, tiles="Esri.WorldGrayCanvas")


# Interactive wrapper: choose the parameters, then press the button to (re)draw the map.
def interact_explore_preds():
    """Display controls for explore_preds and an "Update Map" button that renders the map below them."""
    controls = {
        'adm_level': widgets.Dropdown(options=[1, 2, 3, 4], value=1, description='Admin Level:'),
        'threshold': widgets.Dropdown(options=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95], value=0.65, description='Threshold:'),
        'agg': widgets.Dropdown(options=['mean', 'max'], value='mean', description='Aggregation:'),
        'relative': widgets.Checkbox(value=True, description='Relative'),
        'show_zero': widgets.Checkbox(value=False, description='Show Zero Values'),
    }
    output = widgets.Output()

    def on_click(_button):
        # Render into the Output widget so that each click replaces the previous map.
        with output:
            clear_output(wait=True)
            display(explore_preds(**{name: w.value for name, w in controls.items()}))

    button = widgets.Button(description="Update Map")
    button.on_click(on_click)

    display(widgets.VBox([*controls.values(), button, output]))

# Call the interactive wrapper function
interact_explore_preds()

In [None]:
import pandas as pd


adm1 = load_ukraine_admin_polygons(adm_level=1)


def _read_preds_agg(run):
    """Concatenate the per-oblast settlement aggregates saved under ``PREDS_PATH / run``."""
    folder_agg = PREDS_PATH / run / "oblasts_with_preds_agg"
    # ignore_index: each per-oblast file is read with its own 0-based index.
    return pd.concat(
        [gpd.read_file(folder_agg / f"preds_agg_{o}.geojson") for o in adm1.ADM1_EN.unique()],
        ignore_index=True,
    )


run_name = "240224/2022-02-24_2023-02-23"
gdf22 = _read_preds_agg(run_name)

run_name = "240224/2021-02-24_2022-02-23"
gdf21 = _read_preds_agg(run_name)

# Building-wise 2022-minus-2021 aggregates.
run_name = "240224"
gdf22_21 = _read_preds_agg(run_name)

In [None]:
# Sanity check: (rows, columns) of the settlement aggregates loaded for each run.
gdf22.shape, gdf21.shape, gdf22_21.shape

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output


def aggregate_preds_region(adm_level, threshold=0.5, agg="mean", year=2022):
    """Sum settlement-level damage counts of one run per administrative region.

    Args:
        adm_level: Administrative level (1-4) to aggregate to.
        threshold: Damage threshold; selects the ``count_{agg}_{threshold:.2f}`` column.
        agg: ``"mean"`` or ``"max"``, the per-building aggregation the counts are based on.
        year: ``2022`` (uses ``gdf22``), ``2021`` (``gdf21``) or ``"2022 - 2021"`` (``gdf22_21``,
            the building-wise difference).

    Returns:
        GeoDataFrame with one row per region of ``adm_level``. Regions without settlements are
        included and filled with 0. Columns are the summed count and ``n_buildings`` plus
        ``count_{agg}_{threshold:.2f}_relative`` = count / n_buildings.

    Raises:
        ValueError: if ``year`` is not one of the values above.
    """
    # Pick the notebook-level frame for the requested year; unknown values are rejected rather
    # than silently falling back to the difference frame.
    if year == 2022:
        gdf = gdf22
    elif year == 2021:
        gdf = gdf21
    elif year == "2022 - 2021":
        gdf = gdf22_21
    else:
        raise ValueError(f"Unknown year {year!r}; expected 2021, 2022 or '2022 - 2021'")

    ukr_regions = load_ukraine_admin_polygons(adm_level)

    count_col = f"count_{agg}_{threshold:.2f}"
    # Keep every non-count column plus the single count column of interest.
    c_to_keep = [c for c in gdf.columns if not c.startswith("count")] + [count_col]
    gdf_ = gdf[c_to_keep]

    adm_cols = [f"ADM{i}_EN" for i in range(1, adm_level + 1)]
    sum_cols = [c for c in gdf_.columns if c.startswith(("count", "n_buildings"))]
    gdf_agg_regions = (
        gdf_.groupby(adm_cols)[sum_cols]
        .sum()
        .reset_index()
        .merge(ukr_regions, on=adm_cols, how="right")  # keep regions without settlements
    )
    gdf_agg_regions = gpd.GeoDataFrame(gdf_agg_regions, crs=ukr_regions.crs)

    gdf_agg_regions[f"{count_col}_relative"] = gdf_agg_regions[count_col] / gdf_agg_regions["n_buildings"]
    # Regions without settlements get NaN from the right merge (and 0/0 ratios); report them as 0.
    return gdf_agg_regions.fillna(0)


def explore_preds(adm_level, threshold=0.5, agg="mean", year=2022, relative=True, show_zero=False):
    """Return an interactive map (``GeoDataFrame.explore``) of one run's damage counts at ``adm_level``.

    ``year`` selects the run (see aggregate_preds_region). Regions are coloured by the relative
    count when ``relative`` is True, otherwise by the absolute count. Zero-count regions are
    dropped unless ``show_zero``.
    """
    col = f"count_{agg}_{threshold:.2f}"
    col_ = col + "_relative"

    regions = aggregate_preds_region(adm_level, threshold=threshold, agg=agg, year=year)
    if not show_zero:
        regions = regions[regions[col] > 0]

    value_col = col_ if relative else col
    shown = regions[[col, col_, f"ADM{adm_level}_EN", "geometry"]]
    # Alternative basemap: tiles='CartoDB.DarkMatterNoLabels'
    return shown.explore(value_col, cmap="YlOrRd", vmin=0, tiles="Esri.WorldGrayCanvas")


# Interactive wrapper: choose the parameters (including the run/year) and press the button to redraw.
def interact_explore_preds():
    """Show a row of controls for explore_preds plus an "Update Map" button; the map renders below."""
    controls = {
        "adm_level": widgets.Dropdown(options=[1, 2, 3, 4], value=1, description="Admin Level:"),
        "threshold": widgets.Dropdown(
            options=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95], value=0.65, description="Threshold:"
        ),
        "agg": widgets.Dropdown(options=["mean", "max"], value="mean", description="Aggregation:"),
        "year": widgets.Dropdown(options=[2021, 2022, '2022 - 2021'], value=2022, description="Year:"),
        "relative": widgets.Checkbox(value=True, description="Relative"),
        "show_zero": widgets.Checkbox(value=False, description="Show Zero Values"),
    }
    output = widgets.Output()

    def redraw(_button):
        # Render into the Output widget so that each click replaces the previous map.
        with output:
            clear_output(wait=True)
            display(explore_preds(**{key: w.value for key, w in controls.items()}))

    button = widgets.Button(description="Update Map")
    button.on_click(redraw)

    display(widgets.VBox([widgets.HBox([*controls.values(), button]), output]))


# Call the interactive wrapper function
interact_explore_preds()