/
record_type.go
161 lines (135 loc) · 4.72 KB
/
record_type.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package tlv
import (
"testing"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/fn"
"golang.org/x/exp/constraints"
)
// RecordT is a high-order type that makes it easy to encode known "primitive"
// types as TLV records. The TLV type used on the wire is bound to the T type
// parameter, while the record's value is stored in Val.
type RecordT[T TlvType, V any] struct {
	// recordType is the type of the TLV record. The constructors in this
	// file leave it at its zero value; its TypeVal method is what supplies
	// the wire type.
	recordType T

	// Val is the value of the underlying record. Go doesn't let us just
	// embed the type param as a struct field, so we need an intermediate
	// variable.
	Val V
}
// RecordProducerT is a type-aware wrapper around the normal RecordProducer
// interface. A type parameter constrained by RecordProducerT[K] must be
// exactly *K, and *K must implement RecordProducer.
type RecordProducerT[T any] interface {
	RecordProducer

	// This is a non-interface type constraint that allows us to pass a
	// concrete type as a type parameter rather than a pointer to the type
	// that satisfies the RecordProducer interface.
	*T
}
// NewRecordT wraps an existing record value of type K in a RecordT whose wire
// type is given by T. This doubles as an extra type annotation, and the
// concrete type K is kept as the record's value type.
//
// The V type parameter is not used in the body. It exists only so that the
// RecordProducerT[K] constraint forces *K to implement RecordProducer.
func NewRecordT[T TlvType, K any, V RecordProducerT[K]](
	record K,
) RecordT[T, K] {
	var wrapped RecordT[T, K]
	wrapped.Val = record

	return wrapped
}
// Primitive is a type constraint that captures the set of "primitive" types:
// unsigned integers, byte slices, 32/33/64-byte arrays, bools and
// *btcec.PublicKey, plus any defined type whose underlying type is one of
// these (the ~ forms).
type Primitive interface {
	constraints.Unsigned | ~[]byte | ~[32]byte | ~[33]byte | ~bool |
		~*btcec.PublicKey | ~[64]byte
}
// NewPrimitiveRecord wraps a primitive value in a RecordT whose wire type is
// given by T. The accepted value types are those allowed by the Primitive
// constraint.
func NewPrimitiveRecord[T TlvType, V Primitive](val V) RecordT[T, V] {
	var rec RecordT[T, V]
	rec.Val = val

	return rec
}
// Record returns the TLV record for this value, using the wire type bound to
// the T type parameter.
//
// If a pointer to Val implements RecordProducer, that producer's record is
// reused (value, sizing and codec functions), but its type is replaced by the
// one from T. Otherwise, Val is handed to MakePrimitiveRecord.
func (t *RecordT[T, V]) Record() Record {
	// A type assertion can't be applied to a type parameter directly, so
	// the pointer is boxed into an interface first.
	producer, isProducer := any(&t.Val).(RecordProducer)
	if !isProducer {
		return MakePrimitiveRecord(t.recordType.TypeVal(), &t.Val)
	}

	inner := producer.Record()

	// Rebuild the record so the wire type always comes from T rather than
	// from whatever the wrapped producer reports.
	return Record{
		value:      inner.value,
		typ:        t.recordType.TypeVal(),
		staticSize: inner.staticSize,
		sizeFunc:   inner.sizeFunc,
		encoder:    inner.encoder,
		decoder:    inner.decoder,
	}
}
// TlvType returns the type that identifies this record on the wire. It is
// obtained from the recordType field, so it is fixed by the T type parameter.
func (t *RecordT[T, V]) TlvType() Type {
	wireType := t.recordType.TypeVal()

	return wireType
}
// Zero returns a new RecordT with the same type parameters as the receiver and
// a zero-valued Val. The receiver itself is neither read nor modified.
func (t *RecordT[T, V]) Zero() RecordT[T, V] {
	var zero RecordT[T, V]

	return zero
}
// OptionalRecordT is a high-order type that represents an optional TLV record.
// This can be used when a TLV record doesn't always need to be present (ok to
// be odd). NOTE(review): "ok to be odd" presumably refers to the Lightning
// convention that optional TLV types use odd numbers; confirm.
//
// It embeds fn.Option, so the Option methods are available directly on an
// OptionalRecordT. A None value means the record is absent.
type OptionalRecordT[T TlvType, V any] struct {
	fn.Option[RecordT[T, V]]
}
// TlvType returns the type that identifies this record on the wire, as fixed
// by the T type parameter. It works whether or not the option holds a value.
func (t *OptionalRecordT[T, V]) TlvType() Type {
	// The option may be None, so the type is read from a zero-valued
	// RecordT[T, V] instead of the wrapped record.
	var zero RecordT[T, V]

	return zero.TlvType()
}
// WhenSomeV calls f with the record's value when the option is Some, and does
// nothing when it is None. Unlike WhenSome, f receives the inner value V
// rather than the RecordT wrapper.
func (t *OptionalRecordT[T, V]) WhenSomeV(f func(V)) {
	callWithVal := func(rec RecordT[T, V]) {
		f(rec.Val)
	}

	t.Option.WhenSome(callWithVal)
}
// UnwrapOrFailV is used to extract a value from an option within a test
// context. If the option is None, the test fails through the embedded
// Option's UnwrapOrFail. This gives the underlying value of the record,
// rather than the record itself.
func (o *OptionalRecordT[T, V]) UnwrapOrFailV(t *testing.T) V {
	// Mark this as a test helper so that a failure is reported at the
	// caller's line rather than inside this function.
	t.Helper()

	inner := o.Option.UnwrapOrFail(t)

	return inner.Val
}
// UnwrapOrErrV extracts the record's value from the option. If the embedded
// Option's UnwrapOrErr reports an error (presumably the given err when the
// option is None), that error is returned together with the zero value of V.
// On success it returns the underlying value of the record, not the RecordT
// wrapper.
func (o *OptionalRecordT[T, V]) UnwrapOrErrV(err error) (V, error) {
	record, unwrapErr := o.Option.UnwrapOrErr(err)
	if unwrapErr != nil {
		var zero V
		return zero, unwrapErr
	}

	return record.Val, nil
}
// Zero returns a zero-valued RecordT with the same type parameters as this
// optional record. The receiver's contents are not consulted.
func (t *OptionalRecordT[T, V]) Zero() RecordT[T, V] {
	return RecordT[T, V]{}
}
// SomeRecordT wraps the given record in an OptionalRecordT that holds it as a
// Some value.
func SomeRecordT[T TlvType, V any](
	record RecordT[T, V],
) OptionalRecordT[T, V] {
	var opt OptionalRecordT[T, V]
	opt.Option = fn.Some(record)

	return opt
}
// ZeroRecordT returns a RecordT whose fields, including Val, all hold their
// zero values.
func ZeroRecordT[T TlvType, V any]() RecordT[T, V] {
	return RecordT[T, V]{}
}