From 0477d98db53ad98ce6d1d25d447f4e8c0ffc1bc7 Mon Sep 17 00:00:00 2001 From: marrow16 Date: Fri, 11 Nov 2022 18:49:00 +0000 Subject: [PATCH] Path vars Path vars are now flexible and accept other types apart from strings Plus added host, headers and query params for building full URLs --- headers.go | 81 ++++++++++++++++++ headers_test.go | 105 +++++++++++++++++++++++ host.go | 23 +++++ host_test.go | 15 ++++ options.go | 15 ++++ path_part.go | 41 +++++++-- path_vars.go | 39 +++++---- path_vars_test.go | 5 +- query_params.go | 141 +++++++++++++++++++++++++++++++ query_params_test.go | 197 +++++++++++++++++++++++++++++++++++++++++++ template.go | 81 +++++++++++++++--- template_test.go | 108 ++++++++++++++++++++++-- value.go | 72 ++++++++++++++++ value_test.go | 140 ++++++++++++++++++++++++++++++ 14 files changed, 1023 insertions(+), 40 deletions(-) create mode 100644 headers.go create mode 100644 headers_test.go create mode 100644 host.go create mode 100644 host_test.go create mode 100644 query_params.go create mode 100644 query_params_test.go create mode 100644 value.go create mode 100644 value_test.go diff --git a/headers.go b/headers.go new file mode 100644 index 0000000..b0d181b --- /dev/null +++ b/headers.go @@ -0,0 +1,81 @@ +package urit + +import ( + "errors" +) + +type HeadersOption interface { + GetHeaders() (map[string]string, error) +} + +type Headers interface { + HeadersOption + Set(key string, value interface{}) Headers + Get(key string) (interface{}, bool) + Has(key string) bool + Del(key string) Headers + Clone() Headers +} + +func NewHeaders(namesAndValues ...interface{}) (Headers, error) { + if len(namesAndValues)%2 != 0 { + return nil, errors.New("must be a value for each name") + } + result := &headers{ + entries: map[string]interface{}{}, + } + for i := 0; i < len(namesAndValues)-1; i += 2 { + if k, ok := namesAndValues[i].(string); ok { + result.entries[k] = namesAndValues[i+1] + } else { + return nil, errors.New("name must be a string") + } + } + return result, nil +} + +type headers struct { + entries map[string]interface{} +} + +func (h *headers) GetHeaders() (map[string]string, error) { + result := map[string]string{} + for k, v := range h.entries { + if str, err := getValue(v); err == nil { + result[k] = str + } else { + return result, err + } + } + return result, nil +} + +func (h *headers) Set(key string, value interface{}) Headers { + h.entries[key] = value + return h +} + +func (h *headers) Get(key string) (interface{}, bool) { + v, ok := h.entries[key] + return v, ok +} + +func (h *headers) Has(key string) bool { + _, ok := h.entries[key] + return ok +} + +func (h *headers) Del(key string) Headers { + delete(h.entries, key) + return h +} + +func (h *headers) Clone() Headers { + result := &headers{ + entries: map[string]interface{}{}, + } + for k, v := range h.entries { + result.entries[k] = v + } + return result +} diff --git a/headers_test.go b/headers_test.go new file mode 100644 index 0000000..2fd193f --- /dev/null +++ b/headers_test.go @@ -0,0 +1,105 @@ +package urit + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestNewHeaders(t *testing.T) { + h, err := NewHeaders() + require.NoError(t, err) + require.NotNil(t, h) + rh, ok := h.(*headers) + require.True(t, ok) + require.Equal(t, 0, len(rh.entries)) + + h, err = NewHeaders("foo", 1.23) + require.NoError(t, err) + require.NotNil(t, h) + rh, ok = h.(*headers) + require.True(t, ok) + require.Equal(t, 1, len(rh.entries)) + require.Equal(t, 1.23, rh.entries["foo"]) +} + +func TestNewHeadersErrors(t *testing.T) { + _, err := NewHeaders("foo") // must be an even number! + require.Error(t, err) + require.Equal(t, `must be a value for each name`, err.Error()) + + _, err = NewHeaders(true, false) // first must be a string! + require.Error(t, err) + require.Equal(t, `name must be a string`, err.Error()) +} + +func TestHeaders_GetHeaders(t *testing.T) { + h, err := NewHeaders() + require.NoError(t, err) + hds, err := h.GetHeaders() + require.NoError(t, err) + require.Equal(t, 0, len(hds)) + + h, err = NewHeaders("foo", 1.23) + require.NoError(t, err) + hds, err = h.GetHeaders() + require.NoError(t, err) + require.Equal(t, 1, len(hds)) + require.Equal(t, "1.23", hds["foo"]) + + h, err = NewHeaders("foo", nil) + require.NoError(t, err) + _, err = h.GetHeaders() + require.Error(t, err) + require.Equal(t, `unknown value type`, err.Error()) + + h, err = NewHeaders("foo", func() { + // this does not yield a string + }) + require.NoError(t, err) + _, err = h.GetHeaders() + require.Error(t, err) + require.Equal(t, `unknown value type`, err.Error()) +} + +func TestHeaders_GetSet(t *testing.T) { + h, err := NewHeaders("foo", 1) + require.NoError(t, err) + v, ok := h.Get("foo") + require.True(t, ok) + require.Equal(t, 1, v) + h.Set("foo", 2) + v, ok = h.Get("foo") + require.True(t, ok) + require.Equal(t, 2, v) + _, ok = h.Get("bar") + require.False(t, ok) +} + +func TestHeaders_HasDel(t *testing.T) { + h, err := NewHeaders("foo", 1) + require.NoError(t, err) + require.True(t, h.Has("foo")) + h.Del("foo") + require.False(t, h.Has("foo")) + require.False(t, h.Has("bar")) +} + +func TestHeaders_Clone(t *testing.T) { + h1, err := NewHeaders("foo", 1) + require.NoError(t, err) + require.True(t, h1.Has("foo")) + require.False(t, h1.Has("bar")) + h2 := h1.Clone() + require.True(t, h2.Has("foo")) + require.False(t, h2.Has("bar")) + h2.Del("foo") + require.True(t, h1.Has("foo")) + require.False(t, h1.Has("bar")) + require.False(t, h2.Has("foo")) + require.False(t, h2.Has("bar")) + h2.Set("bar", 2) + require.True(t, h1.Has("foo")) + require.False(t, h1.Has("bar")) + require.False(t, h2.Has("foo")) + require.True(t, h2.Has("bar")) +} diff --git a/host.go b/host.go new file mode 100644 index 0000000..688c2da --- /dev/null +++ b/host.go @@ -0,0 +1,23 @@ +package urit + +type HostOption interface { + GetAddress() string +} + +type Host interface { + HostOption +} + +func NewHost(address string) Host { + return &host{ + address: address, + } +} + +type host struct { + address string +} + +func (h *host) GetAddress() string { + return h.address +} diff --git a/host_test.go b/host_test.go new file mode 100644 index 0000000..e9fa40b --- /dev/null +++ b/host_test.go @@ -0,0 +1,15 @@ +package urit + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestNewHost(t *testing.T) { + h := NewHost(`www.example.com`) + require.NotNil(t, h) + rh, ok := h.(*host) + require.True(t, ok) + require.Equal(t, `www.example.com`, rh.address) + require.Equal(t, `www.example.com`, h.GetAddress()) +} diff --git a/options.go b/options.go index 44c0c6a..efe8793 100644 --- a/options.go +++ b/options.go @@ -23,9 +23,11 @@ type VarMatchOption interface { var ( _CaseInsensitiveFixed = &caseInsensitiveFixed{} + _PathRegexCheck = &pathRegexChecker{} ) var ( CaseInsensitiveFixed = _CaseInsensitiveFixed // is a FixedMatchOption that can be used with templates to allow case-insensitive fixed path parts + PathRegexCheck = _PathRegexCheck // is a VarMatchOption that can be used with Template.PathFrom or Template.RequestFrom to check that vars passed in match regexes for the path part ) type fixedMatchOptions []FixedMatchOption @@ -64,3 +66,16 @@ type caseInsensitiveFixed struct{} func (o *caseInsensitiveFixed) Match(value string, expected string, pathPos int, vars PathVars) bool { return value == expected || strings.EqualFold(value, expected) } + +type pathRegexChecker struct{} + +func (o *pathRegexChecker) Applicable(value string, position int, name string, rx *regexp.Regexp, rxs string, pathPos int, vars PathVars) bool { + return rx != nil +} + +func (o *pathRegexChecker) Match(value string, position int, name string, rx *regexp.Regexp, rxs string, pathPos int, vars PathVars) (string, bool) { + if rx != nil && !rx.MatchString(value) { + return value, false + } + return value, true +} diff --git a/path_part.go b/path_part.go index ddcf175..077e856 100644 --- a/path_part.go +++ b/path_part.go @@ -1,6 +1,7 @@ package urit import ( + "errors" "fmt" "regexp" "strconv" @@ -150,27 +151,53 @@ func (pt *pathPart) pathFrom(tracker *positionsTracker) (string, error) { type positionsTracker struct { vars PathVars - position int + varPosition int + pathPosition int namedPositions map[string]int + varMatches varMatchOptions } func (tr *positionsTracker) getVar(pt *pathPart) (string, error) { - if tr.vars.VarsType() == Positions { - if str, ok := tr.vars.GetPositional(tr.position); ok { - tr.position++ + useVars := tr.vars + if useVars == nil { + useVars = Positional() + } + var err error + if useVars.VarsType() == Positions { + if str, ok := useVars.GetPositional(tr.varPosition); ok { + tr.varPosition++ return str, nil } - return "", fmt.Errorf("no var for position %d", tr.position+1) + return "", fmt.Errorf("no var for varPosition %d", tr.varPosition+1) } else { np := tr.namedPositions[pt.name] - if str, ok := tr.vars.GetNamed(pt.name, np); ok { + if str, ok := useVars.GetNamed(pt.name, np); ok { + str, err = tr.checkVar(str, pt, tr.varPosition, tr.pathPosition) + if err != nil { + return "", err + } tr.namedPositions[pt.name] = np + 1 + tr.varPosition++ return str, nil } else if np == 0 { return "", fmt.Errorf("no var for '%s'", pt.name) } - return "", fmt.Errorf("no var for '%s' (position %d)", pt.name, np+1) + return "", fmt.Errorf("no var for '%s' (varPosition %d)", pt.name, np+1) + } +} + +func (tr *positionsTracker) checkVar(s string, pt *pathPart, pos int, pathPos int) (result string, err error) { + result = s + for _, ck := range tr.varMatches { + if ck.Applicable(result, pos, pt.name, pt.regexp, pt.orgRegexp, pathPos, tr.vars) { + if altS, ok := ck.Match(result, pos, pt.name, pt.regexp, pt.orgRegexp, pathPos, tr.vars); ok { + result = altS + } else { + err = errors.New("no match path var") + } + } } + return } func addRegexHeadAndTail(rx string) string { diff --git a/path_vars.go b/path_vars.go index 08b6bb3..7d331cd 100644 --- a/path_vars.go +++ b/path_vars.go @@ -6,7 +6,7 @@ type PathVar struct { Name string NamedPosition int Position int - Value string + Value interface{} } // PathVars is the interface used to pass path vars into a template and returned from a template after extracting @@ -23,8 +23,8 @@ type PathVars interface { Clear() // VarsType returns the path vars type (Positions or Names) VarsType() PathVarsType - AddNamedValue(name string, val string) error - AddPositionalValue(val string) error + AddNamedValue(name string, val interface{}) error + AddPositionalValue(val interface{}) error } type pathVars struct { @@ -43,9 +43,9 @@ func newPathVars(varsType PathVarsType) PathVars { func (pvs *pathVars) GetPositional(position int) (string, bool) { if position < 0 && (len(pvs.all)+position) >= 0 { - return pvs.all[len(pvs.all)+position].Value, true + return getValueIf(pvs.all[len(pvs.all)+position].Value) } else if position >= 0 && position < len(pvs.all) { - return pvs.all[position].Value, true + return getValueIf(pvs.all[position].Value) } return "", false } @@ -53,9 +53,9 @@ func (pvs *pathVars) GetPositional(position int) (string, bool) { func (pvs *pathVars) GetNamed(name string, position int) (string, bool) { if vs, ok := pvs.named[name]; ok { if position < 0 && (len(vs)+position) >= 0 { - return vs[len(vs)+position].Value, true + return getValueIf(vs[len(vs)+position].Value) } else if position >= 0 && position < len(vs) { - return vs[position].Value, true + return getValueIf(vs[position].Value) } } return "", false @@ -63,14 +63,14 @@ func (pvs *pathVars) GetNamed(name string, position int) (string, bool) { func (pvs *pathVars) GetNamedFirst(name string) (string, bool) { if vs, ok := pvs.named[name]; ok && len(vs) > 0 { - return vs[0].Value, true + return getValueIf(vs[0].Value) } return "", false } func (pvs *pathVars) GetNamedLast(name string) (string, bool) { if vs, ok := pvs.named[name]; ok && len(vs) > 0 { - return vs[len(vs)-1].Value, true + return getValueIf(vs[len(vs)-1].Value) } return "", false } @@ -114,7 +114,7 @@ func (pvs *pathVars) VarsType() PathVarsType { return pvs.varsType } -func (pvs *pathVars) AddNamedValue(name string, val string) error { +func (pvs *pathVars) AddNamedValue(name string, val interface{}) error { if pvs.varsType != Names { return errors.New("cannot add named var to non-names vars") } @@ -130,7 +130,7 @@ func (pvs *pathVars) AddNamedValue(name string, val string) error { return nil } -func (pvs *pathVars) AddPositionalValue(val string) error { +func (pvs *pathVars) AddPositionalValue(val interface{}) error { if pvs.varsType != Positions { return errors.New("cannot add positional var to non-positionals vars") } @@ -142,7 +142,7 @@ func (pvs *pathVars) AddPositionalValue(val string) error { } // Positional creates a positional PathVars from the values supplied -func Positional(values ...string) PathVars { +func Positional(values ...interface{}) PathVars { result := newPathVars(Positions) for _, val := range values { _ = result.AddPositionalValue(val) @@ -152,15 +152,22 @@ func Positional(values ...string) PathVars { // Named creates a named PathVars from the name and value pairs supplied // -// Note: If there is not a value for each name - this function panics. -// So ensure that the number of varargs passed is an even number -func Named(namesAndValues ...string) PathVars { +// Notes: +// +// * If there is not a value for each name - this function panics (so ensure that the number of varargs passed is an even number!) +// +// * If any of the name values are not a string - this function panics +func Named(namesAndValues ...interface{}) PathVars { if len(namesAndValues)%2 != 0 { panic("must be a value for each name") } result := newPathVars(Names) for i := 0; i < len(namesAndValues); i += 2 { - _ = result.AddNamedValue(namesAndValues[i], namesAndValues[i+1]) + if name, ok := namesAndValues[i].(string); ok { + _ = result.AddNamedValue(name, namesAndValues[i+1]) + } else { + panic("name must be a string") + } } return result } diff --git a/path_vars_test.go b/path_vars_test.go index 16dec09..0c1114a 100644 --- a/path_vars_test.go +++ b/path_vars_test.go @@ -91,7 +91,10 @@ func TestNamed(t *testing.T) { func TestNamedPanics(t *testing.T) { require.Panics(t, func() { - Named("a", "b", "c") + Named("a", "b", "c") // not an even number! + }) + require.Panics(t, func() { + Named(true, false) // first should be a string! }) } diff --git a/query_params.go b/query_params.go new file mode 100644 index 0000000..b668de9 --- /dev/null +++ b/query_params.go @@ -0,0 +1,141 @@ +package urit + +import ( + "errors" + "net/url" + "sort" + "strings" +) + +type QueryParamsOption interface { + GetQuery() (string, error) +} + +type QueryParams interface { + QueryParamsOption + Get(key string) (interface{}, bool) + GetIndex(key string, index int) (interface{}, bool) + Set(key string, value interface{}) QueryParams + Add(key string, value interface{}) QueryParams + Del(key string) QueryParams + Has(key string) bool + Sorted(on bool) QueryParams + Clone() QueryParams +} + +func NewQueryParams(namesAndValues ...interface{}) (QueryParams, error) { + if len(namesAndValues)%2 != 0 { + return nil, errors.New("must be a value for each name") + } + result := &queryParams{ + params: map[string][]interface{}{}, + sorted: true, + } + for i := 0; i < len(namesAndValues)-1; i += 2 { + if k, ok := namesAndValues[i].(string); ok { + result.params[k] = append(result.params[k], namesAndValues[i+1]) + } else { + return nil, errors.New("name must be a string") + } + } + return result, nil +} + +type queryParams struct { + params map[string][]interface{} + sorted bool +} + +func (qp *queryParams) GetQuery() (string, error) { + var qb strings.Builder + if len(qp.params) > 0 { + names := make([]string, 0, len(qp.params)) + for name := range qp.params { + names = append(names, name) + } + if qp.sorted { + sort.Strings(names) + } + for _, name := range names { + if v := qp.params[name]; len(v) == 0 || (len(v) == 1 && v[0] == nil) { + qb.WriteString(ampersandOrQuestionMark(qb.Len() == 0)) + qb.WriteString(url.QueryEscape(name)) + } else { + for _, qv := range v { + qb.WriteString(ampersandOrQuestionMark(qb.Len() == 0)) + qb.WriteString(url.QueryEscape(name)) + if qv != nil { + if str, err := getValue(qv); err == nil { + qb.WriteString("=") + qb.WriteString(url.QueryEscape(str)) + } else { + return "", err + } + } + } + } + } + } + return qb.String(), nil +} + +func (qp *queryParams) Get(key string) (interface{}, bool) { + if vs, ok := qp.params[key]; ok && len(vs) > 0 { + return vs[0], true + } + return nil, false +} + +func (qp *queryParams) GetIndex(key string, index int) (interface{}, bool) { + if vs, ok := qp.params[key]; ok && len(vs) > 0 { + if index >= 0 && index < len(vs) { + return vs[index], true + } else if index < 0 && (len(vs)+index) >= 0 { + return vs[len(vs)+index], true + } + } + return nil, false +} + +func (qp *queryParams) Set(key string, value interface{}) QueryParams { + qp.params[key] = []interface{}{value} + return qp +} + +func (qp *queryParams) Add(key string, value interface{}) QueryParams { + qp.params[key] = append(qp.params[key], value) + return qp +} + +func (qp *queryParams) Del(key string) QueryParams { + delete(qp.params, key) + return qp +} + +func (qp *queryParams) Has(key string) bool { + _, ok := qp.params[key] + return ok +} + +func (qp *queryParams) Sorted(on bool) QueryParams { + qp.sorted = on + return qp +} + +func (qp *queryParams) Clone() QueryParams { + result := &queryParams{ + params: map[string][]interface{}{}, + sorted: qp.sorted, + } + for k, v := range qp.params { + result.params[k] = append(v) + } + return result +} + +func ampersandOrQuestionMark(first bool) string { + if first { + return "?" + } + return "&" +} diff --git a/query_params_test.go b/query_params_test.go new file mode 100644 index 0000000..efe2a63 --- /dev/null +++ b/query_params_test.go @@ -0,0 +1,197 @@ +package urit + +import ( + "github.com/stretchr/testify/require" + "strings" + "testing" +) + +func TestNewQueryParams(t *testing.T) { + p, err := NewQueryParams() + require.NoError(t, err) + require.NotNil(t, p) + rp, ok := p.(*queryParams) + require.True(t, ok) + require.Equal(t, 0, len(rp.params)) + q, err := p.GetQuery() + require.NoError(t, err) + require.Equal(t, ``, q) + + p, err = NewQueryParams("foo", 1.23) + require.NoError(t, err) + require.NotNil(t, p) + rp, ok = p.(*queryParams) + require.True(t, ok) + require.Equal(t, 1, len(rp.params)) + require.Equal(t, 1, len(rp.params["foo"])) + require.Equal(t, 1.23, rp.params["foo"][0]) +} + +func TestNewQueryParamsErrors(t *testing.T) { + _, err := NewQueryParams("foo") // must be an even number! + require.Error(t, err) + require.Equal(t, `must be a value for each name`, err.Error()) + + _, err = NewQueryParams(true, false) // first must be a string! + require.Error(t, err) + require.Equal(t, `name must be a string`, err.Error()) +} + +func TestQueryParams_GetQuery(t *testing.T) { + p, err := NewQueryParams() + require.NoError(t, err) + q, err := p.GetQuery() + require.NoError(t, err) + require.Equal(t, ``, q) + + p, err = NewQueryParams("foo", nil) + require.NoError(t, err) + q, err = p.GetQuery() + require.NoError(t, err) + require.Equal(t, `?foo`, q) + + p, err = NewQueryParams("foo", nil, "foo", true) + require.NoError(t, err) + q, err = p.GetQuery() + require.NoError(t, err) + require.Contains(t, q, `foo=true`) + require.True(t, strings.HasPrefix(q, "?")) + require.Equal(t, strings.Index(q, "&"), strings.LastIndex(q, "&")) + + p, err = NewQueryParams("foo", func() { + // this does not yield a string + }) + require.NoError(t, err) + _, err = p.GetQuery() + require.Error(t, err) + require.Equal(t, `unknown value type`, err.Error()) +} + +func TestQueryParams_Get(t *testing.T) { + p, err := NewQueryParams() + require.NoError(t, err) + _, ok := p.Get("foo") + require.False(t, ok) + p.Add("foo", nil) + v, ok := p.Get("foo") + require.True(t, ok) + require.Nil(t, v) +} + +func TestQueryParams_GetIndex(t *testing.T) { + p, err := NewQueryParams("foo", 1, "foo", 2, "foo", 3) + require.NoError(t, err) + v, ok := p.GetIndex("foo", 0) + require.True(t, ok) + require.Equal(t, 1, v) + v, ok = p.GetIndex("foo", 1) + require.True(t, ok) + require.Equal(t, 2, v) + v, ok = p.GetIndex("foo", 2) + require.True(t, ok) + require.Equal(t, 3, v) + _, ok = p.GetIndex("foo", 3) + require.False(t, ok) + v, ok = p.GetIndex("foo", -1) + require.True(t, ok) + require.Equal(t, 3, v) + v, ok = p.GetIndex("foo", -2) + require.True(t, ok) + require.Equal(t, 2, v) + v, ok = p.GetIndex("foo", -3) + require.True(t, ok) + require.Equal(t, 1, v) + _, ok = p.GetIndex("foo", -4) + require.False(t, ok) +} + +func TestQueryParams_Set(t *testing.T) { + p, err := NewQueryParams("foo", 1, "foo", 2) + require.NoError(t, err) + v, ok := p.GetIndex("foo", 0) + require.True(t, ok) + require.Equal(t, 1, v) + _, ok = p.GetIndex("foo", 1) + require.True(t, ok) + p.Set("foo", 3) + v, ok = p.Get("foo") + require.True(t, ok) + require.Equal(t, 3, v) + _, ok = p.GetIndex("foo", 1) + require.False(t, ok) +} + +func TestQueryParams_Add(t *testing.T) { + p, err := NewQueryParams("foo", 1) + require.NoError(t, err) + v, ok := p.GetIndex("foo", 0) + require.True(t, ok) + require.Equal(t, 1, v) + _, ok = p.GetIndex("foo", 1) + require.False(t, ok) + p.Add("foo", 2) + v, ok = p.GetIndex("foo", 1) + require.True(t, ok) + require.Equal(t, 2, v) +} + +func TestQueryParams_Del(t *testing.T) { + p, err := NewQueryParams("foo", 1, "foo", 2) + require.NoError(t, err) + v, ok := p.GetIndex("foo", 1) + require.True(t, ok) + require.Equal(t, 2, v) + p.Del("foo") + _, ok = p.Get("foo") + require.False(t, ok) +} + +func TestQueryParams_Has(t *testing.T) { + p, err := NewQueryParams("foo", 1, "foo", 2) + require.NoError(t, err) + require.True(t, p.Has("foo")) + p.Del("foo") + require.False(t, p.Has("foo")) +} + +func TestQueryParams_Sorted(t *testing.T) { + p, err := NewQueryParams("foo", 1, "baz", 2, "bar", 3) + require.NoError(t, err) + rp, ok := p.(*queryParams) + require.True(t, ok) + require.True(t, rp.sorted) + q, err := p.GetQuery() + require.NoError(t, err) + require.Equal(t, `?bar=3&baz=2&foo=1`, q) + p.Sorted(false) + q, err = p.GetQuery() + require.NoError(t, err) + require.Contains(t, q, `foo=1`) + require.Contains(t, q, `baz=2`) + require.Contains(t, q, `bar=3`) +} + +func TestQueryParams_Clone(t *testing.T) { + p1, err := NewQueryParams("foo", 1) + require.NoError(t, err) + rp1, ok := p1.(*queryParams) + require.True(t, ok) + require.True(t, rp1.sorted) + p2 := p1.Clone() + rp2, ok := p2.(*queryParams) + require.True(t, ok) + require.True(t, rp2.sorted) + p2.Sorted(false) + require.False(t, rp2.sorted) + require.True(t, rp1.sorted) + + _, ok = p1.Get("foo") + require.True(t, ok) + _, ok = p2.Get("foo") + require.True(t, ok) + p2.Del("foo") + _, ok = p1.Get("foo") + require.True(t, ok) + _, ok = p2.Get("foo") + require.False(t, ok) +} diff --git a/template.go b/template.go index 49f130b..a2c9690 100644 --- a/template.go +++ b/template.go @@ -3,6 +3,7 @@ package urit import ( "errors" "github.com/go-andiamo/splitter" + "io" "net/http" "strings" ) @@ -21,7 +22,7 @@ const ( // The options can be any FixedMatchOption or VarMatchOption - which can be used // to extend or check fixed or variable path parts func NewTemplate(path string, options ...interface{}) (Template, error) { - fs, vs, so := separateOptions(options) + fs, vs, so := separateParseOptions(options) return (&template{ originalTemplate: slashPrefix(path), pathParts: make([]pathPart, 0), @@ -44,7 +45,9 @@ func MustCreateTemplate(path string, options ...interface{}) Template { // Template is the interface for a URI template type Template interface { // PathFrom generates a path from the template given the specified path vars - PathFrom(vars PathVars) (string, error) + PathFrom(vars PathVars, options ...interface{}) (string, error) + // RequestFrom generates a http.Request from the template given the specified path vars + RequestFrom(method string, vars PathVars, body io.Reader, options ...interface{}) (*http.Request, error) // Matches checks whether the specified path matches the template - // and if a successful match, returns the extracted path vars Matches(path string, options ...interface{}) (PathVars, bool) @@ -73,12 +76,22 @@ type template struct { } // PathFrom generates a path from the template given the specified path vars -func (t *template) PathFrom(vars PathVars) (string, error) { +func (t *template) PathFrom(vars PathVars, options ...interface{}) (string, error) { + hostOption, queryOption, _, varMatches := separatePathOptions(options) + return t.buildPath(vars, hostOption, queryOption, varMatches) +} + +func (t *template) buildPath(vars PathVars, hostOption HostOption, queryOption QueryParamsOption, varMatches varMatchOptions) (string, error) { var pb strings.Builder + if hostOption != nil { + pb.WriteString(hostOption.GetAddress()) + } tracker := &positionsTracker{ vars: vars, - position: 0, + varPosition: 0, + pathPosition: 0, namedPositions: map[string]int{}, + varMatches: varMatches, } for _, pt := range t.pathParts { if str, err := pt.pathFrom(tracker); err == nil { @@ -86,10 +99,41 @@ func (t *template) PathFrom(vars PathVars) (string, error) { } else { return "", err } + tracker.pathPosition++ + } + if queryOption != nil { + if q, err := queryOption.GetQuery(); err == nil { + pb.WriteString(q) + } else { + return "", err + } } return pb.String(), nil } +// RequestFrom generates a http.Request from the template given the specified path vars +func (t *template) RequestFrom(method string, vars PathVars, body io.Reader, options ...interface{}) (*http.Request, error) { + hostOption, queryOption, headerOption, varMatches := separatePathOptions(options) + url, err := t.buildPath(vars, hostOption, queryOption, varMatches) + if err != nil { + return nil, err + } + result, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + if headerOption != nil { + hds, err := headerOption.GetHeaders() + if err != nil { + return nil, err + } + for k, v := range hds { + result.Header.Set(k, v) + } + } + return result, nil +} + // Matches checks whether the specified path matches the template - // and if a successful match, returns the extracted path vars func (t *template) Matches(path string, options ...interface{}) (PathVars, bool) { @@ -98,7 +142,7 @@ func (t *template) Matches(path string, options ...interface{}) (PathVars, bool) return nil, false } result := newPathVars(t.varsType) - fixedOpts, varOpts := t.mergeOptions(options) + fixedOpts, varOpts := t.mergeParseOptions(options) ok := true for i, pt := range t.pathParts { ok = pt.match(pts[i], i, result, fixedOpts, varOpts) @@ -143,7 +187,7 @@ func (t *template) Sub(path string, options ...interface{}) (Template, error) { func (t *template) ResolveTo(vars PathVars) (Template, error) { tracker := &positionsTracker{ vars: vars, - position: 0, + varPosition: 0, namedPositions: map[string]int{}, } result := &template{ @@ -242,7 +286,22 @@ func (t *template) OriginalTemplate() string { return t.originalTemplate } -func separateOptions(options []interface{}) (fixedMatchOptions, varMatchOptions, []splitter.Option) { +func separatePathOptions(options []interface{}) (host HostOption, params QueryParamsOption, headers HeadersOption, varMatches varMatchOptions) { + for _, intf := range options { + if h, ok := intf.(HostOption); ok { + host = h + } else if q, ok := intf.(QueryParamsOption); ok { + params = q + } else if hd, ok := intf.(HeadersOption); ok { + headers = hd + } else if v, ok := intf.(VarMatchOption); ok { + varMatches = append(varMatches, v) + } + } + return +} + +func separateParseOptions(options []interface{}) (fixedMatchOptions, varMatchOptions, []splitter.Option) { seenFixed := map[FixedMatchOption]bool{} seenVar := map[VarMatchOption]bool{} fixeds := make(fixedMatchOptions, 0) @@ -262,11 +321,11 @@ func separateOptions(options []interface{}) (fixedMatchOptions, varMatchOptions, return fixeds, vars, splitOps } -func (t *template) mergeOptions(options []interface{}) (fixedMatchOptions, varMatchOptions) { +func (t *template) mergeParseOptions(options []interface{}) (fixedMatchOptions, varMatchOptions) { if len(options) == 0 { return t.fixedMatchOpts, t.varMatchOpts } else if len(t.fixedMatchOpts) == 0 && len(t.varMatchOpts) == 0 { - fs, vs, _ := separateOptions(options) + fs, vs, _ := separateParseOptions(options) return fs, vs } fixed := make(fixedMatchOptions, 0) @@ -362,7 +421,7 @@ func (t *template) newUriPathPart(pt string, pos int, subParts []splitter.SubPar }, nil } } - return t.newVarPathPart(pt, pos, subParts) + return t.newVarPathPart(subParts) } func (t *template) addVar(pt pathPart) { @@ -375,7 +434,7 @@ func (t *template) addVar(pt pathPart) { } } -func (t *template) newVarPathPart(pt string, pos int, subParts []splitter.SubPart) (pathPart, error) { +func (t *template) newVarPathPart(subParts []splitter.SubPart) (pathPart, error) { result := pathPart{ fixed: false, subParts: make([]pathPart, 0), diff --git a/template_test.go b/template_test.go index fb5b4b8..9a518fb 100644 --- a/template_test.go +++ b/template_test.go @@ -118,7 +118,7 @@ func TestTemplate_PathFrom(t *testing.T) { "foo", "fooey", "bar", "abc")) require.Error(t, err) - require.Equal(t, `no var for 'bar' (position 2)`, err.Error()) + require.Equal(t, `no var for 'bar' (varPosition 2)`, err.Error()) } func TestTemplate_PathFrom_Positional(t *testing.T) { @@ -131,15 +131,15 @@ func TestTemplate_PathFrom_Positional(t *testing.T) { _, err = tmp.PathFrom(Positional("fooey", "barey")) require.Error(t, err) - require.Equal(t, `no var for position 3`, err.Error()) + require.Equal(t, `no var for varPosition 3`, err.Error()) _, err = tmp.PathFrom(Positional("fooey")) require.Error(t, err) - require.Equal(t, `no var for position 2`, err.Error()) + require.Equal(t, `no var for varPosition 2`, err.Error()) _, err = tmp.PathFrom(Positional()) require.Error(t, err) - require.Equal(t, `no var for position 1`, err.Error()) + require.Equal(t, `no var for varPosition 1`, err.Error()) } func TestTemplate_ResolveTo(t *testing.T) { @@ -259,6 +259,104 @@ func TestTemplate_Matches_WithVarOption(t *testing.T) { require.True(t, ok) } +func TestTemplate_PathFrom_WithHost(t *testing.T) { + tmp, err := NewTemplate(`/foo/{foo-id}/bar/{bar-id}`) + require.NoError(t, err) + + h := NewHost(`https://www.example.com`) + pth, err := tmp.PathFrom(Named("foo-id", "1", "bar-id", "2"), h) + require.NoError(t, err) + require.Equal(t, `https://www.example.com/foo/1/bar/2`, pth) +} + +func TestTemplate_PathFrom_WithQuery(t *testing.T) { + tmp, err := NewTemplate(`/foo/{foo-id}/bar/{bar-id}`) + require.NoError(t, err) + + q, err := NewQueryParams("fooq", true, "barq", 1.23) + require.NoError(t, err) + pth, err := tmp.PathFrom(Named("foo-id", "1", "bar-id", "2"), q) + require.NoError(t, err) + require.Contains(t, pth, `/foo/1/bar/2`) + require.Contains(t, pth, `fooq=true`) + require.Contains(t, pth, `barq=1.23`) + + q, err = NewQueryParams() + require.NoError(t, err) + pth, err = tmp.PathFrom(Named("foo-id", "1", "bar-id", "2"), q) + require.NoError(t, err) + require.Equal(t, `/foo/1/bar/2`, pth) +} + +func TestTemplate_PathFrom_WithQueryErrors(t *testing.T) { + tmp, err := NewTemplate(`/foo/{foo-id}/bar/{bar-id}`) + require.NoError(t, err) + + q, err := NewQueryParams("fooq", func() bool { + // this does not return a string! + return false + }) + require.NoError(t, err) + _, err = tmp.PathFrom(Named("foo-id", "1", "bar-id", "2"), q) + require.Error(t, err) + require.Equal(t, `unknown value type`, err.Error()) +} + +func TestTemplate_PathFrom_WithRegexCheck(t *testing.T) { + tmp, err := NewTemplate(`/foo/{foo-id:[a-z]{3}}/bar/{bar-id:[0-9]{3}}`) + require.NoError(t, err) + + _, err = tmp.PathFrom(Named("foo-id", "1", "bar-id", "2"), PathRegexCheck) + require.Error(t, err) + require.Equal(t, `no match path var`, err.Error()) + + _, err = tmp.PathFrom(Named("foo-id", "abc", "bar-id", "123"), PathRegexCheck) + require.NoError(t, err) +} + +func TestTemplate_RequestFrom(t *testing.T) { + tmp, err := NewTemplate(`/foo/{foo-id}/bar/{bar-id}`) + require.NoError(t, err) + + h := NewHost(`https://www.example.com`) + q, err := NewQueryParams("fooq", true, "barq", 1.23) + require.NoError(t, err) + hds, err := NewHeaders("Accept", "application/json") + require.NoError(t, err) + + req, err := tmp.RequestFrom("GET", Named("foo-id", "1", "bar-id", "2"), nil, h, q, hds) + require.NoError(t, err) + + require.Equal(t, "GET", req.Method) + require.Equal(t, `www.example.com`, req.Host) + require.Equal(t, `/foo/1/bar/2`, req.URL.Path) + require.Contains(t, req.URL.RawQuery, `fooq=true`) + require.Contains(t, req.URL.RawQuery, `barq=1.23`) + require.Equal(t, `application/json`, req.Header.Get("Accept")) + require.Equal(t, 1, len(req.Header)) +} + +func TestTemplate_RequestFrom_Errors(t *testing.T) { + tmp, err := NewTemplate(`/foo/{foo-id}/bar/{bar-id}`) + require.NoError(t, err) + + _, err = tmp.RequestFrom("GET", nil, nil) + require.Error(t, err) + require.Equal(t, `no var for varPosition 1`, err.Error()) + + _, err = tmp.RequestFrom("£££", Named("foo-id", "1", "bar-id", "2"), nil) + require.Error(t, err) + require.Equal(t, `net/http: invalid method "£££"`, err.Error()) + + hds, err := NewHeaders("Accept", func() bool { + // this does not return a string! + return false + }) + _, err = tmp.RequestFrom("GET", Named("foo-id", "1", "bar-id", "2"), nil, hds) + require.Error(t, err) + require.Equal(t, `unknown value type`, err.Error()) +} + func TestTemplate_MergeOptions(t *testing.T) { testCases := []struct { initOptions []interface{} @@ -351,7 +449,7 @@ func TestTemplate_MergeOptions(t *testing.T) { require.NoError(t, err) rt, ok := tmp.(*template) require.True(t, ok) - fs, vs := rt.mergeOptions(tc.addOptions) + fs, vs := rt.mergeParseOptions(tc.addOptions) require.Equal(t, tc.expectFixeds, len(fs)) require.Equal(t, tc.expectVars, len(vs)) }) diff --git a/value.go b/value.go new file mode 100644 index 0000000..e29b165 --- /dev/null +++ b/value.go @@ -0,0 +1,72 @@ +package urit + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "time" +) + +type Stringable interface { + String() string +} + +func getValueIf(v interface{}) (string, bool) { + switch av := v.(type) { + case string: + return av, true + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%d", av), true + case float32, float64: + return fmt.Sprintf("%v", av), true + case bool: + if av { + return "true", true + } + return "false", true + case time.Time: + return av.Format(time.RFC3339), true + case *time.Time: + return av.Format(time.RFC3339), true + default: + if str, ok := stringableValue(v); ok { + return str, true + } + } + return "", false +} + +func getValue(v interface{}) (string, error) { + if str, ok := getValueIf(v); ok { + return str, nil + } + return "", errors.New("unknown value type") +} + +func stringableValue(v interface{}) (string, bool) { + if v == nil { + return "", false + } + if sa, ok := v.(Stringable); ok { + return sa.String(), true + } + rt := reflect.TypeOf(v) + if rt.Kind() == reflect.Func { + if rt.NumIn() == 0 && rt.NumOut() == 1 && rt.Out(0).Kind() == reflect.String { + rv := reflect.ValueOf(v) + rvs := rv.Call(nil) + return rvs[0].String(), true + } + return "", false + } + if data, err := json.Marshal(v); err == nil { + str := string(data[:]) + if strings.HasPrefix(str, `"`) && strings.HasSuffix(str, `"`) { + return str[1 : len(str)-1], true + } + return str, true + } + return "", false +} diff --git a/value_test.go b/value_test.go new file mode 100644 index 0000000..5bd4979 --- /dev/null +++ b/value_test.go @@ -0,0 +1,140 @@ +package urit + +import ( + "errors" + "fmt" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestGetValueIf(t *testing.T) { + dt := time.Date(2022, 11, 8, 12, 13, 14, 0, time.UTC) + testCases := []struct { + value interface{} + expectOk bool + expectStr string + }{ + { + "foo", + true, + "foo", + }, + { + 1, + true, + "1", + }, + { + 1.23, + true, + "1.23", + }, + { + true, + true, + "true", + }, + { + false, + true, + "false", + }, + { + dt, + true, + "2022-11-08T12:13:14Z", + }, + { + &dt, + true, + "2022-11-08T12:13:14Z", + }, + { + func() string { + return "foo" + }, + true, + "foo", + }, + { + &valueStruct{Value: "foo"}, + true, + "foo", + }, + { + &marshallable{Value: "foo"}, + true, + "foo", + }, + { + &marshallable2{Value: true}, + true, + "true", + }, + { + &marshallable2{Value: false}, + false, + "", + }, + { + nil, + false, + "", + }, + } + for i, tc := range testCases { + t.Run(fmt.Sprintf("[%d]", i+1), func(t *testing.T) { + str, ok := getValueIf(tc.value) + if tc.expectOk { + require.True(t, ok) + require.Equal(t, tc.expectStr, str) + } else { + require.False(t, ok) + } + }) + } +} + +type valueStruct struct { + Value string +} + +func (v *valueStruct) String() string { + return v.Value +} + +type marshallable struct { + Value string +} + +func (m *marshallable) MarshalJSON() ([]byte, error) { + return []byte(`"` + m.Value + `"`), nil +} + +type marshallable2 struct { + Value bool +} + +func (m *marshallable2) MarshalJSON() ([]byte, error) { + if m.Value { + return []byte("true"), nil + } + return nil, errors.New("whoops") +} + +func TestGetValue(t *testing.T) { + _, err := getValue(nil) + require.Error(t, err) + require.Equal(t, `unknown value type`, err.Error()) + + _, err = getValue(func() { + // this does not yield a string + }) + require.Error(t, err) + require.Equal(t, `unknown value type`, err.Error()) + + str, err := getValue("foo") + require.NoError(t, err) + require.Equal(t, "foo", str) +}