In [None]:
import { display } from "tslab";
import { readFileSync } from "fs";

const css = readFileSync("../style.css", "utf8");
display.html(`<style>${css}</style>`);

## Imports and Setup

In [None]:
import { buildParser } from '@lezer/generator';
import { Tree, TreeCursor } from '@lezer/common';
import { LRParser } from '@lezer/lr';
import { RecursiveSet, Tuple } from "recursive-set";

# Symbolic Differentiation

In this notebook, our goal is to implement **symbolic differentiation**. Concretely, we want to implement a function `diff` that takes one argument `expr`, which represents an *arithmetic expression* given as an abstract syntax tree produced by the parser defined below.

Here, an arithmetic expression is any expression built from the variable `x` and non-negative integer literals by applying any of the operator symbols `+`, `-`, `*`, `/`, and `**`.
The operator `**` represents exponentiation, i.e., an expression of the form $a \texttt{**} b$ is interpreted as $a^b$.
Furthermore, if $e$ is an expression, then both $\exp(e)$ and $\ln(e)$ are expressions too.

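The function call `diff(expr)` computes the derivative of `expr` with respect to the variable `x`. For example, `diff(parseExpr("x * exp(x)"))` returns an expression that prints as
`1 * exp(x) + x * (1 * exp(x))`

This follows from the product rule together with the chain rule for $\exp$, which contributes the factor $1$ (the derivative of the inner function $x$). The result is not simplified:
$$ \frac{\mathrm{d}\;}{\mathrm{d}x} \bigl( x \cdot \mathrm{e}^x \bigr) = 1 \cdot \mathrm{e}^x + x \cdot \bigl(1 \cdot \mathrm{e}^x\bigr) = 1 \cdot \mathrm{e}^x + x \cdot \mathrm{e}^x. $$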

This implementation is written in **TypeScript** and uses the `Tuple` class from **recursive-set** for value semantics, so that structurally identical expressions are treated as equal. Parsing is handled by the **Lezer** parser generator.

The grammar for the language implemented by this parser is as follows:

$$
\begin{array}{lcl}
 \texttt{expr} & \rightarrow & \;\texttt{expr}\; \texttt{'+'} \; \texttt{product} \\
 & \mid & \;\texttt{expr}\; \texttt{'-'} \; \texttt{product} \\
 & \mid & \;\texttt{product} \\[0.2cm]
 \texttt{product} & \rightarrow & \;\texttt{product}\; \texttt{'*'} \;\texttt{factor} \\
 & \mid & \;\texttt{product}\; \texttt{'/'} \;\texttt{factor} \\
 & \mid & \;\texttt{factor} \\[0.2cm]
 \texttt{factor} & \rightarrow & \texttt{base} \;\texttt{'**'} \; \texttt{factor} \\
 & \mid & \texttt{base} \\[0.2cm]
 \texttt{base} & \rightarrow & \texttt{exp} \; \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \texttt{ln} \; \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \texttt{'('} \; \texttt{expr} \;\texttt{')'} \\
 & \mid & \;\texttt{NUMBER} \\
 & \mid & \;\texttt{'x'} 
\end{array}
$$
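
The grammar above is left-recursive in `expr` and `product`, which a hand-written recursive-descent parser cannot use directly. Purely as an illustration (this notebook uses Lezer instead), the following sketch replaces the left recursion by loops and keeps the right recursion of `factor`. The names `tokenize` and `parseToString` are hypothetical, and the sketch returns a fully parenthesized string rather than an AST so that the resulting grouping is visible.

```typescript
// Sketch of a recursive-descent parser for the grammar above (not the Lezer
// parser used in this notebook). It returns a fully parenthesized string.
function tokenize(src: string): string[] {
    // Integers without leading zeros, x, exp, ln, ** and single-character symbols.
    const re = /\s*(0|[1-9][0-9]*|x|exp|ln|\*\*|[-+*\/()])/y;
    const tokens: string[] = [];
    let pos = 0;
    while (pos < src.length) {
        re.lastIndex = pos;
        const m = re.exec(src);
        if (m === null) {
            if (src.slice(pos).trim() === "") break; // only trailing whitespace left
            throw new Error(`Unexpected character near position ${pos}`);
        }
        tokens.push(m[1]);
        pos = re.lastIndex;
    }
    return tokens;
}

function parseToString(src: string): string {
    const tokens = tokenize(src);
    let i = 0;
    const peek = (): string | undefined => tokens[i];
    const consume = (t: string): void => {
        if (tokens[i] !== t) throw new Error(`Expected "${t}" but found "${tokens[i]}"`);
        i++;
    };

    // expr -> product (('+' | '-') product)*   (left recursion replaced by a loop)
    function expr(): string {
        let left = product();
        while (peek() === "+" || peek() === "-") {
            const op = tokens[i++];
            left = `(${left} ${op} ${product()})`;
        }
        return left;
    }
    // product -> factor (('*' | '/') factor)*
    function product(): string {
        let left = factor();
        while (peek() === "*" || peek() === "/") {
            const op = tokens[i++];
            left = `(${left} ${op} ${factor()})`;
        }
        return left;
    }
    // factor -> base ('**' factor)?   (right recursion yields right associativity)
    function factor(): string {
        const b = base();
        if (peek() === "**") {
            i++;
            return `(${b} ** ${factor()})`;
        }
        return b;
    }
    // base -> exp '(' expr ')' | ln '(' expr ')' | '(' expr ')' | NUMBER | 'x'
    function base(): string {
        const t = peek();
        if (t === "exp" || t === "ln") {
            i++;
            consume("(");
            const inner = expr();
            consume(")");
            return `${t}(${inner})`;
        }
        if (t === "(") {
            i++;
            const inner = expr();
            consume(")");
            return inner;
        }
        if (t !== undefined && (t === "x" || /^[0-9]+$/.test(t))) {
            i++;
            return t;
        }
        throw new Error(`Unexpected token "${t}"`);
    }

    const result = expr();
    if (i !== tokens.length) throw new Error(`Unexpected token "${tokens[i]}"`);
    return result;
}

console.log(parseToString("1 - x - 2"));   // ((1 - x) - 2)
console.log(parseToString("2 ** 3 ** x")); // (2 ** (3 ** x))
```

The loops in `expr` and `product` build the left-associative grouping directly, while the recursive call in `factor` produces the right-associative grouping of `**`.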

***

## Abstract Syntax Tree (AST) Definition

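We represent AST nodes as immutable tuples built on the `Tuple` class of `recursive-set`. This guarantees **value semantics**: two copies of `x + 1` that are created independently will be equal by reference, because their contents are identical.

We define an abstract base class `ExprNode` that extends `Tuple`. The first component of each tuple is a tag (`"num"`, `"x"`, `"exp"`, `"ln"`, or an operator symbol such as `"+"`), and the remaining components hold the node's value or children. The type `ASTElement` restricts these components to strings, numbers, and nested `ExprNode`s, and the interface `ITupleAccess` types the index-based `get` accessor used to read them. Each node class also reports its `precedence()`, which `toString` uses to decide where parentheses are needed.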
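
The same kind of value semantics can also be obtained without a library by *hash-consing*: a factory keeps a table of all nodes created so far and returns the existing node whenever a structurally identical one is requested, so structural equality becomes `===`. The following is a minimal sketch of that technique under our own assumptions; the `HNode` type and the `mk` factory are hypothetical, and the sketch makes no claim about how `recursive-set` implements `Tuple`.

```typescript
// Hash-consing sketch: every distinct structure is represented by exactly one object.
interface HNode {
    readonly id: number;             // unique number per distinct structure
    readonly tag: string;            // "num", "x", "+", "*", "exp", ...
    readonly value?: number;         // payload of "num" nodes
    readonly args: readonly HNode[]; // children (already interned)
}

const internTable = new Map<string, HNode>();

// Return the unique node with the given tag, children, and value.
function mk(tag: string, args: readonly HNode[] = [], value?: number): HNode {
    // Children are interned, so their ids identify them structurally.
    const key = `${tag}|${value ?? ""}|${args.map((a) => a.id).join(",")}`;
    let node = internTable.get(key);
    if (node === undefined) {
        node = Object.freeze({ id: internTable.size, tag, value, args: Object.freeze([...args]) });
        internTable.set(key, node);
    }
    return node;
}

const one = mk("num", [], 1);
const a = mk("+", [mk("x"), one]);
const b = mk("+", [mk("x"), mk("num", [], 1)]);
console.log(a === b);                       // true: same structure, same object
console.log(a === mk("+", [one, mk("x")])); // false: x + 1 and 1 + x differ
```

The key only needs the children's ids because the children were interned first; this is what makes the lookup cheap even for deep expressions.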


In [None]:
type ASTElement = string | number | ExprNode;
interface ITupleAccess {
    get(index: number): ASTElement;
    length: number;
}
abstract class ExprNode extends Tuple<ASTElement[]> {
    abstract precedence(): number;
    protected getItem(index: number): ASTElement {
        return (this as unknown as ITupleAccess).get(index);
    }
}
class Num extends ExprNode {
    constructor(val: number) {
        super("num", val);
    }
    get value(): number {
        return this.getItem(1) as number;
    }
    precedence() {
        return 4;
    }
    toString() {
        return String(this.value);
    }
}
class VarX extends ExprNode {
    constructor() {
        super("x");
    }
    precedence() {
        return 4;
    }
    toString() {
        return "x";
    }
}
class ExpFunc extends ExprNode {
    constructor(arg: ExprNode) {
        super("exp", arg);
    }
    get arg(): ExprNode {
        return this.getItem(1) as ExprNode;
    }
    precedence() {
        return 4;
    }
    toString() {
        return `exp(${this.arg})`;
    }
}
class LnFunc extends ExprNode {
    constructor(arg: ExprNode) {
        super("ln", arg);
    }
    get arg(): ExprNode {
        return this.getItem(1) as ExprNode;
    }
    precedence() {
        return 4;
    }
    toString() {
        return `ln(${this.arg})`;
    }
}
abstract class BinOp extends ExprNode {
    constructor(op: string, left: ExprNode, right: ExprNode) {
        super(op, left, right);
    }
    get op(): string {
        return this.getItem(0) as string;
    }
    get left(): ExprNode {
        return this.getItem(1) as ExprNode;
    }
    get right(): ExprNode {
        return this.getItem(2) as ExprNode;
    }
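    toString() {
        const p = this.precedence();
        let l = this.left.toString();
        let r = this.right.toString();
        // A child with lower precedence always needs parentheses. At equal
        // precedence, left-associative operators parenthesize their right
        // child (x - (y - z)), while the right-associative `**` parenthesizes
        // its left child ((x ** 2) ** 3).
        const rightAssoc = this.op === "**";
        const lp = this.left.precedence();
        const rp = this.right.precedence();
        if (lp < p || (rightAssoc && lp === p)) l = `(${l})`;
        if (rp < p || (!rightAssoc && rp === p)) r = `(${r})`;
        return `${l} ${this.op} ${r}`;
    }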
}
class Add extends BinOp {
    constructor(l: ExprNode, r: ExprNode) {
        super("+", l, r);
    }
    precedence() {
        return 1;
    }
}
class Sub extends BinOp {
    constructor(l: ExprNode, r: ExprNode) {
        super("-", l, r);
    }
    precedence() {
        return 1;
    }
}
class Mul extends BinOp {
    constructor(l: ExprNode, r: ExprNode) {
        super("*", l, r);
    }
    precedence() {
        return 2;
    }
}
class Div extends BinOp {
    constructor(l: ExprNode, r: ExprNode) {
        super("/", l, r);
    }
    precedence() {
        return 2;
    }
}
class Pow extends BinOp {
    constructor(l: ExprNode, r: ExprNode) {
        super("**", l, r);
    }
    precedence() {
        return 3;
    }
}
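
`BinOp.toString` decides where parentheses are needed by comparing precedences, and the correct rule depends on associativity: at equal precedence, a left-associative operator such as `-` needs parentheses around its right operand (`x - (y - z)`), while the right-associative `**` needs them around its left operand (`(x ** 2) ** 3`). The following standalone sketch states this rule on a hypothetical plain-object type `PExpr` (not the `Tuple`-based classes above), assuming the same precedence levels as the classes.

```typescript
// Minimal parenthesization sketch, assuming the notebook's precedence levels:
// + and - -> 1, * and / -> 2, ** -> 3, atoms (numbers, x, exp(...), ln(...)) -> 4.
type PExpr =
    | { kind: "atom"; text: string }
    | { kind: "bin"; op: string; left: PExpr; right: PExpr };

const PRECEDENCE: Record<string, number> = { "+": 1, "-": 1, "*": 2, "/": 2, "**": 3 };

function precedenceOf(e: PExpr): number {
    return e.kind === "atom" ? 4 : PRECEDENCE[e.op];
}

function show(e: PExpr): string {
    if (e.kind === "atom") return e.text;
    const p = PRECEDENCE[e.op];
    const rightAssoc = e.op === "**";
    // Lower precedence always needs parentheses; equal precedence only on the
    // side that does not match the operator's associativity.
    const wrapLeft = rightAssoc ? precedenceOf(e.left) <= p : precedenceOf(e.left) < p;
    const wrapRight = rightAssoc ? precedenceOf(e.right) < p : precedenceOf(e.right) <= p;
    const l = wrapLeft ? `(${show(e.left)})` : show(e.left);
    const r = wrapRight ? `(${show(e.right)})` : show(e.right);
    return `${l} ${e.op} ${r}`;
}

const atom = (text: string): PExpr => ({ kind: "atom", text });
const bin = (op: string, left: PExpr, right: PExpr): PExpr => ({ kind: "bin", op, left, right });

console.log(show(bin("-", atom("x"), bin("-", atom("y"), atom("z")))));   // x - (y - z)
console.log(show(bin("-", bin("-", atom("x"), atom("y")), atom("z")))));   // x - y - z
console.log(show(bin("**", bin("**", atom("x"), atom("2")), atom("3")))); // (x ** 2) ** 3
console.log(show(bin("**", atom("x"), bin("**", atom("2"), atom("3"))))); // x ** 2 ** 3
```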

## Grammar Definition

### Specification of the Parser

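We use **Lezer** to define the grammar. Operator precedence is encoded in the rule hierarchy: `**` binds tighter than `*` and `/`, which in turn bind tighter than `+` and `-`. The operator `**` is right-associative because `Factor` recurses on the right of `Power`. The rules for `+`/`-` and `*`/`/`, by contrast, use repetition and therefore produce a flat sequence of operands and operators; their left associativity is established when this sequence is converted into the AST.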
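
A repetition rule such as `Product ((Plus | Minus) Product)*` yields a flat sequence like `a - b - c`, so the grouping has to be chosen when that sequence is turned into a tree. The hypothetical helpers below (`foldLeftAssoc`, `foldRightAssoc`, and `render` are not part of the notebook) show the two possible folds on plain strings: the left fold gives the grouping used for `+`, `-`, `*`, and `/`, the right fold the grouping of `**`.

```typescript
// Hypothetical helpers: group a flat operand/operator sequence into a tree.
type Grouped = string | { op: string; left: Grouped; right: Grouped };

// Fully parenthesized rendering, so that the grouping is visible.
function render(g: Grouped): string {
    return typeof g === "string" ? g : `(${render(g.left)} ${g.op} ${render(g.right)})`;
}

function checkShape(operands: Grouped[], operators: string[]): void {
    if (operands.length !== operators.length + 1) {
        throw new Error("need exactly one more operand than operators");
    }
}

// a op b op c  ->  (a op b) op c   (grouping for +, -, *, /)
function foldLeftAssoc(operands: Grouped[], operators: string[]): Grouped {
    checkShape(operands, operators);
    let acc = operands[0];
    for (let i = 0; i < operators.length; i++) {
        acc = { op: operators[i], left: acc, right: operands[i + 1] };
    }
    return acc;
}

// a op b op c  ->  a op (b op c)   (grouping for **)
function foldRightAssoc(operands: Grouped[], operators: string[]): Grouped {
    checkShape(operands, operators);
    let acc = operands[operands.length - 1];
    for (let i = operators.length - 1; i >= 0; i--) {
        acc = { op: operators[i], left: operands[i], right: acc };
    }
    return acc;
}

console.log(render(foldLeftAssoc(["a", "b", "c"], ["-", "-"])));   // ((a - b) - c)
console.log(render(foldRightAssoc(["a", "b", "c"], ["**", "**"]))); // (a ** (b ** c))
```

The choice matters: $(8 - 4) - 2 = 2$, whereas $8 - (4 - 2) = 6$.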

In [None]:
const grammarDefinition = `
@top Expression { expr }
expr { Product ((Plus | Minus) Product)* }
Product { Factor ((Mul | Div) Factor)* }
Factor { PowerExpr { Base (Power Factor)? } }
Base {
  ExpFunc { Exp "(" expr ")" } |
  LnFunc { Ln "(" expr ")" } |
  ParenExpr { "(" expr ")" } |
  Number |
  VarX
}
@tokens {
  Number { "0" | $[1-9] $[0-9]* }
  VarX { "x" }
  Exp { "exp" }
  Ln { "ln" }
  Power { "**" }
  Plus { "+" }
  Minus { "-" }
  Mul { "*" }
  Div { "/" }
  space { $[ \t\n\r]+ }
  "(" ")"
}
@skip { space }
`;

const parser: LRParser = buildParser(grammarDefinition);

## Transformation to AST

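The parser produces a concrete syntax tree (CST). We traverse this tree with a `TreeCursor` and convert it into our immutable `ExprNode` AST. Lezer creates tree nodes only for rules whose names start with an uppercase letter, so the lowercase rule `expr` has no node of its own: its children (`Product`, `Plus` or `Minus`, `Product`, ...) appear directly inside `Expression`, `ParenExpr`, `ExpFunc`, and `LnFunc`. Each of these nodes must therefore fold all of its operands into a left-associative tree, not just take its first child.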

In [None]:
// Tokens that carry no value of their own and are skipped when collecting operands.
const PUNCTUATION = new Set(["(", ")", "Exp", "Ln"]);

function lezerToAST(cursor: TreeCursor, source: string): ExprNode {
    const name = cursor.name;
    const text = source.slice(cursor.from, cursor.to);

    // Lezer's error recovery marks unparsable input with error nodes ("⚠").
    if (cursor.type.isError) {
        throw new Error(`Syntax error at position ${cursor.from}: "${text}"`);
    }

    switch (name) {
        // `expr` is a lowercase rule, so it creates no node: its children
        // (Product, Plus/Minus, Product, ...) sit directly under these nodes.
        case "Expression":
        case "ParenExpr":
            return parseLeftAssociative(cursor, source, ["+", "-"]);
        case "ExpFunc":
            return new ExpFunc(parseLeftAssociative(cursor, source, ["+", "-"]));
        case "LnFunc":
            return new LnFunc(parseLeftAssociative(cursor, source, ["+", "-"]));
        case "Product":
            return parseLeftAssociative(cursor, source, ["*", "/"]);
        case "PowerExpr": {
            // Children: Base, optionally followed by Power and Factor.
            // Since the exponent is itself a Factor, `**` is right-associative.
            cursor.firstChild();
            let result = lezerToAST(cursor, source);
            if (cursor.nextSibling() && cursor.nextSibling()) {
                result = new Pow(result, lezerToAST(cursor, source));
            }
            cursor.parent();
            return result;
        }
        case "Number":
            return new Num(parseInt(text, 10));
        case "VarX":
            return new VarX();
        default:
            // Wrapper nodes such as Factor and Base have exactly one child.
            if (cursor.firstChild()) {
                const child = lezerToAST(cursor, source);
                cursor.parent();
                return child;
            }
            throw new Error(`Unknown node: ${name}`);
    }
}

// Collect the operands and operators among the children of the current node
// and fold them into a left-associative tree: a - b - c becomes (a - b) - c.
function parseLeftAssociative(
    cursor: TreeCursor,
    source: string,
    ops: string[],
): ExprNode {
    const text = source.slice(cursor.from, cursor.to);
    const operands: ExprNode[] = [];
    const operators: string[] = [];
    if (cursor.firstChild()) {
        do {
            if (PUNCTUATION.has(cursor.name)) continue; // parentheses, `exp`, `ln`
            const t = source.slice(cursor.from, cursor.to);
            if (ops.includes(t)) operators.push(t);
            else operands.push(lezerToAST(cursor, source));
        } while (cursor.nextSibling());
        cursor.parent();
    }
    if (operands.length !== operators.length + 1) {
        throw new Error(`Malformed expression: "${text}"`);
    }
    let result = operands[0];
    for (let i = 0; i < operators.length; i++) {
        const op = operators[i];
        const r = operands[i + 1];
        if (op === "+") result = new Add(result, r);
        else if (op === "-") result = new Sub(result, r);
        else if (op === "*") result = new Mul(result, r);
        else if (op === "/") result = new Div(result, r);
    }
    return result;
}

function parseExpr(s: string): ExprNode {
    const tree: Tree = parser.parse(s);
    // Reject the input if error recovery inserted error nodes anywhere.
    tree.iterate({
        enter(node) {
            if (node.type.isError) {
                throw new Error(`Syntax error at position ${node.from} in "${s}"`);
            }
        },
    });
    return lezerToAST(tree.cursor(), s);
}

## Symbolic Differentiation

Now we implement the core function `diff` that takes an expression containing the variable `x` and computes its derivative.

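`diff` dispatches on the tag stored in the first component of the tuple and applies the standard differentiation rules:

- **Base Cases:** $\frac{d}{dx} c = 0$ for a number $c$, and $\frac{d}{dx} x = 1$.
- **Sum and Difference Rules:** $\frac{d}{dx}(u \pm v) = u' \pm v'$
- **Product Rule:** $\frac{d}{dx}(u \cdot v) = u'v + uv'$
- **Quotient Rule:** $\frac{d}{dx}\left(\frac{u}{v}\right) = \frac{u'v - uv'}{v^2}$, where $v^2$ is built as $v \cdot v$.
- **Chain Rule:** applied to the elementary functions, $\frac{d}{dx}\exp(u) = u' \cdot \exp(u)$ and $\frac{d}{dx}\ln(u) = \frac{u'}{u}$.
- **Power Rule:**
    - For $x^n$ (base `x`, exponent a numeric literal $n$), we use $n \cdot x^{n-1}$; the cases $n = 0$ and $n = 1$ directly return $0$ and $1$.
    - For $a^x$ (base a numeric literal $a$, exponent `x`), we use $a^x \cdot \ln(a)$.
    - In all other cases $u^v$, we rewrite the power as $\exp(v \cdot \ln(u))$, which is valid for $u > 0$, and differentiate it using the chain rule.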
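
As a worked instance of the last rule (for $u > 0$), applying the chain rule and the product rule to the rewritten form gives the general formula in the first line below; the second line instantiates it with $u = v = x$, the first test input `x ** x`. `diff` does not perform the final simplification to $x^x(\ln(x) + 1)$.

$$
\begin{aligned}
\frac{\mathrm{d}}{\mathrm{d}x}\, u^v
&= \frac{\mathrm{d}}{\mathrm{d}x} \exp\bigl(v \cdot \ln(u)\bigr)
 = \exp\bigl(v \cdot \ln(u)\bigr) \cdot \Bigl(v' \cdot \ln(u) + v \cdot \frac{u'}{u}\Bigr) \\
\frac{\mathrm{d}}{\mathrm{d}x}\, x^x
&= \exp\bigl(x \cdot \ln(x)\bigr) \cdot \Bigl(1 \cdot \ln(x) + x \cdot \frac{1}{x}\Bigr)
 = x^x \bigl(\ln(x) + 1\bigr)
\end{aligned}
$$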


In [None]:
function diff(e: ExprNode): ExprNode {
    const tuple = e as unknown as ITupleAccess;
    const tag = tuple.get(0) as string;
    switch (tag) {
        case "num":
            return new Num(0);
        case "x":
            return new Num(1);

        case "+": {
            const n = e as Add;
            return new Add(diff(n.left), diff(n.right));
        }
        case "-": {
            const n = e as Sub;
            return new Sub(diff(n.left), diff(n.right));
        }
        case "*": {
            const n = e as Mul;
            return new Add(
                new Mul(diff(n.left), n.right),
                new Mul(n.left, diff(n.right)),
            );
        }
        case "/": {
            const n = e as Div;
            return new Div(
                new Sub(
                    new Mul(diff(n.left), n.right),
                    new Mul(n.left, diff(n.right)),
                ),
                new Mul(n.right, n.right),
            );
        }
        case "ln": {
            const n = e as LnFunc;
            return new Div(diff(n.arg), n.arg);
        }
        case "exp": {
            const n = e as ExpFunc;
            return new Mul(diff(n.arg), n);
        }
        case "**": {
            const n = e as Pow;
            const base = n.left;
            const exp = n.right;
            const baseTag = (base as unknown as ITupleAccess).get(0) as string;
            const expTag = (exp as unknown as ITupleAccess).get(0) as string;
            if (baseTag === "x" && expTag === "num") {
                const val = (exp as Num).value;
                if (val === 0) return new Num(0);
                if (val === 1) return new Num(1);
                return new Mul(new Num(val), new Pow(base, new Num(val - 1)));
            }
            if (baseTag === "num" && expTag === "x") {
                return new Mul(n, new LnFunc(base));
            }
            return diff(new ExpFunc(new Mul(exp, new LnFunc(base))));
        }
        default:
            throw new Error(`Diff not implemented for ${tag}`);
    }
}

## Tests

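We test the implementation on two inputs: `x ** x`, which exercises the general power rule via $\exp(v \cdot \ln(u))$, and `x * ln(x) / exp(x/x)`, which combines the product, quotient, and chain rules. The helper `test` prints the input and the unsimplified derivative, or the error message if parsing or differentiation fails.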
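
Since `diff` does not simplify its results, a quick way to check a printed derivative is to compare it numerically with a central finite difference, $f'(x) \approx \frac{f(x+h) - f(x-h)}{2h}$. The sketch below is not part of the notebook and its function names are hypothetical. It relies on the fact that the printed syntax (integers, `x`, `+ - * / **`, `exp(...)`, `ln(...)`) is also valid JavaScript once `exp` and `ln` are bound to `Math.exp` and `Math.log`. Because `evaluateAt` compiles the string with `new Function`, it should only be applied to trusted strings such as the notebook's own output.

```typescript
// Evaluate a printed expression at x. The printed syntax is also valid
// JavaScript once `exp` and `ln` are bound to Math.exp and Math.log.
// `new Function` compiles arbitrary code, so use this only on trusted strings.
function evaluateAt(expr: string, x: number): number {
    const f = new Function("x", "exp", "ln", `return (${expr});`) as (
        x: number,
        exp: (t: number) => number,
        ln: (t: number) => number,
    ) => number;
    return f(x, Math.exp, Math.log);
}

// Central difference: f'(x) ~ (f(x + h) - f(x - h)) / (2h).
function numericDerivative(expr: string, x: number, h = 1e-5): number {
    return (evaluateAt(expr, x + h) - evaluateAt(expr, x - h)) / (2 * h);
}

// True if `derivative` matches the numeric derivative of `input` at every sample point.
// The default points are positive because ln is only defined for positive arguments.
function agreesNumerically(
    input: string,
    derivative: string,
    points: number[] = [0.5, 1, 2, 3],
): boolean {
    return points.every((x) => {
        const expected = numericDerivative(input, x);
        const actual = evaluateAt(derivative, x);
        return Math.abs(expected - actual) <= 1e-6 * (1 + Math.abs(expected));
    });
}

console.log(agreesNumerically("x * exp(x)", "1 * exp(x) + x * (1 * exp(x))")); // true
console.log(agreesNumerically("x * x", "x"));                                   // false
```

To check a test, pass the `In:` string and the `Out:` string printed by `test` to `agreesNumerically`.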

In [None]:
function test(s: string) {
    console.log(`In:  ${s}`);
    try {
        const d = diff(parseExpr(s));
        console.log(`Out: ${d}`);
    } catch (e) {
        console.log(`Err: ${e instanceof Error ? e.message : e}`);
    }
    console.log("-".repeat(20));
}

In [None]:
test("x ** x")

In [None]:
test("x * ln(x) / exp(x/x)")