# Jax-gcm (physics = optional)

In [1]:
import xarray as xr
import numpy as np
import jax.numpy as jnp
import jax_datetime as jdt

from jcm.model import Model
from jcm.physics_interface import PhysicsState, Physics


W0218 19:07:46.085108    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:07:46.214886    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:07:46.264314    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


In [2]:
from jcm.geometry import Geometry
from jcm.model import Model
import numpy as np

# Geometry
TERRAIN = "/home/gmathieu/src/jax-gcm/jcm/data/bc/t30/clim/terrain.nc"
geom = Geometry.from_file(TERRAIN, target_resolution=31, num_levels=8)

# Model (sans physics)
m = Model(geometry=geom, physics=None)

# Extraire les axes nodaux horizontaux
lon_rad = np.array(m.coords.horizontal.nodal_axes[0])
lat_rad = np.array(m.coords.horizontal.nodal_axes[1])

lon_deg = lon_rad * 180/np.pi
lat_deg = lat_rad * 180/np.pi

print("lon_deg:", lon_deg.shape, lon_deg.min(), lon_deg.max())
print("lat_deg:", lat_deg.shape, lat_deg.min(), lat_deg.max())


W0218 19:07:46.459844    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:07:46.715868    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:07:46.904284    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:07:46.955438    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:07:47.011464    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:07:47.063133    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:07:47.110500    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. 

lon_deg: (96,) 0.0 356.25
lat_deg: (48,) -57.22536341559415 57.22536341559415


In [3]:
import xarray as xr
import numpy as np

ERA5_PATH = "/home/gmathieu/links/scratch/era5_snapshot/era5_19900501T00.nc"
ds_era = xr.open_dataset(ERA5_PATH, chunks=None)

lat_name = "latitude" if "latitude" in ds_era.dims else "lat"
lon_name = "longitude" if "longitude" in ds_era.dims else "lon"

era = ds_era

# mettre lon en 0..360 (pour matcher lon_deg 0..356.25)
if float(era[lon_name].min()) < 0:
    era = era.assign_coords({lon_name: (era[lon_name] % 360)}).sortby(lon_name)

# lat croissantes
if float(era[lat_name][0]) > float(era[lat_name][-1]):
    era = era.sortby(lat_name)

wanted = [
    "u_component_of_wind",
    "v_component_of_wind",
    "temperature",
    "specific_humidity",
    "surface_pressure",
    "geopotential_at_surface",
]

missing = [v for v in wanted if v not in era.data_vars]
print("missing:", missing)

ds_in = era[wanted].interp(
    {
        lon_name: xr.DataArray(lon_deg, dims=(lon_name,)),
        lat_name: xr.DataArray(lat_deg, dims=(lat_name,)),
    },
    method="linear",
)

print("temperature regridded:", ds_in["temperature"].dims, ds_in["temperature"].shape)
print("surface_pressure regridded:", ds_in["surface_pressure"].dims, ds_in["surface_pressure"].shape)
print("geopotential regridded:", ds_in["geopotential_at_surface"].dims, ds_in["geopotential_at_surface"].shape)



missing: []
temperature regridded: ('hybrid', 'latitude', 'longitude') (137, 48, 96)
surface_pressure regridded: ('latitude', 'longitude') (48, 96)
geopotential regridded: ('latitude', 'longitude') (48, 96)


In [4]:
print(ds_era)
print("data_vars:", list(ds_era.data_vars)[:50])
print("coords:", list(ds_era.coords))


<xarray.Dataset> Size: 3GB
Dimensions:                              (latitude: 721, longitude: 1440,
                                          hybrid: 137)
Coordinates:
  * latitude                             (latitude) float32 3kB 90.0 ... -90.0
  * longitude                            (longitude) float32 6kB 0.0 ... 359.8
  * hybrid                               (hybrid) float32 548B 1.0 2.0 ... 137.0
    time                                 datetime64[ns] 8B ...
