# GraphCast con datos ocenaográficos
---
Modifica el script ``nan``

Modelo usado:
* Random
* Remplazando todas las variables por la SST

Se tuvo en cuenta:
* Zona de estudio
* Escalas espaciales
* Zona de estudio

No se tuvo en cuenta:
* Escalas temporales
* Orden de los niveles verticales

Parámetros:
* Aleatorios
* latent size: ``2``
* gnn msg steps: ``4``
* random_mesh_size: ``9``
  
> **Conclusiones:** 
> Es necesario definir forcings dado que graphcast lo usa más adelante en el modulo autorregresivo.

In [15]:
import copernicusmarine
# Este módulo proporciona un decorador y funciones para crear clases con campos de datos, 
# similar a collections.namedtuple.
import dataclasses
# Este módulo proporciona clases para manipular fechas y horas.
import datetime
# Este módulo proporciona funciones de orden superior y operaciones de datos en general.
import functools
# Este módulo proporciona funciones matemáticas.
import math
# Este módulo proporciona operaciones de coincidencia de expresiones regulares.
import re
# Este módulo proporciona soporte para tipos opcionales.
from typing import Optional
# Este paquete proporciona funcionalidad para trabajar con sistemas de referencia de coordenadas en Cartopy, 
# una biblioteca de Python para dibujar mapas geoespaciales.
import cartopy.crs as ccrs
# Este paquete proporciona funcionalidad para interactuar con Google Cloud Storage, 
# un servicio de almacenamiento en la nube de Google.
from google.cloud import storage
# Este módulo proporciona implementaciones de modelos autoregresivos para series temporales.
# from graphcast import autoregressive
# Este módulo proporciona funcionalidad para realizar operaciones de casting en datos.
from graphcast import casting
# Este módulo proporciona funcionalidad para el control de puntos de control 
# durante la ejecución de un programa.
from graphcast import checkpoint
# Este módulo proporciona utilidades para el manejo de datos en Graphcast.
# from graphcast import data_utils
# Este módulo proporciona la implementación principal de Graphcast.
# from graphcast import graphcast
# Este módulo proporciona funcionalidad para normalizar datos en Graphcast.
from graphcast import normalization
# Este módulo proporciona funcionalidad para realizar operaciones de rollout en Graphcast.
from graphcast import rollout
# Este módulo proporciona funcionalidad para trabajar con arreglos multidimensionales etiquetados 
# y compatibles con JAX.
from graphcast import xarray_jax
# Este módulo proporciona funcionalidad para trabajar con estructuras de árbol 
# de arreglos multidimensionales etiquetados.
from graphcast import xarray_tree
# Esta clase proporciona funcionalidad para mostrar HTML en el contexto de IPython.
from IPython.display import HTML
# Este paquete proporciona interactividad en el notebook de IPython.
import ipywidgets as widgets
# Este paquete proporciona una biblioteca para construir redes neuronales en JAX.
import haiku as hk
# Este paquete proporciona funcionalidad para realizar cálculos numéricos de alto rendimiento 
# en dispositivos acelerados por GPU y CPU.
import jax
# Este paquete proporciona funcionalidad para la visualización de datos en Python.
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
# Este paquete proporciona funcionalidad para trabajar con arreglos y matrices en Python.
import numpy as np
# Este paquete proporciona funcionalidad para trabajar con arreglos multidimensionales etiquetados.
import xarray as xr


In [16]:
# Importamos paquete local
import sys, os
folders_lis = os.getcwd().split(os.sep)
repo_path = os.sep.join(folders_lis[:folders_lis.index('PhD_repo') + 1])
sys.path.append(repo_path + '\\src')
from art1_tools import graphcast_newvars
from art1_tools import data_utils_newvars
from art1_tools import autoregressive_newvars
from art1_tools import replace
from art1_tools import interpolation

In [198]:
import importlib
# Reload the module
importlib.reload(graphcast_newvars)
importlib.reload(data_utils_newvars)
importlib.reload(autoregressive_newvars)
importlib.reload(replace)

<module 'art1_tools.replace' from 'c:\\Users\\gcuervo\\OneDrive - Universidad de Las Palmas de Gran Canaria\\Documents\\Doctorado\\PhD_repo\\src\\art1_tools\\replace.py'>

In [18]:
def parse_file_parts(file_name):
    """
    Parse a file name into parts separated by underscores and hyphens.

    Args:
        file_name (str): The name of the file to parse.

    Returns:
        dict: A dictionary containing the parsed parts of the file name. The keys are the parts before the first hyphen,
              and the values are the parts after the first hyphen in each section of the file name.

    Example:
        >>> parse_file_parts("example_part1-part2_part3-part4.txt")
        {'example_part1': 'part2', 'part3': 'part4'}
    """
    return dict(part.split("-", 1) for part in file_name.split("_"))

In [19]:
# Crear un cliente de Google Cloud Storage anónimo.
# gcs_client = storage.Client.create_anonymous_client()

# Obtener el bucket "dm_graphcast" del cliente de Google Cloud Storage.
# gcs_bucket = gcs_client.get_bucket("dm_graphcast")

## Cargamos Modelo random
---

In [20]:
def data_valid_for_model(
        file_name: str, model_config: graphcast_newvars.ModelConfig, task_config: graphcast_newvars.TaskConfig):
    """
    Verificar si los datos del archivo son válidos para el modelo y la tarea configurados.

    Args:
        file_name (str): El nombre del archivo de datos.
        model_config (graphcast.ModelConfig): La configuración del modelo.
        task_config (graphcast.TaskConfig): La configuración de la tarea.

    Returns:
        bool: True si los datos son válidos para el modelo y la tarea configurados, False de lo contrario.

    """
    file_parts = parse_file_parts(file_name.removesuffix(".nc"))
    return (
        model_config.resolution in (0, float(file_parts["res"])) and
        len(task_config.pressure_levels) == int(file_parts["levels"]) and
        (
            ("total_precipitation_6hr" in task_config.input_variables and
             file_parts["source"] in ("era5", "fake")) or
            ("total_precipitation_6hr" not in task_config.input_variables and
             file_parts["source"] in ("hres", "fake"))
        )
    )

