forked from cloudflare/cfssl
/
lookup.go
144 lines (123 loc) · 3.55 KB
/
lookup.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package whitelist
import (
"errors"
"log"
"net"
"net/http"
)
// NetConnLookup extracts the peer IP from the remote address of a
// net.Conn. A single net.Conn should be passed to Address.
//
// The remote address is expected in "host:port" form. An IPv6 zone
// (the "%eth0" in "fe80::1%eth0") is dropped before parsing, because
// net.ParseIP does not accept zoned literals and a net.IP cannot carry
// a zone anyway.
//
// An error is returned when conn is nil, when it reports no remote
// address, when the address cannot be split into host and port, or when
// the host part is not an IP literal. On success the returned IP is
// never nil.
func NetConnLookup(conn net.Conn) (net.IP, error) {
	if conn == nil {
		return nil, errors.New("whitelist: no connection")
	}

	netAddr := conn.RemoteAddr()
	if netAddr == nil {
		return nil, errors.New("whitelist: no address returned")
	}

	host, _, err := net.SplitHostPort(netAddr.String())
	if err != nil {
		return nil, err
	}

	// Link-local IPv6 peers are reported with a zone suffix that
	// net.ParseIP cannot parse; trim it so those peers can be matched.
	for i := 0; i < len(host); i++ {
		if host[i] == '%' {
			host = host[:i]
			break
		}
	}

	ip := net.ParseIP(host)
	if ip == nil {
		// Report the failure instead of handing back a nil IP together
		// with a nil error, which callers could not tell from success.
		return nil, errors.New("whitelist: remote address is not an IP address")
	}
	return ip, nil
}
// HTTPRequestLookup extracts the client IP from the RemoteAddr field of
// an *http.Request. A single *http.Request should be passed to Address.
//
// RemoteAddr is expected in "host:port" form. An IPv6 zone (the "%eth0"
// in "fe80::1%eth0") is dropped before parsing, because net.ParseIP
// does not accept zoned literals and a net.IP cannot carry a zone
// anyway.
//
// An error is returned when req is nil, when RemoteAddr cannot be split
// into host and port, or when the host part is not an IP literal. On
// success the returned IP is never nil.
func HTTPRequestLookup(req *http.Request) (net.IP, error) {
	if req == nil {
		return nil, errors.New("whitelist: no request")
	}

	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return nil, err
	}

	// Link-local IPv6 clients are reported with a zone suffix that
	// net.ParseIP cannot parse; trim it so those clients can be matched.
	for i := 0; i < len(host); i++ {
		if host[i] == '%' {
			host = host[:i]
			break
		}
	}

	ip := net.ParseIP(host)
	if ip == nil {
		// Report the failure instead of handing back a nil IP together
		// with a nil error, which callers could not tell from success.
		return nil, errors.New("whitelist: remote address is not an IP address")
	}
	return ip, nil
}
// Handler wraps an HTTP handler with IP whitelisting. Its ServeHTTP
// method looks up the remote IP of each request and dispatches to one
// of two wrapped handlers depending on whether the ACL permits that IP.
type Handler struct {
// allowHandler serves requests whose remote IP the ACL permits.
allowHandler http.Handler
// denyHandler serves requests that are not permitted. When it is nil,
// ServeHTTP writes a 401 Unauthorized response instead.
denyHandler http.Handler
// whitelist is the ACL consulted for every request.
whitelist ACL
}
// NewHandler builds an http.Handler that gates requests on acl. The
// allow handler is called when the request is whitelisted; the deny
// handler is called if the request is not whitelisted.
//
// A nil allow handler or a nil ACL is rejected with an error. deny is
// stored as given and may be nil, in which case Handler.ServeHTTP
// answers non-whitelisted requests with 401 Unauthorized.
func NewHandler(allow, deny http.Handler, acl ACL) (http.Handler, error) {
	switch {
	case allow == nil:
		return nil, errors.New("whitelist: allow cannot be nil")
	case acl == nil:
		return nil, errors.New("whitelist: ACL cannot be nil")
	}

	wrapped := &Handler{
		whitelist:    acl,
		denyHandler:  deny,
		allowHandler: allow,
	}
	return wrapped, nil
}
// ServeHTTP resolves the remote IP of req and routes the request
// according to the whitelist:
//
//   - if the lookup fails, the error is logged and a 500 response is sent;
//   - if the IP is permitted, the allow handler serves the request;
//   - otherwise the deny handler serves it, or a 401 response is sent
//     when no deny handler was configured.
func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	remote, lookupErr := HTTPRequestLookup(req)
	if lookupErr != nil {
		log.Printf("failed to lookup request address: %v", lookupErr)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	switch {
	case h.whitelist.Permitted(remote):
		h.allowHandler.ServeHTTP(w, req)
	case h.denyHandler != nil:
		h.denyHandler.ServeHTTP(w, req)
	default:
		http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
	}
}
// A HandlerFunc contains a pair of http.HandleFunc-handler functions
// that will be called depending on whether a request is allowed or
// denied. Its ServeHTTP method looks up the remote IP of each request
// and asks the ACL whether that IP is permitted.
type HandlerFunc struct {
// allow is called for requests whose remote IP the ACL permits.
allow func(http.ResponseWriter, *http.Request)
// deny is called for requests that are not permitted. When it is nil,
// ServeHTTP writes a 401 Unauthorized response instead.
deny func(http.ResponseWriter, *http.Request)
// whitelist is the ACL consulted for every request.
whitelist ACL
}
// NewHandlerFunc returns a new basic whitelisting handler built from
// plain handler functions. allow is called for whitelisted requests and
// deny for all others.
//
// A nil allow function or a nil ACL is rejected with an error. deny is
// stored as given and may be nil, in which case HandlerFunc.ServeHTTP
// answers non-whitelisted requests with 401 Unauthorized.
func NewHandlerFunc(allow, deny func(http.ResponseWriter, *http.Request), acl ACL) (*HandlerFunc, error) {
	switch {
	case allow == nil:
		return nil, errors.New("whitelist: allow cannot be nil")
	case acl == nil:
		return nil, errors.New("whitelist: ACL cannot be nil")
	}

	hf := new(HandlerFunc)
	hf.allow, hf.deny, hf.whitelist = allow, deny, acl
	return hf, nil
}
// ServeHTTP looks up the remote IP of the incoming request, asks the
// whitelist whether it is permitted, and calls the matching function.
//
// A failed lookup is logged and answered with 500 Internal Server
// Error. A request that is not permitted goes to the deny function, or
// is answered with 401 Unauthorized when no deny function was set.
func (h *HandlerFunc) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	clientIP, err := HTTPRequestLookup(req)
	if err != nil {
		log.Printf("failed to lookup request address: %v", err)
		const code = http.StatusInternalServerError
		http.Error(w, http.StatusText(code), code)
		return
	}

	if h.whitelist.Permitted(clientIP) {
		h.allow(w, req)
		return
	}

	if h.deny != nil {
		h.deny(w, req)
		return
	}

	const code = http.StatusUnauthorized
	http.Error(w, http.StatusText(code), code)
}