From 75905818bfdc7a3d63207b8edd37a2c2d60ab5fd Mon Sep 17 00:00:00 2001 From: Sequoia Ploeg Date: Fri, 19 Jan 2024 10:58:02 -0700 Subject: [PATCH 1/2] Switch to using a context manager --- gplugins/common/utils/disable_print.py | 12 +++++--- gplugins/devsim/get_simulation_xsection.py | 7 ++--- gplugins/modes/find_modes.py | 33 +++++++++++----------- gplugins/modes/find_modes_cross_section.py | 33 +++++++++++----------- 4 files changed, 43 insertions(+), 42 deletions(-) diff --git a/gplugins/common/utils/disable_print.py b/gplugins/common/utils/disable_print.py index f4504290..f480d07f 100644 --- a/gplugins/common/utils/disable_print.py +++ b/gplugins/common/utils/disable_print.py @@ -5,9 +5,13 @@ import sys -def disable_print() -> None: - sys.stdout = open(os.devnull, "w") +class DisablePrint: + def __init__(self): + self.output = sys.stdout + def __enter__(self) -> "DisablePrint": + sys.stdout = open(os.devnull, "w") + return self -def enable_print() -> None: - sys.stdout = sys.__stdout__ + def __exit__(self, exc_type, exc_value, exc_tb): + sys.stdout = self.output diff --git a/gplugins/devsim/get_simulation_xsection.py b/gplugins/devsim/get_simulation_xsection.py index 1a48a554..592f7a10 100644 --- a/gplugins/devsim/get_simulation_xsection.py +++ b/gplugins/devsim/get_simulation_xsection.py @@ -23,7 +23,7 @@ from pydantic import BaseModel, ConfigDict from scipy.interpolate import griddata -from gplugins.common.utils.disable_print import disable_print, enable_print +from gplugins.common.utils.disable_print import DisablePrint from gplugins.tidy3d.materials import get_nk from gplugins.tidy3d.modes import Precision, Waveguide @@ -471,9 +471,8 @@ def plot( ) # , scalar_bar_args=sargs) _ = plotter.show_grid() _ = plotter.camera_position = "xy" - disable_print() - _ = plotter.show(jupyter_backend=jupyter_backend) - enable_print() + with DisablePrint(): + _ = plotter.show(jupyter_backend=jupyter_backend) def list_fields(self, tempfile="temp.dat"): """Returns the header of the mesh, which lists all possible fields.""" diff --git a/gplugins/modes/find_modes.py b/gplugins/modes/find_modes.py index b76ea81b..2ce6fc54 100644 --- a/gplugins/modes/find_modes.py +++ b/gplugins/modes/find_modes.py @@ -21,7 +21,7 @@ from gdsfactory.typings import PathType from meep import mpb -from gplugins.common.utils.disable_print import disable_print, enable_print +from gplugins.common.utils.disable_print import DisablePrint from gplugins.common.utils.get_sparameters_path import get_kwargs_hash from gplugins.modes.get_mode_solver_coupler import get_mode_solver_coupler from gplugins.modes.get_mode_solver_rib import get_mode_solver_rib @@ -166,22 +166,21 @@ def find_modes_waveguide( return modes # Output the x component of the Poynting vector for mode_number bands at omega - disable_print() - k = mode_solver.find_k( - parity, - omega, - mode_number, - mode_number + nmodes, - mp.Vector3(1), - tol, - omega * 2.02, - omega * 0.01, - omega * 10, - # mpb.output_poynting_x, - mpb.display_yparities, - mpb.display_group_velocities, - ) - enable_print() + with DisablePrint(): + k = mode_solver.find_k( + parity, + omega, + mode_number, + mode_number + nmodes, + mp.Vector3(1), + tol, + omega * 2.02, + omega * 0.01, + omega * 10, + # mpb.output_poynting_x, + mpb.display_yparities, + mpb.display_group_velocities, + ) neff = np.array(k) * wavelength # vg = mode_solver.compute_group_velocities() diff --git a/gplugins/modes/find_modes_cross_section.py b/gplugins/modes/find_modes_cross_section.py index def258a6..91c89292 100755 --- a/gplugins/modes/find_modes_cross_section.py +++ b/gplugins/modes/find_modes_cross_section.py @@ -21,7 +21,7 @@ from gdsfactory.typings import CrossSectionSpec, PathType from meep import mpb -from gplugins.common.utils.disable_print import disable_print, enable_print +from gplugins.common.utils.disable_print import DisablePrint from gplugins.common.utils.get_sparameters_path import get_kwargs_hash from gplugins.modes.get_mode_solver_cross_section import ( get_mode_solver_cross_section, @@ -108,22 +108,21 @@ def find_modes_cross_section( return modes # Output the x component of the Poynting vector for mode_number bands at omega - disable_print() - k = mode_solver.find_k( - parity, - omega, - mode_number, - mode_number + nmodes, - mp.Vector3(1), - tol, - omega * 2.02, - omega * 0.01, - omega * 10, - # mpb.output_poynting_x, - mpb.display_yparities, - mpb.display_group_velocities, - ) - enable_print() + with DisablePrint(): + k = mode_solver.find_k( + parity, + omega, + mode_number, + mode_number + nmodes, + mp.Vector3(1), + tol, + omega * 2.02, + omega * 0.01, + omega * 10, + # mpb.output_poynting_x, + mpb.display_yparities, + mpb.display_group_velocities, + ) neff = np.array(k) * wavelength # vg = mode_solver.compute_group_velocities() From 42346101e5634e7c128ba73b60c87cd9ac6fb993 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jan 2024 17:59:55 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- gplugins/common/utils/disable_print.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gplugins/common/utils/disable_print.py b/gplugins/common/utils/disable_print.py index f480d07f..591124c0 100644 --- a/gplugins/common/utils/disable_print.py +++ b/gplugins/common/utils/disable_print.py @@ -9,7 +9,7 @@ class DisablePrint: def __init__(self): self.output = sys.stdout - def __enter__(self) -> "DisablePrint": + def __enter__(self) -> DisablePrint: sys.stdout = open(os.devnull, "w") return self