In [277]:
from egglog import *
from egglog.egraph import ActionLike, FactLike
from egglog.bindings import Eq
from __future__ import annotations
from collections import defaultdict
from enum import Enum
from dataclasses import dataclass
from typing import Iterator

In [278]:
# Feature flags read by the cells below:
#   debug          - record the egglog program text (EGraph save_egglog_string),
#                    record dim errors in a relation instead of panicking, and
#                    write the program to zxgen.egg at the end
#   proof_tree     - declare ZXProof / zx_edge / zx_path and the proof rules
#   dim_errors     - register rules that detect compositions whose inner
#                    dimensions disagree
#   dim_proof_tree - also build proofs for Dim rewrites (requires proof_tree)
debug, proof_tree, dim_errors, dim_proof_tree = True, True, True, False

In [279]:
# Dim proof rules are registered on proof_ruleset, which only exists when
# proof_tree is set, so dim_proof_tree on its own is an invalid configuration.
assert proof_tree or not dim_proof_tree, "Can't use dim_proofs without proofs"

In [280]:
# The single e-graph used by the whole notebook. With debug on, it keeps the
# egglog program text, which the last cell writes to zxgen.egg.
egraph = EGraph(save_egglog_string=debug)

In [281]:
# Egglog sort for symbolic dimensions (the `n` / `m` sizes of ZX terms).
# Method bodies are `...`: these are uninterpreted egglog constructors whose
# meaning comes from the rewrite rules registered in later cells.
class Dim(Expr):
    # Integer literal dimension.
    @classmethod
    def lit(cls, value: i64Like) -> Dim: ...

    # Named (symbolic) dimension.
    @classmethod
    def var(cls, name: StringLike) -> Dim:  ...

    # Sum of dimensions (e.g. the sizes of a `stack`).
    def __add__(self, other: Dim) -> Dim: ...

    # Product of dimensions. NOTE(review): no rule in this notebook uses it.
    def __mul__(self, other: Dim) -> Dim: ... 
  
# Proof terms for Dim rewrites; declared only when dim_proof_tree is set.
if dim_proof_tree:
  class DimProof(Expr):
    # Proof with no steps, sitting at `dim`.
    @classmethod
    def refl(cls, dim : Dim) -> DimProof:
      ...
    
    # In proof_rules, a single edge x -s-> y is encoded as refl(y).trans(x, s).
    def trans(self, dim : Dim, name : StringLike) -> DimProof:
      ...
      
    # Combine the proofs for the two operands of the binary term `dim`
    # (see rule_for_bin_op).
    @classmethod
    def join(cls, dim : Dim, prfl : DimProof, prfr : DimProof) -> DimProof:
      ...

In [282]:

# Integer constants already let-bound in `egraph`: value -> Dim variable.
# (The original annotation used the builtin function `any`, which is not a type.)
egraph_consts : dict[int, Dim] = dict()

def register_const(const : int) -> Dim:
  """Return a Dim variable bound to `Dim.lit(const)` in `egraph`.

  The first request for a value declares a `let` named
  `___________int_const__<const>` in the e-graph. Later requests for the same
  value return the cached variable instead of declaring it again.
  """
  if const in egraph_consts:
    return egraph_consts[const]
  name = f"___________int_const__{const}"
  egraph.let(name, Dim.lit(const))
  val = var(name, Dim)
  egraph_consts[const] = val
  return val

In [283]:
# Egglog sort of ZX diagrams. Every method is an uninterpreted egglog
# constructor/function (bodies are `...`); `n` and `m` are given values per
# constructor by zx_fn_rules.
class ZX(Expr):
    # Opaque named diagram with the given dimensions.
    @classmethod
    def symbol(cls, symbol : StringLike, n : Dim, m : Dim) -> ZX:
      ...
    
    # Re-type this diagram as (n, m); cast_rules drop casts that match the
    # diagram's own dimensions.
    def cast(self, n : Dim, m : Dim) -> ZX:
      ...
    
    # Sequential composition; zx_fn_rules only define its dimensions when
    # self.m == other.n (result is (self.n, other.m)).
    def compose(self, other : ZX) -> ZX:
      ...
      
    # Side-by-side stacking; dimensions add (self.n + other.n, self.m + other.m).
    def stack(self, other : ZX) -> ZX:
      ...
      
    # Z generator with dimensions (n, m) and parameter alpha.
    # NOTE(review): presumably the ZX-calculus Z spider with phase alpha;
    # alpha is typed as Dim here.
    @classmethod
    def Z(cls, n : Dim, m : Dim, alpha : Dim) -> ZX:
      ...
      
    # X generator with dimensions (n, m) and parameter alpha (see Z).
    @classmethod
    def X(cls, n : Dim, m : Dim, alpha : Dim) -> ZX:
      ...
     
    # TODO: general nStack; only the nStack1 variant below is implemented.
    # def nStack(self, n : Dim) -> ZX:
    #   ...
      
    # Dimensions (n, n); zx_fn_rules require self to be 1x1.
    # NOTE(review): presumably n stacked copies of self.
    def nStack1(self, n : Dim) -> ZX:
      ...
    
    # n parallel wires with dimensions (n, n); nwire_rules remove it next to a
    # matching composition partner.
    @classmethod
    def nWire(cls, n: Dim) -> ZX:
      ...
      
    # First dimension of the diagram (defined per constructor by zx_fn_rules).
    @property
    def n(self) -> Dim:
      ...
      
    # Second dimension of the diagram (defined per constructor by zx_fn_rules).
    @property
    def m(self) -> Dim:
      ...
      
# Ruleset for all ZX/Dim rewrite rules (registered with @rw_ruleset.register).
rw_ruleset = Ruleset(name="rw_rules")

# Proof machinery; declared only when proof_tree is set.
if proof_tree:   
  class ZXProof(Expr):
    # Proof with no steps, sitting at `zx`.
    @classmethod
    def refl(cls, zx : ZX) -> ZXProof:
      ...
    
    # In proof_rules, a single edge x -s-> y is encoded as refl(y).trans(x, s).
    def trans(self, zx : ZX, name : StringLike) -> ZXProof:
      ...
      
    # Combine the proofs for the two children of the binary term `zx`
    # (see rule_for_bin_op).
    @classmethod
    def join(cls, zx : ZX, prfl : ZXProof, prfr : ZXProof) -> ZXProof:
      ...
    
  # Ruleset for deriving paths/proofs from edges and for dimension-error checks.
  proof_ruleset = Ruleset(name="proof_rules")
    # TODO (disabled): a ZXProof constructor built from two DimProofs, needed
    # by the commented-out zx_cast_rules; it belongs inside ZXProof above.
    # @classmethod
    # def dim_proof(cls, prfn : DimProof, prfm : DimProof) -> ZXProof:
    #   ...

In [284]:
if proof_tree:
  
  # Name of the rule that produced zx2 from zx1 (set by gen_rule). When several
  # rules produce the same pair, the merge keeps the most recently set name.
  @function(merge=lambda old, new : new)
  def zx_edge(zx1 : ZX, zx2: ZX) -> StringLike:
    ...
    
  # Proof term for a derived path from zx1 to zx2 (built by proof_rules and
  # rule_for_bin_op). NOTE(review): unextractable presumably keeps these calls
  # out of extracted ZX terms.
  @function(unextractable=True)
  def zx_path(zx1 : ZX, zx2: ZX) -> ZXProof:
    ... 
    
  # (lhs, rhs) pairs to relate; populated by create_goals. Not referenced by
  # any rule shown here.
  goals = relation('goals', ZX, ZX)

    
# Dim counterparts of zx_path / zx_edge; declared only with dim_proof_tree.
if dim_proof_tree:
  @function(unextractable=True)
  def dim_path(dim1 : Dim, dim2 : Dim) -> DimProof:
    ...
    
  @function(merge=lambda old, new : new)
  def dim_edge(dim1 : Dim, dim2 : Dim) -> StringLike:
    ...

  

In [285]:
def matchvar(Type,i = 0):
  """Return the `i`-th reserved pattern variable of egglog sort `Type`.

  NOTE(review): the long, odd prefix presumably avoids clashing with the
  variable names used inside individual rules.
  """
  reserved_name = f'tobematched_______{i}'
  return var(reserved_name, Type)


In [286]:
def build_dim_rule(match : any, to : any, name : str, *conditions : FactLike):
  """Build a Dim simplification `match -> to`, guarded by `conditions`.

  Without dim_proof_tree this is a plain egglog rewrite, and `name` is unused.
  With dim_proof_tree, each matching term instead gets a `dim_edge(match, to)`
  labelled `name`, which the proof rules turn into dim_paths.
  """
  if not dim_proof_tree:
    return rewrite(match).to(to, *conditions)
  body = (eq(matchvar(Dim)).to(match), *conditions)
  edge_action = set_(dim_edge(match, to)).to(name)
  return rule(*body).then(edge_action)

In [287]:
@rw_ruleset.register
def dim_rules():
  """Dim simplifications: drop a left `lit(0) +` and constant-fold literal sums."""
  a = var('a', Dim)
  i1, i2 = var('i1', i64), var('i2', i64)
  # Intended as "a is not a constant", since constant folding already handles
  # lit(0) + lit(k).
  # NOTE(review): `i1` is bound nowhere else in this rule, so this likely matches
  # *some* existing literal that differs from `a` rather than excluding a
  # literal `a`. Confirm the guard behaves as intended.
  a_not_literal = ne(a).to(Dim.lit(i1))
  yield build_dim_rule(Dim.lit(0) + a, a, 'simpl#add_0_l', a_not_literal)
  yield build_dim_rule(Dim.lit(i1) + Dim.lit(i2), Dim.lit(i1 + i2), 'simpl#c-fold')

In [288]:
@rw_ruleset.register
def zx_fn_rules():
  """Give `.n` and `.m` a value for every ZX constructor.

  For each constructor there are three rules. One forces `.n` / `.m` terms to
  exist for every matching term. The other two rewrite those terms to their
  values, guarded by the constructor's side conditions (e.g. compose needs
  matching inner dimensions).
  """
  a, b, x = vars_("a b x", ZX)
  s = var("s", String)
  n, m, alpha = vars_("n m alpha", Dim)
  one = register_const(1)
  # (constructor term, its n, its m, side conditions for the n/m rewrites)
  constructors = [
    (a.cast(n, m), n, m, []),
    (a.compose(b), a.n, b.m, [eq(a.m).to(b.n)]),
    (a.stack(b), a.n + b.n, a.m + b.m, []),
    (ZX.Z(n, m, alpha), n, m, []),
    (ZX.X(n, m, alpha), n, m, []),
    (ZX.symbol(n=n, m=m, symbol=s), n, m, []),
    (a.nStack1(n), n, n, [eq(a.n).to(one), eq(a.m).to(one)]),
    (ZX.nWire(n), n, n, []),
  ]
  for term, term_n, term_m, side_conds in constructors:
    # Materialise x.n and x.m so the rewrites below have something to match.
    yield rule(eq(x).to(term)).then(x.n, x.m)
    yield rewrite(term.n).to(term_n, *side_conds)
    yield rewrite(term.m).to(term_m, *side_conds)

In [289]:
# Detect compositions whose inner dimensions disagree (a.m != b.n).
# These rules are registered on proof_ruleset, which only exists under proof_tree.
assert (not dim_errors) or proof_tree, "dim_errors rules are registered on proof_ruleset, which requires proof_tree"
if dim_errors:
  if debug:
    # Record the offending pair and both dimensions so they can be inspected
    # (e.g. with `print-function dim_err`).
    dim_err = relation('dim_err', ZX, ZX, Dim, Dim)
    @proof_ruleset.register
    def dim_err_rule():
      a, b, x = vars_("a b x", ZX)
      yield rule(eq(x).to(a.compose(b)), ne(a.m).to(b.n)).then(dim_err(a, b, a.m, b.n))
  else:
    # Outside debug mode, an illegal composition aborts the run.
    @proof_ruleset.register
    def dim_err_panic():
      # These pattern variables were previously never defined (NameError).
      a, b, x = vars_("a b x", ZX)
      yield rule(eq(x).to(a.compose(b)), ne(a.m).to(b.n)).then(panic('Found illegal composition '))


In [290]:

def proof_rules(edge, path, Type, ProofType):
  """Build rules that derive `path(x, z)` proof terms from labelled edges.

  A single edge x -s-> y becomes a path proved by `ProofType.refl(y).trans(x, s)`.
  Longer paths are formed by extending existing paths with edges.

  Returns:
    A list of three rules: single edge, edge prepended to a path, and edge
    appended to a path.
  """
  x, y, z = vars_("x y z", Type)
  s = var("s", String)
  p = var("p", ProofType)

  # Every edge x -s-> y is itself a path.
  single_edge = rule(eq(s).to(edge(x, y))).then(
    set_(path(x, y)).to(ProofType.refl(y).trans(x, s)))

  # Edge x -s-> y followed by an existing path y ~> z, skipping the x == z loop.
  prepend_edge = rule(
    eq(s).to(edge(x, y)),
    eq(p).to(path(y, z)),
    ne(x).to(z),  # This might break (ne is dangerous)
  ).then(set_(path(x, z)).to(p.trans(x, s)))

  # Existing path x ~> y followed by edge y -s-> z.
  # NOTE(review): this reuses `p.trans(x, s)`, the same shape as prepend_edge,
  # although here `s` labels the y -> z edge, not a step out of x. The proof
  # term may therefore not describe the path. Confirm the intended encoding.
  append_edge = rule(
    eq(s).to(edge(y, z)),
    eq(p).to(path(x, y)),
  ).then(set_(path(x, z)).to(p.trans(x, s)))

  return [single_edge, prepend_edge, append_edge]

