In [1]:
%load_ext autoreload
%autoreload 2

import os, shutil
import xarray as xr
import torch
import dask
import pandas as pd
from torch.utils.data import DataLoader
import numpy as np
import logging
import dataclasses

from aurora import Batch

from aurora_benchmark.utils import verbose_print, xr_to_netcdf

from aurora_benchmark.parallel import AuroraBatchDataParallel, rollout, ParallelAurora, ParallelAuroraSmall
from aurora_benchmark.data import (
    XRAuroraDataset, 
    XRAuroraBatchedDataset,
    aurora_batch_collate_fn, 
    aurora_batch_to_xr, 
    unpack_aurora_batch
)

logger = logging.getLogger()
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Suppress logs from Google libraries
logging.getLogger('google').setLevel(logging.ERROR)
logging.getLogger('google.auth').setLevel(logging.ERROR)
logging.getLogger('google.cloud').setLevel(logging.ERROR)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
era5_surface_paths = [
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/10v_2021-2022-6h-1440x721.nc",
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/2t_2021-2022-6h-1440x721.nc",
#  - data/era5_wb2/2021-2022-6h-1440x721/tp_2021-2022-6h-1440x721.nc # original only
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/10u_2021-2022-6h-1440x721.nc",
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/msl_2021-2022-6h-1440x721.nc",
#  - data/era5_wb2/2021-2022-6h-1440x721/sst_2021-2022-6h-1440x721.nc # original only
]
era5_atmospheric_paths = [
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/q_2021-2022-6h-1440x721.nc",
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/t_2021-2022-6h-1440x721.nc",
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/u_2021-2022-6h-1440x721.nc",
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/v_2021-2022-6h-1440x721.nc",
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/z_2021-2022-6h-1440x721.nc",
]
era5_static_paths = [
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/lsm_static-1440x721.nc",
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/z_static-1440x721.nc",
"/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/slt_static-1440x721.nc",
]

forecast_dir = "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2_forecasts/2021-2022-6h-1d-6w-1440x721_original_variables/"
eval_dir = "../figures/era5_wb2_eval/2021-2022-6h-1d-6w-1440x721_original_variables"

# Load the data into a single dataset with the same coords but multiple variables
surface_ds = xr.merge([
    xr.open_dataset(p, engine="netcdf4", chunks={"time": 50, "latitude": 721, "longitude": 1440})
    for p in era5_surface_paths
])

atmospheric_ds = xr.merge([
    xr.open_dataset(p, engine="netcdf4", chunks={"time": 50, "latitude": 721, "longitude": 1440, "level": 1})
    for p in era5_atmospheric_paths
])

static_ds = xr.merge([
    xr.open_dataset(p, engine="netcdf4", chunks={"latitude": 721, "longitude": 1440})
    for p in era5_static_paths
])

surface_ds.dims, atmospheric_ds.dims, static_ds.dims





In [3]:
# era5_surface_paths = [
#     "../toy_data/era5-1d-360x180/msl-2021-2022-1d-360x180.nc",
#     "../toy_data/era5-1d-360x180/t2m-2021-2022-1d-360x180.nc",
#     "../toy_data/era5-1d-360x180/u10-2021-2022-1d-360x180.nc",
#     "../toy_data/era5-1d-360x180/v10-2021-2022-1d-360x180.nc",
# ]
# era5_atmospheric_paths = [ 
#     "../toy_data/era5-1d-360x180/t-2021-2022-1d-360x180.nc",
#     "../toy_data/era5-1d-360x180/q-2021-2022-1d-360x180.nc",
#     "../toy_data/era5-1d-360x180/u-2021-2022-1d-360x180.nc",
#     "../toy_data/era5-1d-360x180/v-2021-2022-1d-360x180.nc",
#     "../toy_data/era5-1d-360x180/z-2021-2022-1d-360x180.nc",
# ]
# era5_static_paths = [
#     "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/lsm_static-1440x721.nc",
#     "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/z_static-1440x721.nc",
#     "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2/2021-2022-6h-1440x721/slt_static-1440x721.nc",
# ]

