In [1]:
%%HTML
<style>
.container { width: 100% }
</style>

# Generating Abstract Syntax Trees

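Our grammar is stored in the file `Differentiator.g4`. It describes arithmetical expressions that are built from variables, natural numbers, the operators `+`, `-`, `*`, `/`, and parentheses.
The function symbols `ln` (natural logarithm) and `exp` (exponential) are also supported. The actions embedded in the grammar build the abstract syntax tree as nested Python tuples. A binary operation becomes a triple such as `('+', lhs, rhs)`, an application of `ln` or `exp` becomes a pair such as `('ln', arg)`, a variable becomes a string, and a number becomes an `int`.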

In [2]:
!cat Differentiator.g4

grammar Differentiator;

expr returns [result]
    : e=expr '+' p=product {$result = ('+', $e.result, $p.result)}
    | e=expr '-' p=product {$result = ('-', $e.result, $p.result)}
    | product          {$result = $product.result                }    
    ;

product returns [result]
    : p=product '*' f=factor {$result = ('*', $p.result, $f.result)}
    | p=product '/' f=factor {$result = ('/', $p.result, $f.result)}
    | f=factor               {$result = $f.result                  }
    ;

factor returns [result]
    : '(' expr ')'       {$result = $expr.result;        }
    | 'exp' '(' expr ')' {$result = ('exp', $expr.result)}
    | 'ln'  '(' expr ')' {$result = ('ln' , $expr.result)}
    | VAR                {$result = $VAR.text            }
    | NUM                {$result = int($NUM.text)       }
    ;

VAR : [a-zA-Z][a-zA-Z0-9]*;
NUM : '0'|[1-9][0-9]*;
WS  : [ \t\n\r] -> skip; 
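To see which trees these actions produce without installing ANTLR, here is a minimal hand-written recursive-descent sketch that mirrors `Differentiator.g4`. It is a hypothetical stand-in, not the generated parser. The names `tokenize` and `Parser` are assumptions, and the left-recursive rules `expr` and `product` are rewritten as loops so that the operators stay left-associative. Unlike the grammar, whose `expr` rule does not end with `EOF`, the sketch rejects input with trailing tokens.

```python
import re

def tokenize(text):
    """Split text into (kind, value) pairs, mimicking the lexer rules of Differentiator.g4."""
    tokens = []
    # Alternatives: NUM ('0' | [1-9][0-9]*), identifier, any other non-blank character.
    # Whitespace matches none of them and is therefore skipped, like WS.
    for num, name, other in re.findall(r'(0|[1-9][0-9]*)|([a-zA-Z][a-zA-Z0-9]*)|(\S)', text):
        if num:
            tokens.append(('NUM', int(num)))
        elif name in ('exp', 'ln'):
            tokens.append((name, name))      # keywords take priority over VAR
        elif name:
            tokens.append(('VAR', name))
        else:
            tokens.append((other, other))    # '+', '-', '*', '/', '(', ')' or junk
    tokens.append(('EOF', None))
    return tokens

class Parser:
    """Hypothetical recursive-descent stand-in for the ANTLR-generated parser."""

    def __init__(self, text):
        self.tokens = tokenize(text)
        self.pos = 0

    def kind(self):
        return self.tokens[self.pos][0]

    def take(self, kind=None):
        k, value = self.tokens[self.pos]
        if kind is not None and k != kind:
            raise SyntaxError(f'expected {kind!r} but found {k!r}')
        self.pos += 1
        return value

    def parse(self):
        result = self.expr()
        self.take('EOF')                     # the whole input must be consumed
        return result

    def expr(self):
        # expr : expr ('+' | '-') product | product   (left recursion as a loop)
        result = self.product()
        while self.kind() in ('+', '-'):
            op = self.take()
            result = (op, result, self.product())
        return result

    def product(self):
        # product : product ('*' | '/') factor | factor
        result = self.factor()
        while self.kind() in ('*', '/'):
            op = self.take()
            result = (op, result, self.factor())
        return result

    def factor(self):
        k = self.kind()
        if k == '(':
            self.take('(')
            result = self.expr()
            self.take(')')
            return result
        if k in ('exp', 'ln'):
            self.take()
            self.take('(')
            result = (k, self.expr())
            self.take(')')
            return result
        if k in ('VAR', 'NUM'):
            return self.take()               # a string or an int
        raise SyntaxError(f'unexpected token {k!r}')
```

For instance, `Parser('a - b - c').parse()` returns `('-', ('-', 'a', 'b'), 'c')`, which shows the left associativity that the left-recursive rules encode.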


We start by generating both the scanner and the parser for the Python 3 target.

In [3]:
!antlr4 -Dlanguage=Python3 Differentiator.g4

The files `DifferentiatorLexer.py` and `DifferentiatorParser.py` contain the generated scanner and parser, respectively.  We have to import these files.  Furthermore, the Python runtime of 
<span style="font-variant:small-caps;">Antlr</span>
needs to be imported; it is the module `antlr4`, which is installed by the package `antlr4-python3-runtime`.

In [4]:
from DifferentiatorLexer  import DifferentiatorLexer
from DifferentiatorParser import DifferentiatorParser
import antlr4

The function `main` repeatedly prompts for an expression, parses it, differentiates it with respect to the variable `x`, and prints the result. It stops when the user enters an empty line.

In [5]:
def main():
    # A single parser object is reused; its token stream is replaced for every line.
    parser = DifferentiatorParser(None)
    line   = input('> ')
    while line != '':
        input_stream = antlr4.InputStream(line)
        lexer        = DifferentiatorLexer(input_stream)
        token_stream = antlr4.CommonTokenStream(lexer)
        parser.setInputStream(token_stream)
        term = parser.expr()         # term.result holds the nested-tuple tree
        d    = diff(term.result)     # the derivative, again as a nested tuple
        print(toString(d))
        line = input('> ')

The function `diff` takes the abstract syntax tree `e` of an arithmetic expression, represented as a nested tuple, and differentiates it with respect to the variable `x`. The result is again a nested tuple.
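Written out, the branches of `diff` implement the following rules, where $f'$ denotes the derivative of $f$ with respect to $x$ and $c$ stands for any number or any variable other than $x$:

```latex
\begin{aligned}
(f + g)' &= f' + g' \\
(f - g)' &= f' - g' \\
(f \cdot g)' &= f' \cdot g + f \cdot g' \\
\left(\frac{f}{g}\right)' &= \frac{f' \cdot g - f \cdot g'}{g \cdot g} \\
\bigl(\ln(f)\bigr)' &= \frac{f'}{f} \\
\bigl(\exp(f)\bigr)' &= f' \cdot \exp(f) \\
x' &= 1 \\
c' &= 0
\end{aligned}
```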

In [6]:
def diff(e):
    "differentiate the expression e with respect to the variable x"
    # Leaves: variables are strings, numbers are ints.  They must be handled
    # first, because indexing an int with e[0] would raise a TypeError.
    if not isinstance(e, tuple):
        return 1 if e == 'x' else 0
    if e[0] == '+':
        f , g  = e[1:]
        fs, gs = diff(f), diff(g)
        return ('+', fs, gs)
    if e[0] == '-':
        f , g  = e[1:]
        fs, gs = diff(f), diff(g)
        return ('-', fs, gs)
    if e[0] == '*':                    # product rule
        f , g  = e[1:]
        fs, gs = diff(f), diff(g)
        return ('+', ('*', fs, g), ('*', f, gs))
    if e[0] == '/':                    # quotient rule
        f , g  = e[1:]
        fs, gs = diff(f), diff(g)
        return ('/', ('-', ('*', fs, g), ('*', f, gs)), ('*', g, g))
    if e[0] == 'ln':                   # (ln f)' = f' / f
        f  = e[1]
        fs = diff(f)
        return ('/', fs, f)
    if e[0] == 'exp':                  # (exp f)' = f' * exp f
        f  = e[1]
        fs = diff(f)
        return ('*', fs, e)
    raise ValueError(f'unknown operator {e[0]!r}')
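One way to sanity-check these rules is to evaluate a tree and its derivative numerically and to compare the result with a central difference quotient. The sketch below assumes the tuple representation described above, with numbers as `int` or `float` and variables as strings that are looked up in a dictionary. The helpers `evaluate` and `central_difference` are hypothetical, and the tuple `deriv` is written by hand by applying the product, `ln`, and `exp` rules to `x * ln(exp(x))`.

```python
import math

def evaluate(e, env):
    """Evaluate a nested-tuple expression; strings are variables looked up in env."""
    if isinstance(e, (int, float)):
        return e
    if isinstance(e, str):
        return env[e]
    op = e[0]
    if op == 'ln':
        return math.log(evaluate(e[1], env))
    if op == 'exp':
        return math.exp(evaluate(e[1], env))
    a, b = evaluate(e[1], env), evaluate(e[2], env)
    if op == '+':
        return a + b
    if op == '-':
        return a - b
    if op == '*':
        return a * b
    if op == '/':
        return a / b
    raise ValueError(f'unknown operator {op!r}')

def central_difference(e, x, h=1e-5):
    """Approximate the derivative of e with respect to the variable x at the point x."""
    return (evaluate(e, {'x': x + h}) - evaluate(e, {'x': x - h})) / (2 * h)

# x * ln(exp(x)) and its derivative obtained from the product, ln and exp rules
f     = ('*', 'x', ('ln', ('exp', 'x')))
deriv = ('+', ('*', 1, ('ln', ('exp', 'x'))),
              ('*', 'x', ('/', ('*', 1, ('exp', 'x')), ('exp', 'x'))))

print(evaluate(deriv, {'x': 2.0}), central_difference(f, 2.0))
```

Since `x * ln(exp(x))` equals `x * x`, both printed numbers should be close to `4` at `x = 2`.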

The function `toString` takes an arithmetical expression that is represented as a nested tuple and converts it into a string.

In [7]:
def toString(e):
    # Leaves (variables and numbers) are printed as they are; checking this
    # first avoids indexing an int such as the 0 returned by diff.
    if not isinstance(e, tuple):
        return str(e)
    if e[0] == '+':
        f, g = e[1:]
        return toString(f) + ' + ' + toString(g)
    if e[0] == '-':
        f, g = e[1:]
        return toString(f) + ' - (' + toString(g) + ')'
    if e[0] == '*':
        f, g = e[1:]
        return parenString(f) + ' * ' + parenString(g)
    if e[0] == '/':
        f, g = e[1:]
        return parenString(f) + ' / (' + toString(g) + ')'
    if e[0] == 'ln':
        return 'ln(' + toString(e[1]) + ')'
    if e[0] == 'exp':
        return 'exp(' + toString(e[1]) + ')'
    raise ValueError(f'unknown operator {e[0]!r}')

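The function `parenString` converts `e` into a string and encloses it in parentheses if `e` is a sum or a difference, since these bind less tightly than `*` and `/`.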

In [8]:
def parenString(e):
    # Only tuples can be sums or differences; leaves never need parentheses.
    if isinstance(e, tuple) and e[0] in ['+', '-']:
        return '(' + toString(e) + ')'
    else:
        return toString(e)
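The functions `toString` and `parenString` always parenthesize the right operand of `-` and `/`, which yields output such as `exp(x) / (exp(x))`. A possible alternative, sketched below under the assumption that only the operators of `Differentiator.g4` occur, passes the precedence required by the enclosing operator down the recursion and adds parentheses only when a subexpression binds more loosely. The names `PRECEDENCE` and `to_string` are hypothetical.

```python
# Binding strength of the binary operators; a higher number binds tighter.
PRECEDENCE = {'+': 1, '-': 1, '*': 2, '/': 2}

def to_string(e, context=0):
    """Render a nested-tuple expression, adding parentheses only where needed.

    context is the precedence demanded by the enclosing operator; a
    subexpression whose operator binds more loosely is parenthesized.
    """
    if not isinstance(e, tuple):
        return str(e)
    op = e[0]
    if op in ('ln', 'exp'):
        return f'{op}({to_string(e[1])})'
    prec = PRECEDENCE[op]
    # The operators are left-associative: a left operand may have the same
    # precedence, a right operand must bind strictly tighter, so that
    # a - (b - c) keeps its parentheses while (a - b) - c loses them.
    left  = to_string(e[1], prec)
    right = to_string(e[2], prec + 1)
    text  = f'{left} {op} {right}'
    return f'({text})' if prec < context else text
```

With this scheme `('/', ('exp', 'x'), ('exp', 'x'))` is printed as `exp(x) / exp(x)`, while `('-', 'a', ('-', 'b', 'c'))` still becomes `a - (b - c)`.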

In [9]:
main()

> x * ln(exp(x))
1 * ln(exp(x)) + x * 1 * exp(x) / (exp(x))
> ln(exp(x * x))
(1 * x + x * 1) * exp(x * x) / (exp(x * x))
> 
