@@ -20,8 +20,8 @@ import (
2020 "bytes"
2121 "context"
2222 "encoding/hex"
23+ "github.com/google/btree"
2324 "math"
24- "sort"
2525 "strconv"
2626 "sync"
2727 "sync/atomic"
@@ -245,6 +245,13 @@ func (o *oracle) doneCommit(cts uint64) {
245245 o .txnMark .Done (cts )
246246}
247247
248+ type EntryItem Entry
249+
250+ func (e * EntryItem ) Less (than btree.Item ) bool {
251+ other := than .(* EntryItem )
252+ return bytes .Compare (e .Key , other .Key ) < 0
253+ }
254+
248255// Txn represents a Badger transaction.
249256type Txn struct {
250257 readTs uint64
@@ -258,8 +265,8 @@ type Txn struct {
258265 conflictKeys map [uint64 ]struct {}
259266 readsLock sync.Mutex // guards the reads slice. See addReadKey.
260267
261- pendingWrites map [ string ] * Entry // cache stores any writes done by txn.
262- duplicateWrites []* Entry // Used in managed mode to store duplicate entries.
268+ pendingWrites * btree. BTree // cache stores any writes done by txn.
269+ duplicateWrites []* Entry // Used in managed mode to store duplicate entries.
263270
264271 numIterators int32
265272 discarded bool
@@ -268,40 +275,74 @@ type Txn struct {
268275}
269276
270277type pendingWritesIterator struct {
271- entries []* Entry
272- nextIdx int
278+ entries * btree.BTree
279+ itClose chan <- interface {}
280+ it <- chan * Entry
281+ current * Entry
273282 readTs uint64
274283 reversed bool
275284}
276285
277286func (pi * pendingWritesIterator ) Next () {
278- pi .nextIdx ++
287+ pi .current = <- pi .it
288+ }
289+
290+ func (pi * pendingWritesIterator ) reset (itConsumer func (it btree.ItemIterator )) {
291+ if pi .it != nil {
292+ close (pi .itClose )
293+ }
294+ it := make (chan * Entry , 0 )
295+ itClose := make (chan interface {}, 0 )
296+ go func () {
297+ itFunc := func (i btree.Item ) bool {
298+ select {
299+ case it <- (* Entry )(i .(* EntryItem )):
300+ return true
301+ case <- itClose :
302+ return false
303+ }
304+ }
305+ itConsumer (itFunc )
306+ close (it )
307+ }()
308+
309+ pi .it = it
310+ pi .itClose = itClose
311+
312+ pi .Next ()
279313}
280314
281315func (pi * pendingWritesIterator ) Rewind () {
282- pi .nextIdx = 0
316+ pi .reset (func (itFunc btree.ItemIterator ) {
317+ if pi .reversed {
318+ pi .entries .Descend (itFunc )
319+ } else {
320+ pi .entries .Ascend (itFunc )
321+ }
322+ })
283323}
284324
285325func (pi * pendingWritesIterator ) Seek (key []byte ) {
286326 key = y .ParseKey (key )
287- pi .nextIdx = sort .Search (len (pi .entries ), func (idx int ) bool {
288- cmp := bytes .Compare (pi .entries [idx ].Key , key )
289- if ! pi .reversed {
290- return cmp >= 0
327+ pivot := & EntryItem {Key : key }
328+
329+ pi .reset (func (itFunc btree.ItemIterator ) {
330+ if pi .reversed {
331+ pi .entries .DescendLessOrEqual (pivot , itFunc )
332+ } else {
333+ pi .entries .AscendGreaterOrEqual (pivot , itFunc )
291334 }
292- return cmp <= 0
293335 })
294336}
295337
296338func (pi * pendingWritesIterator ) Key () []byte {
297339 y .AssertTrue (pi .Valid ())
298- entry := pi .entries [pi .nextIdx ]
299- return y .KeyWithTs (entry .Key , pi .readTs )
340+ return y .KeyWithTs (pi .current .Key , pi .readTs )
300341}
301342
302343func (pi * pendingWritesIterator ) Value () y.ValueStruct {
303344 y .AssertTrue (pi .Valid ())
304- entry := pi .entries [ pi . nextIdx ]
345+ entry := pi .current
305346 return y.ValueStruct {
306347 Value : entry .Value ,
307348 Meta : entry .meta ,
@@ -312,32 +353,36 @@ func (pi *pendingWritesIterator) Value() y.ValueStruct {
312353}
313354
314355func (pi * pendingWritesIterator ) Valid () bool {
315- return pi .nextIdx < len ( pi . entries )
356+ return pi .current != nil
316357}
317358
318359func (pi * pendingWritesIterator ) Close () error {
319360 return nil
320361}
321362
322363func (txn * Txn ) newPendingWritesIterator (reversed bool ) * pendingWritesIterator {
323- if ! txn .update || len ( txn .pendingWrites ) == 0 {
364+ if ! txn .update || txn .pendingWrites . Len ( ) == 0 {
324365 return nil
325366 }
326- entries := make ([]* Entry , 0 , len (txn .pendingWrites ))
327- for _ , e := range txn .pendingWrites {
328- entries = append (entries , e )
329- }
330- // Number of pending writes per transaction shouldn't be too big in general.
331- sort .Slice (entries , func (i , j int ) bool {
332- cmp := bytes .Compare (entries [i ].Key , entries [j ].Key )
333- if ! reversed {
334- return cmp < 0
335- }
336- return cmp > 0
337- })
367+ //entries := make([]*Entry, 0, txn.pendingWrites.Len())
368+ //
369+ //txn.pendingWrites.Ascend(func(i btree.Item) bool {
370+ // entries = append(entries, (*Entry)(i.(*EntryItem)))
371+ // return true
372+ //})
373+ //
374+ //// Number of pending writes per transaction shouldn't be too big in general.
375+ //sort.Slice(entries, func(i, j int) bool {
376+ // cmp := bytes.Compare(entries[i].Key, entries[j].Key)
377+ // if !reversed {
378+ // return cmp < 0
379+ // }
380+ // return cmp > 0
381+ //})
382+
338383 return & pendingWritesIterator {
339384 readTs : txn .readTs ,
340- entries : entries ,
385+ entries : txn . pendingWrites . Clone () ,
341386 reversed : reversed ,
342387 }
343388}
@@ -381,6 +426,14 @@ func ValidEntry(db *DB, key, val []byte) error {
381426 return nil
382427}
383428
429+ func (txn * Txn ) getFromPendingWrites (key []byte ) (* Entry , bool ) {
430+ result := txn .pendingWrites .Get (& EntryItem {Key : key })
431+ if result == nil {
432+ return nil , false
433+ }
434+ return (* Entry )(result .(* EntryItem )), true
435+ }
436+
384437func (txn * Txn ) modify (e * Entry ) error {
385438 switch {
386439 case ! txn .update :
@@ -418,10 +471,10 @@ func (txn *Txn) modify(e *Entry) error {
418471 // If a duplicate entry was inserted in managed mode, move it to the duplicate writes slice.
419472 // Add the entry to duplicateWrites only if both the entries have different versions. For
420473 // same versions, we will overwrite the existing entry.
421- if oldEntry , ok := txn .pendingWrites [ string (e .Key )] ; ok && oldEntry .version != e .version {
474+ if oldEntry , ok := txn .getFromPendingWrites (e .Key ); ok && oldEntry .version != e .version {
422475 txn .duplicateWrites = append (txn .duplicateWrites , oldEntry )
423476 }
424- txn .pendingWrites [ string ( e . Key )] = e
477+ txn .pendingWrites . ReplaceOrInsert (( * EntryItem )( e ))
425478 return nil
426479}
427480
@@ -474,7 +527,7 @@ func (txn *Txn) Get(key []byte) (item *Item, rerr error) {
474527
475528 item = new (Item )
476529 if txn .update {
477- if e , has := txn .pendingWrites [ string (key )] ; has && bytes .Equal (key , e .Key ) {
530+ if e , has := txn .getFromPendingWrites (key ); has && bytes .Equal (key , e .Key ) {
478531 if isDeletedOrExpired (e .meta , e .ExpiresAt ) {
479532 return nil , ErrKeyNotFound
480533 }
@@ -570,16 +623,18 @@ func (txn *Txn) commitAndSend() (func() error, error) {
570623 keepTogether = false
571624 }
572625 }
573- for _ , e := range txn .pendingWrites {
574- setVersion (e )
575- }
626+ txn .pendingWrites .Ascend (func (i btree.Item ) bool {
627+ setVersion ((* Entry )(i .(* EntryItem )))
628+ return true
629+ })
630+
576631 // The duplicateWrites slice will be non-empty only if there are duplicate
577632 // entries with different versions.
578633 for _ , e := range txn .duplicateWrites {
579634 setVersion (e )
580635 }
581636
582- entries := make ([]* Entry , 0 , len ( txn .pendingWrites )+ len (txn .duplicateWrites )+ 1 )
637+ entries := make ([]* Entry , 0 , txn .pendingWrites . Len ( )+ len (txn .duplicateWrites )+ 1 )
583638
584639 processEntry := func (e * Entry ) {
585640 // Suffix the keys with commit ts, so the key versions are sorted in
@@ -602,9 +657,10 @@ func (txn *Txn) commitAndSend() (func() error, error) {
602657 // var b strings.Builder
603658 // fmt.Fprintf(&b, "Read: %d. Commit: %d. reads: %v. writes: %v. Keys: ",
604659 // txn.readTs, commitTs, txn.reads, txn.conflictKeys)
605- for _ , e := range txn .pendingWrites {
606- processEntry (e )
607- }
660+ txn .pendingWrites .Ascend (func (i btree.Item ) bool {
661+ processEntry ((* Entry )(i .(* EntryItem )))
662+ return true
663+ })
608664 for _ , e := range txn .duplicateWrites {
609665 processEntry (e )
610666 }
@@ -641,11 +697,14 @@ func (txn *Txn) commitPrecheck() error {
641697 return errors .New ("Trying to commit a discarded txn" )
642698 }
643699 keepTogether := true
644- for _ , e := range txn .pendingWrites {
700+ txn .pendingWrites .Ascend (func (i btree.Item ) bool {
701+ e := (* Entry )(i .(* EntryItem ))
645702 if e .version != 0 {
646703 keepTogether = false
704+ return false
647705 }
648- }
706+ return true
707+ })
649708
650709 // If keepTogether is True, it implies transaction markers will be added.
650709 // In that case, commitTs should never be zero. This might happen if
@@ -679,7 +738,7 @@ func (txn *Txn) commitPrecheck() error {
679738func (txn * Txn ) Commit () error {
680739 // txn.conflictKeys can be zero if conflict detection is turned off. So we
681740 // should check txn.pendingWrites.
682- if len ( txn .pendingWrites ) == 0 {
741+ if txn .pendingWrites . Len ( ) == 0 {
683742 return nil // Nothing to do.
684743 }
685744 // Precheck before discarding txn.
@@ -730,7 +789,7 @@ func (txn *Txn) CommitWith(cb func(error)) {
730789 panic ("Nil callback provided to CommitWith" )
731790 }
732791
733- if len ( txn .pendingWrites ) == 0 {
792+ if txn .pendingWrites . Len ( ) == 0 {
734793 // Do not run these callbacks from here, because the CommitWith and the
735794 // callback might be acquiring the same locks. Instead run the callback
736795 // from another goroutine.
@@ -800,7 +859,7 @@ func (db *DB) newTransaction(update, isManaged bool) *Txn {
800859 if db .opt .DetectConflicts {
801860 txn .conflictKeys = make (map [uint64 ]struct {})
802861 }
803- txn .pendingWrites = make ( map [ string ] * Entry )
862+ txn .pendingWrites = btree . New ( 5 )
804863 }
805864 if ! isManaged {
806865 txn .readTs = db .orc .readTs ()
0 commit comments