Data variables:
    surface_pressure                     (latitude, longitude) float32 4MB ...
    geopotential_at_surface              (latitude, longitude) float32 4MB ...
    u_component_of_wind                  (hybrid, latitude, longitude) float32 569MB ...
    v_component_of_wind                  (hybrid, latitude, longitude) float32 569MB ...
    temperature                          (hybrid, latitude, longitude) float32 569MB ...
    specific_humidity                    (hybrid, latitude, longitude) float32 569MB ...

In [5]:
import xarray as xr
import pandas as pd
import numpy as np

path_full = "/scratch/gmathieu/pred_dino_full_48h_6h.nc"
ds_full = xr.open_dataset(path_full)

model_times = pd.to_datetime(ds_full["time"].values)
print("Model time[0] =", model_times[0])
print("Model time[-1] =", model_times[-1])
print("Ntimes =", len(model_times))
print(model_times)

# check: 0..48h step 6h
dt = (model_times[1:] - model_times[:-1]).to_series().value_counts()
print("Δt counts:\n", dt)



Model time[0] = 2000-01-01 00:00:00
Model time[-1] = 2000-01-03 00:00:00
Ntimes = 9
DatetimeIndex(['2000-01-01 00:00:00', '2000-01-01 06:00:00',
               '2000-01-01 12:00:00', '2000-01-01 18:00:00',
               '2000-01-02 00:00:00', '2000-01-02 06:00:00',
               '2000-01-02 12:00:00', '2000-01-02 18:00:00',
               '2000-01-03 00:00:00'],
              dtype='datetime64[ns]', freq=None)
Δt counts:
 0 days 06:00:00    8
Name: count, dtype: int64


In [6]:
import glob
import xarray as xr
import pandas as pd

# dossier des snapshots ERA5 physiques (1990)
era_dir = "/scratch/gmathieu/era5_snapshots_48h_6h_1990"

files = sorted(glob.glob(f"{era_dir}/era5_*.nc"))
print("N ERA5 files:", len(files))
for f in files[:3]:
    print(" ", f)

# concaténation temporelle (les fichiers ont un time scalaire)
ds_era = xr.open_mfdataset(
    files,
    combine="nested",
    concat_dim="time",
    engine="scipy",   # important car tu as écrit NETCDF3
)

print(ds_era)


N ERA5 files: 9
  /scratch/gmathieu/era5_snapshots_48h_6h_1990/era5_19900501T00.nc
  /scratch/gmathieu/era5_snapshots_48h_6h_1990/era5_19900501T06.nc
  /scratch/gmathieu/era5_snapshots_48h_6h_1990/era5_19900501T12.nc
<xarray.Dataset> Size: 31GB
Dimensions:                              (time: 9, latitude: 721,
                                          longitude: 1440, hybrid: 137)
Coordinates:
  * time                                 (time) datetime64[ns] 72B 1990-05-01...
  * latitude                             (latitude) float32 3kB 90.0 ... -90.0
  * longitude                            (longitude) float32 6kB 0.0 ... 359.8
  * hybrid                               (hybrid) float32 548B 1.0 2.0 ... 137.0
Data variables:
    surface_pressure                     (time, latitude, longitude) float32 37MB dask.array<chunksize=(1, 721, 1440), meta=np.ndarray>
    geopotential_at_surface              (time, latitude, longitude) float32 37MB dask.array<chunksize=(1, 721, 1440), meta=np.ndarr

In [7]:
import numpy as np
import xarray as xr

ds_full = xr.open_dataset("/scratch/gmathieu/pred_dino_full_48h_6h.nc")
ds_dry  = xr.open_dataset("/scratch/gmathieu/pred_dino_only_48h_6h.nc")

era0 = ds_era.isel(time=0)

# ERA5 latitude est souvent décroissante → interp aime mieux croissante
era0_sorted = era0.sortby("latitude")

# interp ERA5 -> grille modèle (attention aux noms de coords)
era_on_model = era0_sorted.interp(
    latitude=ds_full["lat"],
    longitude=ds_full["lon"],
)

