Skip to content

Commit

Permalink
route: add Headers method for matching request headers (#133)
Browse files Browse the repository at this point in the history
Co-authored-by: E99p1ant <i@github.red>
  • Loading branch information
unknwon and wuhan005 committed Jun 11, 2022
1 parent 7e355eb commit 92b5d14
Show file tree
Hide file tree
Showing 9 changed files with 510 additions and 91 deletions.
4 changes: 4 additions & 0 deletions codecov.yml
Expand Up @@ -5,6 +5,10 @@ coverage:
default:
threshold: 1%
informational: true
patch:
default:
only_pulls: true
informational: true

comment:
layout: 'diff'
Expand Down
37 changes: 37 additions & 0 deletions internal/route/header_matcher.go
@@ -0,0 +1,37 @@
// Copyright 2022 Flamego. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package route

import (
"net/http"
"regexp"
)

// HeaderMatcher stores matchers for request headers.
type HeaderMatcher struct {
matches map[string]*regexp.Regexp // Key is the header name
}

// NewHeaderMatcher creates a new HeaderMatcher using given matches, where keys
// are header names.
func NewHeaderMatcher(matches map[string]*regexp.Regexp) *HeaderMatcher {
return &HeaderMatcher{
matches: matches,
}
}

// Match returns true if all matches are successfully in the given header.
func (m *HeaderMatcher) Match(header http.Header) bool {
for name, re := range m.matches {
v := header.Get(name)
if v == "" {
return false
}
if !re.MatchString(v) {
return false
}
}
return true
}
81 changes: 81 additions & 0 deletions internal/route/header_matcher_test.go
@@ -0,0 +1,81 @@
// Copyright 2022 Flamego. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package route

import (
"net/http"
"regexp"
"testing"

"github.com/stretchr/testify/assert"
)

func TestHeaderMatcher(t *testing.T) {
header := make(http.Header)
header.Set("Server", "Caddy")
header.Set("Status", "200 OK")

tests := []struct {
name string
matches map[string]*regexp.Regexp
want bool
}{
{
name: "loose matches",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile("Caddy"),
"Status": regexp.MustCompile("200"),
},
want: true,
},
{
name: "loose matches",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile("Caddy"),
"Status": regexp.MustCompile("404"),
},
want: false,
},

{
name: "exact matches",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile("^Caddy$"),
"Status": regexp.MustCompile("^200 OK$"),
},
want: true,
},
{
name: "exact matches",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile("^Caddy$"),
"Status": regexp.MustCompile("^200$"),
},
want: false,
},

{
name: "presence match",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile(""),
},
want: true,
},
{
name: "presence match",
matches: map[string]*regexp.Regexp{
"Server": regexp.MustCompile(""),
"Cache-Control": regexp.MustCompile(""),
},
want: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := NewHeaderMatcher(test.matches).Match(header)
assert.Equal(t, test.want, got)
})
}
}
48 changes: 37 additions & 11 deletions internal/route/leaf.go
Expand Up @@ -6,6 +6,7 @@ package route

