diff --git a/yt_idefix/data_structures.py b/yt_idefix/data_structures.py index b583829f..8c1cc6a8 100644 --- a/yt_idefix/data_structures.py +++ b/yt_idefix/data_structures.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from functools import cached_property from pathlib import Path -from typing import Literal +from typing import TYPE_CHECKING, Literal import inifix import numpy as np @@ -29,6 +29,16 @@ # even though we don't call them directly from .io import IdefixDmpIO, IdefixVtkIO, PlutoVtkIO # noqa +if TYPE_CHECKING: + # these should really be unyt_array, + # but mypy doesn't recognize it as a valid type as of unyt 2.9.3 and mypy 0.991 + XSpans = np.ndarray + YSpans = np.ndarray + ZSpans = np.ndarray + XCoords = np.ndarray + YCoords = np.ndarray + ZCoords = np.ndarray + ytLogger = logging.getLogger("yt") @@ -114,14 +124,14 @@ def _get_field_offset_index(self) -> dict[str, int]: @abstractmethod @cached_property - def _cell_widths(self): + def _cell_widths(self) -> tuple[XSpans, YSpans, ZSpans]: # must return a 3-tuple of 1D unyt_array # with unit "code_length" and dtype float64 ... @abstractmethod @cached_property - def _cell_centers(self): + def _cell_centers(self) -> tuple[XCoords, YCoords, ZCoords]: # must return a 3-tuple of 1D unyt_array # with unit "code_length" and dtype float64 ... @@ -152,35 +162,48 @@ def _get_field_offset_index(self) -> dict[str, int]: return self.ds._field_offset_index @cached_property - def _cell_widths(self): + def _cell_widths(self) -> tuple[XSpans, YSpans, ZSpans]: with open(self.index_filename, "rb") as fh: cell_edges = vtk_io.read_grid_coordinates(fh, geometry=self.ds.geometry) - cell_widths: list[np.ndarray] = [] + dims = self.ds.domain_dimensions length_unit = self.ds.quan(1, "code_length") + + cell_widths: tuple[XSpans, YSpans, ZSpans] + cell_widths = ( + np.empty(max(dims[0], 2), dtype="float64") * length_unit, + np.empty(max(dims[1], 2), dtype="float64") * length_unit, + np.empty(max(dims[2], 2), dtype="float64") * length_unit, + ) + for idir, edges in enumerate(cell_edges[:3]): - ncells = self.ds.domain_dimensions[idir] - if ncells > 1: - cell_widths.append(np.ediff1d(edges).astype("float64") * length_unit) + if dims[idir] > 1: + cell_widths[idir][:] = np.ediff1d(edges) else: - cell_widths.append(np.array([self.ds.domain_width[idir]]) * length_unit) - return tuple(cell_widths) + cell_widths[idir][:] = self.ds.domain_width[idir] + return cell_widths @cached_property - def _cell_centers(self): + def _cell_centers(self) -> tuple[XCoords, YCoords, ZCoords]: with open(self.index_filename, "rb") as fh: cell_edges = vtk_io.read_grid_coordinates(fh, geometry=self.ds.geometry) - cell_centers: list[np.ndarray] = [] + dims = self.ds.domain_dimensions length_unit = self.ds.quan(1, "code_length") + + cell_centers: tuple[XCoords, YCoords, ZCoords] + cell_centers = ( + np.empty(max(dims[0], 2), dtype="float64") * length_unit, + np.empty(max(dims[1], 2), dtype="float64") * length_unit, + np.empty(max(dims[2], 2), dtype="float64") * length_unit, + ) + for idir, edges in enumerate(cell_edges[:3]): - ncells = self.ds.domain_dimensions[idir] - if ncells > 1: - e64 = edges.astype("float64") - cell_centers.append(0.5 * (e64[1:] + e64[:-1]) * length_unit) + if dims[idir] > 1: + cell_centers[idir][:] = 0.5 * (edges[1:] + edges[:-1]) else: - cell_centers.append(np.array([edges[0]]) * length_unit) - return tuple(cell_centers) + cell_centers[idir][:] = edges[0] + return cell_centers class IdefixDmpHierarchy(IdefixHierarchy): @@ -189,16 +212,24 @@ def _get_field_offset_index(self) -> dict[str, int]: return dmp_io.get_field_offset_index(fh) @cached_property - def _cell_widths(self): + def _cell_widths(self) -> tuple[XSpans, YSpans, ZSpans]: _fprops, fdata = dmp_io.read_idefix_dmpfile(self.index_filename, skip_data=True) length_unit = self.ds.quan(1, "code_length") - return tuple((fdata[f"xr{d}"] - fdata[f"xl{d}"]) * length_unit for d in "123") + return ( + (fdata["xr1"] - fdata["xl1"]) * length_unit, + (fdata["xr2"] - fdata["xl2"]) * length_unit, + (fdata["xr3"] - fdata["xl3"]) * length_unit, + ) @cached_property - def _cell_centers(self): + def _cell_centers(self) -> tuple[XCoords, YCoords, ZCoords]: _fprops, fdata = dmp_io.read_idefix_dmpfile(self.index_filename, skip_data=True) length_unit = self.ds.quan(1, "code_length") - return tuple(fdata[f"x{d}"] * length_unit for d in "123") + return ( + fdata["x1"] * length_unit, + fdata["x2"] * length_unit, + fdata["x3"] * length_unit, + ) class IdefixDataset(Dataset, ABC):