Skip to content

Commit

Permalink
tlv: add new RecordT[T] utility type
Browse files Browse the repository at this point in the history
In this commit, we add a new type, `Record[T]` to reduce some of the
common boiler plate for TLV types. This time lets you take either a
primitive type, or an existing Record, and gain common methods used to
create tlv streams.

It also serves as extra type annotation as well, since wire structs can
use this to wrap any existing type and gain the relevant record methods.
  • Loading branch information
Roasbeef committed Oct 28, 2023
1 parent 3b7cda9 commit d7adb79
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 3 deletions.
3 changes: 2 additions & 1 deletion tlv/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/davecgh/go-spew v1.1.1
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1
github.com/stretchr/testify v1.8.2
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
)

require (
Expand All @@ -14,7 +15,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.9.0 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/sys v0.13.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Expand Down
6 changes: 4 additions & 2 deletions tlv/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
Expand Down
59 changes: 59 additions & 0 deletions tlv/record_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package tlv

import (
"github.com/btcsuite/btcd/btcec/v2"
"golang.org/x/exp/constraints"
)

// Primitive is a type constraint that capture the set of "primitive" types,
// which are the built in stdlib types, and type defs of those types
type Primitive interface {
constraints.Integer | ~[]byte | ~[32]byte | ~[33]byte | ~bool |
~*btcec.PublicKey | ~[64]byte
}

// RecordT is a high-order type makes it easy to encode known "primitive" types
// as TLV records.
type RecordT[T Record | Primitive] struct {
// Val is the value of the underlying record. Go doesn't let us just
// embed the type param as a struct field, so we need an intermediate
// variable.
Val T

// Type is the type of the TLV record.
Type Type
}

// NewPrimitiveRecord creates a new RecordT type from a given primitive type.
func NewPrimitiveRecord[T Primitive](val T, tlvType Type) RecordT[T] {
return RecordT[T]{
Val: val,
Type: tlvType,
}
}

// NewRecordT creates a new RecordT type from a given Record type. This is
// useful to wrap a given record in this utility type, which also serves as an
// extra type annotation. The underlying type of the record is retained.
func NewRecordT[T Record](record T) RecordT[T] {
// Go doesn't yet allow interfaces for union type constraints. So we'll
// cast to any, then get a concrete record so we can call the method on
// it.
tlvRecord := any(record).(Record)

return RecordT[T]{
Val: record,
Type: tlvRecord.Type(),
}
}

// Record returns the underlying record interface for the record type.
func (t *RecordT[T]) Record() Record {
// Go doesn't allow type assertions on a type param, so to work around
// this, we'll convert to any, then do our type assertion.
if tlvRecord, ok := any(t.Val).(Record); ok {
return tlvRecord
}

return MakePrimitiveRecord(t.Type, &t.Val)
}
83 changes: 83 additions & 0 deletions tlv/record_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package tlv

import (
"bytes"
"testing"

"github.com/stretchr/testify/require"
)

const (
fakeCsvDelayType = 1
fakeIsCoolType = 2
)

type fakeWireMsg struct {
CsvDelay RecordT[uint16]

IsCool RecordT[bool]
}

// TestRecordTFromPrimitive tests the RecordT type. We should be able to create
// types of both record types, and also primitive types, and encode/decode them
// as normal.
func TestRecordTFromPrimitive(t *testing.T) {
t.Parallel()

wireMsg := fakeWireMsg{
CsvDelay: NewPrimitiveRecord(uint16(5), fakeCsvDelayType),
IsCool: NewPrimitiveRecord(true, fakeIsCoolType),
}

encodeStream, err := NewStream(
wireMsg.CsvDelay.Record(), wireMsg.IsCool.Record(),
)
require.NoError(t, err)

var b bytes.Buffer
err = encodeStream.Encode(&b)
require.NoError(t, err)

newWireMsg := fakeWireMsg{
CsvDelay: NewPrimitiveRecord(uint16(5), fakeCsvDelayType),
IsCool: NewPrimitiveRecord(true, fakeIsCoolType),
}

decodeStream, err := NewStream(
newWireMsg.CsvDelay.Record(),
newWireMsg.IsCool.Record(),
)
require.NoError(t, err)

err = decodeStream.Decode(&b)
require.NoError(t, err)

require.Equal(t, wireMsg, newWireMsg)
}

// TestRecordTFromRecord tests that we can create a RecordT type from an
// existing record type and encode/decode as normal.
func TestRecordTFromRecord(t *testing.T) {
t.Parallel()

val := uint16(5)
record := MakeStaticRecord(fakeCsvDelayType, &val, 2, EUint16, DUint16)

encodeStream, err := NewStream(record)
require.NoError(t, err)

var b bytes.Buffer
err = encodeStream.Encode(&b)
require.NoError(t, err)

var val2 uint16
record2 := MakeStaticRecord(fakeCsvDelayType, &val2, 2, EUint16, DUint16)

decodeStream, err := NewStream(record2)
require.NoError(t, err)

err = decodeStream.Decode(&b)
require.NoError(t, err)

require.Equal(t, val, val2)
}

0 comments on commit d7adb79

Please sign in to comment.