diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index af4f3b0af..26e053500 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -190,7 +190,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context, uri string, req proto.Message, sessionID session.ID) (proto.Message, error) { - session, err := p.sessionDB.GetSessionByID(ctx, sessionID) + session, err := p.sessionDB.GetSession(ctx, sessionID) if err != nil { return nil, err } @@ -220,7 +220,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context, func (p *PrivacyMapper) replaceOutgoingResponse(ctx context.Context, uri string, resp proto.Message, sessionID session.ID) (proto.Message, error) { - session, err := p.sessionDB.GetSessionByID(ctx, sessionID) + session, err := p.sessionDB.GetSession(ctx, sessionID) if err != nil { return nil, err } diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index 008af72c5..c99671cdf 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -386,7 +386,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string, return nil, err } - session, err := r.sessionDB.GetSessionByID(ctx, sessionID) + session, err := r.sessionDB.GetSession(ctx, sessionID) if err != nil { return nil, err } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 86e638b53..ff82eab68 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -11,6 +11,6 @@ import ( type SessionDB interface { session.IDToGroupIndex - // GetSessionByID returns the session for a specific id. - GetSessionByID(context.Context, session.ID) (*session.Session, error) + // GetSession returns the session for a specific id. + GetSession(context.Context, session.ID) (*session.Session, error) } diff --git a/firewalldb/mock.go b/firewalldb/mock.go index 0213de864..81905736d 100644 --- a/firewalldb/mock.go +++ b/firewalldb/mock.go @@ -58,16 +58,16 @@ func (m *mockSessionDB) GetSessionIDs(_ context.Context, groupID session.ID) ( return ids, nil } -// GetSessionByID returns the session for a specific id. -func (m *mockSessionDB) GetSessionByID(_ context.Context, - sessionID session.ID) (*session.Session, error) { +// GetSession returns the session for a specific id. +func (m *mockSessionDB) GetSession(_ context.Context, + id session.ID) (*session.Session, error) { - s, ok := m.sessionToGroupID[sessionID] + s, ok := m.sessionToGroupID[id] if !ok { return nil, fmt.Errorf("no session found for session ID") } - f, ok := m.privacyFlags[sessionID] + f, ok := m.privacyFlags[id] if !ok { return nil, fmt.Errorf("no privacy flags found for session ID") } diff --git a/session/interface.go b/session/interface.go index d1e25b648..c34264c0d 100644 --- a/session/interface.go +++ b/session/interface.go @@ -294,8 +294,10 @@ type Store interface { expiry time.Time, serverAddr string, opts ...Option) (*Session, error) - // GetSession fetches the session with the given key. - GetSession(ctx context.Context, key *btcec.PublicKey) (*Session, error) + // GetSessionByLocalPub fetches the session with the given local pub + // key. + GetSessionByLocalPub(ctx context.Context, + key *btcec.PublicKey) (*Session, error) // ListAllSessions returns all sessions currently known to the store. ListAllSessions(ctx context.Context) ([]*Session, error) @@ -309,12 +311,12 @@ type Store interface { error) // UpdateSessionRemotePubKey can be used to add the given remote pub key - // to the session with the given local pub key. - UpdateSessionRemotePubKey(ctx context.Context, localPubKey, + // to the session with the given ID. + UpdateSessionRemotePubKey(ctx context.Context, id ID, remotePubKey *btcec.PublicKey) error - // GetSessionByID fetches the session with the given ID. - GetSessionByID(ctx context.Context, id ID) (*Session, error) + // GetSession fetches the session with the given ID. + GetSession(ctx context.Context, id ID) (*Session, error) // DeleteReservedSessions deletes all sessions that are in the // StateReserved state. diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 829caff3a..71099192a 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -298,29 +298,19 @@ func (db *BoltStore) NewSession(ctx context.Context, label string, typ Type, } // UpdateSessionRemotePubKey can be used to add the given remote pub key -// to the session with the given local pub key. +// to the session with the given ID. // // NOTE: this is part of the Store interface. -func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, localPubKey, +func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, id ID, remotePubKey *btcec.PublicKey) error { - key := localPubKey.SerializeCompressed() - return db.Update(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) if err != nil { return err } - serialisedSession := sessionBucket.Get(key) - - if len(serialisedSession) == 0 { - return ErrSessionNotFound - } - - session, err := DeserializeSession( - bytes.NewReader(serialisedSession), - ) + session, err := getSessionByID(sessionBucket, id) if err != nil { return err } @@ -331,11 +321,11 @@ func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, localPubKey, }) } -// GetSession fetches the session with the given key. +// GetSessionByLocalPub fetches the session with the given local pub key. // // NOTE: this is part of the Store interface. -func (db *BoltStore) GetSession(_ context.Context, key *btcec.PublicKey) ( - *Session, error) { +func (db *BoltStore) GetSessionByLocalPub(_ context.Context, + key *btcec.PublicKey) (*Session, error) { var session *Session err := db.View(func(tx *bbolt.Tx) error { @@ -575,12 +565,10 @@ func (db *BoltStore) ShiftState(_ context.Context, id ID, dest State) error { }) } -// GetSessionByID fetches the session with the given ID. +// GetSession fetches the session with the given ID. // // NOTE: this is part of the Store interface. -func (db *BoltStore) GetSessionByID(_ context.Context, id ID) (*Session, - error) { - +func (db *BoltStore) GetSession(_ context.Context, id ID) (*Session, error) { var session *Session err := db.View(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) diff --git a/session/server.go b/session/server.go index 22de3bf8d..75c0e3edf 100644 --- a/session/server.go +++ b/session/server.go @@ -36,7 +36,7 @@ func newMailboxSession() *mailboxSession { func (m *mailboxSession) start(session *Session, serverCreator GRPCServerCreator, authData []byte, - onUpdate func(ctx context.Context, local, + onUpdate func(ctx context.Context, id ID, remote *btcec.PublicKey) error, onNewStatus func(s mailbox.ServerStatus)) error { @@ -53,7 +53,7 @@ func (m *mailboxSession) start(session *Session, keys := mailbox.NewConnData( ecdh, session.RemotePublicKey, session.PairingSecret[:], authData, func(key *btcec.PublicKey) error { - return onUpdate(ctx, session.LocalPublicKey, key) + return onUpdate(ctx, session.ID, key) }, nil, ) @@ -112,7 +112,7 @@ func NewServer(serverCreator GRPCServerCreator) *Server { } func (s *Server) StartSession(session *Session, authData []byte, - onUpdate func(ctx context.Context, local, + onUpdate func(ctx context.Context, id ID, remote *btcec.PublicKey) error, onNewStatus func(s mailbox.ServerStatus)) (chan struct{}, error) { diff --git a/session/store_test.go b/session/store_test.go index 846dc4e07..6c7ec7176 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -29,7 +29,7 @@ func TestBasicSessionStore(t *testing.T) { db := NewTestDB(t, clock) // Try fetch a session that doesn't exist yet. - _, err := db.GetSessionByID(ctx, ID{1, 3, 4, 4}) + _, err := db.GetSession(ctx, ID{1, 3, 4, 4}) require.ErrorIs(t, err, ErrSessionNotFound) // Reserve a session. This should succeed. @@ -37,7 +37,7 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) // Show that the session starts in the reserved state. - s1, err = db.GetSessionByID(ctx, s1.ID) + s1, err = db.GetSession(ctx, s1.ID) require.NoError(t, err) require.Equal(t, StateReserved, s1.State) @@ -46,7 +46,7 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) // Show that the session is now in the created state. - s1, err = db.GetSessionByID(ctx, s1.ID) + s1, err = db.GetSession(ctx, s1.ID) require.NoError(t, err) require.Equal(t, StateCreated, s1.State) @@ -82,17 +82,17 @@ func TestBasicSessionStore(t *testing.T) { // Ensure that we can retrieve each session by both its local pub key // and by its ID. for _, s := range []*Session{s1, s2, s3} { - session, err := db.GetSession(ctx, s.LocalPublicKey) + session, err := db.GetSessionByLocalPub(ctx, s.LocalPublicKey) require.NoError(t, err) assertEqualSessions(t, s, session) - session, err = db.GetSessionByID(ctx, s.ID) + session, err = db.GetSession(ctx, s.ID) require.NoError(t, err) assertEqualSessions(t, s, session) } // Fetch session 1 and assert that it currently has no remote pub key. - session1, err := db.GetSession(ctx, s1.LocalPublicKey) + session1, err := db.GetSessionByLocalPub(ctx, s1.LocalPublicKey) require.NoError(t, err) require.Nil(t, session1.RemotePublicKey) @@ -101,13 +101,11 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) remotePub := remotePriv.PubKey() - err = db.UpdateSessionRemotePubKey( - ctx, session1.LocalPublicKey, remotePub, - ) + err = db.UpdateSessionRemotePubKey(ctx, session1.ID, remotePub) require.NoError(t, err) // Assert that the session now does have the remote pub key. - session1, err = db.GetSession(ctx, s1.LocalPublicKey) + session1, err = db.GetSessionByLocalPub(ctx, s1.LocalPublicKey) require.NoError(t, err) require.True(t, remotePub.IsEqual(session1.RemotePublicKey)) @@ -116,7 +114,7 @@ func TestBasicSessionStore(t *testing.T) { // Now revoke the session and assert that the state is revoked. require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked)) - s1, err = db.GetSession(ctx, s1.LocalPublicKey) + s1, err = db.GetSessionByLocalPub(ctx, s1.LocalPublicKey) require.NoError(t, err) require.Equal(t, s1.State, StateRevoked) @@ -299,7 +297,7 @@ func TestStateShift(t *testing.T) { // Check that the session is in the StateCreated state. Also check that // the "RevokedAt" time has not yet been set. - s1, err := db.GetSession(ctx, s1.LocalPublicKey) + s1, err := db.GetSessionByLocalPub(ctx, s1.LocalPublicKey) require.NoError(t, err) require.Equal(t, StateCreated, s1.State) require.Equal(t, time.Time{}, s1.RevokedAt) @@ -310,7 +308,7 @@ func TestStateShift(t *testing.T) { // This should have worked. Since it is now in a terminal state, the // "RevokedAt" time should be set. - s1, err = db.GetSession(ctx, s1.LocalPublicKey) + s1, err = db.GetSessionByLocalPub(ctx, s1.LocalPublicKey) require.NoError(t, err) require.Equal(t, StateRevoked, s1.State) require.True(t, clock.Now().Equal(s1.RevokedAt)) @@ -361,7 +359,7 @@ func TestLinkedAccount(t *testing.T) { }) // Make sure that a fetched session includes the account ID. - s1, err = db.GetSessionByID(ctx, s1.ID) + s1, err = db.GetSession(ctx, s1.ID) require.NoError(t, err) require.True(t, s1.AccountID.IsSome()) s1.AccountID.WhenSome(func(id accounts.AccountID) { @@ -453,7 +451,7 @@ func createSession(t *testing.T, db Store, label string, err = db.ShiftState(context.Background(), s.ID, StateCreated) require.NoError(t, err) - s, err = db.GetSessionByID(context.Background(), s.ID) + s, err = db.GetSession(context.Background(), s.ID) require.NoError(t, err) return s diff --git a/session_rpcserver.go b/session_rpcserver.go index a007a9560..092ca2e7e 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -350,7 +350,7 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, // Re-fetch the session to get the latest state of it before marshaling // it. - sess, err = s.cfg.db.GetSessionByID(ctx, sess.ID) + sess, err = s.cfg.db.GetSession(ctx, sess.ID) if err != nil { return nil, fmt.Errorf("error fetching session: %v", err) } @@ -577,7 +577,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context, return nil, fmt.Errorf("error parsing public key: %v", err) } - sess, err := s.cfg.db.GetSession(ctx, pubKey) + sess, err := s.cfg.db.GetSessionByLocalPub(ctx, pubKey) if err != nil { return nil, fmt.Errorf("error fetching session: %v", err) } @@ -882,7 +882,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, copy(groupID[:], req.LinkedGroupId) // Check that the group actually does exist. - groupSess, err := s.cfg.db.GetSessionByID(ctx, groupID) + groupSess, err := s.cfg.db.GetSession(ctx, groupID) if err != nil { return nil, err } @@ -1245,9 +1245,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, "autopilot server: %v", err) } - err = s.cfg.db.UpdateSessionRemotePubKey( - ctx, sess.LocalPublicKey, remoteKey, - ) + err = s.cfg.db.UpdateSessionRemotePubKey(ctx, sess.ID, remoteKey) if err != nil { return nil, fmt.Errorf("error setting remote pubkey: %v", err) } @@ -1269,7 +1267,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, // Re-fetch the session to get the latest state of it before marshaling // it. - sess, err = s.cfg.db.GetSessionByID(ctx, sess.ID) + sess, err = s.cfg.db.GetSession(ctx, sess.ID) if err != nil { return nil, fmt.Errorf("error fetching session: %v", err) } @@ -1319,7 +1317,7 @@ func (s *sessionRpcServer) RevokeAutopilotSession(ctx context.Context, return nil, fmt.Errorf("error parsing public key: %v", err) } - sess, err := s.cfg.db.GetSession(ctx, pubKey) + sess, err := s.cfg.db.GetSessionByLocalPub(ctx, pubKey) if err != nil { return nil, err }