In [1]:
import re
import json
import statistics

from pathlib import Path

import scipy.stats

import numpy as np
import matplotlib.pyplot as plt

In [2]:
def load_data(target, results_dir="data/results/", max_epsilon=0.07):
    """Load the per-epsilon result files of one evaluation target.

    Files are expected to be named ``{target}_{epsilon}eps_{steps}steps.json``
    where ``epsilon`` is a decimal number (e.g. ``0.05``) and ``steps`` an
    integer. Files that merely share the ``target`` prefix but do not follow
    this naming scheme are skipped instead of aborting the whole load.

    Parameters
    ----------
    target : str
        File name prefix identifying the evaluation target.
    results_dir : str or Path, optional
        Directory that is scanned for result files.
    max_epsilon : float, optional
        Files whose epsilon is larger than this value are ignored.

    Returns
    -------
    dict
        Parsed JSON contents keyed by epsilon (as ``float``). If several
        files share the same epsilon, the one visited last wins.
    """
    # The decimal point is escaped (a bare "." would match any character) and
    # the target is escaped so regex metacharacters in it are taken literally.
    pattern = re.compile(rf"{re.escape(target)}_(\d+\.\d+)eps_(\d+)steps")
    results = {}
    for path in Path(results_dir).glob(f"{target}*.json"):
        match = pattern.fullmatch(path.stem)
        if match is None:
            # Prefix matched the glob, but this is not a result file.
            continue
        epsilon = float(match.group(1))
        if epsilon > max_epsilon:
            continue
        with open(path, 'r') as f:
            results[epsilon] = json.load(f)
    return results

In [3]:
def plot_results(target):
    """Plot the mean deviation per method against the perturbation size.

    For every epsilon returned by ``load_data(target)`` and every method other
    than ``"unperturbed"``, the deviation of each target is the negated median
    (over its repeated measurements) of ``method - unperturbed``. The curve
    shows the mean of these deviations over all targets; the shaded band is
    ``1.644850 * stdev / sqrt(n)`` around it (1.644850 is the 95th percentile
    of the standard normal distribution, i.e. a two-sided 90% interval).

    The x axis shows ``sqrt(1 + epsilon**2) - 1`` on a logarithmic scale,
    formatted as a percentage.

    Parameters
    ----------
    target : str
        Name of the evaluation target, forwarded to ``load_data``.

    Raises
    ------
    ValueError
        If no results are available for ``target``.
    """
    results = load_data(target)
    aggregated_results = {}
    for epsilon in sorted(results.keys()):
        unperturbed_results = np.asarray(results[epsilon]["unperturbed"])
        for method, result in results[epsilon].items():
            if method == "unperturbed":
                continue
            # result is list with one list per target
            # each target contains multiple measurements for the single target
            medians = -np.median(np.asarray(result) - unperturbed_results, axis=1)
            mean = statistics.mean(medians)
            if len(medians) > 1:
                ci = 1.644850 * statistics.stdev(medians) / len(medians)**0.5
            else:
                # A spread cannot be estimated from a single target.
                ci = np.nan
            aggregated_results.setdefault(method, []).append((float(epsilon), mean, ci))

    if not aggregated_results:
        raise ValueError(f"no results found for target {target!r}")

    # Ticks are collected over all methods rather than taken from whichever
    # method happened to be plotted last.
    tick_positions = set()
    with plt.xkcd():
        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
        for method, values in aggregated_results.items():
            values = np.asarray(values)
            epsilons = np.sqrt(1 + np.square(values[:, 0])) - 1
            mean = values[:, 1]
            # NaN intervals (single target) collapse to a zero-width band.
            ci = np.nan_to_num(values[:, 2])
            ax.plot(epsilons, mean, label=method, marker="+")
            ax.fill_between(epsilons, mean - ci, mean + ci, alpha=0.2)
            tick_positions.update(epsilons.tolist())

        ax.legend()
        ax.set_xscale('log')
        ax.set_xlabel("Perturbation size")
        ax.set_ylabel(f"{target} deviation")
        ax.xaxis.set_major_formatter(lambda x, _: f"+{round(100*x, 2):.02f}%")
        ax.set_xticks(sorted(tick_positions))
    plt.show()

In [4]:
# Plot the deviation curves for the "wind" result files.
plot_results("wind")

In [5]:
# Plot the deviation curves for the "temperature" result files.
plot_results("temperature")

In [6]:
# Plot the deviation curves for the "precipitation" result files.
plot_results("precipitation")

