In [1]:
%load_ext autoreload
%autoreload 2

import os, shutil
import xarray as xr
import torch
import dask
import pandas as pd
from torch.utils.data import DataLoader
import numpy as np
import logging

from aurora import Aurora, AuroraSmall, rollout

from aurora_benchmark.utils import verbose_print, xr_to_netcdf

from aurora_benchmark.data import (
    XRAuroraDataset, 
    XRAuroraBatchedDataset,
    aurora_batch_collate_fn, 
    aurora_batch_to_xr, 
    unpack_aurora_batch
)

# Configure the root logger so that INFO-level messages emitted by the
# benchmark helpers are shown in the notebook output.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Re-running this cell must not attach a second console handler to the root
# logger (that would print every log line multiple times). The handler is
# therefore given a name and reused when it is already attached.
_CONSOLE_HANDLER_NAME = "notebook_console"
console_handler = next(
    (h for h in logger.handlers if h.get_name() == _CONSOLE_HANDLER_NAME),
    None,
)
if console_handler is None:
    console_handler = logging.StreamHandler()
    console_handler.set_name(_CONSOLE_HANDLER_NAME)
    logger.addHandler(console_handler)

# (Re-)apply level and format so edits to this cell take effect on re-run.
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)

# Suppress logs from Google libraries
logging.getLogger('google').setLevel(logging.ERROR)
logging.getLogger('google.auth').setLevel(logging.ERROR)
logging.getLogger('google.cloud').setLevel(logging.ERROR)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Paths of the per-variable netCDF inputs (surface, atmospheric, static).
era5_surface_paths = [
    "../toy_data/era5-1d-360x180/msl-2021-2022-1d-360x180.nc",
    "../toy_data/era5-1d-360x180/t2m-2021-2022-1d-360x180.nc",
    "../toy_data/era5-1d-360x180/u10-2021-2022-1d-360x180.nc",
    "../toy_data/era5-1d-360x180/v10-2021-2022-1d-360x180.nc",
]
era5_atmospheric_paths = [
    "../toy_data/era5-1d-360x180/t-2021-2022-1d-360x180.nc",
    "../toy_data/era5-1d-360x180/q-2021-2022-1d-360x180.nc",
    "../toy_data/era5-1d-360x180/u-2021-2022-1d-360x180.nc",
    "../toy_data/era5-1d-360x180/v-2021-2022-1d-360x180.nc",
    "../toy_data/era5-1d-360x180/z-2021-2022-1d-360x180.nc",
]
era5_static_paths = [
    "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/lsm_static-1440x721.nc",
    "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/z_static-1440x721.nc",
    "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/slt_static-1440x721.nc",
]

# Renames applied after merging: short coordinate names -> long names, and the
# surface variable names -> the names used in `surf_vars` further down.
coord_renames = {"lat": "latitude", "lon": "longitude"}
surface_var_renames = {"t2m": "2t", "u10": "10u", "v10": "10v"}

# Surface fields: open each file, drop the "time_bnds" variable, then merge
# everything into one dataset sharing the same coordinates.
surface_dss = []
for path in era5_surface_paths:
    surface_dss.append(xr.open_dataset(path, engine="netcdf4").drop_vars("time_bnds"))
surface_ds = xr.merge(surface_dss).rename({**surface_var_renames, **coord_renames})

# Atmospheric fields: same treatment, only the coordinates need renaming.
atmospheric_dss = []
for path in era5_atmospheric_paths:
    atmospheric_dss.append(xr.open_dataset(path, engine="netcdf4").drop_vars("time_bnds"))
atmospheric_ds = xr.merge(atmospheric_dss).rename(coord_renames)

# Static fields are stored on a 1440x721 grid; block-average them by an integer
# factor per axis (trimming the remainder) to bring them down to 360x180.
lon_factor = 1440 // 360
lat_factor = 721 // 180
static_dss = []
for path in era5_static_paths:
    full_res = xr.open_dataset(path, engine="netcdf4")
    static_dss.append(
        full_res.coarsen(longitude=lon_factor, latitude=lat_factor, boundary="trim").mean()
    )
static_ds = xr.merge(static_dss)

# Show the dimensions of the three merged datasets as the cell output.
surface_ds.dims, atmospheric_ds.dims, static_ds.dims



