diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index 4fb9cb0d91..b5a304f2f7 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -47,7 +47,21 @@ def decode_bytes_array(array_of_bytes: Numpy) -> List[str]: return list_of_str -""" Base Classes """ +""" xarray subclasses """ + + +class Tidy3dDataArray(xr.DataArray): + """Subclass of xarray's DataArray that implements some custom functions.""" + + __slots__ = () + + @property + def abs(self): + """Absolute value of complex-valued data.""" + return abs(self) + + +""" Base data classes """ class Tidy3dData(Tidy3dBaseModel): @@ -63,7 +77,7 @@ class Config: # pylint: disable=too-few-public-methods json_encoders = { # how to write certain types to json files np.ndarray: numpy_encoding, # use custom encoding defined in .types np.int64: lambda x: int(x), # pylint: disable=unnecessary-lambda - xr.DataArray: lambda x: None, # dont write + Tidy3dDataArray: lambda x: None, # dont write xr.Dataset: lambda x: None, # dont write } @@ -106,7 +120,7 @@ class MonitorData(Tidy3dData, ABC): """ @property - def data(self) -> xr.DataArray: + def data(self) -> Tidy3dDataArray: # pylint:disable=line-too-long """Returns an xarray representation of the montitor data. @@ -120,7 +134,7 @@ def data(self) -> xr.DataArray: data_dict = self.dict() coords = {dim: data_dict[dim] for dim in self._dims} - return xr.DataArray(self.values, coords=coords) + return Tidy3dDataArray(self.values, coords=coords) def __eq__(self, other) -> bool: """Check equality against another MonitorData instance. @@ -204,8 +218,8 @@ def data(self) -> xr.Dataset: data_arrays = {name: arr.data for name, arr in self.data_dict.items()} # make an xarray dataset - # return xr.Dataset(data_arrays) # datasets are annoying - return data_arrays + return xr.Dataset(data_arrays) # datasets are annoying + # return data_arrays def __eq__(self, other): """Check for equality against other :class:`CollectionData` object.""" @@ -488,9 +502,11 @@ class SimulationData(Tidy3dBaseModel): @property def log(self): """Prints the server-side log.""" - print(self.log_string if self.log_string else "no log stored") + if not self.log_string: + raise DataError("No log stored in SimulationData.") + return self.log_string - def __getitem__(self, monitor_name: str) -> Union[xr.DataArray, xr.Dataset]: + def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]: """Get the :class:`MonitorData` xarray representation by name (``sim_data[monitor_name]``). Parameters @@ -508,75 +524,111 @@ def __getitem__(self, monitor_name: str) -> Union[xr.DataArray, xr.Dataset]: raise DataError(f"monitor {monitor_name} not found") return monitor_data.data - # @add_ax_if_none - # def plot_field( - # self, - # field_monitor_name: str, - # field_name: str, - # x: float = None, - # y: float = None, - # z: float = None, - # freq: float = None, - # time: float = None, - # eps_alpha: pydantic.confloat(ge=0.0, le=1.0) = 0.5, - # ax: Ax = None, - # **kwargs, - # ) -> Ax: - # """Plot the field data for a monitor with simulation plot overlayed. - - # Parameters - # ---------- - # field_monitor_name : ``str`` - # Name of :class:`FieldMonitor` or :class:`FieldTimeData` to plot. - # field_name : ``str`` - # Name of `field` in monitor to plot (eg. 'Ex'). - # x : ``float``, optional - # Position of plane in x direction. - # y : ``float``, optional - # Position of plane in y direction. - # z : ``float``, optional - # Position of plane in z direction. - # freq: ``float``, optional - # if monitor is a :class:`FieldMonitor`, specifies the frequency (Hz) to plot the field. - # time: ``float``, optional - # if monitor is a :class:`FieldTimeMonitor`, specifies the time (sec) to plot the field. - # cbar: `bool``, optional - # if True (default), will include colorbar - # ax : ``matplotlib.axes._subplots.Axes``, optional - # matplotlib axes to plot on, if not specified, one is created. - # **patch_kwargs - # Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``. - - # Returns - # ------- - # ``matplotlib.axes._subplots.Axes`` - # The supplied or created matplotlib axes. - - # TODO: fully test and finalize arguments. - # """ - - # if field_monitor_name not in self.monitor_data: - # raise DataError(f"field_monitor_name {field_monitor_name} not found in SimulationData.") - - # monitor_data = self.monitor_data.get(field_monitor_name) - - # if not isinstance(monitor_data, FieldData): - # raise DataError(f"field_monitor_name {field_monitor_name} not a FieldData instance.") - - # if field_name not in monitor_data.data_dict: - # raise DataError(f"field_name {field_name} not found in {field_monitor_name}.") - - # xr_data = monitor_data.data_dict.get(field_name) - # if isinstance(monitor_data, FieldData): - # field_data = xr_data.sel(f=freq) - # else: - # field_data = xr_data.sel(t=time) - - # ax = field_data.sel(x=x, y=y, z=z).real.plot.pcolormesh(ax=ax) - # ax = self.simulation.plot_structures_eps( - # freq=freq, cbar=False, x=x, y=y, z=z, alpha=eps_alpha, ax=ax - # ) - # return ax + @add_ax_if_none + def plot_field( + self, + field_monitor_name: str, + field_name: str, + x: float = None, + y: float = None, + z: float = None, + val: Literal["real", "imag", "abs"] = "real", + freq: float = None, + time: float = None, + cbar: bool = None, + eps_alpha: float = 0.2, + ax: Ax = None, + **kwargs, + ) -> Ax: + """Plot the field data for a monitor with simulation plot overlayed. + + Parameters + ---------- + field_monitor_name : str + Name of :class:`FieldMonitor` or :class:`FieldTimeData` to plot. + field_name : str + Name of `field` in monitor to plot (eg. 'Ex'). + x : float = None + Position of plane in x direction. + y : float = None + Position of plane in y direction. + z : float = None + Position of plane in z direction. + val : Literal['real', 'imag', 'abs'] = 'real' + What part of the field to plot (in ) + freq: float = None + If monitor is a :class:`FieldMonitor`, specifies the frequency (Hz) to plot the field. + time: float = None + if monitor is a :class:`FieldTimeMonitor`, specifies the time (sec) to plot the field. + cbar: bool = True + if True (default), will include colorbar + eps_alpha : float = 0.2 + Opacity of the structure permittivity. + Must be between 0 and 1 (inclusive). + ax : matplotlib.axes._subplots.Axes = None + matplotlib axes to plot on, if not specified, one is created. + **patch_kwargs + Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + # get the monitor data + if field_monitor_name not in self.monitor_data: + raise DataError(f"Monitor named '{field_monitor_name}' not found.") + monitor_data = self.monitor_data.get(field_monitor_name) + if not isinstance(monitor_data, FieldData): + raise DataError(f"field_monitor_name '{field_monitor_name}' not a FieldData instance.") + + # get the field data component + if field_name not in monitor_data.data_dict: + raise DataError(f"field_name {field_name} not found in {field_monitor_name}.") + xr_data = monitor_data.data_dict.get(field_name).data + + # select the frequency or time value + if "f" in monitor_data.coords: + if freq is None: + raise DataError("'freq' must be supplied to plot a FieldMonitor.") + field_data = xr_data.interp(f=freq) + elif "t" in monitor_data.coords: + if time is None: + raise DataError("'time' must be supplied to plot a FieldMonitor.") + field_data = xr_data.interp(t=time) + else: + raise DataError("Field data has neither time nor frequency data, something went wrong.") + + # select the cross section data + axis, pos = self.simulation.parse_xyz_kwargs(x=x, y=y, z=z) + axis_label = "xyz"[axis] + sel_kwarg = {axis_label: pos} + try: + field_data = field_data.sel(**sel_kwarg) + except Exception as e: + raise DataError(f"Could not select data at {axis_label}={pos}.") from e + + # select the field value + if val not in ("real", "imag", "abs"): + raise DataError(f"'val' must be one of ``{'real', 'imag', 'abs'}``, given {val}") + if val == "real": + field_data = field_data.real + elif val == "imag": + field_data = field_data.imag + elif val == "real": + field_data = abs(field_data) + + # plot the field + xy_coords = list("xyz") + xy_coords.pop(axis) + field_data.plot(ax=ax, x=xy_coords[0], y=xy_coords[1]) + + # plot the simulation epsilon + ax = self.simulation.plot_structures_eps( + freq=freq, cbar=cbar, x=x, y=y, z=z, alpha=eps_alpha, ax=ax + ) + return ax def export(self, fname: str) -> None: """Export :class:`SimulationData` to single hdf5 file including monitor data. diff --git a/tidy3d/components/geometry.py b/tidy3d/components/geometry.py index 1d455f93c6..d748256759 100644 --- a/tidy3d/components/geometry.py +++ b/tidy3d/components/geometry.py @@ -208,7 +208,7 @@ def plot( # pylint:disable=line-too-long # find shapes that intersect self at plane - axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z) + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) shapes_intersect = self.intersections(x=x, y=y, z=z) # for each intersection, plot the shape @@ -328,7 +328,7 @@ def unpop_axis(ax_coord: Any, plane_coords: Tuple[Any, Any], axis: int) -> Tuple return tuple(coords) @staticmethod - def _parse_xyz_kwargs(**xyz) -> Tuple[Axis, float]: + def parse_xyz_kwargs(**xyz) -> Tuple[Axis, float]: """Turns x,y,z kwargs into index of the normal axis and position along that axis. Parameters @@ -380,7 +380,7 @@ def intersections(self, x: float = None, y: float = None, z: float = None): For more details refer to `Shapely's Documentaton `_. """ - axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z) + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) if axis == self.axis: z0, _ = self.pop_axis(self.center, axis=self.axis) if (position < z0 - self.length / 2) or (position > z0 + self.length / 2): @@ -529,7 +529,7 @@ def intersections(self, x: float = None, y: float = None, z: float = None): For more details refer to `Shapely's Documentaton `_. """ - axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z) + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) z0, (x0, y0) = self.pop_axis(self.center, axis=axis) Lz, (Lx, Ly) = self.pop_axis(self.size, axis=axis) dz = np.abs(z0 - position) @@ -647,7 +647,7 @@ def intersections(self, x: float = None, y: float = None, z: float = None): For more details refer to `Shapely's Documentaton `_. """ - axis, position = self._parse_xyz_kwargs(x=x, y=y, z=z) + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) z0, (x0, y0) = self.pop_axis(self.center, axis=axis) intersect_dist = self._intersect_dist(position, z0) if not intersect_dist: diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index b0d84a22a8..67078d6359 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -11,7 +11,7 @@ from .viz import add_ax_if_none from .validators import validate_name_str -from ..constants import C_0, inf, pec_val +from ..constants import C_0, pec_val from ..log import log @@ -23,15 +23,16 @@ class AbstractMedium(ABC, Tidy3dBaseModel): Parameters ---------- - frequeuncy_range : Tuple[float, float] = (-inf, inf) + frequeuncy_range : Tuple[float, float] = None Range of validity for the medium in Hz. + If None, then all frequencies are valid. If simulation or plotting functions use frequency out of this range, a warning is thrown. name : str = None Optional name for the medium. """ name: str = None - frequency_range: Tuple[FreqBound, FreqBound] = (-inf, inf) + frequency_range: Tuple[FreqBound, FreqBound] = None _name_validator = validate_name_str() @@ -88,7 +89,7 @@ def _eps_model(self, frequency: float) -> complex: """New eps_model function.""" # if frequency is none, don't check, return original function - if frequency is None: + if frequency is None or self.frequency_range is None: return eps_model(self, frequency) fmin, fmax = self.frequency_range diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index a57fc4cc03..da56e6fcea 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -637,7 +637,7 @@ def plot_grid(self, x: float = None, y: float = None, z: float = None, ax: Ax = The supplied or created matplotlib axes. """ cell_boundaries = self.grid.boundaries - axis, _ = self._parse_xyz_kwargs(x=x, y=y, z=z) + axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) _, (axis_x, axis_y) = self.pop_axis([0, 1, 2], axis=axis) boundaries_x = cell_boundaries.dict()["xyz"[axis_x]] boundaries_y = cell_boundaries.dict()["xyz"[axis_y]] @@ -668,7 +668,7 @@ def _set_plot_bounds(self, ax: Ax, x: float = None, y: float = None, z: float = The axes after setting the boundaries. """ - axis, _ = self._parse_xyz_kwargs(x=x, y=y, z=z) + axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) _, ((xmin, ymin), (xmax, ymax)) = self._pop_bounds(axis=axis) _, (pml_thick_x, pml_thick_y) = self.pop_axis(self.pml_thicknesses, axis=axis) diff --git a/tidy3d/components/viz.py b/tidy3d/components/viz.py index 4a5f3a6548..ea4e56fdee 100644 --- a/tidy3d/components/viz.py +++ b/tidy3d/components/viz.py @@ -81,7 +81,7 @@ class SourceParams(PatchParamSwitcher): def get_plot_params(self) -> PatchParams: """Returns :class:`PatchParams` based on user-supplied args.""" - return PatchParams(alpha=0.7, facecolor="blueviolet", edgecolor="blueviolet") + return PatchParams(alpha=0.4, facecolor="blueviolet", edgecolor="blueviolet") class MonitorParams(PatchParamSwitcher): @@ -89,7 +89,7 @@ class MonitorParams(PatchParamSwitcher): def get_plot_params(self) -> PatchParams: """Returns :class:`PatchParams` based on user-supplied args.""" - return PatchParams(alpha=0.7, facecolor="crimson", edgecolor="crimson") + return PatchParams(alpha=0.4, facecolor="crimson", edgecolor="crimson") class StructMediumParams(PatchParamSwitcher): @@ -140,9 +140,9 @@ class SymParams(PatchParamSwitcher): def get_plot_params(self) -> PatchParams: """Returns :class:`PatchParams` based on user-supplied args.""" if self.sym_value == 1: - return PatchParams(alpha=0.5, facecolor="lightsteelblue", edgecolor="lightsteelblue") + return PatchParams(alpha=0.3, facecolor="lightsteelblue", edgecolor="lightsteelblue") if self.sym_value == -1: - return PatchParams(alpha=0.5, facecolor="lightgreen", edgecolor="lightgreen") + return PatchParams(alpha=0.3, facecolor="lightgreen", edgecolor="lightgreen") return PatchParams() diff --git a/tidy3d/constants.py b/tidy3d/constants.py index 6c9145c590..a4086ffa1e 100644 --- a/tidy3d/constants.py +++ b/tidy3d/constants.py @@ -22,7 +22,7 @@ HBAR = 6.582119569e-16 # infinity (very large) -inf = 1e20 +inf = 1e10 # floating point precisions dp_eps = np.finfo(np.float64).eps diff --git a/tidy3d/convert.py b/tidy3d/convert.py index 152d3990a4..91b776c203 100644 --- a/tidy3d/convert.py +++ b/tidy3d/convert.py @@ -336,15 +336,15 @@ def old_json_monitors(sim: Simulation) -> Dict: mnt.update({"frequency": [f * 1e-12 for f in monitor.freqs]}) elif isinstance(monitor, TimeMonitor): # handle case where stop is None - stop = monitor.stop * 1e12 if monitor.stop else sim.run_time * 1e12 + stop = monitor.stop if monitor.stop else sim.run_time # handle case where stop > sim.run_time - stop = min(stop, sim.run_time * 1e12) + stop = min(stop, sim.run_time) mnt.update( { - "t_start": monitor.start * 1e12, + "t_start": monitor.start, "t_stop": stop, - "t_step": 1e12 * sim.dt * monitor.interval, + "t_step": sim.dt * monitor.interval, } ) @@ -482,9 +482,7 @@ def load_old_monitor_data(simulation: Simulation, data_file: str) -> SolverDataD sampler_label = "f" elif isinstance(monitor, TimeMonitor): - stop = monitor.stop if monitor.stop else simulation.run_time - step = simulation.dt * monitor.interval - sampler_values = list(np.arange(monitor.start, stop, step)) + sampler_values = np.array(f_handle[name]["tmesh"]).ravel() sampler_label = "t" if isinstance(monitor, AbstractFieldMonitor):