In [1]:
import sys
sys.path.append('/home/bernatj/jbayesevt')

import numpy as np
from pathlib import Path
import xarray as xr
import os
import datetime
from earth2mip import inference_ensemble, registry
from earth2mip.networks import get_model
from earth2mip.initial_conditions import cds
from earth2mip.inference_ensemble import run_basic_inference
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from bayesevt._src.data.ics import LocalDataSourceXArray

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
model_name = "pangu_6" # "fcnv2_sm"
model = f"e2mip://{model_name}"
device = "cuda:1"

In [3]:
time_loop  = get_model(
    model=model,
    device=device,
)
channel_names = time_loop.in_channel_names

In [4]:
#config 

file_format = 'grib'
outputdir="/home/bernatj/Data/ai-forecasts/fcst/"

#we want to run our model for different dates
t0_i = datetime.datetime(2022,7,1,0)
t0_f = datetime.datetime(2022,7,31,18)
delta_h = 6
# number of forecast steps
num_steps = 4 * 15 # 6h intervals
#vars_to_save = ['t2m','t850','msl','tcwv','z500','u100m','v100m','u10m','v10m']
vars_to_save = ['t2m','z500']

init_times = []
current_time = t0_i
while current_time <= t0_f:
    init_times.append(current_time)
    current_time += datetime.timedelta(hours=delta_h)

In [5]:
for t0 in init_times:
    print(t0)

2022-07-01 00:00:00
2022-07-01 06:00:00
2022-07-01 12:00:00
2022-07-01 18:00:00
2022-07-02 00:00:00
2022-07-02 06:00:00
2022-07-02 12:00:00
2022-07-02 18:00:00
2022-07-03 00:00:00
2022-07-03 06:00:00
2022-07-03 12:00:00
2022-07-03 18:00:00
2022-07-04 00:00:00
2022-07-04 06:00:00
2022-07-04 12:00:00
2022-07-04 18:00:00
2022-07-05 00:00:00
2022-07-05 06:00:00
2022-07-05 12:00:00
2022-07-05 18:00:00
2022-07-06 00:00:00
2022-07-06 06:00:00
2022-07-06 12:00:00
2022-07-06 18:00:00
2022-07-07 00:00:00
2022-07-07 06:00:00
2022-07-07 12:00:00
2022-07-07 18:00:00
2022-07-08 00:00:00
2022-07-08 06:00:00
2022-07-08 12:00:00
2022-07-08 18:00:00
2022-07-09 00:00:00
2022-07-09 06:00:00
2022-07-09 12:00:00
2022-07-09 18:00:00
2022-07-10 00:00:00
2022-07-10 06:00:00
2022-07-10 12:00:00
2022-07-10 18:00:00
2022-07-11 00:00:00
2022-07-11 06:00:00
2022-07-11 12:00:00
2022-07-11 18:00:00
2022-07-12 00:00:00
2022-07-12 06:00:00
2022-07-12 12:00:00
2022-07-12 18:00:00
2022-07-13 00:00:00
2022-07-13 06:00:00


In [6]:
def do_forecast(channel_names, file_paths,  pressure_name='isobaricInhPa', engine='netcdf4'):
    
    #get the init data
    data_source_xr = LocalDataSourceXArray(
    channel_names=channel_names,
    file_paths=file_paths,
    pressure_name=pressure_name, 
    name_convention="short_name",
    engine=engine, 
    )

    #run the model
    forecast = run_basic_inference(
    time_loop, 
    n=num_steps, 
    data_source=data_source_xr, 
    time=t0
    )

    return forecast

In [7]:
if file_format == 'grib':
    ending='grib'
    engine='cfgrib'
elif file_format == 'netcdf':
    ending='nc'
    engine='netcdf4'

