diff --git a/pkg/code/data/nonce/memory/store.go b/pkg/code/data/nonce/memory/store.go index 5193689b..d2c2d1ce 100644 --- a/pkg/code/data/nonce/memory/store.go +++ b/pkg/code/data/nonce/memory/store.go @@ -6,8 +6,8 @@ import ( "sort" "sync" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/database/query" ) type store struct { @@ -58,27 +58,23 @@ func (s *store) findAddress(address string) *nonce.Record { } func (s *store) findByState(state nonce.State) []*nonce.Record { - res := make([]*nonce.Record, 0) - for _, item := range s.records { - if item.State == state { - res = append(res, item) - } - } - return res + return s.findFn(func(nonce *nonce.Record) bool { + return nonce.State == state + }) } func (s *store) findByStateAndPurpose(state nonce.State, purpose nonce.Purpose) []*nonce.Record { + return s.findFn(func(record *nonce.Record) bool { + return record.State == state && record.Purpose == purpose + }) +} + +func (s *store) findFn(f func(nonce *nonce.Record) bool) []*nonce.Record { res := make([]*nonce.Record, 0) for _, item := range s.records { - if item.State != state { - continue - } - - if item.Purpose != purpose { - continue + if f(item) { + res = append(res, item) } - - res = append(res, item) } return res } @@ -194,7 +190,9 @@ func (s *store) GetRandomAvailableByPurpose(ctx context.Context, purpose nonce.P s.mu.Lock() defer s.mu.Unlock() - items := s.findByStateAndPurpose(nonce.StateAvailable, purpose) + items := s.findFn(func(n *nonce.Record) bool { + return n.Purpose == purpose && n.IsAvailable() + }) if len(items) == 0 { return nil, nonce.ErrNonceNotFound } diff --git a/pkg/code/data/nonce/nonce.go b/pkg/code/data/nonce/nonce.go index 69cec6d6..97f39378 100644 --- a/pkg/code/data/nonce/nonce.go +++ b/pkg/code/data/nonce/nonce.go @@ -3,6 +3,7 @@ package nonce import ( "crypto/ed25519" "errors" + "time" "github.com/mr-tron/base58" ) @@ -20,14 +21,11 @@ const ( StateAvailable // The nonce is available to be used by a payment intent, subscription, or other nonce-related transaction. StateReserved // The nonce is reserved by a payment intent, subscription, or other nonce-related transaction. StateInvalid // The nonce account is invalid (e.g. insufficient funds, etc). + StateClaimed // The nonce is claimed for future use by a process (identified by Node ID). ) -// Split nonce pool across different use cases. This has an added benefit of: -// - Solving for race conditions without distributed locks. -// - Avoiding different use cases from starving each other and ending up in a -// deadlocked state. Concretely, it would be really bad if clients could starve -// internal processes from creating transactions that would allow us to progress -// and submit existing transactions. +// Purpose indicates the intended use purpose of the nonce. By partitioning nonce's by +// purpose, we help prevent various use cases from starving each other. type Purpose uint8 const ( @@ -46,6 +44,17 @@ type Record struct { Purpose Purpose State State + // Contains the NodeId that transitioned the state into StateClaimed. + // + // Should be ignored if State != StateClaimed. + ClaimNodeId string + + // The time at which StateClaimed is no longer valid, and the state should + // be considered StateAvailable. + // + // Should be ignored if State != StateClaimed. + ClaimExpiresAt time.Time + Signature string } @@ -53,15 +62,28 @@ func (r *Record) GetPublicKey() (ed25519.PublicKey, error) { return base58.Decode(r.Address) } +func (r *Record) IsAvailable() bool { + if r.State == StateAvailable { + return true + } + if r.State != StateClaimed { + return false + } + + return time.Now().After(r.ClaimExpiresAt) +} + func (r *Record) Clone() Record { return Record{ - Id: r.Id, - Address: r.Address, - Authority: r.Authority, - Blockhash: r.Blockhash, - Purpose: r.Purpose, - State: r.State, - Signature: r.Signature, + Id: r.Id, + Address: r.Address, + Authority: r.Authority, + Blockhash: r.Blockhash, + Purpose: r.Purpose, + State: r.State, + ClaimNodeId: r.ClaimNodeId, + ClaimExpiresAt: r.ClaimExpiresAt, + Signature: r.Signature, } } @@ -72,21 +94,33 @@ func (r *Record) CopyTo(dst *Record) { dst.Blockhash = r.Blockhash dst.Purpose = r.Purpose dst.State = r.State + dst.ClaimNodeId = r.ClaimNodeId + dst.ClaimExpiresAt = r.ClaimExpiresAt dst.Signature = r.Signature } -func (v *Record) Validate() error { - if len(v.Address) == 0 { +func (r *Record) Validate() error { + if len(r.Address) == 0 { return errors.New("nonce account address is required") } - if len(v.Authority) == 0 { + if len(r.Authority) == 0 { return errors.New("authority address is required") } - if v.Purpose == PurposeUnknown { + if r.Purpose == PurposeUnknown { return errors.New("nonce purpose must be set") } + + if r.State == StateClaimed { + if r.ClaimNodeId == "" { + return errors.New("missing claim node id") + } + if r.ClaimExpiresAt == (time.Time{}) || r.ClaimExpiresAt.IsZero() { + return errors.New("missing claim expiry date") + } + } + return nil } @@ -102,6 +136,8 @@ func (s State) String() string { return "reserved" case StateInvalid: return "invalid" + case StateClaimed: + return "claimed" } return "unknown" diff --git a/pkg/code/data/nonce/postgres/model.go b/pkg/code/data/nonce/postgres/model.go index b7a990b7..87044248 100644 --- a/pkg/code/data/nonce/postgres/model.go +++ b/pkg/code/data/nonce/postgres/model.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "time" "github.com/jmoiron/sqlx" @@ -17,13 +18,15 @@ const ( ) type nonceModel struct { - Id sql.NullInt64 `db:"id"` - Address string `db:"address"` - Authority string `db:"authority"` - Blockhash string `db:"blockhash"` - Purpose uint `db:"purpose"` - State uint `db:"state"` - Signature string `db:"signature"` + Id sql.NullInt64 `db:"id"` + Address string `db:"address"` + Authority string `db:"authority"` + Blockhash string `db:"blockhash"` + Purpose uint `db:"purpose"` + State uint `db:"state"` + Signature string `db:"signature"` + ClaimNodeId string `db:"claim_node_id"` + ClaimExpiresAtMs int64 `db:"claim_expires_at"` } func toNonceModel(obj *nonce.Record) (*nonceModel, error) { @@ -32,33 +35,37 @@ func toNonceModel(obj *nonce.Record) (*nonceModel, error) { } return &nonceModel{ - Id: sql.NullInt64{Int64: int64(obj.Id), Valid: true}, - Address: obj.Address, - Authority: obj.Authority, - Blockhash: obj.Blockhash, - Purpose: uint(obj.Purpose), - State: uint(obj.State), - Signature: obj.Signature, + Id: sql.NullInt64{Int64: int64(obj.Id), Valid: true}, + Address: obj.Address, + Authority: obj.Authority, + Blockhash: obj.Blockhash, + Purpose: uint(obj.Purpose), + State: uint(obj.State), + Signature: obj.Signature, + ClaimNodeId: obj.ClaimNodeId, + ClaimExpiresAtMs: obj.ClaimExpiresAt.UnixMilli(), }, nil } func fromNonceModel(obj *nonceModel) *nonce.Record { return &nonce.Record{ - Id: uint64(obj.Id.Int64), - Address: obj.Address, - Authority: obj.Authority, - Blockhash: obj.Blockhash, - Purpose: nonce.Purpose(obj.Purpose), - State: nonce.State(obj.State), - Signature: obj.Signature, + Id: uint64(obj.Id.Int64), + Address: obj.Address, + Authority: obj.Authority, + Blockhash: obj.Blockhash, + Purpose: nonce.Purpose(obj.Purpose), + State: nonce.State(obj.State), + Signature: obj.Signature, + ClaimNodeId: obj.ClaimNodeId, + ClaimExpiresAt: time.UnixMilli(obj.ClaimExpiresAtMs), } } func (m *nonceModel) dbSave(ctx context.Context, db *sqlx.DB) error { return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { query := `INSERT INTO ` + nonceTableName + ` - (address, authority, blockhash, purpose, state, signature) - VALUES ($1,$2,$3,$4,$5,$6) + (address, authority, blockhash, purpose, state, signature, claim_node_id, claim_expires_at) + VALUES ($1,$2,$3,$4,$5,$6,$7,$8) ON CONFLICT (address) DO UPDATE SET blockhash = $3, state = $5, signature = $6 @@ -75,6 +82,8 @@ func (m *nonceModel) dbSave(ctx context.Context, db *sqlx.DB) error { m.Purpose, m.State, m.Signature, + m.ClaimNodeId, + m.ClaimExpiresAtMs, ).StructScan(m) return pgutil.CheckNoRows(err, nonce.ErrInvalidNonce) @@ -162,48 +171,77 @@ func dbGetAllByState(ctx context.Context, db *sqlx.DB, state nonce.State, cursor return res, nil } -// todo: Implementation still isn't perfect, but better than no randomness. It's -// sufficiently efficient, as long as our nonce pool is larger than the max offset. -// todo: We may need to tune the offset based on pool size and environment, but it -// should be sufficiently good enough for now. +// We query a random nonce by first selecting any available candidate from the +// total set, applying an upper limit of 100, and _then_ randomly shuffling the +// results and selecting the first. By bounding the size before ORDER BY random(), +// we avoid having to shuffle large sets of results. +// +// Previously, we would use OFFSET FLOOR(RANDOM() * 100). However, if the pool +// (post filter) size was less than 100, any selection > pool size would result +// in the OFFSET being set to zero. This meant random() disappeared for a subset +// of values. In practice, this would result in a bias, and increased contention. +// +// For example, 50 Available nonce's, 25 Claimed (expired), 25 Reserved. With Offset: +// +// 1. 50% of the time would be a random Available. +// 2. 25% of the time would be a random expired Claimed. +// 3. 25% of the time would be _the first_ Available. +// +// This meant that 25% of the time would not be random. As we pull from the pool, +// this % only increases, further causing contention. +// +// Performance wise, this approach is slightly worse, but the vast majority of the +// time is spent on the scan and filter. Below are two example query plans (from a +// small dataset in an online editor). +// +// QUERY PLAN (OFFSET): +// +// Limit (cost=17.80..35.60 rows=1 width=140) (actual time=0.019..0.019 rows=0 loops=1) +// -> Seq Scan on codewallet__core_nonce (cost=0.00..17.80 rows=1 width=140) (actual time=0.016..0.017 rows=0 loops=1) +// Filter: ((signature IS NOT NULL) AND (purpose = 1) AND ((state = 0) OR ((state = 2) AND (claim_expires_at < 200)))) +// Rows Removed by Filter: 100 +// +// Planning Time: 0.046 ms +// Execution Time: 0.031 ms +// +// QUERY PLAN (ORDER BY): +// +// Limit (cost=17.82..17.83 rows=1 width=148) (actual time=0.018..0.019 rows=0 loops=1) +// -> Sort (cost=17.82..17.83 rows=1 width=148) (actual time=0.018..0.018 rows=0 loops=1) +// Sort Key: (random()) +// Sort Method: quicksort Memory: 25kB +// -> Subquery Scan on sub (cost=0.00..17.81 rows=1 width=148) (actual time=0.015..0.016 rows=0 loops=1) +// -> Limit (cost=0.00..17.80 rows=1 width=140) (actual time=0.015..0.015 rows=0 loops=1) +// -> Seq Scan on codewallet__core_nonce (cost=0.00..17.80 rows=1 width=140) (actual time=0.015..0.015 rows=0 loops=1) +// Filter: ((signature IS NOT NULL) AND (purpose = 1) AND ((state = 0) OR ((state = 2) AND (claim_expires_at < 200)))) +// Rows Removed by Filter: 100 +// +// Planning Time: 0.068 ms +// Execution Time: 0.037 ms +// +// Overall, the Seq Scan and Filter is the bulk of the work, with the ORDER BY RANDOM() +// adding a small (fixed) amount of overhead. The trade-off is negligible time complexity +// for more reliable semantics. func dbGetRandomAvailableByPurpose(ctx context.Context, db *sqlx.DB, purpose nonce.Purpose) (*nonceModel, error) { res := &nonceModel{} + nowMs := time.Now().UnixMilli() // Signature null check is required because some legacy records didn't have this // set and causes this call to fail. This is a result of the field not being // defined at the time of record creation. // // todo: Fix said nonce records - query := `SELECT - id, address, authority, blockhash, purpose, state, signature - FROM ` + nonceTableName + ` - WHERE state = $1 AND purpose = $2 AND signature IS NOT NULL - OFFSET FLOOR(RANDOM() * 100) - LIMIT 1 - ` - fallbackQuery := `SELECT - id, address, authority, blockhash, purpose, state, signature - FROM ` + nonceTableName + ` - WHERE state = $1 AND purpose = $2 AND signature IS NOT NULL + query := ` + SELECT id, address, authority, blockhash, purpose, state, signature FROM ( + SELECT id, address, authority, blockhash, purpose, state, signature + FROM ` + nonceTableName + ` + WHERE ((state = $1) OR (state = $2 AND claim_expires_at < $3)) AND purpose = $4 AND signature IS NOT NULL + LIMIT 100 + ) sub + ORDER BY random() LIMIT 1 ` - err := db.GetContext(ctx, res, query, nonce.StateAvailable, purpose) - if err != nil { - err = pgutil.CheckNoRows(err, nonce.ErrNonceNotFound) - - // No nonces found. Because our query isn't perfect, fall back to a - // strategy that will guarantee to select something if an available - // nonce exists. - if err == nonce.ErrNonceNotFound { - err := db.GetContext(ctx, res, fallbackQuery, nonce.StateAvailable, purpose) - if err != nil { - return nil, pgutil.CheckNoRows(err, nonce.ErrNonceNotFound) - } - return res, nil - } - - return nil, err - } - return res, nil + err := db.GetContext(ctx, res, query, nonce.StateAvailable, nonce.StateClaimed, nowMs, purpose) + return res, pgutil.CheckNoRows(err, nonce.ErrNonceNotFound) } diff --git a/pkg/code/data/nonce/postgres/store_test.go b/pkg/code/data/nonce/postgres/store_test.go index bfc3b46f..70ef5ee6 100644 --- a/pkg/code/data/nonce/postgres/store_test.go +++ b/pkg/code/data/nonce/postgres/store_test.go @@ -22,13 +22,16 @@ const ( CREATE TABLE codewallet__core_nonce( id SERIAL NOT NULL PRIMARY KEY, - address text NOT NULL UNIQUE, + address text NOT NULL UNIQUE, authority text NOT NULL, blockhash text NULL, purpose integer NOT NULL, state integer NOT NULL, - signature text NULL + signature text NULL, + + claim_node_id text NULL, + claim_expires_at bigint NULL ); ` diff --git a/pkg/code/data/nonce/tests/tests.go b/pkg/code/data/nonce/tests/tests.go index 8eee5cff..3f6bbbb1 100644 --- a/pkg/code/data/nonce/tests/tests.go +++ b/pkg/code/data/nonce/tests/tests.go @@ -3,10 +3,13 @@ package tests import ( "context" "fmt" + "strconv" + "strings" "testing" + "time" - "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/database/query" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,6 +18,7 @@ func RunTests(t *testing.T, s nonce.Store, teardown func()) { for _, tf := range []func(t *testing.T, s nonce.Store){ testRoundTrip, testUpdate, + testUpdateInvalid, testGetAllByState, testGetCount, testGetRandomAvailableByPurpose, @@ -80,6 +84,52 @@ func testUpdate(t *testing.T, s nonce.Store) { assert.EqualValues(t, 1, actual.Id) } +func testUpdateInvalid(t *testing.T, s nonce.Store) { + ctx := context.Background() + + for _, invalid := range []*nonce.Record{ + {}, + { + Address: "test_address", + }, + { + Address: "test_address", + Authority: "test_authority", + }, + { + Address: "test_address", + Authority: "test_authority", + Blockhash: "block_hash", + }, + { + Address: "test_address", + Authority: "test_authority", + Blockhash: "test_blockhash", + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateClaimed, + }, + { + Address: "test_address", + Authority: "test_authority", + Blockhash: "test_blockhash", + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateClaimed, + ClaimNodeId: "my-node", + }, + { + Address: "test_address", + Authority: "test_authority", + Blockhash: "test_blockhash", + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateClaimed, + ClaimExpiresAt: time.Now().Add(time.Hour), + }, + } { + require.Error(t, invalid.Validate()) + assert.Error(t, s.Save(ctx, invalid)) + } +} + func testGetAllByState(t *testing.T, s nonce.Store) { ctx := context.Background() @@ -260,8 +310,9 @@ func testGetRandomAvailableByPurpose(t *testing.T, s nonce.Store) { nonce.StateUnknown, nonce.StateAvailable, nonce.StateReserved, + nonce.StateClaimed, } { - for i := 0; i < 500; i++ { + for i := 0; i < 50; i++ { record := &nonce.Record{ Address: fmt.Sprintf("nonce_%s_%s_%d", purpose, state, i), Authority: "authority", @@ -270,27 +321,83 @@ func testGetRandomAvailableByPurpose(t *testing.T, s nonce.Store) { State: state, Signature: "", } + if state == nonce.StateClaimed { + record.ClaimNodeId = "my-node-id" + + if i < 25 { + record.ClaimExpiresAt = time.Now().Add(-time.Hour) + } else { + record.ClaimExpiresAt = time.Now().Add(time.Hour) + } + } + require.NoError(t, s.Save(ctx, record)) } } } + var sequentialLoads int + var availableState, claimedState int + var lastNonce *nonce.Record selectedByAddress := make(map[string]struct{}) - for i := 0; i < 100; i++ { + for i := 0; i < 1000; i++ { actual, err := s.GetRandomAvailableByPurpose(ctx, nonce.PurposeClientTransaction) require.NoError(t, err) assert.Equal(t, nonce.PurposeClientTransaction, actual.Purpose) - assert.Equal(t, nonce.StateAvailable, actual.State) + assert.True(t, actual.IsAvailable()) + + switch actual.State { + case nonce.StateAvailable: + availableState++ + case nonce.StateClaimed: + claimedState++ + assert.True(t, time.Now().After(actual.ClaimExpiresAt)) + default: + } + + // We test for randomness by ensuring we're not loading nonce's sequentially. + if lastNonce != nil && lastNonce.Purpose == actual.Purpose { + lastID, err := strconv.ParseInt(strings.Split(lastNonce.Address, "_")[4], 10, 64) + require.NoError(t, err) + currentID, _ := strconv.ParseInt(strings.Split(actual.Address, "_")[4], 10, 64) + require.NoError(t, err) + + if currentID == lastID+1 { + sequentialLoads++ + } + } + selectedByAddress[actual.Address] = struct{}{} + lastNonce = actual } - assert.True(t, len(selectedByAddress) > 10) + assert.Greater(t, len(selectedByAddress), 10) + assert.NotZero(t, availableState) + assert.NotZero(t, claimedState) + + // We allocated 50 available nonce's, and 25 expired claim nonces. Given that + // we randomly select out of the first available 100 nonces, we expect a ratio + // of 2:1 Available vs Expired Claimed nonces. + assert.InDelta(t, 2.0, float64(availableState)/float64(claimedState), 0.5) + + assert.Less(t, sequentialLoads, 100) + availableState, claimedState = 0, 0 selectedByAddress = make(map[string]struct{}) for i := 0; i < 100; i++ { actual, err := s.GetRandomAvailableByPurpose(ctx, nonce.PurposeInternalServerProcess) require.NoError(t, err) assert.Equal(t, nonce.PurposeInternalServerProcess, actual.Purpose) - assert.Equal(t, nonce.StateAvailable, actual.State) + assert.True(t, actual.IsAvailable()) + + switch actual.State { + case nonce.StateAvailable: + availableState++ + case nonce.StateClaimed: + claimedState++ + assert.True(t, time.Now().After(actual.ClaimExpiresAt)) + default: + } + selectedByAddress[actual.Address] = struct{}{} } assert.True(t, len(selectedByAddress) > 10) diff --git a/pkg/code/transaction/nonce.go b/pkg/code/transaction/nonce.go index 1f6fc7f3..6bfca0a1 100644 --- a/pkg/code/transaction/nonce.go +++ b/pkg/code/transaction/nonce.go @@ -8,12 +8,12 @@ import ( "github.com/mr-tron/base58" - "github.com/code-payments/code-server/pkg/retry" - "github.com/code-payments/code-server/pkg/solana" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/nonce" + "github.com/code-payments/code-server/pkg/retry" + "github.com/code-payments/code-server/pkg/solana" ) var ( @@ -66,7 +66,7 @@ func SelectAvailableNonce(ctx context.Context, data code_data.Provider, useCase defer globalNonceLock.Unlock() randomRecord, err := data.GetRandomAvailableNonceByPurpose(ctx, useCase) - if err == nonce.ErrNonceNotFound { + if errors.Is(err, nonce.ErrNonceNotFound) { return ErrNoAvailableNonces } else if err != nil { return err @@ -77,14 +77,14 @@ func SelectAvailableNonce(ctx context.Context, data code_data.Provider, useCase lock = getNonceLock(record.Address) lock.Lock() - // Refetch because the state could have changed by the time we got the lock + // Re-fetch because the state could have changed by the time we got the lock record, err = data.GetNonce(ctx, record.Address) if err != nil { lock.Unlock() return err } - if record.State != nonce.StateAvailable { + if !record.IsAvailable() { // Unlock and try again lock.Unlock() return errors.New("selected nonce that became unavailable") @@ -105,6 +105,8 @@ func SelectAvailableNonce(ctx context.Context, data code_data.Provider, useCase // Reserve the nonce for use with a fulfillment record.State = nonce.StateReserved + record.ClaimNodeId = "" + record.ClaimExpiresAt = time.UnixMilli(0) err = data.SaveNonce(ctx, record) if err != nil { lock.Unlock() @@ -212,12 +214,15 @@ func (n *SelectedNonce) MarkReservedWithSignature(ctx context.Context, sig strin return n.data.SaveNonce(ctx, n.record) } - if n.record.State != nonce.StateAvailable { + if !n.record.IsAvailable() { return errors.New("nonce must be available to reserve") } n.record.State = nonce.StateReserved n.record.Signature = sig + n.record.ClaimNodeId = "" + n.record.ClaimExpiresAt = time.UnixMilli(0) + return n.data.SaveNonce(ctx, n.record) } @@ -250,8 +255,8 @@ func (n *SelectedNonce) UpdateSignature(ctx context.Context, sig string) error { // ReleaseIfNotReserved makes a nonce available if it hasn't been reserved with // a signature. It's recommended to call this in tandem with Unlock when the -// caller knows it's safe to go from the reserved to available state (ie. don't -// use this in uprade flows!). +// caller knows it's safe to go from the reserved to available state (i.e. don't +// use this in upgrade flows!). func (n *SelectedNonce) ReleaseIfNotReserved() error { n.localLock.Lock() defer n.localLock.Unlock() @@ -264,6 +269,12 @@ func (n *SelectedNonce) ReleaseIfNotReserved() error { return nil } + if n.record.State == nonce.StateClaimed { + n.record.State = nonce.StateAvailable + n.record.ClaimNodeId = "" + n.record.ClaimExpiresAt = time.UnixMilli(0) + } + // A nonce is not fully reserved if it's state is reserved, but there is no // assigned signature. if n.record.State == nonce.StateReserved && len(n.record.Signature) == 0 { diff --git a/pkg/code/transaction/nonce_test.go b/pkg/code/transaction/nonce_test.go index ca8d7f28..86c19e70 100644 --- a/pkg/code/transaction/nonce_test.go +++ b/pkg/code/transaction/nonce_test.go @@ -10,14 +10,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/solana" - "github.com/code-payments/code-server/pkg/testutil" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/nonce" "github.com/code-payments/code-server/pkg/code/data/vault" + "github.com/code-payments/code-server/pkg/pointer" + "github.com/code-payments/code-server/pkg/solana" + "github.com/code-payments/code-server/pkg/testutil" ) func TestNonce_SelectAvailableNonce(t *testing.T) { @@ -62,6 +62,32 @@ func TestNonce_SelectAvailableNonce(t *testing.T) { assert.Equal(t, ErrNoAvailableNonces, err) } +func TestNonce_SelectAvailableNonceClaimed(t *testing.T) { + env := setupNonceTestEnv(t) + + // Should ignore (non-expired) claimed nonces. + expiredNonces := map[string]*nonce.Record{} + for i := 0; i < 10; i++ { + n := generateClaimedNonce(t, env, true) + expiredNonces[n.Address] = n + + generateClaimedNonce(t, env, false) + } + + for i := 0; i < 10; i++ { + nonce, err := SelectAvailableNonce(env.ctx, env.data, nonce.PurposeClientTransaction) + require.NoError(t, err) + + _, ok := expiredNonces[nonce.Account.PublicKey().ToBase58()] + require.True(t, ok) + require.True(t, nonce.record.ClaimExpiresAt.Before(time.Now())) + delete(expiredNonces, nonce.Account.PublicKey().ToBase58()) + } + + _, err := SelectAvailableNonce(env.ctx, env.data, nonce.PurposeInternalServerProcess) + require.ErrorIs(t, ErrNoAvailableNonces, err) +} + func TestNonce_SelectNonceFromFulfillmentToUpgrade_HappyPath(t *testing.T) { env := setupNonceTestEnv(t) @@ -238,7 +264,7 @@ func generateAvailableNonce(t *testing.T, env nonceTestEnv, useCase nonce.Purpos Address: nonceAccount.PublicKey().ToBase58(), Authority: common.GetSubsidizer().PublicKey().ToBase58(), Blockhash: base58.Encode(bh[:]), - Purpose: nonce.PurposeClientTransaction, + Purpose: useCase, State: nonce.StateAvailable, } require.NoError(t, env.data.SaveKey(env.ctx, nonceKey)) @@ -246,6 +272,37 @@ func generateAvailableNonce(t *testing.T, env nonceTestEnv, useCase nonce.Purpos return nonceRecord } +func generateClaimedNonce(t *testing.T, env nonceTestEnv, expired bool) *nonce.Record { + nonceAccount := testutil.NewRandomAccount(t) + + var bh solana.Blockhash + rand.Read(bh[:]) + + nonceKey := &vault.Record{ + PublicKey: nonceAccount.PublicKey().ToBase58(), + PrivateKey: nonceAccount.PrivateKey().ToBase58(), + State: vault.StateAvailable, + CreatedAt: time.Now(), + } + nonceRecord := &nonce.Record{ + Address: nonceAccount.PublicKey().ToBase58(), + Authority: common.GetSubsidizer().PublicKey().ToBase58(), + Blockhash: base58.Encode(bh[:]), + Purpose: nonce.PurposeClientTransaction, + State: nonce.StateClaimed, + ClaimNodeId: "my-node-id", + } + if expired { + nonceRecord.ClaimExpiresAt = time.Now().Add(-time.Hour) + } else { + nonceRecord.ClaimExpiresAt = time.Now().Add(time.Hour) + } + + require.NoError(t, env.data.SaveKey(env.ctx, nonceKey)) + require.NoError(t, env.data.SaveNonce(env.ctx, nonceRecord)) + return nonceRecord +} + func generateAvailableNonces(t *testing.T, env nonceTestEnv, useCase nonce.Purpose, count int) []*nonce.Record { var nonces []*nonce.Record for i := 0; i < count; i++ { diff --git a/pkg/database/postgres/errors.go b/pkg/database/postgres/errors.go index 38f8155a..a61e286c 100644 --- a/pkg/database/postgres/errors.go +++ b/pkg/database/postgres/errors.go @@ -2,14 +2,14 @@ package pg import ( "database/sql" + "errors" "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" - "github.com/pkg/errors" ) func CheckNoRows(inErr, outErr error) error { - if inErr == sql.ErrNoRows { + if errors.Is(inErr, sql.ErrNoRows) { return outErr } return inErr