Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions db/sqlc/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions db/sqlc/queries/sessions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ WHERE id = $2;
DELETE FROM sessions
WHERE state = $1;

-- name: DeleteSession :exec
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could maybe in future make things more defensive (both here and above by restricting this to only allowing it for "where status=reserved" - but i guess it's ok for now since your store interface does this protection

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, I see your point enforcing that on the db layer would be safer, one downside of this is that we'd need to hard code an integer here for the type (i.e. 4), so we'd lose some type safety, right? for this reason I kept the current approach, but I'll definitely consider that approach in future work 🙏

DELETE FROM sessions
WHERE id = $1;

-- name: GetSessionByLocalPublicKey :one
SELECT * FROM sessions
WHERE local_public_key = $1;
Expand Down
10 changes: 10 additions & 0 deletions db/sqlc/sessions.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions session/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ type Store interface {
// StateReserved state.
DeleteReservedSessions(ctx context.Context) error

// DeleteReservedSession deletes the session with the given ID if it is
// in the StateReserved state.
DeleteReservedSession(ctx context.Context, id ID) error

// ShiftState updates the state of the session with the given ID to the
// "dest" state.
ShiftState(ctx context.Context, id ID, dest State) error
Expand Down
154 changes: 104 additions & 50 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,10 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
return err
}

return sessionBucket.ForEach(func(k, v []byte) error {
// We create a copy of the sessions to delete so that we are
// not iterating and modifying the bucket at the same time.
var sessionsToDelete []*Session
err = sessionBucket.ForEach(func(k, v []byte) error {
// We'll also get buckets here, skip those (identified
// by nil value).
if v == nil {
Expand All @@ -458,69 +461,120 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
return nil
}

err = sessionBucket.Delete(k)
if err != nil {
return err
}
sessionsToDelete = append(sessionsToDelete, session)

idIndexBkt := sessionBucket.Bucket(idIndexKey)
if idIndexBkt == nil {
return ErrDBInitErr
}
return nil
})
if err != nil {
return err
}

// Delete the entire session ID bucket.
err = idIndexBkt.DeleteBucket(session.ID[:])
if err != nil {
for _, session := range sessionsToDelete {
if err := deleteSession(sessionBucket,
session); err != nil {
return err
}
}

groupIdIndexBkt := sessionBucket.Bucket(groupIDIndexKey)
if groupIdIndexBkt == nil {
return ErrDBInitErr
}
return nil
})
}

groupBkt := groupIdIndexBkt.Bucket(session.GroupID[:])
if groupBkt == nil {
return ErrDBInitErr
}
// deleteSession deletes all the parts of a session from the database. This
// assumes that the session has already been fetched from the db.
func deleteSession(sessionBucket *bbolt.Bucket, session *Session) error {
sessionKey := getSessionKey(session)
err := sessionBucket.Delete(sessionKey)
if err != nil {
return err
}

sessionIDsBkt := groupBkt.Bucket(sessionIDKey)
if sessionIDsBkt == nil {
return ErrDBInitErr
}
idIndexBkt := sessionBucket.Bucket(idIndexKey)
if idIndexBkt == nil {
return ErrDBInitErr
}

var (
seqKey []byte
numSessions int
)
err = sessionIDsBkt.ForEach(func(k, v []byte) error {
numSessions++
// Delete the entire session ID bucket.
err = idIndexBkt.DeleteBucket(session.ID[:])
if err != nil {
return err
}

if !bytes.Equal(v, session.ID[:]) {
return nil
}
groupIdIndexBkt := sessionBucket.Bucket(groupIDIndexKey)
if groupIdIndexBkt == nil {
return ErrDBInitErr
}

seqKey = k
groupBkt := groupIdIndexBkt.Bucket(session.GroupID[:])
if groupBkt == nil {
return ErrDBInitErr
}

return nil
})
if err != nil {
return err
}
sessionIDsBkt := groupBkt.Bucket(sessionIDKey)
if sessionIDsBkt == nil {
return ErrDBInitErr
}

if numSessions == 0 {
return fmt.Errorf("no sessions found for "+
"group ID %x", session.GroupID)
}
var (
seqKey []byte
numSessions int
)
err = sessionIDsBkt.ForEach(func(k, v []byte) error {
numSessions++

if numSessions == 1 {
// Delete the whole group bucket.
return groupBkt.DeleteBucket(sessionIDKey)
}
if !bytes.Equal(v, session.ID[:]) {
return nil
}

// Else, delete just the session ID entry.
return sessionIDsBkt.Delete(seqKey)
})
seqKey = k

return nil
})
if err != nil {
return err
}

if numSessions == 0 {
return fmt.Errorf("no sessions found for "+
"group ID %x", session.GroupID)
}

