# Inference with FCNv2

## Setup

In [13]:
import numpy as np
import datetime
import os
import matplotlib.pyplot as plt
import json

from scipy.signal import periodogram

# Limit the run to one GPU via the WORLD_SIZE environment variable
os.environ["WORLD_SIZE"] = "1"
# Keep the model registry in a "models" folder located in the parent of the
# current working directory, creating the folder when it is missing
model_registry = os.path.join(
    os.path.dirname(os.path.realpath(os.getcwd())),
    "models",
)
os.makedirs(model_registry, exist_ok=True)
os.environ["MODEL_REGISTRY"] = model_registry
print(os.environ["MODEL_REGISTRY"])
# With the environment variables set, we can now import Earth-2 MIP
from earth2mip import registry, inference_ensemble
from earth2mip.initial_conditions import cds
from earth2mip.networks.fcnv2_sm import load as fcnv2_sm_load

/home/workspace/FCN/earth2mip/models


# Run Full Inference

In [2]:
import logging
import os

# The log file lives under ./logs, so create that folder up front
os.makedirs("logs", exist_ok=True)

# Root logging setup: DEBUG and above go to logs/update_netcdf.log, and the
# file is opened in "w" mode so each configuration starts from an empty file
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename=os.path.join("logs", "update_netcdf.log"),
    filemode="w",
)

# Handle on the root logger for use in later cells
logger = logging.getLogger()


In [3]:
from earth2mip.schema import EnsembleRun
from earth2mip.inference_ensemble import run_inference
import torch
import numpy as np
import earth2mip.initial_conditions
from earth2mip.schema import Grid, PerturbationStrategy

# Fetch the "fcnv2_sm" package from the registry, then build the network from it
package = registry.get_model("fcnv2_sm")
print("loading FCNv2 small model, this can take a bit")
model = fcnv2_sm_load(package)

# Optional inference settings reused by later cells
perturb = group = None  # None: no explicit perturbation / torch distributed group is supplied
progress = True  # Show progress bar


loading FCNv2 small model, this can take a bit


In [None]:
# Display the input channel names exposed by the loaded model
model.in_channel_names

In [14]:
from earth2mip.weather_events import WeatherEvent, WeatherEventProperties, Domain, Window, Diagnostic

# Disabled example: building the weather event programmatically instead of
# describing it inside the `config` dictionary. It is kept for reference only
# and does NOT run, so none of the names below (including `data_source`) are
# defined by this cell.
#
# # Define diagnostics
# diagnostic = Diagnostic(type="raw", channels=["u100", "u200"], nbins=10)
# data_source = cds.DataSource(model.in_channel_names)
#
# # Define domains
# window = Window(
#     name="global",
#     lat_min=-90,
#     lat_max=90,
#     lon_min=0,
#     lon_max=360,
#     diagnostics=[diagnostic]
# )
#
# # Define weather event properties
# weather_event_properties = WeatherEventProperties(
#     name="example_event7",
#     start_time=datetime.datetime(2023, 5, 21),
#     initial_condition_source='era5',
# )
#
# # Create WeatherEvent
# weather_event = WeatherEvent(
#     properties=weather_event_properties,
#     domains=[window]
# )

# Settings for the ensemble forecast (the fields of an EnsembleRun), built with
# keyword-style dict() calls; key order matches the serialized JSON order
config = dict(
    ensemble_members=4,
    noise_amplitude=0.05,
    simulation_length=10,
    weather_event=dict(
        properties=dict(
            name="Globe",
            start_time="2018-06-01 00:00:00",
            initial_condition_source="cds",
        ),
        domains=[
            dict(
                name="global",
                type="Window",
                diagnostics=[dict(type="raw", channels=["t2m", "u10m"])],
            )
        ],
    ),
    output_path="outputs/01_ensemble_notebook",
    output_frequency=1,
    weather_model="fcnv2_sm",
    seed=12345,
    use_cuda_graphs=False,
    ensemble_batch_size=1,
    autocast_fp16=False,
    perturbation_strategy="correlated",
    noise_reddening=2.0,
)

In [15]:
# Serialize the settings dictionary to a JSON string, the form handed to main() here
config_str = json.dumps(config)
# Launch the ensemble run; the output recorded below shows CDS requests being issued
inference_ensemble.main(config_str)

  warn("Distributed manager is already intialized")
