Skip to content

Commit

Permalink
channeldb: consolidate root bucket TLVs into new struct
Browse files Browse the repository at this point in the history
In this commit, we consolidate the root bucket TLVs into a new struct.
This makes it easier to see all the new TLV fields at a glance. We also
convert TLV usage to use the new type param based APis.
  • Loading branch information
Roasbeef committed Mar 12, 2024
1 parent 007f968 commit 784d236
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 83 deletions.
179 changes: 98 additions & 81 deletions channeldb/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,27 +221,60 @@ const (
// A tlv type definition used to serialize an outpoint's indexStatus
// for use in the outpoint index.
indexStatusType tlv.Type = 0
)

// A tlv type definition used to serialize and deserialize a KeyLocator
// from the database.
keyLocType tlv.Type = 1
// chanAuxData houses the auxiliary data that is stored for each channel in a
// TLV stream within the root bucket. This is stored as a TLV stream appended
// to the existing hard-coded fields in the channel's root bucket.
type chanAuxData struct {
revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord]

// A tlv type used to serialize and deserialize the
// `InitialLocalBalance` field.
initialLocalBalanceType tlv.Type = 2
initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64]

// A tlv type used to serialize and deserialize the
// `InitialRemoteBalance` field.
initialRemoteBalanceType tlv.Type = 3
initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64]

// A tlv type definition used to serialize and deserialize the
// confirmed ShortChannelID for a zero-conf channel.
realScidType tlv.Type = 4
realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID]

// A tlv type definition used to serialize and deserialize the
// Memo for the channel channel.
channelMemoType tlv.Type = 5
)
memo tlv.OptionalRecordT[tlv.TlvType5, []byte]
}

// toOpeChan converts the chanAuxData to an OpenChannel by setting the relevant
// fields in the OpenChannel struct.
func (c *chanAuxData) toOpenChan(o *OpenChannel) {
o.RevocationKeyLocator = c.revokeKeyLoc.Val.KeyLocator
o.InitialLocalBalance = lnwire.MilliSatoshi(c.initialLocalBalance.Val)
o.InitialRemoteBalance = lnwire.MilliSatoshi(c.initialRemoteBalance.Val)
o.confirmedScid = c.realScid.Val
c.memo.WhenSomeV(func(memo []byte) {
o.Memo = memo
})
}

// newChanAuxDataFromChan creates a new chanAuxData from the given channel.
func newChanAuxDataFromChan(openChan *OpenChannel) *chanAuxData {
c := &chanAuxData{
revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1, keyLocRecord](
keyLocRecord{openChan.RevocationKeyLocator},
),
initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2, uint64](
uint64(openChan.InitialLocalBalance),
),
initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3, uint64](
uint64(openChan.InitialRemoteBalance),
),
realScid: tlv.NewRecordT[tlv.TlvType4, lnwire.ShortChannelID](
openChan.confirmedScid,
),
}

if len(openChan.Memo) == 0 {
c.memo = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5](openChan.Memo),
)
}

return c
}

// indexStatus is an enum-like type that describes what state the
// outpoint is in. Currently only two possible values.
Expand Down Expand Up @@ -852,6 +885,10 @@ type OpenChannel struct {
// channel that will be useful to our future selves.
Memo []byte

// TapscriptRoot is an optional tapscript root used to derive the
// musig2 funding output.
TapscriptRoot fn.Option[chainhash.Hash]

// TODO(roasbeef): eww
Db *ChannelStateDB

Expand Down Expand Up @@ -3932,26 +3969,20 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
return err
}

// Convert balance fields into uint64.
localBalance := uint64(channel.InitialLocalBalance)
remoteBalance := uint64(channel.InitialRemoteBalance)
auxData := newChanAuxDataFromChan(channel)

tlvRecords := []tlv.Record{
auxData.revokeKeyLoc.Record(),
auxData.initialLocalBalance.Record(),
auxData.initialRemoteBalance.Record(),
auxData.realScid.Record(),
}
auxData.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) {
tlvRecords = append(tlvRecords, memo.Record())
})

// Create the tlv stream.
tlvStream, err := tlv.NewStream(
// Write the RevocationKeyLocator as the first entry in a tlv
// stream.
MakeKeyLocRecord(
keyLocType, &channel.RevocationKeyLocator,
),
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
MakeScidRecord(realScidType, &channel.confirmedScid),
tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo),
)
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}
Expand Down Expand Up @@ -4146,28 +4177,16 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error {
}
}

