diff --git a/tests/test_components/test_scene.py b/tests/test_components/test_scene.py index 6f7e5741c3..ff5178d182 100644 --- a/tests/test_components/test_scene.py +++ b/tests/test_components/test_scene.py @@ -2,6 +2,7 @@ from __future__ import annotations +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np import pydantic.v1 as pd @@ -9,6 +10,7 @@ import tidy3d as td from tidy3d.components.scene import MAX_GEOMETRY_COUNT, MAX_NUM_MEDIUMS +from tidy3d.components.viz import STRUCTURE_EPS_CMAP, STRUCTURE_EPS_CMAP_R from tidy3d.exceptions import SetupError from ..utils import SIM_FULL, cartesian_to_unstructured @@ -142,11 +144,62 @@ def test_get_structure_plot_params(): pp = SCENE_FULL._get_structure_eps_plot_params( medium=SCENE_FULL.medium, freq=1, eps_min=1, eps_max=2 ) - assert float(pp.facecolor) == 1.0 + expected_color = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP)(0.0) + assert np.allclose(pp.facecolor, expected_color) pp = SCENE_FULL._get_structure_eps_plot_params(medium=td.PEC, freq=1, eps_min=1, eps_max=2) assert pp.facecolor == "gold" +def test_structure_eps_color_mapping(): + medium_min = td.Medium(permittivity=1.0) + medium_max = td.Medium(permittivity=5.0) + norm = mpl.colors.Normalize(vmin=1.0, vmax=5.0) + + pp_min = SCENE_FULL._get_structure_eps_plot_params( + medium=medium_min, + freq=1, + eps_min=1.0, + eps_max=5.0, + norm=norm, + reverse=False, + ) + expected_min = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP)(norm(1.0)) + assert np.allclose(pp_min.facecolor, expected_min) + + pp_max = SCENE_FULL._get_structure_eps_plot_params( + medium=medium_max, + freq=1, + eps_min=1.0, + eps_max=5.0, + norm=norm, + reverse=False, + ) + expected_max = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP)(norm(5.0)) + assert np.allclose(pp_max.facecolor, expected_max) + + pp_min_reverse = SCENE_FULL._get_structure_eps_plot_params( + medium=medium_min, + freq=1, + eps_min=1.0, + eps_max=5.0, + norm=norm, + reverse=True, + ) + expected_min_reverse = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP_R)(norm(1.0)) + assert np.allclose(pp_min_reverse.facecolor, expected_min_reverse) + + pp_max_reverse = SCENE_FULL._get_structure_eps_plot_params( + medium=medium_max, + freq=1, + eps_min=1.0, + eps_max=5.0, + norm=norm, + reverse=True, + ) + expected_max_reverse = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP_R)(norm(5.0)) + assert np.allclose(pp_max_reverse.facecolor, expected_max_reverse) + + def test_num_mediums(): """Make sure we error if too many mediums supplied.""" diff --git a/tidy3d/components/scene.py b/tidy3d/components/scene.py index f2f4d99087..85b643188c 100644 --- a/tidy3d/components/scene.py +++ b/tidy3d/components/scene.py @@ -90,6 +90,10 @@ MAX_STRUCTURES_PER_MEDIUM = 1_000 +def _get_colormap(reverse: bool = False): + return STRUCTURE_EPS_CMAP_R if reverse else STRUCTURE_EPS_CMAP + + class Scene(Tidy3dBaseModel): """Contains generic information about the geometry and medium properties common to all types of simulations. @@ -1200,7 +1204,7 @@ def _add_cbar_eps( vmin=eps_min, vmax=eps_max, label=r"$\epsilon_r$", - cmap=STRUCTURE_EPS_CMAP if not reverse else STRUCTURE_EPS_CMAP_R, + cmap=_get_colormap(reverse=reverse), ax=ax, norm=norm, ) @@ -1314,16 +1318,14 @@ def _pcolormesh_shape_custom_medium_structure_eps( # extract slice if volumetric unstructured data eps = eps.plane_slice(axis=normal_axis_ind, pos=normal_position) - if reverse: - eps = eps_min + eps_max - eps - # at this point eps_mean is TriangularGridDataset and we just plot it directly # with applying shape mask + cmap_name = _get_colormap(reverse=reverse) eps.plot( grid=False, ax=ax, cbar=False, - cmap=STRUCTURE_EPS_CMAP, + cmap=cmap_name, vmin=eps_min, vmax=eps_max, pcolor_kwargs={ @@ -1395,18 +1397,15 @@ def _pcolormesh_shape_custom_medium_structure_eps( # remove the normal_axis and take real part eps_shape = eps_shape.real.mean(axis=normal_axis_ind) - # reverse - if reverse: - eps_shape = eps_min + eps_max - eps_shape - # pcolormesh plane_xp, plane_yp = np.meshgrid(plane_coord[0], plane_coord[1], indexing="ij") + cmap_name = _get_colormap(reverse=reverse) ax.pcolormesh( plane_xp, plane_yp, eps_shape, clip_path=(polygon_path(shape), ax.transData), - cmap=STRUCTURE_EPS_CMAP, + cmap=cmap_name, alpha=alpha, clip_box=ax.bbox, norm=norm, @@ -1447,23 +1446,15 @@ def _get_structure_eps_plot_params( plot_params = plot_params.copy(update={"edgecolor": "k", "linewidth": 1}) else: eps_medium = medium._eps_plot(frequency=freq, eps_component=eps_component) - if norm is not None: - # Use the same normalization as the colorbar for consistency - color = norm(eps_medium) - # TODO: This is a hack to ensure color consistency with the colorbar. - # It should be removed once we establish a proper color mapping where - # eps_min maps to 0 and eps_max maps to 1 for 'reverse=False'. - if not reverse: - color = 1 - color - color = min(1, max(color, 0)) # clip in case of custom eps limits - else: - # Fallback to linear mapping for backward compatibility - delta_eps = eps_medium - eps_min - delta_eps_max = eps_max - eps_min + 1e-5 - eps_fraction = delta_eps / delta_eps_max - color = eps_fraction if reverse else 1 - eps_fraction - color = min(1, max(color, 0)) # clip in case of custom eps limits - plot_params = plot_params.copy(update={"facecolor": str(color)}) + active_norm = ( + norm if norm is not None else mpl.colors.Normalize(vmin=eps_min, vmax=eps_max) + ) + color_value = float(active_norm(eps_medium)) + color_value = min(1.0, max(0.0, color_value)) + cmap_name = _get_colormap(reverse=reverse) + cmap = mpl.cm.get_cmap(cmap_name) + rgba = tuple(float(component) for component in cmap(color_value)) + plot_params = plot_params.copy(update={"facecolor": rgba}) return plot_params