In [928]:
from __future__ import annotations
from egglog import *
from egglog.egraph import ActionLike, FactLike
from egglog.bindings import Eq
from collections import defaultdict
from enum import Enum
from dataclasses import dataclass
from typing import Iterator

In [929]:
# Notebook feature flags:
#   debug      - keep the generated egglog program text and dump it to zxgen.egg
#   proof_tree - record rewrite edges/paths with proof terms instead of plain rewrites
#   dim_errors - flag compositions whose inner dimensions differ (dim_err relation)
debug, proof_tree, dim_errors = True, True, True

In [930]:
# Shared e-graph that every @egraph.register cell below adds its rules to.
# With `debug` set it also records the egglog program text, which the last
# cell reads via `egraph.as_egglog_string` and writes to zxgen.egg.
egraph = EGraph(save_egglog_string=debug)

In [931]:
class Dim(Expr):
    # Symbolic integer-valued egglog sort. Used for the `n`/`m` dimensions of
    # ZX terms and also as the `alpha` argument of ZX.Z / ZX.X.

    # Integer literal, e.g. Dim.lit(0) in the `simpl#add_0_l` rule.
    @classmethod
    def lit(cls, value: i64Like) -> Dim: ...

    # Named symbolic dimension.
    @classmethod
    def var(cls, name: StringLike) -> Dim:  ...

    # Sum of two dimensions (e.g. the dims of `stack`).
    def __add__(self, other: Dim) -> Dim: ...

    # Product of two dimensions (no rules for it in this part of the notebook).
    def __mul__(self, other: Dim) -> Dim: ...

    
  
if proof_tree:
  # Proof terms explaining how one Dim was rewritten into another; they are
  # stored in the `dim_path` relation.
  class DimProof(Expr):
    # Trivial proof that `dim` rewrites to itself.
    @classmethod
    def refl(cls, dim : Dim) -> DimProof:
      ...
    
    # Extend this proof with one step taken from `dim` by the rule `name`
    # (used as `refl(target).trans(source, rule_name)` in the path rules).
    def trans(self, dim : Dim, name : StringLike) -> DimProof:
      ...
      
    # Proof for a binary-operator term `dim`, built from the proofs of its
    # left (`prfl`) and right (`prfr`) operands.
    @classmethod
    def join(cls, dim : Dim, prfl : DimProof, prfr : DimProof) -> DimProof:
      ...

In [932]:

# Cache: integer constant -> Dim expression naming its global `let` binding.
# NOTE: the cache is not tied to a particular EGraph. If the cell that creates
# `egraph` is re-run without re-running this cell, stale entries remain.
egraph_consts : dict[int, Dim] = {}

def register_const(const : int) -> Dim:
  """Return a Dim expression that refers to a global egglog binding of `const`.

  The first call for a given integer adds `let <name> = Dim.lit(const)` to the
  global `egraph` and caches the result. Later calls return the cached expression,
  so each constant is bound only once.

  Args:
    const: the integer value to bind.

  Returns:
    A Dim expression referring to the binding by name, usable inside rules.
  """
  if const in egraph_consts:
    return egraph_consts[const]
  name = f"___________int_const__{const}"
  egraph.let(name, Dim.lit(const))
  # Refer to the binding through a plain `var` of the same name so it can be used
  # in rule queries. NOTE(review): this relies on egglog resolving the bare name
  # to the global `let` binding rather than to a fresh pattern variable; confirm.
  val = var(name, Dim)
  egraph_consts[const] = val
  return val

