# Core

> This is a proof validator to help students and hobbyists with mathematical thinking and problem solving.
>
> It's for when you buy a math book from your local used bookstore and want a companion piece of software for exploring its concepts — one that's fun and easy to use for anyone with a programming background. It's meant to help amateurs satisfy their mathematical curiosity.
>
> It validates what you're doing in a "black box" manner and tries to offer enough guardrails that you can spot your mistakes and feel *reasonably* confident that you know what you're doing.
>
> Not perfectly confident, reasonably.


In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| export
import functools
import html
import inspect
import typing
from random import randint, seed
from typing import Callable, List, Tuple, Union

import sympy as sp
from IPython.display import display, Latex, HTML
from sympy import Equality, Unequality
from sympy.core.basic import Basic
from sympy.core.relational import Relational

## Helper functions

In [None]:
#| export
# Basic types: thin aliases over SymPy so user code can talk about variables and constants.
Var, Const = sp.Symbol, sp.Number
# A one-variable function mapping a symbol to an expression.
Func = Callable[[Var], sp.Expr]
# Anything a proof may conclude (Equality and Unequality are themselves Relationals).
Goal = Union[Equality, Unequality, Relational]
# Any SymPy object usable as one side of a relation.
Expression = Union[sp.Expr, Basic]

# TODO - build in a repr override to use latex

# Helper functions to create a more user-friendly interface
def variable(name: str, # name of the symbol
             **assumptions # optional SymPy assumptions, e.g. integer=True or positive=True
             ) -> Var:
    """Create a SymPy symbol called `name`.
    Extra keyword arguments are passed to `sp.Symbol` as assumptions. Symbols with the same
    name but different assumptions are distinct and do not compare equal."""
    return sp.Symbol(name, **assumptions)

def constant(value: Union[int, float]) -> Const:
    """Wrap a Python int or float as a SymPy number."""
    number = sp.Number(value)
    return number

def equation(expr: str) -> sp.Expr:
    """Parse the string `expr` into a SymPy object without simplifying it (`evaluate=False`).
    Note: SymPy's parser evaluates the string as Python code, so only pass trusted input."""
    parsed = sp.sympify(expr, evaluate=False)
    return parsed

def equals(lhs: Expression, rhs: Expression) -> Equality:
    """Build the unevaluated equation `lhs = rhs`.
    `evaluate=False` keeps the relation exactly as written instead of letting SymPy reduce it
    (e.g. to `True`/`False`), so it can be compared with a goal or displayed as a proof step."""
    return sp.Eq(lhs, rhs, evaluate=False)

def not_equals(lhs: Expression, rhs: Expression) -> Unequality:
    """Build the unevaluated relation `lhs != rhs`, kept exactly as written (`evaluate=False`)."""
    relation = sp.Ne(lhs, rhs, evaluate=False)
    return relation

## Core functionality 

In [None]:
#| export
def make_examples(domain: str, #Domain of the example equation
                  N: int, #Number of examples
                  equation: str, #Equation to generate examples for
                  low: int = -100, #Smallest sampled input (inclusive)
                  high: int = 100 #Largest sampled input (inclusive)
                  ) -> List[Tuple[sp.Expr, sp.Expr]]: #List of input-output pairs
    """For a given domain and equation, select N examples and generate a list of N input-output pairs.
    Currently, the domain can be either 'real' or 'integer', and the equation must be an expression
    in the single variable `x`. For both domains the inputs are random integers in [low, high]
    (which are also reals)."""
    # SymPy assumptions attached to the symbol `x` for each supported domain.
    assumptions_by_domain = {'real': {}, 'integer': {'integer': True}}
    if domain not in assumptions_by_domain:
        raise ValueError(f"Domain {domain} not supported.")
    x = sp.Symbol('x', **assumptions_by_domain[domain])
    # Parse once, binding the name `x` to the domain's symbol. A plain `sympify` would create an
    # assumption-free `Symbol('x')`, which is a *different* symbol from `Symbol('x', integer=True)`,
    # so `.subs(x, ...)` would silently leave the expression unevaluated.
    # NOTE: sympify evaluates its input with `eval`; only pass trusted equation strings.
    expr = sp.sympify(equation, locals={'x': x})
    examples = []
    for _ in range(N):
        x_val = randint(low, high)
        examples.append((x_val, expr.subs(x, x_val)))
    return examples

In [None]:
#| hide
# Parameter table for make_examples, built from its inline parameter comments.
DocmentTbl(make_examples)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| domain | str | Domain of the example equation |
| N | int | Number of examples |
| equation | str | Equation to generate examples for |
| **Returns** | **typing.List[typing.Tuple[sympy.core.expr.Expr, sympy.core.expr.Expr]]** | **List of input-output pairs** |

In [None]:
# Seed the random generator so this example (and the saved notebook output) is reproducible.
seed(0)
examples = make_examples('real', 10, 'x**2')
print(examples)

[(93, 8649), (-49, 2401), (25, 625), (-100, 10000), (56, 3136), (-2, 4), (94, 8836), (32, 1024), (5, 25), (32, 1024)]


