Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to using a context manager for disable_print #310

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 8 additions & 4 deletions gplugins/common/utils/disable_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
import sys


def disable_print() -> None:
sys.stdout = open(os.devnull, "w")
class DisablePrint:
def __init__(self):
self.output = sys.stdout

def __enter__(self) -> DisablePrint:
sys.stdout = open(os.devnull, "w")
return self

def enable_print() -> None:
sys.stdout = sys.__stdout__
def __exit__(self, exc_type, exc_value, exc_tb):
sys.stdout = self.output
7 changes: 3 additions & 4 deletions gplugins/devsim/get_simulation_xsection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pydantic import BaseModel, ConfigDict
from scipy.interpolate import griddata

from gplugins.common.utils.disable_print import disable_print, enable_print
from gplugins.common.utils.disable_print import DisablePrint
from gplugins.tidy3d.materials import get_nk
from gplugins.tidy3d.modes import Precision, Waveguide

Expand Down Expand Up @@ -471,9 +471,8 @@ def plot(
) # , scalar_bar_args=sargs)
_ = plotter.show_grid()
_ = plotter.camera_position = "xy"
disable_print()
_ = plotter.show(jupyter_backend=jupyter_backend)
enable_print()
with DisablePrint():
_ = plotter.show(jupyter_backend=jupyter_backend)

def list_fields(self, tempfile="temp.dat"):
"""Returns the header of the mesh, which lists all possible fields."""
Expand Down
33 changes: 16 additions & 17 deletions gplugins/modes/find_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from gdsfactory.typings import PathType
from meep import mpb

from gplugins.common.utils.disable_print import disable_print, enable_print
from gplugins.common.utils.disable_print import DisablePrint
from gplugins.common.utils.get_sparameters_path import get_kwargs_hash
from gplugins.modes.get_mode_solver_coupler import get_mode_solver_coupler
from gplugins.modes.get_mode_solver_rib import get_mode_solver_rib
Expand Down Expand Up @@ -166,22 +166,21 @@ def find_modes_waveguide(
return modes

# Output the x component of the Poynting vector for mode_number bands at omega
disable_print()
k = mode_solver.find_k(
parity,
omega,
mode_number,
mode_number + nmodes,
mp.Vector3(1),
tol,
omega * 2.02,
omega * 0.01,
omega * 10,
# mpb.output_poynting_x,
mpb.display_yparities,
mpb.display_group_velocities,
)
enable_print()
with DisablePrint():
k = mode_solver.find_k(
parity,
omega,
mode_number,
mode_number + nmodes,
mp.Vector3(1),
tol,
omega * 2.02,
omega * 0.01,
omega * 10,
# mpb.output_poynting_x,
mpb.display_yparities,
mpb.display_group_velocities,
)
neff = np.array(k) * wavelength

# vg = mode_solver.compute_group_velocities()
Expand Down
33 changes: 16 additions & 17 deletions gplugins/modes/find_modes_cross_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from gdsfactory.typings import CrossSectionSpec, PathType
from meep import mpb

from gplugins.common.utils.disable_print import disable_print, enable_print
from gplugins.common.utils.disable_print import DisablePrint
from gplugins.common.utils.get_sparameters_path import get_kwargs_hash
from gplugins.modes.get_mode_solver_cross_section import (
get_mode_solver_cross_section,
Expand Down Expand Up @@ -108,22 +108,21 @@ def find_modes_cross_section(
return modes

# Output the x component of the Poynting vector for mode_number bands at omega
disable_print()
k = mode_solver.find_k(
parity,
omega,
mode_number,
mode_number + nmodes,
mp.Vector3(1),
tol,
omega * 2.02,
omega * 0.01,
omega * 10,
# mpb.output_poynting_x,
mpb.display_yparities,
mpb.display_group_velocities,
)
enable_print()
with DisablePrint():
k = mode_solver.find_k(
parity,
omega,
mode_number,
mode_number + nmodes,
mp.Vector3(1),
tol,
omega * 2.02,
omega * 0.01,
omega * 10,
# mpb.output_poynting_x,
mpb.display_yparities,
mpb.display_group_velocities,
)
neff = np.array(k) * wavelength

# vg = mode_solver.compute_group_velocities()
Expand Down