In [38]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np # type: ignore
import pandas as pd # type: ignore
import torch # type: ignore
from datetime import datetime # type: ignore
import xarray as xr # type: ignore
from xarray import DataArray, Dataset # type: ignore
from aurora import Batch, Metadata# type: ignore

from aurora_benchmark.data import (
    XRAuroraDataset, 
    aurora_batch_collate_fn, 
    aurora_batch_to_xr, 
    xr_to_aurora_batch,
    unpack_aurora_batch
)
    

# create a toy dataset
time = pd.date_range(start="2020-01-01", end="2020-04-30", freq="6h")
num_longitudes = 1440 // 4
num_latitudes = 721 // 4
num_levels = 2

surface_ds = Dataset(
    {
        "10u": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        "10v": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        "2t": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        "msl": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        #"sst": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        #"tp": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
    },
    coords={"time": time, "latitude": np.linspace(-90, 90, num_latitudes), "longitude": np.linspace(-180, 180, num_longitudes)}
)

atmospheric_ds = Dataset(
    {
        "t": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
        "u": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
        "v": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
        "q": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
        "z": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
    },
    coords={"time": time, 
            "latitude": np.linspace(-90, 90, num_latitudes), 
            "longitude": np.linspace(0, 360, num_longitudes+1)[:-1],
            "level": [1000, 700]}
)

static_ds = Dataset(
    {
        "z": DataArray(np.random.rand(num_latitudes, num_longitudes), dims=["latitude", "longitude"]),
        "lsm": DataArray(np.random.rand(num_latitudes, num_longitudes), dims=["latitude", "longitude"]),
        "slt": DataArray(np.random.rand(num_latitudes, num_longitudes), dims=["latitude", "longitude"]),
    },
    coords={"latitude": np.linspace(-90, 90, num_latitudes), "longitude": np.linspace(-180, 180, num_longitudes)}
)

# import yaml 

# with open("../configs/forecast_era5_wb2_2021-2022-6h-6w-1440x721_original_vars.yaml", "r") as f:
#     config = yaml.safe_load(f)
    
# era5_surface_paths = config['era5_surface_paths']
# era5_atmospheric_paths = config['era5_atmospheric_paths']
# era5_static_paths = config['era5_static_paths']
# interest_variables = config['interest_variables']
# interest_levels = config['interest_levels']
# output_dir = config['output_dir']
# batch_size = config.get('batch_size', 4)
# replacement_variables = config.get('replacement_variables', {})
# era5_base_frequency = config.get('era5_base_frequency', "6h")
# init_frequency = config.get('init_frequency', "1d")
# forecast_horizon = config.get('forecast_horizon', "6W")
# eval_aggregation = config.get('eval_aggregation', None)
# eval_start = config.get('eval_start', "1W")
# aurora_model = config.get('aurora_model', "aurora-0.25-pretrained.ckpt")
# device = config.get('device', "cuda" if torch.cuda.is_available() else "cpu")
# persist = config.get('persist', False)
# rechunk = config.get('rechunk', True)
# drop_timestamps = config.get('drop_timestamps', True)
# verbose = config.get('verbose', True)

# era5_surface_paths = [
#     os.path.join("/projects/prjs0981/ewalt/aurora_benchmark/", p)
#     for p in era5_surface_paths
# ]
# era5_atmospheric_paths = [
#     os.path.join("/projects/prjs0981/ewalt/aurora_benchmark/", p)
#     for p in era5_atmospheric_paths
# ]
# era5_static_paths = [
#     os.path.join("/projects/prjs0981/ewalt/aurora_benchmark/", p)
#     for p in era5_static_paths
# ]

# AURORA_VARAIBLE_RENAMES = {
#     "surface": {
#         "u10": "10u",
#         "v10": "10v",
#         "t2m": "2t",
#     },
#     "atmospheric": {},
#     "static": {},
# }

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [39]:
# surface_ds = xr.merge(
#     [xr.open_dataset(path, engine="netcdf4", 
#                         chunks={"time": 4*batch_size, "latitude": 721, "longitude": 1440}) 
#         for path in era5_surface_paths],
# ).rename(AURORA_VARAIBLE_RENAMES["surface"])
# atmospheric_ds = xr.merge(
#     [xr.open_dataset(path, engine="netcdf4",
#                         chunks={"time": 4*batch_size, "latitude": 721, "longitude": 1440, "level": 7}) 
#         for path in era5_atmospheric_paths],
# ).rename(AURORA_VARAIBLE_RENAMES["atmospheric"])
# static_ds = xr.merge(
#     [xr.open_dataset(path, engine="netcdf4")
#         for path in era5_static_paths],
# ).rename(AURORA_VARAIBLE_RENAMES["static"])