In [933]:
class ZX(Expr):
    # egglog sort for (presumably ZX-calculus) diagram terms. Every term has an
    # input dimension `n` and an output dimension `m`. Their values are not stored
    # on the term; they are defined by rewrites over the `n`/`m` functions.

    # Opaque named diagram with dims (n, m).
    @classmethod
    def symbol(cls, symbol : StringLike, n : Dim, m : Dim) -> ZX:
      ...
    
    # The same diagram re-typed with dims (n, m). The `cast_id` rule removes the
    # cast when the diagram already has exactly those dims.
    def cast(self, n : Dim, m : Dim) -> ZX:
      ...
    
    # Sequential composition: dims (self.n, other.m), well-formed only when
    # self.m equals other.n (mismatches are flagged in `dim_err` when enabled).
    def compose(self, other : ZX) -> ZX:
      ...
      
    # Parallel composition: dims (self.n + other.n, self.m + other.m).
    def stack(self, other : ZX) -> ZX:
      ...
      
    # Presumably the Z spider: dims (n, m) with phase `alpha`.
    @classmethod
    def Z(cls, n : Dim, m : Dim, alpha : Dim) -> ZX:
      ...
      
    # Presumably the X spider: dims (n, m) with phase `alpha`.
    @classmethod
    def X(cls, n : Dim, m : Dim, alpha : Dim) -> ZX:
      ...
     
    # TODO
    # def nStack(self, n : Dim) -> ZX:
    #   ...
      
    # `n` parallel copies of a 1 -> 1 diagram; dims (n, n). The dimension rules
    # only apply when self.n and self.m are both 1.
    def nStack1(self, n : Dim) -> ZX:
      ...
    
    # Presumably `n` parallel identity wires; dims (n, n).
    @classmethod
    def nWire(cls, n: Dim) -> ZX:
      ...
      
    # Input dimension of the diagram.
    @property
    def n(self) -> Dim:
      ...
      
    # Output dimension of the diagram.
    @property
    def m(self) -> Dim:
      ...

if proof_tree:
  # Proof terms explaining how one ZX term was rewritten into another; they
  # are stored in the `zx_path` relation.
  class ZXProof(Expr):
    # Trivial proof that `zx` rewrites to itself.
    @classmethod
    def refl(cls, zx : ZX) -> ZXProof:
      ...
    
    # Extend this proof with one step taken from `zx` by the rule `name`.
    def trans(self, zx : ZX, name : StringLike) -> ZXProof:
      ...
      
    # Proof for a binary-operator term `zx` (stack / compose), built from the
    # proofs of its left (`prfl`) and right (`prfr`) operands.
    @classmethod
    def join(cls, zx : ZX, prfl : ZXProof, prfr : ZXProof) -> ZXProof:
      ...
    
    # Proof that a cast was rewritten by rewriting its two dimension arguments
    # (`prfn` for the first, `prfm` for the second).
    @classmethod
    def dim_proof(cls, prfn : DimProof, prfm : DimProof) -> ZXProof:
      ...
    


In [934]:
if proof_tree:
  # One recorded rewrite step: (source, target, rule name).
  dim_edge = relation('dim_edge', Dim, Dim, StringLike)
  zx_edge = relation('zx_edge', ZX, ZX, StringLike)
  # A chain of steps: (source, target, proof term, length counter). The counter
  # adds 1 per edge, plus 1 per lifting through an operator or a cast.
  dim_path = relation('dim_path', Dim, Dim, DimProof, i64)
  zx_path = relation('zx_path', ZX, ZX, ZXProof, i64)
  # NOTE(review): not referenced in this part of the notebook; presumably holds
  # pairs of terms whose equivalence should be proved.
  goals = relation('goals', ZX, ZX)
  

In [935]:
def matchvar(Type):
  """Return the shared pattern variable `tobematched` of sort `Type`.

  Rules bind the whole matched term to this variable via `eq(matchvar(T)).to(...)`.
  """
  shared_name = 'tobematched'
  return var(shared_name, Type)


In [936]:
def build_dim_rule(match : Dim, to : Dim, name : str, *conditions : FactLike):
  """Build the Dim simplification `match -> to`, named `name`.

  With `proof_tree` enabled, nothing is unioned. The rule only records the fact
  `dim_edge(match, to, name)`, and the proof-path rules later chain these edges
  into `dim_path` facts. Otherwise it is a plain egglog rewrite.

  Args:
    match: pattern to rewrite.
    to: replacement term.
    name: rule name recorded on the edge.
    *conditions: extra facts that must hold for the rule to fire.

  Returns:
    An egglog rule (proof mode) or rewrite.
  """
  if proof_tree:
    # Bind the whole matched term to the shared match variable so the pattern
    # becomes a query fact.
    conds = (eq(matchvar(Dim)).to(match), *conditions)
    return rule(*conds).then(dim_edge(match, to, name))
  return rewrite(match).to(to, *conditions)

In [937]:
@egraph.register
def dim_rules():
  """Basic Dim simplifications: left additive identity and constant folding."""
  a = var('a', Dim)
  i1, i2 = var('i1', i64), var('i2', i64)
  zero = Dim.lit(0)
  yield from (
    build_dim_rule(zero + a, a, 'simpl#add_0_l'),
    build_dim_rule(Dim.lit(i1) + Dim.lit(i2), Dim.lit(i1 + i2), 'simpl#c-fold'),
  )

