diff --git a/CHANGELOG.md b/CHANGELOG.md index a04d01b741..22a791d11a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `smoothed_projection` for topology optimization of completely binarized designs. - Added more RF-specific mode characteristics to `MicrowaveModeData`, including propagation constants (alpha, beta, gamma), phase/group velocities, wave impedance, and automatic mode classification with configurable polarization thresholds in `MicrowaveModeSpec`. - Introduce `tidy3d.rf` namespace to consolidate all RF classes. +- Added support for custom colormaps in `plot_field`. ### Breaking Changes - Edge singularity correction at PEC and lossy metal edges defaults to `True`. diff --git a/tests/test_components/test_eme.py b/tests/test_components/test_eme.py index 6712d921b6..8e61a5cbd7 100644 --- a/tests/test_components/test_eme.py +++ b/tests/test_components/test_eme.py @@ -1093,6 +1093,19 @@ def test_eme_sim_data(): _ = sim_data.plot_field( "field", "E", eme_port_index=0, val="abs^2", f=td.C_0, mode_index=0, ax=AX ) + _ = sim_data.plot_field( + "field", "Ex", eme_port_index=0, val="real", f=td.C_0, mode_index=0, cmap="plasma", ax=AX + ) + _ = sim_data.plot_field( + "field", + "Ex", + eme_port_index=0, + val="real", + f=td.C_0, + mode_index=0, + cmap=plt.get_cmap("cividis"), + ax=AX, + ) # test smatrix in basis with sweep smatrix = _get_eme_smatrix_dataset(num_modes_1=5, num_modes_2=5, num_sweep=10) diff --git a/tests/test_data/test_sim_data.py b/tests/test_data/test_sim_data.py index dcbe5ab560..8c0383bb09 100644 --- a/tests/test_data/test_sim_data.py +++ b/tests/test_data/test_sim_data.py @@ -6,6 +6,7 @@ import numpy as np import pydantic.v1 as pydantic import pytest +from matplotlib import colors as mcolors import tidy3d as td from tidy3d.components.data.data_array import ScalarFieldTimeDataArray @@ -194,6 +195,22 @@ def test_plot(phase): plt.close() +def test_plot_field_custom_cmap(): + sim_data = make_sim_data() + _ = sim_data.plot_field("field", "Ex", val="real", f=2e14, z=0.10, cmap="viridis") + plt.close() + custom_cmap = mcolors.LinearSegmentedColormap.from_list("two", ["black", "white"]) + _ = sim_data.plot_field( + "field", + "Ez", + val="imag", + f=2e14, + z=0.10, + cmap=custom_cmap, + ) + plt.close() + + def test_plot_field_missing_derived_data(): sim_data = make_sim_data() with pytest.raises(Tidy3dKeyError): diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index d8a0b853de..fdb0c6b552 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -8,7 +8,7 @@ from abc import ABC from collections import defaultdict from os import PathLike -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import h5py import numpy as np @@ -34,6 +34,9 @@ from .data_array import FreqDataArray, TimeDataArray from .monitor_data import AbstractFieldData, FieldTimeData +if TYPE_CHECKING: + from matplotlib.colors import Colormap + DATA_TYPE_MAP = {data.__fields__["monitor"].type_: data for data in MonitorDataTypes} # maps monitor type (string) to the class of the corresponding data @@ -456,6 +459,7 @@ def plot_field_monitor_data( vmax: Optional[float] = None, ax: Ax = None, shading: str = "flat", + cmap: Optional[Union[str, Colormap]] = None, **sel_kwargs: Any, ) -> Ax: """Plot the field data for a monitor with simulation plot overlaid. @@ -492,6 +496,8 @@ def plot_field_monitor_data( matplotlib axes to plot on, if not specified, one is created. shading: str = 'flat' Shading argument for Xarray plot method ('flat','nearest','goraud') + cmap : Optional[Union[str, Colormap]] = None + Colormap for visualizing the field values. ``None`` uses the default which infers it from the data. sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data. These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``), frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable. @@ -656,6 +662,7 @@ def plot_field_monitor_data( cmap_type=cmap_type, ax=ax, shading=shading, + cmap=cmap, infer_intervals=True if shading == "flat" else False, ) @@ -672,6 +679,7 @@ def plot_field( vmax: Optional[float] = None, ax: Ax = None, shading: str = "flat", + cmap: Optional[Union[str, Colormap]] = None, **sel_kwargs: Any, ) -> Ax: """Plot the field data for a monitor with simulation plot overlaid. @@ -709,6 +717,8 @@ def plot_field( matplotlib axes to plot on, if not specified, one is created. shading: str = 'flat' Shading argument for Xarray plot method ('flat','nearest','goraud') + cmap : Optional[Union[str, Colormap]] = None + Colormap for visualizing the field values. ``None`` uses the default which infers it from the data. sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data. These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``), frequency or time dimensions (``f``, ``t``) or ``mode_index``, if applicable. @@ -736,6 +746,7 @@ def plot_field( vmax=vmax, ax=ax, shading=shading, + cmap=cmap, **sel_kwargs, ) @@ -752,6 +763,7 @@ def plot_scalar_array( vmin: Optional[float] = None, vmax: Optional[float] = None, cmap_type: ColormapType = "divergent", + cmap: Optional[Union[str, Colormap]] = None, ax: Ax = None, **kwargs: Any, ) -> Ax: @@ -784,6 +796,8 @@ def plot_scalar_array( inferred from the data and other keyword arguments. cmap_type : Literal["divergent", "sequential", "cyclic"] = "divergent" Type of color map to use for plotting. + cmap : Optional[Union[str, Colormap]] = None + Colormap for visualizing the field values. ``None`` uses the default which infers it from the data. Overrides inferred colormap from `cmap_type`. ax : matplotlib.axes._subplots.Axes = None matplotlib axes to plot on, if not specified, one is created. **kwargs : Extra arguments to ``DataArray.plot``. @@ -798,19 +812,23 @@ def plot_scalar_array( interp_kwarg = {"xyz"[axis]: position} if cmap_type == "divergent": - cmap = "RdBu" + default_cmap = "RdBu" center = 0.0 eps_reverse = False elif cmap_type == "sequential": - cmap = "magma" + default_cmap = "magma" center = False eps_reverse = True elif cmap_type == "cyclic": - cmap = "twilight" + default_cmap = "twilight" vmin = -np.pi vmax = np.pi center = False eps_reverse = False + else: + default_cmap = None + + cmap_to_use = default_cmap if cmap is None else cmap # plot the field xy_coord_labels = list("xyz") @@ -820,7 +838,7 @@ def plot_scalar_array( ax=ax, x=x_coord_label, y=y_coord_label, - cmap=cmap, + cmap=cmap_to_use, vmin=vmin, vmax=vmax, robust=robust, diff --git a/tidy3d/components/mode/data/sim_data.py b/tidy3d/components/mode/data/sim_data.py index 9ba247161a..0aba271bf2 100644 --- a/tidy3d/components/mode/data/sim_data.py +++ b/tidy3d/components/mode/data/sim_data.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import pydantic.v1 as pd @@ -16,6 +16,9 @@ ModeSimulationMonitorDataType = Union[PermittivityData, MediumData] +if TYPE_CHECKING: + from matplotlib.colors import Colormap + class ModeSimulationData(AbstractYeeGridSimulationData): """Data associated with a mode solver simulation.""" @@ -53,6 +56,7 @@ def plot_field( vmin: Optional[float] = None, vmax: Optional[float] = None, ax: Ax = None, + cmap: Optional[Union[str, Colormap]] = None, **sel_kwargs: Any, ) -> Ax: """Plot the field for a :class:`.ModeSolverData` with :class:`.Simulation` plot overlaid. @@ -80,6 +84,8 @@ def plot_field( inferred from the data and other keyword arguments. ax : matplotlib.axes._subplots.Axes = None matplotlib axes to plot on, if not specified, one is created. + cmap : Optional[Union[str, Colormap]] = None + Colormap for visualizing the field values. ``None`` uses the default which infers it from the data. sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data. These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``), frequency or time dimensions (``f``, ``t``) or `mode_index`, if applicable. @@ -102,6 +108,7 @@ def plot_field( vmin=vmin, vmax=vmax, ax=ax, + cmap=cmap, **sel_kwargs, ) diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index f32a56a4cf..5acc5f4e64 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -6,7 +6,7 @@ from functools import wraps from math import isclose -from typing import Any, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args import numpy as np import pydantic.v1 as pydantic @@ -82,6 +82,9 @@ from tidy3d.constants import C_0, fp_eps from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log + +if TYPE_CHECKING: + from matplotlib.colors import Colormap from tidy3d.packaging import supports_local_subpixel, tidy3d_extras # Importing the local solver may not work if e.g. scipy is not installed @@ -2249,6 +2252,7 @@ def plot_field( vmin: Optional[float] = None, vmax: Optional[float] = None, ax: Ax = None, + cmap: Optional[Union[str, Colormap]] = None, **sel_kwargs: Any, ) -> Ax: """Plot the field for a :class:`.ModeSolverData` with :class:`.Simulation` plot overlaid. @@ -2276,6 +2280,8 @@ def plot_field( inferred from the data and other keyword arguments. ax : matplotlib.axes._subplots.Axes = None matplotlib axes to plot on, if not specified, one is created. + cmap : Optional[Union[str, Colormap]] = None + Colormap for visualizing the field values. ``None`` uses the default which infers it from the data. sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data. These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``), frequency or time dimensions (``f``, ``t``) or `mode_index`, if applicable. @@ -2300,6 +2306,7 @@ def plot_field( vmin=vmin, vmax=vmax, ax=ax, + cmap=cmap, **sel_kwargs, ) diff --git a/tidy3d/components/tcad/data/sim_data.py b/tidy3d/components/tcad/data/sim_data.py index fb62043507..d72d25a931 100644 --- a/tidy3d/components/tcad/data/sim_data.py +++ b/tidy3d/components/tcad/data/sim_data.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC -from typing import Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import numpy as np import pydantic.v1 as pd @@ -35,6 +35,9 @@ from tidy3d.exceptions import DataError, Tidy3dKeyError from tidy3d.log import log +if TYPE_CHECKING: + from matplotlib.colors import Colormap + class DeviceCharacteristics(Tidy3dBaseModel): """Stores device characteristics. For example, in steady-state it stores @@ -281,6 +284,7 @@ def plot_field( vmin: Optional[float] = None, vmax: Optional[float] = None, ax: Ax = None, + cmap: Optional[Union[str, Colormap]] = None, **sel_kwargs: Any, ) -> Ax: """Plot the data for a monitor with simulation structures overlaid. @@ -310,6 +314,8 @@ def plot_field( inferred from the data and other keyword arguments. ax : matplotlib.axes._subplots.Axes = None matplotlib axes to plot on, if not specified, one is created. + cmap : Optional[Union[str, Colormap]] = None + Colormap for visualizing the field values. ``None`` uses the default which infers it from the data. sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data. These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``), or time dimension (``t``) if applicable. @@ -345,7 +351,7 @@ def plot_field( if scale == "log": field_data = np.log10(np.abs(field_data)) - cmap = "coolwarm" + cmap_to_use = "coolwarm" if cmap is None else cmap # do sel on unstructured data # it could produce either SpatialDataArray or UnstructuredGridDatasetType @@ -361,7 +367,7 @@ def plot_field( if isinstance(field_data, TriangularGridDataset): field_data.plot( ax=ax, - cmap=cmap, + cmap=cmap_to_use, vmin=vmin, vmax=vmax, cbar_kwargs={"label": field_name}, @@ -436,7 +442,7 @@ def plot_field( ax=ax, x=x_coord_label, y=y_coord_label, - cmap=cmap, + cmap=cmap_to_use, vmin=vmin, vmax=vmax, robust=robust, diff --git a/tidy3d/plugins/waveguide/rectangular_dielectric.py b/tidy3d/plugins/waveguide/rectangular_dielectric.py index 647022f3c0..22b01f0d3d 100644 --- a/tidy3d/plugins/waveguide/rectangular_dielectric.py +++ b/tidy3d/plugins/waveguide/rectangular_dielectric.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Annotated, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union import numpy import pydantic.v1 as pydantic @@ -27,6 +27,9 @@ from tidy3d.log import log from tidy3d.plugins.mode.mode_solver import ModeSolver +if TYPE_CHECKING: + from matplotlib.colors import Colormap + AnnotatedMedium = Annotated[MediumType, pydantic.Field(discriminator=TYPE_TAG_STR)] @@ -1099,6 +1102,7 @@ def plot_field( vmax: Optional[float] = None, ax: Ax = None, geometry_edges: Optional[str] = None, + cmap: Optional[Union[str, Colormap]] = None, **sel_kwargs: Any, ) -> Ax: """Plot the field for a :class:`.ModeSolverData` with :class:`.Simulation` plot overlaid. @@ -1127,6 +1131,8 @@ def plot_field( ax : matplotlib.axes._subplots.Axes = None matplotlib axes to plot on, if not specified, one is created. geometry_edges : Optional color to use for the geometry edges overlaid on the fields. + cmap : Optional[Union[str, Colormap]] = None + Colormap for visualizing the field values. ``None`` uses the default which infers it from the data. sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data. These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``), frequency or time dimensions (``f``, ``t``) or `mode_index`, if applicable. @@ -1147,6 +1153,7 @@ def plot_field( vmin=vmin, vmax=vmax, ax=ax, + cmap=cmap, **sel_kwargs, ) if geometry_edges is not None: