In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import math
from google.cloud import storage
import numpy as np
import xarray
import jax
import jax.numpy as jnp
import haiku as hk
import dataclasses

from graphcast import rollout
from graphcast import xarray_jax
from graphcast import normalization
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import xarray_tree
from graphcast import gencast
from graphcast import denoiser
from graphcast import nan_cleaning
import graphcast.casting
import graphcast.samplers_utils

from data_loading import load_data

In [3]:
from graphcast.rollout import _get_next_inputs

In [4]:
# anonymous gcs client for the bucket
# bucket contains model, example data and normalization data
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
dir_prefix = "gencast/"

# Load normalization data
with gcs_bucket.blob(dir_prefix+"stats/diffs_stddev_by_level.nc").open("rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/mean_by_level.nc").open("rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/stddev_by_level.nc").open("rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/min_by_level.nc").open("rb") as f:
  min_by_level = xarray.load_dataset(f).compute()

In [5]:
with gcs_bucket.blob(dir_prefix + f"params/GenCast 1p0deg Mini <2019.npz").open("rb") as f:
    ckpt = checkpoint.load(f, gencast.CheckPoint)
params = ckpt.params
state = {}

task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
denoiser_architecture_config = ckpt.denoiser_architecture_config

# change to triblockdiag_mha for gpu
denoiser_architecture_config.sparse_transformer_config.attention_type = "triblockdiag_mha"
denoiser_architecture_config.sparse_transformer_config.mask_type = "full"

# fix to more appropriate config
noise_config = gencast.NoiseConfig(
    training_max_noise_level=sampler_config.max_noise_level/2,
    training_min_noise_level=sampler_config.min_noise_level,
    training_noise_level_rho=sampler_config.rho
)

In [6]:
def construct_wrapped_gencast():
  """Constructs and wraps the GenCast Predictor."""
  predictor = gencast.GenCast(
      sampler_config=sampler_config,
      task_config=task_config,
      denoiser_architecture_config=denoiser_architecture_config,
      noise_config=noise_config,
      noise_encoder_config=noise_encoder_config,
  )

  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level,
  )

  predictor = nan_cleaning.NaNCleaner(
      predictor=predictor,
      reintroduce_nans=True,
      fill_value=min_by_level,
      var_to_clean='sea_surface_temperature',
  )

  return predictor

In [7]:
def build_static_data_selector(coords, lat_start, lat_end, lon_start, lon_end):
    def find_index(haystack, needle):
        return np.argmin(np.abs(haystack.data - needle))
    lat_start = find_index(coords["lat"], lat_start)
    lat_end = find_index(coords["lat"], lat_end)

    def convert_lon(lon):
        if lon < 0:
            return lon + 360
        return lon

    lon_start = find_index(coords["lon"], convert_lon(lon_start))
    lon_end = find_index(coords["lon"], convert_lon(lon_end))
    # add 1 to include last point
    lat_end += 1
    lon_end += 1

    # print(f"{lat_start}:{lat_end}, {lon_start}:{lon_end}")
    def selector(data):
        data = jnp.roll(data, -lon_start, axis=-1)
        
        inner_lon_end = (lon_end - lon_start) % len(coords["lon"].data)
        return data[..., lat_start:lat_end, 0:inner_lon_end]
    return selector

In [8]:
@hk.transform_with_state
def exact_forward_fn(inputs, targets, forcings):
    predictor = construct_wrapped_gencast()
    denoised_predictions = predictor(
        inputs, targets, forcings
    )

    return denoised_predictions

forward_fn_jitted = jax.jit(
    lambda rng, i, t, f: exact_forward_fn.apply(params, state, rng, i, t, f)[0]
)

