In [25]:
%load_ext autoreload
%autoreload 2

import numpy as np # type: ignore
import pandas as pd # type: ignore
import torch # type: ignore
from datetime import datetime # type: ignore
import xarray as xr # type: ignore
from xarray import DataArray, Dataset # type: ignore
from aurora import Batch, Metadata# type: ignore

from aurora_benchmark.data import (
    XRAuroraDataset, 
    aurora_batch_collate_fn, 
    aurora_batch_to_xr, 
    xr_to_aurora_batch,
    unpack_aurora_batch
)
    

# create a toy dataset
time = pd.date_range(start="2020-01-01", end="2020-04-30", freq="6h")
num_longitudes = 1440 // 4
num_latitudes = 721 // 4
num_levels = 2

surface_ds = Dataset(
    {
        "10u": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        "10v": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        "2t": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        "msl": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        #"sst": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
        #"tp": DataArray(np.random.rand(len(time), num_latitudes, num_longitudes), dims=["time", "latitude", "longitude"]),
    },
    coords={"time": time, "latitude": np.linspace(-90, 90, num_latitudes), "longitude": np.linspace(-180, 180, num_longitudes)}
)

atmospheric_ds = Dataset(
    {
        "t": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
        "u": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
        "v": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
        "q": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
        "z": DataArray(np.random.rand(len(time), num_levels, num_latitudes, num_longitudes), dims=["time", "level", "latitude", "longitude"]),
    },
    coords={"time": time, 
            "latitude": np.linspace(-90, 90, num_latitudes), 
            "longitude": np.linspace(0, 360, num_longitudes+1)[:-1],
            "level": [1000, 700]}
)