## Cargamos un batch de datos
---

In [21]:
dir_data = "C://Users//gcuervo//OneDrive - Universidad de Las Palmas de Gran Canaria//Documents//Doctorado//DB//GraphCast_data//"
atmos_canarias = xr.load_dataset(dir_data + "canarias_atmos_360_181.nc")
atmosvars2drop = set(atmos_canarias.variables.keys()) - set(['2m_temperature', 
                                                        'lat', 'lon', 'level', 
                                                        'time', 'datetime',
                                                        'toa_incident_solar_radiation',
                                                        'land_sea_mask'])
atmos_canarias = atmos_canarias.drop_vars(list(atmosvars2drop))
atmos_canarias



In [22]:
# https://data.marine.copernicus.eu/product/SST_ATL_SST_L4_REP_OBSERVATIONS_010_026/download?dataset=cmems-IFREMER-ATL-SST-L4-REP-OBS_FULL_TIME_SERIE_202012
satelite = "cmems-IFREMER-ATL-SST-L4-REP-OBS_FULL_TIME_SERIE"
min_lon, max_lon = (-20.97, -4.5) # IBI solo va hasta esta lon: -20.97
min_lat, max_lat = (19.55, 34.6)
start_date, end_date = ("2020-12-18", "2020-12-31") 
# Load xarray dataset
SST_SAT = copernicusmarine.open_dataset(dataset_id=satelite,
                                        minimum_longitude=min_lon, 
                                        maximum_longitude=max_lon,
                                        minimum_latitude=min_lat,
                                        maximum_latitude=max_lat,
                                        start_datetime=start_date,
                                        end_datetime=end_date,
                                        )
SST_SAT

Fetching catalog:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching catalog: 100%|██████████| 4/4 [00:19<00:00,  4.94s/it]


INFO - 2024-03-21T16:48:47Z - Dataset version was not specified, the latest one was selected: "202012"
INFO - 2024-03-21T16:48:47Z - Dataset part was not specified, the first one was selected: "default"
INFO - 2024-03-21T16:48:48Z - Service was not specified, the default one was selected: "arco-geo-series"


In [23]:
# Space resolution change
SST_SAT_360x181 = interpolation.resize_lonxlat(SST_SAT, (360, 181))
# Agrega batch
temp_arr = np.expand_dims(SST_SAT_360x181['analysed_sst'].to_numpy(), axis=0)
# Replace NaN values with min value of entire array
# temp_arr[np.isnan(temp_arr)] = np.nanmin(temp_arr)
example_batch = replace.replace_atmos_by_ocean(atmos_canarias, temp_arr)
example_batch = (example_batch.isel(level=0)
                 .rename_vars({"2m_temperature": 'analysed_sst'})
                 )
example_batch.level.data = np.array(0)
example_batch = example_batch.assign_coords(level=0).expand_dims('level')

In [24]:
example_batch

In [25]:
def scale(
    data: xr.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
) -> tuple[xr.Dataset, matplotlib.colors.Normalize, str]:
    """
    Escala los datos para la visualización.

    Args:
        data (xr.Dataset): El conjunto de datos a escalar.
        center (Optional[float], optional): El centro para la escala. Por defecto es None.
        robust (bool, optional): Indica si se debe utilizar una escala robusta. Por defecto es False.

    Returns:
        tuple[xr.Dataset, matplotlib.colors.Normalize, str]: Una tupla que contiene los datos escalados, 
        el objeto de normalización y el mapa de colores.

    """
    vmin = np.nanpercentile(data, (2 if robust else 0))
    vmax = np.nanpercentile(data, (98 if robust else 100))
    if center is not None:
        diff = max(vmax - center, center - vmin)
        vmin = center - diff
        vmax = center + diff
    return (data, matplotlib.colors.Normalize(vmin, vmax),
            ("RdBu_r" if center is not None else "viridis"))

def select(
    data: xr.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
) -> xr.Dataset:
    """
    Seleccione y filtre datos de un conjunto de datos Xr.

    Args:
        data (xr.Dataset): El conjunto de datos Xr del cual seleccionar.
        variable (str): El nombre de la variable a seleccionar.
        level (Optional[int], optional): El nivel específico a seleccionar (si corresponde). Por defecto es None.
        max_steps (Optional[int], optional): El número máximo de pasos de tiempo a seleccionar. Por defecto es None.

    Returns:
        xr.Dataset: El conjunto de datos seleccionado.

    """
    data = data[variable]
    if "batch" in data.dims:
        data = data.isel(batch=0)
    if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
        data = data.isel(time=range(0, max_steps))
    if level is not None and "level" in data.coords:
        data = data.sel(level=level)
    return data