def rule_for_bin_op(binop, path, edge, Type, ProofType, EdgeType=String):
  """Congruence rules: lift paths on the children of a binary term to the term.

  For a term `a = binop(x1, x2)`, a path `x1 ~> y1` and/or `x2 ~> y2` yields a
  path from `a` to the rebuilt term. It is proved by
  `ProofType.join(a, left, right)`, where an unchanged child contributes
  `ProofType.refl(child)`.

  Args:
    binop: callable `(lhs, rhs) -> Type` that builds the binary term.
    path: path function called as `path(src, dst)`, e.g. `zx_path`.
    edge: unused; kept so existing positional calls still line up.
    Type: egglog sort of the terms (e.g. `ZX` or `Dim`).
    ProofType: egglog proof sort providing `refl` and `join`.
    EdgeType: unused; kept for backward compatibility.

  Returns:
    Three rules: both children moved, only the left child, only the right child.
  """
  a, x1, x2, y1, y2 = vars_("a x1 x2 y1 y2", Type)
  p1, p2 = vars_("p1 p2", ProofType)
  base_eq = eq(a).to(binop(x1, x2))
  path1 = eq(p1).to(path(x1, y1))
  path2 = eq(p2).to(path(x2, y2))
  # No extra "child has no edge" guards are added. This relies on all edges
  # being created (rw_ruleset) before any paths are derived (proof_ruleset).
  return [
    rule(base_eq, path1, path2)
      .then(set_(path(a, binop(y1, y2))).to(ProofType.join(a, p1, p2))),
    rule(base_eq, path1)
      .then(set_(path(a, binop(y1, x2))).to(ProofType.join(a, p1, ProofType.refl(x2)))),
    rule(base_eq, path2)
      .then(set_(path(a, binop(x1, y2))).to(ProofType.join(a, ProofType.refl(x1), p2))),
  ]

if dim_proof_tree:
  @proof_ruleset.register
  def dim_proof_rules():
    """Path rules turning dim_edges into dim_paths."""
    yield from proof_rules(dim_edge, dim_path, Dim, DimProof)

  @proof_ruleset.register
  def dim_binop_rules():
    """Lift dim_paths through `+` (congruence)."""
    yield from rule_for_bin_op(lambda lhs, rhs: lhs + rhs, dim_path, dim_edge, Dim, DimProof)
    
if proof_tree:
  @proof_ruleset.register
  def zx_binop_rules():
    """Lift zx_paths through `stack` and `compose` (congruence)."""
    for op in (lambda lhs, rhs: lhs.stack(rhs), lambda lhs, rhs: lhs.compose(rhs)):
      yield from rule_for_bin_op(op, zx_path, zx_edge, ZX, ZXProof)

  @proof_ruleset.register
  def zx_proof_rules():
    """Path rules turning zx_edges into zx_paths."""
    yield from proof_rules(zx_edge, zx_path, ZX, ZXProof)

  # TODO (disabled): lift dim_paths through `cast`. Needs ZXProof.dim_proof
  # (not declared yet) and dim_path (dim_proof_tree only).
  # @egraph.register
  # def zx_cast_rules():
  #   zx = var("zx", ZX)
  #   n1, n2, m1, m2 = vars_("n1 n2 m1 m2", Dim)
  #   p1, p2 = vars_("p1 p2", DimProof)
  #   path1 = eq(p1).to(dim_path(n1, m1))
  #   path2 = eq(p2).to(dim_path(n2, m2))
  #   base_eq = eq(matchvar(ZX)).to(zx.cast(n1, n2))    
  #   yield rule(base_eq, path1, path2) \
  #           .then(set_(zx_path(zx, zx.cast(m1, m2))).to(ZXProof.dim_proof(p1, p2)))
  #   yield rule(base_eq, path1) \
  #            .then(set_(zx_path(zx, zx.cast(m1, n2))).to(ZXProof.dim_proof(p1, DimProof.refl(n2))))
  #   yield rule(base_eq, path2) \
  #             .then(set_(zx_path(zx, zx.cast(n1, m2))).to(ZXProof.dim_proof(DimProof.refl(n1), p2)))
    

    

