/
request.go
78 lines (65 loc) · 1.95 KB
/
request.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package main
import (
"bufio"
"bytes"
"io"
"io/ioutil"
"net/http"
"github.com/mailgun/multibuf"
"github.com/pkg/errors"
"github.com/vulcand/oxy/utils"
)
// unlimiedSize is the sentinel passed as maxSize to NewRequest to disable the
// request size check (the check only applies when maxSize is greater than it).
const unlimiedSize = -1
// Request pairs a copy of an HTTP request with a multibuf reader that holds
// the request body.
type Request struct {
// httpRequest is a copy of the source request; copyRequest points its Body
// at buffer.
httpRequest http.Request
// buffer is the multibuf reader created from the source request body in
// NewRequest; it is released by Close.
buffer multibuf.MultiReader
}
// NewRequest builds a Request from inRequest, buffering its body with multibuf.
//
// memoryBufferSize is passed to multibuf.MemBytes. maxSize is the largest
// accepted body size in bytes; a value of unlimiedSize (-1) or lower disables
// the limit.
//
// The limit is enforced twice: first against the declared ContentLength, so an
// oversized request is rejected before any body byte is read, and then by
// multibuf while the body is copied, so a body of unknown length (for example
// chunked, ContentLength == -1) cannot bypass it.
//
// It returns an error when the request exceeds maxSize or when the body cannot
// be buffered.
func NewRequest(inRequest *http.Request, memoryBufferSize int64, maxSize int64) (*Request, error) {
	if maxSize > unlimiedSize && inRequest.ContentLength > maxSize {
		return nil, errors.Errorf("request exceeded size limit (%d > %d)",
			inRequest.ContentLength, maxSize)
	}

	// http.Request.Body may be nil for requests that were not produced by the
	// HTTP server; treat that as an empty body instead of handing multibuf a
	// nil reader.
	var source io.Reader = bytes.NewReader(nil)
	if inRequest.Body != nil {
		source = inRequest.Body
	}

	var (
		body multibuf.MultiReader
		err  error
	)
	if maxSize > 0 {
		// Only a positive limit is forwarded; multibuf's handling of a zero
		// limit is not relied on here.
		body, err = multibuf.New(source,
			multibuf.MemBytes(memoryBufferSize), multibuf.MaxBytes(maxSize))
	} else {
		body, err = multibuf.New(source, multibuf.MemBytes(memoryBufferSize))
	}
	if err != nil {
		return nil, errors.Wrap(err, "cannot copy request body")
	}

	request := &Request{buffer: body}
	request.copyRequest(inRequest)
	return request, nil
}
// LoadRequest parses one HTTP request from reader with http.ReadRequest and
// wraps it via NewRequest. No size limit is applied to loaded requests.
//
// A parse failure is returned wrapped with the message "cannot load request".
func LoadRequest(reader *bufio.Reader, memoryBufferSize int64) (*Request, error) {
	httpReq, readErr := http.ReadRequest(reader)
	if readErr != nil {
		return nil, errors.Wrap(readErr, "cannot load request")
	}

	// unlimiedSize switches off the size check in NewRequest.
	return NewRequest(httpReq, memoryBufferSize, unlimiedSize)
}
// Close closes the underlying body buffer.
func (r *Request) Close() {
	// Close has no error return, so the buffer's close error is discarded.
	_ = r.buffer.Close()
}
// Save serializes the request and writes it to file with a single Write call.
//
// The request is first staged in memory: httpRequest is written with
// http.Request.Write, then whatever r.buffer.WriteTo yields is appended. The
// staged bytes are handed to file in one call.
//
// It returns the number of bytes written to file. If file accepts fewer bytes
// than were staged without reporting an error, io.ErrShortWrite is returned so
// a truncated record is never reported as success.
func (r *Request) Save(file io.Writer) (int, error) {
	// NOTE(review): the whole request, body included, is held in memory here
	// before the final write.
	var staged bytes.Buffer
	if err := r.httpRequest.Write(&staged); err != nil {
		return 0, err
	}
	// TODO: verify that the entire buffer was written successfully.
	if _, err := r.buffer.WriteTo(&staged); err != nil {
		return 0, err
	}

	written, err := file.Write(staged.Bytes())
	if err == nil && written < staged.Len() {
		err = io.ErrShortWrite
	}
	return written, err
}
// copyRequest fills r.httpRequest from req, using r.buffer as the body source
// of the copy.
func (r *Request) copyRequest(req *http.Request) {
	target, bodySource := &r.httpRequest, r.buffer
	copyRequest(target, req, bodySource)
}
// Helpers
// copyRequest overwrites dstRequest with a copy of srcRequest. The URL and
// headers are duplicated through the oxy utils helpers, TransferEncoding is
// reset to an empty slice, and Body is replaced by buffer wrapped in a no-op
// closer.
func copyRequest(dstRequest, srcRequest *http.Request, buffer io.Reader) {
	// Shallow copy of every field; this already carries ContentLength over.
	*dstRequest = *srcRequest

	// Replace the reference-typed fields so the copy does not share them with
	// the source.
	dstRequest.URL = utils.CopyURL(srcRequest.URL)
	dstRequest.Header = http.Header{}
	utils.CopyHeaders(dstRequest.Header, srcRequest.Header)

	dstRequest.TransferEncoding = make([]string, 0)
	dstRequest.Body = ioutil.NopCloser(buffer)
}