# choisir un niveau près-surface ERA5 (IMPORTANT)
# On évite le piège: on détecte si hybrid=1 est surface ou le contraire en regardant la pression moyenne
sp = era_on_model["surface_pressure"]
# approx: le niveau le plus proche surface est celui qui a la PLUS GRANDE pression,
# mais comme on n'a pas p(lev) ici, on fait un choix simple: hybrid=137 est souvent le plus bas sur ERA5.
# => on teste les deux extrêmes et on voit lequel colle aux IC modèle.
lev_candidates = [0, -1]

for lev in lev_candidates:
    T_era = era_on_model["temperature"].isel(hybrid=lev)
    U_era = era_on_model["u_component_of_wind"].isel(hybrid=lev)
    V_era = era_on_model["v_component_of_wind"].isel(hybrid=lev)

    Tm = ds_full["temperature"].isel(time=0, level=0)  # ton choix modèle "near surface"
    Um = ds_full["u_wind"].isel(time=0, level=0)
    Vm = ds_full["v_wind"].isel(time=0, level=0)

    def stats(name, a, b):
        d = (a - b).values
        return name, float(np.nanmean(d)), float(np.nanmax(np.abs(d)))

    print("\n--- ERA5 hybrid index", lev, "---")
    print(stats("T full-ERA", Tm, T_era))
    print(stats("U full-ERA", Um, U_era))
    print(stats("V full-ERA", Vm, V_era))




--- ERA5 hybrid index 0 ---
('T full-ERA', 97.030133873377, 120.07072822037955)
('U full-ERA', -1.8899037084328707, 88.17793343475488)
('V full-ERA', 1.3212530756297551, 43.59059217422228)

--- ERA5 hybrid index -1 ---
('T full-ERA', 7.262376960679911, 58.52748883549745)
('U full-ERA', 0.9234774889629586, 33.24108472937911)
('V full-ERA', -0.45585774967845083, 23.293091631494203)


In [11]:
# interpolation ERA5 -> grille modèle
import xarray as xr

# 1) prends le snapshot ERA5 t0
era0 = ds_era.isel(time=0)

# 2) ERA5 latitude décroissante -> on trie pour interp
if "latitude" in era0.coords:
    era0 = era0.sortby("latitude")

# 3) interp ERA5 -> grille du modèle (lat/lon)
era_interp = era0.interp(
    latitude=ds_full["lat"],
    longitude=ds_full["lon"],
)

print("Done interp:", era_interp.dims)


# exemple temperature lowest level
import numpy as np

mod0 = ds_full["temperature"].isel(time=0, level=0)
era0 = era_interp["temperature"].isel(hybrid=-1)   # near-surface ERA5

diff = (mod0 - era0).values
print("IC diff stats (T, level0 vs ERA5 hybrid=-1):")
print("mean =", float(np.nanmean(diff)))
print("max  =", float(np.nanmax(np.abs(diff))))



IC diff stats (T, level0 vs ERA5 hybrid=-1):
mean = 7.262376960679911
max  = 58.52748883549745


In [12]:
import jax.numpy as jnp

nlev_model = geom.nodal_shape[0]  # 8
nlev_era = ds_in.sizes["hybrid"]  # 137

idx = np.linspace(0, nlev_era - 1, nlev_model).round().astype(int)
print("Selected ERA5 hybrid indices:", idx)

ds8 = ds_in.isel(hybrid=idx)  # garde (8,48,96) en (hybrid,lat,lon)

# Helper: (lev,lat,lon) -> (lev,lon,lat) pour matcher geom.nodal_shape (8,96,48)
def lev_lon_lat(da):
    return jnp.asarray(da.transpose("hybrid", lat_name, lon_name).data).transpose(0,2,1)

u_wind = lev_lon_lat(ds8["u_component_of_wind"])
v_wind = lev_lon_lat(ds8["v_component_of_wind"])
temperature = lev_lon_lat(ds8["temperature"])
specific_humidity = lev_lon_lat(ds8["specific_humidity"])

# 2D champs: (lat,lon) -> (lon,lat)
def lon_lat_2d(da):
    return jnp.asarray(da.transpose(lat_name, lon_name).data).T

