Skip to content

Commit

Permalink
client: introduce RMQDB
Browse files Browse the repository at this point in the history
This DB will be necessary for the RMQ to track unacked RMs that have
been paid for across client restarts.
  • Loading branch information
miki committed Jan 19, 2023
1 parent 35dbcbf commit 6caca46
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 10 deletions.
4 changes: 3 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ func New(cfg Config) (*Client, error) {
}
ck := lowlevel.NewConnKeeper(ckCfg)

q := lowlevel.NewRMQ(cfg.logger("RMQU"), cfg.PayClient, id)
rmqdb := &rmqDBAdapter{}
q := lowlevel.NewRMQ(cfg.logger("RMQU"), cfg.PayClient, id, rmqdb)
ctx, cancel := context.WithCancel(context.Background())

dbCtx, dbCtxCancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -335,6 +336,7 @@ func New(cfg Config) (*Client, error) {
c.gcmq = gcmcacher.New(gcmqDelay, gcmqMaxDelay, slog.Disabled, c.handleDelayedGCMessages)

rmgrdb.c = c
rmqdb.c = c
kxl.kxCompleted = c.kxCompleted

return c, nil
Expand Down
21 changes: 20 additions & 1 deletion client/internal/lowlevel/rmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ func (r rmmsg) sendReply(err error) {
r.replyChan <- err
}

// RMQDB is the interface required of a DB to persist RMQ-related data.
type RMQDB interface {
// StoreRVPaymentAttempt should store that an attempt to pay to push
// to the given RV is being made with the given invoice.
StoreRVPaymentAttempt(RVID, string, time.Time) error

// RVHasPaymentAttempt should return the invoice and time that an
// attempt to pay to push to the RV was made (i.e. it returns the
// invoice and time saved on a prior call to StoreRVPaymentAttempt).
RVHasPaymentAttempt(RVID) (string, time.Time, error)

// DeleteRVPaymentAttempt removes the prior attempt to pay for the given
// RV.
DeleteRVPaymentAttempt(RVID) error
}

// RMQ is a queue for sending RoutedMessages (RMs) to the server. The rmq
// supports a flickering server connection: any unsent RMs are queued (FIFO
// style) until a new server session is bound via `bindToSession`.
Expand All @@ -48,6 +64,7 @@ type RMQ struct {
enqueueDone chan struct{}
lenChan chan chan int
timingStat timestats.Tracker
db RMQDB

nextSendChan chan *rmmsg
sendDoneChan chan struct{}
Expand All @@ -58,7 +75,8 @@ type RMQ struct {
nextInvoice string
}

func NewRMQ(log slog.Logger, payClient clientintf.PaymentClient, localID *zkidentity.FullIdentity) *RMQ {
func NewRMQ(log slog.Logger, payClient clientintf.PaymentClient,
localID *zkidentity.FullIdentity, db RMQDB) *RMQ {
if log == nil {
log = slog.Disabled
}
Expand All @@ -67,6 +85,7 @@ func NewRMQ(log slog.Logger, payClient clientintf.PaymentClient, localID *zkiden
localID: localID,
log: log,
rmChan: make(chan rmmsg),
db: db,
enqueueDone: make(chan struct{}),
lenChan: make(chan chan int),
nextSendChan: make(chan *rmmsg),
Expand Down
16 changes: 8 additions & 8 deletions client/internal/lowlevel/rmq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestRMQSuccessRM(t *testing.T) {
t.Parallel()

mockID := &zkidentity.FullIdentity{}
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID)
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID, newMockRMQDB())
runErr := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
go func() { runErr <- q.Run(ctx) }()
Expand Down Expand Up @@ -71,7 +71,7 @@ func TestRMQAckErrors(t *testing.T) {
t.Parallel()

mockID := &zkidentity.FullIdentity{}
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID)
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID, newMockRMQDB())
runErr := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
go func() { runErr <- q.Run(ctx) }()
Expand Down Expand Up @@ -140,7 +140,7 @@ func TestRMQMultipleRM(t *testing.T) {

nb := 10
mockID := &zkidentity.FullIdentity{}
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID)
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID, newMockRMQDB())
runErr := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -229,7 +229,7 @@ func TestCanceledRMQErrorsRM(t *testing.T) {
t.Parallel()

mockID := &zkidentity.FullIdentity{}
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID)
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID, newMockRMQDB())
runErr := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
go func() { runErr <- q.Run(ctx) }()
Expand Down Expand Up @@ -274,7 +274,7 @@ func TestCanceledRMQAfterQueuedErrorsRM(t *testing.T) {
t.Parallel()

mockID := &zkidentity.FullIdentity{}
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID)
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID, newMockRMQDB())
runErr := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
go func() { runErr <- q.Run(ctx) }()
Expand Down Expand Up @@ -328,7 +328,7 @@ func TestEnqueueRMBeforeSession(t *testing.T) {
t.Parallel()

mockID := &zkidentity.FullIdentity{}
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID)
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID, newMockRMQDB())
runErr := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -395,7 +395,7 @@ func TestRMQEncryptErrorFailsRM(t *testing.T) {
t.Parallel()

mockID := &zkidentity.FullIdentity{}
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID)
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID, newMockRMQDB())
runErr := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -458,7 +458,7 @@ func TestRMQMaxMsgSizeErrors(t *testing.T) {
t.Parallel()
const maxMsgSize = rpc.MaxMsgSize
mockID := &zkidentity.FullIdentity{}
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID)
q := NewRMQ(nil, clientintf.FreePaymentClient{}, mockID, newMockRMQDB())
rm := mockRM(strings.Repeat(" ", maxMsgSize+1))
err := q.SendRM(rm)
if !errors.Is(err, errORMTooLarge) {
Expand Down
40 changes: 40 additions & 0 deletions client/internal/lowlevel/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"math/rand"
"net"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -446,3 +447,42 @@ func (db *mockRvMgrDB) MarkRVUnpaid(rv RVID) error {
delete(db.paid, rv)
return nil
}

type mockRMQDBEntry struct {
invoice string
date time.Time
}

type mockRMQDB struct {
mtx sync.Mutex
store map[RVID]mockRMQDBEntry
}

func newMockRMQDB() *mockRMQDB {
return &mockRMQDB{
store: make(map[RVID]mockRMQDBEntry),
}
}

func (m *mockRMQDB) RVHasPaymentAttempt(rv RVID) (string, time.Time, error) {
m.mtx.Lock()
defer m.mtx.Unlock()
if e, ok := m.store[rv]; ok {
return e.invoice, e.date, nil
}
return "", time.Time{}, nil
}

func (m *mockRMQDB) StoreRVPaymentAttempt(rv RVID, invoice string, date time.Time) error {
m.mtx.Lock()
defer m.mtx.Unlock()
m.store[rv] = mockRMQDBEntry{invoice: invoice, date: date}
return nil
}

func (m *mockRMQDB) DeleteRVPaymentAttempt(rv RVID) error {
m.mtx.Lock()
defer m.mtx.Unlock()
delete(m.store, rv)
return nil
}
30 changes: 30 additions & 0 deletions client/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"encoding/binary"
"sort"
"time"

"github.com/companyzero/bisonrelay/client/clientdb"
"github.com/companyzero/bisonrelay/client/internal/lowlevel"
Expand Down Expand Up @@ -82,6 +83,35 @@ func (rvdb *rvManagerDBAdapter) MarkRVUnpaid(rv lowlevel.RVID) error {
return err
}

// rmqDBAdapter is an adapter structure that satisfies the RMQDB interface using
// a client's db as backing storage.
type rmqDBAdapter struct {
c *Client
}

func (rmqdb *rmqDBAdapter) RVHasPaymentAttempt(rv lowlevel.RVID) (string, time.Time, error) {
var invoice string
var ts time.Time
err := rmqdb.c.dbView(func(tx clientdb.ReadTx) error {
var err error
invoice, ts, err = rmqdb.c.db.HasPushPaymentAttempt(tx, rv)
return err
})
return invoice, ts, err
}

func (rmqdb *rmqDBAdapter) StoreRVPaymentAttempt(rv lowlevel.RVID, invoice string, ts time.Time) error {
return rmqdb.c.dbUpdate(func(tx clientdb.ReadWriteTx) error {
return rmqdb.c.db.StorePushPaymentAttempt(tx, rv, invoice, ts)
})
}

func (rmqdb *rmqDBAdapter) DeleteRVPaymentAttempt(rv lowlevel.RVID) error {
return rmqdb.c.dbUpdate(func(tx clientdb.ReadWriteTx) error {
return rmqdb.c.db.DeletePushPaymentAttempt(tx, rv)
})
}

// SortedUserPayStatsIDs returns a sorted list of IDs from the passed stats
// map, ordered by largest total payments.
func SortedUserPayStatsIDs(stats map[UserID]clientdb.UserPayStats) []UserID {
Expand Down

0 comments on commit 6caca46

Please sign in to comment.