Skip to content

Commit

Permalink
refactor CQMap and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
toumix committed Nov 3, 2020
1 parent 40a77e6 commit f26d868
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 98 deletions.
8 changes: 4 additions & 4 deletions discopy/biclosed.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class Ty(monoidal.Ty):
((y << x) >> y) @ x
"""
@staticmethod
def upgrade(ty):
if len(ty) == 1 and isinstance(ty[0], (Over, Under)):
return ty[0]
return Ty(*ty.objects)
def upgrade(typ):
if len(typ) == 1 and isinstance(typ[0], (Over, Under)):
return typ[0]
return Ty(*typ.objects)

def __init__(self, *objects, left=None, right=None):
self.left, self.right = left, right
Expand Down
17 changes: 8 additions & 9 deletions discopy/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,18 @@

from discopy.cat import AxiomError
from discopy import messages, monoidal, rigid
from discopy.cat import Quiver
from discopy.monoidal import Sum
from discopy.rigid import PRO


def tuplify(xs):
def tuplify(stuff):
""" Returns :code:`xs` if it is already a tuple else :code:`(xs, )`. """
return xs if isinstance(xs, tuple) else (xs, )
return stuff if isinstance(stuff, tuple) else (stuff, )


def untuplify(*xs):
def untuplify(*stuff):
""" Returns either the tuple :code:`xs` or its only element. """
return xs[0] if len(xs) == 1 else xs
return stuff[0] if len(stuff) == 1 else stuff


class Function(rigid.Box):
Expand Down Expand Up @@ -198,10 +197,10 @@ def __call__(self, *values):
>>> assert SWAP(1, 2) == (2, 1)
>>> assert (COPY @ COPY >> Id(1) @ SWAP @ Id(1))(1, 2) == (1, 2, 1, 2)
"""
ob = Quiver(lambda t: PRO(len(t)))
ar = Quiver(lambda f:
Function(len(f.dom), len(f.cod), f.function))
return PythonFunctor(ob, ar)(self)(*values)
return PythonFunctor(
ob=lambda t: PRO(len(t)),
ar=lambda f: Function(len(f.dom), len(f.cod), f.function))(
self)(*values)


class Id(Diagram):
Expand Down
6 changes: 3 additions & 3 deletions discopy/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,13 @@ def __str__(self):
return " + ".join("({})".format(arrow) for arrow in self.terms)

def __add__(self, other):
if other == 0:
return self
other = other if isinstance(other, Sum) else Sum(other)
return self.upgrade(Sum(*(self.terms + other.terms)))

def __radd__(self, other):
if isinstance(other, Arrow):
return self + Sum(other)
return self if 0 == other else other + self
return self.__add__(other)