static_ds = Dataset(
    {
        "z": DataArray(np.random.rand(num_latitudes, num_longitudes), dims=["latitude", "longitude"]),
        "lsm": DataArray(np.random.rand(num_latitudes, num_longitudes), dims=["latitude", "longitude"]),
        "slt": DataArray(np.random.rand(num_latitudes, num_longitudes), dims=["latitude", "longitude"]),
    },
    coords={"latitude": np.linspace(-90, 90, num_latitudes), "longitude": np.linspace(-180, 180, num_longitudes)}
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
dataset = XRAuroraDataset(
    surface_ds=surface_ds,
    atmospheric_ds=atmospheric_ds,
    static_ds=static_ds,
    init_frequency="1D",
    forecast_horizon="2W",
    num_time_samples=2,
    pressure_levels=[1000, 700]
)

In [3]:
len(dataset)

106

In [4]:
# Ensure the first batch is equal to the dataset's input tensors on timesteps 1 and 2
bidx = 0
batch = dataset[bidx] 
for subvars, subds in [(batch.surf_vars, surface_ds), 
                        (batch.atmos_vars, atmospheric_ds)]:
    for svar, sdata in subvars.items():
        eqdata = torch.from_numpy(subds[svar].isel(time=slice(bidx, bidx+2)).values).unsqueeze(0)
        print(svar, eqdata.shape, sdata.shape)
        assert np.allclose(sdata, subds[svar].isel(time=slice(bidx, bidx+2)).values)

10u torch.Size([1, 2, 180, 360]) torch.Size([1, 2, 180, 360])
10v torch.Size([1, 2, 180, 360]) torch.Size([1, 2, 180, 360])
2t torch.Size([1, 2, 180, 360]) torch.Size([1, 2, 180, 360])
msl torch.Size([1, 2, 180, 360]) torch.Size([1, 2, 180, 360])
t torch.Size([1, 2, 2, 180, 360]) torch.Size([1, 2, 2, 180, 360])
u torch.Size([1, 2, 2, 180, 360]) torch.Size([1, 2, 2, 180, 360])
v torch.Size([1, 2, 2, 180, 360]) torch.Size([1, 2, 2, 180, 360])
q torch.Size([1, 2, 2, 180, 360]) torch.Size([1, 2, 2, 180, 360])
z torch.Size([1, 2, 2, 180, 360]) torch.Size([1, 2, 2, 180, 360])


In [5]:
collated_batch = aurora_batch_collate_fn([
    dataset[4],
    dataset[3],
    dataset[2],
    dataset[1]
])
for subvars in [collated_batch.surf_vars, collated_batch.atmos_vars, collated_batch.static_vars]:
    for varname, vartensor in subvars.items():
        print(varname, vartensor.shape)

10u torch.Size([4, 2, 180, 360])
10v torch.Size([4, 2, 180, 360])
2t torch.Size([4, 2, 180, 360])
msl torch.Size([4, 2, 180, 360])
t torch.Size([4, 2, 2, 180, 360])
u torch.Size([4, 2, 2, 180, 360])
v torch.Size([4, 2, 2, 180, 360])
q torch.Size([4, 2, 2, 180, 360])
z torch.Size([4, 2, 2, 180, 360])
z torch.Size([180, 360])
lsm torch.Size([180, 360])
slt torch.Size([180, 360])


In [6]:
collated_batch.metadata.time

(datetime.datetime(2020, 1, 5, 6, 0),
 datetime.datetime(2020, 1, 4, 6, 0),
 datetime.datetime(2020, 1, 3, 6, 0),
 datetime.datetime(2020, 1, 2, 6, 0))

In [7]:
# TEST DATALOADERS
eval_loader = torch.utils.data.DataLoader(dataset, batch_size=53, collate_fn=aurora_batch_collate_fn)

for i, batch in enumerate(eval_loader):
    print(i, batch.metadata.time)

0 (datetime.datetime(2020, 1, 1, 6, 0), datetime.datetime(2020, 1, 2, 6, 0), datetime.datetime(2020, 1, 3, 6, 0), datetime.datetime(2020, 1, 4, 6, 0), datetime.datetime(2020, 1, 5, 6, 0), datetime.datetime(2020, 1, 6, 6, 0), datetime.datetime(2020, 1, 7, 6, 0), datetime.datetime(2020, 1, 8, 6, 0), datetime.datetime(2020, 1, 9, 6, 0), datetime.datetime(2020, 1, 10, 6, 0), datetime.datetime(2020, 1, 11, 6, 0), datetime.datetime(2020, 1, 12, 6, 0), datetime.datetime(2020, 1, 13, 6, 0), datetime.datetime(2020, 1, 14, 6, 0), datetime.datetime(2020, 1, 15, 6, 0), datetime.datetime(2020, 1, 16, 6, 0), datetime.datetime(2020, 1, 17, 6, 0), datetime.datetime(2020, 1, 18, 6, 0), datetime.datetime(2020, 1, 19, 6, 0), datetime.datetime(2020, 1, 20, 6, 0), datetime.datetime(2020, 1, 21, 6, 0), datetime.datetime(2020, 1, 22, 6, 0), datetime.datetime(2020, 1, 23, 6, 0), datetime.datetime(2020, 1, 24, 6, 0), datetime.datetime(2020, 1, 25, 6, 0), datetime.datetime(2020, 1, 26, 6, 0), datetime.datetime(

In [265]:
# TEST FORWARD PASS
from aurora import Aurora, AuroraSmall, rollout

batch_size = 3

interest_variables = ["2t", "u", "v"]
interest_levels = [1000]

eval_start = "12h" # before that, this is the warmup period and we don't evaluate
eval_aggregation = "10d"

forecast_horizon = "36h"
frequency = "6h"
forecast_steps = pd.Timedelta(forecast_horizon) / pd.Timedelta(frequency)
warmup_steps = pd.Timedelta(eval_start) / pd.Timedelta(frequency) if eval_start is not None else 0.0

assert forecast_steps.is_integer(), f"forecast_horizon not a multiple of frequency"
forecast_steps = int(forecast_steps)

assert warmup_steps.is_integer(), f"eval_start not a multiple of frequency"
warmup_steps = int(warmup_steps)

print(f"Forecast steps: {forecast_steps} (including {warmup_steps} warmup steps)")

eval_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=aurora_batch_collate_fn)

print(f"Creating Aurora model ...")
# model = Aurora(use_lora=False)  # Model is not fine-tuned.
# model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
model = AuroraSmall()
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")

if torch.cuda.is_available(): device = "cuda"
else: device = "cpu"
model = model.to(device)
print(f"Computing on {device} ...")

xr_preds = {"surface_ds": [], "atmospheric_ds": []}
with torch.inference_mode() and torch.no_grad():
    for i, batch in enumerate(eval_loader):
        batch = batch.to(device)
        # rollout until for forecast_steps
        print(f"Rollout prediction on batch {i} ...")
        trajectories = [[] for _ in range(batch_size)]
        for s, batch_pred in enumerate(rollout(model, batch, steps=forecast_steps)):
            if s < warmup_steps:
                print(f" * Rollout step {s+1}: skipping warmup period")
                continue            
            # separate batched batches
            sub_batch_preds = unpack_aurora_batch(batch_pred.to("cpu"))
            print(f" * Rollout step {s+1}: unpacked {len(sub_batch_preds)} sub-batches")
            assert len(sub_batch_preds) == batch_size
            # accumulate
            for b, sub_batch_pred in enumerate(sub_batch_preds):
                trajectories[b].append(sub_batch_pred)
        print(f"Processing trajectories ...")
        # convert to xr 
        for init_time, trajectory in zip(batch.metadata.time, trajectories):
            print(f" * init_time={init_time}: combining {len(trajectory)} steps")
            assert len(trajectory) == forecast_steps-warmup_steps
            # collate trajectory batches
            trajectory = aurora_batch_collate_fn(trajectory)
            # convert to xr.Dataset
            trajectory = aurora_batch_to_xr(trajectory, frequency=frequency)
            # add lead time
            for var_type, vars_ds in trajectory.items():
                # ensure processing is necessary
                if var_type == "static_ds":
                    print(f" * Skipping static variables")
                    continue # we do not care about static variables for the forecast
                if not any([var in vars_ds.data_vars for var in interest_variables]):
                    print(f" * Skipping {var_type} variables as no interest variables are present")
                    continue # don't bother processing variables we are not interested in
                if var_type == "atmospheric_ds" and (interest_levels is None or len(interest_levels)==0):
                    print(f" * Skipping atmospheric variables as no interest levels have been requested")
                    continue # we do not care about atmospheric variables if no levels are of interest
                
                # select interest variables and levels
                vars_interest_variables = [var for var in vars_ds.data_vars if var in interest_variables]
                if var_type == "atmospheric_ds":
                    vars_ds = vars_ds[vars_interest_variables].sel(level=interest_levels)
                else:
                    vars_ds = vars_ds[vars_interest_variables]
                
                # add lead time
                vars_ds = vars_ds.assign_coords({"lead_time": vars_ds.time.values - np.datetime64(init_time)})
                vars_ds = vars_ds.set_index({"lead_time": "lead_time"})
                
                # add lead time
                vars_ds = vars_ds.assign_coords({"init_time": np.array([init_time for _ in vars_ds.time.values])})
                vars_ds = vars_ds.set_index({"init_time": "init_time"})
                
                # TODO: aggregate desired timesteps to agg freq  
                # vars_ds = vars_ds.resample(time=eval_aggregation).mean()  # not enough as it messes with "lead time"
                
                # append to predictions
                xr_preds[var_type].append(vars_ds)
            
        if i==1: break
        
for var_type, var_ds_list in xr_preds.items():
    xr_preds[var_type] = xr.concat(var_ds_list, dim="time")

Forecast steps: 6 (including 2 warmup steps)
Creating Aurora model ...
Computing on cpu ...
Rollout prediction on batch 0 ...
 * Rollout step 1: skipping warmup period
 * Rollout step 2: skipping warmup period
 * Rollout step 3: unpacked 3 sub-batches
 * Rollout step 4: unpacked 3 sub-batches


In [253]:
lts = np.unique(xr_preds["surface_ds"].lead_time.values).astype("timedelta64[h]")
for lt in lts:
    ltds = xr_preds["surface_ds"].sel(lead_time=lt)
    print(ltds.dims, ltds.data_vars)
    for var in ltds.data_vars:
        print(var, ltds[var].shape)
        

    2t       (time, latitude, longitude) float32 6MB 52.28 50.79 ... 69.17 74.33
2t (24, 180, 360)
    2t       (time, latitude, longitude) float32 6MB 52.28 50.79 ... 69.17 74.33
2t (24, 180, 360)
    2t       (time, latitude, longitude) float32 6MB 52.28 50.79 ... 69.17 74.33
2t (24, 180, 360)
    2t       (time, latitude, longitude) float32 6MB 52.28 50.79 ... 69.17 74.33
2t (24, 180, 360)


In [257]:
ds = xr_preds["surface_ds"]
ds.lead_time

In [264]:
# resample to 1 day frequency (make sure to always use the same init_time (i.e. some kind of groupby))


AttributeError: 'DatetimeAccessor' object has no attribute 'dt.weekofyear'