In [7]:
def results_to_dats(target):
    """Print the aggregated deviation curves of ``target`` as plain-text rows.

    The aggregation per epsilon and method is: negated median of
    ``method - unperturbed`` over each target's measurements, then the mean
    over targets together with ``1.644850 * stdev / sqrt(n)`` as interval
    half-width (NaN, later replaced by 0, when there is only one target).

    For every method its name is printed on one line, followed by one line per
    epsilon holding ``sqrt(1 + epsilon**2) - 1``, the mean and the half-width
    separated by single spaces.

    Parameters
    ----------
    target : str
        Name of the evaluation target, forwarded to ``load_data``.
    """
    per_epsilon = load_data(target)
    curves = {}
    for eps in sorted(per_epsilon):
        entry = per_epsilon[eps]
        baseline = np.asarray(entry["unperturbed"])
        for name, measurements in entry.items():
            if name != "unperturbed":
                # one row per target, one column per repeated measurement
                deviations = -np.median(np.asarray(measurements) - baseline, axis=1)
                center = statistics.mean(deviations)
                count = len(deviations)
                half_width = (
                    1.644850 * statistics.stdev(deviations) / count**0.5
                    if count > 1
                    else np.nan
                )
                curves.setdefault(name, []).append((float(eps), center, half_width))

    for name, rows in curves.items():
        raw_eps, centers, half_widths = np.asarray(rows).T
        relative_eps = np.sqrt(1 + np.square(raw_eps)) - 1
        half_widths = np.nan_to_num(half_widths)
        print(name)
        lines = [
            " ".join(str(item) for item in row)
            for row in zip(relative_eps, centers, half_widths)
        ]
        print("\n".join(lines))

In [8]:
# Print the aggregated "wind" curves as space-separated rows.
results_to_dats("wind")

In [9]:
# Print the aggregated "temperature" curves as space-separated rows.
results_to_dats("temperature")

In [10]:
# Print the aggregated "precipitation" curves as space-separated rows.
results_to_dats("precipitation")

In [None]:
import xarray
from tqdm.auto import tqdm

In [93]:
# Open the ERA5 Zarr store from the WeatherBench2 bucket and keep only the four
# variables used below (2m temperature, 10m wind components, 12h precipitation).
# NOTE(review): the store name suggests 6-hourly data on a 1440x721 grid - confirm.
era5 = xarray.open_zarr("gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr")
era5 = era5[["2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind", "total_precipitation_12hr"]]

In [123]:
# Keep every 4th latitude and longitude and every 2nd time step starting at index 1.
# NOTE(review): presumably this leaves the 06:00 and 18:00 steps selected in the
# next cell - confirm against the time axis of the store.
subsampled = era5.isel(latitude=slice(0,None,4),longitude=slice(0,None,4),time=slice(1,None,2))

In [124]:
# Rearrange the subsampled series into a (hour, number, dayofyear, ...) stack,
# where "number" is the position of the year within `years`.
hours = [6, 18]
# range() excludes its end point, so the last year used is 2022.
years = list(range(1959, 2023))
out = []
for hour in hours:
    datasets = []
    for year in tqdm(years):
        # Select the time steps of this hour of day that fall into this year.
        tmp = subsampled.isel(time=subsampled.time.dt.hour == hour).sel(time=str(year))
        # Replace the time dimension by the day of the year so that different
        # years share the same coordinate.
        tmp = tmp.assign_coords(dayofyear=tmp.time.dt.dayofyear).swap_dims(
            {'time': 'dayofyear'}
        )
        datasets.append(tmp)
    # Stack the per-year datasets along a new "number" dimension.
    ds_per_hour = xarray.concat(
        datasets,
        dim=xarray.DataArray(
            np.arange(len(years)), coords={'number': np.arange(len(years))}
        ),
    )
    out.append(ds_per_hour)
# Stack the per-hour datasets along a new "hour" dimension; `out` is rebound
# from a list to the combined dataset here.
out = xarray.concat(out, dim=xarray.DataArray(hours, dims=['hour']))

In [11]:
def compute_extremes(values):
    """Return high quantiles of the largest positive anomaly per ``number``.

    The mean over the ``number`` dimension is subtracted from ``values``, the
    maximum of the result is taken over the ``dayofyear`` and ``hour``
    dimensions, and the 0.99, 0.999 and 1.0 quantiles of these maxima are
    computed along ``number``.

    Parameters
    ----------
    values
        Object offering ``mean``, ``max`` and ``quantile`` with a ``dim``
        keyword and supporting subtraction (the notebook passes variables of
        the stacked xarray dataset ``out``).

    Returns
    -------
    The result of the final ``quantile`` call.
    """
    anomalies = values - values.mean(dim="number")
    peak_anomaly = anomalies.max(dim=["dayofyear", "hour"])
    return peak_anomaly.quantile(dim="number", q=[0.99, 0.999, 1.0])