def __iter__(self):
for arrow in self.terms:
Expand Down
6 changes: 3 additions & 3 deletions discopy/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def generate(self, start, max_sentences, max_depth, max_iter=100,
if seed is not None:
random.seed(seed)
prods, cache = list(self.productions), set()
n, i = 0, 0
while (not max_sentences or n < max_sentences) and i < max_iter:
n_sentences, i = 1, 0
while n_sentences <= (max_sentences or n_sentences) and i < max_iter:
i += 1
sentence = Id(start)
depth = 0
Expand All @@ -109,7 +109,7 @@ def generate(self, start, max_sentences, max_depth, max_iter=100,
yield sentence
if remove_duplicates:
cache.add(sentence)
n += 1
n_sentences += 1
break
tag = sentence.dom[0]
random.shuffle(prods)
Expand Down
22 changes: 11 additions & 11 deletions discopy/monoidal.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ def tensor(self, *others):
objects = self.objects + [x for t in others for x in t.objects]
return self.upgrade(Ty(*objects))

def count(self, ob):
def count(self, obj):
"""
Counts the occurrence of a given object.
Parameters
----------
ob : :class:`Ty` or :class:`Ob`
obj : :class:`Ty` or :class:`Ob`
either a type of length 1 or an object
Returns
Expand All @@ -142,13 +142,13 @@ def count(self, ob):
>>> xs = x ** 5
>>> assert xs.count(x) == xs.count(x[0]) == xs.objects.count(Ob('x'))
"""
ob, = ob if isinstance(ob, Ty) else (ob, )
return self.objects.count(ob)
obj, = obj if isinstance(obj, Ty) else (obj, )
return self.objects.count(obj)

@staticmethod
def upgrade(ty):
def upgrade(typ):
""" Allows class inheritance for tensor and __getitem__ """
return ty
return typ

def __init__(self, *objects):
self._objects = tuple(
Expand Down Expand Up @@ -203,11 +203,11 @@ class PRO(Ty):
>>> assert PRO(1) == PRO(Ob(1))
"""
@staticmethod
def upgrade(ty):
for x in ty:
if x.name != 1:
raise TypeError(messages.type_err(int, x.name))
return PRO(len(ty))
def upgrade(typ):
for obj in typ:
if obj.name != 1:
raise TypeError(messages.type_err(int, obj.name))
return PRO(len(typ))

def __init__(self, n=0):
if isinstance(n, PRO):
Expand Down
92 changes: 46 additions & 46 deletions discopy/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(self, dim=Dim(1)):
super().__init__(Dim(1), dim)


class CQMap(rigid.Box):
class CQMap(Tensor):
"""
Implements classical-quantum maps.
Expand All @@ -139,31 +139,30 @@ class CQMap(rigid.Box):
cod : :class:`CQ`
Codomain.
array : list, optional
Array of size :code:`product(data.dom @ data.cod)`.
data : :class:`discopy.tensor.Tensor`, optional
with domain :code:`dom.classical @ dom.quantum ** 2` and codomain
Array of size :code:`product(utensor.dom @ utensor.cod)`.
utensor : :class:`discopy.tensor.Tensor`, optional
Underlying tensor with domain
:code:`dom.classical @ dom.quantum ** 2` and codomain
:code:`cod.classical @ cod.quantum ** 2``.
"""
def __init__(self, dom, cod, array=None, data=None):
if array is None and data is None:
raise ValueError("One of array or data must be given.")
if data is None:
data = Tensor(dom.classical @ dom.quantum @ dom.quantum,
cod.classical @ cod.quantum @ cod.quantum, array)
self.array = data.array
super().__init__("CQMap", dom, cod, data=data)

def __eq__(self, other):
return isinstance(other, CQMap)\
and (self.dom, self.cod) == (other.dom, other.cod)\
and self.data == other.data
@property
def utensor(self):
""" Underlying tensor. """
return Tensor(self._udom, self._ucod, self.array)

def __init__(self, dom, cod, array=None, utensor=None):
if array is None and utensor is None:
raise ValueError("One of array or utensor must be given.")
if utensor is None:
udom = dom.classical @ dom.quantum @ dom.quantum
ucod = cod.classical @ cod.quantum @ cod.quantum
else:
udom, ucod = utensor.dom, utensor.cod
super().__init__(udom, ucod, utensor.array if array is None else array)
self._dom, self._cod, self._udom, self._ucod = dom, cod, udom, ucod

def __repr__(self):
return "CQMap(dom={}, cod={}, array={})".format(
self.dom, self.cod, np.array2string(self.array.flatten()))

def __str__(self):
return repr(self)
return super().__repr__().replace("Tensor", "CQMap")

def __add__(self, other):
if other == 0:
Expand All @@ -177,49 +176,50 @@ def __radd__(self, other):

@staticmethod
def id(dom):
data = Tensor.id(dom.classical @ dom.quantum @ dom.quantum)
return CQMap(dom, dom, data.array)
utensor = Tensor.id(dom.classical @ dom.quantum @ dom.quantum)
return CQMap(dom, dom, utensor=utensor)

def then(self, *others):
if len(others) != 1 or any(isinstance(other, Sum) for other in others):
if len(others) != 1:
return monoidal.Diagram.then(self, *others)
data = self.data >> others[0].data
return CQMap(self.dom, others[0].cod, data.array)
other, = others
return CQMap(
self.dom, other.cod, utensor=self.utensor >> other.utensor)

def dagger(self):
return CQMap(self.cod, self.dom, self.data.dagger().array)
return CQMap(self.cod, self.dom, utensor=self.utensor.dagger())

def tensor(self, *others):
if len(others) != 1 or any(isinstance(other, Sum) for other in others):
if len(others) != 1:
return monoidal.Diagram.tensor(self, *others)
other, = others
f = rigid.Box('f', Ty('c00', 'q00', 'q00'), Ty('c10', 'q10', 'q10'))
g = rigid.Box('g', Ty('c01', 'q01', 'q01'), Ty('c11', 'q11', 'q11'))
ob = {Ty("{}{}{}".format(a, b, c)):
z.__getattribute__(y).__getattribute__(x)
for a, x in zip(['c', 'q'], ['classical', 'quantum'])
for b, y in zip([0, 1], ['dom', 'cod'])
for c, z in zip([0, 1], [self, others[0]])}
ar = {f: self.array, g: others[0].array}
permute_above = Diagram.id(f.dom[:1] @ g.dom[:1] @ f.dom[1:2])\
above = Diagram.id(f.dom[:1] @ g.dom[:1] @ f.dom[1:2])\
@ Diagram.swap(g.dom[1:2], f.dom[2:]) @ Diagram.id(g.dom[2:])\
>> Diagram.id(f.dom[:1]) @ Diagram.swap(g.dom[:1], f.dom[1:])\
@ Diagram.id(g.dom[1:])
permute_below =\
below =\
Diagram.id(f.cod[:1]) @ Diagram.swap(f.cod[1:], g.cod[:1])\
@ Diagram.id(g.cod[1:])\
>> Diagram.id(f.cod[:1] @ g.cod[:1] @ f.cod[1:2])\
@ Diagram.swap(f.cod[2:], g.cod[1:2]) @ Diagram.id(g.cod[2:])
F = TensorFunctor(ob, ar)
array = F(permute_above >> f @ g >> permute_below).array
dom, cod = self.dom @ others[0].dom, self.cod @ others[0].cod
return CQMap(dom, cod, array)
diagram2tensor = TensorFunctor(
ob={Ty("{}{}{}".format(a, b, c)):
z.__getattribute__(y).__getattribute__(x)
for a, x in zip(['c', 'q'], ['classical', 'quantum'])
for b, y in zip([0, 1], ['dom', 'cod'])
for c, z in zip([0, 1], [self, other])},
ar={f: self.utensor.array, g: other.utensor.array})
return CQMap(self.dom @ other.dom, self.cod @ other.cod,
utensor=diagram2tensor(above >> f @ g >> below))

@staticmethod
def swap(left, right):
data = Tensor.swap(left.classical, right.classical)\
utensor = Tensor.swap(left.classical, right.classical)\
@ Tensor.swap(left.quantum, right.quantum)\
@ Tensor.swap(left.quantum, right.quantum)
return CQMap(left @ right, right @ left, data.array)
return CQMap(left @ right, right @ left, utensor=utensor)

@staticmethod
def measure(dim, destructive=True):
Expand Down Expand Up @@ -279,7 +279,7 @@ def caps(left, right):

def round(self, decimals=0):
""" Rounds the entries of a CQMap up to a number of decimals. """
return CQMap(self.dom, self.cod, data=self.data.round(decimals))
return CQMap(self.dom, self.cod, utensor=self.utensor.round(decimals))


class CQMapFunctor(rigid.Functor):
Expand Down Expand Up @@ -897,12 +897,12 @@ class Bits(ClassicalGate):
... == Tensor(dom=Dim(1), cod=Dim(2, 2), array=[0, 0, 1, 0])
"""
def __init__(self, *bitstring, _dagger=False):
data = Tensor.id(Dim(1)).tensor(*(
utensor = Tensor.id(Dim(1)).tensor(*(
Tensor(Dim(1), Dim(2), [0, 1] if bit else [1, 0])
for bit in bitstring))
name = "Bits({})".format(', '.join(map(str, bitstring)))
dom, cod = (len(bitstring), 0) if _dagger else (0, len(bitstring))
super().__init__(name, dom, cod, array=data.array, _dagger=_dagger)
super().__init__(name, dom, cod, array=utensor.array, _dagger=_dagger)
self.bitstring = bitstring

def __repr__(self):
Expand Down
26 changes: 13 additions & 13 deletions discopy/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class Ty(monoidal.Ty, Ob):
>>> assert (s @ n).l == n.l @ s.l and (s @ n).r == n.r @ s.r
"""
@staticmethod
def upgrade(ty):
return Ty(*ty.objects)
def upgrade(typ):
return Ty(*typ.objects)

@property
def l(self):
Expand Down Expand Up @@ -112,8 +112,8 @@ class PRO(monoidal.PRO, Ty):
Objects of the free rigid monoidal category generated by 1.
"""
@staticmethod
def upgrade(ty):
return PRO(len(monoidal.PRO.upgrade(ty)))
def upgrade(typ):
return PRO(len(monoidal.PRO.upgrade(typ)))

@property
def l(self):
Expand Down Expand Up @@ -483,13 +483,13 @@ def __init__(self, ob, ar, ob_factory=Ty, ar_factory=Diagram):

def __call__(self, diagram):
if isinstance(diagram, monoidal.Ty):
def adjoint(ob):
result = self.ob[type(diagram)(type(ob)(ob.name, z=0))]
if ob.z < 0:
for _ in range(-ob.z):
def adjoint(obj):
result = self.ob[type(diagram)(type(obj)(obj.name, z=0))]
if obj.z < 0:
for _ in range(-obj.z):
result = result.l
elif ob.z > 0:
for _ in range(ob.z):
elif obj.z > 0:
for _ in range(obj.z):
result = result.r
return result
return self.ob_factory().tensor(*map(adjoint, diagram.objects))
Expand All @@ -506,9 +506,9 @@ def adjoint(ob):

def cups(left, right, ar_factory=Diagram, cup_factory=Cup, reverse=False):
""" Constructs a diagram of nested cups. """
for ty in left, right:
if not isinstance(ty, Ty):
raise TypeError(messages.type_err(Ty, ty))
for typ in left, right:
if not isinstance(typ, Ty):
raise TypeError(messages.type_err(Ty, typ))
if left.r != right and right.r != left:
raise AxiomError(messages.are_not_adjoints(left, right))
result = ar_factory.id(left @ right)
Expand Down
16 changes: 8 additions & 8 deletions discopy/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@

def array2string(array, max_length=messages.NUMPY_THRESHOLD):
""" array2string is not implemented in jax.numpy """
ls = list(array)
if len(ls) > max_length:
ls = ls[:max_length // 2] + ["..."] + ls[1 - max_length // 2:]
return "[{}]".format(", ".join(map(str, ls)))
flat = list(array)
flat = flat if len(flat) <= max_length else\
flat[:max_length // 2] + ["..."] + flat[1 - max_length // 2:]
return "[{}]".format(", ".join(map(str, flat)))
np.array2string = array2string
except ImportError: # pragma: no cover
import numpy as np
Expand Down Expand Up @@ -139,7 +139,7 @@ def __eq__(self, other):
def then(self, *others):
if len(others) != 1 or any(isinstance(other, Sum) for other in others):
return monoidal.Diagram.then(self, *others)
other = others[0]
other, = others
if not isinstance(other, Tensor):
raise TypeError(messages.type_err(Tensor, other))
if self.cod != other.dom:
Expand Down Expand Up @@ -262,10 +262,10 @@ def __call__(self, diagram):
dom, cod = self(diagram.dom), self(diagram.cod)
return sum(map(self, diagram), Tensor.zeros(dom, cod))
if isinstance(diagram, monoidal.Ty):
def ob(x):
result = self.ob[type(diagram)(x.name)]
def obj_to_dim(obj):
result = self.ob[type(diagram)(obj.name)]
return result if isinstance(result, Dim) else Dim(result)
return Dim(1).tensor(*map(ob, diagram.objects))
return Dim(1).tensor(*map(obj_to_dim, diagram.objects))
if isinstance(diagram, Cup):
return Tensor.cups(self(diagram.dom[:1]), self(diagram.dom[1:]))
if isinstance(diagram, Cap):
Expand Down
Loading

0 comments on commit f26d868

Please sign in to comment.