Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Use proper error in packing and unpacking

All the relevant functions now return an error instead of
a simple boolean. This greatly approves the feedback to coders.

Spotted some fishy error handling along the way and fix that too.
  • Loading branch information...
commit 570bf8dc69acf2fe3f3b9bf81f11108428632eb0 1 parent 099c19d
@miekg authored
Showing with 211 additions and 252 deletions.
  1. +10 −11 client.go
  2. +14 −14 dnssec.go
  3. +153 −199 msg.go
  4. +4 −4 nsecx.go
  5. +6 −9 server.go
  6. +24 −15 tsig.go
View
21 client.go
@@ -101,9 +101,9 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, err error) {
func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
var n int
var w *reply
- out, ok := m.Pack()
- if !ok {
- return nil, 0, ErrPack
+ out, err := m.Pack()
+ if err != nil {
+ return nil, 0, err
}
var in []byte
switch c.Net {
@@ -123,8 +123,8 @@ func (c *Client) ExchangeRtt(m *Msg, a string) (r *Msg, rtt time.Duration, err e
}
r = new(Msg)
r.Size = n
- if ok := r.Unpack(in[:n]); !ok {
- return nil, w.rtt, ErrUnpack
+ if err := r.Unpack(in[:n]); err != nil {
+ return nil, w.rtt, err
}
return r, w.rtt, nil
}
@@ -158,8 +158,8 @@ func (w *reply) receive() (*Msg, error) {
return nil, err
}
p = p[:n]
- if ok := m.Unpack(p); !ok {
- return nil, ErrUnpack
+ if err := m.Unpack(p); err != nil {
+ return nil, err
}
w.rtt = time.Since(w.t)
m.Size = n
@@ -260,10 +260,9 @@ func (w *reply) send(m *Msg) (err error) {
}
w.tsigRequestMAC = mac
} else {
- ok := false
- out, ok = m.Pack()
- if !ok {
- return ErrPack
+ out, err = m.Pack()
+ if err != nil {
+ return err
}
}
w.t = time.Now()
View
28 dnssec.go
@@ -119,8 +119,8 @@ func (k *RR_DNSKEY) KeyTag() uint16 {
keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize)
- n, ok := PackStruct(keywire, wire, 0)
- if !ok {
+ n, err := PackStruct(keywire, wire, 0)
+ if err != nil {
return 0
}
wire = wire[:n]
@@ -157,15 +157,15 @@ func (k *RR_DNSKEY) ToDS(h int) *RR_DS {
keywire.Algorithm = k.Algorithm
keywire.PublicKey = k.PublicKey
wire := make([]byte, DefaultMsgSize)
- n, ok := PackStruct(keywire, wire, 0)
- if !ok {
+ n, err := PackStruct(keywire, wire, 0)
+ if err != nil {
return nil
}
wire = wire[:n]
owner := make([]byte, 255)
- off, ok1 := PackDomainName(k.Hdr.Name, owner, 0, nil, false)
- if !ok1 {
+ off, err1 := PackDomainName(k.Hdr.Name, owner, 0, nil, false)
+ if err1 != nil {
return nil
}
owner = owner[:off]
@@ -237,9 +237,9 @@ func (rr *RR_RRSIG) Sign(k PrivateKey, rrset []RR) error {
// Create the desired binary blob
signdata := make([]byte, DefaultMsgSize)
- n, ok := PackStruct(sigwire, signdata, 0)
- if !ok {
- return ErrPack
+ n, err := PackStruct(sigwire, signdata, 0)
+ if err != nil {
+ return err
}
signdata = signdata[:n]
wire := rawSignatureData(rrset, rr)
@@ -349,9 +349,9 @@ func (rr *RR_RRSIG) Verify(k *RR_DNSKEY, rrset []RR) error {
sigwire.SignerName = strings.ToLower(rr.SignerName)
// Create the desired binary blob
signeddata := make([]byte, DefaultMsgSize)
- n, ok := PackStruct(sigwire, signeddata, 0)
- if !ok {
- return ErrPack
+ n, err := PackStruct(sigwire, signeddata, 0)
+ if err != nil {
+ return err
}
signeddata = signeddata[:n]
wire := rawSignatureData(rrset, rr)
@@ -684,8 +684,8 @@ func rawSignatureData(rrset []RR, s *RR_RRSIG) (buf []byte) {
}
// 6.2. Canonical RR Form. (5) - origTTL
wire := make([]byte, r.Len()*2)
- off, ok1 := PackRR(r1, wire, 0, nil, false)
- if !ok1 {
+ off, err1 := PackRR(r1, wire, 0, nil, false)
+ if err1 != nil {
return nil
}
wire = wire[:off]
View
352 msg.go
@@ -28,8 +28,10 @@ const maxCompressionOffset = 2 << 13 // We have 14 bits for the compression poin
var (
ErrUnpack error = &Error{Err: "unpacking failed"}
ErrPack error = &Error{Err: "packing failed"}
+ ErrFqdn error = &Error{Err: "domain must be fully qualified"}
ErrId error = &Error{Err: "id mismatch"}
- ErrBuf error = &Error{Err: "buffer size too large"}
+ ErrRdata error = &Error{Err: "bad rdata"}
+ ErrBuf error = &Error{Err: "buffer size too small"}
ErrShortRead error = &Error{Err: "short read"}
ErrConn error = &Error{Err: "conn holds both UDP and TCP connection"}
ErrConnEmpty error = &Error{Err: "conn has no connection"}
@@ -46,10 +48,7 @@ var (
ErrSigGen error = &Error{Err: "bad signature generation"}
ErrAuth error = &Error{Err: "bad authentication"}
ErrSoa error = &Error{Err: "no SOA"}
- ErrHandle error = &Error{Err: "handle is nil"}
- ErrChan error = &Error{Err: "channel is nil"}
- ErrName error = &Error{Err: "type not found for name"}
- ErrRRset error = &Error{Err: "invalid rrset"}
+ ErrRRset error = &Error{Err: "bad rrset"}
ErrDenialNsec3 error = &Error{Err: "no NSEC3 records"}
ErrDenialCe error = &Error{Err: "no matching closest encloser found"}
ErrDenialNc error = &Error{Err: "no covering NSEC3 found for next closer"}
@@ -200,13 +199,12 @@ var Rcode_str = map[int]string{
// If compression is wanted compress must be true and the compression
// map needs to hold a mapping between domain names and offsets
// pointing into msg[].
-func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
- // Add trailing dot to canonicalize name.
+func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
lenmsg := len(msg)
ls := len(s)
+ // If not fully qualified, error out
if ls == 0 || s[ls-1] != '.' {
- //println("dns: name not fully qualified")
- return lenmsg, false
+ return lenmsg, ErrFqdn
}
// Each dot ends a segment of the name.
@@ -234,30 +232,22 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
if bs[i] == '.' {
if i-begin >= 1<<6 { // top two bits of length must be clear
- return lenmsg, false
+ return lenmsg, ErrRdata
}
// off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified
if off+1 > len(msg) {
- return lenmsg, false
+ return lenmsg, ErrBuf
}
msg[off] = byte(i - begin)
offset := off
off++
- // TODO(mg): because of the new check above, this can go. But
- // just leave it as is for the moment.
- // if off > lenmsg {
- // return lenmsg, false
- // }
for j := begin; j < i; j++ {
if off+1 > len(msg) {
- return lenmsg, false
+ return lenmsg, ErrBuf
}
msg[off] = bs[j]
off++
- // if off > lenmsg {
- // return lenmsg, false
- // }
}
// Dont try to compress '.'
if compression != nil && string(bs[begin:]) != ".'" {
@@ -285,7 +275,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
}
// Root label is special
if string(bs) == "." {
- return off, true
+ return off, nil
}
// If we did compression and we find something at the pointer here
if pointer != -1 {
@@ -297,7 +287,7 @@ func PackDomainName(s string, msg []byte, off int, compression map[string]int, c
msg[off] = 0
End:
off++
- return off, true
+ return off, nil
}
// Unpack a domain name.
@@ -315,14 +305,14 @@ End:
// We let them jump anywhere and stop jumping after a while.
// UnpackDomainName unpacks a domain name into a string.
-func UnpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
+func UnpackDomainName(msg []byte, off int) (s string, off1 int, err error) {
s = ""
lenmsg := len(msg)
ptr := 0 // number of pointers followed
Loop:
for {
if off >= lenmsg {
- return "", lenmsg, false
+ return "", lenmsg, ErrBuf
}
c := int(msg[off])
off++
@@ -331,13 +321,13 @@ Loop:
if c == 0x00 {
// end of name
if s == "" {
- return ".", off, true
+ return ".", off, nil
}
break Loop
}
// literal string
if off+c > lenmsg {
- return "", lenmsg, false
+ return "", lenmsg, ErrBuf
}
for j := off; j < off+c; j++ {
if msg[j] == '.' {
@@ -356,7 +346,7 @@ Loop:
// also, don't follow too many pointers --
// maybe there's a loop.
if off >= lenmsg {
- return "", lenmsg, false
+ return "", lenmsg, ErrBuf
}
c1 := msg[off]
off++
@@ -364,41 +354,38 @@ Loop:
off1 = off
}
if ptr++; ptr > 10 {
- return "", lenmsg, false
+ return "", lenmsg, &Error{Err: "too many compression pointers"}
}
off = (c^0xC0)<<8 | int(c1)
default:
// 0x80 and 0x40 are reserved
- return "", lenmsg, false
+ return "", lenmsg, ErrRdata
}
}
if ptr == 0 {
off1 = off
}
- return s, off1, true
+ return s, off1, nil
}
// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string,
// slices and other (often anonymous) structs.
-func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
+func packStructValue(val reflect.Value, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
+ lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ {
- // f := val.Type().Field(i)
- lenmsg := len(msg)
switch fv := val.Field(i); fv.Kind() {
default:
- return lenmsg, false
+ return lenmsg, &Error{Err: "bad kind packing"}
case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") {
default:
- // println("dns: unknown tag packing slice", val.Type().Field(i).Tag.Get("dns"), '"', val.Type().Field(i).Tag, '"')
- return lenmsg, false
+ return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing slice"}
case "domain-name":
for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).String()
- off, ok = PackDomainName(element, msg, off, compression, false && compress)
- if !ok {
- // println("dns: overflow packing domain-name", off)
- return lenmsg, false
+ off, err = PackDomainName(element, msg, off, compression, false && compress)
+ if err != nil {
+ return lenmsg, err
}
}
case "txt":
@@ -406,8 +393,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
element := val.Field(i).Index(j).String()
// Counted string: 1 byte length.
if len(element) > 255 || off+1+len(element) > lenmsg {
- // println("dns: overflow packing TXT string")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing txt"}
}
msg[off] = byte(len(element))
off++
@@ -421,8 +407,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
element := val.Field(i).Index(j).Interface()
b, e := element.(EDNS0).pack()
if e != nil {
- // println("dns: failure packing OPT")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing opt"}
}
// Option code
msg[off], msg[off+1] = packUint16(element.(EDNS0).Option())
@@ -436,22 +421,17 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "a":
// It must be a slice of 4, even if it is 16, we encode
// only the first 4
+ if off+net.IPv4len > lenmsg {
+ return lenmsg, &Error{Err: "overflow packing a"}
+ }
switch fv.Len() {
case net.IPv6len:
- if off+net.IPv4len > lenmsg {
- // println("dns: overflow packing A", off, lenmsg)
- return lenmsg, false
- }
msg[off] = byte(fv.Index(12).Uint())
msg[off+1] = byte(fv.Index(13).Uint())
msg[off+2] = byte(fv.Index(14).Uint())
msg[off+3] = byte(fv.Index(15).Uint())
off += net.IPv4len
case net.IPv4len:
- if off+net.IPv4len > lenmsg {
- // println("dns: overflow packing A", off, lenmsg)
- return lenmsg, false
- }
msg[off] = byte(fv.Index(0).Uint())
msg[off+1] = byte(fv.Index(1).Uint())
msg[off+2] = byte(fv.Index(2).Uint())
@@ -460,13 +440,11 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case 0:
// Allowed, for dynamic updates
default:
- // println("dns: overflow packing A")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing a"}
}
case "aaaa":
if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg {
- // println("dns: overflow packing AAAA")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing aaaa"}
}
for j := 0; j < net.IPv6len; j++ {
msg[off] = byte(fv.Index(j).Uint())
@@ -481,8 +459,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
serv := uint16((fv.Index(j).Uint()))
bitmapbyte = uint16(serv / 8)
if int(bitmapbyte) > lenmsg {
- // println("dns: overflow packing WKS")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing wks"}
}
bit := uint16(serv) - bitmapbyte*8
msg[bitmapbyte] = byte(1 << (7 - bit))
@@ -498,8 +475,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
lastwindow := uint16(0)
length := uint16(0)
if off+2 > lenmsg {
- // println("dns: overflow packing NSECx bitmap")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing nsecx"}
}
for j := 0; j < val.Field(i).Len(); j++ {
t := uint16((fv.Index(j).Uint()))
@@ -508,15 +484,13 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// New window, jump to the new offset
off += int(length) + 3
if off > lenmsg {
- // println("dns: overflow packing NSECx bitmap")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
}
}
length = (t - window*256) / 8
bit := t - (window * 256) - (length * 8)
if off+2+int(length) > lenmsg {
- // println("dns: overflow packing NSECx bitmap")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
}
// Setting the window #
@@ -530,23 +504,20 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += 2 + int(length)
off++
if off > lenmsg {
- // println("dns: overflow packing NSECx bitmap")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing nsecx bitmap"}
}
}
case reflect.Struct:
- off, ok = packStructValue(fv, msg, off, compression, compress)
+ off, err = packStructValue(fv, msg, off, compression, compress)
case reflect.Uint8:
if off+1 > lenmsg {
- // println("dns: overflow packing uint8")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing uint8"}
}
msg[off] = byte(fv.Uint())
off++
case reflect.Uint16:
if off+2 > lenmsg {
- // println("dns: overflow packing uint16")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing uint16"}
}
i := fv.Uint()
msg[off] = byte(i >> 8)
@@ -554,8 +525,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += 2
case reflect.Uint32:
if off+4 > lenmsg {
- // println("dns: overflow packing uint32")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing uint32"}
}
i := fv.Uint()
msg[off] = byte(i >> 24)
@@ -566,8 +536,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case reflect.Uint64:
// Only used in TSIG, where it stops at 48 bits, so we discard the upper 16
if off+6 > lenmsg {
- // println("dns: overflow packing uint64")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing uint64 as uint48"}
}
i := fv.Uint()
msg[off] = byte(i >> 40)
@@ -583,24 +552,21 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
s := fv.String()
switch val.Type().Field(i).Tag.Get("dns") {
default:
- return lenmsg, false
+ return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag packing string"}
case "base64":
b64, err := packBase64([]byte(s))
if err != nil {
- // println("dns: overflow packing base64")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing base64"}
}
copy(msg[off:off+len(b64)], b64)
off += len(b64)
case "domain-name":
- if off, ok = PackDomainName(s, msg, off, compression, false && compress); !ok {
- // println("dns: overflow packing domain-name", off)
- return lenmsg, false
+ if off, err = PackDomainName(s, msg, off, compression, false && compress); err != nil {
+ return lenmsg, err
}
case "cdomain-name":
- if off, ok = PackDomainName(s, msg, off, compression, true && compress); !ok {
- // println("dns: overflow packing domain-name", off)
- return lenmsg, false
+ if off, err = PackDomainName(s, msg, off, compression, true && compress); err != nil {
+ return lenmsg, err
}
case "size-base32":
// This is purely for NSEC3 atm, the previous byte must
@@ -611,8 +577,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "base32":
b32, err := packBase32([]byte(s))
if err != nil {
- // println("dns: overflow packing base32")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing base32"}
}
copy(msg[off:off+len(b32)], b32)
off += len(b32)
@@ -622,12 +587,10 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// There is no length encoded here
h, e := hex.DecodeString(s)
if e != nil {
- // println("dns: overflow packing (size-)hex string")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing hex"}
}
if off+hex.DecodedLen(len(s)) > lenmsg {
- // Overflow
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing hex"}
}
copy(msg[off:off+hex.DecodedLen(len(s))], h)
off += hex.DecodedLen(len(s))
@@ -641,8 +604,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "":
// Counted string: 1 byte length.
if len(s) > 255 || off+1+len(s) > lenmsg {
- // println("dns: overflow packing string")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow packing string"}
}
msg[off] = byte(len(s))
off++
@@ -653,48 +615,44 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
}
}
}
- return off, true
+ return off, nil
}
func structValue(any interface{}) reflect.Value {
return reflect.ValueOf(any).Elem()
}
-func PackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
- off, ok = packStructValue(structValue(any), msg, off, nil, false)
- return off, ok
+func PackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
+ off, err = packStructValue(structValue(any), msg, off, nil, false)
+ return off, err
}
-func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
- off, ok = packStructValue(structValue(any), msg, off, compression, compress)
- return off, ok
+func packStructCompress(any interface{}, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
+ off, err = packStructValue(structValue(any), msg, off, compression, compress)
+ return off, err
}
// Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue.
-func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
+func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err error) {
var rdstart int
+ lenmsg := len(msg)
for i := 0; i < val.NumField(); i++ {
- // f := val.Type().Field(i)
- lenmsg := len(msg)
switch fv := val.Field(i); fv.Kind() {
default:
- // println("dns: unknown case unpacking struct")
- return lenmsg, false
+ return lenmsg, &Error{Err: "bad kind unpacking"}
case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") {
default:
- // println("dns: unknown tag unpacking slice", val.Type().Field(i).Tag)
- return lenmsg, false
+ return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking slice"}
case "domain-name":
// HIP record slice of name (or none)
servers := make([]string, 0)
var s string
for off < lenmsg {
- s, off, ok = UnpackDomainName(msg, off)
- if !ok {
- // println("dns: failure unpacking domain-name")
- return lenmsg, false
+ s, off, err = UnpackDomainName(msg, off)
+ if err != nil {
+ return lenmsg, err
}
servers = append(servers, s)
}
@@ -705,8 +663,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
Txts:
l := int(msg[off])
if off+l+1 > lenmsg {
- // println("dns: failure unpacking txt strings")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking txt"}
}
txt = append(txt, string(msg[off+1:off+l+1]))
off += l + 1
@@ -724,14 +681,14 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
break
}
edns := make([]EDNS0, 0)
- // Goto to this place, when there is a goto
code := uint16(0)
-
- code, off = unpackUint16(msg, off) // Overflow? TODO
+ if off+2 > lenmsg {
+ return lenmsg, &Error{Err: "overflow unpacking opt"}
+ }
+ code, off = unpackUint16(msg, off)
optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > off+rdlength {
- // println("dns: overflow unpacking OPT")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking opt"}
}
switch code {
case EDNS0NSID:
@@ -746,18 +703,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
off = off1 + int(optlen)
}
fv.Set(reflect.ValueOf(edns))
- // goto ??
+ // multiple EDNS codes?
case "a":
if off+net.IPv4len > len(msg) {
- // println("dns: overflow unpacking A")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking a"}
}
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len
case "aaaa":
if off+net.IPv6len > lenmsg {
- // println("dns: overflow unpacking AAAA")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking aaaa"}
}
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
msg[off+5], msg[off+6], msg[off+7], msg[off+8], msg[off+9], msg[off+10],
@@ -806,8 +761,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
endrr := rdstart + rdlength
if off+2 > lenmsg {
- // println("dns: overflow unpacking NSEC")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking nsecx"}
}
nsec := make([]uint16, 0)
length := 0
@@ -820,15 +774,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
// A length window of zero is strange. If there
// the window should not have been specified. Bail out
// println("dns: length == 0 when unpacking NSEC")
- return lenmsg, false
+ return lenmsg, ErrRdata
}
if length > 32 {
- // println("dns: length > 32 when unpacking NSEC")
- return lenmsg, false
+ return lenmsg, ErrRdata
}
- // Walk the bytes in the window - and check the bit
- // setting..
+ // Walk the bytes in the window - and check the bit settings...
off += 2
for j := 0; j < length; j++ {
b := msg[off+j]
@@ -863,29 +815,26 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
fv.Set(reflect.ValueOf(nsec))
}
case reflect.Struct:
- off, ok = unpackStructValue(fv, msg, off)
+ off, err = unpackStructValue(fv, msg, off)
if val.Type().Field(i).Name == "Hdr" {
rdstart = off
}
case reflect.Uint8:
if off+1 > lenmsg {
- // println("dns: overflow unpacking uint8")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking uint8"}
}
fv.SetUint(uint64(uint8(msg[off])))
off++
case reflect.Uint16:
var i uint16
if off+2 > lenmsg {
- // println("dns: overflow unpacking uint16")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking uint16"}
}
i, off = unpackUint16(msg, off)
fv.SetUint(uint64(i))
case reflect.Uint32:
if off+4 > lenmsg {
- // println("dns: overflow unpacking uint32")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking uint32"}
}
fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])))
off += 4
@@ -893,8 +842,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
// This is *only* used in TSIG where the last 48 bits are occupied
// So for now, assume a uint48 (6 bytes)
if off+6 > lenmsg {
- // println("dns: overflow unpacking uint64")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking uint64 as uint48"}
}
fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5])))
@@ -903,15 +851,13 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
var s string
switch val.Type().Field(i).Tag.Get("dns") {
default:
- // println("dns: unknown tag unpacking string")
- return lenmsg, false
+ return lenmsg, &Error{Name: val.Type().Field(i).Tag.Get("dns"), Err: "bad tag unpacking string"}
case "hex":
// Rest of the RR is hex encoded, network order an issue here?
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength
if endrr > lenmsg {
- // println("dns: overflow when unpacking hex string")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking hex"}
}
s = hex.EncodeToString(msg[off:endrr])
off = endrr
@@ -920,18 +866,16 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength
if endrr > lenmsg {
- // println("dns: failure unpacking base64")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking base64"}
}
s = unpackBase64(msg[off:endrr])
off = endrr
case "cdomain-name":
fallthrough
case "domain-name":
- s, off, ok = UnpackDomainName(msg, off)
- if !ok {
- // println("dns: failure unpacking domain-name")
- return lenmsg, false
+ s, off, err = UnpackDomainName(msg, off)
+ if err != nil {
+ return lenmsg, err
}
case "size-base32":
var size int
@@ -944,8 +888,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
}
if off+size > lenmsg {
- // println("dns: failure unpacking size-base32 string")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking base32"}
}
s = unpackBase32(msg[off : off+size])
off += size
@@ -973,8 +916,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
}
if off+size > lenmsg {
- // println("dns: failure unpacking size-hex string")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking hex"}
}
s = hex.EncodeToString(msg[off : off+size])
off += size
@@ -983,8 +925,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txt:
if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
- // println("dns: failure unpacking txt string")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking txt"}
}
n := int(msg[off])
off++
@@ -998,8 +939,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
case "":
if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
- // println("dns: failure unpacking string")
- return lenmsg, false
+ return lenmsg, &Error{Err: "overflow unpacking string"}
}
n := int(msg[off])
off++
@@ -1011,7 +951,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
fv.SetString(s)
}
}
- return off, true
+ return off, nil
}
// Helper function for unpacking
@@ -1021,9 +961,9 @@ func unpackUint16(msg []byte, off int) (v uint16, off1 int) {
return
}
-func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
- off, ok = unpackStructValue(structValue(any), msg, off)
- return off, ok
+func UnpackStruct(any interface{}, msg []byte, off int) (off1 int, err error) {
+ off, err = unpackStructValue(structValue(any), msg, off)
+ return off, err
}
func unpackBase32(b []byte) string {
@@ -1067,28 +1007,26 @@ func packBase32(s []byte) ([]byte, error) {
}
// Resource record packer.
-func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, ok bool) {
+func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
if rr == nil {
- return len(msg), false
+ return len(msg), &Error{Err: "nil rr"}
}
- off1, ok = packStructCompress(rr, msg, off, compression, compress)
- if !ok {
- return len(msg), false
- }
- if !rawSetRdlength(msg, off, off1) {
- return len(msg), false
+ off1, err = packStructCompress(rr, msg, off, compression, compress)
+ if err != nil {
+ return len(msg), err
}
- return off1, true
+ rawSetRdlength(msg, off, off1)
+ return off1, nil
}
// Resource record unpacker.
-func UnpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
+func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
// unpack just the header, to find the rr type and length
var h RR_Header
off0 := off
- if off, ok = UnpackStruct(&h, msg, off); !ok {
- return nil, len(msg), false
+ if off, err = UnpackStruct(&h, msg, off); err != nil {
+ return nil, len(msg), err
}
end := off + int(h.Rdlength)
// make an rr of that type and re-unpack.
@@ -1098,11 +1036,11 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) {
} else {
rr = mk()
}
- off, ok = UnpackStruct(rr, msg, off0)
+ off, err = UnpackStruct(rr, msg, off0)
if off != end {
- return &h, end, true
+ return &h, end, nil
}
- return rr, off, ok
+ return rr, off, err
}
// Reverse a map
@@ -1176,9 +1114,9 @@ func (h *MsgHdr) String() string {
// Pack packs a Msg: it is converted to to wire format.
// If the dns.Compress is true the message will be in compressed wire format.
-func (dns *Msg) Pack() (msg []byte, ok bool) {
+func (dns *Msg) Pack() (msg []byte, err error) {
if dns == nil {
- return nil, false
+ return nil, &Error{Err: "nil message"}
}
var dh Header
compression := make(map[string]int) // Compression pointer mappings
@@ -1227,34 +1165,41 @@ func (dns *Msg) Pack() (msg []byte, ok bool) {
// Pack it in: header and then the pieces.
off := 0
- off, ok = packStructCompress(&dh, msg, off, compression, dns.Compress)
+ off, err = packStructCompress(&dh, msg, off, compression, dns.Compress)
for i := 0; i < len(question); i++ {
- off, ok = packStructCompress(&question[i], msg, off, compression, dns.Compress)
+ off, err = packStructCompress(&question[i], msg, off, compression, dns.Compress)
+ if err != nil {
+ return nil, err
+ }
}
for i := 0; i < len(answer); i++ {
- off, ok = PackRR(answer[i], msg, off, compression, dns.Compress)
+ off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
+ if err != nil {
+ return nil, err
+ }
}
for i := 0; i < len(ns); i++ {
- off, ok = PackRR(ns[i], msg, off, compression, dns.Compress)
+ off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
+ if err != nil {
+ return nil, err
+ }
}
for i := 0; i < len(extra); i++ {
- off, ok = PackRR(extra[i], msg, off, compression, dns.Compress)
- }
- if !ok {
- return nil, false
+ off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
+ if err != nil {
+ return nil, err
+ }
}
- //println("allocated", dns.Len()+1, "used", off)
- return msg[:off], true
+ return msg[:off], nil
}
// Unpack unpacks a binary message to a Msg structure.
-func (dns *Msg) Unpack(msg []byte) bool {
+func (dns *Msg) Unpack(msg []byte) (err error) {
// Header.
var dh Header
off := 0
- var ok bool
- if off, ok = UnpackStruct(&dh, msg, off); !ok {
- return false
+ if off, err = UnpackStruct(&dh, msg, off); err != nil {
+ return err
}
dns.Id = dh.Id
dns.Response = (dh.Bits & _QR) != 0
@@ -1275,25 +1220,34 @@ func (dns *Msg) Unpack(msg []byte) bool {
dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ {
- off, ok = UnpackStruct(&dns.Question[i], msg, off)
+ off, err = UnpackStruct(&dns.Question[i], msg, off)
+ if err != nil {
+ return err
+ }
}
for i := 0; i < len(dns.Answer); i++ {
- dns.Answer[i], off, ok = UnpackRR(msg, off)
+ dns.Answer[i], off, err = UnpackRR(msg, off)
+ if err != nil {
+ return err
+ }
}
for i := 0; i < len(dns.Ns); i++ {
- dns.Ns[i], off, ok = UnpackRR(msg, off)
+ dns.Ns[i], off, err = UnpackRR(msg, off)
+ if err != nil {
+ return err
+ }
}
for i := 0; i < len(dns.Extra); i++ {
- dns.Extra[i], off, ok = UnpackRR(msg, off)
- }
- if !ok {
- return false
+ dns.Extra[i], off, err = UnpackRR(msg, off)
+ if err != nil {
+ return err
+ }
}
if off != len(msg) {
// TODO(mg) remove eventually
// println("extra bytes in dns packet", off, "<", len(msg))
}
- return true
+ return nil
}
// Convert a complete message to a string with dig-like output.
View
8 nsecx.go
@@ -38,14 +38,14 @@ func HashName(label string, ha uint8, iter uint16, salt string) string {
saltwire := new(saltWireFmt)
saltwire.Salt = salt
wire := make([]byte, DefaultMsgSize)
- n, ok := PackStruct(saltwire, wire, 0)
- if !ok {
+ n, err := PackStruct(saltwire, wire, 0)
+ if err != nil {
return ""
}
wire = wire[:n]
name := make([]byte, 255)
- off, ok1 := PackDomainName(strings.ToLower(label), name, 0, nil, false)
- if !ok1 {
+ off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false)
+ if err != nil {
return ""
}
name = name[:off]
View
15 server.go
@@ -396,7 +396,7 @@ func (c *conn) serve() {
w := new(response)
w.conn = c
req := new(Msg)
- if !req.Unpack(c.request) {
+ if req.Unpack(c.request) != nil {
// Send a format error back
x := new(Msg)
x.SetRcodeFormatError(req)
@@ -436,10 +436,7 @@ func (c *conn) serve() {
// Write implements the ResponseWriter.Write method.
func (w *response) Write(m *Msg) (err error) {
- var (
- data []byte
- ok bool
- )
+ var data []byte
if m == nil {
return &Error{Err: "nil message"}
}
@@ -449,9 +446,9 @@ func (w *response) Write(m *Msg) (err error) {
return err
}
} else {
- data, ok = m.Pack()
- if !ok {
- return ErrPack
+ data, err = m.Pack()
+ if err != nil {
+ return err
}
}
return w.WriteBuf(data)
@@ -470,7 +467,7 @@ func (w *response) WriteBuf(m []byte) (err error) {
}
case w.conn._TCP != nil:
if len(m) > MaxMsgSize {
- return ErrBuf
+ return &Error{Err: "message too large"}
}
l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(m)))
View
39 tsig.go
@@ -165,9 +165,9 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
rr := m.Extra[len(m.Extra)-1].(*RR_TSIG)
m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
- mbuf, ok := m.Pack()
- if !ok {
- return nil, "", ErrPack
+ mbuf, err := m.Pack()
+ if err != nil {
+ return nil, "", err
}
buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly)
@@ -194,10 +194,10 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
t.OrigId = m.Id
tbuf := make([]byte, t.Len())
- if off, ok := PackRR(t, tbuf, 0, nil, false); ok {
+ if off, err := PackRR(t, tbuf, 0, nil, false); err != nil {
tbuf = tbuf[:off] // reset to actual size used
} else {
- return nil, "", ErrPack
+ return nil, "", err
}
mbuf = append(mbuf, tbuf...)
rawSetExtraLen(mbuf, uint16(len(m.Extra)+1))
@@ -298,13 +298,13 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
// Copied from msg.go's Unpack()
// Header.
var dh Header
+ var err error
dns := new(Msg)
rr := new(RR_TSIG)
off := 0
tsigoff := 0
- var ok bool
- if off, ok = UnpackStruct(&dh, msg, off); !ok {
- return nil, nil, ErrUnpack
+ if off, err = UnpackStruct(&dh, msg, off); err !=nil {
+ return nil, nil, err
}
if dh.Arcount == 0 {
return nil, nil, ErrNoSig
@@ -321,17 +321,29 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ {
- off, ok = UnpackStruct(&dns.Question[i], msg, off)
+ off, err = UnpackStruct(&dns.Question[i], msg, off)
+ if err != nil {
+ return nil, nil, err
+ }
}
for i := 0; i < len(dns.Answer); i++ {
- dns.Answer[i], off, ok = UnpackRR(msg, off)
+ dns.Answer[i], off, err = UnpackRR(msg, off)
+ if err != nil {
+ return nil, nil, err
+ }
}
for i := 0; i < len(dns.Ns); i++ {
- dns.Ns[i], off, ok = UnpackRR(msg, off)
+ dns.Ns[i], off, err = UnpackRR(msg, off)
+ if err != nil {
+ return nil, nil, err
+ }
}
for i := 0; i < len(dns.Extra); i++ {
tsigoff = off
- dns.Extra[i], off, ok = UnpackRR(msg, off)
+ dns.Extra[i], off, err = UnpackRR(msg, off)
+ if err != nil {
+ return nil, nil, err
+ }
if dns.Extra[i].Header().Rrtype == TypeTSIG {
rr = dns.Extra[i].(*RR_TSIG)
// Adjust Arcount.
@@ -340,9 +352,6 @@ func stripTsig(msg []byte) ([]byte, *RR_TSIG, error) {
break
}
}
- if !ok {
- return nil, nil, ErrUnpack
- }
if rr == nil {
return nil, nil, ErrNoSig
}
Please sign in to comment.
Something went wrong with that request. Please try again.