diff --git a/src/core/annotation-parser.ts b/src/core/annotation-parser.ts index d830f7c2..33c75ac8 100644 --- a/src/core/annotation-parser.ts +++ b/src/core/annotation-parser.ts @@ -1,8 +1,9 @@ -import type { PythonType } from '../types/index.js'; +import type { PythonGenericParameter, PythonType } from '../types/index.js'; export interface AnnotationParserOptions { onUnknownTypeName?: (name: string) => void; knownTypeVarNames?: Iterable; + typeParameters?: readonly PythonGenericParameter[]; } export function parseAnnotationToPythonType( @@ -11,6 +12,9 @@ export function parseAnnotationToPythonType( ): PythonType { const onUnknownTypeName = options.onUnknownTypeName; const knownTypeVarNames = new Set(options.knownTypeVarNames ?? []); + const knownTypeParameters = new Map( + (options.typeParameters ?? []).map(param => [param.name, param] as const) + ); const modulePrefixes = ['', 'typing.', 'typing_extensions.', 'collections.abc.'] as const; const unknownType = (): PythonType => ({ kind: 'custom', name: 'Any', module: 'typing' }); @@ -35,68 +39,40 @@ export function parseAnnotationToPythonType( return { kind: 'literal', value: null } as PythonType; } const num = Number(t); - // Number('') is 0, so ensure we don't treat empty literals as numeric. if (t !== '' && !Number.isNaN(num)) { return { kind: 'literal', value: num } as PythonType; } return { kind: 'custom', name: t } as PythonType; }; - const parseTypingFactoryName = (text: string, name: string): string | null => { - for (const prefix of modulePrefixes) { - const start = `${prefix}${name}(`; - if (!text.startsWith(start) || !text.endsWith(')')) { - continue; - } - - const inner = text.slice(start.length, -1).trim(); - if (inner.length < 2) { - return null; - } - - const quote = inner[0]; - if ((quote !== "'" && quote !== '"') || inner[inner.length - 1] !== quote) { - return null; - } - - const commaIndex = inner.indexOf(','); - const quoted = commaIndex === -1 ? inner : inner.slice(0, commaIndex).trimEnd(); - - if (quoted.length < 2 || quoted[quoted.length - 1] !== quote) { - return null; - } - - return quoted.slice(1, -1); + const mapKnownTypeParameter = (name: string): PythonType | null => { + const normalized = name.replace(/^~/, '').trim(); + const param = knownTypeParameters.get(normalized); + if (!param) { + return null; } - - return null; - }; - - const matchBracketedAlias = ( - text: string, - aliases: readonly string[] - ): { alias: string; inner: string } | null => { - for (const prefix of modulePrefixes) { - for (const alias of aliases) { - const start = `${prefix}${alias}[`; - if (!text.startsWith(start) || !text.endsWith(']')) { - continue; - } + switch (param.kind) { + case 'typevar': return { - alias, - inner: text.slice(start.length, -1), - }; - } + kind: 'typevar', + name: param.name, + bound: param.bound, + constraints: param.constraints, + variance: param.variance, + } satisfies PythonType; + case 'paramspec': + return { kind: 'paramspec', name: param.name } satisfies PythonType; + case 'typevartuple': + return { kind: 'typevartuple', name: param.name } satisfies PythonType; } - return null; }; const mapSimpleName = (name: string): PythonType => { - const n = name - .replace(/^~/, '') - .replace(/^(typing\.|typing_extensions\.|collections\.abc\.)/, '') - .trim(); - + const n = name.replace(/^(typing\.|typing_extensions\.|collections\.abc\.)/, '').trim(); + const known = mapKnownTypeParameter(n); + if (known) { + return known; + } if (knownTypeVarNames.has(n)) { return { kind: 'typevar', name: n }; } @@ -105,7 +81,6 @@ export function parseAnnotationToPythonType( return { kind: 'primitive', name: n }; } - // Track unknown typing-ish names for diagnostics if ( n === 'Any' || n === 'Never' || @@ -137,9 +112,51 @@ export function parseAnnotationToPythonType( if (n === 'frozenset' || n === 'FrozenSet') { return { kind: 'collection', name: 'frozenset', itemTypes: [] }; } + if (n.startsWith('~')) { + return { kind: 'typevar', name: n.slice(1) }; + } return { kind: 'custom', name: n }; }; + const parseTypingFactoryName = (text: string, name: string): string | null => { + for (const prefix of modulePrefixes) { + const start = `${prefix}${name}(`; + if (!text.startsWith(start) || !text.endsWith(')')) { + continue; + } + + const inner = text.slice(start.length, -1).trim(); + if (inner.length < 2) { + return null; + } + + const quote = inner[0]; + if ((quote !== "'" && quote !== '"') || inner[inner.length - 1] !== quote) { + return null; + } + + const commaIndex = inner.indexOf(','); + const quoted = commaIndex === -1 ? inner : inner.slice(0, commaIndex).trimEnd(); + if (quoted.length < 2 || quoted[quoted.length - 1] !== quote) { + return null; + } + + return quoted.slice(1, -1); + } + + return null; + }; + + const splitQualifiedName = (raw: string): { name: string; module?: string } => { + const trimmed = raw.trim(); + const parts = trimmed.split('.').filter(Boolean); + if (parts.length <= 1) { + return { name: trimmed }; + } + const name = parts[parts.length - 1] ?? trimmed; + return { name, module: parts.slice(0, -1).join('.') || undefined }; + }; + const normalizeCollectionName = ( raw: string ): { name: 'list' | 'dict' | 'tuple' | 'set' | 'frozenset'; inner?: string } | null => { @@ -228,6 +245,38 @@ export function parseAnnotationToPythonType( return results; }; + const splitGenericInvocation = (raw: string): { name: string; inner: string } | null => { + if (!raw.endsWith(']')) { + return null; + } + const bracketStart = raw.indexOf('['); + if (bracketStart <= 0) { + return null; + } + + let depth = 0; + for (let i = bracketStart; i < raw.length; i++) { + const ch = raw.charAt(i); + if (ch === '[') { + depth++; + } else if (ch === ']') { + depth--; + if (depth === 0 && i !== raw.length - 1) { + return null; + } + } + } + + if (depth !== 0) { + return null; + } + + return { + name: raw.slice(0, bracketStart).trim(), + inner: raw.slice(bracketStart + 1, -1), + }; + }; + const parse = (ann: unknown, depth = 0): PythonType => { if (ann === null || ann === undefined) { return unknownType(); @@ -237,33 +286,32 @@ export function parseAnnotationToPythonType( } const rawText = String(ann).trim(); const raw = rawText.startsWith('~') ? rawText.slice(1).trim() : rawText; - - if (/^[A-Za-z_][A-Za-z0-9_]*$/.test(raw) && rawText.startsWith('~')) { - return { kind: 'typevar', name: raw }; + if (raw === '') { + return unknownType(); } - if (/^[A-Za-z_][A-Za-z0-9_]*\.args$/.test(raw)) { - return { kind: 'collection', name: 'list', itemTypes: [unknownType()] }; + if (/^[A-Za-z_][A-Za-z0-9_]*$/.test(raw) && rawText.startsWith('~')) { + return mapKnownTypeParameter(raw) ?? { kind: 'typevar', name: raw }; } - if (/^[A-Za-z_][A-Za-z0-9_]*\.kwargs$/.test(raw)) { - return { - kind: 'collection', - name: 'dict', - itemTypes: [{ kind: 'primitive', name: 'str' }, unknownType()], - }; + const paramspecArgsMatch = rawText.match(/^~?([A-Za-z_][A-Za-z0-9_]*)\.(args|kwargs)$/); + if (paramspecArgsMatch?.[1] && paramspecArgsMatch[2]) { + const baseName = paramspecArgsMatch[1]; + const known = mapKnownTypeParameter(baseName); + if (known?.kind === 'paramspec' || rawText.startsWith('~')) { + return paramspecArgsMatch[2] === 'args' + ? ({ kind: 'paramspec_args', name: baseName } satisfies PythonType) + : ({ kind: 'paramspec_kwargs', name: baseName } satisfies PythonType); + } } - // Handle built-in class repr: - const classMatch = raw.match(/^$/); - if (classMatch) { + const builtInClassMatch = raw.match(/^$/); + if (builtInClassMatch) { const inner = (raw.match(/^$/) ?? [])[1] ?? ''; const name = (inner.split('.').pop() ?? '').toString(); return mapSimpleName(name); } - // PEP 604 unions: int | str | None - // Note: split at top-level only (avoid recursing forever on pipes inside quoted Literals). if (raw.includes('|')) { const parts = splitTopLevel(raw, '|'); if (parts.length > 1) { @@ -272,161 +320,179 @@ export function parseAnnotationToPythonType( } } - // typing.Union[...] - if (raw.startsWith('typing.Union[') || raw.startsWith('Union[')) { + if ( + raw.startsWith('typing.Union[') || + raw.startsWith('typing_extensions.Union[') || + raw.startsWith('Union[') + ) { const inner = raw.slice(raw.indexOf('[') + 1, raw.lastIndexOf(']')); const parts = splitTopLevel(inner, ','); const types = parts.map(p => parse(p.trim(), depth + 1)); return { kind: 'union', types }; } - // Optional[T] - const optionalAlias = matchBracketedAlias(raw, ['Optional']); - if (optionalAlias) { - const base = parse(optionalAlias.inner, depth + 1); - return { kind: 'optional', type: base }; - } - - const sequenceAlias = matchBracketedAlias(raw, ['Sequence']); - if (sequenceAlias) { - return { - kind: 'collection', - name: 'list', - itemTypes: [parse(sequenceAlias.inner, depth + 1)], - }; - } - - const mappingAlias = matchBracketedAlias(raw, ['Mapping']); - if (mappingAlias) { - const parts = splitTopLevel(mappingAlias.inner, ','); - const itemTypes = parts.map(p => parse(p.trim(), depth + 1)); - return { kind: 'collection', name: 'dict', itemTypes } as PythonType; - } - - const iteratorAlias = matchBracketedAlias(raw, [ - 'Iterator', - 'AsyncIterator', - 'Iterable', - 'AsyncIterable', - ]); - if (iteratorAlias) { - const parts = splitTopLevel(iteratorAlias.inner, ','); - return { - kind: 'generic', - name: iteratorAlias.alias, - typeArgs: parts.map(p => parse(p.trim(), depth + 1)), - }; - } - - const awaitableAlias = matchBracketedAlias(raw, ['Awaitable']); - if (awaitableAlias) { - return { - kind: 'generic', - name: 'Promise', - typeArgs: [parse(awaitableAlias.inner, depth + 1)], - }; - } - - const coroutineAlias = matchBracketedAlias(raw, ['Coroutine']); - if (coroutineAlias) { - const parts = splitTopLevel(coroutineAlias.inner, ','); - return { - kind: 'generic', - name: 'Promise', - typeArgs: [parse(parts[parts.length - 1] ?? 'Any', depth + 1)], - }; + if ( + raw.startsWith('typing.Optional[') || + raw.startsWith('typing_extensions.Optional[') || + raw.startsWith('Optional[') + ) { + const inner = raw.slice(raw.indexOf('[') + 1, raw.lastIndexOf(']')); + return { kind: 'optional', type: parse(inner, depth + 1) }; } - const literalAlias = matchBracketedAlias(raw, ['Literal']); - if (literalAlias) { - const parts = splitTopLevel(literalAlias.inner, ','); + if ( + raw.startsWith('typing.Literal[') || + raw.startsWith('typing_extensions.Literal[') || + raw.startsWith('Literal[') + ) { + const inner = raw.slice(raw.indexOf('[') + 1, raw.lastIndexOf(']')); + const parts = splitTopLevel(inner, ','); if (parts.length === 1) { return mapLiteral(String(parts[0] ?? '').trim()); } - return { kind: 'union', types: parts.map(p => mapLiteral(String(p).trim())) } as PythonType; + return { kind: 'union', types: parts.map(p => mapLiteral(String(p).trim())) }; } - // typing_extensions wrappers: ClassVar[T], Final[T], Required[T], NotRequired[T] const extMatch = raw.match( /^(typing\.|typing_extensions\.)?(ClassVar|Final|Required|NotRequired)\[(.*)\]$/ ); if (extMatch) { - const inner = extMatch[3] ?? ''; - return parse(inner, depth + 1); + return parse(extMatch[3] ?? '', depth + 1); } - const typeVarName = parseTypingFactoryName(raw, 'TypeVar'); + const typeVarName = parseTypingFactoryName(rawText, 'TypeVar'); if (typeVarName) { - return { kind: 'typevar', name: typeVarName }; + return mapKnownTypeParameter(typeVarName) ?? { kind: 'typevar', name: typeVarName }; } - const paramSpecName = parseTypingFactoryName(raw, 'ParamSpec'); + const paramSpecName = parseTypingFactoryName(rawText, 'ParamSpec'); if (paramSpecName) { - return { kind: 'custom', name: paramSpecName, module: 'typing' }; + return mapKnownTypeParameter(paramSpecName) ?? { kind: 'paramspec', name: paramSpecName }; } - const typeVarTupleName = parseTypingFactoryName(raw, 'TypeVarTuple'); + const typeVarTupleName = parseTypingFactoryName(rawText, 'TypeVarTuple'); if (typeVarTupleName) { - return { kind: 'custom', name: typeVarTupleName, module: 'typing' }; + return ( + mapKnownTypeParameter(typeVarTupleName) ?? { + kind: 'typevartuple', + name: typeVarTupleName, + } + ); } - - // LiteralString - if (raw === 'typing.LiteralString' || raw === 'LiteralString') { - return { kind: 'primitive', name: 'str' } as PythonType; + if ( + raw === 'typing.LiteralString' || + raw === 'typing_extensions.LiteralString' || + raw === 'LiteralString' + ) { + return { kind: 'primitive', name: 'str' }; } - const callableAlias = matchBracketedAlias(raw, ['Callable']); - if (callableAlias) { - const parts = splitTopLevel(callableAlias.inner, ','); + if ( + raw.startsWith('typing.Callable[') || + raw.startsWith('typing_extensions.Callable[') || + raw.startsWith('Callable[') + ) { + const inner = raw.slice(raw.indexOf('[') + 1, raw.lastIndexOf(']')); + const parts = splitTopLevel(inner, ','); if (parts.length >= 2) { const paramsPart = (parts[0] ?? '').trim(); const returnPart = parts.slice(1).join(',').trim(); + const returnType = parse(returnPart, depth + 1); + + if (paramsPart === '...' || paramsPart === 'Ellipsis') { + return { + kind: 'callable', + parameters: [{ kind: 'custom', name: '...' }], + returnType, + }; + } + + const directParamSpec = parse(paramsPart, depth + 1); + if (directParamSpec.kind === 'paramspec') { + return { + kind: 'callable', + parameters: [], + parameterSpec: directParamSpec, + returnType, + }; + } + const paramInner = paramsPart.startsWith('[') && paramsPart.endsWith(']') ? paramsPart.slice(1, -1) : ''; - const paramTypes = ((): PythonType[] => { - // Callable[..., R] uses a top-level Ellipsis. - if (paramsPart === '...' || paramsPart === 'Ellipsis') { - return [{ kind: 'custom', name: '...' } as PythonType]; - } - const trimmed = paramInner.trim(); - if (trimmed === '...' || trimmed === 'Ellipsis') { - return [{ kind: 'custom', name: '...' } as PythonType]; - } - if (!paramsPart.startsWith('[') || !paramsPart.endsWith(']')) { - return [{ kind: 'custom', name: '...' } as PythonType]; - } - return trimmed ? splitTopLevel(trimmed, ',').map(p => parse(p.trim(), depth + 1)) : []; - })(); - const returnType = parse(returnPart, depth + 1); - return { kind: 'callable', parameters: paramTypes, returnType } as PythonType; + const trimmed = paramInner.trim(); + if (trimmed === '...' || trimmed === 'Ellipsis') { + return { + kind: 'callable', + parameters: [{ kind: 'custom', name: '...' }], + returnType, + }; + } + + const parameters = trimmed + ? splitTopLevel(trimmed, ',').map(p => parse(p.trim(), depth + 1)) + : []; + return { kind: 'callable', parameters, returnType }; } } - const unpackAlias = matchBracketedAlias(raw, ['Unpack']); - if (unpackAlias) { - return unknownType(); + if ( + raw.startsWith('typing.Mapping[') || + raw.startsWith('typing_extensions.Mapping[') || + raw.startsWith('Mapping[') + ) { + const inner = raw.slice(raw.indexOf('[') + 1, raw.lastIndexOf(']')); + const parts = splitTopLevel(inner, ','); + return { + kind: 'collection', + name: 'dict', + itemTypes: parts.map(p => parse(p.trim(), depth + 1)), + }; } - const annotatedAlias = matchBracketedAlias(raw, ['Annotated']); - if (annotatedAlias) { - const parts = splitTopLevel(annotatedAlias.inner, ','); + if ( + raw.startsWith('typing.Annotated[') || + raw.startsWith('typing_extensions.Annotated[') || + raw.startsWith('Annotated[') + ) { + const inner = raw.slice(raw.indexOf('[') + 1, raw.lastIndexOf(']')); + const parts = splitTopLevel(inner, ','); if (parts.length > 0) { - const base = parse((parts[0] ?? '').trim(), depth + 1); - const metaParts = parts.slice(1).map(p => String(p).trim()); - return { kind: 'annotated', base, metadata: metaParts } as PythonType; + return { + kind: 'annotated', + base: parse(parts[0] ?? '', depth + 1), + metadata: parts.slice(1).map(p => String(p).trim()), + }; } } - // Collections: list[T], dict[K,V], tuple[...], set[T], frozenset[T] + if ( + raw.startsWith('typing.Unpack[') || + raw.startsWith('typing_extensions.Unpack[') || + raw.startsWith('Unpack[') + ) { + const inner = raw.slice(raw.indexOf('[') + 1, raw.lastIndexOf(']')); + return { kind: 'unpack', type: parse(inner, depth + 1) }; + } + const coll = normalizeCollectionName(raw); if (coll) { - const { name, inner } = coll; - const itemParts = splitTopLevel(inner ?? '', ','); - const itemTypes = (inner ? itemParts : []).map(p => parse(p.trim(), depth + 1)); - return { kind: 'collection', name, itemTypes }; + const itemParts = splitTopLevel(coll.inner ?? '', ','); + const itemTypes = (coll.inner ? itemParts : []).map(p => parse(p.trim(), depth + 1)); + return { kind: 'collection', name: coll.name, itemTypes }; + } + + const generic = splitGenericInvocation(raw); + if (generic) { + const typeArgs = splitTopLevel(generic.inner, ',').map(part => parse(part.trim(), depth + 1)); + const qualified = splitQualifiedName(generic.name); + return { + kind: 'generic', + name: qualified.name, + module: qualified.module, + typeArgs, + }; } - // Bare names like int, str, float, bool, bytes, None return mapSimpleName(raw); }; diff --git a/src/core/generator.ts b/src/core/generator.ts index 390724e4..01f2cb47 100644 --- a/src/core/generator.ts +++ b/src/core/generator.ts @@ -3,9 +3,12 @@ */ import type { + PythonGenericParameter, PythonFunction, PythonClass, PythonModule, + PythonType, + PythonTypeAlias, GeneratedCode, TypescriptType, } from '../types/index.js'; @@ -13,6 +16,19 @@ import { globalCache } from '../utils/cache.js'; import { TypeMapper } from './mapper.js'; +interface GenericRenderParam { + name: string; + declaration: string; +} + +interface GenericRenderContext { + currentModule?: string; + declaration: string; + typeArguments: string; + emittedNames: Set; + emittedParamSpecs: Set; +} + export class CodeGenerator { private readonly mapper: TypeMapper; private readonly reservedTsIdentifiers = new Set([ @@ -109,6 +125,186 @@ export class CodeGenerator { }); } + private getTypeParameters( + typeParameters?: readonly PythonGenericParameter[] + ): readonly PythonGenericParameter[] { + return typeParameters ?? []; + } + + private buildGenericRenderContext( + typeParameters: readonly PythonGenericParameter[], + types: readonly PythonType[], + currentModule?: string + ): GenericRenderContext { + const callableParamSpecs = new Set(); + types.forEach(type => this.collectCallableParamSpecs(type, callableParamSpecs)); + + const emitted: GenericRenderParam[] = []; + const emittedNames = new Set(); + const emittedParamSpecs = new Set(); + + typeParameters.forEach(param => { + if ( + param.kind === 'typevar' && + !param.bound && + !(param.constraints && param.constraints.length > 0) && + (!param.variance || param.variance === 'invariant') + ) { + emitted.push({ name: param.name, declaration: param.name }); + emittedNames.add(param.name); + return; + } + + if (param.kind === 'paramspec' && callableParamSpecs.has(param.name)) { + emitted.push({ name: param.name, declaration: `${param.name} extends unknown[]` }); + emittedNames.add(param.name); + emittedParamSpecs.add(param.name); + } + }); + + return { + currentModule, + declaration: + emitted.length > 0 ? `<${emitted.map(param => param.declaration).join(', ')}>` : '', + typeArguments: emitted.length > 0 ? `<${emitted.map(param => param.name).join(', ')}>` : '', + emittedNames, + emittedParamSpecs, + }; + } + + private mergeGenericRenderContexts( + outer: GenericRenderContext, + inner: GenericRenderContext + ): GenericRenderContext { + return { + currentModule: inner.currentModule ?? outer.currentModule, + declaration: inner.declaration, + typeArguments: inner.typeArguments, + emittedNames: new Set([...outer.emittedNames, ...inner.emittedNames]), + emittedParamSpecs: new Set([...outer.emittedParamSpecs, ...inner.emittedParamSpecs]), + }; + } + + private collectCallableParamSpecs(type: PythonType, out: Set): void { + switch (type.kind) { + case 'collection': + type.itemTypes.forEach(item => this.collectCallableParamSpecs(item, out)); + break; + case 'paramspec': + out.add(type.name); + break; + case 'union': + type.types.forEach(item => this.collectCallableParamSpecs(item, out)); + break; + case 'optional': + this.collectCallableParamSpecs(type.type, out); + break; + case 'generic': + type.typeArgs.forEach(item => this.collectCallableParamSpecs(item, out)); + break; + case 'callable': + if (type.parameterSpec?.kind === 'paramspec') { + out.add(type.parameterSpec.name); + } + type.parameters.forEach(item => this.collectCallableParamSpecs(item, out)); + this.collectCallableParamSpecs(type.returnType, out); + break; + case 'annotated': + this.collectCallableParamSpecs(type.base, out); + break; + case 'final': + case 'classvar': + this.collectCallableParamSpecs(type.type, out); + break; + case 'unpack': + this.collectCallableParamSpecs(type.type, out); + break; + default: + break; + } + } + + private sanitizeType(type: PythonType, ctx: GenericRenderContext): PythonType { + const unknownType = (): PythonType => ({ kind: 'custom', name: 'Any', module: 'typing' }); + + switch (type.kind) { + case 'primitive': + case 'literal': + return type; + case 'custom': + return type; + case 'collection': + return { + ...type, + itemTypes: type.itemTypes.map(item => this.sanitizeType(item, ctx)), + }; + case 'union': + return { + ...type, + types: type.types.map(item => this.sanitizeType(item, ctx)), + }; + case 'optional': + return { ...type, type: this.sanitizeType(type.type, ctx) }; + case 'generic': + return { + ...type, + module: type.module === ctx.currentModule ? undefined : type.module, + typeArgs: type.typeArgs.map(item => this.sanitizeType(item, ctx)), + }; + case 'callable': + if (type.parameterSpec && !ctx.emittedParamSpecs.has(type.parameterSpec.name)) { + return { + ...type, + parameters: [{ kind: 'custom', name: '...' }], + parameterSpec: undefined, + returnType: this.sanitizeType(type.returnType, ctx), + }; + } + return { + ...type, + parameters: type.parameters.map(item => this.sanitizeType(item, ctx)), + parameterSpec: + type.parameterSpec && ctx.emittedParamSpecs.has(type.parameterSpec.name) + ? type.parameterSpec + : undefined, + returnType: this.sanitizeType(type.returnType, ctx), + }; + case 'annotated': + return { ...type, base: this.sanitizeType(type.base, ctx) }; + case 'typevar': + case 'paramspec': + return ctx.emittedNames.has(type.name) ? type : unknownType(); + case 'paramspec_args': + return { + kind: 'collection', + name: 'list', + itemTypes: [unknownType()], + }; + case 'paramspec_kwargs': + return { + kind: 'collection', + name: 'dict', + itemTypes: [{ kind: 'primitive', name: 'str' }, unknownType()], + }; + case 'typevartuple': + return unknownType(); + case 'unpack': + return unknownType(); + case 'final': + return { ...type, type: this.sanitizeType(type.type, ctx) }; + case 'classvar': + return { ...type, type: this.sanitizeType(type.type, ctx) }; + } + } + + private typeToTsFromPython( + type: PythonType, + ctx: GenericRenderContext, + mappingContext: 'value' | 'return' + ): string { + return this.typeToTs(this.mapper.mapPythonType(this.sanitizeType(type, ctx), mappingContext)); + } + private renderLooksLikeKwargsExpr( valueExpr: string, options: { @@ -122,9 +318,7 @@ export class CodeGenerator { const keyCheck = (() => { if (options.requiredKwOnlyNames.length > 0) { return options.requiredKwOnlyNames - .map( - k => `Object.prototype.hasOwnProperty.call(${valueExpr}, ${JSON.stringify(k)})` - ) + .map(k => `Object.prototype.hasOwnProperty.call(${valueExpr}, ${JSON.stringify(k)})`) .join(' && '); } if (options.hasVarKwArgs) { @@ -133,9 +327,7 @@ export class CodeGenerator { } if (options.keywordOnlyNames.length > 0) { return options.keywordOnlyNames - .map( - k => `Object.prototype.hasOwnProperty.call(${valueExpr}, ${JSON.stringify(k)})` - ) + .map(k => `Object.prototype.hasOwnProperty.call(${valueExpr}, ${JSON.stringify(k)})`) .join(' || '); } return 'false'; @@ -163,9 +355,15 @@ export class CodeGenerator { const needsVarArgsArray = Boolean(varArgsParam) && needsKwargsParam; const positionalParams = filteredParams.filter(p => !p.keywordOnly && !p.varArgs && !p.kwArgs); + const genericContext = this.buildGenericRenderContext( + this.getTypeParameters(func.typeParameters), + [func.returnType, ...filteredParams.map(param => param.type)], + moduleName + ); + const typeParamDecl = genericContext.declaration; const tsTypeForValue = (p: (typeof filteredParams)[number]): string => - this.typeToTs(this.mapper.mapPythonType(p.type, 'value')); + this.typeToTsFromPython(p.type, genericContext, 'value'); const kwargsType = (() => { if (!needsKwargsParam) { @@ -225,7 +423,7 @@ export class CodeGenerator { const paramDecl = implParams.join(', '); const hasKwArgs = needsKwargsParam; - const returnType = this.typeToTs(this.mapper.mapPythonType(func.returnType, 'return')); + const returnType = this.typeToTsFromPython(func.returnType, genericContext, 'return'); const fname = this.escapeIdentifier(func.name); const moduleId = moduleName ?? '__main__'; @@ -266,12 +464,12 @@ export class CodeGenerator { rest.push(k); } overloads.push( - `export function ${fname}(${[...head, ...rest].join(', ')}): Promise<${returnType}>;` + `export function ${fname}${typeParamDecl}(${[...head, ...rest].join(', ')}): Promise<${returnType}>;` ); if (varArgsParam && needsVarArgsArray) { // Also allow callers to omit the varargs surrogate parameter entirely (i.e. `fn(kwargs)`). overloads.push( - `export function ${fname}(${[...head, renderKwargsParam(true)].join(', ')}): Promise<${returnType}>;` + `export function ${fname}${typeParamDecl}(${[...head, renderKwargsParam(true)].join(', ')}): Promise<${returnType}>;` ); } } @@ -288,7 +486,7 @@ export class CodeGenerator { rest.push(k); } overloads.push( - `export function ${fname}(${[...head, ...rest].join(', ')}): Promise<${returnType}>;` + `export function ${fname}${typeParamDecl}(${[...head, ...rest].join(', ')}): Promise<${returnType}>;` ); } } @@ -349,7 +547,9 @@ export class CodeGenerator { requiredKwOnlyNames, hasVarKwArgs, }); - callPreludeLines.push(` if (__kwargs === undefined && __args.length > ${requiredPosCount}) {`); + callPreludeLines.push( + ` if (__kwargs === undefined && __args.length > ${requiredPosCount}) {` + ); callPreludeLines.push(` const __candidate = __args[__args.length - 1];`); callPreludeLines.push(` if (${looksLikeKwargs}) {`); callPreludeLines.push(` __kwargs = __candidate as any;`); @@ -384,14 +584,19 @@ export class CodeGenerator { } const callPrelude = callPreludeLines.length > 0 ? `${callPreludeLines.join('\n')}\n` : ''; - const ts = `${jsdoc}${overloadDecl}export async function ${fname}(${paramDecl}): Promise<${returnType}> { + const ts = `${jsdoc}${overloadDecl}export async function ${fname}${typeParamDecl}(${paramDecl}): Promise<${returnType}> { ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.name}', __args${ hasKwArgs ? ', __kwargs' : '' }); } `; - return this.wrap(ts, [func.name]); + const declarationBody = + overloads.length > 0 + ? overloadDecl + : `export function ${fname}${typeParamDecl}(${paramDecl}): Promise<${returnType}>;\n`; + + return this.wrap(ts, `${jsdoc}${declarationBody}`, [func.name]); } generateClassWrapper( @@ -400,81 +605,113 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n _annotatedJSDoc = false ): GeneratedCode { const jsdoc = this.generateJsDoc(cls.docstring); - // Structural type aliases for special kinds + const classGenericContext = this.buildGenericRenderContext( + this.getTypeParameters(cls.typeParameters), + [ + ...cls.properties.map(property => property.type), + ...cls.methods.flatMap(method => [ + method.returnType, + ...method.parameters.map(p => p.type), + ]), + ], + moduleName + ); + const classTypeParamDecl = classGenericContext.declaration; + const cname = this.escapeIdentifier(cls.name); + const classSelfType = `${cname}${classGenericContext.typeArguments}`; + const tsValueType = (p: (typeof cls.methods)[number]['parameters'][number]): string => + this.typeToTsFromPython(p.type, classGenericContext, 'value'); + + const wrapAlias = (body: string): GeneratedCode => { + const ts = `${jsdoc}export type ${cname}${classTypeParamDecl} = ${body}\n`; + return this.wrap(ts, ts, [cls.name]); + }; + if (cls.decorators.includes('__typed_dict__') || cls.kind === 'typed_dict') { const props = cls.properties .map(p => { const pname = this.escapeIdentifier(p.name); const opt = (p as unknown as { optional?: boolean }).optional === true ? '?' : ''; - const t = this.typeToTs(this.mapper.mapPythonType(p.type, 'value')); + const t = this.typeToTsFromPython(p.type, classGenericContext, 'value'); return `${pname}${opt}: ${t};`; }) .join(' '); - const cname = this.escapeIdentifier(cls.name); - const ts = `${jsdoc}export type ${cname} = { ${props} }\n`; - return this.wrap(ts, [cls.name]); + return wrapAlias(`{ ${props} }`); } if (cls.kind === 'namedtuple') { - // NamedTuple -> readonly tuple type alias `[T1, T2, ...]` const elements = cls.properties.map(p => - this.typeToTs(this.mapper.mapPythonType(p.type, 'value')) + this.typeToTsFromPython(p.type, classGenericContext, 'value') ); - const cname = this.escapeIdentifier(cls.name); - const ts = `${jsdoc}export type ${cname} = readonly [${elements.join(', ')}]\n`; - return this.wrap(ts, [cls.name]); + return wrapAlias(`readonly [${elements.join(', ')}]`); } if (cls.kind === 'protocol') { - // Protocol -> structural interface-like type alias for attributes and callables (subset) const props = cls.properties .map( p => - `${this.escapeIdentifier(p.name)}: ${this.typeToTs(this.mapper.mapPythonType(p.type, 'value'))};` + `${this.escapeIdentifier(p.name)}: ${this.typeToTsFromPython(p.type, classGenericContext, 'value')};` ) .join(' '); const methods = cls.methods + .filter(m => m.name !== '__init__') .map(m => { const fparams = m.parameters.filter(p => p.name !== 'self' && p.name !== 'cls'); + const methodOwnGenericContext = this.buildGenericRenderContext( + this.getTypeParameters(m.typeParameters), + [m.returnType, ...fparams.map(param => param.type)], + moduleName + ); + const methodGenericContext = this.mergeGenericRenderContexts( + classGenericContext, + methodOwnGenericContext + ); + const methodTypeParamDecl = methodOwnGenericContext.declaration; const paramsDecl = fparams .map( p => - `${this.escapeIdentifier(p.name)}${p.optional ? '?' : ''}: ${this.typeToTs(this.mapper.mapPythonType(p.type, 'value'))}` + `${this.escapeIdentifier(p.name)}${p.optional ? '?' : ''}: ${this.typeToTsFromPython(p.type, methodGenericContext, 'value')}` ) .join(', '); - const returnType = this.typeToTs(this.mapper.mapPythonType(m.returnType, 'return')); - return `${this.escapeIdentifier(m.name)}: (${paramsDecl}) => ${returnType};`; + const returnType = this.typeToTsFromPython(m.returnType, methodGenericContext, 'return'); + return `${this.escapeIdentifier(m.name)}: ${methodTypeParamDecl}(${paramsDecl}) => ${returnType};`; }) .join(' '); - const cname = this.escapeIdentifier(cls.name); - const ts = `${jsdoc}export type ${cname} = { ${props} ${methods} }\n`; - return this.wrap(ts, [cls.name]); + return wrapAlias(`{ ${props} ${methods} }`); } if (cls.kind === 'dataclass' || cls.kind === 'pydantic') { - // Data containers -> object type alias const props = cls.properties .map(p => { const pname = this.escapeIdentifier(p.name); const opt = (p as unknown as { optional?: boolean }).optional === true ? '?' : ''; - const t = this.typeToTs(this.mapper.mapPythonType(p.type, 'value')); + const t = this.typeToTsFromPython(p.type, classGenericContext, 'value'); return `${pname}${opt}: ${t};`; }) .join(' '); - const cname = this.escapeIdentifier(cls.name); - const ts = `${jsdoc}export type ${cname} = { ${props} }\n`; - return this.wrap(ts, [cls.name]); + return wrapAlias(`{ ${props} }`); } - const cname = this.escapeIdentifier(cls.name); - const sortedMethods = [...cls.methods].sort((a, b) => a.name.localeCompare(b.name)); - const tsValueType = (p: (typeof cls.methods)[number]['parameters'][number]): string => - this.typeToTs(this.mapper.mapPythonType(p.type, 'value')); - - const methodBodies = sortedMethods - .filter(m => m.name !== '__init__') - .map(m => { - const fparams = m.parameters.filter(p => p.name !== 'self' && p.name !== 'cls'); + const sortedMethods = [...cls.methods].sort((a, b) => a.name.localeCompare(b.name)); + const methodBodies: string[] = []; + const methodDeclarations: string[] = []; + + sortedMethods + .filter(method => method.name !== '__init__') + .forEach(method => { + const fparams = method.parameters.filter(p => p.name !== 'self' && p.name !== 'cls'); + const methodOwnGenericContext = this.buildGenericRenderContext( + this.getTypeParameters(method.typeParameters), + [method.returnType, ...fparams.map(param => param.type)], + moduleName + ); + const methodGenericContext = this.mergeGenericRenderContexts( + classGenericContext, + methodOwnGenericContext + ); + const methodTypeParamDecl = methodOwnGenericContext.declaration; + const methodTsValueType = (p: (typeof fparams)[number]): string => + this.typeToTsFromPython(p.type, methodGenericContext, 'value'); const keywordOnlyParams = fparams.filter(p => p.keywordOnly); const positionalOnlyNames = fparams.filter(p => p.positionalOnly).map(p => p.name); const hasVarKwArgs = fparams.some(p => p.kwArgs); @@ -493,7 +730,7 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n ): string => { const pname = this.escapeIdentifier(p.name); const opt = !forceRequired && p.optional ? '?' : ''; - return `${pname}${opt}: ${tsValueType(p)}`; + return `${pname}${opt}: ${methodTsValueType(p)}`; }; const kwargsType = (() => { @@ -504,16 +741,16 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n return 'Record'; } const props = keywordOnlyParams - .map(p => `${JSON.stringify(p.name)}${p.optional ? '?' : ''}: ${tsValueType(p)};`) + .map(p => `${JSON.stringify(p.name)}${p.optional ? '?' : ''}: ${methodTsValueType(p)};`) .join(' '); const obj = `{ ${props} }`; return hasVarKwArgs ? `(${obj} & Record)` : obj; })(); const paramsDeclParts: string[] = []; - for (const p of positionalParams) { + positionalParams.forEach(p => { paramsDeclParts.push(renderPositionalParam(p)); - } + }); if (varArgsParam) { const vname = this.escapeIdentifier(varArgsParam.name); paramsDeclParts.push( @@ -526,8 +763,12 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n const paramsDecl = paramsDeclParts.join(', '); const requiredKwOnlyNames = keywordOnlyParams.filter(p => !p.optional).map(p => p.name); - const returnType = this.typeToTs(this.mapper.mapPythonType(m.returnType, 'return')); - const mname = this.escapeIdentifier(m.name); + const returnType = this.typeToTsFromPython( + method.returnType, + methodGenericContext, + 'return' + ); + const mname = this.escapeIdentifier(method.name); const overloads: string[] = []; if (needsKwargsParam && requiredKwOnlyNames.length > 0) { @@ -544,10 +785,12 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n ); } rest.push(`kwargs: ${kwargsType}`); - overloads.push(` ${mname}(${[...head, ...rest].join(', ')}): Promise<${returnType}>;`); + overloads.push( + ` ${mname}${methodTypeParamDecl}(${[...head, ...rest].join(', ')}): Promise<${returnType}>;` + ); if (varArgsParam && needsVarArgsArray) { overloads.push( - ` ${mname}(${[...head, `kwargs: ${kwargsType}`].join(', ')}): Promise<${returnType}>;` + ` ${mname}${methodTypeParamDecl}(${[...head, `kwargs: ${kwargsType}`].join(', ')}): Promise<${returnType}>;` ); } } @@ -595,11 +838,13 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n callPreludeLines.push(` if (${vname} !== undefined) {`); callPreludeLines.push(` if (globalThis.Array.isArray(${vname})) {`); callPreludeLines.push(` __varargs = ${vname};`); - callPreludeLines.push(` } else if (__kwargs === undefined && ${looksLikeKwargs}) {`); + callPreludeLines.push( + ` } else if (__kwargs === undefined && ${looksLikeKwargs}) {` + ); callPreludeLines.push(` __kwargs = ${vname} as any;`); callPreludeLines.push(` } else {`); callPreludeLines.push( - ` throw new Error(\`${m.name} expected ${varArgsParam.name} to be an array\`);` + ` throw new Error(\`${method.name} expected ${varArgsParam.name} to be an array\`);` ); callPreludeLines.push(` }`); callPreludeLines.push(` }`); @@ -609,6 +854,7 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n } } const callPrelude = callPreludeLines.length > 0 ? `${callPreludeLines.join('\n')}\n` : ''; + const guardLines: string[] = []; if (needsKwargsParam && positionalOnlyNames.length > 0) { guardLines.push( @@ -619,7 +865,7 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n ` if (__kwargs && Object.prototype.hasOwnProperty.call(__kwargs, key)) {` ); guardLines.push( - ` throw new Error(\`${m.name} does not accept positional-only argument "\${key}" as a keyword argument\`);` + ` throw new Error(\`${method.name} does not accept positional-only argument "\${key}" as a keyword argument\`);` ); guardLines.push(` }`); guardLines.push(` }`); @@ -638,26 +884,28 @@ ${callPrelude}${guards} return getRuntimeBridge().call('${moduleId}', '${func.n guardLines.push(` }`); guardLines.push(` if (__missing.length > 0) {`); guardLines.push( - ` throw new Error(\`Missing required keyword-only arguments for ${m.name}: \${__missing.join(', ')}\`);` + ` throw new Error(\`Missing required keyword-only arguments for ${method.name}: \${__missing.join(', ')}\`);` ); guardLines.push(` }`); } const guards = guardLines.length > 0 ? `${guardLines.join('\n')}\n` : ''; - return `${overloadDecl} async ${mname}(${paramsDecl}): Promise<${returnType}> { -${callPrelude}${guards} return getRuntimeBridge().callMethod(this.__handle, '${m.name}', __args${ - needsKwargsParam ? ', __kwargs' : '' - }); - }`; - }) - .join('\n'); + methodBodies.push(`${overloadDecl} async ${mname}${methodTypeParamDecl}(${paramsDecl}): Promise<${returnType}> { +${callPrelude}${guards} return getRuntimeBridge().callMethod(this.__handle, '${method.name}', __args${ + needsKwargsParam ? ', __kwargs' : '' + }); + }`); + methodDeclarations.push( + `${overloadDecl}${overloads.length > 0 ? '' : ` ${mname}${methodTypeParamDecl}(${paramsDecl}): Promise<${returnType}>;\n`}` + ); + }); - // Constructor typing from __init__ const init = cls.methods.find(m => m.name === '__init__'); const ctorSpec = (() => { if (!init) { return { overloadDecl: '', + declaration: ` static create${classTypeParamDecl}(...args: unknown[]): Promise<${classSelfType}>;\n`, paramsDecl: `...args: unknown[]`, callPrelude: ` const __args: unknown[] = [...args];\n`, hasKwargs: false, @@ -702,9 +950,9 @@ ${callPrelude}${guards} return getRuntimeBridge().callMethod(this.__handle, ' })(); const paramsDeclParts: string[] = []; - for (const p of positionalParams) { + positionalParams.forEach(p => { paramsDeclParts.push(renderPositionalParam(p)); - } + }); if (varArgsParam) { const vname = this.escapeIdentifier(varArgsParam.name); paramsDeclParts.push(needsVarArgsArray ? `${vname}?: unknown[]` : `...${vname}: unknown[]`); @@ -730,17 +978,23 @@ ${callPrelude}${guards} return getRuntimeBridge().callMethod(this.__handle, ' ); } rest.push(`kwargs: ${kwargsType}`); - overloads.push(` static create(${[...head, ...rest].join(', ')}): Promise<${cname}>;`); + overloads.push( + ` static create${classTypeParamDecl}(${[...head, ...rest].join(', ')}): Promise<${classSelfType}>;` + ); if (varArgsParam && needsVarArgsArray) { overloads.push( - ` static create(${[...head, `kwargs: ${kwargsType}`].join(', ')}): Promise<${cname}>;` + ` static create${classTypeParamDecl}(${[...head, `kwargs: ${kwargsType}`].join(', ')}): Promise<${classSelfType}>;` ); } } } const overloadDecl = overloads.length > 0 ? `${overloads.join('\n')}\n` : ''; - const guardLines: string[] = []; + const declaration = + overloads.length > 0 + ? overloadDecl + : ` static create${classTypeParamDecl}(${paramsDecl}): Promise<${classSelfType}>;\n`; + const guardLines: string[] = []; const callPreludeLines: string[] = []; if (needsKwargsParam) { callPreludeLines.push(` let __kwargs = kwargs;`); @@ -796,6 +1050,7 @@ ${callPrelude}${guards} return getRuntimeBridge().callMethod(this.__handle, ' } } const callPrelude = callPreludeLines.length > 0 ? `${callPreludeLines.join('\n')}\n` : ''; + if (needsKwargsParam && positionalOnlyNames.length > 0) { guardLines.push( ` const __positionalOnly = ${JSON.stringify(positionalOnlyNames)} as const;` @@ -829,26 +1084,55 @@ ${callPrelude}${guards} return getRuntimeBridge().callMethod(this.__handle, ' guardLines.push(` }`); } - return { overloadDecl, paramsDecl, callPrelude, hasKwargs: needsKwargsParam, guardLines }; + return { + overloadDecl, + declaration, + paramsDecl, + callPrelude, + hasKwargs: needsKwargsParam, + guardLines, + }; })(); const moduleId = moduleName ?? '__main__'; - const methodsSection = methodBodies ? `\n${methodBodies}\n` : '\n'; + const methodsSection = methodBodies.length > 0 ? `\n${methodBodies.join('\n')}\n` : '\n'; + const declarationMethodsSection = + methodDeclarations.length > 0 ? `\n${methodDeclarations.join('')}\n` : '\n'; const ctorGuards = ctorSpec.guardLines.length > 0 ? `${ctorSpec.guardLines.join('\n')}\n` : ''; - const ts = `${jsdoc}export class ${cname} { - private readonly __handle: string; - private constructor(handle: string) { this.__handle = handle; } - ${ctorSpec.overloadDecl} static async create(${ctorSpec.paramsDecl}): Promise<${cname}> { + const newClassExpr = `new ${cname}${classGenericContext.typeArguments}(handle)`; + const ts = `${jsdoc}export class ${cname}${classTypeParamDecl} { + private readonly __handle: string; + private constructor(handle: string) { this.__handle = handle; } +${ctorSpec.overloadDecl} static async create${classTypeParamDecl}(${ctorSpec.paramsDecl}): Promise<${classSelfType}> { ${ctorSpec.callPrelude}${ctorGuards} const handle = await getRuntimeBridge().instantiate('${moduleId}', '${cls.name}', __args${ ctorSpec.hasKwargs ? ', __kwargs' : '' }); - return new ${cname}(handle); + return ${newClassExpr}; } - static fromHandle(handle: string): ${cname} { return new ${cname}(handle); }${methodsSection} async disposeHandle(): Promise { await getRuntimeBridge().disposeInstance(this.__handle); } + static fromHandle${classTypeParamDecl}(handle: string): ${classSelfType} { return ${newClassExpr}; }${methodsSection} async disposeHandle(): Promise { await getRuntimeBridge().disposeInstance(this.__handle); } +} +`; + + const declaration = `${jsdoc}export class ${cname}${classTypeParamDecl} { + private readonly __handle: string; + private constructor(handle: string); +${ctorSpec.declaration} static fromHandle${classTypeParamDecl}(handle: string): ${classSelfType};${declarationMethodsSection} disposeHandle(): Promise; } `; - return this.wrap(ts, [cls.name]); + return this.wrap(ts, declaration, [cls.name]); + } + + generateTypeAlias(alias: PythonTypeAlias, moduleName?: string): GeneratedCode { + const genericContext = this.buildGenericRenderContext( + this.getTypeParameters(alias.typeParameters), + [alias.type], + moduleName + ); + const aliasName = this.escapeIdentifier(alias.name, { preserveCase: true }); + const body = this.typeToTsFromPython(alias.type, genericContext, 'value'); + const ts = `export type ${aliasName}${genericContext.declaration} = ${body}\n`; + return this.wrap(ts, ts, [alias.name]); } /** @@ -877,16 +1161,22 @@ ${ctorSpec.callPrelude}${ctorGuards} const handle = await getRuntimeBridge(). } generateModuleDefinition(module: PythonModule, annotatedJSDoc = false): GeneratedCode { - const functionCodes = [...module.functions] + const functionResults = [...module.functions] + .sort((a, b) => a.name.localeCompare(b.name)) + .map(f => this.generateFunctionWrapper(f, module.name, annotatedJSDoc)); + const classResults = [...module.classes] .sort((a, b) => a.name.localeCompare(b.name)) - .map(f => this.generateFunctionWrapper(f, module.name, annotatedJSDoc).typescript) - .join('\n'); - const classCodes = [...module.classes] + .map(c => this.generateClassWrapper(c, module.name, annotatedJSDoc)); + const typeAliasResults = [...(module.typeAliases ?? [])] .sort((a, b) => a.name.localeCompare(b.name)) - .map(c => this.generateClassWrapper(c, module.name, annotatedJSDoc).typescript) - .join('\n'); + .map(alias => this.generateTypeAlias(alias, module.name)); + + const functionCodes = functionResults.map(result => result.typescript).join('\n'); + const classCodes = classResults.map(result => result.typescript).join('\n'); + const typeAliasCodes = typeAliasResults.map(result => result.typescript).join('\n'); const header = `// Generated by tywrap\n// Module: ${module.name}\n// DO NOT EDIT MANUALLY\n\n`; + const declarationHeader = `// Generated by tywrap\n// Type Declarations\n// DO NOT EDIT MANUALLY\n\n`; const hasRuntimeClasses = module.classes.some(c => { const kind = c.kind ?? 'class'; return kind === 'class' && !c.decorators.includes('__typed_dict__'); @@ -894,8 +1184,13 @@ ${ctorSpec.callPrelude}${ctorGuards} const handle = await getRuntimeBridge(). const needsRuntime = module.functions.length > 0 || hasRuntimeClasses; const bridgeDecl = needsRuntime ? `import { getRuntimeBridge } from 'tywrap/runtime';\n\n` : ''; - const ts = `${header + bridgeDecl + functionCodes}\n${classCodes}`; - return this.wrap(ts, [module.name]); + const ts = `${`${header}${bridgeDecl}${functionCodes}\n${classCodes}\n${typeAliasCodes}`.trimEnd()}\n`; + const declaration = `${`${declarationHeader}${functionResults + .map(result => result.declaration) + .join('\n')}\n${classResults + .map(result => result.declaration) + .join('\n')}\n${typeAliasResults.map(result => result.declaration).join('\n')}`.trimEnd()}\n`; + return this.wrap(ts, declaration, [module.name]); } private generateJsDoc(doc?: string, paramAnnotations?: readonly string[]): string { @@ -915,10 +1210,10 @@ ${ctorSpec.callPrelude}${ctorGuards} const handle = await getRuntimeBridge(). return `/**\n${lines.join('\n')}\n */\n`; } - private wrap(typescript: string, _sources: string[]): GeneratedCode { + private wrap(typescript: string, declaration: string, _sources: string[]): GeneratedCode { return { typescript, - declaration: '', + declaration, sourceMap: undefined, metadata: { generatedAt: new Date(), diff --git a/src/core/mapper.ts b/src/core/mapper.ts index 17bf3378..c2fe7c59 100644 --- a/src/core/mapper.ts +++ b/src/core/mapper.ts @@ -9,7 +9,13 @@ import type { UnionType as PyUnionType, OptionalType as PyOptionalType, GenericType as PyGenericType, + CallableType as PyCallableType, TypeVarType as PyTypeVarType, + ParamSpecType as PyParamSpecType, + ParamSpecArgsType as PyParamSpecArgsType, + ParamSpecKwargsType as PyParamSpecKwargsType, + TypeVarTupleType as PyTypeVarTupleType, + UnpackType as PyUnpackType, FinalType as PyFinalType, ClassVarType as PyClassVarType, TypescriptType, @@ -20,6 +26,7 @@ import type { TSUnionType, TSFunctionType, TSGenericType, + TSCustomType, TSIndexSignature, TSLiteralType, TypePreset, @@ -61,6 +68,16 @@ export class TypeMapper { return this.mapPythonType(pythonType.base, context); case 'typevar': return this.mapTypeVarType(pythonType, context); + case 'paramspec': + return this.mapParamSpecType(pythonType); + case 'paramspec_args': + return this.mapParamSpecArgsType(pythonType); + case 'paramspec_kwargs': + return this.mapParamSpecKwargsType(); + case 'typevartuple': + return this.mapTypeVarTupleType(pythonType); + case 'unpack': + return this.mapUnpackType(pythonType, context); case 'final': return this.mapFinalType(pythonType, context); case 'classvar': @@ -177,11 +194,17 @@ export class TypeMapper { }; } - mapGenericType(type: PyGenericType, context: MappingContext = 'value'): TSGenericType { + mapGenericType(type: PyGenericType, context: MappingContext = 'value'): TypescriptType { + const normalized = this.normalizeCustomType({ name: type.name, module: type.module }); + const typeArgs = type.typeArgs.map(t => this.mapPythonType(t, context)); + const knownGenericType = this.mapKnownGenericType(normalized, typeArgs); + if (knownGenericType) { + return knownGenericType; + } return { kind: 'generic', - name: type.name, - typeArgs: type.typeArgs.map(t => this.mapPythonType(t, context)), + name: normalized.name, + typeArgs, }; } @@ -236,39 +259,9 @@ export class TypeMapper { }; } - // Async types - if (name === 'Awaitable' || fullName === 'typing.Awaitable') { - return { - kind: 'generic', - name: 'Promise', - typeArgs: [{ kind: 'primitive', name: 'unknown' }], - }; - } - if (name === 'Coroutine' || fullName === 'typing.Coroutine') { - return { - kind: 'generic', - name: 'Promise', - typeArgs: [{ kind: 'primitive', name: 'unknown' }], - }; - } - - // Collection types that should be generics - if (name === 'Sequence' || fullName === 'typing.Sequence') { - return { - kind: 'generic', - name: 'Array', - typeArgs: [{ kind: 'primitive', name: 'unknown' }], - }; - } - if (name === 'Mapping' || fullName === 'typing.Mapping') { - return { - kind: 'object', - properties: [], - indexSignature: { - keyType: { kind: 'primitive', name: 'string' }, - valueType: { kind: 'primitive', name: 'unknown' }, - }, - }; + const knownGenericType = this.mapKnownGenericType(this.normalizeCustomType(type)); + if (knownGenericType) { + return knownGenericType; } const presetType = this.mapPresetType(type); @@ -288,11 +281,7 @@ export class TypeMapper { return { kind: 'custom', name: normalized.name, module: normalized.module }; } - mapCallableType(type: { - kind: 'callable'; - parameters: PythonType[]; - returnType: PythonType; - }): TSFunctionType { + mapCallableType(type: PyCallableType): TSFunctionType { // Support Callable[[...], R] → (...args: unknown[]) => R const onlyEllipsis = type.parameters.length === 1 && @@ -302,21 +291,30 @@ export class TypeMapper { return { kind: 'function', isAsync: false, - parameters: onlyEllipsis + parameters: type.parameterSpec ? ([ { name: 'args', - type: { kind: 'array', elementType: { kind: 'primitive', name: 'unknown' } }, + type: this.mapParamSpecType(type.parameterSpec), optional: false, rest: true, }, ] as const satisfies TSFunctionType['parameters']) - : type.parameters.map((p, i) => ({ - name: `arg${i}`, - type: this.mapPythonType(p, 'value'), - optional: false, - rest: false, - })), + : onlyEllipsis + ? ([ + { + name: 'args', + type: { kind: 'array', elementType: { kind: 'primitive', name: 'unknown' } }, + optional: false, + rest: true, + }, + ] as const satisfies TSFunctionType['parameters']) + : type.parameters.map((p, i) => ({ + name: `arg${i}`, + type: this.mapPythonType(p, 'value'), + optional: false, + rest: false, + })), returnType: this.mapPythonType(type.returnType, 'return'), } satisfies TSFunctionType; } @@ -328,15 +326,49 @@ export class TypeMapper { return { kind: 'literal', value: type.value }; } - mapTypeVarType(_type: PyTypeVarType, _context: MappingContext = 'value'): TSPrimitiveType { - // Generated wrappers do not declare TypeScript generics that mirror Python - // TypeVar/ParamSpec scopes, so the sound fallback is unknown. + mapTypeVarType(type: PyTypeVarType, _context: MappingContext = 'value'): TSCustomType { + // TypeVar maps to a generic type parameter in TypeScript. + return { + kind: 'custom', + name: type.name, + module: 'typing', + }; + } + + mapParamSpecType(type: PyParamSpecType): TSCustomType { return { - kind: 'primitive', - name: 'unknown', + kind: 'custom', + name: type.name, + module: 'typing', + }; + } + + mapParamSpecArgsType(_type: PyParamSpecArgsType): TSArrayType { + return { + kind: 'array', + elementType: { kind: 'primitive', name: 'unknown' }, + }; + } + + mapParamSpecKwargsType(): TSObjectType { + return { + kind: 'object', + properties: [], + indexSignature: { + keyType: { kind: 'primitive', name: 'string' }, + valueType: { kind: 'primitive', name: 'unknown' }, + }, }; } + mapTypeVarTupleType(_type: PyTypeVarTupleType): TSPrimitiveType { + return { kind: 'primitive', name: 'unknown' }; + } + + mapUnpackType(_type: PyUnpackType, _context: MappingContext = 'value'): TSPrimitiveType { + return { kind: 'primitive', name: 'unknown' }; + } + mapFinalType(type: PyFinalType, context: MappingContext = 'value'): TypescriptType { // Final[T] maps to T in TypeScript (no direct Final equivalent) // The Final qualifier is more of a static analysis hint @@ -349,6 +381,58 @@ export class TypeMapper { return this.mapPythonType(type.type, context); } + private mapKnownGenericType( + type: { name: string; module?: string }, + typeArgs: TypescriptType[] = [] + ): TypescriptType | undefined { + const unknownType: TSPrimitiveType = { kind: 'primitive', name: 'unknown' }; + const isKnownTypingModule = + type.module === undefined || + type.module === 'typing' || + type.module === 'typing_extensions' || + type.module === 'collections.abc'; + + if (!isKnownTypingModule) { + return undefined; + } + + if (type.name === 'Awaitable') { + return { + kind: 'generic', + name: 'Promise', + typeArgs: [typeArgs[0] ?? unknownType], + } satisfies TSGenericType; + } + + if (type.name === 'Coroutine') { + return { + kind: 'generic', + name: 'Promise', + typeArgs: [typeArgs[typeArgs.length - 1] ?? unknownType], + } satisfies TSGenericType; + } + + if (type.name === 'Sequence') { + return { + kind: 'array', + elementType: typeArgs[0] ?? unknownType, + } satisfies TSArrayType; + } + + if (type.name === 'Mapping') { + return { + kind: 'object', + properties: [], + indexSignature: { + keyType: this.asIndexKeyType(typeArgs[0] ?? unknownType), + valueType: typeArgs[1] ?? unknownType, + }, + } satisfies TSObjectType; + } + + return undefined; + } + private asIndexKeyType(key: TypescriptType): TSPrimitiveType { if (key.kind === 'primitive' && (key.name === 'string' || key.name === 'number')) { return key; diff --git a/src/index.ts b/src/index.ts index 47aaffe5..79933a32 100644 --- a/src/index.ts +++ b/src/index.ts @@ -87,13 +87,27 @@ export type { PythonModule, PythonFunction, PythonClass, + PythonTypeAlias, PythonType, + PythonGenericParameter, + PythonGenericParameterKind, PrimitiveType, CollectionType, UnionType, OptionalType, CustomType, GenericType, + CallableType, + LiteralType, + AnnotatedType, + TypeVarType, + ParamSpecType, + ParamSpecArgsType, + ParamSpecKwargsType, + TypeVarTupleType, + UnpackType, + FinalType, + ClassVarType, Parameter, Property, PythonImport, diff --git a/src/types/index.ts b/src/types/index.ts index 1a8b9fe3..1068bb9f 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -8,6 +8,7 @@ export interface PythonModule { version?: string; functions: PythonFunction[]; classes: PythonClass[]; + typeAliases?: PythonTypeAlias[]; imports: PythonImport[]; exports: string[]; } @@ -19,6 +20,7 @@ export interface PythonFunction { decorators: string[]; isAsync: boolean; isGenerator: boolean; + typeParameters?: PythonGenericParameter[]; returnType: PythonType; parameters: Parameter[]; } @@ -31,6 +33,23 @@ export interface PythonClass { docstring?: string; decorators: string[]; kind?: 'class' | 'protocol' | 'typed_dict' | 'namedtuple' | 'dataclass' | 'pydantic'; + typeParameters?: PythonGenericParameter[]; +} + +export interface PythonTypeAlias { + name: string; + type: PythonType; + typeParameters?: PythonGenericParameter[]; +} + +export type PythonGenericParameterKind = 'typevar' | 'paramspec' | 'typevartuple'; + +export interface PythonGenericParameter { + name: string; + kind: PythonGenericParameterKind; + bound?: PythonType; + constraints?: PythonType[]; + variance?: 'covariant' | 'contravariant' | 'invariant'; } export interface PythonImport { @@ -55,6 +74,7 @@ export interface Property { name: string; type: PythonType; readonly: boolean; + optional?: boolean; setter?: boolean; getter?: boolean; } @@ -78,6 +98,11 @@ export type PythonType = | AnnotatedType | CustomType | TypeVarType + | ParamSpecType + | ParamSpecArgsType + | ParamSpecKwargsType + | TypeVarTupleType + | UnpackType | FinalType | ClassVarType; @@ -105,6 +130,7 @@ export interface OptionalType { export interface GenericType { kind: 'generic'; name: string; + module?: string; typeArgs: PythonType[]; } @@ -117,6 +143,7 @@ export interface CustomType { export interface CallableType { kind: 'callable'; parameters: PythonType[]; + parameterSpec?: ParamSpecType; returnType: PythonType; } @@ -139,6 +166,31 @@ export interface TypeVarType { variance?: 'covariant' | 'contravariant' | 'invariant'; } +export interface ParamSpecType { + kind: 'paramspec'; + name: string; +} + +export interface ParamSpecArgsType { + kind: 'paramspec_args'; + name: string; +} + +export interface ParamSpecKwargsType { + kind: 'paramspec_kwargs'; + name: string; +} + +export interface TypeVarTupleType { + kind: 'typevartuple'; + name: string; +} + +export interface UnpackType { + kind: 'unpack'; + type: PythonType; +} + export interface FinalType { kind: 'final'; type: PythonType; diff --git a/src/tywrap.ts b/src/tywrap.ts index 693e5efb..331fa172 100644 --- a/src/tywrap.ts +++ b/src/tywrap.ts @@ -8,9 +8,11 @@ import { parseAnnotationToPythonType } from './core/annotation-parser.js'; import { createConfig } from './config/index.js'; import type { TywrapOptions, + PythonGenericParameter, PythonFunction, PythonModule as TSPythonModule, PythonClass, + PythonTypeAlias, Parameter, PythonType, } from './types/index.js'; @@ -20,6 +22,8 @@ import { globalParallelProcessor } from './utils/parallel-processor.js'; import { resolvePythonExecutable } from './utils/python.js'; import { computeIrCacheFilename } from './utils/ir-cache.js'; +const TYWRAP_IR_VERSION = '0.2.0'; + // Collect unknown typing constructs encountered during annotation parsing (per-generate run) let unknownTypeNamesCollector: Map = new Map(); function recordUnknown(name: string): void { @@ -241,6 +245,10 @@ export async function generate( } else { moduleModel.classes = moduleModel.classes.filter(c => !shouldExclude(c.name, true)); } + + moduleModel.typeAliases = (moduleModel.typeAliases ?? []).filter( + alias => !shouldExclude(alias.name, false) + ); } // Generate module code @@ -256,7 +264,7 @@ export async function generate( if (resolvedOptions.output?.declaration) { filesToEmit.push({ path: pathUtils.join(outputDir, `${baseName}.generated.d.ts`), - content: renderDts(gen.typescript), + content: gen.declaration, }); } @@ -346,7 +354,7 @@ async function fetchPythonIr( try { const result = await processUtils.exec( pythonPath, - ['-m', 'tywrap_ir', '--module', moduleName, '--no-pretty'], + ['-m', 'tywrap_ir', '--module', moduleName, '--ir-version', TYWRAP_IR_VERSION, '--no-pretty'], { timeoutMs: options.timeoutMs, env } ); if (result.code === 0) { @@ -387,7 +395,7 @@ async function fetchPythonIr( const fallback = await processUtils.exec( pythonPath, - [localMain, '--module', moduleName, '--no-pretty'], + [localMain, '--module', moduleName, '--ir-version', TYWRAP_IR_VERSION, '--no-pretty'], { timeoutMs: options.timeoutMs, env } ); if (fallback.code !== 0) { @@ -443,15 +451,43 @@ function transformIrToTsModel(ir: unknown): TSPythonModule { typeof ir === 'object' && ir !== null ? (ir as Record) : {}; const functions = (obj.functions as unknown[]) ?? []; const classes = (obj.classes as unknown[]) ?? []; + const aliases = (obj.type_aliases as unknown[]) ?? []; const moduleTypeVarNames = collectModuleTypeVarNames(obj); - const parseType = (annotation: unknown): PythonType => + const parseType = ( + annotation: unknown, + typeParameters: readonly PythonGenericParameter[] = [] + ): PythonType => parseAnnotationToPythonType(annotation, { onUnknownTypeName: recordUnknown, knownTypeVarNames: moduleTypeVarNames, + typeParameters, }); - const mapParam = (p: Record): Parameter => ({ + const mapTypeParameters = (value: Record): PythonGenericParameter[] => + Array.isArray(value.type_params) + ? (value.type_params as unknown[]).map(v => { + const param = (v ?? {}) as Record; + return { + name: String(param.name ?? ''), + kind: String(param.kind ?? 'typevar') as PythonGenericParameter['kind'], + bound: param.bound ? parseType(param.bound) : undefined, + constraints: Array.isArray(param.constraints) + ? (param.constraints as unknown[]).map(item => parseType(item)) + : undefined, + variance: + param.variance === 'covariant' || + param.variance === 'contravariant' || + param.variance === 'invariant' + ? param.variance + : undefined, + } satisfies PythonGenericParameter; + }) + : []; + const mapParam = ( + p: Record, + typeParameters: readonly PythonGenericParameter[] = [] + ): Parameter => ({ name: String(p.name ?? ''), - type: parseType(p.annotation), + type: parseType(p.annotation, typeParameters), optional: Boolean(p.default), varArgs: p.kind === 'VAR_POSITIONAL', kwArgs: p.kind === 'VAR_KEYWORD', @@ -459,60 +495,86 @@ function transformIrToTsModel(ir: unknown): TSPythonModule { keywordOnly: p.kind === 'KEYWORD_ONLY', }); - const mapFunc = (f: Record): PythonFunction => ({ - name: String(f.name ?? ''), - signature: { - parameters: Array.isArray(f.parameters) - ? (f.parameters as unknown[]).map(v => mapParam((v ?? {}) as Record)) - : [], - returnType: parseType(f.returns), + const mapFunc = ( + f: Record, + inheritedTypeParameters: readonly PythonGenericParameter[] = [] + ): PythonFunction => { + const localTypeParameters = mapTypeParameters(f); + const annotationTypeParameters = [...inheritedTypeParameters, ...localTypeParameters]; + return { + name: String(f.name ?? ''), + signature: { + parameters: Array.isArray(f.parameters) + ? (f.parameters as unknown[]).map(v => + mapParam((v ?? {}) as Record, annotationTypeParameters) + ) + : [], + returnType: parseType(f.returns, annotationTypeParameters), + isAsync: Boolean(f.is_async), + isGenerator: Boolean(f.is_generator), + }, + docstring: (f.docstring as string | undefined) ?? undefined, + decorators: [], isAsync: Boolean(f.is_async), isGenerator: Boolean(f.is_generator), - }, - docstring: (f.docstring as string | undefined) ?? undefined, - decorators: [], - isAsync: Boolean(f.is_async), - isGenerator: Boolean(f.is_generator), - returnType: parseType(f.returns), - parameters: Array.isArray(f.parameters) - ? (f.parameters as unknown[]).map(v => mapParam((v ?? {}) as Record)) - : [], - }); + typeParameters: [...localTypeParameters], + returnType: parseType(f.returns, annotationTypeParameters), + parameters: Array.isArray(f.parameters) + ? (f.parameters as unknown[]).map(v => + mapParam((v ?? {}) as Record, annotationTypeParameters) + ) + : [], + }; + }; - const mapClass = (c: Record): PythonClass => ({ - name: String(c.name ?? ''), - bases: Array.isArray(c.bases) ? (c.bases as string[]) : [], - methods: Array.isArray(c.methods) - ? (c.methods as unknown[]).map(v => mapFunc((v ?? {}) as Record)) - : [], - properties: Array.isArray(c.fields) - ? ((c.fields as unknown[]).map(v => { - const p = (v ?? {}) as Record; - const optional = Boolean(p.default); - return { - name: String(p.name ?? ''), - type: parseType(p.annotation), - readonly: false, - setter: false, - getter: true, - optional, - } as unknown as never; - }) as unknown as PythonClass['properties']) - : [], - docstring: (c.docstring as string | undefined) ?? undefined, - decorators: (c.typed_dict as boolean) ? ['__typed_dict__'] : [], - kind: (c.typed_dict as boolean) - ? 'typed_dict' - : (c.is_protocol as boolean) - ? 'protocol' - : (c.is_namedtuple as boolean) - ? 'namedtuple' - : (c.is_dataclass as boolean) - ? 'dataclass' - : (c.is_pydantic as boolean) - ? 'pydantic' - : 'class', - }); + const mapClass = (c: Record): PythonClass => { + const classTypeParameters = mapTypeParameters(c); + return { + name: String(c.name ?? ''), + bases: Array.isArray(c.bases) ? (c.bases as string[]) : [], + methods: Array.isArray(c.methods) + ? (c.methods as unknown[]).map(v => + mapFunc((v ?? {}) as Record, classTypeParameters) + ) + : [], + properties: Array.isArray(c.fields) + ? (c.fields as unknown[]).map(v => { + const p = (v ?? {}) as Record; + return { + name: String(p.name ?? ''), + type: parseType(p.annotation, classTypeParameters), + readonly: false, + setter: false, + getter: true, + optional: Boolean(p.default), + }; + }) + : [], + docstring: (c.docstring as string | undefined) ?? undefined, + decorators: (c.typed_dict as boolean) ? ['__typed_dict__'] : [], + kind: (c.typed_dict as boolean) + ? 'typed_dict' + : (c.is_protocol as boolean) + ? 'protocol' + : (c.is_namedtuple as boolean) + ? 'namedtuple' + : (c.is_dataclass as boolean) + ? 'dataclass' + : (c.is_pydantic as boolean) + ? 'pydantic' + : 'class', + typeParameters: classTypeParameters, + }; + }; + + const mapTypeAlias = (alias: Record): PythonTypeAlias => { + const typeParameters = mapTypeParameters(alias); + return { + name: String(alias.name ?? ''), + type: parseType(alias.definition, typeParameters), + typeParameters, + }; + }; const moduleModel: TSPythonModule = { name: (obj.module as string) ?? 'module', @@ -523,6 +585,7 @@ function transformIrToTsModel(ir: unknown): TSPythonModule { : undefined, functions: functions.map(v => mapFunc((v ?? {}) as Record)), classes: classes.map(v => mapClass((v ?? {}) as Record)), + typeAliases: aliases.map(v => mapTypeAlias((v ?? {}) as Record)), imports: [], exports: [], }; @@ -546,6 +609,7 @@ async function computeCacheKey( const keyObject = { module: moduleName, moduleVersion: moduleConfig?.version ?? null, + irVersion: TYWRAP_IR_VERSION, pythonImportPath: options.pythonImportPath ?? [], runtime: { pythonPath: runtimePython, @@ -565,44 +629,6 @@ async function computeCacheKey( return await computeIrCacheFilename(keyObject); } -/** - * Very lightweight .d.ts emitter derived from generated TS wrappers - * This is intentionally minimal and stable for our wrappers shape. - */ -function renderDts(generatedTs: string): string { - const header = `// Generated by tywrap\n// Type Declarations\n// DO NOT EDIT MANUALLY\n\n`; - const lines: string[] = [header]; - // Extract function exports - const funcRegex = /export\s+async\s+function\s+(\w+)\s*\(([^)]*)\)\s*:\s*Promise<([^>]+)>/g; - let m: RegExpExecArray | null; - while ((m = funcRegex.exec(generatedTs)) !== null) { - const name = m[1]; - const params = m[2]; - const ret = m[3]; - lines.push(`export function ${name}(${params}): Promise<${ret}>;`); - } - // Extract class exports and methods - const classRegex = /export\s+class\s+(\w+)\s*\{([\s\S]*?)\n\}/g; - while ((m = classRegex.exec(generatedTs)) !== null) { - const className = String(m[1] ?? ''); - const body = String(m[2] ?? ''); - // Constructor is always variadic unknown[] in current generator - const methods: string[] = []; - const methodRegex = /\n\s+async\s+(\w+)\s*\(([^)]*)\)\s*:\s*Promise<([^>]+)>/g; - let mm: RegExpExecArray | null; - while ((mm = methodRegex.exec(body)) !== null) { - methods.push(` ${mm[1]}(${mm[2]}): Promise<${mm[3]}>;`); - } - lines.push(`export class ${className} {`); - lines.push(` constructor(...args: unknown[]);`); - if (methods.length > 0) { - lines.push(methods.join('\n')); - } - lines.push('}'); - } - return `${lines.join('\n')}\n`; -} - /** * Minimal source map placeholder (stable, empty mappings) */ diff --git a/src/utils/ir-cache.ts b/src/utils/ir-cache.ts index e4323798..76b3f1d3 100644 --- a/src/utils/ir-cache.ts +++ b/src/utils/ir-cache.ts @@ -3,6 +3,7 @@ import { hashUtils } from './runtime.js'; export interface IrCacheKeyObject { module: string; moduleVersion: string | null; + irVersion: string; pythonImportPath?: readonly string[]; runtime: { pythonPath: string; diff --git a/src/utils/runtime.ts b/src/utils/runtime.ts index 9e0c9ca9..cd60d013 100644 --- a/src/utils/runtime.ts +++ b/src/utils/runtime.ts @@ -715,13 +715,13 @@ export const hashUtils = { const bytes = Array.from(new Uint8Array(digest)); return bytes.map(b => b.toString(16).padStart(2, '0')).join(''); } - // Fallback to DJB2 (non-crypto) for unknown runtimes - let hash = 5381; - for (let i = 0; i < text.length; i++) { - hash = (hash << 5) + hash + text.charCodeAt(i); - hash |= 0; - } - // Match sha256 hex shape (64 chars) so callers can rely on fixed-length output. - return Math.abs(hash).toString(16).padStart(64, '0'); - }, - }; + // Fallback to DJB2 (non-crypto) for unknown runtimes + let hash = 5381; + for (let i = 0; i < text.length; i++) { + hash = (hash << 5) + hash + text.charCodeAt(i); + hash |= 0; + } + // Match sha256 hex shape (64 chars) so callers can rely on fixed-length output. + return Math.abs(hash).toString(16).padStart(64, '0'); + }, +}; diff --git a/test/annotation-parser.test.ts b/test/annotation-parser.test.ts index c3bf7d0f..bbc024bf 100644 --- a/test/annotation-parser.test.ts +++ b/test/annotation-parser.test.ts @@ -48,74 +48,23 @@ describe('annotation parser', () => { expect((t as any).parameters[0].name).toBe('...'); }); - it('parses Sequence[T] as a list-like collection', () => { - const t = parseAnnotationToPythonType('Sequence[bool]'); - expect(t.kind).toBe('collection'); - expect((t as any).name).toBe('list'); - expect((t as any).itemTypes[0].name).toBe('bool'); + it('parses known ParamSpec argument packs', () => { + const t = parseAnnotationToPythonType('P.args', { + typeParameters: [{ name: 'P', kind: 'paramspec' }], + }); + expect(t.kind).toBe('paramspec_args'); + expect((t as any).name).toBe('P'); }); - it('parses collections.abc aliases used by advanced fixtures', () => { - const sequence = parseAnnotationToPythonType('collections.abc.Sequence[str]'); - expect(sequence.kind).toBe('collection'); - expect((sequence as any).name).toBe('list'); - expect((sequence as any).itemTypes[0].name).toBe('str'); - - const mapping = parseAnnotationToPythonType('collections.abc.Mapping[str, int]'); - expect(mapping.kind).toBe('collection'); - expect((mapping as any).name).toBe('dict'); - expect((mapping as any).itemTypes).toHaveLength(2); - - const iterator = parseAnnotationToPythonType('collections.abc.AsyncIterator[bytes]'); - expect(iterator.kind).toBe('generic'); - expect((iterator as any).name).toBe('AsyncIterator'); - expect((iterator as any).typeArgs[0].name).toBe('bytes'); + it('parses explicit ParamSpec argument packs without declared type parameters', () => { + const t = parseAnnotationToPythonType('~P.args'); + expect(t.kind).toBe('paramspec_args'); + expect((t as any).name).toBe('P'); }); - it('parses TypeVar-like helpers into safe wrapper-friendly shapes', () => { - const typeVar = parseAnnotationToPythonType("TypeVar('T')"); - expect(typeVar.kind).toBe('typevar'); - expect((typeVar as any).name).toBe('T'); - - const paramSpec = parseAnnotationToPythonType("ParamSpec('P')"); - expect(paramSpec.kind).toBe('custom'); - expect((paramSpec as any).name).toBe('P'); - - const inferredTypeVar = parseAnnotationToPythonType('~T'); - expect(inferredTypeVar.kind).toBe('typevar'); - expect((inferredTypeVar as any).name).toBe('T'); - - const callable = parseAnnotationToPythonType('typing.Callable[~P, ~T]'); - expect(callable.kind).toBe('callable'); - expect((callable as any).parameters).toHaveLength(1); - expect((callable as any).parameters[0].name).toBe('...'); - expect((callable as any).returnType.kind).toBe('typevar'); - - const args = parseAnnotationToPythonType('P.args'); - expect(args.kind).toBe('collection'); - expect((args as any).name).toBe('list'); - - const kwargs = parseAnnotationToPythonType('P.kwargs'); - expect(kwargs.kind).toBe('collection'); - expect((kwargs as any).name).toBe('dict'); - - const unpack = parseAnnotationToPythonType('Unpack[Ts]'); - expect(unpack.kind).toBe('custom'); - expect((unpack as any).name).toBe('Any'); - expect((unpack as any).module).toBe('typing'); - }); - - it('treats bare module-scoped typevar names as safe placeholders when provided', () => { - const bareTypeVar = parseAnnotationToPythonType('T', { - knownTypeVarNames: ['T', 'P'], - }); - expect(bareTypeVar.kind).toBe('typevar'); - expect((bareTypeVar as any).name).toBe('T'); - - const bareParamSpec = parseAnnotationToPythonType('P', { - knownTypeVarNames: ['T', 'P'], - }); - expect(bareParamSpec.kind).toBe('typevar'); - expect((bareParamSpec as any).name).toBe('P'); + it('does not treat arbitrary dotted names as ParamSpec packs', () => { + const t = parseAnnotationToPythonType('Request.args'); + expect(t.kind).toBe('custom'); + expect((t as any).name).toBe('Request.args'); }); }); diff --git a/test/generator.test.ts b/test/generator.test.ts index 620d958e..9b7dde91 100644 --- a/test/generator.test.ts +++ b/test/generator.test.ts @@ -629,6 +629,513 @@ describe('CodeGenerator', () => { expect(code.typescript).toContain('name?: string;'); }); + it('emits generic function type parameters on overloads and declarations', () => { + const code = gen.generateFunctionWrapper( + { + name: 'coalesce', + signature: { + parameters: [ + { + name: 'x', + type: { kind: 'typevar', name: 'T' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + { + name: 'y', + type: { kind: 'typevar', name: 'T' }, + optional: true, + varArgs: false, + kwArgs: false, + }, + ], + returnType: { kind: 'typevar', name: 'T' }, + isAsync: false, + isGenerator: false, + }, + docstring: undefined, + decorators: [], + isAsync: false, + isGenerator: false, + typeParameters: [{ name: 'T', kind: 'typevar', variance: 'invariant' }], + returnType: { kind: 'typevar', name: 'T' }, + parameters: [ + { + name: 'x', + type: { kind: 'typevar', name: 'T' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + { + name: 'y', + type: { kind: 'typevar', name: 'T' }, + optional: true, + varArgs: false, + kwArgs: false, + }, + ], + } as any, + 'generic_module' + ); + + expect(code.typescript).toContain('export function coalesce(x: T): Promise;'); + expect(code.typescript).toContain('export async function coalesce(x: T, y?: T): Promise'); + expect(code.declaration).toContain('export function coalesce(x: T): Promise;'); + }); + + it('generates protocol aliases without init helpers and preserves method generics', () => { + const code = gen.generateClassWrapper( + { + name: 'Mapper', + bases: ['Protocol'], + methods: [ + { + name: '__init__', + signature: { + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + returnType: { kind: 'primitive', name: 'None' }, + isAsync: false, + isGenerator: false, + }, + docstring: undefined, + decorators: [], + isAsync: false, + isGenerator: false, + returnType: { kind: 'primitive', name: 'None' }, + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + }, + { + name: 'map', + signature: { + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + { + name: 'x', + type: { kind: 'typevar', name: 'U' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + returnType: { kind: 'typevar', name: 'U' }, + isAsync: false, + isGenerator: false, + }, + docstring: undefined, + decorators: [], + isAsync: false, + isGenerator: false, + typeParameters: [{ name: 'U', kind: 'typevar', variance: 'invariant' }], + returnType: { kind: 'typevar', name: 'U' }, + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + { + name: 'x', + type: { kind: 'typevar', name: 'U' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + }, + ], + properties: [], + docstring: undefined, + decorators: [], + kind: 'protocol', + } as any, + 'protocol_module' + ); + + expect(code.typescript).toContain('export type Mapper ='); + expect(code.typescript).toContain('map: (x: U) => U;'); + expect(code.typescript).not.toContain('__init__'); + expect(code.typescript).not.toContain('NoInitOrReplaceInit'); + }); + + it('emits generic classes and type aliases with safe fallbacks', () => { + const typeP = { name: 'P', kind: 'paramspec' } as const; + const typeT = { name: 'T', kind: 'typevar', variance: 'invariant' } as const; + const code = gen.generateModuleDefinition({ + name: 'generic_module', + functions: [ + { + name: 'forward', + signature: { + parameters: [ + { + name: 'container', + type: { + kind: 'generic', + name: 'Container', + module: 'generic_module', + typeArgs: [{ kind: 'typevar', name: 'T' }], + }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + returnType: { + kind: 'generic', + name: 'Container', + module: 'generic_module', + typeArgs: [{ kind: 'typevar', name: 'T' }], + }, + isAsync: false, + isGenerator: false, + }, + docstring: undefined, + decorators: [], + isAsync: false, + isGenerator: false, + typeParameters: [typeT], + returnType: { + kind: 'generic', + name: 'Container', + module: 'generic_module', + typeArgs: [{ kind: 'typevar', name: 'T' }], + }, + parameters: [ + { + name: 'container', + type: { + kind: 'generic', + name: 'Container', + module: 'generic_module', + typeArgs: [{ kind: 'typevar', name: 'T' }], + }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + }, + { + name: 'accept_transform', + signature: { + parameters: [ + { + name: 'transform', + type: { + kind: 'generic', + name: 'Transform', + module: 'generic_module', + typeArgs: [ + { kind: 'paramspec', name: 'P' }, + { kind: 'typevar', name: 'T' }, + ], + }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + returnType: { + kind: 'generic', + name: 'Transform', + module: 'generic_module', + typeArgs: [ + { kind: 'paramspec', name: 'P' }, + { kind: 'typevar', name: 'T' }, + ], + }, + isAsync: false, + isGenerator: false, + }, + docstring: undefined, + decorators: [], + isAsync: false, + isGenerator: false, + typeParameters: [typeP, typeT], + returnType: { + kind: 'generic', + name: 'Transform', + module: 'generic_module', + typeArgs: [ + { kind: 'paramspec', name: 'P' }, + { kind: 'typevar', name: 'T' }, + ], + }, + parameters: [ + { + name: 'transform', + type: { + kind: 'generic', + name: 'Transform', + module: 'generic_module', + typeArgs: [ + { kind: 'paramspec', name: 'P' }, + { kind: 'typevar', name: 'T' }, + ], + }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + }, + ], + classes: [ + { + name: 'Container', + bases: ['Generic'], + methods: [ + { + name: '__init__', + signature: { + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + { + name: 'value', + type: { kind: 'typevar', name: 'T' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + returnType: { kind: 'primitive', name: 'None' }, + isAsync: false, + isGenerator: false, + }, + docstring: undefined, + decorators: [], + isAsync: false, + isGenerator: false, + returnType: { kind: 'primitive', name: 'None' }, + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + { + name: 'value', + type: { kind: 'typevar', name: 'T' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + }, + { + name: 'clone', + signature: { + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + returnType: { + kind: 'generic', + name: 'Container', + module: 'generic_module', + typeArgs: [{ kind: 'typevar', name: 'T' }], + }, + isAsync: false, + isGenerator: false, + }, + docstring: undefined, + decorators: [], + isAsync: false, + isGenerator: false, + returnType: { + kind: 'generic', + name: 'Container', + module: 'generic_module', + typeArgs: [{ kind: 'typevar', name: 'T' }], + }, + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + }, + { + name: 'id', + signature: { + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + { + name: 'x', + type: { kind: 'typevar', name: 'U' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + returnType: { + kind: 'collection', + name: 'tuple', + itemTypes: [ + { kind: 'typevar', name: 'T' }, + { kind: 'typevar', name: 'U' }, + ], + }, + isAsync: false, + isGenerator: false, + }, + docstring: undefined, + decorators: [], + isAsync: false, + isGenerator: false, + typeParameters: [{ name: 'U', kind: 'typevar', variance: 'invariant' }], + returnType: { + kind: 'collection', + name: 'tuple', + itemTypes: [ + { kind: 'typevar', name: 'T' }, + { kind: 'typevar', name: 'U' }, + ], + }, + parameters: [ + { + name: 'self', + type: { kind: 'primitive', name: 'None' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + { + name: 'x', + type: { kind: 'typevar', name: 'U' }, + optional: false, + varArgs: false, + kwArgs: false, + }, + ], + }, + ], + properties: [], + docstring: undefined, + decorators: [], + typeParameters: [typeT], + }, + ], + typeAliases: [ + { + name: 'Pair', + type: { + kind: 'collection', + name: 'tuple', + itemTypes: [ + { kind: 'typevar', name: 'T' }, + { kind: 'typevar', name: 'T' }, + ], + }, + typeParameters: [typeT], + }, + { + name: 'Transform', + type: { + kind: 'callable', + parameters: [], + parameterSpec: { kind: 'paramspec', name: 'P' }, + returnType: { kind: 'typevar', name: 'T' }, + }, + typeParameters: [ + { name: 'P', kind: 'paramspec' }, + { name: 'T', kind: 'typevar', variance: 'invariant' }, + ], + }, + { + name: 'Variadic', + type: { + kind: 'collection', + name: 'tuple', + itemTypes: [ + { + kind: 'unpack', + type: { kind: 'typevartuple', name: 'Ts' }, + }, + ], + }, + typeParameters: [{ name: 'Ts', kind: 'typevartuple' }], + }, + ], + imports: [], + exports: [], + } as any); + + expect(code.typescript).toContain('export async function forward(container: Container)'); + expect(code.typescript).toContain( + 'export async function acceptTransform

(transform: Transform): Promise>' + ); + expect(code.typescript).toContain('Promise>'); + expect(code.typescript).toContain('export class Container'); + expect(code.typescript).toContain('static async create(value: T): Promise>'); + expect(code.typescript).toContain('static fromHandle(handle: string): Container'); + expect(code.typescript).toContain('async clone(): Promise>'); + expect(code.typescript).toContain('async id(x: U): Promise<[T, U]>'); + expect(code.typescript).toContain('export type Pair = [T, T]'); + expect(code.typescript).toContain( + 'export type Transform

= (...args: P) => T' + ); + expect(code.typescript).toContain('export type Variadic = [unknown]'); + expect(code.typescript).not.toContain('~T'); + expect(code.typescript).not.toContain('~P'); + expect(code.typescript).not.toContain('Unpack['); + expect(code.typescript).not.toMatch(/\bTs\b/); + expect(code.declaration).toContain('export type Pair = [T, T]'); + expect(code.declaration).toContain( + 'export function acceptTransform

(transform: Transform): Promise>;' + ); + expect(code.declaration).toContain('id(x: U): Promise<[T, U]>;'); + expect(code.declaration).not.toContain('getRuntimeBridge'); + }); + it('generates class wrapper', () => { const code = gen.generateClassWrapper( { diff --git a/test/integration.test.ts b/test/integration.test.ts index 4c68fd99..039a87a3 100644 --- a/test/integration.test.ts +++ b/test/integration.test.ts @@ -18,6 +18,31 @@ const nodeBridgeTimeoutMs = isCi ? 60000 : 30000; const nodeBridgeTestTimeoutMs = isCi ? 60000 : 30000; const defaultPythonPath = getDefaultPythonPath(); +async function supportsVariadicTypingFeatures(): Promise { + const result = await processUtils.exec(defaultPythonPath, [ + '-c', + ` +try: + from typing import ParamSpec, TypeVarTuple, Unpack +except ImportError: + try: + from typing_extensions import ParamSpec, TypeVarTuple, Unpack + except ImportError: + raise SystemExit(1) +raise SystemExit(0) +`, + ]); + return result.code === 0; +} + +async function supportsPep695Syntax(): Promise { + const result = await processUtils.exec(defaultPythonPath, [ + '-c', + 'import sys; raise SystemExit(0 if sys.version_info >= (3, 12) else 1)', + ]); + return result.code === 0; +} + describe('IR-only integration', () => { it('tywrap_ir emits JSON IR for math', async () => { const result = await processUtils.exec(defaultPythonPath, [ @@ -177,6 +202,292 @@ describe('IR-only integration', () => { await rm(tempDir, { recursive: true, force: true }); } }, 30_000); + + it('generates safe generic wrappers and declaration files that typecheck', async () => { + if (!(await supportsVariadicTypingFeatures())) { + return; + } + + const tempDir = await mkdtemp(join(tmpdir(), 'tywrap-generics-')); + try { + const importDir = join(tempDir, 'py'); + await mkdir(importDir, { recursive: true }); + await writeFile( + join(importDir, 'generic_module.py'), + `from __future__ import annotations + +from typing import Callable, Generic, TypeVar +try: + from typing import ParamSpec, TypeVarTuple, Unpack +except ImportError: + from typing_extensions import ParamSpec, TypeVarTuple, Unpack + +T = TypeVar("T") +U = TypeVar("U") +P = ParamSpec("P") +Ts = TypeVarTuple("Ts") + +Pair = tuple[T, T] +Transform = Callable[P, T] +Variadic = tuple[Unpack[Ts]] + +def identity(x: T, fallback: T | None = None) -> T: + return x if fallback is None else fallback + +def forward(container: Container[T]) -> Container[T]: + return container + +def accept_transform(transform: Transform[P, T]) -> Transform[P, T]: + return transform + +def passthrough(*args: P.args, **kwargs: P.kwargs) -> None: + return None + +class Container(Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + + def get(self) -> T: + return self.value + + def clone(self) -> Container[T]: + return Container(self.value) + + def id(self, x: U) -> tuple[T, U]: + return (self.value, x) +`, + 'utf-8' + ); + + const outDir = join(tempDir, 'generated'); + const res = await generate({ + pythonModules: { generic_module: { runtime: 'node', typeHints: 'strict' } }, + pythonImportPath: [importDir], + output: { dir: outDir, format: 'esm', declaration: true, sourceMap: false }, + runtime: { node: { pythonPath: defaultPythonPath } }, + performance: { caching: false, batching: false, compression: 'none' }, + development: { hotReload: false, sourceMap: false, validation: 'none' }, + } as any); + + expect(res.written.some(p => p.endsWith('generic_module.generated.ts'))).toBe(true); + expect(res.written.some(p => p.endsWith('generic_module.generated.d.ts'))).toBe(true); + + const typescript = await fsUtils.readFile(join(outDir, 'generic_module.generated.ts')); + const declaration = await fsUtils.readFile(join(outDir, 'generic_module.generated.d.ts')); + + expect(typescript).toContain('export function identity(x: T): Promise;'); + expect(typescript).toContain( + 'export async function forward(container: Container): Promise>' + ); + expect(typescript).toContain( + 'export async function acceptTransform

(' + ); + expect(typescript).not.toContain('Transform'); + expect(typescript).toContain('export async function passthrough(args?: unknown[]'); + expect(typescript).toContain('kwargs?: Record'); + expect(typescript).toContain('export class Container'); + expect(typescript).toContain('static async create(value: T): Promise>'); + expect(typescript).toContain('static fromHandle(handle: string): Container'); + expect(typescript).toContain('async id(x: U): Promise<[T, U]>'); + expect(typescript).toContain('export type Pair = [T, T]'); + expect(typescript).toContain( + 'export type Transform

= (...args: P) => T' + ); + expect(typescript).toContain('export type Variadic = [unknown]'); + expect(typescript).not.toContain('~T'); + expect(typescript).not.toContain('~P'); + expect(typescript).not.toContain('Unpack['); + expect(declaration).toContain('export type Pair = [T, T]'); + expect(declaration).toContain('export class Container'); + expect(declaration).toContain('export function acceptTransform

('); + expect(declaration).toContain('id(x: U): Promise<[T, U]>;'); + expect(declaration).not.toContain('getRuntimeBridge'); + + const compileDir = join(tempDir, 'compile'); + await mkdir(compileDir, { recursive: true }); + await writeFile( + join(compileDir, 'runtime-stub.d.ts'), + `// Minimal stub covering only the bridge methods generated wrappers call. +// The real RuntimeExecution also has dispose() and other members; expand this if codegen starts using them. +export interface RuntimeBridge { + call(module: string, functionName: string, args: unknown[], kwargs?: Record): Promise; + instantiate(module: string, className: string, args: unknown[], kwargs?: Record): Promise; + callMethod(handle: string, methodName: string, args: unknown[], kwargs?: Record): Promise; + disposeInstance(handle: string): Promise; +} + +export declare function getRuntimeBridge(): RuntimeBridge; +`, + 'utf-8' + ); + await writeFile( + join(compileDir, 'consumer.ts'), + `import { Container, Pair, Transform, acceptTransform, forward, identity, passthrough } from '../generated/generic_module.generated.js'; + +const pair: Pair = ['a', 'b']; +const transform: Transform<[number], string> = (...args) => String(args[0]); +const container = Container.fromHandle('handle'); +const accepted: Promise> = acceptTransform<[number], string>(transform); +const forwarded: Promise> = forward(container); +const resolved: Promise = identity(1); +const cloned: Promise> = container.clone(); +const identified: Promise<[number, string]> = container.id('value'); +const passthroughResult: Promise = passthrough([1, 2], { flag: true }); + +void pair; +void transform; +void accepted; +void forwarded; +void resolved; +void cloned; +void identified; +void passthroughResult; +`, + 'utf-8' + ); + await writeFile( + join(compileDir, 'tsconfig.json'), + JSON.stringify( + { + compilerOptions: { + target: 'ES2022', + module: 'NodeNext', + moduleResolution: 'NodeNext', + strict: true, + noEmit: true, + skipLibCheck: false, + baseUrl: '.', + paths: { + 'tywrap/runtime': ['./runtime-stub'], + }, + }, + include: [ + './consumer.ts', + './runtime-stub.d.ts', + '../generated/generic_module.generated.ts', + '../generated/generic_module.generated.d.ts', + ], + }, + null, + 2 + ), + 'utf-8' + ); + + const tscPath = join(process.cwd(), 'node_modules', 'typescript', 'lib', 'tsc.js'); + const compile = await processUtils.exec(process.execPath, [ + tscPath, + '-p', + join(compileDir, 'tsconfig.json'), + '--pretty', + 'false', + ]); + expect(compile.code, compile.stdout || compile.stderr).toBe(0); + } finally { + await rm(tempDir, { recursive: true, force: true }); + } + }, 30_000); + + it('generates protocol aliases without leaking init helpers and preserves method generics', async () => { + if (!(await supportsPep695Syntax())) { + return; + } + + const tempDir = await mkdtemp(join(tmpdir(), 'tywrap-protocol-generics-')); + try { + const importDir = join(tempDir, 'py'); + await mkdir(importDir, { recursive: true }); + await writeFile( + join(importDir, 'protocol_module.py'), + `from __future__ import annotations + +from typing import Protocol + +class Mapper(Protocol): + def map[U](self, x: U) -> U: + ... +`, + 'utf-8' + ); + + const outDir = join(tempDir, 'generated'); + const res = await generate({ + pythonModules: { protocol_module: { runtime: 'node', typeHints: 'strict' } }, + pythonImportPath: [importDir], + output: { dir: outDir, format: 'esm', declaration: true, sourceMap: false }, + runtime: { node: { pythonPath: defaultPythonPath } }, + performance: { caching: false, batching: false, compression: 'none' }, + development: { hotReload: false, sourceMap: false, validation: 'none' }, + } as any); + + expect(res.written.some(p => p.endsWith('protocol_module.generated.ts'))).toBe(true); + expect(res.written.some(p => p.endsWith('protocol_module.generated.d.ts'))).toBe(true); + + const typescript = await fsUtils.readFile(join(outDir, 'protocol_module.generated.ts')); + const declaration = await fsUtils.readFile(join(outDir, 'protocol_module.generated.d.ts')); + + expect(typescript).toContain('export type Mapper ='); + expect(typescript).toContain('map: (x: U) => U;'); + expect(typescript).not.toContain('NoInitOrReplaceInit'); + expect(typescript).not.toContain('__init__'); + expect(typescript).not.toContain('getRuntimeBridge'); + expect(declaration).toContain('map: (x: U) => U;'); + expect(declaration).not.toContain('NoInitOrReplaceInit'); + + const compileDir = join(tempDir, 'compile'); + await mkdir(compileDir, { recursive: true }); + await writeFile( + join(compileDir, 'consumer.ts'), + `import type { Mapper } from '../generated/protocol_module.generated.js'; + +const mapper: Mapper = { + map: (x: U) => x, +}; + +const result: string = mapper.map('value'); + +void mapper; +void result; +`, + 'utf-8' + ); + await writeFile( + join(compileDir, 'tsconfig.json'), + JSON.stringify( + { + compilerOptions: { + target: 'ES2022', + module: 'NodeNext', + moduleResolution: 'NodeNext', + strict: true, + noEmit: true, + skipLibCheck: false, + }, + include: [ + './consumer.ts', + '../generated/protocol_module.generated.ts', + '../generated/protocol_module.generated.d.ts', + ], + }, + null, + 2 + ), + 'utf-8' + ); + + const tscPath = join(process.cwd(), 'node_modules', 'typescript', 'lib', 'tsc.js'); + const compile = await processUtils.exec(process.execPath, [ + tscPath, + '-p', + join(compileDir, 'tsconfig.json'), + '--pretty', + 'false', + ]); + expect(compile.code, compile.stdout || compile.stderr).toBe(0); + } finally { + await rm(tempDir, { recursive: true, force: true }); + } + }, 30_000); }); describe('NodeBridge smoke', () => { diff --git a/test/ir_cache_filename.test.ts b/test/ir_cache_filename.test.ts index 40ffa15f..040fb9c7 100644 --- a/test/ir_cache_filename.test.ts +++ b/test/ir_cache_filename.test.ts @@ -7,6 +7,7 @@ describe('computeIrCacheFilename', () => { const filename = await computeIrCacheFilename({ module: 'pkg.sub/../evil:a', moduleVersion: null, + irVersion: '0.2.0', runtime: { pythonPath: '/usr/bin/python3', virtualEnv: null }, output: { format: 'esm', declaration: false, sourceMap: false }, performance: { caching: true, compression: 'none' }, @@ -19,4 +20,26 @@ describe('computeIrCacheFilename', () => { expect(filename).not.toContain('..'); expect(filename).not.toContain(':'); }); + + it('changes when the IR schema version changes', async () => { + const keyBase = { + module: 'pkg.sub', + moduleVersion: null, + runtime: { pythonPath: '/usr/bin/python3', virtualEnv: null }, + output: { format: 'esm', declaration: true, sourceMap: false }, + performance: { caching: true, compression: 'none' }, + typeHints: 'strict' as const, + }; + + const before = await computeIrCacheFilename({ + ...keyBase, + irVersion: '0.1.0', + }); + const after = await computeIrCacheFilename({ + ...keyBase, + irVersion: '0.2.0', + }); + + expect(after).not.toBe(before); + }); }); diff --git a/test/ir_extraction_advanced.test.ts b/test/ir_extraction_advanced.test.ts index 06a3d0d3..a3ddb785 100644 --- a/test/ir_extraction_advanced.test.ts +++ b/test/ir_extraction_advanced.test.ts @@ -65,6 +65,39 @@ function createTempModule(name: string, content: string): void { } } +async function supportsVariadicTypingFeatures(): Promise { + const result = await processUtils.exec(PYTHON_EXECUTABLE, [ + '-c', + ` +try: + from typing import ParamSpec, TypeVarTuple, Unpack +except ImportError: + try: + from typing_extensions import ParamSpec, TypeVarTuple, Unpack + except ImportError: + raise SystemExit(1) +raise SystemExit(0) +`, + ]); + return result.code === 0; +} + +async function supportsTypingExtensionsBackports(): Promise { + const result = await processUtils.exec(PYTHON_EXECUTABLE, [ + '-c', + 'from typing_extensions import ParamSpec, TypeVarTuple, Unpack', + ]); + return result.code === 0; +} + +async function supportsPep695Syntax(): Promise { + const result = await processUtils.exec(PYTHON_EXECUTABLE, [ + '-c', + 'import sys; raise SystemExit(0 if sys.version_info >= (3, 12) else 1)', + ]); + return result.code === 0; +} + beforeAll(async () => { // Ensure test fixtures exist if (!existsSync(FIXTURES_DIR)) { @@ -175,6 +208,207 @@ describe('IR Extraction - Complex Fixture Files', () => { }); }); +describe('IR Extraction - Generic Metadata', () => { + it('extracts ordered type parameters for functions, classes, and type aliases', async () => { + if (!(await supportsVariadicTypingFeatures())) { + return; + } + + createTempModule( + 'generic_type_params', + ` +from __future__ import annotations + +from typing import Callable, Generic, TypeVar +try: + from typing import ParamSpec, TypeVarTuple, Unpack +except ImportError: + from typing_extensions import ParamSpec, TypeVarTuple, Unpack + +T = TypeVar("T") +K = TypeVar("K") +V = TypeVar("V") +P = ParamSpec("P") +Ts = TypeVarTuple("Ts") + +Pair = tuple[T, T] +Transform = Callable[P, T] +Variadic = tuple[Unpack[Ts]] + +def identity(x: T) -> T: + return x + +class Container(Generic[T]): + def __init__(self, value: T) -> None: + self.value = value + + def get(self) -> T: + return self.value + +class KeyValueStore(Generic[K, V]): + def __init__(self) -> None: + self._data: dict[K, V] = {} + + def put(self, key: K, value: V) -> None: + self._data[key] = value +` + ); + + const ir = await extractIR('generic_type_params'); + const summarizeParams = (params: Array<{ name: string; kind: string }>) => + params.map(param => ({ name: param.name, kind: param.kind })); + + const identity = ir.functions.find((f: any) => f.name === 'identity'); + expect(identity).toBeDefined(); + expect(summarizeParams(identity.type_params)).toEqual([{ name: 'T', kind: 'typevar' }]); + + const container = ir.classes.find((c: any) => c.name === 'Container'); + expect(container).toBeDefined(); + expect(summarizeParams(container.type_params)).toEqual([{ name: 'T', kind: 'typevar' }]); + + const keyValueStore = ir.classes.find((c: any) => c.name === 'KeyValueStore'); + expect(keyValueStore).toBeDefined(); + expect(summarizeParams(keyValueStore.type_params)).toEqual([ + { name: 'K', kind: 'typevar' }, + { name: 'V', kind: 'typevar' }, + ]); + + const pair = ir.type_aliases.find((alias: any) => alias.name === 'Pair'); + expect(pair).toBeDefined(); + expect(summarizeParams(pair.type_params)).toEqual([{ name: 'T', kind: 'typevar' }]); + + const transform = ir.type_aliases.find((alias: any) => alias.name === 'Transform'); + expect(transform).toBeDefined(); + expect(summarizeParams(transform.type_params)).toEqual([ + { name: 'P', kind: 'paramspec' }, + { name: 'T', kind: 'typevar' }, + ]); + + const variadic = ir.type_aliases.find((alias: any) => alias.name === 'Variadic'); + expect(variadic).toBeDefined(); + expect(summarizeParams(variadic.type_params)).toEqual([{ name: 'Ts', kind: 'typevartuple' }]); + }); + + it('extracts backported generic markers from typing_extensions when available', async () => { + if (!(await supportsTypingExtensionsBackports())) { + return; + } + + createTempModule( + 'generic_type_params_backport', + ` +from __future__ import annotations + +from typing import Callable, TypeVar +from typing_extensions import ParamSpec, TypeVarTuple, Unpack + +T = TypeVar("T") +P = ParamSpec("P") +Ts = TypeVarTuple("Ts") + +Transform = Callable[P, T] +Variadic = tuple[Unpack[Ts]] +` + ); + + const ir = await extractIR('generic_type_params_backport'); + const summarizeParams = (params: Array<{ name: string; kind: string }>) => + params.map(param => ({ name: param.name, kind: param.kind })); + + const transform = ir.type_aliases.find((alias: any) => alias.name === 'Transform'); + expect(transform).toBeDefined(); + expect(summarizeParams(transform.type_params)).toEqual([ + { name: 'P', kind: 'paramspec' }, + { name: 'T', kind: 'typevar' }, + ]); + + const variadic = ir.type_aliases.find((alias: any) => alias.name === 'Variadic'); + expect(variadic).toBeDefined(); + expect(summarizeParams(variadic.type_params)).toEqual([{ name: 'Ts', kind: 'typevartuple' }]); + }); + + it('extracts PEP 695 type parameters from functions, classes, and type aliases on Python 3.12+', async () => { + if (!(await supportsPep695Syntax())) { + return; + } + + createTempModule( + 'generic_type_params_pep695', + ` +from __future__ import annotations + +type Pair[T] = tuple[T, T] + +def identity[T](x: T) -> T: + return x + +class Box[T]: + def __init__(self, value: T) -> None: + self.value = value + + def id[U](self, x: U) -> tuple[T, U]: + return (self.value, x) +` + ); + + const ir = await extractIR('generic_type_params_pep695'); + const summarizeParams = (params: Array<{ name: string; kind: string }>) => + params.map(param => ({ name: param.name, kind: param.kind })); + + const identity = ir.functions.find((f: any) => f.name === 'identity'); + expect(identity).toBeDefined(); + expect(summarizeParams(identity.type_params)).toEqual([{ name: 'T', kind: 'typevar' }]); + + const box = ir.classes.find((c: any) => c.name === 'Box'); + expect(box).toBeDefined(); + expect(summarizeParams(box.type_params)).toEqual([{ name: 'T', kind: 'typevar' }]); + const boxId = box.methods.find((method: any) => method.name === 'id'); + expect(boxId).toBeDefined(); + expect(summarizeParams(boxId.type_params)).toEqual([{ name: 'U', kind: 'typevar' }]); + + const pair = ir.type_aliases.find((alias: any) => alias.name === 'Pair'); + expect(pair).toBeDefined(); + expect(pair.definition).toBe('tuple[T, T]'); + expect(summarizeParams(pair.type_params)).toEqual([{ name: 'T', kind: 'typevar' }]); + }); + + it('preserves protocol method names and generic metadata on Python 3.12+', async () => { + if (!(await supportsPep695Syntax())) { + return; + } + + createTempModule( + 'protocol_method_generics_pep695', + ` +from __future__ import annotations + +from typing import Protocol + +class Mapper(Protocol): + def map[U](self, x: U) -> U: + ... +` + ); + + const ir = await extractIR('protocol_method_generics_pep695'); + const summarizeParams = (params: Array<{ name: string; kind: string }>) => + params.map(param => ({ name: param.name, kind: param.kind })); + + const mapper = ir.classes.find((c: any) => c.name === 'Mapper'); + expect(mapper).toBeDefined(); + expect(mapper.is_protocol).toBe(true); + + const mapMethod = mapper.methods.find((method: any) => method.name === 'map'); + expect(mapMethod).toBeDefined(); + expect(summarizeParams(mapMethod.type_params)).toEqual([{ name: 'U', kind: 'typevar' }]); + + const initMethod = mapper.methods.find((method: any) => method.qualname.endsWith('.__init__')); + if (initMethod) { + expect(initMethod.name).toBe('__init__'); + } + }); +}); + describe('IR Extraction - Metadata and Version Info', () => { it('should include correct metadata in IR output', async () => { const ir = await extractIR('math'); // Use built-in module diff --git a/test/mapper.test.ts b/test/mapper.test.ts index b551e660..07025221 100644 --- a/test/mapper.test.ts +++ b/test/mapper.test.ts @@ -118,26 +118,37 @@ describe('TypeMapper', () => { expect(mapped.kind).toBe('object'); }); - it('maps TypeVar-style placeholders to unknown instead of undeclared TS identifiers', () => { + it('preserves dedicated generic placeholders for later sanitization', () => { const mappedTypeVar = mapper.mapPythonType({ kind: 'typevar', name: 'T', } as any); expect(mappedTypeVar).toEqual({ - kind: 'primitive', - name: 'unknown', + kind: 'custom', + name: 'T', + module: 'typing', }); - const mappedParamSpec = mapper.mapPythonType({ + const mappedRawTypingPlaceholder = mapper.mapPythonType({ kind: 'custom', name: 'P', module: 'typing', } as any); - expect(mappedParamSpec).toEqual({ + expect(mappedRawTypingPlaceholder).toEqual({ kind: 'primitive', name: 'unknown', }); + const mappedParamSpec = mapper.mapPythonType({ + kind: 'paramspec', + name: 'P', + } as any); + expect(mappedParamSpec).toEqual({ + kind: 'custom', + name: 'P', + module: 'typing', + }); + const mappedCallable = mapper.mapPythonType({ kind: 'callable', parameters: [{ kind: 'custom', name: '...' }], @@ -146,8 +157,9 @@ describe('TypeMapper', () => { expect(mappedCallable.kind).toBe('function'); if (mappedCallable.kind === 'function') { expect(mappedCallable.returnType).toEqual({ - kind: 'primitive', - name: 'unknown', + kind: 'custom', + name: 'T', + module: 'typing', }); } }); diff --git a/test/type_mapping_enhanced.test.ts b/test/type_mapping_enhanced.test.ts index 3e71faff..e4fb5a63 100644 --- a/test/type_mapping_enhanced.test.ts +++ b/test/type_mapping_enhanced.test.ts @@ -3,7 +3,9 @@ import { TypeMapper } from '../src/core/mapper.js'; import type { PythonType, TypescriptType, + TSArrayType, TSPrimitiveType, + TSCustomType, TSGenericType, TSObjectType, } from '../src/types/index.js'; @@ -12,30 +14,34 @@ describe('TypeMapper - Enhanced Type Support', () => { const mapper = new TypeMapper(); describe('TypeVar Support', () => { - test('maps basic TypeVar to unknown', () => { + test('maps basic TypeVar to custom type', () => { const typeVar: PythonType = { kind: 'typevar', name: 'T', }; - const result = mapper.mapPythonType(typeVar); + const result = mapper.mapPythonType(typeVar) as TSCustomType; - expect(result).toEqual({ kind: 'primitive', name: 'unknown' }); + expect(result.kind).toBe('custom'); + expect(result.name).toBe('T'); + expect(result.module).toBe('typing'); }); - test('maps bounded TypeVar to unknown', () => { + test('maps bounded TypeVar preserving name', () => { const boundedTypeVar: PythonType = { kind: 'typevar', name: 'T', bound: { kind: 'custom', name: 'BaseClass' }, }; - const result = mapper.mapPythonType(boundedTypeVar); + const result = mapper.mapPythonType(boundedTypeVar) as TSCustomType; - expect(result).toEqual({ kind: 'primitive', name: 'unknown' }); + expect(result.kind).toBe('custom'); + expect(result.name).toBe('T'); + expect(result.module).toBe('typing'); }); - test('maps constrained TypeVar to unknown', () => { + test('maps constrained TypeVar preserving name', () => { const constrainedTypeVar: PythonType = { kind: 'typevar', name: 'T', @@ -45,21 +51,23 @@ describe('TypeMapper - Enhanced Type Support', () => { ], }; - const result = mapper.mapPythonType(constrainedTypeVar); + const result = mapper.mapPythonType(constrainedTypeVar) as TSCustomType; - expect(result).toEqual({ kind: 'primitive', name: 'unknown' }); + expect(result.kind).toBe('custom'); + expect(result.name).toBe('T'); }); - test('maps covariant TypeVar to unknown', () => { + test('maps covariant TypeVar', () => { const covariantTypeVar: PythonType = { kind: 'typevar', name: 'T_co', variance: 'covariant', }; - const result = mapper.mapPythonType(covariantTypeVar); + const result = mapper.mapPythonType(covariantTypeVar) as TSCustomType; - expect(result).toEqual({ kind: 'primitive', name: 'unknown' }); + expect(result.kind).toBe('custom'); + expect(result.name).toBe('T_co'); }); }); @@ -191,18 +199,17 @@ describe('TypeMapper - Enhanced Type Support', () => { expect(result.typeArgs).toEqual([{ kind: 'primitive', name: 'unknown' }]); }); - test('maps typing.Sequence to Array generic', () => { + test('maps typing.Sequence to an array shape', () => { const sequenceType: PythonType = { kind: 'custom', name: 'Sequence', module: 'typing', }; - const result = mapper.mapPythonType(sequenceType) as TSGenericType; + const result = mapper.mapPythonType(sequenceType) as TSArrayType; - expect(result.kind).toBe('generic'); - expect(result.name).toBe('Array'); - expect(result.typeArgs).toEqual([{ kind: 'primitive', name: 'unknown' }]); + expect(result.kind).toBe('array'); + expect(result.elementType).toEqual({ kind: 'primitive', name: 'unknown' }); }); test('maps typing.Mapping to object with index signature', () => { @@ -221,6 +228,144 @@ describe('TypeMapper - Enhanced Type Support', () => { valueType: { kind: 'primitive', name: 'unknown' }, }); }); + + test('maps generic typing.Awaitable[T] to Promise', () => { + const awaitableType: PythonType = { + kind: 'generic', + name: 'Awaitable', + module: 'typing', + typeArgs: [{ kind: 'primitive', name: 'str' }], + }; + + const result = mapper.mapPythonType(awaitableType) as TSGenericType; + + expect(result.kind).toBe('generic'); + expect(result.name).toBe('Promise'); + expect(result.typeArgs).toEqual([{ kind: 'primitive', name: 'string' }]); + }); + + test('maps generic typing.Sequence[T] to T[]', () => { + const sequenceType: PythonType = { + kind: 'generic', + name: 'Sequence', + module: 'typing', + typeArgs: [{ kind: 'primitive', name: 'int' }], + }; + + const result = mapper.mapPythonType(sequenceType) as TSArrayType; + + expect(result.kind).toBe('array'); + expect(result.elementType).toEqual({ kind: 'primitive', name: 'number' }); + }); + + test('maps generic collections.abc.Sequence[T] to T[]', () => { + const sequenceType: PythonType = { + kind: 'generic', + name: 'Sequence', + module: 'collections.abc', + typeArgs: [{ kind: 'primitive', name: 'bool' }], + }; + + const result = mapper.mapPythonType(sequenceType) as TSArrayType; + + expect(result.kind).toBe('array'); + expect(result.elementType).toEqual({ kind: 'primitive', name: 'boolean' }); + }); + + test('maps generic typing_extensions.Mapping[K, V] to an object index signature', () => { + const mappingType: PythonType = { + kind: 'generic', + name: 'Mapping', + module: 'typing_extensions', + typeArgs: [ + { kind: 'primitive', name: 'str' }, + { kind: 'primitive', name: 'int' }, + ], + }; + + const result = mapper.mapPythonType(mappingType) as TSObjectType; + + expect(result.kind).toBe('object'); + expect(result.properties).toEqual([]); + expect(result.indexSignature).toEqual({ + keyType: { kind: 'primitive', name: 'string' }, + valueType: { kind: 'primitive', name: 'number' }, + }); + }); + + test('preserves third-party Sequence generics', () => { + const sequenceType: PythonType = { + kind: 'generic', + name: 'Sequence', + module: 'pkg', + typeArgs: [{ kind: 'primitive', name: 'int' }], + }; + + const result = mapper.mapPythonType(sequenceType) as TSGenericType; + + expect(result.kind).toBe('generic'); + expect(result.name).toBe('Sequence'); + expect(result.typeArgs).toEqual([{ kind: 'primitive', name: 'number' }]); + }); + + test('preserves third-party Awaitable generics', () => { + const awaitableType: PythonType = { + kind: 'generic', + name: 'Awaitable', + module: 'pkg', + typeArgs: [{ kind: 'primitive', name: 'str' }], + }; + + const result = mapper.mapPythonType(awaitableType) as TSGenericType; + + expect(result.kind).toBe('generic'); + expect(result.name).toBe('Awaitable'); + expect(result.typeArgs).toEqual([{ kind: 'primitive', name: 'string' }]); + }); + + test('preserves third-party Coroutine generics', () => { + const coroutineType: PythonType = { + kind: 'generic', + name: 'Coroutine', + module: 'vendor', + typeArgs: [ + { kind: 'primitive', name: 'int' }, + { kind: 'primitive', name: 'int' }, + { kind: 'primitive', name: 'str' }, + ], + }; + + const result = mapper.mapPythonType(coroutineType) as TSGenericType; + + expect(result.kind).toBe('generic'); + expect(result.name).toBe('Coroutine'); + expect(result.typeArgs).toEqual([ + { kind: 'primitive', name: 'number' }, + { kind: 'primitive', name: 'number' }, + { kind: 'primitive', name: 'string' }, + ]); + }); + + test('preserves third-party Mapping generics', () => { + const mappingType: PythonType = { + kind: 'generic', + name: 'Mapping', + module: 'vendor', + typeArgs: [ + { kind: 'primitive', name: 'str' }, + { kind: 'primitive', name: 'int' }, + ], + }; + + const result = mapper.mapPythonType(mappingType) as TSGenericType; + + expect(result.kind).toBe('generic'); + expect(result.name).toBe('Mapping'); + expect(result.typeArgs).toEqual([ + { kind: 'primitive', name: 'string' }, + { kind: 'primitive', name: 'number' }, + ]); + }); }); describe('Module-qualified Type Names', () => { @@ -256,11 +401,11 @@ describe('TypeMapper - Enhanced Type Support', () => { module: 'my.module', }; - const result = mapper.mapPythonType(unknownType); + const result = mapper.mapPythonType(unknownType) as TSCustomType; expect(result.kind).toBe('custom'); - expect((result as any).name).toBe('MyCustomClass'); - expect((result as any).module).toBe('my.module'); + expect(result.name).toBe('MyCustomClass'); + expect(result.module).toBe('my.module'); }); }); @@ -274,9 +419,11 @@ describe('TypeMapper - Enhanced Type Support', () => { }, }; - const result = mapper.mapPythonType(finalTypeVar); + const result = mapper.mapPythonType(finalTypeVar) as TSCustomType; - expect(result).toEqual({ kind: 'primitive', name: 'unknown' }); + expect(result.kind).toBe('custom'); + expect(result.name).toBe('T'); + expect(result.module).toBe('typing'); }); test('maps ClassVar[Final[int]] correctly', () => { @@ -310,10 +457,11 @@ describe('TypeMapper - Enhanced Type Support', () => { const unionResult = result as any; expect(unionResult.types).toHaveLength(3); - // TypeVar becomes unknown + // TypeVar becomes custom type expect(unionResult.types[0]).toEqual({ - kind: 'primitive', - name: 'unknown', + kind: 'custom', + name: 'T', + module: 'typing', }); // Final[None] becomes null @@ -342,8 +490,9 @@ describe('TypeMapper - Enhanced Type Support', () => { expect(valueResult).toEqual(returnResult); expect(valueResult).toEqual({ - kind: 'primitive', - name: 'unknown', + kind: 'custom', + name: 'T', + module: 'typing', }); }); diff --git a/tywrap_ir/tywrap_ir/__main__.py b/tywrap_ir/tywrap_ir/__main__.py index 8e7cb6d9..5c8fe803 100644 --- a/tywrap_ir/tywrap_ir/__main__.py +++ b/tywrap_ir/tywrap_ir/__main__.py @@ -10,7 +10,7 @@ def main() -> None: parser.add_argument("--module", help="Python module name, e.g. math or pandas") parser.add_argument("--package", help="Python package name (alias of --module for now)") parser.add_argument("--output", help="Write JSON IR to file instead of stdout") - parser.add_argument("--ir-version", default="0.1.0", help="IR schema version") + parser.add_argument("--ir-version", default="0.2.0", help="IR schema version") parser.add_argument("--include-private", action="store_true", help="Include private members (leading _)") parser.add_argument("--no-pretty", action="store_true", help="Disable pretty JSON formatting") args = parser.parse_args() diff --git a/tywrap_ir/tywrap_ir/ir.py b/tywrap_ir/tywrap_ir/ir.py index 7ffca575..71bddd51 100644 --- a/tywrap_ir/tywrap_ir/ir.py +++ b/tywrap_ir/tywrap_ir/ir.py @@ -1,14 +1,16 @@ from __future__ import annotations +import ast +import dataclasses as _dataclasses import importlib import inspect import json import platform import sys -from dataclasses import dataclass, asdict -from typing import Any, Dict, List, Optional, get_type_hints -import dataclasses as _dataclasses +import types import typing +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional, get_type_hints try: from importlib import metadata as importlib_metadata # py3.8+ @@ -16,6 +18,15 @@ import importlib_metadata # type: ignore +@dataclass +class IRTypeParam: + name: str + kind: str + bound: str | None = None + constraints: List[str] | None = None + variance: str | None = None + + @dataclass class IRParam: name: str @@ -33,6 +44,7 @@ class IRFunction: returns: Optional[str] is_async: bool is_generator: bool + type_params: List[IRTypeParam] @dataclass @@ -49,6 +61,7 @@ class IRClass: is_namedtuple: bool is_dataclass: bool is_pydantic: bool + type_params: List[IRTypeParam] @dataclass @@ -64,6 +77,7 @@ class IRTypeAlias: name: str definition: str is_generic: bool + type_params: List[IRTypeParam] @dataclass @@ -78,23 +92,16 @@ class IRModule: warnings: List[str] -# Minimal, stringified annotation representation - def _stringify_annotation(annotation: Any) -> Optional[str]: if annotation is inspect._empty: # type: ignore[attr-defined] return None try: - # Handle forward references more elegantly str_repr = str(annotation) - - # Clean up class references to show just the class name if str_repr.startswith(""): - # Extract class name from - class_path = str_repr[8:-2] # Remove - if '.' in class_path: - return class_path.split('.')[-1] # Just the class name + class_path = str_repr[8:-2] + if "." in class_path: + return class_path.split(".")[-1] return class_path - return str_repr except Exception: return None @@ -111,116 +118,319 @@ def _param_kind_to_str(kind: inspect._ParameterKind) -> str: return mapping.get(kind, str(kind)) +def _type_param_kind(value: Any) -> str | None: + cls = type(value) + name = getattr(cls, "__name__", "") + module = getattr(cls, "__module__", "") + if module in {"typing", "typing_extensions"} and name in {"TypeVar", "ParamSpec", "TypeVarTuple"}: + return name.lower() + return None + + +def _serialize_type_param(value: Any) -> IRTypeParam | None: + kind = _type_param_kind(value) + if kind is None: + return None + + if kind == "typevar": + constraints = [_stringify_annotation(item) or str(item) for item in getattr(value, "__constraints__", ())] or None + variance = "invariant" + if getattr(value, "__covariant__", False): + variance = "covariant" + elif getattr(value, "__contravariant__", False): + variance = "contravariant" + bound_value = getattr(value, "__bound__", None) + return IRTypeParam( + name=str(getattr(value, "__name__", str(value)).replace("~", "")), + kind="typevar", + bound=_stringify_annotation(bound_value) if bound_value is not None else None, + constraints=constraints, + variance=variance, + ) + + if kind == "paramspec": + return IRTypeParam( + name=str(getattr(value, "__name__", str(value)).replace("~", "")), + kind="paramspec", + ) + + if kind == "typevartuple": + return IRTypeParam( + name=str(getattr(value, "__name__", str(value)).replace("~", "")), + kind="typevartuple", + ) + + return None + + +def _append_type_param(value: Any, seen: set[str], out: List[IRTypeParam]) -> None: + param = _serialize_type_param(value) + if param is None: + return + key = _type_param_key(param) + if key in seen: + return + seen.add(key) + out.append(param) + + +def _type_param_key(param: IRTypeParam) -> str: + return f"{param.kind}:{param.name}" + + +def _append_declared_type_params(obj: Any, seen: set[str], out: List[IRTypeParam]) -> None: + for attr_name in ("__type_params__", "__parameters__"): + try: + params = getattr(obj, attr_name, ()) + except Exception: + continue + for param in params or (): + _append_type_param(param, seen, out) + + +def _collect_declared_type_params(obj: Any) -> List[IRTypeParam]: + seen: set[str] = set() + out: List[IRTypeParam] = [] + _append_declared_type_params(obj, seen, out) + return out + + +def _collect_type_params_from_annotation( + annotation: Any, + seen: set[str], + out: List[IRTypeParam], +) -> None: + _append_type_param(annotation, seen, out) + + try: + origin = typing.get_origin(annotation) + except Exception: + origin = None + if origin is not None: + _append_type_param(origin, seen, out) + + try: + args = typing.get_args(annotation) + except Exception: + args = () + for arg in args: + _collect_type_params_from_annotation(arg, seen, out) + + text = _stringify_annotation(annotation) or "" + paramspec_match = text.split(".", 1)[0] if text.endswith((".args", ".kwargs")) else None + if paramspec_match: + inferred = IRTypeParam(name=paramspec_match.replace("~", ""), kind="paramspec") + key = _type_param_key(inferred) + if key not in seen: + seen.add(key) + out.append(inferred) + + +def _collect_annotation_type_params(*annotations: Any) -> List[IRTypeParam]: + seen: set[str] = set() + out: List[IRTypeParam] = [] + for annotation in annotations: + _collect_type_params_from_annotation(annotation, seen, out) + return out + + +def _merge_type_params(*groups: List[IRTypeParam]) -> List[IRTypeParam]: + seen: set[str] = set() + out: List[IRTypeParam] = [] + for group in groups: + for param in group: + key = _type_param_key(param) + if key in seen: + continue + seen.add(key) + out.append(param) + return out + + +def _collect_type_params_from_object_and_annotations(obj: Any, *annotations: Any) -> List[IRTypeParam]: + return _merge_type_params( + _collect_declared_type_params(obj), + _collect_annotation_type_params(*annotations), + ) + + +def _collect_scoped_type_params( + obj: Any, + *annotations: Any, + inherited_type_params: List[IRTypeParam] | None = None, +) -> List[IRTypeParam]: + declared = _collect_declared_type_params(obj) + annotation_params = _collect_annotation_type_params(*annotations) + if not inherited_type_params: + return _merge_type_params(declared, annotation_params) + + inherited_keys = {_type_param_key(param) for param in inherited_type_params} + scoped_annotation_params = [ + param for param in annotation_params if _type_param_key(param) not in inherited_keys + ] + return _merge_type_params(declared, scoped_annotation_params) + + +def _unwrap_type_alias_value(value: Any) -> Any: + type_alias_type = getattr(typing, "TypeAliasType", None) + if type_alias_type is None or not isinstance(value, type_alias_type): + return value + try: + return value.__value__ + except Exception: + return value + + +def _top_level_assigned_names(module: Any) -> set[str]: + try: + source = inspect.getsource(module) + except Exception: + return set() + + try: + tree = ast.parse(source) + except Exception: + return set() + + names: set[str] = set() + for stmt in tree.body: + if isinstance(stmt, ast.Assign): + for target in stmt.targets: + if isinstance(target, ast.Name): + names.add(target.id) + elif isinstance(stmt, ast.AnnAssign): + if isinstance(stmt.target, ast.Name): + names.add(stmt.target.id) + elif hasattr(ast, "TypeAlias") and isinstance(stmt, getattr(ast, "TypeAlias")): + name_node = getattr(stmt, "name", None) + if isinstance(name_node, ast.Name): + names.add(name_node.id) + return names + + +def _is_type_alias_value(value: Any) -> bool: + if inspect.isfunction(value) or inspect.isbuiltin(value) or inspect.isclass(value) or inspect.ismodule(value): + return False + type_alias_type = getattr(typing, "TypeAliasType", None) + if type_alias_type is not None and isinstance(value, type_alias_type): + return True + if isinstance(value, types.GenericAlias): + return True + try: + if typing.get_origin(value) is not None: + return True + except Exception: + pass + return hasattr(value, "__parameters__") and bool(getattr(value, "__parameters__", ())) + + def _extract_constants(module: Any, module_name: str, include_private: bool) -> List[IRConstant]: """Extract module-level constants and Final variables.""" constants: List[IRConstant] = [] - - # Check module annotations for type hints - annotations = getattr(module, '__annotations__', {}) - + annotations = getattr(module, "__annotations__", {}) + for name in dir(module): if not include_private and name.startswith("_"): continue - + try: value = getattr(module, name) except Exception: continue - - # Skip functions, classes, modules, and other non-constant items - if (inspect.isfunction(value) or inspect.isclass(value) or - inspect.ismodule(value) or callable(value)): + + if inspect.isfunction(value) or inspect.isclass(value) or inspect.ismodule(value) or callable(value): continue - - # Check if it's a constant (uppercase naming convention or Final) + is_constant = name.isupper() or name in annotations - - if is_constant: - annotation = annotations.get(name) - annotation_str = _stringify_annotation(annotation) if annotation else None - - # Check if it's Final - is_final = False - if annotation_str and ('Final[' in annotation_str or annotation_str == 'Final'): - is_final = True - - # Get string representation of value (truncated for large objects) - try: - value_repr = repr(value) - if len(value_repr) > 200: - value_repr = value_repr[:197] + "..." - except Exception: - value_repr = "" - - constants.append(IRConstant( + if not is_constant: + continue + + annotation = annotations.get(name) + annotation_str = _stringify_annotation(annotation) if annotation else None + is_final = bool(annotation_str and ("Final[" in annotation_str or annotation_str == "Final")) + + try: + value_repr = repr(value) + if len(value_repr) > 200: + value_repr = value_repr[:197] + "..." + except Exception: + value_repr = "" + + constants.append( + IRConstant( name=name, annotation=annotation_str, value_repr=value_repr, - is_final=is_final - )) - + is_final=is_final, + ) + ) + return constants def _extract_type_aliases(module: Any, module_name: str, include_private: bool) -> List[IRTypeAlias]: - """Extract type aliases from module.""" + """Extract top-level type aliases from module assignments.""" type_aliases: List[IRTypeAlias] = [] - - # Check module annotations for type aliases - annotations = getattr(module, '__annotations__', {}) - - for name, annotation in annotations.items(): + annotations = getattr(module, "__annotations__", {}) + assigned_names = _top_level_assigned_names(module) + candidate_names = sorted(assigned_names | set(annotations.keys())) + + for name in candidate_names: if not include_private and name.startswith("_"): continue - + try: - value = getattr(module, name, None) + value = getattr(module, name) except Exception: continue - - # Type aliases are typically annotated but check for specific patterns - annotation_str = _stringify_annotation(annotation) - - # Check if it's a type alias (has TypeAlias annotation or follows patterns) - is_type_alias = ( - annotation_str and ('TypeAlias' in annotation_str or - 'typing.Union' in annotation_str or - 'typing.Optional' in annotation_str or - 'typing.List' in annotation_str or - 'typing.Dict' in annotation_str or - '|' in annotation_str) # Modern union syntax - ) - - if is_type_alias: - # Check if it's generic (contains type parameters) - is_generic = bool(annotation_str and any( - marker in annotation_str for marker in ['[', '~', 'TypeVar', 'Generic'] - )) - - type_aliases.append(IRTypeAlias( + + annotation = annotations.get(name) + annotation_str = _stringify_annotation(annotation) if annotation is not None else None + is_type_alias_annotation = bool(annotation_str and "TypeAlias" in annotation_str) + if not _is_type_alias_value(value) and not is_type_alias_annotation: + continue + + unwrapped_value = _unwrap_type_alias_value(value) + type_params = _collect_type_params_from_object_and_annotations(value, unwrapped_value) + definition = _stringify_annotation(unwrapped_value) or annotation_str or str(unwrapped_value) + type_aliases.append( + IRTypeAlias( name=name, - definition=annotation_str or str(annotation), - is_generic=is_generic - )) - + definition=definition, + is_generic=bool(type_params), + type_params=type_params, + ) + ) + return type_aliases -def _extract_function(obj: Any, qualname: str) -> Optional[IRFunction]: +def _extract_function( + obj: Any, + qualname: str, + *, + include_type_params: bool = True, + inherited_type_params: List[IRTypeParam] | None = None, + display_name: str | None = None, +) -> Optional[IRFunction]: try: sig = inspect.signature(obj) except Exception: return None - # Use get_type_hints to resolve ForwardRefs where possible try: - hints = get_type_hints(obj) + hints = get_type_hints(obj, include_extras=True) except Exception: - hints = {} + try: + hints = get_type_hints(obj) + except Exception: + hints = {} params: List[IRParam] = [] + annotations_for_params: List[Any] = [] for name, p in sig.parameters.items(): ann = hints.get(name, p.annotation) + annotations_for_params.append(ann) params.append( IRParam( name=name, @@ -235,15 +445,26 @@ def _extract_function(obj: Any, qualname: str) -> Optional[IRFunction]: # Async generators must be marked as generators so callers can distinguish # them from plain coroutines. is_generator = inspect.isgeneratorfunction(obj) or inspect.isasyncgenfunction(obj) + type_params = ( + _collect_scoped_type_params( + obj, + *annotations_for_params, + returns, + inherited_type_params=inherited_type_params, + ) + if include_type_params + else [] + ) return IRFunction( - name=getattr(obj, "__name__", qualname.split(".")[-1]), + name=display_name or getattr(obj, "__name__", qualname.split(".")[-1]), qualname=qualname, docstring=inspect.getdoc(obj), parameters=params, returns=_stringify_annotation(returns), is_async=is_async, is_generator=is_generator, + type_params=type_params, ) @@ -255,31 +476,38 @@ def _extract_class(cls: type, module_name: str, include_private: bool) -> Option return None bases = [b.__name__ for b in getattr(cls, "__bases__", []) if hasattr(b, "__name__")] + class_type_params = _collect_type_params_from_object_and_annotations(cls) methods: List[IRFunction] = [] for meth_name, value in inspect.getmembers( cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethoddescriptor(x) or inspect.isbuiltin(x), ): - if not include_private and meth_name.startswith("_"): + if not include_private and meth_name.startswith("_") and meth_name != "__init__": continue - fn = _extract_function(value, f"{module_name}.{cls.__name__}.{meth_name}") + fn = _extract_function( + value, + f"{module_name}.{cls.__name__}.{meth_name}", + inherited_type_params=class_type_params, + display_name=meth_name, + ) if fn is not None: methods.append(fn) - # TypedDict detection and fields typed_dict = False total: Optional[bool] = None fields: List[IRParam] = [] try: - # Heuristic: TypedDict classes have __annotations__ and __total__ if hasattr(cls, "__annotations__") and hasattr(cls, "__total__"): typed_dict = True total = bool(getattr(cls, "__total__", True)) - ann = get_type_hints(cls, include_extras=True) if hasattr(typing, "get_origin") else getattr(cls, "__annotations__", {}) + ann = ( + get_type_hints(cls, include_extras=True) + if hasattr(typing, "get_origin") + else getattr(cls, "__annotations__", {}) + ) for fname, ftype in ann.items(): text = _stringify_annotation(ftype) - # Determine optionality from NotRequired/Required wrappers if present s = str(ftype) is_not_required = "NotRequired[" in s or "typing.NotRequired[" in s is_required = "Required[" in s or "typing.Required[" in s @@ -288,7 +516,6 @@ def _extract_class(cls: type, module_name: str, include_private: bool) -> Option except Exception: pass - # Protocol detection is_protocol = False try: for b in getattr(cls, "__mro__", []): @@ -298,18 +525,27 @@ def _extract_class(cls: type, module_name: str, include_private: bool) -> Option except Exception: is_protocol = False - # NamedTuple detection is_namedtuple = hasattr(cls, "_fields") and isinstance(getattr(cls, "_fields", None), (list, tuple)) if is_namedtuple and not fields: try: - ann = get_type_hints(cls, include_extras=True) if hasattr(typing, "get_origin") else getattr(cls, "__annotations__", {}) + ann = ( + get_type_hints(cls, include_extras=True) + if hasattr(typing, "get_origin") + else getattr(cls, "__annotations__", {}) + ) for fname in getattr(cls, "_fields", []): ftype = ann.get(fname, None) - fields.append(IRParam(name=str(fname), kind="FIELD", annotation=_stringify_annotation(ftype), default=False)) + fields.append( + IRParam( + name=str(fname), + kind="FIELD", + annotation=_stringify_annotation(ftype), + default=False, + ) + ) except Exception: pass - # Dataclass detection is_dataclass = False try: is_dataclass = _dataclasses.is_dataclass(cls) @@ -318,12 +554,20 @@ def _extract_class(cls: type, module_name: str, include_private: bool) -> Option if is_dataclass and not fields: try: for f in _dataclasses.fields(cls): # type: ignore[attr-defined] - defaulted = not (f.default is _dataclasses.MISSING and f.default_factory is _dataclasses.MISSING) # type: ignore[attr-defined] - fields.append(IRParam(name=f.name, kind="FIELD", annotation=_stringify_annotation(f.type), default=defaulted)) + defaulted = not ( + f.default is _dataclasses.MISSING and f.default_factory is _dataclasses.MISSING + ) # type: ignore[attr-defined] + fields.append( + IRParam( + name=f.name, + kind="FIELD", + annotation=_stringify_annotation(f.type), + default=defaulted, + ) + ) except Exception: pass - # Pydantic detection is_pydantic = False try: import pydantic @@ -341,28 +585,40 @@ def _extract_class(cls: type, module_name: str, include_private: bool) -> Option is_pydantic = False if is_pydantic and not fields: try: - # v2 model_fields = getattr(cls, "model_fields", None) if isinstance(model_fields, dict): for fname, finfo in model_fields.items(): ann = getattr(finfo, "annotation", None) required = getattr(finfo, "is_required", False) - fields.append(IRParam(name=str(fname), kind="FIELD", annotation=_stringify_annotation(ann), default=(not required))) + fields.append( + IRParam( + name=str(fname), + kind="FIELD", + annotation=_stringify_annotation(ann), + default=(not required), + ) + ) else: - # v1 __fields__ = getattr(cls, "__fields__", None) if isinstance(__fields__, dict): for fname, finfo in __fields__.items(): ann = getattr(finfo, "type_", None) required = getattr(finfo, "required", False) - fields.append(IRParam(name=str(fname), kind="FIELD", annotation=_stringify_annotation(ann), default=(not required))) + fields.append( + IRParam( + name=str(fname), + kind="FIELD", + annotation=_stringify_annotation(ann), + default=(not required), + ) + ) except Exception: pass return IRClass( name=name, qualname=f"{module_name}.{name}", - docstring=inspect.getdoc(cls), + docstring=inspect.getdoc(cls) if getattr(cls, "__doc__", None) else None, bases=bases, methods=methods, typed_dict=typed_dict, @@ -372,6 +628,7 @@ def _extract_class(cls: type, module_name: str, include_private: bool) -> Option is_namedtuple=is_namedtuple, is_dataclass=is_dataclass, is_pydantic=is_pydantic, + type_params=class_type_params, ) @@ -403,19 +660,15 @@ def _collect_metadata(module_name: str, ir_version: str) -> Dict[str, Any]: def extract_module_ir( module_name: str, *, - ir_version: str = "0.1.0", + ir_version: str = "0.2.0", include_private: bool = False, ) -> Dict[str, Any]: - """ - Extract a minimal IR for a Python module: top-level callables with signature info. - """ module = importlib.import_module(module_name) functions: List[IRFunction] = [] classes: List[IRClass] = [] warnings: List[str] = [] - - # Extract constants and type aliases + constants = _extract_constants(module, module_name, include_private) type_aliases = _extract_type_aliases(module, module_name, include_private) @@ -426,12 +679,10 @@ def extract_module_ir( continue if not include_private and name.startswith("_"): continue - # Include plain functions and builtins (e.g., math.sqrt) if inspect.isfunction(value) or inspect.isbuiltin(value): fn = _extract_function(value, f"{module_name}.{name}") if fn is not None: functions.append(fn) - # Include classes defined in this module if inspect.isclass(value) and getattr(value, "__module__", None) == module.__name__: cls_ir = _extract_class(value, module_name, include_private) if cls_ir is not None: @@ -447,13 +698,13 @@ def extract_module_ir( metadata=_collect_metadata(module_name, ir_version), warnings=warnings, ) - # Return as plain dicts ready for JSON emitting return asdict(ir) + def emit_ir_json( module_name: str, *, - ir_version: str = "0.1.0", + ir_version: str = "0.2.0", include_private: bool = False, pretty: bool = True, ) -> str: diff --git a/tywrap_ir/tywrap_ir/optimized_ir.py b/tywrap_ir/tywrap_ir/optimized_ir.py index 7999ad93..9346b416 100644 --- a/tywrap_ir/tywrap_ir/optimized_ir.py +++ b/tywrap_ir/tywrap_ir/optimized_ir.py @@ -158,7 +158,7 @@ def __init__(self, def extract_module_ir_optimized(self, module_name: str, *, - ir_version: str = "0.1.0", + ir_version: str = "0.2.0", include_private: bool = False) -> Dict[str, Any]: """ Extract IR with performance optimizations @@ -411,7 +411,7 @@ def clear_cache(self) -> None: def extract_module_ir_optimized(module_name: str, *, - ir_version: str = "0.1.0", + ir_version: str = "0.2.0", include_private: bool = False, enable_caching: bool = True, enable_parallel: bool = True) -> Dict[str, Any]: