In [None]:
%load_ext jupyter_black
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys

# Add the parent directory to the module search path. The entry is relative,
# so it is resolved against the kernel's current working directory.
# Presumably this is what lets the `mesoscaler` imports in later cells
# resolve without installing the package -- TODO confirm.
# NOTE(review): `append` runs every time the cell is executed, so re-running
# this cell adds a duplicate "../" entry to sys.path.
sys.path.append("../")


In [None]:
from mesoscaler._metadata import get_metadata

In [1]:
from __future__ import annotations

import os
import types

import pyproj
import xarray as xr
import pyresample.kd_tree

from mesoscaler.typing import Any
from mesoscaler.enums import (
    URMA,
    ERA5,
    Dimensions,
    Coordinates,
    X,
    Y,
    Z,
    T,
    LAT,
    LON,
    LVL,
    TIME,
    COORDINATES,
    DIMENSIONS,
)
# Relative location of the directory that holds the zarr stores used below.
_test_data = "../tests/data"

# One store path per dataset, both rooted at `_test_data`.
urma_store, era5_store = (
    os.path.join(_test_data, store_name) for store_name in ("urma.zarr", "era5.zarr")
)

# Data-variable lists: everything that iterating each enum class yields,
# in iteration order.
urma_dvars, era5_dvars = [*URMA], [*ERA5]

# ERA5 variables on the first line, URMA variables on the second.
print(era5_dvars, urma_dvars, sep="\n")
# NOTE(review): unfinished scratch note ("URMAEnum.dvarsfrom"); `URMAEnum` is not among the names imported above.


[geopotential, specific_humidity, temperature, u_component_of_wind, v_component_of_wind, vertical_velocity]
[total_cloud_cover, ceiling, u_wind_component_10m, v_wind_component_10m, wind_speed_10m, wind_speed_gust, wind_direction_10m, temperature_2m, dewpoint_temperature_2m, specific_humidity_2m, surface_pressure, visibility, orography]


In [41]:
import pandas as pd




