diff --git a/doc/examples_sphinx-gallery/visual_style.py b/doc/examples_sphinx-gallery/visual_style.py index 8273edfd6..232942a11 100644 --- a/doc/examples_sphinx-gallery/visual_style.py +++ b/doc/examples_sphinx-gallery/visual_style.py @@ -47,3 +47,25 @@ # default, you can use igraph's `configuration instance # :class:`igraph.configuration.Configuration`. A quick example on how to use # it can be found here: :ref:`tutorials-configuration`. + +# %% +# In the matplotlib backend, igraph creates a special container +# :class:`igraph.drawing.matplotlib.graph.GraphArtist` which is a matplotlib Artist +# and the first child of the target Axes. That object can be used to customize +# the plot appearance after the initial drawing, e.g.: +g = ig.Graph.Barabasi(n=30, m=1) +fig, ax = plt.subplots() +ig.plot(g, target=ax) +artist = ax.get_children()[0] +# Option 1: +artist.set(vertex_color="blue") +# Option 2: +artist.set_vertex_color("blue") +plt.show() + +# %% +# .. note:: +# The ``artist.set`` method can be used to change multiple properties at +# once and is generally more efficient than multiple calls to specific +# ``artist.set_...` methods. + diff --git a/src/igraph/drawing/__init__.py b/src/igraph/drawing/__init__.py index 5ae708ccc..06bcda90a 100644 --- a/src/igraph/drawing/__init__.py +++ b/src/igraph/drawing/__init__.py @@ -264,13 +264,16 @@ def plot(obj, target=None, bbox=(0, 0, 600, 600), *args, **kwds): warn("%s does not support plotting" % (obj,)) return else: - plotter( + result = plotter( backend, target, palette=palette, *args, **kwds, ) + # NOTE: for matplotlib, result is the container Artist. It would be + # good to return this instead of target, like we do for Cairo. + # However, that breaks API so let's wait for a major release if save_path is not None: if backend == "matplotlib": diff --git a/src/igraph/drawing/graph.py b/src/igraph/drawing/graph.py index ce56fe7ad..aa38d0571 100644 --- a/src/igraph/drawing/graph.py +++ b/src/igraph/drawing/graph.py @@ -555,4 +555,4 @@ def __plot__(self, backend, context, *args, **kwds): "drawer_factory", DrawerDirectory.resolve(self, backend)(context), ) - drawer.draw(self, *args, **kwds) + return drawer.draw(self, *args, **kwds) diff --git a/src/igraph/drawing/matplotlib/edge.py b/src/igraph/drawing/matplotlib/edge.py index adbdfab89..929555ecf 100644 --- a/src/igraph/drawing/matplotlib/edge.py +++ b/src/igraph/drawing/matplotlib/edge.py @@ -51,11 +51,12 @@ class VisualEdgeBuilder(AttributeCollectorBase): width = 2.0 background = None align_label = False + zorder = 1 return VisualEdgeBuilder def draw_directed_edge(self, edge, src_vertex, dest_vertex): - if src_vertex == dest_vertex: # TODO + if src_vertex == dest_vertex: return self.draw_loop_edge(edge, src_vertex) ax = self.context @@ -163,7 +164,7 @@ def draw_directed_edge(self, edge, src_vertex, dest_vertex): path["codes"].append("LINETO") # Draw the edge - stroke = mpl.patches.PathPatch( + arrowshaft = mpl.patches.PathPatch( mpl.path.Path( path["vertices"], codes=[getattr(mpl.path.Path, x) for x in path["codes"]], @@ -171,8 +172,10 @@ def draw_directed_edge(self, edge, src_vertex, dest_vertex): edgecolor=edge.color, facecolor="none", linewidth=edge.width, + zorder=edge.zorder, + transform=ax.transData, + clip_on=True, ) - ax.add_patch(stroke) # Draw the arrow head arrowhead = mpl.patches.Polygon( @@ -184,8 +187,11 @@ def draw_directed_edge(self, edge, src_vertex, dest_vertex): closed=True, facecolor=edge.color, edgecolor="none", + zorder=edge.zorder, + transform=ax.transData, + clip_on=True, ) - ax.add_patch(arrowhead) + return [arrowshaft, arrowhead] def draw_loop_edge(self, edge, vertex): """Draws a loop edge. @@ -201,7 +207,7 @@ def draw_loop_edge(self, edge, vertex): radius = vertex.size * 1.5 center_x = vertex.position[0] + cos(pi / 4) * radius / 2.0 center_y = vertex.position[1] - sin(pi / 4) * radius / 2.0 - stroke = mpl.patches.Arc( + art = mpl.patches.Arc( (center_x, center_y), radius / 2.0, radius / 2.0, @@ -210,9 +216,11 @@ def draw_loop_edge(self, edge, vertex): linewidth=edge.width, facecolor="none", edgecolor=edge.color, + zorder=edge.zorder, + transform=ax.transData, + clip_on=True, ) - # FIXME: make a PathCollection?? - ax.add_patch(stroke) + return [art] def draw_undirected_edge(self, edge, src_vertex, dest_vertex): """Draws an undirected edge. @@ -247,7 +255,7 @@ def draw_undirected_edge(self, edge, src_vertex, dest_vertex): path["vertices"].append(dest_vertex.position) path["codes"].append("LINETO") - stroke = mpl.patches.PathPatch( + art = mpl.patches.PathPatch( mpl.path.Path( path["vertices"], codes=[getattr(mpl.path.Path, x) for x in path["codes"]], @@ -255,6 +263,8 @@ def draw_undirected_edge(self, edge, src_vertex, dest_vertex): edgecolor=edge.color, facecolor="none", linewidth=edge.width, + zorder=edge.zorder, + transform=ax.transData, + clip_on=True, ) - # FIXME: make a PathCollection?? - ax.add_artist(stroke) + return [art] diff --git a/src/igraph/drawing/matplotlib/graph.py b/src/igraph/drawing/matplotlib/graph.py index 6673929ff..4269e5d5e 100644 --- a/src/igraph/drawing/matplotlib/graph.py +++ b/src/igraph/drawing/matplotlib/graph.py @@ -14,10 +14,11 @@ """ from warnings import warn +from functools import wraps, partial from igraph._igraph import convex_hull, VertexSeq from igraph.drawing.baseclasses import AbstractGraphDrawer -from igraph.drawing.utils import Point +from igraph.drawing.utils import Point, FakeModule from .edge import MatplotlibEdgeDrawer from .polygon import MatplotlibPolygonDrawer @@ -26,67 +27,171 @@ __all__ = ("MatplotlibGraphDrawer",) -_, plt = find_matplotlib() +mpl, plt = find_matplotlib() +try: + Artist = mpl.artist.Artist +except AttributeError: + Artist = FakeModule ##################################################################### -class MatplotlibGraphDrawer(AbstractGraphDrawer): - """Graph drawer that uses a pyplot.Axes as context""" - - _shape_dict = { - "rectangle": "s", - "circle": "o", - "hidden": "none", - "triangle-up": "^", - "triangle-down": "v", - } +# NOTE: https://github.com/networkx/grave/blob/main/grave/grave.py +def _stale_wrapper(func): + """Decorator to manage artist state.""" + + @wraps(func) + def inner(self, *args, **kwargs): + try: + func(self, *args, **kwargs) + finally: + self.stale = False + + return inner + + +def _forwarder(forwards, cls=None): + """Decorator to forward specific methods to Artist children.""" + if cls is None: + return partial(_forwarder, forwards) + + def make_forward(name): + def method(self, *args, **kwargs): + ret = getattr(cls.mro()[1], name)(self, *args, **kwargs) + for c in self.get_children(): + getattr(c, name)(*args, **kwargs) + return ret + + return method + + for f in forwards: + method = make_forward(f) + method.__name__ = f + method.__doc__ = "broadcasts {} to children".format(f) + setattr(cls, f, method) + + return cls + + +def _additional_set_methods(attributes, cls=None): + """Decorator to add specific set methods for children properties.""" + if cls is None: + return partial(_additional_set_methods, attributes) + + def make_setter(name): + def method(self, value): + self.set(**{name: value}) + return method + + for attr in attributes: + desc = attr.replace('_', ' ') + method = make_setter(attr) + method.__name__ = f"set_{attr}" + method.__doc__ = f"Set {desc}." + setattr(cls, f"set_{attr}", method) + + return cls + + +@_additional_set_methods( + ( + "vertex_color", + "vertex_size", + "vertex_font", + "vertex_label", + "vertex_label_angle", + "vertex_label_color", + "vertex_label_dist", + "vertex_label_size", + "vertex_order", + "vertex_shape", + "vertex_size", + "edge_color", + "edge_curved", + "edge_font", + "edge_arrow_size", + "edge_arrow_width", + "edge_width", + "edge_label", + "edge_background", + "edge_align_label", + "autocurve", + "layout", + ) +) +@_forwarder( + ( + "set_clip_path", + "set_clip_box", + "set_transform", + "set_snap", + "set_sketch_params", + "set_figure", + "set_animated", + "set_picker", + ) +) +class GraphArtist(Artist, AbstractGraphDrawer): + """Artist for an igraph.Graph object. + + Arguments: + graph: An igraph.Graph object to plot + layout: A layout object or matrix of coordinates to use for plotting. + Each element or row should describes the coordinates for a vertex. + vertex_style: A dictionary specifying style options for vertices. + edge_style: A dictionary specifying style options for edges. + """ def __init__( self, - ax, + graph, vertex_drawer_factory=MatplotlibVertexDrawer, edge_drawer_factory=MatplotlibEdgeDrawer, + mark_groups=None, + layout=None, + palette=None, + **kwds, ): - """Constructs the graph drawer and associates it with the mpl Axes - - @param ax: the matplotlib Axes to draw into. - @param vertex_drawer_factory: a factory method that returns an - L{AbstractVertexDrawer} instance bound to the given Matplotlib axes. - The factory method must take three parameters: the axes and the - palette to be used for drawing colored vertices, and the layout of - the graph. The default vertex drawer is L{MatplotlibVertexDrawer}. - @param edge_drawer_factory: a factory method that returns an - L{AbstractEdgeDrawer} instance bound to a given Matplotlib Axes. - The factory method must take two parameters: the Axes and the palette - to be used for drawing colored edges. The default edge drawer is - L{MatplotlibEdgeDrawer}. - """ - self.ax = ax + super().__init__() + self.graph = graph self.vertex_drawer_factory = vertex_drawer_factory self.edge_drawer_factory = edge_drawer_factory - - def draw(self, graph, *args, **kwds): - # Deferred import to avoid a cycle in the import graph - from igraph.clustering import VertexClustering, VertexCover - - # Positional arguments are not used - if args: - warn( - "Positional arguments to plot functions are ignored " - "and will be deprecated soon.", - DeprecationWarning, - ) - - # Some abbreviations for sake of simplicity - directed = graph.is_directed() - ax = self.ax - - # Palette - palette = kwds.pop("palette", None) - - # Calculate/get the layout of the graph - layout = self.ensure_layout(kwds.get("layout", None), graph) + self.kwds = kwds + self.kwds["mark_groups"] = mark_groups + self.kwds["palette"] = palette + self.kwds["layout"] = layout + + self._kwds_post_update() + + def _kwds_post_update(self): + self.kwds["layout"] = self.ensure_layout(self.kwds["layout"], self.graph) + self.edge_curved = self._set_edge_curve(**self.kwds) + self._clear_state() + self.stale = True + + def _clear_state(self): + self._vertices = [] + self._edges = [] + self._vertex_labels = [] + self._edge_labels = [] + self._group_artists = [] + self._legend_info = {} + + def get_children(self): + artists = sum( + [ + self._group_artists, + self._edges, + self._vertices, + self._edge_labels, + self._vertex_labels, + ], + [], + ) + return tuple(artists) + + def _set_edge_curve(self, **kwds): + graph = self.graph # Decide whether we need to calculate the curvature of edges # automatically -- and calculate them if needed. @@ -94,8 +199,8 @@ def draw(self, graph, *args, **kwds): if autocurve or ( autocurve is None and "edge_curved" not in kwds - and "curved" not in graph.edge_attributes() - and graph.ecount() < 10000 + and "curved" not in self.graph.edge_attributes() + and self.graph.ecount() < 10000 ): from igraph import autocurve @@ -103,137 +208,72 @@ def draw(self, graph, *args, **kwds): if default is True: default = 0.5 default = float(default) - kwds["edge_curved"] = autocurve( + return autocurve( graph, attribute=None, default=default, ) + return None - # Construct the vertex, edge and label drawers - vertex_drawer = self.vertex_drawer_factory(ax, palette, layout) - edge_drawer = self.edge_drawer_factory(ax, palette) + def get_vertices(self): + """Get vertex artists.""" + return self._vertices - # Construct the visual vertex/edge builders based on the specifications - # provided by the vertex_drawer and the edge_drawer - vertex_builder = vertex_drawer.VisualVertexBuilder(graph.vs, kwds) - edge_builder = edge_drawer.VisualEdgeBuilder(graph.es, kwds) - - # Draw the highlighted groups (if any) - if "mark_groups" in kwds: - mark_groups = kwds["mark_groups"] - - # Deferred import to avoid a cycle in the import graph - from igraph.clustering import VertexClustering, VertexCover - - # Figure out what to do with mark_groups in order to be able to - # iterate over it and get memberlist-color pairs - if isinstance(mark_groups, dict): - # Dictionary mapping vertex indices or tuples of vertex - # indices to colors - group_iter = iter(mark_groups.items()) - elif isinstance(mark_groups, (VertexClustering, VertexCover)): - # Vertex clustering - group_iter = ((group, color) for color, group in enumerate(mark_groups)) - elif hasattr(mark_groups, "__iter__"): - # Lists, tuples, iterators etc - group_iter = iter(mark_groups) - else: - # False - group_iter = iter({}.items()) + def get_edges(self): + """Get edge artists. - if kwds.get("legend", False): - legend_info = { - "handles": [], - "labels": [], - } - - # Iterate over color-memberlist pairs - for group, color_id in group_iter: - if not group or color_id is None: - continue - - color = palette.get(color_id) - - if isinstance(group, VertexSeq): - group = [vertex.index for vertex in group] - if not hasattr(group, "__iter__"): - raise TypeError("group membership list must be iterable") - - # Get the vertex indices that constitute the convex hull - hull = [group[i] for i in convex_hull([layout[idx] for idx in group])] - - # Calculate the preferred rounding radius for the corners - corner_radius = 1.25 * max(vertex_builder[idx].size for idx in hull) - - # Construct the polygon - polygon = [layout[idx] for idx in hull] - - if len(polygon) == 2: - # Expand the polygon (which is a flat line otherwise) - a, b = Point(*polygon[0]), Point(*polygon[1]) - c = corner_radius * (a - b).normalized() - n = Point(-c[1], c[0]) - polygon = [a + n, b + n, b - c, b - n, a - n, a + c] - else: - # Expand the polygon around its center of mass - center = Point( - *[sum(coords) / float(len(coords)) for coords in zip(*polygon)] - ) - polygon = [ - Point(*point).towards(center, -corner_radius) - for point in polygon - ] - - # Draw the hull - facecolor = (color[0], color[1], color[2], 0.25 * color[3]) - drawer = MatplotlibPolygonDrawer(ax) - drawer.draw( - polygon, - corner_radius=corner_radius, - facecolor=facecolor, - edgecolor=color, - ) + Note that for directed edges, an edge might have more than one + artist, e.g. arrow shaft and arrowhead. + """ + return self._edges - if kwds.get("legend", False): - legend_info["handles"].append( - plt.Rectangle( - (0, 0), - 0, - 0, - facecolor=facecolor, - edgecolor=color, - ) - ) - legend_info["labels"].append(str(color_id)) + def get_groups(self): + """Get group/cluster/cover artists.""" + return self._group_artists - if kwds.get("legend", False): - ax.legend( - legend_info["handles"], - legend_info["labels"], - ) + def get_vertex_labels(self): + """Get vertex label artists.""" + return self._vertex_labels - # Determine the order in which we will draw the vertices and edges - vertex_order = self._determine_vertex_order(graph, kwds) - edge_order = self._determine_edge_order(graph, kwds) + def get_edge_labels(self): + """Get edge label artists.""" + return self._edge_labels - # Construct the iterator that we will use to draw the vertices - vs = graph.vs - if vertex_order is None: - # Default vertex order - vertex_coord_iter = zip(vs, vertex_builder, layout) + def get_datalim(self): + """Get limits on x/y axes based on the graph layout data. + + There is a small padding based on the size of the vertex marker to + ensure it fits into the canvas. + """ + import numpy as np + + vertex_builder = self.vertex_builder + layout = self.kwds["layout"] + + mins = np.min(layout, axis=0) + maxs = np.max(layout, axis=0) + + # Pad by vertex size, to ensure they fit + if vertex_builder.size is not None: + mins -= vertex_builder.size * 1.1 + maxs += vertex_builder.size * 1.1 else: - # Specified vertex order - vertex_coord_iter = ( - (vs[i], vertex_builder[i], layout[i]) for i in vertex_order - ) + mins[0] -= vertex_builder.width * 0.55 + mins[1] -= vertex_builder.height * 0.55 + maxs[0] += vertex_builder.width * 0.55 + maxs[1] += vertex_builder.height * 0.55 - # Draw the vertices - drawer_method = vertex_drawer.draw - for vertex, visual_vertex, coords in vertex_coord_iter: - drawer_method(visual_vertex, vertex, coords) + return (mins, maxs) + + def _draw_vertex_labels(self): + import numpy as np + + kwds = self.kwds + layout = self.kwds["layout"] + vertex_builder = self.vertex_builder + vertex_order = self.vertex_order # Construct the iterator that we will use to draw the vertex labels - vs = graph.vs if vertex_order is None: # Default vertex order vertex_coord_iter = zip(vertex_builder, layout) @@ -251,16 +291,226 @@ def draw(self, graph, *args, **kwds): vertex.label_size, ) - ax.text( - *coords, + # Locate text relative to vertex in data units. This is consistent + # with the vertex size being in data units, but might be not fully + # satisfactory when zooming in/out. In that case, revisit this + # section + dist = vertex.label_dist + angle = vertex.label_angle + if vertex.size is not None: + vertex_width = vertex.size + vertex_height = vertex.size + else: + vertex_width = vertex.width + vertex_height = vertex.height + xtext = coords[0] + dist * vertex_width * np.cos(angle) + ytext = coords[1] + dist * vertex_height * np.sin(angle) + xytext = (xtext, ytext) + textcoords = "data" + + art = mpl.text.Annotation( vertex.label, + coords, + xytext=xytext, + textcoords=textcoords, fontsize=label_size, - ha='center', - va='center', - # TODO: overlap, offset, etc. + ha="center", + va="center", + transform=self.axes.transData, + clip_on=True, + zorder=3, + ) + self._vertex_labels.append(art) + + def _draw_edge_labels(self): + graph = self.graph + kwds = self.kwds + vertex_builder = self.vertex_builder + edge_builder = self.edge_builder + edge_drawer = self.edge_drawer + edge_order = self.edge_order or range(self.graph.ecount()) + + labels = kwds.get("edge_label", None) + if labels is None: + return + + edge_label_iter = ( + (labels[i], edge_builder[i], graph.es[i]) for i in edge_order + ) + for label, visual_edge, edge in edge_label_iter: + # Ask the edge drawer to propose an anchor point for the label + src, dest = edge.tuple + src_vertex, dest_vertex = vertex_builder[src], vertex_builder[dest] + (x, y), (halign, valign) = edge_drawer.get_label_position( + visual_edge, + src_vertex, + dest_vertex, + ) + + text_kwds = {} + text_kwds["ha"] = halign.value + text_kwds["va"] = valign.value + + if visual_edge.background is not None: + text_kwds["bbox"] = dict( + facecolor=visual_edge.background, + edgecolor="none", + ) + text_kwds["ha"] = "center" + text_kwds["va"] = "center" + + if visual_edge.align_label: + # Rotate the text to align with the edge + rotation = edge_drawer.get_label_rotation( + visual_edge, + src_vertex, + dest_vertex, + ) + text_kwds["rotation"] = rotation + + art = mpl.text.Annotation( + label, + (x, y), + fontsize=visual_edge.label_size, + color=visual_edge.label_color, + transform=self.axes.transData, + clip_on=True, + zorder=3, + **text_kwds, + ) + self._vertex_labels.append(art) + + def _draw_groups(self): + """Draw the highlighted vertex groups, if requested""" + # Deferred import to avoid a cycle in the import graph + from igraph.clustering import VertexClustering, VertexCover + + kwds = self.kwds + palette = self.kwds["palette"] + layout = self.kwds["layout"] + mark_groups = self.kwds["mark_groups"] + vertex_builder = self.vertex_builder + + if not mark_groups: + return + + # Figure out what to do with mark_groups in order to be able to + # iterate over it and get memberlist-color pairs + if isinstance(mark_groups, dict): + # Dictionary mapping vertex indices or tuples of vertex + # indices to colors + group_iter = iter(mark_groups.items()) + elif isinstance(mark_groups, (VertexClustering, VertexCover)): + # Vertex clustering + group_iter = ((group, color) for color, group in enumerate(mark_groups)) + elif hasattr(mark_groups, "__iter__"): + # One-off generators: we need to store the actual list for future + # calls (e.g. resizing, recoloring, etc.). If we don't do this, + # the generator is exhausted: we cannot rewind it. + self.mark_groups = mark_groups = list(mark_groups) + # Lists, tuples, iterators etc + group_iter = iter(mark_groups) + else: + # False + group_iter = iter({}.items()) + + if kwds.get("legend", False): + legend_info = { + "handles": [], + "labels": [], + } + + # Iterate over color-memberlist pairs + for group, color_id in group_iter: + if not group or color_id is None: + continue + + color = palette.get(color_id) + + if isinstance(group, VertexSeq): + group = [vertex.index for vertex in group] + if not hasattr(group, "__iter__"): + raise TypeError("group membership list must be iterable") + + # Get the vertex indices that constitute the convex hull + hull = [group[i] for i in convex_hull([layout[idx] for idx in group])] + + # Calculate the preferred rounding radius for the corners + corner_radius = 1.25 * max(vertex_builder[idx].size for idx in hull) + + # Construct the polygon + polygon = [layout[idx] for idx in hull] + + if len(polygon) == 2: + # Expand the polygon (which is a flat line otherwise) + a, b = Point(*polygon[0]), Point(*polygon[1]) + c = corner_radius * (a - b).normalized() + n = Point(-c[1], c[0]) + polygon = [a + n, b + n, b - c, b - n, a - n, a + c] + else: + # Expand the polygon around its center of mass + center = Point( + *[sum(coords) / float(len(coords)) for coords in zip(*polygon)] + ) + polygon = [ + Point(*point).towards(center, -corner_radius) for point in polygon + ] + + # Draw the hull + facecolor = (color[0], color[1], color[2], 0.25 * color[3]) + drawer = MatplotlibPolygonDrawer(self.axes) + art = drawer.draw( + polygon, + corner_radius=corner_radius, + facecolor=facecolor, + edgecolor=color, ) + self._group_artists.append(art) + + if kwds.get("legend", False): + legend_info["handles"].append( + plt.Rectangle( + (0, 0), + 0, + 0, + facecolor=facecolor, + edgecolor=color, + ) + ) + legend_info["labels"].append(str(color_id)) + + if kwds.get("legend", False): + self.legend_info = legend_info + + def _draw_vertices(self): + """Draw the vertices""" + graph = self.graph + layout = self.kwds["layout"] + vertex_drawer = self.vertex_drawer + vertex_builder = self.vertex_builder + vertex_order = self.vertex_order + + vs = graph.vs + if vertex_order is None: + # Default vertex order + vertex_coord_iter = zip(vs, vertex_builder, layout) + else: + # Specified vertex order + vertex_coord_iter = ( + (vs[i], vertex_builder[i], layout[i]) for i in vertex_order + ) + for vertex, visual_vertex, coords in vertex_coord_iter: + art = vertex_drawer.draw(visual_vertex, vertex, coords) + self._vertices.append(art) + + def _draw_edges(self): + """Draw the edges""" + graph = self.graph + vertex_builder = self.vertex_builder + edge_drawer = self.edge_drawer + edge_builder = self.edge_builder + edge_order = self.edge_order - # Construct the iterator that we will use to draw the edges es = graph.es if edge_order is None: # Default edge order @@ -269,60 +519,201 @@ def draw(self, graph, *args, **kwds): # Specified edge order edge_coord_iter = ((es[i], edge_builder[i]) for i in edge_order) - # Draw the edges + directed = graph.is_directed() if directed: + # Arrows and the likes drawer_method = edge_drawer.draw_directed_edge else: + # Lines drawer_method = edge_drawer.draw_undirected_edge for edge, visual_edge in edge_coord_iter: src, dest = edge.tuple src_vertex, dest_vertex = vertex_builder[src], vertex_builder[dest] - drawer_method(visual_edge, src_vertex, dest_vertex) + arts = drawer_method(visual_edge, src_vertex, dest_vertex) + self._edges.extend(arts) - # Draw the edge labels - labels = kwds.get("edge_label", None) - if labels is not None: - edge_label_iter = ( - (labels[i], edge_builder[i], graph.es[i]) for i in range(graph.ecount()) + def _reprocess(self): + """Prepare artist and children for the actual drawing. + + Children are not drawn here, but the dictionaries of properties are + marshalled to their specific artists. + """ + # clear state and mark as stale + # since all children artists are part of the state, clearing it + # will trigger a deletion by the backend at the next draw cycle + self._clear_state() + self.stale = True + + # get local refs to everything (just for less typing) + graph = self.graph + palette = self.kwds["palette"] + layout = self.kwds["layout"] + kwds = self.kwds + + # Construct the vertex, edge and label drawers + self.vertex_drawer = self.vertex_drawer_factory(self.axes, palette, layout) + self.edge_drawer = self.edge_drawer_factory(self.axes, palette) + + # Construct the visual vertex/edge builders based on the specifications + # provided by the vertex_drawer and the edge_drawer + self.vertex_builder = self.vertex_drawer.VisualVertexBuilder(graph.vs, kwds) + self.edge_builder = self.edge_drawer.VisualEdgeBuilder(graph.es, kwds) + + # Determine the order in which we will draw the vertices and edges + # These methods come from AbstractGraphDrawer + self.vertex_order = self._determine_vertex_order(graph, kwds) + self.edge_order = self._determine_edge_order(graph, kwds) + + self._draw_groups() + self._draw_edges() + self._draw_vertices() + self._draw_vertex_labels() + self._draw_edge_labels() + + # Forward mpl properties to children + # TODO sort out all of the things that need to be forwarded + for child in self.get_children(): + # set the figure / axes on child, this ensures each primitive + # knows where to draw + child.set_figure(self.figure) + child.axes = self.axes + + # forward the clippath/box to the children need this logic + # because mpl exposes some fast-path logic + clip_path = self.get_clip_path() + if clip_path is None: + clip_box = self.get_clip_box() + child.set_clip_box(clip_box) + else: + child.set_clip_path(clip_path) + + @_stale_wrapper + def draw(self, renderer, *args, **kwds): + """Draw each of the children, with some buffering mechanism.""" + if not self.get_visible(): + return + + if not self.get_children(): + self._reprocess() + + # NOTE: looks like we have to manage the zorder ourselves + children = list(self.get_children()) + children.sort(key=lambda x: x.zorder) + for art in children: + art.draw(renderer, *args, **kwds) + + def set( + self, + **kwds, + ): + """Set multiple parameters at once. + + The same options can be used as in the igraph.plot function. + """ + if len(kwds) == 0: + return + + self.kwds.update(kwds) + self._kwds_post_update() + + def contains(self, mouseevent): + """Track 'contains' event for mouse interactions.""" + props = {"vertices": [], "edges": []} + hit = False + for i, art in enumerate(self._edges): + edge_hit = art.contains(mouseevent)[0] + hit |= edge_hit + props["edges"].append(i) + + for i, art in enumerate(self._vertices): + vertex_hit = art.contains(mouseevent)[0] + hit |= vertex_hit + props["vertices"].append(i) + + return hit, props + + def pick(self, mouseevent): + """Track 'pick' event for mouse interactions.""" + if self.pickable(): + picker = self.get_picker() + if callable(picker): + inside, prop = picker(self, mouseevent) + else: + inside, prop = self.contains(mouseevent) + if inside: + self.figure.canvas.pick_event(mouseevent, self, **prop) + + +class MatplotlibGraphDrawer(AbstractGraphDrawer): + """Graph drawer that uses a pyplot.Axes as context""" + + _shape_dict = { + "rectangle": "s", + "circle": "o", + "hidden": "none", + "triangle-up": "^", + "triangle-down": "v", + } + + def __init__( + self, + ax, + vertex_drawer_factory=MatplotlibVertexDrawer, + edge_drawer_factory=MatplotlibEdgeDrawer, + ): + """Constructs the graph drawer and associates it with the mpl Axes + + @param ax: the matplotlib Axes to draw into. + @param vertex_drawer_factory: a factory method that returns an + L{AbstractVertexDrawer} instance bound to the given Matplotlib axes. + The factory method must take three parameters: the axes and the + palette to be used for drawing colored vertices, and the layout of + the graph. The default vertex drawer is L{MatplotlibVertexDrawer}. + @param edge_drawer_factory: a factory method that returns an + L{AbstractEdgeDrawer} instance bound to a given Matplotlib Axes. + The factory method must take two parameters: the Axes and the palette + to be used for drawing colored edges. The default edge drawer is + L{MatplotlibEdgeDrawer}. + """ + self.ax = ax + self.vertex_drawer_factory = vertex_drawer_factory + self.edge_drawer_factory = edge_drawer_factory + + def draw(self, graph, *args, **kwds): + if args: + warn( + "Positional arguments to plot functions are ignored " + "and will be deprecated soon.", + DeprecationWarning, ) - for label, visual_edge, edge in edge_label_iter: - # Ask the edge drawer to propose an anchor point for the label - src, dest = edge.tuple - src_vertex, dest_vertex = vertex_builder[src], vertex_builder[dest] - (x, y), (halign, valign) = edge_drawer.get_label_position( - visual_edge, - src_vertex, - dest_vertex, - ) - text_kwargs = {} - text_kwargs['ha'] = halign.value - text_kwargs['va'] = valign.value + # Some abbreviations for sake of simplicity + ax = self.ax - if visual_edge.background is not None: - text_kwargs['bbox'] = dict( - facecolor=visual_edge.background, - edgecolor='none', - ) - text_kwargs['ha'] = 'center' - text_kwargs['va'] = 'center' + # Create artist + art = GraphArtist( + graph, + vertex_drawer_factory=self.vertex_drawer_factory, + edge_drawer_factory=self.edge_drawer_factory, + *args, + **kwds, + ) + + # Bind artist to axes + ax.add_artist(art) + + # Create children artists (this also binds them to the axes) + art._reprocess() + + # Legend for groups + if ("mark_groups" in kwds) and kwds.get("legend", False): + ax.legend( + art._legend_info["handles"], + art._legend_info["labels"], + ) - if visual_edge.align_label: - # Rotate the text to align with the edge - rotation = edge_drawer.get_label_rotation( - visual_edge, src_vertex, dest_vertex, - ) - text_kwargs['rotation'] = rotation - - ax.text( - x, - y, - label, - fontsize=visual_edge.label_size, - color=visual_edge.label_color, - **text_kwargs, - # TODO: offset, etc. - ) + # Set new data limits + ax.update_datalim(art.get_datalim()) # Despine ax.spines["right"].set_visible(False) @@ -339,3 +730,5 @@ def draw(self, graph, *args, **kwds): # Autoscale for x/y axis limits ax.autoscale_view() + + return art diff --git a/src/igraph/drawing/matplotlib/polygon.py b/src/igraph/drawing/matplotlib/polygon.py index 9cbbf6546..8e429a9b8 100644 --- a/src/igraph/drawing/matplotlib/polygon.py +++ b/src/igraph/drawing/matplotlib/polygon.py @@ -80,8 +80,11 @@ def draw(self, points, corner_radius=0, **kwds): codes.extend([mpl.path.Path.CURVE4] * 3) u = v - stroke = mpl.patches.PathPatch( + art = mpl.patches.PathPatch( mpl.path.Path(path, codes=codes, closed=True), + transform=ax.transData, + clip_on=True, + zorder=4, **kwds, ) - ax.add_patch(stroke) + return art diff --git a/src/igraph/drawing/matplotlib/vertex.py b/src/igraph/drawing/matplotlib/vertex.py index 3a7db0d90..00270a010 100644 --- a/src/igraph/drawing/matplotlib/vertex.py +++ b/src/igraph/drawing/matplotlib/vertex.py @@ -44,6 +44,7 @@ class VisualVertexBuilder(AttributeCollectorBase): return VisualVertexBuilder def draw(self, visual_vertex, vertex, coords): + """Build the Artist for a vertex and return it.""" ax = self.context width = ( @@ -57,7 +58,7 @@ def draw(self, visual_vertex, vertex, coords): else visual_vertex.size ) - stroke = visual_vertex.shape.draw_path( + art = visual_vertex.shape.draw_path( ax, coords[0], coords[1], @@ -68,4 +69,4 @@ def draw(self, visual_vertex, vertex, coords): linewidth=visual_vertex.frame_width, zorder=visual_vertex.zorder, ) - ax.add_patch(stroke) + return art diff --git a/src/igraph/drawing/shapes.py b/src/igraph/drawing/shapes.py index d4ca9fc49..69d39bb4d 100644 --- a/src/igraph/drawing/shapes.py +++ b/src/igraph/drawing/shapes.py @@ -96,7 +96,12 @@ def draw_path(ctx, center_x, center_y, width, height=None, **kwargs): height = height or width if hasattr(plt, "Axes") and isinstance(ctx, plt.Axes): return mpl.patches.Rectangle( - (center_x - width / 2, center_y - height / 2), width, height, **kwargs + (center_x - width / 2, center_y - height / 2), + width, + height, + transform=ctx.transData, + clip_on=True, + **kwargs, ) else: ctx.rectangle(center_x - width / 2, center_y - height / 2, width, height) @@ -163,7 +168,13 @@ def draw_path(ctx, center_x, center_y, width, height=None, **kwargs): @see: ShapeDrawer.draw_path""" if hasattr(plt, "Axes") and isinstance(ctx, plt.Axes): - return mpl.patches.Circle((center_x, center_y), width / 2, **kwargs) + return mpl.patches.Circle( + (center_x, center_y), + width / 2, + transform=ctx.transData, + clip_on=True, + **kwargs, + ) else: ctx.arc(center_x, center_y, width / 2, 0, 2 * pi) @@ -197,7 +208,13 @@ def draw_path(ctx, center_x, center_y, width, height=None, **kwargs): [center_x + 0.5 * width, center_y - 0.333 * height], [center_x, center_x + 0.667 * height], ] - return mpl.patches.Polygon(vertices, closed=True, **kwargs) + return mpl.patches.Polygon( + vertices, + closed=True, + transform=ctx.transData, + clip_on=True, + **kwargs, + ) else: ctx.move_to(center_x - width / 2, center_y + height / 2) ctx.line_to(center_x, center_y - height / 2) @@ -234,7 +251,13 @@ def draw_path(ctx, center_x, center_y, width, height=None, **kwargs): [center_x + 0.5 * width, center_y + 0.333 * height], [center_x, center_y - 0.667 * height], ] - return mpl.patches.Polygon(vertices, closed=True, **kwargs) + return mpl.patches.Polygon( + vertices, + closed=True, + transform=ctx.transData, + clip_on=True, + **kwargs, + ) else: ctx.move_to(center_x - width / 2, center_y - height / 2) @@ -273,7 +296,13 @@ def draw_path(ctx, center_x, center_y, width, height=None, **kwargs): [center_x + 0.5 * width, center_y], [center_x, center_y + 0.5 * height], ] - return mpl.patches.Polygon(vertices, closed=True, **kwargs) + return mpl.patches.Polygon( + vertices, + closed=True, + transform=ctx.transData, + clip_on=True, + **kwargs, + ) else: ctx.move_to(center_x - width / 2, center_y) ctx.line_to(center_x, center_y + height / 2) diff --git a/tests/drawing/matplotlib/baseline_images/test_graph/clustering_directed.png b/tests/drawing/matplotlib/baseline_images/test_graph/clustering_directed.png index bd2235aa5..8a8179e05 100644 Binary files a/tests/drawing/matplotlib/baseline_images/test_graph/clustering_directed.png and b/tests/drawing/matplotlib/baseline_images/test_graph/clustering_directed.png differ diff --git a/tests/drawing/matplotlib/baseline_images/test_graph/clustering_directed_large.png b/tests/drawing/matplotlib/baseline_images/test_graph/clustering_directed_large.png index 6d7f5f038..a6ea8972f 100644 Binary files a/tests/drawing/matplotlib/baseline_images/test_graph/clustering_directed_large.png and b/tests/drawing/matplotlib/baseline_images/test_graph/clustering_directed_large.png differ diff --git a/tests/drawing/matplotlib/baseline_images/test_graph/graph_basic.png b/tests/drawing/matplotlib/baseline_images/test_graph/graph_basic.png index c9883a651..259ec8d48 100644 Binary files a/tests/drawing/matplotlib/baseline_images/test_graph/graph_basic.png and b/tests/drawing/matplotlib/baseline_images/test_graph/graph_basic.png differ diff --git a/tests/drawing/matplotlib/baseline_images/test_graph/graph_directed.png b/tests/drawing/matplotlib/baseline_images/test_graph/graph_directed.png index fafa8fe0b..19f703fde 100644 Binary files a/tests/drawing/matplotlib/baseline_images/test_graph/graph_directed.png and b/tests/drawing/matplotlib/baseline_images/test_graph/graph_directed.png differ diff --git a/tests/drawing/matplotlib/baseline_images/test_graph/graph_edit_children.png b/tests/drawing/matplotlib/baseline_images/test_graph/graph_edit_children.png index 33fa772b2..14a9122d8 100644 Binary files a/tests/drawing/matplotlib/baseline_images/test_graph/graph_edit_children.png and b/tests/drawing/matplotlib/baseline_images/test_graph/graph_edit_children.png differ diff --git a/tests/drawing/matplotlib/baseline_images/test_graph/graph_mark_groups_directed.png b/tests/drawing/matplotlib/baseline_images/test_graph/graph_mark_groups_directed.png index fafa8fe0b..19f703fde 100644 Binary files a/tests/drawing/matplotlib/baseline_images/test_graph/graph_mark_groups_directed.png and b/tests/drawing/matplotlib/baseline_images/test_graph/graph_mark_groups_directed.png differ diff --git a/tests/drawing/matplotlib/test_graph.py b/tests/drawing/matplotlib/test_graph.py index 0cec58830..bbc7554d2 100644 --- a/tests/drawing/matplotlib/test_graph.py +++ b/tests/drawing/matplotlib/test_graph.py @@ -15,6 +15,7 @@ try: import matplotlib as mpl + mpl.use("agg") import matplotlib.pyplot as plt except ImportError: @@ -68,17 +69,22 @@ def test_mark_groups_squares(self): plt.close("all") g = Graph.Ring(5, directed=True) fig, ax = plt.subplots() - plot(g, target=ax, mark_groups=True, vertex_shape="s", - layout=self.layout_small_ring) + plot( + g, + target=ax, + mark_groups=True, + vertex_shape="s", + layout=self.layout_small_ring, + ) @image_comparison(baseline_images=["graph_edit_children"], remove_text=True) def test_mark_groups_squares(self): plt.close("all") g = Graph.Ring(5) fig, ax = plt.subplots() - plot(g, target=ax, vertex_shape="o", - layout=self.layout_small_ring) - dot = ax.get_children()[0] + plot(g, target=ax, vertex_shape="o", layout=self.layout_small_ring) + graph_artist = ax.get_children()[0] + dot = graph_artist.get_vertices()[0] dot.set_facecolor("blue") dot.radius *= 0.5 @@ -86,7 +92,7 @@ def test_mark_groups_squares(self): def test_gh_587(self): plt.close("all") g = Graph.Ring(5) - with overridden_configuration('plotting.backend', 'matplotlib'): + with overridden_configuration("plotting.backend", "matplotlib"): plot(g, target="graph_basic.png", layout=self.layout_small_ring) @@ -168,8 +174,7 @@ def test_clustering_directed_small(self): g = Graph.Ring(5, directed=True) clu = VertexClustering(g, [0] * 5) fig, ax = plt.subplots() - plot(clu, target=ax, mark_groups=True, - layout=self.layout_small_ring) + plot(clu, target=ax, mark_groups=True, layout=self.layout_small_ring) @image_comparison(baseline_images=["clustering_directed_large"], remove_text=True) def test_clustering_directed_large(self):