# # Load the data into a single dataset with the same coords but multiple variables
# surface_dss = [
#     xr.open_dataset(path, engine="netcdf4").drop_vars("time_bnds")
#     #xr.open_dataset(path, engine="h5netcdf").rename({"msl": svar}).drop_vars("time_bnds")
#     #xr.open_zarr(path, chunks={"time": 50, "latitude": 180, "longitude": 360}).rename({"msl": svar})#.drop_vars("time_bnds")
#     for path in era5_surface_paths
# ]
# surface_ds = xr.merge(surface_dss).rename({"t2m": "2t", "u10": "10u", "v10": "10v", "lat": "latitude", "lon": "longitude"})
# atmospheric_dss = [
#     xr.open_dataset(path, engine="netcdf4").drop_vars("time_bnds")
#     #xr.open_dataset(path, engine="h5netcdf").rename({"msl": svar}).expand_dims({"level": [1000, 700, 250]}).drop_vars("time_bnds")
#     #xr.open_zarr(path, chunks={"time": 50, "latitude": 180, "longitude": 360, "level": 1}).rename({"msl": svar}).expand_dims({"level": [1000, 700, 250]})#.drop_vars("time_bnds")
#     for path in era5_atmospheric_paths
# ]
# atmospheric_ds = xr.merge(atmospheric_dss).rename({"lat": "latitude", "lon": "longitude"})
# static_dss = [
#      xr.open_dataset(path, engine="netcdf4").coarsen(longitude=1440//360, latitude=721//180, boundary="trim").mean()
#     # xr.open_dataset(path, engine="h5netcdf").rename({"msl": svar}).isel(time=0).drop_vars("time_bnds")
#     #xr.open_zarr(path, chunks={"latitude": 180, "longitude": 360}).rename({"msl": svar}).isel(time=0)#.drop_vars("time_bnds")
#     for path in era5_static_paths
# ]
# static_ds = xr.merge(static_dss)


# print("TESTING PARALLEL ON TOY DATA!!!!")
# surface_ds.dims, atmospheric_ds.dims, static_ds.dims

In [4]:
batch_size = 4
num_workers = 2
eval_start ="1w"
era5_base_frequency = "6h"
forecast_horizon = "6w"
use_dataloader = False
eval_aggregation = "1w"
init_frequency = "1w"
verbose = True
drop_timestamps = False
persist = False
rechunk = False
output_dir = "/projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2_forecasts/2021-2022-6h-1w-6w-1440x721_notebook/"

surf_vars = ["2t", "msl", "10u", "10v"]
atmospheric_vars = ["t", "q", "z", "u", "v"]
static_vars = ["z", "lsm", "slt"]

interest_variables = atmospheric_vars + surf_vars
interest_levels = [1000, 500, 250]

os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")

warmup_steps = int(pd.Timedelta(eval_start) / pd.Timedelta(era5_base_frequency)) if eval_start is not None else 0
forecast_steps = int(pd.Timedelta(forecast_horizon) / pd.Timedelta(era5_base_frequency))

assert (forecast_steps-warmup_steps) * pd.Timedelta(era5_base_frequency) >= pd.Timedelta(eval_aggregation), "Evaluation steps must be at least as long as eval_aggregation" 

if use_dataloader:
    dask.config.set(scheduler='synchronous')
else:
    dask.config.set(scheduler='threads')
verbose_print(verbose, f"Using dask scheduler: {dask.config.get('scheduler')}")