In [938]:
@egraph.register
def zx_fn_rules():
  """Define the `n` / `m` dimension functions for every ZX constructor.

  Each constructor contributes three items:
    * a rule that materializes `x.n` and `x.m` for any matched term `x`;
    * a rewrite giving that constructor's `n`;
    * a rewrite giving its `m`.
  Both rewrites are guarded by the constructor's side conditions.
  """
  a, b, x = vars_("a b x", ZX)
  s = var("s", String)
  n, m, alpha = vars_("n m alpha", Dim)
  one = register_const(1)
  # (constructor term, its n, its m, side conditions)
  signatures = (
    (a.cast(n, m), n, m, ()),
    (a.compose(b), a.n, b.m, (eq(a.m).to(b.n),)),
    (a.stack(b), a.n + b.n, a.m + b.m, ()),
    (ZX.Z(n, m, alpha), n, m, ()),
    (ZX.X(n, m, alpha), n, m, ()),
    (ZX.symbol(n=n, m=m, symbol=s), n, m, ()),
    (a.nStack1(n), n, n, (eq(a.n).to(one), eq(a.m).to(one))),
    (ZX.nWire(n), n, n, ()),
  )
  for term, term_n, term_m, side_conds in signatures:
    # force n and m to be calculated
    yield rule(eq(x).to(term)).then(x.n, x.m)
    yield rewrite(term.n).to(term_n, *side_conds)
    yield rewrite(term.m).to(term_m, *side_conds)

In [939]:
if dim_errors:
  # Marks `compose` terms whose inner dimensions disagree.
  dim_err = relation('dim_err', ZX)

  @egraph.register
  def dim_err_rule():
    """Add `a.compose(b)` to `dim_err` when `a.m` and `b.n` are not equal."""
    lhs, rhs, whole = vars_("a b x", ZX)
    composed = lhs.compose(rhs)
    # NOTE(review): `ne` compares the current e-classes, so this may fire before
    # the Dim rewrites have made two equal dimensions congruent. Confirm that
    # transient false positives are acceptable.
    yield rule(eq(whole).to(composed), ne(lhs.m).to(rhs.n)).then(dim_err(composed))