In [291]:
@dataclass(frozen=True)
class EvalDim:
  """Plain-Python (non-egglog) description of a dimension.

  Frozen, so instances are hashable (gen_constraints groups parameters by them).
  `a + b` builds an EvalAdd node; nothing is evaluated.
  """
  def __add__(self, other : EvalDim):
    return EvalAdd(a=self, b=other)

@dataclass(frozen=True)
class EvalSymbol(EvalDim):
  """A named dimension such as EvalSymbol('n')."""
  s : str

@dataclass(frozen=True)
class EvalInt(EvalDim):
  """An integer-literal dimension."""
  i : int

@dataclass(frozen=True)
class EvalAdd(EvalDim):
  """The sum `a + b` of two dimensions."""
  a : EvalDim
  b : EvalDim

class ZXParam:
  """A named ZX pattern variable with symbolic (EvalDim) dimensions.

  `n` and `m` only describe which parameter dimensions must agree (see
  gen_constraints). Equality and hashing use the name alone.
  """
  name : str
  n : EvalDim 
  m : EvalDim
  
  def __init__(self, name : str, n : EvalDim, m : EvalDim) -> None:
    self.name = name
    self.n = n
    self.m = m
    
  def __eq__(self, other : object) -> bool:
    # Defer to Python's default handling for foreign types instead of raising
    # AttributeError on `other.name`.
    if not isinstance(other, ZXParam):
      return NotImplemented
    return self.name == other.name
  
  def __hash__(self) -> int:
    return hash(self.name)
  
  def get_var(self) -> ZX:
    """Return the egglog ZX variable named after this parameter."""
    return var(self.name, ZX)
  
class ParamConstrType(Enum):
  """Selects which dimension property (`n` or `m`) of a ZX term is meant."""
  n = 0
  m = 1

def flatten_constr_type_to_property(zx : ZX, typ : ParamConstrType):
  """Return `zx.n` for ParamConstrType.n and `zx.m` for any other value."""
  if typ == ParamConstrType.n:
    return zx.n
  return zx.m

In [292]:
def gen_constraints(params : list[ZXParam], *constraints : FactLike) -> Iterator[Fact]:
  """Yield egglog facts forcing parameter dimensions that share an EvalDim to be equal.

  Every group of (parameter, n|m) pairs declared with the same EvalDim is tied
  together by equating each member with the group's first member. The extra
  `constraints` are yielded unchanged afterwards.
  """
  dim_to_param : dict[EvalDim, set[tuple[ZXParam, ParamConstrType]]] = defaultdict(set)
  for param in params:
    dim_to_param[param.n].add((param, ParamConstrType.n))
    dim_to_param[param.m].add((param, ParamConstrType.m))
  for dimparams in dim_to_param.values():
    # Set iteration follows hash order, and ZXParam hashes by its name string,
    # whose hash is randomized per process. Sort so the chosen base member and
    # the generated egglog program are reproducible across runs.
    listparams = sorted(dimparams, key=lambda pt: (pt[0].name, pt[1].value))
    if len(listparams) > 1:
      base_zx, base_typ = listparams[0]
      for other_zx, other_typ in listparams[1:]:
        yield eq(flatten_constr_type_to_property(base_zx.get_var(), base_typ)) \
              .to(flatten_constr_type_to_property(other_zx.get_var(), other_typ))

  yield from constraints
        
        

In [293]:
def gen_rule(fromExpr : ZX, toExpr : ZX, name : str, *constraints : FactLike):
  """Turn the ZX rewrite `fromExpr -> toExpr` (guarded by `constraints`) into an egglog rule.

  With proof_tree, the rule does not union the two sides. Instead it records a
  `zx_edge(fromExpr, toExpr)` labelled `name`, which the proof rules turn into
  zx_paths. Without proof_tree, a plain egglog `rewrite` is returned.
  """
  # Previously tested `proof_rules` (a function, hence always truthy), so the
  # edge branch ran even without proof_tree, where zx_edge is undefined.
  if proof_tree:
    x = matchvar(ZX)
    return rule(eq(x).to(fromExpr), *constraints) \
            .then(set_(zx_edge(fromExpr, toExpr)).to(name))
  return rewrite(fromExpr).to(toExpr, *constraints)

In [294]:
def gen_zx_rule(params : list[ZXParam], fromExpr : ZX, toExpr : ZX, name : str, *constraints : FactLike):
  """Like gen_rule, with the dimension-agreement facts for `params` placed
  before the extra `constraints`."""
  all_facts = (*gen_constraints(params, *constraints),)
  return gen_rule(fromExpr, toExpr, name, *all_facts)

