diff --git a/pkg/code/data/balance/checkpoint.go b/pkg/code/data/balance/checkpoint.go new file mode 100644 index 00000000..ca87a7e2 --- /dev/null +++ b/pkg/code/data/balance/checkpoint.go @@ -0,0 +1,47 @@ +package balance + +import ( + "errors" + "time" +) + +// Note: Only supports external balances +type Record struct { + Id uint64 + + TokenAccount string + Quarks uint64 + SlotCheckpoint uint64 + + LastUpdatedAt time.Time +} + +func (r *Record) Validate() error { + if len(r.TokenAccount) == 0 { + return errors.New("token account is required") + } + + return nil +} + +func (r *Record) Clone() Record { + return Record{ + Id: r.Id, + + TokenAccount: r.TokenAccount, + Quarks: r.Quarks, + SlotCheckpoint: r.SlotCheckpoint, + + LastUpdatedAt: r.LastUpdatedAt, + } +} + +func (r *Record) CopyTo(dst *Record) { + dst.Id = r.Id + + dst.TokenAccount = r.TokenAccount + dst.Quarks = r.Quarks + dst.SlotCheckpoint = r.SlotCheckpoint + + dst.LastUpdatedAt = r.LastUpdatedAt +} diff --git a/pkg/code/data/balance/memory/store.go b/pkg/code/data/balance/memory/store.go new file mode 100644 index 00000000..706f28ad --- /dev/null +++ b/pkg/code/data/balance/memory/store.go @@ -0,0 +1,92 @@ +package memory + +import ( + "context" + "sync" + "time" + + "github.com/code-payments/code-server/pkg/code/data/balance" +) + +type store struct { + mu sync.Mutex + records []*balance.Record + last uint64 +} + +// New returns a new in memory balance.Store +func New() balance.Store { + return &store{} +} + +// SaveCheckpoint implements balance.Store.SaveCheckpoint +func (s *store) SaveCheckpoint(_ context.Context, data *balance.Record) error { + if err := data.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.last++ + if item := s.find(data); item != nil { + if data.SlotCheckpoint <= item.SlotCheckpoint { + return balance.ErrStaleCheckpoint + } + + item.SlotCheckpoint = data.SlotCheckpoint + item.Quarks = data.Quarks + item.LastUpdatedAt = time.Now() + item.CopyTo(data) + } else { + if data.Id == 0 { + data.Id = s.last + } + data.LastUpdatedAt = time.Now() + c := data.Clone() + s.records = append(s.records, &c) + } + + return nil +} + +// GetCheckpoint implements balance.Store.GetCheckpoint +func (s *store) GetCheckpoint(_ context.Context, account string) (*balance.Record, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if item := s.findByTokenAccount(account); item != nil { + cloned := item.Clone() + return &cloned, nil + } + return nil, balance.ErrCheckpointNotFound +} + +func (s *store) find(data *balance.Record) *balance.Record { + for _, item := range s.records { + if item.Id == data.Id { + return item + } + if data.TokenAccount == item.TokenAccount { + return item + } + } + return nil +} + +func (s *store) findByTokenAccount(account string) *balance.Record { + for _, item := range s.records { + if account == item.TokenAccount { + return item + } + } + return nil +} + +func (s *store) reset() { + s.mu.Lock() + defer s.mu.Unlock() + + s.records = nil + s.last = 0 +} diff --git a/pkg/code/data/balance/memory/store_test.go b/pkg/code/data/balance/memory/store_test.go new file mode 100644 index 00000000..078464a2 --- /dev/null +++ b/pkg/code/data/balance/memory/store_test.go @@ -0,0 +1,15 @@ +package memory + +import ( + "testing" + + "github.com/code-payments/code-server/pkg/code/data/balance/tests" +) + +func TestBalanceMemoryStore(t *testing.T) { + testStore := New() + teardown := func() { + testStore.(*store).reset() + } + tests.RunTests(t, testStore, teardown) +} diff --git a/pkg/code/data/balance/postgres/model.go b/pkg/code/data/balance/postgres/model.go new file mode 100644 index 00000000..9458d77a --- /dev/null +++ b/pkg/code/data/balance/postgres/model.go @@ -0,0 +1,94 @@ +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/jmoiron/sqlx" + + "github.com/code-payments/code-server/pkg/code/data/balance" + pgutil "github.com/code-payments/code-server/pkg/database/postgres" +) + +const ( + tableName = "codewallet__core_balancecheckpoint" +) + +type model struct { + Id sql.NullInt64 `db:"id"` + + TokenAccount string `db:"token_account"` + Quarks uint64 `db:"quarks"` + SlotCheckpoint uint64 `db:"slot_checkpoint"` + + LastUpdatedAt time.Time `db:"last_updated_at"` +} + +func toModel(obj *balance.Record) (*model, error) { + if err := obj.Validate(); err != nil { + return nil, err + } + + return &model{ + TokenAccount: obj.TokenAccount, + Quarks: obj.Quarks, + SlotCheckpoint: obj.SlotCheckpoint, + LastUpdatedAt: obj.LastUpdatedAt, + }, nil +} + +func fromModel(obj *model) *balance.Record { + return &balance.Record{ + Id: uint64(obj.Id.Int64), + TokenAccount: obj.TokenAccount, + Quarks: obj.Quarks, + SlotCheckpoint: obj.SlotCheckpoint, + LastUpdatedAt: obj.LastUpdatedAt, + } +} + +func (m *model) dbSave(ctx context.Context, db *sqlx.DB) error { + return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { + query := `INSERT INTO ` + tableName + ` + (token_account, quarks, slot_checkpoint, last_updated_at) + VALUES ($1, $2, $3, $4) + + ON CONFLICT (token_account) + DO UPDATE + SET quarks = $2, slot_checkpoint = $3, last_updated_at = $4 + WHERE ` + tableName + `.token_account = $1 AND ` + tableName + `.slot_checkpoint < $3 + + RETURNING + id, token_account, quarks, slot_checkpoint, last_updated_at` + + m.LastUpdatedAt = time.Now() + + err := tx.QueryRowxContext( + ctx, + query, + m.TokenAccount, + m.Quarks, + m.SlotCheckpoint, + m.LastUpdatedAt.UTC(), + ).StructScan(m) + + return pgutil.CheckNoRows(err, balance.ErrStaleCheckpoint) + }) +} + +func dbGetCheckpoint(ctx context.Context, db *sqlx.DB, account string) (*model, error) { + res := &model{} + + query := `SELECT + id, token_account, quarks, slot_checkpoint, last_updated_at + FROM ` + tableName + ` + WHERE token_account = $1 + LIMIT 1` + + err := db.GetContext(ctx, res, query, account) + if err != nil { + return nil, pgutil.CheckNoRows(err, balance.ErrCheckpointNotFound) + } + return res, nil +} diff --git a/pkg/code/data/balance/postgres/store.go b/pkg/code/data/balance/postgres/store.go new file mode 100644 index 00000000..fa6d095c --- /dev/null +++ b/pkg/code/data/balance/postgres/store.go @@ -0,0 +1,47 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/jmoiron/sqlx" + + "github.com/code-payments/code-server/pkg/code/data/balance" +) + +type store struct { + db *sqlx.DB +} + +// New returns a new postgres balance.Store +func New(db *sql.DB) balance.Store { + return &store{ + db: sqlx.NewDb(db, "pgx"), + } +} + +// SaveCheckpoint implements balance.Store.SaveCheckpoint +func (s *store) SaveCheckpoint(ctx context.Context, record *balance.Record) error { + model, err := toModel(record) + if err != nil { + return err + } + + if err := model.dbSave(ctx, s.db); err != nil { + return err + } + + res := fromModel(model) + res.CopyTo(record) + + return nil +} + +// GetCheckpoint implements balance.Store.GetCheckpoint +func (s *store) GetCheckpoint(ctx context.Context, account string) (*balance.Record, error) { + model, err := dbGetCheckpoint(ctx, s.db, account) + if err != nil { + return nil, err + } + return fromModel(model), nil +} diff --git a/pkg/code/data/balance/postgres/store_test.go b/pkg/code/data/balance/postgres/store_test.go new file mode 100644 index 00000000..4800b140 --- /dev/null +++ b/pkg/code/data/balance/postgres/store_test.go @@ -0,0 +1,109 @@ +package postgres + +import ( + "database/sql" + "os" + "testing" + + "github.com/ory/dockertest/v3" + "github.com/sirupsen/logrus" + + "github.com/code-payments/code-server/pkg/code/data/balance" + "github.com/code-payments/code-server/pkg/code/data/balance/tests" + + postgrestest "github.com/code-payments/code-server/pkg/database/postgres/test" + + _ "github.com/jackc/pgx/v4/stdlib" +) + +var ( + testStore balance.Store + teardown func() +) + +const ( + // Used for testing ONLY, the table and migrations are external to this repository + tableCreate = ` + CREATE TABLE codewallet__core_balancecheckpoint ( + id SERIAL NOT NULL PRIMARY KEY, + + token_account TEXT NOT NULL, + quarks INTEGER NOT NULL, + slot_checkpoint INTEGER NOT NULL, + + last_updated_at TIMESTAMP WITH TIME ZONE, + + CONSTRAINT codewallet__core_balancecheckpoint__uniq__token_account UNIQUE (token_account) + ); + ` + + // Used for testing ONLY, the table and migrations are external to this repository + tableDestroy = ` + DROP TABLE codewallet__core_balancecheckpoint; + ` +) + +func TestMain(m *testing.M) { + log := logrus.StandardLogger() + + testPool, err := dockertest.NewPool("") + if err != nil { + log.WithError(err).Error("Error creating docker pool") + os.Exit(1) + } + + var cleanUpFunc func() + db, cleanUpFunc, err := postgrestest.StartPostgresDB(testPool) + if err != nil { + log.WithError(err).Error("Error starting postgres image") + os.Exit(1) + } + defer db.Close() + + if err := createTestTables(db); err != nil { + logrus.StandardLogger().WithError(err).Error("Error creating test tables") + cleanUpFunc() + os.Exit(1) + } + + testStore = New(db) + teardown = func() { + if pc := recover(); pc != nil { + cleanUpFunc() + panic(pc) + } + + if err := resetTestTables(db); err != nil { + logrus.StandardLogger().WithError(err).Error("Error resetting test tables") + cleanUpFunc() + os.Exit(1) + } + } + + code := m.Run() + cleanUpFunc() + os.Exit(code) +} + +func TestBalancePostgresStore(t *testing.T) { + tests.RunTests(t, testStore, teardown) +} + +func createTestTables(db *sql.DB) error { + _, err := db.Exec(tableCreate) + if err != nil { + logrus.StandardLogger().WithError(err).Error("could not create test tables") + return err + } + return nil +} + +func resetTestTables(db *sql.DB) error { + _, err := db.Exec(tableDestroy) + if err != nil { + logrus.StandardLogger().WithError(err).Error("could not drop test tables") + return err + } + + return createTestTables(db) +} diff --git a/pkg/code/data/balance/store.go b/pkg/code/data/balance/store.go new file mode 100644 index 00000000..ce5a65bc --- /dev/null +++ b/pkg/code/data/balance/store.go @@ -0,0 +1,22 @@ +package balance + +import ( + "context" + "errors" +) + +var ( + ErrCheckpointNotFound = errors.New("checkpoint not found") + + ErrStaleCheckpoint = errors.New("checkpoint is stale") +) + +type Store interface { + // SaveCheckpoint saves a balance at a checkpoint. ErrStaleCheckpoint is returned + // if the checkpoint is outdated + SaveCheckpoint(ctx context.Context, record *Record) error + + // GetCheckpoint gets a balance checkpoint for a given account. ErrCheckpointNotFound + // is returend if no DB record exists. + GetCheckpoint(ctx context.Context, account string) (*Record, error) +} diff --git a/pkg/code/data/balance/tests/tests.go b/pkg/code/data/balance/tests/tests.go new file mode 100644 index 00000000..a354f134 --- /dev/null +++ b/pkg/code/data/balance/tests/tests.go @@ -0,0 +1,81 @@ +package tests + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/code-payments/code-server/pkg/code/data/balance" +) + +func RunTests(t *testing.T, s balance.Store, teardown func()) { + for _, tf := range []func(t *testing.T, s balance.Store){ + testHappyPath, + } { + tf(t, s) + teardown() + } +} + +func testHappyPath(t *testing.T, s balance.Store) { + t.Run("testHappyPath", func(t *testing.T) { + ctx := context.Background() + + _, err := s.GetCheckpoint(ctx, "token_account") + assert.Equal(t, balance.ErrCheckpointNotFound, err) + + start := time.Now() + + expected := &balance.Record{ + TokenAccount: "token_account", + Quarks: 0, + SlotCheckpoint: 0, + } + cloned := expected.Clone() + + require.NoError(t, s.SaveCheckpoint(ctx, expected)) + assert.EqualValues(t, 1, expected.Id) + assert.True(t, expected.LastUpdatedAt.After(start)) + + actual, err := s.GetCheckpoint(ctx, "token_account") + require.NoError(t, err) + assertEquivalentRecords(t, actual, &cloned) + + start = time.Now() + + expected.Quarks = 12345 + expected.SlotCheckpoint = 10 + cloned = expected.Clone() + + require.NoError(t, s.SaveCheckpoint(ctx, expected)) + assert.EqualValues(t, 1, expected.Id) + assert.True(t, expected.LastUpdatedAt.After(start)) + + actual, err = s.GetCheckpoint(ctx, "token_account") + require.NoError(t, err) + assertEquivalentRecords(t, actual, &cloned) + + expected.Quarks = 67890 + assert.Equal(t, balance.ErrStaleCheckpoint, s.SaveCheckpoint(ctx, expected)) + + actual, err = s.GetCheckpoint(ctx, "token_account") + require.NoError(t, err) + assertEquivalentRecords(t, actual, &cloned) + + expected.SlotCheckpoint -= 1 + assert.Equal(t, balance.ErrStaleCheckpoint, s.SaveCheckpoint(ctx, expected)) + + actual, err = s.GetCheckpoint(ctx, "token_account") + require.NoError(t, err) + assertEquivalentRecords(t, actual, &cloned) + }) +} + +func assertEquivalentRecords(t *testing.T, obj1, obj2 *balance.Record) { + assert.Equal(t, obj1.TokenAccount, obj2.TokenAccount) + assert.Equal(t, obj1.Quarks, obj2.Quarks) + assert.Equal(t, obj1.SlotCheckpoint, obj2.SlotCheckpoint) +} diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index ce188690..6858cca4 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -21,6 +21,7 @@ import ( "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/badgecount" + "github.com/code-payments/code-server/pkg/code/data/balance" "github.com/code-payments/code-server/pkg/code/data/chat" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/contact" @@ -50,6 +51,7 @@ import ( account_memory_client "github.com/code-payments/code-server/pkg/code/data/account/memory" action_memory_client "github.com/code-payments/code-server/pkg/code/data/action/memory" badgecount_memory_client "github.com/code-payments/code-server/pkg/code/data/badgecount/memory" + balance_memory_client "github.com/code-payments/code-server/pkg/code/data/balance/memory" chat_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/memory" commitment_memory_client "github.com/code-payments/code-server/pkg/code/data/commitment/memory" contact_memory_client "github.com/code-payments/code-server/pkg/code/data/contact/memory" @@ -80,6 +82,7 @@ import ( account_postgres_client "github.com/code-payments/code-server/pkg/code/data/account/postgres" action_postgres_client "github.com/code-payments/code-server/pkg/code/data/action/postgres" badgecount_postgres_client "github.com/code-payments/code-server/pkg/code/data/badgecount/postgres" + balance_postgres_client "github.com/code-payments/code-server/pkg/code/data/balance/postgres" chat_postgres_client "github.com/code-payments/code-server/pkg/code/data/chat/postgres" commitment_postgres_client "github.com/code-payments/code-server/pkg/code/data/commitment/postgres" contact_postgres_client "github.com/code-payments/code-server/pkg/code/data/contact/postgres" @@ -385,6 +388,11 @@ type DatabaseData interface { GetLoginsByAppInstall(ctx context.Context, appInstallId string) (*login.MultiRecord, error) GetLatestLoginByOwner(ctx context.Context, owner string) (*login.Record, error) + // Balance + // -------------------------------------------------------------------------------- + SaveBalanceCheckpoint(ctx context.Context, record *balance.Record) error + GetBalanceCheckpoint(ctx context.Context, account string) (*balance.Record, error) + // ExecuteInTx executes fn with a single DB transaction that is scoped to the call. // This enables more complex transactions that can span many calls across the provider. // @@ -423,6 +431,7 @@ type DatabaseProvider struct { chat chat.Store badgecount badgecount.Store login login.Store + balance balance.Store exchangeCache cache.Cache timelockCache cache.Cache @@ -480,6 +489,7 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { chat: chat_postgres_client.New(db), badgecount: badgecount_postgres_client.New(db), login: login_postgres_client.New(db), + balance: balance_postgres_client.New(db), exchangeCache: cache.NewCache(maxExchangeRateCacheBudget), timelockCache: cache.NewCache(maxTimelockCacheBudget), @@ -518,6 +528,7 @@ func NewTestDatabaseProvider() DatabaseData { chat: chat_memory_client.New(), badgecount: badgecount_memory_client.New(), login: login_memory_client.New(), + balance: balance_memory_client.New(), exchangeCache: cache.NewCache(maxExchangeRateCacheBudget), timelockCache: nil, // Shouldn't be used for tests @@ -1395,3 +1406,12 @@ func (dp *DatabaseProvider) GetLoginsByAppInstall(ctx context.Context, appInstal func (dp *DatabaseProvider) GetLatestLoginByOwner(ctx context.Context, owner string) (*login.Record, error) { return dp.login.GetLatestByOwner(ctx, owner) } + +// Balance +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) SaveBalanceCheckpoint(ctx context.Context, record *balance.Record) error { + return dp.balance.SaveCheckpoint(ctx, record) +} +func (dp *DatabaseProvider) GetBalanceCheckpoint(ctx context.Context, account string) (*balance.Record, error) { + return dp.balance.GetCheckpoint(ctx, account) +}