-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
tlv: add new RecordT[T] utility type
In this commit, we add a new type, `Record[T]` to reduce some of the common boiler plate for TLV types. This time lets you take either a primitive type, or an existing Record, and gain common methods used to create tlv streams. It also serves as extra type annotation as well, since wire structs can use this to wrap any existing type and gain the relevant record methods.
- Loading branch information
Showing
4 changed files
with
148 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package tlv | ||
|
||
import ( | ||
"github.com/btcsuite/btcd/btcec/v2" | ||
"golang.org/x/exp/constraints" | ||
) | ||
|
||
// Primitive is a type constraint that capture the set of "primitive" types, | ||
// which are the built in stdlib types, and type defs of those types | ||
type Primitive interface { | ||
constraints.Integer | ~[]byte | ~[32]byte | ~[33]byte | ~bool | | ||
~*btcec.PublicKey | ~[64]byte | ||
} | ||
|
||
// RecordT is a high-order type makes it easy to encode known "primitive" types | ||
// as TLV records. | ||
type RecordT[T Record | Primitive] struct { | ||
// Val is the value of the underlying record. Go doesn't let us just | ||
// embed the type param as a struct field, so we need an intermediate | ||
// variable. | ||
Val T | ||
|
||
// Type is the type of the TLV record. | ||
Type Type | ||
} | ||
|
||
// NewPrimitiveRecord creates a new RecordT type from a given primitive type. | ||
func NewPrimitiveRecord[T Primitive](val T, tlvType Type) RecordT[T] { | ||
return RecordT[T]{ | ||
Val: val, | ||
Type: tlvType, | ||
} | ||
} | ||
|
||
// NewRecordT creates a new RecordT type from a given Record type. This is | ||
// useful to wrap a given record in this utility type, which also serves as an | ||
// extra type annotation. The underlying type of the record is retained. | ||
func NewRecordT[T Record](record T) RecordT[T] { | ||
// Go doesn't yet allow interfaces for union type constraints. So we'll | ||
// cast to any, then get a concrete record so we can call the method on | ||
// it. | ||
tlvRecord := any(record).(Record) | ||
|
||
return RecordT[T]{ | ||
Val: record, | ||
Type: tlvRecord.Type(), | ||
} | ||
} | ||
|
||
// Record returns the underlying record interface for the record type. | ||
func (t *RecordT[T]) Record() Record { | ||
// Go doesn't allow type assertions on a type param, so to work around | ||
// this, we'll convert to any, then do our type assertion. | ||
if tlvRecord, ok := any(t.Val).(Record); ok { | ||
return tlvRecord | ||
} | ||
|
||
return MakePrimitiveRecord(t.Type, &t.Val) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package tlv | ||
|
||
import ( | ||
"bytes" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const ( | ||
fakeCsvDelayType = 1 | ||
fakeIsCoolType = 2 | ||
) | ||
|
||
type fakeWireMsg struct { | ||
CsvDelay RecordT[uint16] | ||
|
||
IsCool RecordT[bool] | ||
} | ||
|
||
// TestRecordTFromPrimitive tests the RecordT type. We should be able to create | ||
// types of both record types, and also primitive types, and encode/decode them | ||
// as normal. | ||
func TestRecordTFromPrimitive(t *testing.T) { | ||
t.Parallel() | ||
|
||
wireMsg := fakeWireMsg{ | ||
CsvDelay: NewPrimitiveRecord(uint16(5), fakeCsvDelayType), | ||
IsCool: NewPrimitiveRecord(true, fakeIsCoolType), | ||
} | ||
|
||
encodeStream, err := NewStream( | ||
wireMsg.CsvDelay.Record(), wireMsg.IsCool.Record(), | ||
) | ||
require.NoError(t, err) | ||
|
||
var b bytes.Buffer | ||
err = encodeStream.Encode(&b) | ||
require.NoError(t, err) | ||
|
||
newWireMsg := fakeWireMsg{ | ||
CsvDelay: NewPrimitiveRecord(uint16(5), fakeCsvDelayType), | ||
IsCool: NewPrimitiveRecord(true, fakeIsCoolType), | ||
} | ||
|
||
decodeStream, err := NewStream( | ||
newWireMsg.CsvDelay.Record(), | ||
newWireMsg.IsCool.Record(), | ||
) | ||
require.NoError(t, err) | ||
|
||
err = decodeStream.Decode(&b) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, wireMsg, newWireMsg) | ||
} | ||
|
||
// TestRecordTFromRecord tests that we can create a RecordT type from an | ||
// existing record type and encode/decode as normal. | ||
func TestRecordTFromRecord(t *testing.T) { | ||
t.Parallel() | ||
|
||
val := uint16(5) | ||
record := MakeStaticRecord(fakeCsvDelayType, &val, 2, EUint16, DUint16) | ||
|
||
encodeStream, err := NewStream(record) | ||
require.NoError(t, err) | ||
|
||
var b bytes.Buffer | ||
err = encodeStream.Encode(&b) | ||
require.NoError(t, err) | ||
|
||
var val2 uint16 | ||
record2 := MakeStaticRecord(fakeCsvDelayType, &val2, 2, EUint16, DUint16) | ||
|
||
decodeStream, err := NewStream(record2) | ||
require.NoError(t, err) | ||
|
||
err = decodeStream.Decode(&b) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, val, val2) | ||
} |