import (
"bytes"
"net/http"
"regexp"
"strconv"
"strings"
Expand All @@ -27,6 +28,9 @@ const (

// Leaf is a leaf derived from a segment.
type Leaf interface {
// SetHeaderMatcher sets the HeaderMatcher for the leaf.
SetHeaderMatcher(m *HeaderMatcher)

// URLPath fills in bind parameters with given values to build the "path"
// portion of the URL. If `withOptional` is true, the path will include the
// current leaf when it is optional; otherwise, the current leaf is excluded.
Expand All @@ -46,15 +50,16 @@ type Leaf interface {
getMatchStyle() MatchStyle
// match returns true if the leaf matches the segment, values of bind parameters
// are stored in the `Params`.
match(segment string, params Params) bool
match(segment string, params Params, header http.Header) bool
}

// baseLeaf contains common fields for any leaf.
type baseLeaf struct {
parent Tree // The parent tree this leaf belongs to.
route *Route // The route that the segment belongs to.
segment *Segment // The segment that the leaf is derived from.
handler Handler // The handler bound to the leaf.
parent Tree // The parent tree this leaf belongs to.
route *Route // The route that the segment belongs to.
segment *Segment // The segment that the leaf is derived from.
handler Handler // The handler bound to the leaf.
headerMatcher *HeaderMatcher // The matcher for header values.
}

func (l *baseLeaf) getParent() Tree {
Expand All @@ -65,6 +70,14 @@ func (l *baseLeaf) getSegment() *Segment {
return l.segment
}

func (l *baseLeaf) SetHeaderMatcher(m *HeaderMatcher) {
l.headerMatcher = m
}

func (l *baseLeaf) matchHeader(header http.Header) bool {
return l.headerMatcher == nil || l.headerMatcher.Match(header)
}

func (l *baseLeaf) URLPath(vals map[string]string, withOptional bool) string {
var buf bytes.Buffer
for _, s := range l.route.Segments {
Expand Down Expand Up @@ -123,8 +136,8 @@ func (*staticLeaf) getMatchStyle() MatchStyle {
return matchStyleStatic
}

func (l *staticLeaf) match(segment string, _ Params) bool {
return l.literals == segment
func (l *staticLeaf) match(segment string, _ Params, header http.Header) bool {
return l.literals == segment && l.matchHeader(header)
}

func (l *staticLeaf) Static() bool {
Expand All @@ -149,12 +162,16 @@ func (*regexLeaf) getMatchStyle() MatchStyle {
return matchStyleRegex
}

func (l *regexLeaf) match(segment string, params Params) bool {
func (l *regexLeaf) match(segment string, params Params, header http.Header) bool {
submatches := l.regexp.FindStringSubmatch(segment)
if len(submatches) < len(l.binds)+1 {
return false
}

if !l.matchHeader(header) {
return false
}

for i, bind := range l.binds {
params[bind] = submatches[i+1]
}
Expand All @@ -171,7 +188,10 @@ func (*placeholderLeaf) getMatchStyle() MatchStyle {
return matchStylePlaceholder
}

func (l *placeholderLeaf) match(segment string, params Params) bool {
func (l *placeholderLeaf) match(segment string, params Params, header http.Header) bool {
if !l.matchHeader(header) {
return false
}
params[l.bind] = segment
return true
}
Expand All @@ -187,7 +207,10 @@ func (*matchAllLeaf) getMatchStyle() MatchStyle {
return matchStyleAll
}

func (l *matchAllLeaf) match(segment string, params Params) bool {
func (l *matchAllLeaf) match(segment string, params Params, header http.Header) bool {
if !l.matchHeader(header) {
return false
}
params[l.bind] = segment
return true
}
Expand All @@ -196,13 +219,16 @@ func (l *matchAllLeaf) match(segment string, params Params) bool {
// defined). The `path` should be original request path, `segment` should NOT be
// unescaped by the caller. It returns true if segments are captured within the
// limit, and the capture result is stored in `params`.
func (l *matchAllLeaf) matchAll(path, segment string, next int, params Params) bool {
func (l *matchAllLeaf) matchAll(path, segment string, next int, params Params, header http.Header) bool {
// Do `next-1` because "next" starts at the next character of preceding "/"; do
// `strings.Count()+1` because the segment itself also counts. E.g. "webapi" +
// "users/events" => 3
if l.capture > 0 && l.capture < strings.Count(path[next-1:], "/")+1 {
return false
}
if !l.matchHeader(header) {
return false
}

params[l.bind] = segment + "/" + path[next:]
return true
Expand Down
19 changes: 10 additions & 9 deletions internal/route/leaf_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewLeaf(t *testing.T) {
Expand All @@ -21,7 +22,7 @@ func TestNewLeaf(t *testing.T) {
})

parser, err := NewParser()
assert.Nil(t, err)
require.NoError(t, err)

tests := []struct {
route string
Expand Down Expand Up @@ -60,12 +61,12 @@ func TestNewLeaf(t *testing.T) {
for _, test := range tests {
t.Run(test.route, func(t *testing.T) {
route, err := parser.Parse(test.route)
assert.Nil(t, err)
require.NoError(t, err)
assert.Len(t, route.Segments, 1)

segment := route.Segments[0]
got, err := newLeaf(nil, route, segment, nil)
assert.Nil(t, err)
require.NoError(t, err)

switch test.style {
case matchStyleStatic:
Expand All @@ -86,7 +87,7 @@ func TestNewLeaf(t *testing.T) {

func TestNewLeaf_Regex(t *testing.T) {
parser, err := NewParser()
assert.Nil(t, err)
require.NoError(t, err)

tests := []struct {
route string
Expand Down Expand Up @@ -122,12 +123,12 @@ func TestNewLeaf_Regex(t *testing.T) {
for _, test := range tests {
t.Run(test.route, func(t *testing.T) {
route, err := parser.Parse(test.route)
assert.Nil(t, err)
require.NoError(t, err)
assert.Len(t, route.Segments, 1)

segment := route.Segments[0]
got, err := newLeaf(nil, route, segment, nil)
assert.Nil(t, err)
require.NoError(t, err)

leaf := got.(*regexLeaf)
assert.Equal(t, test.wantRegexp, leaf.regexp.String())
Expand All @@ -138,7 +139,7 @@ func TestNewLeaf_Regex(t *testing.T) {

func TestLeaf_URLPath(t *testing.T) {
parser, err := NewParser()
assert.Nil(t, err)
require.NoError(t, err)

tests := []struct {
route string
Expand Down Expand Up @@ -245,11 +246,11 @@ func TestLeaf_URLPath(t *testing.T) {
for _, test := range tests {
t.Run(test.route, func(t *testing.T) {
route, err := parser.Parse(test.route)
assert.Nil(t, err)
require.NoError(t, err)

segment := route.Segments[len(route.Segments)-1]
leaf, err := newLeaf(nil, route, segment, nil)
assert.Nil(t, err)
require.NoError(t, err)

got := leaf.URLPath(test.vals, test.withOptional)
assert.Equal(t, test.want, got)
Expand Down

0 comments on commit 92b5d14

Please sign in to comment.