diff --git a/README.md b/README.md
index 8929228..c0484b5 100644
--- a/README.md
+++ b/README.md
@@ -196,6 +196,23 @@ console.log(result.nodes); // Prints the array of nodes in the shortest path
console.log(result.weight); // Prints the total weight of the path
```
+# shortestPath(graph, sourceNode, destinationNode, nextWeightFn)
+
+Calculates the weight based on the custom function.
+
+```javascript
+import type { NextWeightFnParams } from '../../types.js';
+function multiplyWeightFunction(wp: NextWeightFnParams): number {
+ if (wp.currentPathWeight === undefined) {
+ return wp.edgeWeight;
+ }
+ return wp.edgeWeight * wp.currentPathWeight;
+}
+var result = shortestPath(graph, 'a', 'c', multiplyWeightFunction);
+console.log(result.nodes); // Prints the array of nodes in the shortest path
+console.log(result.weight); // Prints the total weight of the path
+```
+
diff --git a/src/algorithms/shortestPath/getPath.ts b/src/algorithms/shortestPath/getPath.ts
index 8ce82e9..95d6ef5 100644
--- a/src/algorithms/shortestPath/getPath.ts
+++ b/src/algorithms/shortestPath/getPath.ts
@@ -1,8 +1,19 @@
import type { EdgeWeight, NoInfer } from '../../types.js';
import type { TraversingTracks } from './types.js';
+import type { NextWeightFnParams } from '../../types.js';
import { Graph } from '../../Graph.js';
+/**
+ * Computes edge weight as the sum of all the edges in the path.
+ */
+export function addWeightFunction( wp: NextWeightFnParams): number {
+ if (wp.currentPathWeight === undefined) {
+ return wp.edgeWeight;
+ }
+ return wp.edgeWeight + wp.currentPathWeight;
+}
+
/**
* Assembles the shortest path by traversing the
* predecessor subgraph from destination to source.
@@ -12,22 +23,30 @@ export function getPath(
tracks: TraversingTracks>,
source: NoInfer,
destination: NoInfer,
+ nextWeightFn: (params: NextWeightFnParams) => number = addWeightFunction
): {
nodes: [Node, Node, ...Node[]];
- weight: number;
+ weight: number | undefined;
} {
const { p } = tracks;
const nodeList: Node[] & { weight?: EdgeWeight } = [];
- let totalWeight = 0;
+ let totalWeight : EdgeWeight | undefined = undefined;
let node = destination;
+ let hop = 1;
while (p.has(node)) {
const currentNode = p.get(node)!;
nodeList.push(node);
- totalWeight += graph.getEdgeWeight(currentNode, node);
+ const edgeWeight = graph.getEdgeWeight(currentNode, node)
+ totalWeight = nextWeightFn({
+ edgeWeight, currentPathWeight: totalWeight,
+ hop: hop, graph: graph, path: tracks,
+ previousNode: node, currentNode: currentNode
+ });
node = currentNode;
+ hop++;
}
if (node !== source) {
diff --git a/src/algorithms/shortestPath/shortestPath.spec.ts b/src/algorithms/shortestPath/shortestPath.spec.ts
index c67492b..134eca9 100644
--- a/src/algorithms/shortestPath/shortestPath.spec.ts
+++ b/src/algorithms/shortestPath/shortestPath.spec.ts
@@ -1,8 +1,10 @@
-import { describe, expect, it } from 'vitest';
+import { describe, expect, it, vi } from 'vitest';
import { Graph } from '../../Graph.js';
import { serializeGraph } from '../../utils/serializeGraph.js';
import { shortestPath } from './shortestPath.js';
import { shortestPaths } from './shortestPaths.js';
+import { addWeightFunction } from './getPath.js';
+import { NextWeightFnParams } from '../../types.js';
describe("Dijkstra's Shortest Path Algorithm", function () {
it('Should compute shortest path on a single edge.', function () {
@@ -104,3 +106,98 @@ describe("Dijkstra's Shortest Path Algorithm", function () {
expect(postSerializedGraph.links).toContainEqual({ source: 'f', target: 'c' });
});
});
+
+describe('addWeightFunction', () => {
+ it('should return edgeWeight if currentPathWeight is undefined', () => {
+ const graph = new Graph();
+ const params = {
+ edgeWeight: 5, currentPathWeight: undefined, hop: 1,
+ graph: graph, path: { d: new Map(), p: new Map(), q: new Set() },
+ previousNode: 'a', currentNode: 'b'
+ };
+ expect(addWeightFunction(params)).toBe(5);
+ });
+
+ it('should return the sum of edgeWeight and currentPathWeight', () => {
+ const graph = new Graph()
+ const params = { edgeWeight: 5, currentPathWeight: 10, hop: 1,
+ graph: graph, path: { d: new Map(), p: new Map(), q: new Set() },
+ previousNode: 'a', currentNode: 'b'
+ };
+ expect(addWeightFunction(params)).toBe(15);
+ });
+});
+
+describe('shortestPath with custom weight functions', () => {
+ it('should compute shortest path with default weight function (sum of weights)', () => {
+ const graph = new Graph().addEdge('a', 'b', 1).addEdge('b', 'c', 2);
+ expect(shortestPath(graph, 'a', 'c')).toEqual({
+ nodes: ['a', 'b', 'c'],
+ weight: 3,
+ });
+ });
+
+ it('should compute shortest path with a custom weight function', () => {
+ const customWeightFn = ({ edgeWeight, currentPathWeight, hop }: NextWeightFnParams) => {
+ if (currentPathWeight === undefined) {
+ return edgeWeight;
+ }
+ return currentPathWeight + edgeWeight ** hop;
+ };
+
+ const graph = new Graph().addEdge('a', 'b', 2).addEdge('b', 'c', 3);
+ expect(shortestPath(graph, 'a', 'c', customWeightFn)).toEqual({
+ nodes: ['a', 'b', 'c'],
+ weight: 7,
+ });
+ });
+
+ it('should pass correct parameters to custom weight function for a path with 3 nodes', () => {
+ const customWeightFn = vi.fn(({ edgeWeight, currentPathWeight, hop }: NextWeightFnParams) => {
+ if (currentPathWeight === undefined) {
+ return edgeWeight;
+ }
+ return currentPathWeight + edgeWeight ** hop;
+ });
+
+ const graph = new Graph().addEdge('a', 'b', 1).addEdge('b', 'c', 2);
+ shortestPath(graph, 'a', 'c', customWeightFn);
+
+ expect(customWeightFn).toHaveBeenCalledWith({ edgeWeight: 2, currentPathWeight: undefined, hop: 1,
+ graph: graph, currentNode: 'b', previousNode: 'c',
+ path: {
+ d: new Map([['a', 0], ['b', 1], ['c', 3]]),
+ p: new Map([['b', 'a'], ['c', 'b']]),
+ q: new Set(),
+ },
+ });
+ expect(customWeightFn).toHaveBeenCalledWith({ edgeWeight: 1, currentPathWeight: 2, hop: 2,
+ graph: graph, currentNode: 'a', previousNode: 'b',
+ path: {
+ d: new Map([['a', 0], ['b', 1], ['c', 3]]),
+ p: new Map([['b', 'a'], ['c', 'b']]),
+ q: new Set(),
+ }
+ });
+ });
+
+ it('should compute shortest path with a custom weight function in a graph with multiple paths', () => {
+ const customWeightFn = ({ edgeWeight, currentPathWeight }: NextWeightFnParams) => {
+ if (currentPathWeight === undefined) {
+ return edgeWeight;
+ }
+ return edgeWeight + currentPathWeight;
+ };
+
+ const graph = new Graph()
+ .addEdge('a', 'b', 1)
+ .addEdge('b', 'c', 2)
+ .addEdge('a', 'd', 1)
+ .addEdge('d', 'c', 1);
+
+ expect(shortestPath(graph, 'a', 'c', customWeightFn)).toEqual({
+ nodes: ['a', 'd', 'c'],
+ weight: 2,
+ });
+ });
+});
diff --git a/src/algorithms/shortestPath/shortestPath.ts b/src/algorithms/shortestPath/shortestPath.ts
index 3289dab..ba255ac 100644
--- a/src/algorithms/shortestPath/shortestPath.ts
+++ b/src/algorithms/shortestPath/shortestPath.ts
@@ -1,8 +1,9 @@
import { Graph } from '../../Graph.js';
import { NoInfer } from '../../types.js';
import { dijkstra } from './dijkstra.js';
-import { getPath } from './getPath.js';
+import { getPath, addWeightFunction } from './getPath.js';
import { TraversingTracks } from './types.js';
+import type { NextWeightFnParams } from '../../types.js';
/**
* Dijkstra's Shortest Path Algorithm.
@@ -13,9 +14,10 @@ export function shortestPath(
graph: Graph,
source: NoInfer,
destination: NoInfer,
+ nextWeightFn: (params: NextWeightFnParams) => number = addWeightFunction
): {
nodes: [Node, Node, ...Node[]];
- weight: number;
+ weight: number | undefined;
} {
const tracks: TraversingTracks = {
d: new Map(),
@@ -25,5 +27,5 @@ export function shortestPath(
dijkstra(graph, tracks, source, destination);
- return getPath(graph, tracks, source, destination);
+ return getPath(graph, tracks, source, destination, nextWeightFn);
}
diff --git a/src/algorithms/shortestPath/shortestPaths.ts b/src/algorithms/shortestPath/shortestPaths.ts
index 928b382..57a1281 100644
--- a/src/algorithms/shortestPath/shortestPaths.ts
+++ b/src/algorithms/shortestPath/shortestPaths.ts
@@ -41,7 +41,7 @@ export function shortestPaths(
try {
path = shortestPath(graph, source, destination);
- if (!path.weight || pathWeight < path.weight) break;
+ if (!path.weight || !pathWeight || pathWeight < path.weight) break;
paths.push(path);
} catch (e) {
break;
diff --git a/src/index.ts b/src/index.ts
index 817e98e..a2419ab 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -1,4 +1,4 @@
-export type { Edge, Serialized, SerializedInput, EdgeWeight } from './types.js';
+export type { Edge, Serialized, SerializedInput, EdgeWeight, NextWeightFnParams } from './types.js';
export { Graph } from './Graph.js';
export { CycleError } from './CycleError.js';
diff --git a/src/types.ts b/src/types.ts
index bbce017..37b3c54 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -1,3 +1,6 @@
+import { TraversingTracks } from './algorithms/shortestPath/types.js';
+import { Graph } from './Graph.js';
+
export type EdgeWeight = number;
export type Edge = {
@@ -18,3 +21,13 @@ export type SerializedInput = {
};
export type NoInfer = [T][T extends any ? 0 : never];
+
+export type NextWeightFnParams = {
+ edgeWeight: EdgeWeight;
+ currentPathWeight: EdgeWeight | undefined;
+ hop: number;
+ graph: Graph;
+ path: TraversingTracks>;
+ previousNode: NoInfer;
+ currentNode: NoInfer;
+};