In [341]:
from egglog import *
from egglog.egraph import ActionLike, FactLike
from egglog.bindings import Eq
from __future__ import annotations


In [342]:
# Passed to EGraph(save_egglog_string=...) and checked again before the
# generated egglog program is written to 'zxgen.egg'.
debug = True
# Gates the proof-tracking declarations: proof classes, edge/path relations
# and the proof rules. build_dim_rule also switches form on this flag.
proof_tree = True
# Gates declaration of the `dim_err` relation and the rule that populates it.
dim_errors = True

In [343]:
# Single e-graph shared by every cell below; `save_egglog_string` follows `debug`
# so the program text is only retained when it will be dumped.
egraph = EGraph(save_egglog_string=debug)

In [344]:
class Dim(Expr):
    # Expression sort for dimensions. Every method is a body-less declaration;
    # meaning is given only by the rules registered later in the notebook.

    # Literal built from an i64-like value (the dim rules fold `lit(i1) + lit(i2)`
    # into `lit(i1 + i2)` and drop a leading `lit(0) +`).
    @classmethod
    def lit(cls, value: i64Like) -> Dim: ...

    # Dimension identified by a string name -- presumably a symbolic variable; confirm.
    @classmethod
    def var(cls, name: StringLike) -> Dim:  ...

    # `+` on two Dim terms; used for stacked diagram sizes and by the dim rules.
    def __add__(self, other: Dim) -> Dim: ...

    # `*` on two Dim terms; no rule in the visible notebook refers to it.
    def __mul__(self, other: Dim) -> Dim: ... 
    
  
# Proof terms for Dim paths; only declared when proof tracking is enabled.
if proof_tree:
  class DimProof(Expr):
    # Proof consisting of the single term `dim`; the proof rules start every
    # chain from `refl(y)` where `y` is the end of an edge.
    @classmethod
    def refl(cls, dim : Dim) -> DimProof:
      ...
    
    # Extends this proof with the term `dim` and the rule label `name`
    # (the proof rules pass the edge's source term and its name string).
    def trans(self, dim : Dim, name : StringLike) -> DimProof:
      ...
      
    # Combines two operand proofs for the binary term `dim`
    # (`prfl` for the first operand, `prfr` for the second).
    @classmethod
    def join(cls, dim : Dim, prfl : DimProof, prfr : DimProof) -> DimProof:
      ...

In [345]:

# Cache of let-bound literal dimensions, keyed by their integer value.
# (Previously annotated with the builtin function `any`, which is not a type.)
egraph_consts : dict[int, Dim] = dict()

def register_const(const : int) -> Dim:
  """Return a variable bound to ``Dim.lit(const)``, creating the binding on first use.

  The literal is bound in the e-graph with ``egraph.let`` under a name derived
  from ``const`` and the resulting variable is memoised in ``egraph_consts``,
  so each distinct integer is registered at most once.

  Args:
    const: integer value of the literal dimension.

  Returns:
    The ``Dim`` variable that refers to the let-bound literal.
  """
  if const in egraph_consts:
    return egraph_consts[const]
  name = f"___________int_const__{const}"
  egraph.let(name, Dim.lit(const))
  # Refer to the binding by name rather than by the value returned from `let`.
  val = var(name, Dim)
  egraph_consts[const] = val
  return val

In [346]:
class ZX(Expr):
    # Expression sort for diagrams. Every method is a body-less declaration; the
    # `n` / `m` values quoted below are the ones assigned by the zx_fn_rules cell.

    # Named diagram with explicitly given n and m.
    @classmethod
    def symbol(cls, symbol : StringLike, n : Dim, m : Dim) -> ZX:
      ...
    
    # Same diagram with n and m replaced by the given dimensions.
    def cast(self, n : Dim, m : Dim) -> ZX:
      ...
    
    # Composition: n is self.n, m is other.m; defined only when self.m equals other.n.
    def compose(self, other : ZX) -> ZX:
      ...
      
    # Stacking: n is self.n + other.n, m is self.m + other.m.
    def stack(self, other : ZX) -> ZX:
      ...
      
    # Z node with the given n and m.
    # NOTE(review): `alpha` is typed as Dim -- presumably a phase parameter; confirm.
    @classmethod
    def Z(cls, n : Dim, m : Dim, alpha : Dim) -> ZX:
      ...
      
    # X node with the given n and m; same signature as Z.
    @classmethod
    def X(cls, n : Dim, m : Dim, alpha : Dim) -> ZX:
      ...
     
    # TODO
    # def nStack(self, n : Dim) -> ZX:
    #   ...
      
    # n-indexed stack of a diagram whose own n and m both equal 1; result has n = m = `n`.
    def nStack1(self, n : Dim) -> ZX:
      ...
    
    # Diagram with n = m = `n`.
    @classmethod
    def nWire(cls, n: Dim) -> ZX:
      ...
      
    # First dimension of the diagram (rewritten per constructor by zx_fn_rules).
    @property
    def n(self) -> Dim:
      ...
      
    # Second dimension of the diagram (rewritten per constructor by zx_fn_rules).
    @property
    def m(self) -> Dim:
      ...