In [9]:
@hk.transform_with_state
def train_forward_fn(inputs, targets, forcings):
    nan_cleaner = construct_wrapped_gencast()
    normalizer = nan_cleaner._predictor
    predictor = normalizer._predictor

    if nan_cleaner._var_to_clean in inputs.keys():
        inputs = nan_cleaner._clean(inputs)
    if nan_cleaner._var_to_clean in targets.keys():
        targets = nan_cleaner._clean(targets)
    if forcings and nan_cleaner._var_to_clean in forcings.keys():
        forcings = nan_cleaner._clean(forcings)

    norm_inputs = normalization.normalize(inputs, normalizer._scales, normalizer._locations)
    norm_forcings = normalization.normalize(forcings, normalizer._scales, normalizer._locations)
    norm_target_residuals = xarray_tree.map_structure(
        lambda t: normalizer._subtract_input_and_normalize_target(inputs, t),
        targets
    )

    dtype = graphcast.casting.infer_floating_dtype(targets)  # pytype: disable=wrong-arg-types
    key = hk.next_rng_key()
    batch_size = inputs.sizes['batch']
    noise_levels = xarray_jax.DataArray(
        data=graphcast.samplers_utils.rho_inverse_cdf(
            min_value=predictor._noise_config.training_min_noise_level,
            max_value=predictor._noise_config.training_max_noise_level,
            rho=predictor._noise_config.training_noise_level_rho,
            cdf=jax.random.uniform(key, shape=(batch_size,), dtype=dtype)
        ),
        dims=('batch',)
    )

    # Sample noise and apply it to targets:
    noise = (
        graphcast.samplers_utils.spherical_white_noise_like(targets) * noise_levels
    )
    noisy_targets = norm_target_residuals + noise

    # make actual predictions
    denoised_predictions = predictor._preconditioned_denoiser(
        norm_inputs, noisy_targets, noise_levels, norm_forcings
    )

    denoised_predictions = xarray_tree.map_structure(
        lambda pred: normalizer._unnormalize_prediction_and_add_input(inputs, pred),
        denoised_predictions,
    )

    return denoised_predictions

train_forward_fn_jitted = jax.jit(
    lambda rng, i, t, f: train_forward_fn.apply(params, state, rng, i, t, f)[0]
)

In [10]:
@hk.transform_with_state
def approx_forward_fn(inputs, targets, forcings, approximation_steps):
    nan_cleaner = construct_wrapped_gencast()
    normalizer = nan_cleaner._predictor
    predictor = normalizer._predictor

    if nan_cleaner._var_to_clean in inputs.keys():
        inputs = nan_cleaner._clean(inputs)
    if nan_cleaner._var_to_clean in targets.keys():
        targets = nan_cleaner._clean(targets)
    if forcings and nan_cleaner._var_to_clean in forcings.keys():
        forcings = nan_cleaner._clean(forcings)

    norm_inputs = normalization.normalize(inputs, normalizer._scales, normalizer._locations)
    norm_forcings = normalization.normalize(forcings, normalizer._scales, normalizer._locations)
    norm_target_residuals = xarray_tree.map_structure(
        lambda t: normalizer._subtract_input_and_normalize_target(inputs, t),
        targets
    )

    dtype = graphcast.casting.infer_floating_dtype(targets)  # pytype: disable=wrong-arg-types
    batch_size = inputs.sizes['batch']

    key = hk.next_rng_key()
    noise_levels = xarray_jax.DataArray(
        data=graphcast.samplers_utils.rho_inverse_cdf(
            min_value=predictor._noise_config.training_min_noise_level,
            max_value=predictor._noise_config.training_max_noise_level,
            rho=predictor._noise_config.training_noise_level_rho,
            cdf=jax.random.uniform(key, minval=(approximation_steps - 1)/approximation_steps, maxval=1, shape=(batch_size,), dtype=dtype)
        ),
        dims=('batch',)
    )
    noise = (
        graphcast.samplers_utils.spherical_white_noise_like(targets) * noise_levels
    )
    x_current = norm_target_residuals + noise

    for step in range(approximation_steps):
        # make actual predictions
        denoised_predictions = predictor._preconditioned_denoiser(
            norm_inputs, x_current, noise_levels, norm_forcings
        )
        if step == approximation_steps - 1:
            continue # skip rest, as no next sample exists

        new_noise_levels = xarray_jax.DataArray(
            data=graphcast.samplers_utils.rho_inverse_cdf(
                min_value=predictor._noise_config.training_min_noise_level,
                max_value=predictor._noise_config.training_max_noise_level,
                rho=predictor._noise_config.training_noise_level_rho,
                cdf=jax.random.uniform(
                    key,
                    minval=(approximation_steps - step - 2)/approximation_steps,
                    maxval=(approximation_steps - step - 1)/approximation_steps,
                    shape=(batch_size,),
                    dtype=dtype
                )
            ),
            dims=('batch',)
        )
        next_over_current = new_noise_levels / noise_levels
        x_next = (1 - next_over_current) * denoised_predictions + next_over_current * x_current
        
        noise_levels = new_noise_levels
        x_current = x_next

    denoised_predictions = xarray_tree.map_structure(
        lambda pred: normalizer._unnormalize_prediction_and_add_input(inputs, pred),
        denoised_predictions,
    )

    return denoised_predictions