In [940]:
if proof_tree:
  def proof_rules(edge, path, Type, ProofType):
    """Rules that chain single rewrite steps (`edge`) into paths (`path`).

    * A lone edge `src -> mid` becomes a length-1 path. Its proof is
      `refl(mid)` extended by `trans` with the step's source and rule name.
    * An edge `src -> mid` followed by a path `mid -> dst` becomes a path
      `src -> dst` whose length is one more than the tail's.

    NOTE(review): the length is part of the relation tuple, so a cycle of edges
    yields an unbounded family of paths. Confirm that runs are iteration-limited.
    """
    src, mid, dst = vars_("x y z", Type)
    label = var("s", String)
    tail_proof = var("p", ProofType)
    tail_len = var("d", i64)
    single_step = rule(edge(src, mid, label)).then(
      path(src, mid, ProofType.refl(mid).trans(src, label), 1))
    extend = rule(edge(src, mid, label), path(mid, dst, tail_proof, tail_len)).then(
      path(src, dst, tail_proof.trans(src, label), 1 + tail_len))
    return [single_step, extend]
  
  def rule_for_bin_op(binop, path, Type, ProofType):
    """Lift rewrite paths of the operands of `binop` to the whole term.

    For a term `a = binop(x1, x2)` with paths `x1 ~> y1` (p1, d1) and/or
    `x2 ~> y2` (p2, d2), three rules are produced:
      * both operands rewritten: a ~> binop(y1, y2);
      * left only: a ~> binop(y1, x2);
      * right only: a ~> binop(x1, y2).
    The unchanged side is justified with `ProofType.refl`. Each lifted path's
    length is one more than the sum of the operand path lengths used.

    Args:
      binop: callable building the operator term from two operands.
      path: the path relation for `Type`.
      Type: egglog sort of the operands and result.
      ProofType: proof sort providing `refl` and `join`.

    Returns:
      A list of three egglog rules.
    """
    a, x1, x2, y1, y2 = vars_("a x1 x2 y1 y2", Type)
    p1, p2 = vars_("p1 p2", ProofType)
    d1, d2 = vars_("d1 d2", i64)
    base_eq = eq(a).to(binop(x1, x2))
    path1 = path(x1, y1, p1, d1)
    path2 = path(x2, y2, p2, d2)
    return [
      # Both sides rewritten: the result must combine y1 (left) with y2 (right).
      rule(base_eq, path1, path2)
        .then(path(a, binop(y1, y2), ProofType.join(a, p1, p2), (1 + (d1 + d2)))),
      rule(base_eq, path1)
        .then(path(a, binop(y1, x2), ProofType.join(a, p1, ProofType.refl(x2)), 1 + d1)),
      rule(base_eq, path2)
        .then(path(a, binop(x1, y2), ProofType.join(a, ProofType.refl(x1), p2), 1 + d2)),
    ]
  
  @egraph.register
  def dim_proof_rules():
    """Register the edge-to-path chaining rules for Dim rewrites."""
    # `yield from` avoids a loop variable named `rule`, which would shadow
    # egglog's `rule` function inside this function.
    yield from proof_rules(dim_edge, dim_path, Dim, DimProof)
  
  @egraph.register
  def zx_proof_rules():
    """Register the edge-to-path chaining rules for ZX rewrites."""
    # `yield from` avoids a loop variable named `rule`, which would shadow
    # egglog's `rule` function inside this function.
    yield from proof_rules(zx_edge, zx_path, ZX, ZXProof)

  @egraph.register
  def dim_binop_rules():
    """Lift Dim rewrite paths through `+`."""
    # `yield from` avoids a loop variable named `rule`, which would shadow
    # egglog's `rule` function inside this function.
    yield from rule_for_bin_op(lambda x, y  : x + y, dim_path, Dim, DimProof)
      
  @egraph.register
  def zx_binop_rules():
    """Lift ZX rewrite paths through `stack` and `compose`."""
    binops = [lambda x, y  : x.stack(y), lambda x, y  : x.compose(y)]
    for op in binops:
      # `yield from` avoids a loop variable named `rule`, which would shadow
      # egglog's `rule` function inside this function.
      yield from rule_for_bin_op(op, zx_path, ZX, ZXProof)
    
  @egraph.register
  def zx_cast_rules():
    """Lift Dim rewrite paths on a cast's dimension arguments to ZX paths.

    For a matched term `zx.cast(n1, n2)` with Dim paths `n1 ~> m1` and/or
    `n2 ~> m2`, record a ZX path from that cast term to the re-cast term. The
    proof is a `ZXProof.dim_proof`, with `DimProof.refl` for an unchanged side.
    """
    src = matchvar(ZX)
    zx = var("zx", ZX)
    n1, n2, m1, m2 = vars_("n1 n2 m1 m2", Dim)
    p1, p2 = vars_("p1 p2", DimProof)
    d1, d2 = vars_("d1 d2", i64)
    path1 = dim_path(n1, m1, p1, d1)
    path2 = dim_path(n2, m2, p2, d2)
    base_eq = eq(src).to(zx.cast(n1, n2))
    # The path starts at the matched cast term `zx.cast(n1, n2)` (bound to `src`),
    # not at the uncast inner diagram `zx`, which has different dimensions.
    yield rule(base_eq, path1, path2).then(zx_path(src, zx.cast(m1, m2), ZXProof.dim_proof(p1, p2), (1 + (d1 + d2))))
    yield rule(base_eq, path1).then(zx_path(src, zx.cast(m1, n2), ZXProof.dim_proof(p1, DimProof.refl(n2)), 1 + d1))
    yield rule(base_eq, path2).then(zx_path(src, zx.cast(n1, m2), ZXProof.dim_proof(DimProof.refl(n1), p2), 1 + d2))

In [941]:
@dataclass(frozen=True)
class EvalDim:
  """Base of the Python-side symbolic dimension terms; `a + b` builds an EvalAdd."""

  def __add__(self, other : EvalDim):
    summed = EvalAdd(self, other)
    return summed
  
@dataclass(frozen=True)
class EvalSymbol(EvalDim):
  """A named symbolic dimension, e.g. EvalSymbol('n')."""
  s : str  # the dimension's name

@dataclass(frozen=True)
class EvalInt(EvalDim):
  """An integer-constant dimension (not constructed in the visible cells)."""
  i : int  # the constant value
    
@dataclass(frozen=True)
class EvalAdd(EvalDim):
  """Sum of two dimensions; produced by `EvalDim.__add__`."""
  a : EvalDim  # left summand
  b : EvalDim  # right summand

