Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and improve speed of flatten and absorb #875

Merged
merged 4 commits into from Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 18 additions & 11 deletions gdsfactory/component.py
Expand Up @@ -1089,7 +1089,7 @@ def align(self, elements="all", alignment="ymax"):
_align(elements, alignment=alignment)
return self

def flatten(self, single_layer: Optional[Tuple[int, int]] = None):
def flatten(self, single_layer: Optional[LayerSpec] = None):
"""Returns a flattened copy of the component.

Flattens the hierarchy of the Component such that there are no longer
Expand All @@ -1102,13 +1102,19 @@ def flatten(self, single_layer: Optional[Tuple[int, int]] = None):
"""
component_flat = Component()

poly_dict = self.get_polygons(by_spec=True, include_paths=False, as_array=False)
for layer, polys in poly_dict.items():
if polys:
component_flat.add_polygon(polys, layer=single_layer or layer)
_cell = self._cell.copy(name=component_flat.name)
_cell = _cell.flatten()
component_flat._cell = _cell
if single_layer is not None:
from gdsfactory import get_layer

for path in self._cell.get_paths():
component_flat.add(path)
layer, datatype = get_layer(single_layer)
for polygon in _cell.polygons:
polygon.layer = layer
polygon.datatype = datatype
for path in _cell.paths:
path.set_layers(layer)
path.set_datatypes(datatype)

component_flat.info = self.info.copy()
component_flat.add_ports(self.ports)
Expand Down Expand Up @@ -1706,11 +1712,12 @@ def absorb(self, reference) -> "Component":
raise ValueError(
"The reference you asked to absorb does not exist in this Component."
)
ref_polygons = reference.get_polygons(by_spec=True, include_paths=False)
for (layer, polys) in ref_polygons.items():
[self.add_polygon(points=p, layer=layer) for p in polys]
ref_polygons = reference.get_polygons(
by_spec=False, include_paths=False, as_array=False
)
self._add_polygons(*ref_polygons)

self.add(reference.parent.labels)
self.add(reference.get_labels())
self.add(reference.get_paths())
self.remove(reference)
return self
Expand Down
7 changes: 5 additions & 2 deletions gdsfactory/component_reference.py
Expand Up @@ -286,7 +286,7 @@ def get_polygons(
layer_to_polygons[layer].append(polygon.points)
return layer_to_polygons

def get_labels(self, depth=None, set_transform=False):
def get_labels(self, depth=None, set_transform=True):
"""Return the list of labels created by this reference.

Args:
Expand All @@ -301,7 +301,10 @@ def get_labels(self, depth=None, set_transform=False):
out : list of `Label`
List containing the labels in this cell and its references.
"""
return self._reference.get_labels(depth=depth, set_transform=set_transform)
if set_transform:
return self._reference.get_labels(depth=depth)
else:
return self.parent.get_labels(depth=depth)

def get_bounding_box(self):
return self._reference.bounding_box()
Expand Down
18 changes: 18 additions & 0 deletions gdsfactory/tests/test_flatten.py
Expand Up @@ -40,6 +40,24 @@ def test_flattened_cell_keeps_ports():
assert len(c2.ports) == 2, len(c2.ports)


def test_flattened_cell_keeps_labels():
c1 = gf.Component()
c1.add_label("hi!")
c2 = c1.flatten()
assert len(c2.labels) == 1


def test_flatten_single_layer():
target_layer = (999, 51)
c1 = gf.components.straight()
c2 = c1.flatten(single_layer=target_layer)
c1_polygons = c1.get_polygons(as_array=False)
c2_polygons = c2.get_polygons(as_array=False)
assert len(c1_polygons) == len(c2_polygons)
for p in c2_polygons:
assert (p.layer, p.datatype) == target_layer


if __name__ == "__main__":
test_flattened_cell_keeps_ports()
# c1 = gf.components.mzi()
Expand Down