approx_forward_fn_jitted = jax.jit(
    lambda rng, i, t, f, s: approx_forward_fn.apply(params, state, rng, i, t, f, s)[0],
    static_argnums=[4],
)

In [11]:
def multi_step_forward(rng, inputs, targets, forcings, forward_fn):
    forcings_coords = xarray_jax.unwrap_coords(forcings.isel(time=slice(0, 1)))
    targets_coords = xarray_jax.unwrap_coords(targets.isel(time=slice(0, 1)))
    forcings_dims = {k: forcings[k].dims for k in forcings.data_vars}
    targets_dims = {k: targets[k].dims for k in targets.data_vars}
    def body_fn(current_inputs, arrays):
        current_targets, current_forcings = arrays
        current_targets = xarray.Dataset(
            {
                k: (
                    targets_dims[k],
                    xarray_jax.wrap(
                        jnp.expand_dims(jnp.expand_dims(v, 0), 0)
                        )
                    )
                for k, v in current_targets.items()
            },
            targets_coords,
        )
        current_forcings = xarray.Dataset(
            {
                k: (forcings_dims[k], xarray_jax.wrap(jnp.expand_dims(v, [0, 1])))
                for k, v in current_forcings.items()
            },
            forcings_coords,
        )

        predictions = forward_fn(
            rng, current_inputs, current_targets, current_forcings,
        )
        next_frame = xarray.merge([predictions, current_forcings])
        next_inputs = _get_next_inputs(current_inputs, next_frame)
        
        next_inputs = next_inputs.assign_coords(time=current_inputs.coords["time"])
        return next_inputs, xarray_jax.unwrap_vars(predictions)
    
    arrays = (
        xarray_jax.unwrap_vars(targets.isel(batch=0)),
        xarray_jax.unwrap_vars(forcings.isel(batch=0)),
    )
    _, all_predictions = jax.lax.scan(jax.checkpoint(body_fn), inputs, arrays)
    all_predictions = xarray.Dataset(
        {
            k: (targets_dims[k], xarray_jax.wrap(jnp.expand_dims(jnp.squeeze(v), 0)))
            for k, v in all_predictions.items()
        },
        xarray_jax.unwrap_coords(targets)
    )
    return all_predictions
multi_step_forward_jit = jax.jit(multi_step_forward, static_argnums=4)

Load data (required to be read before loss function definition to use coords for target location selection)

In [16]:
one_year = xarray.open_dataset("data/one_year_era5.nc")
inputs, targets, forcings = load_data(one_year, "2022-11-07T06:00:00", task_config)

Loss function and gradient definitions

In [14]:
def general_loss_fn(rng, inputs, targets, forcings, forward_fn):
  denoised_predictions = multi_step_forward_jit(
    rng,
    inputs,
    targets,
    forcings,
    forward_fn
  ).isel(time=-1)

  # sqrt(u^2 + v^2)
  wind_speed = np.sqrt(
    np.square(denoised_predictions["10m_u_component_of_wind"]) +
    np.square(denoised_predictions["10m_v_component_of_wind"])
  )

  # select wind speed in LA/CCS 2022
  wind_speed = xarray_jax.unwrap_data(wind_speed, require_jax=True)
  wind_speed = build_static_data_selector(inputs.coords, 34, 34, -118, -118)(wind_speed)

  # minus so that minimization of loss maximizes wind speed
  loss = -jnp.max(wind_speed)
  return loss

baseline_loss_fn = jax.jit(
  lambda rng, i, t, f: general_loss_fn(rng, i, t, f, train_forward_fn_jitted),
)

improved_loss_fn = jax.jit(
  lambda rng, i, t, f, s: general_loss_fn(rng, i, t, f, lambda rng, i, t, f: approx_forward_fn_jitted(rng, i, t, f, s)),
  static_argnums=(4,),
)