class ZXParam:
  """A named ZX pattern variable together with its symbolic dimensions.

  `n` and `m` are the EvalDim values for the variable's `.n` / `.m`. Params that
  share an EvalDim are tied together by equality constraints when rules are
  generated. Equality and hashing use only `name`, so two params with the same
  name count as the same pattern variable regardless of their dims.
  """
  name : str
  n : EvalDim  # symbolic dimension of the variable's `.n`
  m : EvalDim  # symbolic dimension of the variable's `.m`

  def __init__(self, name : str, n : EvalDim, m : EvalDim) -> None:
    self.name = name
    self.n = n
    self.m = m

  def __eq__(self, other : object) -> bool:
    # Defer to Python's fallback for foreign types instead of raising
    # AttributeError on `other.name`.
    if not isinstance(other, ZXParam):
      return NotImplemented
    return self.name == other.name

  def __hash__(self) -> int:
    return hash(self.name)

  def get_var(self) -> ZX:
    """Return the egglog ZX pattern variable named after this param."""
    return var(self.name, ZX)
  
class ParamConstrType(Enum):
  """Which dimension of a ZXParam a constraint refers to: its `n` or its `m`."""
  n = 0
  m = 1

def flatten_constr_type_to_property(zx : ZX, typ : ParamConstrType):
  """Select `zx.n` when `typ` is ParamConstrType.n; any other value selects `zx.m`."""
  if typ == ParamConstrType.n:
    return zx.n
  return zx.m

In [942]:
def gen_constraints(params : list[ZXParam]) -> Iterator[Fact]:
  """Yield equality facts tying together params that share a symbolic dimension.

  Each param contributes its `n` dim and its `m` dim. For every EvalDim that
  occurs more than once, the first occurrence is equated with each later one.
  Occurrences are ordered by position in `params`, with n before m. For example,
  zx0 (n, m) and zx1 (m, o) yield `eq(zx0.m).to(zx1.n)`.

  Args:
    params: the pattern variables of a rule, with their symbolic dims.

  Yields:
    `eq(...).to(...)` facts over the variables' `.n` / `.m`.
  """
  dim_to_param : dict[EvalDim, list[tuple[ZXParam, ParamConstrType]]] = defaultdict(list)

  for param in params:
    for dim, typ in ((param.n, ParamConstrType.n), (param.m, ParamConstrType.m)):
      entry = (param, typ)
      # Use insertion-ordered lists, not sets: a set's iteration order depends on
      # string hashing, which changes between interpreter runs. Duplicates are
      # still skipped.
      if entry not in dim_to_param[dim]:
        dim_to_param[dim].append(entry)
  for dimparams in dim_to_param.values():
    if len(dimparams) > 1:
      base_param, base_typ = dimparams[0]
      base = flatten_constr_type_to_property(base_param.get_var(), base_typ)
      for other_param, other_typ in dimparams[1:]:
        yield eq(base).to(flatten_constr_type_to_property(other_param.get_var(), other_typ))
        
        

In [943]:
def gen_rule(fromExpr : ZX, toExpr : ZX, name : str, *constraints : FactLike):
  """Build the ZX rule `fromExpr -> toExpr`, named `name` and guarded by `constraints`.

  With `proof_tree` enabled, the rule only records `zx_edge(fromExpr, toExpr, name)`
  and the proof-path rules build on these edges. Otherwise it is a plain rewrite.

  Args:
    fromExpr: pattern to rewrite.
    toExpr: replacement term.
    name: rule name recorded on the edge.
    *constraints: facts that must hold for the rule to fire.

  Returns:
    An egglog rule (proof mode) or rewrite.
  """
  # Branch on the `proof_tree` flag itself: `zx_edge` only exists when it is set.
  if proof_tree:
    x = matchvar(ZX)
    return rule(eq(x).to(fromExpr), *constraints).then(zx_edge(fromExpr, toExpr, name))
  return rewrite(fromExpr).to(toExpr, *constraints)

In [944]:
def gen_zx_rule(params : list[ZXParam], fromExpr : ZX, toExpr : ZX, name : str):
  """Build `fromExpr -> toExpr` via gen_rule, guarded by the dim constraints of `params`."""
  return gen_rule(fromExpr, toExpr, name, *gen_constraints(params))