surface_pressure = lon_lat_2d(ds8["surface_pressure"])
geopotential = lon_lat_2d(ds8["geopotential_at_surface"])

print("u_wind", u_wind.shape)
print("temperature", temperature.shape)
print("surface_pressure", surface_pressure.shape)
print("geopotential", geopotential.shape)
print("geom.nodal_shape", geom.nodal_shape)


Selected ERA5 hybrid indices: [  0  19  39  58  78  97 117 136]


W0218 19:15:12.148374    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


u_wind (8, 96, 48)
temperature (8, 96, 48)
surface_pressure (96, 48)
geopotential (96, 48)
geom.nodal_shape (8, 96, 48)


W0218 19:15:12.484204    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


In [13]:
import jcm.physics_interface as pi

P0 = 1e5
normalized_surface_pressure = (surface_pressure / P0)[None, ...]  # (1,96,48)

print("geopotential:", geopotential.shape)
print("normalized_surface_pressure:", normalized_surface_pressure.shape)

initial_physics_state = pi.PhysicsState(
    u_wind=u_wind,
    v_wind=v_wind,
    temperature=temperature,
    specific_humidity=specific_humidity,
    geopotential=geopotential,
    normalized_surface_pressure=normalized_surface_pressure,
)
print("PhysicsState OK (with nsp 3D).")


geopotential: (96, 48)
normalized_surface_pressure: (1, 96, 48)
PhysicsState OK (with nsp 3D).


W0218 19:15:13.597850    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:13.650037    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


In [14]:
import xarray as xr
import numpy as np

IC_OUT = "IC_USED.nc"

ds_ic = xr.Dataset(
    data_vars=dict(
        u_wind=(("level","longitude","latitude"), np.asarray(initial_physics_state.u_wind)),
        v_wind=(("level","longitude","latitude"), np.asarray(initial_physics_state.v_wind)),
        temperature=(("level","longitude","latitude"), np.asarray(initial_physics_state.temperature)),
        specific_humidity=(("level","longitude","latitude"), np.asarray(initial_physics_state.specific_humidity)),
        geopotential=(("longitude","latitude"), np.asarray(initial_physics_state.geopotential)),
        normalized_surface_pressure=(("one","longitude","latitude"),
                                     np.asarray(initial_physics_state.normalized_surface_pressure)),
    ),
    coords=dict(
        level=np.arange(geom.nodal_shape[0], dtype=np.int32),
        one=np.array([0], dtype=np.int32),
        longitude=lon_deg,
        latitude=lat_deg,
    ),
    attrs=dict(
        source_era5=ERA5_PATH,
        era5_vertical_indices=str(idx.tolist()),
        P0_used_Pa=100000.0,
        note="IC used for both dino-only (NoPhysics) and dino-full (SpeedyPhysics). Shapes follow model order.",
    )
)

ds_ic.to_netcdf(IC_OUT)
print("WROTE:", IC_OUT)
print(ds_ic["normalized_surface_pressure"].dims, ds_ic["normalized_surface_pressure"].shape)



WROTE: IC_USED.nc
('one', 'longitude', 'latitude') (1, 96, 48)


In [15]:
import jcm.physics_interface as pi
import jax.numpy as jnp
import jcm.physics_interface as pi

class NoPhysics(pi.Physics):
    def compute_tendencies(self, state: pi.PhysicsState, forcing, geometry, date):
        # Tendances nulles (mêmes shapes/dtypes)
        tend = pi.PhysicsTendency(
            u_wind=jnp.zeros_like(state.u_wind),
            v_wind=jnp.zeros_like(state.v_wind),
            temperature=jnp.zeros_like(state.temperature),
            specific_humidity=jnp.zeros_like(state.specific_humidity),
        )
        data = self.get_empty_data(geometry)  # None par défaut
        return tend, data



In [16]:
tend0, data0 = NoPhysics().compute_tendencies(initial_physics_state, forcing=None, geometry=geom, date=None)
print("tend u", tend0.u_wind.shape, tend0.u_wind.dtype, float(tend0.u_wind.max()))
print("data", data0)


tend u (8, 96, 48) float32 0.0
data None


W0218 19:15:15.564777    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:15.624383    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


In [17]:
from jcm.model import Model

model_dry = Model(geometry=geom, physics=NoPhysics())
print("model_dry built:", model_dry)


model_dry built: <jcm.model.Model object at 0x152e509dbbd0>


In [18]:
import numpy as np
print("T0 min/max:", float(np.asarray(initial_physics_state.temperature).min()),
                 float(np.asarray(initial_physics_state.temperature).max()))


T0 min/max: 177.0670623779297 308.121337890625


In [19]:
from jcm.physics.speedy.speedy_physics import SpeedyPhysics

physics_full = SpeedyPhysics()   # paramètres par défaut (ok pour commencer)
model_full = Model(geometry=geom, physics=physics_full)

print("model_full built:", model_full)



model_full built: <jcm.model.Model object at 0x153d02f03dd0>


In [20]:
import os, jcm
from jcm.forcing import ForcingData

JCM_ROOT = os.path.dirname(jcm.__file__)
FORCING = os.path.join(JCM_ROOT, "data/bc/t30/clim/forcing.nc")

forcing = ForcingData.from_file(FORCING, target_resolution=31)
print("forcing:", type(forcing))


  return xr.merge([daily_time_vars, ds_monthly[non_time_vars]])
W0218 19:15:19.184322    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.223506    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.275076    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.325157    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.376174    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


forcing: <class 'jcm.forcing.ForcingData'>


In [21]:


pred_dry = model_dry.run(initial_state=initial_physics_state, forcing=forcing, save_interval=30.0, total_time=30.0)
print("dry run OK")


W0218 19:15:19.524596    3313 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.542248    3313 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.568175    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.694875    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.742538    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.789512    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:19.836400    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. 

dry run OK


In [22]:
pred_full = model_full.run(
    initial_state=initial_physics_state,
    forcing=forcing,
    save_interval=30.0,
    total_time=30.0,
)
print("full run OK")


W0218 19:15:29.397547    3313 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:29.413007    3313 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:29.427675    3313 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:29.442401    3313 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:35.119920    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


full run OK


In [23]:
import numpy as np

ds_full = pred_full.to_xarray()
ds_dry  = pred_dry.to_xarray()

print("ds_full dims:", ds_full.dims)
print("ds_dry  dims:", ds_dry.dims)

time_dim = "time" if "time" in ds_full.dims else None
print("time_dim:", time_dim)

if time_dim:
    for v in ["temperature", "u_wind", "v_wind", "specific_humidity"]:
        if v in ds_full and v in ds_dry:
            a = ds_full[v].isel(time=0).values
            b = ds_dry[v].isel(time=0).values
            print(v, "maxabs diff t0:", float(np.max(np.abs(a-b))))


W0218 19:15:57.550284    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:57.897579    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:57.970212    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:58.004428    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:58.074686    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:58.150515    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:58.185314    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. 

time_dim: time
temperature maxabs diff t0: 0.0
u_wind maxabs diff t0: 0.0
v_wind maxabs diff t0: 0.0
specific_humidity maxabs diff t0: 0.0


W0218 19:15:58.456796    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:15:58.489117    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


In [24]:
import xarray as xr
import numpy as np

ds_ic = xr.Dataset(
    data_vars=dict(
        u_wind=(("level","longitude","latitude"), np.asarray(initial_physics_state.u_wind)),
        v_wind=(("level","longitude","latitude"), np.asarray(initial_physics_state.v_wind)),
        temperature=(("level","longitude","latitude"), np.asarray(initial_physics_state.temperature)),
        specific_humidity=(("level","longitude","latitude"), np.asarray(initial_physics_state.specific_humidity)),
        geopotential=(("longitude","latitude"), np.asarray(initial_physics_state.geopotential)),
        normalized_surface_pressure=(("one","longitude","latitude"), np.asarray(initial_physics_state.normalized_surface_pressure)),
    ),
    coords=dict(
        level=np.arange(geom.nodal_shape[0], dtype=np.int32),
        one=np.array([0], dtype=np.int32),
        longitude=lon_deg,
        latitude=lat_deg,
    ),
    attrs=dict(
        source_era5=ERA5_PATH,
        era5_vertical_indices=str(idx.tolist()),
        P0_used_Pa=100000.0,
        note="IC used for both dino-only (NoPhysics) and dino-full (SpeedyPhysics). Shapes follow model order.",
    )
)
ds_ic.to_netcdf("IC_USED.nc")
print("RE-WROTE: IC_USED.nc")