In [None]:
def adv_grads_fn(rng, inputs, targets, forcings, forward_fn):
  def _aux(rng, i, t, f):
    loss = general_loss_fn(
      rng, i, t, f, forward_fn,
    )
    return loss

  loss, grads = jax.value_and_grad(
    _aux,
    argnums=1,
  )(rng, inputs, targets, forcings)
  return loss, grads

baseline_grads_fn = jax.jit(lambda rng, inputs, targets, forcings: adv_grads_fn(rng, inputs, targets, forcings, train_forward_fn_jitted))
improved_grads_fn = jax.jit(lambda rng, inputs, targets, forcings: adv_grads_fn(
  rng,
  inputs,
  targets,
  forcings,
  lambda r, i, t, f: approx_forward_fn_jitted(r, i, t, f, 2),
))

In [46]:
further_improved_grads_fn = jax.jit(lambda rng, inputs, targets, forcings: adv_grads_fn(
  rng,
  inputs,
  targets,
  forcings,
  lambda r, i, t, f: approx_forward_fn_jitted(r, i, t, f, 3),
))

In [None]:
# baseline predicted wind speed

loss = general_loss_fn(
    jax.random.PRNGKey(1234567890),
    inputs,
    targets,
    forcings,
    forward_fn_jitted,
)
print("Wind speed (without perturbation):", -float(loss))

  self._set_arrayXarray(i, j, x)
  self._set_arrayXarray(i, j, x)


Loss: -7.2243332862854


Define baseline attack and attack with improved gradients

In [18]:
SCALES_PATH = "./data/estimated_error_scales.nc"
with open(SCALES_PATH, "rb") as f:
    STDDEVS = xarray.load_dataset(f).compute()

In [21]:
def add_perturbation(inputs, perturbation):
    return inputs + normalization.unnormalize(perturbation, STDDEVS, None)


def scale_std(data):
    data = data - jnp.mean(data, axis=(-1, -2))[..., np.newaxis, np.newaxis]
    current_std = jnp.std(data, axis=(-1, -2))
    return data / current_std[..., np.newaxis, np.newaxis]


def projection(data, epsilon):
    data = xarray_jax.unwrap(data)
    data = data - jnp.mean(data, axis=(-1, -2))[..., np.newaxis, np.newaxis]
    current_std = jnp.std(data, axis=(-1, -2))
    current_std = current_std[..., np.newaxis, np.newaxis]
    data = data / current_std * jnp.minimum(current_std, epsilon)
    return xarray_jax.wrap(data)

In [31]:
VARS_TO_ATTACK = [
    '10m_u_component_of_wind',
    '10m_v_component_of_wind',
    '2m_temperature',
    'geopotential',
    'mean_sea_level_pressure',
    'sea_surface_temperature',
    'specific_humidity',
    'temperature',
    'u_component_of_wind',
    'v_component_of_wind',
    'vertical_velocity',
]

def baseline_attack(inputs, targets, forcings, epsilon, maxiter=10):
    # refer to Algorithm 1 of https://arxiv.org/abs/2405.19424
    # changed to 'l2' norm instead of l-inf norm
    # changes sign of grad to normalize by std
    alpha = 2 * epsilon / maxiter
    
    # zero init
    perturbation = xarray_tree.map_structure(lambda a: 0*a, inputs)
    for t in range(maxiter):
        perturbed_inputs = add_perturbation(inputs, perturbation)
        loss, grads = baseline_grads_fn(
            rng=jax.random.PRNGKey(t),
            inputs=perturbed_inputs,
            targets=targets,
            forcings=forcings,
        )
        for var in VARS_TO_ATTACK:
            x = xarray_jax.unwrap_data(perturbation[var])
            diff = xarray_jax.unwrap_data(grads[var])
            diff = scale_std(diff) # scale to std = 1
            new_x = x - alpha * diff
            new_x = projection(new_x, epsilon)
            perturbation[var].data = xarray_jax.wrap(new_x)
        print(f"Step {t}: (approx.) Wind Speed: {-loss}")

    return perturbation

In [36]:
perturbation = baseline_attack(
    inputs,
    targets,
    forcings,
    0.05,
    maxiter=50,
)

loss = general_loss_fn(
    jax.random.PRNGKey(1234567890),
    add_perturbation(inputs, perturbation),
    targets,
    forcings,
    forward_fn_jitted,
)
print("Final result:", -float(loss))