[{'hash': 5931780838539,
  'name': 'Dimensions',
  '__mesometa_cls_data__': {},
  '__mesometa_loc__': <mesoscaler.generic.Loc at 0x7fe1e2ecefe0>,
  '__mesometa_series__': {'T': T, 'Z': Z, 'Y': Y, 'X': X},
  '__mesometa_member_data__': {'T': {}, 'Z': {}, 'Y': {}, 'X': {}},
  '__mesometa_member_aliases__': [{'T': 'time',
    'Z': 'z',
    'Y': 'grid_latitude',
    'X': 'longitude'},
   {'T': 't', 'Z': 'altitude', 'Y': 'y', 'X': 'grid_longitude'},
   {'T': None, 'Z': 'height', 'Y': 'latitude', 'X': 'x'},
   {'T': None, 'Z': 'level', 'Y': None, 'X': None}]},
 {'hash': 5931780838778,
  'name': 'Coordinates',
  '__mesometa_cls_data__': {},
  '__mesometa_loc__': <mesoscaler.generic.Loc at 0x7fe1e2ecf670>,
  '__mesometa_series__': {'time': time,
   'vertical': vertical,
   'latitude': latitude,
   'longitude': longitude},
  '__mesometa_member_data__': {'time': {'axis': (T,)},
   'vertical': {'axis': (Z,)},
   'latitude': {'axis': (Y, X)},
   'longitude': {'axis': (Y, X)},
   'axis': {}},
  '__

In [None]:
# Open the URMA zarr store at `urma_store` and keep the result as a
# module-level handle; `get_urma` (defined in a later cell) copies from it and
# the ERA5 cell uses its `time` coordinate for selection.
_URMA_DATASET = xr.open_zarr(urma_store)
# Bare last expression: the notebook displays the dataset.
_URMA_DATASET

In [None]:
# Build the local ERA5 store only when nothing exists at `era5_store`; when the
# path already exists the remote read and the write below are skipped.
if not os.path.exists(era5_store):
    # Remote source store (a gs:// URL in the `weatherbench2` bucket).
    _google_store = "gs://weatherbench2/datasets/era5/1959-2022-full_37-1h-0p25deg-chunk-1.zarr-v2"
    # Keep only the variables listed in `era5_dvars`, select the time values of
    # the URMA dataset, and write the result to `era5_store` (mode="w").
    xr.open_zarr(_google_store)[era5_dvars].sel(time=_URMA_DATASET.time).to_zarr(era5_store, mode="w")

# Always read back from the local store, whether it was just written or
# already existed.
_ERA5_DATASET = xr.open_zarr(era5_store)
# Bare last expression: the notebook displays the dataset.
_ERA5_DATASET

In [None]:
from mesoscaler.core import make_independent


def get_urma() -> xr.Dataset:
    """Return ``make_independent`` applied to a copy of the URMA dataset.

    The module-level ``_URMA_DATASET`` is copied with ``.copy()`` first, and
    the copy (not the module-level object itself) is what gets passed on.
    """
    urma_copy = _URMA_DATASET.copy()
    return make_independent(urma_copy)


def get_era5() -> xr.Dataset:
    """Return ``make_independent`` applied to a copy of the ERA5 dataset.

    The module-level ``_ERA5_DATASET`` is copied with ``.copy()`` first, and
    the copy (not the module-level object itself) is what gets passed on.
    """
    era5_copy = _ERA5_DATASET.copy()
    return make_independent(era5_copy)


# Bind the result of `get_urma()` to `urma`; the bare name on the last line
# makes the notebook display it.
urma = get_urma()
urma

In [None]:
# Bind the result of `get_era5()` to `era5`; the bare name on the last line
# makes the notebook display it.
era5 = get_era5()
era5

In [None]:
from mesoscaler.core import GriddedDataset

# Open each store in turn (URMA first, then ERA5) and print three flags of its
# CRS: is_geocentric, is_geographic, is_projected. The generator is lazy, so
# the second store is only opened after the first line has been printed, and
# `ds` is left bound to the ERA5 dataset when the loop finishes.
for ds in (
    GriddedDataset.from_zarr(store, dvars)
    for store, dvars in (
        (urma_store, [URMA.CEIL, URMA.VIS]),
        (era5_store, [ERA5.T]),
    )
):
    print(ds.crs.is_geocentric, ds.crs.is_geographic, ds.crs.is_projected)

In [None]:
import numpy as np

# NOTE(review): this rebinds `urma` and `era5`, which earlier cells bound to
# the results of get_urma() / get_era5(). From here on both names refer to
# whatever GriddedDataset.from_zarr returns for the given store and variables.
urma = GriddedDataset.from_zarr(urma_store, [URMA.CEIL, URMA.VIS])
era5 = GriddedDataset.from_zarr(era5_store, [ERA5.T])

In [None]:
# import functools


# class ReSampler:
#     def __init__(
#         self,
#         width=80,
#         height=80,
#         *,
#         ids: IndependentDataset,
#     ) -> None:
#         md = ids.metadata
#         self._source_definition = pyresample.geometry.AreaDefinition(
#             area_id=f"{md.title} Area Definition",
#             description=md.comment,
#             proj_id=md.crs.name,
#             projection=md.crs,
#             height=ids.y.size,
#             width=ids.y.size,
#             area_extent=ids.area_extent,
#             # lons=lons,
#             # lats=lats,
#         )

#         self._target_definition = functools.partial(
#             pyresample.geometry.AreaDefinition,
#             area_id="lambert_azimuthal_equal_area",
#             description="lambert_azimuthal_equal_area",
#             proj_id="lambert_azimuthal_equal_area",
#             width=width,
#             height=height,
#         )

#     def get_target_definition(
#         self, latitude: float, longitude: float, area_extent: list[float]
#     ) -> pyresample.geometry.AreaDefinition:
#         crs = pyproj.CRS.from_cf(
#             {
#                 "grid_mapping_name": "lambert_azimuthal_equal_area",
#                 "latitude_of_projection_origin": latitude,
#                 "longitude_of_projection_origin": longitude,
#                 "units": "m",
#             }
#         )

#         return self._target_definition(projection=crs, area_extent=area_extent)


# def get_grid_definition(self: IndependentDataset) -> pyresample.geometry.GridDefinition:
#     lons = self.lons.to_numpy()
#     if is_greenwich(ds.lons.to_numpy()):
#         lons = (lons - 180) % 360 - 180

#     return pyresample.geometry.GridDefinition(lons, self.lats)


# ds = IndependentDataset.from_zarr(urma_store, URMAEnum("t2m", "d2m"))
# lons = (ds.lons.to_numpy() - 180) % 360 - 180

# from src.mesoformer.typing import Boolean


# def is_greenwich(lons: np.ndarray) -> Boolean:
#     x = lons - (lons - 180) % 180
#     return np.all(x == 180.0)


# # print(
# #     ds.lons.to_numpy(),
# #     is_greenwich(ds.lons.to_numpy()),
# #     is_greenwich(360),
# #     # ds.lons.to_numpy(),
# #     # # ds.lons.to_numpy()[0],
# #     # lons,
# #     # np.unique(lons - (lons - 180) % 360 - 180),
# #     sep="\n==\n",
# #     # ds.lats.shape,
# # )
# # # get_grid_definition(IndependentDataset.from_zarr(urma_store, URMAEnum("t2m", "d2m")))


# # print(
# #     f"""\
# # {is_greenwich(0)=}
# # {is_greenwich(180)=}
# # {is_greenwich(360)=}
# # {is_greenwich(40)=}
# # """
# # )
# def is_geodetic(crs: pyproj.CRS) -> Boolean:
#     return crs.is_geographic


# print(
#     "\n=========\n".join(
#         f"""
# {crs.name=}
# area_of_use:\n{crs.area_of_use or 'undefined'}
# {is_geodetic(crs)=}
# """
#         for crs in (ERA5Enum.crs, URMAEnum.crs)
#     )
# )

In [None]:
# from src.mesoformer.typing import StrPath, Sequence, DictStrAny, Mapping, Hashable, Iterable
# from src.mesoformer.datasets.metadata import MetadataMixin
# from src.mesoformer.datasets.metadata import T, Z, X, Y, LON, LAT, LVL, TIME, DatasetMetadata
# import functools

# VariableLike = type[CFDatasetEnum] | CFDatasetEnum | Sequence[CFDatasetEnum]


# class PartialAreaDefinitions:
#     def __init__(
#         self,
#         width=256,
#         height=256,
#         *,
#         metadata: DatasetMetadata,
#         latitude: float = 25.0,
#         longitude: float = 265.0,
#     ) -> None:
#         self._source = None

#         self._target = functools.partial(
#             pyresample.geometry.AreaDefinition,
#             area_id="lambert_azimuthal_equal_area",
#             description="lambert_azimuthal_equal_area",
#             proj_id="lambert_azimuthal_equal_area",
#             projection=pyproj.CRS.from_cf(
#                 {
#                     "grid_mapping_name": "lambert_azimuthal_equal_area",
#                     "latitude_of_projection_origin": latitude,
#                     "longitude_of_projection_origin": longitude,
#                 }
#             ),
#             width=width,
#             height=height,
#         )

#     def __repr__(self) -> str:
#         return f"{self.__class__.__name__}({self._target})"

#     def __call__(self, area_extent: list[float]):
#         tgt = self._target(area_extent=area_extent)
#         return tgt, tgt

#     def iter_definitions(
#         self, area_extents: Iterable[list[float]]
#     ) -> Iterable[tuple[pyresample.geometry.AreaDefinition, pyresample.geometry.AreaDefinition]]:
#         for area_extent in area_extents:
#             yield self(area_extent)


# area_defs = PartialAreaDefinitions(
#     80,
#     80,
#     metadata=DatasetMetadata.from_title("ERA5"),
#     latitude=25.0,
#     longitude=265.0,
# )


# for src, tgt in area_defs.iter_definitions(extents):
#     print(src, tgt)

In [None]:
# class ReSampler:
#     def __init__(
#         self,
#         area_defs: PartialAreaDefinitions,
#     ) -> None:
#         pass

In [None]:
# class MesoDataset(IndependentDataset):
#     cf = types.MappingProxyType(
#         {
#             "geographic_crs_name": "NDFD CONUS 2.5km Lambert Conformal Conic",
#             "projected_crs_name": "NDFD",
#             "semi_major_axis": 6378137.0,
#             "semi_minor_axis": 6356752.31424518,
#             "inverse_flattening": 298.25722356301,
#             "reference_ellipsoid_name": "WGS 84",
#             "longitude_of_prime_meridian": 0.0,
#             "prime_meridian_name": "Greenwich",
#             "horizontal_datum_name": "WGS84",
#             "latitude_of_projection_origin": 20.191999,
#             "longitude_of_projection_origin": 238.445999,
#             "false_easting": 0.0,
#             "false_northing": 0.0,
#         }
#     )

#     def __init__(self, ds: xr.Dataset, dvars: VariableLike) -> None:
#         super().__init__(ds, dvars)

#     #     lons, lats = self.lons.to_numpy(), self.lats.to_numpy()
#     #     area_extent = [lons.min(), lats.min(), lons.max(), lats.max()]
#     #     self._source_definition = pyresample.geometry.AreaDefinition(
#     #         self.cf["geographic_crs_name"],
#     #         "National Digital Forecast Database Grid",
#     #         self.cf["projected_crs_name"],
#     #         self.get_crs("lambert_conformal_conic", standard_parallel=25),
#     #         self.x.size,
#     #         self.y.size,
#     #         area_extent=area_extent,
#     #         lons=lons,
#     #         lats=lats,
#     #     )

#     # def get_source_definition(self) -> pyresample.geometry.AreaDefinition:
#     #     return self._source_definition

#     # def get_target_definition(
#     #     self, latitude: float, longitude: float, width: float, height: float, area_extent: list[float]
#     # ) -> pyresample.geometry.AreaDefinition:
#     #     crs = self.get_crs("lambert_azimuthal_equal_area", latitude=latitude, longitude=longitude)
#     #     return pyresample.geometry.AreaDefinition(
#     #         "target_projection",
#     #         "description",
#     #         None,
#     #         crs,
#     #         width=width,
#     #         height=height,
#     #         area_extent=area_extent,
#     #     )

#     # def get_crs(
#     #     self, grid_mapping_name: str, *, latitude: float | None = None, longitude: float | None = None, **kwargs
#     # ) -> pyproj.CRS:
#     #     origin: dict[str, Any] = {"grid_mapping_name": grid_mapping_name, "units": "m"}
#     #     if latitude is not None:
#     #         origin["latitude_of_projection_origin"] = latitude

#     #     if longitude is not None:
#     #         origin["longitude_of_projection_origin"] = longitude

#     #     return pyproj.CRS.from_cf(self.cf | origin | kwargs)

#     # # @property
#     # # def area_extent(self) -> list[float]:
#     # #     return [
#     # #         self.lons.min(),
#     # #         self.lats.min(),
#     # #         self.lons.max(),
#     # #         self.lats.max(),
#     # #     ]
#     # def _resample_on_center(self, target: pyresample.geometry.AreaDefinition) -> np.ndarray:
#     #     return pyresample.kd_tree.resample_nearest(self.area_definition, self.to_numpy(), target, radius_of_influence=50000)

#     # def resample_on_center(
#     #     self,
#     #     longitude: float,
#     #     latitude: float,
#     #     *,
#     #     width=256,
#     #     height=256,
#     #     dx=100,
#     #     dy=100,
#     #     scale_x=1,
#     #     scale_y=1,
#     #     units="km",
#     # ):
#     #     if units == "km":
#     #         dx *= 1000
#     #         dy *= 1000

#     #     dx /= 2
#     #     dy /= 2

#     #     height *= scale_y
#     #     width *= scale_x

#     #     source = self.get_source_definition()
#     #     data = self.da.to_numpy()

#     #     area_extent = [-dx * scale_x, -dy * scale_y, dx * scale_x, dy * scale_y]
#     #     target = self.get_target_definition(latitude, longitude, width, height, area_extent)

#     #     return pyresample.kd_tree.resample_nearest(source, data, target, radius_of_influence=50000)


# cf = MesoDataset.from_zarr(urma_store, URMAEnum("ceiling", "vis"))
# cf

In [None]:
# ceil = cf.dvars[0]
# print(ceil)
# ceil.crs

In [None]:
# import matplotlib.pyplot as plt


# class Dataset:
#     """
#     # CONUS and Northern Hemisphere Grids

#     https://graphical.weather.gov/docs/ndfdSRS.htm#:~:text=The%20NDFD%20uses%20the%20World%20Geodetic%20System,1984%20%28WGS84%29%20ellipsoid%20for%20its%20horizontal%20datum.

#     Grid Parameter	    CONUS 2.5km
#     Number of Points	2953665
#     Projection Type	    Lambert Conformal
#     Shape of Earth      Sphere
#     Earth Radius	    6371.2 km
#     Number of Points on the parallel	2145
#     Number of Points on the Meridian	1377
#     Latitude1:	20.191999
#     Longitude1:	238.445999
#     u/v vectors relative to:	easterly/northerly
#     Dx	2539.703 m
#     Dy	2539.703 m
#     GRIB2 grid, scan mode	64 (0100)
#     Scan i/x direction	positive
#     Scan j/y direction	positive
#     Consecutive points in	i/x direction
#     Adjacent rows scan in	same direction
#     Mesh Latitude	25
#     Orientation Longitude	265
#     Which Pole is on the Plane	north
#     Is Projection Bi-polar	no
#     Tangent Latitude1	25
#     Tangent Latitude2	25
#     Southern Latitude	-90
#     Southern Longitude	0
#     """

#     cf = types.MappingProxyType(
#         {
#             "geographic_crs_name": "NDFD CONUS 2.5km Lambert Conformal Conic",
#             "projected_crs_name": "NDFD",
#             "semi_major_axis": 6378137.0,
#             "semi_minor_axis": 6356752.31424518,
#             "inverse_flattening": 298.25722356301,
#             "reference_ellipsoid_name": "WGS 84",
#             "longitude_of_prime_meridian": 0.0,
#             "prime_meridian_name": "Greenwich",
#             "horizontal_datum_name": "WGS84",
#             "latitude_of_projection_origin": 20.191999,
#             "longitude_of_projection_origin": 238.445999,
#             "false_easting": 0.0,
#             "false_northing": 0.0,
#         }
#     )

#     def __init__(self, ds: xr.Dataset):
#         self.da = da = ds.to_array().transpose(X, Y, ...)
#         self.lons = lons = (da["longitude"].to_numpy() + 180) % 360 - 180
#         self.lats = lats = da["latitude"].to_numpy()

#         self._source_definition = pyresample.geometry.AreaDefinition(
#             self.cf["geographic_crs_name"],
#             "National Digital Forecast Database Grid",
#             self.cf["projected_crs_name"],
#             self.get_crs("lambert_conformal_conic", standard_parallel=25),
#             da[Y].size,
#             da[X].size,
#             area_extent=self.area_extent,
#             lons=lons,
#             lats=lats,
#         )

#     def get_source_definition(self) -> pyresample.geometry.AreaDefinition:
#         return self._source_definition

#     def get_target_definition(
#         self, latitude: float, longitude: float, width: float, height: float, area_extent: list[float]
#     ) -> pyresample.geometry.AreaDefinition:
#         crs = self.get_crs("lambert_azimuthal_equal_area", latitude=latitude, longitude=longitude)
#         return pyresample.geometry.AreaDefinition(
#             "target_projection",
#             "description",
#             None,
#             crs,
#             width=width,
#             height=height,
#             area_extent=area_extent,
#         )

#     def get_crs(
#         self, grid_mapping_name: str, *, latitude: float | None = None, longitude: float | None = None, **kwargs
#     ) -> pyproj.CRS:
#         origin: dict[str, Any] = {"grid_mapping_name": grid_mapping_name, "units": "m"}
#         if latitude is not None:
#             origin["latitude_of_projection_origin"] = latitude

#         if longitude is not None:
#             origin["longitude_of_projection_origin"] = longitude

#         return pyproj.CRS.from_cf(self.cf | origin | kwargs)

#     @property
#     def area_extent(self) -> list[float]:
#         return [
#             self.lons.min(),
#             self.lats.min(),
#             self.lons.max(),
#             self.lats.max(),
#         ]

#     def resample_on_center(
#         self,
#         longitude: float,
#         latitude: float,
#         *,
#         width=256,
#         height=256,
#         dx=100,
#         dy=100,
#         scale_x=1,
#         scale_y=1,
#         units="km",
#     ):
#         if units == "km":
#             dx *= 1000
#             dy *= 1000

#         dx /= 2
#         dy /= 2

#         height *= scale_y
#         width *= scale_x

#         source = self.get_source_definition()
#         data = self.da.to_numpy()

#         area_extent = [-dx * scale_x, -dy * scale_y, dx * scale_x, dy * scale_y]
#         target = self.get_target_definition(latitude, longitude, width, height, area_extent)

#         return pyresample.kd_tree.resample_nearest(source, data, target, radius_of_influence=50000)


# dataset = Dataset(get_era5().isel(T=0, Z=0))
# data = dataset.resample_on_center(longitude=-89.835, latitude=38.54)
# H, W, C = data.shape

# fig, axes = plt.subplots(1, C, figsize=(20, 5))

# for i in range(C):
#     ax = axes[i]
#     ax.imshow(data[:, :, i], origin="upper", cmap="terrain")
#     ax.set_xticks([])
#     ax.set_yticks([])