In [295]:
@rw_ruleset.register
def cast_rules():
  """Cast simplification: drop casts that match the term's own dimensions,
  and collapse a cast of a cast into a single cast."""
  zx = ZXParam('zx', EvalSymbol('n'), EvalSymbol('m')).get_var()
  n, m = vars_("n m", Dim)
  o, p = vars_("o p", Dim)
  # Ways of knowing zx.n matches n (and zx.m matches m): e-class equality,
  # and, with dim proofs, a dim_path in either direction.
  n_facts = [eq(zx.n).to(n)]
  m_facts = [eq(zx.m).to(m)]
  if dim_proof_tree:
    p1, p2 = vars_("p1 p2", DimProof)
    n_facts.extend([eq(p1).to(dim_path(zx.n, n)), eq(p1).to(dim_path(n, zx.n))])
    m_facts.extend([eq(p2).to(dim_path(zx.m, m)), eq(p2).to(dim_path(m, zx.m))])
  # One cast_id rule per (n-fact, m-fact) combination.
  yield from (
    gen_rule(zx.cast(n, m), zx, f'cast_id#{i}{j}', n_fact, m_fact)
    for i, n_fact in enumerate(n_facts)
    for j, m_fact in enumerate(m_facts)
  )
  yield gen_rule(zx.cast(n, m).cast(o, p), zx.cast(o, p), 'cast_contract')

In [296]:
@rw_ruleset.register
def nwire_rules():
  """Remove an nWire next to a composition partner with matching dimensions.

  nwire_removal_r: zx.compose(nWire(m)) -> zx, when m == zx.m
  nwire_removal_l: nWire(n).compose(zx) -> zx, when n == zx.n
  """
  zx = ZXParam('zx', EvalSymbol('n'), EvalSymbol('m')).get_var()
  n, m = vars_("n m", Dim)
  yield gen_rule(zx.compose(ZX.nWire(m)), zx, 'nwire_removal_r', eq(m).to(zx.m))
  # Was also labelled 'nwire_removal_r', making the two rules indistinguishable
  # in zx_edge labels and proofs.
  yield gen_rule(ZX.nWire(n).compose(zx), zx, 'nwire_removal_l', eq(n).to(zx.n))
  

In [297]:
@rw_ruleset.register
def gen_compose_assoc():
  """Re-associate composition:
  zx0.compose(zx1.compose(zx2)) -> zx0.compose(zx1).compose(zx2).

  The shared dimension symbols make gen_constraints require
  zx0.m == zx1.n and zx1.m == zx2.n.
  """
  params : list[ZXParam] = [
    ZXParam('zx0', EvalSymbol('n'), EvalSymbol('m')),
    ZXParam('zx1', EvalSymbol('m'), EvalSymbol('o')),
    ZXParam('zx2', EvalSymbol('o'), EvalSymbol('p')),
  ]
  zx0, zx1, zx2 = (param.get_var() for param in params)
  right_nested = zx0.compose(zx1.compose(zx2))
  left_nested = zx0.compose(zx1).compose(zx2)
  yield gen_zx_rule(params, right_nested, left_nested, 'ComposeRules.compose_assoc')
  
@rw_ruleset.register
def gen_stack_assoc():
  """Re-associate `stack` in both directions.

  stack_assoc:      zx0.stack(zx1.stack(zx2)) -> zx0.stack(zx1).stack(zx2)
  stack_assoc_back: zx0.stack(zx1).stack(zx2) -> zx0.stack(zx1.stack(zx2))

  Each result is cast back to the dimensions of the matched term (the same
  sums, bracketed like the matched term). The re-associated term's own
  dimensions are those sums bracketed the other way.
  """
  params : list[ZXParam] = [
    ZXParam('zx0', EvalSymbol('n0'), EvalSymbol('m0')),
    ZXParam('zx1', EvalSymbol('n1'), EvalSymbol('m1')),
    ZXParam('zx2', EvalSymbol('n2'), EvalSymbol('m2')),
  ]
  zx0, zx1, zx2 = (param.get_var() for param in params)
  # NOTE(review): the two rules feed each other, each adding a cast. This
  # presumably stays bounded thanks to cast_contract / cast_id and the fixed
  # iteration count passed to egraph.run; confirm.
  yield gen_zx_rule(
    params,
    zx0.stack(zx1.stack(zx2)),
    zx0.stack(zx1).stack(zx2).cast(zx0.n + (zx1.n + zx2.n), zx0.m + (zx1.m + zx2.m)),
    'StackRules.stack_assoc',
  )
  yield gen_zx_rule(
    params,
    zx0.stack(zx1).stack(zx2),
    zx0.stack(zx1.stack(zx2)).cast((zx0.n + zx1.n) + zx2.n, (zx0.m + zx1.m) + zx2.m),
    'StackRules.stack_assoc_back',
  )
  
