diff --git a/.changeset/shaky-dryers-refuse.md b/.changeset/shaky-dryers-refuse.md new file mode 100644 index 0000000..3f3679e --- /dev/null +++ b/.changeset/shaky-dryers-refuse.md @@ -0,0 +1,5 @@ +--- +'@codama/renderers-rust': patch +--- + +fix: make serde field attributes dynamic and optional diff --git a/src/getRenderMapVisitor.ts b/src/getRenderMapVisitor.ts index 3f8e1da..5423e21 100644 --- a/src/getRenderMapVisitor.ts +++ b/src/getRenderMapVisitor.ts @@ -56,7 +56,11 @@ export function getRenderMapVisitor(options: GetRenderMapOptions = {}) { const dependencyMap = options.dependencyMap ?? {}; const getImportFrom = getImportFromFactory(options.linkOverrides ?? {}); const getTraitsFromNode = getTraitsFromNodeFactory(options.traitOptions); - const typeManifestVisitor = getTypeManifestVisitor({ getImportFrom, getTraitsFromNode }); + const typeManifestVisitor = getTypeManifestVisitor({ + getImportFrom, + getTraitsFromNode, + traitOptions: options.traitOptions, + }); const anchorTraits = options.anchorTraits ?? true; return pipe( diff --git a/src/getTypeManifestVisitor.ts b/src/getTypeManifestVisitor.ts index c3a87ab..0d7b022 100644 --- a/src/getTypeManifestVisitor.ts +++ b/src/getTypeManifestVisitor.ts @@ -1,9 +1,12 @@ import { CODAMA_ERROR__RENDERERS__UNSUPPORTED_NODE, CodamaError } from '@codama/errors'; import { + AccountNode, arrayTypeNode, CountNode, + DefinedTypeNode, definedTypeNode, fixedCountNode, + InstructionNode, isNode, NumberTypeNode, numberTypeNode, @@ -18,7 +21,13 @@ import { import { extendVisitor, mergeVisitor, pipe, visit } from '@codama/visitors-core'; import { ImportMap } from './ImportMap'; -import { GetImportFromFunction, GetTraitsFromNodeFunction, rustDocblock } from './utils'; +import { + GetImportFromFunction, + getSerdeFieldAttribute, + GetTraitsFromNodeFunction, + rustDocblock, + TraitOptions, +} from './utils'; export type TypeManifest = { imports: ImportMap; @@ -31,12 +40,14 @@ export function getTypeManifestVisitor(options: { getTraitsFromNode: GetTraitsFromNodeFunction; nestedStruct?: boolean; parentName?: string | null; + traitOptions?: TraitOptions; }) { - const { getImportFrom, getTraitsFromNode } = options; + const { getImportFrom, getTraitsFromNode, traitOptions } = options; let parentName: string | null = options.parentName ?? null; let nestedStruct: boolean = options.nestedStruct ?? false; let inlineStruct: boolean = false; let parentSize: NumberTypeNode | number | null = null; + let parentNode: AccountNode | DefinedTypeNode | InstructionNode | null = null; return pipe( mergeVisitor( @@ -51,10 +62,12 @@ export function getTypeManifestVisitor(options: { extendVisitor(v, { visitAccount(account, { self }) { parentName = pascalCase(account.name); + parentNode = account; const manifest = visit(account.data, self); const traits = getTraitsFromNode(account); manifest.imports.mergeWith(traits.imports); parentName = null; + parentNode = null; return { ...manifest, type: traits.render + manifest.type, @@ -140,10 +153,12 @@ export function getTypeManifestVisitor(options: { visitDefinedType(definedType, { self }) { parentName = pascalCase(definedType.name); + parentNode = definedType; const manifest = visit(definedType.type, self); const traits = getTraitsFromNode(definedType); manifest.imports.mergeWith(traits.imports); parentName = null; + parentNode = null; const renderedType = isNode(definedType.type, ['enumTypeNode', 'structTypeNode']) ? manifest.type @@ -204,12 +219,18 @@ export function getTypeManifestVisitor(options: { parentName = originalParentName; let derive = ''; - if (childManifest.type === '(Pubkey)') { - derive = - '#[cfg_attr(feature = "serde", serde(with = "serde_with::As::"))]\n'; - } else if (childManifest.type === '(Vec)') { - derive = - '#[cfg_attr(feature = "serde", serde(with = "serde_with::As::>"))]\n'; + if (parentNode && childManifest.type === '(Pubkey)') { + derive = getSerdeFieldAttribute( + 'serde_with::As::', + parentNode, + traitOptions, + ); + } else if (parentNode && childManifest.type === '(Vec)') { + derive = getSerdeFieldAttribute( + 'serde_with::As::>', + parentNode, + traitOptions, + ); } return { @@ -385,25 +406,36 @@ export function getTypeManifestVisitor(options: { const resolvedNestedType = resolveNestedTypeNode(structFieldType.type); let derive = ''; - if (fieldManifest.type === 'Pubkey') { - derive = - '#[cfg_attr(feature = "serde", serde(with = "serde_with::As::"))]\n'; - } else if (fieldManifest.type === 'Vec') { - derive = - '#[cfg_attr(feature = "serde", serde(with = "serde_with::As::>"))]\n'; - } else if ( - isNode(resolvedNestedType, 'arrayTypeNode') && - isNode(resolvedNestedType.count, 'fixedCountNode') && - resolvedNestedType.count.value > 32 - ) { - derive = '#[cfg_attr(feature = "serde", serde(with = "serde_big_array::BigArray"))]\n'; - } else if ( - isNode(resolvedNestedType, ['bytesTypeNode', 'stringTypeNode']) && - isNode(structFieldType.type, 'fixedSizeTypeNode') && - structFieldType.type.size > 32 - ) { - derive = - '#[cfg_attr(feature = "serde", serde(with = "serde_with::As::"))]\n'; + if (parentNode) { + if (fieldManifest.type === 'Pubkey') { + derive = getSerdeFieldAttribute( + 'serde_with::As::', + parentNode, + traitOptions, + ); + } else if (fieldManifest.type === 'Vec') { + derive = getSerdeFieldAttribute( + 'serde_with::As::>', + parentNode, + traitOptions, + ); + } else if ( + isNode(resolvedNestedType, 'arrayTypeNode') && + isNode(resolvedNestedType.count, 'fixedCountNode') && + resolvedNestedType.count.value > 32 + ) { + derive = getSerdeFieldAttribute('serde_big_array::BigArray', parentNode, traitOptions); + } else if ( + isNode(resolvedNestedType, ['bytesTypeNode', 'stringTypeNode']) && + isNode(structFieldType.type, 'fixedSizeTypeNode') && + structFieldType.type.size > 32 + ) { + derive = getSerdeFieldAttribute( + 'serde_with::As::', + parentNode, + traitOptions, + ); + } } return { diff --git a/src/utils/traitOptions.ts b/src/utils/traitOptions.ts index 170ddac..90998d3 100644 --- a/src/utils/traitOptions.ts +++ b/src/utils/traitOptions.ts @@ -174,3 +174,60 @@ function extractFullyQualifiedNames(traits: string[], imports: ImportMap): strin return trait.slice(index + 2); }); } + +/** + * Helper function to get the serde field attribute format based on trait configuration. + * Returns the appropriate attribute string for serde field customization, or empty string if no serde traits. + */ +export function getSerdeFieldAttribute( + serdeWith: string, + node: AccountNode | DefinedTypeNode | InstructionNode, + userOptions: TraitOptions = {}, +): string { + assertIsNode(node, ['accountNode', 'definedTypeNode', 'instructionNode']); + const options: Required = { ...DEFAULT_TRAIT_OPTIONS, ...userOptions }; + + // Get the node type and return early if it's a type alias. + const nodeType = getNodeType(node); + if (nodeType === 'alias') { + return ''; + } + + // Find all the traits for the node. + const sanitizedOverrides = Object.fromEntries( + Object.entries(options.overrides).map(([key, value]) => [camelCase(key), value]), + ); + const nodeOverrides: string[] | undefined = sanitizedOverrides[node.name]; + const allTraits = nodeOverrides === undefined ? getDefaultTraits(nodeType, options) : nodeOverrides; + + // Check if serde traits are present. + const hasSerdeSerialize = allTraits.some(t => t === 'serde::Serialize' || t === 'Serialize'); + const hasSerdeDeserialize = allTraits.some(t => t === 'serde::Deserialize' || t === 'Deserialize'); + + if (!hasSerdeSerialize && !hasSerdeDeserialize) { + return ''; + } + + // Check if serde is feature-flagged. + const partitionedTraits = partitionTraitsInFeatures(allTraits, options.featureFlags); + const featuredTraits = partitionedTraits[1]; + + // Find which feature flag contains serde traits. + let serdeFeatureName: string | undefined; + for (const [feature, traits] of Object.entries(featuredTraits)) { + if ( + traits.some( + t => t === 'serde::Serialize' || t === 'serde::Deserialize' || t === 'Serialize' || t === 'Deserialize', + ) + ) { + serdeFeatureName = feature; + break; + } + } + + if (serdeFeatureName) { + return `#[cfg_attr(feature = "${serdeFeatureName}", serde(with = "${serdeWith}"))]\n`; + } else { + return `#[serde(with = "${serdeWith}")]\n`; + } +} diff --git a/test/utils/traitOptions.test.ts b/test/utils/traitOptions.test.ts index 3532411..46c3e24 100644 --- a/test/utils/traitOptions.test.ts +++ b/test/utils/traitOptions.test.ts @@ -1,14 +1,20 @@ import { accountNode, + arrayTypeNode, + bytesTypeNode, camelCase, definedTypeNode, enumEmptyVariantTypeNode, enumStructVariantTypeNode, enumTypeNode, + fixedCountNode, + fixedSizeTypeNode, instructionArgumentNode, instructionNode, numberTypeNode, + prefixedCountNode, programNode, + publicKeyTypeNode, rootNode, structFieldTypeNode, structTypeNode, @@ -480,3 +486,212 @@ describe('conditional try_to_vec generation', () => { expect(instruction).toContain('#[derive(Clone, Debug)]'); }); }); + +describe('conditional serde field attributes', () => { + test('it generates cfg_attr serde field attributes when serde is feature-flagged (default)', () => { + // Given an account with a Pubkey field. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ name: 'authority', type: publicKeyTypeNode() }), + structFieldTypeNode({ + name: 'tokens', + type: arrayTypeNode(publicKeyTypeNode(), prefixedCountNode(numberTypeNode('u32'))), + }), + ]), + name: 'myAccount', + }); + + // When we render with default traits (serde is feature-flagged). + const renderMap = visit( + rootNode(programNode({ accounts: [node], name: 'myProgram', publicKey: '1111' })), + getRenderMapVisitor(), + ); + + // Then we expect field attributes to be wrapped in cfg_attr. + const account = renderMap.get('accounts/my_account.rs') as string; + expect(account).toContain( + '#[cfg_attr(feature = "serde", serde(with = "serde_with::As::"))]', + ); + expect(account).toContain( + '#[cfg_attr(feature = "serde", serde(with = "serde_with::As::>"))]', + ); + }); + + test('it generates plain serde field attributes when serde is not feature-flagged', () => { + // Given an account with a Pubkey field. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ name: 'authority', type: publicKeyTypeNode() }), + structFieldTypeNode({ + name: 'tokens', + type: arrayTypeNode(publicKeyTypeNode(), prefixedCountNode(numberTypeNode('u32'))), + }), + ]), + name: 'myAccount', + }); + + // When we render with serde in base defaults but no feature flags. + const renderMap = visit( + rootNode(programNode({ accounts: [node], name: 'myProgram', publicKey: '1111' })), + getRenderMapVisitor({ + traitOptions: { + baseDefaults: [ + 'BorshSerialize', + 'BorshDeserialize', + 'serde::Serialize', + 'serde::Deserialize', + 'Clone', + 'Debug', + 'Eq', + 'PartialEq', + ], + featureFlags: {}, // No feature flags + }, + }), + ); + + // Then we expect field attributes without cfg_attr. + const account = renderMap.get('accounts/my_account.rs') as string; + expect(account).toContain('#[serde(with = "serde_with::As::")]'); + expect(account).toContain('#[serde(with = "serde_with::As::>")]'); + expect(account).not.toContain('#[cfg_attr(feature = "serde"'); + }); + + test('it omits serde field attributes when serde traits are removed', () => { + // Given an account with a Pubkey field. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ name: 'authority', type: publicKeyTypeNode() }), + structFieldTypeNode({ + name: 'tokens', + type: arrayTypeNode(publicKeyTypeNode(), prefixedCountNode(numberTypeNode('u32'))), + }), + ]), + name: 'myAccount', + }); + + // When we render without serde traits. + const renderMap = visit( + rootNode(programNode({ accounts: [node], name: 'myProgram', publicKey: '1111' })), + getRenderMapVisitor({ + traitOptions: { + baseDefaults: ['BorshSerialize', 'BorshDeserialize', 'Clone', 'Debug'], + featureFlags: {}, + }, + }), + ); + + // Then we expect no serde field attributes at all. + const account = renderMap.get('accounts/my_account.rs') as string; + expect(account).not.toContain('serde(with'); + expect(account).not.toContain('serde_with::As'); + expect(account).not.toContain('DisplayFromStr'); + }); + + test('it handles large array serde attributes conditionally', () => { + // Given an account with a large fixed array field. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ + name: 'data', + type: fixedSizeTypeNode(arrayTypeNode(numberTypeNode('u8'), fixedCountNode(64)), 64), + }), + ]), + name: 'myAccount', + }); + + // When we render with default traits (serde is feature-flagged). + const renderMap = visit( + rootNode(programNode({ accounts: [node], name: 'myProgram', publicKey: '1111' })), + getRenderMapVisitor(), + ); + + // Then we expect the big array attribute to be feature-flagged. + const account = renderMap.get('accounts/my_account.rs') as string; + expect(account).toContain('#[cfg_attr(feature = "serde", serde(with = "serde_big_array::BigArray"))]'); + }); + + test('it handles bytes serde attributes conditionally', () => { + // Given an account with a fixed-size bytes field. + const node = accountNode({ + data: structTypeNode([ + structFieldTypeNode({ + name: 'signature', + type: fixedSizeTypeNode(bytesTypeNode(), 64), + }), + ]), + name: 'myAccount', + }); + + // When we render with serde not feature-flagged. + const renderMap = visit( + rootNode(programNode({ accounts: [node], name: 'myProgram', publicKey: '1111' })), + getRenderMapVisitor({ + traitOptions: { + baseDefaults: ['serde::Serialize', 'serde::Deserialize', 'Clone', 'Debug'], + featureFlags: {}, + }, + }), + ); + + // Then we expect the bytes attribute without cfg_attr. + const account = renderMap.get('accounts/my_account.rs') as string; + expect(account).toContain('#[serde(with = "serde_with::As::")]'); + expect(account).not.toContain('#[cfg_attr(feature = "serde"'); + }); + + test('it uses custom feature names for serde field attributes', () => { + // Given an account with a Pubkey field. + const node = accountNode({ + data: structTypeNode([structFieldTypeNode({ name: 'authority', type: publicKeyTypeNode() })]), + name: 'myAccount', + }); + + // When we render with serde under a custom feature name. + const renderMap = visit( + rootNode(programNode({ accounts: [node], name: 'myProgram', publicKey: '1111' })), + getRenderMapVisitor({ + traitOptions: { + baseDefaults: ['BorshSerialize', 'BorshDeserialize', 'serde::Serialize', 'serde::Deserialize'], + featureFlags: { + json_support: ['serde::Serialize', 'serde::Deserialize'], + }, + }, + }), + ); + + // Then we expect field attributes with the custom feature name. + const account = renderMap.get('accounts/my_account.rs') as string; + expect(account).toContain( + '#[cfg_attr(feature = "json_support", serde(with = "serde_with::As::"))]', + ); + expect(account).not.toContain('#[cfg_attr(feature = "serde"'); + }); + + test('it respects overrides for serde field attributes', () => { + // Given an account with a Pubkey field. + const node = accountNode({ + data: structTypeNode([structFieldTypeNode({ name: 'authority', type: publicKeyTypeNode() })]), + name: 'myAccount', + }); + + // When we render with overrides that include serde but not feature-flagged. + const renderMap = visit( + rootNode(programNode({ accounts: [node], name: 'myProgram', publicKey: '1111' })), + getRenderMapVisitor({ + traitOptions: { + baseDefaults: ['Clone', 'Debug'], + featureFlags: {}, + overrides: { + myAccount: ['serde::Serialize', 'serde::Deserialize', 'Clone', 'Debug'], + }, + }, + }), + ); + + // Then we expect field attributes without cfg_attr since override has serde but no feature flag. + const account = renderMap.get('accounts/my_account.rs') as string; + expect(account).toContain('#[serde(with = "serde_with::As::")]'); + expect(account).not.toContain('#[cfg_attr(feature = "serde"'); + }); +});