From d22399dbf5ab90b7b88ec478b2981c4c2771488e Mon Sep 17 00:00:00 2001 From: Niko Savola Date: Mon, 27 Nov 2023 15:07:39 -0800 Subject: [PATCH] Add pins in `taper_cross_section`, closes #2329 --- gdsfactory/add_tapers_cross_section.py | 26 +++++++++++++++----- gdsfactory/components/taper_cross_section.py | 15 ++++++++++- gdsfactory/cross_section.py | 7 ++++-- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/gdsfactory/add_tapers_cross_section.py b/gdsfactory/add_tapers_cross_section.py index 6853170f17..e5ac8d7649 100644 --- a/gdsfactory/add_tapers_cross_section.py +++ b/gdsfactory/add_tapers_cross_section.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import gdsfactory as gf from gdsfactory.cell import cell @@ -13,7 +14,7 @@ @cell def add_tapers( - component: Component, + component: ComponentSpec, taper: ComponentSpec = taper_cross_section, select_ports: Callable | None = select_ports_optical, taper_port_name1: str = "o1", @@ -37,18 +38,31 @@ def add_tapers( npoints: number of points. linear: shape of the transition, sine when False. kwargs: cross_section settings for section2. + + Note: + If ``taper`` is a partial function and ``cross_section2`` is None, then + ``cross_section2`` is inferred from the partial keywords. """ c = gf.Component() + component = gf.get_component(component) + ports_to_taper = select_ports(component.ports) if select_ports else component.ports ports_to_taper_names = [p.name for p in ports_to_taper.values()] for port_name, port in component.ports.items(): if port.name in ports_to_taper_names: - taper_ref = c << taper( - cross_section1=port.cross_section, - cross_section2=cross_section2, - **kwargs, - ) + if isinstance(taper, partial) and cross_section2 is None: + taper_ref = c << taper( + cross_section2=partial( + taper.keywords["cross_section2"], width=port.width + ), + ) + else: + taper_ref = c << taper( + cross_section1=port.cross_section, + cross_section2=cross_section2, + **kwargs, + ) taper_ref.connect(taper_ref.ports[taper_port_name1].name, port) c.add_port(name=port_name, port=taper_ref.ports[taper_port_name2]) else: diff --git a/gdsfactory/components/taper_cross_section.py b/gdsfactory/components/taper_cross_section.py index 009d79efda..3ef2f406a3 100644 --- a/gdsfactory/components/taper_cross_section.py +++ b/gdsfactory/components/taper_cross_section.py @@ -1,6 +1,7 @@ from __future__ import annotations from functools import partial +from itertools import islice import gdsfactory as gf from gdsfactory.cell import cell @@ -57,6 +58,19 @@ def taper_cross_section( ref = c << gf.path.extrude_transition(taper_path, transition=transition) c.add_ports(ref.ports) c.absorb(ref) + + # set one pin for each cross section + x1.add_pins( + c, + select_ports=lambda ports: {(port_name := next(iter(ports))): ports[port_name]}, + ) + x2.add_pins( + c, + select_ports=lambda ports: { + (port_name := next(islice(iter(ports), 1, None))): ports[port_name], + }, + ) + if "type" in x1.info and x1.info["type"] == x2.info.get("type"): c.add_route_info(cross_section=x1, length=length, taper=True) return c @@ -68,7 +82,6 @@ def taper_cross_section( taper_cross_section, linear=False, width_type="parabolic", npoints=101 ) - if __name__ == "__main__": # x1 = partial(strip, width=0.5) # x2 = partial(strip, width=2.5) diff --git a/gdsfactory/cross_section.py b/gdsfactory/cross_section.py index ab8286cf34..946e8370ed 100644 --- a/gdsfactory/cross_section.py +++ b/gdsfactory/cross_section.py @@ -289,7 +289,10 @@ def mirror(self) -> CrossSection: sections = [s.model_copy(update=dict(offset=-s.offset)) for s in self.sections] return self.model_copy(update={"sections": tuple(sections)}) - def add_pins(self, component: Component) -> Component: + def add_pins(self, component: Component, *args, **kwargs) -> Component: + """Add pins to a target component according to :class:`CrossSection`. + Args and kwargs are passed to the function defined by the `add_pins_function_name`. + """ if self.add_pins_function_name is None: return component @@ -300,7 +303,7 @@ def add_pins(self, component: Component) -> Component: f"add_pins_function_module = {self.add_pins_function_module}" ) function = getattr(add_pins, self.add_pins_function_name) - return function(component=component) + return function(*args, component=component, **kwargs) def add_bbox( self,