In [26]:
def plot_data(
    data: dict[str, xr.Dataset],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
) -> tuple[xr.Dataset, matplotlib.colors.Normalize, str]:
    """
    Visualiza los datos en un gráfico de animación.

    Args:
        data (dict[str, xr.Dataset]): Un diccionario de datos para visualizar.
        fig_title (str): El título de la figura.
        plot_size (float, optional): El tamaño de la trama. Por defecto es 5.
        robust (bool, optional): Indica si se debe utilizar una escala robusta. Por defecto es False.
        cols (int, optional): El número de columnas en el diseño de la trama. Por defecto es 4.

    Returns:
        tuple[xr.Dataset, matplotlib.colors.Normalize, str]: Una tupla que contiene los datos escalados, 
        el objeto de normalización y el mapa de colores.

    """
    first_data = next(iter(data.values()))[0]
    max_steps = first_data.sizes.get("time", 1)
    assert all(max_steps == d.sizes.get("time", 1)
               for d, _, _ in data.values())

    cols = min(cols, len(data))
    rows = math.ceil(len(data) / cols)
    # figure = plt.figure(figsize=(plot_size * 2 * cols,
    #                              plot_size * rows))
    figure = plt.figure(figsize=(15, 7))
    figure.suptitle(fig_title, fontsize=16)
    figure.subplots_adjust(wspace=0.05, hspace=0)
    figure.tight_layout()

    images = []
    for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
        ax = figure.add_subplot(rows, cols, i+1,)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        im = ax.imshow(
            plot_data.isel(time=0, missing_dims="ignore"), norm=norm,
            origin="lower", cmap=cmap, aspect='auto')
        plt.colorbar(
            mappable=im,
            ax=ax,
            orientation="vertical",
            pad=0.02,
            aspect=16,
            shrink=0.75,
            cmap=cmap,
            extend=("both" if robust else "neither"))
        images.append(im)

    def update(frame):
            
        """
        Actualiza los datos mostrados en la animación.

        Args:
            frame (int): El número del frame a mostrar.

        """
        if "time" in first_data.dims:
            td = datetime.timedelta(
                microseconds=first_data["time"][frame].item() / 1000)
            figure.suptitle(f"{fig_title}, {td}", fontsize=16)
        else:
            figure.suptitle(fig_title, fontsize=16)
        for im, (plot_data, norm, cmap) in zip(images, data.values()):
            im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))

    ani = animation.FuncAnimation(
        fig=figure, func=update, frames=max_steps, interval=250)
    folder = "C:\\Users\\gcuervo\\OneDrive - Universidad de Las Palmas de Gran Canaria\\Documents\\Doctorado\\images\\"
    ani.save(filename=folder + "test_5.gif", writer="pillow")
    plt.close(figure.number)
    return HTML(ani.to_jshtml())

In [27]:
# Tamaño de la trama para la visualización.
plot_size = 7
plot_example_variable = "analysed_sst"
plot_example_level = 0
# Seleccionar los datos de ejemplo para graficar.
data = {
    " ": scale(select(example_batch, plot_example_variable, plot_example_level, 12),
               robust="Robust"),
}

# Título de la figura basado en la variable y nivel seleccionados.
fig_title = plot_example_variable
if "level" in example_batch[plot_example_variable].coords:
    fig_title += f" a {plot_example_level} m"

# Graficar los datos de ejemplo.
plot_data(data, fig_title, plot_size, "Robust")

## Matrices para normalización
---

In [28]:
# Crear un cliente de Google Cloud Storage anónimo.
gcs_client = storage.Client.create_anonymous_client()

# Obtener el bucket "dm_graphcast" del cliente de Google Cloud Storage.
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

In [29]:
# Cargar los datos de diferencias y desviaciones estándar por nivel.
with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
    diffs_stddev_by_level = xr.load_dataset(f).compute()

# Cargar los datos de media por nivel.
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
    mean_by_level = xr.load_dataset(f).compute()

# Cargar los datos de desviación estándar por nivel.
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
    stddev_by_level = xr.load_dataset(f).compute()

In [30]:
folder = "C:\\Users\\gcuervo\\OneDrive - Universidad de Las Palmas de Gran Canaria\\Documents\\Doctorado\\DB\\GraphCast_data\\"
SST_SAT_mean_1982_2021 = xr.load_dataset(folder + "mean_sat_1982_2021.nc")
SST_SAT_std_1982_2021 = xr.load_dataset(folder + "std_sat_1982_2021.nc")
SST_SAT_diff_std_1982_2021 = xr.load_dataset(folder + "diff_std_sat_1982_2021.nc")

In [31]:
def replace_norm_drops_vars(norm_matrices: xr.Dataset,
                            new_value: xr.DataArray,
                            vars_to_replace: xr.Dataset,
                            new_var_name: str
                            ) -> xr.Dataset:
    norm_matrices = replace.replace_norm_matrices(norm_matrices.rename_vars({"2m_temperature": 'analysed_sst'}),
                                                        new_value,
                                                        vars_to_replace
                                                        )
    normvar2drop = set(norm_matrices.variables.keys()) - set([new_var_name,'level',
                                                              'toa_incident_solar_radiation',
                                                              'year_progress', 'year_progress_sin',
                                                              'year_progress_cos', 'day_progress',
                                                              'day_progress_sin', 'day_progress_cos',])
    norm_matrices = norm_matrices.drop_vars(list(normvar2drop))
    norm_matrices = (norm_matrices
                            .assign_coords(level=0)
                            .expand_dims('level'))
    return norm_matrices

In [32]:
diffs_stddev_by_level = replace_norm_drops_vars(diffs_stddev_by_level,
                                                SST_SAT_diff_std_1982_2021["analysed_sst"],
                                                example_batch,
                                                "analysed_sst")
mean_by_level = replace_norm_drops_vars(mean_by_level,
                                        SST_SAT_mean_1982_2021["analysed_sst"],
                                        example_batch,
                                        "analysed_sst")
stddev_by_level = replace_norm_drops_vars(stddev_by_level,
                                          SST_SAT_std_1982_2021["analysed_sst"],
                                          example_batch,
                                          "analysed_sst")

## Confirguración del modelo
---

In [33]:
example_batch