In [40]:
# subslice = slice("2021-01-01", "2021-04-30")
# surface_ds = surface_ds\
#         .sel(time=subslice)\
#         .isel(latitude=slice(None, None, 4), longitude=slice(None, None, 4))\
#         .compute() #.coarsen(latitude=4, longitude=4, boundary="trim").mean().compute()
# atmospheric_ds = atmospheric_ds\
#         .sel(time=subslice)\
#         .isel(latitude=slice(None, None, 4), longitude=slice(None, None, 4))\
#         .compute() #.coarsen(latitude=4, longitude=4, boundary="trim").mean().compute()

In [41]:
# surface_ds.to_netcdf("../tmp_data/surface_naive_coarse_20210101-20210430.nc")
# atmospheric_ds.to_netcdf("../tmp_data/atmospheric_naive_coarse_20210101-20210430.nc")

In [42]:
dataset = XRAuroraDataset(
    surface_ds=surface_ds,
    atmospheric_ds=atmospheric_ds,
    static_ds=static_ds,
    init_frequency="1D",
    forecast_horizon="2W",
    num_time_samples=2,
    pressure_levels=[1000, 700]
)

In [43]:
# # Ensure the first batch is equal to the dataset's input tensors on timesteps 1 and 2
# bidx = 0
# batch = dataset[bidx] 
# for subvars, subds in [(batch.surf_vars, surface_ds), 
#                         (batch.atmos_vars, atmospheric_ds)]:
#     for svar, sdata in subvars.items():
#         eqdata = torch.from_numpy(subds[svar].isel(time=slice(bidx, bidx+2)).values).unsqueeze(0)
#         print(svar, eqdata.shape, sdata.shape)
#         assert np.allclose(sdata, subds[svar].isel(time=slice(bidx, bidx+2)).values)

In [44]:
# collated_batch = aurora_batch_collate_fn([
#     dataset[4],
#     dataset[3],
#     dataset[2],
#     dataset[1]
# ])
# for subvars in [collated_batch.surf_vars, collated_batch.atmos_vars, collated_batch.static_vars]:
#     for varname, vartensor in subvars.items():
#         print(varname, vartensor.shape)

In [45]:
# TEST FORWARD PASS
from aurora import Aurora, AuroraSmall, rollout

batch_size = 3

interest_variables = ["2t", "u", "v"]
interest_levels = [1000]

eval_start = "12h" # before that, this is the warmup period and we don't evaluate
eval_aggregation = "10d"

forecast_horizon = "36h"
frequency = "6h"
forecast_steps = pd.Timedelta(forecast_horizon) / pd.Timedelta(frequency)
warmup_steps = pd.Timedelta(eval_start) / pd.Timedelta(frequency) if eval_start is not None else 0.0

assert forecast_steps.is_integer(), f"forecast_horizon not a multiple of frequency"
forecast_steps = int(forecast_steps)

assert warmup_steps.is_integer(), f"eval_start not a multiple of frequency"
warmup_steps = int(warmup_steps)

print(f"Forecast steps: {forecast_steps} (including {warmup_steps} warmup steps)")

eval_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=aurora_batch_collate_fn)

print(f"Creating Aurora model ...")
# model = Aurora(use_lora=False)  # Model is not fine-tuned.
# model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
model = AuroraSmall()
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")

if torch.cuda.is_available(): device = "cuda"
else: device = "cpu"
model = model.to(device)
print(f"Computing on {device} ...")