In [None]:
test_eq(len(examples), 10)

In [None]:
test_fail(lambda: make_examples(domain='complex', N=10, equation='x**2'),
          contains="Domain complex not supported.")

In [None]:
#| export
def prove(goal: Goal, #Goal to prove
          proof_func: Callable[..., Goal], #Proof function
          *args #Arguments to proof function
          ) -> bool: #True if proof succeeds, False otherwise
    """Prove `goal` by calling `proof_func(*args)` and comparing its result with `goal`.
    The goal counts as proved only if the derived result is structurally equal (`==`) to `goal`;
    a mathematically equivalent statement written differently does not match.
    On success the goal is displayed followed by "Q.E.D." and True is returned. On failure,
    including any exception raised by the proof function, a diagnostic is printed and False is returned."""
    try:
        derived_result = proof_func(*args)
        proved = goal == derived_result
    except Exception as e:  # deliberate: report errors in a user's proof instead of crashing
        reason = str(e)
    else:
        if proved:
            display(Latex(f"$${sp.latex(goal)} \\quad \\text{{Q.E.D.}}$$"))
            return True
        reason = f"Derived result {derived_result} does not match goal {goal}"
    print(f"Proof failed: {reason}")
    print("Check your assumptions and proof function for errors.")
    return False

In [None]:
#| hide
# Parameter table for prove, built from its inline parameter comments.
DocmentTbl(prove)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| goal | typing.Union[sympy.core.relational.Equality, sympy.core.relational.Unequality, sympy.core.relational.Relational] | Goal to prove |
| proof_func | typing.Callable[..., typing.Union[sympy.core.relational.Equality, sympy.core.relational.Unequality, sympy.core.relational.Relational]] | Proof function |
| args |  |  |
| **Returns** | **bool** | **True if proof succeeds, False otherwise** |

In [None]:
x_sym, two = variable('x'), constant(2)
goal_x_is_2 = equals(x_sym, two)
test_eq(prove(goal_x_is_2, equals, x_sym, two), True)

<IPython.core.display.Latex object>

In [None]:
x_sym = variable('x')
result = prove(equals(x_sym, constant(2)), equals, x_sym, constant(3))
test_eq(result, False)

Proof failed: Derived result Eq(x, 3) does not match goal Eq(x, 2)
Check your assumptions and proof function for errors.


In [None]:
#| export
def _proof_statement(line: str # one source line of a proof function, already stripped
                     ) -> str: # the code to execute for that line, or '' to skip it
    """Map one stripped source line of a proof to the statement `print_proof` should execute."""
    # The function header and its decorators are not proof steps. Requiring the trailing space
    # keeps names that merely start with "def" (e.g. `default = 1`) executable.
    if line.startswith(('@', 'def ', 'async def ')):
        return ''
    # `return expr` is replayed as the bare expression `expr`; a bare `return` has nothing to run.
    # Match the keyword exactly so names such as `returned` are not truncated.
    if line == 'return' or line.startswith(('return ', 'return(')):
        return line[len('return'):].strip()
    return line


def print_proof(proof: Callable[..., Goal], # the proof function
                 *args, # positional arguments to the proof function
                 **kwargs # keyword arguments to the proof function
                 ) -> None: # no return value
    """Print a proof step by step. Mostly used when defining a proof evaluation function.
    The source of `proof` is replayed one line at a time in a private namespace in which the
    proof's parameters are bound to `args`/`kwargs`; other names resolve against the module
    `proof` was defined in. After each line, every newly assigned variable is rendered once as
    LaTeX (the inputs themselves are not shown). `# comment` lines are shown as plain text, so
    comments do not support latex formatting. `return expr` is evaluated but not displayed.
    Only one-line statements are supported: multi-line statements, blocks (`if`, `for`, ...)
    and multi-line docstrings inside the proof will not replay correctly."""
    # Bind the arguments to the proof's own parameter names, exactly as calling
    # `proof(*args, **kwargs)` would (including parameter defaults).
    bound = inspect.signature(proof).bind(*args, **kwargs)
    bound.apply_defaults()
    # Replay into an explicit dict: names written through `exec(..., locals())` are not reliably
    # visible afterwards (from Python 3.13 each `locals()` call returns a fresh snapshot).
    namespace = dict(bound.arguments)
    # Unknown names resolve in the proof's own module, as they would in a real call.
    proof_globals = getattr(proof, '__globals__', globals())
    shown = set(namespace)  # display each variable once; the inputs themselves are not displayed
    source_lines, _ = inspect.getsourcelines(proof)
    for raw_line in source_lines:
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('#'):
            # Comments are narration: show them as escaped plain text (math mode would drop spaces).
            display(HTML(f"<p>{html.escape(line[1:].strip())}</p>"))
            continue
        statement = _proof_statement(line)
        if not statement:
            continue
        exec(statement, proof_globals, namespace)
        for name, value in namespace.items():
            if name not in shown:
                shown.add(name)
                display(Latex(f"$${sp.latex(value)}$$"))

