From 2d41577832d20f5858dd24c880ed15436aad8089 Mon Sep 17 00:00:00 2001 From: cs01 Date: Mon, 13 Apr 2026 09:20:37 -0700 Subject: [PATCH] perf: i64 abi specialization for integer-pure numeric functions, 1.5x fib --- src/ast/types.ts | 3 + src/codegen/expressions/arrow-functions.ts | 9 +- src/codegen/expressions/calls.ts | 22 + .../infrastructure/function-generator.ts | 23 +- .../int-specialization-detector.ts | 656 ++++++++++++++++++ src/codegen/llvm-generator.ts | 14 + tests/compiler.test.ts | 4 +- 7 files changed, 725 insertions(+), 6 deletions(-) create mode 100644 src/codegen/infrastructure/int-specialization-detector.ts diff --git a/src/ast/types.ts b/src/ast/types.ts index e89b8c7d..239f55c1 100644 --- a/src/ast/types.ts +++ b/src/ast/types.ts @@ -415,6 +415,9 @@ export interface FunctionNode { // When true, codegen emits LLVM `declare` instead of `define`, no _cs_ prefix declare?: boolean; typeParameters?: string[]; + // Marked true by int-specialization pass when params and return are + // all integer-valued. Triggers i64 ABI codegen instead of double. + intSpecialized?: boolean; } export interface ClassMethod { diff --git a/src/codegen/expressions/arrow-functions.ts b/src/codegen/expressions/arrow-functions.ts index d270a148..4edd43a7 100644 --- a/src/codegen/expressions/arrow-functions.ts +++ b/src/codegen/expressions/arrow-functions.ts @@ -162,7 +162,8 @@ export class ArrowFunctionExpressionGenerator extends BaseGenerator { } // All FunctionNode fields must be present so the native compiler allocates - // the full struct size — closureInfo is the 11th field (after declare). + // the full struct size — closureInfo is the last field after the + // FunctionNode prefix. 
const liftedFunc: LiftedFunction = { name: funcName, params: funcParams, @@ -174,6 +175,7 @@ export class ArrowFunctionExpressionGenerator extends BaseGenerator { loc: undefined, declare: false, typeParameters: undefined, + intSpecialized: false, closureInfo, }; @@ -231,8 +233,8 @@ export class ArrowFunctionExpressionGenerator extends BaseGenerator { } if (funcResult) { // Type assertion must include ALL fields from FunctionNode + closureInfo - // in exact struct order. LiftedFunction extends FunctionNode (10 fields), - // so closureInfo is at index 10. Omitting middle fields causes GEP to + // in exact struct order. LiftedFunction extends FunctionNode, so + // closureInfo is the final field. Omitting middle fields causes GEP to // read the wrong offset in native code. const func = funcResult as { name: string; @@ -245,6 +247,7 @@ export class ArrowFunctionExpressionGenerator extends BaseGenerator { loc: SourceLocation; declare: boolean; typeParameters: string[]; + intSpecialized: boolean; closureInfo: ClosureInfo; }; return func.closureInfo; diff --git a/src/codegen/expressions/calls.ts b/src/codegen/expressions/calls.ts index 9c2166e6..080fd51a 100644 --- a/src/codegen/expressions/calls.ts +++ b/src/codegen/expressions/calls.ts @@ -904,6 +904,15 @@ export class CallExpressionGenerator { ), ); } + // Integer-specialized callee: every double param/return becomes i64. + // The existing FFI coercion paths in this loop already handle paramType + // === "i64" (fptosi from double, or pass-through from i64). 
+ if (func.intSpecialized) { + for (let pi = 0; pi < paramTypes.length; pi++) { + if (paramTypes[pi] === "double") paramTypes[pi] = "i64"; + } + if (returnType === "double") returnType = "i64"; + } } else { const funcNode = this.getFunctionFromAST(expr.name); if (funcNode) { @@ -936,6 +945,12 @@ export class CallExpressionGenerator { ); } } + if (funcNode.intSpecialized) { + for (let pi = 0; pi < paramTypes.length; pi++) { + if (paramTypes[pi] === "double") paramTypes[pi] = "i64"; + } + if (returnType === "double") returnType = "i64"; + } } } @@ -1022,6 +1037,13 @@ export class CallExpressionGenerator { return coerced; } if (returnType === "i64") { + // Integer-specialized callees keep their result as native i64 so that + // surrounding integer arithmetic stays in the i64 lane (no fadd round-trip). + // Other i64-returning extern calls still get coerced to double. + if (func && func.intSpecialized) { + this.ctx.setVariableType(temp, "i64"); + return temp; + } const coerced = this.ctx.nextTemp(); this.ctx.emit(`${coerced} = sitofp i64 ${temp} to double`); return coerced; diff --git a/src/codegen/infrastructure/function-generator.ts b/src/codegen/infrastructure/function-generator.ts index cee013e3..b86bd903 100644 --- a/src/codegen/infrastructure/function-generator.ts +++ b/src/codegen/infrastructure/function-generator.ts @@ -266,6 +266,22 @@ export class FunctionGenerator { const liftedFunc = func as LiftedFunction; const closureInfo = liftedFunc.closureInfo; const hasClosure = closureInfo ? closureInfo.captures.length > 0 : false; + + // Integer-specialized functions (detected by markIntSpecializedFunctions) + // use the i64 ABI: every numeric param becomes i64 instead of double, and + // the return type becomes i64 instead of double. Closures are excluded by + // construction (the detector only sees ast.functions, not lifted lambdas) + // but we double-check here. + const intSpecialized = func.intSpecialized && !hasClosure ? 
true : false; + if (intSpecialized) { + for (let i = 0; i < paramLLVMTypes.length; i++) { + if (paramLLVMTypes[i] === "double") paramLLVMTypes[i] = "i64"; + } + if (returnType === "double") { + returnType = "i64"; + this.ctx.setCurrentFunctionReturnType("i64"); + } + } const captures = closureInfo ? closureInfo.captures : null; let hasOptionalParams = false; if (func.parameters) { @@ -310,7 +326,8 @@ export class FunctionGenerator { const bodyStmts = func.body ? func.body.statements : []; const numericParamNames: string[] = []; for (let i = 0; i < funcParams.length; i++) { - if (paramLLVMTypes[i] === "double") { + // Numeric params include both default-double params and intSpec'd i64 params. + if (paramLLVMTypes[i] === "double" || paramLLVMTypes[i] === "i64") { numericParamNames.push(funcParams[i]); } } @@ -630,11 +647,15 @@ export class FunctionGenerator { break; } } + // intSpecialized: the function ABI passes i64 directly, so skip fptosi. + const paramAbiIsI64 = llvmType === "i64"; if (paramIsI64) { this.ctx.defineVariable(paramName, allocaReg, "i64", SymbolKind_Number, "local"); this.ctx.emit(`${allocaReg} = alloca i64`); if (isOptional && hasOptionalParams) { this.generateOptionalParamInitI64(i, allocaReg, paramInfo!, funcParams); + } else if (paramAbiIsI64) { + this.ctx.emit(`store i64 %arg${i}, i64* ${allocaReg}`); } else { const i64Val = this.ctx.nextTemp(); this.ctx.emit(`${i64Val} = fptosi double %arg${i} to i64`); diff --git a/src/codegen/infrastructure/int-specialization-detector.ts b/src/codegen/infrastructure/int-specialization-detector.ts new file mode 100644 index 00000000..42b6fef7 --- /dev/null +++ b/src/codegen/infrastructure/int-specialization-detector.ts @@ -0,0 +1,656 @@ +// Detects functions that can be specialized to a pure-i64 ABI instead of +// the default double ABI. 
Eligible functions: +// - All numeric params are integer-valued throughout the body +// - Return type is `number` (or unspecified) and every return value is an +// integer-shaped expression +// - No closures, async, optional/default params, or non-numeric params +// - Body has no try/throw/await/for-of/switch +// - Body calls only itself (no foreign function or method calls) +// +// When marked `intSpecialized`, the function-generator emits an i64 signature +// (skipping the entry fptosi), and the call-site lowering passes/returns i64 +// directly. All call sites already understand i64 paramTypes via the existing +// FFI coercion paths in calls.ts. +// +// IMPORTANT: this file runs under both the node and native compilers. To stay +// self-hosting safe we (a) only cast AST nodes to their canonical types from +// `src/ast/types.ts` (never to inline subset shapes — see CLAUDE.md rule #5), +// and (b) avoid `for...of`, `Set`, `Map`, etc. + +import type { + Statement, + Expression, + FunctionNode, + AST, + BinaryNode, + UnaryNode, + VariableNode, + NumberNode, + CallNode, + MethodCallNode, + NewNode, + MemberAccessNode, + IndexAccessNode, + ArrayNode, + ObjectNode, + MapNode, + SetNode, + TemplateLiteralNode, + ConditionalExpressionNode, + AwaitExpressionNode, + MemberAccessAssignmentNode, + IndexAccessAssignmentNode, + TypeAssertionNode, + SpreadElementNode, + ArrowFunctionNode, + ReturnStatement, + VariableDeclaration, + AssignmentStatement, + IfStatement, + WhileStatement, + DoWhileStatement, + ForStatement, + ForOfStatement, + ThrowStatement, + TryStatement, + SwitchStatement, + BlockStatement, +} from "../../ast/types.js"; +import { findI64EligibleVariables } from "./integer-analysis.js"; + +// ---------------------------------------------------------------------------- +// Statement walkers — keep all logic in terms of the canonical AST types. 
+// ----------------------------------------------------------------------------
+
+// Scans `stmts`, plus every statement nested inside an if/else branch, a loop
+// body or a bare `{ ... }` block, for a construct the i64 specialization
+// refuses to reason about: try, throw, await, for-of or switch. Returns true
+// as soon as one is found, false if the whole tree is clean.
+function bodyHasDisqualifyingStmt(stmts: Statement[]): boolean {
+  for (let i = 0; i < stmts.length; i++) {
+    const s = stmts[i];
+    const t = s.type;
+    if (t === "try" || t === "throw" || t === "await" || t === "for_of" || t === "switch") {
+      return true;
+    }
+    if (t === "if") {
+      const ifS = s as IfStatement;
+      if (bodyHasDisqualifyingStmt(ifS.thenBlock.statements)) return true;
+      if (ifS.elseBlock && bodyHasDisqualifyingStmt(ifS.elseBlock.statements)) return true;
+    } else if (t === "while") {
+      const w = s as WhileStatement;
+      if (bodyHasDisqualifyingStmt(w.body.statements)) return true;
+    } else if (t === "do_while") {
+      // Cast to the node's own canonical type rather than WhileStatement, per
+      // the self-hosting rule stated in this file's header.
+      const dw = s as DoWhileStatement;
+      if (bodyHasDisqualifyingStmt(dw.body.statements)) return true;
+    } else if (t === "for") {
+      const f = s as ForStatement;
+      if (bodyHasDisqualifyingStmt(f.body.statements)) return true;
+    } else if (t === "block") {
+      // A bare block is an ordinary statement; without this branch a
+      // `{ throw ...; }` nested in the body would slip past the check.
+      const bl = s as BlockStatement;
+      if (bodyHasDisqualifyingStmt(bl.statements)) return true;
+    }
+  }
+  return false;
+}
+
+// Returns true if the expression contains a call to anything other than
+// `selfName`, including method calls or any non-self function call. Self
+// recursive calls are allowed; their args are walked recursively.
+// Besides calls, the walk descends into binary, unary, conditional,
+// type-assertion, member-access and index-access operands. Constructor calls
+// (`new`) count as foreign. Any other node kind is treated as call-free.
+function exprHasForeignInvocation(e: Expression, selfName: string): boolean {
+  const t = e.type;
+  if (t === "call") {
+    const c = e as CallNode;
+    if (c.name !== selfName) return true;
+    if (c.args) {
+      for (let i = 0; i < c.args.length; i++) {
+        if (exprHasForeignInvocation(c.args[i], selfName)) return true;
+      }
+    }
+    return false;
+  }
+  // Method calls and constructor calls always leave the function.
+  if (t === "method_call" || t === "new") {
+    return true;
+  }
+  if (t === "binary") {
+    const b = e as BinaryNode;
+    return (
+      exprHasForeignInvocation(b.left, selfName) || exprHasForeignInvocation(b.right, selfName)
+    );
+  }
+  if (t === "unary") {
+    const u = e as UnaryNode;
+    return exprHasForeignInvocation(u.operand, selfName);
+  }
+  if (t === "conditional") {
+    const ce = e as ConditionalExpressionNode;
+    return (
+      exprHasForeignInvocation(ce.condition, selfName) ||
+      exprHasForeignInvocation(ce.consequent, selfName) ||
+      exprHasForeignInvocation(ce.alternate, selfName)
+    );
+  }
+  if (t === "type_assertion") {
+    const ta = e as TypeAssertionNode;
+    return exprHasForeignInvocation(ta.expression, selfName);
+  }
+  if (t === "member_access") {
+    const ma = e as MemberAccessNode;
+    return exprHasForeignInvocation(ma.object, selfName);
+  }
+  if (t === "index_access") {
+    const ia = e as IndexAccessNode;
+    return (
+      exprHasForeignInvocation(ia.object, selfName) || exprHasForeignInvocation(ia.index, selfName)
+    );
+  }
+  return false;
+}
+
+// Statement-level companion of exprHasForeignInvocation: returns true if any
+// expression reachable from `stmts` invokes something other than `selfName`.
+// Conditions, for-loop init/update clauses, bare blocks and expression
+// statements are checked as well as return/declaration/assignment values, so
+// the "body calls only itself" rule holds for the whole body.
+function stmtsHaveForeignInvocation(stmts: Statement[], selfName: string): boolean {
+  for (let i = 0; i < stmts.length; i++) {
+    const s = stmts[i];
+    const t = s.type;
+    if (t === "return") {
+      const r = s as ReturnStatement;
+      if (r.value && exprHasForeignInvocation(r.value, selfName)) return true;
+    } else if (t === "variable_declaration") {
+      const vd = s as VariableDeclaration;
+      if (vd.value && exprHasForeignInvocation(vd.value, selfName)) return true;
+    } else if (t === "assignment") {
+      const a = s as AssignmentStatement;
+      if (a.value && exprHasForeignInvocation(a.value, selfName)) return true;
+    } else if (t === "if") {
+      const ifS = s as IfStatement;
+      if (exprHasForeignInvocation(ifS.condition, selfName)) return true;
+      if (stmtsHaveForeignInvocation(ifS.thenBlock.statements, selfName)) return true;
+      if (ifS.elseBlock && stmtsHaveForeignInvocation(ifS.elseBlock.statements, selfName))
+        return true;
+    } else if (t === "while") {
+      const w = s as WhileStatement;
+      if (exprHasForeignInvocation(w.condition, selfName)) return true;
+      if (stmtsHaveForeignInvocation(w.body.statements, selfName)) return true;
+    } else if (t === "do_while") {
+      // Cast to the node's own canonical type rather than WhileStatement.
+      const dw = s as DoWhileStatement;
+      if (exprHasForeignInvocation(dw.condition, selfName)) return true;
+      if (stmtsHaveForeignInvocation(dw.body.statements, selfName)) return true;
+    } else if (t === "for") {
+      const f = s as ForStatement;
+      // init/update are a declaration, an assignment or a bare expression;
+      // wrapping each in a one-element list reuses the dispatch in this loop.
+      if (f.init && stmtsHaveForeignInvocation([f.init as Statement], selfName)) return true;
+      if (f.condition && exprHasForeignInvocation(f.condition, selfName)) return true;
+      if (f.update && stmtsHaveForeignInvocation([f.update as Statement], selfName)) return true;
+      if (stmtsHaveForeignInvocation(f.body.statements, selfName)) return true;
+    } else if (t === "block") {
+      const bl = s as BlockStatement;
+      if (stmtsHaveForeignInvocation(bl.statements, selfName)) return true;
+    } else {
+      // Leftover case: a bare expression used as a statement (e.g. a call).
+      // Node kinds the expression walker does not know report false.
+      if (exprHasForeignInvocation(s as Expression, selfName)) return true;
+    }
+  }
+  return false;
+}
+
+// Returns true if `e` is an integer-shaped expression (used to vet every
+// return value of a candidate function).
+// `eligibleNames` is the result of findI64EligibleVariables — i.e. locals/params
+// that have already been proven to never receive a non-integer value.
+// Integer-shaped means: a whole-number literal, a name in `eligibleNames`, a
+// self-call whose arguments are all integer-shaped, or one of the operators
+// listed below applied to integer-shaped operands. `/` and every other node
+// kind are rejected.
+function isIntegerShapedExpr(e: Expression, eligibleNames: string[], selfName: string): boolean {
+  const t = e.type;
+  if (t === "number") {
+    return (e as NumberNode).value % 1 === 0;
+  }
+  if (t === "variable") {
+    const name = (e as VariableNode).name;
+    for (let i = 0; i < eligibleNames.length; i++) {
+      if (eligibleNames[i] === name) return true;
+    }
+    return false;
+  }
+  if (t === "binary") {
+    const b = e as BinaryNode;
+    const op = b.op;
+    if (
+      op === "+" ||
+      op === "-" ||
+      op === "*" ||
+      op === "%" ||
+      op === "&" ||
+      op === "|" ||
+      op === "^" ||
+      op === "<<" ||
+      op === ">>" ||
+      op === ">>>"
+    ) {
+      return (
+        isIntegerShapedExpr(b.left, eligibleNames, selfName) &&
+        isIntegerShapedExpr(b.right, eligibleNames, selfName)
+      );
+    }
+    return false;
+  }
+  if (t === "unary") {
+    const u = e as UnaryNode;
+    if (u.op === "-" || u.op === "+" || u.op === "~") {
+      return isIntegerShapedExpr(u.operand, eligibleNames, selfName);
+    }
+    return false;
+  }
+  if (t === "call") {
+    const c = e as CallNode;
+    if (c.name !== selfName) return false;
+    if (!c.args) return true;
+    for (let i = 0; i < c.args.length; i++) {
+      if (!isIntegerShapedExpr(c.args[i], eligibleNames, selfName)) return false;
+    }
+    return true;
+  }
+  return false;
+}
+
+// Appends the value expression of every `return` reachable through if/else
+// branches, loop bodies and bare blocks to `out`. Returns false as soon as a
+// value-less `return;` is found (such a function cannot use the i64 ABI),
+// true otherwise.
+function collectReturnExprs(stmts: Statement[], out: Expression[]): boolean {
+  for (let i = 0; i < stmts.length; i++) {
+    const s = stmts[i];
+    const t = s.type;
+    if (t === "return") {
+      const r = s as ReturnStatement;
+      if (!r.value) return false;
+      out.push(r.value);
+    } else if (t === "if") {
+      const ifS = s as IfStatement;
+      if (!collectReturnExprs(ifS.thenBlock.statements, out)) return false;
+      if (ifS.elseBlock && !collectReturnExprs(ifS.elseBlock.statements, out)) return false;
+    } else if (t === "while") {
+      const w = s as WhileStatement;
+      if (!collectReturnExprs(w.body.statements, out)) return false;
+    } else if (t === "do_while") {
+      // Cast to the node's own canonical type rather than WhileStatement.
+      const dw = s as DoWhileStatement;
+      if (!collectReturnExprs(dw.body.statements, out)) return false;
+    } else if (t === "for") {
+      const f = s as ForStatement;
+      if (!collectReturnExprs(f.body.statements, out)) return false;
+    } else if (t === "block") {
+      // Returns inside a bare `{ ... }` must be vetted too; skipping them
+      // would let a non-integer return value through unchecked.
+      const bl = s as BlockStatement;
+      if (!collectReturnExprs(bl.statements, out)) return false;
+    }
+  }
+  return true;
+}
+
+// ----------------------------------------------------------------------------
+// Eligibility check.
+// ----------------------------------------------------------------------------
+
+// Decides whether `func` may be compiled with the i64 ABI. Every check below
+// must pass; the first failing one returns false. Escape analysis (is the
+// function ever used as a value?) is not part of this check.
+function isEligible(func: FunctionNode): boolean {
+  if (func.async) return false;
+  if (func.declare) return false;
+  if (!func.params || func.params.length === 0) return false;
+  if (!func.body || !func.body.statements) return false;
+
+  // Reject any non-`number` declared param type.
+  const paramTypes = func.paramTypes || [];
+  for (let i = 0; i < func.params.length; i++) {
+    const pt = paramTypes[i];
+    if (pt && pt !== "number" && pt !== "") return false;
+  }
+  if (func.parameters) {
+    for (let i = 0; i < func.parameters.length; i++) {
+      const p = func.parameters[i];
+      if (!p) continue;
+      if (p.optional || p.defaultValue) return false;
+      if (p.type && p.type !== "number" && p.type !== "") return false;
+    }
+  }
+
+  // Reject any non-`number` return type.
+  const rt = func.returnType || "";
+  if (rt !== "" && rt !== "number") return false;
+
+  // Reject statements we don't want to reason about.
+  if (bodyHasDisqualifyingStmt(func.body.statements)) return false;
+
+  // Reject any method call or non-self function call.
+  if (stmtsHaveForeignInvocation(func.body.statements, func.name)) return false;
+
+  // Run the existing per-variable integer analyzer; every param must come
+  // back as eligible.
+  const eligible = findI64EligibleVariables(func.body.statements, func.params);
+  if (eligible.length < func.params.length) return false;
+  for (let i = 0; i < func.params.length; i++) {
+    let found = false;
+    for (let j = 0; j < eligible.length; j++) {
+      if (eligible[j] === func.params[i]) {
+        found = true;
+        break;
+      }
+    }
+    if (!found) return false;
+  }
+
+  // Every return statement must produce an integer-shaped expression, and
+  // there must be at least one.
+  const returns: Expression[] = [];
+  if (!collectReturnExprs(func.body.statements, returns)) return false;
+  if (returns.length === 0) return false;
+  for (let i = 0; i < returns.length; i++) {
+    if (!isIntegerShapedExpr(returns[i], eligible, func.name)) return false;
+  }
+
+  return true;
+}
+
+// ----------------------------------------------------------------------------
+// Escape analysis — a function is "escaped" (ineligible for i64 ABI
+// specialization) if its name appears as a first-class value anywhere in the
+// program: as a callback arg, assigned to a local, returned, stored in an
+// array/object literal, captured by a closure, etc. Such uses go through the
+// canonical double-double function-pointer contract, so we must not mutate
+// that signature. The ONLY non-escaping use is a direct call: `foo(a, b)` —
+// where `foo` is `CallNode.name` (a string field, not a VariableNode).
+//
+// We collect the set of VariableNode names referenced anywhere in expression
+// position, then intersect with top-level function names. The result is
+// conservative: if a local variable shadows a function name, we treat it as
+// escaped too, which just means we fail to specialize in a corner case.
+// ----------------------------------------------------------------------------
+
+// Appends to `out` every name that `e` uses in value position: each
+// VariableNode name, plus each method-call name (see the method_call branch).
+// Direct call targets (`CallNode.name`) are not recorded. Duplicates are
+// not removed.
+function collectEscapedVarRefsExpr(e: Expression, out: string[]): void {
+  const t = e.type;
+  if (t === "variable") {
+    out.push((e as VariableNode).name);
+    return;
+  }
+  if (t === "call") {
+    const c = e as CallNode;
+    // c.name is a string (direct call target) — NOT a value reference.
+    if (c.args) {
+      for (let i = 0; i < c.args.length; i++) collectEscapedVarRefsExpr(c.args[i], out);
+    }
+    return;
+  }
+  if (t === "method_call") {
+    const mc = e as MethodCallNode;
+    // The method name is a string, but chad falls back to calling a
+    // top-level function if no class method matches. That fallback path
+    // uses the canonical double ABI, so if `mc.method` names a top-level
+    // function, we must not specialize that function. Treat the method
+    // name as a virtual escape reference. False positives are harmless —
+    // they just prevent specialization when the receiver is actually a
+    // class with its own method of the same name.
+    out.push(mc.method);
+    if (mc.object) collectEscapedVarRefsExpr(mc.object, out);
+    if (mc.args) {
+      for (let i = 0; i < mc.args.length; i++) collectEscapedVarRefsExpr(mc.args[i], out);
+    }
+    return;
+  }
+  if (t === "new") {
+    const nn = e as NewNode;
+    if (nn.args) {
+      for (let i = 0; i < nn.args.length; i++) collectEscapedVarRefsExpr(nn.args[i], out);
+    }
+    return;
+  }
+  if (t === "binary") {
+    const b = e as BinaryNode;
+    collectEscapedVarRefsExpr(b.left, out);
+    collectEscapedVarRefsExpr(b.right, out);
+    return;
+  }
+  if (t === "unary") {
+    const u = e as UnaryNode;
+    collectEscapedVarRefsExpr(u.operand, out);
+    return;
+  }
+  if (t === "member_access") {
+    const ma = e as MemberAccessNode;
+    collectEscapedVarRefsExpr(ma.object, out);
+    return;
+  }
+  if (t === "index_access") {
+    const ia = e as IndexAccessNode;
+    collectEscapedVarRefsExpr(ia.object, out);
+    collectEscapedVarRefsExpr(ia.index, out);
+    return;
+  }
+  if (t === "array") {
+    const a = e as ArrayNode;
+    if (a.elements) {
+      for (let i = 0; i < a.elements.length; i++) collectEscapedVarRefsExpr(a.elements[i], out);
+    }
+    return;
+  }
+  if (t === "object") {
+    const o = e as ObjectNode;
+    if (o.properties) {
+      for (let i = 0; i < o.properties.length; i++) {
+        collectEscapedVarRefsExpr(o.properties[i].value, out);
+      }
+    }
+    return;
+  }
+  if (t === "map") {
+    const m = e as MapNode;
+    if (m.entries) {
+      for (let i = 0; i < m.entries.length; i++) {
+        collectEscapedVarRefsExpr(m.entries[i].key, out);
+        collectEscapedVarRefsExpr(m.entries[i].value, out);
+      }
+    }
+    return;
+  }
+  if (t === "set") {
+    const s = e as SetNode;
+    if (s.values) {
+      for (let i = 0; i < s.values.length; i++) collectEscapedVarRefsExpr(s.values[i], out);
+    }
+    return;
+  }
+  if (t === "template_literal") {
+    const tl = e as TemplateLiteralNode;
+    if (tl.parts) {
+      for (let i = 0; i < tl.parts.length; i++) {
+        const p = tl.parts[i];
+        if (typeof p !== "string") collectEscapedVarRefsExpr(p as Expression, out);
+      }
+    }
+    return;
+  }
+  if (t === "conditional") {
+    const ce = e as ConditionalExpressionNode;
+    collectEscapedVarRefsExpr(ce.condition, out);
+    collectEscapedVarRefsExpr(ce.consequent, out);
+    collectEscapedVarRefsExpr(ce.alternate, out);
+    return;
+  }
+  if (t === "await") {
+    const aw = e as AwaitExpressionNode;
+    collectEscapedVarRefsExpr(aw.argument, out);
+    return;
+  }
+  if (t === "member_access_assignment") {
+    const maa = e as MemberAccessAssignmentNode;
+    collectEscapedVarRefsExpr(maa.object, out);
+    collectEscapedVarRefsExpr(maa.value, out);
+    return;
+  }
+  if (t === "index_access_assignment") {
+    const iaa = e as IndexAccessAssignmentNode;
+    collectEscapedVarRefsExpr(iaa.object, out);
+    collectEscapedVarRefsExpr(iaa.index, out);
+    collectEscapedVarRefsExpr(iaa.value, out);
+    return;
+  }
+  if (t === "type_assertion") {
+    const ta = e as TypeAssertionNode;
+    collectEscapedVarRefsExpr(ta.expression, out);
+    return;
+  }
+  if (t === "spread_element") {
+    const sp = e as SpreadElementNode;
+    collectEscapedVarRefsExpr(sp.argument, out);
+    return;
+  }
+  if (t === "arrow_function") {
+    const af = e as ArrowFunctionNode;
+    const body = af.body;
+    if (body && (body as BlockStatement).type === "block") {
+      collectEscapedVarRefsStmts((body as BlockStatement).statements, out);
+    } else if (body) {
+      collectEscapedVarRefsExpr(body as Expression, out);
+    }
+    return;
+  }
+  // Leaves: number, string, boolean, null, undefined, regex, this, super — no children.
+}
+
+// Statement-level walker: feeds every expression reachable from `stmts`
+// (values, conditions, loop clauses, nested bodies) to
+// collectEscapedVarRefsExpr. A statement kind not listed here is treated as a
+// bare expression statement. Entries of `stmts` must be non-null.
+function collectEscapedVarRefsStmts(stmts: Statement[], out: string[]): void {
+  for (let i = 0; i < stmts.length; i++) {
+    const s = stmts[i];
+    const t = s.type;
+    if (t === "variable_declaration") {
+      const vd = s as VariableDeclaration;
+      if (vd.value) collectEscapedVarRefsExpr(vd.value, out);
+    } else if (t === "assignment") {
+      const a = s as AssignmentStatement;
+      if (a.value) collectEscapedVarRefsExpr(a.value, out);
+    } else if (t === "return") {
+      const r = s as ReturnStatement;
+      if (r.value) collectEscapedVarRefsExpr(r.value, out);
+    } else if (t === "if") {
+      const ifS = s as IfStatement;
+      collectEscapedVarRefsExpr(ifS.condition, out);
+      collectEscapedVarRefsStmts(ifS.thenBlock.statements, out);
+      if (ifS.elseBlock) collectEscapedVarRefsStmts(ifS.elseBlock.statements, out);
+    } else if (t === "while") {
+      const w = s as WhileStatement;
+      collectEscapedVarRefsExpr(w.condition, out);
+      collectEscapedVarRefsStmts(w.body.statements, out);
+    } else if (t === "do_while") {
+      const dw = s as DoWhileStatement;
+      collectEscapedVarRefsExpr(dw.condition, out);
+      collectEscapedVarRefsStmts(dw.body.statements, out);
+    } else if (t === "for") {
+      const f = s as ForStatement;
+      if (f.init) {
+        if ((f.init as VariableDeclaration).type === "variable_declaration") {
+          const vd2 = f.init as VariableDeclaration;
+          if (vd2.value) collectEscapedVarRefsExpr(vd2.value, out);
+        } else {
+          const as2 = f.init as AssignmentStatement;
+          if (as2.value) collectEscapedVarRefsExpr(as2.value, out);
+        }
+      }
+      if (f.condition) collectEscapedVarRefsExpr(f.condition, out);
+      if (f.update) {
+        const upType = (f.update as { type: string }).type;
+        if (upType === "assignment") {
+          const asu = f.update as AssignmentStatement;
+          if (asu.value) collectEscapedVarRefsExpr(asu.value, out);
+        } else {
+          collectEscapedVarRefsExpr(f.update as Expression, out);
+        }
+      }
+      collectEscapedVarRefsStmts(f.body.statements, out);
+    } else if (t === "for_of") {
+      const fo = s as ForOfStatement;
+      collectEscapedVarRefsExpr(fo.iterable, out);
+      collectEscapedVarRefsStmts(fo.body.statements, out);
+    } else if (t === "throw") {
+      const th = s as ThrowStatement;
+      if (th.argument) collectEscapedVarRefsExpr(th.argument, out);
+    } else if (t === "try") {
+      const tr = s as TryStatement;
+      collectEscapedVarRefsStmts(tr.tryBlock.statements, out);
+      if (tr.catchBody) collectEscapedVarRefsStmts(tr.catchBody.statements, out);
+      if (tr.finallyBlock) collectEscapedVarRefsStmts(tr.finallyBlock.statements, out);
+    } else if (t === "switch") {
+      const sw = s as SwitchStatement;
+      collectEscapedVarRefsExpr(sw.discriminant, out);
+      if (sw.cases) {
+        for (let j = 0; j < sw.cases.length; j++) {
+          const cs = sw.cases[j];
+          if (cs.test) collectEscapedVarRefsExpr(cs.test, out);
+          if (cs.consequent) collectEscapedVarRefsStmts(cs.consequent, out);
+        }
+      }
+    } else if (t === "block") {
+      const bl = s as BlockStatement;
+      collectEscapedVarRefsStmts(bl.statements, out);
+    } else {
+      // Leftover case: a bare expression used as a statement (e.g. a call).
+      collectEscapedVarRefsExpr(s as Expression, out);
+    }
+  }
+}
+
+// Gathers every value-position name used anywhere in the program: function
+// bodies, class method bodies, and all three top-level lists. The result may
+// contain duplicates and names that are not functions; the caller filters.
+function collectEscapedFunctionNames(ast: AST): string[] {
+  const refs: string[] = [];
+
+  const funcs = ast.functions;
+  if (funcs) {
+    for (let i = 0; i < funcs.length; i++) {
+      const f = funcs[i] as FunctionNode;
+      if (!f || !f.body || !f.body.statements) continue;
+      collectEscapedVarRefsStmts(f.body.statements, refs);
+    }
+  }
+
+  const classes = ast.classes;
+  if (classes) {
+    for (let i = 0; i < classes.length; i++) {
+      const cls = classes[i];
+      if (!cls || !cls.methods) continue;
+      for (let j = 0; j < cls.methods.length; j++) {
+        const m = cls.methods[j];
+        if (m && m.body && m.body.statements) {
+          collectEscapedVarRefsStmts(m.body.statements, refs);
+        }
+      }
+    }
+  }
+
+  // Top-level statements get the full statement walker, the same treatment
+  // as topLevelItems below, so a function used as a value inside a top-level
+  // if/loop/expression statement is not missed.
+  if (ast.topLevelStatements) {
+    for (let i = 0; i < ast.topLevelStatements.length; i++) {
+      const s = ast.topLevelStatements[i];
+      if (s) collectEscapedVarRefsStmts([s as Statement], refs);
+    }
+  }
+
+  if (ast.topLevelExpressions) {
+    for (let i = 0; i < ast.topLevelExpressions.length; i++) {
+      const e = ast.topLevelExpressions[i];
+      if (e) collectEscapedVarRefsExpr(e as Expression, refs);
+    }
+  }
+
+  // Items are a mix of statements and bare expressions; the statement walker
+  // dispatches on `type` and falls back to the expression walker itself.
+  if (ast.topLevelItems) {
+    for (let i = 0; i < ast.topLevelItems.length; i++) {
+      const it = ast.topLevelItems[i];
+      if (it) collectEscapedVarRefsStmts([it as Statement], refs);
+    }
+  }
+
+  return refs;
+}
+
+// Entry point of the pass: sets `intSpecialized = true` on every function in
+// `ast.functions` that is never referenced as a value and passes isEligible.
+// Functions that fail either test are left untouched.
+export function markIntSpecializedFunctions(ast: AST): void {
+  const funcs = ast.functions;
+  if (!funcs) return;
+
+  // Every name the program uses in value position (may contain non-functions).
+  const escapedRefs = collectEscapedFunctionNames(ast);
+
+  for (let i = 0; i < funcs.length; i++) {
+    const f = funcs[i] as FunctionNode;
+    if (!f) continue;
+
+    // Reject if the function is ever referenced as a value.
+    let escaped = false;
+    if (f.name) {
+      for (let j = 0; j < escapedRefs.length; j++) {
+        if (escapedRefs[j] === f.name) {
+          escaped = true;
+          break;
+        }
+      }
+    }
+    if (escaped) continue;
+
+    if (isEligible(f)) {
+      f.intSpecialized = true;
+    }
+  }
+}
diff --git a/src/codegen/llvm-generator.ts b/src/codegen/llvm-generator.ts index 1addd3a6..d0072241 100644 --- a/src/codegen/llvm-generator.ts +++ b/src/codegen/llvm-generator.ts @@ -152,6 +152,7 @@ import { JsonObjectMeta } from "./expressions/access/member.js"; import type { TargetInfo } from "../target-types.js"; import { checkClosureMutations } from "../semantic/closure-mutation-checker.js"; import { checkUnionTypes } from "../semantic/union-type-checker.js"; +import { markIntSpecializedFunctions } from "./infrastructure/int-specialization-detector.js"; import { checkTypeAssertions } from "../semantic/type-assertion-checker.js"; import { checkUninitializedFields } from "../semantic/uninitialized-field-checker.js"; import { analyzeEscapes } from "../semantic/escape-analysis.js"; @@ -2862,6 +2863,7 @@ export class LLVMGenerator extends BaseGenerator implements IGeneratorContext { checkArgumentCounts(this.ast, this.sourceCode); checkAsyncAwait(this.ast, this.sourceCode); this.stackEligibleVars = analyzeEscapes(this.ast); + markIntSpecializedFunctions(this.ast); const irParts: string[] = []; @@ -3839,6 +3841,18 @@ export class LLVMGenerator extends BaseGenerator implements IGeneratorContext { } else if (valueType === "i8*" || lastValue === "null") { lastValue = "0.0"; } + } else if (this.currentFunctionReturnType === "i64") { + // Integer-specialized function: every return must end up i64.
+ const valueType = this.getVariableType(lastValue); + if (valueType === "double" || !valueType) { + const converted = this.nextTemp(); + this.emit(`${converted} = fptosi double ${lastValue} to i64`); + lastValue = converted; + } else if (valueType === "i32") { + const converted = this.nextTemp(); + this.emit(`${converted} = sext i32 ${lastValue} to i64`); + lastValue = converted; + } } if (this.currentFunctionReturnType === "void") { diff --git a/tests/compiler.test.ts b/tests/compiler.test.ts index 824e63b0..df07b522 100644 --- a/tests/compiler.test.ts +++ b/tests/compiler.test.ts @@ -150,8 +150,8 @@ describe(`ChadScript Compiler (${compilerLabel})`, () => { // Check for essential LLVM IR components assert.ok( - llContent.includes("define double @_cs_add"), - "Should define add function (mangled)", + llContent.includes("define double @_cs_add") || llContent.includes("define i64 @_cs_add"), + "Should define add function (mangled, either double or i64 ABI)", ); assert.ok(llContent.includes("define i32 @main"), "Should define main function"); assert.ok(llContent.includes("ret"), "Should have return statements");