In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.io import shapereader

from matplotlib.colors import LogNorm

import regionmask
import ipywidgets as widgets
from IPython.display import display

# -----------------------
# User settings
# -----------------------
FILE = "grpi_hemco.nc"      # dataset opened with xarray below
VAR = "emi_ch4"             # variable pulled out of that dataset
COUNTRY_LIST = ["India", "Pakistan", "Cambodia", "Myanmar", "Brazil"]

# Visualization settings
COARSEN = 1                 # lat/lon block size for averaging; values <= 1 leave the grid untouched
PAD_DEG_DEFAULT = 1.0       # degrees added on each side of a country's bounding box
PAD_DEG_SMALL = 2.0         # padding used in place of the default for SMALL_COUNTRIES
SMALL_COUNTRIES = {"Cambodia"}

# -----------------------
# Load data
# -----------------------
ds = xr.open_dataset(FILE)
da = ds[VAR]  # expected dims: time, lat, lon

# When longitudes run past 180 (e.g. a 0..360 grid), wrap them into -180..180
# and re-sort so that masking and zooming see a monotonic axis (matters for Brazil).
if float(da["lon"].max()) > 180:
    da = da.assign_coords(lon=((da["lon"] + 180) % 360) - 180)
    da = da.sortby("lon")

# Pull the values into memory up front so widget interaction does not hit the disk.
da = da.load()

# -----------------------
# Countries mask (regionmask)
# -----------------------
countries = regionmask.defined_regions.natural_earth_v5_0_0.countries_110

# Map each country name to its region *number*: that is the value regionmask
# writes into the mask. Using a positional index (enumerate) only works while
# the numbers happen to be 0..N-1 in order.
name_to_idx = dict(zip(countries.names, countries.numbers))

# Keep only countries that exist in this Natural Earth dataset, and say which
# ones were dropped so a misspelt name does not vanish silently.
_missing_countries = [c for c in COUNTRY_LIST if c not in name_to_idx]
if _missing_countries:
    print(
        "Skipping countries not found in regionmask Natural Earth names: "
        + ", ".join(_missing_countries)
    )
COUNTRY_LIST = [c for c in COUNTRY_LIST if c in name_to_idx]
if not COUNTRY_LIST:
    raise ValueError("None of the requested countries were found in regionmask Natural Earth names.")

# Mask holds the region number inside each polygon, NaN outside.
# The first time step is passed only to provide the lat/lon grid.
mask = countries.mask(da.isel(time=0))  # version-safe (no lat_name/lon_name)

def coarsen_if_needed(x):
    """Block-average ``x`` over lat/lon when the module-level ``COARSEN`` exceeds 1.

    ``x`` must have both ``lat`` and ``lon`` dimensions. Each block is
    ``COARSEN`` x ``COARSEN`` cells; cells that do not fill a complete block at
    the edge are trimmed. When ``COARSEN`` is falsy or not greater than 1, ``x``
    is returned unchanged.
    """
    if not COARSEN or not COARSEN > 1:
        return x

    blocks = x.coarsen(lat=COARSEN, lon=COARSEN, boundary="trim")
    return blocks.mean()

# Coarsened copy of the country mask.
# NOTE(review): not referenced anywhere else in this cell, and block-averaging
# region numbers does not yield valid region numbers - confirm it is still needed.
mask_plot = coarsen_if_needed(mask)

# -----------------------
# Country polygons for zoom extents (Cartopy Natural Earth shapefile)
# -----------------------
shp = shapereader.natural_earth(
    resolution="110m",
    category="cultural",
    name="admin_0_countries",
)

# Read every record once; polygon_extent_cartopy scans this list on each lookup.
_shp_reader = shapereader.Reader(shp)
records = list(_shp_reader.records())