In [3]:
# ---------------------------------------------------------------------------
# Evaluation configuration
# ---------------------------------------------------------------------------
batch_size = 14
num_workers = 2  # DataLoader workers; only used when use_dataloader=True
eval_start = "1w"  # rollout time skipped before evaluation starts (None = no warmup)
era5_base_frequency = "1d"  # time step between consecutive rollout steps
forecast_horizon = "10w"
use_dataloader = False
eval_aggregation = "1w"  # forecasts are averaged over windows of this length
init_frequency = "1w"
verbose = True
drop_timestamps = False
persist = False
rechunk = False
output_dir = "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2_forecasts/2021-2022-1d-1w-10w-360x180_original_variables/"

surf_vars = ["2t", "msl", "10u", "10v"]
atmospheric_vars = ["t", "q", "z", "u", "v"]
static_vars = ["z", "lsm", "slt"]

# Variables / pressure levels that are written to disk by the evaluation loop.
interest_variables = atmospheric_vars + surf_vars
interest_levels = [1000, 700, 250]

os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")

# Number of rollout steps to skip (warmup) and to run in total, both expressed
# in units of era5_base_frequency.
warmup_steps = int(pd.Timedelta(eval_start) / pd.Timedelta(era5_base_frequency)) if eval_start is not None else 0
forecast_steps = int(pd.Timedelta(forecast_horizon) / pd.Timedelta(era5_base_frequency))

assert (forecast_steps-warmup_steps) * pd.Timedelta(era5_base_frequency) >= pd.Timedelta(eval_aggregation), "Evaluation steps must be at least as long as eval_aggregation"

# The torch DataLoader path uses dask's synchronous scheduler, the batched
# dataset path uses the threaded one.
dask.config.set(scheduler='synchronous' if use_dataloader else 'threads')
verbose_print(verbose, f"Using dask scheduler: {dask.config.get('scheduler')}")

# Keyword arguments shared by both dataset flavours, defined once so the two
# code paths cannot drift apart.
dataset_kwargs = dict(
    surface_ds=surface_ds,
    atmospheric_ds=atmospheric_ds,
    static_ds=static_ds,
    init_frequency=init_frequency,
    forecast_horizon=forecast_horizon,
    num_time_samples=2, # Aurora has fixed history length of 2...
    drop_timestamps=drop_timestamps,
    persist=persist,
    rechunk=rechunk,
    atmospheric_variables=atmospheric_vars,
    surface_variables=surf_vars,
    static_variables=static_vars,
)

if use_dataloader:
    verbose_print(verbose, "Creating XRAuroraDataset and DataLoader...")
    dataset = XRAuroraDataset(**dataset_kwargs)
    verbose_print(verbose, f"Loaded dataset of length {len(dataset)} (drop_timestamps={drop_timestamps}, persist={persist}, rechunk={rechunk})")

    # `num_workers` comes from the configuration above and is no longer
    # overwritten here. A possible alternative is to derive it from
    # SLURM_CPUS_PER_TASK / os.cpu_count().
    verbose_print(verbose, f"Creating DataLoader with {num_workers} workers ...")
    eval_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=aurora_batch_collate_fn,
        num_workers=num_workers,
    )
    batch_iterator = eval_loader
else:
    # This is done to avoid the issue with torch DataLoader and dask
    # when using netcdf files (i.e. netcdf backend is not thread safe)
    verbose_print(verbose, "Creating XRAuroraBatchedDataset ...")
    dataset = XRAuroraBatchedDataset(batch_size=batch_size, **dataset_kwargs)
    batch_iterator = dataset

verbose_print(verbose, f"interest_vars: {interest_variables}, interest_levels: {interest_levels}")
verbose_print(verbose, f"Dataset length: {dataset.flat_length() if hasattr(dataset, 'flat_length') else len(dataset)}")
verbose_print(verbose, f"Dataloader length: {len(batch_iterator)} (type: {type(batch_iterator)}, batch_size: {batch_size})")

2024-10-16 16:27:10,940 - aurora_benchmark.utils - INFO - Using dask scheduler: threads
2024-10-16 16:27:10,940 - aurora_benchmark.utils - INFO - Creating XRAuroraBatchedDataset ...
2024-10-16 16:27:10,948 - aurora_benchmark.utils - INFO - interest_vars: ['t', 'q', 'z', 'u', '2t', 'msl', '10u', '10v'], interest_levels: [1000, 700, 250]
2024-10-16 16:27:10,948 - aurora_benchmark.utils - INFO - Dataset length: 3
2024-10-16 16:27:10,948 - aurora_benchmark.utils - INFO - Dataloader length: 3 (type: <class 'aurora_benchmark.data.XRAuroraBatchedDataset'>, batch_size: 14)


Output directory: /projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2_forecasts/2021-2022-1d-1w-10w-360x180_original_variables/