Step 0: (approx.) Wind Speed: 1.4433245658874512
Step 1: (approx.) Wind Speed: 1.3948262929916382
Step 2: (approx.) Wind Speed: 4.288384437561035
Step 3: (approx.) Wind Speed: 1.462218165397644
Step 4: (approx.) Wind Speed: 4.538511276245117
Step 5: (approx.) Wind Speed: 1.5979453325271606
Step 6: (approx.) Wind Speed: 1.426227331161499
Step 7: (approx.) Wind Speed: 4.000718116760254
Step 8: (approx.) Wind Speed: 5.748746395111084
Step 9: (approx.) Wind Speed: 1.740755558013916
Step 10: (approx.) Wind Speed: 1.5389586687088013
Step 11: (approx.) Wind Speed: 2.7514212131500244
Step 12: (approx.) Wind Speed: 1.2495098114013672
Step 13: (approx.) Wind Speed: 2.6105642318725586
Step 14: (approx.) Wind Speed: 5.588534832000732
Step 15: (approx.) Wind Speed: 7.0757856369018555
Step 16: (approx.) Wind Speed: 7.806009769439697
Step 17: (approx.) Wind Speed: 2.105501651763916
Step 18: (approx.) Wind Speed: 8.645541191101074
Step 19: (approx.) Wind Speed: 5.073997497558594
Step 20: (approx.) Win

In [29]:
def improved_attack(inputs, targets, forcings, epsilon, maxiter=10):
    alpha = 2 * epsilon / maxiter
    
    # zero init
    perturbation = xarray_tree.map_structure(lambda a: 0*a, inputs)
    for t in range(maxiter):
        perturbed_inputs = add_perturbation(inputs, perturbation)
        loss, grads = improved_grads_fn(
            rng=jax.random.PRNGKey(t),
            inputs=perturbed_inputs,
            targets=targets,
            forcings=forcings,
        )
        for var in VARS_TO_ATTACK:
            x = xarray_jax.unwrap_data(perturbation[var])
            diff = xarray_jax.unwrap_data(grads[var])
            diff = scale_std(diff) # scale to std = 1
            new_x = x - alpha * diff
            new_x = projection(new_x, epsilon)
            perturbation[var].data = xarray_jax.wrap(new_x)
        print(f"Step {t}: (approx.) Wind Speed: {-loss}")

    return perturbation

In [None]:
perturbation = improved_attack(
    inputs,
    targets,
    forcings,
    0.05,
    maxiter=50,
)

loss = general_loss_fn(
    jax.random.PRNGKey(1234567890),
    add_perturbation(inputs, perturbation),
    targets,
    forcings,
    forward_fn_jitted,
)
print("Final result:", -float(loss))

Step 0: (approx.) Wind Speed: 1.0412932634353638
Step 1: (approx.) Wind Speed: 4.2267632484436035
Step 2: (approx.) Wind Speed: 3.977543830871582
Step 3: (approx.) Wind Speed: 1.2554935216903687
Step 4: (approx.) Wind Speed: 5.2124104499816895
Step 5: (approx.) Wind Speed: 1.2360069751739502
Step 6: (approx.) Wind Speed: 2.2564589977264404
Step 7: (approx.) Wind Speed: 6.575007915496826
Step 8: (approx.) Wind Speed: 7.477817535400391
Step 9: (approx.) Wind Speed: 1.8586806058883667
Step 10: (approx.) Wind Speed: 1.6908849477767944
Step 11: (approx.) Wind Speed: 9.239643096923828
Step 12: (approx.) Wind Speed: 8.010008811950684
Step 13: (approx.) Wind Speed: 5.974884986877441
Step 14: (approx.) Wind Speed: 8.77587604522705
Step 15: (approx.) Wind Speed: 10.150940895080566
Step 16: (approx.) Wind Speed: 10.392120361328125
Step 17: (approx.) Wind Speed: 10.731352806091309
Step 18: (approx.) Wind Speed: 12.22215461730957
Step 19: (approx.) Wind Speed: 11.645843505859375
Step 20: (approx.) 