// Create balance fields in uint64, and Memo field as byte slice.
var (
localBalance uint64
remoteBalance uint64
memo []byte
)
var auxData chanAuxData
zeroMemo := auxData.memo.Zero()

// Create the tlv stream.
tlvStream, err := tlv.NewStream(
// Write the RevocationKeyLocator as the first entry in a tlv
// stream.
MakeKeyLocRecord(
keyLocType, &channel.RevocationKeyLocator,
),
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
MakeScidRecord(realScidType, &channel.confirmedScid),
tlv.MakePrimitiveRecord(channelMemoType, &memo),
auxData.revokeKeyLoc.Record(),
auxData.initialLocalBalance.Record(),
auxData.initialRemoteBalance.Record(),
auxData.realScid.Record(),
zeroMemo.Record(),
)
if err != nil {
return err
Expand All @@ -4177,14 +4196,9 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error {
return err
}

// Attach the balance fields.
channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance)
channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance)

// Attach the memo field if non-empty.
if len(memo) > 0 {
channel.Memo = memo
}
// Assign all the relevant fields from the aux data into the actual
// open channel.
auxData.toOpenChan(channel)

channel.Packager = NewChannelPackager(channel.ShortChannelID)

Expand Down Expand Up @@ -4342,8 +4356,27 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error {
return chanBucket.Delete(frozenChanKey)
}

// EKeyLocator is an encoder for keychain.KeyLocator.
func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the
// tlv.RecordProducer interface.
type keyLocRecord struct {
keychain.KeyLocator
}

// Record creates a Record out of a KeyLocator using the passed Type and the
// EKeyLocator and DKeyLocator functions. The size will always be 8 as
// KeyFamily is uint32 and the Index is uint32.
//
// NOTE: This is part of the tlv.RecordProducer interface.
func (k *keyLocRecord) Record() tlv.Record {
// Note that we set the type here as zero, as when used with a
// tlv.RecordT, the type param will be used as the type.
return tlv.MakeStaticRecord(
0, &k.KeyLocator, 8, eKeyLocator, dKeyLocator,
)
}

// eKeyLocator is an encoder for keychain.KeyLocator.
func eKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*keychain.KeyLocator); ok {
err := tlv.EUint32T(w, uint32(v.Family), buf)
if err != nil {
Expand All @@ -4355,8 +4388,8 @@ func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator")
}

// DKeyLocator is a decoder for keychain.KeyLocator.
func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
// dKeyLocator is a decoder for keychain.KeyLocator.
func dKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*keychain.KeyLocator); ok {
var family uint32
err := tlv.DUint32(r, &family, buf, 4)
Expand All @@ -4370,22 +4403,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8)
}

// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed
// Type and the EKeyLocator and DKeyLocator functions. The size will always be
// 8 as KeyFamily is uint32 and the Index is uint32.
func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record {
return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator)
}

// MakeScidRecord creates a Record out of a ShortChannelID using the passed
// Type and the EShortChannelID and DShortChannelID functions. The size will
// always be 8 for the ShortChannelID.
func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record {
return tlv.MakeStaticRecord(
typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID,
)
}

// ShutdownInfo contains various info about the shutdown initiation of a
// channel.
type ShutdownInfo struct {
Expand Down
4 changes: 2 additions & 2 deletions channeldb/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1523,14 +1523,14 @@ func TestKeyLocatorEncoding(t *testing.T) {
buf [8]byte
)

err := EKeyLocator(&b, &keyLoc, &buf)
err := eKeyLocator(&b, &keyLoc, &buf)
require.NoError(t, err, "unable to encode key locator")

// Next, we'll attempt to decode the bytes into a new KeyLocator.
r := bytes.NewReader(b.Bytes())
var decodedKeyLoc keychain.KeyLocator

err = DKeyLocator(r, &decodedKeyLoc, &buf, 8)
err = dKeyLocator(r, &decodedKeyLoc, &buf, 8)
require.NoError(t, err, "unable to decode key locator")

// Finally, we'll compare that the original KeyLocator and the decoded
Expand Down

0 comments on commit 784d236

Please sign in to comment.