/
serve.go
107 lines (94 loc) · 2.61 KB
/
serve.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package httpu
import (
"bufio"
"bytes"
"context"
"fmt"
"log"
"net/http"
"regexp"
"sync"
"golang.org/x/net/ipv4"
"golang.org/x/sync/errgroup"
)
// DefaultMaxMessageBytes is the size, in bytes, of the buffer each incoming
// datagram is read into when a server does not configure its own limit.
const DefaultMaxMessageBytes = 2048

var (
	// trailingWhitespaceRx matches a run of spaces sitting directly before a
	// CRLF line terminator.
	trailingWhitespaceRx = regexp.MustCompile(" +\r\n")

	// crlf is the bare line terminator substituted for each match of
	// trailingWhitespaceRx.
	crlf = []byte("\r\n")
)
// Handler is the interface by which received SSDP messages are passed to
// handling code.
type Handler interface {
	// ServeMessage is called for each SSDP message received. When invoked by
	// this package's Serve, r.RemoteAddr holds the address the message was
	// received from. Each returned response is encoded and sent back to that
	// address. If a non-nil error is returned, it is logged and no responses
	// are sent.
	ServeMessage(r *http.Request) ([]*http.Response, error)
}
// HandlerFunc lets an ordinary function act as a Handler: for any function fn
// with the matching signature, HandlerFunc(fn) is a Handler whose
// ServeMessage method calls fn.
type HandlerFunc func(r *http.Request) ([]*http.Response, error)

// ServeMessage satisfies Handler by forwarding r to the wrapped function and
// returning its results unchanged.
func (fn HandlerFunc) ServeMessage(r *http.Request) ([]*http.Response, error) {
	responses, err := fn(r)
	return responses, err
}
// server holds the settings used to serve SSDP messages read from a packet
// connection.
type server struct {
	// Handler receives every request that was parsed successfully.
	Handler Handler

	// MaxMessageBytes is the size of the buffer each datagram is read into;
	// zero selects DefaultMaxMessageBytes.
	MaxMessageBytes int
}
// Serve dispatches every SSDP message read from conn to handler, using read
// buffers of DefaultMaxMessageBytes bytes. It blocks until the underlying
// server loop stops and returns that loop's error.
func Serve(ctx context.Context, conn *ipv4.PacketConn, handler Handler) error {
	return (&server{
		Handler:         handler,
		MaxMessageBytes: DefaultMaxMessageBytes,
	}).Serve(ctx, conn)
}
// Serve reads datagrams from conn and handles each one in its own goroutine.
// Each datagram is parsed as an HTTP request, after stripping spaces that
// precede a CRLF. The request is passed to srv.Handler with RemoteAddr set to
// the sender's address. Every response the handler returns is encoded with
// WriteResponse and sent back to the sender.
//
// Per-message failures are logged and do not stop the loop. These include an
// unparseable request, a handler error, and an encode or write failure.
// Serve returns the error from conn.ReadFrom, or ctx.Err() if ctx is found to
// be cancelled before a read. Before returning, it waits for all in-flight
// message goroutines to finish.
//
// NOTE(review): a ReadFrom that is already blocked is not interrupted by ctx.
// To stop a Serve that is waiting for traffic, callers must close conn or set
// a read deadline on it.
func (srv *server) Serve(ctx context.Context, conn *ipv4.PacketConn) error {
	maxMessageBytes := DefaultMaxMessageBytes
	// A negative size would make the pool's make() panic; only a positive
	// value overrides the default.
	if srv.MaxMessageBytes > 0 {
		maxMessageBytes = srv.MaxMessageBytes
	}
	bufPool := &sync.Pool{
		New: func() interface{} {
			return make([]byte, maxMessageBytes)
		},
	}

	// Message goroutines log their own failures and always return nil. The
	// group exists only so that Serve waits for them before returning.
	var tasks errgroup.Group
	defer tasks.Wait()

	for {
		if err := ctx.Err(); err != nil {
			return err
		}
		buf := bufPool.Get().([]byte)
		n, _, peerAddr, err := conn.ReadFrom(buf)
		if err != nil {
			bufPool.Put(buf)
			return err
		}
		tasks.Go(func() error {
			defer bufPool.Put(buf)
			// At least one router's UPnP implementation has added a trailing space
			// after "HTTP/1.1" - trim it. ReplaceAllLiteral returns a copy, so the
			// parsed request never aliases buf once buf is back in the pool.
			reqBuf := trailingWhitespaceRx.ReplaceAllLiteral(buf[:n], crlf)
			req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqBuf)))
			if err != nil {
				log.Printf("httpu: Failed to parse request: %v", err)
				return nil
			}
			req.RemoteAddr = peerAddr.String()
			responses, err := srv.Handler.ServeMessage(req)
			// No need to call req.Body.Close - the underlying reader is in-memory.
			if err != nil {
				log.Printf("httpu: Failed to handle request: %v", err)
				return nil
			}
			var wr bytes.Buffer
			for _, resp := range responses {
				wr.Reset()
				if err := WriteResponse(&wr, resp); err != nil {
					fmt.Printf("Error while encoding response: %v\n", err)
					// Never send a partially encoded response; try the next one.
					continue
				}
				if _, err := conn.WriteTo(wr.Bytes(), nil, peerAddr); err != nil {
					fmt.Printf("Error writing response: %v\n", err)
					return nil
				}
			}
			return nil
		})
	}
}