In [1]:
%load_ext autoreload
%autoreload 2

import os, shutil
import xarray as xr
import torch
import dask
import pandas as pd
from torch.utils.data import DataLoader
import numpy as np
import logging

from aurora import Aurora, AuroraSmall, rollout

from aurora_benchmark.utils import verbose_print, xr_to_netcdf

# Select dask's thread-based scheduler for the computations triggered in this notebook.
dask.config.set(scheduler='threads')

from aurora_benchmark.data import (
    XRAuroraDataset, 
    aurora_batch_collate_fn, 
    aurora_batch_to_xr, 
    unpack_aurora_batch
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Toy ERA5 mean-sea-level-pressure file and the sibling stores sharing its stem.
p = "../toy_data/era5-1d-360x180/msl-2021-2022-1d-360x180.nc"
ph5 = "../toy_data/era5-1d-360x180/msl-2021-2022-1d-360x180.h5"
pzarr = "../toy_data/era5-1d-360x180/msl-2021-2022-1d-360x180.zarr"

# Start from a clean slate: any previously written Zarr store is removed first.
zarr_store_present = os.path.exists(pzarr)
if zarr_store_present:
    shutil.rmtree(pzarr)

# Read the NetCDF source, discard the `time_bnds` variable and write the rest as Zarr.
d = xr.open_dataset(p, engine="netcdf4").drop_vars("time_bnds")
d.to_zarr(pzarr)

<xarray.backends.zarr.ZarrStore at 0x150f3849d240>

In [3]:
# Every path points at the same single-variable (msl) toy store; its variable is
# renamed below to stand in for each surface / atmospheric / static field.
era5_surface_paths = [
    pzarr, pzarr, pzarr
]
era5_atmospheric_paths = [ # we will repeat on level dimension
    pzarr, pzarr, pzarr
]
era5_static_paths = [ # we will select the first time step
    pzarr, pzarr, pzarr
]

surf_vars = ["2t", "msl", "10u"]
atmospheric_vars = ["t", "q", "u"]
static_vars = ["z", "lsm", "slt"]


def _open_msl_as(path, name, chunks):
    """Lazily open the Zarr store at `path` and expose its `msl` variable as `name`.

    `chunks` is forwarded to `xr.open_zarr`, so its keys must be the dimension
    names stored on disk.
    """
    return xr.open_zarr(path, chunks=chunks).rename({"msl": name})


# Chunk keys use the on-disk dimension names (lat/lon). The rename to
# latitude/longitude only happens after opening, so those names would not match
# any dimension at open time and the requested spatial chunking would not apply.
# `level` is likewise absent from the store (it is created by expand_dims below),
# so it is not a usable chunk key here.
time_space_chunks = {"time": 50, "lat": 180, "lon": 360}
space_chunks = {"lat": 180, "lon": 360}

# Load the data into a single dataset with the same coords but multiple variables
surface_dss = [
    _open_msl_as(path, svar, time_space_chunks)
    for path, svar in zip(era5_surface_paths, surf_vars)
]
surface_ds = xr.merge(surface_dss)

# Atmospheric stand-ins: the same 2D field broadcast over three pressure levels.
atmospheric_dss = [
    _open_msl_as(path, svar, time_space_chunks).expand_dims({"level": [1000, 700, 250]})
    for path, svar in zip(era5_atmospheric_paths, atmospheric_vars)
]
atmospheric_ds = xr.merge(atmospheric_dss)

# Static stand-ins: only the first time step is kept.
static_dss = [
    _open_msl_as(path, svar, space_chunks).isel(time=0)
    for path, svar in zip(era5_static_paths, static_vars)
]
static_ds = xr.merge(static_dss)

# rename coord lat to latitude and lon to longitude
surface_ds = surface_ds.rename({"lat": "latitude", "lon": "longitude"})
atmospheric_ds = atmospheric_ds.rename({"lat": "latitude", "lon": "longitude"})
static_ds = static_ds.rename({"lat": "latitude", "lon": "longitude"})

surface_ds.dims, atmospheric_ds.dims, static_ds.dims





In [24]:
# --- evaluation settings -------------------------------------------------
batch_size = 32
num_workers = 2
eval_start ="3d"
era5_base_frequency = "1d"
forecast_horizon = "1w"

# Express the warm-up window and the forecast horizon as numbers of
# base-frequency steps; no warm-up is applied when `eval_start` is None.
if eval_start is None:
    warmup_steps = 0
else:
    warmup_steps = int(pd.Timedelta(eval_start) / pd.Timedelta(era5_base_frequency))
forecast_steps = int(
    pd.Timedelta(forecast_horizon) / pd.Timedelta(era5_base_frequency)
)

# --- dataset over the three variable groups --------------------------------
dataset = XRAuroraDataset(
    surface_ds=surface_ds,
    surface_variables=surf_vars,
    atmospheric_ds=atmospheric_ds,
    atmospheric_variables=atmospheric_vars,
    static_ds=static_ds,
    static_variables=static_vars,
    init_frequency="1d",
    forecast_horizon=forecast_horizon,
    num_time_samples=2,  # Aurora has fixed history length of 2...
    drop_timestamps=True,
    persist=False,
    rechunk=False,
)

# --- batched loader using the project's collate function -------------------
eval_loader = DataLoader(
    dataset,
    collate_fn=aurora_batch_collate_fn,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=False,
)

print("Dataset length:", len(dataset))
print("Dataloader length:", len(eval_loader))

Dataset length: 356
Dataloader length: 12




In [None]:
# Per-group variable renames (old name -> new name). Only the surface group has
# entries; they map to the short names used in `surf_vars` (e.g. "u10" -> "10u").
AURORA_VARIABLE_RENAMES = {
    "surface": {"u10": "10u", "v10": "10v", "t2m": "2t"},
    "atmospheric": {},
    "static": {},
}

# Reverse lookup (new name -> old name), built group by group from the table above.
INVERTED_AURORA_VARIABLE_RENAMES = {
    group: {new_name: old_name for old_name, new_name in renames.items()}
    for group, renames in AURORA_VARIABLE_RENAMES.items()
}

# Variables and pressure levels to keep from the predictions.
interest_variables = atmospheric_vars + surf_vars
interest_levels = [1000, 700, 250]


verbose = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("loading model ...")
# The model is loaded here (rather than relying on a `model` left over in the
# kernel) so that the rollout below also works after Restart & Run All.
aurora_model = "aurora-0.25-pretrained.ckpt"
model = Aurora(use_lora=False)
model.load_checkpoint("microsoft/aurora", aurora_model)
model = model.to(device)
model.eval()  # evaluation only: switch the module out of training mode

print(f"Evaluating on {device}")
# evaluation loop
# Collects, per variable group, one xr.Dataset per forecast initialisation time.
xr_preds = {"surface_ds": [], "atmospheric_ds": []}
# `torch.inference_mode() and torch.no_grad()` evaluates to its right operand only,
# so inference mode was never entered. inference_mode alone also disables
# gradient tracking.
with torch.inference_mode():
    for i, batch in enumerate(dataset):
        batch = batch.to(device)
        init_times = batch.metadata.time
        # Size the bookkeeping from the batch itself rather than from the global
        # `batch_size`: batches yielded here need not contain `batch_size` samples
        # (and a DataLoader's final batch may be smaller as well).
        n_samples = len(init_times)
        # rollout until for forecast_steps
        print(f"Rollout prediction on batch {i} ...")
        trajectories = [[] for _ in range(n_samples)]
        for s, batch_pred in enumerate(rollout(model, batch, steps=forecast_steps)):
            if s < warmup_steps:
                print(f" * Rollout step {s+1}: skipping warmup period")
                continue
            # separate batched batches
            sub_batch_preds = unpack_aurora_batch(batch_pred.to("cpu"))
            print(f" * Rollout step {s+1}: unpacked {len(sub_batch_preds)} sub-batches")
            assert len(sub_batch_preds) == n_samples
            # accumulate
            for b, sub_batch_pred in enumerate(sub_batch_preds):
                trajectories[b].append(sub_batch_pred)
        print(f"Processing trajectories ...")
        # convert to xr
        for init_time, trajectory in zip(init_times, trajectories):
            print(f" * init_time={init_time}: combining {len(trajectory)} steps")
            assert len(trajectory) == forecast_steps-warmup_steps
            # collate trajectory batches
            trajectory = aurora_batch_collate_fn(trajectory)
            # convert to xr.Dataset
            trajectory = aurora_batch_to_xr(trajectory, frequency=era5_base_frequency)
            # add lead time
            for var_type, vars_ds in trajectory.items():
                # ensure processing is necessary
                if var_type == "static_ds":
                    print(f" * Skipping static variables")
                    continue # we do not care about static variables for the forecast
                if not any(var in vars_ds.data_vars for var in interest_variables):
                    print(f" * Skipping {var_type} variables as no interest variables are present")
                    continue # don't bother processing variables we are not interested in
                if var_type == "atmospheric_ds" and (interest_levels is None or len(interest_levels)==0):
                    print(f" * Skipping atmospheric variables as no interest levels have been requested")
                    continue # we do not care about atmospheric variables if no levels are of interest

                # select interest variables and levels
                vars_interest_variables = [var for var in vars_ds.data_vars if var in interest_variables]
                if var_type == "atmospheric_ds":
                    vars_ds = vars_ds[vars_interest_variables].sel(level=interest_levels)
                else:
                    vars_ds = vars_ds[vars_interest_variables]

                # add lead time (offset of each predicted time from the initialisation time)
                # NOTE(review): the array is assigned without naming a dimension, so
                # `lead_time` presumably becomes a dimension of its own rather than a
                # coordinate along `time` — confirm this is the intended layout.
                vars_ds = vars_ds.assign_coords({"lead_time": vars_ds.time.values - np.datetime64(init_time)})
                vars_ds = vars_ds.set_index({"lead_time": "lead_time"})

                # TODO: aggregate desired timesteps to agg freq
                # vars_ds = vars_ds.resample(time=eval_aggregation).mean()  # not enough as it messes with "lead time"

                # append to predictions
                xr_preds[var_type].append(vars_ds)

# merge predictions (one concatenated dataset per variable group)
# TODO: save the merged datasets; nothing is written to disk yet.
merged_preds = {}
for var_type, var_ds_list in xr_preds.items():
    if not var_ds_list:
        # nothing was collected for this group; xr.concat cannot take an empty list
        continue
    ds = xr.concat(var_ds_list, dim="time")
    # The rename table is keyed by "surface"/"atmospheric", while the prediction
    # groups carry a "_ds" suffix.
    group = var_type[:-len("_ds")] if var_type.endswith("_ds") else var_type
    # Only rename variables that are actually present in this dataset.
    renames = {
        old: new
        for old, new in INVERTED_AURORA_VARIABLE_RENAMES.get(group, {}).items()
        if old in ds.data_vars
    }
    merged_preds[var_type] = ds.rename(renames)

loading model ...
Evaluating on cuda
Rollout prediction on batch 0 ...
 * Rollout step 1: skipping warmup period
 * Rollout step 2: skipping warmup period
 * Rollout step 3: skipping warmup period
 * Rollout step 4: unpacked 1 sub-batches
 * Rollout step 5: unpacked 1 sub-batches
 * Rollout step 6: unpacked 1 sub-batches
 * Rollout step 7: unpacked 1 sub-batches
Processing trajectories ...
 * init_time=2021-01-02 09:00:00: combining 4 steps
 * Skipping static variables
Rollout prediction on batch 1 ...
 * Rollout step 1: skipping warmup period
 * Rollout step 2: skipping warmup period
 * Rollout step 3: skipping warmup period
 * Rollout step 4: unpacked 1 sub-batches
 * Rollout step 5: unpacked 1 sub-batches
 * Rollout step 6: unpacked 1 sub-batches
 * Rollout step 7: unpacked 1 sub-batches
Processing trajectories ...
 * init_time=2021-01-03 09:00:00: combining 4 steps
 * Skipping static variables
Rollout prediction on batch 2 ...
 * Rollout step 1: skipping warmup period
 * Rollout st

In [23]:
# Debug: a two-element list holding the same `batch` object twice. `batch` is
# whatever the evaluation loop last left in the kernel.
batches = [batch] * 2
# Show the first timestamp stored in the first entry's metadata.
batches[0].metadata.time[0]

datetime.datetime(2021, 1, 2, 9, 0)

In [32]:
# Debug: show the metadata time of the first unpacked sub-batch. Relies on
# `sub_batch_preds` left in the kernel by the evaluation loop cell.
sub_batch_preds[0].metadata.time

datetime.datetime(2021, 1, 4, 3, 0)