In [945]:
@egraph.register
def cast_rules():
  """`cast_id`: zx.cast(n, m) -> zx when zx already has dims (n, m)."""
  zx = ZXParam('zx', EvalSymbol('n'), EvalSymbol('m')).get_var()
  n, m = vars_("n m", Dim)
  dims_match = (eq(zx.n).to(n), eq(zx.m).to(m))
  yield gen_rule(zx.cast(n, m), zx, 'cast_id', *dims_match)

In [947]:
@egraph.register
def gen_compose_assoc():
  """`compose_assoc`: zx0 ; (zx1 ; zx2)  ->  (zx0 ; zx1) ; zx2.

  zx0/zx1 share dim `m` and zx1/zx2 share dim `o`; gen_zx_rule turns these
  shared dims into equality guards.
  """
  params : list[ZXParam] = [
    ZXParam('zx0', EvalSymbol('n'), EvalSymbol('m')),
    ZXParam('zx1', EvalSymbol('m'), EvalSymbol('o')),
    ZXParam('zx2', EvalSymbol('o'), EvalSymbol('p')),
  ]
  zx0, zx1, zx2 = [param.get_var() for param in params]
  right_nested = zx0.compose(zx1.compose(zx2))
  left_nested = zx0.compose(zx1).compose(zx2)
  yield gen_zx_rule(params, right_nested, left_nested, 'ComposeRules.compose_assoc')
  
@egraph.register
def gen_stack_assoc():
  """Stack associativity in both directions.

  Re-associating changes the shape of the dimension sums, so the result is
  wrapped in a cast back to the dims of the original term:
    * forward:  zx0 stack (zx1 stack zx2) -> ((zx0 stack zx1) stack zx2)
                cast to (n0 + (n1 + n2), m0 + (m1 + m2));
    * backward: (zx0 stack zx1) stack zx2 -> (zx0 stack (zx1 stack zx2))
                cast to ((n0 + n1) + n2, (m0 + m1) + m2).
  All dims are distinct, so no equality guards are generated.
  """
  params : list[ZXParam] = [
    ZXParam('zx0', EvalSymbol('n0'), EvalSymbol('m0')),
    ZXParam('zx1', EvalSymbol('n1'), EvalSymbol('m1')),
    ZXParam('zx2', EvalSymbol('n2'), EvalSymbol('m2')),
  ]
  zx0, zx1, zx2 = [param.get_var() for param in params]

  # The output dim sums the three `.m` values (previously zx2.n was used by mistake).
  yield gen_zx_rule(params,
           zx0.stack(zx1.stack(zx2)),
           zx0.stack(zx1).stack(zx2).cast(
             zx0.n + (zx1.n + zx2.n),
             zx0.m + (zx1.m + zx2.m),
           ),
           'StackRules.stack_assoc')

  yield gen_zx_rule(params,
           zx0.stack(zx1).stack(zx2),
           zx0.stack(zx1.stack(zx2)).cast(
             (zx0.n + zx1.n) + zx2.n,
             (zx0.m + zx1.m) + zx2.m,
           ),
           'StackRules.stack_assoc_back')
  
@egraph.register
def gen_stack_compose_distr():
  """`stack_compose_distr`:
  (zx0 stack zx1) compose (zx2 stack zx3) -> (zx0 compose zx2) stack (zx1 compose zx3).

  zx2's `n` dim is zx0's `m` dim (m0) and zx3's `n` dim is zx1's `m` dim (m1);
  gen_zx_rule turns these shared dims into equality guards.
  """
  params : list[ZXParam] = [
    ZXParam('zx0', EvalSymbol('n0'), EvalSymbol('m0')),
    ZXParam('zx1', EvalSymbol('n1'), EvalSymbol('m1')),
    ZXParam('zx2', EvalSymbol('m0'), EvalSymbol('o0')),
    ZXParam('zx3', EvalSymbol('m1'), EvalSymbol('o1')),
  ]
  zx0, zx1, zx2, zx3 = [param.get_var() for param in params]
  stacked_then_composed = zx0.stack(zx1).compose(zx2.stack(zx3))
  composed_then_stacked = zx0.compose(zx2).stack(zx1.compose(zx3))
  yield gen_zx_rule(params, stacked_then_composed, composed_then_stacked,
                    'stack_compose_distr')

In [None]:
with open('zxgen.egg', 'w') as egg_file:
  # The egglog program text is only recorded when `debug` is set.
  if debug:
    egg_file.write(egraph.as_egglog_string)
  else:
    egg_file.write("Only on debug")