-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
client_db.go
145 lines (118 loc) · 3.82 KB
/
client_db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package migration1
import (
"bytes"
"errors"
"fmt"
"github.com/lightningnetwork/lnd/kvdb"
)
// Bucket and key names used by the tower client database.
var (
	// cSessionBkt is the top-level bucket holding every client session,
	// laid out as:
	//   session-id => cSessionBody -> encoded ClientSessionBody
	//              => cSessionCommits => seqnum -> encoded CommittedUpdate
	//              => cSessionAcks => seqnum -> encoded BackupID
	cSessionBkt = []byte("client-session-bucket")

	// cSessionBody is the key, inside each per-session bucket nested in
	// cSessionBkt, under which the encoded ClientSession body is stored.
	cSessionBody = []byte("client-session-body")

	// cTowerIDToSessionIDIndexBkt is the top-level index bucket, laid out
	// as:
	//   tower-id -> session-id -> 1
	cTowerIDToSessionIDIndexBkt = []byte(
		"client-tower-to-session-index-bucket",
	)
)

// Sentinel errors returned while reading the tower client database.
var (
	// ErrUninitializedDB is returned when the top-level buckets the
	// database relies on have not been created.
	ErrUninitializedDB = errors.New("db not initialized")

	// ErrClientSessionNotFound is returned when no bucket exists for the
	// requested client session.
	ErrClientSessionNotFound = errors.New("client session not found")

	// ErrCorruptClientSession is returned when a client session's on-disk
	// layout does not match what is expected (e.g. its body is missing).
	ErrCorruptClientSession = errors.New("client session corrupted")
)
// MigrateTowerToSessionIndex builds the towerID-to-sessionID index for the
// watchtower client DB. It reads every stored client session, then writes one
// sub-bucket per tower ID into the cTowerIDToSessionIDIndexBkt top-level
// bucket, with one key per session belonging to that tower. Any read, bucket
// creation or duplicate-session error aborts the migration and is returned
// unchanged.
func MigrateTowerToSessionIndex(tx kvdb.RwTx) error {
	log.Infof("Migrating the tower client db to add a " +
		"towerID-to-sessionID index")

	// Gather all tower -> sessions pairs before writing anything.
	towerSessions, err := getIndexEntries(tx)
	if err != nil {
		return err
	}

	// Create the top-level bucket that will hold the index.
	index, err := tx.CreateTopLevelBucket(cTowerIDToSessionIDIndexBkt)
	if err != nil {
		return err
	}

	// Populate a sub-bucket keyed by each tower ID with that tower's
	// session IDs.
	for tower, ids := range towerSessions {
		bkt, bktErr := index.CreateBucketIfNotExists(tower.Bytes())
		if bktErr != nil {
			return bktErr
		}

		for id := range ids {
			if err := addIndex(bkt, id); err != nil {
				return err
			}
		}
	}

	return nil
}
// addIndex records sessionID in towerBkt, a per-tower sub-bucket of the
// towerID-to-sessionID index. The raw session ID bytes are used as the key and
// []byte{1} is stored as a placeholder value.
//
// An error is returned if the session ID is already present in the bucket, or
// if the underlying Put fails.
func addIndex(towerBkt kvdb.RwBucket, sessionID SessionID) error {
	key := sessionID[:]

	if towerBkt.Get(key) != nil {
		// Format the raw key bytes rather than the SessionID value: if
		// SessionID implements fmt.Stringer, %x would hex-encode the
		// output of String() instead of the ID bytes themselves.
		return fmt.Errorf("session %x duplicated", key)
	}

	return towerBkt.Put(key, []byte{1})
}
// getIndexEntries walks every key of the top-level cSessionBkt bucket, loads
// the client session stored under it, and groups the session IDs by tower ID.
// The result maps each tower ID to the set of its session IDs (every present
// entry is true).
//
// ErrUninitializedDB is returned if cSessionBkt does not exist. Any error from
// loading an individual session aborts the walk and is returned as-is.
func getIndexEntries(tx kvdb.RwTx) (map[TowerID]map[SessionID]bool, error) {
	sessionsBkt := tx.ReadBucket(cSessionBkt)
	if sessionsBkt == nil {
		return nil, ErrUninitializedDB
	}

	entries := make(map[TowerID]map[SessionID]bool)

	collect := func(sessionKey, _ []byte) error {
		s, err := getClientSession(sessionsBkt, sessionKey)
		if err != nil {
			return err
		}

		// Lazily create the session set for a tower the first time
		// it is seen.
		towerSessions, ok := entries[s.TowerID]
		if !ok {
			towerSessions = make(map[SessionID]bool)
			entries[s.TowerID] = towerSessions
		}
		towerSessions[s.ID] = true

		return nil
	}

	if err := sessionsBkt.ForEach(collect); err != nil {
		return nil, err
	}

	return entries, nil
}
// getClientSession loads the client session whose ID is idBytes from the
// sessions bucket. The session's ID field is filled from idBytes and the rest
// of the struct is decoded from the body stored under cSessionBody.
//
// It returns ErrClientSessionNotFound if no nested bucket exists for idBytes,
// ErrCorruptClientSession if that bucket has no body, or the error from
// decoding the body.
func getClientSession(sessions kvdb.RBucket, idBytes []byte) (*ClientSession,
	error) {

	bkt := sessions.NestedReadBucket(idBytes)
	if bkt == nil {
		return nil, ErrClientSessionNotFound
	}

	body := bkt.Get(cSessionBody)
	if body == nil {
		// A session bucket without its body means the on-disk data
		// is inconsistent.
		return nil, ErrCorruptClientSession
	}

	session := new(ClientSession)
	copy(session.ID[:], idBytes)

	if err := session.Decode(bytes.NewReader(body)); err != nil {
		return nil, err
	}

	return session, nil
}