In [1]:
import sys
sys.path.append('/home/bernatj/jbayesevt')

import numpy as np
from pathlib import Path
import xarray as xr
import os
import datetime
from earth2mip import inference_ensemble, registry
from earth2mip.networks import get_model
from earth2mip.initial_conditions import cds
from earth2mip.inference_ensemble import run_basic_inference
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from bayesevt._src.data.ics import LocalDataSourceXArray

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# earth2mip network to run (an alternative is "fcn") and the GPU to run it on.
model_name = "fcnv2_sm"
model = "e2mip://" + model_name
device = "cuda:1"

In [3]:
# Load the network on the chosen device; its input channel list is later
# handed to the local data source so the initial conditions match the model.
time_loop = get_model(model=model, device=device)
channel_names = time_loop.in_channel_names

In [4]:
# --- Run configuration --------------------------------------------------------

# Input file format: 'netcdf' or 'grib' (selects file ending and xarray engine).
file_format = 'netcdf'
outputdir = "/home/bernatj/Data/ai-forecasts/fcst/"

# Forecasts are initialised every `delta_h` hours from t0_i to t0_f (inclusive).
t0_i = datetime.datetime(2018, 9, 9, 0)
t0_f = datetime.datetime(2018, 9, 12, 18)
delta_h = 6
# Forecast length in model steps: 4 six-hourly steps per day for 15 days.
num_steps = 4 * 15

init_times = []
current_time = t0_i
while not current_time > t0_f:
    init_times += [current_time]
    current_time = current_time + datetime.timedelta(hours=delta_h)

# Channels written to disk. Full set used previously:
# ['t2m', 'tcwv', 'msl', 'v100m', 'u100m', 'u850', 'v850', 't850', 'z500']
vars_to_save = ['tcwv']

# PGW (pseudo-global-warming) source models whose perturbed initial conditions are forecast.
models = [
    'awi-cm-1-1-mr', 'bcc-csm2-mr', 'cams-csm1-0', 'canesm5-1',
    'cas-esm2-0', 'cmcc-cm2-hr4', 'cmcc-cm2-sr5', 'cmcc-esm2',
    'ec-earth3-cc', 'ec-earth3-veg-lr', 'ec-earth3-veg', 'ec-earth3',
]

In [5]:
# Echo every initialisation time that will be forecast.
for t0 in list(init_times):
    print(f"{t0}")

2018-09-09 00:00:00
2018-09-09 06:00:00
2018-09-09 12:00:00
2018-09-09 18:00:00
2018-09-10 00:00:00
2018-09-10 06:00:00
2018-09-10 12:00:00
2018-09-10 18:00:00
2018-09-11 00:00:00
2018-09-11 06:00:00
2018-09-11 12:00:00
2018-09-11 18:00:00
2018-09-12 00:00:00
2018-09-12 06:00:00
2018-09-12 12:00:00
2018-09-12 18:00:00


In [6]:
def do_forecast(channel_names, file_paths, pressure_name='isobaricInhPa', engine='netcdf4',
                init_time=None, n_steps=None, network=None):
    """Run one forecast from local initial-condition files.

    Parameters
    ----------
    channel_names : list of str
        Channels the network expects (in this notebook, ``time_loop.in_channel_names``).
    file_paths : list of str
        Initial-condition files passed to ``LocalDataSourceXArray``.
    pressure_name : str
        Name of the pressure-level coordinate in the input files.
    engine : str
        xarray backend used to open ``file_paths`` (e.g. 'netcdf4' or 'cfgrib').
    init_time : datetime.datetime, optional
        Initialisation time. Defaults to the notebook-global ``t0`` (the
        previous implicit behaviour). Pass it explicitly so the result does
        not depend on whichever loop last bound ``t0``.
    n_steps : int, optional
        Number of forecast steps. Defaults to the global ``num_steps``.
    network : optional
        Time loop to run. Defaults to the global ``time_loop``.

    Returns
    -------
    The object returned by ``run_basic_inference``. The caller uses it with
    ``.sel(channel=...)`` and ``.to_dataset(...)``, so presumably an xarray
    DataArray (TODO confirm).
    """
    # Fall back to notebook globals so existing calls keep working unchanged.
    if init_time is None:
        init_time = t0
    if n_steps is None:
        n_steps = num_steps
    if network is None:
        network = time_loop

    # Data source that reads the initial conditions for init_time.
    data_source_xr = LocalDataSourceXArray(
        channel_names=channel_names,
        file_paths=file_paths,
        pressure_name=pressure_name,
        name_convention="short_name",
        engine=engine,
    )

    # Roll the network forward n_steps from init_time.
    return run_basic_inference(
        network,
        n=n_steps,
        data_source=data_source_xr,
        time=init_time,
    )