In [4]:
verbose = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("loading model ...")
aurora_model = "aurora-0.25-pretrained.ckpt"
model = Aurora(use_lora=False)
model.load_checkpoint("microsoft/aurora", aurora_model)
model = model.to(device)
# Put the model in evaluation mode so any train-only behaviour is disabled
# while forecasting.
model.eval()

verbose_print(verbose, f"Evaluating on {device}")
# Evaluation loop.
# Note: `with torch.inference_mode() and torch.no_grad():` would evaluate the
# `and` expression first and only enter `torch.no_grad()`. Inference mode
# already disables gradient tracking, so it is entered on its own.
with torch.inference_mode():
    for i, batch in enumerate(batch_iterator):
        verbose_print(verbose,f"Rollout prediction on batch {i} ...")
        if batch is None: break

        batch = batch.to(device)
        # One list of per-step predictions for every sample slot in the batch.
        # The last batch may hold fewer samples; the unused slots stay empty and
        # are dropped by the zip() with batch.metadata.time below.
        trajectories = [[] for _ in range(batch_size)]
        # roll out for forecast_steps steps
        for s, batch_pred in enumerate(rollout(model, batch, steps=forecast_steps)):
            if s < warmup_steps:
                verbose_print(verbose,f" * Rollout step {s+1}: skipping warmup period")
                continue
            # separate batched batches (moved to CPU to free device memory)
            sub_batch_preds = unpack_aurora_batch(batch_pred.to("cpu"))
            verbose_print(verbose,f" * Rollout step {s+1}: unpacked {len(sub_batch_preds)} sub-batches")
            if i != len(batch_iterator) - 1: # the last batch may not be full
                assert len(sub_batch_preds) == batch_size
            # accumulate
            for b, sub_batch_pred in enumerate(sub_batch_preds):
                trajectories[b].append(sub_batch_pred)

        # convert to xr and process
        verbose_print(verbose,"Processing trajectories ...")
        for init_time, trajectory in zip(batch.metadata.time, trajectories):
            verbose_print(verbose,f" * init_time={init_time}: combining {len(trajectory)} steps")
            assert len(trajectory) == forecast_steps-warmup_steps
            # collate trajectory batches
            trajectory = aurora_batch_collate_fn(trajectory)
            # convert to xr.Dataset
            trajectory = aurora_batch_to_xr(trajectory, frequency=era5_base_frequency)

            # process individual trajectory elements (i.e. variable types)
            for var_type, vars_ds in trajectory.items():
                # ensure processing is necessary
                if var_type == "static_ds":
                    verbose_print(verbose," * Skipping static variables")
                    continue # we do not care about static variables for the forecast
                if not any(var in vars_ds.data_vars for var in interest_variables):
                    verbose_print(verbose,f" * Skipping {var_type} variables as no interest variables are present")
                    continue # don't bother processing variables we are not interested in
                if var_type == "atmospheric_ds" and (interest_levels is None or len(interest_levels)==0):
                    verbose_print(verbose," * Skipping atmospheric variables as no interest levels have been requested")
                    continue # we do not care about atmospheric variables if no levels are of interest

                # select interest variables and levels
                vars_interest_variables = [var for var in vars_ds.data_vars if var in interest_variables]
                if var_type == "atmospheric_ds":
                    vars_ds = vars_ds[vars_interest_variables].sel(level=interest_levels)
                else:
                    vars_ds = vars_ds[vars_interest_variables]

                # override time coordinates using the era5_base_frequency
                # NOTE(review): the first retained step is labelled
                # init_time + warmup_steps * era5_base_frequency; confirm this
                # matches the valid time of rollout step `warmup_steps` (0-indexed).
                vars_ds = vars_ds.assign_coords(
                    {"time": pd.date_range(init_time+warmup_steps*pd.Timedelta(era5_base_frequency),
                                           periods=vars_ds.sizes["time"],
                                           freq=era5_base_frequency)})

                # aggregate at eval_agg frequency
                # use pd.Timedelta to avoid xarray automatically starting the resampling
                # on Mondays for weekly etc.
                # Note that resulting 'time' will be the first timestamp in the aggregated period
                vars_ds = vars_ds.resample(time=pd.Timedelta(eval_aggregation), origin=init_time).mean()
                # express the time axis as a lead time relative to init_time
                vars_ds = vars_ds.rename({"time": "lead_time"})
                vars_ds["lead_time"] = vars_ds["lead_time"] - np.datetime64(init_time)

                # per-variable processing: every variable goes to its own file
                for var in vars_ds.data_vars:
                    var_ds = vars_ds[var]

                    # save
                    path = f"forecast_{var}_" + "-".join([
                        init_time.strftime("%Y%m%dT%H%M%S"),
                        str(era5_base_frequency),
                        str(eval_aggregation),
                        str(eval_start),
                        str(forecast_horizon),
                        str(var_ds.sizes["longitude"])+ "x" +str(var_ds.sizes["latitude"]),
                    ]) + ".nc"
                    path = os.path.join(output_dir, path)
                    verbose_print(verbose, f"   * Saving new {var_type} forecast: {path}")
                    xr_to_netcdf(
                        var_ds, path,
                        precision="float32",
                        compression_level=1,
                        sort_time=False,
                        exist_ok=True
                    )

