In [1]:
%load_ext autoreload
%autoreload 2

import os, shutil
import xarray as xr
import torch
import dask
import pandas as pd
from torch.utils.data import DataLoader
import numpy as np
import logging

from aurora import Aurora, AuroraSmall, rollout

from aurora_benchmark.utils import verbose_print, xr_to_netcdf

from aurora_benchmark.data import (
    XRAuroraDataset, 
    XRAuroraBatchedDataset,
    aurora_batch_collate_fn, 
    aurora_batch_to_xr, 
    unpack_aurora_batch
)

def configure_notebook_logging(level=logging.INFO):
    """Set up root logging for this notebook; safe to call more than once.

    The root logger is set to ``level`` and gets a single console handler with
    a timestamped format. The handler is tagged when it is first installed, so
    re-running the cell reuses it instead of stacking a new one (which would
    print every log record several times).

    Parameters
    ----------
    level : int
        Logging level applied to the root logger and to the console handler.

    Returns
    -------
    tuple
        ``(root_logger, console_handler, formatter)``.
    """
    root = logging.getLogger()
    root.setLevel(level)

    fmt = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # Look for the handler a previous run of this cell installed.
    handler = next(
        (h for h in root.handlers if getattr(h, "_notebook_console_handler", False)),
        None,
    )
    if handler is None:
        handler = logging.StreamHandler()
        handler._notebook_console_handler = True  # marker used by the lookup above
        root.addHandler(handler)
    handler.setLevel(level)
    handler.setFormatter(fmt)

    # Suppress logs from Google libraries
    for noisy_name in ('google', 'google.auth', 'google.cloud'):
        logging.getLogger(noisy_name).setLevel(logging.ERROR)

    return root, handler, fmt


logger, console_handler, formatter = configure_notebook_logging()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations of the toy ERA5 "msl" file (netCDF source, plus HDF5 and zarr variants).
p = "../toy_data/era5-1d-360x180/msl-2021-2022-1d-360x180.nc"
ph5 = "../toy_data/era5-1d-360x180/msl-2021-2022-1d-360x180.h5"
pzarr = "../toy_data/era5-1d-360x180/msl-2021-2022-1d-360x180.zarr"

# Remove any zarr store left over from an earlier run before writing a new one.
if os.path.exists(pzarr):
    shutil.rmtree(pzarr)

# Read the netCDF file, discard the "time_bnds" variable and write the rest as zarr.
d = xr.open_dataset(p, engine="netcdf4")
d = d.drop_vars("time_bnds")
d.to_zarr(pzarr)

<xarray.backends.zarr.ZarrStore at 0x148e10428fc0>

In [3]:
# The toy setup feeds the same netCDF file `p` to every variable slot.
# (The zarr copy `pzarr` can be listed here instead if the loader below is
# switched to xr.open_zarr.)
era5_surface_paths = [p] * 3
era5_atmospheric_paths = [p] * 3  # we will repeat on level dimension
era5_static_paths = [p] * 3  # we will select the first time step

surf_vars = ["2t", "msl", "10u"]
atmospheric_vars = ["t", "q", "u"]
static_vars = ["z", "lsm", "slt"]


def _open_msl_as(path, name, reshape=None):
    """Open ``path`` with the netcdf4 engine and expose its ``msl`` variable as ``name``.

    ``reshape`` is an optional ``Dataset -> Dataset`` callable that runs after
    the rename and before ``time_bnds`` is dropped.

    Alternatives tried during development: ``engine="h5netcdf"``, or
    ``xr.open_zarr(path, chunks={"time": 50, "latitude": 180, "longitude": 360})``
    (the zarr store written earlier in this notebook has no ``time_bnds`` to drop).
    """
    renamed = xr.open_dataset(path, engine="netcdf4").rename({"msl": name})
    if reshape is not None:
        renamed = reshape(renamed)
    return renamed.drop_vars("time_bnds")


# Load the data into a single dataset with the same coords but multiple variables
surface_dss = [
    _open_msl_as(path, svar)
    for path, svar in zip(era5_surface_paths, surf_vars)
]
surface_ds = xr.merge(surface_dss)

