$\newcommand{\To}{\Rightarrow}$
$\newcommand{\false}{\mathrm{false}}$

In [1]:
import os, sys
sys.path.append(os.path.split(os.getcwd())[0])

In [2]:
from kernel.type import TFun
from kernel.term import Term, Var
from logic import basic
from logic import matcher
from logic.proofterm import ProofTerm
from logic.conv import Conv, ConvException
from logic.nat import natT, plus, times, to_binary, add_conv
from syntax import printer

thy = basic.load_theory('nat')

## Conversions

With the macro system in place, we begin in earnest the study of proof automation. We start with rewriting: using equalities to transform a term into an equivalent term. The central concept in the automation of rewriting is the *conversion*. A conversion is a function that takes a term $t$ as input and returns a theorem of the form $t = t'$. In some sense, a conversion can be viewed as a macro whose argument is a single term and which has no input sequents. However, we single out conversions as a separate concept because they compose well.

In Python, a conversion is created by inheriting from the class `Conv`. A conversion class needs to implement a `get_proof_term` function, which takes as inputs the current theory and a term, and (if the inputs are valid) returns a proof term rewriting the input term.
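
To make the interface concrete without depending on the library, here is a minimal toy sketch. It is not the library's `Conv` class: terms are nested tuples, an "equation" is an unchecked `(lhs, rhs)` pair rather than a theorem, and the names `ToyConv`, `ToyConvException`, `get_equation` and `AddZeroConv` are all made up for this illustration.

```python
# Toy model only: terms are nested tuples such as ("+", "a", 0), and an
# "equation" is a plain (lhs, rhs) pair, not a kernel-checked theorem.

class ToyConvException(Exception):
    """Raised when a toy conversion cannot act on its input."""


class ToyConv:
    """Hypothetical stand-in for a conversion base class."""

    def get_equation(self, t):
        raise NotImplementedError


class AddZeroConv(ToyConv):
    """Rewrite t + 0 to t, and fail on any other input."""

    def get_equation(self, t):
        if isinstance(t, tuple) and len(t) == 3 and t[0] == "+" and t[2] == 0:
            return (t, t[1])
        raise ToyConvException


lhs, rhs = AddZeroConv().get_equation(("+", "a", 0))
print(lhs, "=", rhs)
```

The property worth noting is that the left side of the returned equation is always the input term itself. This is what lets conversions be chained later on.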

We consider the most basic example: rewriting using a theorem. We implement this as a class `rewr_conv_test`. Again, this is named to avoid conflicts with the actual `rewr_conv` class. The constructor for `rewr_conv_test` takes the name of the theorem, and an optional argument `sym` specifying whether the rewriting is performed left-to-right (`sym = False`) or right-to-left (`sym = True`).

In [3]:
class rewr_conv_test(Conv):
    def __init__(self, th_name, sym=False):
        self.th_name = th_name
        self.sym = sym
        
    def get_proof_term(self, thy, t):
        pt = ProofTerm.theorem(thy, self.th_name)
        if self.sym:
            pt = ProofTerm.symmetric(pt)
        try:
            tyinst, inst = matcher.first_order_match(pt.prop.lhs, t)
        except matcher.MatchException:
            raise ConvException
        if tyinst:
            pt = ProofTerm.subst_type(tyinst, pt)
        if inst:
            pt = ProofTerm.substitution(inst, pt)
        return pt

If the matching fails, the conversion raises `ConvException`. This is a standard exception that signals that the conversion is unable to act on the input. We test this conversion on a simple example. First, we create a conversion using theorem `add_assoc`:

In [4]:
print(printer.print_thm(thy, thy.get_theorem('add_assoc')))
assoc_cv = rewr_conv_test('add_assoc')

|- x + y + z = x + (y + z)


We apply this conversion to a new term:

In [5]:
a = Var("a", natT)
t = plus(plus(a, to_binary(2)), to_binary(3))
print("t:", printer.print_term(thy, t))
pt = assoc_cv.get_proof_term(thy, t)
print("th:", printer.print_thm(thy, pt.th, unicode=True))

t: a + 2 + 3
th: ⊢ a + 2 + 3 = a + (2 + 3)


The resulting proof term can be verified as before.

In [6]:
prf = pt.export()
thy.check_proof(prf)
print(printer.print_proof(thy, prf, unicode=True))

0: ⊢ x + y + z = x + (y + z) by theorem add_assoc
1: ⊢ a + 2 + 3 = a + (2 + 3) by substitution {x: a, y: 2, z: 3} from 0
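
The substitution `{x: a, y: 2, z: 3}` in step 1 is found by first-order matching. The following is a rough sketch of that idea on a toy term representation; it is not the code of `matcher.first_order_match`. Terms are nested tuples, pattern variables are strings starting with `?`, and the helper names `match` and `subst` are assumptions made for this example. Type instantiation is ignored.

```python
def match(pat, t, inst=None):
    """Match pattern `pat` against term `t`.

    Returns a dict mapping pattern variables to terms, or None on failure.
    """
    inst = {} if inst is None else inst
    if isinstance(pat, str) and pat.startswith("?"):
        if pat in inst:
            # A repeated variable must be matched to the same term.
            return inst if inst[pat] == t else None
        inst[pat] = t
        return inst
    if isinstance(pat, tuple) and isinstance(t, tuple) and len(pat) == len(t):
        for p, s in zip(pat, t):
            if match(p, s, inst) is None:
                return None
        return inst
    # Constants and operator symbols must agree exactly.
    return inst if pat == t else None


def subst(inst, t):
    """Replace pattern variables in `t` according to `inst`."""
    if isinstance(t, str) and t in inst:
        return inst[t]
    if isinstance(t, tuple):
        return tuple(subst(inst, s) for s in t)
    return t


# Toy version of add_assoc: (?x + ?y) + ?z = ?x + (?y + ?z)
assoc_lhs = ("+", ("+", "?x", "?y"), "?z")
assoc_rhs = ("+", "?x", ("+", "?y", "?z"))

t = ("+", ("+", "a", 2), 3)
inst = match(assoc_lhs, t)
print(inst)
print(subst(inst, assoc_rhs))
```

When `match` returns `None`, the corresponding conversion would raise its exception instead of returning an equation.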


## Composition of conversions

An important feature of conversions is that they can be composed and modified by combinators: functions (in fact constructors of classes) that take one or more conversions as input and return a new conversion. As a first example, we define combinators that create conversions acting on parts of a term:

In [7]:
class arg_conv_test(Conv):
    def __init__(self, cv):
        self.cv = cv
        
    def get_proof_term(self, thy, t):
        pt = self.cv.get_proof_term(thy, t.arg)
        return ProofTerm.combination(ProofTerm.reflexive(t.fun), pt)

Calling `arg_conv_test(cv)` creates a new conversion that applies `cv` to the argument of a term. Let's test this on a simple example (recall that the argument of a binary operation is the argument on the right).

In [8]:
cv = arg_conv_test(rewr_conv_test('add_comm'))

b = Var("b", natT)
t = plus(a, plus(b, to_binary(2)))
print("t:", printer.print_term(thy, t))
pt = cv.get_proof_term(thy, t)
print("th:", printer.print_thm(thy, pt.th, unicode=True))

t: a + (b + 2)
th: ⊢ a + (b + 2) = a + (2 + b)


Likewise, we can define a conversion combinator that applies the input conversion to the function part of a term:

In [9]:
class fun_conv_test(Conv):
    def __init__(self, cv):
        self.cv = cv
        
    def get_proof_term(self, thy, t):
        pt = self.cv.get_proof_term(thy, t.fun)
        return ProofTerm.combination(pt, ProofTerm.reflexive(t.arg))

With these, we can directly define the conversion combinator for applying a conversion to the left side of a binary operator:

In [10]:
def arg1_conv_test(cv):
    return fun_conv_test(arg_conv_test(cv))

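Note carefully the order of application, which can be tricky at first sight: `fun_conv_test` descends to the function part `plus x` of a term `plus x y`, and `arg_conv_test` then applies `cv` to the argument `x` of that function part. We can test this function as follows: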

In [11]:
cv = arg1_conv_test(rewr_conv_test('add_comm'))

t = plus(plus(a, b), to_binary(2))
print("t:", printer.print_term(thy, t))
pt = cv.get_proof_term(thy, t)
print("th:", printer.print_thm(thy, pt.th, unicode=True))

t: a + b + 2
th: ⊢ a + b + 2 = b + a + 2
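
The order of nesting is easier to see on a small model of curried terms. The sketch below is only an illustration: `App`, `arg_path`, `fun_path` and `swap_args` are hypothetical names, and the functions return just the rewritten term instead of an equality theorem.

```python
from collections import namedtuple

# Toy curried application node: App(fun, arg). Atoms are plain strings.
App = namedtuple("App", ["fun", "arg"])


def plus(x, y):
    # x + y is stored as ((plus x) y).
    return App(App("plus", x), y)


def arg_path(f):
    """Apply f to the argument of an application."""
    return lambda t: App(t.fun, f(t.arg))


def fun_path(f):
    """Apply f to the function part of an application."""
    return lambda t: App(f(t.fun), t.arg)


def swap_args(t):
    # Stand-in for rewriting with add_comm: x + y becomes y + x.
    return plus(t.arg, t.fun.arg)


t = plus(plus("a", "b"), "2")

# The function part of t is (plus (a + b)), and the argument of that
# function part is the left operand a + b.
left = fun_path(arg_path(swap_args))(t)
print(left == plus(plus("b", "a"), "2"))
```

Reading `fun_path(arg_path(f))` from the outside in gives the path taken through the term: first to the function part, then to its argument.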


Another way to combine conversions is to apply them in sequence. This is implemented by the following class, whose constructor takes a sequence of conversions as arguments and returns a conversion that applies them one after another. In the code below, the method `is_reflexive` checks whether a theorem is of the form $t = t$. It is used to skip trivial steps, which keeps the resulting proof term as simple as possible.

In [12]:
class every_conv_test(Conv):
    def __init__(self, *cvs):
        self.cvs = cvs
        
    def get_proof_term(self, thy, t):
        pt = ProofTerm.reflexive(t)
        for cv in self.cvs:
            pt2 = cv.get_proof_term(thy, pt.prop.rhs)
            if not pt2.th.is_reflexive():
                if pt.th.is_reflexive():
                    pt = pt2
                else:
                    pt = ProofTerm.transitive(pt, pt2)
        return pt

We now test this on the following example: given a term of the form $(a + b) + c$, rewrite it to $(a + c) + b$. This operation, called swap, is useful in many normalization procedures. This can be done by a chain of equalities:

$$(a + b) + c = a + (b + c) = a + (c + b) = (a + c) + b.$$

In words, we first rewrite using `add_assoc`, then rewrite on the argument using `add_comm`, and finally rewrite using `add_assoc` in the right-to-left direction. This can be implemented as follows:

In [13]:
swap_cv = every_conv_test(
    rewr_conv_test('add_assoc'),
    arg_conv_test(rewr_conv_test('add_comm')),
    rewr_conv_test('add_assoc', sym=True))

t = plus(plus(a, b), to_binary(2))
print("t:", printer.print_term(thy, t))
pt = swap_cv.get_proof_term(thy, t)
print("th:", printer.print_thm(thy, pt.th, unicode=True))

t: a + b + 2
th: ⊢ a + b + 2 = a + 2 + b


Let's check and print the proof:

In [14]:
prf = pt.export()
thy.check_proof(prf)
print(printer.print_proof(thy, prf, unicode=True))

0: ⊢ x + y + z = x + (y + z) by theorem add_assoc
1: ⊢ a + b + 2 = a + (b + 2) by substitution {x: a, y: b, z: 2} from 0
2: ⊢ plus a = plus a by reflexive plus a
3: ⊢ x + y = y + x by theorem add_comm
4: ⊢ b + 2 = 2 + b by substitution {x: b, y: 2} from 3
5: ⊢ a + (b + 2) = a + (2 + b) by combination from 2, 4
6: ⊢ a + b + 2 = a + (2 + b) by transitive from 1, 5
7: ⊢ x + (y + z) = x + y + z by symmetric from 0
8: ⊢ a + (2 + b) = a + 2 + b by substitution {x: a, y: 2, z: b} from 7
9: ⊢ a + b + 2 = a + 2 + b by transitive from 6, 8
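
The same chain can be replayed on a toy model to show how sequencing relies on transitivity. This is a sketch under simplifying assumptions, not the library code: sums are tuples `("+", l, r)`, an equation is an unchecked `(lhs, rhs)` pair, and all function names are invented for the example.

```python
class ToyConvException(Exception):
    """Raised when a toy conversion cannot act on its input."""


def is_plus(t):
    return isinstance(t, tuple) and len(t) == 3 and t[0] == "+"


def assoc(t):
    """(x + y) + z = x + (y + z)"""
    if is_plus(t) and is_plus(t[1]):
        (_, (_, x, y), z) = t
        return (t, ("+", x, ("+", y, z)))
    raise ToyConvException


def assoc_sym(t):
    """x + (y + z) = (x + y) + z"""
    if is_plus(t) and is_plus(t[2]):
        (_, x, (_, y, z)) = t
        return (t, ("+", ("+", x, y), z))
    raise ToyConvException


def comm(t):
    """x + y = y + x"""
    if is_plus(t):
        return (t, ("+", t[2], t[1]))
    raise ToyConvException


def on_arg(cv):
    """Apply cv to the right operand of a sum."""
    def new_cv(t):
        if not is_plus(t):
            raise ToyConvException
        _, rhs = cv(t[2])
        return (t, ("+", t[1], rhs))
    return new_cv


def every(*cvs):
    """Apply the given conversions in sequence."""
    def new_cv(t):
        eq = (t, t)  # start from the trivial equation t = t
        for cv in cvs:
            lhs, rhs = cv(eq[1])
            assert lhs == eq[1]  # transitivity needs matching middle terms
            eq = (eq[0], rhs)
        return eq
    return new_cv


swap = every(assoc, on_arg(comm), assoc_sym)
eq = swap(("+", ("+", "a", "b"), 2))
print(eq[0], "=", eq[1])
```

Each conversion is applied to the right side of the equation accumulated so far, so the left side of the final equation is still the original input.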


## Rewriting on subterms

One task that we frequently encounter is using an equality to rewrite all subterms of a term. For example, suppose we have obtained $f(a)=b$ in a proof, and wish to use it to rewrite $g(f(a))+f(a)$ to $g(b)+b$. This requires a recursive search over the structure of the term, performing the rewrite wherever possible.

First, we define a new conversion that simply replaces the left side of an equality by the right side, without performing matching. If the input term does not agree with the left side, it raises `ConvException`.

In [15]:
class replace_conv(Conv):
    def __init__(self, pt):
        self.pt = pt
        
    def get_proof_term(self, thy, t):
        if t == self.pt.prop.lhs:
            return self.pt
        else:
            raise ConvException

The conversion combinator `try_conv_test` attempts to apply a conversion. On failure, it returns the trivial equality:

In [16]:
class try_conv_test(Conv):
    def __init__(self, cv):
        self.cv = cv
        
    def get_proof_term(self, thy, t):
        try:
            return self.cv.get_proof_term(thy, t)
        except ConvException:
            return ProofTerm.reflexive(t)

We test both conversions on a simple example:

In [17]:
f = Var("f", TFun(natT, natT))
eq_pt = ProofTerm.assume(Term.mk_equals(f(a), b))
cv = replace_conv(eq_pt)

pt1 = cv.get_proof_term(thy, f(a))
print(printer.print_thm(thy, pt1.th, unicode=True))
pt2 = try_conv_test(cv).get_proof_term(thy, f(b))
print(printer.print_thm(thy, pt2.th, unicode=True))

f a = b ⊢ f a = b
⊢ f b = f b


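The conversion combinator `sub_conv_test` applies a conversion to the immediate subterms of a term, that is, to the function and the argument of a combination. Currently we only handle the combination case; any other term is left unchanged.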

In [18]:
class sub_conv_test(Conv):
    def __init__(self, cv):
        self.cv = cv
        
    def get_proof_term(self, thy, t):
        if t.is_comb():
            return every_conv_test(
                fun_conv_test(self.cv),
                arg_conv_test(self.cv)).get_proof_term(thy, t)
        else:
            return ProofTerm.reflexive(t)

Now, we implement the conversion combinator `top_conv_test` that tries to apply a conversion on all subterms of a term. This corresponds to the `top_conv` combinator in the actual library. The name `top_conv` will be explained later.

In [19]:
class top_conv_test(Conv):
    def __init__(self, cv):
        self.cv = cv
        
    def get_proof_term(self, thy, t):
        return every_conv_test(
            try_conv_test(self.cv),
            sub_conv_test(self)).get_proof_term(thy, t)

The code can be explained as follows: to apply `cv` to all subterms of a term, we first try to apply it to the term itself. If the resulting term is a combination, we then recursively apply the conversion to all subterms of its function and argument.

We now test this function:

In [20]:
cv = top_conv_test(replace_conv(eq_pt))

g = Var("g", TFun(natT, natT))
t = plus(g(f(a)), f(a))
print("t:", printer.print_term(thy, t))
pt = cv.get_proof_term(thy, t)
print("th:", printer.print_thm(thy, pt.th, unicode=True))

t: g (f a) + f a
th: f a = b ⊢ g (f a) + f a = g b + b


There is one subtlety in the implementation of `top_conv_test`: applying the conversion on subterms of $t$ comes *after* applying the conversion on $t$ itself. This means if applying the conversion on $t$ results in a new term where the conversion can still act on some of its subterms, these actions will be performed. We will take advantage of this feature frequently in later sections. For now, we give a simple example illustrating this in action.

Consider rewriting using the distributivity theorem:

In [21]:
distrib_th = thy.get_theorem('distrib_l')
print(printer.print_thm(thy, distrib_th, unicode=True))
distrib_cv = rewr_conv_test("distrib_l")

⊢ x * (y + z) = x * y + x * z


Suppose we want to use this identity to rewrite $a\cdot ((b + c) + d)$ (for clarity, we have inserted parentheses that are usually omitted). After applying `distrib_cv` to the whole term, we get $a\cdot (b + c) + a\cdot d$. Note that the first argument of this term can still be rewritten using the identity, resulting in $(a\cdot b + a\cdot c) + a\cdot d$. This means `top_conv_test` can perform the entire rewrite in a single call:

In [22]:
c = Var("c", natT)
d = Var("d", natT)
t = times(a, plus(plus(b,c),d))
print("t:", printer.print_term(thy, t))
pt = top_conv_test(try_conv_test(distrib_cv)).get_proof_term(thy, t)
print("th:", printer.print_thm(thy, pt.th))

t: a * (b + c + d)
th: |- a * (b + c + d) = a * b + a * c + a * d


However, there are also times when it is better to rewrite on subterms first, then on the whole term. We give an example using `add_conv`. This conversion takes a term of the form $a + b$, where $a$ and $b$ are both constant natural numbers, and evaluates the arithmetic operation.

In [23]:
pt = add_conv().get_proof_term(thy, plus(to_binary(5), to_binary(7)))
print(printer.print_thm(thy, pt.th))

|- 5 + 7 = 12


The conversion is unable to do anything when either side is not a constant:

In [24]:
add_conv().get_proof_term(thy, plus(to_binary(5), plus(to_binary(7), to_binary(3))))

ConvException: 

Now suppose we wish to evaluate an expression like $5 + (7 + 3)$. We need to first evaluate the subterms, then the term itself. This is the opposite of the order used by `top_conv_test`. We can implement this order as follows:

In [25]:
class bottom_conv_test(Conv):
    def __init__(self, cv):
        self.cv = cv
        
    def get_proof_term(self, thy, t):
        return every_conv_test(
            sub_conv_test(self),
            try_conv_test(self.cv)).get_proof_term(thy, t)

In [26]:
t = plus(to_binary(5), plus(to_binary(7), to_binary(3)))
pt = bottom_conv_test(try_conv_test(add_conv())).get_proof_term(thy, t)
print(printer.print_thm(thy, pt.th, unicode=True))

⊢ 5 + (7 + 3) = 15


The conversions `top_conv_test` and `bottom_conv_test` perform similar functionality but in different ways. Both are useful in different situations. Their names are explained by the fact that `top_conv_test` performs rewriting "top-down", while `bottom_conv_test` performs rewriting "bottom-up".
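
The difference between the two traversal orders can be reproduced on a toy rewriter. The sketch below is an illustration only and makes several simplifications: terms are tuples `(op, left, right)`, each rewrite returns just the new term instead of an equality theorem, and the names `try_rw`, `sub_rw`, `top_down` and `bottom_up` are invented here.

```python
class ToyConvException(Exception):
    """Raised when a toy rewrite does not apply."""


def is_op(t, op):
    return isinstance(t, tuple) and len(t) == 3 and t[0] == op


def distrib(t):
    """x * (y + z) becomes x * y + x * z."""
    if is_op(t, "*") and is_op(t[2], "+"):
        (_, x, (_, y, z)) = t
        return ("+", ("*", x, y), ("*", x, z))
    raise ToyConvException


def add_consts(t):
    """m + n becomes their sum when both are integer constants."""
    if is_op(t, "+") and isinstance(t[1], int) and isinstance(t[2], int):
        return t[1] + t[2]
    raise ToyConvException


def try_rw(rw):
    """Apply rw, leaving the term unchanged on failure."""
    def new_rw(t):
        try:
            return rw(t)
        except ToyConvException:
            return t
    return new_rw


def sub_rw(rw):
    """Apply rw to both operands of a compound term."""
    def new_rw(t):
        if isinstance(t, tuple):
            return (t[0], rw(t[1]), rw(t[2]))
        return t
    return new_rw


def top_down(rw):
    def new_rw(t):
        # Rewrite the term itself first, then recurse into the result.
        return sub_rw(new_rw)(try_rw(rw)(t))
    return new_rw


def bottom_up(rw):
    def new_rw(t):
        # Recurse into the operands first, then rewrite the term itself.
        return try_rw(rw)(sub_rw(new_rw)(t))
    return new_rw


t1 = ("*", "a", ("+", ("+", "b", "c"), "d"))
t2 = ("+", 5, ("+", 7, 3))

print(top_down(distrib)(t1))      # fully distributed
print(bottom_up(distrib)(t1))     # inner a * (b + c) is left as is
print(bottom_up(add_consts)(t2))  # fully evaluated
print(top_down(add_consts)(t2))   # outer sum is left unevaluated
```

In this toy model, each order misses work that the other one finds: bottom-up never revisits the subterm created by the top-level distributivity step, while top-down has already passed the outer sum by the time its operand becomes a constant.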

We have now studied implementations of basic conversions. In practice, all of these conversions are already in the library. In the next section, we show how to use the existing API for programming with conversions.