In [7]:
# Map the configured input format to its file ending and xarray engine.
# An unknown format fails here instead of later surfacing as a NameError
# (or silently reusing `ending`/`engine` left over from a previous run).
if file_format == 'grib':
    ending, engine = 'grib', 'cfgrib'
elif file_format == 'netcdf':
    ending, engine = 'nc', 'netcdf4'
else:
    raise ValueError(f"Unsupported file_format {file_format!r}; expected 'grib' or 'netcdf'")

In [8]:
# Run one forecast per (PGW source model, initialisation time) and save each
# selected channel to its own NetCDF file.
# The loop variable is `cmip_model` rather than `model` so it does not
# overwrite the earth2mip model URI from the config cell (re-running the
# get_model cell afterwards would otherwise try to load e.g. 'ec-earth3').
# `t0` keeps its name: do_forecast is called without an explicit time and
# uses the global `t0`.
input_root = "/home/bernatj/Data/ai-forecasts/input"
for cmip_model in models:
    for t0 in init_times:

        yyyymmddhh = t0.strftime('%Y%m%d%H')

        # single-level (sl) and pressure-level (pl) initial conditions
        input_dir = os.path.join(input_root, file_format, yyyymmddhh)
        file_paths = [
            os.path.join(input_dir, f"fcnv2_sl_PGW_{cmip_model}_{yyyymmddhh}.{ending}"),
            os.path.join(input_dir, f"fcnv2_pl_PGW_{cmip_model}_{yyyymmddhh}.{ending}"),
        ]

        # run one forecast
        forecast = do_forecast(channel_names, file_paths, engine=engine)

        # Store the data. The same joined directory is used for makedirs and
        # to_netcdf, so the write target exists whether or not outputdir ends in '/'.
        fcst_dir = os.path.join(outputdir, yyyymmddhh)
        os.makedirs(fcst_dir, exist_ok=True)
        for var in vars_to_save:
            out_path = os.path.join(fcst_dir, f'{var}_fcnv2_PGW_{cmip_model}_{yyyymmddhh}.nc')
            (forecast.sel(channel=var)
                     .squeeze()
                     .drop_vars('channel')
                     .to_dataset(name=var)
                     .to_netcdf(out_path))

        print(f'finished forecast for init {t0}')

<xarray.DataArray 'u10' (channel: 73, latitude: 721, longitude: 1440)> Size: 303MB
dask.array<concatenate, shape=(73, 721, 1440), dtype=float32, chunksize=(1, 721, 1440), chunktype=numpy.ndarray>
Coordinates:
    number      int64 8B 0
    time        datetime64[ns] 8B 2018-09-09
    step        timedelta64[ns] 8B 00:00:00
    surface     float64 8B ...
  * latitude    (latitude) float64 6kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
    valid_time  datetime64[ns] 8B 2018-09-09
  * longitude   (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
    height      float64 8B 2.0
    dayofyear   int64 8B 252
  * channel     (channel) <U5 1kB 'u10m' 'v10m' 'u100m' ... 'r925' 'r1000'
Attributes: (12/30)
    GRIB_paramId:                             165
    GRIB_dataType:                            an
    GRIB_numberOfPoints:                      1038240
    GRIB_typeOfLevel:                         surface
    GRIB_stepUnits:                           1
    GRIB_stepType:            