In [None]:
import numpy as np
import xarray as xr


def calculate_station_weights(
    data_xr: xr.DataArray, reference_angle: float = 0.75
) -> xr.DataArray:
    """
    Calculate station density weights based on equation 22 in Rodwell (2010).

    At every time step, each station reporting a non-NaN value receives the
    weight ``1 / sum_j exp(-(d_ij / reference_angle) ** 2)``. The sum runs over
    all stations reporting at that time step (including the station itself) and
    ``d_ij`` is the great-circle angular distance in degrees between stations
    ``i`` and ``j``. Stations whose value is NaN at a time step get weight 0.

    The computation is positional (it does not look values up by label), so
    duplicated time or station labels are handled correctly.

    Args:
        data_xr: xarray DataArray with dimensions (time, station), in either
                order, and coordinates 'latitude' and 'longitude' (degrees)
                for the station dimension
        reference_angle: Reference angle in degrees (default 0.75° as per paper)

    Returns:
        xarray DataArray with the same shape, dimension order, coordinates,
        name and attrs as the input, containing the weights. The dtype follows
        the input when it is floating point and is float64 otherwise. If a
        station with NaN coordinates reports at a time step, every non-zero
        weight at that time step is NaN.
    """
    # Cap on the number of elements in one block of the pairwise kernel matrix,
    # so memory stays bounded for large station counts.
    max_block_elements = 4_000_000

    def gaussian_kernel(lat_a, lon_a, lat_b, lon_b):
        """
        Return exp(-(d / reference_angle) ** 2) for every pair (a, b), where d
        is the angular distance in degrees. Inputs are 1-D arrays in radians;
        the result has shape (len(lat_a), len(lat_b)).
        """
        dlat = lat_b[np.newaxis, :] - lat_a[:, np.newaxis]
        dlon = lon_b[np.newaxis, :] - lon_a[:, np.newaxis]

        # Haversine formula. sin^2(dlon / 2) is even and 2*pi-periodic in dlon,
        # so no longitude normalisation or wrap-around handling is required.
        a = (
            np.sin(dlat / 2) ** 2
            + np.cos(lat_a)[:, np.newaxis]
            * np.cos(lat_b)[np.newaxis, :]
            * np.sin(dlon / 2) ** 2
        )
        # Rounding can push `a` marginally above 1 for (near-)antipodal points,
        # which would make arcsin return NaN; NaN inputs pass through clip.
        a = np.clip(a, 0.0, 1.0)
        distance = np.degrees(2 * np.arcsin(np.sqrt(a)))
        return np.exp(-((distance / reference_angle) ** 2))

    # Work on a fixed (time, station) layout; the original order is restored
    # on return.
    ordered = data_xr.transpose("time", "station")
    valid = ordered.notnull().values
    valid_f = valid.astype(np.float64)
    n_time, n_station = valid.shape

    lats = np.radians(np.asarray(data_xr.latitude.values, dtype=np.float64))
    lons = np.radians(np.asarray(data_xr.longitude.values, dtype=np.float64))

    # density[t, i] = sum over stations j valid at t of kernel[i, j].
    # The kernel does not depend on time, so it is built once (in row blocks)
    # and the per-time sums are obtained with a matrix product.
    density = np.zeros((n_time, n_station), dtype=np.float64)
    rows_per_block = max(1, max_block_elements // max(n_station, 1))
    for start in range(0, n_station, rows_per_block):
        rows = slice(start, start + rows_per_block)
        kernel = gaussian_kernel(lats[rows], lons[rows], lats, lons)

        # NaN kernel entries (NaN coordinates) must only affect time steps at
        # which the offending station actually reports, so they are summed
        # separately instead of being multiplied by a zero mask (0 * NaN = NaN).
        undefined = np.isnan(kernel)
        block_density = valid_f @ np.where(undefined, 0.0, kernel).T
        if undefined.any():
            tainted = (valid_f @ undefined.T.astype(np.float64)) > 0
            block_density = np.where(tainted, np.nan, block_density)
        density[:, rows] = block_density

    # A valid station always contributes exp(0) = 1 to its own density, so the
    # division is only evaluated where it is well defined; other entries stay 0.
    weight_values = np.zeros_like(density)
    np.divide(1.0, density, out=weight_values, where=valid)

    # Keep floating input dtypes; never store fractional weights in an integer
    # array, where they would be truncated to 0.
    if np.issubdtype(data_xr.dtype, np.floating):
        weight_values = weight_values.astype(data_xr.dtype)

    weights = ordered.copy(data=weight_values)
    return weights.transpose(*data_xr.dims)


def normalised_weights(
    data_xr: xr.DataArray, reference_angle: float = 0.75
) -> xr.DataArray:
    """
    Compute station density weights rescaled to average 1 over the stations.

    The raw weights come from calculate_station_weights. Entries that are not
    strictly positive are replaced by NaN, and the remaining weights are
    multiplied by (number of non-NaN weights) / (sum of weights), both taken
    along the "station" dimension.

    Note! This isn't strictly necessary with scores which will now automatically
    normalise the weights!

    Args:
        data_xr: xarray DataArray with dimensions (time, station) and coordinates
                'latitude' and 'longitude' for the station dimension
        reference_angle: Reference angle in degrees (default 0.75° as per paper)

    Returns:
        xarray DataArray with same shape as input containing the normalised
        weights (NaN where the raw weight is not strictly positive)
    """
    raw = calculate_station_weights(data_xr, reference_angle)

    # Mask out non-positive weights so they are ignored by count() and sum()
    positive = raw.where(raw > 0)

    # Scale factor along "station": count of usable weights over their total
    scale = positive.count("station") / positive.sum("station")

    return positive * scale

In [None]:
# Directory the input NetCDF files are read from and the computed weights are
# written to. The trailing slash matters: file names are appended directly.
RESULTS_PATH = "../data/station_weights/"

In [None]:
# Load the two input DataArrays from RESULTS_PATH. Both are later indexed with
# a "component" label ("total") and by position along "lead_time".
# NOTE(review): the "099" / "0999" suffixes presumably denote 0.99 and 0.999
# thresholds or quantiles — confirm against the code that produced these files.
da_099 = xr.open_dataarray(f"{RESULTS_PATH}hrrr27_099.nc")
da_0999 = xr.open_dataarray(f"{RESULTS_PATH}hrrr27_0999.nc")

In [None]:
# Compute normalised station weights for lead-time positions 0..7 of both
# datasets, then stack the slices along "lead_time" and save them as NetCDF.
weight_list_099 = []
weight_list_0999 = []
for i_lead_time in np.arange(8):
    # "total" component at this lead time -> normalised weights, with a
    # length-1 "lead_time" dimension added so the slices can be concatenated.
    da1 = da_099.sel(component="total").isel(lead_time=i_lead_time)
    weights1 = normalised_weights(da1).expand_dims("lead_time")
    weight_list_099.append(weights1)

    da2 = da_0999.sel(component="total").isel(lead_time=i_lead_time)
    weights2 = normalised_weights(da2).expand_dims("lead_time")
    weight_list_0999.append(weights2)

    # Progress indicator: one line per finished lead time
    print(i_lead_time)

weights_099 = xr.concat(weight_list_099, dim="lead_time")
weights_099.to_netcdf(f"{RESULTS_PATH}weights_099.nc")

weights_0999 = xr.concat(weight_list_0999, dim="lead_time")
weights_0999.to_netcdf(f"{RESULTS_PATH}weights_0999.nc")