# Atmospheric variables: replicate the 2-D field along a new "level" dimension.
atmospheric_dss = [
    _open_msl_as(path, svar, reshape=lambda ds: ds.expand_dims({"level": [1000, 700, 250]}))
    for path, svar in zip(era5_atmospheric_paths, atmospheric_vars)
]
atmospheric_ds = xr.merge(atmospheric_dss)

# Static variables: keep only the first time step.
static_dss = [
    _open_msl_as(path, svar, reshape=lambda ds: ds.isel(time=0))
    for path, svar in zip(era5_static_paths, static_vars)
]
static_ds = xr.merge(static_dss)

# rename coord lat to latitude and lon to longitude
_coord_renames = {"lat": "latitude", "lon": "longitude"}
surface_ds = surface_ds.rename(_coord_renames)
atmospheric_ds = atmospheric_ds.rename(_coord_renames)
static_ds = static_ds.rename(_coord_renames)

surface_ds.dims, atmospheric_ds.dims, static_ds.dims



In [4]:
# --- Evaluation configuration -------------------------------------------------
batch_size = 16
# DataLoader worker count; only used when use_dataloader is True.
# (Alternative considered: derive it from SLURM_CPUS_PER_TASK / os.cpu_count().)
num_workers = 2
eval_start ="1w"              # rollout steps before this lead time are skipped as warmup
era5_base_frequency = "1d"    # duration of one rollout step
forecast_horizon = "10w"      # total rollout length
use_dataloader = False
eval_aggregation = "1w"       # forecasts are averaged over windows of this length
init_frequency = "1w"
verbose = True
drop_timestamps = False
persist = False
rechunk = False
# NOTE(review): cluster-specific absolute path; adjust when running elsewhere.
output_dir = "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2_forecasts/2021-2022-1d-1w-10w-360x180_original_variables/"

os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")

# Express the warmup period and the horizon as a number of base-frequency steps.
warmup_steps = int(pd.Timedelta(eval_start) / pd.Timedelta(era5_base_frequency)) if eval_start is not None else 0
forecast_steps = int(pd.Timedelta(forecast_horizon) / pd.Timedelta(era5_base_frequency))

assert (forecast_steps-warmup_steps) * pd.Timedelta(era5_base_frequency) >= pd.Timedelta(eval_aggregation), "Evaluation steps must be at least as long as eval_aggregation" 

# The DataLoader path runs dask synchronously; the batched-dataset path uses threads.
dask.config.set(scheduler='synchronous' if use_dataloader else 'threads')
verbose_print(verbose, f"Using dask scheduler: {dask.config.get('scheduler')}")

# Keyword arguments shared by both dataset flavours, kept in one place so the
# two code paths cannot drift apart.
dataset_kwargs = dict(
    surface_ds=surface_ds,
    atmospheric_ds=atmospheric_ds,
    static_ds=static_ds,
    init_frequency=init_frequency,
    forecast_horizon=forecast_horizon,
    num_time_samples=2, # Aurora has fixed history length of 2...
    drop_timestamps=drop_timestamps,
    persist=persist,
    rechunk=rechunk,
    atmospheric_variables=atmospheric_vars,
    surface_variables=surf_vars,
    static_variables=static_vars,
)

if use_dataloader:
    verbose_print(verbose, "Creating XRAuroraDataset and DataLoader...")
    dataset = XRAuroraDataset(**dataset_kwargs)
    verbose_print(verbose, f"Loaded dataset of length {len(dataset)} (drop_timestamps={drop_timestamps}, persist={persist}, rechunk={rechunk})")

    # Uses the num_workers value from the configuration above; it is no longer
    # overwritten here, so changing the configuration actually takes effect.
    verbose_print(verbose, f"Creating DataLoader with {num_workers} workers ...")
    eval_loader = DataLoader(
        dataset, 
        batch_size=batch_size, 
        collate_fn=aurora_batch_collate_fn,
        num_workers=num_workers,
    )
    batch_iterator = eval_loader