if numSessions == 1 {
// If this is the last session in the group, we can delete the
// whole group bucket.
return groupIdIndexBkt.DeleteBucket(session.GroupID[:])
}

// Else, delete just the session ID entry from the group.
return sessionIDsBkt.Delete(seqKey)
}

// DeleteReservedSession removes a given session that is in the reserved state
// from the database.
//
// NOTE: This is part of the Store interface.
func (db *BoltStore) DeleteReservedSession(_ context.Context, id ID) error {
return db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
return err
}

// We'll first get the session to make sure it's actually in the
// reserved state before deleting. This gives us a slightly
// better error message than just trying to delete and getting a
// "not found" if the session was in another state.
session, err := getSessionByID(sessionBucket, id)
if err != nil {
return err
}

if session.State != StateReserved {
return fmt.Errorf("session not in reserved state, is "+
"%v", session.State)
}

return deleteSession(sessionBucket, session)
Comment on lines +572 to +577
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as mentioned above, could just make the db query enforce this constraint

})
}

Expand Down
25 changes: 25 additions & 0 deletions session/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type SQLQueries interface {
SetSessionGroupID(ctx context.Context, arg sqlc.SetSessionGroupIDParams) error
UpdateSessionState(ctx context.Context, arg sqlc.UpdateSessionStateParams) error
DeleteSessionsWithState(ctx context.Context, state int16) error
DeleteSession(ctx context.Context, id int64) error
GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error)
GetAccount(ctx context.Context, id int64) (sqlc.Account, error)
}
Expand Down Expand Up @@ -431,6 +432,30 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error {
})
}

// DeleteReservedSession removes a given session that is in the reserved state
// from the database.
//
// NOTE: This is part of the Store interface.
func (s *SQLStore) DeleteReservedSession(ctx context.Context, id ID) error {
var writeTxOpts db.QueriesTxOptions
return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error {
session, err := db.GetSessionByAlias(ctx, id[:])
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("%w: unable to get session: %w",
ErrSessionNotFound, err)
} else if err != nil {
return fmt.Errorf("unable to get session: %w", err)
}

if State(session.State) != StateReserved {
return fmt.Errorf("session not in reserved state, is "+
"%v", State(session.State))
}

return db.DeleteSession(ctx, session.ID)
})
}

// GetSessionByLocalPub fetches the session with the given local pub key.
//
// NOTE: This is part of the Store interface.
Expand Down
26 changes: 26 additions & 0 deletions session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ func TestBasicSessionStore(t *testing.T) {
// of the sessions are reserved.
require.NoError(t, db.DeleteReservedSessions(ctx))

// Explicitly trying to delete session 1 should fail as it's not
// reserved.
require.Error(t, db.DeleteReservedSession(ctx, s1.ID))

sessions, err = db.ListSessionsByState(ctx, StateReserved)
require.NoError(t, err)
require.Empty(t, sessions)
Expand Down Expand Up @@ -192,6 +196,28 @@ func TestBasicSessionStore(t *testing.T) {
_, err = db.GetGroupID(ctx, s4.ID)
require.ErrorIs(t, err, ErrSessionNotFound)

// Reserve a new session and link it to session 1.
s5, err := reserveSession(
db, "session 5", withLinkedGroupID(&session1.GroupID),
)
require.NoError(t, err)
sessions, err = db.ListSessionsByState(ctx, StateReserved)
require.NoError(t, err)
require.Equal(t, 1, len(sessions))
assertEqualSessions(t, s5, sessions[0])

// Now delete the reserved session by its ID and show that it is no
// longer in the database and no longer in the group ID/session ID
// index.
require.NoError(t, db.DeleteReservedSession(ctx, s5.ID))

sessions, err = db.ListSessionsByState(ctx, StateReserved)
require.NoError(t, err)
require.Empty(t, sessions)

_, err = db.GetGroupID(ctx, s5.ID)
require.ErrorIs(t, err, ErrSessionNotFound)

// Only session 1 should remain in this group.
sessIDs, err = db.GetSessionIDs(ctx, s4.GroupID)
require.NoError(t, err)
Expand Down
17 changes: 17 additions & 0 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,23 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
privacyFlags.Serialize(),
)
if err != nil {
// If we tried to link to a previous session, we delete the
// newly created session in the case of errors to avoid having
// non-revoked sessions lying around.
if len(req.LinkedGroupId) != 0 {
log.Infof("Session registration with autopilot " +
"server failed, deleting the newly created " +
"session")

deleteErr := s.cfg.db.DeleteReservedSession(
ctx, sess.ID,
)
if deleteErr != nil {
log.Errorf("Error deleting session after "+
"failed linking attempt: %v", deleteErr)
}
}

return nil, fmt.Errorf("error registering session with "+
"autopilot server: %v", err)
}
Expand Down
Loading