In [211]:
# Obtener la fuente seleccionada en las pestañas de selección.
# Obtener la fuente seleccionada en las pestañas de selección.
source = "Random"
random_mesh_size = 1 # How many times to split each triangle.
random_latent_size = 2 ** 1 #  MLP output layer size (embedding)
random_gnn_msg_steps = 4
# random_levels = 13
resolution = example_batch.lon.data.ptp() / example_batch.lon.size
# Comprobar la fuente seleccionada y configurar el modelo en consecuencia.
if source == "Random":
    # Configurar parámetros para el modelo aleatorio.
    params = None  # Se llenará a continuación
    state = {}
    model_config = graphcast_newvars.ModelConfig(
        resolution=resolution,
        mesh_size=random_mesh_size,
        latent_size=random_latent_size,
        gnn_msg_steps=random_gnn_msg_steps,
        hidden_layers=1,
        radius_query_fraction_edge_length=0.6)
    task_config = graphcast_newvars.TaskConfig(
        input_variables=graphcast_newvars.TASK_SST.input_variables,
        target_variables=graphcast_newvars.TASK_SST.target_variables,
        forcing_variables=graphcast_newvars.TASK_SST.forcing_variables,
        pressure_levels=graphcast_newvars.PRESSURE_LEVELS_SST,
        input_duration=graphcast_newvars.TASK_SST.input_duration,
    )

# Devolver la configuración del modelo.
print(model_config)

# Extraer datos de entrenamiento y evaluación
train_steps = 12
eval_steps = 12
# Extraer datos de entrada, objetivos y forzamientos para entrenamiento.
train_inputs, train_targets, train_forcings = data_utils_newvars.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{train_steps * 6}h"),
    **dataclasses.asdict(task_config))

# Extraer datos de entrada, objetivos y forzamientos para evaluación.
eval_inputs, eval_targets, eval_forcings = data_utils_newvars.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{eval_steps * 6}h"),
    **dataclasses.asdict(task_config))

# Construir funciones jitted, y posiblemente inicializar pesos aleatorios

def construct_wrapped_graphcast(
        model_config: graphcast_newvars.ModelConfig,
        task_config: graphcast_newvars.TaskConfig):
    """Construye y envuelve el Predictor de GraphCast_graphcast_newvars."""
    # Predictor más profundo de un paso.
    predictor = graphcast_newvars.GraphCast(model_config, task_config)

    # Modifica las entradas/salidas a `graphcast.GraphCast` para manejar la conversión de/a
    # float32 de/a BFloat16.
    predictor = casting.Bfloat16Cast(predictor)

    # Modifica las entradas/salidas a `casting.Bfloat16Cast` para que la conversión de/a
    # BFloat16 ocurra después de aplicar la normalización a las entradas/objetivos.
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level)

    # Envuelve todo para que el modelo de un paso pueda producir trayectorias.
    predictor = autoregressive_newvars.Predictor(
        predictor, gradient_checkpointing=True)
    return predictor

# Transformación de Haiku que envuelve el predictor para ejecutar hacia adelante.
@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    """Realiza una ejecución hacia adelante del modelo predictor."""
    predictor = construct_wrapped_graphcast(model_config, task_config)
    return predictor(inputs, targets_template=targets_template, forcings=forcings)


# Transformación de Haiku que envuelve la función de pérdida.
@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
    """Calcula la pérdida del modelo predictor."""
    predictor = construct_wrapped_graphcast(model_config, task_config)
    loss, diagnostics = predictor.loss(inputs, targets, forcings)
    return xarray_tree.map_structure(
        lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
        (loss, diagnostics))


def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
    """Calcula los gradientes de la función de pérdida."""
    def _aux(params, state, i, t, f):
        (loss, diagnostics), next_state = loss_fn.apply(
            params, state, jax.random.PRNGKey(0), model_config, task_config,
            i, t, f)
        return loss, (diagnostics, next_state)
    (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
        _aux, has_aux=True)(params, state, inputs, targets, forcings)
    return loss, diagnostics, next_state, grads


def with_configs(fn):
    """Envuelve una función con las configuraciones del modelo y la tarea."""
    return functools.partial(
        fn, model_config=model_config, task_config=task_config)


def with_params(fn):
    """Envuelve una función con los parámetros y el estado."""
    return functools.partial(fn, params=params, state=state)


def drop_state(fn):
    """Elimina el estado de salida de una función."""
    return lambda **kw: fn(**kw)[0]

# Compilar la inicialización de la transformación run_forward con Haiku y JAX JIT.
init_jitted = jax.jit(with_configs(run_forward.init))

# Si los parámetros no están inicializados, inicializa los parámetros y el estado.
if params is None:
    params, state = init_jitted(
        rng=jax.random.PRNGKey(0),
        inputs=train_inputs,
        targets_template=train_targets,
        forcings=train_forcings)

# Compilar la función de pérdida y los gradientes con Haiku, JAX JIT y las configuraciones del modelo y la tarea.
loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))

# Compilar la transformación run_forward con Haiku, JAX JIT y las configuraciones del modelo y la tarea.
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

# Autoregressive rollout (bucle en Python)

# Verificar que la resolución del modelo coincida con la resolución de los datos.
# assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
assert model_config.resolution in (0, atmos_canarias.lon.data.ptp() / eval_inputs.sizes["lon"]), (
    "La resolución del modelo no coincide con la resolución de los datos. Probablemente desee "
    "volver a filtrar la lista de conjuntos de datos y descargar los datos correctos.")

# Imprimir las dimensiones de las entradas, objetivos y forzamientos de evaluación.

# Realizar la predicción a través de un proceso de despliegue autoregresivo en bloques.
# eval_inputs: 6H antes hasta tiempo actual t_o
# eval_targets: tiempo actual t_o hasta 3d despues
# eval_forcings: radiacion solar tiempo actual t_o hasta 3d despues
predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
predictions

# Calcular la pérdida utilizando la función de pérdida autoregresiva compilada
loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)

# Imprimir la pérdida calculada
print("Pérdida:", float(loss))
# Plot predictions

# Tamaño del gráfico
plot_size = 5
plot_pred_max_steps = 12
plot_pred_variable = "2m_temperature"
plot_pred_level = 50
plot_pred_robust = "Robust"
# Determinar el número máximo de pasos de tiempo a graficar
plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps)

