Skip to content

Commit

Permalink
Merge pull request #172 from neutrinoceros/resolve_unchecked_annotations
Browse files Browse the repository at this point in the history
TYP: resolve unchecked annotations
  • Loading branch information
neutrinoceros committed Jan 8, 2023
2 parents b21d5b3 + 900fca0 commit 0a84916
Showing 1 changed file with 53 additions and 22 deletions.
75 changes: 53 additions & 22 deletions yt_idefix/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from abc import ABC, abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Literal
from typing import TYPE_CHECKING, Literal

import inifix
import numpy as np
Expand All @@ -29,6 +29,16 @@
# even though we don't call them directly
from .io import IdefixDmpIO, IdefixVtkIO, PlutoVtkIO # noqa

if TYPE_CHECKING:
# these should really be unyt_array,
# but mypy doesn't recognize it as a valid type as of unyt 2.9.3 and mypy 0.991
XSpans = np.ndarray
YSpans = np.ndarray
ZSpans = np.ndarray
XCoords = np.ndarray
YCoords = np.ndarray
ZCoords = np.ndarray

ytLogger = logging.getLogger("yt")


Expand Down Expand Up @@ -114,14 +124,14 @@ def _get_field_offset_index(self) -> dict[str, int]:

@abstractmethod
@cached_property
def _cell_widths(self):
def _cell_widths(self) -> tuple[XSpans, YSpans, ZSpans]:
# must return a 3-tuple of 1D unyt_array
# with unit "code_length" and dtype float64
...

@abstractmethod
@cached_property
def _cell_centers(self):
def _cell_centers(self) -> tuple[XCoords, YCoords, ZCoords]:
# must return a 3-tuple of 1D unyt_array
# with unit "code_length" and dtype float64
...
Expand Down Expand Up @@ -152,35 +162,48 @@ def _get_field_offset_index(self) -> dict[str, int]:
return self.ds._field_offset_index

@cached_property
def _cell_widths(self):
def _cell_widths(self) -> tuple[XSpans, YSpans, ZSpans]:
with open(self.index_filename, "rb") as fh:
cell_edges = vtk_io.read_grid_coordinates(fh, geometry=self.ds.geometry)

cell_widths: list[np.ndarray] = []
dims = self.ds.domain_dimensions
length_unit = self.ds.quan(1, "code_length")

cell_widths: tuple[XSpans, YSpans, ZSpans]
cell_widths = (
np.empty(max(dims[0], 2), dtype="float64") * length_unit,
np.empty(max(dims[1], 2), dtype="float64") * length_unit,
np.empty(max(dims[2], 2), dtype="float64") * length_unit,
)

for idir, edges in enumerate(cell_edges[:3]):
ncells = self.ds.domain_dimensions[idir]
if ncells > 1:
cell_widths.append(np.ediff1d(edges).astype("float64") * length_unit)
if dims[idir] > 1:
cell_widths[idir][:] = np.ediff1d(edges)
else:
cell_widths.append(np.array([self.ds.domain_width[idir]]) * length_unit)
return tuple(cell_widths)
cell_widths[idir][:] = self.ds.domain_width[idir]
return cell_widths

@cached_property
def _cell_centers(self):
def _cell_centers(self) -> tuple[XCoords, YCoords, ZCoords]:
with open(self.index_filename, "rb") as fh:
cell_edges = vtk_io.read_grid_coordinates(fh, geometry=self.ds.geometry)

cell_centers: list[np.ndarray] = []
dims = self.ds.domain_dimensions
length_unit = self.ds.quan(1, "code_length")

cell_centers: tuple[XCoords, YCoords, ZCoords]
cell_centers = (
np.empty(max(dims[0], 2), dtype="float64") * length_unit,
np.empty(max(dims[1], 2), dtype="float64") * length_unit,
np.empty(max(dims[2], 2), dtype="float64") * length_unit,
)

for idir, edges in enumerate(cell_edges[:3]):
ncells = self.ds.domain_dimensions[idir]
if ncells > 1:
e64 = edges.astype("float64")
cell_centers.append(0.5 * (e64[1:] + e64[:-1]) * length_unit)
if dims[idir] > 1:
cell_centers[idir][:] = 0.5 * (edges[1:] + edges[:-1])
else:
cell_centers.append(np.array([edges[0]]) * length_unit)
return tuple(cell_centers)
cell_centers[idir][:] = edges[0]
return cell_centers


class IdefixDmpHierarchy(IdefixHierarchy):
Expand All @@ -189,16 +212,24 @@ def _get_field_offset_index(self) -> dict[str, int]:
return dmp_io.get_field_offset_index(fh)

@cached_property
def _cell_widths(self):
def _cell_widths(self) -> tuple[XSpans, YSpans, ZSpans]:
_fprops, fdata = dmp_io.read_idefix_dmpfile(self.index_filename, skip_data=True)
length_unit = self.ds.quan(1, "code_length")
return tuple((fdata[f"xr{d}"] - fdata[f"xl{d}"]) * length_unit for d in "123")
return (
(fdata["xr1"] - fdata["xl1"]) * length_unit,
(fdata["xr2"] - fdata["xl2"]) * length_unit,
(fdata["xr3"] - fdata["xl3"]) * length_unit,
)

@cached_property
def _cell_centers(self):
def _cell_centers(self) -> tuple[XCoords, YCoords, ZCoords]:
_fprops, fdata = dmp_io.read_idefix_dmpfile(self.index_filename, skip_data=True)
length_unit = self.ds.quan(1, "code_length")
return tuple(fdata[f"x{d}"] * length_unit for d in "123")
return (
fdata["x1"] * length_unit,
fdata["x2"] * length_unit,
fdata["x3"] * length_unit,
)


class IdefixDataset(Dataset, ABC):
Expand Down

0 comments on commit 0a84916

Please sign in to comment.