In [38]:
def improved_attack_with_momentum(inputs, targets, forcings, epsilon, maxiter=10):
    alpha = 2 * epsilon / maxiter
    beta = 0.9
    
    # zero init
    perturbation = xarray_tree.map_structure(lambda a: 0*a, inputs)
    first_moment = xarray_tree.map_structure(lambda a: 0*a, inputs)

    for t in range(maxiter):
        perturbed_inputs = add_perturbation(inputs, perturbation)
        loss, grads = improved_grads_fn(
            rng=jax.random.PRNGKey(t),
            inputs=perturbed_inputs,
            targets=targets,
            forcings=forcings,
        )
        for var in VARS_TO_ATTACK:
            x = xarray_jax.unwrap_data(perturbation[var])
            diff = xarray_jax.unwrap_data(grads[var])
            diff = scale_std(diff) # scale to std = 1
            diff = beta * xarray_jax.unwrap_data(first_moment[var]) + (1 - beta) * diff
            first_moment[var].data = xarray_jax.wrap(diff)
            new_x = x - alpha * diff
            new_x = projection(new_x, epsilon)
            perturbation[var].data = xarray_jax.wrap(new_x)
        print(f"Step {t}: (approx.) Wind Speed: {-loss}")

    return perturbation

In [39]:
perturbation = improved_attack_with_momentum(
    inputs,
    targets,
    forcings,
    0.05,
    maxiter=50,
)

loss = general_loss_fn(
    jax.random.PRNGKey(1234567890),
    add_perturbation(inputs, perturbation),
    targets,
    forcings,
    forward_fn_jitted,
)
print("Final result:", -float(loss))

Step 0: (approx.) Wind Speed: 1.0463693141937256
Step 1: (approx.) Wind Speed: 4.152469635009766
Step 2: (approx.) Wind Speed: 3.469343900680542
Step 3: (approx.) Wind Speed: 1.0064433813095093
Step 4: (approx.) Wind Speed: 3.9934256076812744
Step 5: (approx.) Wind Speed: 0.7225298285484314
Step 6: (approx.) Wind Speed: 1.849095106124878
Step 7: (approx.) Wind Speed: 4.607650279998779
Step 8: (approx.) Wind Speed: 4.955946922302246
Step 9: (approx.) Wind Speed: 1.0756417512893677
Step 10: (approx.) Wind Speed: 1.050492525100708
Step 11: (approx.) Wind Speed: 7.091588973999023
Step 12: (approx.) Wind Speed: 5.749095439910889
Step 13: (approx.) Wind Speed: 4.868059158325195
Step 14: (approx.) Wind Speed: 6.107431411743164
Step 15: (approx.) Wind Speed: 6.791924476623535
Step 16: (approx.) Wind Speed: 6.914328575134277
Step 17: (approx.) Wind Speed: 7.753853797912598
Step 18: (approx.) Wind Speed: 8.287782669067383
Step 19: (approx.) Wind Speed: 7.800251483917236
Step 20: (approx.) Wind S

In [42]:
def final_attack(
        inputs,
        targets,
        forcings,
        epsilon,
        maxiter=10
    ):
    beta = 0.9
    def _cos_anneal(eta_0, eta_min, t):
        return eta_min + 0.5 * (eta_0 - eta_min) * (1 + np.cos(t * np.pi / maxiter))
    _learning_rate = lambda t: _cos_anneal(2*epsilon, epsilon/maxiter, t)

    # zero init
    perturbation = xarray_tree.map_structure(lambda a: 0*a, inputs)
    first_moment = xarray_tree.map_structure(lambda a: 0*a, inputs)

    for t in range(maxiter):
        perturbed_inputs = add_perturbation(inputs, perturbation)
        loss, grads = improved_grads_fn(
            rng=jax.random.PRNGKey(t),
            inputs=perturbed_inputs,
            targets=targets,
            forcings=forcings,
        )

        for var in VARS_TO_ATTACK:
            x = xarray_jax.unwrap_data(perturbation[var])
            diff = xarray_jax.unwrap_data(grads[var])
            diff = scale_std(diff) # scale to std = 1
            diff = beta * xarray_jax.unwrap_data(first_moment[var]) + (1 - beta) * diff
            first_moment[var].data = xarray_jax.wrap(diff)
            learning_rate = _learning_rate(t) / (1 - beta**(t+1))
            new_x = x - learning_rate * diff
            new_x = projection(new_x, epsilon)
            perturbation[var].data = xarray_jax.wrap(new_x)
        print(f"Step {t}: (approx.) Wind Speed: {-loss}")

    return perturbation

