diff --git a/internal/fulfiller/fulfiller.go b/internal/fulfiller/fulfiller.go index 65deb6dc..a9c9d838 100644 --- a/internal/fulfiller/fulfiller.go +++ b/internal/fulfiller/fulfiller.go @@ -178,53 +178,49 @@ type pcall struct { ecall } -// embargoClient is a client that flushes a queue of calls. -type embargoClient struct { +// EmbargoClient is a client that flushes a queue of calls. +// Fulfiller will create these automatically when pipelined calls are +// made on unresolved answers. EmbargoClient is exported so that rpc +// can avoid making calls on its own Conn. +type EmbargoClient struct { client capnp.Client - mu sync.RWMutex - q queue.Queue + mu sync.RWMutex + q queue.Queue + calls ecallList } func newEmbargoClient(client capnp.Client, queue []ecall) capnp.Client { - ec := &embargoClient{client: client} - qq := make(ecallList, callQueueSize) - n := copy(qq, queue) - ec.q.Init(qq, n) + ec := &EmbargoClient{ + client: client, + calls: make(ecallList, callQueueSize), + } + ec.q.Init(ec.calls, copy(ec.calls, queue)) go ec.flushQueue() return ec } -func (ec *embargoClient) push(cl *capnp.Call) capnp.Answer { +func (ec *EmbargoClient) push(cl *capnp.Call) capnp.Answer { f := new(Fulfiller) cl, err := cl.Copy(nil) if err != nil { return capnp.ErrorAnswer(err) } - if ok := ec.q.Push(ecall{cl, f}); !ok { + i := ec.q.Push() + if i == -1 { return capnp.ErrorAnswer(errCallQueueFull) } + ec.calls[i] = ecall{cl, f} return f } -func (ec *embargoClient) peek() ecall { - if ec.q.Len() == 0 { - return ecall{} - } - return ec.q.Peek().(ecall) -} - -func (ec *embargoClient) pop() ecall { - if ec.q.Len() == 0 { - return ecall{} - } - return ec.q.Pop().(ecall) -} - // flushQueue is run in its own goroutine. -func (ec *embargoClient) flushQueue() { +func (ec *EmbargoClient) flushQueue() { + var c ecall ec.mu.Lock() - c := ec.peek() + if i := ec.q.Front(); i != -1 { + c = ec.calls[i] + } ec.mu.Unlock() for c.call != nil { ans := ec.client.Call(c.call) @@ -237,13 +233,19 @@ func (ec *embargoClient) flushQueue() { } }(c.f, ans) ec.mu.Lock() - ec.pop() - c = ec.peek() + ec.q.Pop() + if i := ec.q.Front(); i != -1 { + c = ec.calls[i] + } else { + c = ecall{} + } ec.mu.Unlock() } } -func (ec *embargoClient) WrappedClient() capnp.Client { +// Client returns the underlying client if the embargo has been lifted +// and nil otherwise. +func (ec *EmbargoClient) Client() capnp.Client { ec.mu.RLock() ok := ec.isPassthrough() ec.mu.RUnlock() @@ -253,11 +255,13 @@ func (ec *embargoClient) WrappedClient() capnp.Client { return ec.client } -func (ec *embargoClient) isPassthrough() bool { +func (ec *EmbargoClient) isPassthrough() bool { return ec.q.Len() == 0 } -func (ec *embargoClient) Call(cl *capnp.Call) capnp.Answer { +// Call either queues a call to the underlying client or starts a call +// if the embargo has been lifted. +func (ec *EmbargoClient) Call(cl *capnp.Call) capnp.Answer { // Fast path: queue is flushed. ec.mu.RLock() ok := ec.isPassthrough() @@ -278,12 +282,26 @@ func (ec *embargoClient) Call(cl *capnp.Call) capnp.Answer { return ans } -func (ec *embargoClient) Close() error { +// TryQueue will attempt to queue a call or return nil if the embargo +// has been lifted. +func (ec *EmbargoClient) TryQueue(cl *capnp.Call) capnp.Answer { + ec.mu.Lock() + if ec.isPassthrough() { + ec.mu.Unlock() + return nil + } + ans := ec.push(cl) + ec.mu.Unlock() + return ans +} + +// Close closes the underlying client, rejecting any queued calls. +func (ec *EmbargoClient) Close() error { ec.mu.Lock() // reject all queued calls for ec.q.Len() > 0 { - c := ec.pop() - c.f.Reject(errQueueCallCancel) + ec.calls[ec.q.Front()].f.Reject(errQueueCallCancel) + ec.q.Pop() } ec.mu.Unlock() return ec.client.Close() @@ -301,16 +319,8 @@ func (el ecallList) Len() int { return len(el) } -func (el ecallList) At(i int) interface{} { - return el[i] -} - -func (el ecallList) Set(i int, x interface{}) { - if x == nil { - el[i] = ecall{} - } else { - el[i] = x.(ecall) - } +func (el ecallList) Clear(i int) { + el[i] = ecall{} } var ( diff --git a/internal/queue/queue.go b/internal/queue/queue.go index 3ac93fd7..3f0b95e4 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -6,9 +6,11 @@ type Queue struct { q Interface start int n int + cap int } -// New creates a new queue that starts with n elements. +// New creates a new queue that starts with n elements. The interface's +// length must not change over the course of the queue's usage. func New(q Interface, n int) *Queue { qq := new(Queue) qq.Init(q, n) @@ -18,54 +20,51 @@ func New(q Interface, n int) *Queue { // Init initializes a queue. The old queue is untouched. func (q *Queue) Init(r Interface, n int) { q.q = r - q.start, q.n = 0, n + q.start = 0 + q.n = n + q.cap = r.Len() } // Len returns the length of the queue. This is different from the -// underlying interface's length. +// underlying interface's length, which is the queue's capacity. func (q *Queue) Len() int { return q.n } -// Push pushes an element on the queue. If the queue is full, -// Push returns false. If x is nil, Push panics. -func (q *Queue) Push(x interface{}) bool { - n := q.q.Len() - if q.n >= n { - return false +// Push reserves space for an element on the queue, returning its index. +// If the queue is full, Push returns -1. +func (q *Queue) Push() int { + if q.n >= q.cap { + return -1 } - i := (q.start + q.n) % n - q.q.Set(i, x) + i := (q.start + q.n) % q.cap q.n++ - return true + return i } -// Peek returns the element at the front of the queue. -// If the queue is empty, Peek panics. -func (q *Queue) Peek() interface{} { +// Front returns the index of the front of the queue, or -1 if the queue is empty. +func (q *Queue) Front() int { if q.n == 0 { - panic("Queue.Pop called on empty queue") + return -1 } - return q.q.At(q.start) + return q.start } -// Pop pops an element from the queue. -// If the queue is empty, Pop panics. -func (q *Queue) Pop() interface{} { - x := q.Peek() - q.q.Set(q.start, nil) - q.start = (q.start + 1) % q.q.Len() +// Pop pops an element from the queue, returning whether it succeeded. +func (q *Queue) Pop() bool { + if q.n == 0 { + return false + } + q.q.Clear(q.start) + q.start = (q.start + 1) % q.cap q.n-- - return x + return true } // A type implementing Interface can be used to store elements in a Queue. type Interface interface { // Len returns the number of elements available. Len() int - // At returns the element at i. - At(i int) interface{} - // Set sets the element at i to x. - // If x is nil, that element should be cleared. - Set(i int, x interface{}) + // Clear removes the element at i. + Clear(i int) } diff --git a/internal/queue/queue_test.go b/internal/queue/queue_test.go index cc62debe..03b24dba 100644 --- a/internal/queue/queue_test.go +++ b/internal/queue/queue_test.go @@ -1,8 +1,6 @@ package queue -import ( - "testing" -) +import "testing" func TestNew(t *testing.T) { qi := make(ints, 5) @@ -23,8 +21,8 @@ func TestPrepush(t *testing.T) { if n := q.Len(); n != 1 { t.Fatalf("New(qi, 1).Len() = %d; want 1", n) } - if x := q.Pop().(int); x != 42 { - t.Errorf("Pop() = %d; want 42", x) + if i := q.Front(); i != 0 { + t.Errorf("q.Front() = %d; want 0") } } @@ -32,14 +30,18 @@ func TestPush(t *testing.T) { qi := make(ints, 5) q := New(qi, 0) - ok := q.Push(42) - - if !ok { - t.Error("q.Push(42) returned false") + i := q.Push() + if i == -1 { + t.Error("q.Push() returned -1") } + qi[i] = 42 + if n := q.Len(); n != 1 { t.Errorf("q.Len() after push = %d; want 1", n) } + if front := q.Front(); front != i { + t.Errorf("q.Front() after push = %d; want %d", front, i) + } } func TestPushFull(t *testing.T) { @@ -47,20 +49,28 @@ func TestPushFull(t *testing.T) { q := New(qi, 0) var ok [6]bool - ok[0] = q.Push(10) - ok[1] = q.Push(11) - ok[2] = q.Push(12) - ok[3] = q.Push(13) - ok[4] = q.Push(14) - ok[5] = q.Push(15) + push := func(n int, val int) { + i := q.Push() + if i == -1 { + return + } + ok[n] = true + qi[i] = val + } + push(0, 10) + push(1, 11) + push(2, 12) + push(3, 13) + push(4, 14) + push(5, 15) for i := 0; i < 5; i++ { if !ok[i] { - t.Errorf("q.Push(%d) returned false", 10+i) + t.Errorf("q.Push() #%d returned -1", i, 10+i) } } if ok[5] { - t.Error("q.Push(15) returned true") + t.Error("q.Push() #5 returned true") } if n := q.Len(); n != 5 { t.Errorf("q.Len() after full = %d; want 5", n) @@ -70,26 +80,33 @@ func TestPushFull(t *testing.T) { func TestPop(t *testing.T) { qi := make(ints, 5) q := New(qi, 0) - q.Push(1) - q.Push(2) - q.Push(3) - - outs := make([]int, 0, len(qi)) - outs = append(outs, q.Pop().(int)) - outs = append(outs, q.Pop().(int)) - outs = append(outs, q.Pop().(int)) + qi[q.Push()] = 1 + qi[q.Push()] = 2 + qi[q.Push()] = 3 + + outs := make([]int, 3) + for n := range outs { + i := q.Front() + if i == -1 { + t.Fatalf("before q.Pop() #%d, Front == -1", n) + } + outs[n] = qi[i] + if !q.Pop() { + t.Fatalf("q.Pop() #%d = false", n) + } + } if n := q.Len(); n != 0 { t.Errorf("q.Len() after pops = %d; want 0", n) } if outs[0] != 1 { - t.Errorf("first pop = %d; want 1", outs[0]) + t.Errorf("pop #0 = %d; want 1", outs[0]) } if outs[1] != 2 { - t.Errorf("first pop = %d; want 2", outs[1]) + t.Errorf("pop #1 = %d; want 2", outs[1]) } if outs[2] != 3 { - t.Errorf("first pop = %d; want 3", outs[2]) + t.Errorf("pop #2 = %d; want 3", outs[2]) } for i := range qi { if qi[i] != 0 { @@ -101,29 +118,27 @@ func TestPop(t *testing.T) { func TestWrap(t *testing.T) { qi := make(ints, 5) q := New(qi, 0) - var ok [7]bool - ok[0] = q.Push(10) - ok[1] = q.Push(11) - ok[2] = q.Push(12) + qi[q.Push()] = 10 + qi[q.Push()] = 11 + qi[q.Push()] = 12 q.Pop() q.Pop() - ok[3] = q.Push(13) - ok[4] = q.Push(14) - ok[5] = q.Push(15) - ok[6] = q.Push(16) + qi[q.Push()] = 13 + qi[q.Push()] = 14 + qi[q.Push()] = 15 + qi[q.Push()] = 16 - for i := 0; i < 6; i++ { - if !ok[i] { - t.Errorf("q.Push(%d) returned false", 10+i) - } - } if n := q.Len(); n != 5 { t.Errorf("q.Len() = %d; want 5", n) } for i := 12; q.Len() > 0; i++ { - if x := q.Pop().(int); x != i { - t.Errorf("q.Pop() = %d; want %d", x, i) + if x := qi[q.Front()]; x != i { + t.Errorf("qi[q.Front()] = %d; want %d", x, i) + } + if !q.Pop() { + t.Error("q.Pop() returned false") + break } } } @@ -134,14 +149,6 @@ func (is ints) Len() int { return len(is) } -func (is ints) At(i int) interface{} { - return is[i] -} - -func (is ints) Set(i int, x interface{}) { - if x == nil { - is[i] = 0 - } else { - is[i] = x.(int) - } +func (is ints) Clear(i int) { + is[i] = 0 } diff --git a/rpc/answer.go b/rpc/answer.go index e7689712..548a7e62 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -15,63 +15,40 @@ import ( // TODO(light): make this a ConnOption const callQueueSize = 64 -type answerTable struct { - tab map[answerID]*answer - manager *manager - out chan<- rpccapnp.Message - returns chan<- *outgoingReturn - queueCloses chan<- queueClientClose -} - -func (at *answerTable) get(id answerID) *answer { - var a *answer - if at.tab != nil { - a = at.tab[id] - } - return a -} - -// insert creates a new question with the given ID, returning nil +// insertAnswer creates a new answer with the given ID, returning nil // if the ID is already in use. -func (at *answerTable) insert(id answerID, cancel context.CancelFunc) *answer { - if at.tab == nil { - at.tab = make(map[answerID]*answer) - } - var a *answer - if _, ok := at.tab[id]; !ok { - a = &answer{ - id: id, - cancel: cancel, - manager: at.manager, - out: at.out, - returns: at.returns, - queueCloses: at.queueCloses, - resolved: make(chan struct{}), - queue: make([]pcall, 0, callQueueSize), - } - at.tab[id] = a +func (c *Conn) insertAnswer(id answerID, cancel context.CancelFunc) *answer { + if c.answers == nil { + c.answers = make(map[answerID]*answer) + } else if _, exists := c.answers[id]; exists { + return nil } + a := &answer{ + id: id, + cancel: cancel, + conn: c, + resolved: make(chan struct{}), + queue: make([]pcall, 0, callQueueSize), + } + c.answers[id] = a return a } -func (at *answerTable) pop(id answerID) *answer { - var a *answer - if at.tab != nil { - a = at.tab[id] - delete(at.tab, id) +func (c *Conn) popAnswer(id answerID) *answer { + if c.answers == nil { + return nil } + a := c.answers[id] + delete(c.answers, id) return a } type answer struct { - id answerID - cancel context.CancelFunc - resultCaps []exportID - manager *manager - out chan<- rpccapnp.Message - returns chan<- *outgoingReturn - queueCloses chan<- queueClientClose - resolved chan struct{} + id answerID + cancel context.CancelFunc + resultCaps []exportID + conn *Conn + resolved chan struct{} mu sync.RWMutex obj capnp.Ptr @@ -80,12 +57,11 @@ type answer struct { queue []pcall } -// fulfill is called to resolve an answer successfully and returns a list -// of return messages to send. -// It must be called from the coordinate goroutine. -func (a *answer) fulfill(msgs []rpccapnp.Message, obj capnp.Ptr, makeCapTable capTableMaker) []rpccapnp.Message { +// fulfill is called to resolve an answer successfully. It returns an +// error if its connection is shut down while sending messages. The +// caller must be holding onto a.conn.mu. +func (a *answer) fulfill(obj capnp.Ptr) error { a.mu.Lock() - defer a.mu.Unlock() if a.done { panic("answer.fulfill called more than once") } @@ -96,32 +72,37 @@ func (a *answer) fulfill(msgs []rpccapnp.Message, obj capnp.Ptr, makeCapTable ca ret, _ := retmsg.Return() payload, _ := ret.NewResults() payload.SetContentPtr(obj) - payloadTab, err := makeCapTable(ret.Segment()) - if err != nil { - // TODO(light): handle this more gracefully - panic(err) + var firstErr error + if payloadTab, err := a.conn.makeCapTable(ret.Segment()); err == nil { + payload.SetCapTable(payloadTab) + if err := a.conn.sendMessage(retmsg); err != nil { + firstErr = err + } + } else { + firstErr = err } - payload.SetCapTable(payloadTab) - msgs = append(msgs, retmsg) - queues, msgs := a.emptyQueue(msgs, obj) + queues, err := a.emptyQueue(obj) + if err != nil && firstErr == nil { + firstErr = err + } ctab := obj.Segment().Message().CapTable for capIdx, q := range queues { - ctab[capIdx] = newQueueClient(a.manager, ctab[capIdx], q, a.out, a.queueCloses) + ctab[capIdx] = newQueueClient(a.conn, ctab[capIdx], q) } close(a.resolved) - return msgs + a.mu.Unlock() + return firstErr } -// reject is called to resolve an answer with failure and returns a list -// of return messages to send. -// It must be called from the coordinate goroutine. -func (a *answer) reject(msgs []rpccapnp.Message, err error) []rpccapnp.Message { +// reject is called to resolve an answer with failure. It returns an +// error if its connection is shut down while sending messages. The +// caller must be holding onto a.conn.mu. +func (a *answer) reject(err error) error { if err == nil { panic("answer.reject called with nil") } a.mu.Lock() - defer a.mu.Unlock() if a.done { panic("answer.reject called more than once") } @@ -129,29 +110,40 @@ func (a *answer) reject(msgs []rpccapnp.Message, err error) []rpccapnp.Message { m := newReturnMessage(nil, a.id) mret, _ := m.Return() setReturnException(mret, err) - msgs = append(msgs, m) + var firstErr error + if err := a.conn.sendMessage(m); err != nil { + firstErr = err + } for i := range a.queue { - msgs = a.queue[i].a.reject(msgs, err) - a.queue[i] = pcall{} + if err := a.queue[i].a.reject(err); err != nil && firstErr == nil { + firstErr = err + } } + a.queue = nil close(a.resolved) - return msgs + a.mu.Unlock() + return firstErr } // emptyQueue splits the queue by which capability it targets // and drops any invalid calls. Once this function returns, a.queue // will be nil. -func (a *answer) emptyQueue(msgs []rpccapnp.Message, obj capnp.Ptr) (map[capnp.CapabilityID][]qcall, []rpccapnp.Message) { +func (a *answer) emptyQueue(obj capnp.Ptr) (map[capnp.CapabilityID][]qcall, error) { + var firstErr error qs := make(map[capnp.CapabilityID][]qcall, len(a.queue)) for i, pc := range a.queue { c, err := capnp.TransformPtr(obj, pc.transform) if err != nil { - msgs = pc.a.reject(msgs, err) + if err := pc.a.reject(err); err != nil && firstErr == nil { + firstErr = err + } continue } ci := c.Interface() if !ci.IsValid() { - msgs = pc.a.reject(msgs, capnp.ErrNullClient) + if err := pc.a.reject(capnp.ErrNullClient); err != nil && firstErr == nil { + firstErr = err + } continue } cn := ci.Capability() @@ -161,43 +153,27 @@ func (a *answer) emptyQueue(msgs []rpccapnp.Message, obj capnp.Ptr) (map[capnp.C qs[cn] = append(qs[cn], pc.qcall) } a.queue = nil - return qs, msgs + return qs, firstErr } -func (a *answer) peek() (obj capnp.Ptr, err error, ok bool) { - a.mu.RLock() - obj, err, ok = a.obj, a.err, a.done - a.mu.RUnlock() - return -} - -// queueCall is called from the coordinate goroutine to add a call to -// the queue. -func (a *answer) queueCall(result *answer, transform []capnp.PipelineOp, call *capnp.Call) error { - a.mu.Lock() - defer a.mu.Unlock() - if a.done { - panic("answer.queueCall called on resolved answer") - } +// queueCallLocked enqueues a call to be made after the answer has been +// resolved. The answer must not be resolved yet. pc should have +// transform and one of pc.a or pc.f to be set. The caller must be +// holding onto a.mu. +func (a *answer) queueCallLocked(call *capnp.Call, pc pcall) error { if len(a.queue) == cap(a.queue) { return errQueueFull } - cc, err := call.Copy(nil) + var err error + pc.call, err = call.Copy(nil) if err != nil { return err } - a.queue = append(a.queue, pcall{ - transform: transform, - qcall: qcall{ - a: result, - call: cc, - }, - }) + a.queue = append(a.queue, pc) return nil } -// queueDisembargo is called from the coordinate goroutine to add a -// disembargo message to the queue. +// queueDisembargo enqueues a disembargo message. func (a *answer) queueDisembargo(transform []capnp.PipelineOp, id embargoID, target rpccapnp.MessageTarget) (queued bool, err error) { a.mu.Lock() defer a.mu.Unlock() @@ -217,12 +193,12 @@ func (a *answer) queueDisembargo(transform []capnp.PipelineOp, id embargoID, tar // No need to embargo, disembargo immediately. return false, nil } - if ic, ok := extractRPCClient(qc.client).(*importClient); !(ok && a.manager == ic.manager) { + if ic := isImport(qc.client); ic == nil || a.conn != ic.conn { return false, errDisembargoNonImport } qc.mu.Lock() if !qc.isPassthrough() { - err = qc.pushEmbargo(id, target) + err = qc.pushEmbargoLocked(id, target) if err == nil { queued = true } @@ -236,24 +212,24 @@ func (a *answer) pipelineClient(transform []capnp.PipelineOp) capnp.Client { } // joinAnswer resolves an RPC answer by waiting on a generic answer. -// It waits until the generic answer is finished, so it should be run -// in its own goroutine. +// The caller must not be holding onto a.conn.mu. func joinAnswer(a *answer, ca capnp.Answer) { s, err := ca.Struct() - r := &outgoingReturn{ - a: a, - obj: s.ToPtr(), - err: err, - } select { - case a.returns <- r: - case <-a.manager.finish: + case <-a.conn.mu: + // Locked. + case <-a.conn.bg.Done(): + return + } + if err == nil { + a.fulfill(s.ToPtr()) + } else { + a.reject(err) } + a.conn.mu.Unlock() } // joinFulfiller resolves a fulfiller by waiting on a generic answer. -// It waits until the generic answer is finished, so it should be run -// in its own goroutine. func joinFulfiller(f *fulfiller.Fulfiller, ca capnp.Answer) { s, err := ca.Struct() if err != nil { @@ -263,85 +239,67 @@ func joinFulfiller(f *fulfiller.Fulfiller, ca capnp.Answer) { } } -// outgoingReturn is a message sent to the coordinate goroutine to -// indicate that a call started by an answer has completed. A simple -// message is insufficient, since the connection needs to populate the -// return message's capability table. -type outgoingReturn struct { - a *answer - obj capnp.Ptr - err error -} - type queueClient struct { - manager *manager - client capnp.Client - out chan<- rpccapnp.Message - closes chan<- queueClientClose + client capnp.Client + conn *Conn - mu sync.RWMutex - q queue.Queue + mu sync.RWMutex + q queue.Queue + calls qcallList } -func newQueueClient(m *manager, client capnp.Client, queue []qcall, out chan<- rpccapnp.Message, closes chan<- queueClientClose) *queueClient { +func newQueueClient(c *Conn, client capnp.Client, queue []qcall) *queueClient { qc := &queueClient{ - manager: m, - client: client, - out: out, - closes: closes, - } - qq := make(qcallList, callQueueSize) - n := copy(qq, queue) - qc.q.Init(qq, n) + client: client, + conn: c, + calls: make(qcallList, callQueueSize), + } + qc.q.Init(qc.calls, copy(qc.calls, queue)) go qc.flushQueue() return qc } -func (qc *queueClient) pushCall(cl *capnp.Call) capnp.Answer { +func (qc *queueClient) pushCallLocked(cl *capnp.Call) capnp.Answer { f := new(fulfiller.Fulfiller) cl, err := cl.Copy(nil) if err != nil { return capnp.ErrorAnswer(err) } - if ok := qc.q.Push(qcall{call: cl, f: f}); !ok { + i := qc.q.Push() + if i == -1 { return capnp.ErrorAnswer(errQueueFull) } + qc.calls[i] = qcall{call: cl, f: f} return f } -func (qc *queueClient) pushEmbargo(id embargoID, tgt rpccapnp.MessageTarget) error { - ok := qc.q.Push(qcall{embargoID: id, embargoTarget: tgt}) - if !ok { +func (qc *queueClient) pushEmbargoLocked(id embargoID, tgt rpccapnp.MessageTarget) error { + i := qc.q.Push() + if i == -1 { return errQueueFull } + qc.calls[i] = qcall{embargoID: id, embargoTarget: tgt} return nil } -func (qc *queueClient) peek() qcall { - if qc.q.Len() == 0 { - return qcall{} - } - return qc.q.Peek().(qcall) -} - -func (qc *queueClient) pop() qcall { - if qc.q.Len() == 0 { - return qcall{} - } - return qc.q.Pop().(qcall) -} - // flushQueue is run in its own goroutine. func (qc *queueClient) flushQueue() { + var c qcall qc.mu.RLock() - c := qc.peek() + if i := qc.q.Front(); i != -1 { + c = qc.calls[i] + } qc.mu.RUnlock() for c.which() != qcallInvalid { qc.handle(&c) qc.mu.Lock() - qc.pop() - c = qc.peek() + qc.q.Pop() + if i := qc.q.Front(); i != -1 { + c = qc.calls[i] + } else { + c = qcall{} + } qc.mu.Unlock() } } @@ -358,7 +316,7 @@ func (qc *queueClient) handle(c *qcall) { msg := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, c.embargoID) d, _ := msg.Disembargo() d.SetTarget(c.embargoTarget) - sendMessage(qc.manager, qc.out, msg) + qc.conn.sendMessage(msg) } } @@ -382,63 +340,63 @@ func (qc *queueClient) Call(cl *capnp.Call) capnp.Answer { qc.mu.Unlock() return qc.client.Call(cl) } - ans := qc.pushCall(cl) + ans := qc.pushCallLocked(cl) qc.mu.Unlock() return ans } -func (qc *queueClient) WrappedClient() capnp.Client { - qc.mu.RLock() - ok := qc.isPassthrough() - qc.mu.RUnlock() - if !ok { +func (qc *queueClient) tryQueue(cl *capnp.Call) capnp.Answer { + qc.mu.Lock() + if qc.isPassthrough() { + qc.mu.Unlock() return nil } - return qc.client + ans := qc.pushCallLocked(cl) + qc.mu.Unlock() + return ans } func (qc *queueClient) Close() error { - done := make(chan struct{}) - select { - case qc.closes <- queueClientClose{qc, done}: - case <-qc.manager.finish: - return qc.manager.err() - } select { - case <-done: - case <-qc.manager.finish: - return qc.manager.err() + case <-qc.conn.mu: + // Locked. + case <-qc.conn.bg.Done(): + return ErrConnClosed + } + rejErr := qc.rejectQueue() + qc.conn.mu.Unlock() + if err := qc.client.Close(); err != nil { + return err } - return qc.client.Close() + return rejErr } -// rejectQueue is called from the coordinate goroutine to close out a queueClient. -func (qc *queueClient) rejectQueue(msgs []rpccapnp.Message) []rpccapnp.Message { +// rejectQueue drains the client's queue. It returns an error if the +// connection was shut down while messages are sent. The caller must be +// holding onto qc.conn.mu. +func (qc *queueClient) rejectQueue() error { + var firstErr error qc.mu.Lock() - for { - c := qc.pop() - if w := c.which(); w == qcallRemoteCall { - msgs = c.a.reject(msgs, errQueueCallCancel) - } else if w == qcallLocalCall { + for ; qc.q.Len() > 0; qc.q.Pop() { + c := qc.calls[qc.q.Front()] + switch c.which() { + case qcallRemoteCall: + if err := c.a.reject(errQueueCallCancel); err != nil && firstErr == nil { + firstErr = err + } + case qcallLocalCall: c.f.Reject(errQueueCallCancel) - } else if w == qcallDisembargo { + case qcallDisembargo: m := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, c.embargoID) d, _ := m.Disembargo() d.SetTarget(c.embargoTarget) - msgs = append(msgs, m) - } else { - break + if err := qc.conn.sendMessage(m); err != nil && firstErr == nil { + firstErr = err + } } } qc.mu.Unlock() - return msgs -} - -// queueClientClose is a message sent to the coordinate goroutine to -// handle rejecting a queue. -type queueClientClose struct { - qc *queueClient - done chan<- struct{} + return firstErr } // pcall is a queued pipeline call. @@ -486,16 +444,8 @@ func (ql qcallList) Len() int { return len(ql) } -func (ql qcallList) At(i int) interface{} { - return ql[i] -} - -func (ql qcallList) Set(i int, x interface{}) { - if x == nil { - ql[i] = qcall{} - } else { - ql[i] = x.(qcall) - } +func (ql qcallList) Clear(i int) { + ql[i] = qcall{} } // A localAnswerClient is used to provide a pipelined client of an answer. @@ -511,45 +461,29 @@ func (lac *localAnswerClient) Call(call *capnp.Call) capnp.Answer { lac.a.mu.Unlock() return clientFromResolution(lac.transform, obj, err).Call(call) } - defer lac.a.mu.Unlock() - if len(lac.a.queue) == cap(lac.a.queue) { - return capnp.ErrorAnswer(errQueueFull) - } f := new(fulfiller.Fulfiller) - cc, err := call.Copy(nil) - if err != nil { - return capnp.ErrorAnswer(err) - } - lac.a.queue = append(lac.a.queue, pcall{ + err := lac.a.queueCallLocked(call, pcall{ transform: lac.transform, - qcall: qcall{ - f: f, - call: cc, - }, + qcall: qcall{f: f}, }) - return f -} - -func (lac *localAnswerClient) WrappedClient() capnp.Client { - obj, err, ok := lac.a.peek() - if !ok { - return nil + lac.a.mu.Unlock() + if err != nil { + return capnp.ErrorAnswer(errQueueFull) } - return clientFromResolution(lac.transform, obj, err) + return f } func (lac *localAnswerClient) Close() error { - obj, err, ok := lac.a.peek() - if !ok { + lac.a.mu.RLock() + obj, err, done := lac.a.obj, lac.a.err, lac.a.done + lac.a.mu.RUnlock() + if !done { return nil } client := clientFromResolution(lac.transform, obj, err) return client.Close() } -// A capTableMaker converts the clients in a segment's message into capability descriptors. -type capTableMaker func(*capnp.Segment) (rpccapnp.CapDescriptor_List, error) - var ( errQueueFull = errors.New("rpc: pipeline queue full") errQueueCallCancel = errors.New("rpc: queued call canceled") diff --git a/rpc/bench_test.go b/rpc/bench_test.go index ba23e897..386b24c7 100644 --- a/rpc/bench_test.go +++ b/rpc/bench_test.go @@ -16,10 +16,9 @@ func BenchmarkPingPong(b *testing.B) { if *logMessages { p = logtransport.New(nil, p) } - c := rpc.NewConn(p) - d := rpc.NewConn(q, rpc.BootstrapFunc(func(ctx context.Context) (capnp.Client, error) { - return testcapnp.PingPong_ServerToClient(pingPongServer{}).Client, nil - })) + log := testLogger{b} + c := rpc.NewConn(p, rpc.ConnLog(log)) + d := rpc.NewConn(q, rpc.ConnLog(log), rpc.BootstrapFunc(bootstrapPingPong)) defer d.Wait() defer c.Close() @@ -44,6 +43,10 @@ func BenchmarkPingPong(b *testing.B) { } } +func bootstrapPingPong(ctx context.Context) (capnp.Client, error) { + return testcapnp.PingPong_ServerToClient(pingPongServer{}).Client, nil +} + type pingPongServer struct{} func (pingPongServer) EchoNum(call testcapnp.PingPong_echoNum) error { diff --git a/rpc/cancel_test.go b/rpc/cancel_test.go index 5a8ecc75..dd609ef4 100644 --- a/rpc/cancel_test.go +++ b/rpc/cancel_test.go @@ -14,14 +14,15 @@ import ( func TestCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + log := testLogger{t} p, q := pipetransport.New() if *logMessages { p = logtransport.New(nil, p) } - c := rpc.NewConn(p) + c := rpc.NewConn(p, rpc.ConnLog(log)) notify := make(chan struct{}) hanger := testcapnp.Hanger_ServerToClient(Hanger{notify: notify}) - d := rpc.NewConn(q, rpc.MainInterface(hanger.Client)) + d := rpc.NewConn(q, rpc.MainInterface(hanger.Client), rpc.ConnLog(log)) defer d.Wait() defer c.Close() client := testcapnp.Hanger{Client: c.Bootstrap(ctx)} diff --git a/rpc/embargo_test.go b/rpc/embargo_test.go index b83728fc..63e17ee8 100644 --- a/rpc/embargo_test.go +++ b/rpc/embargo_test.go @@ -18,9 +18,10 @@ func TestEmbargo(t *testing.T) { if *logMessages { p = logtransport.New(nil, p) } - c := rpc.NewConn(p) + log := testLogger{t} + c := rpc.NewConn(p, rpc.ConnLog(log)) echoSrv := testcapnp.Echoer_ServerToClient(new(Echoer)) - d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client)) + d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client), rpc.ConnLog(log)) defer d.Wait() defer c.Close() client := testcapnp.Echoer{Client: c.Bootstrap(ctx)} diff --git a/rpc/errors.go b/rpc/errors.go index 1244798d..26bb1420 100644 --- a/rpc/errors.go +++ b/rpc/errors.go @@ -25,6 +25,22 @@ func (e Exception) Error() string { // An Abort is a hang-up by a remote vat. type Abort Exception +func copyAbort(m rpccapnp.Message) (Abort, error) { + ma, err := m.Abort() + if err != nil { + return Abort{}, err + } + msg, _, _ := capnp.NewMessage(capnp.SingleSegment(nil)) + if err := msg.SetRootPtr(ma.ToPtr()); err != nil { + return Abort{}, err + } + p, err := msg.RootPtr() + if err != nil { + return Abort{}, err + } + return Abort{rpccapnp.Exception{Struct: p.Struct()}}, nil +} + // Error returns the exception's reason. func (a Abort) Error() string { r, err := a.Reason() diff --git a/rpc/internal/refcount/refcount.go b/rpc/internal/refcount/refcount.go index 5e357fa8..6a17cd57 100644 --- a/rpc/internal/refcount/refcount.go +++ b/rpc/internal/refcount/refcount.go @@ -18,20 +18,30 @@ type RefCount struct { } // New creates a reference counter and the first client reference. -func New(c capnp.Client) (rc *RefCount, ref capnp.Client) { - rc = &RefCount{Client: c} - ref = rc.Ref() +func New(c capnp.Client) (rc *RefCount, ref1 capnp.Client) { + if rr, ok := c.(*Ref); ok { + return rr.rc, rr.rc.Ref() + } + rc = &RefCount{Client: c, refs: 1} + ref1 = rc.newRef() return } // Ref makes a new client reference. func (rc *RefCount) Ref() capnp.Client { - // TODO(light): what if someone calls Ref() after refs hits zero? rc.mu.Lock() + if rc.refs <= 0 { + rc.mu.Unlock() + return capnp.ErrorClient(errZeroRef) + } rc.refs++ rc.mu.Unlock() - r := &ref{rc: rc} - runtime.SetFinalizer(r, (*ref).Close) + return rc.newRef() +} + +func (rc *RefCount) newRef() *Ref { + r := &Ref{rc: rc} + runtime.SetFinalizer(r, (*Ref).Close) return r } @@ -39,11 +49,13 @@ func (rc *RefCount) call(cl *capnp.Call) capnp.Answer { // We lock here so that we can prevent the client from being closed // while we start the call. rc.mu.Lock() - defer rc.mu.Unlock() if rc.refs <= 0 { + rc.mu.Unlock() return capnp.ErrorAnswer(errClosed) } - return rc.Client.Call(cl) + ans := rc.Client.Call(cl) + rc.mu.Unlock() + return ans } // decref decreases the reference count by one, closing the Client if it reaches zero. @@ -67,22 +79,30 @@ func (rc *RefCount) decref() error { return nil } -var errClosed = errors.New("rpc: Close() called on closed client") +var ( + errZeroRef = errors.New("rpc: Ref() called on zeroed refcount") + errClosed = errors.New("rpc: Close() called on closed client") +) -type ref struct { +// A Ref is a single reference to a client wrapped by RefCount. +type Ref struct { rc *RefCount once sync.Once } -func (r *ref) Call(cl *capnp.Call) capnp.Answer { +// Call makes a call on the underlying client. +func (r *Ref) Call(cl *capnp.Call) capnp.Answer { return r.rc.call(cl) } -func (r *ref) WrappedClient() capnp.Client { +// Client returns the underlying client. +func (r *Ref) Client() capnp.Client { return r.rc.Client } -func (r *ref) Close() error { +// Close decrements the reference count. Close will be called on +// finalization (i.e. garbage collection). +func (r *Ref) Close() error { var err error closed := false r.once.Do(func() { diff --git a/rpc/introspect.go b/rpc/introspect.go index 6603f40b..993d966d 100644 --- a/rpc/introspect.go +++ b/rpc/introspect.go @@ -1,26 +1,96 @@ package rpc import ( - "log" - "zombiezen.com/go/capnproto2" "zombiezen.com/go/capnproto2/internal/fulfiller" + "zombiezen.com/go/capnproto2/rpc/internal/refcount" rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc" ) -// nestedCall is called from the coordinate goroutine to make a client call. -// Since the client may point -func (c *Conn) nestedCall(client capnp.Client, cl *capnp.Call) capnp.Answer { - client = extractRPCClient(client) - ac := appCallFromClientCall(c, client, cl) - if ac != nil { - ans, err := c.handleCall(ac) - if err != nil { - log.Println("rpc: failed to handle call:", err) - return capnp.ErrorAnswer(err) +// While the code below looks repetitive, resist the urge to refactor. +// Each operation is distinct in assumptions it can make about +// particular cases, and there isn't a convenient type signature that +// fits all cases. + +// lockedCall is used to make a call to an arbitrary client while +// holding onto c.mu. Since the client could point back to c, naively +// calling c.Call could deadlock. +func (c *Conn) lockedCall(client capnp.Client, cl *capnp.Call) capnp.Answer { +dig: + for client := client; ; { + switch curr := client.(type) { + case *importClient: + if curr.conn != c { + // This doesn't use our conn's lock, so it is safe to call. + return curr.Call(cl) + } + return curr.lockedCall(cl) + case *fulfiller.EmbargoClient: + if ans := curr.TryQueue(cl); ans != nil { + return ans + } + client = curr.Client() + case *refcount.Ref: + client = curr.Client() + case *embargoClient: + if ans := curr.tryQueue(cl); ans != nil { + return ans + } + client = curr.client + case *queueClient: + if ans := curr.tryQueue(cl); ans != nil { + return ans + } + client = curr.client + case *localAnswerClient: + curr.a.mu.Lock() + if curr.a.done { + obj, err := curr.a.obj, curr.a.err + curr.a.mu.Unlock() + client = clientFromResolution(curr.transform, obj, err) + } else { + f := new(fulfiller.Fulfiller) + err := curr.a.queueCallLocked(cl, pcall{ + transform: curr.transform, + qcall: qcall{f: f}, + }) + curr.a.mu.Unlock() + if err != nil { + return capnp.ErrorAnswer(err) + } + return f + } + case *capnp.PipelineClient: + p := (*capnp.Pipeline)(curr) + ans := p.Answer() + transform := p.Transform() + if capnp.IsFixedAnswer(ans) { + s, err := ans.Struct() + client = clientFromResolution(transform, s.ToPtr(), err) + continue + } + switch ans := ans.(type) { + case *fulfiller.Fulfiller: + ap := ans.Peek() + if ap == nil { + break dig + } + s, err := ap.Struct() + client = clientFromResolution(transform, s.ToPtr(), err) + case *question: + if ans.conn != c { + // This doesn't use our conn's lock, so it is safe to call. + return ans.PipelineCall(transform, cl) + } + return ans.lockedPipelineCall(transform, cl) + default: + break dig + } + default: + break dig } - return ans } + // TODO(light): Add a CallOption that signals to bypass sync. // The above hack works in *most* cases. // @@ -30,124 +100,177 @@ func (c *Conn) nestedCall(client capnp.Client, cl *capnp.Call) capnp.Answer { // 2) Arbitrary implementations of Client may exist // 3) Local E-order must be preserved // - // #3 is the one that creates a goroutine send cycle, since - // application code must synchronize with the coordinate goroutine - // to preserve order of delivery. You can't really overcome this - // without breaking one of the first two constraints. + // #3 is the one that creates a deadlock, since application code must + // acquire the connection mutex to preserve order of delivery. You + // can't really overcome this without breaking one of the first two + // constraints. // // To avoid #2 as much as possible, implementing Client is discouraged // by several docs. return client.Call(cl) } +// descriptorForClient fills desc for client, adding it to the export +// table if necessary. The caller must be holding onto c.mu. func (c *Conn) descriptorForClient(desc rpccapnp.CapDescriptor, client capnp.Client) error { - client = extractRPCClient(client) - if ic, ok := client.(*importClient); ok && isImportFromConn(ic, c) { - desc.SetReceiverHosted(uint32(ic.id)) - return nil - } - if pc, ok := client.(*capnp.PipelineClient); ok { - p := (*capnp.Pipeline)(pc) - if q, ok := p.Answer().(*question); ok && isQuestionFromConn(q, c) { - a, err := desc.NewReceiverAnswer() - if err != nil { - return err - } - a.SetQuestionId(uint32(q.id)) - err = transformToPromisedAnswer(desc.Segment(), a, p.Transform()) - if err != nil { - return err +dig: + for client := client; ; { + switch ct := client.(type) { + case *importClient: + if ct.conn != c { + break dig } + desc.SetReceiverHosted(uint32(ct.id)) return nil + case *fulfiller.EmbargoClient: + client = ct.Client() + if client == nil { + break dig + } + case *refcount.Ref: + client = ct.Client() + case *embargoClient: + ct.mu.RLock() + ok := ct.isPassthrough() + ct.mu.RUnlock() + if !ok { + break dig + } + client = ct.client + case *queueClient: + ct.mu.RLock() + ok := ct.isPassthrough() + ct.mu.RUnlock() + if !ok { + break dig + } + client = ct.client + case *localAnswerClient: + ct.a.mu.RLock() + obj, err, done := ct.a.obj, ct.a.err, ct.a.done + ct.a.mu.RUnlock() + if !done { + break dig + } + client = clientFromResolution(ct.transform, obj, err) + case *capnp.PipelineClient: + p := (*capnp.Pipeline)(ct) + ans := p.Answer() + transform := p.Transform() + if capnp.IsFixedAnswer(ans) { + s, err := ans.Struct() + client = clientFromResolution(transform, s.ToPtr(), err) + continue + } + switch ans := ans.(type) { + case *fulfiller.Fulfiller: + ap := ans.Peek() + if ap == nil { + break dig + } + s, err := ap.Struct() + client = clientFromResolution(transform, s.ToPtr(), err) + case *question: + ans.mu.RLock() + obj, err, state := ans.obj, ans.err, ans.state + ans.mu.RUnlock() + if state != questionInProgress { + client = clientFromResolution(transform, obj, err) + continue + } + if ans.conn != c { + break dig + } + a, err := desc.NewReceiverAnswer() + if err != nil { + return err + } + a.SetQuestionId(uint32(ans.id)) + err = transformToPromisedAnswer(desc.Segment(), a, p.Transform()) + if err != nil { + return err + } + return nil + default: + break dig + } + default: + break dig } } - id := c.exports.add(client) - desc.SetSenderHosted(uint32(id)) - return nil -} -func appCallFromClientCall(c *Conn, client capnp.Client, cl *capnp.Call) *appCall { - if ic, ok := client.(*importClient); ok && isImportFromConn(ic, c) { - ac, _ := newAppImportCall(ic.id, cl) - return ac - } - if pc, ok := client.(*capnp.PipelineClient); ok { - p := (*capnp.Pipeline)(pc) - if q, ok := p.Answer().(*question); ok && isQuestionFromConn(q, c) { - ac, _ := newAppPipelineCall(q, p.Transform(), cl) - return ac - } - } + id := c.addExport(client) + desc.SetSenderHosted(uint32(id)) return nil } -// extractRPCClient attempts to extract the client that is the most -// meaningful for further processing of RPCs. For example, instead of a -// PipelineClient on a resolved answer, the client's capability. -func extractRPCClient(client capnp.Client) capnp.Client { +// isImport returns the underlying import if client represents an import +// or nil otherwise. +func isImport(client capnp.Client) *importClient { for { - switch c := client.(type) { + switch curr := client.(type) { case *importClient: - return c + return curr + case *fulfiller.EmbargoClient: + client = curr.Client() + if client == nil { + return nil + } + case *refcount.Ref: + client = curr.Client() + case *embargoClient: + curr.mu.RLock() + ok := curr.isPassthrough() + curr.mu.RUnlock() + if !ok { + return nil + } + client = curr.client + case *queueClient: + curr.mu.RLock() + ok := curr.isPassthrough() + curr.mu.RUnlock() + if !ok { + return nil + } + client = curr.client + case *localAnswerClient: + curr.a.mu.RLock() + obj, err, done := curr.a.obj, curr.a.err, curr.a.done + curr.a.mu.RUnlock() + if !done { + return nil + } + client = clientFromResolution(curr.transform, obj, err) case *capnp.PipelineClient: - p := (*capnp.Pipeline)(c) - next := extractRPCClientFromPipeline(p.Answer(), p.Transform()) - if next == nil { - return client - } - client = next - case clientWrapper: - wc := c.WrappedClient() - if wc == nil { - return client - } - client = wc + p := (*capnp.Pipeline)(curr) + ans := p.Answer() + if capnp.IsFixedAnswer(ans) { + s, err := ans.Struct() + client = clientFromResolution(p.Transform(), s.ToPtr(), err) + continue + } + switch ans := ans.(type) { + case *fulfiller.Fulfiller: + ap := ans.Peek() + if ap == nil { + return nil + } + s, err := ap.Struct() + client = clientFromResolution(p.Transform(), s.ToPtr(), err) + case *question: + ans.mu.RLock() + obj, err, state := ans.obj, ans.err, ans.state + ans.mu.RUnlock() + if state != questionResolved { + return nil + } + client = clientFromResolution(p.Transform(), obj, err) + default: + return nil + } default: - return client - } - } -} - -func extractRPCClientFromPipeline(ans capnp.Answer, transform []capnp.PipelineOp) capnp.Client { - if capnp.IsFixedAnswer(ans) { - s, err := ans.Struct() - return clientFromResolution(transform, s.ToPtr(), err) - } - switch a := ans.(type) { - case *fulfiller.Fulfiller: - ap := a.Peek() - if ap == nil { - // This can race, see TODO in nestedCall. - return nil - } - s, err := ap.Struct() - return clientFromResolution(transform, s.ToPtr(), err) - case *question: - _, obj, err, ok := a.peek() - if !ok { - // This can race, see TODO in nestedCall. return nil } - return clientFromResolution(transform, obj, err) - default: - return nil } } - -// clientWrapper is an interface for types that wrap clients. -// If WrappedClient returns a non-nil value, that means that a Call to -// the wrapper passes through to the returned client. -// TODO(light): this should probably be exported at some point. -type clientWrapper interface { - WrappedClient() capnp.Client -} - -func isQuestionFromConn(q *question, c *Conn) bool { - // TODO(light): ideally there would be better ways to check. - return q.manager == &c.manager -} - -func isImportFromConn(ic *importClient, c *Conn) bool { - // TODO(light): ideally there would be better ways to check. - return ic.manager == &c.manager -} diff --git a/rpc/issue3_test.go b/rpc/issue3_test.go index b49a6630..6036b9f3 100644 --- a/rpc/issue3_test.go +++ b/rpc/issue3_test.go @@ -17,9 +17,10 @@ func TestIssue3(t *testing.T) { if *logMessages { p = logtransport.New(nil, p) } - c := rpc.NewConn(p) + log := testLogger{t} + c := rpc.NewConn(p, rpc.ConnLog(log)) echoSrv := testcapnp.Echoer_ServerToClient(new(SideEffectEchoer)) - d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client)) + d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client), rpc.ConnLog(log)) defer d.Wait() defer c.Close() client := testcapnp.Echoer{Client: c.Bootstrap(ctx)} diff --git a/rpc/log.go b/rpc/log.go new file mode 100644 index 00000000..4df33325 --- /dev/null +++ b/rpc/log.go @@ -0,0 +1,49 @@ +package rpc + +import ( + "log" + + "golang.org/x/net/context" +) + +// A Logger records diagnostic information and errors that are not +// associated with a call. The arguments passed into a log call are +// interpreted like fmt.Printf. They should not be held onto past the +// call's return. +type Logger interface { + Infof(ctx context.Context, format string, args ...interface{}) + Errorf(ctx context.Context, format string, args ...interface{}) +} + +type defaultLogger struct{} + +func (defaultLogger) Infof(ctx context.Context, format string, args ...interface{}) { + log.Printf("rpc: "+format, args...) +} + +func (defaultLogger) Errorf(ctx context.Context, format string, args ...interface{}) { + log.Printf("rpc: "+format, args...) +} + +func (c *Conn) infof(format string, args ...interface{}) { + if c.log == nil { + return + } + c.log.Infof(c.bg, format, args...) +} + +func (c *Conn) errorf(format string, args ...interface{}) { + if c.log == nil { + return + } + c.log.Errorf(c.bg, format, args...) +} + +// ConnLog sets the connection's log to the given Logger, which may be +// nil to disable logging. By default, logs are sent to the standard +// log package. +func ConnLog(log Logger) ConnOption { + return ConnOption{func(c *connParams) { + c.log = log + }} +} diff --git a/rpc/manager.go b/rpc/manager.go deleted file mode 100644 index 3baad4a7..00000000 --- a/rpc/manager.go +++ /dev/null @@ -1,83 +0,0 @@ -package rpc - -import ( - "sync" - - "golang.org/x/net/context" -) - -// manager signals the running goroutines in a Conn. -// Since there is one manager per connection, it's also a way of -// identifying an object's origin. -type manager struct { - finish chan struct{} - wg sync.WaitGroup - ctx context.Context - - mu sync.RWMutex - done bool - e error -} - -func (m *manager) init() { - m.finish = make(chan struct{}) - var cancel context.CancelFunc - m.ctx, cancel = context.WithCancel(context.Background()) - go func() { - <-m.finish - cancel() - }() -} - -// context returns a context that is cancelled when the manager shuts down. -func (m *manager) context() context.Context { - return m.ctx -} - -// do starts a function in a new goroutine and will block shutdown -// until it has returned. If the manager has already started shutdown, -// then it is a no-op. -func (m *manager) do(f func()) { - m.mu.RLock() - done := m.done - if !done { - m.wg.Add(1) - } - m.mu.RUnlock() - if !done { - go func() { - defer m.wg.Done() - f() - }() - } -} - -// shutdown closes the finish channel and sets the error. The first -// call to shutdown returns true; subsequent calls are no-ops and return -// false. This will not wait for the manager's goroutines to finish. -func (m *manager) shutdown(e error) bool { - m.mu.Lock() - ok := !m.done - if ok { - close(m.finish) - m.done = true - m.e = e - } - m.mu.Unlock() - return ok -} - -// wait blocks until the manager is shut down and all of its goroutines -// are finished. -func (m *manager) wait() { - <-m.finish - m.wg.Wait() -} - -// err returns the error passed to shutdown. -func (m *manager) err() error { - m.mu.RLock() - e := m.e - m.mu.RUnlock() - return e -} diff --git a/rpc/promise_test.go b/rpc/promise_test.go index 73eee8c0..586630f5 100644 --- a/rpc/promise_test.go +++ b/rpc/promise_test.go @@ -18,10 +18,11 @@ func TestPromisedCapability(t *testing.T) { if *logMessages { p = logtransport.New(nil, p) } - c := rpc.NewConn(p) + log := testLogger{t} + c := rpc.NewConn(p, rpc.ConnLog(log)) delay := make(chan struct{}) echoSrv := testcapnp.Echoer_ServerToClient(&DelayEchoer{delay: delay}) - d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client)) + d := rpc.NewConn(q, rpc.MainInterface(echoSrv.Client), rpc.ConnLog(log)) defer d.Wait() defer c.Close() client := testcapnp.Echoer{Client: c.Bootstrap(ctx)} diff --git a/rpc/question.go b/rpc/question.go index c76cd4e2..0c720bea 100644 --- a/rpc/question.go +++ b/rpc/question.go @@ -10,70 +10,58 @@ import ( rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc" ) -type questionTable struct { - tab []*question - gen idgen - - manager *manager - calls chan<- *appCall - cancels chan<- *question -} - -// new creates a new question with an unassigned ID. -func (qt *questionTable) new(ctx context.Context, method *capnp.Method) *question { - id := questionID(qt.gen.next()) +// newQuestion creates a new question with an unassigned ID. +func (c *Conn) newQuestion(ctx context.Context, method *capnp.Method) *question { + id := questionID(c.questionID.next()) q := &question{ ctx: ctx, + conn: c, method: method, - manager: qt.manager, - calls: qt.calls, - cancels: qt.cancels, resolved: make(chan struct{}), id: id, } // TODO(light): populate paramCaps - if int(id) == len(qt.tab) { - qt.tab = append(qt.tab, q) + if int(id) == len(c.questions) { + c.questions = append(c.questions, q) } else { - qt.tab[id] = q + c.questions[id] = q } return q } -func (qt *questionTable) get(id questionID) *question { - var q *question - if int(id) < len(qt.tab) { - q = qt.tab[id] +func (c *Conn) findQuestion(id questionID) *question { + if int(id) >= len(c.questions) { + return nil } - return q + return c.questions[id] } -func (qt *questionTable) pop(id questionID) *question { - var q *question - if int(id) < len(qt.tab) { - q = qt.tab[id] - qt.tab[id] = nil - qt.gen.remove(uint32(id)) +func (c *Conn) popQuestion(id questionID) *question { + q := c.findQuestion(id) + if q == nil { + return nil } + c.questions[id] = nil + c.questionID.remove(uint32(id)) return q } type question struct { + id questionID ctx context.Context + conn *Conn method *capnp.Method // nil if this is bootstrap paramCaps []exportID - calls chan<- *appCall - cancels chan<- *question - manager *manager resolved chan struct{} - // Fields below are protected by mu. - mu sync.RWMutex - id questionID - obj capnp.Ptr - err error - state questionState + // Protected by conn.mu derived [][]capnp.PipelineOp + + // Fields below are protected by mu. + mu sync.RWMutex + obj capnp.Ptr + err error + state questionState } type questionState uint8 @@ -90,29 +78,28 @@ func (q *question) start() { go func() { select { case <-q.resolved: + // Resolved naturally, nothing to do. case <-q.ctx.Done(): select { - case q.cancels <- q: + case <-q.conn.mu: + if q.cancel(q.ctx.Err()) { + q.conn.sendMessage(newFinishMessage(nil, q.id, true /* release */)) + } + q.conn.mu.Unlock() case <-q.resolved: - case <-q.manager.finish: + case <-q.conn.bg.Done(): } - case <-q.manager.finish: + case <-q.conn.bg.Done(): // TODO(light): connection should reject all questions on shutdown. } }() } -// fulfill is called to resolve a question successfully and returns the disembargoes. -// It must be called from the coordinate goroutine. -func (q *question) fulfill(obj capnp.Ptr, makeDisembargo func() (embargoID, embargo)) []rpccapnp.Message { - q.mu.Lock() - if q.state != questionInProgress { - q.mu.Unlock() - panic("question.fulfill called more than once") - } +// fulfill is called to resolve a question successfully. +// The caller must be holding onto q.conn.mu. +func (q *question) fulfill(obj capnp.Ptr) { ctab := obj.Segment().Message().CapTable visited := make([]bool, len(ctab)) - msgs := make([]rpccapnp.Message, 0, len(q.derived)) for _, d := range q.derived { tgt, err := capnp.TransformPtr(obj, d) if err != nil { @@ -122,57 +109,78 @@ func (q *question) fulfill(obj capnp.Ptr, makeDisembargo func() (embargoID, emba if !in.IsValid() { continue } - client := extractRPCClient(in.Client()) - if ic, ok := client.(*importClient); ok && ic.manager == q.manager { + if ic := isImport(in.Client()); ic != nil && ic.conn == q.conn { // Imported from remote vat. Don't need to disembargo. continue } - if cn := in.Capability(); !visited[cn] { - id, e := makeDisembargo() - ctab[cn] = newEmbargoClient(q.manager, ctab[cn], e) - m := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_senderLoopback, id) - dis, _ := m.Disembargo() - mt, _ := dis.NewTarget() - pa, _ := mt.NewPromisedAnswer() - pa.SetQuestionId(uint32(q.id)) - transformToPromisedAnswer(m.Segment(), pa, d) - mt.SetPromisedAnswer(pa) - msgs = append(msgs, m) - visited[cn] = true + cn := in.Capability() + if visited[cn] { + continue } + visited[cn] = true + id, e := q.conn.newEmbargo() + ctab[cn] = newEmbargoClient(ctab[cn], e, q.conn.bg.Done()) + m := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_senderLoopback, id) + dis, _ := m.Disembargo() + mt, _ := dis.NewTarget() + pa, _ := mt.NewPromisedAnswer() + pa.SetQuestionId(uint32(q.id)) + transformToPromisedAnswer(m.Segment(), pa, d) + mt.SetPromisedAnswer(pa) + + select { + case q.conn.out <- m: + case <-q.conn.bg.Done(): + // TODO(soon): perhaps just drop all embargoes in this case? + } + } + + q.mu.Lock() + if q.state != questionInProgress { + panic("question.fulfill called more than once") } q.obj, q.state = obj, questionResolved close(q.resolved) q.mu.Unlock() - return msgs } // reject is called to resolve a question with failure. -// It must be called from the coordinate goroutine. -func (q *question) reject(state questionState, err error) { +// The caller must be holding onto q.conn.mu. +func (q *question) reject(err error) { if err == nil { panic("question.reject called with nil") } q.mu.Lock() if q.state != questionInProgress { - q.mu.Unlock() panic("question.reject called more than once") } - q.err, q.state = err, state + q.err = err + q.state = questionResolved close(q.resolved) q.mu.Unlock() } -func (q *question) peek() (id questionID, obj capnp.Ptr, err error, ok bool) { - q.mu.RLock() - id, obj, err, ok = q.id, q.obj, q.err, q.state != questionInProgress - q.mu.RUnlock() - return +// cancel is called to resolve a question with cancellation. +// The caller must be holding onto q.conn.mu. +func (q *question) cancel(err error) bool { + if err == nil { + panic("question.cancel called with nil") + } + q.mu.Lock() + canceled := q.state == questionInProgress + if canceled { + q.err = err + q.state = questionCanceled + close(q.resolved) + } + q.mu.Unlock() + return canceled } +// addPromise records a returned capability as being used for a call. +// This is needed for determining embargoes upon resolution. The +// caller must be holding onto q.conn.mu. func (q *question) addPromise(transform []capnp.PipelineOp) { - q.mu.Lock() - defer q.mu.Unlock() for _, d := range q.derived { if transformsEqual(transform, d) { return @@ -194,33 +202,85 @@ func transformsEqual(t, u []capnp.PipelineOp) bool { } func (q *question) Struct() (capnp.Struct, error) { - <-q.resolved - _, obj, err, _ := q.peek() - return obj.Struct(), err + select { + case <-q.resolved: + case <-q.conn.bg.Done(): + return capnp.Struct{}, ErrConnClosed + } + q.mu.RLock() + s, err := q.obj.Struct(), q.err + q.mu.RUnlock() + return s, err } func (q *question) PipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer { - ac, achan := newAppPipelineCall(q, transform, ccall) select { - case q.calls <- ac: + case <-q.conn.mu: case <-ccall.Ctx.Done(): return capnp.ErrorAnswer(ccall.Ctx.Err()) - case <-q.manager.finish: - return capnp.ErrorAnswer(q.manager.err()) + case <-q.conn.bg.Done(): + return capnp.ErrorAnswer(ErrConnClosed) } + ans := q.lockedPipelineCall(transform, ccall) + q.conn.mu.Unlock() + return ans +} + +// lockedPipelineCall is equivalent to PipelineCall but assumes that the +// caller is already holding onto q.conn.mu. +func (q *question) lockedPipelineCall(transform []capnp.PipelineOp, ccall *capnp.Call) capnp.Answer { + if q.conn.findQuestion(q.id) != q { + // Question has been finished. The call should happen as if it is + // back in application code. + q.mu.RLock() + obj, err, state := q.obj, q.err, q.state + q.mu.RUnlock() + if state == questionInProgress { + panic("question popped but not done") + } + client := clientFromResolution(transform, obj, err) + return q.conn.lockedCall(client, ccall) + } + + pipeq := q.conn.newQuestion(ccall.Ctx, &ccall.Method) + msg := newMessage(nil) + msgCall, _ := msg.NewCall() + msgCall.SetQuestionId(uint32(pipeq.id)) + msgCall.SetInterfaceId(ccall.Method.InterfaceID) + msgCall.SetMethodId(ccall.Method.MethodID) + target, _ := msgCall.NewTarget() + a, _ := target.NewPromisedAnswer() + a.SetQuestionId(uint32(q.id)) + err := transformToPromisedAnswer(a.Segment(), a, transform) + if err != nil { + q.conn.popQuestion(pipeq.id) + return capnp.ErrorAnswer(err) + } + payload, _ := msgCall.NewParams() + if err := q.conn.fillParams(payload, ccall); err != nil { + q.conn.popQuestion(q.id) + return capnp.ErrorAnswer(err) + } + select { - case a := <-achan: - return a + case q.conn.out <- msg: case <-ccall.Ctx.Done(): + q.conn.popQuestion(pipeq.id) return capnp.ErrorAnswer(ccall.Ctx.Err()) - case <-q.manager.finish: - return capnp.ErrorAnswer(q.manager.err()) + case <-q.conn.bg.Done(): + q.conn.popQuestion(pipeq.id) + return capnp.ErrorAnswer(ErrConnClosed) } + q.addPromise(transform) + pipeq.start() + return pipeq } func (q *question) PipelineClose(transform []capnp.PipelineOp) error { <-q.resolved - _, obj, err, _ := q.peek() + q.mu.RLock() + obj, err := q.obj, q.err + q.mu.RUnlock() if err != nil { return err } @@ -238,21 +298,23 @@ func (q *question) PipelineClose(transform []capnp.PipelineOp) error { // embargoClient is a client that waits until an embargo signal is // received to deliver calls. type embargoClient struct { - manager *manager + cancel <-chan struct{} client capnp.Client embargo embargo - mu sync.RWMutex - q queue.Queue + mu sync.RWMutex + q queue.Queue + calls ecallList } -func newEmbargoClient(manager *manager, client capnp.Client, e embargo) *embargoClient { +func newEmbargoClient(client capnp.Client, e embargo, cancel <-chan struct{}) *embargoClient { ec := &embargoClient{ - manager: manager, client: client, embargo: e, + cancel: cancel, + calls: make(ecallList, callQueueSize), } - ec.q.Init(make(ecallList, callQueueSize), 0) + ec.q.Init(ec.calls, 0) go ec.flushQueue() return ec } @@ -263,26 +325,14 @@ func (ec *embargoClient) push(cl *capnp.Call) capnp.Answer { if err != nil { return capnp.ErrorAnswer(err) } - if ok := ec.q.Push(ecall{cl, f}); !ok { + i := ec.q.Push() + if i == -1 { return capnp.ErrorAnswer(errQueueFull) } + ec.calls[i] = ecall{cl, f} return f } -func (ec *embargoClient) peek() ecall { - if ec.q.Len() == 0 { - return ecall{} - } - return ec.q.Peek().(ecall) -} - -func (ec *embargoClient) pop() ecall { - if ec.q.Len() == 0 { - return ecall{} - } - return ec.q.Pop().(ecall) -} - func (ec *embargoClient) Call(cl *capnp.Call) capnp.Answer { // Fast path: queue is flushed. ec.mu.RLock() @@ -302,14 +352,15 @@ func (ec *embargoClient) Call(cl *capnp.Call) capnp.Answer { return ans } -func (ec *embargoClient) WrappedClient() capnp.Client { - ec.mu.RLock() - ok := ec.isPassthrough() - ec.mu.RUnlock() - if !ok { +func (ec *embargoClient) tryQueue(cl *capnp.Call) capnp.Answer { + ec.mu.Lock() + if ec.isPassthrough() { + ec.mu.Unlock() return nil } - return ec.client + ans := ec.push(cl) + ec.mu.Unlock() + return ans } func (ec *embargoClient) isPassthrough() bool { @@ -323,11 +374,8 @@ func (ec *embargoClient) isPassthrough() bool { func (ec *embargoClient) Close() error { ec.mu.Lock() - for { - c := ec.pop() - if c.call == nil { - break - } + for ; ec.q.Len() > 0; ec.q.Pop() { + c := ec.calls[ec.q.Front()] c.f.Reject(errQueueCallCancel) } ec.mu.Unlock() @@ -338,18 +386,31 @@ func (ec *embargoClient) Close() error { func (ec *embargoClient) flushQueue() { select { case <-ec.embargo: - case <-ec.manager.finish: + case <-ec.cancel: + ec.mu.Lock() + for ec.q.Len() > 0 { + ec.q.Pop() + } + ec.mu.Unlock() return } + var c ecall ec.mu.RLock() - c := ec.peek() + if i := ec.q.Front(); i != -1 { + c = ec.calls[i] + } ec.mu.RUnlock() for c.call != nil { ans := ec.client.Call(c.call) go joinFulfiller(c.f, ans) + ec.mu.Lock() - ec.pop() - c = ec.peek() + ec.q.Pop() + if i := ec.q.Front(); i != -1 { + c = ec.calls[i] + } else { + c = ecall{} + } ec.mu.Unlock() } } @@ -365,14 +426,6 @@ func (el ecallList) Len() int { return len(el) } -func (el ecallList) At(i int) interface{} { - return el[i] -} - -func (el ecallList) Set(i int, x interface{}) { - if x == nil { - el[i] = ecall{} - } else { - el[i] = x.(ecall) - } +func (el ecallList) Clear(i int) { + el[i] = ecall{} } diff --git a/rpc/release_test.go b/rpc/release_test.go index 17a74401..083f2ebb 100644 --- a/rpc/release_test.go +++ b/rpc/release_test.go @@ -20,9 +20,10 @@ func TestRelease(t *testing.T) { if *logMessages { p = logtransport.New(nil, p) } - c := rpc.NewConn(p) + log := testLogger{t} + c := rpc.NewConn(p, rpc.ConnLog(log)) hf := new(HandleFactory) - d := rpc.NewConn(q, rpc.MainInterface(testcapnp.HandleFactory_ServerToClient(hf).Client)) + d := rpc.NewConn(q, rpc.MainInterface(testcapnp.HandleFactory_ServerToClient(hf).Client), rpc.ConnLog(log)) defer d.Wait() defer c.Close() client := testcapnp.HandleFactory{Client: c.Bootstrap(ctx)} @@ -52,9 +53,10 @@ func TestReleaseAlias(t *testing.T) { if *logMessages { p = logtransport.New(nil, p) } - c := rpc.NewConn(p) + log := testLogger{t} + c := rpc.NewConn(p, rpc.ConnLog(log)) hf := singletonHandleFactory() - d := rpc.NewConn(q, rpc.MainInterface(testcapnp.HandleFactory_ServerToClient(hf).Client)) + d := rpc.NewConn(q, rpc.MainInterface(testcapnp.HandleFactory_ServerToClient(hf).Client), rpc.ConnLog(log)) defer d.Wait() defer c.Close() client := testcapnp.HandleFactory{Client: c.Bootstrap(ctx)} diff --git a/rpc/rpc.go b/rpc/rpc.go index 5130bdcf..e2f5d5fe 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -4,7 +4,7 @@ package rpc // import "zombiezen.com/go/capnproto2/rpc" import ( "fmt" "io" - "log" + "sync" "golang.org/x/net/context" "zombiezen.com/go/capnproto2" @@ -16,27 +16,36 @@ import ( // It is safe to use from multiple goroutines. type Conn struct { transport Transport + log Logger mainFunc func(context.Context) (capnp.Client, error) mainCloser io.Closer - manager manager - in <-chan rpccapnp.Message - out chan<- rpccapnp.Message - calls chan *appCall - cancels <-chan *question - releases chan *outgoingRelease - returns <-chan *outgoingReturn - queueCloses <-chan queueClientClose + out chan rpccapnp.Message - // Mutable state. Only accessed from coordinate goroutine. - questions questionTable - answers answerTable - imports importTable - exports exportTable - embargoes embargoTable + bg context.Context + bgCancel context.CancelFunc + workers sync.WaitGroup + + // Mutable state protected by stateMu + stateMu sync.RWMutex + stateCond sync.Cond // broadcasts when state changes + state connState + closeErr error + + // Mutable state protected by mu + mu chanMutex + questions []*question + questionID idgen + exports []*export + exportID idgen + embargoes []chan<- struct{} + embargoID idgen + answers map[answerID]*answer + imports map[importID]*impent } type connParams struct { + log Logger mainFunc func(context.Context) (capnp.Client, error) mainCloser io.Closer sendBufferSize int @@ -82,197 +91,284 @@ func SendBufferSize(numMsgs int) ConnOption { // NewConn creates a new connection that communicates on c. // Closing the connection will cause c to be closed. func NewConn(t Transport, options ...ConnOption) *Conn { - conn := &Conn{transport: t} p := &connParams{ + log: defaultLogger{}, sendBufferSize: 4, } - conn.manager.init() for _, o := range options { o.f(p) } - conn.mainFunc = p.mainFunc - conn.mainCloser = p.mainCloser - i := make(chan rpccapnp.Message) - o := make(chan rpccapnp.Message, p.sendBufferSize) - calls := make(chan *appCall) - cancels := make(chan *question) - rets := make(chan *outgoingReturn) - queueCloses := make(chan queueClientClose) - releases := make(chan *outgoingRelease) - conn.in, conn.out = i, o - conn.calls = calls - conn.cancels = cancels - conn.releases = releases - conn.returns = rets - conn.queueCloses = queueCloses - conn.questions.manager = &conn.manager - conn.questions.calls = calls - conn.questions.cancels = cancels - conn.answers.manager = &conn.manager - conn.answers.out = o - conn.answers.returns = rets - conn.answers.queueCloses = queueCloses - conn.imports.manager = &conn.manager - conn.imports.calls = calls - conn.imports.releases = releases - - conn.manager.do(conn.coordinate) - conn.manager.do(func() { - dispatchRecv(&conn.manager, t, i) - }) - conn.manager.do(func() { - dispatchSend(&conn.manager, t, o) - }) + + conn := &Conn{ + transport: t, + out: make(chan rpccapnp.Message, p.sendBufferSize), + mainFunc: p.mainFunc, + mainCloser: p.mainCloser, + log: p.log, + mu: newChanMutex(), + } + conn.stateCond.L = conn.stateMu.RLocker() + conn.bg, conn.bgCancel = context.WithCancel(context.Background()) + conn.workers.Add(2) + go conn.dispatchRecv() + go conn.dispatchSend() return conn } // Wait waits until the connection is closed or aborted by the remote vat. // Wait will always return an error, usually ErrConnClosed or of type Abort. func (c *Conn) Wait() error { - c.manager.wait() - return c.manager.err() + c.stateMu.RLock() + for c.state != connDead { + c.stateCond.Wait() + } + err := c.closeErr + c.stateMu.RUnlock() + return err } -// Close closes the connection. +// Close closes the connection and the underlying transport. func (c *Conn) Close() error { - // Stop helper goroutines. - if !c.manager.shutdown(ErrConnClosed) { + c.stateMu.Lock() + alive := c.state == connAlive + if alive { + c.bgCancel() + c.closeErr = ErrConnClosed + c.state = connDying + c.stateCond.Broadcast() + } + c.stateMu.Unlock() + if !alive { return ErrConnClosed } - c.manager.wait() - // Hang up. - // TODO(light): add timeout to write. - ctx := context.Background() - n := newAbortMessage(nil, errShutdown) - werr := c.transport.SendMessage(ctx, n) - cerr := c.transport.Close() - if werr != nil { - return werr - } - if cerr != nil { - return cerr + c.teardown(newAbortMessage(nil, errShutdown)) + c.stateMu.RLock() + err := c.closeErr + c.stateMu.RUnlock() + if err != ErrConnClosed { + return err } return nil } -// coordinate runs in its own goroutine. -// It manages dispatching received messages and calls. -func (c *Conn) coordinate() { - for { - select { - case m := <-c.in: - c.handleMessage(m) - case ac := <-c.calls: - ans, err := c.handleCall(ac) - if err == nil { - ac.achan <- ans - } else { - log.Println("rpc: failed to handle call:", err) - ac.achan <- capnp.ErrorAnswer(err) - } - case q := <-c.cancels: - c.handleCancel(q) - case r := <-c.releases: - r.echan <- c.handleRelease(r.id) - case r := <-c.returns: - c.handleReturn(r) - case qcc := <-c.queueCloses: - c.handleQueueClose(qcc) - case <-c.manager.finish: - c.exports.releaseAll() - if c.mainCloser != nil { - if err := c.mainCloser.Close(); err != nil { - log.Println("rpc: closing main interface:", err) - } - } - return +// shutdown cancels the background context and sets closeErr to e. +// No abort message will be sent on the transport. After shutdown +// returns, the Conn will be in the dying or dead state. Calling +// shutdown on a dying or dead Conn is a no-op. +func (c *Conn) shutdown(e error) { + c.stateMu.Lock() + if c.state == connAlive { + c.bgCancel() + c.closeErr = e + c.state = connDying + c.stateCond.Broadcast() + go c.teardown(rpccapnp.Message{}) + } + c.stateMu.Unlock() +} + +// abort cancels the background context, sets closeErr to e, and queues +// an abort message to be sent on the transport before the Conn goes +// into the dead state. After abort returns, the Conn will be in the +// dying or dead state. Calling abort on a dying or dead Conn is a +// no-op. +func (c *Conn) abort(e error) { + c.stateMu.Lock() + if c.state == connAlive { + c.bgCancel() + c.closeErr = e + c.state = connDying + c.stateCond.Broadcast() + go c.teardown(newAbortMessage(nil, e)) + } + c.stateMu.Unlock() +} + +// teardown moves the connection from the dying to the dead state. +func (c *Conn) teardown(abort rpccapnp.Message) { + c.workers.Wait() + + c.mu.Lock() + for _, q := range c.questions { + if q != nil { + q.cancel(ErrConnClosed) + } + } + c.questions = nil + exps := c.exports + c.exports = nil + c.embargoes = nil + for _, a := range c.answers { + a.cancel() + } + c.answers = nil + c.imports = nil + c.mainFunc = nil + c.mu.Unlock() + + if c.mainCloser != nil { + if err := c.mainCloser.Close(); err != nil { + c.errorf("closing main interface: %v", err) } + c.mainCloser = nil } + // Closing an export may try to lock the Conn, so run it outside + // critical section. + for id, e := range exps { + if e == nil { + continue + } + if err := e.client.Close(); err != nil { + c.errorf("export %v close: %v", id, err) + } + } + + var werr error + if abort.IsValid() { + werr = c.transport.SendMessage(context.Background(), abort) + } + cerr := c.transport.Close() + + c.stateMu.Lock() + if c.closeErr == ErrConnClosed { + if cerr != nil { + c.closeErr = cerr + } else if werr != nil { + c.closeErr = werr + } + } + c.state = connDead + c.stateCond.Broadcast() + c.stateMu.Unlock() } // Bootstrap returns the receiver's main interface. func (c *Conn) Bootstrap(ctx context.Context) capnp.Client { // TODO(light): Create a client that returns immediately. - ac, achan := newAppBootstrapCall(ctx) select { - case c.calls <- ac: - select { - case a := <-achan: - return capnp.NewPipeline(a).Client() - case <-ctx.Done(): - return capnp.ErrorClient(ctx.Err()) - case <-c.manager.finish: - return capnp.ErrorClient(c.manager.err()) - } + case <-c.mu: + // Locked. + defer c.mu.Unlock() + case <-ctx.Done(): + return capnp.ErrorClient(ctx.Err()) + case <-c.bg.Done(): + return capnp.ErrorClient(ErrConnClosed) + } + + q := c.newQuestion(ctx, nil /* method */) + msg := newMessage(nil) + boot, _ := msg.NewBootstrap() + boot.SetQuestionId(uint32(q.id)) + // The mutex must be held while sending so that call order is preserved. + // Worst case, this blocks until a message is sent on the transport. + // Common case, this just adds to the channel queue. + select { + case c.out <- msg: + q.start() + return capnp.NewPipeline(q).Client() case <-ctx.Done(): + c.popQuestion(q.id) return capnp.ErrorClient(ctx.Err()) - case <-c.manager.finish: - return capnp.ErrorClient(c.manager.err()) + case <-c.bg.Done(): + c.popQuestion(q.id) + return capnp.ErrorClient(ErrConnClosed) } } -// handleMessage is run in the coordinate goroutine. +// handleMessage is run from the receive goroutine to process a single +// message. m cannot be held onto past the return of handleMessage, and +// c.mu is not held at the start of handleMessage. func (c *Conn) handleMessage(m rpccapnp.Message) { switch m.Which() { case rpccapnp.Message_Which_unimplemented: // no-op for now to avoid feedback loop case rpccapnp.Message_Which_abort: - ma, err := m.Abort() + a, err := copyAbort(m) if err != nil { - log.Println("rpc: decode abort:", err) + c.errorf("decode abort: %v", err) // Keep going, since we're trying to abort anyway. } - a := Abort{ma} - log.Print(a) - c.manager.shutdown(a) + c.infof("abort: %v", a) + c.shutdown(a) case rpccapnp.Message_Which_return: - if err := c.handleReturnMessage(m); err != nil { - log.Println("rpc: handle return:", err) + m = copyRPCMessage(m) + c.mu.Lock() + err := c.handleReturnMessage(m) + c.mu.Unlock() + + if err != nil { + c.errorf("handle return: %v", err) } case rpccapnp.Message_Which_finish: - // TODO(light): what if answers never had this ID? - // TODO(light): return if cancelled mfin, err := m.Finish() if err != nil { - log.Println("rpc: decode finish:", err) + c.errorf("decode finish: %v", err) return } id := answerID(mfin.QuestionId()) - a := c.answers.pop(id) + + c.mu.Lock() + a := c.popAnswer(id) + if a == nil { + c.mu.Unlock() + c.errorf("finish called for unknown answer %d", id) + return + } a.cancel() if mfin.ReleaseResultCaps() { - c.exports.releaseList(a.resultCaps) + for _, id := range a.resultCaps { + c.releaseExport(id, 1) + } } + c.mu.Unlock() case rpccapnp.Message_Which_bootstrap: boot, err := m.Bootstrap() if err != nil { - log.Println("rpc: decode bootstrap:", err) + c.errorf("decode bootstrap: %v", err) return } id := answerID(boot.QuestionId()) - if err := c.handleBootstrapMessage(id); err != nil { - log.Println("rpc: handle bootstrap:", err) + + c.mu.Lock() + err = c.handleBootstrapMessage(id) + c.mu.Unlock() + + if err != nil { + c.errorf("handle bootstrap: %v", err) } case rpccapnp.Message_Which_call: - if err := c.handleCallMessage(m); err != nil { - log.Println("rpc: handle call:", err) + m = copyRPCMessage(m) + c.mu.Lock() + err := c.handleCallMessage(m) + c.mu.Unlock() + + if err != nil { + c.errorf("handle call: %v", err) } case rpccapnp.Message_Which_release: rel, err := m.Release() if err != nil { - log.Println("rpc: decode release:", err) + c.errorf("decode release: %v", err) return } id := exportID(rel.Id()) refs := int(rel.ReferenceCount()) - c.exports.release(id, refs) + + c.mu.Lock() + c.releaseExport(id, refs) + c.mu.Unlock() case rpccapnp.Message_Which_disembargo: - if err := c.handleDisembargoMessage(m); err != nil { + m = copyRPCMessage(m) + c.mu.Lock() + err := c.handleDisembargoMessage(m) + c.mu.Unlock() + + if err != nil { // Any failure in a disembargo is a protocol violation. c.abort(err) } default: - log.Printf("rpc: received unimplemented message, which = %v", m.Which()) + c.infof("received unimplemented message, which = %v", m.Which()) um := newUnimplementedMessage(nil, m) c.sendMessage(um) } @@ -284,89 +380,22 @@ func newUnimplementedMessage(buf []byte, m rpccapnp.Message) rpccapnp.Message { return n } -// handleCall is run from the coordinate goroutine to send a question to a remote vat. -func (c *Conn) handleCall(ac *appCall) (capnp.Answer, error) { - if ac.kind == appPipelineCall && c.questions.get(ac.question.id) != ac.question { - // Question has been finished. The call should happen as if it is - // back in application code. - _, obj, err, done := ac.question.peek() - if !done { - panic("question popped but not done") - } - client := clientFromResolution(ac.transform, obj, err) - return c.nestedCall(client, ac.Call), nil - } - q := c.questions.new(ac.Ctx, &ac.Method) - if ac.kind == appPipelineCall { - pq := c.questions.get(ac.question.id) - pq.addPromise(ac.transform) - } - msg, err := c.newCallMessage(nil, q.id, ac) +func (c *Conn) fillParams(payload rpccapnp.Payload, cl *capnp.Call) error { + params, err := cl.PlaceParams(payload.Segment()) if err != nil { - return nil, err - } - select { - case c.out <- msg: - q.start() - return q, nil - case <-ac.Ctx.Done(): - c.questions.pop(q.id) - return nil, ac.Ctx.Err() - case <-c.manager.finish: - c.questions.pop(q.id) - return nil, c.manager.err() - } -} - -func (c *Conn) newCallMessage(buf []byte, id questionID, ac *appCall) (rpccapnp.Message, error) { - msg := newMessage(buf) - - if ac.kind == appBootstrapCall { - boot, _ := msg.NewBootstrap() - boot.SetQuestionId(uint32(id)) - return msg, nil - } - - msgCall, _ := msg.NewCall() - msgCall.SetQuestionId(uint32(id)) - msgCall.SetInterfaceId(ac.Method.InterfaceID) - msgCall.SetMethodId(ac.Method.MethodID) - - target, _ := msgCall.NewTarget() - switch ac.kind { - case appImportCall: - target.SetImportedCap(uint32(ac.importID)) - case appPipelineCall: - a, err := target.NewPromisedAnswer() - if err != nil { - return rpccapnp.Message{}, err - } - a.SetQuestionId(uint32(ac.question.id)) - err = transformToPromisedAnswer(a.Segment(), a, ac.transform) - if err != nil { - return rpccapnp.Message{}, err - } - default: - panic("unknown call type") - } - - payload, _ := msgCall.NewParams() - params, err := ac.PlaceParams(payload.Segment()) - if err != nil { - return rpccapnp.Message{}, err + return err } if err := payload.SetContent(params); err != nil { - return rpccapnp.Message{}, err + return err } ctab, err := c.makeCapTable(payload.Segment()) if err != nil { - return rpccapnp.Message{}, err + return err } if err := payload.SetCapTable(ctab); err != nil { - return rpccapnp.Message{}, err + return err } - - return msg, nil + return nil } func transformToPromisedAnswer(s *capnp.Segment, answer rpccapnp.PromisedAnswer, transform []capnp.PipelineOp) error { @@ -381,49 +410,28 @@ func transformToPromisedAnswer(s *capnp.Segment, answer rpccapnp.PromisedAnswer, return err } -// handleCancel is called from the coordinate goroutine to handle a question's cancelation. -func (c *Conn) handleCancel(q *question) { - q.reject(questionCanceled, q.ctx.Err()) - // TODO(light): timeout? - msg := newFinishMessage(nil, q.id, true /* release */) - c.sendMessage(msg) -} - -// handleRelease is run in the coordinate goroutine to handle an import -// client's release request. It sends a release message for an import ID. -func (c *Conn) handleRelease(id importID) error { - i := c.imports.pop(id) - if i == 0 { - return nil - } - // TODO(light): deadline to close? - msg := newMessage(nil) - mr, err := msg.NewRelease() - if err != nil { - return err - } - mr.SetId(uint32(id)) - mr.SetReferenceCount(uint32(i)) - return c.sendMessage(msg) -} - -// handleReturnMessage is run in the coordinate goroutine. +// handleReturnMessage is to handle a received return message. +// The caller is holding onto c.mu. func (c *Conn) handleReturnMessage(m rpccapnp.Message) error { ret, err := m.Return() if err != nil { return err } id := questionID(ret.AnswerId()) - q := c.questions.pop(id) + q := c.popQuestion(id) if q == nil { return fmt.Errorf("received return for unknown question id=%d", id) } if ret.ReleaseParamCaps() { - c.exports.releaseList(q.paramCaps) + for _, id := range q.paramCaps { + c.releaseExport(id, 1) + } } - if _, _, _, resolved := q.peek(); resolved { - // If the question was already resolved, that means it was canceled, - // in which case we already sent the finish message. + q.mu.RLock() + qstate := q.state + q.mu.RUnlock() + if qstate == questionCanceled { + // We already sent the finish message. return nil } releaseResultCaps := true @@ -446,13 +454,7 @@ func (c *Conn) handleReturnMessage(m rpccapnp.Message) error { if err != nil { return err } - disembargoes := q.fulfill(content, c.embargoes.new) - for _, d := range disembargoes { - if err := c.sendMessage(d); err != nil { - // shutdown - return nil - } - } + q.fulfill(content) case rpccapnp.Return_Which_exception: exc, err := ret.Exception() if err != nil { @@ -467,15 +469,15 @@ func (c *Conn) handleReturnMessage(m rpccapnp.Message) error { } else { e = bootstrapError{e} } - q.reject(questionResolved, e) + q.reject(e) case rpccapnp.Return_Which_canceled: err := &questionError{ id: id, method: q.method, err: fmt.Errorf("receiver reported canceled"), } - log.Println(err) - q.reject(questionResolved, err) + c.errorf("%v", err) + q.reject(err) return nil default: um := newUnimplementedMessage(nil, m) @@ -510,7 +512,7 @@ func (c *Conn) populateMessageCapTable(payload rpccapnp.Payload) error { msg.AddCap(nil) case rpccapnp.CapDescriptor_Which_senderHosted: id := importID(desc.SenderHosted()) - client := c.imports.addRef(id) + client := c.addImport(id) msg.AddCap(client) case rpccapnp.CapDescriptor_Which_senderPromise: // We do the same thing as senderHosted, above. @kentonv suggested this on @@ -524,11 +526,11 @@ func (c *Conn) populateMessageCapTable(payload rpccapnp.Payload) error { // > messages sent to it will uselessly round-trip over the network // > rather than being delivered locally. id := importID(desc.SenderPromise()) - client := c.imports.addRef(id) + client := c.addImport(id) msg.AddCap(client) case rpccapnp.CapDescriptor_Which_receiverHosted: id := exportID(desc.ReceiverHosted()) - e := c.exports.get(id) + e := c.findExport(id) if e == nil { return fmt.Errorf("rpc: capability table references unknown export ID %d", id) } @@ -539,7 +541,7 @@ func (c *Conn) populateMessageCapTable(payload rpccapnp.Payload) error { return err } id := answerID(recvAns.QuestionId()) - a := c.answers.get(id) + a := c.answers[id] if a == nil { return fmt.Errorf("rpc: capability table references unknown answer ID %d", id) } @@ -550,7 +552,7 @@ func (c *Conn) populateMessageCapTable(payload rpccapnp.Payload) error { transform := promisedAnswerOpsToTransform(recvTransform) msg.AddCap(a.pipelineClient(transform)) default: - log.Println("rpc: unknown capability type", desc.Which()) + c.errorf("unknown capability type %v", desc.Which()) return errUnimplemented } } @@ -575,12 +577,12 @@ func (c *Conn) makeCapTable(s *capnp.Segment) (rpccapnp.CapDescriptor_List, erro return t, nil } -// handleBootstrapMessage is run in the coordinate goroutine to handle -// a received bootstrap message. +// handleBootstrapMessage handles a received bootstrap message. +// The caller holds onto c.mu. func (c *Conn) handleBootstrapMessage(id answerID) error { ctx, cancel := c.newContext() defer cancel() - a := c.answers.insert(id, cancel) + a := c.insertAnswer(id, cancel) if a == nil { // Question ID reused, error out. retmsg := newReturnMessage(nil, id) @@ -588,25 +590,12 @@ func (c *Conn) handleBootstrapMessage(id answerID) error { setReturnException(r, errQuestionReused) return c.sendMessage(retmsg) } - msgs := make([]rpccapnp.Message, 0, 1) if c.mainFunc == nil { - msgs = a.reject(msgs, errNoMainInterface) - for _, m := range msgs { - if err := c.sendMessage(m); err != nil { - return err - } - } - return nil + return a.reject(errNoMainInterface) } main, err := c.mainFunc(ctx) if err != nil { - msgs = a.reject(msgs, errNoMainInterface) - for _, m := range msgs { - if err := c.sendMessage(m); err != nil { - return err - } - } - return nil + return a.reject(errNoMainInterface) } m := &capnp.Message{ Arena: capnp.SingleSegment(make([]byte, 0)), @@ -614,18 +603,11 @@ func (c *Conn) handleBootstrapMessage(id answerID) error { } s, _ := m.Segment(0) in := capnp.NewInterface(s, 0) - msgs = a.fulfill(msgs, in.ToPtr(), c.makeCapTable) - for _, m := range msgs { - if err := c.sendMessage(m); err != nil { - return err - } - } - return nil + return a.fulfill(in.ToPtr()) } -// handleCallMessage is run in the coordinate goroutine to handle a -// received call message. It mutates the capability table of its -// parameter. +// handleCallMessage handles a received call message. It mutates the +// capability table of its parameter. The caller holds onto c.mu. func (c *Conn) handleCallMessage(m rpccapnp.Message) error { mcall, err := m.Call() if err != nil { @@ -652,7 +634,7 @@ func (c *Conn) handleCallMessage(m rpccapnp.Message) error { } ctx, cancel := c.newContext() id := answerID(mcall.QuestionId()) - a := c.answers.insert(id, cancel) + a := c.insertAnswer(id, cancel) if a == nil { // Question ID reused, error out. c.abort(errQuestionReused) @@ -672,13 +654,7 @@ func (c *Conn) handleCallMessage(m rpccapnp.Message) error { Params: paramContent.Struct(), } if err := c.routeCallMessage(a, mt, cl); err != nil { - msgs := a.reject(nil, err) - for _, m := range msgs { - if err := c.sendMessage(m); err != nil { - return err - } - } - return nil + return a.reject(err) } return nil } @@ -687,11 +663,11 @@ func (c *Conn) routeCallMessage(result *answer, mt rpccapnp.MessageTarget, cl *c switch mt.Which() { case rpccapnp.MessageTarget_Which_importedCap: id := exportID(mt.ImportedCap()) - e := c.exports.get(id) + e := c.findExport(id) if e == nil { return errBadTarget } - answer := c.nestedCall(e.client, cl) + answer := c.lockedCall(e.client, cl) go joinAnswer(result, answer) case rpccapnp.MessageTarget_Which_promisedAnswer: mpromise, err := mt.PromisedAnswer() @@ -703,7 +679,7 @@ func (c *Conn) routeCallMessage(result *answer, mt rpccapnp.MessageTarget, cl *c // Grandfather paradox. return errBadTarget } - pa := c.answers.get(id) + pa := c.answers[id] if pa == nil { return errBadTarget } @@ -712,15 +688,18 @@ func (c *Conn) routeCallMessage(result *answer, mt rpccapnp.MessageTarget, cl *c return err } transform := promisedAnswerOpsToTransform(mtrans) - if obj, err, done := pa.peek(); done { + pa.mu.Lock() + if pa.done { + obj, err := pa.obj, pa.err + pa.mu.Unlock() client := clientFromResolution(transform, obj, err) - answer := c.nestedCall(client, cl) + answer := c.lockedCall(client, cl) go joinAnswer(result, answer) - return nil - } - if err := pa.queueCall(result, transform, cl); err != nil { - return err + } else { + err = pa.queueCallLocked(cl, pcall{transform: transform, qcall: qcall{a: result}}) + pa.mu.Unlock() } + return err default: panic("unreachable") } @@ -747,7 +726,7 @@ func (c *Conn) handleDisembargoMessage(msg rpccapnp.Message) error { return err } aid := answerID(dpa.QuestionId()) - a := c.answers.get(aid) + a := c.answers[aid] if a == nil { return errDisembargoMissingAnswer } @@ -771,7 +750,7 @@ func (c *Conn) handleDisembargoMessage(msg rpccapnp.Message) error { } case rpccapnp.Disembargo_context_Which_receiverLoopback: id := embargoID(d.Context().ReceiverLoopback()) - c.embargoes.disembargo(id) + c.disembargo(id) default: um := newUnimplementedMessage(nil, msg) c.sendMessage(um) @@ -796,7 +775,7 @@ func newDisembargoMessage(buf []byte, which rpccapnp.Disembargo_context_Which, i // newContext creates a new context for a local call. func (c *Conn) newContext() (context.Context, context.CancelFunc) { - return context.WithCancel(c.manager.context()) + return context.WithCancel(c.bg) } func promisedAnswerOpsToTransform(list rpccapnp.PromisedAnswer_Op_List) []capnp.PipelineOp { @@ -816,44 +795,6 @@ func promisedAnswerOpsToTransform(list rpccapnp.PromisedAnswer_Op_List) []capnp. return transform } -// handleReturn is called from the coordinate goroutine to send an -// answer's return value over the transport. -func (c *Conn) handleReturn(r *outgoingReturn) { - msgs := make([]rpccapnp.Message, 0, 32) - if r.err == nil { - msgs = r.a.fulfill(msgs, r.obj, c.makeCapTable) - } else { - msgs = r.a.reject(msgs, r.err) - } - for _, m := range msgs { - if err := c.sendMessage(m); err != nil { - return - } - } -} - -func (c *Conn) handleQueueClose(qcc queueClientClose) { - msgs := make([]rpccapnp.Message, 0, 32) - msgs = qcc.qc.rejectQueue(msgs) - close(qcc.done) - for _, m := range msgs { - if err := c.sendMessage(m); err != nil { - return - } - } -} - -func (c *Conn) sendMessage(msg rpccapnp.Message) error { - return sendMessage(&c.manager, c.out, msg) -} - -func (c *Conn) abort(err error) { - // TODO(light): ensure that the message is sent before shutting down? - am := newAbortMessage(nil, err) - c.sendMessage(am) - c.manager.shutdown(err) -} - func newAbortMessage(buf []byte, err error) rpccapnp.Message { n := newMessage(buf) e, _ := n.NewAbort() @@ -905,54 +846,37 @@ func newMessage(buf []byte) rpccapnp.Message { return m } -// An appCall is a message sent to the coordinate goroutine to indicate -// that the application code wants to initiate an outgoing call. -type appCall struct { - *capnp.Call - kind int - achan chan<- capnp.Answer +// chanMutex is a mutex backed by a channel so that it can be used in a select. +// A receive is a lock and a send is an unlock. +type chanMutex chan struct{} - // Import calls - importID importID +type connState int - // Pipeline calls - question *question - transform []capnp.PipelineOp -} +const ( + connAlive connState = iota + connDying + connDead +) -func newAppImportCall(id importID, cl *capnp.Call) (*appCall, <-chan capnp.Answer) { - achan := make(chan capnp.Answer, 1) - return &appCall{ - Call: cl, - kind: appImportCall, - achan: achan, - importID: id, - }, achan +func newChanMutex() chanMutex { + mu := make(chanMutex, 1) + mu <- struct{}{} + return mu } -func newAppPipelineCall(q *question, transform []capnp.PipelineOp, cl *capnp.Call) (*appCall, <-chan capnp.Answer) { - achan := make(chan capnp.Answer, 1) - return &appCall{ - Call: cl, - kind: appPipelineCall, - achan: achan, - question: q, - transform: transform, - }, achan +func (mu chanMutex) Lock() { + <-mu } -func newAppBootstrapCall(ctx context.Context) (*appCall, <-chan capnp.Answer) { - achan := make(chan capnp.Answer, 1) - return &appCall{ - Call: &capnp.Call{Ctx: ctx}, - kind: appBootstrapCall, - achan: achan, - }, achan +func (mu chanMutex) TryLock(ctx context.Context) error { + select { + case <-mu: + return nil + case <-ctx.Done(): + return ctx.Err() + } } -// Kinds of application calls. -const ( - appImportCall = iota - appPipelineCall - appBootstrapCall -) +func (mu chanMutex) Unlock() { + mu <- struct{}{} +} diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go index 10de9d05..ce99f8b0 100644 --- a/rpc/rpc_test.go +++ b/rpc/rpc_test.go @@ -22,18 +22,35 @@ const ( var logMessages = flag.Bool("logmessages", false, "whether to log the transport in tests. Messages are always from client to server.") -func newTestConn(t *testing.T, options ...rpc.ConnOption) (*rpc.Conn, rpc.Transport) { +type testLogger struct { + t interface { + Logf(format string, args ...interface{}) + } +} + +func (l testLogger) Infof(ctx context.Context, format string, args ...interface{}) { + l.t.Logf("conn log: "+format, args...) +} + +func (l testLogger) Errorf(ctx context.Context, format string, args ...interface{}) { + l.t.Logf("conn log: "+format, args...) +} + +func newUnpairedConn(t *testing.T, options ...rpc.ConnOption) (*rpc.Conn, rpc.Transport) { p, q := pipetransport.New() if *logMessages { p = logtransport.New(nil, p) } - c := rpc.NewConn(p, options...) + newopts := make([]rpc.ConnOption, len(options), len(options)+1) + copy(newopts, options) + newopts = append(newopts, rpc.ConnLog(testLogger{t})) + c := rpc.NewConn(p, newopts...) return c, q } func TestBootstrap(t *testing.T) { ctx := context.Background() - conn, p := newTestConn(t) + conn, p := newUnpairedConn(t) defer conn.Close() defer p.Close() @@ -78,7 +95,7 @@ func TestBootstrapFulfilledSenderPromise(t *testing.T) { func testBootstrapFulfilled(t *testing.T, resultIsPromise bool) { ctx := context.Background() - conn, p := newTestConn(t) + conn, p := newUnpairedConn(t) defer conn.Close() defer p.Close() @@ -170,7 +187,7 @@ func bootstrapAndFulfill(t *testing.T, ctx context.Context, conn *rpc.Conn, p rp func TestCallOnPromisedAnswer(t *testing.T) { ctx := context.Background() - conn, p := newTestConn(t) + conn, p := newUnpairedConn(t) defer conn.Close() defer p.Close() client, bootstrapID := readBootstrap(t, ctx, conn, p) @@ -250,7 +267,7 @@ func TestCallOnExportId_BootstrapIsHosted(t *testing.T) { func testCallOnExportId(t *testing.T, bootstrapIsPromise bool) { ctx := context.Background() - conn, p := newTestConn(t) + conn, p := newUnpairedConn(t) defer conn.Close() defer p.Close() client := bootstrapAndFulfill(t, ctx, conn, p, bootstrapIsPromise) @@ -311,7 +328,7 @@ func testCallOnExportId(t *testing.T, bootstrapIsPromise bool) { func TestMainInterface(t *testing.T) { main := mockClient() - conn, p := newTestConn(t, rpc.MainInterface(main)) + conn, p := newUnpairedConn(t, rpc.MainInterface(main)) defer conn.Close() defer p.Close() @@ -393,7 +410,7 @@ func TestReceiveCallOnPromisedAnswer(t *testing.T) { } return result, nil }) - conn, p := newTestConn(t, rpc.MainInterface(main)) + conn, p := newUnpairedConn(t, rpc.MainInterface(main)) defer conn.Close() defer p.Close() _, bootqID := bootstrapRoundtrip(t, p) @@ -476,7 +493,7 @@ func TestReceiveCallOnExport(t *testing.T) { } return result, nil }) - conn, p := newTestConn(t, rpc.MainInterface(main)) + conn, p := newUnpairedConn(t, rpc.MainInterface(main)) defer conn.Close() defer p.Close() importID := sendBootstrapAndFinish(t, p) diff --git a/rpc/tables.go b/rpc/tables.go index 187713ae..8f93a49d 100644 --- a/rpc/tables.go +++ b/rpc/tables.go @@ -1,7 +1,7 @@ package rpc import ( - "log" + "errors" "zombiezen.com/go/capnproto2" "zombiezen.com/go/capnproto2/rpc/internal/refcount" @@ -22,97 +22,119 @@ type impent struct { refs int } -type importTable struct { - tab map[importID]*impent - manager *manager - calls chan<- *appCall - releases chan<- *outgoingRelease -} - -// addRef increases the counter of the times the import ID was sent to this vat. -func (it *importTable) addRef(id importID) capnp.Client { - if it.tab == nil { - it.tab = make(map[importID]*impent) - } - ent := it.tab[id] - var ref capnp.Client - if ent == nil { - client := &importClient{ - id: id, - manager: it.manager, - calls: it.calls, - releases: it.releases, - } - var rc *refcount.RefCount - rc, ref = refcount.New(client) - ent = &impent{rc: rc, refs: 0} - it.tab[id] = ent +// addImport increases the counter of the times the import ID was sent to this vat. +func (c *Conn) addImport(id importID) capnp.Client { + if c.imports == nil { + c.imports = make(map[importID]*impent) + } else if ent := c.imports[id]; ent != nil { + ent.refs++ + return ent.rc.Ref() } - if ref == nil { - ref = ent.rc.Ref() + client := &importClient{ + id: id, + conn: c, } - ent.refs++ + rc, ref := refcount.New(client) + c.imports[id] = &impent{rc: rc, refs: 1} return ref } -// pop removes the import ID and returns the number of times the import ID was sent to this vat. -func (it *importTable) pop(id importID) (refs int) { - if it.tab != nil { - if ent := it.tab[id]; ent != nil { - refs = ent.refs - } - delete(it.tab, id) +// popImport removes the import ID and returns the number of times the import ID was sent to this vat. +func (c *Conn) popImport(id importID) (refs int) { + if c.imports == nil { + return 0 } - return -} - -// An outgoingRelease is a message sent to the coordinate goroutine to -// indicate that an import should be released. -type outgoingRelease struct { - id importID - echan chan<- error + ent := c.imports[id] + if ent == nil { + return 0 + } + refs = ent.refs + delete(c.imports, id) + return refs } // An importClient implements capnp.Client for a remote capability. type importClient struct { - id importID - manager *manager - calls chan<- *appCall - releases chan<- *outgoingRelease + id importID + conn *Conn + closed bool // protected by conn.mu } func (ic *importClient) Call(cl *capnp.Call) capnp.Answer { - // TODO(light): don't send if closed. - ac, achan := newAppImportCall(ic.id, cl) select { - case ic.calls <- ac: - select { - case a := <-achan: - return a - case <-ic.manager.finish: - return capnp.ErrorAnswer(ic.manager.err()) - } - case <-ic.manager.finish: - return capnp.ErrorAnswer(ic.manager.err()) + case <-ic.conn.mu: + case <-cl.Ctx.Done(): + return capnp.ErrorAnswer(cl.Ctx.Err()) + case <-ic.conn.bg.Done(): + return capnp.ErrorAnswer(ErrConnClosed) + } + ans := ic.lockedCall(cl) + ic.conn.mu.Unlock() + return ans +} + +// lockedCall is equivalent to Call but assumes that the caller is +// already holding onto ic.conn.mu. +func (ic *importClient) lockedCall(cl *capnp.Call) capnp.Answer { + if ic.closed { + return capnp.ErrorAnswer(errImportClosed) + } + + q := ic.conn.newQuestion(cl.Ctx, &cl.Method) + msg := newMessage(nil) + msgCall, _ := msg.NewCall() + msgCall.SetQuestionId(uint32(q.id)) + msgCall.SetInterfaceId(cl.Method.InterfaceID) + msgCall.SetMethodId(cl.Method.MethodID) + target, _ := msgCall.NewTarget() + target.SetImportedCap(uint32(ic.id)) + payload, _ := msgCall.NewParams() + if err := ic.conn.fillParams(payload, cl); err != nil { + ic.conn.popQuestion(q.id) + return capnp.ErrorAnswer(err) + } + + select { + case ic.conn.out <- msg: + case <-cl.Ctx.Done(): + ic.conn.popQuestion(q.id) + return capnp.ErrorAnswer(cl.Ctx.Err()) + case <-ic.conn.bg.Done(): + ic.conn.popQuestion(q.id) + return capnp.ErrorAnswer(ErrConnClosed) } + q.start() + return q } func (ic *importClient) Close() error { - echan := make(chan error, 1) - r := &outgoingRelease{ - id: ic.id, - echan: echan, + ic.conn.mu.Lock() + closed := ic.closed + var i int + if !closed { + i = ic.conn.popImport(ic.id) + ic.closed = true + } + ic.conn.mu.Unlock() + + if closed { + return errImportClosed + } + if i == 0 { + return nil } + msg := newMessage(nil) + mr, err := msg.NewRelease() + if err != nil { + return err + } + mr.SetId(uint32(ic.id)) + mr.SetReferenceCount(uint32(i)) select { - case ic.releases <- r: - select { - case err := <-echan: - return err - case <-ic.manager.finish: - return ic.manager.err() - } - case <-ic.manager.finish: - return ic.manager.err() + case ic.conn.out <- msg: + return nil + case <-ic.conn.bg.Done(): + return ErrConnClosed } } @@ -124,47 +146,38 @@ type export struct { refs int } -type exportTable struct { - tab []*export - gen idgen -} - -func (et *exportTable) get(id exportID) *export { - var e *export - if int(id) < len(et.tab) { - e = et.tab[id] +func (c *Conn) findExport(id exportID) *export { + if int(id) >= len(c.exports) { + return nil } - return e + return c.exports[id] } -// add ensures that the client is present in the table, returning its ID. +// addExport ensures that the client is present in the table, returning its ID. // If the client is already in the table, the previous ID is returned. -func (et *exportTable) add(client capnp.Client) exportID { - for i, e := range et.tab { +func (c *Conn) addExport(client capnp.Client) exportID { + for i, e := range c.exports { if e != nil && e.client == client { e.refs++ return exportID(i) } } - id := exportID(et.gen.next()) + id := exportID(c.exportID.next()) export := &export{ id: id, client: client, refs: 1, } - if int(id) == len(et.tab) { - et.tab = append(et.tab, export) + if int(id) == len(c.exports) { + c.exports = append(c.exports, export) } else { - et.tab[id] = export + c.exports[id] = export } return id } -func (et *exportTable) release(id exportID, refs int) { - if int(id) >= len(et.tab) { - return - } - e := et.tab[id] +func (c *Conn) releaseExport(id exportID, refs int) { + e := c.findExport(id) if e == nil { return } @@ -173,64 +186,39 @@ func (et *exportTable) release(id exportID, refs int) { return } if e.refs < 0 { - log.Printf("rpc: warning: export %v has negative refcount (%d)", id, e.refs) + c.errorf("warning: export %v has negative refcount (%d)", id, e.refs) } if err := e.client.Close(); err != nil { - log.Printf("rpc: export %v close: %v", id, err) + c.errorf("export %v close: %v", id, err) } - et.tab[id] = nil - et.gen.remove(uint32(id)) -} - -func (et *exportTable) releaseAll() { - for id, e := range et.tab { - if e == nil { - continue - } - if err := e.client.Close(); err != nil { - log.Printf("rpc: export %v close: %v", id, err) - } - et.tab[id] = nil - et.gen.remove(uint32(id)) - } -} - -// releaseList decrements the reference count of each of the given exports by 1. -func (et *exportTable) releaseList(ids []exportID) { - for _, id := range ids { - et.release(id, 1) - } -} - -type embargoTable struct { - tab []chan<- struct{} - gen idgen + c.exports[id] = nil + c.exportID.remove(uint32(id)) } type embargo <-chan struct{} -func (et *embargoTable) new() (embargoID, embargo) { - id := embargoID(et.gen.next()) +func (c *Conn) newEmbargo() (embargoID, embargo) { + id := embargoID(c.embargoID.next()) e := make(chan struct{}) - if int(id) == len(et.tab) { - et.tab = append(et.tab, e) + if int(id) == len(c.embargoes) { + c.embargoes = append(c.embargoes, e) } else { - et.tab[id] = e + c.embargoes[id] = e } return id, e } -func (et *embargoTable) disembargo(id embargoID) { - if int(id) >= len(et.tab) { +func (c *Conn) disembargo(id embargoID) { + if int(id) >= len(c.embargoes) { return } - e := et.tab[id] + e := c.embargoes[id] if e == nil { return } close(e) - et.tab[id] = nil - et.gen.remove(uint32(id)) + c.embargoes[id] = nil + c.embargoID.remove(uint32(id)) } // idgen returns a sequence of monotonically increasing IDs with @@ -255,3 +243,5 @@ func (gen *idgen) next() uint32 { func (gen *idgen) remove(i uint32) { gen.free = append(gen.free, i) } + +var errImportClosed = errors.New("rpc: call on closed import") diff --git a/rpc/transport.go b/rpc/transport.go index cd15a720..c318703a 100644 --- a/rpc/transport.go +++ b/rpc/transport.go @@ -3,7 +3,6 @@ package rpc import ( "bytes" "io" - "log" "time" "golang.org/x/net/context" @@ -97,46 +96,44 @@ type writeDeadlineSetter interface { } // dispatchSend runs in its own goroutine and sends messages on a transport. -func dispatchSend(m *manager, transport Transport, msgs <-chan rpccapnp.Message) { +func (c *Conn) dispatchSend() { + defer c.workers.Done() for { select { - case msg := <-msgs: - err := transport.SendMessage(m.context(), msg) + case msg := <-c.out: + err := c.transport.SendMessage(c.bg, msg) if err != nil { - log.Printf("rpc: writing %v: %v", msg.Which(), err) + c.errorf("writing %v: %v", msg.Which(), err) } - case <-m.finish: + case <-c.bg.Done(): return } } } -// sendMessage sends a message to out to be sent. It returns an error -// if the manager finished. -func sendMessage(m *manager, out chan<- rpccapnp.Message, msg rpccapnp.Message) error { +// sendMessage enqueues a message to be sent or returns an error if the +// connection is shut down before the message is queued. It is safe to +// call from multiple goroutines and does not require holding c.mu. +func (c *Conn) sendMessage(msg rpccapnp.Message) error { select { - case out <- msg: + case c.out <- msg: return nil - case <-m.finish: - return m.err() + case <-c.bg.Done(): + return ErrConnClosed } } // dispatchRecv runs in its own goroutine and receives messages from a transport. -func dispatchRecv(m *manager, transport Transport, msgs chan<- rpccapnp.Message) { +func (c *Conn) dispatchRecv() { + defer c.workers.Done() for { - msg, err := transport.RecvMessage(m.context()) - if err != nil { - if isTemporaryError(err) { - log.Println("rpc: read temporary error:", err) - continue - } - m.shutdown(err) - return - } - select { - case msgs <- copyRPCMessage(msg): - case <-m.finish: + msg, err := c.transport.RecvMessage(c.bg) + if err == nil { + c.handleMessage(msg) + } else if isTemporaryError(err) { + c.errorf("read temporary error: %v", err) + } else { + c.shutdown(err) return } }