In [None]:
#| hide
# Parameter table for print_proof, built from its inline parameter comments.
DocmentTbl(print_proof)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| proof | typing.Callable[..., typing.Union[sympy.core.relational.Equality, sympy.core.relational.Unequality, sympy.core.relational.Relational]] | the proof function |
| args |  |  |
| **Returns** | **None** | **no return value** |

In [None]:
def print_proof_example():
    # This is a test function to show how print_proof works
    x = variable('x')
    # Introduce a second variable
    y = variable('y')
    return equals(x, y)

# print_proof displays the steps itself and returns None, so there is nothing to capture.
print_proof(print_proof_example)

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

In [None]:
#| export
def contradiction_proof(proof: Callable[..., Goal] # the proof function
                        ) -> Callable[..., Unequality]: # the wrapped contradiction proof function to evaluate
    """Wrap a proof by contradiction so that calling it displays each step and returns its conclusion.
    Calling the wrapper first replays the proof step by step with `print_proof`, then calls `proof`
    itself and returns its result, which must be an `Unequality` (the contradiction reached).
    Pass the wrapper to `prove` to check it against a goal.
    Raises TypeError if `proof` is annotated to return something other than an `Unequality`,
    or if it returns anything else."""
    @functools.wraps(proof)
    def wrapper(*args, **kwargs):
        declared = typing.get_type_hints(proof).get('return')
        # An explicit return annotation must promise an Unequality. Without one (even if the
        # parameters are annotated), the isinstance check on the actual result enforces the contract.
        if declared is not None and not (isinstance(declared, type) and issubclass(declared, Unequality)):
            print(declared)
            raise TypeError("Proof function must return Unequality")
        try:
            print_proof(proof, *args, **kwargs)
        except Exception as e:
            print(f"Error in proof function: {str(e)}")
            raise
        # print_proof only replays the source for display; call the proof for its real result.
        result = proof(*args, **kwargs)
        if not isinstance(result, Unequality):
            raise TypeError("Proof function must return Unequality")
        return result
    return wrapper

In [None]:
#| hide
# Parameter table for contradiction_proof, built from its inline parameter comments.
DocmentTbl(contradiction_proof)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| proof | typing.Callable[..., typing.Union[sympy.core.relational.Equality, sympy.core.relational.Unequality, sympy.core.relational.Relational]] | the proof function |
| **Returns** | **typing.Callable[..., sympy.core.relational.Unequality]** | **the contradiction proof function to evaluate** |

In [None]:
#| export
def direct_proof(proof: Callable[..., Goal] # the proof function
                        ) -> Callable[..., Equality]: # the wrapped proof function to evaluate
    """Wrap a direct proof so that calling it displays each step and returns its conclusion.
    Calling the wrapper first replays the proof step by step with `print_proof`, then calls `proof`
    itself and returns its result, which must be an `Equality`.
    Pass the wrapper to `prove` to check it against a goal.
    Raises TypeError if `proof` is annotated to return something other than an `Equality`,
    or if it returns anything else."""
    @functools.wraps(proof)
    def wrapper(*args, **kwargs):
        declared = typing.get_type_hints(proof).get('return')
        # An explicit return annotation must promise an Equality. Without one (even if the
        # parameters are annotated), the isinstance check on the actual result enforces the contract.
        if declared is not None and not (isinstance(declared, type) and issubclass(declared, Equality)):
            print(declared)
            raise TypeError("Proof function must return Equality")
        try:
            print_proof(proof, *args, **kwargs)
        except Exception as e:
            print(f"Error in proof function: {str(e)}")
            raise
        # print_proof only replays the source for display; call the proof for its real result.
        result = proof(*args, **kwargs)
        if not isinstance(result, Equality):
            raise TypeError("Proof function must return Equality")
        return result
    return wrapper

In [None]:
#| hide
# Parameter table for direct_proof, built from its inline parameter comments.
DocmentTbl(direct_proof)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| proof | typing.Callable[..., typing.Union[sympy.core.relational.Equality, sympy.core.relational.Unequality, sympy.core.relational.Relational]] | the proof function |
| **Returns** | **typing.Callable[..., sympy.core.relational.Equality]** | **the wrapped proof function to evaluate** |

Here's how to prove a goal by contradiction using `contradiction_proof`.

In [None]:
# Start by defining your domain
arbitrary_x = variable("x")
expression = arbitrary_x + 1

# Then define your goal
contradiction_goal = not_equals(expression, arbitrary_x)

@contradiction_proof
def proof_of_x_plus_one(x):
    # Given x, assume x + 1 = x holds for an arbitrary x
    assumed_eq = equals(x + 1, x)

    # Calculate x + 1
    x_plus_one = x + 1

    # Observing x + 1 != x, we have reached a contradiction
    return not_equals(x_plus_one, assumed_eq.rhs)

In [None]:
# nbdev runs this cell as a test: fail loudly if the contradiction proof does not check out.
test_eq(prove(contradiction_goal, proof_of_x_plus_one, arbitrary_x), True)

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

<IPython.core.display.Latex object>

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
