In [1]:
import os
os.chdir('..')

In [2]:
from kernel.type import TFun, NatType
from kernel.term import Var, Nat, Eq
from kernel import term_ord
from kernel.proofterm import refl
from kernel.macro import Macro
from kernel.theory import check_proof, register_macro
from kernel.proofterm import ProofTerm
from kernel.report import ProofReport
from logic import basic
from logic.conv import Conv, rewr_conv, arg_conv, arg1_conv, binop_conv, top_conv
from data import nat
from syntax.settings import settings

basic.load_theory('nat')
settings.unicode = True

## Normalization of polynomials

In this section, we give an extended example of conversions. Our goal is to normalize an expression on natural numbers consisting of addition, multiplication, and constant powers. This is done by "expanding" the expression into a polynomial and collecting like terms. For example, the expression

$$ (a + b) (a + 2b) $$

is normalized to

$$ a^2 + 2b^2 + 3ab. $$

The basic structure of the procedure is as follows: to normalize an expression with addition or multiplication at the root, we first normalize the two sides. Then we only need to consider how to add or multiply two normalized expressions, which significantly reduces the number of cases we need to consider.

We first define the concept of a "normal form". There are two requirements: any expression can be rewritten into a normal form, and if two expressions can be made equal by applying the standard rules of addition and multiplication (commutativity, associativity, and distributivity), then the two expressions have the same normal form.

The normal form is defined as follows. An atomic term is a term that is not a constant, an addition, a multiplication, or a power with a constant exponent. Examples of atomic terms include $x$, $f(x)$, $x^y$, etc. We fix an ordering on the atomic terms. The specific ordering is not important as long as it is fixed and used consistently. We use the function `term_ord.fast_compare` to compare two terms. The `fast_compare` function first compares the sizes of the two terms (according to their abstract syntax trees), and breaks ties using a lexicographic order. It returns one of $-1$, $0$, or $1$. For example:

In [3]:
x = Var('x', NatType)
y = Var('y', NatType)
f = Var('f', TFun(NatType, NatType))
print(term_ord.fast_compare(x, y))
print(term_ord.fast_compare(y, f(x)))

-1
-1
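To make the size-then-lexicographic idea concrete, here is a small stand-alone sketch on a toy term representation (strings for variables, tuples for applications, so `('f', 'x')` stands for `f x`). The names `size`, `tokens` and `toy_compare` are hypothetical; this is not how `term_ord.fast_compare` is actually implemented, it only produces the same kind of ordering, including the two results printed above.

```python
def size(t):
    """Number of nodes in the syntax tree of a toy term.

    A toy term is either a string (a variable or constant) or a tuple of
    toy terms representing an application, e.g. ('f', 'x') for f x.
    """
    if isinstance(t, str):
        return 1
    return 1 + sum(size(s) for s in t)


def tokens(t):
    """Flatten a toy term into a list of strings, used as a tie-breaker."""
    if isinstance(t, str):
        return [t]
    out = ['(']
    for s in t:
        out.extend(tokens(s))
    out.append(')')
    return out


def toy_compare(t1, t2):
    """Compare by size first, then lexicographically; return -1, 0 or 1."""
    k1 = (size(t1), tokens(t1))
    k2 = (size(t2), tokens(t2))
    return (k1 > k2) - (k1 < k2)


print(toy_compare('x', 'y'))          # same size, 'x' < 'y'
print(toy_compare('y', ('f', 'x')))   # 'y' is smaller than f x
```

Comparing sizes first is cheap and already separates most pairs of terms; the lexicographic comparison is only needed to break ties between terms of equal size.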


A monomial is an expression of the form $ca_1^{e_1}a_2^{e_2}\cdots a_k^{e_k}$, where $c$ is a natural number greater than $0$, each $a_i$ is an atomic term, the $a_i$'s are in sorted order (in particular, they are distinct), and each $e_i$ is a constant natural number greater than $0$. We call $c$ the coefficient of the monomial, and $a_1^{e_1}a_2^{e_2}\cdots a_k^{e_k}$ the body of the monomial (considered to be $1$ when $k=0$).

A polynomial is a sum of monomials $m_1+m_2+\cdots+m_l$, where the $m_i$ are sorted according to their bodies (in particular, no two monomials have the same body). The case $l=0$ corresponds to the zero expression.

The overall goal of this section is then to design a conversion that rewrites any expression into a polynomial as defined above.
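The definitions of monomial and polynomial can be stated as executable invariants. The following sketch uses a toy representation that is an assumption for illustration only (not how holpy stores terms): a body is a tuple of `(atom, exponent)` pairs, a monomial is a pair `(coefficient, body)`, and a polynomial is a list of monomials. Bodies are ordered by number of atoms first and then lexicographically, a stand-in for a size-first order such as `fast_compare`; the helper names are hypothetical.

```python
# Toy representation (an assumption for illustration, not holpy's):
#   body       = tuple of (atom, exponent) pairs, atoms as strings
#   monomial   = (coefficient, body), with body == () for a constant
#   polynomial = list of monomials; [] is the zero polynomial


def body_key(body):
    """Order bodies by number of atoms first, then lexicographically.
    This stands in for a size-first ordering such as fast_compare."""
    return (len(body), body)


def is_monomial(m):
    """c > 0, every exponent > 0, and atoms strictly sorted (hence distinct)."""
    c, body = m
    atoms = [a for a, _ in body]
    return (isinstance(c, int) and c > 0
            and all(isinstance(e, int) and e > 0 for _, e in body)
            and all(a1 < a2 for a1, a2 in zip(atoms, atoms[1:])))


def is_polynomial(p):
    """Every element is a monomial and bodies are strictly sorted."""
    keys = [body_key(body) for _, body in p]
    return (all(is_monomial(m) for m in p)
            and all(k1 < k2 for k1, k2 in zip(keys, keys[1:])))


# a^2 + 2b^2 + 3ab, the normal form from the introduction
print(is_polynomial([(1, (('a', 2),)), (2, (('b', 2),)), (3, (('a', 1), ('b', 1)))]))
```

Requiring strict order on both atoms and bodies is what makes the form canonical: there is exactly one way to arrange a given collection of monomials.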

We begin with the task of multiplying two monomials in normal form, and in particular with keeping their atoms in sorted order. First, we consider the case without coefficients and exponents, so our goal is to normalize an expression like $(a\cdot b)\cdot (a\cdot c)$ into $a\cdot a\cdot b\cdot c$. For this, we first implement the multiplication of a monomial by a single atom.

In [4]:
class swap_mult_r(Conv):
    """Rewrite (a * b) * c to (a * c) * b."""
    def get_proof_term(self, t):
        return refl(t).on_rhs(
            rewr_conv('mult_assoc'),  # a * (b * c)
            arg_conv(rewr_conv('mult_comm')),  # a * (c * b)
            rewr_conv('mult_assoc', sym=True))  # (a * c) * b

In [5]:
class norm_mult_atom(Conv):
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg1.is_times():  # t is of the form (a * b) * c
            cp = term_ord.fast_compare(t.arg1.arg, t.arg)  # compare b with c
            if cp > 0:  # if b > c, need to swap b with c
                return pt.on_rhs(
                    swap_mult_r(),   # (a * c) * b
                    arg1_conv(self))   # possibly move c further inside a
            else:  # if b <= c, atoms are already ordered, since a * b is assumed sorted
                return pt
        else:  # t is of the form a * b
            cp = term_ord.fast_compare(t.arg1, t.arg)
            if cp > 0:  # if a > b, need to swap a and b
                return pt.on_rhs(rewr_conv('mult_comm'))
            else:
                return pt

We test this function on some examples:

In [6]:
def test_conv(cv, ts):
    for t in ts:
        print(refl(t).on_rhs(cv).prop)

test_conv(norm_mult_atom(), [
    (x * y * f(x)) * x,
    (x * y * f(x)) * y,
    y * x
])

x * y * f x * x = x * x * y * f x
x * y * f x * y = x * y * y * f x
y * x = x * y


The general case is then quite simple:

In [7]:
class norm_mult_monomial(Conv):
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg.is_times():  # t is of form a * (b * c)
            return pt.on_rhs(
                rewr_conv('mult_assoc', sym=True),  # (a * b) * c
                arg1_conv(self),  # merge terms in b into a
                norm_mult_atom())  # merge c into a * b
        else:
            return pt.on_rhs(norm_mult_atom())

This can be tested as follows:

In [8]:
test_conv(norm_mult_monomial(), [
    (x * y * f(x)) * (x * y * f(x)),
    (2 * x) * (3 * y)
])

x * y * f x * (x * y * f x) = x * x * y * y * f x * f x
2 * x * (3 * y) = 2 * x * y * 3
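The strategy behind `norm_mult_atom` and `norm_mult_monomial` can be mirrored on plain Python lists, with atoms as strings and Python's string order standing in for `fast_compare` (an assumption). Each step of the `while` loop roughly corresponds to one application of `swap_mult_r`. The function names are hypothetical, and the sketch only computes the product; it does not build a proof.

```python
def mult_atom(atoms, c):
    """Multiply a sorted product `atoms` by one atom c.

    Like norm_mult_atom, c starts at the right end and moves left past every
    strictly larger atom; equal atoms stay adjacent, and c ends up to their right.
    """
    result = list(atoms)
    i = len(result)
    while i > 0 and result[i - 1] > c:
        i -= 1
    result.insert(i, c)
    return result


def mult_monomial(atoms1, atoms2):
    """Multiply two sorted products by inserting the atoms of the second
    one at a time, left to right, as norm_mult_monomial does."""
    result = list(atoms1)
    for c in atoms2:
        result = mult_atom(result, c)
    return result


print(mult_monomial(['a', 'b'], ['a', 'c']))   # (a * b) * (a * c)
```

This is essentially insertion sort: because the left factor is already sorted, each new atom only needs to travel left until it meets an atom that is not larger.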


There are two respects in which the algorithm still needs to be improved. First, we need to collect (and multiply together) the coefficients at the front. Second, we need to combine equal atoms into powers (e.g. $x\cdot x$ into $x^2$). We first implement a function to compare two terms of the form $a_i^{e_i}$ and $a_j^{e_j}$ by their bases:

In [9]:
def compare_atom(t1, t2):
    """Assume t1 and t2 are in the form a_i^{e_i} and a_j^{e_j},
    compare a_i with a_j."""
    return term_ord.fast_compare(t1.arg1, t2.arg1)

Next, we re-implement `norm_mult_atom` to take a product of $a_i^{e_i}$'s instead of $a_i$'s, and then define a conversion `norm_mult_monomial_wo_coeff`, which assumes its input is of the form $(a_1^{e_1}\cdots a_k^{e_k})\cdot (b_1^{f_1}\cdots b_l^{f_l})$, where $k>0$ and $l>0$.

In [10]:
class norm_mult_atom(Conv):
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg1.is_times():  # t is of form (a * b) * c
            cp = compare_atom(t.arg1.arg, t.arg)  # compare b with c by their base
            if cp > 0:  # if b > c, need to swap b with c
                return pt.on_rhs(
                    swap_mult_r(),  # (a * c) * b
                    arg1_conv(self))   # possibly move c further inside a
            elif cp == 0:  # if b and c have the same base, combine the exponents
                return pt.on_rhs(
                    rewr_conv('mult_assoc'),  # a * (b^e1 * b^e2)
                    arg_conv(rewr_conv('nat_power_add', sym=True)),  # a * (b^(e1 + e2))
                    arg_conv(arg_conv(nat.nat_conv())))  # evaluate e1 + e2
            else:  # if b < c, atoms are already ordered, since a * b is assumed sorted
                return pt
        else:  # t is of the form a * b
            cp = compare_atom(t.arg1, t.arg)  # compare a with b by their base
            if cp > 0:  # if a > b, need to swap a and b
                return pt.on_rhs(rewr_conv('mult_comm'))
            elif cp == 0:  # if a and b have the same base, combine the exponents
                return pt.on_rhs(
                    rewr_conv('nat_power_add', sym=True),
                    arg_conv(nat.nat_conv()))
            else:
                return pt

class norm_mult_monomial_wo_coeff(Conv):
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg.is_times():  # t is of form a * (b * c)
            return pt.on_rhs(
                rewr_conv('mult_assoc', sym=True),  # (a * b) * c
                arg1_conv(self),  # merge terms in b into a
                norm_mult_atom())  # merge c into a * b
        else:
            return pt.on_rhs(norm_mult_atom())

We now test this conversion on some examples:

In [11]:
x = Var('x', NatType)
y = Var('y', NatType)
z = Var('z', NatType)

test_conv(norm_mult_monomial_wo_coeff(), [
    (x ** 1 * y ** 2) * (x ** 2 * y ** 1),
    (y ** 2) * (x ** 3 * z ** 1)
])

x ^ (1::nat) * y ^ (2::nat) * (x ^ (2::nat) * y ^ (1::nat)) = x ^ (3::nat) * y ^ (3::nat)
y ^ (2::nat) * (x ^ (3::nat) * z ^ (1::nat)) = x ^ (3::nat) * y ^ (2::nat) * z ^ (1::nat)


Next, we implement the version of `norm_mult_monomial` with coefficients. This conversion assumes the input is of the form $(c\cdot a_1^{e_1}\cdots a_k^{e_k})\cdot (d\cdot b_1^{f_1}\cdots b_l^{f_l})$. It must also handle the case where $k=0$ or $l=0$, in which the corresponding monomial is just the constant $c$ or $d$.

In [12]:
class norm_mult_monomial(Conv):
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg1.is_number() and t.arg.is_number():  # c * d
            return pt.on_rhs(nat.nat_conv())
        elif t.arg1.is_number() and not t.arg.is_number():  # c * (d * body)
            return pt.on_rhs(
                rewr_conv('mult_assoc', sym=True),  # (c * d) * body
                arg1_conv(nat.nat_conv()))  # evaluate c * d
        elif not t.arg1.is_number() and t.arg.is_number():  # (c * body) * d
            return pt.on_rhs(rewr_conv('mult_comm'), self)  # d * (c * body)
        else:  # (c * body1) * (d * body2)
            return pt.on_rhs(
                rewr_conv('mult_assoc', sym=True),  # ((c * body1) * d) * body2
                arg1_conv(swap_mult_r()),  # ((c * d) * body1) * body2
                arg1_conv(arg1_conv(nat.nat_conv())),  # evaluate c * d
                rewr_conv('mult_assoc'),  # cd * (body1 * body2)
                arg_conv(norm_mult_monomial_wo_coeff()))

We test this conversion on some examples:

In [13]:
test_conv(norm_mult_monomial(), [
    (3 * (x ** 1 * y ** 2)) * (2 * (x ** 2 * y ** 1)),
    (1 * y ** 2) * (1 * (x ** 3 * z ** 1)),
    3 * Nat(5),
    (3 * (x ** 2)) * 5,
    3 * (5 * (x ** 2))
])

3 * (x ^ (1::nat) * y ^ (2::nat)) * (2 * (x ^ (2::nat) * y ^ (1::nat))) = 6 * (x ^ (3::nat) * y ^ (3::nat))
1 * y ^ (2::nat) * (1 * (x ^ (3::nat) * z ^ (1::nat))) = 1 * (x ^ (3::nat) * y ^ (2::nat) * z ^ (1::nat))
(3::nat) * 5 = 15
3 * x ^ (2::nat) * 5 = 15 * x ^ (2::nat)
3 * (5 * x ^ (2::nat)) = 15 * x ^ (2::nat)
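Stripped of the rewriting, the computation performed by `norm_mult_monomial` amounts to multiplying the coefficients and adding the exponents of equal atoms. The sketch below uses a toy representation (a monomial is a pair `(coefficient, body)`, where a body is a tuple of `(atom, exponent)` pairs sorted by atom and `()` denotes a constant); the representation and the function name are illustrative assumptions. It reproduces the five results above on that representation, but uses a dictionary instead of step-by-step swaps, so it yields no proof.

```python
def mult_monomial(m1, m2):
    """Multiply monomials m = (coeff, body), where body is a tuple of
    (atom, exponent) pairs sorted by atom and () denotes a constant."""
    c1, body1 = m1
    c2, body2 = m2
    exps = dict(body1)
    for atom, e in body2:
        # equal bases are combined: a^e1 * a^e2 = a^(e1 + e2)
        exps[atom] = exps.get(atom, 0) + e
    # coefficients are multiplied: (c * b1) * (d * b2) = (c * d) * (b1 * b2)
    return (c1 * c2, tuple(sorted(exps.items())))


# 3 * (x^1 * y^2) times 2 * (x^2 * y^1)
print(mult_monomial((3, (('x', 1), ('y', 2))), (2, (('x', 2), ('y', 1)))))
```

The constant cases ($k=0$ or $l=0$) need no special treatment here, since an empty body simply contributes no exponents.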


Next, we consider the problem of adding two polynomials. The idea is the same as before: we sort the monomials by their bodies, and combine monomials with the same body. We first define a function that compares two monomials by their bodies.

In [14]:
def compare_monomial(t1, t2):
    """Assume t1 and t2 are in the form c1 * body1 and c2 * body2,
    compare body1 with body2."""
    if t1.is_number() and t2.is_number():
        return 0
    if t1.is_number() and not t2.is_number():
        return -1
    if not t1.is_number() and t2.is_number():
        return 1
    else:
        return term_ord.fast_compare(t1.arg, t2.arg)

Analogous to `swap_mult_r`, we define a conversion that swaps the last two summands:

In [15]:
class swap_add_r(Conv):
    """Rewrite (a + b) + c to (a + c) + b."""
    def get_proof_term(self, t):
        return refl(t).on_rhs(
            rewr_conv('add_assoc'),  # a + (b + c)
            arg_conv(rewr_conv('add_comm')),  # a + (c + b)
            rewr_conv('add_assoc', sym=True))  # (a + c) + b

Next, we define the addition of a monomial to a sum of monomials. Each monomial is assumed to be of the form $c\cdot \mathit{body}$. Note, however, that the body may be empty, in which case the monomial is just the constant $c$.

In [16]:
class norm_add_monomial(Conv):
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg1.is_plus():  # (a + b) + c
            cp = compare_monomial(t.arg1.arg, t.arg)  # compare b with c
            if cp > 0:  # if b > c, need to swap b with c
                return pt.on_rhs(
                    swap_add_r(),  # (a + c) + b
                    arg1_conv(self))  # possibly move c further into a
            elif cp == 0:  # if b and c have the same body, combine coefficients
                return pt.on_rhs(
                    rewr_conv('add_assoc'),  # a + (c1 * b + c2 * b)
                    arg_conv(rewr_conv('distrib_r', sym=True)), # a + (c1 + c2) * b
                    arg_conv(arg1_conv(nat.nat_conv())))  # evaluate c1 + c2
            else:  # if b < c, monomials are already sorted
                return pt
        else:  # a + b
            cp = compare_monomial(t.arg1, t.arg)  # compare a with b
            if cp > 0:  # if a > b, need to swap a with b
                return pt.on_rhs(rewr_conv('add_comm'))
            elif cp == 0:  # if a and b have the same body, combine coefficients
                if t.arg.is_number():
                    return pt.on_rhs(nat.nat_conv())
                else:
                    return pt.on_rhs(
                        rewr_conv('distrib_r', sym=True),
                        arg1_conv(nat.nat_conv()))
            else:
                return pt

Let's briefly test this conversion:

In [17]:
test_conv(norm_add_monomial(), [
    1 * y ** 1 + 1 * x ** 1,
    1 * x ** 1 + 1 * x ** 1,
    (1 * x ** 1 + 2 * y ** 1) + 2 * x ** 1,
    (1 * x ** 1 + 2 * y ** 1) + 2 * y ** 1,
    (1 * x ** 1 + 2 * y ** 1) + 3 * z ** 1,
    Nat(1) + 2,
    (1 + 1 * x ** 1) + 2,
])

1 * y ^ (1::nat) + 1 * x ^ (1::nat) = 1 * x ^ (1::nat) + 1 * y ^ (1::nat)
1 * x ^ (1::nat) + 1 * x ^ (1::nat) = 2 * x ^ (1::nat)
1 * x ^ (1::nat) + 2 * y ^ (1::nat) + 2 * x ^ (1::nat) = 3 * x ^ (1::nat) + 2 * y ^ (1::nat)
1 * x ^ (1::nat) + 2 * y ^ (1::nat) + 2 * y ^ (1::nat) = 1 * x ^ (1::nat) + 4 * y ^ (1::nat)
1 * x ^ (1::nat) + 2 * y ^ (1::nat) + 3 * z ^ (1::nat) = 1 * x ^ (1::nat) + 2 * y ^ (1::nat) + 3 * z ^ (1::nat)
(1::nat) + 2 = 3
1 + 1 * x ^ (1::nat) + 2 = 3 + 1 * x ^ (1::nat)


So far so good. Now we implement the conversion that adds two polynomials. Note that we now need to consider the case where either polynomial is zero.

In [18]:
class norm_add_polynomial(Conv):
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg1.is_zero():
            return pt.on_rhs(rewr_conv('nat_plus_def_1'))
        elif t.arg.is_zero():
            return pt.on_rhs(rewr_conv('add_0_right'))
        elif t.arg.is_plus():  # t is of form a + (b + c)
            return pt.on_rhs(
                rewr_conv('add_assoc', sym=True),  # (a + b) + c
                arg1_conv(self),  # merge terms in b into a
                norm_add_monomial())  # merge c into a + b
        else:
            return pt.on_rhs(norm_add_monomial())

In [19]:
test_conv(norm_add_polynomial(), [
    (1 * x ** 1 + 1 * y ** 1) + (2 * x ** 1 + 3 * y ** 1),
    (1 * x ** 1 + 1 * y ** 1) + (2 * x ** 1 + 3 * z ** 1),
    0 + 2 * x,
    2 * y + 0,
    (1 + 1 * x ** 1) + (2 + 1 * y ** 1),
])

1 * x ^ (1::nat) + 1 * y ^ (1::nat) + (2 * x ^ (1::nat) + 3 * y ^ (1::nat)) = 3 * x ^ (1::nat) + 4 * y ^ (1::nat)
1 * x ^ (1::nat) + 1 * y ^ (1::nat) + (2 * x ^ (1::nat) + 3 * z ^ (1::nat)) = 3 * x ^ (1::nat) + 1 * y ^ (1::nat) + 3 * z ^ (1::nat)
0 + 2 * x = 2 * x
2 * y + 0 = 2 * y
1 + 1 * x ^ (1::nat) + (2 + 1 * y ^ (1::nat)) = 3 + 1 * x ^ (1::nat) + 1 * y ^ (1::nat)
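The combined behaviour of `norm_add_monomial` and `norm_add_polynomial` can be sketched as insertion with merging on a toy list representation (monomials as `(coefficient, body)` pairs, bodies as tuples of `(atom, exponent)` pairs, constants with body `()`). The ordering by number of atoms and then lexicographically is an assumed stand-in for `compare_monomial`, and the function names are hypothetical.

```python
def body_key(body):
    """Assumed order: constants (empty body) first, then by size, then lexicographic."""
    return (len(body), body)


def add_monomial(poly, m):
    """Insert monomial m = (c, body) into the sorted polynomial `poly`,
    mirroring norm_add_monomial: move m left past larger bodies, and merge
    coefficients when an equal body is found (c1 * b + c2 * b = (c1 + c2) * b)."""
    c, body = m
    result = list(poly)
    i = len(result)
    while i > 0 and body_key(result[i - 1][1]) > body_key(body):
        i -= 1
    if i > 0 and result[i - 1][1] == body:
        c0, _ = result[i - 1]
        result[i - 1] = (c0 + c, body)
    else:
        result.insert(i, (c, body))
    return result


def add_polynomials(p1, p2):
    """Add p2 to p1 one monomial at a time, like norm_add_polynomial.
    The empty list plays the role of 0, so no special case is needed."""
    result = list(p1)
    for m in p2:
        result = add_monomial(result, m)
    return result


X, Y = (('x', 1),), (('y', 1),)
print(add_polynomials([(1, X), (1, Y)], [(2, X), (3, Y)]))   # (x + y) + (2x + 3y)
```

Since coefficients are positive natural numbers, merging never produces a zero coefficient, so no monomial ever has to be removed.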


To multiply two polynomials, we use the distributive laws on the left and on the right, which reduces the problem to multiplying monomials and adding polynomials. Note the special case where either polynomial is zero.

In [20]:
class norm_mult_poly_monomial(Conv):
    """Multiply a polynomial a_1 + ... + a_n with a monomial c."""
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg1.is_plus():  # (a + b) * c
            return pt.on_rhs(
                rewr_conv('distrib_r'),  # a * c + b * c
                arg1_conv(self),  # process a * c
                arg_conv(norm_mult_monomial()), # process b * c
                norm_add_polynomial())  # add the results
        else:
            return pt.on_rhs(norm_mult_monomial())

class norm_mult_polynomials(Conv):
    """Multiply two polynomials."""
    def get_proof_term(self, t):
        pt = refl(t)
        if t.arg1.is_zero():
            return pt.on_rhs(rewr_conv('nat_times_def_1'))
        elif t.arg.is_zero():
            return pt.on_rhs(rewr_conv('mult_0_right'))
        elif t.arg.is_plus():  # a * (b + c)
            return pt.on_rhs(
                rewr_conv('distrib_l'), # a * b + a * c
                arg1_conv(self),  # process a * b
                arg_conv(norm_mult_poly_monomial()),  # process a * c
                norm_add_polynomial())  # add the results
        else:
            return pt.on_rhs(norm_mult_poly_monomial())

This can be tested as follows:

In [21]:
test_conv(norm_mult_polynomials(), [
    (1 * x ** 1 + 1 * y ** 1) * (1 * x ** 1 + 1 * y ** 1),
    (1 * x ** 1) * (2 * x ** 2 + 2 * y ** 2),
    (1 * x ** 2 + 2 * y ** 1) * (3 * x ** 2),
    0 * (1 * x ** 1),
    (1 * x ** 1) * 0,
])

(1 * x ^ (1::nat) + 1 * y ^ (1::nat)) * (1 * x ^ (1::nat) + 1 * y ^ (1::nat)) = 1 * x ^ (2::nat) + 1 * y ^ (2::nat) + 2 * (x ^ (1::nat) * y ^ (1::nat))
1 * x ^ (1::nat) * (2 * x ^ (2::nat) + 2 * y ^ (2::nat)) = 2 * x ^ (3::nat) + 2 * (x ^ (1::nat) * y ^ (2::nat))
(1 * x ^ (2::nat) + 2 * y ^ (1::nat)) * (3 * x ^ (2::nat)) = 3 * x ^ (4::nat) + 6 * (x ^ (2::nat) * y ^ (1::nat))
0 * (1 * x ^ (1::nat)) = 0
1 * x ^ (1::nat) * 0 = 0
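Computationally, the two distributive laws mean that every monomial of the first polynomial is multiplied with every monomial of the second, and the products are then added. A toy sketch on `(coefficient, body)` lists (the representation, the name `mult_polynomials`, and the body order by number of atoms then lexicographic are illustrative assumptions) reproduces the five results above:

```python
def mult_polynomials(p1, p2):
    """Multiply polynomials given as lists of (coeff, body) pairs.

    By distributivity, every monomial of p1 is multiplied with every monomial
    of p2; products with equal bodies are then summed.
    """
    coeffs = {}
    for c1, b1 in p1:
        for c2, b2 in p2:
            exps = dict(b1)
            for atom, e in b2:
                exps[atom] = exps.get(atom, 0) + e
            body = tuple(sorted(exps.items()))
            coeffs[body] = coeffs.get(body, 0) + c1 * c2
    # sort by number of atoms, then lexicographically (assumed order)
    return sorted(((c, b) for b, c in coeffs.items()),
                  key=lambda m: (len(m[1]), m[1]))


X1, Y1 = (('x', 1),), (('y', 1),)
print(mult_polynomials([(1, X1), (1, Y1)], [(1, X1), (1, Y1)]))   # (x + y) * (x + y)
```

An empty list stands for zero, and the double loop then produces no products at all, which matches the two zero cases handled explicitly by `norm_mult_polynomials`.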


We now define the full normalization function. The conversion always rewrites its input into the normal form defined above. In particular, it rewrites an atom $x$ into $1\cdot x^1$, which is harder for humans to read but is the form expected by the procedures above. We will show how to fix this problem later.

In [22]:
class norm_full(Conv):
    def get_proof_term(self, t):
        pt = refl(t)
        if t.is_plus():
            return pt.on_rhs(binop_conv(self), norm_add_polynomial())
        elif t.is_times():
            return pt.on_rhs(binop_conv(self), norm_mult_polynomials())
        elif t.is_number():
            return pt
        elif t.is_nat_power() and t.arg.is_number():  # rewrite x ^ n to 1 * x ^ n
            return pt.on_rhs(rewr_conv('mult_1_left', sym=True))
        else:  # rewrite x to 1 * x ^ 1
            return pt.on_rhs(
                rewr_conv('nat_power_1', sym=True),
                rewr_conv('mult_1_left', sym=True))

In [23]:
test_conv(norm_full(), [
    (x + y) * (x + y),
    0 * (x + y),
    (x + y) * 0,
    (3 * x) * (2 * y) * (x + y),
    0 + x + y,
    1 + x + 2 + y,
])

(x + y) * (x + y) = 1 * x ^ (2::nat) + 1 * y ^ (2::nat) + 2 * (x ^ (1::nat) * y ^ (1::nat))
0 * (x + y) = 0
(x + y) * 0 = 0
3 * x * (2 * y) * (x + y) = 6 * (x ^ (1::nat) * y ^ (2::nat)) + 6 * (x ^ (2::nat) * y ^ (1::nat))
0 + x + y = 1 * x ^ (1::nat) + 1 * y ^ (1::nat)
1 + x + 2 + y = 3 + 1 * x ^ (1::nat) + 1 * y ^ (1::nat)
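The recursion in `norm_full` (normalize both operands, then add or multiply the normalized results) can be mirrored by a self-contained sketch over a toy expression type. Everything here (the tuple encoding, the names `add`, `mult` and `normalize`, and the body order) is a hypothetical illustration that computes normal forms without producing proofs. One deliberate difference: `norm_full` turns `x ^ n` into `1 * x ^ n`, assuming the base is atomic, whereas this sketch expands a constant power by repeated multiplication.

```python
def _key(m):
    """Assumed order on monomials: by number of atoms, then lexicographic."""
    c, body = m
    return (len(body), body)


def add(p1, p2):
    """Sum of two polynomials; monomials with equal bodies are merged."""
    coeffs = {}
    for c, body in p1 + p2:
        coeffs[body] = coeffs.get(body, 0) + c
    return sorted(((c, b) for b, c in coeffs.items()), key=_key)


def mult(p1, p2):
    """Product of two polynomials, by distributivity."""
    result = []
    for c1, b1 in p1:
        for c2, b2 in p2:
            exps = dict(b1)
            for atom, e in b2:
                exps[atom] = exps.get(atom, 0) + e
            result = add(result, [(c1 * c2, tuple(sorted(exps.items())))])
    return result


def normalize(t):
    """Normalize a toy expression: an int constant, a string atom, or a
    tuple (op, lhs, rhs) with op in '+', '*', or '^' (constant exponent)."""
    if isinstance(t, int):
        return [(t, ())] if t > 0 else []   # 0 is the empty polynomial
    if isinstance(t, str):
        return [(1, ((t, 1),))]             # atom x becomes 1 * x^1
    op, lhs, rhs = t
    if op == '+':
        return add(normalize(lhs), normalize(rhs))
    if op == '*':
        return mult(normalize(lhs), normalize(rhs))
    if op == '^' and isinstance(rhs, int):
        result = [(1, ())]                  # x^0 = 1
        for _ in range(rhs):
            result = mult(result, normalize(lhs))
        return result
    raise ValueError(f"unsupported expression: {t!r}")


print(normalize(('*', ('+', 'x', 'y'), ('+', 'x', 'y'))))
```

Because the result is canonical, comparing the normal forms of two expressions is enough to decide whether they are equal under these rules, which is the idea behind the next section.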


## Deciding equality

One of the major uses of normalization is to prove that two terms are equal: normalize both sides and check that the results coincide. This can be implemented as a macro as follows. It is registered as `my_nat_norm` to avoid conflict with the existing `nat_norm` macro.

In [24]:
@register_macro('my_nat_norm')
class my_nat_norm_macro(Macro):
    def __init__(self):
        self.level = 1
        self.limit = 'nat_power_add'

    def get_proof_term(self, goal, prevs):
        assert goal.is_equals(), 'my_nat_norm: goal is not an equality'

        # Obtain the normalization of the two sides
        pt1 = refl(goal.lhs).on_rhs(norm_full())
        pt2 = refl(goal.rhs).on_rhs(norm_full())

        assert pt1.rhs == pt2.rhs, 'my_nat_norm: normalizations are not equal.'
        return pt1.transitive(pt2.symmetric())
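In `get_proof_term` above, `pt1` proves $\vdash s = n$ for the left-hand side $s$, and `pt2` proves $\vdash t = n$ for the right-hand side $t$, where $n$ is the common normal form checked by the second assertion. The returned proof term is the following two-step derivation, using symmetry and then transitivity of equality:

```latex
\dfrac{\vdash s = n \qquad \dfrac{\vdash t = n}{\vdash n = t}\;\textsf{symmetric}}
      {\vdash s = t}\;\textsf{transitive}
```

For the goal used below, both sides normalize to $1\cdot x^2 + 1\cdot y^2 + 2\cdot(x^1\cdot y^1)$, the normal form printed for $(x+y)(x+y)$ in the `norm_full` tests.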

We can construct a proof term using the macro as follows:

In [25]:
pt = ProofTerm('my_nat_norm', Eq((x + y) * (x + y), x ** 2 + 2 * x * y + y ** 2))
print(pt)

ProofTerm(⊢ (x + y) * (x + y) = x ^ (2::nat) + 2 * x * y + y ^ (2::nat))


If we just export the proof term, we see it has only one line:

In [26]:
print(pt.export())

0: ⊢ (x + y) * (x + y) = x ^ (2::nat) + 2 * x * y + y ^ (2::nat) by my_nat_norm (x + y) * (x + y) = x ^ (2::nat) + 2 * x * y + y ^ (2::nat)


However, checking the exported proof with a `ProofReport` shows that it expands into many more steps:

In [27]:
rpt = ProofReport()
check_proof(pt.export(), rpt=rpt)
print(rpt)

Steps: 209
  Theorems:  13
  Primitive: 196
  Macro:     0
Theorems applied: nat_of_nat_def, add_assoc, nat_power_1, nat_power_add, add_1_left, mult_1_right, mult_1_left, add_comm, mult_comm, distrib_r, one_Suc, distrib_l, mult_assoc
Macros evaluated: 
Macros expanded: my_nat_norm
Gaps: []


## Simplification

One aspect that is still unsatisfactory is that the result of normalization is sometimes more complex than needed. For example, patterns like $1\cdot x^1$ appear in the result, where $x$ would suffice. This is easily fixed by applying the rewrites `mult_1_left` and `nat_power_1` (which rewrite $1\cdot x$ to $x$ and $x^1$ to $x$, respectively) everywhere in the result using `top_conv`:

In [28]:
class simp_full(Conv):
    def get_proof_term(self, t):
        return refl(t).on_rhs(
            norm_full(),
            top_conv(rewr_conv('mult_1_left')),
            top_conv(rewr_conv('nat_power_1')))

This conversion achieves something close to "simplification" by fully expanding the expression:

In [29]:
test_conv(simp_full(), [
    (x + y) * (x + y),
    (x + 2 * y) * (2 * x + y),
    (1 + 2 * x + 2) * (2 + y + 1),
])

(x + y) * (x + y) = x ^ (2::nat) + y ^ (2::nat) + 2 * (x * y)
(x + 2 * y) * (2 * x + y) = 2 * x ^ (2::nat) + 2 * y ^ (2::nat) + 5 * (x * y)
(1 + 2 * x + 2) * (2 + y + 1) = 9 + 6 * x + 3 * y + 2 * (x * y)
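On a toy `(coefficient, body)` list representation (bodies as tuples of `(atom, exponent)` pairs, an assumption used only for illustration), the two rewrites in `simp_full` correspond to omitting a coefficient equal to $1$ and an exponent equal to $1$ when printing. The helper names below are hypothetical, and they only change the display, whereas `simp_full` produces a proof that the simplified term equals the original.

```python
def show_monomial(c, body):
    """Render (c, body), dropping the 1 * and ^1 that simp_full removes."""
    factors = [a if e == 1 else f"{a}^{e}" for a, e in body]  # x^1 ~> x (nat_power_1)
    if not factors:
        return str(c)                      # a constant monomial
    prod = " * ".join(factors)
    if c == 1:
        return prod                        # 1 * b ~> b (mult_1_left)
    if len(factors) > 1:
        prod = f"({prod})"
    return f"{c} * {prod}"


def show_polynomial(p):
    """Render a polynomial given as a list of (coeff, body) pairs."""
    return " + ".join(show_monomial(c, b) for c, b in p) if p else "0"


print(show_polynomial([(1, (('x', 2),)), (1, (('y', 2),)), (2, (('x', 1), ('y', 1)))]))
```

Keeping the verbose form internally and simplifying only at the end means none of the earlier conversions have to deal with the special cases $c = 1$ or $e_i = 1$.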