loading model ...


2024-10-16 16:27:28,855 - aurora_benchmark.utils - INFO - Evaluating on cuda
2024-10-16 16:27:29,540 - aurora_benchmark.utils - INFO - Rollout prediction on batch 0 ...
2024-10-16 16:27:32,614 - aurora_benchmark.utils - INFO -  * Rollout step 1: skipping warmup period
2024-10-16 16:27:35,381 - aurora_benchmark.utils - INFO -  * Rollout step 2: skipping warmup period
2024-10-16 16:27:38,146 - aurora_benchmark.utils - INFO -  * Rollout step 3: skipping warmup period
2024-10-16 16:27:40,909 - aurora_benchmark.utils - INFO -  * Rollout step 4: skipping warmup period
2024-10-16 16:27:43,675 - aurora_benchmark.utils - INFO -  * Rollout step 5: skipping warmup period
2024-10-16 16:27:46,455 - aurora_benchmark.utils - INFO -  * Rollout step 6: skipping warmup period
2024-10-16 16:27:49,220 - aurora_benchmark.utils - INFO -  * Rollout step 7: skipping warmup period
2024-10-16 16:27:52,012 - aurora_benchmark.utils - INFO -  * Rollout step 8: unpacked 14 sub-batches
2024-10-16 16:27:54,803 - auro

In [6]:
# Print the tensor shape of every variable in `batch` (the last batch left over
# from the evaluation loop), going through the surface, atmospheric and static
# groups in turn. The loop variable is called `var_group` so that the builtin
# `vars` is not shadowed in the notebook namespace.
for var_group in [batch.surf_vars, batch.atmos_vars, batch.static_vars]:
    for var, data in var_group.items():
        print(var, data.shape)

2t torch.Size([14, 2, 180, 360])
msl torch.Size([14, 2, 180, 360])
10u torch.Size([14, 2, 180, 360])
10v torch.Size([14, 2, 180, 360])
t torch.Size([14, 2, 13, 180, 360])
q torch.Size([14, 2, 13, 180, 360])
z torch.Size([14, 2, 13, 180, 360])
u torch.Size([14, 2, 13, 180, 360])
z torch.Size([180, 360])
lsm torch.Size([180, 360])
slt torch.Size([180, 360])


In [8]:
# Display the latitude values stored in the metadata of `batch` (the variable
# left over from the evaluation loop) as the cell output.
batch.metadata.lat

tensor([ 89.5000,  88.5000,  87.5000,  86.5000,  85.5000,  84.5000,  83.5000,
         82.5000,  81.5000,  80.5000,  79.5000,  78.5000,  77.5000,  76.5000,
         75.5000,  74.5000,  73.5000,  72.5000,  71.5000,  70.5000,  69.5000,
         68.5000,  67.5000,  66.5000,  65.5000,  64.5000,  63.5000,  62.5000,
         61.5000,  60.5000,  59.5000,  58.5000,  57.5000,  56.5000,  55.5000,
         54.5000,  53.5000,  52.5000,  51.5000,  50.5000,  49.5000,  48.5000,
         47.5000,  46.5000,  45.5000,  44.5000,  43.5000,  42.5000,  41.5000,
         40.5000,  39.5000,  38.5000,  37.5000,  36.5000,  35.5000,  34.5000,
         33.5000,  32.5000,  31.5000,  30.5000,  29.5000,  28.5000,  27.5000,
         26.5000,  25.5000,  24.5000,  23.5000,  22.5000,  21.5000,  20.5000,
         19.5000,  18.5000,  17.5000,  16.5000,  15.5000,  14.5000,  13.5000,
         12.5000,  11.5000,  10.5000,   9.5000,   8.5000,   7.5000,   6.5000,
          5.5000,   4.5000,   3.5000,   2.5000,   1.5000,   0.50