Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions firewall/privacy_mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
uri string, req proto.Message, sessionID session.ID) (proto.Message,
error) {

session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
session, err := p.sessionDB.GetSession(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,7 +220,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
func (p *PrivacyMapper) replaceOutgoingResponse(ctx context.Context, uri string,
resp proto.Message, sessionID session.ID) (proto.Message, error) {

session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
session, err := p.sessionDB.GetSession(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion firewall/rule_enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string,
return nil, err
}

session, err := r.sessionDB.GetSessionByID(ctx, sessionID)
session, err := r.sessionDB.GetSession(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions firewalldb/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ import (
type SessionDB interface {
session.IDToGroupIndex

// GetSessionByID returns the session for a specific id.
GetSessionByID(context.Context, session.ID) (*session.Session, error)
// GetSession returns the session for a specific id.
GetSession(context.Context, session.ID) (*session.Session, error)
}
10 changes: 5 additions & 5 deletions firewalldb/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ func (m *mockSessionDB) GetSessionIDs(_ context.Context, groupID session.ID) (
return ids, nil
}

// GetSessionByID returns the session for a specific id.
func (m *mockSessionDB) GetSessionByID(_ context.Context,
sessionID session.ID) (*session.Session, error) {
// GetSession returns the session for a specific id.
func (m *mockSessionDB) GetSession(_ context.Context,
id session.ID) (*session.Session, error) {

s, ok := m.sessionToGroupID[sessionID]
s, ok := m.sessionToGroupID[id]
if !ok {
return nil, fmt.Errorf("no session found for session ID")
}

f, ok := m.privacyFlags[sessionID]
f, ok := m.privacyFlags[id]
if !ok {
return nil, fmt.Errorf("no privacy flags found for session ID")
}
Expand Down
14 changes: 8 additions & 6 deletions session/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,10 @@ type Store interface {
expiry time.Time, serverAddr string, opts ...Option) (*Session,
error)

// GetSession fetches the session with the given key.
GetSession(ctx context.Context, key *btcec.PublicKey) (*Session, error)
// GetSessionByLocalPub fetches the session with the given local pub
// key.
GetSessionByLocalPub(ctx context.Context,
key *btcec.PublicKey) (*Session, error)

// ListAllSessions returns all sessions currently known to the store.
ListAllSessions(ctx context.Context) ([]*Session, error)
Expand All @@ -309,12 +311,12 @@ type Store interface {
error)

// UpdateSessionRemotePubKey can be used to add the given remote pub key
// to the session with the given local pub key.
UpdateSessionRemotePubKey(ctx context.Context, localPubKey,
// to the session with the given ID.
UpdateSessionRemotePubKey(ctx context.Context, id ID,
remotePubKey *btcec.PublicKey) error

// GetSessionByID fetches the session with the given ID.
GetSessionByID(ctx context.Context, id ID) (*Session, error)
// GetSession fetches the session with the given ID.
GetSession(ctx context.Context, id ID) (*Session, error)

// DeleteReservedSessions deletes all sessions that are in the
// StateReserved state.
Expand Down
28 changes: 8 additions & 20 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,29 +298,19 @@ func (db *BoltStore) NewSession(ctx context.Context, label string, typ Type,
}

// UpdateSessionRemotePubKey can be used to add the given remote pub key
// to the session with the given local pub key.
// to the session with the given ID.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, localPubKey,
func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, id ID,
remotePubKey *btcec.PublicKey) error {

key := localPubKey.SerializeCompressed()

return db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
return err
}

serialisedSession := sessionBucket.Get(key)

if len(serialisedSession) == 0 {
return ErrSessionNotFound
}

session, err := DeserializeSession(
bytes.NewReader(serialisedSession),
)
session, err := getSessionByID(sessionBucket, id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Maybe rename the internal getSessionByID to getSession as well :)?

Copy link
Member Author

@ellemouton ellemouton Mar 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is verryyyy nitty yall 😅 i think exposed/exported methods are the main focus. especially since we will delete all this code kvdb code by the end of the series

if err != nil {
return err
}
Expand All @@ -331,11 +321,11 @@ func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, localPubKey,
})
}

// GetSession fetches the session with the given key.
// GetSessionByLocalPub fetches the session with the given local pub key.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) GetSession(_ context.Context, key *btcec.PublicKey) (
*Session, error) {
func (db *BoltStore) GetSessionByLocalPub(_ context.Context,
key *btcec.PublicKey) (*Session, error) {

var session *Session
err := db.View(func(tx *bbolt.Tx) error {
Expand Down Expand Up @@ -575,12 +565,10 @@ func (db *BoltStore) ShiftState(_ context.Context, id ID, dest State) error {
})
}

// GetSessionByID fetches the session with the given ID.
// GetSession fetches the session with the given ID.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) GetSessionByID(_ context.Context, id ID) (*Session,
error) {

func (db *BoltStore) GetSession(_ context.Context, id ID) (*Session, error) {
var session *Session
err := db.View(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
Expand Down
6 changes: 3 additions & 3 deletions session/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func newMailboxSession() *mailboxSession {

func (m *mailboxSession) start(session *Session,
serverCreator GRPCServerCreator, authData []byte,
onUpdate func(ctx context.Context, local,
onUpdate func(ctx context.Context, id ID,
remote *btcec.PublicKey) error,
onNewStatus func(s mailbox.ServerStatus)) error {

Expand All @@ -53,7 +53,7 @@ func (m *mailboxSession) start(session *Session,
keys := mailbox.NewConnData(
ecdh, session.RemotePublicKey, session.PairingSecret[:],
authData, func(key *btcec.PublicKey) error {
return onUpdate(ctx, session.LocalPublicKey, key)
return onUpdate(ctx, session.ID, key)
}, nil,
)

Expand Down Expand Up @@ -112,7 +112,7 @@ func NewServer(serverCreator GRPCServerCreator) *Server {
}

func (s *Server) StartSession(session *Session, authData []byte,
onUpdate func(ctx context.Context, local,
onUpdate func(ctx context.Context, id ID,
remote *btcec.PublicKey) error,
onNewStatus func(s mailbox.ServerStatus)) (chan struct{}, error) {

Expand Down
28 changes: 13 additions & 15 deletions session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ func TestBasicSessionStore(t *testing.T) {
db := NewTestDB(t, clock)

// Try fetch a session that doesn't exist yet.
_, err := db.GetSessionByID(ctx, ID{1, 3, 4, 4})
_, err := db.GetSession(ctx, ID{1, 3, 4, 4})
require.ErrorIs(t, err, ErrSessionNotFound)

// Reserve a session. This should succeed.
s1, err := reserveSession(db, "session 1")
require.NoError(t, err)

// Show that the session starts in the reserved state.
s1, err = db.GetSessionByID(ctx, s1.ID)
s1, err = db.GetSession(ctx, s1.ID)
require.NoError(t, err)
require.Equal(t, StateReserved, s1.State)

Expand All @@ -46,7 +46,7 @@ func TestBasicSessionStore(t *testing.T) {
require.NoError(t, err)

// Show that the session is now in the created state.
s1, err = db.GetSessionByID(ctx, s1.ID)
s1, err = db.GetSession(ctx, s1.ID)
require.NoError(t, err)
require.Equal(t, StateCreated, s1.State)

Expand Down Expand Up @@ -82,17 +82,17 @@ func TestBasicSessionStore(t *testing.T) {
// Ensure that we can retrieve each session by both its local pub key
// and by its ID.
for _, s := range []*Session{s1, s2, s3} {
session, err := db.GetSession(ctx, s.LocalPublicKey)
session, err := db.GetSessionByLocalPub(ctx, s.LocalPublicKey)
require.NoError(t, err)
assertEqualSessions(t, s, session)

session, err = db.GetSessionByID(ctx, s.ID)
session, err = db.GetSession(ctx, s.ID)
require.NoError(t, err)
assertEqualSessions(t, s, session)
}

// Fetch session 1 and assert that it currently has no remote pub key.
session1, err := db.GetSession(ctx, s1.LocalPublicKey)
session1, err := db.GetSessionByLocalPub(ctx, s1.LocalPublicKey)
require.NoError(t, err)
require.Nil(t, session1.RemotePublicKey)

Expand All @@ -101,13 +101,11 @@ func TestBasicSessionStore(t *testing.T) {
require.NoError(t, err)
remotePub := remotePriv.PubKey()

err = db.UpdateSessionRemotePubKey(
ctx, session1.LocalPublicKey, remotePub,
)
err = db.UpdateSessionRemotePubKey(ctx, session1.ID, remotePub)
require.NoError(t, err)

// Assert that the session now does have the remote pub key.
session1, err = db.GetSession(ctx, s1.LocalPublicKey)
session1, err = db.GetSessionByLocalPub(ctx, s1.LocalPublicKey)
require.NoError(t, err)
require.True(t, remotePub.IsEqual(session1.RemotePublicKey))

Expand All @@ -116,7 +114,7 @@ func TestBasicSessionStore(t *testing.T) {

// Now revoke the session and assert that the state is revoked.
require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked))
s1, err = db.GetSession(ctx, s1.LocalPublicKey)
s1, err = db.GetSessionByLocalPub(ctx, s1.LocalPublicKey)
require.NoError(t, err)
require.Equal(t, s1.State, StateRevoked)

Expand Down Expand Up @@ -299,7 +297,7 @@ func TestStateShift(t *testing.T) {

// Check that the session is in the StateCreated state. Also check that
// the "RevokedAt" time has not yet been set.
s1, err := db.GetSession(ctx, s1.LocalPublicKey)
s1, err := db.GetSessionByLocalPub(ctx, s1.LocalPublicKey)
require.NoError(t, err)
require.Equal(t, StateCreated, s1.State)
require.Equal(t, time.Time{}, s1.RevokedAt)
Expand All @@ -310,7 +308,7 @@ func TestStateShift(t *testing.T) {

// This should have worked. Since it is now in a terminal state, the
// "RevokedAt" time should be set.
s1, err = db.GetSession(ctx, s1.LocalPublicKey)
s1, err = db.GetSessionByLocalPub(ctx, s1.LocalPublicKey)
require.NoError(t, err)
require.Equal(t, StateRevoked, s1.State)
require.True(t, clock.Now().Equal(s1.RevokedAt))
Expand Down Expand Up @@ -361,7 +359,7 @@ func TestLinkedAccount(t *testing.T) {
})

// Make sure that a fetched session includes the account ID.
s1, err = db.GetSessionByID(ctx, s1.ID)
s1, err = db.GetSession(ctx, s1.ID)
require.NoError(t, err)
require.True(t, s1.AccountID.IsSome())
s1.AccountID.WhenSome(func(id accounts.AccountID) {
Expand Down Expand Up @@ -453,7 +451,7 @@ func createSession(t *testing.T, db Store, label string,
err = db.ShiftState(context.Background(), s.ID, StateCreated)
require.NoError(t, err)

s, err = db.GetSessionByID(context.Background(), s.ID)
s, err = db.GetSession(context.Background(), s.ID)
require.NoError(t, err)

return s
Expand Down
14 changes: 6 additions & 8 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func (s *sessionRpcServer) AddSession(ctx context.Context,

// Re-fetch the session to get the latest state of it before marshaling
// it.
sess, err = s.cfg.db.GetSessionByID(ctx, sess.ID)
sess, err = s.cfg.db.GetSession(ctx, sess.ID)
if err != nil {
return nil, fmt.Errorf("error fetching session: %v", err)
}
Expand Down Expand Up @@ -577,7 +577,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
return nil, fmt.Errorf("error parsing public key: %v", err)
}

sess, err := s.cfg.db.GetSession(ctx, pubKey)
sess, err := s.cfg.db.GetSessionByLocalPub(ctx, pubKey)
if err != nil {
return nil, fmt.Errorf("error fetching session: %v", err)
}
Expand Down Expand Up @@ -882,7 +882,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
copy(groupID[:], req.LinkedGroupId)

// Check that the group actually does exist.
groupSess, err := s.cfg.db.GetSessionByID(ctx, groupID)
groupSess, err := s.cfg.db.GetSession(ctx, groupID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1245,9 +1245,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
"autopilot server: %v", err)
}

err = s.cfg.db.UpdateSessionRemotePubKey(
ctx, sess.LocalPublicKey, remoteKey,
)
err = s.cfg.db.UpdateSessionRemotePubKey(ctx, sess.ID, remoteKey)
if err != nil {
return nil, fmt.Errorf("error setting remote pubkey: %v", err)
}
Expand All @@ -1269,7 +1267,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,

// Re-fetch the session to get the latest state of it before marshaling
// it.
sess, err = s.cfg.db.GetSessionByID(ctx, sess.ID)
sess, err = s.cfg.db.GetSession(ctx, sess.ID)
if err != nil {
return nil, fmt.Errorf("error fetching session: %v", err)
}
Expand Down Expand Up @@ -1319,7 +1317,7 @@ func (s *sessionRpcServer) RevokeAutopilotSession(ctx context.Context,
return nil, fmt.Errorf("error parsing public key: %v", err)
}

sess, err := s.cfg.db.GetSession(ctx, pubKey)
sess, err := s.cfg.db.GetSessionByLocalPub(ctx, pubKey)
if err != nil {
return nil, err
}
Expand Down
Loading