In [43]:
perturbation = final_attack(
    inputs,
    targets,
    forcings,
    0.05,
    maxiter=50,
)

loss = general_loss_fn(
    jax.random.PRNGKey(1234567890),
    add_perturbation(inputs, perturbation),
    targets,
    forcings,
    forward_fn_jitted,
)
print("Final result:", -float(loss))

Step 0: (approx.) Wind Speed: 1.0449819564819336
Step 1: (approx.) Wind Speed: 4.042561054229736
Step 2: (approx.) Wind Speed: 8.909462928771973
Step 3: (approx.) Wind Speed: 6.065245151519775
Step 4: (approx.) Wind Speed: 11.684289932250977
Step 5: (approx.) Wind Speed: 11.777127265930176
Step 6: (approx.) Wind Speed: 13.052958488464355
Step 7: (approx.) Wind Speed: 19.54697608947754
Step 8: (approx.) Wind Speed: 28.92591094970703
Step 9: (approx.) Wind Speed: 18.248559951782227
Step 10: (approx.) Wind Speed: 15.459378242492676
Step 11: (approx.) Wind Speed: 23.38280487060547
Step 12: (approx.) Wind Speed: 27.882600784301758
Step 13: (approx.) Wind Speed: 21.43470573425293
Step 14: (approx.) Wind Speed: 23.156517028808594
Step 15: (approx.) Wind Speed: 29.779376983642578
Step 16: (approx.) Wind Speed: 27.42388916015625
Step 17: (approx.) Wind Speed: 30.237730026245117
Step 18: (approx.) Wind Speed: 39.65716552734375
Step 19: (approx.) Wind Speed: 27.62385368347168
Step 20: (approx.) W

In [44]:
def baseline_attack_with_larger_steps(
        inputs,
        targets,
        forcings,
        epsilon,
        maxiter=10
    ):
    beta = 0.9
    def _cos_anneal(eta_0, eta_min, t):
        return eta_min + 0.5 * (eta_0 - eta_min) * (1 + np.cos(t * np.pi / maxiter))
    _learning_rate = lambda t: _cos_anneal(2*epsilon, epsilon/maxiter, t)

    # zero init
    perturbation = xarray_tree.map_structure(lambda a: 0*a, inputs)
    first_moment = xarray_tree.map_structure(lambda a: 0*a, inputs)

    for t in range(maxiter):
        perturbed_inputs = add_perturbation(inputs, perturbation)
        loss, grads = baseline_grads_fn(
            rng=jax.random.PRNGKey(t),
            inputs=perturbed_inputs,
            targets=targets,
            forcings=forcings,
        )

        for var in VARS_TO_ATTACK:
            x = xarray_jax.unwrap_data(perturbation[var])
            diff = xarray_jax.unwrap_data(grads[var])
            diff = scale_std(diff) # scale to std = 1
            diff = beta * xarray_jax.unwrap_data(first_moment[var]) + (1 - beta) * diff
            first_moment[var].data = xarray_jax.wrap(diff)
            learning_rate = _learning_rate(t) / (1 - beta**(t+1))
            new_x = x - learning_rate * diff
            new_x = projection(new_x, epsilon)
            perturbation[var].data = xarray_jax.wrap(new_x)
        print(f"Step {t}: (approx.) Wind Speed: {-loss}")

    return perturbation

In [45]:
perturbation = baseline_attack_with_larger_steps(
    inputs,
    targets,
    forcings,
    0.05,
    maxiter=50,
)

loss = general_loss_fn(
    jax.random.PRNGKey(1234567890),
    add_perturbation(inputs, perturbation),
    targets,
    forcings,
    forward_fn_jitted,
)
print("Final result:", -float(loss))