if use_dataloader:
    verbose_print(verbose, f"Creating XRAuroraDataset and DataLoader...")
    dataset = XRAuroraDataset(
        surface_ds=surface_ds,
        atmospheric_ds=atmospheric_ds,
        static_ds=static_ds,
        init_frequency=init_frequency,
        forecast_horizon=forecast_horizon,
        num_time_samples=2, # Aurora has fixed history length of 2...
        drop_timestamps=drop_timestamps,
        persist=persist,
        rechunk=rechunk,
        atmospheric_variables=atmospheric_vars,
        surface_variables=surf_vars,
        static_variables=static_vars,
    )
    verbose_print(verbose, f"Loaded dataset of length {len(dataset)} (drop_timestamps={drop_timestamps}, persist={persist}, rechunk={rechunk})")
    
    num_workers = 2 #int(os.getenv('SLURM_CPUS_PER_TASK', 1))+2 if os.getenv('SLURM_CPUS_PER_TASK') is not None else os.cpu_count()+2
    verbose_print(verbose, f"Creating DataLoader with {num_workers} workers ...")
    eval_loader = DataLoader(
        dataset, 
        batch_size=batch_size, 
        collate_fn=aurora_batch_collate_fn,
        num_workers=num_workers,
    )
    batch_iterator = eval_loader
else:
    # This is done to avoid the issue with torch DataLoader and dask
    # when using netcdf files (i.e. netcdf backend is not thread safe)
    verbose_print(verbose, f"Creating XRAuroraBatchedDataset ...")
    dataset = XRAuroraBatchedDataset(
        batch_size=batch_size,
        surface_ds=surface_ds,
        atmospheric_ds=atmospheric_ds,
        static_ds=static_ds,
        init_frequency=init_frequency,
        forecast_horizon=forecast_horizon,
        num_time_samples=2, # Aurora has fixed history length of 2...
        drop_timestamps=drop_timestamps,
        persist=persist,
        rechunk=rechunk,
        atmospheric_variables=atmospheric_vars,
        surface_variables=surf_vars,
        static_variables=static_vars,
    )
    batch_iterator = dataset

verbose_print(verbose, f"interest_vars: {interest_variables}, interest_levels: {interest_levels}")
verbose_print(verbose, f"Dataset length: {dataset.flat_length() if hasattr(dataset, 'flat_length') else len(dataset)}")
verbose_print(verbose, f"Dataloader length: {len(batch_iterator)} (type: {type(batch_iterator)}, batch_size: {batch_size})")

2024-10-17 10:06:13,460 - aurora_benchmark.utils - INFO - Using dask scheduler: threads
2024-10-17 10:06:13,461 - aurora_benchmark.utils - INFO - Creating XRAuroraBatchedDataset ...
2024-10-17 10:06:13,467 - aurora_benchmark.utils - INFO - interest_vars: ['t', 'q', 'z', 'u', 'v', '2t', 'msl', '10u', '10v'], interest_levels: [1000, 500, 250]
2024-10-17 10:06:13,468 - aurora_benchmark.utils - INFO - Dataset length: 12
2024-10-17 10:06:13,468 - aurora_benchmark.utils - INFO - Dataloader length: 12 (type: <class 'aurora_benchmark.data.XRAuroraBatchedDataset'>, batch_size: 4)


Output directory: /projects/prjs0981/ewalt/aurora_benchmark/data/era5_wb2_forecasts/2021-2022-6h-1w-6w-1440x721_notebook/


In [None]:
verbose = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


verbose_print(verbose, "loading model ...")
aurora_model = "aurora-0.25-pretrained.ckpt"
model = ParallelAurora(use_lora=False)
model.load_checkpoint("microsoft/aurora", aurora_model)
if torch.cuda.device_count() > 1:
    verbose_print(verbose, f"Using {torch.cuda.device_count()} GPUs")
    model = AuroraBatchDataParallel(model)
model = model.to(device)

