Skip to content

Commit

Permalink
Reduce allocs in ReadMessage (unix transport)
Browse files Browse the repository at this point in the history
  • Loading branch information
marselester committed Dec 5, 2022
1 parent a852926 commit d76dc35
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 35 deletions.
17 changes: 17 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,23 @@ type header struct {
Variant
}

func DecodeMessageBody(msg *Message, body io.Reader, order binary.ByteOrder, fds []int) error {
if err := msg.IsValid(); err != nil {
return err
}
sig, _ := msg.Headers[FieldSignature].value.(Signature)
if sig.str != "" {
dec := newDecoder(body, order, fds)
vs, err := dec.Decode(sig)
if err != nil {
return err
}
msg.Body = vs
}

return nil
}

func DecodeMessageWithFDs(rd io.Reader, fds []int) (msg *Message, err error) {
var order binary.ByteOrder
var hlength, length uint32
Expand Down
109 changes: 74 additions & 35 deletions transport_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ type oobReader struct {
conn *net.UnixConn
oob []byte
buf [4096]byte

// The following fields are used to reduce memory allocs.
headers []header
b []byte
r *bytes.Reader
blength uint32
hlength uint32
proto byte
}

func (o *oobReader) Read(b []byte) (n int, err error) {
Expand Down Expand Up @@ -71,74 +79,102 @@ func (t *unixTransport) EnableUnixFDs() {
}

func (t *unixTransport) ReadMessage() (*Message, error) {
var (
blen, hlen uint32
csheader [16]byte
headers []header
order binary.ByteOrder
unixfds uint32
)
// To be sure that all bytes of out-of-band data are read, we use a special
// reader that uses ReadUnix on the underlying connection instead of Read
// and gathers the out-of-band data in a buffer.
if t.rdr == nil {
t.rdr = &oobReader{conn: t.UnixConn}
t.rdr = &oobReader{
conn: t.UnixConn,
// This buffer is used to decode headers and the body.
// 16 bytes is enough to read the part of the header that has a constant size.
b: make([]byte, 16),
// The reader helps to read from the buffer several times.
r: &bytes.Reader{},
}
} else {
t.rdr.oob = nil
t.rdr.oob = t.rdr.oob[:0]
t.rdr.headers = t.rdr.headers[:0]
}

// read the first 16 bytes (the part of the header that has a constant size),
// from which we can figure out the length of the rest of the message
if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil {
var (
b = t.rdr.b[:1]
r = t.rdr.r
)
if _, err := t.rdr.Read(b); err != nil {
return nil, err
}
switch csheader[0] {
var order binary.ByteOrder
switch b[0] {
case 'l':
order = binary.LittleEndian
case 'B':
order = binary.BigEndian
default:
return nil, InvalidMessageError("invalid byte order")
}
// csheader[4:8] -> length of message body, csheader[12:16] -> length of
// header fields (without alignment)
if err := binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen); err != nil {

// [4:8] is a length of message body,
// [12:16] is a length of header fields (without alignment)
dec := newDecoder(t.rdr, order, nil)
dec.pos = 1
vs, err := dec.Decode(Signature{"yyyuu"})
if err != nil {
return nil, err
}
msg := &Message{}
if err = Store(vs, &msg.Type, &msg.Flags, &t.rdr.proto, &t.rdr.blength, &msg.serial); err != nil {
return nil, err
}
if err := binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen); err != nil {

// Get the header length.
b = t.rdr.b[:4]
if _, err = io.ReadFull(t.rdr, b); err != nil {
return nil, err
}
if hlen%8 != 0 {
hlen += 8 - (hlen % 8)
r.Reset(b)
if err = binary.Read(r, order, &t.rdr.hlength); err != nil {
return nil, err
}
if t.rdr.hlength+t.rdr.blength+16 > 1<<27 {
return nil, InvalidMessageError("message is too long")
}

// decode headers and look for unix fds
headerdata := make([]byte, hlen+4)
copy(headerdata, csheader[12:])
if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil {
// Decode headers and look for unix fds.
if _, err = r.Seek(0, io.SeekStart); err != nil {
return nil, err
}
dec := newDecoder(bytes.NewBuffer(headerdata), order, make([]int, 0))
dec = newDecoder(io.MultiReader(r, t.rdr), order, nil)
dec.pos = 12
vs, err := dec.Decode(Signature{"a(yv)"})
vs, err = dec.Decode(Signature{"a(yv)"})
if err != nil {
return nil, err
}
err = Store(vs, &headers)
if err != nil {
if err = Store(vs, &t.rdr.headers); err != nil {
return nil, err
}
for _, v := range headers {
var unixfds uint32
for _, v := range t.rdr.headers {
if v.Field == byte(FieldUnixFDs) {
unixfds, _ = v.Variant.value.(uint32)
}
}
all := make([]byte, 16+hlen+blen)
copy(all, csheader[:])
copy(all[16:], headerdata[4:])
if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil {

msg.Headers = make(map[HeaderField]Variant)
for _, v := range t.rdr.headers {
msg.Headers[HeaderField(v.Field)] = v.Variant
}

dec.align(8)
// Grow the buffer to accomodate for message body.
if int(t.rdr.blength) > cap(t.rdr.b) {
t.rdr.b = make([]byte, t.rdr.blength)
}
b = t.rdr.b[:t.rdr.blength]
if _, err = io.ReadFull(t.rdr, b); err != nil {
return nil, err
}
r.Reset(b)

if unixfds != 0 {
if !t.hasUnixFDs {
return nil, errors.New("dbus: got unix fds on unsupported transport")
Expand All @@ -155,8 +191,7 @@ func (t *unixTransport) ReadMessage() (*Message, error) {
if err != nil {
return nil, err
}
msg, err := DecodeMessageWithFDs(bytes.NewBuffer(all), fds)
if err != nil {
if err = DecodeMessageBody(msg, r, order, fds); err != nil {
return nil, err
}
// substitute the values in the message body (which are indices for the
Expand All @@ -181,7 +216,11 @@ func (t *unixTransport) ReadMessage() (*Message, error) {
}
return msg, nil
}
return DecodeMessage(bytes.NewBuffer(all))

if err = DecodeMessageBody(msg, r, order, nil); err != nil {
return nil, err
}
return msg, nil
}

func (t *unixTransport) SendMessage(msg *Message) error {
Expand Down
59 changes: 59 additions & 0 deletions transport_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ func (t unixFDTest) Teststructvariant(sv variantContainer) (string, *Error) {
return string(b[:n]), nil
}

type unixFDTestBench struct {
b *testing.B
}

func (bm unixFDTestBench) Testfd(fd UnixFD) (string, *Error) {
var b [4096]byte
file := os.NewFile(uintptr(fd), "testfile")
defer file.Close()
n, err := file.Read(b[:])
if err != nil {
return "", &Error{"com.github.guelfey.test.Error", nil}
}
return string(b[:n]), nil
}

func TestUnixFDs(t *testing.T) {
conn, err := ConnectSessionBus()
if err != nil {
Expand Down Expand Up @@ -155,3 +170,47 @@ func TestUnixFDs(t *testing.T) {
t.Fatal("got", s, "wanted", testString)
}
}

func BenchmarkUnixFDs(b *testing.B) {
conn, err := ConnectSessionBus()
if err != nil {
b.Fatal(err)
}
b.Cleanup(func() {
if err := conn.Close(); err != nil {
b.Error(err)
}
})
r, w, err := os.Pipe()
if err != nil {
b.Fatal(err)
}
b.Cleanup(func() {
if err := w.Close(); err != nil {
b.Error(err)
}
})
name := conn.Names()[0]
test := unixFDTestBench{b}
err = conn.Export(test, "/com/github/guelfey/test", "com.github.guelfey.test")
if err != nil {
b.Fatal(err)
}

var s string
obj := conn.Object(name, "/com/github/guelfey/test")

b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := w.Write([]byte(testString)); err != nil {
b.Fatal(err)
}
err = obj.Call("com.github.guelfey.test.Testfd", 0, UnixFD(r.Fd())).Store(&s)
if err != nil {
b.Fatal(err)
}
if s != testString {
b.Fatal("got", s, "wanted", testString)
}
}
}

0 comments on commit d76dc35

Please sign in to comment.