In [None]:
import { display } from "tslab";
import { readFileSync } from "fs";

const css = readFileSync("../style.css", "utf8");
display.html(`<style>${css}</style>`);

## Imports and Setup

In [None]:
import { buildParser } from '@lezer/generator';
import { Tree, TreeCursor } from '@lezer/common';
import { LRParser } from '@lezer/lr';
import { RecursiveSet, Tuple } from "recursive-set";

In [None]:
import { 
    AST, NumNode, VarNode, BinaryExpr, CallNode, 
    cstToAST, ParserConfig, ast2dot 
} from "./AST2Dot";

# Symbolic Differentiation

In this notebook, our goal is to implement **symbolic differentiation**. Concretely, we want to implement a function `diff` that takes one argument:

- The argument `e` is the abstract syntax tree (AST) of an *arithmetic expression*.

Here, an arithmetic expression is any string built from the variable `x` and natural numbers by applying the operator symbols `+`, `-`, `*`, `/`, and `**`.
The operator `**` represents exponentiation, i.e., an expression of the form $a \texttt{**} b$ is interpreted as $a^b$.
Furthermore, if $e$ is an expression, then both $\exp(e)$ and $\ln(e)$ are expressions too.

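The function call `diff(parseMath(s))` then computes the derivative of the expression `s` with respect to the variable `x`, where `parseMath` (defined below) turns a string into an AST. For example, the function call
`diff(parseMath("x * exp(x)"))` computes an expression equivalent to
`1 * exp(x) + x * exp(x)`. Since `diff` does not simplify its result, the actual output still contains redundant factors of `1`.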

This is because the product rule gives us:
$$ \frac{\mathrm{d}\;}{\mathrm{d}x} \bigl( x \cdot \mathrm{e}^x \bigr) = 1 \cdot \mathrm{e}^x + x \cdot \mathrm{e}^x. $$

This implementation uses **TypeScript** and **RecursiveSet** for value semantics (ensuring that structurally identical expressions are treated as equal). The parsing is handled by the **Lezer** parser generator.

The grammar for the language implemented by this parser is as follows:

$$
\begin{array}{lcl}
 \texttt{expr} & \rightarrow & \;\texttt{expr}\; \texttt{'+'} \; \texttt{product} \\
 & \mid & \;\texttt{expr}\; \texttt{'-'} \; \texttt{product} \\
 & \mid & \;\texttt{product} \\[0.2cm]
 \texttt{product} & \rightarrow & \;\texttt{product}\; \texttt{'*'} \;\texttt{factor} \\
 & \mid & \;\texttt{product}\; \texttt{'/'} \;\texttt{factor} \\
 & \mid & \;\texttt{factor} \\[0.2cm]
 \texttt{factor} & \rightarrow & \texttt{base} \;\texttt{'**'} \; \texttt{factor} \\
 & \mid & \texttt{base} \\[0.2cm]
 \texttt{base} & \rightarrow & \texttt{exp} \; \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \texttt{ln} \; \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \;\texttt{NUMBER} \\
 & \mid & \;\texttt{'x'} 
\end{array}
$$
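A top-down parser cannot use the left-recursive rules for `expr` and `product` directly; instead they are rewritten as repetitions, which is also how the Lezer grammar below expresses them. The following sketch is a hand-written recursive-descent parser for the same grammar. It is not the Lezer parser used in this notebook, and the names `tokenize` and `parseToString` are hypothetical. It returns a fully parenthesized string so that the shape of the resulting tree is visible:

```typescript
// Hypothetical helper: split the input into tokens. "**" is tried before the
// single-character operators so that it becomes one token.
function tokenize(src: string): string[] {
    const re = /\s*(\*\*|exp|ln|x|0|[1-9][0-9]*|[-+*\/()])/y;
    const tokens: string[] = [];
    let pos = 0;
    // Stop once only whitespace (or nothing) remains.
    while (!/^\s*$/.test(src.slice(pos))) {
        re.lastIndex = pos;
        const m = re.exec(src);
        if (m === null) throw new Error(`Unexpected input at position ${pos}`);
        tokens.push(m[1]);
        pos = re.lastIndex;
    }
    return tokens;
}

// Hypothetical recursive-descent parser returning a fully parenthesized string.
function parseToString(src: string): string {
    const tokens = tokenize(src);
    let i = 0;
    const peek = (): string | undefined => tokens[i];
    const expect = (t: string): void => {
        if (tokens[i] !== t) throw new Error(`Expected '${t}', found '${tokens[i] ?? "end of input"}'`);
        i++;
    };

    // expr -> product (('+' | '-') product)*   (loop = left associativity)
    function expr(): string {
        let left = product();
        while (peek() === "+" || peek() === "-") {
            const op = tokens[i++];
            left = `(${left} ${op} ${product()})`;
        }
        return left;
    }

    // product -> factor (('*' | '/') factor)*   (loop = left associativity)
    function product(): string {
        let left = factor();
        while (peek() === "*" || peek() === "/") {
            const op = tokens[i++];
            left = `(${left} ${op} ${factor()})`;
        }
        return left;
    }

    // factor -> base ('**' factor)?   (recursion = right associativity)
    function factor(): string {
        const b = base();
        if (peek() === "**") {
            i++;
            return `(${b} ** ${factor()})`;
        }
        return b;
    }

    // base -> exp '(' expr ')' | ln '(' expr ')' | '(' expr ')' | NUMBER | 'x'
    function base(): string {
        const t = peek();
        if (t === "exp" || t === "ln") {
            i++;
            expect("(");
            const arg = expr();
            expect(")");
            return `${t}(${arg})`;
        }
        if (t === "(") {
            i++;
            const inner = expr();
            expect(")");
            return inner;
        }
        if (t !== undefined && /^(?:[0-9]+|x)$/.test(t)) {
            i++;
            return t;
        }
        throw new Error(`Unexpected token '${t ?? "end of input"}'`);
    }

    const result = expr();
    if (i !== tokens.length) throw new Error(`Unexpected token '${tokens[i]}'`);
    return result;
}

console.log(parseToString("1 + 2 * x ** 2 ** 3")); // prints "(1 + (2 * (x ** (2 ** 3))))"
```

The `while` loops build the tree from the left, whereas `factor` calls itself for the exponent, which nests `**` to the right.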

***

## Grammar Definition

### Specification of the Parser

We use **Lezer** to define the grammar. Together with the AST transformation below, the grammar encodes operator precedence (`**` binds tighter than `*` and `/`, which in turn bind tighter than `+` and `-`) and associativity (`+`, `-`, `*`, and `/` are left-associative; `**` is right-associative).
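Associativity only changes the result for operators that are not associative, such as `-`, `/`, and `**`. The following sketch (with hypothetical helper names `foldLeft` and `foldRight`) evaluates the same list of operands both ways:

```typescript
// Operators whose associativity changes the result.
type Op = "-" | "/" | "**";

const applyOp: Record<Op, (a: number, b: number) => number> = {
    "-": (a, b) => a - b,
    "/": (a, b) => a / b,
    "**": (a, b) => a ** b,
};

// Left-associative: ((x0 op x1) op x2) op ...
function foldLeft(op: Op, xs: number[]): number {
    return xs.slice(1).reduce((acc, x) => applyOp[op](acc, x), xs[0]);
}

// Right-associative: x0 op (x1 op (x2 op ...))
function foldRight(op: Op, xs: number[]): number {
    return xs.slice(0, -1).reduceRight((acc, x) => applyOp[op](x, acc), xs[xs.length - 1]);
}

console.log(foldLeft("-", [8, 4, 2]));   // 2   = (8 - 4) - 2
console.log(foldRight("-", [8, 4, 2]));  // 6   = 8 - (4 - 2)
console.log(foldLeft("**", [2, 3, 2]));  // 64  = (2 ** 3) ** 2
console.log(foldRight("**", [2, 3, 2])); // 512 = 2 ** (3 ** 2)
```

TypeScript's own `**` operator is right-associative as well, so `2 ** 3 ** 2` evaluates to `512`.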

In [None]:
const grammarDefinition = `
    @top Expression { expr }

    expr    { Product ((Plus | Minus) Product)* }
    Product { Factor ((Mul | Div) Factor)* }
    Factor  { PowerExpr { Base (Power Factor)? } }

    Base {
        CallExp { Exp "(" expr ")" } |
        CallLn  { Ln "(" expr ")" } |
        ParenExpr { "(" expr ")" } |
        Number |
        VarX
    }

    @tokens {
        Number { "0" | $[1-9] $[0-9]* }
        VarX { "x" }
        Exp { "exp" }
        Ln { "ln" }
        Power { "**" }
        Plus { "+" } Minus { "-" }
        Mul { "*" }  Div { "/" }
        space { $[ \t\n\r]+ }
        "(" ")"
    }
    @skip { space }
`;

const parser: LRParser = buildParser(grammarDefinition);

## Transformation to AST

The parser generates a Concrete Syntax Tree (CST). To transform it into our strictly typed AST, we use the generic `cstToAST` mapper (imported from `AST2Dot`).

Instead of writing a manual traversal function with a large switch statement, we define the transformation rules declaratively in a `ParserConfig` object:

1.  **Leaf Transformation:** Tokens like `Number` or `VarX` are directly mapped to `NumNode` or `VarNode`.
2.  **Operator Folding (`reduceBinary`):** Lezer represents a chain of operators with the same precedence (e.g., `1 + 2 + 3`) as a flat list of children: `[Product, Plus, Product, Plus, Product]`. The `reduceBinary` helper function iterates through this list and folds it into a left-associative tree: `((1 + 2) + 3)`.
3.  **Function Calls:** Specific grammar nodes like `CallExp` are mapped to `CallNode` instances (e.g., with function name "exp").
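The folding performed by `reduceBinary` can be illustrated without Lezer. The sketch below uses a hypothetical, simplified node type `TreeNode` as a stand-in for the `AST2Dot` classes and folds the child list of `1 + 2 - 3` into a left-associative tree:

```typescript
// Hypothetical simplified tree type (stand-in for NumNode / BinaryExpr).
type TreeNode =
    | { kind: "num"; value: number }
    | { kind: "bin"; op: string; left: TreeNode; right: TreeNode };

// Entries of a flat child list: operands alternate with operator tokens.
type Child = TreeNode | { kind: "op"; op: string };

// Fold [e0, op1, e1, op2, e2, ...] into ((e0 op1 e1) op2 e2) ...
function fold(children: Child[]): TreeNode {
    const first = children[0];
    if (first === undefined || first.kind === "op") throw new Error("Expected operand");
    let left: TreeNode = first;
    for (let i = 1; i < children.length; i += 2) {
        const opToken = children[i];
        const right = children[i + 1];
        if (opToken.kind !== "op") throw new Error("Expected operator");
        if (right === undefined || right.kind === "op") throw new Error("Expected operand");
        left = { kind: "bin", op: opToken.op, left, right };
    }
    return left;
}

const leaf = (value: number): TreeNode => ({ kind: "num", value });
const opTok = (op: string): Child => ({ kind: "op", op });

// Fully parenthesized rendering, to make the tree shape visible.
function showTree(n: TreeNode): string {
    return n.kind === "num" ? String(n.value) : `(${showTree(n.left)} ${n.op} ${showTree(n.right)})`;
}

// Children for `1 + 2 - 3`: [1, "+", 2, "-", 3]
console.log(showTree(fold([leaf(1), opTok("+"), leaf(2), opTok("-"), leaf(3)]))); // prints "((1 + 2) - 3)"
```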

In [None]:
function reduceBinary(children: AST[], _text: string): AST {
    let left = children[0];
    for (let i = 1; i < children.length; i += 2) {
        const opNode = children[i];
        if (!(opNode instanceof VarNode)) throw new Error("Expected Operator");
        left = new BinaryExpr(left, opNode.name, children[i + 1]);
    }
    return left;
}

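// In Lezer, rules whose names start with a lowercase letter (such as `expr`) do not
// produce a node in the syntax tree. The children of `expr` (Product, Plus/Minus,
// Product, ...) therefore appear directly inside Expression, ParenExpr, CallExp and
// CallLn, so these rules fold the flat child list themselves.
const mathConfig: ParserConfig = {
    ignore: new Set(["(", ")"]),
    rules: {
        // Leaves: tokens become numbers, the variable, or operator markers.
        "Number": (_, text) => new NumNode(parseInt(text, 10)),
        "VarX": () => new VarNode("x"),
        "Plus": () => new VarNode("+"), "Minus": () => new VarNode("-"),
        "Mul": () => new VarNode("*"),  "Div": () => new VarNode("/"),
        "Power": () => new VarNode("**"),
        "Exp": () => new VarNode("exp"),
        "Ln":  () => new VarNode("ln"),

        // Chains of same-precedence operators are folded left-associatively.
        "Expression": reduceBinary,
        "Product": reduceBinary,

        // Base (Power Factor)?  ->  right-associative exponentiation.
        "PowerExpr": (children) => {
            if (children.length === 3) return new BinaryExpr(children[0], "**", children[2]);
            return children[0];
        },
        // children[0] is the function-name marker; the rest is the flat argument chain.
        "CallExp": (children, text) => new CallNode("exp", [reduceBinary(children.slice(1), text)]),
        "CallLn":  (children, text) => new CallNode("ln", [reduceBinary(children.slice(1), text)]),

        // Parentheses are ignored, so only the inner chain remains.
        "ParenExpr": reduceBinary
    }
};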

function parseMath(input: string): AST {
    const tree = parser.parse(input);
    return cstToAST(tree.cursor(), input, mathConfig);
}

## Symbolic Differentiation

Now we implement the core function `diff`, which takes the AST of an expression and computes its derivative with respect to the variable `x`.

The function dispatches on the node type (`NumNode`, `VarNode`, `BinaryExpr`, `CallNode`) and, for binary expressions, on the operator. We apply the standard differentiation rules:

- **Sum and Difference Rule:** $\frac{d}{dx}(u \pm v) = u' \pm v'$
- **Product Rule:** $\frac{d}{dx}(u \cdot v) = u'v + uv'$
- **Quotient Rule:** $\frac{d}{dx}\left(\frac{u}{v}\right) = \frac{u'v - uv'}{v^2}$
- **Chain Rule:** Applied to the functions $\exp$ and $\ln$: $\frac{d}{dx}\exp(u) = \exp(u) \cdot u'$ and $\frac{d}{dx}\ln(u) = \frac{u'}{u}$.
- **Power Rule:**
    - For $u^n$ with a numeric exponent $n$, we use $n \cdot u^{n-1} \cdot u'$ (the power rule combined with the chain rule).
    - For general cases $u^v$, we rewrite the power as $\exp(v \cdot \ln(u))$, which is valid for $u > 0$, and differentiate using the chain rule.
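As a worked example of the general case, applying the rewrite to $x^x$ (for $x > 0$) gives:

$$
\begin{aligned}
\frac{\mathrm{d}}{\mathrm{d}x}\, x^x
&= \frac{\mathrm{d}}{\mathrm{d}x}\, \exp\bigl(x \cdot \ln(x)\bigr) \\
&= \exp\bigl(x \cdot \ln(x)\bigr) \cdot \frac{\mathrm{d}}{\mathrm{d}x}\bigl(x \cdot \ln(x)\bigr) \\
&= \exp\bigl(x \cdot \ln(x)\bigr) \cdot \Bigl(1 \cdot \ln(x) + x \cdot \frac{1}{x}\Bigr) \\
&= x^x \bigl(\ln(x) + 1\bigr).
\end{aligned}
$$

Up to the order of the terms, the unsimplified third line corresponds to the expression that `diff` builds for `x ** x`; the last line is the simplified form.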


In [None]:
function diff(e: AST): AST {
    // Constants have derivative 0.
    if (e instanceof NumNode) {
        return new NumNode(0);
    }
    // dx/dx = 1; any other variable is treated as a constant.
    if (e instanceof VarNode) {
        return e.name === "x" ? new NumNode(1) : new NumNode(0);
    }
    if (e instanceof BinaryExpr) {
        switch (e.op) {
            // Sum and difference rule: (u ± v)' = u' ± v'
            case "+": return new BinaryExpr(diff(e.left), "+", diff(e.right));
            case "-": return new BinaryExpr(diff(e.left), "-", diff(e.right));

            // Product rule: (u * v)' = u' * v + u * v'
            case "*":
                return new BinaryExpr(
                    new BinaryExpr(diff(e.left), "*", e.right),
                    "+",
                    new BinaryExpr(e.left, "*", diff(e.right))
                );

            // Quotient rule: (u / v)' = (u' * v - u * v') / (v * v)
            case "/":
                return new BinaryExpr(
                    new BinaryExpr(
                        new BinaryExpr(diff(e.left), "*", e.right),
                        "-",
                        new BinaryExpr(e.left, "*", diff(e.right))
                    ),
                    "/",
                    new BinaryExpr(e.right, "*", e.right)
                );

            case "**": {
                const base = e.left;
                const exponent = e.right;
                // Numeric exponent n: (u ** n)' = n * u ** (n - 1) * u'
                if (exponent instanceof NumNode) {
                    const n = exponent.value;
                    if (n === 0) return new NumNode(0);
                    if (n === 1) return diff(base);
                    const powerRule = new BinaryExpr(
                        new NumNode(n), "*", new BinaryExpr(base, "**", new NumNode(n - 1))
                    );
                    return new BinaryExpr(powerRule, "*", diff(base));
                }
                // General case: u ** v = exp(v * ln(u)), valid for u > 0.
                const g = new BinaryExpr(exponent, "*", new CallNode("ln", [base]));
                return diff(new CallNode("exp", [g]));
            }
        }
    }
    if (e instanceof CallNode) {
        const arg = e.args[0];
        const argPrime = diff(arg);
        // Chain rule: exp(u)' = exp(u) * u'
        if (e.fn === "exp") return new BinaryExpr(e, "*", argPrime);
        // Chain rule: ln(u)' = u' / u
        if (e.fn === "ln") return new BinaryExpr(argPrime, "/", arg);
    }
    throw new Error(`Diff not implemented for node: ${e.constructor.name}`);
}
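`diff` performs no simplification, so its results contain terms such as `1 * exp(x)` or `exp(x) * 1`. One possible post-processing step is a bottom-up simplifier that removes the neutral elements 0 and 1. The sketch below works on a hypothetical discriminated-union type `Expr` rather than on the `AST2Dot` classes, and applies it to the unsimplified product-rule result $u'v + uv'$ for `x * exp(x)`:

```typescript
// Hypothetical minimal AST (a stand-in for the AST2Dot classes).
type Expr =
    | { kind: "num"; value: number }
    | { kind: "var"; name: string }
    | { kind: "bin"; op: "+" | "-" | "*" | "/" | "**"; left: Expr; right: Expr }
    | { kind: "call"; fn: "exp" | "ln"; arg: Expr };

const num = (value: number): Expr => ({ kind: "num", value });
const isNum = (e: Expr, v: number): boolean => e.kind === "num" && e.value === v;

// Bottom-up simplification using only the neutral/absorbing elements 0 and 1.
// Note: 0 * e -> 0 and e ** 0 -> 1 assume e is defined (e.g. they drop x > 0 for ln(x)).
function simplify(e: Expr): Expr {
    if (e.kind === "call") return { kind: "call", fn: e.fn, arg: simplify(e.arg) };
    if (e.kind !== "bin") return e;
    const l = simplify(e.left);
    const r = simplify(e.right);
    switch (e.op) {
        case "+":
            if (isNum(l, 0)) return r;
            if (isNum(r, 0)) return l;
            break;
        case "-":
            if (isNum(r, 0)) return l;
            break;
        case "*":
            if (isNum(l, 0) || isNum(r, 0)) return num(0);
            if (isNum(l, 1)) return r;
            if (isNum(r, 1)) return l;
            break;
        case "/":
            if (isNum(r, 1)) return l;
            break;
        case "**":
            if (isNum(r, 1)) return l;
            if (isNum(r, 0)) return num(1);
            break;
    }
    return { kind: "bin", op: e.op, left: l, right: r };
}

// Fully parenthesized rendering.
function show(e: Expr): string {
    if (e.kind === "num") return String(e.value);
    if (e.kind === "var") return e.name;
    if (e.kind === "call") return `${e.fn}(${show(e.arg)})`;
    return `(${show(e.left)} ${e.op} ${show(e.right)})`;
}

// Unsimplified derivative of x * exp(x): 1 * exp(x) + x * (exp(x) * 1)
const x: Expr = { kind: "var", name: "x" };
const expX: Expr = { kind: "call", fn: "exp", arg: x };
const raw: Expr = {
    kind: "bin", op: "+",
    left: { kind: "bin", op: "*", left: num(1), right: expX },
    right: { kind: "bin", op: "*", left: x, right: { kind: "bin", op: "*", left: expX, right: num(1) } },
};
console.log(show(simplify(raw))); // prints "(exp(x) + (x * exp(x)))"
```

Because the rules are applied after the children have been simplified, a single pass also removes neutral elements that only appear after an inner simplification step.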

## Tests

We test the implementation on two expressions: `x ** x`, which exercises the general power rule, and `x * ln(x) / exp(x/x)`, which combines the product, quotient, and chain rules.
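Besides inspecting the printed output, a derivative can be checked numerically with the central difference $f'(x) \approx \frac{f(x+h) - f(x-h)}{2h}$. The sketch below compares this approximation with the hand-derived closed forms for the two test expressions: $x^x(\ln(x) + 1)$ for `x ** x`, and $(\ln(x) + 1)/e$ for `x * ln(x) / exp(x/x)` (because $x/x = 1$). Checking the output of `diff` itself this way would additionally require an evaluator for the AST, which this notebook does not define; the helper name `numericDerivative` is hypothetical.

```typescript
// Central difference approximation of f'(x); the truncation error is O(h^2).
function numericDerivative(f: (x: number) => number, x: number, h = 1e-5): number {
    return (f(x + h) - f(x - h)) / (2 * h);
}

// Test expressions and their hand-derived derivatives (valid for x > 0).
const cases: { name: string; f: (x: number) => number; df: (x: number) => number }[] = [
    { name: "x ** x", f: (x) => x ** x, df: (x) => x ** x * (Math.log(x) + 1) },
    {
        name: "x * ln(x) / exp(x/x)",
        f: (x) => (x * Math.log(x)) / Math.exp(x / x),
        df: (x) => (Math.log(x) + 1) / Math.E,
    },
];

for (const { name, f, df } of cases) {
    for (const x of [0.5, 1, 2, 3]) {
        const approx = numericDerivative(f, x);
        const exact = df(x);
        console.log(`${name} at x = ${x}: numeric ${approx.toFixed(6)}, closed form ${exact.toFixed(6)}`);
    }
}
```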

In [None]:
function test(s: string) {
    console.log(`In:  ${s}`);
    try {
        const d = diff(parseMath(s));
        console.log(`Out: ${d}`);
    } catch (e) {
        console.log(`Err: ${e instanceof Error ? e.message : e}`);
    }
    console.log("-".repeat(20));
}

In [None]:
test("x ** x")

In [None]:
test("x * ln(x) / exp(x/x)")