Step 0: (approx.) Wind Speed: 1.4433022737503052
Step 1: (approx.) Wind Speed: 2.93444561958313
Step 2: (approx.) Wind Speed: 12.437716484069824
Step 3: (approx.) Wind Speed: 1.5361865758895874
Step 4: (approx.) Wind Speed: 14.537372589111328
Step 5: (approx.) Wind Speed: 1.7392650842666626
Step 6: (approx.) Wind Speed: 1.6761924028396606
Step 7: (approx.) Wind Speed: 16.622295379638672
Step 8: (approx.) Wind Speed: 21.19961166381836
Step 9: (approx.) Wind Speed: 1.9306564331054688
Step 10: (approx.) Wind Speed: 1.614388108253479
Step 11: (approx.) Wind Speed: 11.763879776000977
Step 12: (approx.) Wind Speed: 4.666250705718994
Step 13: (approx.) Wind Speed: 2.9787096977233887
Step 14: (approx.) Wind Speed: 18.135950088500977
Step 15: (approx.) Wind Speed: 25.312488555908203
Step 16: (approx.) Wind Speed: 27.207372665405273
Step 17: (approx.) Wind Speed: 9.047948837280273
Step 18: (approx.) Wind Speed: 31.62239646911621
Step 19: (approx.) Wind Speed: 17.52513885498047
Step 20: (approx.)

In [47]:
def final_further_attack(
        inputs,
        targets,
        forcings,
        epsilon,
        maxiter=10
    ):
    beta = 0.9
    def _cos_anneal(eta_0, eta_min, t):
        return eta_min + 0.5 * (eta_0 - eta_min) * (1 + np.cos(t * np.pi / maxiter))
    _learning_rate = lambda t: _cos_anneal(2*epsilon, epsilon/maxiter, t)

    # zero init
    perturbation = xarray_tree.map_structure(lambda a: 0*a, inputs)
    first_moment = xarray_tree.map_structure(lambda a: 0*a, inputs)

    for t in range(maxiter):
        perturbed_inputs = add_perturbation(inputs, perturbation)
        loss, grads = further_improved_grads_fn(
            rng=jax.random.PRNGKey(t),
            inputs=perturbed_inputs,
            targets=targets,
            forcings=forcings,
        )

        for var in VARS_TO_ATTACK:
            x = xarray_jax.unwrap_data(perturbation[var])
            diff = xarray_jax.unwrap_data(grads[var])
            diff = scale_std(diff) # scale to std = 1
            diff = beta * xarray_jax.unwrap_data(first_moment[var]) + (1 - beta) * diff
            first_moment[var].data = xarray_jax.wrap(diff)
            learning_rate = _learning_rate(t) / (1 - beta**(t+1))
            new_x = x - learning_rate * diff
            new_x = projection(new_x, epsilon)
            perturbation[var].data = xarray_jax.wrap(new_x)
        print(f"Step {t}: (approx.) Wind Speed: {-loss}")

    return perturbation

In [48]:
perturbation = final_further_attack(
    inputs,
    targets,
    forcings,
    0.05,
    maxiter=50,
)

loss = general_loss_fn(
    jax.random.PRNGKey(1234567890),
    add_perturbation(inputs, perturbation),
    targets,
    forcings,
    forward_fn_jitted,
)
print("Final result:", -float(loss))

  self._set_arrayXarray(i, j, x)
  self._set_arrayXarray(i, j, x)


Step 0: (approx.) Wind Speed: 3.1596426963806152
Step 1: (approx.) Wind Speed: 6.499915599822998
Step 2: (approx.) Wind Speed: 8.01211929321289
Step 3: (approx.) Wind Speed: 8.993512153625488
Step 4: (approx.) Wind Speed: 13.690526962280273
Step 5: (approx.) Wind Speed: 10.712691307067871
Step 6: (approx.) Wind Speed: 13.40321159362793
Step 7: (approx.) Wind Speed: 18.278032302856445
Step 8: (approx.) Wind Speed: 20.84583854675293
Step 9: (approx.) Wind Speed: 14.807703018188477
Step 10: (approx.) Wind Speed: 15.317180633544922
Step 11: (approx.) Wind Speed: 24.900348663330078
Step 12: (approx.) Wind Speed: 26.896087646484375
Step 13: (approx.) Wind Speed: 22.970056533813477
Step 14: (approx.) Wind Speed: 24.6019287109375
Step 15: (approx.) Wind Speed: 24.435827255249023
Step 16: (approx.) Wind Speed: 27.702247619628906
Step 17: (approx.) Wind Speed: 33.88011932373047
Step 18: (approx.) Wind Speed: 29.742326736450195
Step 19: (approx.) Wind Speed: 30.660430908203125
Step 20: (approx.) 