verbose_print(verbose, f"Evaluating on {device}")
# evaluation loop
with torch.inference_mode() and torch.no_grad():
    for i, batch in enumerate(batch_iterator):
        verbose_print(verbose,f"Rollout prediction on batch {i} ...")
        if batch is None: break
        
        batch = batch.to(device)
        # rollout until for forecast_steps
        trajectories = [[] for _ in range(batch_size)]
        for s, batch_pred in enumerate(rollout(model, batch, steps=forecast_steps)):
            if s < warmup_steps:
                verbose_print(verbose,f" * Rollout step {s+1}: skipping warmup period")
                continue            
            # separate batched batches
            sub_batch_preds = unpack_aurora_batch(batch_pred.to("cpu"))
            verbose_print(verbose,f" * Rollout step {s+1}: unpacked {len(sub_batch_preds)} sub-batches")
            if i != len(batch_iterator) - 1: # the last batch may not be full
                assert len(sub_batch_preds) == batch_size
            # accumulate
            for b, sub_batch_pred in enumerate(sub_batch_preds):
                trajectories[b].append(sub_batch_pred)
        
        # convert to xr and process
        verbose_print(verbose,f"Processing trajectories ...")
        for init_time, trajectory in zip(batch.metadata.time, trajectories):
            verbose_print(verbose,f" * init_time={init_time}: combining {len(trajectory)} steps")
            assert len(trajectory) == forecast_steps-warmup_steps
            # collate trajectory batches
            trajectory = aurora_batch_collate_fn(trajectory)
            # convert to xr.Dataset
            trajectory = aurora_batch_to_xr(trajectory, frequency=era5_base_frequency)
            
            # process individual trajectory elements (i.e. variable types)
            for var_type, vars_ds in trajectory.items():
                # ensure processing is necessary
                if var_type == "static_ds":
                    verbose_print(verbose,f" * Skipping static variables")
                    continue # we do not care about static variables for the forecast
                if not any([var in vars_ds.data_vars for var in interest_variables]):
                    verbose_print(verbose,f" * Skipping {var_type} variables as no interest variables are present")
                    continue # don't bother processing variables we are not interested in
                if var_type == "atmospheric_ds" and (interest_levels is None or len(interest_levels)==0):
                    verbose_print(verbose,f" * Skipping atmospheric variables as no interest levels have been requested")
                    continue # we do not care about atmospheric variables if no levels are of interest
                
                # select interest variables and levels
                vars_interest_variables = [var for var in vars_ds.data_vars if var in interest_variables]
                if var_type == "atmospheric_ds":
                    vars_ds = vars_ds[vars_interest_variables].sel(level=interest_levels)
                else:
                    vars_ds = vars_ds[vars_interest_variables]
                    
                # override time coordinates using the era5_base_frequency
                vars_ds = vars_ds.assign_coords(
                    {"time": pd.date_range(init_time+warmup_steps*pd.Timedelta(era5_base_frequency), 
                                           periods=vars_ds.sizes["time"], 
                                           freq=era5_base_frequency)})
                
                # aggregate at eval_agg frequency
                # use pd.Timedelta to avoid xarray automatically starting the resampling 
                # on Mondays for weekly etc.
                # Note that resulting'time' will be the first timestamp in the aggregated period
                vars_ds = vars_ds.resample(time=pd.Timedelta(eval_aggregation), origin=init_time).mean()
                vars_ds = vars_ds.rename({"time": "lead_time"})
                vars_ds["lead_time"] = vars_ds["lead_time"] - np.datetime64(init_time)
                
                # per-variable processing
                for var in vars_ds.data_vars:
                    # add lead time
                    var_ds = vars_ds[var]
                    
                    # save
                    path = f"forecast_{var}_" + "-".join([
                        init_time.strftime("%Y%m%dT%H%M%S"),
                        str(era5_base_frequency),
                        str(eval_aggregation),
                        str(eval_start),
                        str(forecast_horizon),
                        str(var_ds.sizes["longitude"])+ "x" +str(var_ds.sizes["latitude"]),
                    ]) + ".nc"
                    path = os.path.join(output_dir, path)
                    verbose_print(verbose, f"   * Saving new {var_type} forecast: {path}")
                    xr_to_netcdf(
                        var_ds, path, 
                        precision="float32", 
                        compression_level=1, 
                        sort_time=False, 
                        exist_ok=True
                    )

2024-10-17 10:06:21,107 - aurora_benchmark.utils - INFO - loading model ...
2024-10-17 10:06:47,303 - aurora_benchmark.utils - INFO - Using 4 GPUs
2024-10-17 10:06:47,849 - aurora_benchmark.utils - INFO - Evaluating on cuda