2024-05-27 18:27:44,569 INFO Welcome to the CDS
2024-05-27 18:27:44,569 INFO Sending request to https://cds.climate.copernicus.eu/api/v2/resources/reanalysis-era5-single-levels
2024-05-27 18:27:44,656 INFO Request is queued
2024-05-27 18:27:44,765 INFO Welcome to the CDS
2024-05-27 18:27:44,766 INFO Sending request to https://cds.climate.copernicus.eu/api/v2/resources/reanalysis-era5-pressure-levels
2024-05-27 18:27:45,742 INFO Request is running
2024-05-27 18:27:45,925 INFO Request is queued
2024-05-27 18:27:47,323 INFO Request is running
2024-05-27 18:27:47,510 INFO Request is queued
2024-05-27 18:27:49,657 INFO Request is running
2024-05-27 18:27:49,836 INFO Request is queued
2024-05-27 18:27:53,120 INFO Request is running
2024-05-27 18:27:53,292 INFO Request is queued
2024-05-27 18:27:58,263 INFO Request is running
2024-05-27 18:27:58,443 INFO Request is queued
2024-05-27 18:28:05,952 INFO Request is running
2024-05-27 18:28:06,11

# Step 3: Call the run_inference function with the modified config
# No earlier cell defines `data_source` (it only appears inside a disabled,
# string-quoted snippet), so build the CDS data source here from the model's
# input channel names.
data_source = cds.DataSource(model.in_channel_names)
# NOTE(review): run_inference presumably expects an EnsembleRun object rather
# than a plain dict — confirm against the installed earth2mip version.
# Parsing the dict here also validates the settings before the run starts.
run_inference(
    model=model,
    config=EnsembleRun.parse_obj(config),
    perturb=perturb,
    group=group,
    progress=progress,
    data_source=data_source
)

# Process Output Data

In [16]:
import xarray
def open_ensemble(f, domain, chunks={"time": 1}):
    """Open one domain group of an ensemble output file.

    The group is opened with ``chunks``; the root dataset's attributes and
    its ``time`` coordinate are then attached to the result.

    Args:
        f: File to read, passed straight to ``xarray.open_dataset``.
        domain: Name of the group inside the file to open.
        chunks: Chunking specification forwarded to ``xarray.open_dataset``.
            The default dict is only read here, never modified.

    Returns:
        The group's dataset carrying the root attributes and the root
        ``time`` coordinate.
    """
    # Read the root-level metadata inside a context manager so the file
    # handle is released as soon as the values have been copied out.
    with xarray.open_dataset(f) as root:
        time = root.time.load()  # pull values into memory before the file closes
        root_attrs = dict(root.attrs)
    ds = xarray.open_dataset(f, chunks=chunks, group=domain)
    ds.attrs = root_attrs
    return ds.assign_coords(time=time)


# Pull the run settings needed to locate and read the ensemble output
ensemble_members = config["ensemble_members"]
output_path = config["output_path"]
domains = config["weather_event"]["domains"][0]["name"]  # name of the first configured domain

# Open the ensemble_out_0.nc file found under output_path and display it
ensemble_file = os.path.join(output_path, "ensemble_out_0.nc")
ds = open_ensemble(ensemble_file, domains)
ds

Unnamed: 0,Array,Chunk
Bytes,348.53 MiB,3.97 MiB
Shape,"(4, 11, 721, 1440)","(2, 1, 361, 720)"
Dask graph,88 chunks in 2 graph layers,88 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 348.53 MiB 3.97 MiB Shape (4, 11, 721, 1440) (2, 1, 361, 720) Dask graph 88 chunks in 2 graph layers Data type float64 numpy.ndarray",4  1  1440  721  11,

Unnamed: 0,Array,Chunk
Bytes,348.53 MiB,3.97 MiB
Shape,"(4, 11, 721, 1440)","(2, 1, 361, 720)"
Dask graph,88 chunks in 2 graph layers,88 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,348.53 MiB,3.97 MiB
Shape,"(4, 11, 721, 1440)","(2, 1, 361, 720)"
Dask graph,88 chunks in 2 graph layers,88 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 348.53 MiB 3.97 MiB Shape (4, 11, 721, 1440) (2, 1, 361, 720) Dask graph 88 chunks in 2 graph layers Data type float64 numpy.ndarray",4  1  1440  721  11,

Unnamed: 0,Array,Chunk
Bytes,348.53 MiB,3.97 MiB
Shape,"(4, 11, 721, 1440)","(2, 1, 361, 720)"
Dask graph,88 chunks in 2 graph layers,88 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