RE-WROTE: IC_USED.nc


In [26]:
import os, gc
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

out = "/scratch/gmathieu/pred_dino_full_48h_6h.nc"
if os.path.exists(out):
    os.remove(out)

gc.collect()


import os, numpy as np
SCRATCH = os.environ["SCRATCH"]

SAVE = 0.25
TOTAL = 2.25  # pour inclure 48h

pred_full = model_full.run(initial_state=initial_physics_state, forcing=forcing,
                           save_interval=SAVE, total_time=TOTAL)
ds_full = pred_full.to_xarray()

t0 = ds_full["time"].values[0]
t_end = t0 + np.timedelta64(48, "h")
ds_full_48 = ds_full.sel(time=slice(t0, t_end))

print("full times:", ds_full_48["time"].values)
print("delta hours:", float((ds_full_48["time"].values[-1]-t0)/np.timedelta64(1,"h")))
print("unique step h:", np.unique((ds_full_48["time"].values[1:]-ds_full_48["time"].values[:-1]) / np.timedelta64(1,"h")))

out_full = f"{SCRATCH}/pred_dino_full_48h_6h.nc"
enc_full = {v: {"zlib": True, "complevel": 1} for v in ds_full_48.data_vars}
ds_full_48.to_netcdf(out_full, encoding=enc_full)
print("WROTE:", out_full)



full times: ['2000-01-01T00:00:00.000000000' '2000-01-01T06:00:00.000000000'
 '2000-01-01T12:00:00.000000000' '2000-01-01T18:00:00.000000000'
 '2000-01-02T00:00:00.000000000' '2000-01-02T06:00:00.000000000'
 '2000-01-02T12:00:00.000000000' '2000-01-02T18:00:00.000000000'
 '2000-01-03T00:00:00.000000000']
delta hours: 48.0
unique step h: [6.]


W0218 19:26:27.223601    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:26:27.509342    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


WROTE: /scratch/gmathieu/pred_dino_full_48h_6h.nc


In [27]:
pred_dry = model_dry.run(initial_state=initial_physics_state, forcing=forcing,
                         save_interval=SAVE, total_time=TOTAL)
ds_dry = pred_dry.to_xarray()

t0 = ds_dry["time"].values[0]
t_end = t0 + np.timedelta64(48, "h")
ds_dry_48 = ds_dry.sel(time=slice(t0, t_end))

out_dry = f"{SCRATCH}/pred_dino_only_48h_6h.nc"
enc_dry = {v: {"zlib": True, "complevel": 1} for v in ds_dry_48.data_vars}
ds_dry_48.to_netcdf(out_dry, encoding=enc_dry)
print("WROTE:", out_dry)



W0218 19:26:38.569265    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


WROTE: /scratch/gmathieu/pred_dino_only_48h_6h.nc


In [28]:
import numpy as np

print("FULL vs DRY at t=0 (level=0) maxabs:")
for v in ["temperature", "u_wind", "v_wind"]:
    d = (ds_full[v].isel(time=0, level=0) - ds_dry[v].isel(time=0, level=0)).values
    print(v, float(np.nanmax(np.abs(d))))


FULL vs DRY at t=0 (level=0) maxabs:


W0218 19:26:54.390965    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.


temperature 0.0
u_wind 0.0
v_wind 0.0


W0218 19:26:54.719089    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
W0218 19:26:54.764813    3202 sol_gpu_cost_model.cc:102] No SoL config found for device: NVIDIA H100 80GB HBM3 MIG 2g.20gb. Using default config.
