Skip to content

Commit

Permalink
Debug-only runtime tracking of funcs running on correct goroutines.
Browse files Browse the repository at this point in the history
  • Loading branch information
bradfitz committed Nov 11, 2014
1 parent d43f8f3 commit 6fe7631
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 2 deletions.
160 changes: 160 additions & 0 deletions gotrack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE

// Defensive debug-only utility to track that functions run on the
// goroutine that they're supposed to.

package http2

import (
"bytes"
"errors"
"fmt"
"runtime"
"strconv"
"sync"
)

var DebugGoroutines = false

type goroutineLock uint64

func newGoroutineLock() goroutineLock {
return goroutineLock(curGoroutineID())
}

func (g goroutineLock) check() {
if !DebugGoroutines {
return
}
if curGoroutineID() != uint64(g) {
panic("running on the wrong goroutine")
}
}

var goroutineSpace = []byte("goroutine ")

func curGoroutineID() uint64 {
bp := littleBuf.Get().(*[]byte)
defer littleBuf.Put(bp)
b := *bp
b = b[:runtime.Stack(b, false)]
// Parse the 4707 otu of "goroutine 4707 ["
b = bytes.TrimPrefix(b, goroutineSpace)
i := bytes.IndexByte(b, ' ')
if i < 0 {
panic(fmt.Sprintf("No space found in %q", b))
}
b = b[:i]
n, err := parseUintBytes(b, 10, 64)
if err != nil {
panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
}
return n
}

var littleBuf = sync.Pool{
New: func() interface{} {
buf := make([]byte, 64)
return &buf
},
}

// parseUintBytes is like strconv.ParseUint, but using a []byte.
func parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
var cutoff, maxVal uint64

if bitSize == 0 {
bitSize = int(strconv.IntSize)
}

s0 := s
switch {
case len(s) < 1:
err = strconv.ErrSyntax
goto Error

case 2 <= base && base <= 36:
// valid base; nothing to do

case base == 0:
// Look for octal, hex prefix.
switch {
case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
base = 16
s = s[2:]
if len(s) < 1 {
err = strconv.ErrSyntax
goto Error
}
case s[0] == '0':
base = 8
default:
base = 10
}

default:
err = errors.New("invalid base " + strconv.Itoa(base))
goto Error
}

n = 0
cutoff = cutoff64(base)
maxVal = 1<<uint(bitSize) - 1

for i := 0; i < len(s); i++ {
var v byte
d := s[i]
switch {
case '0' <= d && d <= '9':
v = d - '0'
case 'a' <= d && d <= 'z':
v = d - 'a' + 10
case 'A' <= d && d <= 'Z':
v = d - 'A' + 10
default:
n = 0
err = strconv.ErrSyntax
goto Error
}
if int(v) >= base {
n = 0
err = strconv.ErrSyntax
goto Error
}

if n >= cutoff {
// n*base overflows
n = 1<<64 - 1
err = strconv.ErrRange
goto Error
}
n *= uint64(base)

n1 := n + uint64(v)
if n1 < n || n1 > maxVal {
// n+v overflows
n = 1<<64 - 1
err = strconv.ErrRange
goto Error
}
n = n1
}

return n, nil

Error:
return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
}

// Return the first number n such that n*base >= 1<<64.
func cutoff64(base int) uint64 {
if base < 2 {
return 0
}
return (1<<64-1)/uint64(base) + 1
}
33 changes: 33 additions & 0 deletions gotrack_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE

package http2

import (
"fmt"
"strings"
"testing"
)

func TestGoroutineLock(t *testing.T) {
DebugGoroutines = true
g := newGoroutineLock()
g.check()

sawPanic := make(chan interface{})
go func() {
defer func() { sawPanic <- recover() }()
g.check() // should panic
}()
e := <-sawPanic
if e == nil {
t.Fatal("did not see panic from check in other goroutine")
}
if !strings.Contains(fmt.Sprint(e), "wrong goroutine") {
t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e)
}
}
21 changes: 20 additions & 1 deletion http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
writeHeaderCh: make(chan headerWriteReq), // must not be buffered
doneServing: make(chan struct{}),
maxWriteFrameSize: initialMaxFrameSize,
serveG: newGoroutineLock(),
}
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
Expand All @@ -89,6 +90,7 @@ type frameAndProcessed struct {
}

