Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/3]: Preparatory work for Forwarding Blinded Routes #8159

Merged
merged 14 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
67 changes: 66 additions & 1 deletion channeldb/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ const (
// begins to be interpreted as an absolute block height, rather than a
// relative one.
AbsoluteThawHeightThreshold uint32 = 500000

// HTLCBlindingPointTLV is the tlv type used for storing blinding
// points with HTLCs.
HTLCBlindingPointTLV tlv.Type = 0
carlaKC marked this conversation as resolved.
Show resolved Hide resolved
)

var (
Expand Down Expand Up @@ -2316,7 +2320,56 @@ type HTLC struct {
// Note that this extra data is stored inline with the OnionBlob for
// legacy reasons, see serialization/deserialization functions for
// detail.
ExtraData []byte
ExtraData lnwire.ExtraOpaqueData

// BlindingPoint is an optional blinding point included with the HTLC.
//
// Note: this field is not a part of on-disk representation of the
// HTLC. It is stored in the ExtraData field, which is used to store
// a TLV stream of additional information associated with the HTLC.
carlaKC marked this conversation as resolved.
Show resolved Hide resolved
BlindingPoint lnwire.BlindingPointRecord
}

// serializeExtraData encodes a TLV stream of extra data to be stored with a
// HTLC. It uses the update_add_htlc TLV types, because this is where extra
// data is passed with a HTLC. At present blinding points are the only extra
// data that we will store, and the function is a no-op if a nil blinding
// point is provided.
//
// This function MUST be called to persist all HTLC values when they are
// serialized.
func (h *HTLC) serializeExtraData() error {
var records []tlv.RecordProducer
h.BlindingPoint.WhenSome(func(b tlv.RecordT[lnwire.BlindingPointTlvType,
*btcec.PublicKey]) {

records = append(records, &b)
})

return h.ExtraData.PackRecords(records...)
}

// deserializeExtraData extracts TLVs from the extra data persisted for the
// htlc and populates values in the struct accordingly.
//
// This function MUST be called to populate the struct properly when HTLCs
// are deserialized.
func (h *HTLC) deserializeExtraData() error {
if len(h.ExtraData) == 0 {
return nil
}

blindingPoint := h.BlindingPoint.Zero()
tlvMap, err := h.ExtraData.ExtractRecords(&blindingPoint)
if err != nil {
return err
}

if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil {
h.BlindingPoint = tlv.SomeRecordT(blindingPoint)
}

return nil
}

// SerializeHtlcs writes out the passed set of HTLC's into the passed writer
Expand All @@ -2340,6 +2393,12 @@ func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
}

for _, htlc := range htlcs {
// Populate TLV stream for any additional fields contained
// in the TLV.
if err := htlc.serializeExtraData(); err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this information already be stored in ExtraData if you just take the opaque bytes from the update add HTLC message? Since we code, but then keep the blob as is.

return err
}

// The onion blob and hltc data are stored as a single var
// bytes blob.
onionAndExtraData := make(
Expand Down Expand Up @@ -2425,6 +2484,12 @@ func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
onionAndExtraData[lnwire.OnionPacketSize:],
)
}

// Finally, deserialize any TLVs contained in that extra data
// if they are present.
if err := htlcs[i].deserializeExtraData(); err != nil {
return nil, err
}
}

return htlcs, nil
Expand Down
53 changes: 30 additions & 23 deletions channeldb/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/lightningnetwork/lnd/lntest/channels"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -1606,9 +1607,25 @@ func TestHTLCsExtraData(t *testing.T) {
OnionBlob: lnmock.MockOnion(),
}

// Add a blinding point to a htlc.
blindingPointHTLC := HTLC{
Signature: testSig.Serialize(),
Incoming: false,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: lnmock.MockOnion(),
BlindingPoint: tlv.SomeRecordT(
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pubKey,
),
),
}

testCases := []struct {
name string
htlcs []HTLC
name string
htlcs []HTLC
blindingIdx int
}{
{
// Serialize multiple HLTCs with no extra data to
Expand All @@ -1620,30 +1637,12 @@ func TestHTLCsExtraData(t *testing.T) {
},
},
{
// Some HTLCs with extra data, some without.
name: "mixed extra data",
htlcs: []HTLC{
mockHtlc,
{
Signature: testSig.Serialize(),
Incoming: false,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: lnmock.MockOnion(),
ExtraData: []byte{1, 2, 3},
},
blindingPointHTLC,
mockHtlc,
{
Signature: testSig.Serialize(),
Incoming: false,
Amt: 10,
RHash: key,
RefundTimeout: 1,
OnionBlob: lnmock.MockOnion(),
ExtraData: bytes.Repeat(
[]byte{9}, 999,
),
},
},
},
}
Expand All @@ -1661,7 +1660,15 @@ func TestHTLCsExtraData(t *testing.T) {
r := bytes.NewReader(b.Bytes())
htlcs, err := DeserializeHtlcs(r)
require.NoError(t, err)
require.Equal(t, testCase.htlcs, htlcs)

require.EqualValues(t, len(testCase.htlcs), len(htlcs))
for i, htlc := range htlcs {
// We use the extra data field when we
// serialize, so we set to nil to be able to
// assert on equal for the test.
htlc.ExtraData = nil
require.Equal(t, testCase.htlcs[i], htlc)
}
})
}
}
Expand Down
84 changes: 71 additions & 13 deletions lnwallet/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain"
"github.com/lightningnetwork/lnd/tlv"
)

