In [1]:
%%bash
cat > /content/.cdsapirc <<EOF
url: https://cds.climate.copernicus.eu/api/v2
key: 5979816a-75f6-457b-aed1-6a45adb32dad
EOF

In [2]:
!ls -la /content/.cdsapirc


-rw-r--r-- 1 root root 88 Apr 16 05:45 /content/.cdsapirc


# Predictions for ERA5 (one single day)

In this example, we will download ERA5 data for 1 Jan 2023 at 0.25 degrees resolution and run Aurora on this data. The fine-tuned version of Aurora specifically only works with IFS HRES T0, so we use the non-fine-tuned version of Aurora in this example.

Running this notebook requires additional Python packages. You can install these as follows:

```
pip install cdsapi matplotlib
```

## Downloading the Data

To begin with, register an account with the [Climate Data Store](https://cds.climate.copernicus.eu/) and create `$HOME/.cdsapirc` with the following content:

```
url: https://cds.climate.copernicus.eu/api
key: <API key>
```

You can find your API key on your account page.

In order to be able to download ERA5 data, you need to accept the terms of use in the [dataset page](https://cds.climate.copernicus.eu/datasets/reanalysis-era5-single-levels?tab=download).

We now download the ERA5 data.

In [3]:
# 安装所需的库：
!pip install gcsfs xarray netcdf4 cdsapi matplotlib --quiet

import xarray as xr
import gcsfs

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.3/9.3 MB[0m [31m47.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
# Mount Google Drive to access your files.
from google.colab import drive
drive.mount('/content/drive')

# Copy the .cdsapirc file from Google Drive to your home directory.
# Adjust the path below if your .cdsapirc file is stored in a different folder on Drive.
!cp /content/drive/MyDrive/heatwaves/.cdsapirc ~/

from pathlib import Path
import cdsapi
import datetime

# Set the directory on your Google Drive where the downloaded data will be saved.
# You can change "era5_aurora_data" to any folder name or path you prefer.
download_path = Path("/content/drive/MyDrive/heatwaves/era5_aurora_data")
download_path.mkdir(parents=True, exist_ok=True)

# Initialize the CDS API client.
# The client will load credentials from ~/.cdsapirc.
c = cdsapi.Client()

Mounted at /content/drive


2025-04-16 05:45:52,516 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.
INFO:datapi.legacy_api_client:[2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.


In [5]:
# Download the static variables.

if not (download_path / "static.nc").exists():
    c.retrieve(
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "geopotential",
                "land_sea_mask",
                "soil_type",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": "00:00",
            "format": "netcdf",
        },
        str(download_path / "static.nc"),
    )
print("Static variables downloaded!")

Static variables downloaded!


In [6]:
# Download the surface-level variables.
if not (download_path / "2023-01-01-surface-level.nc").exists():
    c.retrieve(
        "reanalysis-era5-single-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "2m_temperature",
                "10m_u_component_of_wind",
                "10m_v_component_of_wind",
                "mean_sea_level_pressure",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": ["00:00", "06:00", "12:00", "18:00"],
            "format": "netcdf",
        },
        str(download_path / "2023-01-01-surface-level.nc"),
    )
print("Surface-level variables downloaded!")

# Download the atmospheric variables.
if not (download_path / "2023-01-01-atmospheric.nc").exists():
    c.retrieve(
        "reanalysis-era5-pressure-levels",
        {
            "product_type": "reanalysis",
            "variable": [
                "temperature",
                "u_component_of_wind",
                "v_component_of_wind",
                "specific_humidity",
                "geopotential",
            ],
            "pressure_level": [
                "50",
                "100",
                "150",
                "200",
                "250",
                "300",
                "400",
                "500",
                "600",
                "700",
                "850",
                "925",
                "1000",
            ],
            "year": "2023",
            "month": "01",
            "day": "01",
            "time": ["00:00", "06:00", "12:00", "18:00"],
            "format": "netcdf",
        },
        str(download_path / "2023-01-01-atmospheric.nc"),
    )
print("Atmospheric variables downloaded!")

Surface-level variables downloaded!
Atmospheric variables downloaded!


## Preparing a Batch

We convert the downloaded data to an `aurora.Batch`, which is what the model requires.

In [7]:
pip install microsoft-aurora

Collecting microsoft-aurora
  Downloading microsoft_aurora-1.5.0-py3-none-any.whl.metadata (10 kB)
Collecting azure-storage-blob (from microsoft-aurora)
  Downloading azure_storage_blob-12.25.1-py3-none-any.whl.metadata (26 kB)
Collecting timm==0.6.13 (from microsoft-aurora)
  Downloading timm-0.6.13-py3-none-any.whl.metadata (38 kB)
Collecting azure-core>=1.30.0 (from azure-storage-blob->microsoft-aurora)
  Downloading azure_core-1.33.0-py3-none-any.whl.metadata (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.6/42.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting isodate>=0.6.1 (from azure-storage-blob->microsoft-aurora)
  Downloading isodate-0.7.2-py3-none-any.whl.metadata (11 kB)
Downloading microsoft_aurora-1.5.0-py3-none-any.whl (200 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.4/200.4 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading timm-0.6.13-py3-none-any.whl (549 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

In [8]:
import torch
import xarray as xr

from aurora import Batch, Metadata

# Open the downloaded netCDF files (stored on Google Drive).
static_vars_ds = xr.open_dataset(download_path / "static.nc", engine="netcdf4")
surf_vars_ds   = xr.open_dataset(download_path / "2023-01-01-surface-level.nc", engine="netcdf4")
atmos_vars_ds  = xr.open_dataset(download_path / "2023-01-01-atmospheric.nc", engine="netcdf4")

# Select the time index from the downloaded data. Here, we use index i=1.
i = 1

# Build the batch using the defined variables and metadata.
batch = Batch(
    surf_vars={
        # Use consecutive time steps (i-1 and i). The [None] adds a batch dimension.
        "2t": torch.from_numpy(surf_vars_ds["t2m"].values[[i - 1, i]][None]),
        "10u": torch.from_numpy(surf_vars_ds["u10"].values[[i - 1, i]][None]),
        "10v": torch.from_numpy(surf_vars_ds["v10"].values[[i - 1, i]][None]),
        "msl": torch.from_numpy(surf_vars_ds["msl"].values[[i - 1, i]][None]),
    },
    static_vars={
        # Since static variables don’t change with time, we use only the first time slice.
        "z": torch.from_numpy(static_vars_ds["z"].values[0]),
        "slt": torch.from_numpy(static_vars_ds["slt"].values[0]),
        "lsm": torch.from_numpy(static_vars_ds["lsm"].values[0]),
    },
    atmos_vars={
        # Atmospheric data also uses consecutive time steps.
        "t": torch.from_numpy(atmos_vars_ds["t"].values[[i - 1, i]][None]),
        "u": torch.from_numpy(atmos_vars_ds["u"].values[[i - 1, i]][None]),
        "v": torch.from_numpy(atmos_vars_ds["v"].values[[i - 1, i]][None]),
        "q": torch.from_numpy(atmos_vars_ds["q"].values[[i - 1, i]][None]),
        "z": torch.from_numpy(atmos_vars_ds["z"].values[[i - 1, i]][None]),
    },
    metadata=Metadata(
        # Metadata includes latitude, longitude, time, and atmospheric levels.
        lat=torch.from_numpy(surf_vars_ds.latitude.values),
        lon=torch.from_numpy(surf_vars_ds.longitude.values),
        # The time is converted from datetime64 to Python datetime objects.
        time=(surf_vars_ds.valid_time.values.astype("datetime64[s]").tolist()[i],),
        atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
    ),
)


## Loading and Running the Model

Finally, we are ready to load and run the model and visualise the predictions. We perform a roll-out for two steps, which produces predictions for hours 12:00 and 18:00.

In [9]:
import torch
print(torch.cuda.is_available())
print(torch.version.cuda)

False
None


In [12]:
#TPU- CUDA
import torch_xla.core.xla_model as xm
device = xm.xla_device()  # 获取 TPU 设备
model = model.to(device)

In [13]:
# If GPU, run
from aurora import Aurora, rollout


model = Aurora(use_lora=False)  # The pretrained version does not use LoRA.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

model.eval()
model = model.to("cuda")

with torch.inference_mode():
    preds = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]

model = model.to("cpu")

AssertionError: Torch not compiled with CUDA enabled

In [16]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"  # 尝试只预留 20% 的 TPU 内存

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()


from aurora import Aurora, rollout

model = Aurora(use_lora=False)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
model.eval()
model = model.to(device)


RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 4.00M. That was not possible. There are 16.0K free.; (0x0x0_HBM0)

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2, figsize=(12, 6.5))

for i in range(ax.shape[0]):
    pred = preds[i]

    ax[i, 0].imshow(pred.surf_vars["2t"][0, 0].numpy() - 273.15, vmin=-50, vmax=50)
    ax[i, 0].set_ylabel(str(pred.metadata.time[0]))
    if i == 0:
        ax[i, 0].set_title("Aurora Prediction")
    ax[i, 0].set_xticks([])
    ax[i, 0].set_yticks([])

    ax[i, 1].imshow(surf_vars_ds["t2m"][2 + i].values - 273.15, vmin=-50, vmax=50)
    if i == 0:
        ax[i, 1].set_title("ERA5")
    ax[i, 1].set_xticks([])
    ax[i, 1].set_yticks([])

plt.tight_layout()