def polygon_extent_cartopy(country_name, pad_deg=1.0):
    """
    Look up a country in the module-level ``records`` and return its padded bounds.

    Matching is case-insensitive and checks the NAME_LONG, NAME, ADMIN and
    SOVEREIGNT attributes of each record. An exact match on any of them is
    preferred; if no record matches exactly, records whose attribute merely
    contains ``country_name`` are accepted. The first matching record wins.

    Returns [lon_min, lon_max, lat_min, lat_max], each side widened by
    ``pad_deg`` degrees. Raises ValueError when no record matches.
    """
    wanted = country_name.lower()
    fields = ("NAME_LONG", "NAME", "ADMIN", "SOVEREIGNT")

    def lowered_names(rec):
        # Yield the lower-cased value of every name field that holds a string.
        attrs = rec.attributes
        for key in fields:
            value = attrs.get(key)
            if isinstance(value, str):
                yield value.lower()

    # Pass 1: exact (case-insensitive) match on any name field.
    hits = [rec for rec in records if any(name == wanted for name in lowered_names(rec))]

    # Pass 2: substring match, only tried when pass 1 found nothing.
    if not hits:
        hits = [rec for rec in records if any(wanted in name for name in lowered_names(rec))]

    if not hits:
        raise ValueError(
            f"Could not find '{country_name}' in Natural Earth shapefile attributes. "
            "Try checking record attributes via: records[0].attributes"
        )

    # geometry.bounds is (minx, miny, maxx, maxy); reorder to the extent layout.
    west, south, east, north = hits[0].geometry.bounds
    return [west - pad_deg, east + pad_deg, south - pad_deg, north + pad_deg]

# Resolve every selected country's zoom extent once, so rendering is a dict lookup.
extent_cache = {}
for c in COUNTRY_LIST:
    pad = PAD_DEG_DEFAULT
    if c in SMALL_COUNTRIES:
        # Small countries get the wider padding instead of the default.
        pad = PAD_DEG_SMALL
    extent_cache[c] = polygon_extent_cartopy(c, pad_deg=pad)

# -----------------------
# Shared log-scale limits (optional)
# Compute once across selected countries and all months for stable comparison.
# -----------------------
def compute_shared_limits(da, mask, country_list):
    """Return one (vmin, vmax) pair covering every country in ``country_list``.

    For each country, ``da`` is restricted to cells where ``mask`` equals the
    country's entry in ``name_to_idx`` and to strictly positive values. The
    2nd and 98th percentiles of what remains are collected. The result is the
    smallest lower percentile and the largest upper percentile across
    countries. Countries with no positive data are skipped; if none contribute,
    the fallback (1e-12, 1.0) is returned.
    """
    lows = []
    highs = []

    for cname in country_list:
        inside = mask == name_to_idx[cname]
        masked = da.where(inside)
        positive = masked.where(masked > 0)

        if not bool(positive.notnull().any()):
            continue  # nothing usable for this country

        lows.append(float(positive.quantile(0.02)))
        highs.append(float(positive.quantile(0.98)))

    if lows:
        return min(lows), max(highs)
    return 1e-12, 1.0

shared_vmin, shared_vmax = compute_shared_limits(da, mask, COUNTRY_LIST)

# -----------------------
# Widgets
# -----------------------
# One label per time step: the first 10 characters of the value's string form
# (presumably YYYY-MM-DD for datetime64 time axes).
time_labels = [str(stamp)[:10] for stamp in da["time"].values]

# Slider and play button walk the same index range over the time axis.
_step_range = dict(value=0, min=0, max=len(time_labels) - 1, step=1)

country_w = widgets.Dropdown(
    description="Country:",
    options=COUNTRY_LIST,
    value=COUNTRY_LIST[0],
    layout=widgets.Layout(width="260px"),
)

month_w = widgets.IntSlider(
    description="Month:",
    continuous_update=False,  # fire only when the handle is released
    layout=widgets.Layout(width="520px"),
    **_step_range,
)

play_w = widgets.Play(interval=600, **_step_range)

# Keep the play button and the slider in sync on the front-end side.
widgets.jslink((play_w, "value"), (month_w, "value"))

use_shared_scale_w = widgets.Checkbox(
    description="Shared color scale",
    value=True,
    indent=False,
)