var (
Expand Down Expand Up @@ -371,6 +372,12 @@ type PaymentDescriptor struct {
// isForwarded denotes if an incoming HTLC has been forwarded to any
// possible upstream peers in the route.
isForwarded bool

// BlindingPoint is an optional ephemeral key used in route blinding.
// This value is set for nodes that are relaying payments inside of a
// blinded route (ie, not the introduction node) from update_add_htlc's
// TLVs.
BlindingPoint *btcec.PublicKey
}

// PayDescsFromRemoteLogUpdates converts a slice of LogUpdates received from the
Expand Down Expand Up @@ -411,6 +418,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64,
Height: height,
Index: uint16(i),
},
BlindingPoint: wireMsg.BlingingPointOrNil(),
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob[:], wireMsg.OnionBlob[:])
Expand Down Expand Up @@ -736,6 +744,14 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment {
Incoming: false,
}
copy(h.OnionBlob[:], htlc.OnionBlob)
if htlc.BlindingPoint != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re above (and also related to future changes like storing the endorsement bit in an opaque manner), if we copied the extra records from the pay desc into this HTLC, then it's a more generic way to handle storing any future TLV data.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we copied the extra records from the pay desc into this HTLC

def, but iirc from the original PR we wanted to be more intentional about what we store (ie, only things we care about) rather than including the full ExtraBytes and wasting space if people send us random junk along with the HTLC.

h.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
htlc.BlindingPoint,
),
)
}

if ourCommit && htlc.sig != nil {
h.Signature = htlc.sig.Serialize()
Expand All @@ -760,7 +776,14 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment {
Incoming: true,
}
copy(h.OnionBlob[:], htlc.OnionBlob)

if htlc.BlindingPoint != nil {
h.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
htlc.BlindingPoint,
),
)
}
if ourCommit && htlc.sig != nil {
h.Signature = htlc.sig.Serialize()
}
Expand Down Expand Up @@ -859,6 +882,12 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight,
theirWitnessScript: theirWitnessScript,
}

htlc.BlindingPoint.WhenSome(func(b tlv.RecordT[
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WhenSomeV is useful here again as it gives you just *btcec.PublicKey in this case.

lnwire.BlindingPointTlvType, *btcec.PublicKey]) {

pd.BlindingPoint = b.Val
})

return pd, nil
}

Expand Down Expand Up @@ -1548,6 +1577,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate,
HtlcIndex: wireMsg.ID,
LogIndex: logUpdate.LogIndex,
addCommitHeightRemote: commitHeight,
BlindingPoint: wireMsg.BlingingPointOrNil(),
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob[:], wireMsg.OnionBlob[:])
Expand Down Expand Up @@ -1745,6 +1775,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd
HtlcIndex: wireMsg.ID,
LogIndex: logUpdate.LogIndex,
addCommitHeightLocal: commitHeight,
BlindingPoint: wireMsg.BlingingPointOrNil(),
}
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob, wireMsg.OnionBlob[:])
Expand Down Expand Up @@ -3607,6 +3638,14 @@ func (lc *LightningChannel) createCommitDiff(
PaymentHash: pd.RHash,
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
if pd.BlindingPoint != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar point here re just copying over the raw bytes so you don't need to be concerned about the record mapping at this state.

htlc.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pd.BlindingPoint,
),
)
}
logUpdate.UpdateMsg = htlc

// Gather any references for circuits opened by this Add
Expand Down Expand Up @@ -3736,12 +3775,21 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate {
// four messages that it corresponds to.
switch pd.EntryType {
case Add:
var b lnwire.BlindingPointRecord
if pd.BlindingPoint != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the pay desc just stores the record directly, then we also don't need to handle this shuffling.

tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](pd.BlindingPoint),
)
}

htlc := &lnwire.UpdateAddHTLC{
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
ChanID: chanID,
ID: pd.HtlcIndex,
Amount: pd.Amount,
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
BlindingPoint: b,
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc
Expand Down Expand Up @@ -5742,6 +5790,14 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
Expiry: pd.Timeout,
PaymentHash: pd.RHash,
}
if pd.BlindingPoint != nil {
htlc.BlindingPoint = tlv.SomeRecordT(
//nolint:lll
tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
pd.BlindingPoint,
),
)
}
copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc
addUpdates = append(addUpdates, logUpdate)
Expand Down Expand Up @@ -6079,6 +6135,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC,
HtlcIndex: lc.localUpdateLog.htlcCounter,
OnionBlob: htlc.OnionBlob[:],
OpenCircuitKey: openKey,
BlindingPoint: htlc.BlingingPointOrNil(),
}
}

Expand Down Expand Up @@ -6129,13 +6186,14 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err
}

pd := &PaymentDescriptor{
EntryType: Add,
RHash: PaymentHash(htlc.PaymentHash),
Timeout: htlc.Expiry,
Amount: htlc.Amount,
LogIndex: lc.remoteUpdateLog.logIndex,
HtlcIndex: lc.remoteUpdateLog.htlcCounter,
OnionBlob: htlc.OnionBlob[:],
EntryType: Add,
RHash: PaymentHash(htlc.PaymentHash),
Timeout: htlc.Expiry,
Amount: htlc.Amount,
LogIndex: lc.remoteUpdateLog.logIndex,
HtlcIndex: lc.remoteUpdateLog.htlcCounter,
OnionBlob: htlc.OnionBlob[:],
BlindingPoint: htlc.BlingingPointOrNil(),
}

localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex
Expand Down