From 5ed5fd25fa8188f48fb4c02960d2052c604b41b6 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 27 Feb 2025 08:05:03 +0200 Subject: [PATCH 1/2] session: use error variables In preparation for having the unit tests pass against a different Store implementation, we standardize some of the errors that get returned. --- itest/litd_firewall_test.go | 4 +++- session/errors.go | 12 ++++++++++++ session/kvdb_store.go | 29 +++++++++++++++++------------ session/store_test.go | 10 +++++++--- 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/itest/litd_firewall_test.go b/itest/litd_firewall_test.go index 98b583132..5c82bd4b3 100644 --- a/itest/litd_firewall_test.go +++ b/itest/litd_firewall_test.go @@ -866,7 +866,9 @@ func testSessionLinking(net *NetworkHarness, t *harnessTest) { LinkedGroupId: sessResp.Session.GroupId, }, ) - require.ErrorContains(t.t, err, "is still active") + require.ErrorContains( + t.t, err, session.ErrSessionsInGroupStillActive.Error(), + ) // Revoke the previous one and repeat. _, err = litAutopilotClient.RevokeAutopilotSession( diff --git a/session/errors.go b/session/errors.go index 560a6c2bc..1cb97a8f2 100644 --- a/session/errors.go +++ b/session/errors.go @@ -6,4 +6,16 @@ var ( // ErrSessionNotFound is an error returned when we attempt to retrieve // information about a session but it is not found. ErrSessionNotFound = errors.New("session not found") + + // ErrUnknownGroup is returned when an attempt is made to insert a + // session and link it to an existing group where the group is not + // known. + ErrUnknownGroup = errors.New("unknown group") + + // ErrSessionsInGroupStillActive is returned when an attempt is made to + // insert a session and link it to a group that still has other active + // sessions. + ErrSessionsInGroupStillActive = errors.New( + "group has active sessions", + ) ) diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 84f4dce06..96c98d29c 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -229,8 +229,9 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, if session.ID != session.GroupID { _, err = getKeyForID(sessionBucket, session.GroupID) if err != nil { - return fmt.Errorf("unknown linked session "+ - "%x: %w", session.GroupID, err) + return fmt.Errorf("%w: unknown linked "+ + "session %x: %w", ErrUnknownGroup, + session.GroupID, err) } // Fetch all the session IDs for this group. This will @@ -242,18 +243,22 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, return err } + // Ensure that the all the linked sessions are no longer + // active. for _, id := range sessionIDs { sess, err := getSessionByID(sessionBucket, id) if err != nil { return err } - // Ensure that the session is no longer active. - if !sess.State.Terminal() { - return fmt.Errorf("session (id=%x) "+ - "in group %x is still active", - sess.ID, sess.GroupID) + if sess.State.Terminal() { + continue } + + return fmt.Errorf("%w: session (id=%x) in "+ + "group %x is still active", + ErrSessionsInGroupStillActive, sess.ID, + sess.GroupID) } } @@ -630,14 +635,14 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) { sessionIDBkt := idIndex.Bucket(sessionID[:]) if sessionIDBkt == nil { - return fmt.Errorf("no index entry for session ID: %x", - sessionID) + return fmt.Errorf("%w: no index entry for session "+ + "ID: %x", ErrUnknownGroup, sessionID) } groupIDBytes := sessionIDBkt.Get(groupIDKey) if len(groupIDBytes) == 0 { - return fmt.Errorf("group ID not found for session "+ - "ID %x", sessionID) + return fmt.Errorf("%w: group ID not found for "+ + "session ID %x", ErrUnknownGroup, sessionID) } copy(groupID[:], groupIDBytes) @@ -806,7 +811,7 @@ func addIDToGroupIDPair(sessionBkt *bbolt.Bucket, id, groupID ID) error { func getSessionByID(bucket *bbolt.Bucket, id ID) (*Session, error) { keyBytes, err := getKeyForID(bucket, id) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %w", ErrSessionNotFound, err) } v := bucket.Get(keyBytes) diff --git a/session/store_test.go b/session/store_test.go index aa7afe20f..54a8c6139 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -22,6 +22,10 @@ func TestBasicSessionStore(t *testing.T) { _ = db.Close() }) + // Try fetch a session that doesn't exist yet. + _, err = db.GetSessionByID(ID{1, 3, 4, 4}) + require.ErrorIs(t, err, ErrSessionNotFound) + // Reserve a session. This should succeed. s1, err := reserveSession(db, "session 1") require.NoError(t, err) @@ -183,7 +187,7 @@ func TestBasicSessionStore(t *testing.T) { require.Empty(t, sessions) _, err = db.GetGroupID(s4.ID) - require.ErrorContains(t, err, "no index entry") + require.ErrorIs(t, err, ErrUnknownGroup) // Only session 1 should remain in this group. sessIDs, err = db.GetSessionIDs(s4.GroupID) @@ -211,7 +215,7 @@ func TestLinkingSessions(t *testing.T) { _, err = reserveSession( db, "session 2", withLinkedGroupID(&groupID), ) - require.ErrorContains(t, err, "unknown linked session") + require.ErrorIs(t, err, ErrUnknownGroup) // Create a new session with no previous link. s1 := createSession(t, db, "session 1") @@ -220,7 +224,7 @@ func TestLinkingSessions(t *testing.T) { // session. This should fail due to the first session still being // active. _, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID)) - require.ErrorContains(t, err, "is still active") + require.ErrorIs(t, err, ErrSessionsInGroupStillActive) // Revoke the first session. require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) From 44625c32e1aafa3f4a6d37765bf4cc9d06090838 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 27 Feb 2025 08:10:35 +0200 Subject: [PATCH 2/2] session: add DB constructor helpers for tests In preparation for adding helpers with the same names but that will compile under different build flags, we add the helper DB constructors to use when testing the session store logic against a KVDB backend. --- session/store_test.go | 28 ++++++---------------------- session/test_kvdb.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 22 deletions(-) create mode 100644 session/test_kvdb.go diff --git a/session/store_test.go b/session/store_test.go index 54a8c6139..966b03962 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -16,14 +16,10 @@ var testTime = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) func TestBasicSessionStore(t *testing.T) { // Set up a new DB. clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t, clock) // Try fetch a session that doesn't exist yet. - _, err = db.GetSessionByID(ID{1, 3, 4, 4}) + _, err := db.GetSessionByID(ID{1, 3, 4, 4}) require.ErrorIs(t, err, ErrSessionNotFound) // Reserve a session. This should succeed. @@ -201,11 +197,7 @@ func TestLinkingSessions(t *testing.T) { // Set up a new DB. clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t, clock) groupID, err := IDFromBytes([]byte{1, 2, 3, 4}) require.NoError(t, err) @@ -242,11 +234,7 @@ func TestLinkedSessions(t *testing.T) { // Set up a new DB. clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t, clock) // Create a few sessions. The first one is a new session and the two // after are all linked to the prior one. All these sessions belong to @@ -298,18 +286,14 @@ func TestLinkedSessions(t *testing.T) { func TestStateShift(t *testing.T) { // Set up a new DB. clock := clock.NewTestClock(testTime) - db, err := NewDB(t.TempDir(), "test.db", clock) - require.NoError(t, err) - t.Cleanup(func() { - _ = db.Close() - }) + db := NewTestDB(t, clock) // Add a new session to the DB. s1 := createSession(t, db, "label 1") // Check that the session is in the StateCreated state. Also check that // the "RevokedAt" time has not yet been set. - s1, err = db.GetSession(s1.LocalPublicKey) + s1, err := db.GetSession(s1.LocalPublicKey) require.NoError(t, err) require.Equal(t, StateCreated, s1.State) require.Equal(t, time.Time{}, s1.RevokedAt) diff --git a/session/test_kvdb.go b/session/test_kvdb.go new file mode 100644 index 000000000..6f270d617 --- /dev/null +++ b/session/test_kvdb.go @@ -0,0 +1,28 @@ +package session + +import ( + "testing" + + "github.com/lightningnetwork/lnd/clock" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T, clock clock.Clock) *BoltStore { + return NewTestDBFromPath(t, t.TempDir(), clock) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) *BoltStore { + + store, err := NewDB(dbPath, DBFilename, clock) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, store.DB.Close()) + }) + + return store +}