Permalink
Browse files

Implement basic support for Unix FD passing

This introduces two new types, UnixFD and UnixFDIndex, of which only the
latter is actually supported by the wire protocol. However, all methods
of Connection automatically convert between these types, provided that the
transport supports it.

This also adds the new internal transport interface, which handles the
transport-specific operations of a connection.
  • Loading branch information...
guelfey committed Feb 15, 2013
1 parent ba5dda8 commit 299651ad8c33305979fdfab9b10c958185a20ac5
Showing with 379 additions and 84 deletions.
  1. +22 −1 auth.go
  2. +60 −67 connection.go
  3. +3 −0 message.go
  4. +32 −16 proto.go
  5. +204 −0 transport_unix.go
  6. +58 −0 transport_unix_test.go
View
23 auth.go
@@ -87,7 +87,28 @@ func (conn *Connection) auth() error {
return err
}
if ok {
- return authWriteLine(conn.transport, []byte("BEGIN"))
+ if conn.SupportsUnixFDs() {
+ err = authWriteLine(conn, []byte("NEGOTIATE_UNIX_FD"))
+ if err != nil {
+ return err
+ }
+ line, err := authReadLine(in)
+ if err != nil {
+ return err
+ }
+ switch {
+ case bytes.Equal(line[0], []byte("AGREE_UNIX_FD")):
+ conn.EnableUnixFDs()
+ case bytes.Equal(line[0], []byte("ERROR")):
+ default:
+ return errors.New("authentication protocol error")
+ }
+ }
+ err = authWriteLine(conn.transport, []byte("BEGIN"))
+ if err != nil {
+ return err
+ }
+ return nil
}
}
}
View
@@ -5,7 +5,6 @@ import (
"encoding/binary"
"errors"
"io"
- "net"
"os"
"os/exec"
"reflect"
@@ -20,7 +19,7 @@ const defaultSystemBusAddress = "unix:path=/var/run/dbus/system_bus_socket"
//
// Multiple goroutines may invoke methods on a connection simultaneously.
type Connection struct {
- transport net.Conn
+ transport
uuid string
names []string
namesLck sync.RWMutex
@@ -73,27 +72,9 @@ func ConnectSystemBus() (*Connection, error) {
func NewConnection(address string) (*Connection, error) {
var err error
conn := new(Connection)
- if strings.HasPrefix(address, "unix") {
- abstract := getKey(address, "abstract")
- path := getKey(address, "path")
- switch {
- case abstract == "" && path == "":
- return nil, errors.New("bad address: neither path nor abstract set")
- case abstract != "" && path == "":
- conn.transport, err = net.Dial("unix", "@"+abstract)
- if err != nil {
- return nil, err
- }
- case abstract == "" && path != "":
- conn.transport, err = net.Dial("unix", path)
- if err != nil {
- return nil, err
- }
- case abstract != "" && path != "":
- return nil, errors.New("bad address: both path and abstract set")
- }
- } else {
- return nil, errors.New("bad address: invalid or unsupported transport")
+ conn.transport, err = getTransport(address)
+ if err != nil {
+ return nil, err
}
if err = conn.auth(); err != nil {
conn.transport.Close()
@@ -172,7 +153,7 @@ func (conn *Connection) hello() error {
// transport and dispatching them appropiately.
func (conn *Connection) inWorker() {
for {
- msg, err := conn.readMessage()
+ msg, err := conn.ReadMessage()
if err == nil {
dest, _ := msg.Headers[FieldDestination].value.(string)
found := false
@@ -279,7 +260,7 @@ func (conn *Connection) Names() []string {
// sent to conn.out.
func (conn *Connection) outWorker() {
for msg := range conn.out {
- err := msg.EncodeTo(conn.transport)
+ err := conn.SendMessage(msg)
conn.repliesLck.RLock()
if err != nil {
if conn.replies[msg.Serial] != nil {
@@ -293,41 +274,6 @@ func (conn *Connection) outWorker() {
}
}
-// readMessage reads and decodes a single message from the transport.
-func (conn *Connection) readMessage() (*Message, error) {
- // read the first 16 bytes, from which we can figure out the length of the
- // rest of the message
- var header [16]byte
- if _, err := io.ReadFull(conn.transport, header[:]); err != nil {
- return nil, err
- }
- var order binary.ByteOrder
- switch header[0] {
- case 'l':
- order = binary.LittleEndian
- case 'B':
- order = binary.BigEndian
- default:
- return nil, InvalidMessageError("invalid byte order")
- }
- // header[4:8] -> length of message body, header[12:16] -> length of header
- // fields (without alignment)
- var blen, hlen uint32
- binary.Read(bytes.NewBuffer(header[4:8]), order, &blen)
- binary.Read(bytes.NewBuffer(header[12:16]), order, &hlen)
- if hlen%8 != 0 {
- hlen += 8 - (hlen % 8)
- }
- rest := make([]byte, int(blen+hlen))
- if _, err := io.ReadFull(conn.transport, rest); err != nil {
- return nil, err
- }
- all := make([]byte, 16+len(rest))
- copy(all, header[:])
- copy(all[16:], rest)
- return DecodeMessage(bytes.NewBuffer(all))
-}
-
// sendError creates an error message corresponding to the parameters and sends
// it to conn.out.
func (conn *Connection) sendError(e Error, dest string, serial uint32) {
@@ -383,6 +329,13 @@ func (conn *Connection) serials() {
}
}
+// SupportsUnixFDs returns whether the underlying transport supports passing of
+// unix file descriptors. If this is false, method calls containing unix file
+// descriptors will return an error, emitted signals containing them will not be
+// sent and methods of exported objects that take them as a parameter will
+// behvae as if they weren't present.
+// TODO
+
// Object returns the object identified by the given destination name and path.
func (conn *Connection) Object(dest string, path ObjectPath) *Object {
if !path.IsValid() {
@@ -459,14 +412,54 @@ type Signal struct {
Body []interface{}
}
-// getKey gets a key from a server address. Returns "" on error / not found...
-func getKey(s, key string) string {
- i := strings.IndexRune(s, ':')
- if i == -1 {
- return ""
+// transport is a DBus transport.
+type transport interface {
+ // Read and Write raw data (for example, for the authentication protocol).
+ io.ReadWriteCloser
+
+ // Send the initial null byte used for the EXTERNAL mechanism.
+ SendNullByte() error
+
+ // Returns whether this transport supports passing Unix FDs.
+ SupportsUnixFDs() bool
+
+ // Signal the transport that Unix FD passing is enabled for this connection.
+ EnableUnixFDs()
+
+ // Read / send a message, handling things like Unix FDs.
+ ReadMessage() (*Message, error)
+ SendMessage(*Message) error
+}
+
+func getTransport(address string) (transport, error) {
+ var err error
+ var t transport
+
+ m := map[string]func(string) (transport, error){
+ "unix": newUnixTransport,
}
- s = s[i+1:]
- i = strings.Index(s, key)
+ addresses := strings.Split(address, ";")
+ for _, v := range addresses {
+ i := strings.IndexRune(v, ':')
+ if i == -1 {
+ err = errors.New("bad address: no transport")
+ continue
+ }
+ f := m[v[:i]]
+ if f == nil {
+ err = errors.New("bad address: invalid or unsupported transport")
+ }
+ t, err = f(v[i+1:])
+ if err == nil {
+ return t, nil
+ }
+ }
+ return nil, err
+}
+
+// getKey gets a key from a the list of keys. Returns "" on error / not found...
+func getKey(s, key string) string {
+ i := strings.Index(s, key)
if i == -1 {
return ""
}
View
@@ -279,6 +279,9 @@ func (msg *Message) String() string {
s += " to <null>"
}
s += " serial " + strconv.FormatUint(uint64(msg.Serial), 10)
+ if v, ok := msg.Headers[FieldUnixFds]; ok {
+ s += " unixfds " + strconv.FormatUint(uint64(v.value.(uint32)), 10)
+ }
if v, ok := msg.Headers[FieldPath]; ok {
s += " path " + string(v.value.(ObjectPath))
}
View
@@ -7,21 +7,23 @@ import (
)
var (
- byteType = reflect.TypeOf(byte(0))
- boolType = reflect.TypeOf(false)
- uint8Type = reflect.TypeOf(uint8(0))
- int16Type = reflect.TypeOf(int16(0))
- uint16Type = reflect.TypeOf(uint16(0))
- int32Type = reflect.TypeOf(int32(0))
- uint32Type = reflect.TypeOf(uint32(0))
- int64Type = reflect.TypeOf(int64(0))
- uint64Type = reflect.TypeOf(uint64(0))
- float64Type = reflect.TypeOf(float64(0))
- stringType = reflect.TypeOf("")
- signatureType = reflect.TypeOf(Signature{""})
- objectPathType = reflect.TypeOf(ObjectPath(""))
- variantType = reflect.TypeOf(Variant{Signature{""}, nil})
- interfacesType = reflect.TypeOf([]interface{}{})
+ byteType = reflect.TypeOf(byte(0))
+ boolType = reflect.TypeOf(false)
+ uint8Type = reflect.TypeOf(uint8(0))
+ int16Type = reflect.TypeOf(int16(0))
+ uint16Type = reflect.TypeOf(uint16(0))
+ int32Type = reflect.TypeOf(int32(0))
+ uint32Type = reflect.TypeOf(uint32(0))
+ int64Type = reflect.TypeOf(int64(0))
+ uint64Type = reflect.TypeOf(uint64(0))
+ float64Type = reflect.TypeOf(float64(0))
+ stringType = reflect.TypeOf("")
+ signatureType = reflect.TypeOf(Signature{""})
+ objectPathType = reflect.TypeOf(ObjectPath(""))
+ variantType = reflect.TypeOf(Variant{Signature{""}, nil})
+ interfacesType = reflect.TypeOf([]interface{}{})
+ unixFDType = reflect.TypeOf(UnixFD(0))
+ unixFDIndexType = reflect.TypeOf(UnixFDIndex(0))
)
type invalidTypeError struct {
@@ -46,6 +48,7 @@ var sigToType = map[byte]reflect.Type{
'g': signatureType,
'o': objectPathType,
'v': variantType,
+ 'h': unixFDIndexType,
}
// Signature represents a correct type signature as specified
@@ -84,8 +87,14 @@ func getSignature(t reflect.Type) string {
case reflect.Uint16:
return "q"
case reflect.Int32:
+ if t == unixFDType {
+ return "h"
+ }
return "i"
case reflect.Uint32:
+ if t == unixFDIndexType {
+ return "h"
+ }
return "u"
case reflect.Int64:
return "x"
@@ -230,6 +239,13 @@ func (o ObjectPath) IsValid() bool {
return true
}
+// A UnixFD is a Unix file descriptor sent over the wire. See the package-level
+// documentation for more information about Unix file descriptor passsing.
+type UnixFD int32
+
+// A UnixFDIndex is the representation of a Unix file descriptor in a message.
+type UnixFDIndex uint32
+
// Variant represents a DBus variant type.
type Variant struct {
sig Signature
@@ -300,7 +316,7 @@ func validSingle(s string, depth int) (err error, rem string) {
return SignatureError{Sig: s, Reason: "container nesting too deep"}, ""
}
switch s[0] {
- case 'y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 's', 'g', 'o', 'v':
+ case 'y', 'b', 'n', 'q', 'i', 'u', 'x', 't', 'd', 's', 'g', 'o', 'v', 'h':
return nil, s[1:]
case 'a':
if len(s) > 1 && s[1] == '{' {
Oops, something went wrong.

0 comments on commit 299651a

Please sign in to comment.