else:
    # This is done to avoid the issue with torch DataLoader and dask
    # when using netcdf files (i.e. netcdf backend is not thread safe)
    verbose_print(verbose, "Creating XRAuroraBatchedDataset ...")
    dataset = XRAuroraBatchedDataset(batch_size=batch_size, **dataset_kwargs)
    batch_iterator = dataset

verbose_print(verbose, f"Dataset length: {dataset.flat_length() if hasattr(dataset, 'flat_length') else len(dataset)}")
verbose_print(verbose, f"Dataloader length: {len(batch_iterator)} (type: {type(batch_iterator)}, batch_size: {batch_size})")

2024-10-16 09:24:52,828 - aurora_benchmark.utils - INFO - Using dask scheduler: threads
2024-10-16 09:24:52,828 - aurora_benchmark.utils - INFO - Creating XRAuroraBatchedDataset ...
2024-10-16 09:24:52,831 - aurora_benchmark.utils - INFO - Dataset length: 42
2024-10-16 09:24:52,832 - aurora_benchmark.utils - INFO - Dataloader length: 3 (type: <class 'aurora_benchmark.data.XRAuroraBatchedDataset'>, batch_size: 20)


Output directory: /projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2_forecasts/2021-2022-1d-1w-10w-360x180_original_variables/


In [5]:
# Rename tables grouped by variable type; each entry maps key -> value
# (e.g. "u10" -> "10u"). Empty groups have nothing to rename.
AURORA_VARIABLE_RENAMES = {
    "surface": {"u10": "10u", "v10": "10v", "t2m": "2t"},
    "atmospheric": {},
    "static": {},
}

# Reverse lookup of the table above (value -> key), with the same groups.
INVERTED_AURORA_VARIABLE_RENAMES = {
    var_group: {renamed: original for original, renamed in renames.items()}
    for var_group, renames in AURORA_VARIABLE_RENAMES.items()
}

# Variables and pressure levels whose forecasts are written to disk.
interest_variables = atmospheric_vars + surf_vars
interest_levels = [1000, 700, 250]

# `verbose` comes from the evaluation-configuration cell; it is not reset here
# so that a single setting controls logging everywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("loading model ...")
aurora_model = "aurora-0.25-pretrained.ckpt"
model = Aurora(use_lora=False)
model.load_checkpoint("microsoft/aurora", aurora_model)
# Freshly constructed nn.Modules start in training mode; switch to eval mode so
# that any train-only behaviour (e.g. dropout) is disabled during the rollout.
model.eval()
model = model.to(device)