# Seleccionar datos para graficar: objetivos, predicciones y diferencia entre objetivos y predicciones
data = {
    "Targets": scale(select(eval_targets, plot_pred_variable, plot_pred_level, plot_max_steps), robust=plot_pred_robust),
    "Predictions": scale(select(predictions, plot_pred_variable, plot_pred_level, plot_max_steps), robust=plot_pred_robust),
    "Diff": scale((select(eval_targets, plot_pred_variable, plot_pred_level, plot_max_steps) -
                select(predictions, plot_pred_variable, plot_pred_level, plot_max_steps)),
                robust=plot_pred_robust, center=0),
}

# Título de la figura
fig_title = plot_pred_variable
if "level" in predictions[plot_pred_variable].coords:
    fig_title += f" at {plot_pred_level} hPa"

# Llamar a la función para graficar los datos
msg_steps_dic[n] = [data, fig_title, plot_size, plot_pred_robust]

ModelConfig(resolution=0.08263888888888889, mesh_size=1, latent_size=2, gnn_msg_steps=4, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)




ValueError: Passing a weight that does not correspond to any variable {'10m_u_component_of_wind', 'total_precipitation_6hr', 'mean_sea_level_pressure', '10m_v_component_of_wind', '2m_temperature'}

In [179]:
(targets, target_variables) = data_utils_newvars.mokey1()

In [186]:
target_variables

('analysed_sst',)

In [180]:
targets

In [191]:
list([*target_variables])

['analysed_sst']

In [194]:
tuple("analysed_sst")

('a', 'n', 'a', 'l', 'y', 's', 'e', 'd', '_', 's', 's', 't')

In [196]:
targets[list(target_variables)]

In [117]:
(_grid_lat, _grid_lon, _finest_mesh, _query_radius), (grid_indices, mesh_indices) = graphcast_newvars.mokey1()


In [128]:
_finest_mesh

TriangularMesh(vertices=array([[ 0.49112344,  0.8506508 ,  0.18759246],
       [-0.303531  ,  0.5257311 ,  0.7946544 ],
       [ 0.607062  ,  0.        ,  0.7946544 ],
       ...,
       [-0.74439675,  0.57847726,  0.3335229 ],
       [-0.74606496,  0.5771851 ,  0.33203086],
       [-0.74560386,  0.57638645,  0.33444503]], dtype=float32), faces=array([[      0,  655362,  655364],
       [ 655362,  163842,  655363],
       [ 655364,  655363,  163844],
       ...,
       [2621440,  655361, 2621441],
       [2621439, 2621441,  655359],
       [2621440, 2621441, 2621439]]))

In [127]:
_query_radius

0.001550647895783186

In [118]:
(_grid_lat, _grid_lon, _finest_mesh, _query_radius)

