Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 130 additions & 26 deletions src/compute-engine/differential-equation-utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { Expression, IComputeEngine } from './global-types';
import { isFunction, isSymbol, sym } from './boxed-expression/type-guards';
import { rk4 } from './numerics/differential-equations';
import { rk4, rk4System } from './numerics/differential-equations';

export function symbolArg(
engine: IComputeEngine,
Expand Down Expand Up @@ -29,34 +29,65 @@ export function isDerivativeOfDependent(
dependentName: string,
independentName: string
): boolean {
return derivativeOrderOfDependent(expr, dependentName, independentName) === 1;
}

export function derivativeOrderOfDependent(
expr: Expression,
dependentName: string,
independentName: string
): number | undefined {
if (isFunction(expr, 'D')) {
return (
isDependentFunction(expr.op1, dependentName, independentName) &&
isSymbol(expr.op2, independentName)
const variables = expr.ops.slice(1);
if (
variables.length === 0 ||
!variables.every((op) => isSymbol(op, independentName))
)
return undefined;
const innerOrder = derivativeOrderOfDependent(
expr.op1,
dependentName,
independentName
);
if (innerOrder !== undefined) return innerOrder + variables.length;
if (isDependentFunction(expr.op1, dependentName, independentName))
return variables.length;
return undefined;
}

if (isFunction(expr, 'Apply') && isFunction(expr.op1, 'Derivative')) {
return (
isSymbol(expr.op1.op1, dependentName) &&
expr.nops === 2 &&
isSymbol(expr.op2, independentName)
);
if (
!isSymbol(expr.op1.op1, dependentName) ||
expr.nops !== 2 ||
!isSymbol(expr.op2, independentName)
)
return undefined;

const order = expr.op1.op2 === undefined ? 1 : expr.op1.op2.N().re;
return Number.isInteger(order) && order > 0 ? order : undefined;
}

return false;
return undefined;
}

function explicitRhs(
function explicitDerivativeRhs(
equation: Expression,
dependentName: string,
independentName: string
): Expression | undefined {
): { order: number; rhs: Expression } | undefined {
if (!isFunction(equation, 'Equal')) return undefined;
if (isDerivativeOfDependent(equation.op1, dependentName, independentName))
return equation.op2;
if (isDerivativeOfDependent(equation.op2, dependentName, independentName))
return equation.op1;
const lhsOrder = derivativeOrderOfDependent(
equation.op1,
dependentName,
independentName
);
if (lhsOrder !== undefined) return { order: lhsOrder, rhs: equation.op2 };
const rhsOrder = derivativeOrderOfDependent(
equation.op2,
dependentName,
independentName
);
if (rhsOrder !== undefined) return { order: rhsOrder, rhs: equation.op1 };
return undefined;
}

Expand All @@ -77,6 +108,32 @@ function substituteDependentCall(
);
}

function substituteDependentState(
expr: Expression,
dependentName: string,
independentName: string,
stateNames: readonly string[]
): Expression {
if (isDependentFunction(expr, dependentName, independentName))
return expr.engine.symbol(stateNames[0]);

const order = derivativeOrderOfDependent(
expr,
dependentName,
independentName
);
if (order !== undefined && order < stateNames.length)
return expr.engine.symbol(stateNames[order]);

if (!isFunction(expr)) return expr;
return expr.engine._fn(
expr.operator,
expr.ops.map((op) =>
substituteDependentState(op, dependentName, independentName, stateNames)
)
);
}

export function nDSolve(
equation: Expression,
dependent: Expression,
Expand All @@ -92,12 +149,8 @@ export function nDSolve(
const independentName = sym(limits.op1);
if (!independentName) return undefined;

const [x0, x1, y0] = [
limits.op2.N().re,
limits.op3.N().re,
initialValue.N().re,
];
if (![x0, x1, y0].every(Number.isFinite)) return undefined;
const [x0, x1] = [limits.op2.N().re, limits.op3.N().re];
if (![x0, x1].every(Number.isFinite)) return undefined;

const steps = stepsExpr === undefined ? 100 : stepsExpr.N().re;
if (
Expand All @@ -108,12 +161,63 @@ export function nDSolve(
)
return undefined;

const rhs = explicitRhs(equation.structural, dependentName, independentName);
if (!rhs) return undefined;
const rhsInfo = explicitDerivativeRhs(
equation.structural,
dependentName,
independentName
);
if (!rhsInfo) return undefined;

const initialValues = isFunction(initialValue, 'List')
? initialValue.ops.map((op) => op.N().re)
: [initialValue.N().re];
if (
rhsInfo.order !== initialValues.length ||
!initialValues.every(Number.isFinite)
)
return undefined;

if (rhsInfo.order > 1) {
const stateNames = Array.from(
{ length: rhsInfo.order },
(_, i) => `ndsolve${dependentName}state${i}`
);
const compiledRhs = substituteDependentState(
rhsInfo.rhs,
dependentName,
independentName,
stateNames
);
const compiled = ce._compile(compiledRhs, { realOnly: true });
if (!compiled.success) return undefined;
const run = compiled.run as (vars: Record<string, number>) => number;

const samples = rk4System(
(x, y) => {
const vars: Record<string, number> = { [independentName]: x };
stateNames.forEach((name, i) => {
vars[name] = y[i];
});
const highest = run(vars);
if (!Number.isFinite(highest)) return undefined;
return [...y.slice(1), highest];
},
x0,
initialValues,
x1,
{ steps, deadline: ce._deadline }
);
if (!samples) return undefined;

return ce._fn(
'List',
samples.map(([x, y]) => ce._fn('List', [ce.number(x), ce.number(y[0])]))
);
}

const stateName = `ndsolve${dependentName}state`;
const compiledRhs = substituteDependentCall(
rhs,
rhsInfo.rhs,
dependentName,
independentName,
stateName
Expand All @@ -125,7 +229,7 @@ export function nDSolve(
const samples = rk4(
(x, y) => run({ [independentName]: x, [stateName]: y }),
x0,
y0,
initialValues[0],
x1,
{ steps, deadline: ce._deadline }
);
Expand Down
69 changes: 69 additions & 0 deletions src/compute-engine/numerics/differential-equations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export type RK4Options = {
};

export type ODESample = readonly [x: number, y: number];
export type ODEVectorSample = readonly [x: number, y: readonly number[]];

/**
* Fixed-step classical fourth-order Runge-Kutta solver for scalar explicit
Expand Down Expand Up @@ -50,3 +51,71 @@ export function rk4(

return samples;
}

/**
* Fixed-step classical fourth-order Runge-Kutta solver for first-order systems:
* y' = f(x, y), y(x0) = y0.
*/
export function rk4System(
f: (x: number, y: readonly number[]) => readonly number[] | undefined,
x0: number,
y0: readonly number[],
x1: number,
options: RK4Options
): ODEVectorSample[] | undefined {
const steps = Math.trunc(options.steps);
if (
!Number.isFinite(x0) ||
!Number.isFinite(x1) ||
!Number.isInteger(steps) ||
steps <= 0 ||
y0.length === 0 ||
!y0.every(Number.isFinite)
)
return undefined;

const addScaled = (
y: readonly number[],
dy: readonly number[],
scale: number
): number[] => y.map((yi, i) => yi + scale * dy[i]);

const combine = (
y: readonly number[],
k1: readonly number[],
k2: readonly number[],
k3: readonly number[],
k4: readonly number[],
h: number
): number[] =>
y.map((yi, i) => yi + (h / 6) * (k1[i] + 2 * k2[i] + 2 * k3[i] + k4[i]));

const h = (x1 - x0) / steps;
const samples: ODEVectorSample[] = [[x0, [...y0]]];
let x = x0;
let y = [...y0];

for (let i = 0; i < steps; i++) {
if ((i & 0xff) === 0) checkDeadline(options.deadline);

const k1 = f(x, y);
if (!k1 || k1.length !== y.length || !k1.every(Number.isFinite))
return undefined;
const k2 = f(x + h / 2, addScaled(y, k1, h / 2));
if (!k2 || k2.length !== y.length || !k2.every(Number.isFinite))
return undefined;
const k3 = f(x + h / 2, addScaled(y, k2, h / 2));
if (!k3 || k3.length !== y.length || !k3.every(Number.isFinite))
return undefined;
const k4 = f(x + h, addScaled(y, k3, h));
if (!k4 || k4.length !== y.length || !k4.every(Number.isFinite))
return undefined;

y = combine(y, k1, k2, k3, k4, h);
x = i === steps - 1 ? x1 : x + h;
if (!y.every(Number.isFinite)) return undefined;
samples.push([x, [...y]]);
}

return samples;
}
Loading