diff --git a/gplugins/common/utils/disable_print.py b/gplugins/common/utils/disable_print.py index f4504290..591124c0 100644 --- a/gplugins/common/utils/disable_print.py +++ b/gplugins/common/utils/disable_print.py @@ -5,9 +5,13 @@ import sys -def disable_print() -> None: - sys.stdout = open(os.devnull, "w") +class DisablePrint: + def __init__(self): + self.output = sys.stdout + def __enter__(self) -> DisablePrint: + sys.stdout = open(os.devnull, "w") + return self -def enable_print() -> None: - sys.stdout = sys.__stdout__ + def __exit__(self, exc_type, exc_value, exc_tb): + sys.stdout = self.output diff --git a/gplugins/devsim/get_simulation_xsection.py b/gplugins/devsim/get_simulation_xsection.py index 1a48a554..592f7a10 100644 --- a/gplugins/devsim/get_simulation_xsection.py +++ b/gplugins/devsim/get_simulation_xsection.py @@ -23,7 +23,7 @@ from pydantic import BaseModel, ConfigDict from scipy.interpolate import griddata -from gplugins.common.utils.disable_print import disable_print, enable_print +from gplugins.common.utils.disable_print import DisablePrint from gplugins.tidy3d.materials import get_nk from gplugins.tidy3d.modes import Precision, Waveguide @@ -471,9 +471,8 @@ def plot( ) # , scalar_bar_args=sargs) _ = plotter.show_grid() _ = plotter.camera_position = "xy" - disable_print() - _ = plotter.show(jupyter_backend=jupyter_backend) - enable_print() + with DisablePrint(): + _ = plotter.show(jupyter_backend=jupyter_backend) def list_fields(self, tempfile="temp.dat"): """Returns the header of the mesh, which lists all possible fields.""" diff --git a/gplugins/modes/find_modes.py b/gplugins/modes/find_modes.py index b76ea81b..2ce6fc54 100644 --- a/gplugins/modes/find_modes.py +++ b/gplugins/modes/find_modes.py @@ -21,7 +21,7 @@ from gdsfactory.typings import PathType from meep import mpb -from gplugins.common.utils.disable_print import disable_print, enable_print +from gplugins.common.utils.disable_print import DisablePrint from gplugins.common.utils.get_sparameters_path import get_kwargs_hash from gplugins.modes.get_mode_solver_coupler import get_mode_solver_coupler from gplugins.modes.get_mode_solver_rib import get_mode_solver_rib @@ -166,22 +166,21 @@ def find_modes_waveguide( return modes # Output the x component of the Poynting vector for mode_number bands at omega - disable_print() - k = mode_solver.find_k( - parity, - omega, - mode_number, - mode_number + nmodes, - mp.Vector3(1), - tol, - omega * 2.02, - omega * 0.01, - omega * 10, - # mpb.output_poynting_x, - mpb.display_yparities, - mpb.display_group_velocities, - ) - enable_print() + with DisablePrint(): + k = mode_solver.find_k( + parity, + omega, + mode_number, + mode_number + nmodes, + mp.Vector3(1), + tol, + omega * 2.02, + omega * 0.01, + omega * 10, + # mpb.output_poynting_x, + mpb.display_yparities, + mpb.display_group_velocities, + ) neff = np.array(k) * wavelength # vg = mode_solver.compute_group_velocities() diff --git a/gplugins/modes/find_modes_cross_section.py b/gplugins/modes/find_modes_cross_section.py index def258a6..91c89292 100755 --- a/gplugins/modes/find_modes_cross_section.py +++ b/gplugins/modes/find_modes_cross_section.py @@ -21,7 +21,7 @@ from gdsfactory.typings import CrossSectionSpec, PathType from meep import mpb -from gplugins.common.utils.disable_print import disable_print, enable_print +from gplugins.common.utils.disable_print import DisablePrint from gplugins.common.utils.get_sparameters_path import get_kwargs_hash from gplugins.modes.get_mode_solver_cross_section import ( get_mode_solver_cross_section, @@ -108,22 +108,21 @@ def find_modes_cross_section( return modes # Output the x component of the Poynting vector for mode_number bands at omega - disable_print() - k = mode_solver.find_k( - parity, - omega, - mode_number, - mode_number + nmodes, - mp.Vector3(1), - tol, - omega * 2.02, - omega * 0.01, - omega * 10, - # mpb.output_poynting_x, - mpb.display_yparities, - mpb.display_group_velocities, - ) - enable_print() + with DisablePrint(): + k = mode_solver.find_k( + parity, + omega, + mode_number, + mode_number + nmodes, + mp.Vector3(1), + tol, + omega * 2.02, + omega * 0.01, + omega * 10, + # mpb.output_poynting_x, + mpb.display_yparities, + mpb.display_group_velocities, + ) neff = np.array(k) * wavelength # vg = mode_solver.compute_group_velocities()