In [None]:
import { display } from "tslab";
import { readFileSync } from "fs";

const css = readFileSync("../style.css", "utf8");
display.html(`<style>${css}</style>`);

## Imports and Setup

In [None]:
import { buildParser } from '@lezer/generator';
import { Tree, TreeCursor } from '@lezer/common';
import { LRParser } from '@lezer/lr';
import { RecursiveSet, Tuple } from "recursive-set";

In [None]:
import { 
    AST, NumNode, VarNode, BinaryExpr, CallNode, 
    Operator, cstToAST, ParserConfig, ast2dot 
} from "./AST2Dot";

# Symbolic Differentiation

In this notebook, our goal is to implement **symbolic differentiation**. Concretely, we want to implement a function `diff` that takes one argument:

- The argument `expr` represents an *arithmetic expression*.

Here, an arithmetic expression is any string built from the variable `x` and non-negative integers by applying the operator symbols `+`, `-`, `*`, `/`, and `**`.
The operator `**` represents exponentiation, i.e., an expression of the form $a \texttt{**} b$ is interpreted as $a^b$.
Furthermore, if $e$ is an expression, then both $\exp(e)$ and $\ln(e)$ are expressions too.

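The function call `diff(expr)` will then compute the derivative of `expr` with respect to the variable `x`. In the implementation below, `diff` operates on the AST returned by `parseMath`, so for the expression `x * exp(x)` the call is `diff(parseMath("x * exp(x)"))`. The result is equivalent to
`1 * exp(x) + x * exp(x)`;
because no simplification is performed, the pretty-printed output actually reads `1 * exp(x) + x * 1 * exp(x)`.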

This is because the product rule gives us:
$$ \frac{\mathrm{d}\;}{\mathrm{d}x} \bigl( x \cdot \mathrm{e}^x \bigr) = 1 \cdot \mathrm{e}^x + x \cdot \mathrm{e}^x. $$

This implementation uses **TypeScript** and **RecursiveSet** for value semantics (ensuring that structurally identical expressions are treated as equal). The parsing is handled by the **Lezer** parser generator.

The grammar for the language implemented by this parser is as follows:

$$
\begin{array}{lcl}
 \texttt{expr} & \rightarrow & \;\texttt{expr}\; \texttt{'+'} \; \texttt{product} \\
 & \mid & \;\texttt{expr}\; \texttt{'-'} \; \texttt{product} \\
 & \mid & \;\texttt{product} \\[0.2cm]
 \texttt{product} & \rightarrow & \;\texttt{product}\; \texttt{'*'} \;\texttt{factor} \\
 & \mid & \;\texttt{product}\; \texttt{'/'} \;\texttt{factor} \\
 & \mid & \;\texttt{factor} \\[0.2cm]
 \texttt{factor} & \rightarrow & \texttt{base} \;\texttt{'**'} \; \texttt{factor} \\
 & \mid & \texttt{base} \\[0.2cm]
 \texttt{base} & \rightarrow & \texttt{exp} \; \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \texttt{ln} \; \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \;\texttt{NUMBER} \\
 & \mid & \;\texttt{'x'} 
\end{array}
$$
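
As an illustration of how the nesting of the rules enforces precedence, here is a derivation of `x + 2 * x ** 2` in this grammar, where $\Rightarrow^{*}$ abbreviates several derivation steps:

$$
\begin{array}{lcl}
\texttt{expr} & \Rightarrow & \texttt{expr}\;\texttt{'+'}\;\texttt{product} \;\Rightarrow^{*}\; \texttt{'x'}\;\texttt{'+'}\;\texttt{product} \\
 & \Rightarrow & \texttt{'x'}\;\texttt{'+'}\;\texttt{product}\;\texttt{'*'}\;\texttt{factor} \;\Rightarrow^{*}\; \texttt{'x'}\;\texttt{'+'}\;\texttt{NUMBER}\;\texttt{'*'}\;\texttt{factor} \\
 & \Rightarrow & \texttt{'x'}\;\texttt{'+'}\;\texttt{NUMBER}\;\texttt{'*'}\;\texttt{base}\;\texttt{'**'}\;\texttt{factor} \;\Rightarrow^{*}\; \texttt{'x'}\;\texttt{'+'}\;\texttt{NUMBER}\;\texttt{'*'}\;\texttt{'x'}\;\texttt{'**'}\;\texttt{NUMBER}
\end{array}
$$

Because `'**'` is only introduced inside `factor`, which sits below the `'*'` of `product`, which in turn sits below the `'+'` of `expr`, the expression is grouped as $x + \bigl(2 \cdot x^{2}\bigr)$.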

***

## Grammar Definition (Lezer)

To ensure the parser behaves exactly as our mathematical model predicts, we translate our formal EBNF grammar **nearly 1:1** into Lezer's syntax. This declarative approach allows us to define precedence and associativity directly through the structure of the rules.



### Key Structural Features:

* **Direct Mapping:** Every rule from the formal definition (Expression, Product, Factor, Base) is explicitly represented.
* **Left Recursion (Left-Associativity):** Rules for `expr` and `product` recurse on the left (e.g., `expr { BinaryExpr { expr Plus product } ... }`). This ensures that an expression like $10 - 5 - 2$ is correctly grouped as $(10 - 5) - 2$.
* **Right Recursion (Right-Associativity):** The `factor` rule for the power operator recurses on the right (`base Power factor`). This matters because $a^{b^c}$ is conventionally read as $a^{(b^c)}$, so `2 ** 3 ** 2` must be grouped as `2 ** (3 ** 2)`.
* **Precedence via Hierarchy:** By nesting the rules (Expression $\rightarrow$ Product $\rightarrow$ Factor $\rightarrow$ Base), we enforce that multiplication and division bind more tightly than addition and subtraction, and that exponentiation binds more tightly than multiplication and division.
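
The effect of the recursion direction can be reproduced without any parser. The sketch below uses two hypothetical helpers, `foldLeft` and `foldRight` (not part of the notebook), to combine the same operands in both groupings: a left fold corresponds to what a left-recursive rule such as `expr` produces, a right fold to what the right-recursive `factor` rule produces.

```typescript
// Hypothetical helpers illustrating associativity; the parser does not use them.
type BinOp = (a: number, b: number) => number;

// Left fold: ((x0 op x1) op x2) ..., the grouping of a left-recursive rule.
function foldLeft(xs: number[], op: BinOp): number {
    return xs.slice(1).reduce((acc, x) => op(acc, x), xs[0]);
}

// Right fold: x0 op (x1 op (x2 ...)), the grouping of a right-recursive rule.
function foldRight(xs: number[], op: BinOp): number {
    return xs.slice(0, -1).reduceRight((acc, x) => op(x, acc), xs[xs.length - 1]);
}

const minus: BinOp = (a, b) => a - b;
const power: BinOp = (a, b) => a ** b;

console.log(foldLeft([10, 5, 2], minus));  // (10 - 5) - 2 = 3   (correct for "-")
console.log(foldRight([10, 5, 2], minus)); // 10 - (5 - 2) = 7   (wrong for "-")
console.log(foldRight([2, 3, 2], power));  // 2 ** (3 ** 2) = 512 (correct for "**")
console.log(foldLeft([2, 3, 2], power));   // (2 ** 3) ** 2 = 64  (wrong for "**")
```

Choosing the wrong recursion direction in the grammar would silently produce the "wrong" column above, which is why `expr`/`product` recurse on the left and `factor` on the right.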

In [None]:
const grammarDefinition = `
    @top Expression { expr }

    expr {
        BinaryExpr { expr Plus product } |
        BinaryExpr { expr Minus product } |
        product
    }

    product {
        BinaryExpr { product Mul factor } |
        BinaryExpr { product Div factor } |
        factor
    }

    factor {
        BinaryExpr { base Power factor } |
        base
    }

    base {
        CallExp { Exp "(" expr ")" } |
        CallLn  { Ln  "(" expr ")" } |
        ParenExpr { "(" expr ")" } |
        Number |
        VarX
    }

    @tokens {
        Number { "0" | $[1-9] $[0-9]* }
        VarX { "x" }
        Exp { "exp" }
        Ln { "ln" }
        Power { "**" }
        Plus { "+" } Minus { "-" }
        Mul { "*" }  Div { "/" }
        space { $[ \t\n\r]+ }
        "(" ")"
    }
    @skip { space }
`;

const parser: LRParser = buildParser(grammarDefinition);

## Transformation to AST

The parser generates a Concrete Syntax Tree (CST). To transform it into our strictly typed AST, we use the generic `cstToAST` mapper.

Since our Lezer grammar expresses every operator level through recursion (never through repetition), the structure of the parse tree already matches the intended hierarchy of our AST. We can therefore define the transformation rules declaratively, without manual loops or complex folding logic:

- **Leaf Transformation:** Terminal tokens like `Number` or `VarX` are directly mapped to `NumNode` or `VarNode`.
- **Recursive Binary Mapping:** Because the grammar is recursive, every `BinaryExpr` node consists of exactly three children: `[Left, Operator, Right]`. We simply destructure this array to create our AST nodes. The grammar’s recursion direction (left for `+`, right for `**`) automatically ensures the correct associativity.
- **Function Calls & Parentheses:** The parenthesis tokens are listed in `ignore`, so a `ParenExpr` node simply returns its inner expression, while `CallExp` and `CallLn` are mapped to `CallNode` instances for `exp` and `ln`, respectively.
- **Error Handling:** We treat the syntax error node `⚠` as a regular rule that throws a meaningful exception, ensuring that invalid mathematical expressions are caught during the transformation phase.
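
The mapping strategy above can be illustrated with a toy, self-contained mapper. This is a sketch under stated assumptions, not the actual `cstToAST` from `./AST2Dot` (whose implementation is not shown here): the `CstNode` shape, the `Ast` union, and the convention that a node without a rule passes its single child through are all hypothetical, and the CST for `(x - 2) * x` is built by hand instead of by Lezer.

```typescript
// Hypothetical, simplified CST node; Lezer itself exposes trees through a TreeCursor.
interface CstNode {
    name: string;        // grammar node name, e.g. "BinaryExpr" or "("
    text: string;        // source text covered by the node
    children: CstNode[];
}

// Hypothetical AST; operator tokens are carried as "var" nodes, as in mathConfig.
type Ast =
    | { kind: "num"; value: number }
    | { kind: "var"; name: string }
    | { kind: "bin"; op: string; left: Ast; right: Ast }
    | { kind: "call"; fn: string; arg: Ast };

type Rule = (children: Ast[], text: string) => Ast;

interface MapperConfig {
    ignore: Set<string>;          // token names dropped before rules run
    rules: Record<string, Rule>;  // node name -> constructor for the AST node
}

// Bottom-up conversion: drop ignored children, convert the rest, then apply the node's rule.
function toAst(node: CstNode, config: MapperConfig): Ast {
    const children = node.children
        .filter(child => !config.ignore.has(child.name))
        .map(child => toAst(child, config));
    const rule: Rule | undefined = config.rules[node.name];
    if (rule !== undefined) return rule(children, node.text);
    // Assumption of this sketch: a rule-less wrapper (e.g. the top node) passes its only child through.
    if (children.length === 1) return children[0];
    throw new Error(`No rule for node '${node.name}'`);
}

const config: MapperConfig = {
    ignore: new Set(["(", ")"]),
    rules: {
        Number: (_, text) => ({ kind: "num", value: parseInt(text, 10) }),
        VarX: () => ({ kind: "var", name: "x" }),
        Minus: () => ({ kind: "var", name: "-" }),
        Mul: () => ({ kind: "var", name: "*" }),
        // The recursive grammar guarantees exactly three children: [left, operator, right].
        BinaryExpr: ([left, op, right]) => {
            if (op.kind !== "var") throw new Error("Expected operator in BinaryExpr");
            return { kind: "bin", op: op.name, left, right };
        },
        ParenExpr: ([inner]) => inner,
        "⚠": (_, text) => { throw new Error(`Syntax Error: Unexpected '${text}'`); },
    },
};

const leaf = (name: string, text: string): CstNode => ({ name, text, children: [] });

// Hand-built CST for "(x - 2) * x", including the parenthesis tokens that `ignore` filters out.
const cst: CstNode = {
    name: "Expression", text: "(x - 2) * x", children: [{
        name: "BinaryExpr", text: "(x - 2) * x", children: [
            {
                name: "ParenExpr", text: "(x - 2)", children: [
                    leaf("(", "("),
                    { name: "BinaryExpr", text: "x - 2", children: [leaf("VarX", "x"), leaf("Minus", "-"), leaf("Number", "2")] },
                    leaf(")", ")"),
                ],
            },
            leaf("Mul", "*"),
            leaf("VarX", "x"),
        ],
    }],
};

console.log(JSON.stringify(toAst(cst, config)));
```

Because children are converted before the parent's rule runs, each rule only ever sees finished AST nodes, which is what keeps the individual rules one-liners.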

In [None]:
const mathConfig: ParserConfig = {
    ignore: new Set(["(", ")"]),
    
    rules: {
        "Number": (_, text) => new NumNode(parseInt(text)),
        "VarX": () => new VarNode("x"),
        "Plus": () => new VarNode("+"), "Minus": () => new VarNode("-"),
        "Mul": () => new VarNode("*"),  "Div": () => new VarNode("/"),
        "Power": () => new VarNode("**"),
        "Exp": () => new VarNode("exp"),
        "Ln":  () => new VarNode("ln"),
        "BinaryExpr": ([left, op, right]) => {
            if (op instanceof VarNode) {
                return new BinaryExpr(left, op.name, right);
            }
            throw new Error("Expected operator in BinaryExpr");
        },

        "CallExp": ([_func, arg]) => new CallNode("exp", [arg]),
        "CallLn":  ([_func, arg]) => new CallNode("ln", [arg]),
        "ParenExpr": ([inner]) => inner,
        "⚠": (_, text) => { throw new Error(`Syntax Error: Unexpected '${text}'`); }
    }
};

function parseMath(input: string): AST {
    const tree = parser.parse(input);
    return cstToAST(tree.cursor(), input, mathConfig);
}

## Symbolic Differentiation

Now we implement the core function `diff` that takes an expression containing the variable `x` and computes its derivative.

### AST Construction Helpers

To keep the `diff` function readable and avoid writing `new` for every node, we define several small factory functions. These helpers allow us to construct the Abstract Syntax Tree (AST) using a notation that closely resembles standard arithmetic.

In [None]:
const Num = (n: number) => new NumNode(n);
const Add = (a: AST, b: AST) => new BinaryExpr(a, "+", b);
const Sub = (a: AST, b: AST) => new BinaryExpr(a, "-", b);
const Mul = (a: AST, b: AST) => new BinaryExpr(a, "*", b);
const Div = (a: AST, b: AST) => new BinaryExpr(a, "/", b);
const Ln  = (a: AST) => new CallNode("ln", [a]);
const Exp = (a: AST) => new CallNode("exp", [a]);

### Implementation Logic

The logic is implemented via pattern matching on the AST node types and their operators. We apply standard differentiation rules recursively:

1. **Sum & Difference Rules:** The `+` and `-` cases implement the linearity of the derivative:
   $$ \frac{\mathrm{d}\;}{\mathrm{d}x}\bigl(f(x) \pm g(x)\bigr) = \frac{\mathrm{d}\;}{\mathrm{d}x} f(x) \pm \frac{\mathrm{d}\;}{\mathrm{d}x} g(x) $$

2. **Product Rule:** When the node is a product (`*`), we apply the [Product Rule](https://en.wikipedia.org/wiki/Product_rule):
   $$ \frac{\mathrm{d}\;}{\mathrm{d}x}\bigl(f(x) \cdot g(x)\bigr) = \left(\frac{\mathrm{d}\;}{\mathrm{d}x} f(x)\right)\cdot g(x) + f(x) \cdot \left(\frac{\mathrm{d}\;}{\mathrm{d}x} g(x)\right) $$

3. **Quotient Rule:** For division (`/`), the implementation follows the [Quotient Rule](https://en.wikipedia.org/wiki/Quotient_rule):
   $$ \frac{\mathrm{d}\;}{\mathrm{d}x}\left(\frac{f(x)}{g(x)}\right) = \frac{\displaystyle\left(\frac{\mathrm{d}\;}{\mathrm{d}x} f(x)\right)\cdot g(x) - f(x) \cdot \left(\frac{\mathrm{d}\;}{\mathrm{d}x} g(x)\right)}{g(x) \cdot g(x)} $$

4. **Power Rule (General):** To differentiate an expression of the form $f(x)^{g(x)}$, we first rewrite it using the identity (valid wherever $f(x) > 0$):
   $$ f(x)^{g(x)} = \exp\bigl(\ln\bigl(f(x)^{g(x)}\bigr)\bigr) = \exp\bigl(g(x) \cdot \ln(f(x))\bigr) $$
   Then, we recursively call `diff` on this rewritten expression. This works because our function can already handle `exp` and `ln`.

5. **Logarithm (Chain Rule):** For `ln(f(x))`, the derivative is computed as:
   $$ \frac{\mathrm{d}\;}{\mathrm{d}x} \ln\bigl(f(x)\bigr) = \frac{\frac{\mathrm{d}\;}{\mathrm{d}x} f(x)}{f(x)} $$

6. **Exponential (Chain Rule):** For `exp(f(x))`, the derivative is computed as:
   
   $$\frac{\mathrm{d}\;}{\mathrm{d}x} \exp\bigl(f(x)\bigr) = \left(\frac{\mathrm{d}\;}{\mathrm{d}x} f(x)\right) \cdot \exp\bigl(f(x)\bigr)$$

7. **Base Cases:**
   - **Variable `x`:** Returns `1` (since $\frac{\mathrm{d}x}{\mathrm{d}x} = 1$).
   - **Numbers & Other Variables:** Assumed to be constants relative to $x$, returning `0`.
   
8. **Error Handling:** Finally, the function throws an error if it encounters a node type that is not handled above. Although our parser currently only generates valid arithmetic expressions, the `AST` type definition includes nodes from other contexts (like `AssignNode` from the calculator). This fallback ensures type safety and helps debug invalid manual AST constructions.
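
To see how the rules combine, here is a worked derivation for $x^x$, the input of the test `x ** x` below. Rule 4 rewrites the power, rule 6 handles the outer `exp`, rule 2 the product, and rule 5 the logarithm; the last line is a manual simplification, since `diff` itself does not simplify:

$$
\begin{aligned}
\frac{\mathrm{d}\;}{\mathrm{d}x}\, x^{x}
  &= \frac{\mathrm{d}\;}{\mathrm{d}x} \exp\bigl(x \cdot \ln(x)\bigr) \\
  &= \left(\frac{\mathrm{d}\;}{\mathrm{d}x}\bigl(x \cdot \ln(x)\bigr)\right) \cdot \exp\bigl(x \cdot \ln(x)\bigr) \\
  &= \left(1 \cdot \ln(x) + x \cdot \frac{1}{x}\right) \cdot \exp\bigl(x \cdot \ln(x)\bigr) \\
  &= \bigl(\ln(x) + 1\bigr) \cdot x^{x}.
\end{aligned}
$$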

In [None]:
function diff(e: AST): AST {
    if (e instanceof BinaryExpr) {
        const { left: f, right: g } = e; 
        const df = diff(f);
        const dg = diff(g);
        switch (e.op) {
            case "+": return Add(df, dg);                                   // sum rule
            case "-": return Sub(df, dg);                                   // difference rule
            case "*": return Add(Mul(df, g), Mul(f, dg));                   // product rule
            case "/": return Div(Sub(Mul(df, g), Mul(f, dg)), Mul(g, g));   // quotient rule
            case "**": return diff(Exp(Mul(g, Ln(f))));                     // f ** g = exp(g * ln(f))
        }
    }
    if (e instanceof CallNode) {
        const f = e.args[0];
        const df = diff(f);
        if (e.fn === "ln")  return Div(df, f);                              // (ln f)' = f' / f
        if (e.fn === "exp") return Mul(df, e);                              // (exp f)' = f' * exp(f)
    }
    if (e instanceof NumNode) return Num(0);                                // constants
    if (e instanceof VarNode) return e.name === "x" ? Num(1) : Num(0);      // dx/dx = 1; other variables are constants
    // Unsupported operators, functions, and node types end up here.
    throw new Error(`Diff not implemented for node: ${e.constructor.name}`);
}

## Pretty Printing

The default `toString()` method of our AST nodes is primarily designed for debugging the tree structure. To present the results of our differentiation in a human-readable format, we implement a `prettyPrint` function.

The main challenge in converting an AST back to a string is **minimizing parentheses**. We want to output `x * y + z` instead of `(x * y) + z`, but we must keep them in `x * (y + z)`.

### How it works:

1.  **Precedence Table:** We define a numerical priority for each operator (e.g., `**` is higher than `*`).
2.  **Recursive Wrapping:** The `wrap` helper function decides whether a child node needs parentheses based on two criteria:
    * **Precedence:** If the child's operator has a lower priority than the parent's operator (e.g., an addition inside a multiplication), we **must** wrap it.
    * **Associativity:** 
        * Operators like `-` and `/` are **left-associative**. If the same precedence appears on the right side, we wrap it (e.g., `a - (b - c)`).
        * The power operator `**` is **right-associative**. If the same precedence appears on the left side, we wrap it (e.g., `(a ** b) ** c`).
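
The wrapping decision can also be stated as a standalone predicate. The following is a hypothetical restatement of the `wrap` logic (the names `needsParens`, `Op`, and `PREC` are illustrative); it takes operator symbols instead of AST nodes and models atoms (numbers, variables, calls) as `null`, so that it runs on its own.

```typescript
type Op = "+" | "-" | "*" | "/" | "**";

// Same priorities as the `prec` table in prettyPrint.
const PREC: Record<Op, number> = { "+": 1, "-": 1, "*": 2, "/": 2, "**": 3 };

// Should a child with operator `child` be parenthesized under a parent with operator `parent`?
function needsParens(parent: Op, child: Op | null, isRight: boolean): boolean {
    if (child === null) return false;          // atoms bind tightest, never wrapped
    const p = PREC[parent];
    const c = PREC[child];
    if (c < p) return true;                    // looser child, e.g. x * (y + z)
    if (c === p) {
        // Left-associative - and /: wrap an equal-precedence right operand.
        if ((parent === "-" || parent === "/") && isRight) return true;
        // Right-associative **: wrap an equal-precedence left operand.
        if (parent === "**" && !isRight) return true;
    }
    return false;
}

console.log(needsParens("*", "+", true));    // true:  x * (y + z)
console.log(needsParens("+", "*", false));   // false: x * y + z
console.log(needsParens("-", "-", true));    // true:  a - (b - c)
console.log(needsParens("**", "**", false)); // true:  (a ** b) ** c
console.log(needsParens("**", "**", true));  // false: a ** b ** c
```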

In [None]:
function prettyPrint(node: AST): string {
    if (node instanceof NumNode) return node.value.toString();
    if (node instanceof VarNode) return node.name;
    
    if (node instanceof CallNode) {
        const args = node.args.map(prettyPrint).join(", ");
        return `${node.fn}(${args})`;
    }
    
    if (node instanceof BinaryExpr) {
        const prec: Record<string, number> = { 
            "+": 1, "-": 1, 
            "*": 2, "/": 2, 
            "**": 3 
        };
        const opPrec = prec[node.op] ?? 0;
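        const wrap = (child: AST, isRight: boolean) => {
            // Atoms (numbers, variables, calls) get precedence 4 and are never wrapped.
            let childPrec = 4;
            if (child instanceof BinaryExpr) childPrec = prec[child.op] ?? 0;
            // A child that binds more loosely than its parent always needs parentheses.
            if (childPrec < opPrec) return `(${prettyPrint(child)})`;
            if (childPrec === opPrec) {
                // Left-associative - and /: wrap an equal-precedence right operand.
                if ((node.op === "-" || node.op === "/") && isRight) return `(${prettyPrint(child)})`;
                // Right-associative **: wrap an equal-precedence left operand.
                if (node.op === "**" && !isRight) return `(${prettyPrint(child)})`;
            }
            return prettyPrint(child);
        };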
        return `${wrap(node.left, false)} ${node.op} ${wrap(node.right, true)}`;
    }
    return node.toString();
}

## Tests

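We try the implementation on two test cases: the variable power `x ** x`, which exercises the general power rule, and a composition of `ln`, `exp`, and division that requires the chain and quotient rules. The `test` helper parses the input, differentiates it, and pretty-prints the result; parse and differentiation errors are reported instead of aborting the cell.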
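
Since `diff` does not simplify, its output is hard to check by eye. One complementary check, an addition not present in the original notebook, is to evaluate the symbolic derivative at a point and compare it with a central finite difference $\bigl(f(x+h) - f(x-h)\bigr) / (2h)$. The sketch below is self-contained: it re-declares a minimal AST and a copy of the differentiation rules (`Expr`, `num`, `bin`, `call`, `d`, `evaluate`, and `checkAt` are all hypothetical names) instead of reusing the `AST2Dot` classes, whose evaluation API is not shown here.

```typescript
// Hypothetical minimal AST, re-declared so this sketch runs on its own.
type BinOp = "+" | "-" | "*" | "/" | "**";
type Expr =
    | { kind: "num"; value: number }
    | { kind: "x" }
    | { kind: "bin"; op: BinOp; left: Expr; right: Expr }
    | { kind: "call"; fn: "exp" | "ln"; arg: Expr };

const num = (value: number): Expr => ({ kind: "num", value });
const X: Expr = { kind: "x" };
const bin = (op: BinOp, left: Expr, right: Expr): Expr => ({ kind: "bin", op, left, right });
const call = (fn: "exp" | "ln", arg: Expr): Expr => ({ kind: "call", fn, arg });

// The same rules as `diff`, applied to this simplified representation.
function d(e: Expr): Expr {
    if (e.kind === "num") return num(0);
    if (e.kind === "x") return num(1);
    if (e.kind === "call") {
        return e.fn === "ln"
            ? bin("/", d(e.arg), e.arg)   // (ln f)' = f' / f
            : bin("*", d(e.arg), e);      // (exp f)' = f' * exp(f)
    }
    const { op, left: f, right: g } = e;
    if (op === "**") return d(call("exp", bin("*", g, call("ln", f)))); // f ** g = exp(g * ln(f))
    const df = d(f);
    const dg = d(g);
    if (op === "+" || op === "-") return bin(op, df, dg);
    if (op === "*") return bin("+", bin("*", df, g), bin("*", f, dg));
    return bin("/", bin("-", bin("*", df, g), bin("*", f, dg)), bin("*", g, g));
}

// Numeric evaluation of an expression at a given value of x.
function evaluate(e: Expr, x: number): number {
    if (e.kind === "num") return e.value;
    if (e.kind === "x") return x;
    if (e.kind === "call") {
        const v = evaluate(e.arg, x);
        return e.fn === "ln" ? Math.log(v) : Math.exp(v);
    }
    const a = evaluate(e.left, x);
    const b = evaluate(e.right, x);
    if (e.op === "+") return a + b;
    if (e.op === "-") return a - b;
    if (e.op === "*") return a * b;
    if (e.op === "/") return a / b;
    return a ** b;
}

// Compare the symbolic derivative with a central difference, using a relative tolerance.
function checkAt(e: Expr, x: number, h = 1e-5, tol = 1e-6): boolean {
    const symbolic = evaluate(d(e), x);
    const numeric = (evaluate(e, x + h) - evaluate(e, x - h)) / (2 * h);
    return Math.abs(symbolic - numeric) <= tol * Math.max(1, Math.abs(symbolic));
}

// The two notebook test inputs, plus the introductory example x * exp(x).
const cases: Array<[string, Expr]> = [
    ["x ** x", bin("**", X, X)],
    ["ln(exp(x)) / exp(x/x)", bin("/", call("ln", call("exp", X)), call("exp", bin("/", X, X)))],
    ["x * exp(x)", bin("*", X, call("exp", X))],
];
for (const [label, expr] of cases) {
    console.log(`${label}: derivative agrees at x = 1.5? ${checkAt(expr, 1.5)}`);
}
```

The check point must lie inside the domain of every `ln` in the expression (here $x > 0$), for the same reason that the power-rule rewrite assumes $f(x) > 0$.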

In [None]:
function test(s: string) {
    try {
        const ast = parseMath(s);
        const d = diff(ast);
        console.log(`d/dx ${s} = ${prettyPrint(d)}`);
    } catch (e) {
        console.log(`Error processing "${s}": ${e instanceof Error ? e.message : e}`);
    }
}

In [None]:
test("x ** x")

In [None]:
test("ln(exp(x)) / exp(x/x)")