In [12]:
# Extreme-anomaly quantiles of 2m temperature, averaged over all grid points.
compute_extremes(out["2m_temperature"]).mean(dim=["latitude", "longitude"]).compute()

In [147]:
# Extreme-anomaly quantiles of 12h precipitation, averaged over all grid points.
compute_extremes(out["total_precipitation_12hr"]).mean(dim=["latitude", "longitude"]).compute()

In [148]:
# Wind speed is the Euclidean norm of the 10m u and v wind components.
wind_speed = (out["10m_u_component_of_wind"]**2 + out["10m_v_component_of_wind"]**2)**0.5
# Extreme-anomaly quantiles of the wind speed, averaged over all grid points.
compute_extremes(wind_speed).mean(dim=["latitude", "longitude"]).compute()

In [65]:
from scipy import stats

# Number of variables entering the chi-square test below.
# NOTE(review): presumably a 181 x 360 grid with 5 + 13 * 6 values per grid
# point - confirm against the model's state layout.
var_count = 181 * 360 * (5 + 13*6)
def compute_intersection(target, value):
    """Print, per method, the perturbation size at which the mean deviation reaches ``value``.

    The mean deviation curve of every method is built from
    ``load_data(target)``: per epsilon, the negated median of
    ``method - unperturbed`` over each target's measurements, averaged over
    the targets. Epsilons are converted to ``sqrt(1 + epsilon**2) - 1``.

    The crossing with ``value`` is found by linear interpolation inside the
    first segment whose right end reaches ``value``. If the very first point
    already reaches it, the first segment is extrapolated to the left; if no
    point reaches it, the last segment is extrapolated to the right.

    For every method one line is printed: the method name, the crossing as a
    percentage, and ``p**2`` where ``p`` is the probability that a chi-square
    variable with ``var_count - 1`` degrees of freedom, scaled by
    ``1 + crossing``, exceeds the 0.99 quantile of the unscaled distribution.

    Parameters
    ----------
    target : str
        Name of the evaluation target, forwarded to ``load_data``.
    value : float
        Deviation level whose crossing is searched.

    Raises
    ------
    ValueError
        If a method has fewer than two epsilons, so no segment exists to
        interpolate on.
    """
    results = load_data(target)
    aggregated_results = {}
    for epsilon in sorted(results.keys()):
        unperturbed_results = np.asarray(results[epsilon]["unperturbed"])
        for method, result in results[epsilon].items():
            if method == "unperturbed":
                continue
            # result is list with one list per target
            # each target contains multiple measurements for the single target
            medians = -np.median(np.asarray(result) - unperturbed_results, axis=1)
            aggregated_results.setdefault(method, []).append(
                (float(epsilon), statistics.mean(medians))
            )

    dof = var_count - 1
    # Critical value of the unscaled distribution; identical for all methods.
    threshold = stats.chi2.ppf(0.99, dof)

    for method, values in aggregated_results.items():
        values = np.asarray(values)
        if len(values) < 2:
            raise ValueError(
                f"method {method!r} has fewer than two epsilons; cannot interpolate"
            )
        epsilons = np.sqrt(1 + np.square(values[:, 0])) - 1
        mean = values[:, 1]

        # 1. find the first point whose mean deviation reaches `value`
        i = 0
        while i < len(mean) and mean[i] < value:
            i += 1

        # 2. choose the segment (lower, upper) and the point to start from
        if i < len(mean):
            upper = max(i, 1)  # i == 0: extrapolate using the first segment
            lower = upper - 1
            anchor = i
        else:
            # `value` is never reached: extrapolate using the last segment
            upper, lower, anchor = -1, -2, -1

        # 3. linearly interpolate / extrapolate
        m = (mean[upper] - mean[lower]) / (epsilons[upper] - epsilons[lower])
        intersection_epsilon = epsilons[anchor] + (value - mean[anchor]) / m

        # Survival function instead of 1 - cdf: same quantity, better accuracy
        # in the upper tail.
        p = stats.chi2.sf(threshold, dof, scale=intersection_epsilon + 1)
        print(method, f"{intersection_epsilon*100:.3f}%", p**2)

In [66]:
# Perturbation size at which the mean temperature deviation reaches the given level.
# NOTE(review): the level is presumably taken from the compute_extremes output - confirm.
compute_intersection("temperature", 11.75064901)

In [67]:
# Perturbation size at which the mean wind deviation reaches the given level.
# NOTE(review): the level is presumably taken from the compute_extremes output - confirm.
compute_intersection("wind", 12.56660729)

