Skip to content

Commit

Permalink
Bubbles with multiple args (#279)
Browse files Browse the repository at this point in the history
* Bubbles with multiple args

* Drawing multiple args

* Frame drawing

* Align slot boundaries
  • Loading branch information
toumix committed Apr 22, 2024
1 parent 2bac279 commit 006c396
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 59 deletions.
71 changes: 48 additions & 23 deletions discopy/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,9 @@ def zero(cls, dom, cod):
"""
return cls.sum_factory((), dom, cod)

def bubble(self, **params) -> Bubble:
def bubble(self, *args, **kwargs) -> Bubble:
""" Unary operator on homsets. """
return self.bubble_factory(self, **params)
return self.bubble_factory(self, *args, **kwargs)

@property
def free_symbols(self) -> "set[sympy.Symbol]":
Expand Down Expand Up @@ -686,47 +686,71 @@ class Bubble(Box):
objects :code:`dom` and :code:`cod`.
Parameters:
arg : The arrow inside the bubble.
args : The arrows inside the bubble.
dom : The domain of the bubble, default is that of :code:`other`.
cod : The codomain of the bubble, default is that of :code:`other`.
name (str) : An optional name for the bubble.
method (str) : The method to call when a functor is applied to it.
kwargs : Passed to the `__init__` of :class:`Box`.
"""
def __init__(self, arg: Arrow, dom: Ob = None, cod: Ob = None):
dom = arg.dom if dom is None else dom
cod = arg.cod if cod is None else cod
self.arg = arg
Box.__init__(self, "Bubble", dom, cod)
def __init__(self, *args: Arrow, dom: Ob = None, cod: Ob = None,
name="Bubble", method="bubble", **kwargs):
dom, = set(arg.dom for arg in args) if dom is None else (dom, )
cod, = set(arg.cod for arg in args) if cod is None else (cod, )
self.args, self.method = args, method
Box.__init__(self, name, dom, cod, **kwargs)

@property
def arg(self):
""" The arrow inside the bubble if there is exactly one. """
if len(self.args) == 1:
return self.args[0]
raise ValueError(f"{self} has multiple args.")

@property
def is_id_on_objects(self):
""" Whether the bubble is identity on objects. """
return (self.dom, self.cod) == (self.arg.dom, self.arg.cod)
return len(self.args) == 1 and (
self.dom, self.cod) == (self.arg.dom, self.arg.cod)

def __eq__(self, other):
if isinstance(other, Bubble):
return all(getattr(self, x) == getattr(other, x) for x in (
"args", "dom", "cod", "name", "method"))
return not isinstance(other, Box) and super().__eq__(other)

def __hash__(self):
return hash(tuple(getattr(self, x) for x in [
"args", "dom", "cod", "name", "method"]))

def __str__(self):
str_args = '' if self.is_id_on_objects\
else f'dom={self.dom}, cod={self.cod}'
return f"({self.arg}).bubble({str_args})"
str_args = ",".join(map(str, self.args))
str_dom_cod = '' if self.is_id_on_objects else (
f'dom={self.dom}, cod={self.cod}')
return f"({str_args}).bubble({str_dom_cod})"

def __repr__(self):
str_args = repr(self.arg) if self.is_id_on_objects else\
f"{repr(self.arg)}, dom={repr(self.dom)}, cod={repr(self.cod)}"
return f"{factory_name(type(self))}({str_args})"
repr_args = ", ".join(map(repr, self.args))
repr_dom_cod = "" if self.is_id_on_objects else (
f", dom={repr(self.dom)}, cod={repr(self.cod)}")
return factory_name(type(self)) + (f"({repr_args}{repr_dom_cod})")

@property
def free_symbols(self):
return super().free_symbols.union(self.arg.free_symbols)
return super().free_symbols.union(*[f.free_symbols for f in self.args])

def to_tree(self):
return {
'factory': factory_name(type(self)),
'arg': self.arg.to_tree(),
'args': [f.to_tree() for f in self.args],
'dom': self.dom.to_tree(),
'cod': self.cod.to_tree()}

@classmethod
def from_tree(cls, tree):
dom, cod, arg = map(from_tree, (
tree['dom'], tree['cod'], tree['arg']))
return cls(arg=arg, dom=dom, cod=cod)
args = [tree['arg']] if 'args' not in tree else tree['args']
dom, cod = map(from_tree, (tree['dom'], tree['cod']))
return cls(*map(from_tree, args), dom=dom, cod=cod)


@dataclass
Expand Down Expand Up @@ -868,9 +892,10 @@ def __call__(self, other):
if isinstance(other, Sum):
return sum(map(self, other.terms),
self.cod.ar.zero(self(other.dom), self(other.cod)))
if isinstance(other, Bubble) and hasattr(self.cod.ar, "bubble"):
return self(other.arg).bubble(
dom=self(other.dom), cod=self(other.cod))
if isinstance(other, Bubble) and hasattr(self.cod.ar, other.method):
dom, cod = map(self, (other.dom, other.cod))
return getattr(self.cod.ar, other.method)(
*map(self, other.args), dom=dom, cod=cod)
if isinstance(other, Box) and other.is_dagger:
return self(other.dagger()).dagger()
if isinstance(other, Box):
Expand Down
4 changes: 2 additions & 2 deletions discopy/closed.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,15 @@ class Curry(monoidal.Bubble, Box):
left : Whether to curry on the left or right.
"""
def __init__(self, arg: Diagram, n=1, left=True):
self.arg, self.n, self.left = arg, n, left
self.n, self.left = n, left
name = f"Curry({arg}, {n}, {left})"
if left:
dom = arg.dom[:len(arg.dom) - n]
cod = arg.cod << arg.dom[len(arg.dom) - n:]
else:
dom, cod = arg.dom[n:], arg.dom[:n] >> arg.cod
monoidal.Bubble.__init__(
self, arg, dom, cod, drawing_name="$\\Lambda$")
self, arg, dom=dom, cod=cod, drawing_name="$\\Lambda$")
Box.__init__(self, name, dom, cod)


Expand Down
5 changes: 5 additions & 0 deletions discopy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
"draw_as_discards": lambda _: False,
"draw_as_measures": lambda _: False,
"draw_as_controlled": lambda _: False,
"frame_opening": lambda _: False,
"frame_closing": lambda _: False,
"frame_slot_boundary": lambda _: False,
"frame_slot_opening": lambda box: box.frame_slot_boundary,
"frame_slot_closing": lambda box: box.frame_slot_boundary,
"shape": lambda box:
"circle" if getattr(box, "draw_as_spider", False) else None,
"color": lambda box:
Expand Down
49 changes: 44 additions & 5 deletions discopy/drawing/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,27 +99,38 @@ def add_node(node, position):
def add_box(scan, box, off, depth, x_pos):
bubble_opening = getattr(box, "bubble_opening", False)
bubble_closing = getattr(box, "bubble_closing", False)
frame_opening = getattr(box, "frame_opening", False)
frame_closing = getattr(box, "frame_closing", False)
frame_slot_boundary = getattr(box, "frame_slot_boundary", False)
bubble = bubble_opening or bubble_closing
node = Node("box", box=box, depth=depth)
add_node(node, (x_pos, len(diagram) - depth - .5))
for i, obj in enumerate(box.dom.inside):
y_pos = len(diagram) - depth - (.75 if frame_opening else .25)
wire, position = Node("dom", obj=obj, i=i, depth=depth), (
pos[scan[off + i]][0], len(diagram) - depth - .25)
pos[scan[off + i]][0], y_pos)
add_node(wire, position)
graph.add_edge(scan[off + i], wire)
if not bubble or bubble_closing and i in [0, len(box.dom) - 1]:
graph.add_edge(wire, node)
for i, obj in enumerate(box.cod.inside):
y_pos = len(diagram) - depth - (.25 if frame_closing else .75)
align_wires = len(box.dom) == len(box.cod) and not frame_closing
position = (
pos[scan[off + i]][0] if len(box.dom) == len(box.cod)
pos[scan[off + i]][0] if align_wires
else pos[scan[off + i + 1]][0] if bubble_closing
else x_pos - len(box.cod[1:]) / 2 + i,
len(diagram) - depth - .75)
else x_pos - len(box.cod[1:]) / 2 + i, y_pos)
if frame_slot_boundary and i == 0:
position = (pos[scan[off]][0], position[1])
if frame_slot_boundary and i == len(box.cod[1:]):
position = (pos[scan[off + len(box.dom[1:])]][0], position[1])
elif frame_opening and i in (0, len(box.cod[1:])):
position = (position[0] + (.25 if i else -.25), position[1])
wire = Node("cod", obj=obj, i=i, depth=depth)
add_node(wire, position)
if not bubble or bubble_opening and i in [0, len(box.cod) - 1]:
graph.add_edge(node, wire)
if bubble_opening or bubble_closing:
if bubble_opening or bubble_closing: # Make wires go through bubbles.
source_ty, target_ty = (box.dom, box.cod[1:-1]) if bubble_opening\
else (box.dom[1:-1], box.cod)
for i, (source_obj, target_obj) in enumerate(zip(
Expand Down Expand Up @@ -490,6 +501,11 @@ def draw(diagram, **params):
diagram = diagram.to_drawing()

drawing_methods = [
("frame_opening", draw_frame_opening),
("frame_closing", draw_frame_closing),
("frame_slot_boundary", draw_frame_boundary),
("frame_slot_opening", draw_frame_opening),
("frame_slot_closing", draw_frame_closing),
("draw_as_brakets", draw_brakets),
("draw_as_controlled", draw_controlled_gate),
("draw_as_discards", draw_discard),
Expand Down Expand Up @@ -773,6 +789,29 @@ def __bool__(self):
return all(term == self.terms[0] for term in self.terms)


def draw_frame_opening(backend, positions, node, **params):
box, depth = node.box, node.depth
obj_left, obj_right = box.cod.inside[0], box.cod.inside[-1]
left = Node("cod", obj=obj_left, depth=depth, i=0)
right = Node("cod", obj=obj_right, depth=depth, i=len(box.cod[1:]))
backend.draw_wire(positions[left], positions[right])
return backend


def draw_frame_closing(backend, positions, node, **params):
box, depth = node.box, node.depth
obj_left, obj_right = box.dom.inside[0], box.dom.inside[-1]
left = Node("dom", obj=obj_left, depth=depth, i=0)
right = Node("dom", obj=obj_right, depth=depth, i=len(box.dom[1:]))
backend.draw_wire(positions[left], positions[right])
return backend


def draw_frame_boundary(backend, positions, node, **params):
backend = draw_frame_closing(backend, positions, node, **params)
return draw_frame_opening(backend, positions, node, **params)


def draw_discard(backend, positions, node, **params):
""" Draws a :class:`discopy.quantum.circuit.Discard` box. """
box, depth = node.box, node.depth
Expand Down
4 changes: 2 additions & 2 deletions discopy/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ class Head(monoidal.Bubble, Box):
def __init__(self, arg: Diagram, time_step=0, _attr="head"):
dom, cod = (
getattr(x, _attr).delay(time_step) for x in [arg.dom, arg.cod])
monoidal.Bubble.__init__(self, arg, dom, cod)
monoidal.Bubble.__init__(self, arg, dom=dom, cod=cod)
Box.__init__(self, f"({arg}).{_attr}", self.dom, self.cod, time_step)

delay, reset, __repr__ = HeadOb.delay, HeadOb.reset, HeadOb.__repr__
Expand Down Expand Up @@ -524,7 +524,7 @@ def __init__(self, arg: Diagram, dom=None, cod=None, mem=None, left=False):
if arg.cod != cod @ mem:
raise AxiomError
self.mem, self.left = mem, left
monoidal.Bubble.__init__(self, arg, dom, cod)
monoidal.Bubble.__init__(self, arg, dom=dom, cod=cod)
Box.__init__(self, self.name, dom, cod)
mem_name = "" if len(mem) == 1 else f"mem={mem}"
self.name = f"({self.arg}).feedback({mem_name})"
Expand Down
71 changes: 59 additions & 12 deletions discopy/monoidal.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,40 +957,87 @@ class Bubble(cat.Bubble, Box):
types :code:`dom` and :code:`cod`.
Parameters:
arg : The diagram inside the bubble.
args : The diagrams inside the bubble.
dom : The domain of the bubble, default is that of :code:`other`.
cod : The codomain of the bubble, default is that of :code:`other`.
drawing_name (str) : The name of the bubble when drawing it.
draw_as_bubble (bool) : Whether to draw as a bubble or as a frame.
Examples
--------
>>> x, y = Ty('x'), Ty('y')
>>> f, g = Box('f', x, y ** 3), Box('g', y, y @ y)
>>> f, g, h = Box('f', x, y ** 3), Box('g', y, y @ y), Box('h', x, y)
>>> d = (f.bubble(dom=x @ x, cod=y) >> g).bubble()
>>> d.draw(path='docs/_static/monoidal/bubble-example.png')
.. image:: /_static/monoidal/bubble-example.png
:align: center
>>> b = Bubble(f, g, h, dom=x, cod=y @ y)
>>> b.draw(path='docs/_static/monoidal/bubble-multiple-args.png')
.. image:: /_static/monoidal/bubble-multiple-args.png
:align: center
"""
__ambiguous_inheritance__ = (cat.Bubble, )

def __init__(self, arg: Diagram, dom: Ty = None, cod: Ty = None, **params):
self.drawing_name = params.get("drawing_name", "")
cat.Bubble.__init__(self, arg, dom, cod)
Box.__init__(self, self.name, self.dom, self.cod, data=self.data)

def to_drawing(self):
def __init__(
self, *args: Diagram, dom: Ty = None, cod: Ty = None, **kwargs):
cat.Bubble.__init__(self, *args, dom=dom, cod=cod)
self.drawing_name = kwargs.pop("drawing_name", "")
self.draw_as_bubble = kwargs.pop(
"draw_as_bubble", (len(args) == 1
and len(self.arg.dom) == len(self.dom)
and len(self.arg.cod) == len(self.cod)))
Box.__init__(self, self.name, self.dom, self.cod, **kwargs)

def to_bubble_drawing(self):
dom, cod = self.dom.to_drawing(), self.cod.to_drawing()
argdom, argcod = self.arg.dom.to_drawing(), self.arg.cod.to_drawing()
left, right = Ty(self.drawing_name), Ty("")
left.inside[0].always_draw_label = True
_open = Box("_open", dom, left @ argdom @ right).to_drawing()
_close = Box("_close", left @ argcod @ right, cod).to_drawing()
_open.draw_as_wires = _close.draw_as_wires = True
# Wires can go straight only if types have the same length.
_open.bubble_opening = len(dom) == len(argdom)
_close.bubble_closing = len(cod) == len(argcod)
if len(dom) == len(argdom) and len(cod) == len(argcod):
_open.bubble_opening = _close.bubble_closing = True
_open.draw_as_wires = _close.draw_as_wires = True
else:
_open.frame_slot_opening = _close.frame_slot_closing = True
return _open >> left @ self.arg.to_drawing() @ right >> _close

def to_frame_drawing(self):
dom, cod = self.dom.to_drawing(), self.cod.to_drawing()
if self.args == 1:
inside = self.arg.to_drawing().bubble(
draw_as_bubble=True, dom=Ty(), cod=Ty()).to_drawing()
else:
left = right = Ty('')
first_arg = self.args[0].to_drawing()
last_arg = self.args[-1].to_drawing()
open_first_slot = Box(
"open", Ty(), left @ first_arg.dom @ right).to_drawing()
open_first_slot.frame_slot_opening = True
inside = open_first_slot >> left @ first_arg @ right
for f, g in zip(self.args, self.args[1:]):
b_dom, b_cod = [
left @ x.to_drawing() @ right for x in [f.cod, g.dom]]
b = Box("boundary", b_dom, b_cod)
b.frame_slot_boundary = True
inside >>= b.to_drawing() >> left @ g.to_drawing() @ right
close_last_slot = Box(
"close", left @ last_arg.cod @ right, Ty()).to_drawing()
close_last_slot.frame_slot_closing = True
inside >>= close_last_slot
left, right = Ty(self.drawing_name), Ty("")
_open = Box("_open", dom, left @ right).to_drawing()
_close = Box("_close", left @ right, cod).to_drawing()
_open.frame_opening = _close.frame_closing = True
return _open >> left @ inside @ right >> _close

def to_drawing(self):
return self.to_bubble_drawing(
) if self.draw_as_bubble else self.to_frame_drawing()


class Category(cat.Category):
"""
Expand Down
2 changes: 1 addition & 1 deletion discopy/traced.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def __init__(self, arg: Diagram, left=False):
name = f"Trace({arg}" + ", left=True)" if left else ")"
dom, cod = (arg.dom[1:], arg.cod[1:]) if left\
else (arg.dom[:-1], arg.cod[:-1])
monoidal.Bubble.__init__(self, arg, dom, cod)
monoidal.Bubble.__init__(self, arg, dom=dom, cod=cod)
Box.__init__(self, name, dom, cod)

def __repr__(self):
Expand Down
Binary file modified docs/_static/monoidal/bubble-example.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/monoidal/bubble-multiple-args.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 10 additions & 14 deletions test/syntax/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,19 @@ def walk(x):
assert F(walk).unroll(9).now()[:10] == (0, -1, -2, -1, 0, 1, 0, 1, 2, 1)
assert F(walk).unroll(9).now()[:10] == (0, -1, 0, 1, 0, 1, 0, -1, 0, -1)

def test_fibonacci():
X = Ty('X')
fby, wait = FollowedBy(X), Swap(X, X.d).feedback()
zero, one = Box('0', Ty(), X), Box('1', Ty(), X)
copy, plus = Copy(X), Box('+', X @ X, X)

X = Ty('X')
fby, wait = FollowedBy(X), Swap(X, X.d).feedback()
zero, one = Box('0', Ty(), X), Box('1', Ty(), X)
copy, plus = Copy(X), Box('+', X @ X, X)


@Diagram.feedback
@Diagram.from_callable(X.d, X @ X)
def fib(x):
y = fby(zero.head(), plus.d(fby.d(one.head.d(), wait.d(x)), x))
return (y, y)

@Diagram.feedback
@Diagram.from_callable(X.d, X @ X)
def fib(x):
y = fby(zero.head(), plus.d(fby.d(one.head.d(), wait.d(x)), x))
return (y, y)

def test_fibonacci_eq():
with Diagram.hypergraph_equality:
assert fib == (
copy.d >> one.head.d @ wait.d @ X.d
Expand All @@ -82,8 +80,6 @@ def test_fibonacci_eq():
>> zero.head @ X.d
>> fby >> copy).feedback()


def test_fibonacci_functor():
F = Functor(
ob={X: (int, )},
ar={zero: lambda: 0,
Expand Down

0 comments on commit 006c396

Please sign in to comment.