type serverConn struct {
// Immutable:
hs *http.Server
conn net.Conn
handler http.Handler
Expand All @@ -97,6 +99,9 @@ type serverConn struct {
readFrameCh chan frameAndProcessed // written by serverConn.readFrames
readFrameErrCh chan error
writeHeaderCh chan headerWriteReq // must not be buffered
serveG goroutineLock // used to verify funcs are on serve()

// Everything following is owned by the serve loop; use serveG.check()

maxStreamID uint32 // max ever seen
streams map[uint32]*stream
Expand Down Expand Up @@ -139,6 +144,7 @@ type stream struct {
}

func (sc *serverConn) state(streamID uint32) streamState {
sc.serveG.check()
// http://http2.github.io/http2-spec/#rfc.section.5.1
if st, ok := sc.streams[streamID]; ok {
return st.state
Expand Down Expand Up @@ -170,6 +176,7 @@ func (sc *serverConn) logf(format string, args ...interface{}) {
}

func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
sc.serveG.check()
switch {
case !validHeader(f.Name):
sc.invalidHeader = true
Expand Down Expand Up @@ -199,6 +206,7 @@ func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) {
}

func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check()
// TODO: use a sync.Pool instead of putting the cache on *serverConn?
cv, ok := sc.canonHeader[v]
if !ok {
Expand All @@ -208,6 +216,8 @@ func (sc *serverConn) canonicalHeader(v string) string {
return cv
}

// readFrames is the loop that reads incoming frames.
// It's run on its own goroutine.
func (sc *serverConn) readFrames() {
processed := make(chan struct{}, 1)
for {
Expand All @@ -223,6 +233,7 @@ func (sc *serverConn) readFrames() {
}

func (sc *serverConn) serve() {
sc.serveG.check()
defer sc.conn.Close()
defer close(sc.doneServing)

Expand Down Expand Up @@ -316,6 +327,7 @@ func (sc *serverConn) serve() {
}

func (sc *serverConn) resetStreamInLoop(se StreamError) error {
sc.serveG.check()
if err := sc.framer.WriteRSTStream(se.streamID, uint32(se.code)); err != nil {
return err
}
Expand All @@ -324,6 +336,8 @@ func (sc *serverConn) resetStreamInLoop(se StreamError) error {
}

func (sc *serverConn) processFrame(f Frame) error {
sc.serveG.check()

if s := sc.curHeaderStreamID; s != 0 {
if cf, ok := f.(*ContinuationFrame); !ok {
return ConnectionError(ErrCodeProtocol)
Expand All @@ -346,13 +360,15 @@ func (sc *serverConn) processFrame(f Frame) error {
}

func (sc *serverConn) processSettings(f *SettingsFrame) error {
sc.serveG.check()
f.ForeachSetting(func(s Setting) {
log.Printf(" setting %s = %v", s.ID, s.Val)
})
return nil
}

func (sc *serverConn) processHeaders(f *HeadersFrame) error {
sc.serveG.check()
id := f.Header().StreamID

// http://http2.github.io/http2-spec/#rfc.section.5.1.1
Expand Down Expand Up @@ -386,6 +402,7 @@ func (sc *serverConn) processHeaders(f *HeadersFrame) error {
}

func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
sc.serveG.check()
id := f.Header().StreamID
if sc.curHeaderStreamID != id {
return ConnectionError(ErrCodeProtocol)
Expand All @@ -394,6 +411,7 @@ func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
}

func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, end bool) error {
sc.serveG.check()
if _, err := sc.hpackDecoder.Write(frag); err != nil {
// TODO: convert to stream error I assume?
return err
Expand Down Expand Up @@ -423,6 +441,7 @@ func (sc *serverConn) processHeaderBlockFragment(streamID uint32, frag []byte, e
return nil
}

// Run on its own goroutine.
func (sc *serverConn) startHandler(streamID uint32, bodyOpen bool, method, path, scheme, authority string, reqHeader http.Header) {
var tlsState *tls.ConnectionState // make this non-nil if https
if scheme == "https" {
Expand Down Expand Up @@ -486,8 +505,8 @@ func (sc *serverConn) writeHeader(req headerWriteReq) {
sc.writeHeaderCh <- req
}

// called from serverConn.serve loop.
func (sc *serverConn) writeHeaderInLoop(req headerWriteReq) error {
sc.serveG.check()
sc.headerWriteBuf.Reset()
// TODO: remove this strconv
sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(req.httpResCode)})
Expand Down
5 changes: 4 additions & 1 deletion http2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ import (
"github.com/bradfitz/http2/hpack"
)

func init() { VerboseLogs = true }
func init() {
VerboseLogs = true
DebugGoroutines = true
}

type serverTester struct {
cc net.Conn // client conn
Expand Down

0 comments on commit 6fe7631

Please sign in to comment.