From d5d81351774423b2477dad4062bf7649ccbff429 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Thu, 22 Mar 2018 15:38:26 -0500 Subject: [PATCH] deglobalize variables; protect stats map with sync.Mutex --- handler.go | 10 ++++++---- main.go | 32 +++++++++++++++++++------------- matcher.go | 15 ++++++++------- stats.go | 21 ++++++++++++--------- 4 files changed, 45 insertions(+), 33 deletions(-) diff --git a/handler.go b/handler.go index e95eb08..d57fa86 100644 --- a/handler.go +++ b/handler.go @@ -7,15 +7,18 @@ import ( "log" "net" "net/http" + "sync" "time" ) type myTransport struct { + matcher + stats map[string]MonitoringPath + statsMu sync.RWMutex } type ModifiedRequest struct { Path string - Method string RemoteAddr string } @@ -48,7 +51,6 @@ func parseRequest(r *http.Request) ModifiedRequest { } return ModifiedRequest{ Path: path, - Method: r.Method, RemoteAddr: ip, } } @@ -58,7 +60,7 @@ func (t *myTransport) RoundTrip(request *http.Request) (*http.Response, error) { start := time.Now() parsedRequest := parseRequest(request) - if !MatchAnyRule(parsedRequest) { + if !t.MatchAnyRule(parsedRequest) { log.Println("Not allowed:", parsedRequest.Path, " from IP: ", parsedRequest.RemoteAddr) return &http.Response{ Body: ioutil.NopCloser(bytes.NewBufferString("You are not authorized to make this request")), @@ -83,7 +85,7 @@ func (t *myTransport) RoundTrip(request *http.Request) (*http.Response, error) { }, err } elapsed := time.Since(start) - updateStats(parsedRequest, elapsed) + t.updateStats(parsedRequest, elapsed) log.Println("Response Time:", elapsed.Seconds(), " path: ", parsedRequest.Path, " from IP: ", parsedRequest.RemoteAddr) return response, err diff --git a/main.go b/main.go index 7071a6e..1edaef7 100644 --- a/main.go +++ b/main.go @@ -12,27 +12,38 @@ import ( type Prox struct { target *url.URL proxy *httputil.ReverseProxy + myTransport } -func NewProxy(target string) *Prox { +func NewProxy(target string, m matcher) *Prox { url, _ := url.Parse(target) - return &Prox{target: url, proxy: httputil.NewSingleHostReverseProxy(url)} + p := &Prox{target: url, proxy: httputil.NewSingleHostReverseProxy(url)} + p.stats = make(map[string]MonitoringPath) + p.matcher = m + p.proxy.Transport = &p.myTransport + return p } func (p *Prox) handle(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-rpc-proxy", "rpc-proxy") - p.proxy.Transport = &myTransport{} - p.proxy.ServeHTTP(w, r) +} +func (p *Prox) ServerStatus(w http.ResponseWriter, r *http.Request) { + stats, err := p.getStats() + if err != nil { + http.Error(w, "failed to get stats", http.StatusInternalServerError) + log.Println("Failed to get server stats:", err) + } else { + w.Write(stats) + } } var port *string var redirecturl *string var allowedPathes *string var requestsPerMinuteLimit *int -var globalMap = make(map[string]MonitoringPath) func main() { const ( @@ -60,21 +71,16 @@ func main() { log.Println("requests from IP per minute limited to :", *requestsPerMinuteLimit) // filling matcher rules - err := AddMatcherRules(strings.Split(*allowedPathes, ",")) + rules, err := newMatcher(strings.Split(*allowedPathes, ",")) if err != nil { log.Println("Cannot parse list of allowed pathes", err) } // proxy - proxy := NewProxy(*redirecturl) + proxy := NewProxy(*redirecturl, rules) - http.HandleFunc("/rpc-proxy-server-status", ServerStatus) + http.HandleFunc("/rpc-proxy-server-status", proxy.ServerStatus) // server redirection http.HandleFunc("/", proxy.handle) log.Fatal(http.ListenAndServe(":"+*port, nil)) } - -func ServerStatus(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(getStats())) - return -} diff --git a/matcher.go b/matcher.go index e1803cb..ab38b62 100644 --- a/matcher.go +++ b/matcher.go @@ -4,14 +4,14 @@ import ( "regexp" ) -var paths []*regexp.Regexp +type matcher []*regexp.Regexp -func MatchAnyRule(request ModifiedRequest) bool { +func (m matcher) MatchAnyRule(request ModifiedRequest) bool { if request.Path == "" { return false } - for _, matcher := range paths { + for _, matcher := range m { if matcher.MatchString(request.Path) { return true } @@ -19,13 +19,14 @@ func MatchAnyRule(request ModifiedRequest) bool { return false } -func AddMatcherRules(rules []string) error { +func newMatcher(rules []string) (matcher, error) { + var m matcher for _, p := range rules { compiled, err := regexp.Compile(p) if err != nil { - return err + return nil, err } - paths = append(paths, compiled) + m = append(m, compiled) } - return nil + return m, nil } diff --git a/stats.go b/stats.go index a436b17..17e6fa9 100644 --- a/stats.go +++ b/stats.go @@ -2,7 +2,6 @@ package main import ( "encoding/json" - "log" "time" ) @@ -13,27 +12,31 @@ type MonitoringPath struct { AverageTime float64 } -func updateStats(parsedRequest ModifiedRequest, elapsed time.Duration) { +func (t *myTransport) updateStats(parsedRequest ModifiedRequest, elapsed time.Duration) { key := parsedRequest.RemoteAddr + "-" + parsedRequest.Path - if val, ok := globalMap[key]; ok { + t.statsMu.Lock() + defer t.statsMu.Unlock() + if val, ok := t.stats[key]; ok { val.Count = val.Count + 1 val.TotalDuration += elapsed.Seconds() val.AverageTime = val.TotalDuration / val.Count - globalMap[key] = val + t.stats[key] = val } else { var m MonitoringPath m.Path = parsedRequest.Path m.Count = 1 m.TotalDuration = elapsed.Seconds() m.AverageTime = m.TotalDuration / m.Count - globalMap[key] = m + t.stats[key] = m } } -func getStats() string { - b, err := json.MarshalIndent(globalMap, "", " ") +func (t *myTransport) getStats() ([]byte, error) { + t.statsMu.RLock() + defer t.statsMu.RUnlock() + b, err := json.MarshalIndent(t.stats, "", " ") if err != nil { - log.Println("error:", err) + return nil, err } - return string(b) + return b, nil }