Skip to content

Commit

Permalink
feat: Recursively derive seeds and add custom account resolver (#2194)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChewingGlass authored Sep 21, 2022
1 parent 2a07d84 commit 436791b
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 35 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ The minor version will be incremented upon a breaking change and the patch versi
* lang: Add parsing for consts from impl blocks for IDL PDA seeds generation ([#2128](https://github.com/coral-xyz/anchor/pull/2014))
* lang: Account closing reassigns to system program and reallocates ([#2169](https://github.com/coral-xyz/anchor/pull/2169)).
* ts: Add coders for SPL programs ([#2143](https://github.com/coral-xyz/anchor/pull/2143)).
* ts: Add `has_one` relations inference so accounts mapped via has_one relationships no longer need to be provided
* ts: Add ability to set args after setting accounts and retriving pubkyes
* ts: Add `.prepare()` to builder pattern
* ts: Add `has_one` relations inference so accounts mapped via has_one relationships no longer need to be provided ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
* ts: Add ability to set args after setting accounts and retrieving pubkyes ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
* ts: Add `.prepare()` to builder pattern ([#2160](https://github.com/coral-xyz/anchor/pull/2160))
* spl: Add `freeze_delegated_account` and `thaw_delegated_account` wrappers ([#2164](https://github.com/coral-xyz/anchor/pull/2164))
* ts: Add nested PDA inference ([#2194](https://github.com/coral-xyz/anchor/pull/2194))
* ts: Add ability to resolve missing accounts with a custom resolver ([#2194](https://github.com/coral-xyz/anchor/pull/2194))

### Fixes

Expand Down
19 changes: 19 additions & 0 deletions tests/pda-derivation/programs/pda-derivation/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,30 @@ pub struct InitMyAccount<'info> {
bump,
)]
account: Account<'info, MyAccount>,
nested: Nested<'info>,
#[account(mut)]
payer: Signer<'info>,
system_program: Program<'info, System>,
}

#[derive(Accounts)]
pub struct Nested<'info> {
#[account(
seeds = [
"nested-seed".as_bytes(),
b"test".as_ref(),
MY_SEED.as_ref(),
MY_SEED_STR.as_bytes(),
MY_SEED_U8.to_le_bytes().as_ref(),
&MY_SEED_U32.to_le_bytes(),
&MY_SEED_U64.to_le_bytes(),
],
bump,
)]
/// CHECK: Not needed
account_nested: AccountInfo<'info>,
}

#[account]
pub struct MyAccount {
data: u64,
Expand Down
27 changes: 27 additions & 0 deletions tests/pda-derivation/tests/typescript.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,31 @@ describe("typescript", () => {
.data;
expect(actualData.toNumber()).is.equal(1337);
});

it("should allow custom resolvers", async () => {
let called = false;
const customProgram = new Program<PdaDerivation>(
program.idl,
program.programId,
program.provider,
program.coder,
(instruction) => {
if (instruction.name === "initMyAccount") {
return async ({ accounts }) => {
called = true;
return accounts;
};
}
}
);
await customProgram.methods
.initMyAccount(seedA)
.accounts({
base: base.publicKey,
base2: base.publicKey,
})
.pubkeys();

expect(called).is.true;
});
});
78 changes: 57 additions & 21 deletions ts/packages/anchor/src/program/accounts-resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ import { BorshAccountsCoder } from "src/coder/index.js";

type Accounts = { [name: string]: PublicKey | Accounts };

export type CustomAccountResolver<IDL extends Idl> = (params: {
args: Array<any>;
accounts: Accounts;
provider: Provider;
programId: PublicKey;
idlIx: AllInstructions<IDL>;
}) => Promise<Accounts>;

// Populates a given accounts context with PDAs and common missing accounts.
export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
_args: Array<any>;
Expand All @@ -35,7 +43,8 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
private _provider: Provider,
private _programId: PublicKey,
private _idlIx: AllInstructions<IDL>,
_accountNamespace: AccountNamespace<IDL>
_accountNamespace: AccountNamespace<IDL>,
private _customResolver?: CustomAccountResolver<IDL>
) {
this._args = _args;
this._accountStore = new AccountStore(_provider, _accountNamespace);
Expand Down Expand Up @@ -84,25 +93,22 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
}
}

for (let k = 0; k < this._idlIx.accounts.length; k += 1) {
// Cast is ok because only a non-nested IdlAccount can have a seeds
// cosntraint.
const accountDesc = this._idlIx.accounts[k] as IdlAccount;
const accountDescName = camelCase(accountDesc.name);

// PDA derived from IDL seeds.
if (
accountDesc.pda &&
accountDesc.pda.seeds.length > 0 &&
!this._accounts[accountDescName]
) {
await this.autoPopulatePda(accountDesc);
continue;
}
// Auto populate pdas and relations until we stop finding new accounts
while (
(await this.resolvePdas(this._idlIx.accounts)) +
(await this.resolveRelations(this._idlIx.accounts)) >
0
) {}

if (this._customResolver) {
this._accounts = await this._customResolver({
args: this._args,
accounts: this._accounts,
provider: this._provider,
programId: this._programId,
idlIx: this._idlIx,
});
}

// Auto populate has_one relationships until we stop finding new accounts
while ((await this.resolveRelations(this._idlIx.accounts)) > 0) {}
}

private get(path: string[]): PublicKey | undefined {
Expand Down Expand Up @@ -130,6 +136,36 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
});
}

private async resolvePdas(
accounts: IdlAccountItem[],
path: string[] = []
): Promise<number> {
let found = 0;
for (let k = 0; k < accounts.length; k += 1) {
const accountDesc = accounts[k];
const subAccounts = (accountDesc as IdlAccounts).accounts;
if (subAccounts) {
found += await this.resolvePdas(subAccounts, [
...path,
accountDesc.name,
]);
}

const accountDescCasted: IdlAccount = accountDesc as IdlAccount;
const accountDescName = camelCase(accountDesc.name);
// PDA derived from IDL seeds.
if (
accountDescCasted.pda &&
accountDescCasted.pda.seeds.length > 0 &&
!this.get([...path, accountDescName])
) {
await this.autoPopulatePda(accountDescCasted, path);
found += 1;
}
}
return found;
}

private async resolveRelations(
accounts: IdlAccountItem[],
path: string[] = []
Expand Down Expand Up @@ -172,7 +208,7 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
return found;
}

private async autoPopulatePda(accountDesc: IdlAccount) {
private async autoPopulatePda(accountDesc: IdlAccount, path: string[] = []) {
if (!accountDesc.pda || !accountDesc.pda.seeds)
throw new Error("Must have seeds");

Expand All @@ -183,7 +219,7 @@ export class AccountsResolver<IDL extends Idl, I extends AllInstructions<IDL>> {
const programId = await this.parseProgramId(accountDesc);
const [pubkey] = await PublicKey.findProgramAddress(seeds, programId);

this._accounts[camelCase(accountDesc.name)] = pubkey;
this.set([...path, camelCase(accountDesc.name)], pubkey);
}

private async parseProgramId(accountDesc: IdlAccount): Promise<PublicKey> {
Expand Down
19 changes: 16 additions & 3 deletions ts/packages/anchor/src/program/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { inflate } from "pako";
import { PublicKey } from "@solana/web3.js";
import Provider, { getProvider } from "../provider.js";
import { Idl, idlAddress, decodeIdlAccount } from "../idl.js";
import { Idl, idlAddress, decodeIdlAccount, IdlInstruction } from "../idl.js";
import { Coder, BorshCoder } from "../coder/index.js";
import NamespaceFactory, {
RpcNamespace,
Expand All @@ -16,6 +16,7 @@ import NamespaceFactory, {
import { utf8 } from "../utils/bytes/index.js";
import { EventManager } from "./event.js";
import { Address, translateAddress } from "./common.js";
import { CustomAccountResolver } from "./accounts-resolver.js";

export * from "./common.js";
export * from "./context.js";
Expand Down Expand Up @@ -263,12 +264,18 @@ export class Program<IDL extends Idl = Idl> {
* @param programId The on-chain address of the program.
* @param provider The network and wallet context to use. If not provided
* then uses [[getProvider]].
* @param getCustomResolver A function that returns a custom account resolver
* for the given instruction. This is useful for resolving
* public keys of missing accounts when building instructions
*/
public constructor(
idl: IDL,
programId: Address,
provider?: Provider,
coder?: Coder
coder?: Coder,
getCustomResolver?: (
instruction: IdlInstruction
) => CustomAccountResolver<IDL> | undefined
) {
programId = translateAddress(programId);

Expand All @@ -293,7 +300,13 @@ export class Program<IDL extends Idl = Idl> {
methods,
state,
views,
] = NamespaceFactory.build(idl, this._coder, programId, provider);
] = NamespaceFactory.build(
idl,
this._coder,
programId,
provider,
getCustomResolver ?? (() => undefined)
);
this.rpc = rpc;
this.instruction = instruction;
this.transaction = transaction;
Expand Down
11 changes: 8 additions & 3 deletions ts/packages/anchor/src/program/namespace/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import camelCase from "camelcase";
import { PublicKey } from "@solana/web3.js";
import { Coder } from "../../coder/index.js";
import Provider from "../../provider.js";
import { Idl } from "../../idl.js";
import { Idl, IdlInstruction } from "../../idl.js";
import StateFactory, { StateClient } from "./state.js";
import InstructionFactory, { InstructionNamespace } from "./instruction.js";
import TransactionFactory, { TransactionNamespace } from "./transaction.js";
Expand All @@ -12,6 +12,7 @@ import SimulateFactory, { SimulateNamespace } from "./simulate.js";
import { parseIdlErrors } from "../common.js";
import { MethodsBuilderFactory, MethodsNamespace } from "./methods";
import ViewFactory, { ViewNamespace } from "./views";
import { CustomAccountResolver } from "../accounts-resolver.js";

// Re-exports.
export { StateClient } from "./state.js";
Expand All @@ -32,7 +33,10 @@ export default class NamespaceFactory {
idl: IDL,
coder: Coder,
programId: PublicKey,
provider: Provider
provider: Provider,
getCustomResolver?: (
instruction: IdlInstruction
) => CustomAccountResolver<IDL> | undefined
): [
RpcNamespace<IDL>,
InstructionNamespace<IDL>,
Expand Down Expand Up @@ -85,7 +89,8 @@ export default class NamespaceFactory {
rpcItem,
simulateItem,
viewItem,
account
account,
getCustomResolver && getCustomResolver(idlIx)
);
const name = camelCase(idlIx.name);

Expand Down
17 changes: 12 additions & 5 deletions ts/packages/anchor/src/program/namespace/methods.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ import { SimulateFn } from "./simulate.js";
import { ViewFn } from "./views.js";
import Provider from "../../provider.js";
import { AccountNamespace } from "./account.js";
import { AccountsResolver } from "../accounts-resolver.js";
import {
AccountsResolver,
CustomAccountResolver,
} from "../accounts-resolver.js";
import { Accounts } from "../context.js";

export type MethodsNamespace<
Expand All @@ -40,7 +43,8 @@ export class MethodsBuilderFactory {
rpcFn: RpcFn<IDL>,
simulateFn: SimulateFn<IDL>,
viewFn: ViewFn<IDL> | undefined,
accountNamespace: AccountNamespace<IDL>
accountNamespace: AccountNamespace<IDL>,
customResolver?: CustomAccountResolver<IDL>
): MethodsFn<IDL, I, MethodsBuilder<IDL, I>> {
return (...args) =>
new MethodsBuilder(
Expand All @@ -53,7 +57,8 @@ export class MethodsBuilderFactory {
provider,
programId,
idlIx,
accountNamespace
accountNamespace,
customResolver
);
}
}
Expand All @@ -78,7 +83,8 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
_provider: Provider,
_programId: PublicKey,
_idlIx: AllInstructions<IDL>,
_accountNamespace: AccountNamespace<IDL>
_accountNamespace: AccountNamespace<IDL>,
_customResolver?: CustomAccountResolver<IDL>
) {
this._args = _args;
this._accountsResolver = new AccountsResolver(
Expand All @@ -87,7 +93,8 @@ export class MethodsBuilder<IDL extends Idl, I extends AllInstructions<IDL>> {
_provider,
_programId,
_idlIx,
_accountNamespace
_accountNamespace,
_customResolver
);
}

Expand Down

0 comments on commit 436791b

Please sign in to comment.