diff --git a/src/codegen/infrastructure/array-allocator.ts b/src/codegen/infrastructure/array-allocator.ts index d41cc9c89..0cf1ac668 100644 --- a/src/codegen/infrastructure/array-allocator.ts +++ b/src/codegen/infrastructure/array-allocator.ts @@ -10,6 +10,7 @@ import { MemberAccessNode, VariableNode, MethodCallNode, + CallNode, SourceLocation, } from "../../ast/types.js"; import { InterfaceAllocator } from "./interface-allocator.js"; @@ -437,6 +438,26 @@ export class ArrayAllocator { return null; } + if (idxObjBase.type === "call") { + const callExpr = indexExpr.object as CallNode; + const ast = this.ctx.getAst(); + if (ast && callExpr.name) { + const funcs = ast.functions || []; + for (let i = 0; i < funcs.length; i++) { + const fn = funcs[i]; + if (fn.name === callExpr.name && fn.returnType) { + const rt = fn.returnType; + if (rt.endsWith("[]")) { + const elementType = rt.slice(0, -2).trim(); + return this.interfaceAlloc.getTypeInfoForElementType(elementType); + } + break; + } + } + } + return null; + } + if (idxObjBase.type !== "member_access") return null; const memberAccess = indexExpr.object as MemberAccessNode; diff --git a/src/codegen/infrastructure/variable-allocator.ts b/src/codegen/infrastructure/variable-allocator.ts index 4c867fd4c..2d337baef 100644 --- a/src/codegen/infrastructure/variable-allocator.ts +++ b/src/codegen/infrastructure/variable-allocator.ts @@ -174,6 +174,7 @@ export interface VariableAllocatorContext { setCurrentDeclaredInterfaceType(type: string | undefined): void; getCurrentDeclaredInterfaceType(): string | undefined; getCurrentClassName(): string | null; + getMethodReturnType(className: string, methodName: string): string | null; getParameterTypeFromAST(paramName: string): string | null; resolveImportAlias(localName: string): string; typeResolverGetInterface(name: string): InterfaceDeclaration | null; @@ -1223,33 +1224,70 @@ export class VariableAllocator { if (exprBase.type !== "member_access") return null; const memberExpr = expr as MemberAccessNode; const objBase = memberExpr.object as ExprBase; - if (objBase.type !== "variable") return null; - const varName = (memberExpr.object as VariableNode).name; - if (!varName) return null; let objectInterfaceType: string | null = null; - const ifaceType = this.ctx.symbolTable.getInterfaceType(varName); - if (ifaceType) { - objectInterfaceType = ifaceType; - } else { - const objMeta = this.ctx.symbolTable.getObjectMetadata(varName); - if (objMeta && objMeta.tsTypes) { - if (!objMeta.keys || !memberExpr.property) return null; - const keyIdx = objMeta.keys.indexOf(memberExpr.property); - if (keyIdx >= 0 && objMeta.tsTypes) { - const propType = objMeta.tsTypes[keyIdx]; - if ( - propType && - !propType.endsWith("[]") && - propType !== "string" && - propType !== "number" && - propType !== "boolean" - ) { - const iface = this.getInterface(propType); - if (iface) return propType; + if (objBase.type === "method_call") { + const mc = memberExpr.object as MethodCallNode; + const mcObjBase = mc.object as ExprBase; + let mcClassName: string | null = null; + if (mcObjBase.type === "variable") { + const mcVar = mc.object as VariableNode; + const concrete = this.ctx.symbolTable.getConcreteClass(mcVar.name); + if (concrete) mcClassName = concrete; + else if (this.ctx.symbolTable.isClass(mcVar.name)) { + const ci = this.ctx.symbolTable.getClassInfo(mcVar.name); + if (ci) mcClassName = ci.className; + } + } else if (mcObjBase.type === "this") { + mcClassName = this.ctx.getCurrentClassName(); + } + if (mcClassName) { + const rt = this.ctx.getMethodReturnType(mcClassName, mc.method); + if (rt && !rt.endsWith("[]")) { + objectInterfaceType = stripNullable(rt); + } + } + } else if (objBase.type === "call") { + const ce = memberExpr.object as CallNode; + const ast = this.ctx.getAst(); + if (ast && ce.name) { + const funcs = ast.functions || []; + for (let i = 0; i < funcs.length; i++) { + const fn = funcs[i]; + if (fn.name === ce.name && fn.returnType && !fn.returnType.endsWith("[]")) { + objectInterfaceType = stripNullable(fn.returnType); + break; } } - return null; } + } else if (objBase.type === "variable") { + const varName = (memberExpr.object as VariableNode).name; + if (!varName) return null; + const ifaceType = this.ctx.symbolTable.getInterfaceType(varName); + if (ifaceType) { + objectInterfaceType = ifaceType; + } else { + const objMeta = this.ctx.symbolTable.getObjectMetadata(varName); + if (objMeta && objMeta.tsTypes) { + if (!objMeta.keys || !memberExpr.property) return null; + const keyIdx = objMeta.keys.indexOf(memberExpr.property); + if (keyIdx >= 0 && objMeta.tsTypes) { + const propType = objMeta.tsTypes[keyIdx]; + if ( + propType && + !propType.endsWith("[]") && + propType !== "string" && + propType !== "number" && + propType !== "boolean" + ) { + const iface = this.getInterface(propType); + if (iface) return propType; + } + } + return null; + } + } + } else { + return null; } if (!objectInterfaceType) return null; const objectInterface = this.getInterface(objectInterfaceType); diff --git a/tests/fixtures/interfaces/fn-returning-array-indexed.ts b/tests/fixtures/interfaces/fn-returning-array-indexed.ts new file mode 100644 index 000000000..83a695880 --- /dev/null +++ b/tests/fixtures/interfaces/fn-returning-array-indexed.ts @@ -0,0 +1,15 @@ +interface P { + x: number; + y: string; +} +function make(): P[] { + return [ + { x: 1, y: "a" }, + { x: 2, y: "b" }, + ]; +} +function main(): void { + const last = make()[1]; + if (last.x === 2 && last.y === "b") console.log("TEST_PASSED"); +} +main(); diff --git a/tests/fixtures/interfaces/method-return-inner-interface.ts b/tests/fixtures/interfaces/method-return-inner-interface.ts new file mode 100644 index 000000000..1eef9b401 --- /dev/null +++ b/tests/fixtures/interfaces/method-return-inner-interface.ts @@ -0,0 +1,20 @@ +interface Inner { + v: number; +} +interface Outer { + inner: Inner; + name: string; +} +class S { + build(): Outer { + return { inner: { v: 9 }, name: "x" }; + } +} +function main(): void { + const s = new S(); + const a = s.build().inner; + if (a.v === 9) { + console.log("TEST_PASSED"); + } +} +main();