xr_preds = {"surface_ds": [], "atmospheric_ds": []}
with torch.inference_mode() and torch.no_grad():
    for i, batch in enumerate(eval_loader):
        batch = batch.to(device)
        # rollout until for forecast_steps
        print(f"Rollout prediction on batch {i} ...")
        trajectories = [[] for _ in range(batch_size)]
        for s, batch_pred in enumerate(rollout(model, batch, steps=forecast_steps)):
            if s < warmup_steps:
                print(f" * Rollout step {s+1}: skipping warmup period")
                continue            
            # separate batched batches
            sub_batch_preds = unpack_aurora_batch(batch_pred.to("cpu"))
            print(f" * Rollout step {s+1}: unpacked {len(sub_batch_preds)} sub-batches")
            assert len(sub_batch_preds) == batch_size
            # accumulate
            for b, sub_batch_pred in enumerate(sub_batch_preds):
                trajectories[b].append(sub_batch_pred)
        print(f"Processing trajectories ...")
        # convert to xr 
        for init_time, trajectory in zip(batch.metadata.time, trajectories):
            print(f" * init_time={init_time}: combining {len(trajectory)} steps")
            assert len(trajectory) == forecast_steps-warmup_steps
            # collate trajectory batches
            trajectory = aurora_batch_collate_fn(trajectory)
            # convert to xr.Dataset
            trajectory = aurora_batch_to_xr(trajectory, frequency=frequency)
            # add lead time
            for var_type, vars_ds in trajectory.items():
                # ensure processing is necessary
                if var_type == "static_ds":
                    print(f" * Skipping static variables")
                    continue # we do not care about static variables for the forecast
                if not any([var in vars_ds.data_vars for var in interest_variables]):
                    print(f" * Skipping {var_type} variables as no interest variables are present")
                    continue # don't bother processing variables we are not interested in
                if var_type == "atmospheric_ds" and (interest_levels is None or len(interest_levels)==0):
                    print(f" * Skipping atmospheric variables as no interest levels have been requested")
                    continue # we do not care about atmospheric variables if no levels are of interest
                
                # select interest variables and levels
                vars_interest_variables = [var for var in vars_ds.data_vars if var in interest_variables]
                if var_type == "atmospheric_ds":
                    vars_ds = vars_ds[vars_interest_variables].sel(level=interest_levels)
                else:
                    vars_ds = vars_ds[vars_interest_variables]
                
                # add lead time
                vars_ds = vars_ds.assign_coords({"lead_time": vars_ds.time.values - np.datetime64(init_time)})
                vars_ds = vars_ds.set_index({"lead_time": "lead_time"})
        
                
                # TODO: aggregate desired timesteps to agg freq  
                # vars_ds = vars_ds.resample(time=eval_aggregation).mean()  # not enough as it messes with "lead time"
                
                # append to predictions
                xr_preds[var_type].append(vars_ds)
            
        if i==1: break
        
for var_type, var_ds_list in xr_preds.items():
    xr_preds[var_type] = xr.concat(var_ds_list, dim="time")

Forecast steps: 6 (including 2 warmup steps)
Creating Aurora model ...
Computing on cuda ...
Rollout prediction on batch 0 ...
 * Rollout step 1: skipping warmup period
 * Rollout step 2: skipping warmup period
 * Rollout step 3: unpacked 3 sub-batches
 * Rollout step 4: unpacked 3 sub-batches
 * Rollout step 5: unpacked 3 sub-batches
 * Rollout step 6: unpacked 3 sub-batches
Processing trajectories ...
 * init_time=2020-01-01 06:00:00: combining 4 steps
 * Skipping static variables
 * init_time=2020-01-02 06:00:00: combining 4 steps
 * Skipping static variables
 * init_time=2020-01-03 06:00:00: combining 4 steps
 * Skipping static variables
Rollout prediction on batch 1 ...
 * Rollout step 1: skipping warmup period
 * Rollout step 2: skipping warmup period
 * Rollout step 3: unpacked 3 sub-batches
 * Rollout step 4: unpacked 3 sub-batches
 * Rollout step 5: unpacked 3 sub-batches
 * Rollout step 6: unpacked 3 sub-batches
Processing trajectories ...
 * init_time=2020-01-04 06:00:00: co

In [46]:
lts = np.unique(xr_preds["surface_ds"].lead_time.values).astype("timedelta64[h]")
for lt in lts:
    ltds = xr_preds["surface_ds"].sel(lead_time=lt)
    print(ltds.dims, ltds.data_vars)
    for var in ltds.data_vars:
        print(var, ltds[var].shape)
        

    2t       (time, latitude, longitude) float32 6MB 42.37 50.72 ... 76.05 75.5
2t (24, 180, 360)
    2t       (time, latitude, longitude) float32 6MB 42.37 50.72 ... 76.05 75.5
2t (24, 180, 360)
    2t       (time, latitude, longitude) float32 6MB 42.37 50.72 ... 76.05 75.5
2t (24, 180, 360)
    2t       (time, latitude, longitude) float32 6MB 42.37 50.72 ... 76.05 75.5
2t (24, 180, 360)


In [257]:
ds = xr_preds["surface_ds"]
ds.lead_time

In [47]:
era5_surface_paths

['/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1444x721/v10_2021-2022-6h-1440x721.nc',
 '/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1444x721/t2m_2021-2022-6h-1440x721.nc',
 '/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1444x721/u10_2021-2022-6h-1440x721.nc',
 '/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1444x721/msl_2021-2022-6h-1440x721.nc']

: 