(<xarray.DataArray 'lat' (lat: 181)>
 array([19.75    , 19.831944, 19.913889, 19.995832, 20.077778, 20.159721,
        20.241667, 20.32361 , 20.405556, 20.4875  , 20.569445, 20.651388,
        20.733334, 20.815277, 20.897223, 20.979166, 21.061111, 21.143055,
        21.225   , 21.306944, 21.38889 , 21.470833, 21.552778, 21.634722,
        21.716667, 21.79861 , 21.880556, 21.9625  , 22.044445, 22.126389,
        22.208334, 22.290277, 22.372223, 22.454166, 22.536112, 22.618055,
        22.7     , 22.781944, 22.86389 , 22.945833, 23.027779, 23.109722,
        23.191668, 23.273611, 23.355556, 23.4375  , 23.519444, 23.601389,
        23.683332, 23.765278, 23.847221, 23.929167, 24.01111 , 24.093056,
        24.175   , 24.256945, 24.338888, 24.420834, 24.502777, 24.584723,
        24.666666, 24.748611, 24.830555, 24.9125  , 24.994444, 25.07639 ,
        25.158333, 25.240278, 25.322222, 25.404167, 25.48611 , 25.568056,
        25.65    , 25.731945, 25.813889, 25.895834, 25.977777, 26.059723,
 

In [139]:
grid_nodes_lon, grid_nodes_lat = np.meshgrid(_grid_lon, _grid_lat)
_grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32)
_grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32)

In [119]:
(grid_indices, mesh_indices)

(array([    0,     1,     1, ..., 65158, 65159, 65159]),
 array([2617699, 2617698, 2617699, ..., 1668635, 1668634, 1668635]))

In [88]:
(updated_latent_mesh_nodes, latent_grid_nodes), output_grid_nodes = graphcast_newvars.mokey2()

In [89]:
(updated_latent_mesh_nodes, latent_grid_nodes)

(Traced<ShapedArray(bfloat16[2621442,1,2])>with<DynamicJaxprTrace(level=2/0)>,
 Traced<ShapedArray(bfloat16[65160,1,2])>with<DynamicJaxprTrace(level=2/0)>)

In [133]:
(mesh_nodes_lat, mesh_nodes_lon) = graphcast_newvars.mokey3()


In [134]:
mesh_nodes_lat

array([10.812325, 52.622627, 52.622627, ..., 19.48275 , 19.39209 ,
       19.538795], dtype=float32)

In [135]:
mesh_nodes_lon

array([ 60.     , 120.     ,   0.     , ..., 142.1489 , 142.27306,
       142.29431], dtype=float32)

In [150]:
from graphcast import typed_graph
from graphcast import grid_mesh_connectivity
from art1_tools import model_utils_newvars

_mesh2grid_edge_normalization_factor = (model_config.mesh2grid_edge_normalization_factor)
_spatial_features_kwargs = dict(
        add_node_positions=False,
        add_node_latitude=True,
        add_node_longitude=True,
        add_relative_positions=True,
        relative_longitude_local_coordinates=True,
        relative_latitude_local_coordinates=True,
        )
_num_mesh_nodes = _finest_mesh.vertices.shape[0]
_num_grid_nodes = _grid_lat.shape[0] * _grid_lon.shape[0]
def _init_mesh2grid_graph() -> typed_graph.TypedGraph:
    """Build Mesh2Grid graph."""

    # Create some edges according to how the grid nodes are contained by
    # mesh triangles.
    (grid_indices,
     mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices(
         grid_latitude=_grid_lat,
         grid_longitude=_grid_lon,
         mesh=_finest_mesh)

    # Edges sending info from mesh to grid.
    senders = mesh_indices
    receivers = grid_indices

    # Precompute structural node and edge features according to config options.
    assert mesh_nodes_lat is not None and mesh_nodes_lon is not None
    (senders_node_features, receivers_node_features,
     edge_features) = model_utils_newvars.get_bipartite_graph_spatial_features(
         senders_node_lat=mesh_nodes_lat,
         senders_node_lon=mesh_nodes_lon,
         receivers_node_lat=_grid_nodes_lat,
         receivers_node_lon=_grid_nodes_lon,
         senders=senders,
         receivers=receivers,
         edge_normalization_factor=_mesh2grid_edge_normalization_factor,
         **_spatial_features_kwargs,
     )

    n_grid_node = np.array([_num_grid_nodes])
    n_mesh_node = np.array([_num_mesh_nodes])
    n_edge = np.array([senders.shape[0]])
    grid_node_set = typed_graph.NodeSet(
        n_node=n_grid_node, features=receivers_node_features)
    mesh_node_set = typed_graph.NodeSet(
        n_node=n_mesh_node, features=senders_node_features)
    edge_set = typed_graph.EdgeSet(
        n_edge=n_edge,
        indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
        features=edge_features)
    nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
    edges = {
        typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")):
            edge_set
    }
    mesh2grid_graph = typed_graph.TypedGraph(
        context=typed_graph.Context(n_graph=np.array([1]), features=()),
        nodes=nodes,
        edges=edges)
    return mesh2grid_graph

In [151]:
_mesh2grid_graph_structure = _init_mesh2grid_graph()

In [155]:
def _add_batch_second_axis(data, batch_size):
  # data [leading_dim, trailing_dim]
  assert data.ndim == 2
  ones = jnp.ones([batch_size, 1], dtype=data.dtype)
  return data[:, None] * ones  # [leading_dim, batch, trailing_dim]

In [164]:
(num_surface_vars, task_config, num_atmospheric_vars), num_outputs = graphcast_newvars.mokey4()

In [165]:
(num_surface_vars, task_config, num_atmospheric_vars)

(9,
 TaskConfig(input_variables=('analysed_sst', 'toa_incident_solar_radiation', 'year_progress_sin', 'year_progress_cos', 'day_progress_sin', 'day_progress_cos'), target_variables='analysed_sst', forcing_variables=('toa_incident_solar_radiation', 'year_progress_sin', 'year_progress_cos', 'day_progress_sin', 'day_progress_cos'), pressure_levels=(0,), input_duration='12h'),
 0)

In [None]:
(num_surface_vars + len(task_config.pressure_levels) * num_atmospheric_vars)

In [None]:
num_outputs

In [169]:
ALL_ATMOSPHERIC_VARS = (
    "potential_vorticity",
    "specific_rain_water_content",
    "specific_snow_water_content",
    "geopotential",
    "temperature",
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
    "vertical_velocity",
    "vorticity",
    "divergence",
    "relative_humidity",
    "ozone_mass_mixing_ratio",
    "specific_cloud_liquid_water_content",
    "specific_cloud_ice_water_content",
    "fraction_of_cloud_cover",
)

In [166]:
task_config.target_variables

'analysed_sst'

In [172]:
set(task_config.target_variables) - set(ALL_ATMOSPHERIC_VARS)

{'analysed_sst'}

In [173]:
len(set([task_config.target_variables]) - set(ALL_ATMOSPHERIC_VARS))

1

In [161]:
from graphcast import deep_typed_graph_net
_mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet(
        # Require a specific node dimensionaly for the grid node outputs.
        node_output_size=dict(grid_nodes=num_outputs),
        embed_nodes=False,  # Node features already embdded by previous layers.
        embed_edges=True,  # Embed raw features of the mesh2grid edges.
        edge_latent_size=dict(mesh2grid=model_config.latent_size),
        node_latent_size=dict(
            mesh_nodes=model_config.latent_size,
            grid_nodes=model_config.latent_size),
        mlp_hidden_size=model_config.latent_size,
        mlp_num_hidden_layers=model_config.hidden_layers,
        num_message_passing_steps=1,
        use_layer_norm=True,
        include_sent_messages_in_node_update=False,
        activation="swish",
        f32_aggregation=False,
        name="mesh2grid_gnn",
    )

NameError: name 'num_outputs' is not defined

In [158]:
import chex
def _run_mesh2grid_gnn(updated_latent_mesh_nodes: chex.Array,
                         latent_grid_nodes: chex.Array,
                         ) -> chex.Array:
    """Runs the mesh2grid_gnn, extracting the output grid nodes."""

    # Add the structural edge features of this graph. Note we don't need
    # to add the structural node features, because these are already part of
    # the latent state, via the original Grid2Mesh gnn, however, we need
    # the edge ones, because it is the first time we are seeing this particular
    # set of edges.
    batch_size = updated_latent_mesh_nodes.shape[1]

    mesh2grid_graph = _mesh2grid_graph_structure
    global var9
    var9 = mesh2grid_graph
    assert mesh2grid_graph is not None
    mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"]
    grid_nodes = mesh2grid_graph.nodes["grid_nodes"]
    new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes)
    new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes)
    mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid")
    edges = mesh2grid_graph.edges[mesh2grid_key]

    new_edges = edges._replace(
        features=_add_batch_second_axis(
            edges.features.astype(latent_grid_nodes.dtype), batch_size))

    input_graph = mesh2grid_graph._replace(
        edges={mesh2grid_key: new_edges},
        nodes={
            "mesh_nodes": new_mesh_nodes,
            "grid_nodes": new_grid_nodes
        })
    # Run the GNN.
    output_graph = self._mesh2grid_gnn(input_graph)
    global var8
    var8 = output_graph
    output_grid_nodes = output_graph.nodes["grid_nodes"].features

    return output_grid_nodes

