In [None]:
import { display } from "tslab";
import { readFileSync } from "fs";

const css = readFileSync("../style.css", "utf8");
display.html(`<style>${css}</style>`);

## Imports and Setup

In [None]:
import {
  createToken,
  Lexer,
  CstParser,
  IToken,
  ILexingResult,
  TokenType,
  CstNode,
} from "chevrotain";

# Symbolic Differentiation

This notebook implements *symbolic differentiation* in TypeScript.  
Concretely, we implement a function `diffString` that takes a single argument `expr`, an *arithmetic expression* in the variable `x`.

An arithmetic expression is any string built from the variable `x` and natural numbers by using the operators `+`, `-`, `*`, `/`, and `**`.  
The operator `**` represents exponentiation, i.e. an expression of the form $a \texttt{**} b$ is interpreted as $a^b$.  
Furthermore, if $e$ is an expression, then both $\exp(e)$ and $\ln(e)$ are expressions as well.

The function call `diffString(expr)` computes the derivative of `expr` with respect to the variable `x`.  
For example, the call

`diffString("x * exp(x)")`

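yields the output

`1*exp(x) + x*1*exp(x)`

because

$$ \frac{\mathrm{d}\;}{\mathrm{d}x} \bigl( x \cdot \mathrm{e}^x \bigr) = 1 \cdot \mathrm{e}^x + x \cdot \bigl(1 \cdot \mathrm{e}^x\bigr). $$

The result is not simplified: the extra factor $1$ is the derivative of the inner function $x$ that the chain rule produces for $\exp(x)$.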

The input string is parsed with a parser built on Chevrotain.  
The grammar for the language accepted by this parser is:

\[
\begin{array}{lcl}
  \mathrm{expr}    & \rightarrow & \;\mathrm{expr}\; \texttt{'+'} \; \mathrm{product}  \\
                   & \mid        & \;\mathrm{expr}\; \texttt{'-'} \; \mathrm{product}  \\
                   & \mid        & \;\mathrm{product}                                  \\[0.2cm]
  \mathrm{product} & \rightarrow & \;\mathrm{product}\; \texttt{'*'} \;\mathrm{factor} \\
                   & \mid        & \;\mathrm{product}\; \texttt{'/'} \;\mathrm{factor} \\
                   & \mid        & \;\mathrm{factor}                                   \\[0.2cm]
  \mathrm{factor}  & \rightarrow & \mathrm{base} \;\texttt{'**'} \; \mathrm{factor}    \\
                   & \mid        & \mathrm{base}                                       \\[0.2cm]
  \mathrm{base}    & \rightarrow & \texttt{exp} \; \texttt{'('} \; \mathrm{expr} \;\texttt{')'}     \\
                   & \mid        & \texttt{ln} \; \texttt{'('} \; \mathrm{expr} \;\texttt{')'}      \\
                   & \mid        & \texttt{'('} \; \mathrm{expr} \;\texttt{')'}                     \\
                   & \mid        & \;\texttt{NUMBER}                                                \\
                   & \mid        & \;\texttt{'x'}
\end{array}
\]
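
The grammar above is left-recursive in $\mathrm{expr}$ and $\mathrm{product}$. Chevrotain builds LL(k) parsers, which cannot process left-recursive rules directly, so the parser below uses an equivalent formulation with repetition. Written out, following the rule comments in the parser code, that formulation reads roughly as follows (the rule for $\mathrm{base}$ is unchanged):

$$
\begin{array}{lcl}
  \mathrm{expr}    & \rightarrow & \mathrm{product} \; \bigl( (\texttt{'+'} \mid \texttt{'-'}) \; \mathrm{product} \bigr)^{*} \\[0.2cm]
  \mathrm{product} & \rightarrow & \mathrm{factor} \; \bigl( (\texttt{'*'} \mid \texttt{'/'}) \; \mathrm{factor} \bigr)^{*}   \\[0.2cm]
  \mathrm{factor}  & \rightarrow & \mathrm{base} \; \bigl( \texttt{'**'} \; \mathrm{factor} \bigr)^{?}
\end{array}
$$

The repetition loses the left-to-right nesting of the original rules, so the left associativity of `+`, `-`, `*`, and `/` has to be restored when the syntax tree is built.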

## Specification of the Scanner

We define the tokens for numbers, functions (`exp`, `ln`), power `**`, arithmetic operators, parentheses, variable `x`, and whitespace.
Whitespace is skipped by the lexer and does not appear in the token stream.

The token `NUMBER` specifies a *natural number*.
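
To see what the pattern used for `NUMBER` accepts, the following sketch applies the same regular expression outside of Chevrotain. The helper `matchNumberAt` is hypothetical and only imitates matching at a fixed offset; it is an illustration, not Chevrotain's matching code.

```typescript
// Hypothetical helper (not part of Chevrotain): try to match the NUMBER
// pattern 0|[1-9][0-9]* at the given offset of the text.
function matchNumberAt(text: string, offset: number): string | null {
  const m = text.slice(offset).match(/^(?:0|[1-9][0-9]*)/);
  return m === null ? null : m[0];
}

console.log(matchNumberAt("123+x", 0)); // "123"
console.log(matchNumberAt("007", 0));   // "0": a leading zero is a number on its own
console.log(matchNumberAt("x", 0));     // null
```

Because the pattern excludes leading zeros, an input such as `007` is not read as a single number.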

In [None]:
const NumberTok: TokenType = createToken({
  name: "NUMBER",
  pattern: /0|[1-9][0-9]*/,
});

In [None]:
const ExpTok: TokenType = createToken({ name: "EXP", pattern: /exp/ });

In [None]:
const LnTok: TokenType = createToken({ name: "LN", pattern: /ln/ });

Below, we need to escape the metacharacter `*`.

In [None]:
const PowerTok: TokenType = createToken({ name: "POWER", pattern: /\*\*/ });

In [None]:
const Plus: TokenType = createToken({ name: "Plus", pattern: /\+/ });
const Minus: TokenType = createToken({ name: "Minus", pattern: /-/ });
const Mul: TokenType = createToken({ name: "Mul", pattern: /\*/ });
const Div: TokenType = createToken({ name: "Div", pattern: /\// });
const LParen: TokenType = createToken({ name: "LParen", pattern: /\(/ });
const RParen: TokenType = createToken({ name: "RParen", pattern: /\)/ });
const XTok: TokenType = createToken({ name: "X", pattern: /x/ });

Blanks, tabs, and line breaks are skipped.

In [None]:
const WhiteSpace: TokenType = createToken({
  name: "WhiteSpace",
  pattern: /[ \t\n\r]+/,
  group: Lexer.SKIPPED,
});

In [None]:
const allTokens: TokenType[] = [
  WhiteSpace,
  NumberTok,
  ExpTok,
  LnTok,
  PowerTok,
  Plus,
  Minus,
  Mul,
  Div,
  LParen,
  RParen,
  XTok,
];
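
The order of this list matters: Chevrotain's lexer tries the token types in the order given and takes the first pattern that matches, so `PowerTok` has to come before `Mul`. The sketch below illustrates the effect with a hypothetical first-match tokenizer (`firstMatchTokens` and `LexRule` are illustration names, not Chevrotain API).

```typescript
// Hypothetical first-match tokenizer, independent of Chevrotain.
type LexRule = [string, RegExp];

function firstMatchTokens(input: string, rules: LexRule[]): string[] {
  const out: string[] = [];
  let pos = 0;
  while (pos < input.length) {
    let matched = false;
    for (const [name, pattern] of rules) {
      // Every pattern is anchored with ^, so it can only match at `pos`.
      const m = input.slice(pos).match(pattern);
      if (m !== null) {
        out.push(name);
        pos += m[0].length;
        matched = true;
        break; // first matching rule wins
      }
    }
    if (!matched) throw new Error(`No rule matches at offset ${pos}`);
  }
  return out;
}

const powerFirst: LexRule[] = [["POWER", /^\*\*/], ["Mul", /^\*/], ["X", /^x/]];
const mulFirst: LexRule[] = [["Mul", /^\*/], ["POWER", /^\*\*/], ["X", /^x/]];

console.log(firstMatchTokens("x**x", powerFirst).join(" ")); // X POWER X
console.log(firstMatchTokens("x**x", mulFirst).join(" "));   // X Mul Mul X
```

With `Mul` listed first, `**` would be split into two multiplication tokens and the parser would reject the input.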

We generate the lexer.

In [None]:
const DifferLexer: Lexer = new Lexer(allTokens);

For debugging, we provide a helper `tokenizeDiff` that shows the token images.

In [None]:
function tokenizeDiff(input: string): string[] {
  const lexRes: ILexingResult = DifferLexer.tokenize(input);
  if (lexRes.errors.length > 0) {
    throw new Error(`Lexing error: ${lexRes.errors[0].message}`);
  }
  return lexRes.tokens.map((t: IToken): string => t.image);
}

Example:

In [None]:
console.log(tokenizeDiff("ln(x ** x) + exp(x * x)"));

## Specification of the Parser

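We translate the grammar into Chevrotain rules.  
The left-recursive rules for `expr` and `product` are written with the repetition operator `MANY`, and the optional exponent in `factor` with `OPTION`.  
The start rule is `expr`.

The abstract syntax tree that is later built from the parse result has the type `ExprNode` and is represented as nested arrays: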

- numbers: constants,
- `"x"`: the variable,
- unary operations: `["ln", e]`, `["exp", e]`,
- binary operations: `["+", f, g]`, `["*", f, g]`, `["**", f, g]`, etc.
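
As an illustration of this representation, the following sketch writes down by hand the tree for the example string `ln(x ** x) + exp(x * x)`. The type definitions repeat those of the next cell so that the snippet is self-contained; the constant `sample` is a hypothetical name.

```typescript
type BinaryOp = "+" | "-" | "*" | "/" | "**";
type UnaryOp = "ln" | "exp";
type ExprNode =
  | number
  | "x"
  | [UnaryOp, ExprNode]
  | [BinaryOp, ExprNode, ExprNode];

// Hand-written tree for "ln(x ** x) + exp(x * x)".
const sample: ExprNode = [
  "+",
  ["ln", ["**", "x", "x"]],
  ["exp", ["*", "x", "x"]],
];

console.log(JSON.stringify(sample));
// ["+",["ln",["**","x","x"]],["exp",["*","x","x"]]]
```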

In [None]:
type BinaryOp = "+" | "-" | "*" | "/" | "**";
type UnaryOp = "ln" | "exp";

type ExprNode =
  | number
  | "x"
  | [UnaryOp, ExprNode]
  | [BinaryOp, ExprNode, ExprNode];

class DifferParser extends CstParser {
  public expr!: (idx?: number) => CstNode;
  public product!: (idx?: number) => CstNode;
  public factor!: (idx?: number) => CstNode;
  public base!: (idx?: number) => CstNode;

  constructor() {
    super(allTokens, { maxLookahead: 2 });
    const $ = this;

    // expr : product (('+' | '-') product)*
    $.RULE("expr", () => {
      $.SUBRULE($.product);
      $.MANY(() => {
        $.OR([
          {
            ALT: () => {
              $.CONSUME(Plus);
              $.SUBRULE2($.product);
            },
          },
          {
            ALT: () => {
              $.CONSUME(Minus);
              $.SUBRULE3($.product);
            },
          },
        ]);
      });
    });

    // product : factor (('*' | '/') factor)*
    $.RULE("product", () => {
      $.SUBRULE($.factor);
      $.MANY(() => {
        $.OR([
          {
            ALT: () => {
              $.CONSUME(Mul);
              $.SUBRULE2($.factor);
            },
          },
          {
            ALT: () => {
              $.CONSUME(Div);
              $.SUBRULE3($.factor);
            },
          },
        ]);
      });
    });

    // factor : base '**' factor | base
    $.RULE("factor", () => {
      $.SUBRULE($.base);
      $.OPTION(() => {
        $.CONSUME(PowerTok);
        $.SUBRULE2($.factor);
      });
    });

    // base : exp '(' expr ')' | ln '(' expr ')' | '(' expr ')' | NUMBER | 'x'
    $.RULE("base", () => {
      $.OR([
        {
          ALT: () => {
            $.CONSUME(ExpTok);
            $.CONSUME(LParen);
            $.SUBRULE($.expr);
            $.CONSUME(RParen);
          },
        },
        {
          ALT: () => {
            $.CONSUME(LnTok);
            $.CONSUME2(LParen);
            $.SUBRULE2($.expr);
            $.CONSUME2(RParen);
          },
        },
        {
          ALT: () => {
            $.CONSUME3(LParen);
            $.SUBRULE3($.expr);
            $.CONSUME3(RParen);
          },
        },
        { ALT: () => $.CONSUME(NumberTok) },
        { ALT: () => $.CONSUME(XTok) },
      ]);
    });

    this.performSelfAnalysis();
  }
}

const parser: DifferParser = new DifferParser();
const BaseCstVisitor = parser.getBaseCstVisitorConstructor();

## CST to AST: The Visitor

We now implement the visitor that transforms the CST into our `ExprNode` representation.
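
A CST node stores its children grouped by token or rule name, so the `+` and `-` tokens of an `expr` arrive in two separate arrays and their relative order has to be recovered from the token offsets. The sketch below shows that step in isolation; `Tok` is a hypothetical stand-in for the two `IToken` fields used, and the offsets are written by hand for the input `x - 1 + 2 - 3`.

```typescript
// Hypothetical stand-in for the IToken fields needed here.
interface Tok {
  image: string;
  startOffset: number;
}

// Operators of "x - 1 + 2 - 3", grouped by token type as in a CST.
const plusTokens: Tok[] = [{ image: "+", startOffset: 6 }];
const minusTokens: Tok[] = [
  { image: "-", startOffset: 2 },
  { image: "-", startOffset: 10 },
];

// Merge both groups and sort by position to restore the source order.
const ordered: string[] = [...plusTokens, ...minusTokens]
  .sort((a, b) => a.startOffset - b.startOffset)
  .map((t) => t.image);

console.log(ordered.join(" ")); // - + -
```

The i-th operator in this ordered list then sits between operand i and operand i + 1, which is how the visitor folds the operands into a left-associative tree.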


In [None]:
class ToASTVisitor extends BaseCstVisitor {
  constructor() {
    super();
    (this as unknown as { validateVisitor(): void }).validateVisitor();
  }

  // expr : product (('+' | '-') product)*
  public expr(ctx: {
    product: CstNode[];
    Plus?: IToken[];
    Minus?: IToken[];
  }): ExprNode {
    let node: ExprNode = this.visit(ctx.product[0]) as ExprNode;

    if (ctx.product.length > 1) {
      // The CST groups tokens by type; merge and sort by offset to restore source order
      const plusTokens = ctx.Plus || [];
      const minusTokens = ctx.Minus || [];
      const allOps = [...plusTokens, ...minusTokens].sort(
        (a, b) => a.startOffset - b.startOffset
      );

      for (let i = 1; i < ctx.product.length; i++) {
        const rightNode = this.visit(ctx.product[i]) as ExprNode;
        const operator = allOps[i - 1];

        if (operator.tokenType.name === "Plus") {
          node = ["+", node, rightNode];
        } else {
          node = ["-", node, rightNode];
        }
      }
    }

    return node;
  }

  // product : factor (('*' | '/') factor)*
  public product(ctx: {
    factor: CstNode[];
    Mul?: IToken[];
    Div?: IToken[];
  }): ExprNode {
    let node: ExprNode = this.visit(ctx.factor[0]) as ExprNode;

    if (ctx.factor.length > 1) {
      // The CST groups tokens by type; merge and sort by offset to restore source order
      const mulTokens = ctx.Mul || [];
      const divTokens = ctx.Div || [];
      const allOps = [...mulTokens, ...divTokens].sort(
        (a, b) => a.startOffset - b.startOffset
      );

      for (let i = 1; i < ctx.factor.length; i++) {
        const rightNode = this.visit(ctx.factor[i]) as ExprNode;
        const operator = allOps[i - 1];

        if (operator.tokenType.name === "Mul") {
          node = ["*", node, rightNode];
        } else {
          node = ["/", node, rightNode];
        }
      }
    }

    return node;
  }

  // factor : base '**' factor | base
  public factor(ctx: {
    base: CstNode[];
    POWER?: IToken[];
    factor?: CstNode[];
  }): ExprNode {
    const baseNode: ExprNode = this.visit(ctx.base[0]) as ExprNode;
    // Right-associative logic is handled correctly by recursion here
    if (ctx.POWER && ctx.factor) {
      const right: ExprNode = this.visit(ctx.factor[0]) as ExprNode;
      return ["**", baseNode, right];
    }
    return baseNode;
  }

  // base : exp '(' expr ')' | ln '(' expr ')' | '(' expr ')' | NUMBER | 'x'
  public base(ctx: {
    EXP?: IToken[];
    LN?: IToken[];
    expr?: CstNode[];
    LParen?: IToken[];
    NUMBER?: IToken[];
    X?: IToken[];
  }): ExprNode {
    if (ctx.EXP) {
      return ["exp", this.visit(ctx.expr![0]) as ExprNode];
    }
    if (ctx.LN) {
      return ["ln", this.visit(ctx.expr![0]) as ExprNode];
    }
    // Parentheses case: needs expr AND LParen
    if (ctx.expr && ctx.LParen) {
      return this.visit(ctx.expr[0]) as ExprNode;
    }
    if (ctx.NUMBER) {
      return parseInt(ctx.NUMBER[0].image, 10);
    }
    if (ctx.X) {
      return "x";
    }
    throw new Error("Unexpected base rule");
  }
}

const toAST: ToASTVisitor = new ToASTVisitor();

## Parsing Function

We combine lexer, parser, and visitor into a single `parseExpr` function.

In [None]:
function parseExpr(s: string): ExprNode {
  const lexRes: ILexingResult = DifferLexer.tokenize(s);
  if (lexRes.errors.length > 0) {
    throw new Error(`Lexing error: ${lexRes.errors[0].message}`);
  }

  parser.input = lexRes.tokens;
  const cst: CstNode = parser.expr();

  if (parser.errors.length > 0) {
    throw new Error(`Parse error: ${parser.errors[0].message}`);
  }

  return toAST.visit(cst) as ExprNode;
}

Example AST:

In [None]:
console.log(parseExpr("ln(x ** x) + exp(x * x)"));

## The `diff` Function

We now implement symbolic differentiation on the AST.
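
For reference, these are the differentiation rules that the function below applies, where $f$ and $g$ are expressions in $x$ and $f'$ denotes the derivative of $f$. Numbers have derivative $0$ and the variable $x$ has derivative $1$.

$$
\begin{array}{lcl}
  (f + g)'      & = & f' + g'                                             \\[0.1cm]
  (f - g)'      & = & f' - g'                                             \\[0.1cm]
  (f \cdot g)'  & = & f' \cdot g + f \cdot g'                             \\[0.1cm]
  (f / g)'      & = & (f' \cdot g - f \cdot g') \,/\, (g \cdot g)         \\[0.1cm]
  (f^g)'        & = & \bigl(\exp(g \cdot \ln(f))\bigr)'                   \\[0.1cm]
  \ln(f)'       & = & f' / f                                              \\[0.1cm]
  \exp(f)'      & = & f' \cdot \exp(f)
\end{array}
$$

The rule for powers does not give a closed form by itself: it rewrites $f^g$ as $\exp(g \cdot \ln(f))$ and lets the rules for $\exp$, products, and $\ln$ do the rest.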

In [None]:
function diff(e: ExprNode): ExprNode {
  if (typeof e === "number") return 0;
  if (e === "x") return 1;
  if (typeof e === "string") return 0;

  const [op, f, g] = e as [string, ExprNode, ExprNode];

  switch (op) {
    case "+":
      return ["+", diff(f), diff(g)];
    case "-":
      return ["-", diff(f), diff(g)];
    case "*":
      return ["+", ["*", diff(f), g], ["*", f, diff(g)]];
    case "/":
      return [
        "/",
        ["-", ["*", diff(f), g], ["*", f, diff(g)]],
        ["*", g, g],
      ];
    case "**":
      // f(x)**g(x) = exp(g(x)*ln(f(x))) → differentiate that
      return diff(["exp", ["*", g, ["ln", f]]]);
    case "ln":
      return ["/", diff(f), f];
    case "exp":
      return ["*", diff(f), e];
  }
  return 0;
}

## Pretty Printing

We turn an `ExprNode` back into a readable string while respecting operator precedence.

In [None]:
function precedence(op: string): number {
  const Precedences: Record<string, number> = {
    "+": 1,
    "-": 1,
    "*": 2,
    "/": 2,
    "**": 3,
  };
  return Precedences[op] ?? 4;
}

function precedenceOp(expr: ExprNode): number {
  if (Array.isArray(expr)) return precedence(expr[0]);
  return 4;
}

function toStringExpr(e: ExprNode): string {
  if (typeof e === "number" || typeof e === "string") {
    return String(e);
  }

  if (Array.isArray(e) && e.length === 2) {
    return e[0] + "(" + toStringExpr(e[1]) + ")";
  }

  if (e[0] === "+") {
    return toStringExpr(e[1]) + " + " + toStringExpr(e[2]);
  }

  if (e[0] === "-") {
    const lhs: string = toStringExpr(e[1]);
    const rhs: string =
      precedenceOp(e[2]) === 1
        ? "(" + toStringExpr(e[2]) + ")"
        : toStringExpr(e[2]);
    return lhs + " - " + rhs;
  }

  if (e[0] === "*") {
    const lhs: string =
      precedenceOp(e[1]) === 1
        ? "(" + toStringExpr(e[1]) + ")"
        : toStringExpr(e[1]);
    const rhs: string =
      precedenceOp(e[2]) === 1
        ? "(" + toStringExpr(e[2]) + ")"
        : toStringExpr(e[2]);
    return lhs + "*" + rhs;
  }

  if (e[0] === "/") {
    const lhs: string =
      precedenceOp(e[1]) === 1
        ? "(" + toStringExpr(e[1]) + ")"
        : toStringExpr(e[1]);
    const rhs: string =
      precedenceOp(e[2]) <= 2
        ? "(" + toStringExpr(e[2]) + ")"
        : toStringExpr(e[2]);
    return lhs + "/" + rhs;
  }

  if (e[0] === "**") {
    const lhs: string =
      precedenceOp(e[1]) <= 3
        ? "(" + toStringExpr(e[1]) + ")"
        : toStringExpr(e[1]);
    const rhs: string =
      precedenceOp(e[2]) <= 2
        ? "(" + toStringExpr(e[2]) + ")"
        : toStringExpr(e[2]);
    return lhs + "**" + rhs;
  }

  return "?";
}

## Top-Level Interface and Tests

We expose a function `diffString` that works directly with strings, and a `test` helper.

In [None]:
function diffString(s: string): string {
  const t: ExprNode = parseExpr(s);
  const d: ExprNode = diff(t);
  return toStringExpr(d);
}

function test(s: string): void {
  console.log(`d/dx ${s} = ${diffString(s)}`);
}

Example tests:

In [None]:
test("x ** x");
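
Tracing the rules by hand for this input gives a useful check. The power is first rewritten as $\exp(x \cdot \ln(x))$, and then the rules for $\exp$, products, and $\ln$ are applied:

$$
\frac{\mathrm{d}\;}{\mathrm{d}x}\, x^x
  = \frac{\mathrm{d}\;}{\mathrm{d}x} \exp\bigl(x \cdot \ln(x)\bigr)
  = \Bigl(1 \cdot \ln(x) + x \cdot \frac{1}{x}\Bigr) \cdot \exp\bigl(x \cdot \ln(x)\bigr)
  = \bigl(\ln(x) + 1\bigr) \cdot x^x.
$$

Since no simplification is performed, the printed derivative should correspond to the middle form, i.e. `(1*ln(x) + x*1/x)*exp(x*ln(x))`.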

In [None]:
test("x * ln(x) / exp(x/x)");
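
One way to gain more confidence in such results is to compare them with a numerical derivative. The sketch below is not part of the notebook's code: `evaluate` and `numericDerivative` are hypothetical helpers, the type definitions are repeated so that the snippet is self-contained, and the derivative tree `df` is written by hand following the product and chain rules.

```typescript
type BinaryOp = "+" | "-" | "*" | "/" | "**";
type UnaryOp = "ln" | "exp";
type ExprNode =
  | number
  | "x"
  | [UnaryOp, ExprNode]
  | [BinaryOp, ExprNode, ExprNode];

// Hypothetical helper: evaluate an expression tree at a given value of x.
function evaluate(e: ExprNode, x: number): number {
  if (typeof e === "number") return e;
  if (e === "x") return x;
  if (e.length === 2) {
    const arg = evaluate(e[1], x);
    return e[0] === "ln" ? Math.log(arg) : Math.exp(arg);
  }
  const l = evaluate(e[1], x);
  const r = evaluate(e[2], x);
  switch (e[0]) {
    case "+": return l + r;
    case "-": return l - r;
    case "*": return l * r;
    case "/": return l / r;
    default:  return Math.pow(l, r); // "**"
  }
}

// Hypothetical helper: central difference approximation of the derivative.
function numericDerivative(e: ExprNode, x: number, h: number = 1e-6): number {
  return (evaluate(e, x + h) - evaluate(e, x - h)) / (2 * h);
}

// f(x) = x * exp(x) and its unsimplified symbolic derivative.
const f: ExprNode = ["*", "x", ["exp", "x"]];
const df: ExprNode = [
  "+",
  ["*", 1, ["exp", "x"]],
  ["*", "x", ["*", 1, ["exp", "x"]]],
];

const x0 = 1.5;
const symbolic = evaluate(df, x0);
const numeric = numericDerivative(f, x0);
console.log(Math.abs(symbolic - numeric) < 1e-4); // true
```

A check of this kind only samples individual points, so it can expose a wrong rule but cannot prove that a derivative is correct.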