In [8]:
for t0 in init_times:

    yyyymmddhh = t0.strftime('%Y%m%d%H')
    
    file_paths = [  
    f"/home/bernatj/Data/ai-forecasts/input/{file_format}/{yyyymmddhh}/fcnv2_sl_PGW_multimodel_{yyyymmddhh}.{ending}",
    f"/home/bernatj/Data/ai-forecasts/input/{file_format}/{yyyymmddhh}/fcnv2_pl_PGW_multimodel_{yyyymmddhh}.{ending}"
    ]
    
    #run one forecast
    forecast = do_forecast(channel_names, file_paths, engine=engine)

    #store the data
    os.makedirs(outputdir+'/'+yyyymmddhh, exist_ok=True)
    for var in vars_to_save:
        forecast.sel(channel=var).squeeze().drop_vars('channel').to_dataset(name=var).to_netcdf(outputdir+f'{yyyymmddhh}/{var}_fcnv2_PGW_multimodel_{yyyymmddhh}.nc')

    print(f'finished forecast for init {t0}')

<xarray.DataArray 'u10' (channel: 73, latitude: 721, longitude: 1440)> Size: 303MB
dask.array<concatenate, shape=(73, 721, 1440), dtype=float32, chunksize=(1, 721, 1440), chunktype=numpy.ndarray>
Coordinates:
    number      int64 8B 0
    time        datetime64[ns] 8B 2022-07-01
    step        timedelta64[ns] 8B 00:00:00
    surface     float64 8B ...
  * latitude    (latitude) float64 6kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
    valid_time  datetime64[ns] 8B 2022-07-01
  * longitude   (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
    height      float64 8B 2.0
    dayofyear   int64 8B 182
  * channel     (channel) <U5 1kB 'u10m' 'v10m' 'u100m' ... 'r925' 'r1000'
Attributes: (12/30)
    GRIB_paramId:                             165
    GRIB_dataType:                            an
    GRIB_numberOfPoints:                      1038240
    GRIB_typeOfLevel:                         surface
    GRIB_stepUnits:                           1
    GRIB_stepType:            

FileNotFoundError: [Errno 2] No such file or directory: '/home/bernatj/Data/ai-forecasts/input/netcdf/2022073100/fcnv2_sl_PGW_multimodel_2022073100.nc'

In [17]:
for t0 in init_times:

    yyyymmddhh = t0.strftime('%Y%m%d%H')
    
    file_paths = [  
    f"/home/bernatj/Data/ai-forecasts/input/{file_format}/{yyyymmddhh}/fcnv2_sl_{yyyymmddhh}.{ending}",
    f"/home/bernatj/Data/ai-forecasts/input/{file_format}/{yyyymmddhh}/fcnv2_pl_{yyyymmddhh}.{ending}"
    ]
    
    #run one forecast
    forecast = do_forecast(channel_names, file_paths, engine=engine)

    #store the data
    os.makedirs(outputdir+'/'+yyyymmddhh, exist_ok=True)
    for var in vars_to_save:
        forecast.sel(channel=var).squeeze().drop_vars('channel').to_dataset(name=var).to_netcdf(outputdir+f'{yyyymmddhh}/{var}_fcnv2_{yyyymmddhh}.nc')

    print(f'finished forecast for init {t0}')

<xarray.DataArray 'u10' (channel: 73, latitude: 721, longitude: 1440)> Size: 303MB
dask.array<concatenate, shape=(73, 721, 1440), dtype=float32, chunksize=(1, 721, 1440), chunktype=numpy.ndarray>
Coordinates:
    number      int64 8B 0
    time        datetime64[ns] 8B 2023-10-25
    step        timedelta64[ns] 8B 00:00:00
    surface     float64 8B ...
  * latitude    (latitude) float64 6kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
  * longitude   (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
    valid_time  datetime64[ns] 8B 2023-10-25
  * channel     (channel) <U5 1kB 'u10m' 'v10m' 'u100m' ... 'r925' 'r1000'
Attributes: (12/30)
    GRIB_paramId:                             165
    GRIB_dataType:                            an
    GRIB_numberOfPoints:                      1038240
    GRIB_typeOfLevel:                         surface
    GRIB_stepUnits:                           1
    GRIB_stepType:                            instant
    ...                             