# Proof terms for ZX paths; mirrors DimProof and is only declared with proof tracking on.
if proof_tree:   
  class ZXProof(Expr):
    # Proof consisting of the single diagram `zx`; starting point of every chain.
    @classmethod
    def refl(cls, zx : ZX) -> ZXProof:
      ...
    
    # Extends this proof with the diagram `zx` and the rule label `name`.
    def trans(self, zx : ZX, name : StringLike) -> ZXProof:
      ...
      
    # Combines two operand proofs for the binary diagram `zx` (stack or compose).
    @classmethod
    def join(cls, zx : ZX, prfl : ZXProof, prfr : ZXProof) -> ZXProof:
      ...
    
    # Wraps two Dim proofs as a ZX proof -- presumably one each for n and m; confirm.
    # Not referenced by any rule in the visible notebook.
    @classmethod
    def dim_proof(cls, prfn : DimProof, prfm : DimProof) -> ZXProof:
      ...
    


In [347]:
if proof_tree:
  # Single rewrite steps: (source term, target term, rule name).
  # `dim_edge` facts are produced by build_dim_rule; nothing visible produces `zx_edge`.
  # NOTE(review): `StringLike` is the argument-type alias, whereas rule variables use
  # `String`; presumably it resolves to the String sort here -- confirm.
  dim_edge = relation('dim_edge', Dim, Dim, StringLike)
  zx_edge = relation('zx_edge', ZX, ZX, StringLike)
  # Multi-step paths: (source term, target term, proof term, i64 counter).
  # The proof rules set the counter to 1 for one edge and add 1 per further step.
  dim_path = relation('dim_path', Dim, Dim, DimProof, i64)
  zx_path = relation('zx_path', ZX, ZX, ZXProof, i64)
  # Pairs of ZX terms; no rule in the visible notebook reads or writes this relation.
  goals = relation('goals', ZX, ZX)
  

In [348]:
def build_dim_rule(match : Dim, to : Dim, name : str, *conditions : FactLike):
  """Build one Dim simplification, in proof-tracking or plain-rewrite form.

  With ``proof_tree`` enabled the result is a rule that records a
  ``dim_edge(match, to, name)`` fact whenever ``match`` is present and all
  ``conditions`` hold. Otherwise it is an ordinary conditional rewrite of
  ``match`` into ``to``.

  Args:
    match: pattern to look for.
    to: replacement term.
    name: label stored on the edge; ignored in plain-rewrite mode.
    *conditions: extra facts that must hold for the rule to fire.

  Returns:
    The egglog rule or rewrite command.
  """
  if proof_tree:
    # Equate the pattern with a throwaway variable so that its presence in the
    # e-graph can be used as a rule premise.
    x = var('tobematched', Dim)
    conds = (eq(x).to(match),) + conditions
    return rule(*conds).then(dim_edge(match, to, name))
  return rewrite(match).to(to, *conditions)

In [349]:
@egraph.register
def dim_rules():
  """Register the Dim simplifications: drop a left zero and fold literal sums."""
  dim = var('a', Dim)
  left_val, right_val = vars_("i1 i2", i64)

  zero_plus_dim = Dim.lit(0) + dim
  yield build_dim_rule(zero_plus_dim, dim, 'simpl#add_0_l')

  literal_sum = Dim.lit(left_val) + Dim.lit(right_val)
  folded = Dim.lit(left_val + right_val)
  yield build_dim_rule(literal_sum, folded, 'simpl#c-fold')

In [350]:
@egraph.register
def zx_fn_rules():
  """Register, for every ZX constructor, the rules that define its `n` and `m`.

  Each constructor contributes three commands: a rule that demands `n` and `m`
  for any matching term, and one conditional rewrite for each of the two.
  """
  lhs, rhs, any_term = vars_("a b x", ZX)
  label = var("s", String)
  n, m, alpha = vars_("n m alpha", Dim)
  one = register_const(1)

  # Rows are (pattern, value of .n, value of .m, side conditions).
  shape_table = [
    (lhs.cast(n, m), n, m, []),
    (lhs.compose(rhs), lhs.n, rhs.m, [eq(lhs.m).to(rhs.n)]),
    (lhs.stack(rhs), lhs.n + rhs.n, lhs.m + rhs.m, []),
    (ZX.Z(n, m, alpha), n, m, []),
    (ZX.X(n, m, alpha), n, m, []),
    (ZX.symbol(n=n, m=m, symbol=label), n, m, []),
    (lhs.nStack1(n), n, n, [eq(lhs.n).to(one), eq(lhs.m).to(one)]),
    (ZX.nWire(n), n, n, []),
  ]

  for pattern, n_value, m_value, side_conds in shape_table:
    # Make sure both projections exist for every matching term.
    yield rule(eq(any_term).to(pattern)).then(any_term.n, any_term.m)
    # Definition of n, then of m, each under the row's side conditions.
    yield rewrite(pattern.n).to(n_value, *side_conds)
    yield rewrite(pattern.m).to(m_value, *side_conds)