@rw_ruleset.register
def gen_stack_compose_distr():
  """Swap stacking and composition:
  zx0.stack(zx1).compose(zx2.stack(zx3)) -> zx0.compose(zx2).stack(zx1.compose(zx3)).

  The shared dimension symbols make gen_constraints require
  zx0.m == zx2.n and zx1.m == zx3.n.
  """
  params : list[ZXParam] = [
    ZXParam('zx0', EvalSymbol('n0'), EvalSymbol('m0')),
    ZXParam('zx1', EvalSymbol('n1'), EvalSymbol('m1')),
    ZXParam('zx2', EvalSymbol('m0'), EvalSymbol('o0')),
    ZXParam('zx3', EvalSymbol('m1'), EvalSymbol('o1')),
  ]
  zx0, zx1, zx2, zx3 = (param.get_var() for param in params)
  stacked_then_composed = zx0.stack(zx1).compose(zx2.stack(zx3))
  composed_then_stacked = zx0.compose(zx2).stack(zx1.compose(zx3))
  yield gen_zx_rule(params, stacked_then_composed, composed_then_stacked, 'stack_compose_distr')


In [298]:
# Inspect the registered rewrite rules (shown as this cell's output).
rw_ruleset.rules

[Rewrite(ruleset=None, _lhs=Dim.lit(0) + a, _rhs=a, _conditions=(ExprFact(_expr=ne(a).to(Dim.lit(i1))),)),
 Rewrite(ruleset=None, _lhs=Dim.lit(i1) + Dim.lit(i2), _rhs=Dim.lit(i1 + i2), _conditions=()),
 Rule(head=(ExprAction(_expr=x.n), ExprAction(_expr=x.m)), body=(Eq(_exprs=[x, a.cast(n, m)]),), name='', ruleset=None),
 Rewrite(ruleset=None, _lhs=a.cast(n, m).n, _rhs=n, _conditions=()),
 Rewrite(ruleset=None, _lhs=a.cast(n, m).m, _rhs=m, _conditions=()),
 Rule(head=(ExprAction(_expr=x.n), ExprAction(_expr=x.m)), body=(Eq(_exprs=[x, a.compose(b)]),), name='', ruleset=None),
 Rewrite(ruleset=None, _lhs=a.compose(b).n, _rhs=a.n, _conditions=(Eq(_exprs=[a.m, b.n]),)),
 Rewrite(ruleset=None, _lhs=a.compose(b).m, _rhs=b.m, _conditions=(Eq(_exprs=[a.m, b.n]),)),
 Rule(head=(ExprAction(_expr=x.n), ExprAction(_expr=x.m)), body=(Eq(_exprs=[x, a.stack(b)]),), name='', ruleset=None),
 Rewrite(ruleset=None, _lhs=a.stack(b).n, _rhs=a.n + b.n, _conditions=()),
 Rewrite(ruleset=None, _lhs=a.stack(b)

In [299]:
# Inspect the registered proof / dimension-error rules (shown as this cell's output).
proof_ruleset.rules



[Rule(head=(ExprAction(_expr=dim_err(a, b, a.m, b.n)),), body=(Eq(_exprs=[x, a.compose(b)]), ExprFact(_expr=ne(a.m).to(b.n))), name='', ruleset=None),
 Rule(head=(Set(_call=zx_path(a, y1.stack(y2)), _rhs=ZXProof.join(a, p1, p2)),), body=(Eq(_exprs=[a, x1.stack(x2)]), Eq(_exprs=[p1, zx_path(x1, y1)]), Eq(_exprs=[p2, zx_path(x2, y2)])), name='', ruleset=None),
 Rule(head=(Set(_call=zx_path(a, y1.stack(x2)), _rhs=ZXProof.join(a, p1, ZXProof.refl(x2))),), body=(Eq(_exprs=[a, x1.stack(x2)]), Eq(_exprs=[p1, zx_path(x1, y1)])), name='', ruleset=None),
 Rule(head=(Set(_call=zx_path(a, x1.stack(y2)), _rhs=ZXProof.join(a, ZXProof.refl(x1), p2)),), body=(Eq(_exprs=[a, x1.stack(x2)]), Eq(_exprs=[p2, zx_path(x2, y2)])), name='', ruleset=None),
 Rule(head=(Set(_call=zx_path(a, y1.compose(y2)), _rhs=ZXProof.join(a, p1, p2)),), body=(Eq(_exprs=[a, x1.compose(x2)]), Eq(_exprs=[p1, zx_path(x1, y1)]), Eq(_exprs=[p2, zx_path(x2, y2)])), name='', ruleset=None),
 Rule(head=(Set(_call=zx_path(a, y1.compose(x

In [300]:
# Generators used to build the test terms.
egraph.let("Wire", ZX.symbol('-', register_const(1), register_const(1)))
egraph.let("Cap", ZX.symbol('Cap', register_const(2), register_const(0)))
egraph.let("Cup", ZX.symbol('Cup', register_const(0), register_const(2)))

wire, cap, cup = vars_("Wire Cap Cup", ZX)

# Alternative lhs/rhs pairs (enable one pair at a time):
# egraph.let("lhs", wire.stack(wire.compose(wire).stack(cup.compose(cap))).compose(ZX.nWire(register_const(2))))
# egraph.let("rhs", wire.cast(register_const(1) + register_const(0), register_const(1) + register_const(0)).stack(wire).stack(cup).compose(wire.stack(wire).stack(cap)))
# egraph.let("lhs", wire.stack(wire.stack(wire)))
# egraph.let("rhs", wire.stack(wire).stack(wire))
egraph.let("lhs", wire.compose(wire).compose(wire))
egraph.let("rhs", wire.compose(wire.compose(wire)))
# egraph.let("lhs", cup.cast(register_const(0), register_const(1) + register_const(1)))
# egraph.let("rhs", cup)

# Phase 1: run the rewrite ruleset (up to 25 iterations).
egraph.run(25, ruleset=rw_ruleset)

# Phase 2 (proof mode only): `goals`, `proof_ruleset` and the zx_path rules are
# only declared when proof_tree is set, so skip this phase otherwise.
proof_report = None
if proof_tree:
  @egraph.register
  def create_goals():
    lhs, rhs = vars_("lhs rhs", ZX)
    yield goals(lhs, rhs)

  proof_report = egraph.run(5, ruleset=proof_ruleset)
# Last expression: displays the proof-phase RunReport (nothing when skipped).
proof_report


RunReport(True, {'(rule ((= a (ZX_compose x1 x2))\n       (= p2 (zx_path x2 y2))\n       (= v5___ (ZXProof_join a (ZXProof_refl x1) p2)))\n      ((set (zx_path a (ZX_compose x1 y2)) v5___))\n         )': datetime.timedelta(0), '(rule ((= s (zx_edge x y))\n       (= p (zx_path y z))\n       (!= x z)\n       (= v7___ (ZXProof_trans p x s)))\n      ((set (zx_path x z) v7___))\n         )': datetime.timedelta(0), '(rule ((= a (ZX_stack x1 x2))\n       (= p1 (zx_path x1 y1))\n       (= v1___ (ZXProof_join a p1 (ZXProof_refl x2))))\n      ((set (zx_path a (ZX_stack y1 x2)) v1___))\n         )': datetime.timedelta(0), '(rule ((= s (zx_edge x y))\n       (= p (zx_path y z))\n       (!= x z))\n      ((set (zx_path x z) (ZXProof_trans p x s)))\n         )': datetime.timedelta(0), '(rule ((= a (ZX_compose x1 x2))\n       (= p2 (zx_path x2 y2)))\n      ((set (zx_path a (ZX_compose x1 y2)) (ZXProof_join a (ZXProof_refl x1) p2)))\n         )': datetime.timedelta(0), '(rule ((= s (zx_edge y z))\n    

In [301]:
# Extra egglog commands appended to the dumped program (zxgen.egg, debug only).
# Only reference functions that were actually declared under the current flags:
# zx_edge / zx_path need proof_tree, and dim_err needs dim_errors and debug.
test_commands = []
if proof_tree:
  test_commands.append("(print-function zx_edge 100)")
if dim_errors and debug:
  test_commands.append("(print-function dim_err 100)")
if proof_tree:
  test_commands.append("(extract (zx_path lhs rhs))")
# Same layout as before: a leading newline, then one command per line.
test_code = "\n" + "".join(f"{command}\n" for command in test_commands)

In [302]:
# Write the recorded egglog program followed by test_code to zxgen.egg.
# The program text is only recorded when debug is set (save_egglog_string).
with open('zxgen.egg', 'w') as out_file:
  if debug:
    contents = egraph.as_egglog_string + test_code
  else:
    contents = "Only on debug"
  out_file.write(contents)