In [3]:
aurora_model = "aurora-0.25-pretrained.ckpt"
model = ParallelAurora(use_lora=False)
model.load_checkpoint("microsoft/aurora", aurora_model)
if torch.cuda.device_count() > 1:
    model = AuroraBatchDataParallel(model)
model = model.to(device)

out = model.forward(batch)

NameError: name 'device' is not defined

In [34]:
type(out), type(out.surf_vars), out.surf_vars["2t"].shape, out.static_vars["z"].shape

(aurora.batch.Batch,
 dict,
 torch.Size([4, 1, 180, 360]),
 torch.Size([180, 360]))

In [35]:
out.metadata.rollout_step

1

In [21]:
surf_vars["2t"][0].shape, len(surf_vars["2t"])

(torch.Size([4, 2, 180, 360]), 3)

In [51]:
from functools import partial
from datetime import timedelta

class ParallelAurora(Aurora):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, batch: Batch) -> Batch:
        """Forward pass.

        Args:
            batch (:class:`Batch`): Batch to run the model on.

        Raises:
            ValueError: If no metric is provided.

        Returns:
            :class:`Batch`: Prediction for the batch.
        """
        
        # Ensure everything is on the same device and in the right format.
        device = next(iter(batch.surf_vars.values())).device
        batch = batch.type(torch.float32)
        batch = batch.normalise()
        batch = batch.crop(patch_size=self.patch_size)
        batch = batch.to(device)
        
        H, W = batch.spatial_shape
        patch_res = (
            self.encoder.latent_levels,
            H // self.encoder.patch_size,
            W // self.encoder.patch_size,
        )

        # Insert batch and history dimension for static variables.
        B, T = next(iter(batch.surf_vars.values())).shape[:2]
        batch = dataclasses.replace(
            batch,
            static_vars={k: v[None, None].repeat(B, T, 1, 1) for k, v in batch.static_vars.items()},
        )

        x = self.encoder(
            batch,
            lead_time=timedelta(hours=6),
        )
        x = self.backbone(
            x,
            lead_time=timedelta(hours=6),
            patch_res=patch_res,
            rollout_step=batch.metadata.rollout_step,
        )
        pred = self.decoder(
            x,
            batch,
            lead_time=timedelta(hours=6),
            patch_res=patch_res,
        )

        # Remove batch and history dimension from static variables.
        B, T = next(iter(batch.surf_vars.values()))[0]
        pred = dataclasses.replace(
            pred,
            static_vars={k: v[0, 0] for k, v in batch.static_vars.items()},
        )

        # Insert history dimension in prediction. The time should already be right.
        pred = dataclasses.replace(
            pred,
            surf_vars={k: v[:, None] for k, v in pred.surf_vars.items()},
            atmos_vars={k: v[:, None] for k, v in pred.atmos_vars.items()},
        )

        pred = pred.unnormalise()

        return pred
    
ParallelAuroraSmall = partial(
    ParallelAurora,
    encoder_depths=(2, 6, 2),
    encoder_num_heads=(4, 8, 16),
    decoder_depths=(2, 6, 2),
    decoder_num_heads=(16, 8, 4),
    embed_dim=256,
    num_heads=8,
    use_lora=False,
)

In [52]:
model = ParallelAuroraSmall()#use_lora=False)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-small-pretrained.ckpt")#)
model = AuroraBatchDataParallel(model)
model = model.to("cuda")
model.forward(batch)

OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/ewalt/.local/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_1699939/3131655674.py", line 52, in forward
    pred = self.decoder(
           ^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/aurora/model/decoder.py", line 162, in forward
    x_atmos = self.deaggregate_levels(levels_embed, x[..., 1:, :])  # (B, L, C_A, D)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/aurora/model/decoder.py", line 104, in deaggregate_levels
    x = self.level_decoder(level_embed, x)  # (BxL, C, D)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ewalt/.local/lib/python3.11/site-packages/aurora/model/perceiver.py", line 213, in forward
    latents = attn_out + latents if self.residual_latent else attn_out
              ~~~~~~~~~^~~~~~~~~
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.61 GiB. GPU 