In [351]:
if dim_errors:
  dim_err = relation('dim_err', ZX)

  @egraph.register
  def dim_err_rule():
    """Mark every composition whose left `m` differs from its right `n`."""
    lhs, rhs, whole = vars_("a b x", ZX)
    composed = lhs.compose(rhs)
    is_present = eq(whole).to(composed)
    dims_differ = ne(lhs.m).to(rhs.n)
    yield rule(is_present, dims_differ).then(dim_err(composed))

In [352]:
if proof_tree:
  def proof_rules(edge, path, Type, ProofType):
    """Build the two rules that turn `edge` facts into `path` facts.

    The first rule makes a length-1 path out of every edge; the second
    prepends an edge to an existing path and adds 1 to its counter.
    """
    src, mid, dst = vars_("x y z", Type)
    label = var("s", String)
    tail_proof = var("p", ProofType)
    tail_len = var("d", i64)

    single_step = rule(edge(src, mid, label)).then(
      path(src, mid, ProofType.refl(mid).trans(src, label), 1)
    )
    prepend_step = rule(edge(src, mid, label), path(mid, dst, tail_proof, tail_len)).then(
      path(src, dst, tail_proof.trans(src, label), 1 + tail_len)
    )
    return [single_step, prepend_step]
  
  def rule_for_bin_op(binop, path, Type, ProofType):
    """Build the rules that lift operand paths through a binary operator.

    For a term ``a = binop(x1, x2)`` three rules are produced: one for when
    both operands have a path, and one each for when only the first or only
    the second has one. Every rule adds a ``path`` fact from ``a`` to the
    rebuilt term, with a ``ProofType.join`` proof and a counter of 1 plus the
    counters of the operand paths used.

    Args:
      binop: callable building the binary term from two expressions.
      path: path relation for ``Type``.
      Type: expression sort of the operands and the result.
      ProofType: proof sort providing ``join`` and ``refl``.

    Returns:
      List of the three rules, in the order described above.
    """
    a, x1, x2, y1, y2 = vars_("a x1 x2 y1 y2", Type)
    p1, p2 = vars_("p1 p2", ProofType)
    d1, d2 = vars_("d1 d2", i64)
    base_eq = eq(a).to(binop(x1, x2))
    path1 = path(x1, y1, p1, d1)
    path2 = path(x2, y2, p2, d2)
    return [
      # Both operands rewritten. The target must combine y1 AND y2; it used to
      # be binop(y2, y2), which dropped the first operand's rewrite.
      rule(base_eq, path1, path2) \
        .then(path(a, binop(y1, y2), ProofType.join(a, p1, p2), (1 + (d1 + d2)))),
      # Only the first operand rewritten; the second keeps a reflexive proof.
      rule(base_eq, path1) \
        .then(path(a, binop(y1, x2), ProofType.join(a, p1, ProofType.refl(x2)), 1 + d1)),
      # Only the second operand rewritten; the first keeps a reflexive proof.
      rule(base_eq, path2) \
        .then(path(a, binop(x1, y2), ProofType.join(a, ProofType.refl(x1), p2), 1 + d2)),
    ]
  
  @egraph.register
  def dim_proof_rules():
    """Register the edge-to-path rules for Dim terms."""
    yield from proof_rules(dim_edge, dim_path, Dim, DimProof)
  
  @egraph.register
  def zx_proof_rules():
    """Register the edge-to-path rules for ZX terms."""
    yield from proof_rules(zx_edge, zx_path, ZX, ZXProof)

  @egraph.register
  def dim_binop_rules():
    """Register the operand-lifting rules for Dim addition."""
    def dim_add(left, right):
      return left + right

    yield from rule_for_bin_op(dim_add, dim_path, Dim, DimProof)
      
  @egraph.register
  def zx_binop_rules():
    """Register the operand-lifting rules for ZX stack, then ZX compose."""
    def stack_op(left, right):
      return left.stack(right)

    def compose_op(left, right):
      return left.compose(right)

    for zx_op in (stack_op, compose_op):
      yield from rule_for_bin_op(zx_op, zx_path, ZX, ZXProof)
    

In [353]:
# Dump the recorded egglog program; outside debug mode only a placeholder is written.
with open('zxgen.egg', 'w') as egg_file:
  if debug:
    egg_text = egraph.as_egglog_string
  else:
    egg_text = "Only on debug"
  egg_file.write(egg_text)