date_label = widgets.HTML(value=f"<b>Date:</b> {time_labels[0]}")
out = widgets.Output()

# -----------------------
# Plot function
# -----------------------
def render(country, t_index, shared_scale):
    """Draw the log-scaled emission map for one country and time step into ``out``.

    Parameters
    ----------
    country : str
        Key into ``name_to_idx`` and ``extent_cache``.
    t_index : int
        Position along the ``time`` axis of ``da`` (and index into ``time_labels``).
    shared_scale : bool
        If true, use the precomputed ``shared_vmin``/``shared_vmax``; otherwise
        take the 2nd/98th percentiles of the positive values being plotted.

    Side effects: updates ``date_label`` and replaces the contents of ``out``.
    The figure is always closed, even when plotting raises, so failed redraws
    do not accumulate open figures.
    """
    date_str = time_labels[t_index]
    date_label.value = f"<b>Date:</b> {date_str}"

    idx = name_to_idx[country]

    # Select the time step, then keep only cells inside the country polygon.
    da_country = da.isel(time=t_index).where(mask == idx)

    da_plot = coarsen_if_needed(da_country)
    da_plot_pos = da_plot.where(da_plot > 0)  # LogNorm requires positive values

    if shared_scale:
        vmin, vmax = shared_vmin, shared_vmax
    elif bool(da_plot_pos.notnull().any()):
        vmin = float(da_plot_pos.quantile(0.02))
        vmax = float(da_plot_pos.quantile(0.98))
    else:
        # No positive data for this country/time: fall back to a fixed range.
        vmin, vmax = 1e-12, 1.0

    # Safety for LogNorm: strictly positive vmin and at least one decade of range.
    vmin = max(float(vmin), 1e-30)
    vmax = max(float(vmax), vmin * 10)

    units = da.attrs.get("units", "")
    label = "CH₄ emissions" + (f" [{units}]" if units else "")

    with out:
        out.clear_output(wait=True)

        fig = plt.figure(figsize=(7.8, 7.8))
        try:
            ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

            # Zoom to the cached, padded country bounds.
            ax.set_extent(extent_cache[country], crs=ccrs.PlateCarree())

            # Basemap
            ax.add_feature(cfeature.LAND, facecolor="0.95", zorder=0)
            ax.add_feature(cfeature.COASTLINE, linewidth=0.8)
            ax.add_feature(cfeature.BORDERS, linewidth=0.6, alpha=0.8)

            # Data layer
            mesh = ax.pcolormesh(
                da_plot["lon"], da_plot["lat"], da_plot_pos,
                transform=ccrs.PlateCarree(),
                shading="auto",
                cmap="viridis",
                norm=LogNorm(vmin=vmin, vmax=vmax),
            )

            cbar = fig.colorbar(mesh, ax=ax, shrink=0.82, pad=0.02)
            cbar.set_label(label)

            ax.set_title(f"{country} — {date_str} (masked to country, log scale)")
            fig.tight_layout()
            plt.show()
        finally:
            # Close on every path; otherwise each failed redraw leaks a figure.
            plt.close(fig)

# -----------------------
# UI wiring (stable)
# -----------------------
ui = widgets.VBox([
    widgets.HBox([country_w, use_shared_scale_w]),
    widgets.HBox([month_w, play_w]),
    date_label,
    out
])

display(ui)


def _on_control_change(_change):
    """Redraw the map with the current value of every control."""
    render(country_w.value, month_w.value, use_shared_scale_w.value)


# Observe the controls directly instead of using widgets.interactive_output:
# that helper calls the function once on creation (duplicating the initial draw)
# and routes uncaught errors into an Output widget that was never displayed.
for _control in (country_w, month_w, use_shared_scale_w):
    _control.observe(_on_control_change, names="value")

# Initial draw (exactly once)
render(country_w.value, month_w.value, use_shared_scale_w.value)




VBox(children=(HBox(children=(Dropdown(description='Country:', layout=Layout(width='260px'), options=('India',…