In [159]:
_run_mesh2grid_gnn(updated_latent_mesh_nodes=updated_latent_mesh_nodes,
                             latent_grid_nodes=latent_grid_nodes)

In [91]:
output_grid_nodes

Traced<ShapedArray(bfloat16[65160,1,9])>with<DynamicJaxprTrace(level=2/0)>

In [140]:
from typing import Mapping, Optional, Tuple

def stacked_to_dataset(
    stacked_array: xr.Variable,
    template_dataset: xr.Dataset,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xr.Dataset:

  unstack_from_channels_sizes = {}
  var_names = sorted(template_dataset.keys())
  for name in var_names:
    template_var = template_dataset[name]
    if not all(dim in template_var.dims for dim in preserved_dims):
      raise ValueError(
          f"stacked_to_dataset requires all Variables to have {preserved_dims} "
          f"dimensions, but found only {template_var.dims}.")
    unstack_from_channels_sizes[name] = {
        dim: size for dim, size in template_var.sizes.items()
        if dim not in preserved_dims}

  print("unstack_from_channels_sizes: ", unstack_from_channels_sizes)
  channels = {name: np.prod(list(unstack_sizes.values()), dtype=np.int64)
              for name, unstack_sizes in unstack_from_channels_sizes.items()}
  print("channels: ", channels)
  total_expected_channels = sum(channels.values())
  print("total_expected_channels: ", total_expected_channels)
  found_channels = stacked_array.sizes["channels"]
  print("found_channels: ", found_channels)
  print("stacked_array: ", stacked_array)
  if total_expected_channels != found_channels:
    raise ValueError(
        f"Expected {total_expected_channels} channels but found "
        f"{found_channels}, when trying to convert a stacked array of shape "
        f"{stacked_array.sizes} to a dataset of shape {template_dataset}.")

  data_vars = {}
  index = 0
  for name in var_names:
    template_var = template_dataset[name]
    var = stacked_array.isel({"channels": slice(index, index + channels[name])})
    index += channels[name]
    var = var.unstack({"channels": unstack_from_channels_sizes[name]})
    var = var.transpose(*template_var.dims)
    data_vars[name] = xr.DataArray(
        data=var,
        coords=template_var.coords,
        # This might not always be the same as the name it's keyed under; it
        # will refer to the original variable name, whereas the key might be
        # some alias e.g. temperature_850 under which it should be logged:
        name=template_var.name,
    )
  return type(template_dataset)(data_vars)  # pytype:disable=not-callable,wrong-arg-count

In [141]:
stacked_to_dataset(grid_xarray.variable, targets_template)

unstack_from_channels_sizes:  {'analysed_sst': {'level': 1, 'time': 1}}
channels:  {'analysed_sst': 1}
total_expected_channels:  1
found_channels:  9
stacked_array:  <xarray.Variable (batch: 1, lat: 181, lon: 360, channels: 9)>
xarray_jax.JaxArrayWrapper(Traced<ShapedArray(bfloat16[1,181,360,9])>with<DynamicJaxprTrace(level=2/0)>)


ValueError: Expected 1 channels but found 9, when trying to convert a stacked array of shape Frozen({'batch': 1, 'lat': 181, 'lon': 360, 'channels': 9}) to a dataset of shape <xarray.Dataset>
Dimensions:       (level: 1, batch: 1, time: 1, lat: 181, lon: 360)
Coordinates:
  * time          (time) timedelta64[ns] 06:00:00
  * level         (level) int32 0
  * lon           (lon) float64 145.8 145.8 145.9 146.0 ... 175.3 175.4 175.5
  * lat           (lat) float64 19.75 19.83 19.91 20.0 ... 34.34 34.42 34.5
Dimensions without coordinates: batch
Data variables:
    analysed_sst  (level, batch, time, lat, lon) bfloat16 xarray_jax.JaxArray....

In [171]:
grid_shape, grid_node_outputs = graphcast_newvars.mokey2()

In [172]:
grid_shape

(181, 360)

In [173]:
grid_node_outputs

Traced<ShapedArray(bfloat16[65160,1,9])>with<DynamicJaxprTrace(level=2/0)>

In [177]:
updated_latent_mesh_nodes, latent_grid_nodes = graphcast_newvars.mokey3()

In [178]:
updated_latent_mesh_nodes

Traced<ShapedArray(bfloat16[2621442,1,2])>with<DynamicJaxprTrace(level=2/0)>

In [185]:
updated_latent_mesh_nodes.shape

(2621442, 1, 2)

In [179]:
latent_grid_nodes

Traced<ShapedArray(bfloat16[65160,1,2])>with<DynamicJaxprTrace(level=2/0)>

In [188]:
output_grid_nodes = graphcast_newvars.mokey4()

In [189]:
output_grid_nodes

Traced<ShapedArray(bfloat16[65160,1,9])>with<DynamicJaxprTrace(level=2/0)>

In [192]:
output_graph = graphcast_newvars.mokey5()

In [193]:
output_graph

TypedGraph(context=Context(n_graph=array([1]), features=()), nodes={'mesh_nodes': NodeSet(n_node=array([2621442]), features=Traced<ShapedArray(bfloat16[2621442,1,2])>with<DynamicJaxprTrace(level=2/0)>), 'grid_nodes': NodeSet(n_node=array([65160]), features=Traced<ShapedArray(bfloat16[65160,1,9])>with<DynamicJaxprTrace(level=2/0)>)}, edges={EdgeSetKey(name='mesh2grid', node_sets=('mesh_nodes', 'grid_nodes')): EdgeSet(n_edge=array([195480]), indices=EdgesIndices(senders=array([2617699,  163623, 2617703, ..., 1668634, 1668635, 1668568]), receivers=array([    0,     0,     0, ..., 65159, 65159, 65159])), features=Traced<ShapedArray(bfloat16[195480,1,2])>with<DynamicJaxprTrace(level=2/0)>)})

In [194]:
mesh2grid_graph = graphcast_newvars.mokey6()

In [231]:
mesh2grid_graph._asdict().keys()

dict_keys(['context', 'nodes', 'edges'])

In [230]:
mesh2grid_graph._asdict()["context"]

Context(n_graph=array([1]), features=())

In [276]:
mesh2grid_graph._asdict()["edges"]

{EdgeSetKey(name='mesh2grid', node_sets=('mesh_nodes', 'grid_nodes')): EdgeSet(n_edge=array([195480]), indices=EdgesIndices(senders=array([2617699,  163623, 2617703, ..., 1668634, 1668635, 1668568]), receivers=array([    0,     0,     0, ..., 65159, 65159, 65159])), features=array([[ 1.23402465e-01, -3.09035756e-05,  2.86594667e-02,
         -1.20028340e-01],
        [ 8.86492544e-01, -1.03664892e-03,  1.83647836e-01,
          8.67260877e-01],
        [ 8.95416134e-01, -1.03317351e-03, -7.39653860e-01,
          5.04659442e-01],
        ...,
        [ 5.30097526e-01, -3.77520079e-04,  3.59106283e-01,
         -3.89930663e-01],
        [ 3.29538995e-01, -1.43771732e-04, -1.32130086e-01,
          3.01889995e-01],
        [ 7.36959053e-01, -7.02337771e-04, -5.93913063e-01,
         -4.36320325e-01]]))}

In [277]:
mesh2grid_graph._asdict()["nodes"]

{'grid_nodes': NodeSet(n_node=array([65160]), features=array([[ 0.33791676, -0.8265897 ,  0.56280506],
        [ 0.33791676, -0.8274029 ,  0.5616088 ],
        [ 0.33791676, -0.8282143 ,  0.5604116 ],
        ...,
        [ 0.56640625, -0.9966862 ,  0.08134252],
        [ 0.56640625, -0.9968028 ,  0.07990098],
        [ 0.56640625, -0.99691737,  0.07845904]], dtype=float32)),
 'mesh_nodes': NodeSet(n_node=array([2621442]), features=array([[ 0.18759258,  0.49999997,  0.86602545],
        [ 0.7946544 , -0.50000006,  0.8660254 ],
        [ 0.7946544 ,  1.        ,  0.        ],
        ...,
        [ 0.33352306, -0.78960806,  0.6136115 ],
        [ 0.33203098, -0.7909359 ,  0.611899  ],
        [ 0.33444512, -0.79116285,  0.6116056 ]], dtype=float32))}

In [233]:
mesh2grid_graph.nodes.keys()

dict_keys(['grid_nodes', 'mesh_nodes'])

In [213]:
mesh2grid_graph.nodes['grid_nodes']

NodeSet(n_node=array([65160]), features=array([[ 0.33791676, -0.8265897 ,  0.56280506],
       [ 0.33791676, -0.8274029 ,  0.5616088 ],
       [ 0.33791676, -0.8282143 ,  0.5604116 ],
       ...,
       [ 0.56640625, -0.9966862 ,  0.08134252],
       [ 0.56640625, -0.9968028 ,  0.07990098],
       [ 0.56640625, -0.99691737,  0.07845904]], dtype=float32))

In [235]:
mesh2grid_graph.nodes['grid_nodes'].features.shape

(65160, 3)

In [216]:
mesh2grid_graph.nodes['mesh_nodes']

NodeSet(n_node=array([2621442]), features=array([[ 0.18759258,  0.49999997,  0.86602545],
       [ 0.7946544 , -0.50000006,  0.8660254 ],
       [ 0.7946544 ,  1.        ,  0.        ],
       ...,
       [ 0.33352306, -0.78960806,  0.6136115 ],
       [ 0.33203098, -0.7909359 ,  0.611899  ],
       [ 0.33444512, -0.79116285,  0.6116056 ]], dtype=float32))

In [236]:
mesh2grid_graph.nodes['mesh_nodes'].features.shape

(2621442, 3)

In [259]:
list(mesh2grid_graph.edges.keys())[0]

EdgeSetKey(name='mesh2grid', node_sets=('mesh_nodes', 'grid_nodes'))

In [267]:
mesh2grid_graph.edges[list(mesh2grid_graph.edges.keys())[0]]._asdict().keys()

dict_keys(['n_edge', 'indices', 'features'])

In [266]:
mesh2grid_graph.edges[list(mesh2grid_graph.edges.keys())[0]]._asdict()["features"].shape

(195480, 4)

In [272]:
mesh2grid_graph.edges[list(mesh2grid_graph.edges.keys())[0]]._asdict()["indices"]._asdict().keys()

dict_keys(['senders', 'receivers'])

In [274]:
mesh2grid_graph.edges[list(mesh2grid_graph.edges.keys())[0]]._asdict()["indices"]._asdict()["senders"].shape

(195480,)

In [275]:
mesh2grid_graph.edges[list(mesh2grid_graph.edges.keys())[0]]._asdict()["indices"]._asdict()["receivers"].shape

(195480,)

In [279]:
mesh2grid_graph.nodes["grid_nodes"].features.shape

(65160, 3)

In [222]:
node_ico0 = 12
edges_ico0 = 30
for i in range(1, 10):

    new_edges = edges_ico0 * 2
    new_nodes = 12 + edges_ico0

    node_ico0 = new_nodes
    print(node_ico0)
    edges_ico0 = new_edges

42
72
132
252
492
972
1932
3852
7692


In [225]:
7692 * 37

284604

In [219]:
node_ico0

42