diff --git a/packages/traverse/__test__/traverse.test.ts b/packages/traverse/__test__/traverse.test.ts index a5fb5025..1768c61e 100644 --- a/packages/traverse/__test__/traverse.test.ts +++ b/packages/traverse/__test__/traverse.test.ts @@ -1,5 +1,5 @@ -import { visit, walk, NodePath } from '../src'; import type { Visitor, Walker } from '../src'; +import { NodePath,visit, walk } from '../src'; describe('traverse', () => { it('should visit SelectStmt nodes with new walk API', () => { @@ -122,13 +122,13 @@ describe('traverse', () => { const visitedNodes: string[] = []; const visitor: Visitor = { - A_Expr: (path: NodePath) => { + A_Expr: (_path: NodePath) => { visitedNodes.push('A_Expr'); }, - ColumnRef: (path: NodePath) => { + ColumnRef: (_path: NodePath) => { visitedNodes.push('ColumnRef'); }, - A_Const: (path: NodePath) => { + A_Const: (_path: NodePath) => { visitedNodes.push('A_Const'); } }; @@ -390,4 +390,100 @@ describe('traverse', () => { expect(targetListVisited[0].ctx.path).toEqual(['targetList', 0]); expect(targetListVisited[1].ctx.path).toEqual(['targetList', 1]); }); + + it('should traverse WithClause nodes', () => { + const visitedNodes: string[] = []; + + const walker: Walker = (path: NodePath) => { + visitedNodes.push(path.tag); + }; + + const ast = { + SelectStmt: { + withClause: { + WithClause: { + ctes: [ + { + CommonTableExpr: { + ctename: 'cte1', + ctequery: { + SelectStmt: { + targetList: [] as any[], + limitOption: 'LIMIT_OPTION_DEFAULT', + op: 'SETOP_NONE' + } + } + } + } + ] + } + }, + targetList: [] as any[], + limitOption: 'LIMIT_OPTION_DEFAULT', + op: 'SETOP_NONE' + } + }; + + walk(ast, walker); + + expect(visitedNodes).toContain('SelectStmt'); + expect(visitedNodes).toContain('WithClause'); + expect(visitedNodes).toContain('CommonTableExpr'); + expect(visitedNodes.filter(n => n === 'SelectStmt')).toHaveLength(2); + }); + + it('should traverse union larg and rarg nodes', () => { + const visitedNodes: string[] = []; + + const walker: Walker = (path: NodePath) => { + visitedNodes.push(path.tag); + }; + + const ast = { + SelectStmt: { + larg: { + SelectStmt: { + targetList: [ + { + ResTarget: { + val: { + A_Const: { + ival: { Integer: { ival: 1 } } + } + } + } + } + ], + limitOption: 'LIMIT_OPTION_DEFAULT', + op: 'SETOP_NONE' + } + }, + op: 'SETOP_UNION', + rarg: { + SelectStmt: { + targetList: [ + { + ResTarget: { + val: { + A_Const: { + ival: { Integer: { ival: 2 } } + } + } + } + } + ], + limitOption: 'LIMIT_OPTION_DEFAULT', + op: 'SETOP_NONE' + } + } + } + }; + + walk(ast, walker); + + expect(visitedNodes).toContain('SelectStmt'); + expect(visitedNodes.filter(n => n === 'SelectStmt')).toHaveLength(3); + expect(visitedNodes).toContain('ResTarget'); + expect(visitedNodes).toContain('A_Const'); + }); }); diff --git a/packages/traverse/scripts/pg-proto-parser.ts b/packages/traverse/scripts/pg-proto-parser.ts index a98d1386..6b87297c 100644 --- a/packages/traverse/scripts/pg-proto-parser.ts +++ b/packages/traverse/scripts/pg-proto-parser.ts @@ -1,5 +1,5 @@ +import { join,resolve } from 'path'; import { PgProtoParser, PgProtoParserOptions } from 'pg-proto-parser'; -import { resolve, join } from 'path'; const versions = ['17']; const baseDir = resolve(join(__dirname, '../../../__fixtures__/proto')); diff --git a/packages/traverse/src/index.ts b/packages/traverse/src/index.ts index 127dc14a..f2ef5db1 100644 --- a/packages/traverse/src/index.ts +++ b/packages/traverse/src/index.ts @@ -1,2 +1,2 @@ -export { walk, visit, NodePath } from './traverse'; -export type { Visitor, VisitorContext, Walker, NodeTag } from './traverse'; +export type { NodeTag,Visitor, VisitorContext, Walker } from './traverse'; +export { NodePath,visit, walk } from './traverse'; diff --git a/packages/traverse/src/traverse.ts b/packages/traverse/src/traverse.ts index 9f6e7db4..f95ee04d 100644 --- a/packages/traverse/src/traverse.ts +++ b/packages/traverse/src/traverse.ts @@ -1,5 +1,6 @@ import type { Node } from '@pgsql/types'; + import type { NodeSpec } from './17/runtime-schema'; import { runtimeSchema } from './17/runtime-schema'; @@ -47,10 +48,10 @@ export function walk( const actualCallback: Walker = typeof callback === 'function' ? callback : (path: NodePath) => { - const visitor = callback as Visitor; - const visitFn = visitor[path.tag]; - return visitFn ? visitFn(path) : undefined; - }; + const visitor = callback as Visitor; + const visitFn = visitor[path.tag]; + return visitFn ? visitFn(path) : undefined; + }; if (Array.isArray(root)) { root.forEach((node, index) => { @@ -70,7 +71,9 @@ export function walk( const nodeSpec = schemaMap.get(tag); if (nodeSpec) { for (const field of nodeSpec.fields) { - if (field.type === 'Node' && nodeData[field.name] != null) { + // Check if field type is 'Node' or any other node type (e.g., 'WithClause', 'SelectStmt', etc.) + const isNodeType = field.type === 'Node' || schemaMap.has(field.type); + if (isNodeType && nodeData[field.name] != null) { const value = nodeData[field.name]; if (field.isArray && Array.isArray(value)) { value.forEach((item, index) => {