In [68]:
# Perturbation size at which the mean precipitation deviation reaches the given level.
# NOTE(review): the level is presumably taken from the compute_extremes output - confirm.
compute_intersection("precipitation", 0.06293304)

In [None]:
# Evaluation targets; get_per_attack reads each entry's
# ["location"]["latitude"] and ["location"]["longitude"].
with open("data/weather_evaluation_targets.json", "r") as f:
    targets = json.load(f)
# Method names looked up in the result files by get_per_attack.
ATTACKS = ["Ours", "DP-Attacker", "AdvDM"]

In [246]:
def get_per_attack(target, target_value):
    """Compute, per attack and per evaluation target, the perturbation size reaching ``target_value``.

    Reads the module-level ``targets`` (list of target descriptions) and
    ``ATTACKS`` (method names). For the k-th entry of ``targets`` and every
    attack, the deviation per epsilon is the median over the measurements of
    ``-(attack - unperturbed)``, using the k-th measurement list of each
    result file. Epsilons are converted to ``sqrt(1 + epsilon**2) - 1`` and
    sorted ascending.

    The crossing with ``target_value`` is found by linear interpolation inside
    the first segment whose right end reaches it; the first segment is
    extrapolated if the first point already reaches it, and the last segment
    if no point does. At least two epsilons are required.

    Parameters
    ----------
    target : str
        Name of the evaluation target, forwarded to ``load_data``.
    target_value : float
        Deviation level whose crossing is searched.

    Returns
    -------
    dict
        Maps every attack name to a list with one
        ``(longitude, latitude, crossing)`` tuple per entry of ``targets``;
        longitude and latitude are rounded to integers.
    """
    results = load_data(target)
    raw_epsilons = list(results.keys())
    epsilons = np.sqrt(1 + np.square(raw_epsilons)) - 1
    order = np.argsort(epsilons)
    epsilons = epsilons[order]

    results_per_attack = {attack: [] for attack in ATTACKS}
    # The target index must stay untouched for all attacks of one target, so
    # the crossing search below uses its own counter. The loop variable is
    # also named differently from the `target` parameter to avoid shadowing.
    for target_index, target_info in enumerate(targets):
        unperturbed = -np.asarray(
            [results[epsilon]["unperturbed"][target_index] for epsilon in raw_epsilons]
        )
        lat = round(target_info["location"]["latitude"])
        lon = round(target_info["location"]["longitude"])
        for attack in ATTACKS:
            values = -np.asarray(
                [results[epsilon][attack][target_index] for epsilon in raw_epsilons]
            )
            # Median over the repeated measurements, ordered by epsilon.
            deviation = np.median(values - unperturbed, axis=-1)[order]

            # 1. find the first point whose deviation reaches `target_value`
            crossing = 0
            while crossing < len(deviation) and deviation[crossing] < target_value:
                crossing += 1

            # 2. choose the segment (lower, upper) and the point to start from
            if crossing < len(deviation):
                upper = max(crossing, 1)  # crossing == 0: use the first segment
                lower = upper - 1
                anchor = crossing
            else:
                # never reached: extrapolate using the last segment
                upper, lower, anchor = -1, -2, -1

            # 3. linearly interpolate / extrapolate
            m = (deviation[upper] - deviation[lower]) / (epsilons[upper] - epsilons[lower])
            intersection_epsilon = epsilons[anchor] + (target_value - deviation[anchor]) / m
            results_per_attack[attack].append((lon, lat, intersection_epsilon))
    return results_per_attack

# Per-target crossing perturbation sizes for the three variables.
# NOTE(review): the levels are presumably taken from the compute_extremes output - confirm.
temperature_per_attack = get_per_attack("temperature", 11.75064901)
wind_per_attack = get_per_attack("wind", 12.56660729)
precipitation_per_attack = get_per_attack("precipitation", 0.06293304)
# Column 2 of each (lon, lat, crossing) row is the crossing; average it over
# the three variables for the "Ours" attack.
mean_ours = (
    np.asarray(temperature_per_attack["Ours"])[:, 2] +
    np.asarray(wind_per_attack["Ours"])[:, 2] +
    np.asarray(precipitation_per_attack["Ours"])[:, 2]
    ) / 3
# Column 1 holds the rounded latitude of each target.
lats = np.asarray(temperature_per_attack["Ours"])[:, 1]

# NOTE(review): `plot` is not defined or imported in the visible part of this
# notebook - this cell fails on a fresh kernel unless it is defined elsewhere.
plot(None, lats, mean_ours)
plt.show()

In [247]:
# One "<absolute latitude> <mean crossing>" row per target.
print("\n".join(f"{x} {y}" for x,y in zip(np.abs(lats), mean_ours)))