# Symbolic Differentiation 

In this notebook our goal is to implement *symbolic differentiation*.  Concretely, we want to implement a function `diff` that takes one argument:
  - The argument `expr` represents an *arithmetic expression*.
    Here an arithmetic expression is any string that is built from variables and numbers by applying
    any of the operator symbols "`+`", "`-`", "`*`", "`/`", and "`**`".
    The operator "`**`" represents exponentiation, i.e. an expression of the form 
    $a \texttt{**} b$ is interpreted as $a^b$.          
    Furthermore, if $e$ is an expression, then both $\exp(e)$ and $\ln(e)$ are expressions too.

The function call `diff(expr)` will then compute the derivative of `expr` with respect to the variable `x`.  For example, the function call 
`diff("x * exp(x)")` will compute an expression equivalent to
`1 * exp(x) + x * exp(x)` because we have:
$$ \frac{\mathrm{d}\;}{\mathrm{d}x} \bigl( x \cdot \mathrm{e}^x \bigr) = 1 \cdot \mathrm{e}^x + x \cdot \mathrm{e}^x. $$

This notebook implements *symbolic differentiation* using the parser generator `Lezer`.  The grammar of the language recognized by this parser is as follows:
$$
\begin{array}{lcl}
  \texttt{expr}    & \rightarrow & \;\texttt{expr}\; \texttt{'+'} \; \texttt{product}  \\
                   & \mid        & \;\texttt{expr}\; \texttt{'-'} \; \texttt{product}  \\
                   & \mid        & \;\texttt{product}                                  \\[0.2cm]
  \texttt{product} & \rightarrow & \;\texttt{product}\; \texttt{'*'} \;\texttt{factor} \\
                   & \mid        & \;\texttt{product}\; \texttt{'/'} \;\texttt{factor} \\
                   & \mid        & \;\texttt{factor}                                   \\[0.2cm]
  \texttt{factor}  & \rightarrow & \texttt{base} \;\texttt{'**'} \; \texttt{factor}    \\
                   & \mid        & \texttt{base}                                       \\[0.2cm]
  \texttt{base}    & \rightarrow & \texttt{'exp'} \; \texttt{'('} \; \texttt{expr} \;\texttt{')'}     \\
                   & \mid        & \texttt{'ln'} \; \texttt{'('} \; \texttt{expr} \;\texttt{')'}      \\
                   & \mid        & \texttt{'('} \; \texttt{expr} \;\texttt{')'}                     \\
                   & \mid        & \;\texttt{NUMBER}                                                \\
                   & \mid        & \;\texttt{'x'}                               
\end{array}
$$
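To see how the grammar's recursion maps onto parsing code, here is a minimal hand-written recursive-descent parser for the same grammar. This is a swapped-in technique for illustration, not the Lezer parser used below: because recursive descent cannot handle left recursion directly, the left-recursive rules for `expr` and `product` become `while` loops, while the right-recursive `factor` rule stays a recursive call. The helper names `tokenize` and `parse` are hypothetical, and the nested-array AST encoding (e.g. `["+", left, right]`, `["exp", arg]`) is an assumption chosen to match what `diff` consumes later in this notebook.

```typescript
// A minimal recursive-descent parser for the grammar above (a sketch, not the Lezer parser).
// Assumed AST encoding: numbers, variable names, and [op, ...operands] arrays.
type Ast = number | string | [string, Ast] | [string, Ast, Ast];

// Split the input into NUMBER, 'exp', 'ln', 'x', '**', and single-character tokens.
function tokenize(src: string): string[] {
    const tokenRe = /^\s*(0|[1-9][0-9]*|exp|ln|x|\*\*|[-+*\/()])/;
    const tokens: string[] = [];
    let rest = src;
    while (rest.trim() !== "") {
        const m = tokenRe.exec(rest);
        if (m === null) throw new Error(`Unexpected input: "${rest.trim()}"`);
        tokens.push(m[1]);
        rest = rest.slice(m[0].length);
    }
    return tokens;
}

function parse(src: string): Ast {
    const tokens = tokenize(src);
    let i = 0;
    const peek = (): string | undefined => tokens[i];
    const expect = (t: string): void => {
        if (tokens[i] !== t) throw new Error(`Expected "${t}" but found "${tokens[i] ?? "end of input"}"`);
        i++;
    };

    // expr -> expr ('+' | '-') product | product: left recursion becomes a loop
    function expr(): Ast {
        let left = product();
        while (peek() === "+" || peek() === "-") {
            const op = tokens[i++];
            left = [op, left, product()];
        }
        return left;
    }

    // product -> product ('*' | '/') factor | factor: same pattern one level down
    function product(): Ast {
        let left = factor();
        while (peek() === "*" || peek() === "/") {
            const op = tokens[i++];
            left = [op, left, factor()];
        }
        return left;
    }

    // factor -> base '**' factor | base: right recursion stays a recursive call
    function factor(): Ast {
        const b = base();
        if (peek() !== "**") return b;
        i++;
        return ["**", b, factor()];
    }

    // base -> 'exp' '(' expr ')' | 'ln' '(' expr ')' | '(' expr ')' | NUMBER | 'x'
    function base(): Ast {
        const t = peek();
        i++;
        if (t === "exp" || t === "ln") {
            expect("(");
            const arg = expr();
            expect(")");
            return [t, arg];
        }
        if (t === "(") {
            const inner = expr();
            expect(")");
            return inner;
        }
        if (t === "x") return "x";
        if (t !== undefined && /^[0-9]+$/.test(t)) return Number(t);
        throw new Error(`Unexpected token "${t ?? "end of input"}"`);
    }

    const ast = expr();
    if (i !== tokens.length) throw new Error(`Unexpected token "${tokens[i]}"`);
    return ast;
}

console.log(JSON.stringify(parse("10 - 5 - 2")));  // ["-",["-",10,5],2]
console.log(JSON.stringify(parse("2 ** 3 ** 2"))); // ["**",2,["**",3,2]]
```

The loop in `expr` builds the tree bottom-left first, which is exactly the grouping $(10 - 5) - 2$ that the left-recursive rule describes.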

In [None]:
import { buildParser } from '@lezer/generator';
import { TreeCursor } from '@lezer/common';

## Grammar Definition (Lezer)

To ensure the parser behaves exactly as our mathematical model predicts, we translate the formal grammar above **nearly 1:1** into Lezer's syntax. This declarative approach lets us encode precedence and associativity directly in the structure of the rules.

### Key Structural Features:

* **Direct Mapping:** Every rule from the formal definition (Expression, Product, Factor, Base) is explicitly represented.
* **Left Recursion (Left-Associativity):** The rules for `Expr` and `Product` recurse on the left (e.g., `Expr { Expr "+" Product | ... }`). This ensures that an expression like $10 - 5 - 2$ is correctly grouped as $(10 - 5) - 2$.
* **Right Recursion (Right-Associativity):** The `Factor` rule for the power operator recurses on the right (`Base "**" Factor`). This matters because $a^{b^c}$ is conventionally interpreted as $a^{(b^c)}$, not as $(a^b)^c$.
* **Precedence via Hierarchy:** By nesting the rules (Expression $\rightarrow$ Product $\rightarrow$ Factor $\rightarrow$ Base), we naturally enforce that multiplication happens before addition, and exponentiation happens before multiplication.
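The effect of the two recursion directions can be checked numerically. The sketch below is illustrative only: it models left-associativity as a left fold (`reduce`) and right-associativity as a right fold (`reduceRight`), and contrasts each result with the opposite grouping. TypeScript's own `-` and `**` operators happen to follow the same conventions, so `10 - 5 - 2` evaluates to `3` and `2 ** 3 ** 2` to `512`.

```typescript
// Left-associative grouping of 10 - 5 - 2 as (10 - 5) - 2: a left fold over the operands.
const leftGrouped = [5, 2].reduce((acc, v) => acc - v, 10);

// Right-associative grouping of 2 ** 3 ** 2 as 2 ** (3 ** 2): a right fold over the operands.
const rightGrouped = [2, 3, 2].reduceRight((acc, v) => v ** acc);

console.log(leftGrouped, 10 - (5 - 2));   // 3 7
console.log(rightGrouped, (2 ** 3) ** 2); // 512 64
```

Choosing the wrong recursion direction in the grammar would silently produce the second number in each line.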

In [None]:
const grammarDefinition = `
    @top Expression { Expr }

    @tokens {
        Number { "0" | $[1-9] $[0-9]* }
        "x" "exp" "ln" "**"
        "+" "-" "*" "/"
        space { $[ \t\n\r]+ }
        "(" ")"
    }

    Expr {
        Expr "+" Product |
        Expr "-" Product |
        Product
    }

    Product {
        Product "*" Factor |
        Product "/" Factor |
        Factor
    }

    Factor {
        Base "**" Factor |
        Base
    }

    Base {
        "exp" "(" Expr ")" |
        "ln"  "(" Expr ")" |
        "(" Expr ")"       |
        Number             |
        "x"
    }

    @skip { space }
`;

const parser = buildParser(grammarDefinition);

## Transformation to AST

The parser generates a Concrete Syntax Tree (CST). To transform this into our strictly typed AST, we use the generic `cst2ast` mapper from the local `./CST2AST` module. 

Since our Lezer grammar is **strictly recursive**, the structure of the parse tree already matches the intended hierarchy of our AST. We can therefore define the transformation rules declaratively without the need for manual loops or complex folding logic:

- **Leaf Transformation:** `Number` tokens become JavaScript numbers, and the variable `x` becomes the string `"x"`.
- **Recursive Binary Mapping:** Because the grammar is recursive, every binary `Expr`, `Product`, or `Factor` node consists of exactly three children: `[Left, Operator, Right]`. We simply destructure this array to create the AST node `[op, left, right]`. The grammar's recursion direction (left for `+`, right for `**`) automatically ensures the correct associativity.
- **Function Calls & Parentheses:** `Base` nodes either unwrap a parenthesized expression or map `exp(...)` and `ln(...)` to the AST nodes `["exp", arg]` and `["ln", arg]`.
- **Error Handling:** We treat the syntax error node `⚠` as a regular rule that throws a meaningful exception, ensuring that invalid mathematical expressions are caught during the transformation phase.
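The mapping rules above can be sketched on a simplified CST. The `CstNode` interface and the `toAst`, `mkLeaf`, and `mkNode` helpers below are hypothetical illustrations, not the API of `@lezer/common` or of the local `./CST2AST` module. The sketch assumes that every token, including operators and parentheses, appears as a child node, and it produces the nested-array AST that `diff` consumes.

```typescript
// Hypothetical, simplified CST node (not Lezer's TreeCursor API).
interface CstNode {
    name: string;        // e.g. "Expr", "Product", "Factor", "Base", "Number", "x", "+", "⚠"
    text: string;        // source text covered by this node
    children: CstNode[];
}

// Assumed AST encoding: numbers, variable names, and [op, ...operands] arrays.
type Ast = number | string | [string, Ast] | [string, Ast, Ast];

function toAst(n: CstNode): Ast {
    const kids = n.children;
    switch (n.name) {
        case "⚠": // Lezer's name for error nodes: fail loudly
            throw new Error(`Syntax error near "${n.text}"`);
        case "Number":
            return Number(n.text);
        case "x":
            return "x";
        case "Expression": // the @top rule just wraps an Expr
            return toAst(kids[0]);
        case "Expr":
        case "Product":
        case "Factor":
            if (kids.length === 1) return toAst(kids[0]); // unit production, e.g. Expr -> Product
            if (kids.length === 3) {
                const [left, op, right] = kids;             // [Left, Operator, Right]
                return [op.name, toAst(left), toAst(right)];
            }
            break;
        case "Base":
            if (kids.length === 1) return toAst(kids[0]);                  // Number or x
            if (kids.length === 3) return toAst(kids[1]);                  // "(" Expr ")"
            if (kids.length === 4) return [kids[0].name, toAst(kids[2])]; // exp/ln "(" Expr ")"
            break;
    }
    throw new Error(`Unexpected CST node "${n.name}" with ${kids.length} children`);
}

// Small constructors for building CSTs by hand.
const mkLeaf = (name: string, text: string = name): CstNode => ({ name, text, children: [] });
const mkNode = (name: string, ...children: CstNode[]): CstNode =>
    ({ name, text: children.map(c => c.text).join(" "), children });

// Hand-built CST for "x * exp(x)".
const inner = mkNode("Expr", mkNode("Product", mkNode("Factor", mkNode("Base", mkLeaf("x")))));
const cst = mkNode("Expression", mkNode("Expr", mkNode("Product",
    mkNode("Product", mkNode("Factor", mkNode("Base", mkLeaf("x")))),
    mkLeaf("*"),
    mkNode("Factor", mkNode("Base", mkLeaf("exp"), mkLeaf("("), inner, mkLeaf(")"))))));

console.log(JSON.stringify(toAst(cst))); // ["*","x",["exp","x"]]
```

Unit productions such as `Expr -> Product -> Factor -> Base` simply pass through, so the resulting AST is much flatter than the CST.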

In [None]:
const operators = ["+", "-", "*", "/", "**", "exp", "ln"];
const listVars: string[] = [];

In [None]:
import { AST, cst2ast } from "./CST2AST"

In [None]:
const input : string = "ln(exp(x)) / exp(x/x)"
const tree = parser.parse(input);
const ast = cst2ast(tree.cursor(), input, operators, listVars);
console.dir(ast, {depth: null});

## Symbolic Differentiation

Now that we have a structured AST, we can implement symbolic differentiation. We define a function `diff(node)` that takes an AST expression representing $f(x)$ and returns a new AST representing $f'(x)$.

### Implementation Logic

The logic is implemented via pattern matching on the AST node types and their operators. We apply standard differentiation rules recursively:

1. **Sum & Difference Rules:** The `+` and `-` cases implement the linearity of the derivative:
   $$ \frac{\mathrm{d}\;}{\mathrm{d}x}\bigl(f(x) \pm g(x)\bigr) = \frac{\mathrm{d}\;}{\mathrm{d}x} f(x) \pm \frac{\mathrm{d}\;}{\mathrm{d}x} g(x) $$

2. **Product Rule:** When the node is a product (`*`), we apply the [Product Rule](https://en.wikipedia.org/wiki/Product_rule):
   $$ \frac{\mathrm{d}\;}{\mathrm{d}x}\bigl(f(x) \cdot g(x)\bigr) = \left(\frac{\mathrm{d}\;}{\mathrm{d}x} f(x)\right)\cdot g(x) + f(x) \cdot \left(\frac{\mathrm{d}\;}{\mathrm{d}x} g(x)\right) $$

3. **Quotient Rule:** For division (`/`), the implementation follows the [Quotient Rule](https://en.wikipedia.org/wiki/Quotient_rule):
   $$ \frac{\mathrm{d}\;}{\mathrm{d}x}\left(\frac{f(x)}{g(x)}\right) = \frac{\displaystyle\left(\frac{\mathrm{d}\;}{\mathrm{d}x} f(x)\right)\cdot g(x) - f(x) \cdot \left(\frac{\mathrm{d}\;}{\mathrm{d}x} g(x)\right)}{g(x) \cdot g(x)} $$

4. **Power Rule (General):** To differentiate an expression of the form $f(x)^{g(x)}$, we first rewrite it using the identity (valid where $f(x) > 0$):
   $$ f(x)^{g(x)} = \exp\bigl(\ln\bigl(f(x)^{g(x)}\bigr)\bigr) = \exp\bigl(g(x) \cdot \ln(f(x))\bigr) $$
   Then, we recursively call `diff` on this rewritten expression. This works because our function can already handle `exp` and `ln`.

5. **Logarithm (Chain Rule):** For `ln(f(x))`, the derivative is computed as:
   $$ \frac{\mathrm{d}\;}{\mathrm{d}x} \ln\bigl(f(x)\bigr) = \frac{\frac{\mathrm{d}\;}{\mathrm{d}x} f(x)}{f(x)} $$

6. **Exponential (Chain Rule):** For `exp(f(x))`, the derivative is computed as:
   
   $$\frac{\mathrm{d}\;}{\mathrm{d}x} \exp\bigl(f(x)\bigr) = \left(\frac{\mathrm{d}\;}{\mathrm{d}x} f(x)\right) \cdot \exp\bigl(f(x)\bigr)$$

7. **Base Cases:**
   - **Variable `x`:** Returns `1` (since $\frac{\mathrm{d}x}{\mathrm{d}x} = 1$).
   - **Numbers & Other Variables:** Assumed to be constants relative to $x$, returning `0`.
   
8. **Error Handling:** Finally, the function throws an error if it encounters a node type that is not handled above. Although our parser currently only generates valid arithmetic expressions, the `AST` type definition includes nodes from other contexts (like `AssignNode` from the calculator). This fallback catches such nodes at runtime and helps debug invalid, manually constructed ASTs.
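As a worked instance of these rules, consider $x^x$ for $x > 0$. Rule 4 rewrites it as $\exp\bigl(x \cdot \ln(x)\bigr)$, rule 6 differentiates the exponential, rule 2 the product, and rule 5 together with the base case for `x` the logarithm:

$$
\begin{array}{lcl}
  \frac{\mathrm{d}\;}{\mathrm{d}x}\, x^x
    & = & \frac{\mathrm{d}\;}{\mathrm{d}x} \exp\bigl(x \cdot \ln(x)\bigr) \\[0.2cm]
    & = & \left(\frac{\mathrm{d}\;}{\mathrm{d}x} \bigl(x \cdot \ln(x)\bigr)\right) \cdot \exp\bigl(x \cdot \ln(x)\bigr) \\[0.2cm]
    & = & \left(1 \cdot \ln(x) + x \cdot \frac{1}{x}\right) \cdot \exp\bigl(x \cdot \ln(x)\bigr) \\[0.2cm]
    & = & \bigl(\ln(x) + 1\bigr) \cdot x^x
\end{array}
$$

The third line is what `diff` produces for `x ** x`, written in mathematical notation; the last step is a manual simplification, since `diff` does not simplify its result.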

In [None]:
function diff(node: AST): AST {
    if (typeof node === 'number') return 0;
    if (typeof node === 'string') return node === "x" ? 1 : 0;

    if (Array.isArray(node)) {
        const op = node[0];
        const f  = node[1];
        const g  = node[2];

        if (f === undefined) throw new Error(`Invalid AST: Missing operand for ${op}`);
        const df = diff(f);
        
        if (op === "ln")  return ["/", df, f];          // f' / f
        if (op === "exp") return ["*", df, ["exp", f]]; // f' * exp(f)

        if (g === undefined) throw new Error(`Invalid AST: Missing second operand for ${op}`);
        const dg = diff(g);
        switch (op) {
            case "+": return ["+", df, dg];                                         // f' + g'
            case "-": return ["-", df, dg];                                         // f' - g'
            case "*": return ["+", ["*", df, g], ["*", f, dg]];                     // f'g + fg'
            case "/": return ["/", ["-", ["*", df, g], ["*", f, dg]], ["*", g, g]]; // (f'g - fg') / g^2
            case "**": return diff(["exp", ["*", g, ["ln", f]]]);                   // d/dx exp(g * ln(f))
        }
    }
    throw new Error(`Diff not implemented for node: ${JSON.stringify(node)}`);
}

## Pretty Printing

To present the results in a human-readable mathematical format, we convert the AST back into a string using a `prettyPrint` function.

We utilize a **`Map`** to store operator precedence.

### Logic
1.  **Precedence:** We assign a priority level to every operator (e.g., `*` > `+`). If a child node has lower precedence than its parent, it must be wrapped (e.g., `(a + b) * c`).
2.  **Associativity:** When precedence is equal, the position matters:
    * **Left-associative** operators (`-`, `/`) require parentheses on the **right** (e.g., `a - (b - c)`).
    * **Right-associative** operators (`**`) require parentheses on the **left** (e.g., `(a ** b) ** c`).
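The two rules can be condensed into a single predicate. This is a minimal sketch mirroring the logic of `prettyPrint` below; the name `needsParens`, the constant `ATOM`, and the convention of passing `null` for an atom are assumptions made for illustration.

```typescript
// Precedence table as in the notebook; atoms and function calls bind tightest.
const PREC = new Map<string, number>([["+", 1], ["-", 1], ["*", 2], ["/", 2], ["**", 3]]);
const ATOM = 4;

// Does a child with operator `childOp` (null for an atom) need parentheses under `parentOp`?
function needsParens(parentOp: string, childOp: string | null, isRight: boolean): boolean {
    const p = PREC.get(parentOp) ?? 0;
    // Operators missing from the table (exp, ln) are treated like atoms.
    const cp = childOp === null ? ATOM : (PREC.get(childOp) ?? ATOM);
    if (cp !== p) return cp < p;               // weaker child: wrap; stronger child: never wrap
    return isRight
        ? parentOp === "-" || parentOp === "/" // left-associative: wrap an equal-precedence right operand
        : parentOp === "**";                   // right-associative: wrap an equal-precedence left operand
}

console.log(needsParens("-", "-", true));    // true:  a - (b - c)
console.log(needsParens("-", "-", false));   // false: a - b - c
console.log(needsParens("**", "**", false)); // true:  (a ** b) ** c
console.log(needsParens("**", "**", true));  // false: a ** b ** c
console.log(needsParens("*", "+", false));   // true:  (a + b) * c
```

Note that `+` and `*` never wrap an equal-precedence right operand, so `a + (b - c)` prints as `a + b - c`; this is harmless because both groupings have the same value.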


In [None]:
const PRECEDENCE = new Map<string, number>([["+", 1], ["-", 1], ["*", 2], ["/", 2], ["**", 3]]);

function prettyPrint(node: AST): string {
    // 1. Atoms (Numbers & Variables)
    if (!Array.isArray(node)) return String(node);
    const [op, l, r] = node;
    // 2. Functions (exp, ln)
    if (["exp", "ln"].includes(op)) return `${op}(${prettyPrint(l)})`;
    // 3. Binary Operators
    if (r !== undefined) {
        const p    = PRECEDENCE.get(op) ?? 0; // Default to 0 (weakest) if operator not found
        const wrap = (child: AST, isRight: boolean) => {
            let cp = 4; // Default to 4 (strongest) for atoms/functions, or lookup operator precedence
            if (Array.isArray(child)) cp = PRECEDENCE.get(child[0]) ?? 4;
            const needParens = cp < p || (cp === p && (isRight ? ["-", "/"].includes(op) : op === "**"));
            return needParens ? `(${prettyPrint(child)})` : prettyPrint(child);
        };
        return `${wrap(l, false)} ${op} ${wrap(r, true)}`;
    }
    return String(node);
}

## Tests

We check the implementation on two inputs: `x ** x`, which exercises the general power rule, and `ln(exp(x)) / exp(x/x)`, which combines the quotient rule with the chain rule for `ln` and `exp`. The printed derivatives are not simplified.
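Beyond reading the printed output, one way to gain confidence is to compare the symbolic derivative with a central finite difference $\frac{F(x+h) - F(x-h)}{2h}$ at a point where every subexpression is defined. This is a self-contained sketch: the nested-array `Ast` type, the hypothetical helpers `evaluate` and `checkDerivative`, and the copy `d` of the differentiation rules are assumptions introduced so the block runs without the Lezer parser; the step size `h` and tolerance `tol` are ad-hoc choices.

```typescript
// Sketch: compare symbolic derivatives against a central finite difference.
// Assumed AST encoding: numbers, variable names, and [op, ...operands] arrays.
type Ast = number | string | [string, Ast] | [string, Ast, Ast];

// Evaluate an AST at a numeric value of x (other variables are rejected).
function evaluate(e: Ast, x: number): number {
    if (typeof e === "number") return e;
    if (typeof e === "string") {
        if (e === "x") return x;
        throw new Error(`Unbound variable ${e}`);
    }
    if (e.length === 2) {
        const v = evaluate(e[1], x);
        if (e[0] === "exp") return Math.exp(v);
        if (e[0] === "ln") return Math.log(v);
        throw new Error(`Unknown function ${e[0]}`);
    }
    const a = evaluate(e[1], x);
    const b = evaluate(e[2], x);
    switch (e[0]) {
        case "+": return a + b;
        case "-": return a - b;
        case "*": return a * b;
        case "/": return a / b;
        case "**": return a ** b;
    }
    throw new Error(`Unknown operator ${e[0]}`);
}

// The same differentiation rules as `diff`, restated so this block runs on its own.
function d(e: Ast): Ast {
    if (typeof e === "number") return 0;
    if (typeof e === "string") return e === "x" ? 1 : 0;
    const f = e[1];
    if (e.length === 2) {
        if (e[0] === "ln") return ["/", d(f), f];  // f' / f
        if (e[0] === "exp") return ["*", d(f), e]; // f' * exp(f)
        throw new Error(`Unknown function ${e[0]}`);
    }
    const g = e[2];
    switch (e[0]) {
        case "+": return ["+", d(f), d(g)];
        case "-": return ["-", d(f), d(g)];
        case "*": return ["+", ["*", d(f), g], ["*", f, d(g)]];
        case "/": return ["/", ["-", ["*", d(f), g], ["*", f, d(g)]], ["*", g, g]];
        case "**": return d(["exp", ["*", g, ["ln", f]]]);
    }
    throw new Error(`Unknown operator ${e[0]}`);
}

// True if d(e) at x agrees with (F(x + h) - F(x - h)) / (2h) up to a relative tolerance.
function checkDerivative(e: Ast, x: number, h = 1e-5, tol = 1e-6): boolean {
    const symbolic = evaluate(d(e), x);
    const numeric = (evaluate(e, x + h) - evaluate(e, x - h)) / (2 * h);
    return Math.abs(symbolic - numeric) <= tol * Math.max(1, Math.abs(symbolic));
}

console.log(checkDerivative(["**", "x", "x"], 1.5)); // true
```

The check point must keep every `ln` argument positive, since the power rule relies on $f(x) > 0$.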

In [None]:
function test(input: string) {
    try {
        console.log(`\nInput:  ${input}`);
        const tree = parser.parse(input);
        const ast = cst2ast(tree.cursor(), input, operators, listVars);
        const derived = diff(ast);
        console.log(`d/dx :  ${prettyPrint(derived)}`);
    } catch (e) {
        console.error(`Error: ${(e as Error).message}`);
    }
}

In [None]:
test("x ** x")

In [None]:
test("ln(exp(x)) / exp(x/x)")