Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: safe copy ctx #927

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
17 changes: 13 additions & 4 deletions pkg/app/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ var defaultFormValue = func(ctx *RequestContext, key string) []byte {
}

type RequestContext struct {
// isCopy shows that whether it is a copy through ctx.Copy().
isCopy bool

conn network.Conn
Request protocol.Request
Response protocol.Response
Expand All @@ -188,7 +191,8 @@ type RequestContext struct {
// This mutex protect Keys map.
mu sync.RWMutex

// Keys is a key/value pair exclusively for the context of each request.
// Deprecated: DO NOT CALL IT DIRECTLY.
// Use .Get/.Set/.ForEachKey APIs to manipulate it.
Keys map[string]interface{}

hijackHandler HijackHandler
Expand Down Expand Up @@ -688,15 +692,19 @@ func getRedirectStatusCode(statusCode int) int {
// Copy returns a copy of the current context that can be safely used outside
// the request's scope.
//
// NOTE: If you want to pass requestContext to a goroutine, call this method
// NOTE1: If you want to pass requestContext to a goroutine, call this method
// to get a copy of requestContext.
// NOTE2: The copy of the ctx is READ-ONLY, any writing scenario should be passed
// back to the origin ctx, and process in the origin ctx.
// NOTE3: The copy of the ctx will be marked as copy, which means it is safe for
// concurrent read only not write.
func (ctx *RequestContext) Copy() *RequestContext {
cp := &RequestContext{
conn: ctx.conn,
Params: ctx.Params,
}
ctx.Request.CopyTo(&cp.Request)
ctx.Response.CopyTo(&cp.Response)
ctx.Request.CopyToAndMark(&cp.Request)
ctx.Response.CopyToAndMark(&cp.Response)
cp.index = rConsts.AbortIndex
cp.handlers = nil
cp.Keys = map[string]interface{}{}
Expand Down Expand Up @@ -749,6 +757,7 @@ func (ctx *RequestContext) ResetWithoutConn() {
ctx.index = -1
ctx.fullPath = ""
ctx.Keys = nil
ctx.isCopy = false

if ctx.finished != nil {
close(ctx.finished)
Expand Down
2 changes: 1 addition & 1 deletion pkg/common/test/assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func Nil(t testing.TB, data interface{}) {
func NotNil(t testing.TB, data interface{}) {
t.Helper()
if data == nil {
return
t.Fatalf("assertion failed, unexpected: %v, expected: not nil", data)
}

if reflect.ValueOf(data).IsNil() {
Expand Down
18 changes: 18 additions & 0 deletions pkg/common/test/mock/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,24 @@ type OneTimeConn struct {
*Conn
}

func (o *OneTimeConn) Read(b []byte) (int, error) {
length := len(b)
if o.Len() < length {
length = o.Len()
}
if length == 0 {
length = 1
}

buf, err := o.Peek(length)
if err != nil {
return 0, err
}
n := copy(b, buf)
o.Skip(n)
return n, err
}

func (o *OneTimeConn) Peek(n int) ([]byte, error) {
if o.isRead {
return nil, io.EOF
Expand Down
8 changes: 4 additions & 4 deletions pkg/common/utils/ioutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestIoutilCopyBufferWithIoReaderFrom(t *testing.T) {
assert.DeepEqual(t, true, ok)
written, err := CopyBuffer(ioReaderFrom, src, buf)
assert.DeepEqual(t, written, int64(0))
assert.NotNil(t, err)
assert.Nil(t, err)
assert.DeepEqual(t, []byte(nil), writeBuffer.Bytes())
}

Expand Down Expand Up @@ -165,7 +165,7 @@ func TestIoutilCopyBufferWithNilBuffer(t *testing.T) {
written, err := CopyBuffer(dst, src, nil)

assert.DeepEqual(t, written, srcLen)
assert.NotNil(t, err)
assert.Nil(t, err)
assert.DeepEqual(t, []byte(str), writeBuffer.Bytes())
}

Expand All @@ -179,7 +179,7 @@ func TestIoutilCopyBufferWithNilBufferAndIoLimitedReader(t *testing.T) {
written, err := CopyBuffer(dst, &reader, nil)

assert.DeepEqual(t, written, srcLen)
assert.NotNil(t, err)
assert.Nil(t, err)
assert.DeepEqual(t, []byte(str), writeBuffer.Bytes())

// test l.N < 1
Expand All @@ -192,7 +192,7 @@ func TestIoutilCopyBufferWithNilBufferAndIoLimitedReader(t *testing.T) {
written, err = CopyBuffer(dst, &reader, nil)

assert.DeepEqual(t, written, srcLen)
assert.NotNil(t, err)
assert.Nil(t, err)
assert.DeepEqual(t, []byte(str), writeBuffer.Bytes())
}

Expand Down
17 changes: 17 additions & 0 deletions pkg/protocol/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ type argsScanner struct {
type Args struct {
noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

// isCopy shows that whether it is a copy through ctx.Copy().
// Other APIs such as CopyTo do not need to handle this.
isCopy bool

args []argsKV
buf []byte
}
Expand All @@ -75,6 +79,16 @@ func (a *Args) Set(key, value string) {
// Reset clears query args.
func (a *Args) Reset() {
a.args = a.args[:0]

// a.ParseBytes() will trigger reset, which is a process during lazy load(read scenario), so do not reset this flag.
// Args is not a recycle object so the risk of dirty data is relatively low even though we do not reset this field.
// a.isCopy = false
}

// CopyToAndMark copies all args to dst and mark the dst args as a copy.
func (a *Args) CopyToAndMark(dst *Args) {
dst.isCopy = true
a.CopyTo(dst)
}

// CopyTo copies all args to dst.
Expand Down Expand Up @@ -343,6 +357,9 @@ func peekArgStrExists(h []argsKV, k string) (string, bool) {
//
// The returned value is valid until the next call to Args methods.
func (a *Args) QueryString() []byte {
if a.isCopy {
return a.AppendBytes(nil)
}
a.buf = a.AppendBytes(a.buf[:0])
return a.buf
}
Expand Down
13 changes: 13 additions & 0 deletions pkg/protocol/args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,16 @@ func TestArgsVisitAll(t *testing.T) {
})
assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s)
}

func TestCopyArgs_QueryString(t *testing.T) {
a := Args{}
a.Add("foo", "bar")
assert.DeepEqual(t, "foo=bar", string(a.QueryString()))
assert.DeepEqual(t, "foo=bar", string(a.buf))

a.buf = nil
a.isCopy = true

assert.DeepEqual(t, "foo=bar", string(a.QueryString()))
assert.Nil(t, a.buf)
}
53 changes: 49 additions & 4 deletions pkg/protocol/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ var (
type RequestHeader struct {
noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

// isCopy shows that whether it is a copy through ctx.Copy().
// Other APIs such as CopyTo do not need to handle this.
isCopy bool

disableNormalizing bool
connectionClose bool
noDefaultContentType bool
Expand Down Expand Up @@ -110,6 +114,10 @@ func (h *RequestHeader) SetRawHeaders(r []byte) {
type ResponseHeader struct {
noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

// isCopy shows that whether it is a copy through ctx.Copy().
// Other APIs such as CopyTo do not need to handle this.
isCopy bool

disableNormalizing bool
connectionClose bool
noDefaultContentType bool
Expand Down Expand Up @@ -221,11 +229,15 @@ func (h *ResponseHeader) GetHeaders() []argsKV {
func (h *ResponseHeader) Reset() {
h.disableNormalizing = false
h.Trailer().disableNormalizing = false
h.noDefaultContentType = false
h.noDefaultDate = false
h.ResetSkipNormalize()
}

// CopyToAndMark copies all the headers to dst and mark the dst header as a copy.
func (h *ResponseHeader) CopyToAndMark(dst *ResponseHeader) {
dst.isCopy = true
h.CopyTo(dst)
}

// CopyTo copies all the headers to dst.
func (h *ResponseHeader) CopyTo(dst *ResponseHeader) {
dst.Reset()
Expand Down Expand Up @@ -443,6 +455,9 @@ func (h *RequestHeader) AppendBytes(dst []byte) []byte {
//
// The returned representation is valid until the next call to RequestHeader methods.
func (h *RequestHeader) Header() []byte {
if h.isCopy {
return h.AppendBytes(nil)
}
h.bufKV.value = h.AppendBytes(h.bufKV.value[:0])
return h.bufKV.value
}
Expand Down Expand Up @@ -506,6 +521,10 @@ func checkWriteHeaderCode(code int) {

func (h *ResponseHeader) ResetSkipNormalize() {
h.protocol = ""

h.isCopy = false
h.noDefaultContentType = false
h.noDefaultDate = false
h.connectionClose = false

h.statusCode = 0
Expand Down Expand Up @@ -631,6 +650,9 @@ func (h *ResponseHeader) DelBytes(key []byte) {
//
// The returned value is valid until the next call to ResponseHeader methods.
func (h *ResponseHeader) Header() []byte {
if h.isCopy {
return h.AppendBytes(nil)
}
h.bufKV.value = h.AppendBytes(h.bufKV.value[:0])
return h.bufKV.value
}
Expand Down Expand Up @@ -684,6 +706,9 @@ func (h *ResponseHeader) DelClientCookieBytes(key []byte) {
// Returned value is valid until the next call to ResponseHeader.
// Do not store references to returned value. Make copies instead.
func (h *ResponseHeader) Peek(key string) []byte {
if h.isCopy {
return h.peek(getHeaderKeyBytes(&argsKV{}, key, h.disableNormalizing))
}
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peek(k)
}
Expand Down Expand Up @@ -730,6 +755,9 @@ func (h *ResponseHeader) peek(key []byte) []byte {
// Any future calls to the Peek* will modify the returned value.
// Do not store references to returned value. Use ResponseHeader.GetAll(key) instead.
func (h *ResponseHeader) PeekAll(key string) [][]byte {
if h.isCopy {
return h.peekAll(getHeaderKeyBytes(&argsKV{}, key, h.disableNormalizing))
}
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peekAll(k)
}
Expand Down Expand Up @@ -772,6 +800,9 @@ func (h *ResponseHeader) peekAll(key []byte) [][]byte {
// Any future calls to the Peek* will modify the returned value.
// Do not store references to returned value. Use RequestHeader.GetAll(key) instead.
func (h *RequestHeader) PeekAll(key string) [][]byte {
if h.isCopy {
return h.peekAll(getHeaderKeyBytes(&argsKV{}, key, h.disableNormalizing))
}
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peekAll(k)
}
Expand Down Expand Up @@ -1089,6 +1120,12 @@ func (h *RequestHeader) del(key []byte) {
h.h = delAllArgsBytes(h.h, key)
}

// CopyToAndMark copies all the headers to dst and mark the dst header as a copy.
func (h *RequestHeader) CopyToAndMark(dst *RequestHeader) {
dst.isCopy = true
h.CopyTo(dst)
}

// CopyTo copies all the headers to dst.
func (h *RequestHeader) CopyTo(dst *RequestHeader) {
dst.Reset()
Expand Down Expand Up @@ -1117,6 +1154,9 @@ func (h *RequestHeader) CopyTo(dst *RequestHeader) {
// Returned value is valid until the next call to RequestHeader.
// Do not store references to returned value. Make copies instead.
func (h *RequestHeader) Peek(key string) []byte {
if h.isCopy {
return h.peek(getHeaderKeyBytes(&argsKV{}, key, h.disableNormalizing))
}
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
return h.peek(k)
}
Expand Down Expand Up @@ -1414,6 +1454,7 @@ func (h *RequestHeader) SetCanonical(key, value []byte) {
}

func (h *RequestHeader) ResetSkipNormalize() {
h.isCopy = false
h.connectionClose = false
h.protocol = ""
h.noDefaultContentType = false
Expand Down Expand Up @@ -1530,8 +1571,12 @@ func (h *RequestHeader) VisitAll(f func(key, value []byte)) {

h.collectCookies()
if len(h.cookies) > 0 {
h.bufKV.value = appendRequestCookieBytes(h.bufKV.value[:0], h.cookies)
f(bytestr.StrCookie, h.bufKV.value)
if h.isCopy {
f(bytestr.StrCookie, appendRequestCookieBytes(nil, h.cookies))
} else {
h.bufKV.value = appendRequestCookieBytes(h.bufKV.value[:0], h.cookies)
f(bytestr.StrCookie, h.bufKV.value)
}
}
visitArgs(h.h, f)
if h.ConnectionClose() {
Expand Down
43 changes: 1 addition & 42 deletions pkg/protocol/http1/req/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,48 +441,7 @@ func TestTryRead(t *testing.T) {
s := "P"
zr := mock.NewZeroCopyReader(s)
err := tryRead(&rh, zr, 0)
assert.NotNil(t, err)
}

func TestParseFirstLine(t *testing.T) {
tests := []struct {
input []byte
method string
uri string
protocol string
err error
}{
// Test case 1: n < 0
{
input: []byte("GET /path/to/resource HTTP/1.0\r\n"),
method: "GET",
uri: "/path/to/resource",
protocol: "HTTP/1.0",
err: nil,
},
// Test case 2: n == 0
{
input: []byte(" /path/to/resource HTTP/1.1\r\n"),
method: "",
uri: "",
protocol: "",
err: fmt.Errorf("requestURI cannot be empty in"),
},
// Test case 3: !bytes.Equal(b[n+1:], bytestr.StrHTTP11)
{
input: []byte("POST /path/to/resource HTTP/1.2\r\n"),
method: "POST",
uri: "/path/to/resource",
protocol: "HTTP/1.0",
err: nil,
},
}

for _, tc := range tests {
header := &protocol.RequestHeader{}
_, err := parseFirstLine(header, tc.input)
assert.NotNil(t, err)
}
assert.Nil(t, err)
}

func TestParse(t *testing.T) {
Expand Down
Loading
Loading