verbose_print(verbose, f"Evaluating on {device}")
# evaluation loop
# Context managers must be entered individually (or comma-separated). The
# expression `torch.inference_mode() and torch.no_grad()` evaluates to its
# right-hand operand only, so inference mode would never be entered.
# inference_mode() on its own already disables gradient tracking.
with torch.inference_mode():
    for i, batch in enumerate(batch_iterator):
        batch = batch.to(device)
        # rollout until for forecast_steps
        # NOTE(review): the recorded run of this cell stopped inside the first
        # rollout with "RuntimeError: CUDA error: invalid configuration argument".
        # The cause is not established here; a smaller batch_size is one thing to try.
        verbose_print(verbose, f"Rollout prediction on batch {i} ...")
        # One list of per-step predictions for every sample slot in the batch.
        trajectories = [[] for _ in range(batch_size)]
        for s, batch_pred in enumerate(rollout(model, batch, steps=forecast_steps)):
            if s < warmup_steps:
                verbose_print(verbose, f" * Rollout step {s+1}: skipping warmup period")
                continue
            # separate batched batches
            sub_batch_preds = unpack_aurora_batch(batch_pred.to("cpu"))
            verbose_print(verbose, f" * Rollout step {s+1}: unpacked {len(sub_batch_preds)} sub-batches")
            if i != len(batch_iterator) - 1: # the last batch may not be full
                assert len(sub_batch_preds) == batch_size
            # accumulate
            for b, sub_batch_pred in enumerate(sub_batch_preds):
                trajectories[b].append(sub_batch_pred)

        # convert to xr and process
        verbose_print(verbose, "Processing trajectories ...")
        # zip stops at the shorter input, so slots left empty by a short final
        # batch are skipped (assumes batch.metadata.time has one entry per sample
        # -- TODO confirm).
        for init_time, trajectory in zip(batch.metadata.time, trajectories):
            verbose_print(verbose, f" * init_time={init_time}: combining {len(trajectory)} steps")
            assert len(trajectory) == forecast_steps-warmup_steps
            # collate trajectory batches
            trajectory = aurora_batch_collate_fn(trajectory)
            # convert to xr.Dataset
            trajectory = aurora_batch_to_xr(trajectory, frequency=era5_base_frequency)

            # process individual trajectory elements (i.e. variable types)
            for var_type, vars_ds in trajectory.items():
                # ensure processing is necessary
                if var_type == "static_ds":
                    verbose_print(verbose, " * Skipping static variables")
                    continue # we do not care about static variables for the forecast
                if not any(var in vars_ds.data_vars for var in interest_variables):
                    verbose_print(verbose, f" * Skipping {var_type} variables as no interest variables are present")
                    continue # don't bother processing variables we are not interested in
                if var_type == "atmospheric_ds" and (interest_levels is None or len(interest_levels)==0):
                    verbose_print(verbose, " * Skipping atmospheric variables as no interest levels have been requested")
                    continue # we do not care about atmospheric variables if no levels are of interest

                # select interest variables and levels
                vars_interest_variables = [var for var in vars_ds.data_vars if var in interest_variables]
                if var_type == "atmospheric_ds":
                    vars_ds = vars_ds[vars_interest_variables].sel(level=interest_levels)
                else:
                    vars_ds = vars_ds[vars_interest_variables]

                # override time coordinates using the era5_base_frequency
                # NOTE(review): the first kept rollout step is labelled
                # init_time + warmup_steps * era5_base_frequency; confirm this
                # matches how the dataset defines metadata.time.
                vars_ds = vars_ds.assign_coords(
                    {"time": pd.date_range(init_time+warmup_steps*pd.Timedelta(era5_base_frequency), 
                                           periods=vars_ds.sizes["time"], 
                                           freq=era5_base_frequency)})

                # aggregate at eval_agg frequency
                # use pd.Timedelta to avoid xarray automatically starting the resampling 
                # on Mondays for weekly etc.
                # Note that resulting 'time' will be the first timestamp in the aggregated period
                vars_ds = vars_ds.resample(time=pd.Timedelta(eval_aggregation), origin=init_time).mean()
                # express the time axis as an offset from the initialisation time
                vars_ds = vars_ds.rename({"time": "lead_time"})
                vars_ds["lead_time"] = vars_ds["lead_time"] - np.datetime64(init_time)

                # per-variable processing: one output file per variable and init time
                for var in vars_ds.data_vars:
                    var_ds = vars_ds[var]

                    # save
                    filename = f"forecast_{var}_" + "-".join([
                        init_time.strftime("%Y%m%dT%H%M%S"),
                        str(era5_base_frequency),
                        str(eval_aggregation),
                        str(eval_start),
                        str(forecast_horizon),
                        str(var_ds.sizes["longitude"])+ "x" +str(var_ds.sizes["latitude"]),
                    ]) + ".nc"
                    path = os.path.join(output_dir, filename)
                    verbose_print(verbose, f"   * Saving new {var_type} forecast: {path}")
                    xr_to_netcdf(
                        var_ds, path, 
                        precision="float32", 
                        compression_level=1, 
                        sort_time=False, 
                        exist_ok=True
                    )

loading model ...


2024-10-16 09:25:08,324 - aurora_benchmark.utils - INFO - Evaluating on cuda
2024-10-16 09:25:08,506 - aurora_benchmark.utils - INFO - Rollout prediction on batch 0 ...


RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
