From 6cb3b7c046b9fe73f0a1341cd2eea8071a60cc04 Mon Sep 17 00:00:00 2001 From: eric <65116642+nonbutAworker@users.noreply.github.com> Date: Wed, 23 Feb 2022 15:22:20 +0800 Subject: [PATCH 01/28] remove redundant 0 in make chan (#2101) * remove 0 in make(chan,0) to fix go-staticcheck problem --- echo_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/echo_test.go b/echo_test.go index f175d765b..d31e7b604 100644 --- a/echo_test.go +++ b/echo_test.go @@ -961,7 +961,7 @@ func TestEchoStartTLSByteString(t *testing.T) { e := New() e.HideBanner = true - errChan := make(chan error, 0) + errChan := make(chan error) go func() { errChan <- e.StartTLS(":0", test.cert, test.key) @@ -999,7 +999,7 @@ func TestEcho_StartAutoTLS(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - errChan := make(chan error, 0) + errChan := make(chan error) go func() { errChan <- e.StartAutoTLS(tc.addr) From 27b404bbc5290de56044a906c9f1692a08b64e29 Mon Sep 17 00:00:00 2001 From: eric <65116642+nonbutAworker@users.noreply.github.com> Date: Wed, 23 Feb 2022 19:28:20 +0800 Subject: [PATCH 02/28] remove unused notFoundHandler in echo struct (#2102) * remove unused notFoundHandler in echo --- echo.go | 1 - 1 file changed, 1 deletion(-) diff --git a/echo.go b/echo.go index 2e63cc6b1..88a732ee7 100644 --- a/echo.go +++ b/echo.go @@ -75,7 +75,6 @@ type ( maxParam *int router *Router routers map[string]*Router - notFoundHandler HandlerFunc pool sync.Pool Server *http.Server TLSServer *http.Server From 124825ee629f32aade886f1aeb76e0c6f70c7faa Mon Sep 17 00:00:00 2001 From: Yusuf Eyisan Date: Tue, 1 Mar 2022 10:56:46 +0300 Subject: [PATCH 03/28] Bugfix/1834 Fix X-Real-IP bug (#2007) * Fix incorrect return ip value for RealIpHeader * Improve test file to compare correct real IPs to each other and have better comments * Refactor ip extractor tests to be more readable (longer but readable) Co-authored-by: toimtoimtoim --- echo.go | 6 +- ip.go | 142 +++++++++- ip_test.go | 777 +++++++++++++++++++++++++++++++++++++++-------------- 3 files changed, 711 insertions(+), 214 deletions(-) diff --git a/echo.go b/echo.go index 88a732ee7..b658de4d7 100644 --- a/echo.go +++ b/echo.go @@ -214,9 +214,9 @@ const ( HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" - HeaderXRealIP = "X-Real-IP" - HeaderXRequestID = "X-Request-ID" - HeaderXCorrelationID = "X-Correlation-ID" + HeaderXRealIP = "X-Real-Ip" + HeaderXRequestID = "X-Request-Id" + HeaderXCorrelationID = "X-Correlation-Id" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" HeaderOrigin = "Origin" diff --git a/ip.go b/ip.go index 39cb421fd..46d464cf9 100644 --- a/ip.go +++ b/ip.go @@ -6,6 +6,130 @@ import ( "strings" ) +/** +By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 ) +Source: https://echo.labstack.com/guide/ip-address/ + +IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more. +Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that. + +However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application. +In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally. +Otherwise, you might give someone a chance of deceiving you. **A security risk!** + +To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure. +In Echo, this can be done by configuring `Echo#IPExtractor` appropriately. +This guides show you why and how. + +> Note: if you dont' set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice. + +Let's start from two questions to know the right direction: + +1. Do you put any HTTP (L7) proxy in front of the application? + - It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway). +2. If yes, what HTTP header do your proxies use to pass client IP to the application? + +## Case 1. With no proxy + +If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer. +Any HTTP header is untrustable because the clients have full control what headers to be set. + +In this case, use `echo.ExtractIPDirect()`. + +```go +e.IPExtractor = echo.ExtractIPDirect() +``` + +## Case 2. With proxies using `X-Forwarded-For` header + +[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header +to relay clients' IP addresses. +At each hop on the proxies, they append the request IP address at the end of the header. + +Following example diagram illustrates this behavior. + +```text +┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │ +│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │ +└──────────┘ └──────────┘ └──────────┘ └──────────┘ + +Case 1. +XFF: "" "a" "a, b" + ~~~~~~ +Case 2. +XFF: "x" "x, a" "x, a, b" + ~~~~~~~~~ + ↑ What your app will see +``` + +In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is +configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre". +In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`. + +In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader() +``` + +By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +E.g.: + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader( + TrustLinkLocal(false), + TrustIPRanges(lbIPRange), +) +``` + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +## Case 3. With proxies using `X-Real-IP` header + +`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF. + +If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromRealIPHeader() +``` + +Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**. +> Otherwise there is a chance of fraud, as it is what clients can control. + +## About default behavior + +In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer. + +As you might already notice, after reading this article, this is not good. +Sole reason this is default is just backward compatibility. + +## Private IP ranges + +See: https://en.wikipedia.org/wiki/Private_network + +Private IPv4 address ranges (RFC 1918): +* 10.0.0.0 – 10.255.255.255 (24-bit block) +* 172.16.0.0 – 172.31.255.255 (20-bit block) +* 192.168.0.0 – 192.168.255.255 (16-bit block) + +Private IPv6 address ranges: +* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + +*/ + type ipChecker struct { trustLoopback bool trustLinkLocal bool @@ -52,6 +176,7 @@ func newIPChecker(configs []TrustOption) *ipChecker { return checker } +// Go1.16+ added `ip.IsPrivate()` but until that use this implementation func isPrivateIPRange(ip net.IP) bool { if ip4 := ip.To4(); ip4 != nil { return ip4[0] == 10 || @@ -87,10 +212,12 @@ type IPExtractor func(*http.Request) string // ExtractIPDirect extracts IP address using actual IP address. // Use this if your server faces to internet directory (i.e.: uses no proxy). func ExtractIPDirect() IPExtractor { - return func(req *http.Request) string { - ra, _, _ := net.SplitHostPort(req.RemoteAddr) - return ra - } + return extractIP +} + +func extractIP(req *http.Request) string { + ra, _, _ := net.SplitHostPort(req.RemoteAddr) + return ra } // ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. @@ -98,14 +225,13 @@ func ExtractIPDirect() IPExtractor { func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := ExtractIPDirect()(req) realIP := req.Header.Get(HeaderXRealIP) if realIP != "" { - if ip := net.ParseIP(directIP); ip != nil && checker.trust(ip) { + if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } } - return directIP + return extractIP(req) } } @@ -115,7 +241,7 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := ExtractIPDirect()(req) + directIP := extractIP(req) xffs := req.Header[HeaderXForwardedFor] if len(xffs) == 0 { return directIP diff --git a/ip_test.go b/ip_test.go index 5acc11798..755900d3d 100644 --- a/ip_test.go +++ b/ip_test.go @@ -1,235 +1,606 @@ package echo import ( + "github.com/stretchr/testify/assert" "net" "net/http" - "strings" "testing" - - testify "github.com/stretchr/testify/assert" ) -const ( - // For RemoteAddr - ipForRemoteAddrLoopback = "127.0.0.1" // From 127.0.0.0/8 - sampleRemoteAddrLoopback = ipForRemoteAddrLoopback + ":8080" - ipForRemoteAddrExternal = "203.0.113.1" - sampleRemoteAddrExternal = ipForRemoteAddrExternal + ":8080" - // For x-real-ip - ipForRealIP = "203.0.113.10" - // For XFF - ipForXFF1LinkLocal = "169.254.0.101" // From 169.254.0.0/16 - ipForXFF2Private = "192.168.0.102" // From 192.168.0.0/16 - ipForXFF3External = "2001:db8::103" - ipForXFF4Private = "fc00::104" // From fc00::/7 - ipForXFF5External = "198.51.100.105" - ipForXFF6External = "192.0.2.106" - ipForXFFBroken = "this.is.broken.lol" - // keys for test cases - ipTestReqKeyNoHeader = "no header" - ipTestReqKeyRealIPExternal = "x-real-ip; remote addr external" - ipTestReqKeyRealIPInternal = "x-real-ip; remote addr internal" - ipTestReqKeyRealIPAndXFFExternal = "x-real-ip and xff; remote addr external" - ipTestReqKeyRealIPAndXFFInternal = "x-real-ip and xff; remote addr internal" - ipTestReqKeyXFFExternal = "xff; remote addr external" - ipTestReqKeyXFFInternal = "xff; remote addr internal" - ipTestReqKeyBrokenXFF = "broken xff" -) +func mustParseCIDR(s string) *net.IPNet { + _, IPNet, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return IPNet +} + +func TestIPChecker_TrustOption(t *testing.T) { + var testCases = []struct { + name string + givenOptions []TrustOption + whenIP string + expect bool + }{ + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustLoopback(false), + TrustLinkLocal(false), + TrustPrivateNet(false), + // this is private IPv6 ip + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + TrustIPRange(mustParseCIDR("2001:db8::103/48")), + }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustIPRange(mustParseCIDR("2001:db8::103/48")), + }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker(tc.givenOptions) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustIPRange(t *testing.T) { + var testCases = []struct { + name string + givenRange string + whenIP string + expect bool + }{ + { + name: "ip is within trust range, IPV6 network range", + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff", + expect: false, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.9.0", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.7.255", + expect: false, + }, + { + name: "public ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "internal ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "127.0.10.1", + expect: true, + }, + { + name: "public ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "2a00:1450:4026:805::200e", + expect: true, + }, + { + name: "internal ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "0:0:0:0:0:0:0:1", + expect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cidr := mustParseCIDR(tc.givenRange) + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustIPRange(cidr), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} -var ( - sampleXFF = strings.Join([]string{ - ipForXFF6External, ipForXFF5External, ipForXFF4Private, ipForXFF3External, ipForXFF2Private, ipForXFF1LinkLocal, - }, ", ") +func TestTrustPrivateNet(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "do not trust public IPv4 address", + whenIP: "8.8.8.8", + expect: false, + }, + { + name: "do not trust public IPv6 address", + whenIP: "2a00:1450:4026:805::200e", + expect: false, + }, - requests = map[string]*http.Request{ - ipTestReqKeyNoHeader: &http.Request{ - RemoteAddr: sampleRemoteAddrExternal, + { // Class A: 10.0.0.0 — 10.255.255.255 + name: "do not trust IPv4 just outside of class A (lower bounds)", + whenIP: "9.255.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class A (upper bounds)", + whenIP: "11.0.0.0", + expect: false, + }, + { + name: "trust IPv4 of class A (lower bounds)", + whenIP: "10.0.0.0", + expect: true, + }, + { + name: "trust IPv4 of class A (upper bounds)", + whenIP: "10.255.255.255", + expect: true, + }, + + { // Class B: 172.16.0.0 — 172.31.255.255 + name: "do not trust IPv4 just outside of class B (lower bounds)", + whenIP: "172.15.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class B (upper bounds)", + whenIP: "172.32.0.0", + expect: false, + }, + { + name: "trust IPv4 of class B (lower bounds)", + whenIP: "172.16.0.0", + expect: true, + }, + { + name: "trust IPv4 of class B (upper bounds)", + whenIP: "172.31.255.255", + expect: true, + }, + + { // Class C: 192.168.0.0 — 192.168.255.255 + name: "do not trust IPv4 just outside of class C (lower bounds)", + whenIP: "192.167.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class C (upper bounds)", + whenIP: "192.169.0.0", + expect: false, + }, + { + name: "trust IPv4 of class C (lower bounds)", + whenIP: "192.168.0.0", + expect: true, + }, + { + name: "trust IPv4 of class C (upper bounds)", + whenIP: "192.168.255.255", + expect: true, + }, + + { // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + // splits the address block in two equally sized halves, fc00::/8 and fd00::/8. + // https://en.wikipedia.org/wiki/Unique_local_address + name: "trust IPv6 private address", + whenIP: "fdfc:3514:2cb3:4bd5::", + expect: true, + }, + { + name: "do not trust IPv6 just out of /fd (upper bounds)", + whenIP: "/fe00:0000:0000:0000:0000", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + + TrustPrivateNet(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLinkLocal(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust link local IPv4 address (lower bounds)", + whenIP: "169.254.0.0", + expect: true, + }, + { + name: "trust link local IPv4 address (upper bounds)", + whenIP: "169.254.255.255", + expect: true, }, - ipTestReqKeyRealIPExternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, + { + name: "do not trust link local IPv4 address (outside of lower bounds)", + whenIP: "169.253.255.255", + expect: false, + }, + { + name: "do not trust link local IPv4 address (outside of upper bounds)", + whenIP: "169.255.0.0", + expect: false, + }, + { + name: "trust link local IPv6 address ", + whenIP: "fe80::1", + expect: true, + }, + { + name: "do not trust link local IPv6 address ", + whenIP: "fec0::1", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLinkLocal(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLoopback(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust IPv4 as localhost", + whenIP: "127.0.0.1", + expect: true, + }, + { + name: "trust IPv6 as localhost", + whenIP: "::1", + expect: true, + }, + { + name: "do not trust public ip as localhost", + whenIP: "8.8.8.8", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLoopback(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestExtractIPDirect(t *testing.T) { + var testCases = []struct { + name string + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", }, - RemoteAddr: sampleRemoteAddrExternal, + expectIP: "203.0.113.1", }, - ipTestReqKeyRealIPInternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, + { + name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "203.0.113.1:8080", }, - RemoteAddr: sampleRemoteAddrLoopback, + expectIP: "203.0.113.1", }, - ipTestReqKeyRealIPAndXFFExternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - HeaderXForwardedFor: []string{sampleXFF}, + { + name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "127.0.0.1:8080", }, - RemoteAddr: sampleRemoteAddrExternal, + expectIP: "127.0.0.1", }, - ipTestReqKeyRealIPAndXFFInternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - HeaderXForwardedFor: []string{sampleXFF}, + { + name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", }, - RemoteAddr: sampleRemoteAddrLoopback, + expectIP: "203.0.113.1", }, - ipTestReqKeyXFFExternal: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{sampleXFF}, + { + name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"127.0.0.1"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", }, - RemoteAddr: sampleRemoteAddrExternal, + expectIP: "127.0.0.1", }, - ipTestReqKeyXFFInternal: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{sampleXFF}, + { + name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", }, - RemoteAddr: sampleRemoteAddrLoopback, + expectIP: "203.0.113.1", }, - ipTestReqKeyBrokenXFF: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{ipForXFFBroken + ", " + ipForXFF1LinkLocal}, + { + name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", }, - RemoteAddr: sampleRemoteAddrLoopback, + expectIP: "127.0.0.1", + }, + { + name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", }, } -) -func TestExtractIP(t *testing.T) { - _, ipv4AllRange, _ := net.ParseCIDR("0.0.0.0/0") - _, ipv6AllRange, _ := net.ParseCIDR("::/0") - _, ipForXFF3ExternalRange, _ := net.ParseCIDR(ipForXFF3External + "/48") - _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR(ipForRemoteAddrExternal + "/24") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPDirect()(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromRealIPHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") - tests := map[string]*struct { - extractor IPExtractor - expectedIPs map[string]string + var testCases = []struct { + name string + givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string }{ - "ExtractIPDirect": { - ExtractIPDirect(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(default)": { - ExtractIPFromRealIPHeader(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(trust only direct-facing proxy)": { - ExtractIPFromRealIPHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRealIP, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(trust direct-facing proxy)": { - ExtractIPFromRealIPHeader(TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRealIP, - ipTestReqKeyRealIPInternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(default)": { - ExtractIPFromXFFHeader(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForXFF3External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust only direct-facing proxy)": { - ExtractIPFromXFFHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF1LinkLocal, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForXFF1LinkLocal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust direct-facing proxy)": { - ExtractIPFromXFFHeader(TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF3External, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, - ipTestReqKeyXFFExternal: ipForXFF3External, - ipTestReqKeyXFFInternal: ipForXFF3External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust everything)": { - // This is similar to legacy behavior, but ignores x-real-ip header. - ExtractIPFromXFFHeader(TrustIPRange(ipv4AllRange), TrustIPRange(ipv6AllRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF6External, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF6External, - ipTestReqKeyXFFExternal: ipForXFF6External, - ipTestReqKeyXFFInternal: ipForXFF6External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust ipForXFF3External)": { - // This trusts private network also after "additional" trust ranges unlike `TrustNProxies(1)` doesn't - ExtractIPFromXFFHeader(TrustIPRange(ipForXFF3ExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF5External, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForXFF5External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - assert := testify.New(t) - for key, req := range requests { - actual := test.extractor(req) - expected := test.expectedIPs[key] - assert.Equal(expected, actual, "Request: %s", key) - } + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromXFFHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + + var testCases = []struct { + name string + givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request has INVALID external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.3", + }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed) + // 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs) + // 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office) + // 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"}, + }, + RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP + }, + expectIP: "203.0.100.100", // this is first trusted IP in XFF chain + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) }) } } From 7e719b46e290993a7f396819808998e3ae0becf4 Mon Sep 17 00:00:00 2001 From: Wagner Souza Date: Tue, 1 Mar 2022 20:11:28 -0300 Subject: [PATCH 04/28] Add cache-control and connection headers (#2103) Co-authored-by: Wagner Souza --- echo.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/echo.go b/echo.go index b658de4d7..6143403f3 100644 --- a/echo.go +++ b/echo.go @@ -220,6 +220,8 @@ const ( HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" HeaderOrigin = "Origin" + HeaderCacheControl = "Cache-Control" + HeaderConnection = "Connection" // Access control HeaderAccessControlRequestMethod = "Access-Control-Request-Method" From d66712b252b09751742243aaae56fdd5628ce4d2 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 2 Mar 2022 22:59:19 +0100 Subject: [PATCH 05/28] Update direct golang deps --- go.mod | 10 +++++----- go.sum | 23 +++++++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 4de2bdde1..f09e32cf9 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,18 @@ require ( github.com/labstack/gommon v0.3.1 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f - golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 + golang.org/x/net v0.0.0-20220225172249-27dd8689420f + golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.11 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/text v0.3.7 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index f66734243..7b86ace06 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,9 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= -github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -18,25 +19,27 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From da85d23d685ce31105f1a88682edaeb284223c53 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 2 Mar 2022 23:11:46 +0100 Subject: [PATCH 06/28] Revert "Update direct golang deps" This reverts commit d66712b252b09751742243aaae56fdd5628ce4d2. --- go.mod | 10 +++++----- go.sum | 23 ++++++++++------------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index f09e32cf9..4de2bdde1 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,18 @@ require ( github.com/labstack/gommon v0.3.1 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20220214200702-86341886e292 - golang.org/x/net v0.0.0-20220225172249-27dd8689420f - golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 + golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 + golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-colorable v0.1.11 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect + golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect golang.org/x/text v0.3.7 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 7b86ace06..f66734243 100644 --- a/go.sum +++ b/go.sum @@ -5,9 +5,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -19,27 +18,25 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= -golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= -golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= -golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 5ebed440aeec1abf7f08cca41cb02f6aaf0d7f6a Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 2 Mar 2022 23:16:19 +0100 Subject: [PATCH 07/28] Update version to v4.7.0 --- CHANGELOG.md | 21 +++++++++++++++++++++ echo.go | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 372ed13c5..461ac89c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog +## v4.7.0 - 2022-03-01 + +**Enhancements** + +* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060) +* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072) +* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027) +* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064) + +**Fixes** + +* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007) +* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102) + +**General** + +* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103) +* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078) +* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049) +* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README + ## v4.6.3 - 2022-01-10 **Fixes** diff --git a/echo.go b/echo.go index 6143403f3..143f9ffe3 100644 --- a/echo.go +++ b/echo.go @@ -246,7 +246,7 @@ const ( const ( // Version of Echo - Version = "4.6.3" + Version = "4.7.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 3f5b733425617138573e3768381278f619561f7e Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 13 Mar 2022 15:05:12 +0200 Subject: [PATCH 08/28] Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) (#2123) * Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) --- echo_fs_go1.16.go | 40 ++++++++++++++++++++++++------- echo_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 88 insertions(+), 13 deletions(-) diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go index 435459de2..eb17768ab 100644 --- a/echo_fs_go1.16.go +++ b/echo_fs_go1.16.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path/filepath" + "runtime" "strings" ) @@ -94,10 +95,12 @@ func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { } } -// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` -// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` -// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break -// all old applications that rely on being able to traverse up from current executable run path. +// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. +// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. +// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` +// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not +// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to +// traverse up from current executable run path. // NB: private because you really should use fs.FS implementation instances type defaultFS struct { prefix string @@ -108,20 +111,26 @@ func newDefaultFS() *defaultFS { dir, _ := os.Getwd() return &defaultFS{ prefix: dir, - fs: os.DirFS(dir), + fs: nil, } } func (fs defaultFS) Open(name string) (fs.File, error) { + if fs.fs == nil { + return os.Open(name) + } return fs.fs.Open(name) } func subFS(currentFs fs.FS, root string) (fs.FS, error) { root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to - // allow cases when root is given as `../somepath` which is not valid for fs.FS - root = filepath.Join(dFS.prefix, root) + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if isRelativePath(root) { + root = filepath.Join(dFS.prefix, root) + } return &defaultFS{ prefix: root, fs: os.DirFS(root), @@ -130,6 +139,21 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { return fs.Sub(currentFs, root) } +func isRelativePath(path string) bool { + if path == "" { + return true + } + if path[0] == '/' { + return false + } + if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { + // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names + // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats + return false + } + return true +} + // MustSubFS creates sub FS from current filesystem or panic on failure. // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. // diff --git a/echo_test.go b/echo_test.go index d31e7b604..0e1e42be0 100644 --- a/echo_test.go +++ b/echo_test.go @@ -84,6 +84,14 @@ func TestEchoStatic(t *testing.T) { expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, + { + name: "ok with relative path for root points to directory", + givenPrefix: "/images", + givenRoot: "./_fixture/images", + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, { name: "No file", givenPrefix: "/images", @@ -246,11 +254,54 @@ func TestEchoStaticRedirectIndex(t *testing.T) { } func TestEchoFile(t *testing.T) { - e := New() - e.File("/walle", "_fixture/images/walle.png") - c, b := request(http.MethodGet, "/walle", e) - assert.Equal(t, http.StatusOK, c) - assert.NotEmpty(t, b) + var testCases = []struct { + name string + givenPath string + givenFile string + whenPath string + expectCode int + expectStartsWith string + }{ + { + name: "ok", + givenPath: "/walle", + givenFile: "_fixture/images/walle.png", + whenPath: "/walle", + expectCode: http.StatusOK, + expectStartsWith: string([]byte{0x89, 0x50, 0x4e}), + }, + { + name: "ok with relative path", + givenPath: "/", + givenFile: "./go.mod", + whenPath: "/", + expectCode: http.StatusOK, + expectStartsWith: "module github.com/labstack/echo/v", + }, + { + name: "nok file does not exist", + givenPath: "/", + givenFile: "./this-file-does-not-exist", + whenPath: "/", + expectCode: http.StatusNotFound, + expectStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() // we are using echo.defaultFS instance + e.File(tc.givenPath, tc.givenFile) + + c, b := request(http.MethodGet, tc.whenPath, e) + assert.Equal(t, tc.expectCode, c) + + if len(b) > len(tc.expectStartsWith) { + b = b[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, b) + }) + } } func TestEchoMiddleware(t *testing.T) { From 54efc3850dd205bbffe650763533310cae170f4d Mon Sep 17 00:00:00 2001 From: eric <65116642+nonbutAworker@users.noreply.github.com> Date: Sun, 13 Mar 2022 21:31:39 +0800 Subject: [PATCH 09/28] remove some unused code (#2116) * remove unused code --- binder_test.go | 2 +- router_test.go | 31 ------------------------------- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/binder_test.go b/binder_test.go index 946906a96..034967793 100644 --- a/binder_test.go +++ b/binder_test.go @@ -54,7 +54,7 @@ func TestBindingError_Error(t *testing.T) { func TestBindingError_ErrorJSON(t *testing.T) { err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) - resp, err := json.Marshal(err) + resp, _ := json.Marshal(err) assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) } diff --git a/router_test.go b/router_test.go index 5cbb8d9b8..457566b90 100644 --- a/router_test.go +++ b/router_test.go @@ -1,7 +1,6 @@ package echo import ( - "fmt" "net/http" "net/http/httptest" "strings" @@ -2446,33 +2445,3 @@ func BenchmarkRouterGooglePlusAPIMisses(b *testing.B) { func BenchmarkRouterParamsAndAnyAPI(b *testing.B) { benchmarkRouterRoutes(b, paramAndAnyAPI, paramAndAnyAPIToFind) } - -func (n *node) printTree(pfx string, tail bool) { - p := prefix(tail, pfx, "└── ", "├── ") - fmt.Printf("%s%s, %p: type=%d, parent=%p, handler=%v, pnames=%v\n", p, n.prefix, n, n.kind, n.parent, n.methodHandler, n.pnames) - - p = prefix(tail, pfx, " ", "│ ") - - children := n.staticChildren - l := len(children) - - if n.paramChild != nil { - n.paramChild.printTree(p, n.anyChild == nil && l == 0) - } - if n.anyChild != nil { - n.anyChild.printTree(p, l == 0) - } - for i := 0; i < l-1; i++ { - children[i].printTree(p, false) - } - if l > 0 { - children[l-1].printTree(p, true) - } -} - -func prefix(tail bool, p, on, off string) string { - if tail { - return fmt.Sprintf("%s%s", p, on) - } - return fmt.Sprintf("%s%s", p, off) -} From b445958c3ce4cf34997a67ef73a30cd870170945 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 13 Mar 2022 18:20:30 +0200 Subject: [PATCH 10/28] Update version and changelog for 4.7.1 --- CHANGELOG.md | 11 +++++++++++ echo.go | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 461ac89c9..7d1d9086a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## v4.7.1 - 2022-03-13 + +**Fixes** + +* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123) + +**Enhancements** + +* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116) + + ## v4.7.0 - 2022-03-01 **Enhancements** diff --git a/echo.go b/echo.go index 143f9ffe3..5b3087269 100644 --- a/echo.go +++ b/echo.go @@ -246,7 +246,7 @@ const ( const ( // Version of Echo - Version = "4.7.0" + Version = "4.7.1" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 05df10c62f8a753e342623afb7dec8dbf4ef3f59 Mon Sep 17 00:00:00 2001 From: Gabriel Nelle Date: Mon, 14 Mar 2022 10:44:07 +0100 Subject: [PATCH 11/28] fix nil pointer exception when calling Start again after address binding error --- echo.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/echo.go b/echo.go index 5b3087269..c2f23c194 100644 --- a/echo.go +++ b/echo.go @@ -732,7 +732,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { return s.Serve(e.Listener) } -func (e *Echo) configureServer(s *http.Server) (err error) { +func (e *Echo) configureServer(s *http.Server) error { // Setup e.colorer.SetOutput(e.Logger.Output()) s.ErrorLog = e.StdLogger @@ -747,10 +747,11 @@ func (e *Echo) configureServer(s *http.Server) (err error) { if s.TLSConfig == nil { if e.Listener == nil { - e.Listener, err = newListener(s.Addr, e.ListenerNetwork) + l, err := newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } + e.Listener = l } if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) @@ -791,7 +792,7 @@ func (e *Echo) TLSListenerAddr() net.Addr { } // StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). -func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { +func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error { e.startupMutex.Lock() // Setup s := e.Server @@ -808,11 +809,12 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { } if e.Listener == nil { - e.Listener, err = newListener(s.Addr, e.ListenerNetwork) + l, err := newListener(s.Addr, e.ListenerNetwork) if err != nil { e.startupMutex.Unlock() return err } + e.Listener = l } if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) From 5c38c3b770c2e477f17266e09fe77ee07ab70dfe Mon Sep 17 00:00:00 2001 From: Becir Basic Date: Wed, 16 Mar 2022 00:29:42 +0100 Subject: [PATCH 12/28] Recover middleware should not log panic for aborted handler (#2134, fixes #2133) Co-authored-by: Becir Basic --- middleware/recover.go | 4 ++++ middleware/recover_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/middleware/recover.go b/middleware/recover.go index a621a9efe..7b6128533 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "net/http" "runtime" "github.com/labstack/echo/v4" @@ -77,6 +78,9 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { defer func() { if r := recover(); r != nil { + if r == http.ErrAbortHandler { + panic(r) + } err, ok := r.(error) if !ok { err = fmt.Errorf("%v", r) diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 9ac4feedc..b27f3b41c 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -28,6 +28,35 @@ func TestRecover(t *testing.T) { assert.Contains(t, buf.String(), "PANIC RECOVER") } +func TestRecoverErrAbortHandler(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + panic(http.ErrAbortHandler) + })) + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`") + } else { + if err, ok := r.(error); ok { + assert.ErrorIs(t, err, http.ErrAbortHandler) + } else { + assert.Fail(t, "not of error type") + } + } + }() + + h(c) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.NotContains(t, buf.String(), "PANIC RECOVER") +} + func TestRecoverWithConfig_LogLevel(t *testing.T) { tests := []struct { logLevel log.Lvl From 01d7d01bbc1948cd308b2ae93a131654e6dba195 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 16 Mar 2022 01:43:20 +0200 Subject: [PATCH 13/28] Fix CSRF middleware not being able to extract token from `multipart/form-data` form (#2136, fixes #2135) --- middleware/extractor.go | 4 ++-- middleware/extractor_test.go | 39 +++++++++++++++++++++++++++--------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/middleware/extractor.go b/middleware/extractor.go index a57ed4e13..afdfd8195 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -168,8 +168,8 @@ func valuesFromCookie(name string) ValuesExtractor { // valuesFromForm returns a function that extracts values from the form field. func valuesFromForm(name string) ValuesExtractor { return func(c echo.Context) ([]string, error) { - if parseErr := c.Request().ParseForm(); parseErr != nil { - return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr) + if c.Request().Form == nil { + _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does } values := c.Request().Form[name] if len(values) == 0 { diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index ae4b30a8a..2e898f541 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -1,9 +1,11 @@ package middleware import ( + "bytes" "fmt" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" + "mime/multipart" "net/http" "net/http/httptest" "net/url" @@ -499,6 +501,25 @@ func TestValuesFromForm(t *testing.T) { return req } + exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request { + var b bytes.Buffer + w := multipart.NewWriter(&b) + w.WriteField("name", "Jon Snow") + w.WriteField("emails[]", "jon@labstack.com") + if mod != nil { + mod(w) + } + + fw, _ := w.CreateFormFile("upload", "my.file") + fw.Write([]byte(`
hi
`)) + w.Close() + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String())) + req.Header.Add(echo.HeaderContentType, w.FormDataContentType()) + + return req + } + var testCases = []struct { name string givenRequest *http.Request @@ -520,6 +541,14 @@ func TestValuesFromForm(t *testing.T) { whenName: "emails[]", expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, + { + name: "ok, POST multipart/form, multiple value", + givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) { + w.WriteField("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, { name: "ok, GET form, single value", givenRequest: exampleGetFormRequest(nil), @@ -540,16 +569,6 @@ func TestValuesFromForm(t *testing.T) { whenName: "nope", expectError: errFormExtractorValueMissing.Error(), }, - { - name: "nok, POST form, form parsing error", - givenRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodPost, "/", nil) - req.Body = nil - return req - }(), - whenName: "name", - expectError: "valuesFromForm parse form failed: missing form body", - }, { name: "ok, cut values over extractorLimit", givenRequest: examplePostFormRequest(func(v *url.Values) { From 1919cf4491fa46624a34eb1fb2dd13d414343b64 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 13 Mar 2022 16:00:02 +0200 Subject: [PATCH 14/28] Timeout middleware write race --- middleware/timeout.go | 130 ++++++++++++++++++++++++------------- middleware/timeout_test.go | 9 ++- 2 files changed, 91 insertions(+), 48 deletions(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index 768ef8d70..4e8836c85 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -2,10 +2,10 @@ package middleware import ( "context" + "github.com/labstack/echo/v4" "net/http" + "sync" "time" - - "github.com/labstack/echo/v4" ) // --------------------------------------------------------------------------------------------------------------- @@ -55,29 +55,27 @@ import ( // }) // -type ( - // TimeoutConfig defines the config for Timeout middleware. - TimeoutConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code - // It can be used to define a custom timeout error message - ErrorMessage string - - // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after - // request timeouted and we already had sent the error code (503) and message response to the client. - // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer - // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` - OnTimeoutRouteErrorHandler func(err error, c echo.Context) - - // Timeout configures a timeout for the middleware, defaults to 0 for no timeout - // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) - // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output - // difference over 500microseconds (0.5millisecond) response seems to be reliable - Timeout time.Duration - } -) +// TimeoutConfig defines the config for Timeout middleware. +type TimeoutConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code + // It can be used to define a custom timeout error message + ErrorMessage string + + // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after + // request timeouted and we already had sent the error code (503) and message response to the client. + // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer + // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` + OnTimeoutRouteErrorHandler func(err error, c echo.Context) + + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + Timeout time.Duration +} var ( // DefaultTimeoutConfig is the default Timeout middleware config. @@ -94,10 +92,17 @@ func Timeout() echo.MiddlewareFunc { return TimeoutWithConfig(DefaultTimeoutConfig) } -// TimeoutWithConfig returns a Timeout middleware with config. -// See: `Timeout()`. +// TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration. func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { - // Defaults + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + +// ToMiddleware converts Config to middleware or returns an error for invalid configuration +func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultTimeoutConfig.Skipper } @@ -108,26 +113,29 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { return next(c) } + errChan := make(chan error, 1) handlerWrapper := echoHandlerFuncWrapper{ + writer: &ignorableWriter{ResponseWriter: c.Response().Writer}, ctx: c, handler: next, - errChan: make(chan error, 1), + errChan: errChan, errHandler: config.OnTimeoutRouteErrorHandler, } handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) - handler.ServeHTTP(c.Response().Writer, c.Request()) + handler.ServeHTTP(handlerWrapper.writer, c.Request()) select { - case err := <-handlerWrapper.errChan: + case err := <-errChan: return err default: return nil } } - } + }, nil } type echoHandlerFuncWrapper struct { + writer *ignorableWriter ctx echo.Context handler echo.HandlerFunc errHandler func(err error, c echo.Context) @@ -160,23 +168,53 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques } return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers } - // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client - // and should not anymore send additional headers/data - // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body if err != nil { - // Error must be written into Writer created in `http.TimeoutHandler` so to get Response into `commited` state. - // So call global error handler to write error to the client. This is needed or `http.TimeoutHandler` will send - // status code by itself and after that our tries to write status code will not work anymore and/or create errors in - // log about `superfluous response.WriteHeader call from` - t.ctx.Error(err) - // we pass error from handler to middlewares up in handler chain to act on it if needed. But this means that - // global error handler is probably be called twice as `t.ctx.Error` already does that. - - // NB: later call of the global error handler or middlewares will not take any effect, as echo.Response will be - // already marked as `committed` because we called global error handler above. - t.ctx.Response().Writer = originalWriter // make sure we restore before we signal original coroutine about the error + // This is needed as `http.TimeoutHandler` will write status code by itself on error and after that our tries to write + // status code will not work anymore as Echo.Response thinks it has been already "committed" and further writes + // create errors in log about `superfluous response.WriteHeader call from` + t.writer.Ignore(true) + t.ctx.Response().Writer = originalWriter // make sure we restore writer before we signal original coroutine about the error + // we pass error from handler to middlewares up in handler chain to act on it if needed. t.errChan <- err return } + // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client + // and should not anymore send additional headers/data + // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body t.ctx.Response().Writer = originalWriter } + +// ignorableWriter is ResponseWriter implementations that allows us to mark writer to ignore further write calls. This +// is handy in cases when you do not have direct control of code being executed (3rd party middleware) but want to make +// sure that external code will not be able to write response to the client. +// Writer is coroutine safe for writes. +type ignorableWriter struct { + http.ResponseWriter + + lock sync.Mutex + ignoreWrites bool +} + +func (w *ignorableWriter) Ignore(ignore bool) { + w.lock.Lock() + w.ignoreWrites = ignore + w.lock.Unlock() +} + +func (w *ignorableWriter) WriteHeader(code int) { + w.lock.Lock() + defer w.lock.Unlock() + if w.ignoreWrites { + return + } + w.ResponseWriter.WriteHeader(code) +} + +func (w *ignorableWriter) Write(b []byte) (int, error) { + w.lock.Lock() + defer w.lock.Unlock() + if w.ignoreWrites { + return len(b), nil + } + return w.ResponseWriter.Write(b) +} diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index aa6402b8d..7fb802a9a 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -74,13 +74,18 @@ func TestTimeoutErrorOutInHandler(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) + rec.Code = 1 // we want to be sure that even 200 will not be sent err := m(func(c echo.Context) error { + // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able + // to handle returned error and this can be done only then handler has not yet committed (written status code) + // the response. return echo.NewHTTPError(http.StatusTeapot, "err") })(c) assert.Error(t, err) - assert.Equal(t, http.StatusTeapot, rec.Code) - assert.Equal(t, "{\"message\":\"err\"}\n", rec.Body.String()) + assert.EqualError(t, err, "code=418, message=err") + assert.Equal(t, 1, rec.Code) + assert.Equal(t, "", rec.Body.String()) } func TestTimeoutSuccessfulRequest(t *testing.T) { From ec92fedf21e817d2d52004a4178292404beb9eaa Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 16 Mar 2022 08:43:59 +0200 Subject: [PATCH 15/28] Update version and changelog for 4.7.2 --- CHANGELOG.md | 13 +++++++++++++ echo.go | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d1d9086a..ba75d71f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v4.7.2 - 2022-03-16 + +**Fixes** + +* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131) +* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136) +* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126) + +**Enhancements** + +* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134) + + ## v4.7.1 - 2022-03-13 **Fixes** diff --git a/echo.go b/echo.go index c2f23c194..8829619c7 100644 --- a/echo.go +++ b/echo.go @@ -246,7 +246,7 @@ const ( const ( // Version of Echo - Version = "4.7.1" + Version = "4.7.2" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 59d2eaa4ac35c4dca41b6545bd410b95f60fe354 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 13 Mar 2022 17:30:02 +0200 Subject: [PATCH 16/28] Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to ValueBinder --- binder.go | 121 ++++++++++++++++++++-- binder_test.go | 274 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 385 insertions(+), 10 deletions(-) diff --git a/binder.go b/binder.go index 0900ce8dc..4409c174a 100644 --- a/binder.go +++ b/binder.go @@ -1,6 +1,8 @@ package echo import ( + "encoding" + "encoding/json" "fmt" "net/http" "strconv" @@ -52,8 +54,11 @@ import ( * time * duration * BindUnmarshaler() interface + * TextUnmarshaler() interface + * JSONUnmarshaler() interface * UnixTime() - converts unix time (integer) to time.Time - * UnixTimeNano() - converts unix time with nano second precision (integer) to time.Time + * UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time + * UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time * CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error` */ @@ -321,6 +326,78 @@ func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshal return b } +// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface +func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// MustJSONUnmarshaler requires parameter value to exist to be bind to destination implementing json.Unmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface +func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + +// MustTextUnmarshaler requires parameter value to exist to be bind to destination implementing encoding.TextUnmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + // BindWithDelimiter binds parameter to destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { @@ -1161,7 +1238,7 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim // Note: // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, false, false) + return b.unixTime(sourceParam, dest, false, time.Second) } // MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding @@ -1172,10 +1249,31 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder // Note: // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, true, false) + return b.unixTime(sourceParam, dest, true, time.Second) +} + +// UnixTimeMilli binds parameter to time.Time variable (in local Time corresponding to the given Unix time in millisecond precision). +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, time.Millisecond) } -// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nano second precision). +// MustUnixTimeMilli requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// to the given Unix time in millisecond precision). Returns error when value does not exist. +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, time.Millisecond) +} + +// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nanosecond precision). // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 @@ -1185,7 +1283,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, false, true) + return b.unixTime(sourceParam, dest, false, time.Nanosecond) } // MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding @@ -1199,10 +1297,10 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, true, true) + return b.unixTime(sourceParam, dest, true, time.Nanosecond) } -func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, isNano bool) *ValueBinder { +func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -1221,10 +1319,13 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi return b } - if isNano { - *dest = time.Unix(0, n) - } else { + switch precision { + case time.Second: *dest = time.Unix(n, 0) + case time.Millisecond: + *dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows + case time.Nanosecond: + *dest = time.Unix(0, n) } return b } diff --git a/binder_test.go b/binder_test.go index 034967793..910bbfc50 100644 --- a/binder_test.go +++ b/binder_test.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/stretchr/testify/assert" "io" + "math/big" "net/http" "net/http/httptest" "strconv" @@ -2187,6 +2188,188 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { } } +func TestValueBinder_JSONUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue big.Int + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustJSONUnmarshaler("param", &dest).BindError() + } else { + err = b.JSONUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TextUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue big.Int + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustTextUnmarshaler("param", &dest).BindError() + } else { + err = b.TextUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestValueBinder_BindWithDelimiter_types(t *testing.T) { var testCases = []struct { name string @@ -2529,6 +2712,97 @@ func TestValueBinder_UnixTime(t *testing.T) { } } +func TestValueBinder_UnixTimeMilli(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140 + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Time + expectError string + }{ + { + name: "ok, binds value, unix time in milliseconds", + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTimeMilli("param", &dest).BindError() + } else { + err = b.UnixTimeMilli("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestValueBinder_UnixTimeNano(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603 exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 From 63c62bcbe521dd060e38392f60a1437764d0794c Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 16 Mar 2022 00:56:50 +0100 Subject: [PATCH 17/28] Tidy up comments for value binders --- binder.go | 86 +++++++++++++++++++++++++++---------------------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/binder.go b/binder.go index 4409c174a..5a6cf9d9b 100644 --- a/binder.go +++ b/binder.go @@ -209,7 +209,7 @@ func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []st return b.customFunc(sourceParam, customFunc, false) } -// MustCustomFunc requires parameter values to exist to be bind with Func. Returns error when value does not exist. +// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist. func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { return b.customFunc(sourceParam, customFunc, true) } @@ -246,7 +246,7 @@ func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder { return b } -// MustString requires parameter value to exist to be bind to string variable. Returns error when value does not exist +// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder { if b.failFast && b.errors != nil { return b @@ -275,7 +275,7 @@ func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder { return b } -// MustStrings requires parameter values to exist to be bind to slice of string variables. Returns error when value does not exist +// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder { if b.failFast && b.errors != nil { return b @@ -307,7 +307,7 @@ func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) return b } -// MustBindUnmarshaler requires parameter value to exist to be bind to destination implementing BindUnmarshaler interface. +// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { @@ -343,7 +343,7 @@ func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) return b } -// MustJSONUnmarshaler requires parameter value to exist to be bind to destination implementing json.Unmarshaler interface. +// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { if b.failFast && b.errors != nil { @@ -379,7 +379,7 @@ func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnma return b } -// MustTextUnmarshaler requires parameter value to exist to be bind to destination implementing encoding.TextUnmarshaler interface. +// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { @@ -404,7 +404,7 @@ func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, de return b.bindWithDelimiter(sourceParam, dest, delimiter, false) } -// MustBindWithDelimiter requires parameter value to exist to be bind destination by suitable conversion function. +// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, true) @@ -453,7 +453,7 @@ func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, false) } -// MustInt64 requires parameter value to exist to be bind to int64 variable. Returns error when value does not exist +// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, true) } @@ -463,7 +463,7 @@ func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, false) } -// MustInt32 requires parameter value to exist to be bind to int32 variable. Returns error when value does not exist +// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, true) } @@ -473,7 +473,7 @@ func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, false) } -// MustInt16 requires parameter value to exist to be bind to int16 variable. Returns error when value does not exist +// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, true) } @@ -483,7 +483,7 @@ func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, false) } -// MustInt8 requires parameter value to exist to be bind to int8 variable. Returns error when value does not exist +// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, true) } @@ -493,7 +493,7 @@ func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, false) } -// MustInt requires parameter value to exist to be bind to int variable. Returns error when value does not exist +// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, true) } @@ -621,7 +621,7 @@ func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt64s requires parameter value to exist to be bind to int64 slice variable. Returns error when value does not exist +// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -631,7 +631,7 @@ func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt32s requires parameter value to exist to be bind to int32 slice variable. Returns error when value does not exist +// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -641,7 +641,7 @@ func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt16s requires parameter value to exist to be bind to int16 slice variable. Returns error when value does not exist +// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -651,7 +651,7 @@ func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt8s requires parameter value to exist to be bind to int8 slice variable. Returns error when value does not exist +// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -661,7 +661,7 @@ func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInts requires parameter value to exist to be bind to int slice variable. Returns error when value does not exist +// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -671,7 +671,7 @@ func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, false) } -// MustUint64 requires parameter value to exist to be bind to uint64 variable. Returns error when value does not exist +// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, true) } @@ -681,7 +681,7 @@ func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, false) } -// MustUint32 requires parameter value to exist to be bind to uint32 variable. Returns error when value does not exist +// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, true) } @@ -691,7 +691,7 @@ func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, false) } -// MustUint16 requires parameter value to exist to be bind to uint16 variable. Returns error when value does not exist +// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, true) } @@ -701,7 +701,7 @@ func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } -// MustUint8 requires parameter value to exist to be bind to uint8 variable. Returns error when value does not exist +// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } @@ -711,7 +711,7 @@ func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } -// MustByte requires parameter value to exist to be bind to byte variable. Returns error when value does not exist +// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } @@ -721,7 +721,7 @@ func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, false) } -// MustUint requires parameter value to exist to be bind to uint variable. Returns error when value does not exist +// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, true) } @@ -849,7 +849,7 @@ func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint64s requires parameter value to exist to be bind to uint64 slice variable. Returns error when value does not exist +// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -859,7 +859,7 @@ func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint32s requires parameter value to exist to be bind to uint32 slice variable. Returns error when value does not exist +// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -869,7 +869,7 @@ func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint16s requires parameter value to exist to be bind to uint16 slice variable. Returns error when value does not exist +// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -879,7 +879,7 @@ func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint8s requires parameter value to exist to be bind to uint8 slice variable. Returns error when value does not exist +// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -889,7 +889,7 @@ func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUints requires parameter value to exist to be bind to uint slice variable. Returns error when value does not exist +// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -899,7 +899,7 @@ func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, false) } -// MustBool requires parameter value to exist to be bind to bool variable. Returns error when value does not exist +// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, true) } @@ -964,7 +964,7 @@ func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, false) } -// MustBools requires parameter values to exist to be bind to slice of bool variables. Returns error when values does not exist +// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, true) } @@ -974,7 +974,7 @@ func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, false) } -// MustFloat64 requires parameter value to exist to be bind to float64 variable. Returns error when value does not exist +// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, true) } @@ -984,7 +984,7 @@ func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, false) } -// MustFloat32 requires parameter value to exist to be bind to float32 variable. Returns error when value does not exist +// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, true) } @@ -1069,7 +1069,7 @@ func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder return b.floatsValue(sourceParam, dest, false) } -// MustFloat64s requires parameter values to exist to be bind to slice of float64 variables. Returns error when values does not exist +// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } @@ -1079,7 +1079,7 @@ func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder return b.floatsValue(sourceParam, dest, false) } -// MustFloat32s requires parameter values to exist to be bind to slice of float32 variables. Returns error when values does not exist +// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } @@ -1089,7 +1089,7 @@ func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) * return b.time(sourceParam, dest, layout, false) } -// MustTime requires parameter value to exist to be bind to time.Time variable. Returns error when value does not exist +// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder { return b.time(sourceParam, dest, layout, true) } @@ -1120,7 +1120,7 @@ func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string return b.times(sourceParam, dest, layout, false) } -// MustTimes requires parameter values to exist to be bind to slice of time.Time variables. Returns error when values does not exist +// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { return b.times(sourceParam, dest, layout, true) } @@ -1161,7 +1161,7 @@ func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBi return b.duration(sourceParam, dest, false) } -// MustDuration requires parameter value to exist to be bind to time.Duration variable. Returns error when value does not exist +// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder { return b.duration(sourceParam, dest, true) } @@ -1192,7 +1192,7 @@ func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *Valu return b.durationsValue(sourceParam, dest, false) } -// MustDurations requires parameter values to exist to be bind to slice of time.Duration variables. Returns error when values does not exist +// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder { return b.durationsValue(sourceParam, dest, true) } @@ -1241,7 +1241,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder return b.unixTime(sourceParam, dest, false, time.Second) } -// MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time). Returns error when value does not exist. // // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 @@ -1252,7 +1252,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi return b.unixTime(sourceParam, dest, true, time.Second) } -// UnixTimeMilli binds parameter to time.Time variable (in local Time corresponding to the given Unix time in millisecond precision). +// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision). // // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // @@ -1262,7 +1262,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB return b.unixTime(sourceParam, dest, false, time.Millisecond) } -// MustUnixTimeMilli requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time in millisecond precision). Returns error when value does not exist. // // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 @@ -1273,7 +1273,7 @@ func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *Va return b.unixTime(sourceParam, dest, true, time.Millisecond) } -// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nanosecond precision). +// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision). // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 @@ -1286,7 +1286,7 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi return b.unixTime(sourceParam, dest, false, time.Nanosecond) } -// MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding // to the given Unix time value in nano second precision). Returns error when value does not exist. // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 From 572466d92889a5c946885ec90d5a94d7ad25b0a3 Mon Sep 17 00:00:00 2001 From: gemaizi <864321211@qq.com> Date: Mon, 21 Mar 2022 23:45:06 +0800 Subject: [PATCH 18/28] Fix body_limit middleware unit test --- middleware/body_limit_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 0e8642a06..8981534d4 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -33,12 +33,13 @@ func TestBodyLimit(t *testing.T) { assert.Equal(hw, rec.Body.Bytes()) } - // Based on content read (overlimit) + // Based on content length (overlimit) he := BodyLimit("2B")(h)(c).(*echo.HTTPError) assert.Equal(http.StatusRequestEntityTooLarge, he.Code) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) if assert.NoError(BodyLimit("2M")(h)(c)) { @@ -48,6 +49,7 @@ func TestBodyLimit(t *testing.T) { // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) he = BodyLimit("2B")(h)(c).(*echo.HTTPError) From a987b6577c5ade3d4cd3ece29db43487e975b597 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 29 Apr 2022 21:57:14 +0300 Subject: [PATCH 19/28] Update Github CI flow to use Go 1.18, bump actions versions --- .github/workflows/echo.yml | 59 ++++++++++++-------------------------- Makefile | 6 ++-- 2 files changed, 22 insertions(+), 43 deletions(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 266406664..69535f09c 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -19,6 +19,7 @@ on: - '_fixture/**' - '.github/**' - 'codecov.yml' + workflow_dispatch: jobs: test: @@ -27,33 +28,22 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases - go: [1.14, 1.15, 1.16, 1.17] + go: [1.16, 1.17, 1.18] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v1 - with: - go-version: ${{ matrix.go }} - - - name: Set GOPATH and PATH - run: | - echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV - echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH - shell: bash - - - name: Set build variables - run: | - echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV - echo "GO111MODULE=on" >> $GITHUB_ENV - - name: Checkout Code - uses: actions/checkout@v1 + uses: actions/checkout@v3 with: ref: ${{ github.ref }} + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Install Dependencies - run: go get -v golang.org/x/lint/golint + run: go install golang.org/x/lint/golint@latest - name: Run Tests run: | @@ -61,7 +51,7 @@ jobs: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - name: Upload coverage to Codecov - if: success() && matrix.go == 1.17 && matrix.os == 'ubuntu-latest' + if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v1 with: token: @@ -71,39 +61,28 @@ jobs: strategy: matrix: os: [ubuntu-latest] - go: [1.17] + go: [1.18] name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v1 - with: - go-version: ${{ matrix.go }} - - - name: Set GOPATH and PATH - run: | - echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV - echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH - shell: bash - - - name: Set build variables - run: | - echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV - echo "GO111MODULE=on" >> $GITHUB_ENV - - name: Checkout Code (Previous) - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.base_ref }} path: previous - name: Checkout Code (New) - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: path: new + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Install Dependencies - run: go get -v golang.org/x/perf/cmd/benchstat + run: go install golang.org/x/perf/cmd/benchstat@latest - name: Run Benchmark (Previous) run: | diff --git a/Makefile b/Makefile index 48061f7e2..a6c4aaa90 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ tag: check: lint vet race ## Check project init: - @go get -u golang.org/x/lint/golint + @go install golang.org/x/lint/golint@latest lint: ## Lint the files @golint -set_exit_status ${PKG_LIST} @@ -29,6 +29,6 @@ benchmark: ## Run benchmarks help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.15" -test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15 +goversion ?= "1.16" +test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.16 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" From 2e02ce3dd88f4404c87e8a3a410ae8676fad6521 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 21 May 2022 19:27:22 +0300 Subject: [PATCH 20/28] Timeout mw: fix datarace in tests when we are getting data from buffer. Run each test in their own server so multiple tests cases will not cause datarace getting data out of logger buffer. --- middleware/timeout_test.go | 62 ++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 7fb802a9a..bba48a80f 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -362,40 +362,38 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { }, } - e := echo.New() - - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - - // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first - // FIXME: I have no idea how to fix this without adding mutexes. - e.Use(TimeoutWithConfig(TimeoutConfig{ - Timeout: 15 * time.Millisecond, - })) - e.Use(Logger()) - e.Use(Recover()) - - e.GET("/", func(c echo.Context) error { - var delay time.Duration - if err := echo.QueryParamsBinder(c).Duration("delay", &delay).BindError(); err != nil { - return err - } - if delay > 0 { - time.Sleep(delay) - } - return c.JSON(http.StatusTeapot, map[string]string{"message": "OK"}) - }) - - server, addr, err := startServer(e) - if err != nil { - assert.NoError(t, err) - return - } - defer server.Close() - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - buf.Reset() // this is design this can not be run in parallel + e := echo.New() + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first + // FIXME: I have no idea how to fix this without adding mutexes. + e.Use(TimeoutWithConfig(TimeoutConfig{ + Timeout: 15 * time.Millisecond, + })) + e.Use(Logger()) + e.Use(Recover()) + + e.GET("/", func(c echo.Context) error { + var delay time.Duration + if err := echo.QueryParamsBinder(c).Duration("delay", &delay).BindError(); err != nil { + return err + } + if delay > 0 { + time.Sleep(delay) + } + return c.JSON(http.StatusTeapot, map[string]string{"message": "OK"}) + }) + + server, addr, err := startServer(e) + if err != nil { + assert.NoError(t, err) + return + } + defer server.Close() res, err := http.Get(fmt.Sprintf("http://%v%v", addr, tc.whenPath)) if err != nil { From 28797c761df73cef962bbe92395089b60275680a Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 21 May 2022 20:58:15 +0300 Subject: [PATCH 21/28] Timeout mw: fix datarace in tests when we are getting data from buffer (in test) and writing to logger at the same time. --- middleware/timeout_test.go | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index bba48a80f..6da6a3866 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -12,6 +12,7 @@ import ( "net/url" "reflect" "strings" + "sync" "testing" "time" @@ -366,7 +367,7 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := echo.New() - buf := new(bytes.Buffer) + buf := new(coroutineSafeBuffer) e.Logger.SetOutput(buf) // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first @@ -419,6 +420,36 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { } } +// as we are spawning multiple coroutines - one for http server, one for request, one by timeout middleware, one by testcase +// we are accessing logger (writing/reading) from multiple coroutines and causing dataraces (most often reported on macos) +// we could be writing to logger in logger middleware and at the same time our tests is getting logger buffer contents +// in testcase coroutine. +type coroutineSafeBuffer struct { + bytes.Buffer + lock sync.RWMutex +} + +func (b *coroutineSafeBuffer) Write(p []byte) (n int, err error) { + b.lock.Lock() + defer b.lock.Unlock() + + return b.Buffer.Write(p) +} + +func (b *coroutineSafeBuffer) Bytes() []byte { + b.lock.RLock() + defer b.lock.RUnlock() + + return b.Buffer.Bytes() +} + +func (b *coroutineSafeBuffer) String() string { + b.lock.RLock() + defer b.lock.RUnlock() + + return b.Buffer.String() +} + func startServer(e *echo.Echo) (*http.Server, string, error) { l, err := net.Listen("tcp", ":0") if err != nil { From d5f883707bc2cce801e261959c7a8dd5f111f702 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 22 May 2022 00:21:50 +0300 Subject: [PATCH 22/28] =?UTF-8?q?Timeout=20mw:=20rework=20how=20test=20wai?= =?UTF-8?q?ts=20for=20timeout.=20Using=20sleep=20as=20delay=20i=E2=80=A6?= =?UTF-8?q?=20(#2187)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Timeout mw: rework how test waits for timeout. Using sleep as delay is problematic when CI worker is slower than usual. --- middleware/timeout_test.go | 49 ++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 6da6a3866..56eb7bc74 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "context" "errors" "fmt" "io/ioutil" @@ -328,12 +329,13 @@ func TestTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { func TestTimeoutWithFullEchoStack(t *testing.T) { // test timeout with full http server stack running, do see what http.Server.ErrorLog contains var testCases = []struct { - name string - whenPath string - expectStatusCode int - expectResponse string - expectLogContains []string - expectLogNotContains []string + name string + whenPath string + whenForceHandlerTimeout bool + expectStatusCode int + expectResponse string + expectLogContains []string + expectLogNotContains []string }{ { name: "404 - write response in global error handler", @@ -352,14 +354,15 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { expectLogContains: []string{`"status":418,"error":"",`}, }, { - name: "503 - handler timeouts, write response in timeout middleware", - whenPath: "/?delay=50ms", - expectResponse: "Timeout

Timeout

", - expectStatusCode: http.StatusServiceUnavailable, + name: "503 - handler timeouts, write response in timeout middleware", + whenForceHandlerTimeout: true, + whenPath: "/", + expectResponse: "Timeout

Timeout

", + expectStatusCode: http.StatusServiceUnavailable, expectLogNotContains: []string{ "echo:http: superfluous response.WriteHeader call from", - "{", // means that logger was not called. }, + expectLogContains: []string{"http: Handler timeout"}, }, } @@ -371,21 +374,18 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { e.Logger.SetOutput(buf) // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first - // FIXME: I have no idea how to fix this without adding mutexes. e.Use(TimeoutWithConfig(TimeoutConfig{ Timeout: 15 * time.Millisecond, })) e.Use(Logger()) e.Use(Recover()) + wg := sync.WaitGroup{} + if tc.whenForceHandlerTimeout { + wg.Add(1) // make `wg.Wait()` block until we release it with `wg.Done()` + } e.GET("/", func(c echo.Context) error { - var delay time.Duration - if err := echo.QueryParamsBinder(c).Duration("delay", &delay).BindError(); err != nil { - return err - } - if delay > 0 { - time.Sleep(delay) - } + wg.Wait() return c.JSON(http.StatusTeapot, map[string]string{"message": "OK"}) }) @@ -401,6 +401,13 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { assert.NoError(t, err) return } + if tc.whenForceHandlerTimeout { + wg.Done() + // shutdown waits for server to shutdown. this way we wait logger mw to be executed + ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + server.Shutdown(ctx) + } assert.Equal(t, tc.expectStatusCode, res.StatusCode) if body, err := ioutil.ReadAll(res.Body); err == nil { @@ -411,10 +418,10 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { logged := buf.String() for _, subStr := range tc.expectLogContains { - assert.True(t, strings.Contains(logged, subStr)) + assert.True(t, strings.Contains(logged, subStr), "expected logs to contain: %v, logged: '%v'", subStr, logged) } for _, subStr := range tc.expectLogNotContains { - assert.False(t, strings.Contains(logged, subStr)) + assert.False(t, strings.Contains(logged, subStr), "expected logs not to contain: %v, logged: '%v'", subStr, logged) } }) } From 829ddef710f029b5a6e494353dcc7cc0d2141d17 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 15 Jul 2021 23:34:01 +0300 Subject: [PATCH 23/28] V5.0.0-alpha --- .github/workflows/echo.yml | 7 +- .travis.yml | 21 - LICENSE | 2 +- Makefile | 6 +- README.md | 43 +- bind.go | 104 +- bind_test.go | 240 ++++- binder.go | 4 +- binder_external_test.go | 2 +- binder_go1.15_test.go | 265 ----- binder_test.go | 238 +++- context.go | 640 ++++++----- context_fs.go | 33 - context_fs_go1.16.go | 52 - context_fs_go1.16_test.go | 135 --- context_test.go | 725 +++++++------ echo.go | 1054 +++++++----------- echo_fs.go | 62 -- echo_fs_go1.16.go | 169 --- echo_fs_go1.16_test.go | 265 ----- echo_test.go | 1180 +++++++------------- go.mod | 15 +- go.sum | 34 +- group.go | 180 +++- group_fs.go | 9 - group_fs_go1.16.go | 33 - group_fs_go1.16_test.go | 106 -- group_test.go | 613 ++++++++++- httperror.go | 74 ++ httperror_test.go | 52 + json.go | 11 +- json_test.go | 10 +- log.go | 175 ++- log_test.go | 87 ++ middleware/DEVELOPMENT.md | 11 + middleware/basic_auth.go | 97 +- middleware/basic_auth_test.go | 175 ++- middleware/body_dump.go | 66 +- middleware/body_dump_test.go | 88 +- middleware/body_limit.go | 97 +- middleware/body_limit_test.go | 120 ++- middleware/compress.go | 66 +- middleware/compress_test.go | 154 +-- middleware/cors.go | 123 +-- middleware/cors_test.go | 34 +- middleware/csrf.go | 169 ++- middleware/csrf_test.go | 74 +- middleware/decompress.go | 46 +- middleware/decompress_test.go | 96 +- middleware/extractor.go | 43 +- middleware/extractor_test.go | 48 +- middleware/jwt.go | 325 ++---- middleware/jwt_external_test.go | 76 ++ middleware/jwt_test.go | 665 +++++------- middleware/key_auth.go | 198 ++-- middleware/key_auth_test.go | 193 ++-- middleware/logger.go | 187 ++-- middleware/logger_test.go | 4 +- middleware/method_override.go | 46 +- middleware/method_override_test.go | 68 +- middleware/middleware.go | 21 +- middleware/proxy.go | 167 +-- middleware/proxy_test.go | 21 +- middleware/rate_limiter.go | 106 +- middleware/rate_limiter_test.go | 135 ++- middleware/recover.go | 111 +- middleware/recover_test.go | 186 ++-- middleware/redirect.go | 140 ++- middleware/redirect_test.go | 2 +- middleware/request_id.go | 58 +- middleware/request_id_test.go | 101 +- middleware/request_logger.go | 15 +- middleware/request_logger_test.go | 4 +- middleware/rewrite.go | 74 +- middleware/rewrite_test.go | 77 +- middleware/secure.go | 162 +-- middleware/secure_test.go | 84 +- middleware/slash.go | 82 +- middleware/slash_test.go | 6 +- middleware/static.go | 223 ++-- middleware/static_1_16_test.go | 106 -- middleware/static_test.go | 263 ++++- middleware/timeout.go | 220 ---- middleware/timeout_test.go | 484 --------- middleware/util.go | 39 + middleware/util_test.go | 40 +- response.go | 31 +- route.go | 182 ++++ route_test.go | 423 ++++++++ router.go | 845 +++++++++++---- router_test.go | 1618 +++++++++++++++++++++------- server.go | 213 ++++ server_test.go | 815 ++++++++++++++ 93 files changed, 9672 insertions(+), 7297 deletions(-) delete mode 100644 .travis.yml delete mode 100644 binder_go1.15_test.go delete mode 100644 context_fs.go delete mode 100644 context_fs_go1.16.go delete mode 100644 context_fs_go1.16_test.go delete mode 100644 echo_fs.go delete mode 100644 echo_fs_go1.16.go delete mode 100644 echo_fs_go1.16_test.go delete mode 100644 group_fs.go delete mode 100644 group_fs_go1.16.go delete mode 100644 group_fs_go1.16_test.go create mode 100644 httperror.go create mode 100644 httperror_test.go create mode 100644 log_test.go create mode 100644 middleware/DEVELOPMENT.md create mode 100644 middleware/jwt_external_test.go delete mode 100644 middleware/static_1_16_test.go delete mode 100644 middleware/timeout.go delete mode 100644 middleware/timeout_test.go create mode 100644 route.go create mode 100644 route_test.go create mode 100644 server.go create mode 100644 server_test.go diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 69535f09c..6830b2b39 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -19,7 +19,7 @@ on: - '_fixture/**' - '.github/**' - 'codecov.yml' - workflow_dispatch: + workflow_dispatch: # to be able to run workflow manually jobs: test: @@ -28,7 +28,8 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases - go: [1.16, 1.17, 1.18] + # except v5 starts from 1.17 until there is last four major releases after that + go: [1.17, 1.18] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -52,7 +53,7 @@ jobs: - name: Upload coverage to Codecov if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v2 with: token: fail_ci_if_error: false diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 67d45ad78..000000000 --- a/.travis.yml +++ /dev/null @@ -1,21 +0,0 @@ -arch: - - amd64 - - ppc64le - -language: go -go: - - 1.14.x - - 1.15.x - - tip -env: - - GO111MODULE=on -install: - - go get -v golang.org/x/lint/golint -script: - - golint -set_exit_status ./... - - go test -race -coverprofile=coverage.txt -covermode=atomic ./... -after_success: - - bash <(curl -s https://codecov.io/bash) -matrix: - allow_failures: - - go: tip diff --git a/LICENSE b/LICENSE index c46d0105f..2f18411bd 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2021 LabStack +Copyright (c) 2022 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index a6c4aaa90..8093e481c 100644 --- a/Makefile +++ b/Makefile @@ -24,11 +24,11 @@ race: ## Run tests with data race detector @go test -race ${PKG_LIST} benchmark: ## Run benchmarks - @go test -run="-" -bench=".*" ${PKG_LIST} + @go test -run="-" -benchmem -bench=".*" ${PKG_LIST} help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.16" -test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.16 +goversion ?= "1.17" +test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/README.md b/README.md index 8b2321f05..b9cb69e33 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,14 @@ ## Supported Go versions +Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that. + As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). Therefore a Go version capable of understanding /vN suffixed imports is required: - 1.9.7+ - 1.10.3+ -- 1.14+ +- 1.16+ Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. @@ -39,24 +41,13 @@ For older versions, please use the latest v3 tag. - Automatic TLS via Let’s Encrypt - HTTP/2 support -## Benchmarks - -Date: 2020/11/11
-Source: https://github.com/vishr/web-framework-benchmark
-Lower is better! - - - - -The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz - ## [Guide](https://echo.labstack.com/guide) ### Installation ```sh // go get github.com/labstack/echo/{version} -go get github.com/labstack/echo/v4 +go get github.com/labstack/echo/v5 ``` ### Example @@ -65,8 +56,8 @@ go get github.com/labstack/echo/v4 package main import ( - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" "net/http" ) @@ -82,7 +73,9 @@ func main() { e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":1323"); err != http.ErrServerClosed { + log.Fatal(err) + } } // Handler @@ -93,15 +86,15 @@ func hello(c echo.Context) error { # Third-party middlewares -| Repository | Description | -|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | -| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | -| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | -| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | -| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | -| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | -| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. +| Repository | Description | +|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | +| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | +| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | +| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | +| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | Please send a PR to add your own library here. diff --git a/bind.go b/bind.go index c841ca010..0a7eb8b42 100644 --- a/bind.go +++ b/bind.go @@ -11,42 +11,38 @@ import ( "strings" ) -type ( - // Binder is the interface that wraps the Bind method. - Binder interface { - Bind(i interface{}, c Context) error - } +// Binder is the interface that wraps the Bind method. +type Binder interface { + Bind(c Context, i interface{}) error +} - // DefaultBinder is the default implementation of the Binder interface. - DefaultBinder struct{} +// DefaultBinder is the default implementation of the Binder interface. +type DefaultBinder struct{} - // BindUnmarshaler is the interface used to wrap the UnmarshalParam method. - // Types that don't implement this, but do implement encoding.TextUnmarshaler - // will use that interface instead. - BindUnmarshaler interface { - // UnmarshalParam decodes and assigns a value from an form or query param. - UnmarshalParam(param string) error - } -) +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +// Types that don't implement this, but do implement encoding.TextUnmarshaler +// will use that interface instead. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} // BindPathParams binds path params to bindable object -func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { - names := c.ParamNames() - values := c.ParamValues() +func BindPathParams(c Context, i interface{}) error { params := map[string][]string{} - for i, name := range names { - params[name] = []string{values[i]} + for _, param := range c.PathParams() { + params[param.Name] = []string{param.Value} } - if err := b.bindData(i, params, "param"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err := bindData(i, params, "param"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } // BindQueryParams binds query params to bindable object -func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { - if err := b.bindData(i, c.QueryParams(), "query"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindQueryParams(c Context, i interface{}) error { + if err := bindData(i, c.QueryParams(), "query"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } @@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm -func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { +func BindBody(c Context, i interface{}) (err error) { req := c.Request() if req.ContentLength == 0 { return @@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { case *HTTPError: return err default: - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } } case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) } else if se, ok := err.(*xml.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())) } - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - params, err := c.FormParams() + values, err := c.FormValues() if err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } - if err = b.bindData(i, params, "form"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err = bindData(i, values, "form"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } default: return ErrUnsupportedMediaType @@ -97,34 +93,34 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { } // BindHeaders binds HTTP headers to a bindable object -func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { - if err := b.bindData(i, c.Request().Header, "header"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindHeaders(c Context, i interface{}) error { + if err := bindData(i, c.Request().Header, "header"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } // Bind implements the `Binder#Bind` function. // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous -// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. -func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { - if err := b.BindPathParams(c, i); err != nil { +// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. +func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) { + if err := BindPathParams(c, i); err != nil { return err } // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues. // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) - method := c.Request().Method + method := c.Request().Method if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { - if err = b.BindQueryParams(c, i); err != nil { + if err = BindQueryParams(c, i); err != nil { return err } } - return b.BindBody(c, i) + return BindBody(c, i) } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { +func bindData(destination interface{}, data map[string][]string, tag string) error { if destination == nil || len(data) == 0 { return nil } @@ -167,10 +163,10 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri } if inputFieldName == "" { - // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). - // structs that implement BindUnmarshaler are binded only when they have explicit tag + // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). + // structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { + if err := bindData(structField.Addr().Interface(), data, tag); err != nil { return err } } @@ -180,10 +176,8 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri inputValue, exists := data[inputFieldName] if !exists { - // Go json.Unmarshal supports case insensitive binding. However the - // url params are bound case sensitive which is inconsistent. To - // fix this we must check all of the map values in a - // case-insensitive search. + // Go json.Unmarshal supports case-insensitive binding. However, the url params are bound case-sensitive which + // is inconsistent. To fix this we must check all the map values in a case-insensitive search. for k, v := range data { if strings.EqualFold(k, inputFieldName) { inputValue = v @@ -297,7 +291,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { func setIntField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0" + return nil } intVal, err := strconv.ParseInt(value, 10, bitSize) if err == nil { @@ -308,7 +302,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error { func setUintField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0" + return nil } uintVal, err := strconv.ParseUint(value, 10, bitSize) if err == nil { @@ -319,7 +313,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error { func setBoolField(value string, field reflect.Value) error { if value == "" { - value = "false" + return nil } boolVal, err := strconv.ParseBool(value) if err == nil { @@ -330,7 +324,7 @@ func setBoolField(value string, field reflect.Value) error { func setFloatField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0.0" + return nil } floatVal, err := strconv.ParseFloat(value, bitSize) if err == nil { diff --git a/bind_test.go b/bind_test.go index 4ed8dbb50..8f711c4f8 100644 --- a/bind_test.go +++ b/bind_test.go @@ -277,7 +277,7 @@ func TestBindHeaderParam(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := (&DefaultBinder{}).BindHeaders(c, u) + err := BindHeaders(c, u) if assert.NoError(t, err) { assert.Equal(t, 2, u.ID) assert.Equal(t, "Jon Doe", u.Name) @@ -291,7 +291,7 @@ func TestBindHeaderParamBadType(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := (&DefaultBinder{}).BindHeaders(c, u) + err := BindHeaders(c, u) assert.Error(t, err) httpErr, ok := err.(*HTTPError) @@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) { } } +func TestBind_CombineQueryWithHeaderParam(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil) + req.Header.Set("language", "de") + req.Header.Set("length", "99") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPathParams(PathParams{{ + Name: "id", + Value: "999", + }}) + + type SearchOpts struct { + ID int `param:"id"` + Length int `query:"length"` + Page int `query:"page"` + Search string `query:"search"` + Language string `query:"language" header:"language"` + } + + opts := SearchOpts{ + Length: 100, + Page: 0, + Search: "default value", + Language: "en", + } + err := c.Bind(&opts) + assert.NoError(t, err) + + assert.Equal(t, 50, opts.Length) // bind from query + assert.Equal(t, 10, opts.Page) // bind from query + assert.Equal(t, 999, opts.ID) // bind from path param + assert.Equal(t, "et", opts.Language) // bind from query + assert.Equal(t, "default value", opts.Search) // default value stays + + // make sure another bind will not mess already set values unless there are new values + err = BindHeaders(c, &opts) + assert.NoError(t, err) + + assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists + assert.Equal(t, 10, opts.Page) + assert.Equal(t, 999, opts.ID) + assert.Equal(t, "de", opts.Language) // header overwrites now this value + assert.Equal(t, "default value", opts.Search) +} + func TestBindUnmarshalParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) @@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) { func TestBindUnmarshalText(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { func TestBindUnmarshalTextPtr(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) { func TestBindbindData(t *testing.T) { a := assert.New(t) ts := new(bindTestStruct) - b := new(DefaultBinder) - err := b.bindData(ts, values, "form") + err := bindData(ts, values, "form") a.NoError(err) a.Equal(0, ts.I) @@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) { func TestBindParam(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - c.SetPath("/users/:id/:name") - c.SetParamNames("id", "name") - c.SetParamValues("1", "Jon Snow") + cc := c.(RoutableContext) + cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"}) + cc.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + {Name: "name", Value: "Jon Snow"}, + }) u := new(user) err := c.Bind(u) @@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) { // Second test for the absence of a param c2 := e.NewContext(req, rec) - c2.SetPath("/users/:id") - c2.SetParamNames("id") - c2.SetParamValues("1") + cc2 := c2.(RoutableContext) + cc2.SetRouteInfo(routeInfo{path: "/users/:id"}) + cc2.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + }) u = new(user) err = c2.Bind(u) @@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() - req2 := httptest.NewRequest(POST, "/", body) + req2 := httptest.NewRequest(http.MethodPost, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) rec2 := httptest.NewRecorder() c3 := e2.NewContext(req2, rec2) - c3.SetPath("/users/:id") - c3.SetParamNames("id") - c3.SetParamValues("1") + cc3 := c3.(RoutableContext) + cc3.SetRouteInfo(routeInfo{path: "/users/:id"}) + cc3.SetRawPathParams(&PathParams{ + {Name: "id", Value: "1"}, + }) u = new(user) err = c3.Bind(u) @@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) { assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } -func TestBindSetFields(t *testing.T) { - assert := assert.New(t) +func TestSetIntField(t *testing.T) { + ts := new(bindTestStruct) + ts.I = 100 + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setIntField("", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 100, ts.I) + + // second set with value sets the value + err = setIntField("5", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 5, ts.I) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setIntField("", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 5, ts.I) +} +func TestSetUintField(t *testing.T) { ts := new(bindTestStruct) + ts.UI = 100 + val := reflect.ValueOf(ts).Elem() - // Int - if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) { - assert.Equal(5, ts.I) - } - if assert.NoError(setIntField("", 0, val.FieldByName("I"))) { - assert.Equal(0, ts.I) - } - // Uint - if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) { - assert.Equal(uint(10), ts.UI) - } - if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) { - assert.Equal(uint(0), ts.UI) - } + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setUintField("", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(100), ts.UI) + + // second set with value sets the value + err = setUintField("5", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(5), ts.UI) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setUintField("", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(5), ts.UI) +} - // Float - if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) { - assert.Equal(float32(15.5), ts.F32) - } - if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) { - assert.Equal(float32(0.0), ts.F32) - } +func TestSetFloatField(t *testing.T) { + ts := new(bindTestStruct) + ts.F32 = 100 - // Bool - if assert.NoError(setBoolField("true", val.FieldByName("B"))) { - assert.Equal(true, ts.B) - } - if assert.NoError(setBoolField("", val.FieldByName("B"))) { - assert.Equal(false, ts.B) - } + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setFloatField("", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(100), ts.F32) + + // second set with value sets the value + err = setFloatField("15.5", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(15.5), ts.F32) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setFloatField("", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(15.5), ts.F32) +} + +func TestSetBoolField(t *testing.T) { + ts := new(bindTestStruct) + ts.B = true + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setBoolField("", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // second set with value sets the value + err = setBoolField("true", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setBoolField("", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // fourth set to false + err = setBoolField("false", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, false, ts.B) +} + +func TestUnmarshalFieldNonPtr(t *testing.T) { + ts := new(bindTestStruct) + val := reflect.ValueOf(ts).Elem() ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) - if assert.NoError(err) { - assert.Equal(ok, true) - assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) + if assert.NoError(t, err) { + assert.True(t, ok) + assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) } } @@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() assert := assert.New(b) ts := new(bindTestStructWithTags) - binder := new(DefaultBinder) var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form") + err = bindData(ts, values, "form") } assert.NoError(err) assertBindTestStruct(assert, (*bindTestStruct)(ts)) @@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { c := e.NewContext(req, rec) if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("node_from_path") + cc := c.(RoutableContext) + cc.SetRawPathParams(&PathParams{ + {Name: "node", Value: "node_from_path"}, + }) } var bindTarget interface{} @@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { } b := new(DefaultBinder) - err := b.Bind(bindTarget, c) + err := b.Bind(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) { c := e.NewContext(req, rec) if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("real_node") + cc := c.(RoutableContext) + cc.SetRawPathParams(&PathParams{ + {Name: "node", Value: "real_node"}, + }) } var bindTarget interface{} @@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) { } else { bindTarget = &Node{} } - b := new(DefaultBinder) - err := b.BindBody(c, bindTarget) + err := BindBody(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { diff --git a/binder.go b/binder.go index 5a6cf9d9b..e022a7e8d 100644 --- a/binder.go +++ b/binder.go @@ -123,10 +123,10 @@ func QueryParamsBinder(c Context) *ValueBinder { func PathParamsBinder(c Context) *ValueBinder { return &ValueBinder{ failFast: true, - ValueFunc: c.Param, + ValueFunc: c.PathParam, ValuesFunc: func(sourceParam string) []string { // path parameter should not have multiple values so getting values does not make sense but lets not error out here - value := c.Param(sourceParam) + value := c.PathParam(sourceParam) if value == "" { return nil } diff --git a/binder_external_test.go b/binder_external_test.go index f1aecb52b..585ade816 100644 --- a/binder_external_test.go +++ b/binder_external_test.go @@ -4,7 +4,7 @@ package echo_test import ( "encoding/base64" "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "log" "net/http" "net/http/httptest" diff --git a/binder_go1.15_test.go b/binder_go1.15_test.go deleted file mode 100644 index 018628c3a..000000000 --- a/binder_go1.15_test.go +++ /dev/null @@ -1,265 +0,0 @@ -// +build go1.15 - -package echo - -/** - Since version 1.15 time.Time and time.Duration error message pattern has changed (values are wrapped now in \"\") - So pre 1.15 these tests fail with similar error: - - expected: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param" - actual : "code=400, message=failed to bind field value to Duration, internal=time: invalid duration nope, field=param" -*/ - -import ( - "errors" - "github.com/stretchr/testify/assert" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func createTestContext15(URL string, body io.Reader, pathParams map[string]string) Context { - e := New() - req := httptest.NewRequest(http.MethodGet, URL, body) - if body != nil { - req.Header.Set(HeaderContentType, MIMEApplicationJSON) - } - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) - for name, value := range pathParams { - names = append(names, name) - values = append(values, value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) - } - - return c -} - -func TestValueBinder_TimeError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - whenLayout string - expectValue time.Time - expectError string - }{ - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - if tc.givenFailFast { - b.errors = []error{errors.New("previous error")} - } - - dest := time.Time{} - var err error - if tc.whenMust { - err = b.MustTime("param", &dest, tc.whenLayout).BindError() - } else { - err = b.Time("param", &dest, tc.whenLayout).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_TimesError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - whenLayout string - expectValue []time.Time - expectError string - }{ - { - name: "nok, fail fast without binding value", - givenFailFast: true, - whenURL: "/search?param=1¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", - }, - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - b.errors = tc.givenBindErrors - - layout := time.RFC3339 - if tc.whenLayout != "" { - layout = tc.whenLayout - } - - var dest []time.Time - var err error - if tc.whenMust { - err = b.MustTimes("param", &dest, layout).BindError() - } else { - err = b.Times("param", &dest, layout).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_DurationError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - expectValue time.Duration - expectError string - }{ - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - if tc.givenFailFast { - b.errors = []error{errors.New("previous error")} - } - - var dest time.Duration - var err error - if tc.whenMust { - err = b.MustDuration("param", &dest).BindError() - } else { - err = b.Duration("param", &dest).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_DurationsError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - expectValue []time.Duration - expectError string - }{ - { - name: "nok, fail fast without binding value", - givenFailFast: true, - whenURL: "/search?param=1¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", - }, - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - b.errors = tc.givenBindErrors - - var dest []time.Duration - var err error - if tc.whenMust { - err = b.MustDurations("param", &dest).BindError() - } else { - err = b.Durations("param", &dest).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/binder_test.go b/binder_test.go index 910bbfc50..3c17057c0 100644 --- a/binder_test.go +++ b/binder_test.go @@ -26,14 +26,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string) c := e.NewContext(req, rec) if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) + params := make(PathParams, 0) for name, value := range pathParams { - names = append(names, name) - values = append(values, value) + params = append(params, PathParam{ + Name: name, + Value: value, + }) } - c.SetParamNames(names...) - c.SetParamValues(values...) + cc := c.(RoutableContext) + cc.SetRawPathParams(¶ms) } return c @@ -2917,7 +2918,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) } } @@ -2984,7 +2985,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) if dest.Int64 != 1 { b.Fatalf("int64!=1") } @@ -3029,3 +3030,224 @@ func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { } } } + +func TestValueBinder_TimeError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue time.Time + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustTime("param", &dest, tc.whenLayout).BindError() + } else { + err = b.Time("param", &dest, tc.whenLayout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TimesError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue []time.Time + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + layout := time.RFC3339 + if tc.whenLayout != "" { + layout = tc.whenLayout + } + + var dest []time.Time + var err error + if tc.whenMust { + err = b.MustTimes("param", &dest, layout).BindError() + } else { + err = b.Times("param", &dest, layout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Duration + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest time.Duration + var err error + if tc.whenMust { + err = b.MustDuration("param", &dest).BindError() + } else { + err = b.Duration("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationsError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []time.Duration + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []time.Duration + var err error + if tc.whenMust { + err = b.MustDurations("param", &dest).BindError() + } else { + err = b.Durations("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/context.go b/context.go index a4ecfadfc..a397fba70 100644 --- a/context.go +++ b/context.go @@ -3,210 +3,193 @@ package echo import ( "bytes" "encoding/xml" + "errors" "fmt" "io" + "io/fs" "mime/multipart" "net" "net/http" "net/url" + "path/filepath" "strings" "sync" ) -type ( - // Context represents the context of the current HTTP request. It holds request and - // response objects, path, path parameters, data and registered handler. - Context interface { - // Request returns `*http.Request`. - Request() *http.Request +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context interface { + // Request returns `*http.Request`. + Request() *http.Request - // SetRequest sets `*http.Request`. - SetRequest(r *http.Request) + // SetRequest sets `*http.Request`. + SetRequest(r *http.Request) - // SetResponse sets `*Response`. - SetResponse(r *Response) + // SetResponse sets `*Response`. + SetResponse(r *Response) - // Response returns `*Response`. - Response() *Response + // Response returns `*Response`. + Response() *Response - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool + // IsTLS returns true if HTTP connection is TLS otherwise false. + IsTLS() bool - // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. - IsWebSocket() bool + // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. + IsWebSocket() bool - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string + // Scheme returns the HTTP protocol scheme, `http` or `https`. + Scheme() string - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - // The behavior can be configured using `Echo#IPExtractor`. - RealIP() string + // RealIP returns the client's network address based on `X-Forwarded-For` + // or `X-Real-IP` request header. + // The behavior can be configured using `Echo#IPExtractor`. + RealIP() string - // Path returns the registered path for the handler. - Path() string + // RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. + // In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases. + RouteInfo() RouteInfo - // SetPath sets the registered path for the handler. - SetPath(p string) + // Path returns the registered path for the handler. + Path() string - // Param returns path parameter by name. - Param(name string) string + // PathParam returns path parameter by name. + PathParam(name string) string - // ParamNames returns path parameter names. - ParamNames() []string + // PathParams returns path parameter values. + PathParams() PathParams - // SetParamNames sets path parameter names. - SetParamNames(names ...string) + // SetPathParams sets path parameters for current request. + SetPathParams(params PathParams) - // ParamValues returns path parameter values. - ParamValues() []string + // QueryParam returns the query param for the provided name. + QueryParam(name string) string - // SetParamValues sets path parameter values. - SetParamValues(values ...string) + // QueryParamDefault returns the query param or default value for the provided name. + QueryParamDefault(name, defaultValue string) string - // QueryParam returns the query param for the provided name. - QueryParam(name string) string + // QueryParams returns the query parameters as `url.Values`. + QueryParams() url.Values - // QueryParams returns the query parameters as `url.Values`. - QueryParams() url.Values + // QueryString returns the URL query string. + QueryString() string - // QueryString returns the URL query string. - QueryString() string + // FormValue returns the form field value for the provided name. + FormValue(name string) string - // FormValue returns the form field value for the provided name. - FormValue(name string) string + // FormValueDefault returns the form field value or default value for the provided name. + FormValueDefault(name, defaultValue string) string - // FormParams returns the form parameters as `url.Values`. - FormParams() (url.Values, error) + // FormValues returns the form field values as `url.Values`. + FormValues() (url.Values, error) - // FormFile returns the multipart form file for the provided name. - FormFile(name string) (*multipart.FileHeader, error) + // FormFile returns the multipart form file for the provided name. + FormFile(name string) (*multipart.FileHeader, error) - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) + // MultipartForm returns the multipart form. + MultipartForm() (*multipart.Form, error) - // Cookie returns the named cookie provided in the request. - Cookie(name string) (*http.Cookie, error) + // Cookie returns the named cookie provided in the request. + Cookie(name string) (*http.Cookie, error) - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(cookie *http.Cookie) + // SetCookie adds a `Set-Cookie` header in HTTP response. + SetCookie(cookie *http.Cookie) - // Cookies returns the HTTP cookies sent with the request. - Cookies() []*http.Cookie + // Cookies returns the HTTP cookies sent with the request. + Cookies() []*http.Cookie - // Get retrieves data from the context. - Get(key string) interface{} + // Get retrieves data from the context. + Get(key string) interface{} - // Set saves data in the context. - Set(key string, val interface{}) + // Set saves data in the context. + Set(key string, val interface{}) - // Bind binds the request body into provided type `i`. The default binder - // does it based on Content-Type header. - Bind(i interface{}) error + // Bind binds the request body into provided type `i`. The default binder + // does it based on Content-Type header. + Bind(i interface{}) error - // Validate validates provided `i`. It is usually called after `Context#Bind()`. - // Validator must be registered using `Echo#Validator`. - Validate(i interface{}) error + // Validate validates provided `i`. It is usually called after `Context#Bind()`. + // Validator must be registered using `Echo#Validator`. + Validate(i interface{}) error - // Render renders a template with data and sends a text/html response with status - // code. Renderer must be registered using `Echo.Renderer`. - Render(code int, name string, data interface{}) error + // Render renders a template with data and sends a text/html response with status + // code. Renderer must be registered using `Echo.Renderer`. + Render(code int, name string, data interface{}) error - // HTML sends an HTTP response with status code. - HTML(code int, html string) error + // HTML sends an HTTP response with status code. + HTML(code int, html string) error - // HTMLBlob sends an HTTP blob response with status code. - HTMLBlob(code int, b []byte) error + // HTMLBlob sends an HTTP blob response with status code. + HTMLBlob(code int, b []byte) error - // String sends a string response with status code. - String(code int, s string) error + // String sends a string response with status code. + String(code int, s string) error - // JSON sends a JSON response with status code. - JSON(code int, i interface{}) error + // JSON sends a JSON response with status code. + JSON(code int, i interface{}) error - // JSONPretty sends a pretty-print JSON with status code. - JSONPretty(code int, i interface{}, indent string) error + // JSONPretty sends a pretty-print JSON with status code. + JSONPretty(code int, i interface{}, indent string) error - // JSONBlob sends a JSON blob response with status code. - JSONBlob(code int, b []byte) error + // JSONBlob sends a JSON blob response with status code. + JSONBlob(code int, b []byte) error - // JSONP sends a JSONP response with status code. It uses `callback` to construct - // the JSONP payload. - JSONP(code int, callback string, i interface{}) error + // JSONP sends a JSONP response with status code. It uses `callback` to construct + // the JSONP payload. + JSONP(code int, callback string, i interface{}) error - // JSONPBlob sends a JSONP blob response with status code. It uses `callback` - // to construct the JSONP payload. - JSONPBlob(code int, callback string, b []byte) error + // JSONPBlob sends a JSONP blob response with status code. It uses `callback` + // to construct the JSONP payload. + JSONPBlob(code int, callback string, b []byte) error - // XML sends an XML response with status code. - XML(code int, i interface{}) error + // XML sends an XML response with status code. + XML(code int, i interface{}) error - // XMLPretty sends a pretty-print XML with status code. - XMLPretty(code int, i interface{}, indent string) error + // XMLPretty sends a pretty-print XML with status code. + XMLPretty(code int, i interface{}, indent string) error - // XMLBlob sends an XML blob response with status code. - XMLBlob(code int, b []byte) error + // XMLBlob sends an XML blob response with status code. + XMLBlob(code int, b []byte) error - // Blob sends a blob response with status code and content type. - Blob(code int, contentType string, b []byte) error + // Blob sends a blob response with status code and content type. + Blob(code int, contentType string, b []byte) error - // Stream sends a streaming response with status code and content type. - Stream(code int, contentType string, r io.Reader) error + // Stream sends a streaming response with status code and content type. + Stream(code int, contentType string, r io.Reader) error - // File sends a response with the content of the file. - File(file string) error + // File sends a response with the content of the file. + File(file string) error - // Attachment sends a response as attachment, prompting client to save the - // file. - Attachment(file string, name string) error + // FileFS sends a response with the content of the file from given filesystem. + FileFS(file string, filesystem fs.FS) error - // Inline sends a response as inline, opening the file in the browser. - Inline(file string, name string) error + // Attachment sends a response as attachment, prompting client to save the + // file. + Attachment(file string, name string) error - // NoContent sends a response with no body and a status code. - NoContent(code int) error + // Inline sends a response as inline, opening the file in the browser. + Inline(file string, name string) error - // Redirect redirects the request to a provided URL with status code. - Redirect(code int, url string) error + // NoContent sends a response with no body and a status code. + NoContent(code int) error - // Error invokes the registered HTTP error handler. Generally used by middleware. - Error(err error) + // Redirect redirects the request to a provided URL with status code. + Redirect(code int, url string) error - // Handler returns the matched handler by router. - Handler() HandlerFunc - - // SetHandler sets the matched handler by router. - SetHandler(h HandlerFunc) - - // Logger returns the `Logger` instance. - Logger() Logger - - // Set the logger - SetLogger(l Logger) - - // Echo returns the `Echo` instance. - Echo() *Echo + // Echo returns the `Echo` instance. + Echo() *Echo +} - // Reset resets the context after request completes. It must be called along - // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. - // See `Echo#ServeHTTP()` - Reset(r *http.Request, w http.ResponseWriter) - } +// ServableContext is interface that Echo context implementation must implement to be usable in middleware/handlers and +// be able to be routed by Router. +type ServableContext interface { + Context // minimal set of methods for middlewares and handler + RoutableContext // minimal set for routing. These methods should not be accessed in middlewares/handlers - context struct { - request *http.Request - response *Response - path string - pnames []string - pvalues []string - query url.Values - handler HandlerFunc - store Map - echo *Echo - logger Logger - lock sync.RWMutex - } -) + // Reset resets the context after request completes. It must be called along + // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. + // See `Echo#ServeHTTP()` + Reset(r *http.Request, w http.ResponseWriter) +} const ( // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. @@ -221,39 +204,95 @@ const ( defaultIndent = " " ) -func (c *context) writeContentType(value string) { +// DefaultContext is default implementation of Context interface and can be embedded into structs to compose +// new Contexts with extended/modified behaviour. +type DefaultContext struct { + request *http.Request + response *Response + + route RouteInfo + path string + + // pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations. + pathParams *PathParams + // currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request. + // Lifecycle is not handle by Echo and could have excess allocations per served Request + currentParams PathParams + + query url.Values + store Map + echo *Echo + lock sync.RWMutex +} + +// NewDefaultContext creates new instance of DefaultContext. +// Argument pathParamAllocSize must be value that is stored in Echo.contextPathParamAllocSize field and is used +// to preallocate PathParams slice. +func NewDefaultContext(e *Echo, pathParamAllocSize int) *DefaultContext { + p := make(PathParams, pathParamAllocSize) + return &DefaultContext{ + pathParams: &p, + store: make(Map), + echo: e, + } +} + +// Reset resets the context after request completes. It must be called along +// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. +// See `Echo#ServeHTTP()` +func (c *DefaultContext) Reset(r *http.Request, w http.ResponseWriter) { + c.request = r + c.response.reset(w) + c.query = nil + c.store = nil + + c.route = nil + c.path = "" + // NOTE: Don't reset because it has to have length of c.echo.contextPathParamAllocSize at all times + *c.pathParams = (*c.pathParams)[:0] + c.currentParams = nil +} + +func (c *DefaultContext) writeContentType(value string) { header := c.Response().Header() if header.Get(HeaderContentType) == "" { header.Set(HeaderContentType, value) } } -func (c *context) Request() *http.Request { +// Request returns `*http.Request`. +func (c *DefaultContext) Request() *http.Request { return c.request } -func (c *context) SetRequest(r *http.Request) { +// SetRequest sets `*http.Request`. +func (c *DefaultContext) SetRequest(r *http.Request) { c.request = r } -func (c *context) Response() *Response { +// Response returns `*Response`. +func (c *DefaultContext) Response() *Response { return c.response } -func (c *context) SetResponse(r *Response) { +// SetResponse sets `*Response`. +func (c *DefaultContext) SetResponse(r *Response) { c.response = r } -func (c *context) IsTLS() bool { +// IsTLS returns true if HTTP connection is TLS otherwise false. +func (c *DefaultContext) IsTLS() bool { return c.request.TLS != nil } -func (c *context) IsWebSocket() bool { +// IsWebSocket returns true if HTTP connection is WebSocket otherwise false. +func (c *DefaultContext) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) return strings.EqualFold(upgrade, "websocket") } -func (c *context) Scheme() string { +// Scheme returns the HTTP protocol scheme, `http` or `https`. +func (c *DefaultContext) Scheme() string { // Can't use `r.Request.URL.Scheme` // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 if c.IsTLS() { @@ -274,7 +313,10 @@ func (c *context) Scheme() string { return "http" } -func (c *context) RealIP() string { +// RealIP returns the client's network address based on `X-Forwarded-For` +// or `X-Real-IP` request header. +// The behavior can be configured using `Echo#IPExtractor`. +func (c *DefaultContext) RealIP() string { if c.echo != nil && c.echo.IPExtractor != nil { return c.echo.IPExtractor(c.request) } @@ -293,85 +335,116 @@ func (c *context) RealIP() string { return ra } -func (c *context) Path() string { +// Path returns the registered path for the handler. +func (c *DefaultContext) Path() string { return c.path } -func (c *context) SetPath(p string) { +// SetPath sets the registered path for the handler. +func (c *DefaultContext) SetPath(p string) { c.path = p } -func (c *context) Param(name string) string { - for i, n := range c.pnames { - if i < len(c.pvalues) { - if n == name { - return c.pvalues[i] - } - } - } - return "" +// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. +// In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases. +func (c *DefaultContext) RouteInfo() RouteInfo { + return c.route } -func (c *context) ParamNames() []string { - return c.pnames +// SetRouteInfo sets the route info of this request to the context. +func (c *DefaultContext) SetRouteInfo(ri RouteInfo) { + c.route = ri } -func (c *context) SetParamNames(names ...string) { - c.pnames = names +// RawPathParams returns raw path pathParams value. Allocation of PathParams is handled by Context. +func (c *DefaultContext) RawPathParams() *PathParams { + return c.pathParams +} - l := len(names) - if *c.echo.maxParam < l { - *c.echo.maxParam = l - } +// SetRawPathParams replaces any existing param values with new values for this context lifetime (request). +// +// DO NOT USE! +// Do not set any other value than what you got from RawPathParams as allocation of PathParams is handled by Context. +// If you mess up size of pathParams size your application will panic/crash during routing +func (c *DefaultContext) SetRawPathParams(params *PathParams) { + c.pathParams = params +} - if len(c.pvalues) < l { - // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, - // probably those values will be overriden in a Context#SetParamValues - newPvalues := make([]string, l) - copy(newPvalues, c.pvalues) - c.pvalues = newPvalues +// PathParam returns path parameter by name. +func (c *DefaultContext) PathParam(name string) string { + if c.currentParams != nil { + return c.currentParams.Get(name, "") } -} -func (c *context) ParamValues() []string { - return c.pvalues[:len(c.pnames)] + return c.pathParams.Get(name, "") } -func (c *context) SetParamValues(values ...string) { - // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times - // It will brake the Router#Find code - limit := len(values) - if limit > *c.echo.maxParam { - limit = *c.echo.maxParam - } - for i := 0; i < limit; i++ { - c.pvalues[i] = values[i] +// PathParamDefault does not exist as expecting empty path param makes no sense + +// PathParams returns path parameter values. +func (c *DefaultContext) PathParams() PathParams { + if c.currentParams != nil { + return c.currentParams } + + result := make(PathParams, len(*c.pathParams)) + copy(result, *c.pathParams) + return result } -func (c *context) QueryParam(name string) string { +// SetPathParams sets path parameters for current request. +func (c *DefaultContext) SetPathParams(params PathParams) { + c.currentParams = params +} + +// QueryParam returns the query param for the provided name. +func (c *DefaultContext) QueryParam(name string) string { if c.query == nil { c.query = c.request.URL.Query() } return c.query.Get(name) } -func (c *context) QueryParams() url.Values { +// QueryParamDefault returns the query param or default value for the provided name. +// Note: QueryParamDefault does not distinguish if form had no value by that name or value was empty string +func (c *DefaultContext) QueryParamDefault(name, defaultValue string) string { + value := c.QueryParam(name) + if value == "" { + value = defaultValue + } + return value +} + +// QueryParams returns the query parameters as `url.Values`. +func (c *DefaultContext) QueryParams() url.Values { if c.query == nil { c.query = c.request.URL.Query() } return c.query } -func (c *context) QueryString() string { +// QueryString returns the URL query string. +func (c *DefaultContext) QueryString() string { return c.request.URL.RawQuery } -func (c *context) FormValue(name string) string { +// FormValue returns the form field value for the provided name. +func (c *DefaultContext) FormValue(name string) string { return c.request.FormValue(name) } -func (c *context) FormParams() (url.Values, error) { +// FormValueDefault returns the form field value or default value for the provided name. +// Note: FormValueDefault does not distinguish if form had no value by that name or value was empty string +func (c *DefaultContext) FormValueDefault(name, defaultValue string) string { + value := c.FormValue(name) + if value == "" { + value = defaultValue + } + return value +} + +// FormValues returns the form field values as `url.Values`. +func (c *DefaultContext) FormValues() (url.Values, error) { if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) { if err := c.request.ParseMultipartForm(defaultMemory); err != nil { return nil, err @@ -384,7 +457,8 @@ func (c *context) FormParams() (url.Values, error) { return c.request.Form, nil } -func (c *context) FormFile(name string) (*multipart.FileHeader, error) { +// FormFile returns the multipart form file for the provided name. +func (c *DefaultContext) FormFile(name string) (*multipart.FileHeader, error) { f, fh, err := c.request.FormFile(name) if err != nil { return nil, err @@ -393,30 +467,36 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) { return fh, nil } -func (c *context) MultipartForm() (*multipart.Form, error) { +// MultipartForm returns the multipart form. +func (c *DefaultContext) MultipartForm() (*multipart.Form, error) { err := c.request.ParseMultipartForm(defaultMemory) return c.request.MultipartForm, err } -func (c *context) Cookie(name string) (*http.Cookie, error) { +// Cookie returns the named cookie provided in the request. +func (c *DefaultContext) Cookie(name string) (*http.Cookie, error) { return c.request.Cookie(name) } -func (c *context) SetCookie(cookie *http.Cookie) { +// SetCookie adds a `Set-Cookie` header in HTTP response. +func (c *DefaultContext) SetCookie(cookie *http.Cookie) { http.SetCookie(c.Response(), cookie) } -func (c *context) Cookies() []*http.Cookie { +// Cookies returns the HTTP cookies sent with the request. +func (c *DefaultContext) Cookies() []*http.Cookie { return c.request.Cookies() } -func (c *context) Get(key string) interface{} { +// Get retrieves data from the context. +func (c *DefaultContext) Get(key string) interface{} { c.lock.RLock() defer c.lock.RUnlock() return c.store[key] } -func (c *context) Set(key string, val interface{}) { +// Set saves data in the context. +func (c *DefaultContext) Set(key string, val interface{}) { c.lock.Lock() defer c.lock.Unlock() @@ -426,18 +506,24 @@ func (c *context) Set(key string, val interface{}) { c.store[key] = val } -func (c *context) Bind(i interface{}) error { - return c.echo.Binder.Bind(i, c) +// Bind binds the request body into provided type `i`. The default binder +// does it based on Content-Type header. +func (c *DefaultContext) Bind(i interface{}) error { + return c.echo.Binder.Bind(c, i) } -func (c *context) Validate(i interface{}) error { +// Validate validates provided `i`. It is usually called after `Context#Bind()`. +// Validator must be registered using `Echo#Validator`. +func (c *DefaultContext) Validate(i interface{}) error { if c.echo.Validator == nil { return ErrValidatorNotRegistered } return c.echo.Validator.Validate(i) } -func (c *context) Render(code int, name string, data interface{}) (err error) { +// Render renders a template with data and sends a text/html response with status +// code. Renderer must be registered using `Echo.Renderer`. +func (c *DefaultContext) Render(code int, name string, data interface{}) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } @@ -448,19 +534,22 @@ func (c *context) Render(code int, name string, data interface{}) (err error) { return c.HTMLBlob(code, buf.Bytes()) } -func (c *context) HTML(code int, html string) (err error) { +// HTML sends an HTTP response with status code. +func (c *DefaultContext) HTML(code int, html string) (err error) { return c.HTMLBlob(code, []byte(html)) } -func (c *context) HTMLBlob(code int, b []byte) (err error) { +// HTMLBlob sends an HTTP blob response with status code. +func (c *DefaultContext) HTMLBlob(code int, b []byte) (err error) { return c.Blob(code, MIMETextHTMLCharsetUTF8, b) } -func (c *context) String(code int, s string) (err error) { +// String sends a string response with status code. +func (c *DefaultContext) String(code int, s string) (err error) { return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s)) } -func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error) { +func (c *DefaultContext) jsonPBlob(code int, callback string, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -479,13 +568,14 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error return } -func (c *context) json(code int, i interface{}, indent string) error { +func (c *DefaultContext) json(code int, i interface{}, indent string) error { c.writeContentType(MIMEApplicationJSONCharsetUTF8) c.response.Status = code return c.echo.JSONSerializer.Serialize(c, i, indent) } -func (c *context) JSON(code int, i interface{}) (err error) { +// JSON sends a JSON response with status code. +func (c *DefaultContext) JSON(code int, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -493,19 +583,25 @@ func (c *context) JSON(code int, i interface{}) (err error) { return c.json(code, i, indent) } -func (c *context) JSONPretty(code int, i interface{}, indent string) (err error) { +// JSONPretty sends a pretty-print JSON with status code. +func (c *DefaultContext) JSONPretty(code int, i interface{}, indent string) (err error) { return c.json(code, i, indent) } -func (c *context) JSONBlob(code int, b []byte) (err error) { +// JSONBlob sends a JSON blob response with status code. +func (c *DefaultContext) JSONBlob(code int, b []byte) (err error) { return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b) } -func (c *context) JSONP(code int, callback string, i interface{}) (err error) { +// JSONP sends a JSONP response with status code. It uses `callback` to construct +// the JSONP payload. +func (c *DefaultContext) JSONP(code int, callback string, i interface{}) (err error) { return c.jsonPBlob(code, callback, i) } -func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { +// JSONPBlob sends a JSONP blob response with status code. It uses `callback` +// to construct the JSONP payload. +func (c *DefaultContext) JSONPBlob(code int, callback string, b []byte) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { @@ -518,7 +614,7 @@ func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { return } -func (c *context) xml(code int, i interface{}, indent string) (err error) { +func (c *DefaultContext) xml(code int, i interface{}, indent string) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) enc := xml.NewEncoder(c.response) @@ -531,7 +627,8 @@ func (c *context) xml(code int, i interface{}, indent string) (err error) { return enc.Encode(i) } -func (c *context) XML(code int, i interface{}) (err error) { +// XML sends an XML response with status code. +func (c *DefaultContext) XML(code int, i interface{}) (err error) { indent := "" if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { indent = defaultIndent @@ -539,11 +636,13 @@ func (c *context) XML(code int, i interface{}) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLPretty(code int, i interface{}, indent string) (err error) { +// XMLPretty sends a pretty-print XML with status code. +func (c *DefaultContext) XMLPretty(code int, i interface{}, indent string) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLBlob(code int, b []byte) (err error) { +// XMLBlob sends an XML blob response with status code. +func (c *DefaultContext) XMLBlob(code int, b []byte) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(xml.Header)); err != nil { @@ -553,39 +652,86 @@ func (c *context) XMLBlob(code int, b []byte) (err error) { return } -func (c *context) Blob(code int, contentType string, b []byte) (err error) { +// Blob sends a blob response with status code and content type. +func (c *DefaultContext) Blob(code int, contentType string, b []byte) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = c.response.Write(b) return } -func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { +// Stream sends a streaming response with status code and content type. +func (c *DefaultContext) Stream(code int, contentType string, r io.Reader) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = io.Copy(c.response, r) return } -func (c *context) Attachment(file, name string) error { +// File sends a response with the content of the file. +func (c *DefaultContext) File(file string) error { + return fsFile(c, file, c.echo.Filesystem) +} + +// FileFS serves file from given file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (c *DefaultContext) FileFS(file string, filesystem fs.FS) error { + return fsFile(c, file, filesystem) +} + +func fsFile(c Context, file string, filesystem fs.FS) error { + f, err := filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + + fi, _ := f.Stat() + if fi.IsDir() { + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. + f, err = filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + if fi, err = f.Stat(); err != nil { + return err + } + } + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil +} + +// Attachment sends a response as attachment, prompting client to save the file. +func (c *DefaultContext) Attachment(file, name string) error { return c.contentDisposition(file, name, "attachment") } -func (c *context) Inline(file, name string) error { +// Inline sends a response as inline, opening the file in the browser. +func (c *DefaultContext) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } -func (c *context) contentDisposition(file, name, dispositionType string) error { +func (c *DefaultContext) contentDisposition(file, name, dispositionType string) error { c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name)) return c.File(file) } -func (c *context) NoContent(code int) error { +// NoContent sends a response with no body and a status code. +func (c *DefaultContext) NoContent(code int) error { c.response.WriteHeader(code) return nil } -func (c *context) Redirect(code int, url string) error { +// Redirect redirects the request to a provided URL with status code. +func (c *DefaultContext) Redirect(code int, url string) error { if code < 300 || code > 308 { return ErrInvalidRedirectCode } @@ -594,45 +740,7 @@ func (c *context) Redirect(code int, url string) error { return nil } -func (c *context) Error(err error) { - c.echo.HTTPErrorHandler(err, c) -} - -func (c *context) Echo() *Echo { +// Echo returns the `Echo` instance. +func (c *DefaultContext) Echo() *Echo { return c.echo } - -func (c *context) Handler() HandlerFunc { - return c.handler -} - -func (c *context) SetHandler(h HandlerFunc) { - c.handler = h -} - -func (c *context) Logger() Logger { - res := c.logger - if res != nil { - return res - } - return c.echo.Logger -} - -func (c *context) SetLogger(l Logger) { - c.logger = l -} - -func (c *context) Reset(r *http.Request, w http.ResponseWriter) { - c.request = r - c.response.reset(w) - c.query = nil - c.handler = NotFoundHandler - c.store = nil - c.path = "" - c.pnames = nil - c.logger = nil - // NOTE: Don't reset because it has to have length c.echo.maxParam at all times - for i := 0; i < *c.echo.maxParam; i++ { - c.pvalues[i] = "" - } -} diff --git a/context_fs.go b/context_fs.go deleted file mode 100644 index 11ee84bcd..000000000 --- a/context_fs.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -import ( - "net/http" - "os" - "path/filepath" -) - -func (c *context) File(file string) (err error) { - f, err := os.Open(file) - if err != nil { - return NotFoundHandler(c) - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.Join(file, indexPage) - f, err = os.Open(file) - if err != nil { - return NotFoundHandler(c) - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return - } - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) - return -} diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go deleted file mode 100644 index c1c724afd..000000000 --- a/context_fs_go1.16.go +++ /dev/null @@ -1,52 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "errors" - "io" - "io/fs" - "net/http" - "path/filepath" -) - -func (c *context) File(file string) error { - return fsFile(c, file, c.echo.Filesystem) -} - -// FileFS serves file from given file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (c *context) FileFS(file string, filesystem fs.FS) error { - return fsFile(c, file, filesystem) -} - -func fsFile(c Context, file string, filesystem fs.FS) error { - f, err := filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. - f, err = filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return err - } - } - ff, ok := f.(io.ReadSeeker) - if !ok { - return errors.New("file does not implement io.ReadSeeker") - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) - return nil -} diff --git a/context_fs_go1.16_test.go b/context_fs_go1.16_test.go deleted file mode 100644 index 027d1c483..000000000 --- a/context_fs_go1.16_test.go +++ /dev/null @@ -1,135 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestContext_File(t *testing.T) { - var testCases = []struct { - name string - whenFile string - whenFS fs.FS - expectStatus int - expectStartsWith []byte - expectError string - }{ - { - name: "ok, from default file system", - whenFile: "_fixture/images/walle.png", - whenFS: nil, - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "ok, from custom file system", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "code=404, message=Not Found", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - if tc.whenFS != nil { - e.Filesystem = tc.whenFS - } - - handler := func(ec Context) error { - return ec.(*context).File(tc.whenFile) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestContext_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenFile string - whenFS fs.FS - expectStatus int - expectStartsWith []byte - expectError string - }{ - { - name: "ok", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "code=404, message=Not Found", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - handler := func(ec Context) error { - return ec.(*context).FileFS(tc.whenFile, tc.whenFS) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} diff --git a/context_test.go b/context_test.go index a8b9a9946..dca680f9b 100644 --- a/context_test.go +++ b/context_test.go @@ -8,33 +8,33 @@ import ( "errors" "fmt" "io" + "io/fs" "math" "mime/multipart" "net/http" "net/http/httptest" "net/url" + "os" "strings" "testing" "text/template" "time" - "github.com/labstack/gommon/log" - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) -type ( - Template struct { - templates *template.Template - } -) +type Template struct { + templates *template.Template +} var testUser = user{1, "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -46,9 +46,10 @@ func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -60,9 +61,10 @@ func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocXML(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &noOpLogger{} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) b.ResetTimer() b.ReportAllocs() @@ -73,7 +75,7 @@ func BenchmarkAllocXML(b *testing.B) { } func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { - c := context{request: &http.Request{ + c := DefaultContext{request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }} for i := 0; i < b.N; i++ { @@ -104,18 +106,16 @@ func TestContext(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec).(*DefaultContext) // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Render @@ -126,106 +126,106 @@ func TestContext(t *testing.T) { } c.echo.Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, Jon Snow!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } c.echo.Renderer = nil err = c.Render(http.StatusOK, "hello", "Jon Snow") - assert.Error(err) + assert.Error(t, err) // JSON rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } // JSON with "?pretty" req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodGet, "/", nil) // reset // JSONPretty rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } // JSON (error) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.JSON(http.StatusOK, make(chan bool)) - assert.Error(err) + assert.Error(t, err) // JSONP rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) callback := "callback" err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String()) } // XML rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // XML with "?pretty" req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } req = httptest.NewRequest(http.MethodGet, "/", nil) // XML (error) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XML(http.StatusOK, make(chan bool)) - assert.Error(err) + assert.Error(t, err) // XML response write error - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) c.response.Writer = responseWriterErr{} err = c.XML(0, 0) - testify.Error(t, err) + assert.Error(t, err) // XMLPretty rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } t.Run("empty indent", func(t *testing.T) { @@ -237,166 +237,157 @@ func TestContext(t *testing.T) { t.Run("json", func(t *testing.T) { buf.Reset() - assert := testify.New(t) // New JSONBlob with empty indent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) enc := json.NewEncoder(buf) enc.SetIndent(emptyIndent, emptyIndent) err = enc.Encode(u) err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(buf.String(), rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, buf.String(), rec.Body.String()) } }) t.Run("xml", func(t *testing.T) { buf.Reset() - assert := testify.New(t) // New XMLBlob with empty indent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) enc := xml.NewEncoder(buf) enc.Indent(emptyIndent, emptyIndent) err = enc.Encode(u) err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+buf.String(), rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+buf.String(), rec.Body.String()) } }) }) // Legacy JSONBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) data, err := json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON, rec.Body.String()) } // Legacy JSONPBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) callback = "callback" data, err = json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+");", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) } // Legacy XMLBlob rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) data, err = xml.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // String rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.String(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // HTML rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.HTML(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // Stream rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) r := strings.NewReader("response from a stream") err = c.Stream(http.StatusOK, "application/octet-stream", r) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) - assert.Equal("response from a stream", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "response from a stream", rec.Body.String()) } // Attachment rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) } // Inline rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) } // NoContent rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) c.NoContent(http.StatusOK) - assert.Equal(http.StatusOK, rec.Code) - - // Error - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - c.Error(errors.New("error")) - assert.Equal(http.StatusInternalServerError, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) // Reset - c.SetParamNames("foo") - c.SetParamValues("bar") + c.pathParams = &PathParams{ + {Name: "foo", Value: "bar"}, + } c.Set("foe", "ban") c.query = url.Values(map[string][]string{"fon": {"baz"}}) c.Reset(req, httptest.NewRecorder()) - assert.Equal(0, len(c.ParamValues())) - assert.Equal(0, len(c.ParamNames())) - assert.Equal(0, len(c.store)) - assert.Equal("", c.Path()) - assert.Equal(0, len(c.QueryParams())) + assert.Equal(t, 0, len(c.PathParams())) + assert.Equal(t, 0, len(c.store)) + assert.Equal(t, nil, c.RouteInfo()) + assert.Equal(t, 0, len(c.QueryParams())) } func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) - assert := testify.New(t) - if assert.NoError(err) { - assert.Equal(http.StatusCreated, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } @@ -404,12 +395,11 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) - assert := testify.New(t) - if assert.Error(err) { - assert.False(c.response.Committed) + if assert.Error(t, err) { + assert.False(t, c.response.Committed) } } @@ -421,24 +411,22 @@ func TestContextCookie(t *testing.T) { req.Header.Add(HeaderCookie, theme) req.Header.Add(HeaderCookie, user) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec).(*DefaultContext) // Read single cookie, err := c.Cookie("theme") - if assert.NoError(err) { - assert.Equal("theme", cookie.Name) - assert.Equal("light", cookie.Value) + if assert.NoError(t, err) { + assert.Equal(t, "theme", cookie.Name) + assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { switch cookie.Name { case "theme": - assert.Equal("light", cookie.Value) + assert.Equal(t, "light", cookie.Value) case "user": - assert.Equal("Jon Snow", cookie.Value) + assert.Equal(t, "Jon Snow", cookie.Value) } } @@ -453,104 +441,95 @@ func TestContextCookie(t *testing.T) { HttpOnly: true, } c.SetCookie(cookie) - assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") - assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") - assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") -} - -func TestContextPath(t *testing.T) { - e := New() - r := e.Router() - - handler := func(c Context) error { return c.String(http.StatusOK, "OK") } - - r.Add(http.MethodGet, "/users/:id", handler) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1", c) - - assert := testify.New(t) - - assert.Equal("/users/:id", c.Path()) - - r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal("/users/:uid/files/:fid", c.Path()) + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } func TestContextPathParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) + c := e.NewContext(req, nil).(*DefaultContext) + params := &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + } // ParamNames - c.SetParamNames("uid", "fid") - testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) - - // ParamValues - c.SetParamValues("101", "501") - testify.EqualValues(t, []string{"101", "501"}, c.ParamValues()) + c.pathParams = params + assert.EqualValues(t, *params, c.PathParams()) // Param - testify.Equal(t, "501", c.Param("fid")) - testify.Equal(t, "", c.Param("undefined")) + assert.Equal(t, "501", c.PathParam("fid")) + assert.Equal(t, "", c.PathParam("undefined")) } func TestContextGetAndSetParam(t *testing.T) { e := New() r := e.Router() - r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) + _, err := r.Add(Route{ + Method: http.MethodGet, + Path: "/:foo", + Name: "", + Handler: func(Context) error { return nil }, + Middlewares: nil, + }) + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) - c.SetParamNames("foo") + + params := &PathParams{{Name: "foo", Value: "101"}} + // ParamNames + c.(*DefaultContext).pathParams = params // round-trip param values with modification - paramVals := c.ParamValues() - testify.EqualValues(t, []string{""}, c.ParamValues()) - paramVals[0] = "bar" - c.SetParamValues(paramVals...) - testify.EqualValues(t, []string{"bar"}, c.ParamValues()) + paramVals := c.PathParams() + assert.Equal(t, *params, c.PathParams()) + + paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context + assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams()) + + pathParams := PathParams{ + {Name: "aaa", Value: "bbb"}, + {Name: "ccc", Value: "ddd"}, + } + c.SetPathParams(pathParams) + assert.Equal(t, pathParams, c.PathParams()) // shouldn't explode during Reset() afterwards! - testify.NotPanics(t, func() { - c.Reset(nil, nil) + assert.NotPanics(t, func() { + c.(ServableContext).Reset(nil, nil) }) + assert.Equal(t, PathParams{}, c.PathParams()) + assert.Len(t, *c.(*DefaultContext).pathParams, 0) + assert.Equal(t, cap(*c.(*DefaultContext).pathParams), 1) } // Issue #1655 -func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { - assert := testify.New(t) - +func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) { e := New() - assert.Equal(0, *e.maxParam) - - expectedOneParam := []string{"one"} - expectedTwoParams := []string{"one", "two"} - expectedThreeParams := []string{"one", "two", ""} - expectedABCParams := []string{"A", "B", "C"} + c := e.NewContext(nil, nil).(*DefaultContext) - c := e.NewContext(nil, nil) - c.SetParamNames("1", "2") - c.SetParamValues(expectedTwoParams...) - assert.Equal(2, *e.maxParam) - assert.EqualValues(expectedTwoParams, c.ParamValues()) - - c.SetParamNames("1") - assert.Equal(2, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are - assert.EqualValues(expectedOneParam, c.ParamValues()) - - c.SetParamNames("1", "2", "3") - assert.Equal(3, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam - assert.EqualValues(expectedThreeParams, c.ParamValues()) + assert.Equal(t, 0, e.contextPathParamAllocSize) + expectedTwoParams := &PathParams{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + } + c.SetRawPathParams(expectedTwoParams) + assert.Equal(t, 0, e.contextPathParamAllocSize) + assert.Equal(t, *expectedTwoParams, c.PathParams()) - c.SetParamValues("A", "B", "C", "D") - assert.Equal(3, *e.maxParam) - // Here D shouldn't be returned - assert.EqualValues(expectedABCParams, c.ParamValues()) + expectedThreeParams := PathParams{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + {Name: "3", Value: "three"}, + } + c.SetPathParams(expectedThreeParams) + assert.Equal(t, 0, e.contextPathParamAllocSize) + assert.Equal(t, expectedThreeParams, c.PathParams()) } func TestContextFormValue(t *testing.T) { @@ -564,25 +543,29 @@ func TestContextFormValue(t *testing.T) { c := e.NewContext(req, nil) // FormValue - testify.Equal(t, "Jon Snow", c.FormValue("name")) - testify.Equal(t, "jon@labstack.com", c.FormValue("email")) + assert.Equal(t, "Jon Snow", c.FormValue("name")) + assert.Equal(t, "jon@labstack.com", c.FormValue("email")) + + // FormValueDefault + assert.Equal(t, "Jon Snow", c.FormValueDefault("name", "nope")) + assert.Equal(t, "default", c.FormValueDefault("missing", "default")) - // FormParams - params, err := c.FormParams() - if testify.NoError(t, err) { - testify.Equal(t, url.Values{ + // FormValues + values, err := c.FormValues() + if assert.NoError(t, err) { + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, - }, params) + }, values) } // Multipart FormParams error req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) - params, err = c.FormParams() - testify.Nil(t, params) - testify.Error(t, err) + values, err = c.FormValues() + assert.Nil(t, values) + assert.Error(t, err) } func TestContextQueryParam(t *testing.T) { @@ -594,11 +577,15 @@ func TestContextQueryParam(t *testing.T) { c := e.NewContext(req, nil) // QueryParam - testify.Equal(t, "Jon Snow", c.QueryParam("name")) - testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) + assert.Equal(t, "Jon Snow", c.QueryParam("name")) + assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) + + // QueryParamDefault + assert.Equal(t, "Jon Snow", c.QueryParamDefault("name", "nope")) + assert.Equal(t, "default", c.QueryParamDefault("missing", "default")) // QueryParams - testify.Equal(t, url.Values{ + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, c.QueryParams()) @@ -609,7 +596,7 @@ func TestContextFormFile(t *testing.T) { buf := new(bytes.Buffer) mr := multipart.NewWriter(buf) w, err := mr.CreateFormFile("file", "test") - if testify.NoError(t, err) { + if assert.NoError(t, err) { w.Write([]byte("test")) } mr.Close() @@ -618,8 +605,8 @@ func TestContextFormFile(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") - if testify.NoError(t, err) { - testify.Equal(t, "test", f.Filename) + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) } } @@ -634,8 +621,8 @@ func TestContextMultipartForm(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() - if testify.NoError(t, err) { - testify.NotNil(t, f) + if assert.NoError(t, err) { + assert.NotNil(t, f) } } @@ -644,22 +631,22 @@ func TestContextRedirect(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) - testify.Equal(t, http.StatusMovedPermanently, rec.Code) - testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) - testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) + assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) + assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } func TestContextStore(t *testing.T) { - var c Context = new(context) + var c Context = new(DefaultContext) c.Set("name", "Jon Snow") - testify.Equal(t, "Jon Snow", c.Get("name")) + assert.Equal(t, "Jon Snow", c.Get("name")) } func BenchmarkContext_Store(b *testing.B) { e := &Echo{} - c := &context{ + c := &DefaultContext{ echo: e, } @@ -671,42 +658,6 @@ func BenchmarkContext_Store(b *testing.B) { } } -func TestContextHandler(t *testing.T) { - e := New() - r := e.Router() - b := new(bytes.Buffer) - - r.Add(http.MethodGet, "/handler", func(Context) error { - _, err := b.Write([]byte("handler")) - return err - }) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/handler", c) - err := c.Handler()(c) - testify.Equal(t, "handler", b.String()) - testify.NoError(t, err) -} - -func TestContext_SetHandler(t *testing.T) { - var c Context = new(context) - - testify.Nil(t, c.Handler()) - - c.SetHandler(func(c Context) error { - return nil - }) - testify.NotNil(t, c.Handler()) -} - -func TestContext_Path(t *testing.T) { - path := "/pa/th" - - var c Context = new(context) - - c.SetPath(path) - testify.Equal(t, path, c.Path()) -} - type validator struct{} func (*validator) Validate(i interface{}) error { @@ -717,10 +668,10 @@ func TestContext_Validate(t *testing.T) { e := New() c := e.NewContext(nil, nil) - testify.Error(t, c.Validate(struct{}{})) + assert.Error(t, c.Validate(struct{}{})) e.Validator = &validator{} - testify.NoError(t, c.Validate(struct{}{})) + assert.NoError(t, c.Validate(struct{}{})) } func TestContext_QueryString(t *testing.T) { @@ -728,21 +679,21 @@ func TestContext_QueryString(t *testing.T) { queryString := "query=string&var=val" - req := httptest.NewRequest(GET, "/?"+queryString, nil) + req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) - testify.Equal(t, queryString, c.QueryString()) + assert.Equal(t, queryString, c.QueryString()) } func TestContext_Request(t *testing.T) { - var c Context = new(context) + var c Context = new(DefaultContext) - testify.Nil(t, c.Request()) + assert.Nil(t, c.Request()) - req := httptest.NewRequest(GET, "/path", nil) + req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) - testify.Equal(t, req, c.Request()) + assert.Equal(t, req, c.Request()) } func TestContext_Scheme(t *testing.T) { @@ -751,7 +702,7 @@ func TestContext_Scheme(t *testing.T) { s string }{ { - &context{ + &DefaultContext{ request: &http.Request{ TLS: &tls.ConnectionState{}, }, @@ -759,7 +710,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedProto: []string{"https"}}, }, @@ -767,7 +718,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedProtocol: []string{"http"}}, }, @@ -775,7 +726,7 @@ func TestContext_Scheme(t *testing.T) { "http", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedSsl: []string{"on"}}, }, @@ -783,7 +734,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXUrlScheme: []string{"https"}}, }, @@ -791,7 +742,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &DefaultContext{ request: &http.Request{}, }, "http", @@ -799,44 +750,44 @@ func TestContext_Scheme(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.Scheme()) + assert.Equal(t, tt.s, tt.c.Scheme()) } } func TestContext_IsWebSocket(t *testing.T) { tests := []struct { c Context - ws testify.BoolAssertionFunc + ws assert.BoolAssertionFunc }{ { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"websocket"}}, }, }, - testify.True, + assert.True, }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, }, }, - testify.True, + assert.True, }, { - &context{ + &DefaultContext{ request: &http.Request{}, }, - testify.False, + assert.False, }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, - testify.False, + assert.False, }, } @@ -849,30 +800,14 @@ func TestContext_IsWebSocket(t *testing.T) { func TestContext_Bind(t *testing.T) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, nil) u := new(user) req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) - testify.NoError(t, err) - testify.Equal(t, &user{1, "Jon Snow"}, u) -} - -func TestContext_Logger(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - - log1 := c.Logger() - testify.NotNil(t, log1) - - log2 := log.New("echo2") - c.SetLogger(log2) - testify.Equal(t, log2, c.Logger()) - - // Resetting the context returns the initial logger - c.Reset(nil, nil) - testify.Equal(t, log1, c.Logger()) + assert.NoError(t, err) + assert.Equal(t, &user{1, "Jon Snow"}, u) } func TestContext_RealIP(t *testing.T) { @@ -881,7 +816,7 @@ func TestContext_RealIP(t *testing.T) { s string }{ { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }, @@ -889,7 +824,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}}, }, @@ -897,7 +832,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, }, @@ -905,7 +840,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ Header: http.Header{ "X-Real-Ip": []string{"192.168.0.1"}, @@ -915,7 +850,7 @@ func TestContext_RealIP(t *testing.T) { "192.168.0.1", }, { - &context{ + &DefaultContext{ request: &http.Request{ RemoteAddr: "89.89.89.89:1654", }, @@ -925,6 +860,128 @@ func TestContext_RealIP(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.RealIP()) + assert.Equal(t, tt.s, tt.c.RealIP()) + } +} + +func TestContext_File(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok, from default file system", + whenFile: "_fixture/images/walle.png", + whenFS: nil, + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "ok, from custom file system", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + if tc.whenFS != nil { + e.Filesystem = tc.whenFS + } + + handler := func(ec Context) error { + return ec.(*DefaultContext).File(tc.whenFile) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestContext_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + handler := func(ec Context) error { + return ec.(*DefaultContext).FileFS(tc.whenFile, tc.whenFS) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) } } diff --git a/echo.go b/echo.go index 8829619c7..b0dad550e 100644 --- a/echo.go +++ b/echo.go @@ -5,12 +5,12 @@ Example: package main - import ( - "net/http" - - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - ) + import ( + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" + "log" + "net/http" + ) // Handler func hello(c echo.Context) error { @@ -29,7 +29,9 @@ Example: e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":8080"); err != http.ErrServerClosed { + log.Fatal(err) + } } Learn more at https://echo.labstack.com @@ -37,124 +39,89 @@ Learn more at https://echo.labstack.com package echo import ( - "bytes" stdContext "context" - "crypto/tls" "errors" "fmt" "io" - "io/ioutil" - stdLog "log" - "net" + "io/fs" "net/http" - "reflect" - "runtime" + "net/url" + "os" + "os/signal" + "path/filepath" + "strings" "sync" - "time" - - "github.com/labstack/gommon/color" - "github.com/labstack/gommon/log" - "golang.org/x/crypto/acme" - "golang.org/x/crypto/acme/autocert" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) -type ( - // Echo is the top-level framework instance. - Echo struct { - filesystem - common - // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get - // listener address info (on which interface/port was listener binded) without having data races. - startupMutex sync.RWMutex - StdLogger *stdLog.Logger - colorer *color.Color - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - pool sync.Pool - Server *http.Server - TLSServer *http.Server - Listener net.Listener - TLSListener net.Listener - AutoTLSManager autocert.Manager - DisableHTTP2 bool - Debug bool - HideBanner bool - HidePort bool - HTTPErrorHandler HTTPErrorHandler - Binder Binder - JSONSerializer JSONSerializer - Validator Validator - Renderer Renderer - Logger Logger - IPExtractor IPExtractor - ListenerNetwork string - } +// Echo is the top-level framework instance. +// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. +type Echo struct { + // premiddleware are middlewares that are run for every request before routing is done + premiddleware []MiddlewareFunc + // middleware are middlewares that are run after router found a matching route (not found and method not found are also matches) + middleware []MiddlewareFunc - // Route contains a handler and information for matching against requests. - Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` - } + router Router + routers map[string]Router + routerCreator func(e *Echo) Router - // HTTPError represents an error that occurred while handling a request. - HTTPError struct { - Code int `json:"-"` - Message interface{} `json:"message"` - Internal error `json:"-"` // Stores the error returned by an external dependency - } + contextPool sync.Pool + // contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info for context + // creation time so we can allocate path parameter values slice. + contextPathParamAllocSize int - // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(next HandlerFunc) HandlerFunc + // NewContextFunc allows using custom context implementations, instead of default *echo.context + NewContextFunc func(e *Echo, pathParamAllocSize int) ServableContext + Debug bool + HTTPErrorHandler HTTPErrorHandler + Binder Binder + JSONSerializer JSONSerializer + Validator Validator + Renderer Renderer + Logger Logger + IPExtractor IPExtractor - // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(c Context) error + // Filesystem is file system used by Static and File handlers to access files. + // Defaults to os.DirFS(".") + // + // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary + // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths + // including `assets/images` as their prefix. + Filesystem fs.FS +} - // HTTPErrorHandler is a centralized HTTP error handler. - HTTPErrorHandler func(error, Context) +// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. +type JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error +} - // Validator is the interface that wraps the Validate function. - Validator interface { - Validate(i interface{}) error - } +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(c Context, err error) - // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. - JSONSerializer interface { - Serialize(c Context, i interface{}, indent string) error - Deserialize(c Context, i interface{}) error - } +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c Context) error - // Renderer is the interface that wraps the Render function. - Renderer interface { - Render(io.Writer, string, interface{}, Context) error - } +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc - // Map defines a generic map of type `map[string]interface{}`. - Map map[string]interface{} +// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking. +type MiddlewareConfigurator interface { + ToMiddleware() (MiddlewareFunc, error) +} - // Common struct for Echo & Group. - common struct{} -) +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i interface{}) error +} -// HTTP methods -// NOTE: Deprecated, please use the stdlib constants directly instead. -const ( - CONNECT = http.MethodConnect - DELETE = http.MethodDelete - GET = http.MethodGet - HEAD = http.MethodHead - OPTIONS = http.MethodOptions - PATCH = http.MethodPatch - POST = http.MethodPost - // PROPFIND = "PROPFIND" - PUT = http.MethodPut - TRACE = http.MethodTrace -) +// Renderer is the interface that wraps the Render function. +type Renderer interface { + Render(io.Writer, string, interface{}, Context) error +} + +// Map defines a generic map of type `map[string]interface{}`. +type Map map[string]interface{} // MIME types const ( @@ -246,297 +213,360 @@ const ( const ( // Version of Echo - Version = "4.7.2" - website = "https://echo.labstack.com" - // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo - banner = ` - ____ __ - / __/___/ / ___ - / _// __/ _ \/ _ \ -/___/\__/_//_/\___/ %s -High performance, minimalist Go web framework -%s -____________________________________O/_______ - O\ -` -) - -var ( - methods = [...]string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - PROPFIND, - http.MethodPut, - http.MethodTrace, - REPORT, - } -) - -// Errors -var ( - ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) - ErrNotFound = NewHTTPError(http.StatusNotFound) - ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) - ErrForbidden = NewHTTPError(http.StatusForbidden) - ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) - ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) - ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) - ErrBadRequest = NewHTTPError(http.StatusBadRequest) - ErrBadGateway = NewHTTPError(http.StatusBadGateway) - ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) - ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) - ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") - ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") - ErrInvalidListenerNetwork = errors.New("invalid listener network") + Version = "5.0.0-alpha" ) -// Error handlers -var ( - NotFoundHandler = func(c Context) error { - return ErrNotFound - } - - MethodNotAllowedHandler = func(c Context) error { - // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) - // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned - routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) - if ok && routerAllowMethods != "" { - c.Response().Header().Set(HeaderAllow, routerAllowMethods) - } - return ErrMethodNotAllowed - } -) +var methods = [...]string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + PROPFIND, + http.MethodPut, + http.MethodTrace, + REPORT, +} // New creates an instance of Echo. -func New() (e *Echo) { - e = &Echo{ - filesystem: createFilesystem(), - Server: new(http.Server), - TLSServer: new(http.Server), - AutoTLSManager: autocert.Manager{ - Prompt: autocert.AcceptTOS, +func New() *Echo { + logger := newJSONLogger(os.Stdout) + e := &Echo{ + Logger: logger, + Filesystem: newDefaultFS(), + Binder: &DefaultBinder{}, + JSONSerializer: &DefaultJSONSerializer{}, + + routers: make(map[string]Router), + routerCreator: func(ec *Echo) Router { + return NewRouter(RouterConfig{}) }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), - ListenerNetwork: "tcp", } - e.Server.Handler = e - e.TLSServer.Handler = e - e.HTTPErrorHandler = e.DefaultHTTPErrorHandler - e.Binder = &DefaultBinder{} - e.JSONSerializer = &DefaultJSONSerializer{} - e.Logger.SetLevel(log.ERROR) - e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) - e.pool.New = func() interface{} { + + e.router = NewRouter(RouterConfig{}) + e.HTTPErrorHandler = DefaultHTTPErrorHandler(false) + e.contextPool.New = func() interface{} { return e.NewContext(nil, nil) } - e.router = NewRouter(e) - e.routers = map[string]*Router{} - return + return e } -// NewContext returns a Context instance. +// NewContext returns a new Context instance. +// +// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway +// these arguments are useful when creating context for tests and cases like that. func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { - return &context{ - request: r, - response: NewResponse(w, e), - store: make(Map), - echo: e, - pvalues: make([]string, *e.maxParam), - handler: NotFoundHandler, + var c Context + if e.NewContextFunc != nil { + c = e.NewContextFunc(e, e.contextPathParamAllocSize) + } else { + c = NewDefaultContext(e, e.contextPathParamAllocSize) } + c.SetRequest(r) + c.SetResponse(NewResponse(w, e)) + return c } // Router returns the default router. -func (e *Echo) Router() *Router { +func (e *Echo) Router() Router { return e.router } // Routers returns the map of host => router. -func (e *Echo) Routers() map[string]*Router { +func (e *Echo) Routers() map[string]Router { return e.routers } -// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response -// with status code. +// RouterFor returns Router for given host. +func (e *Echo) RouterFor(host string) Router { + return e.routers[host] +} + +// ResetRouterCreator resets callback for creating new router instances. +// Note: current (default) router is immediately replaced with router created with creator func and vhost routers are cleared. +func (e *Echo) ResetRouterCreator(creator func(e *Echo) Router) { + e.routerCreator = creator + e.router = creator(e) + e.routers = make(map[string]Router) +} + +// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response +// with status code. `exposeError` parameter decides if returned message will contain also error message or not // -// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). +// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately) +// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from // handler. Then the error that global error handler received will be ignored because we have already "commited" the // response and status code header has been sent to the client. -func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { - - if c.Response().Committed { - return - } - - he, ok := err.(*HTTPError) - if ok { - if he.Internal != nil { - if herr, ok := he.Internal.(*HTTPError); ok { - he = herr - } +func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { + return func(c Context, err error) { + if c.Response().Committed { + return } - } else { - he = &HTTPError{ + + he := &HTTPError{ Code: http.StatusInternalServerError, Message: http.StatusText(http.StatusInternalServerError), } - } + if errors.As(err, &he) { + if he.Internal != nil { // max 2 levels of checks even if internal could have also internal + errors.As(he.Internal, &he) + } + } - // Issue #1426 - code := he.Code - message := he.Message - if m, ok := he.Message.(string); ok { - if e.Debug { - message = Map{"message": m, "error": err.Error()} - } else { - message = Map{"message": m} + // Issue #1426 + code := he.Code + message := he.Message + if m, ok := he.Message.(string); ok { + if exposeError { + message = Map{"message": m, "error": err.Error()} + } else { + message = Map{"message": m} + } } - } - // Send response - if c.Request().Method == http.MethodHead { // Issue #608 - err = c.NoContent(he.Code) - } else { - err = c.JSON(code, message) - } - if err != nil { - e.Logger.Error(err) + // Send response + var cErr error + if c.Request().Method == http.MethodHead { // Issue #608 + cErr = c.NoContent(he.Code) + } else { + cErr = c.JSON(code, message) + } + if cErr != nil { + c.Echo().Logger.Error(err) // truly rare case. ala client already disconnected + } } } -// Pre adds middleware to the chain which is run before router. +// Pre adds middleware to the chain which is run before router tries to find matching route. +// Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { e.premiddleware = append(e.premiddleware, middleware...) } -// Use adds middleware to the chain which is run after router. +// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed. func (e *Echo) Use(middleware ...MiddlewareFunc) { e.middleware = append(e.middleware, middleware...) } // CONNECT registers a new CONNECT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodConnect, path, h, m...) } // DELETE registers a new DELETE route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodDelete, path, h, m...) } // GET registers a new GET route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodGet, path, h, m...) } // HEAD registers a new HEAD route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodHead, path, h, m...) } // OPTIONS registers a new OPTIONS route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodOptions, path, h, m...) } // PATCH registers a new PATCH route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPatch, path, h, m...) } // POST registers a new POST route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPost, path, h, m...) } // PUT registers a new PUT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPut, path, h, m...) } // TRACE registers a new TRACE route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodTrace, path, h, m...) } -// Any registers a new route for all HTTP methods and path with matching handler -// in the router with optional route-level middleware. -func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// Any registers a new route for all supported HTTP methods and path with matching handler +// in the router with optional route-level middleware. Panics on error. +func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } // Match registers a new route for multiple HTTP methods and path with matching -// handler in the router with optional route-level middleware. -func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// handler in the router with optional route-level middleware. Panics on error. +func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) + } + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris +} + +// Static registers a new route with path prefix to serve static files from the provided root directory. +func (e *Echo) Static(pathPrefix, fsRoot string) RouteInfo { + subFs := MustSubFS(e.Filesystem, fsRoot) + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(subFs, false), + ) +} + +// StaticFS registers a new route with path prefix to serve static files from the provided file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) RouteInfo { + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// StaticDirectoryHandler creates handler function to serve files from provided file system +// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. +func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { + return func(c Context) error { + p := c.PathParam("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath + } + + // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid + name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) + fi, err := fs.Stat(fileSystem, name) + if err != nil { + return ErrNotFound + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, p+"/") + } + return fsFile(c, name, fileSystem) + } +} + +// FileFS registers a new route with path to serve file from the provided file system. +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return e.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// StaticFileHandler creates handler function to serve file from provided file system +func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { + return func(c Context) error { + return fsFile(c, file, filesystem) } - return routes } -func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route, - m ...MiddlewareFunc) *Route { - return get(path, func(c Context) error { +// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error. +func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c Context) error { return c.File(file) - }, m...) + } + return e.Add(http.MethodGet, path, handler, middleware...) } -// File registers a new route with path to serve a static file with optional route-level middleware. -func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { - return e.file(path, file, e.GET, m...) +// AddRoute registers a new Route with default host Router +func (e *Echo) AddRoute(route Routable) (RouteInfo, error) { + return e.add("", route) } -func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - name := handlerName(handler) +func (e *Echo) add(host string, route Routable) (RouteInfo, error) { router := e.findRouter(host) - router.Add(method, path, func(c Context) error { - h := applyMiddleware(handler, middleware...) - return h(c) - }) - r := &Route{ - Method: method, - Path: path, - Name: name, + ri, err := router.Add(route) + if err != nil { + return nil, err + } + + paramsCount := len(ri.Params()) + if paramsCount > e.contextPathParamAllocSize { + e.contextPathParamAllocSize = paramsCount } - e.router.routes[method+path] = r - return r + return ri, nil } // Add registers a new route for an HTTP method and path with matching handler // in the router with optional route-level middleware. -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - return e.add("", method, path, handler, middleware...) +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := e.add( + "", + Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + Name: "", + }, + ) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri } // Host creates a new router group for the provided host and optional host-level middleware. func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) { - e.routers[name] = NewRouter(e) + e.routers[name] = e.routerCreator(e) g = &Group{host: name, echo: e} g.Use(m...) return @@ -549,328 +579,82 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { return } -// URI generates a URI from handler. -func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string { - name := handlerName(handler) - return e.Reverse(name, params...) -} - -// URL is an alias for `URI` function. -func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { - return e.URI(h, params...) -} - -// Reverse generates an URL from route name and provided parameters. -func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() -} - -// Routes returns the registered routes. -func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } - return routes -} - // AcquireContext returns an empty `Context` instance from the pool. // You must return the context by calling `ReleaseContext()`. func (e *Echo) AcquireContext() Context { - return e.pool.Get().(Context) + return e.contextPool.Get().(Context) } // ReleaseContext returns the `Context` instance back to the pool. // You must call it after `AcquireContext()`. func (e *Echo) ReleaseContext(c Context) { - e.pool.Put(c) + e.contextPool.Put(c) } // ServeHTTP implements `http.Handler` interface, which serves HTTP requests. func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Acquire context - c := e.pool.Get().(*context) + var c ServableContext + if e.NewContextFunc != nil { + // NOTE: we are not casting always context to RoutableContext because casting to interface vs pointer to struct is + // "significantly" slower. Echo Context interface has way to many methods so these checks take time. + // These are benchmarks with 1.16: + // * interface extending another interface = +24% slower (3233 ns/op vs 2605 ns/op) + // * interface (not extending any, just methods)= +14% slower + // + // Quote from https://stackoverflow.com/a/31584377 + // "it's even worse with interface-to-interface assertion, because you also need to ensure that the type implements the interface." + // + // So most of the time we do not need custom context type and simple IF + cast to pointer to struct is fast enough. + c = e.contextPool.Get().(ServableContext) + } else { + c = e.contextPool.Get().(*DefaultContext) + } c.Reset(r, w) var h func(Context) error if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h = c.Handler() - h = applyMiddleware(h, e.middleware...) + h = applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...) } else { - h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h := c.Handler() - h = applyMiddleware(h, e.middleware...) - return h(c) + h = func(cc Context) error { + // NOTE: router will be executed after pre middlewares have been run. We assume here that context we receive after pre middlewares + // is the same we began with. If not - this is use-case we do not support and is probably abuse from developer. + h1 := applyMiddleware(e.findRouter(r.Host).Route(c), e.middleware...) + return h1(cc) } h = applyMiddleware(h, e.premiddleware...) } // Execute chain if err := h(c); err != nil { - e.HTTPErrorHandler(err, c) + e.HTTPErrorHandler(c, err) } - // Release context - e.pool.Put(c) + e.contextPool.Put(c) } -// Start starts an HTTP server. +// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by +// sending os.Interrupt signal with `ctrl+c`. +// +// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration +// options. +// +// In need of customization use: +// sc := echo.StartConfig{Address: ":8080"} +// if err := sc.Start(e); err != http.ErrServerClosed { +// log.Fatal(err) +// } +// // or standard library `http.Server` +// s := http.Server{Addr: ":8080", Handler: e} +// if err := s.ListenAndServe(); err != http.ErrServerClosed { +// log.Fatal(err) +// } func (e *Echo) Start(address string) error { - e.startupMutex.Lock() - e.Server.Addr = address - if err := e.configureServer(e.Server); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return e.Server.Serve(e.Listener) -} - -// StartTLS starts an HTTPS server. -// If `certFile` or `keyFile` is `string` the values are treated as file paths. -// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. -func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { - e.startupMutex.Lock() - var cert []byte - if cert, err = filepathOrContent(certFile); err != nil { - e.startupMutex.Unlock() - return - } + sc := StartConfig{Address: address} + ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt) // start shutdown process on ctrl+c + defer cancel() + sc.GracefulContext = ctx - var key []byte - if key, err = filepathOrContent(keyFile); err != nil { - e.startupMutex.Unlock() - return - } - - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.Certificates = make([]tls.Certificate, 1) - if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { - e.startupMutex.Unlock() - return - } - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { - switch v := fileOrContent.(type) { - case string: - return ioutil.ReadFile(v) - case []byte: - return v, nil - default: - return nil, ErrInvalidCertOrKeyType - } -} - -// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. -func (e *Echo) StartAutoTLS(address string) error { - e.startupMutex.Lock() - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func (e *Echo) configureTLS(address string) { - s := e.TLSServer - s.Addr = address - if !e.DisableHTTP2 { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") - } -} - -// StartServer starts a custom http server. -func (e *Echo) StartServer(s *http.Server) (err error) { - e.startupMutex.Lock() - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - if s.TLSConfig != nil { - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -func (e *Echo) configureServer(s *http.Server) error { - // Setup - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = e - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if s.TLSConfig == nil { - if e.Listener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - e.Listener = l - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - return nil - } - if e.TLSListener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - e.TLSListener = tls.NewListener(l, s.TLSConfig) - } - if !e.HidePort { - e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) - } - return nil -} - -// ListenerAddr returns net.Addr for Listener -func (e *Echo) ListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.Listener == nil { - return nil - } - return e.Listener.Addr() -} - -// TLSListenerAddr returns net.Addr for TLSListener -func (e *Echo) TLSListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.TLSListener == nil { - return nil - } - return e.TLSListener.Addr() -} - -// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). -func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error { - e.startupMutex.Lock() - // Setup - s := e.Server - s.Addr = address - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = h2c.NewHandler(e, h2s) - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if e.Listener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - e.startupMutex.Unlock() - return err - } - e.Listener = l - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -// Close immediately stops the server. -// It internally calls `http.Server#Close()`. -func (e *Echo) Close() error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Close(); err != nil { - return err - } - return e.Server.Close() -} - -// Shutdown stops the server gracefully. -// It internally calls `http.Server#Shutdown()`. -func (e *Echo) Shutdown(ctx stdContext.Context) error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Shutdown(ctx); err != nil { - return err - } - return e.Server.Shutdown(ctx) -} - -// NewHTTPError creates a new HTTPError instance. -func NewHTTPError(code int, message ...interface{}) *HTTPError { - he := &HTTPError{Code: code, Message: http.StatusText(code)} - if len(message) > 0 { - he.Message = message[0] - } - return he -} - -// Error makes it compatible with `error` interface. -func (he *HTTPError) Error() string { - if he.Internal == nil { - return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) - } - return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) -} - -// SetInternal sets error to HTTPError.Internal -func (he *HTTPError) SetInternal(err error) *HTTPError { - he.Internal = err - return he -} - -// Unwrap satisfies the Go 1.13 error wrapper interface. -func (he *HTTPError) Unwrap() error { - return he.Internal + return sc.Start(e) } // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. @@ -895,19 +679,7 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } -// GetPath returns RawPath, if it's empty returns Path from URL -// Difference between RawPath and Path is: -// * Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. -// * RawPath is an optional field which only gets set if the default encoding is different from Path. -func GetPath(r *http.Request) string { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path - } - return path -} - -func (e *Echo) findRouter(host string) *Router { +func (e *Echo) findRouter(host string) Router { if len(e.routers) > 0 { if r, ok := e.routers[host]; ok { return r @@ -916,53 +688,59 @@ func (e *Echo) findRouter(host string) *Router { return e.router } -func handlerName(h HandlerFunc) string { - t := reflect.ValueOf(h).Type() - if t.Kind() == reflect.Func { - return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() +func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { + for i := len(middleware) - 1; i >= 0; i-- { + h = middleware[i](h) } - return t.String() + return h } -// // PathUnescape is wraps `url.PathUnescape` -// func PathUnescape(s string) (string, error) { -// return url.PathUnescape(s) -// } - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener +// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` +// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` +// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break +// all old applications that rely on being able to traverse up from current executable run path. +// NB: private because you really should use fs.FS implementation instances +type defaultFS struct { + prefix string + fs fs.FS } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - if c, err = ln.AcceptTCP(); err != nil { - return - } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil { - return +func newDefaultFS() *defaultFS { + dir, _ := os.Getwd() + return &defaultFS{ + prefix: dir, + fs: os.DirFS(dir), } - // Ignore error from setting the KeepAlivePeriod as some systems, such as - // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP - _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute) - return } -func newListener(address, network string) (*tcpKeepAliveListener, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, ErrInvalidListenerNetwork - } - l, err := net.Listen(network, address) - if err != nil { - return nil, err +func (fs defaultFS) Open(name string) (fs.File, error) { + return fs.fs.Open(name) +} + +func subFS(currentFs fs.FS, root string) (fs.FS, error) { + root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows + if dFS, ok := currentFs.(*defaultFS); ok { + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to + // allow cases when root is given as `../somepath` which is not valid for fs.FS + root = filepath.Join(dFS.prefix, root) + return &defaultFS{ + prefix: root, + fs: os.DirFS(root), + }, nil } - return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil + return fs.Sub(currentFs, root) } -func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) +// MustSubFS creates sub FS from current filesystem or panic on failure. +// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. +// +// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with +// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to +// create sub fs which uses necessary prefix for directory path. +func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { + subFs, err := subFS(currentFs, fsRoot) + if err != nil { + panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) } - return h + return subFs } diff --git a/echo_fs.go b/echo_fs.go deleted file mode 100644 index c3790545a..000000000 --- a/echo_fs.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -import ( - "net/http" - "net/url" - "os" - "path/filepath" -) - -type filesystem struct { -} - -func createFilesystem() filesystem { - return filesystem{} -} - -// Static registers a new route with path prefix to serve static files from the -// provided root directory. -func (e *Echo) Static(prefix, root string) *Route { - if root == "" { - root = "." // For security we want to restrict to CWD. - } - return e.static(prefix, root, e.GET) -} - -func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route { - h := func(c Context) error { - p, err := url.PathUnescape(c.Param("*")) - if err != nil { - return err - } - - name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security - fi, err := os.Stat(name) - if err != nil { - // The access path does not exist - return NotFoundHandler(c) - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") - } - return c.File(name) - } - // Handle added routes based on trailing slash: - // /prefix => exact route "/prefix" + any route "/prefix/*" - // /prefix/ => only any route "/prefix/*" - if prefix != "" { - if prefix[len(prefix)-1] == '/' { - // Only add any route for intentional trailing slash - return get(prefix+"*", h) - } - get(prefix, h) - } - return get(prefix+"/*", h) -} diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go deleted file mode 100644 index eb17768ab..000000000 --- a/echo_fs_go1.16.go +++ /dev/null @@ -1,169 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "fmt" - "io/fs" - "net/http" - "net/url" - "os" - "path/filepath" - "runtime" - "strings" -) - -type filesystem struct { - // Filesystem is file system used by Static and File handlers to access files. - // Defaults to os.DirFS(".") - // - // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary - // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths - // including `assets/images` as their prefix. - Filesystem fs.FS -} - -func createFilesystem() filesystem { - return filesystem{ - Filesystem: newDefaultFS(), - } -} - -// Static registers a new route with path prefix to serve static files from the provided root directory. -func (e *Echo) Static(pathPrefix, fsRoot string) *Route { - subFs := MustSubFS(e.Filesystem, fsRoot) - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(subFs, false), - ) -} - -// StaticFS registers a new route with path prefix to serve static files from the provided file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// StaticDirectoryHandler creates handler function to serve files from provided file system -// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. -func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { - return func(c Context) error { - p := c.Param("*") - if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice - tmpPath, err := url.PathUnescape(p) - if err != nil { - return fmt.Errorf("failed to unescape path variable: %w", err) - } - p = tmpPath - } - - // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid - name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) - fi, err := fs.Stat(fileSystem, name) - if err != nil { - return ErrNotFound - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") - } - return fsFile(c, name, fileSystem) - } -} - -// FileFS registers a new route with path to serve file from the provided file system. -func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return e.GET(path, StaticFileHandler(file, filesystem), m...) -} - -// StaticFileHandler creates handler function to serve file from provided file system -func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { - return func(c Context) error { - return fsFile(c, file, filesystem) - } -} - -// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. -// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. -// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` -// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not -// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to -// traverse up from current executable run path. -// NB: private because you really should use fs.FS implementation instances -type defaultFS struct { - prefix string - fs fs.FS -} - -func newDefaultFS() *defaultFS { - dir, _ := os.Getwd() - return &defaultFS{ - prefix: dir, - fs: nil, - } -} - -func (fs defaultFS) Open(name string) (fs.File, error) { - if fs.fs == nil { - return os.Open(name) - } - return fs.fs.Open(name) -} - -func subFS(currentFs fs.FS, root string) (fs.FS, error) { - root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows - if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. - // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we - // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs - if isRelativePath(root) { - root = filepath.Join(dFS.prefix, root) - } - return &defaultFS{ - prefix: root, - fs: os.DirFS(root), - }, nil - } - return fs.Sub(currentFs, root) -} - -func isRelativePath(path string) bool { - if path == "" { - return true - } - if path[0] == '/' { - return false - } - if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { - // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names - // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats - return false - } - return true -} - -// MustSubFS creates sub FS from current filesystem or panic on failure. -// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. -// -// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with -// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to -// create sub fs which uses necessary prefix for directory path. -func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { - subFs, err := subFS(currentFs, fsRoot) - if err != nil { - panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) - } - return subFs -} diff --git a/echo_fs_go1.16_test.go b/echo_fs_go1.16_test.go deleted file mode 100644 index 07e516555..000000000 --- a/echo_fs_go1.16_test.go +++ /dev/null @@ -1,265 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" -) - -func TestEcho_StaticFS(t *testing.T) { - var testCases = []struct { - name string - givenPrefix string - givenFs fs.FS - givenFsRoot string - whenURL string - expectStatus int - expectHeaderLocation string - expectBodyStartsWith string - }{ - { - name: "ok", - givenPrefix: "/images", - givenFs: os.DirFS("./_fixture/images"), - whenURL: "/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "ok, from sub fs", - givenPrefix: "/images", - givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), - whenURL: "/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "No file", - givenPrefix: "/images", - givenFs: os.DirFS("_fixture/scripts"), - whenURL: "/images/bolt.png", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory", - givenPrefix: "/images", - givenFs: os.DirFS("_fixture/images"), - whenURL: "/images/", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory Redirect", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: "/folder", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory Redirect with non-root path", - givenPrefix: "/static", - givenFs: os.DirFS("_fixture"), - whenURL: "/static", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/static/", - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory 404 (request URL without slash)", - givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenFs: os.DirFS("_fixture"), - whenURL: "/folder", // no trailing slash - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Prefixed directory redirect (without slash redirect to slash)", - givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenFs: os.DirFS("_fixture"), - whenURL: "/folder", // no trailing slash - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory with index.html", - givenPrefix: "/", - givenFs: os.DirFS("_fixture"), - whenURL: "/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending with slash)", - givenPrefix: "/assets/", - givenFs: os.DirFS("_fixture"), - whenURL: "/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending without slash)", - givenPrefix: "/assets", - givenFs: os.DirFS("_fixture"), - whenURL: "/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Sub-directory with index.html", - givenPrefix: "/", - givenFs: os.DirFS("_fixture"), - whenURL: "/folder/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "do not allow directory traversal (backslash - windows separator)", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: `/..\\middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "do not allow directory traversal (slash - unix separator)", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: `/../middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - tmpFs := tc.givenFs - if tc.givenFsRoot != "" { - tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) - } - e.StaticFS(tc.givenPrefix, tmpFs) - - req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectStatus, rec.Code) - body := rec.Body.String() - if tc.expectBodyStartsWith != "" { - assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) - } else { - assert.Equal(t, "", body) - } - - if tc.expectHeaderLocation != "" { - assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) - } else { - _, ok := rec.Result().Header["Location"] - assert.False(t, ok) - } - }) - } -} - -func TestEcho_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenPath string - whenFile string - whenFS fs.FS - givenURL string - expectCode int - expectStartsWith []byte - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestEcho_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - expectError string - }{ - { - name: "panics for ../", - givenRoot: "../assets", - expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name", - }, - { - name: "panics for /", - givenRoot: "/assets", - expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - assert.PanicsWithError(t, tc.expectError, func() { - e.Static("../assets", tc.givenRoot) - }) - }) - } -} diff --git a/echo_test.go b/echo_test.go index 0e1e42be0..aac92b924 100644 --- a/echo_test.go +++ b/echo_test.go @@ -3,31 +3,25 @@ package echo import ( "bytes" stdContext "context" - "crypto/tls" "errors" "fmt" - "io/ioutil" + "io/fs" "net" "net/http" "net/http/httptest" "net/url" "os" - "reflect" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/http2" ) -type ( - user struct { - ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` - Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` - } -) +type user struct { + ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` + Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` +} const ( userJSON = `{"id":1,"name":"Jon Snow"}` @@ -61,16 +55,17 @@ func TestEcho(t *testing.T) { // Router assert.NotNil(t, e.Router()) - // DefaultHTTPErrorHandler - e.DefaultHTTPErrorHandler(errors.New("error"), c) + e.HTTPErrorHandler(c, errors.New("error")) + assert.Equal(t, http.StatusInternalServerError, rec.Code) } -func TestEchoStatic(t *testing.T) { +func TestEcho_StaticFS(t *testing.T) { var testCases = []struct { name string givenPrefix string - givenRoot string + givenFs fs.FS + givenFsRoot string whenURL string expectStatus int expectHeaderLocation string @@ -79,15 +74,15 @@ func TestEchoStatic(t *testing.T) { { name: "ok", givenPrefix: "/images", - givenRoot: "_fixture/images", + givenFs: os.DirFS("./_fixture/images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, { - name: "ok with relative path for root points to directory", + name: "ok, from sub fs", givenPrefix: "/images", - givenRoot: "./_fixture/images", + givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), @@ -95,7 +90,7 @@ func TestEchoStatic(t *testing.T) { { name: "No file", givenPrefix: "/images", - givenRoot: "_fixture/scripts", + givenFs: os.DirFS("_fixture/scripts"), whenURL: "/images/bolt.png", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -103,7 +98,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory", givenPrefix: "/images", - givenRoot: "_fixture/images", + givenFs: os.DirFS("_fixture/images"), whenURL: "/images/", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -111,7 +106,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory Redirect", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture/"), whenURL: "/folder", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -120,7 +115,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory Redirect with non-root path", givenPrefix: "/static", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/static", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/static/", @@ -129,7 +124,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory 404 (request URL without slash)", givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -137,7 +132,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory redirect (without slash redirect to slash)", givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -146,7 +141,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory with index.html", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -154,7 +149,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending with slash)", givenPrefix: "/assets/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -162,7 +157,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending without slash)", givenPrefix: "/assets", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -170,7 +165,7 @@ func TestEchoStatic(t *testing.T) { { name: "Sub-directory with index.html", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -178,7 +173,7 @@ func TestEchoStatic(t *testing.T) { { name: "do not allow directory traversal (backslash - windows separator)", givenPrefix: "/", - givenRoot: "_fixture/", + givenFs: os.DirFS("_fixture/"), whenURL: `/..\\middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -186,7 +181,7 @@ func TestEchoStatic(t *testing.T) { { name: "do not allow directory traversal (slash - unix separator)", givenPrefix: "/", - givenRoot: "_fixture/", + givenFs: os.DirFS("_fixture/"), whenURL: `/../middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -196,10 +191,18 @@ func TestEchoStatic(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - e.Static(tc.givenPrefix, tc.givenRoot) + + tmpFs := tc.givenFs + if tc.givenFsRoot != "" { + tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) + } + e.StaticFS(tc.givenPrefix, tmpFs) + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectStatus, rec.Code) body := rec.Body.String() if tc.expectBodyStartsWith != "" { @@ -218,39 +221,117 @@ func TestEchoStatic(t *testing.T) { } } -func TestEchoStaticRedirectIndex(t *testing.T) { - e := New() +func TestEcho_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } - // HandlerFunc - e.Static("/static", "_fixture") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - errCh := make(chan error) + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() - go func() { - errCh <- e.Start(":0") - }() + e.ServeHTTP(rec, req) - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) + assert.Equal(t, tc.expectCode, rec.Code) - addr := e.ListenerAddr().String() - if resp, err := http.Get("http://" + addr + "/static"); err == nil { // http.Get follows redirects by default - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(t, true, strings.HasPrefix(string(body), "")) - } else { - assert.Fail(t, err.Error()) - } +func TestEcho_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + expectError string + }{ + { + name: "panics for ../", + givenRoot: "../assets", + expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name", + }, + { + name: "panics for /", + givenRoot: "/assets", + expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name", + }, + } - } else { - assert.NoError(t, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + assert.PanicsWithError(t, tc.expectError, func() { + e.Static("../assets", tc.givenRoot) + }) + }) } +} - if err := e.Close(); err != nil { - t.Fatal(err) +func TestEchoStaticRedirectIndex(t *testing.T) { + e := New() + + // HandlerFunc + ri := e.Static("/static", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/static*", ri.Path()) + assert.Equal(t, "GET:/static*", ri.Name()) + assert.Equal(t, []string{"*"}, ri.Params()) + + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + addr, err := startOnRandomPort(ctx, e) + if err != nil { + assert.Fail(t, err.Error()) } + + code, body, err := doGet(fmt.Sprintf("http://%v/static", addr)) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(body, "")) + assert.Equal(t, http.StatusOK, code) } func TestEchoFile(t *testing.T) { @@ -310,7 +391,8 @@ func TestEchoMiddleware(t *testing.T) { e.Pre(func(next HandlerFunc) HandlerFunc { return func(c Context) error { - assert.Empty(t, c.Path()) + // before route match is found RouteInfo does not exist + assert.Equal(t, nil, c.RouteInfo()) buf.WriteString("-1") return next(c) } @@ -355,7 +437,7 @@ func TestEchoMiddlewareError(t *testing.T) { return errors.New("error") } }) - e.GET("/", NotFoundHandler) + e.GET("/", notFoundHandler) c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -410,128 +492,202 @@ func TestEchoWrapMiddleware(t *testing.T) { } } +func TestEchoGet_routeInfoIsImmutable(t *testing.T) { + e := New() + ri := e.GET("/test", handlerFunc) + assert.Equal(t, "GET:/test", ri.Name()) + + riFromRouter, err := e.Router().Routes().FindByMethodPath(http.MethodGet, "/test") + assert.NoError(t, err) + assert.Equal(t, "GET:/test", riFromRouter.Name()) + + rInfo := ri.(routeInfo) + rInfo.name = "changed" // this change should not change other returned values + + assert.Equal(t, "GET:/test", ri.Name()) + + riFromRouter, err = e.Router().Routes().FindByMethodPath(http.MethodGet, "/test") + assert.NoError(t, err) + assert.Equal(t, "GET:/test", riFromRouter.Name()) +} + func TestEchoConnect(t *testing.T) { e := New() - testMethod(t, http.MethodConnect, "/", e) + + ri := e.CONNECT("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodConnect+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodConnect, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoDelete(t *testing.T) { e := New() - testMethod(t, http.MethodDelete, "/", e) + + ri := e.DELETE("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodDelete+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodDelete, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoGet(t *testing.T) { e := New() - testMethod(t, http.MethodGet, "/", e) + + ri := e.GET("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodGet+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodGet, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoHead(t *testing.T) { e := New() - testMethod(t, http.MethodHead, "/", e) + + ri := e.HEAD("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodHead+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodHead, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoOptions(t *testing.T) { e := New() - testMethod(t, http.MethodOptions, "/", e) + + ri := e.OPTIONS("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodOptions+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodOptions, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPatch(t *testing.T) { e := New() - testMethod(t, http.MethodPatch, "/", e) + + ri := e.PATCH("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPatch+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPatch, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPost(t *testing.T) { e := New() - testMethod(t, http.MethodPost, "/", e) + + ri := e.POST("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPost+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPost, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPut(t *testing.T) { e := New() - testMethod(t, http.MethodPut, "/", e) + + ri := e.PUT("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPut+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPut, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoTrace(t *testing.T) { e := New() - testMethod(t, http.MethodTrace, "/", e) + + ri := e.TRACE("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodTrace+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodTrace, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoAny(t *testing.T) { // JFC e := New() - e.Any("/", func(c Context) error { + ris := e.Any("/", func(c Context) error { return c.String(http.StatusOK, "Any") }) + assert.Len(t, ris, 11) } func TestEchoMatch(t *testing.T) { // JFC e := New() - e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { + ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { return c.String(http.StatusOK, "Match") }) + assert.Len(t, ris, 2) } -func TestEchoURL(t *testing.T) { - e := New() - static := func(Context) error { return nil } - getUser := func(Context) error { return nil } - getAny := func(Context) error { return nil } - getFile := func(Context) error { return nil } - - e.GET("/static/file", static) - e.GET("/users/:id", getUser) - e.GET("/documents/*", getAny) - g := e.Group("/group") - g.GET("/users/:uid/files/:fid", getFile) - - assert := assert.New(t) - - assert.Equal("/static/file", e.URL(static)) - assert.Equal("/users/:id", e.URL(getUser)) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) - assert.Equal("/documents/*", e.URL(getAny)) - assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) - assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) -} - -func TestEchoRoutes(t *testing.T) { - e := New() - routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, - } - for _, r := range routes { - e.Add(r.Method, r.Path, func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - } - - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } - } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) - } - } - } -} - -func TestEchoRoutesHandleHostsProperly(t *testing.T) { +func TestEcho_Routers_HandleHostsProperly(t *testing.T) { e := New() h := e.Host("route.com") routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + {Method: http.MethodGet, Path: "/users/:user/events"}, + {Method: http.MethodGet, Path: "/users/:user/events/public"}, + {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/refs"}, + {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/tags"}, } for _, r := range routes { h.Add(r.Method, r.Path, func(c Context) error { @@ -539,17 +695,22 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) { }) } - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { + routers := e.Routers() + + routeCom, ok := routers["route.com"] + assert.True(t, ok) + + if assert.Equal(t, len(routes), len(routeCom.Routes())) { + for _, r := range routeCom.Routes() { found := false for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { + if r.Method() == rr.Method && r.Path() == rr.Path { found = true break } } if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) + t.Errorf("Route %s %s not found", r.Method(), r.Path()) } } } @@ -561,7 +722,7 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { return c.String(http.StatusOK, "/with/slash") }) e.GET("/:id", func(c Context) error { - return c.String(http.StatusOK, c.Param("id")) + return c.String(http.StatusOK, c.PathParam("id")) }) var testCases = []struct { @@ -598,8 +759,6 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } func TestEchoHost(t *testing.T) { - assert := assert.New(t) - okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) } @@ -694,8 +853,8 @@ func TestEchoHost(t *testing.T) { e.ServeHTTP(rec, req) - assert.Equal(tc.expectStatus, rec.Code) - assert.Equal(tc.expectBody, rec.Body.String()) + assert.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) }) } } @@ -783,710 +942,157 @@ func TestEchoMethodNotAllowed(t *testing.T) { func TestEchoContext(t *testing.T) { e := New() c := e.AcquireContext() - assert.IsType(t, new(context), c) + assert.IsType(t, new(DefaultContext), c) e.ReleaseContext(c) } -func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) - defer cancel() - - ticker := time.NewTicker(5 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - var addr net.Addr - if isTLS { - addr = e.TLSListenerAddr() - } else { - addr = e.ListenerAddr() - } - if addr != nil && strings.Contains(addr.String(), ":") { - return nil // was started - } - case err := <-errChan: - if err == http.ErrServerClosed { - return nil - } - return err - } - } -} - -func TestEchoStart(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - err := e.Start(":0") - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - assert.NoError(t, err) - - assert.NoError(t, e.Close()) -} - -func TestEcho_StartTLS(t *testing.T) { - var testCases = []struct { - name string - addr string - certFile string - keyFile string - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "nok, invalid certFile", - addr: ":0", - certFile: "not existing", - expectError: "open not existing: no such file or directory", - }, - { - name: "nok, invalid keyFile", - addr: ":0", - keyFile: "not existing", - expectError: "open not existing: no such file or directory", - }, - { - name: "nok, failed to create cert out of certFile and keyFile", - addr: ":0", - keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key - expectError: "tls: found a certificate rather than a key in the PEM for the private key", - }, - { - name: "nok, invalid tls address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - certFile := "_fixture/certs/cert.pem" - if tc.certFile != "" { - certFile = tc.certFile - } - keyFile := "_fixture/certs/key.pem" - if tc.keyFile != "" { - keyFile = tc.keyFile - } - - err := e.StartTLS(tc.addr, certFile, keyFile) - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, true) - if tc.expectError != "" { - if _, ok := err.(*os.PathError); ok { - assert.Error(t, err) // error messages for unix and windows are different. so test only error type here - } else { - assert.EqualError(t, err, tc.expectError) - } - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEchoStartTLSAndStart(t *testing.T) { - // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server +func TestEcho_Start(t *testing.T) { e := New() e.GET("/", func(c Context) error { - return c.String(http.StatusOK, "OK") + return c.String(http.StatusTeapot, "OK") }) - - errTLSChan := make(chan error) - go func() { - certFile := "_fixture/certs/cert.pem" - keyFile := "_fixture/certs/key.pem" - err := e.StartTLS("localhost:", certFile, keyFile) - if err != nil { - errTLSChan <- err - } - }() - - err := waitForServerStart(e, errTLSChan, true) - assert.NoError(t, err) - defer func() { - if err := e.Shutdown(stdContext.Background()); err != nil { - t.Error(err) - } - }() - - // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true) - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }} - res, err := client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - errChan := make(chan error) + rndPort, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer rndPort.Close() + errChan := make(chan error, 1) go func() { - err := e.Start("localhost:") - if err != nil { - errChan <- err - } + errChan <- e.Start(rndPort.Addr().String()) }() - err = waitForServerStart(e, errChan, false) - assert.NoError(t, err) - // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS - res, err = http.Get("http://" + e.ListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - // see if HTTPS works after HTTP listener is also added - res, err = client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + select { + case <-time.After(250 * time.Millisecond): + t.Fatal("start did not error out") + case err := <-errChan: + assert.Contains(t, err.Error(), "bind: address already in use") + } } -func TestEchoStartTLSByteString(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) +func request(method, path string, e *Echo) (int, string) { + req := httptest.NewRequest(method, path, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec.Code, rec.Body.String() +} - testCases := []struct { - cert interface{} - key interface{} - expectedErr error - name string +func TestDefaultHTTPErrorHandler(t *testing.T) { + var testCases = []struct { + name string + givenExposeError bool + givenLoggerFunc bool + whenMethod string + whenError error + expectBody string + expectStatus int + expectLogged string }{ { - cert: "_fixture/certs/cert.pem", - key: "_fixture/certs/key.pem", - expectedErr: nil, - name: `ValidCertAndKeyFilePath`, + name: "ok, expose error = true, HTTPError", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error"), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=my_error","message":"my_error"}` + "\n", }, { - cert: cert, - key: key, - expectedErr: nil, - name: `ValidCertAndKeyByteString`, + name: "ok, expose error = true, HTTPError + internal error", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(errors.New("internal_error")), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=my_error, internal=internal_error","message":"my_error"}` + "\n", }, { - cert: cert, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidKeyType`, + name: "ok, expose error = true, HTTPError + internal HTTPError", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")), + expectStatus: http.StatusTooEarly, + expectBody: `{"error":"code=418, message=my_error, internal=code=425, message=early_error","message":"early_error"}` + "\n", }, { - cert: 0, - key: key, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertType`, + name: "ok, expose error = false, HTTPError", + whenError: NewHTTPError(http.StatusTeapot, "my_error"), + expectStatus: http.StatusTeapot, + expectBody: `{"message":"my_error"}` + "\n", }, { - cert: 0, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertAndKeyTypes`, + name: "ok, expose error = false, HTTPError + internal HTTPError", + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")), + expectStatus: http.StatusTooEarly, + expectBody: `{"message":"early_error"}` + "\n", }, - } - - for _, test := range testCases { - test := test - t.Run(test.name, func(t *testing.T) { - e := New() - e.HideBanner = true - - errChan := make(chan error) - - go func() { - errChan <- e.StartTLS(":0", test.cert, test.key) - }() - - err := waitForServerStart(e, errChan, true) - if test.expectedErr != nil { - assert.EqualError(t, err, test.expectedErr.Error()) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEcho_StartAutoTLS(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ { - name: "ok", - addr: ":0", + name: "ok, expose error = true, Error", + givenExposeError: true, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n", }, { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - errChan <- e.StartAutoTLS(tc.addr) - }() - - err := waitForServerStart(e, errChan, true) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEcho_StartH2CServer(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ - { - name: "ok", - addr: ":0", + name: "ok, expose error = false, Error", + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"message":"Internal Server Error"}` + "\n", }, { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", + name: "ok, http.HEAD, expose error = true, Error", + givenExposeError: true, + whenMethod: http.MethodHead, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: ``, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) e := New() - e.Debug = true - h2s := &http2.Server{} - - errChan := make(chan error) - go func() { - err := e.StartH2CServer(tc.addr, h2s) - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func testMethod(t *testing.T, method, path string, e *Echo) { - p := reflect.ValueOf(path) - h := reflect.ValueOf(func(c Context) error { - return c.String(http.StatusOK, method) - }) - i := interface{}(e) - reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) - _, body := request(method, path, e) - assert.Equal(t, method, body) -} - -func request(method, path string, e *Echo) (int, string) { - req := httptest.NewRequest(method, path, nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - return rec.Code, rec.Body.String() -} - -func TestHTTPError(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Equal(t, "code=400, message=map[code:12]", err.Error()) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) - }) -} - -func TestHTTPError_Unwrap(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Nil(t, errors.Unwrap(err)) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "internal error", errors.Unwrap(err).Error()) - }) -} - -func TestDefaultHTTPErrorHandler(t *testing.T) { - e := New() - e.Debug = true - e.Any("/plain", func(c Context) error { - return errors.New("An error occurred") - }) - e.Any("/badrequest", func(c Context) error { - return NewHTTPError(http.StatusBadRequest, "Invalid request") - }) - e.Any("/servererror", func(c Context) error { - return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ - "code": 33, - "message": "Something bad happened", - "error": "stackinfo", - }) - }) - e.Any("/early-return", func(c Context) error { - c.String(http.StatusOK, "OK") - return errors.New("ERROR") - }) - e.GET("/internal-error", func(c Context) error { - err := errors.New("internal error message body") - return NewHTTPError(http.StatusBadRequest).SetInternal(err) - }) - - // With Debug=true plain response contains error message - c, b := request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) - // and special handling for HTTPError - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b) - // complex errors are serialized to pretty JSON - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) - // if the body is already set HTTPErrorHandler should not add anything to response body - c, b = request(http.MethodGet, "/early-return", e) - assert.Equal(t, http.StatusOK, c) - assert.Equal(t, "OK", b) - // internal error should be reflected in the message - c, b = request(http.MethodGet, "/internal-error", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", b) - - e.Debug = false - // With Debug=false the error response is shortened - c, b = request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b) - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b) - // No difference for error response with non plain string errors - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b) -} - -func TestEchoClose(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - assert.NoError(t, e.Close()) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -func TestEchoShutdown(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second) - defer cancel() - assert.NoError(t, e.Shutdown(ctx)) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -var listenerNetworkTests = []struct { - test string - network string - address string -}{ - {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, - {"tcp ipv6 address", "tcp", "[::1]:1323"}, - {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, - {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, -} - -func supportsIPv6() bool { - addrs, _ := net.InterfaceAddrs() - for _, addr := range addrs { - // Check if any interface has local IPv6 assigned - if strings.Contains(addr.String(), "::1") { - return true - } - } - return false -} - -func TestEchoListenerNetwork(t *testing.T) { - hasIPv6 := supportsIPv6() - for _, tt := range listenerNetworkTests { - if !hasIPv6 && strings.Contains(tt.address, "::") { - t.Skip("Skipping testing IPv6 for " + tt.address + ", not available") - continue - } - t.Run(tt.test, func(t *testing.T) { - e := New() - e.ListenerNetwork = tt.network - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") + e.Logger = &jsonLogger{writer: buf} + e.Any("/path", func(c Context) error { + return tc.whenError }) - errCh := make(chan error) + e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError) - go func() { - errCh <- e.Start(tt.address) - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(t, "OK", string(body)) - } else { - assert.Fail(t, err.Error()) - } - - } else { - assert.Fail(t, err.Error()) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod } + c, b := request(method, "/path", e) - if err := e.Close(); err != nil { - t.Fatal(err) - } + assert.Equal(t, tc.expectStatus, c) + assert.Equal(t, tc.expectBody, b) + assert.Equal(t, tc.expectLogged, buf.String()) }) } } -func TestEchoListenerNetworkInvalid(t *testing.T) { - e := New() - e.ListenerNetwork = "unix" - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) +type myCustomContext struct { + DefaultContext } -func TestEchoReverse(t *testing.T) { - assert := assert.New(t) - - e := New() - dummyHandler := func(Context) error { return nil } - - e.GET("/static", dummyHandler).Name = "/static" - e.GET("/static/*", dummyHandler).Name = "/static/*" - e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" - e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" - e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" - - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) - - assert.Equal("/params/:foo", e.Reverse("/params/:foo")) - assert.Equal("/params/one", e.Reverse("/params/:foo", "one")) - assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) - assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) - assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) - assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) -} - -func TestEchoReverseHandleHostProperly(t *testing.T) { - assert := assert.New(t) - - dummyHandler := func(Context) error { return nil } - - e := New() - h := e.Host("the_host") - h.GET("/static", dummyHandler).Name = "/static" - h.GET("/static/*", dummyHandler).Name = "/static/*" - - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) -} - -func TestEcho_ListenerAddr(t *testing.T) { - e := New() - - addr := e.ListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) +func (c *myCustomContext) QueryParam(name string) string { + return "prefix_" + c.DefaultContext.QueryParam(name) } -func TestEcho_TLSListenerAddr(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - +func TestEcho_customContext(t *testing.T) { e := New() - - addr := e.TLSListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.StartTLS(":0", cert, key) - }() - - err = waitForServerStart(e, errCh, true) - assert.NoError(t, err) -} - -func TestEcho_StartServer(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - certs, err := tls.X509KeyPair(cert, key) - require.NoError(t, err) - - var testCases = []struct { - name string - addr string - TLSConfig *tls.Config - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "ok, start with TLS", - addr: ":0", - TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, - }, - { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - { - name: "nok, invalid tls address", - addr: "nope", - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - expectError: "listen tcp: address nope: missing port in address", - }, + e.NewContextFunc = func(ec *Echo, pathParamAllocSize int) ServableContext { + return &myCustomContext{ + DefaultContext: *NewDefaultContext(ec, pathParamAllocSize), + } } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Debug = true - - server := new(http.Server) - server.Addr = tc.addr - if tc.TLSConfig != nil { - server.TLSConfig = tc.TLSConfig - } - - errCh := make(chan error) - go func() { - errCh <- e.StartServer(server) - }() + e.GET("/info/:id/:file", func(c Context) error { + return c.String(http.StatusTeapot, c.QueryParam("param")) + }) - err := waitForServerStart(e, errCh, tc.TLSConfig != nil) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - assert.NoError(t, e.Close()) - }) - } + status, body := request(http.MethodGet, "/info/1/a.csv?param=123", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "prefix_123", body) } -func benchmarkEchoRoutes(b *testing.B, routes []*Route) { +func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() req := httptest.NewRequest("GET", "/", nil) u := req.URL diff --git a/go.mod b/go.mod index 4de2bdde1..339d24ef2 100644 --- a/go.mod +++ b/go.mod @@ -1,24 +1,19 @@ -module github.com/labstack/echo/v4 +module github.com/labstack/echo/v5 go 1.17 require ( - github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/labstack/gommon v0.3.1 + github.com/golang-jwt/jwt/v4 v4.2.0 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f - golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 + golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 + golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.11 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect - golang.org/x/text v0.3.7 // indirect + golang.org/x/text v0.3.3 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index f66734243..3290b99b8 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= -github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= -github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= -github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= +github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -18,25 +14,15 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/group.go b/group.go index bba470ce8..4f04a73a2 100644 --- a/group.go +++ b/group.go @@ -1,98 +1,121 @@ package echo import ( + "io/fs" "net/http" ) -type ( - // Group is a set of sub-routes for a specified route. It can be used for inner - // routes that share a common middleware or functionality that should be separate - // from the parent echo instance while still inheriting from it. - Group struct { - common - host string - prefix string - middleware []MiddlewareFunc - echo *Echo - } -) +// Group is a set of sub-routes for a specified route. It can be used for inner +// routes that share a common middleware or functionality that should be separate +// from the parent echo instance while still inheriting from it. +type Group struct { + host string + prefix string + middleware []MiddlewareFunc + echo *Echo +} // Use implements `Echo#Use()` for sub-routes within the Group. +// Group middlewares are not executed on request when there is no matching route found. func (g *Group) Use(middleware ...MiddlewareFunc) { g.middleware = append(g.middleware, middleware...) - if len(g.middleware) == 0 { - return - } - // Allow all requests to reach the group as they might get dropped if router - // doesn't find a match, making none of the group middleware process. - g.Any("", NotFoundHandler) - g.Any("/*", NotFoundHandler) } -// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. -func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error. +func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodConnect, path, h, m...) } -// DELETE implements `Echo#DELETE()` for sub-routes within the Group. -func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error. +func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodDelete, path, h, m...) } -// GET implements `Echo#GET()` for sub-routes within the Group. -func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error. +func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodGet, path, h, m...) } -// HEAD implements `Echo#HEAD()` for sub-routes within the Group. -func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error. +func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodHead, path, h, m...) } -// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. -func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error. +func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodOptions, path, h, m...) } -// PATCH implements `Echo#PATCH()` for sub-routes within the Group. -func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error. +func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPatch, path, h, m...) } -// POST implements `Echo#POST()` for sub-routes within the Group. -func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error. +func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPost, path, h, m...) } -// PUT implements `Echo#PUT()` for sub-routes within the Group. -func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error. +func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPut, path, h, m...) } -// TRACE implements `Echo#TRACE()` for sub-routes within the Group. -func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error. +func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodTrace, path, h, m...) } -// Any implements `Echo#Any()` for sub-routes within the Group. -func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error. +func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) + } + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } - return routes + return ris } -// Match implements `Echo#Match()` for sub-routes within the Group. -func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error. +func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } // Group creates a new sub-group with prefix and optional sub-group-level middleware. +// Important! Group middlewares are only executed in case there was exact route match and not +// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add +// a catch-all route `/*` for the group which handler returns always 404 func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) @@ -102,18 +125,57 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { return } -// File implements `Echo#File()` for sub-routes within the Group. -func (g *Group) File(path, file string) { - g.file(path, file, g.GET) +// Static implements `Echo#Static()` for sub-routes within the Group. +func (g *Group) Static(pathPrefix, fsRoot string) RouteInfo { + subFs := MustSubFS(g.echo.Filesystem, fsRoot) + return g.StaticFS(pathPrefix, subFs) +} + +// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) RouteInfo { + return g.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// FileFS implements `Echo#FileFS()` for sub-routes within the Group. +func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return g.GET(path, StaticFileHandler(file, filesystem), m...) } -// Add implements `Echo#Add()` for sub-routes within the Group. -func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - // Combine into a new slice to avoid accidentally passing the same slice for +// File implements `Echo#File()` for sub-routes within the Group. Panics on error. +func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c Context) error { + return c.File(file) + } + return g.Add(http.MethodGet, path, handler, middleware...) +} + +// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. +func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := g.AddRoute(Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri +} + +// AddRoute registers a new Routable with Router +func (g *Group) AddRoute(route Routable) (RouteInfo, error) { + // Combine middleware into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. - m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) - m = append(m, g.middleware...) - m = append(m, middleware...) - return g.echo.add(g.host, method, g.prefix+path, handler, m...) + groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...)) + return g.echo.add(g.host, groupRoute) } diff --git a/group_fs.go b/group_fs.go deleted file mode 100644 index 0a1ce4a94..000000000 --- a/group_fs.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !go1.16 -// +build !go1.16 - -package echo - -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(prefix, root string) { - g.static(prefix, root, g.GET) -} diff --git a/group_fs_go1.16.go b/group_fs_go1.16.go deleted file mode 100644 index 2ba52b5e2..000000000 --- a/group_fs_go1.16.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "io/fs" - "net/http" -) - -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(pathPrefix, fsRoot string) { - subFs := MustSubFS(g.echo.Filesystem, fsRoot) - g.StaticFS(pathPrefix, subFs) -} - -// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { - g.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// FileFS implements `Echo#FileFS()` for sub-routes within the Group. -func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return g.GET(path, StaticFileHandler(file, filesystem), m...) -} diff --git a/group_fs_go1.16_test.go b/group_fs_go1.16_test.go deleted file mode 100644 index d0caa33db..000000000 --- a/group_fs_go1.16_test.go +++ /dev/null @@ -1,106 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestGroup_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenPath string - whenFile string - whenFS fs.FS - givenURL string - expectCode int - expectStartsWith []byte - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/assets/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - g := e.Group("/assets") - g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestGroup_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - expectError string - }{ - { - name: "panics for ../", - givenRoot: "../images", - expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name", - }, - { - name: "panics for /", - givenRoot: "/images", - expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - g := e.Group("/assets") - - assert.PanicsWithError(t, tc.expectError, func() { - g.Static("/images", tc.givenRoot) - }) - }) - } -} diff --git a/group_test.go b/group_test.go index c51fd91eb..3914c0bd8 100644 --- a/group_test.go +++ b/group_test.go @@ -1,31 +1,70 @@ package echo import ( + "github.com/stretchr/testify/assert" + "io/fs" "io/ioutil" "net/http" "net/http/httptest" + "os" + "strings" "testing" - - "github.com/stretchr/testify/assert" ) -// TODO: Fix me -func TestGroup(t *testing.T) { - g := New().Group("/group") - h := func(Context) error { return nil } - g.CONNECT("/", h) - g.DELETE("/", h) - g.GET("/", h) - g.HEAD("/", h) - g.OPTIONS("/", h) - g.PATCH("/", h) - g.POST("/", h) - g.PUT("/", h) - g.TRACE("/", h) - g.Any("/", h) - g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) - g.Static("/static", "/tmp") - g.File("/walle", "_fixture/images//walle.png") +func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware it will not be executed when there are no routes under that group + _ = e.Group("/group", mw) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware and routes when we have no match on some route the middlewares for that + // group will not be executed + g := e.Group("/group", mw) + g.GET("/yes", handlerFunc) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_multiLevelGroup(t *testing.T) { + e := New() + + api := e.Group("/api") + users := api.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + status, body := request(http.MethodGet, "/api/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) } func TestGroupFile(t *testing.T) { @@ -92,11 +131,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { } m2 := func(next HandlerFunc) HandlerFunc { return func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return c.String(http.StatusOK, c.RouteInfo().Path()) } } h := func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return c.String(http.StatusOK, c.RouteInfo().Path()) } g.Use(m1) g.GET("/help", h, m2) @@ -119,3 +158,535 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } + +func TestGroup_CONNECT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.CONNECT("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodConnect, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_DELETE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.DELETE("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodDelete, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_HEAD(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.HEAD("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodHead+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodHead, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_OPTIONS(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.OPTIONS("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodOptions, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PATCH(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PATCH("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPatch, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_POST(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.POST("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPost+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPost, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PUT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PUT("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPut+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPut, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_TRACE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.TRACE("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodTrace, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_Any(t *testing.T) { + e := New() + + users := e.Group("/users") + ris := users.Any("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 11) + + for _, m := range methods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_AnyWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Any("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range methods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Match(t *testing.T) { + e := New() + + myMethods := []string{http.MethodGet, http.MethodPost} + users := e.Group("/users") + ris := users.Match(myMethods, "/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 2) + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_MatchWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + myMethods := []string{http.MethodGet, http.MethodPost} + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Match(myMethods, "/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Static(t *testing.T) { + e := New() + + g := e.Group("/books") + ri := g.Static("/download", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/books/download*", ri.Path()) + assert.Equal(t, "GET:/books/download*", ri.Name()) + assert.Equal(t, []string{"*"}, ri.Params()) + + req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + assert.True(t, strings.HasPrefix(body, "")) +} + +func TestGroup_StaticMultiTest(t *testing.T) { + var testCases = []struct { + name string + givenPrefix string + givenRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, without prefix", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png` + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "nok, without prefix does not serve dir index", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/test/", // `/test` + `*` creates route `/test*` + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "_fixture/scripts", + whenURL: "/test/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenRoot: "_fixture", + whenURL: "/test/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + g := e.Group("/test") + g.Static(tc.givenPrefix, tc.givenRoot) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } +} + +func TestGroup_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/assets/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/assets") + g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestGroup_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + expectError string + }{ + { + name: "panics for ../", + givenRoot: "../images", + expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name", + }, + { + name: "panics for /", + givenRoot: "/images", + expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + g := e.Group("/assets") + + assert.PanicsWithError(t, tc.expectError, func() { + g.Static("/images", tc.givenRoot) + }) + }) + } +} diff --git a/httperror.go b/httperror.go new file mode 100644 index 000000000..5c217dac1 --- /dev/null +++ b/httperror.go @@ -0,0 +1,74 @@ +package echo + +import ( + "errors" + "fmt" + "net/http" +) + +// Errors +var ( + ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) + ErrNotFound = NewHTTPError(http.StatusNotFound) + ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) + ErrForbidden = NewHTTPError(http.StatusForbidden) + ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) + ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) + ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) + ErrBadRequest = NewHTTPError(http.StatusBadRequest) + ErrBadGateway = NewHTTPError(http.StatusBadGateway) + ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) + ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) + ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") +) + +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + Code int `json:"-"` + Message interface{} `json:"message"` + Internal error `json:"-"` // Stores the error returned by an external dependency +} + +// NewHTTPError creates a new HTTPError instance. +func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used? + he := &HTTPError{Code: code, Message: http.StatusText(code)} + if len(message) > 0 { + he.Message = message[0] + } + return he +} + +// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set. +func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError { + he := NewHTTPError(code, message...) + he.Internal = internalError + return he +} + +// Error makes it compatible with `error` interface. +func (he *HTTPError) Error() string { + if he.Internal == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) + } + return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) +} + +// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field +func (he *HTTPError) WithInternal(err error) *HTTPError { + return &HTTPError{ + Code: he.Code, + Message: he.Message, + Internal: err, + } +} + +// Unwrap satisfies the Go 1.13 error wrapper interface. +func (he *HTTPError) Unwrap() error { + return he.Internal +} diff --git a/httperror_test.go b/httperror_test.go new file mode 100644 index 000000000..f9d340f11 --- /dev/null +++ b/httperror_test.go @@ -0,0 +1,52 @@ +package echo + +import ( + "errors" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestHTTPError(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Equal(t, "code=400, message=map[code:12]", err.Error()) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) + }) +} + +func TestNewHTTPErrorWithInternal(t *testing.T) { + he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message") + assert.Equal(t, "code=400, message=test message, internal=test", he.Error()) +} + +func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) { + he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test")) + assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error()) +} + +func TestHTTPError_Unwrap(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Nil(t, errors.Unwrap(err)) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) +} diff --git a/json.go b/json.go index 16b2d0577..16074fa24 100644 --- a/json.go +++ b/json.go @@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { err := json.NewDecoder(c.Request().Body).Decode(i) if ute, ok := err.(*json.UnmarshalTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) + return NewHTTPErrorWithInternal( + http.StatusBadRequest, + err, + fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset), + ) } else if se, ok := err.(*json.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, + err, + fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()), + ) } return err } diff --git a/json_test.go b/json_test.go index 27ee43e73..ac64d2894 100644 --- a/json_test.go +++ b/json_test.go @@ -14,7 +14,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) assert := testify.New(t) @@ -40,7 +40,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Serialize(c, user{1, "Jon Snow"}, " ") if assert.NoError(err) { assert.Equal(userJSONPretty+"\n", rec.Body.String()) @@ -53,7 +53,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec).(*DefaultContext) assert := testify.New(t) @@ -81,7 +81,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { var userUnmarshalSyntaxError = user{} req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Deserialize(c, &userUnmarshalSyntaxError) assert.IsType(&HTTPError{}, err) assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") @@ -93,7 +93,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec).(*DefaultContext) err = enc.Deserialize(c, &userUnmarshalTypeError) assert.IsType(&HTTPError{}, err) assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") diff --git a/log.go b/log.go index 3f8de5904..ce351881c 100644 --- a/log.go +++ b/log.go @@ -1,41 +1,148 @@ package echo import ( + "bytes" "io" - - "github.com/labstack/gommon/log" + "strconv" + "sync" + "time" ) -type ( - // Logger defines the logging interface. - Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - SetPrefix(p string) - Level() log.Lvl - SetLevel(v log.Lvl) - SetHeader(h string) - Print(i ...interface{}) - Printf(format string, args ...interface{}) - Printj(j log.JSON) - Debug(i ...interface{}) - Debugf(format string, args ...interface{}) - Debugj(j log.JSON) - Info(i ...interface{}) - Infof(format string, args ...interface{}) - Infoj(j log.JSON) - Warn(i ...interface{}) - Warnf(format string, args ...interface{}) - Warnj(j log.JSON) - Error(i ...interface{}) - Errorf(format string, args ...interface{}) - Errorj(j log.JSON) - Fatal(i ...interface{}) - Fatalj(j log.JSON) - Fatalf(format string, args ...interface{}) - Panic(i ...interface{}) - Panicj(j log.JSON) - Panicf(format string, args ...interface{}) +//----------------------------------------------------------------------------- +// Example for Zap (https://github.com/uber-go/zap) +//func main() { +// e := echo.New() +// logger, _ := zap.NewProduction() +// e.Logger = &ZapLogger{logger: logger} +//} +//type ZapLogger struct { +// logger *zap.Logger +//} +// +//func (l *ZapLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *ZapLogger) Error(err error) { +// l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo")) +//} + +//----------------------------------------------------------------------------- +// Example for Zerolog (https://github.com/rs/zerolog) +//func main() { +// e := echo.New() +// logger := zerolog.New(os.Stdout) +// e.Logger = &ZeroLogger{logger: &logger} +//} +// +//type ZeroLogger struct { +// logger *zerolog.Logger +//} +// +//func (l *ZeroLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *ZeroLogger) Error(err error) { +// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error()) +//} + +//----------------------------------------------------------------------------- +// Example for Logrus (https://github.com/sirupsen/logrus) +//func main() { +// e := echo.New() +// e.Logger = &LogrusLogger{logger: logrus.New()} +//} +// +//type LogrusLogger struct { +// logger *logrus.Logger +//} +// +//func (l *LogrusLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *LogrusLogger) Error(err error) { +// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err) +//} + +// Logger defines the logging interface that Echo uses internally in few places. +// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice. +type Logger interface { + // Write provides writer interface for http.Server `ErrorLog` and for logging startup messages. + // `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers, + // and underlying FileSystem errors. + // `logger` middleware will use this method to write its JSON payload. + Write(p []byte) (n int, err error) + // Error logs the error + Error(err error) +} + +// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only +// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting. +// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases. +type jsonLogger struct { + writer io.Writer + bufferPool sync.Pool + lock sync.Mutex + + timeNow func() time.Time +} + +func newJSONLogger(writer io.Writer) *jsonLogger { + return &jsonLogger{ + writer: writer, + bufferPool: sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 256)) + }, + }, + timeNow: time.Now, } -) +} + +func (l *jsonLogger) Write(p []byte) (n int, err error) { + pLen := len(p) + if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message + (p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') || + (p[0] == '{' && p[pLen-1] == '}') { + return l.write(p) + } + // we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is + // called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably + // deserves at least WARN level. + return l.printf("INFO", string(p)) +} + +func (l *jsonLogger) Error(err error) { + _, _ = l.printf("ERROR", err.Error()) +} + +func (l *jsonLogger) printf(level string, message string) (n int, err error) { + buf := l.bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer l.bufferPool.Put(buf) + + buf.WriteString(`{"time":"`) + buf.WriteString(l.timeNow().Format(time.RFC3339Nano)) + buf.WriteString(`","level":"`) + buf.WriteString(level) + buf.WriteString(`","prefix":"echo","message":`) + + buf.WriteString(strconv.Quote(message)) + buf.WriteString("}\n") + + return l.write(buf.Bytes()) +} + +func (l *jsonLogger) write(p []byte) (int, error) { + l.lock.Lock() + defer l.lock.Unlock() + return l.writer.Write(p) +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 000000000..c7b4674e9 --- /dev/null +++ b/log_test.go @@ -0,0 +1,87 @@ +package echo + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +type noOpLogger struct { +} + +func (l *noOpLogger) Write(p []byte) (n int, err error) { + return 0, err +} + +func (l *noOpLogger) Error(err error) { +} + +func TestJsonLogger_Write(t *testing.T) { + var testCases = []struct { + name string + when []byte + expect string + }{ + { + name: "ok, write non JSONlike message", + when: []byte("version: %v, build: %v"), + expect: `{"time":"2021-09-07T20:09:37Z","level":"INFO","prefix":"echo","message":"version: %v, build: %v"}` + "\n", + }, + { + name: "ok, write quoted message", + when: []byte(`version: "%v"`), + expect: `{"time":"2021-09-07T20:09:37Z","level":"INFO","prefix":"echo","message":"version: \"%v\""}` + "\n", + }, + { + name: "ok, write JSON", + when: []byte(`{"version": 123}` + "\n"), + expect: `{"version": 123}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + logger := newJSONLogger(buf) + logger.timeNow = func() time.Time { + return time.Unix(1631045377, 0).UTC() + } + + _, err := logger.Write(tc.when) + + result := buf.String() + assert.Equal(t, tc.expect, result) + assert.NoError(t, err) + }) + } +} + +func TestJsonLogger_Error(t *testing.T) { + var testCases = []struct { + name string + whenError error + expect string + }{ + { + name: "ok", + whenError: ErrForbidden, + expect: `{"time":"2021-09-07T20:09:37Z","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + logger := newJSONLogger(buf) + logger.timeNow = func() time.Time { + return time.Unix(1631045377, 0).UTC() + } + + logger.Error(tc.whenError) + + result := buf.String() + assert.Equal(t, tc.expect, result) + }) + } +} diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md new file mode 100644 index 000000000..77cb226dd --- /dev/null +++ b/middleware/DEVELOPMENT.md @@ -0,0 +1,11 @@ +# Development Guidelines for middlewares + +## Best practices: + +* Do not use `panic` in middleware creator functions in case of invalid configuration. +* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead + because previous middlewares up in call chain could have logic for dealing with returned errors. +* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they + want to create middleware with panics or with returning errors on configuration errors. +* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same. + diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 8cf1ed9fc..82e2fbf7a 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,64 +1,59 @@ package middleware import ( + "bytes" "encoding/base64" + "errors" + "fmt" "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // BasicAuthConfig defines the config for BasicAuth middleware. - BasicAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BasicAuthConfig defines the config for BasicAuthWithConfig middleware. +type BasicAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Validator is a function to validate BasicAuth credentials. - // Required. - Validator BasicAuthValidator + // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers + // this function would be called once for each header until first valid result is returned + // Required. + Validator BasicAuthValidator - // Realm is a string to define realm attribute of BasicAuth. - // Default value "Restricted". - Realm string - } + // Realm is a string to define realm attribute of BasicAuthWithConfig. + // Default value "Restricted". + Realm string +} - // BasicAuthValidator defines a function to validate BasicAuth credentials. - BasicAuthValidator func(string, string, echo.Context) (bool, error) -) +// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials. +type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) -var ( - // DefaultBasicAuthConfig is the default BasicAuth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, - } -) - // BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For missing or invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - c := DefaultBasicAuthConfig - c.Validator = fn - return BasicAuthWithConfig(c) + return BasicAuthWithConfig(BasicAuthConfig{Validator: fn}) } -// BasicAuthWithConfig returns an BasicAuth middleware with config. -// See `BasicAuth()`. +// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration +func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { - panic("echo: basic-auth middleware requires a validator function") + return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultBasicAuthConfig.Skipper + config.Skipper = DefaultSkipper } if config.Realm == "" { config.Realm = defaultRealm @@ -70,29 +65,33 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + for _, auth := range c.Request().Header[echo.HeaderAuthorization] { + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return err + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = fmt.Errorf("invalid basic auth value: %w", errDecode) + continue } - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:])) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } } } + if lastError != nil { + return lastError + } + realm := defaultRealm if config.Realm != defaultRealm { realm = strconv.Quote(config.Realm) @@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } - } + }, nil } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 76039db0a..9580dff0b 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -2,70 +2,157 @@ package middleware import ( "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) - f := func(u, p string, c echo.Context) (bool, error) { + validatorFunc := func(c echo.Context, u, p string) (bool, error) { if u == "joe" && p == "secret" { return true, nil } + if u == "error" { + return false, errors.New(p) + } return false, nil } - h := BasicAuth(f)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + defaultConfig := BasicAuthConfig{Validator: validatorFunc} - assert := assert.New(t) + var testCases = []struct { + name string + givenConfig BasicAuthConfig + whenAuth []string + expectHeader string + expectErr string + }{ + { + name: "ok", + givenConfig: defaultConfig, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, multiple", + givenConfig: defaultConfig, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), + basic + " NOT_BASE64", + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + }, + }, + { + name: "nok, invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, not base64 Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, + expectErr: "invalid basic auth value: illegal base64 data at input byte 3", + }, + { + name: "nok, missing Authorization header", + givenConfig: defaultConfig, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "ok, realm", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, realm, case-insensitive header scheme", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "nok, realm, invalid Authorization header", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="someRealm"`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, validator func returns an error", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, + expectErr: "my_error", + }, + { + name: "ok, skipped", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool { + return true + }}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + }, + } - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) - h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + config := tc.givenConfig - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + mw, err := config.ToMiddleware() + assert.NoError(t, err) - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + h := mw(func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) - assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) + if len(tc.whenAuth) != 0 { + for _, a := range tc.whenAuth { + req.Header.Add(echo.HeaderAuthorization, a) + } + } + err = h(c) + + if tc.expectErr != "" { + assert.Equal(t, http.StatusOK, res.Code) + assert.EqualError(t, err, tc.expectErr) + } else { + assert.Equal(t, http.StatusTeapot, res.Code) + assert.NoError(t, err) + } + if tc.expectHeader != "" { + assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + }) + } +} - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) +func TestBasicAuth_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuth(nil) + assert.NotNil(t, mw) + }) + + mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) { + return true, nil + }) + assert.NotNil(t, mw) +} + +func TestBasicAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil}) + assert.NotNil(t, mw) + }) - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) { + return true, nil + }}) + assert.NotNil(t, mw) } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index ebd0d0ab2..390c37d64 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -3,71 +3,65 @@ package middleware import ( "bufio" "bytes" + "errors" "io" "io/ioutil" "net" "net/http" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // BodyDumpConfig defines the config for BodyDump middleware. - BodyDumpConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyDumpConfig defines the config for BodyDump middleware. +type BodyDumpConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Handler receives request and response payload. - // Required. - Handler BodyDumpHandler - } - - // BodyDumpHandler receives the request and response payload. - BodyDumpHandler func(echo.Context, []byte, []byte) + // Handler receives request and response payload. + // Required. + Handler BodyDumpHandler +} - bodyDumpResponseWriter struct { - io.Writer - http.ResponseWriter - } -) +// BodyDumpHandler receives the request and response payload. +type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte) -var ( - // DefaultBodyDumpConfig is the default BodyDump middleware config. - DefaultBodyDumpConfig = BodyDumpConfig{ - Skipper: DefaultSkipper, - } -) +type bodyDumpResponseWriter struct { + io.Writer + http.ResponseWriter +} // BodyDump returns a BodyDump middleware. // // BodyDump middleware captures the request and response payload and calls the // registered handler. func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { - c := DefaultBodyDumpConfig - c.Handler = handler - return BodyDumpWithConfig(c) + return BodyDumpWithConfig(BodyDumpConfig{Handler: handler}) } // BodyDumpWithConfig returns a BodyDump middleware with config. // See: `BodyDump()`. func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration +func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Handler == nil { - panic("echo: body-dump middleware requires a handler function") + return nil, errors.New("echo body-dump middleware requires a handler function") } if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } // Request reqBody := []byte{} - if c.Request().Body != nil { // Read + if c.Request().Body != nil { reqBody, _ = ioutil.ReadAll(c.Request().Body) } c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset @@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} c.Response().Writer = writer - if err = next(c); err != nil { - c.Error(err) - } + err := next(c) // Callback config.Handler(c, reqBody, resBody.Bytes()) - return + return err } - } + }, nil } func (w *bodyDumpResponseWriter) WriteHeader(code int) { diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index e6e00f726..323f46c15 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) { requestBody := "" responseBody := "" - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { + mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) { requestBody = string(reqBody) responseBody = string(resBody) - }) - - assert := assert.New(t) + }}.ToMiddleware() + assert.NoError(t, err) - if assert.NoError(mw(h)(c)) { - assert.Equal(requestBody, hw) - assert.Equal(responseBody, hw) - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.String()) + if assert.NoError(t, mw(h)(c)) { + assert.Equal(t, requestBody, hw) + assert.Equal(t, responseBody, hw) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.String()) } - // Must set default skipper - BodyDumpWithConfig(BodyDumpConfig{ - Skipper: nil, +} + +func TestBodyDump_skipper(t *testing.T) { + e := echo.New() + + isCalled := false + mw, err := BodyDumpConfig{ + Skipper: func(c echo.Context) bool { + return true + }, Handler: func(c echo.Context, reqBody, resBody []byte) { - requestBody = string(reqBody) - responseBody = string(resBody) + isCalled = true }, - }) + }.ToMiddleware() + assert.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + return errors.New("some error") + } + + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.False(t, isCalled) } -func TestBodyDumpFails(t *testing.T) { +func TestBodyDump_fails(t *testing.T) { e := echo.New() hw := "Hello, World!" req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) @@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) { return errors.New("some error") } - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) + mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware() + assert.NoError(t, err) - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestBodyDumpWithConfig_panic(t *testing.T) { assert.Panics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ + mw := BodyDumpWithConfig(BodyDumpConfig{ Skipper: nil, Handler: nil, }) + assert.NotNil(t, mw) }) assert.NotPanics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ - Skipper: func(c echo.Context) bool { - return true - }, - Handler: func(c echo.Context, reqBody, resBody []byte) { - }, - }) + mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}) + assert.NotNil(t, mw) + }) +} - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } +func TestBodyDump_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BodyDump(nil) + assert.NotNil(t, mw) + }) + + assert.NotPanics(t, func() { + BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) }) } diff --git a/middleware/body_limit.go b/middleware/body_limit.go index b436bd595..6b32f9d43 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -1,98 +1,83 @@ package middleware import ( - "fmt" "io" "sync" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/bytes" + "github.com/labstack/echo/v5" ) -type ( - // BodyLimitConfig defines the config for BodyLimit middleware. - BodyLimitConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyLimitConfig defines the config for BodyLimitWithConfig middleware. +type BodyLimitConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Maximum allowed size for a request body, it can be specified - // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. - Limit string `yaml:"limit"` - limit int64 - } - - limitedReader struct { - BodyLimitConfig - reader io.ReadCloser - read int64 - context echo.Context - } -) + // LimitBytes is maximum allowed size in bytes for a request body + LimitBytes int64 +} -var ( - // DefaultBodyLimitConfig is the default BodyLimit middleware config. - DefaultBodyLimitConfig = BodyLimitConfig{ - Skipper: DefaultSkipper, - } -) +type limitedReader struct { + BodyLimitConfig + reader io.ReadCloser + read int64 + context echo.Context +} // BodyLimit returns a BodyLimit middleware. // -// BodyLimit middleware sets the maximum allowed size for a request body, if the -// size exceeds the configured limit, it sends "413 - Request Entity Too Large" -// response. The BodyLimit is determined based on both `Content-Length` request +// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it +// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. -// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, -// G, T or P. -func BodyLimit(limit string) echo.MiddlewareFunc { - c := DefaultBodyLimitConfig - c.Limit = limit - return BodyLimitWithConfig(c) +func BodyLimit(limitBytes int64) echo.MiddlewareFunc { + return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes}) } -// BodyLimitWithConfig returns a BodyLimit middleware with config. -// See: `BodyLimit()`. +// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for +// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response. +// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which +// makes it super secure. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration +func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyLimitConfig.Skipper + config.Skipper = DefaultSkipper } - - limit, err := bytes.Parse(config.Limit) - if err != nil { - panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) + pool := sync.Pool{ + New: func() interface{} { + return &limitedReader{BodyLimitConfig: config} + }, } - config.limit = limit - pool := limitedReaderPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } - req := c.Request() // Based on content length - if req.ContentLength > config.limit { + if req.ContentLength > config.LimitBytes { return echo.ErrStatusRequestEntityTooLarge } // Based on content read r := pool.Get().(*limitedReader) - r.Reset(req.Body, c) + r.Reset(c, req.Body) defer pool.Put(r) req.Body = r return next(c) } - } + }, nil } func (r *limitedReader) Read(b []byte) (n int, err error) { n, err = r.reader.Read(b) r.read += int64(n) - if r.read > r.limit { + if r.read > r.LimitBytes { return n, echo.ErrStatusRequestEntityTooLarge } return @@ -102,16 +87,8 @@ func (r *limitedReader) Close() error { return r.reader.Close() } -func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { +func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) { r.reader = reader r.context = context r.read = 0 } - -func limitedReaderPool(c BodyLimitConfig) sync.Pool { - return sync.Pool{ - New: func() interface{} { - return &limitedReader{BodyLimitConfig: c} - }, - } -} diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 8981534d4..6e7778eab 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -7,11 +7,11 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestBodyLimit(t *testing.T) { +func TestBodyLimitConfig_ToMiddleware(t *testing.T) { e := echo.New() hw := []byte("Hello, World!") req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) @@ -25,35 +25,44 @@ func TestBodyLimit(t *testing.T) { return c.String(http.StatusOK, string(body)) } - assert := assert.New(t) - // Based on content length (within limit) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.Bytes()) + mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) } - // Based on content length (overlimit) - he := BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + // Based on content read (overlimit) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he := mw(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, World!", rec.Body.String()) - } + + mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, World!", rec.Body.String()) // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - he = BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he = mw(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) } func TestBodyLimitReader(t *testing.T) { @@ -63,9 +72,8 @@ func TestBodyLimitReader(t *testing.T) { rec := httptest.NewRecorder() config := BodyLimitConfig{ - Skipper: DefaultSkipper, - Limit: "2B", - limit: 2, + Skipper: DefaultSkipper, + LimitBytes: 2, } reader := &limitedReader{ BodyLimitConfig: config, @@ -80,8 +88,80 @@ func TestBodyLimitReader(t *testing.T) { // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) + reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw))) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) } + +func TestBodyLimit_skipper(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + mw, err := BodyLimitConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + LimitBytes: 2, + }.ToMiddleware() + assert.NoError(t, err) + + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimitWithConfig(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB}) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimit(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimit(2 * MB) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} diff --git a/middleware/compress.go b/middleware/compress.go index ac6672e9d..d383cac63 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -3,6 +3,7 @@ package middleware import ( "bufio" "compress/gzip" + "errors" "io" "io/ioutil" "net" @@ -10,54 +11,49 @@ import ( "strings" "sync" - "github.com/labstack/echo/v4" -) - -type ( - // GzipConfig defines the config for Gzip middleware. - GzipConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Gzip compression level. - // Optional. Default value -1. - Level int `yaml:"level"` - } - - gzipResponseWriter struct { - io.Writer - http.ResponseWriter - wroteBody bool - } + "github.com/labstack/echo/v5" ) const ( gzipScheme = "gzip" ) -var ( - // DefaultGzipConfig is the default Gzip middleware config. - DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, - } -) +// GzipConfig defines the config for Gzip middleware. +type GzipConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -// Gzip returns a middleware which compresses HTTP response using gzip compression -// scheme. + // Gzip compression level. + // Optional. Default value -1. + Level int +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter + wroteBody bool +} + +// Gzip returns a middleware which compresses HTTP response using gzip compression scheme. func Gzip() echo.MiddlewareFunc { - return GzipWithConfig(DefaultGzipConfig) + return GzipWithConfig(GzipConfig{}) } -// GzipWithConfig return Gzip middleware with config. -// See: `Gzip()`. +// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration +func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression + return nil, errors.New("invalid gzip level") } if config.Level == 0 { - config.Level = DefaultGzipConfig.Level + config.Level = -1 } pool := gzipCompressPool(config) @@ -98,7 +94,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } func (w *gzipResponseWriter) WriteHeader(code int) { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index b62bffef5..d6b4f60ed 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -3,94 +3,128 @@ package middleware import ( "bytes" "compress/gzip" - "io" "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" + "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestGzip(t *testing.T) { +func TestGzip_NoAcceptEncodingHeader(t *testing.T) { + // Skip if no Accept-Encoding header + h := Gzip()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - // Skip if no Accept-Encoding header + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) +} + +func TestMustGzipWithConfig_panics(t *testing.T) { + assert.Panics(t, func() { + GzipWithConfig(GzipConfig{Level: 999}) + }) +} + +func TestGzip_AcceptEncodingHeader(t *testing.T) { h := Gzip()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - assert.Equal("test", rec.Body.String()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - // Gzip - req = httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - r, err := gzip.NewReader(rec.Body) - if assert.NoError(err) { - buf := new(bytes.Buffer) - defer r.Close() - buf.ReadFrom(r) - assert.Equal("test", buf.String()) - } + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - chunkBuf := make([]byte, 5) + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) +} - // Gzip chunked - req = httptest.NewRequest(http.MethodGet, "/", nil) +func TestGzip_chunked(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - c = e.NewContext(req, rec) - Gzip()(func(c echo.Context) error { + chunkChan := make(chan struct{}) + waitChan := make(chan struct{}) + h := Gzip()(func(c echo.Context) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data - c.Response().Write([]byte("test\n")) + c.Response().Write([]byte("first\n")) c.Response().Flush() - // Read the first part of the data - assert.True(rec.Flushed) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - r.Reset(rec.Body) - - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write and flush the second part of the data - c.Response().Write([]byte("test\n")) + c.Response().Write([]byte("second\n")) c.Response().Flush() - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write the final part of the data and return - c.Response().Write([]byte("test")) + c.Response().Write([]byte("third")) + + chunkChan <- struct{}{} return nil - })(c) + }) + + go func() { + err := h(c) + chunkChan <- struct{}{} + assert.NoError(t, err) + }() + + <-chunkChan // wait for first write + waitChan <- struct{}{} + + <-chunkChan // wait for second write + waitChan <- struct{}{} + + <-chunkChan // wait for final write in handler + <-chunkChan // wait for return from handler + time.Sleep(5 * time.Millisecond) // to have time for flushing + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) buf := new(bytes.Buffer) - defer r.Close() buf.ReadFrom(r) - assert.Equal("test", buf.String()) + assert.Equal(t, "first\nsecond\nthird", buf.String()) } -func TestGzipNoContent(t *testing.T) { +func TestGzip_NoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) @@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) { } } -func TestGzipEmpty(t *testing.T) { +func TestGzip_Empty(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) @@ -127,7 +161,7 @@ func TestGzipEmpty(t *testing.T) { } } -func TestGzipErrorReturned(t *testing.T) { +func TestGzip_ErrorReturned(t *testing.T) { e := echo.New() e.Use(Gzip()) e.GET("/", func(c echo.Context) error { @@ -141,31 +175,25 @@ func TestGzipErrorReturned(t *testing.T) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } -func TestGzipErrorReturnedInvalidConfig(t *testing.T) { - e := echo.New() - // Invalid level - e.Use(GzipWithConfig(GzipConfig{Level: 12})) - e.GET("/", func(c echo.Context) error { - c.Response().Write([]byte("test")) - return nil - }) - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, rec.Body.String(), "gzip") +func TestGzipWithConfig_invalidLevel(t *testing.T) { + mw, err := GzipConfig{Level: 12}.ToMiddleware() + assert.EqualError(t, err, "invalid gzip level") + assert.Nil(t, mw) } // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() + e.Filesystem = os.DirFS("../") + e.Use(Gzip()) - e.Static("/test", "../_fixture/images") + e.Static("/test", "_fixture/images") req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) // Data is written out in chunks when Content-Length == "", so only // validate the content length if it's not set. diff --git a/middleware/cors.go b/middleware/cors.go index 16259512a..78b44975d 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -6,67 +6,63 @@ import ( "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // CORSConfig defines the config for CORS middleware. - CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // AllowOrigin defines a list of origins that may access the resource. - // Optional. Default value []string{"*"}. - AllowOrigins []string `yaml:"allow_origins"` - - // AllowOriginFunc is a custom function to validate the origin. It takes the - // origin as an argument and returns true if allowed or false otherwise. If - // an error is returned, it is returned by the handler. If this option is - // set, AllowOrigins is ignored. - // Optional. - AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` - - // AllowMethods defines a list methods allowed when accessing the resource. - // This is used in response to a preflight request. - // Optional. Default value DefaultCORSConfig.AllowMethods. - // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value - // from `Allow` header that echo.Router set into context. - AllowMethods []string `yaml:"allow_methods"` - - // AllowHeaders defines a list of request headers that can be used when - // making the actual request. This is in response to a preflight request. - // Optional. Default value []string{}. - AllowHeaders []string `yaml:"allow_headers"` - - // AllowCredentials indicates whether or not the response to the request - // can be exposed when the credentials flag is true. When used as part of - // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. - // Optional. Default value false. - // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. - // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - AllowCredentials bool `yaml:"allow_credentials"` - - // ExposeHeaders defines a whitelist headers that clients are allowed to - // access. - // Optional. Default value []string{}. - ExposeHeaders []string `yaml:"expose_headers"` - - // MaxAge indicates how long (in seconds) the results of a preflight request - // can be cached. - // Optional. Default value 0. - MaxAge int `yaml:"max_age"` - } -) +// CORSConfig defines the config for CORS middleware. +type CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // AllowOrigin defines a list of origins that may access the resource. + // Optional. Default value []string{"*"}. + AllowOrigins []string + + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) + + // AllowMethods defines a list methods allowed when accessing the resource. + // This is used in response to a preflight request. + // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. + AllowMethods []string + + // AllowHeaders defines a list of request headers that can be used when + // making the actual request. This is in response to a preflight request. + // Optional. Default value []string{}. + AllowHeaders []string + + // AllowCredentials indicates whether or not the response to the request + // can be exposed when the credentials flag is true. When used as part of + // a response to a preflight request, this indicates whether or not the + // actual request can be made using credentials. + // Optional. Default value false. + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + AllowCredentials bool + + // ExposeHeaders defines a whitelist headers that clients are allowed to + // access. + // Optional. Default value []string{}. + ExposeHeaders []string + + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached. + // Optional. Default value 0. + MaxAge int +} -var ( - // DefaultCORSConfig is the default CORS middleware config. - DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - } -) +// DefaultCORSConfig is the default CORS middleware config. +var DefaultCORSConfig = CORSConfig{ + Skipper: DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, +} // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS @@ -74,9 +70,14 @@ func CORS() echo.MiddlewareFunc { return CORSWithConfig(DefaultCORSConfig) } -// CORSWithConfig returns a CORS middleware with config. +// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. // See: `CORS()`. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration +func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCORSConfig.Skipper @@ -172,7 +173,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { checkPatterns := false if allowOrigin == "" { // to avoid regex cost by invalid (long) domains (253 is domain name max limit) - if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { + if len(origin) <= (5+3+253) && strings.Contains(origin, "://") { checkPatterns = true } } @@ -230,5 +231,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } return c.NoContent(http.StatusNoContent) } - } + }, nil } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index daadbab6e..2299a885d 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -17,7 +17,7 @@ func TestCORS(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := CORS()(echo.NotFoundHandler) + h := CORS()(func(c echo.Context) error { return echo.ErrNotFound }) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -26,7 +26,7 @@ func TestCORS(t *testing.T) { req = httptest.NewRequest(http.MethodGet, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - h = CORS()(echo.NotFoundHandler) + h = CORS()(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) @@ -38,7 +38,7 @@ func TestCORS(t *testing.T) { AllowOrigins: []string{"localhost"}, AllowCredentials: true, MaxAge: 3600, - })(echo.NotFoundHandler) + })(func(c echo.Context) error { return echo.ErrNotFound }) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -55,7 +55,7 @@ func TestCORS(t *testing.T) { AllowCredentials: true, MaxAge: 3600, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) @@ -73,7 +73,7 @@ func TestCORS(t *testing.T) { AllowCredentials: true, MaxAge: 3600, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) @@ -90,7 +90,7 @@ func TestCORS(t *testing.T) { cors = CORSWithConfig(CORSConfig{ AllowOrigins: []string{"*"}, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) @@ -104,7 +104,7 @@ func TestCORS(t *testing.T) { cors = CORSWithConfig(CORSConfig{ AllowOrigins: []string{"http://*.example.com"}, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -324,7 +324,9 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) { c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey) } - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) h(c) assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow)) @@ -511,11 +513,11 @@ func Test_allowOriginFunc(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, origin) - cors := CORSWithConfig(CORSConfig{ - AllowOriginFunc: allowOriginFunc, - }) - h := cors(echo.NotFoundHandler) - err := h(c) + cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware() + assert.NoError(t, err) + + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) + err = h(c) expected, expectedErr := allowOriginFunc(origin) if expectedErr != nil { diff --git a/middleware/csrf.go b/middleware/csrf.go index 61299f5ca..acab8790b 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -5,91 +5,93 @@ import ( "net/http" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" ) -type ( - // CSRFConfig defines the config for CSRF middleware. - CSRFConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // TokenLength is the length of the generated token. - TokenLength uint8 `yaml:"token_length"` - // Optional. Default value 32. - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:X-CSRF-Token". - // Possible values: - // - "header:" or "header::" - // - "query:" - // - "form:" - // Multiple sources example: - // - "header:X-CSRF-Token,query:csrf" - TokenLookup string `yaml:"token_lookup"` - - // Context key to store generated CSRF token into context. - // Optional. Default value "csrf". - ContextKey string `yaml:"context_key"` - - // Name of the CSRF cookie. This cookie will store CSRF token. - // Optional. Default value "csrf". - CookieName string `yaml:"cookie_name"` - - // Domain of the CSRF cookie. - // Optional. Default value none. - CookieDomain string `yaml:"cookie_domain"` - - // Path of the CSRF cookie. - // Optional. Default value none. - CookiePath string `yaml:"cookie_path"` - - // Max age (in seconds) of the CSRF cookie. - // Optional. Default value 86400 (24hr). - CookieMaxAge int `yaml:"cookie_max_age"` - - // Indicates if CSRF cookie is secure. - // Optional. Default value false. - CookieSecure bool `yaml:"cookie_secure"` - - // Indicates if CSRF cookie is HTTP only. - // Optional. Default value false. - CookieHTTPOnly bool `yaml:"cookie_http_only"` - - // Indicates SameSite mode of the CSRF cookie. - // Optional. Default value SameSiteDefaultMode. - CookieSameSite http.SameSite `yaml:"cookie_same_site"` - } -) +// CSRFConfig defines the config for CSRF middleware. +type CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // TokenLength is the length of the generated token. + TokenLength uint8 + // Optional. Default value 32. + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:X-CSRF-Token". + // Possible values: + // - "header:" or "header::" + // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" + TokenLookup string `yaml:"token_lookup"` + + // Generator defines a function to generate token. + // Optional. Defaults tp randomString(TokenLength). + Generator func() string + + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string + + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value none. + CookiePath string + + // Max age (in seconds) of the CSRF cookie. + // Optional. Default value 86400 (24hr). + CookieMaxAge int + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite +} // ErrCSRFInvalid is returned when CSRF check fails var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") -var ( - // DefaultCSRFConfig is the default CSRF middleware config. - DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, - CookieSameSite: http.SameSiteDefaultMode, - } -) +// DefaultCSRFConfig is the default CSRF middleware config. +var DefaultCSRFConfig = CSRFConfig{ + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, +} // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { - c := DefaultCSRFConfig - return CSRFWithConfig(c) + return CSRFWithConfig(DefaultCSRFConfig) } -// CSRFWithConfig returns a CSRF middleware with config. -// See `CSRF()`. +// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration +func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper @@ -97,6 +99,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } + if config.Generator == nil { + config.Generator = createRandomStringGenerator(config.TokenLength) + } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -113,9 +118,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { config.CookieSecure = true } - extractors, err := createExtractors(config.TokenLookup, "") + extractors, err := createExtractors(config.TokenLookup) if err != nil { - panic(err) + return nil, err } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -126,7 +131,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = random.String(config.TokenLength) // Generate token + token = config.Generator() // Generate token } else { token = k.Value // Reuse token } @@ -157,17 +162,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if lastTokenErr != nil { return lastTokenErr } else if lastExtractorErr != nil { - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") - } else if lastExtractorErr == errFormExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") - } else if lastExtractorErr == errHeaderExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") - } else { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) - } - return lastExtractorErr + return echo.ErrBadRequest.WithInternal(lastExtractorErr) } } @@ -197,7 +192,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } func validateCSRFToken(token, clientToken string) bool { diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 9aff82a98..f8af5e9cc 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -7,22 +7,22 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCSRF_tokenExtractors(t *testing.T) { var testCases = []struct { - name string - whenTokenLookup string - whenCookieName string - givenCSRFCookie string - givenMethod string - givenQueryTokens map[string][]string - givenFormTokens map[string][]string - givenHeaderTokens map[string][]string - expectError string + name string + whenTokenLookup string + whenCookieName string + givenCSRFCookie string + givenMethod string + givenQueryTokens map[string][]string + givenFormTokens map[string][]string + givenHeaderTokens map[string][]string + expectError string + expectToMiddlewareError string }{ { name: "ok, multiple token lookups sources, succeeds on last one", @@ -70,7 +70,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in the form parameter", + expectError: "code=400, message=Bad Request, internal=missing value in the form", }, { name: "ok, token from POST header", @@ -106,7 +106,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in request header", + expectError: "code=400, message=Bad Request, internal=missing value in request header", }, { name: "ok, token from PUT query param", @@ -142,7 +142,15 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in the query string", + expectError: "code=400, message=Bad Request, internal=missing value in the query string", + }, + { + name: "nok, invalid TokenLookup", + whenTokenLookup: "q", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{}, + expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", }, } @@ -186,16 +194,23 @@ func TestCSRF_tokenExtractors(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + config := CSRFConfig{ TokenLookup: tc.whenTokenLookup, CookieName: tc.whenCookieName, - }) + } + csrf, err := config.ToMiddleware() + if tc.expectToMiddlewareError != "" { + assert.EqualError(t, err, tc.expectToMiddlewareError) + return + } else if err != nil { + assert.NoError(t, err) + } h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) - err := h(c) + err = h(c) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -219,6 +234,24 @@ func TestCSRF(t *testing.T) { h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") +} + +func TestMustCSRFWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + csrf := CSRFWithConfig(CSRFConfig{ + TokenLength: 16, + }) + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Generate CSRF token + h(c) + assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") + // Without CSRF cookie req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() @@ -233,7 +266,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(32) + token := randomString(16) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { @@ -302,9 +335,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + csrf, err := CSRFConfig{ CookieSameSite: http.SameSiteNoneMode, - }) + }.ToMiddleware() + assert.NoError(t, err) h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") diff --git a/middleware/decompress.go b/middleware/decompress.go index 88ec70982..dcf7172fa 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -6,21 +6,19 @@ import ( "net/http" "sync" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // DecompressConfig defines the config for Decompress middleware. - DecompressConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// DecompressConfig defines the config for Decompress middleware. +type DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers - GzipDecompressPool Decompressor - } -) + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor +} -//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +// GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" // Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers @@ -28,14 +26,6 @@ type Decompressor interface { gzipDecompressPool() sync.Pool } -var ( - //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{ - Skipper: DefaultSkipper, - GzipDecompressPool: &DefaultGzipDecompressPool{}, - } -) - // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { } @@ -44,19 +34,23 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { return sync.Pool{New: func() interface{} { return new(gzip.Reader) }} } -//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +// Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { - return DecompressWithConfig(DefaultDecompressConfig) + return DecompressWithConfig(DecompressConfig{}) } -//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration. func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration +func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper } if config.GzipDecompressPool == nil { - config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + config.GzipDecompressPool = &DefaultGzipDecompressPool{} } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -95,5 +89,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 51fa6b0f1..c35ed6fa3 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -11,54 +11,82 @@ import ( "sync" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestDecompress(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - // Skip if no Content-Encoding header h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) - - // Decompress + // Decompress request body body := `{"name": "echo"}` gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) } -func TestDecompressDefaultConfig(t *testing.T) { +func TestDecompress_skippedIfNoHeader(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + // Skip if no Content-Encoding header + h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, "test", rec.Body.String()) + +} + +func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + })(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) + +} + +func TestDecompressWithConfig_DefaultConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) // Decompress body := `{"name": "echo"}` @@ -67,11 +95,14 @@ func TestDecompressDefaultConfig(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) } func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { @@ -82,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) assert.NoError(t, err) @@ -99,7 +132,10 @@ func TestDecompressNoContent(t *testing.T) { h := Decompress()(func(c echo.Context) error { return c.NoContent(http.StatusNoContent) }) - if assert.NoError(t, h(c)) { + + err := h(c) + + if assert.NoError(t, err) { assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Equal(t, 0, len(rec.Body.Bytes())) @@ -115,7 +151,9 @@ func TestDecompressErrorReturned(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } @@ -132,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) reqBody, err := ioutil.ReadAll(c.Request().Body) assert.NoError(t, err) @@ -161,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) reqBody, err := ioutil.ReadAll(c.Request().Body) assert.NoError(t, err) diff --git a/middleware/extractor.go b/middleware/extractor.go index afdfd8195..ce5e9f7c6 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -1,9 +1,8 @@ package middleware import ( - "errors" "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/textproto" "strings" ) @@ -14,17 +13,27 @@ const ( extractorLimit = 20 ) -var errHeaderExtractorValueMissing = errors.New("missing value in request header") -var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") -var errQueryExtractorValueMissing = errors.New("missing value in the query string") -var errParamExtractorValueMissing = errors.New("missing value in path params") -var errCookieExtractorValueMissing = errors.New("missing value in cookies") -var errFormExtractorValueMissing = errors.New("missing value in the form") +// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups +type ValueExtractorError struct { + message string +} + +// Error returns errors text +func (e *ValueExtractorError) Error() string { + return e.message +} + +var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"} +var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"} +var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"} +var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"} +var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"} +var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. type ValuesExtractor func(c echo.Context) ([]string, error) -func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { +func createExtractors(lookups string) ([]ValuesExtractor, error) { if lookups == "" { return nil, nil } @@ -49,15 +58,6 @@ func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, err prefix := "" if len(parts) > 2 { prefix = parts[2] - } else if authScheme != "" && parts[1] == echo.HeaderAuthorization { - // backwards compatibility for JWT and KeyAuth: - // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc - // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that - // behaviour for default values and Authorization header. - prefix = authScheme - if !strings.HasSuffix(prefix, " ") { - prefix += " " - } } extractors = append(extractors, valuesFromHeader(parts[1], prefix)) } @@ -125,10 +125,9 @@ func valuesFromQuery(param string) ValuesExtractor { func valuesFromParam(param string) ValuesExtractor { return func(c echo.Context) ([]string, error) { result := make([]string, 0) - paramVales := c.ParamValues() - for i, p := range c.ParamNames() { - if param == p { - result = append(result, paramVales[i]) + for i, p := range c.PathParams() { + if param == p.Name { + result = append(result, p.Value) if i >= extractorLimit-1 { break } diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 2e898f541..59157d498 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -3,7 +3,7 @@ package middleware import ( "bytes" "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "mime/multipart" "net/http" @@ -13,27 +13,11 @@ import ( "testing" ) -type pathParam struct { - name string - value string -} - -func setPathParams(c echo.Context, params []pathParam) { - names := make([]string, 0, len(params)) - values := make([]string, 0, len(params)) - for _, pp := range params { - names = append(names, pp.name) - values = append(values, pp.value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) -} - func TestCreateExtractors(t *testing.T) { var testCases = []struct { name string givenRequest func() *http.Request - givenPathParams []pathParam + givenPathParams echo.PathParams whenLoopups string expectValues []string expectCreateError string @@ -74,8 +58,8 @@ func TestCreateExtractors(t *testing.T) { }, { name: "ok, param", - givenPathParams: []pathParam{ - {name: "id", value: "123"}, + givenPathParams: echo.PathParams{ + {Name: "id", Value: "123"}, }, whenLoopups: "param:id", expectValues: []string{"123"}, @@ -105,12 +89,12 @@ func TestCreateExtractors(t *testing.T) { req = tc.givenRequest() } rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) + c.SetRawPathParams(&tc.givenPathParams) } - extractors, err := createExtractors(tc.whenLoopups, "") + extractors, err := createExtractors(tc.whenLoopups) if tc.expectCreateError != "" { assert.EqualError(t, err, tc.expectCreateError) return @@ -317,19 +301,19 @@ func TestValuesFromQuery(t *testing.T) { } func TestValuesFromParam(t *testing.T) { - examplePathParams := []pathParam{ - {name: "id", value: "123"}, - {name: "gid", value: "456"}, - {name: "gid", value: "789"}, + examplePathParams := echo.PathParams{ + {Name: "id", Value: "123"}, + {Name: "gid", Value: "456"}, + {Name: "gid", Value: "789"}, } - examplePathParams20 := make([]pathParam, 0) + examplePathParams20 := make(echo.PathParams, 0) for i := 1; i < 25; i++ { - examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) + examplePathParams20 = append(examplePathParams20, echo.PathParam{Name: "id", Value: fmt.Sprintf("%v", i)}) } var testCases = []struct { name string - givenPathParams []pathParam + givenPathParams echo.PathParams whenName string expectValues []string expectError string @@ -377,9 +361,9 @@ func TestValuesFromParam(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) + c.SetRawPathParams(&tc.givenPathParams) } extractor := valuesFromParam(tc.whenName) diff --git a/middleware/jwt.go b/middleware/jwt.go index bec5167e2..40b45e77e 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,150 +1,98 @@ -//go:build go1.15 -// +build go1.15 - package middleware import ( "errors" - "fmt" - "github.com/golang-jwt/jwt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/http" - "reflect" ) -type ( - // JWTConfig defines the config for JWT middleware. - JWTConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc BeforeFunc - - // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next - // middleware or handler. - SuccessHandler JWTSuccessHandler - - // ErrorHandler defines a function which is executed for an invalid token. - // It may be used to define a custom JWT error. - ErrorHandler JWTErrorHandler - - // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. - ErrorHandlerWithContext JWTErrorHandlerWithContext - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. - ContinueOnIgnoredError bool - - // Signing key to validate token. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKeys is provided. - SigningKey interface{} - - // Map of signing keys to validate token with kid field usage. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKey is provided. - SigningKeys map[string]interface{} - - // Signing method used to check the token's signing algorithm. - // Optional. Default value HS256. - SigningMethod string - - // Context key to store user information from the token into context. - // Optional. Default value "user". - ContextKey string - - // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. - // Not used if custom ParseTokenFunc is set. - // Optional. Default value jwt.MapClaims - Claims jwt.Claims - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. - // If prefix is left empty the whole value is returned. - // - "query:" - // - "param:" - // - "cookie:" - // - "form:" - // Multiple sources example: - // - "header:Authorization,cookie:myowncookie" - TokenLookup string - - // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. - // This is one of the two options to provide a token extractor. - // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. - // You can also provide both if you want. - TokenLookupFuncs []ValuesExtractor - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // KeyFunc defines a user-defined function that supplies the public key for a token validation. - // The function shall take care of verifying the signing algorithm and selecting the proper key. - // A user-defined KeyFunc can be useful if tokens are issued by an external party. - // Used by default ParseTokenFunc implementation. - // - // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither SigningKeys nor SigningKey is provided. - // Not used if custom ParseTokenFunc is set. - // Default to an internal implementation verifying the signing algorithm and selecting the proper key. - KeyFunc jwt.Keyfunc - - // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token - // parsing fails or parsed token is invalid. - // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) - } +// JWTConfig defines the config for JWT middleware. +type JWTConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // BeforeFunc defines a function which is executed just before the middleware. + BeforeFunc BeforeFunc + + // SuccessHandler defines a function which is executed for a valid token. + SuccessHandler JWTSuccessHandler + + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom JWT error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain. + ErrorHandler JWTErrorHandlerWithContext + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. + ContinueOnIgnoredError bool + + // Context key to store user information from the token into context. + // Optional. Default value "user". + ContextKey string + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. + // If prefix is left empty the whole value is returned. + // - "query:" + // - "param:" + // - "cookie:" + // - "form:" + // Multiple sources example: + // - "header:Authorization:Bearer ,cookie:myowncookie" + TokenLookup string + + // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. + // This is one of the two options to provide a token extractor. + // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. + // You can also provide both if you want. + TokenLookupFuncs []ValuesExtractor + + // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token + // parsing fails or parsed token is invalid. + // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library + ParseTokenFunc func(c echo.Context, auth string) (interface{}, error) +} - // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(c echo.Context) +// JWTSuccessHandler defines a function which is executed for a valid token. +type JWTSuccessHandler func(c echo.Context) - // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(err error) error +// JWTErrorHandler defines a function which is executed for an invalid token. +type JWTErrorHandler func(err error) error - // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(err error, c echo.Context) error -) +// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. +type JWTErrorHandlerWithContext func(c echo.Context, err error) error -// Algorithms const ( + // AlgorithmHS256 is token signing algorithm AlgorithmHS256 = "HS256" ) -// Errors -var ( - ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") - ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") -) +// ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request +var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt") -var ( - // DefaultJWTConfig is the default JWT auth middleware config. - DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - TokenLookupFuncs: nil, - AuthScheme: "Bearer", - Claims: jwt.MapClaims{}, - KeyFunc: nil, - } -) +// ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired +var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") + +// DefaultJWTConfig is the default JWT auth middleware config. +var DefaultJWTConfig = JWTConfig{ + Skipper: DefaultSkipper, + ContextKey: "user", + TokenLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", +} // JWT returns a JSON Web Token (JWT) auth middleware. // @@ -153,48 +101,40 @@ var ( // For missing token, it returns "400 - Bad Request" error. // // See: https://jwt.io/introduction -// See `JWTConfig.TokenLookup` -func JWT(key interface{}) echo.MiddlewareFunc { +func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc { c := DefaultJWTConfig - c.SigningKey = key + c.ParseTokenFunc = parseTokenFunc return JWTWithConfig(c) } -// JWTWithConfig returns a JWT auth middleware with config. -// See: `JWT()`. +// JWTWithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid. +// +// For valid token, it sets the user in context and calls next handler. +// For invalid token, it returns "401 - Unauthorized" error. +// For missing token, it returns "400 - Bad Request" error. +// +// See: https://jwt.io/introduction func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts JWTConfig to middleware or returns an error for invalid configuration +func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } - if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { - panic("echo: jwt middleware requires signing key") - } - if config.SigningMethod == "" { - config.SigningMethod = DefaultJWTConfig.SigningMethod + if config.ParseTokenFunc == nil { + return nil, errors.New("echo jwt middleware requires parse token function") } if config.ContextKey == "" { config.ContextKey = DefaultJWTConfig.ContextKey } - if config.Claims == nil { - config.Claims = DefaultJWTConfig.Claims - } if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 { config.TokenLookup = DefaultJWTConfig.TokenLookup } - if config.AuthScheme == "" { - config.AuthScheme = DefaultJWTConfig.AuthScheme - } - if config.KeyFunc == nil { - config.KeyFunc = config.defaultKeyFunc - } - if config.ParseTokenFunc == nil { - config.ParseTokenFunc = config.defaultParseToken - } - - extractors, err := createExtractors(config.TokenLookup, config.AuthScheme) + extractors, err := createExtractors(config.TokenLookup) if err != nil { - panic(err) + return nil, err } if len(config.TokenLookupFuncs) > 0 { extractors = append(config.TokenLookupFuncs, extractors...) @@ -209,17 +149,16 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.BeforeFunc != nil { config.BeforeFunc(c) } - var lastExtractorErr error var lastTokenErr error for _, extractor := range extractors { - auths, err := extractor(c) - if err != nil { - lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth) + auths, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr continue } for _, auth := range auths { - token, err := config.ParseTokenFunc(auth, c) + token, err := config.ParseTokenFunc(c, auth) if err != nil { lastTokenErr = err continue @@ -232,69 +171,23 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return next(c) } } - // we are here only when we did not successfully extract or parse any of the tokens + + // prioritize token errors over extracting errors err := lastTokenErr - if err == nil { // prioritize token errors over extracting errors + if err == nil { err = lastExtractorErr } if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - if config.ErrorHandlerWithContext != nil { - tmpErr := config.ErrorHandlerWithContext(err, c) + tmpErr := config.ErrorHandler(c, err) if config.ContinueOnIgnoredError && tmpErr == nil { return next(c) } return tmpErr } - - // backwards compatible errors codes - if lastTokenErr != nil { - return &echo.HTTPError{ - Code: ErrJWTInvalid.Code, - Message: ErrJWTInvalid.Message, - Internal: err, - } - } - return err // this is lastExtractorErr value - } - } -} - -func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) { - token := new(jwt.Token) - var err error - // Issue #647, #656 - if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.KeyFunc) - } else { - t := reflect.ValueOf(config.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) - } - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("invalid token") - } - return token, nil -} - -// defaultKeyFunc returns a signing key of the given token. -func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { - // Check the signing method - if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - if len(config.SigningKeys) > 0 { - if kid, ok := t.Header["kid"].(string); ok { - if key, ok := config.SigningKeys[kid]; ok { - return key, nil + if lastTokenErr == nil { + return ErrJWTMissing.WithInternal(err) } + return ErrJWTInvalid.WithInternal(err) } - return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) - } - - return config.SigningKey, nil + }, nil } diff --git a/middleware/jwt_external_test.go b/middleware/jwt_external_test.go new file mode 100644 index 000000000..1b92f188f --- /dev/null +++ b/middleware/jwt_external_test.go @@ -0,0 +1,76 @@ +package middleware_test + +import ( + "errors" + "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" + "net/http" + "net/http/httptest" +) + +// CreateJWTGoParseTokenFunc creates JWTGo implementation for ParseTokenFunc +// +// signingKey is signing key to validate token. +// This is one of the options to provide a token validation key. +// The order of precedence is a user-defined SigningKeys and SigningKey. +// Required if signingKeys is not provided. +// +// signingKeys is Map of signing keys to validate token with kid field usage. +// This is one of the options to provide a token validation key. +// The order of precedence is a user-defined SigningKeys and SigningKey. +// Required if signingKey is not provided +func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) { + // keyFunc defines a user-defined function that supplies the public key for a token validation. + // The function shall take care of verifying the signing algorithm and selecting the proper key. + // A user-defined KeyFunc can be useful if tokens are issued by an external party. + keyFunc := func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != middleware.AlgorithmHS256 { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + if len(signingKeys) == 0 { + return signingKey, nil + } + + if kid, ok := t.Header["kid"].(string); ok { + if key, ok := signingKeys[kid]; ok { + return key, nil + } + } + return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) + } + + return func(c echo.Context, auth string) (interface{}, error) { + token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil + } +} + +func ExampleJWTConfig_withJWTGoAsTokenParser() { + mw := middleware.JWTWithConfig(middleware.JWTConfig{ + ParseTokenFunc: CreateJWTGoParseTokenFunc([]byte("secret"), nil), + }) + + e := echo.New() + e.Use(mw) + + e.GET("/", func(c echo.Context) error { + user := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusTeapot, user.Claims) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + fmt.Printf("status: %v, body: %v", res.Code, res.Body.String()) + // Output: status: 418, body: {"admin":true,"name":"John Doe","sub":"1234567890"} +} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index eee9df966..5e5b99121 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,6 +1,3 @@ -//go:build go1.15 -// +build go1.15 - package middleware import ( @@ -12,11 +9,32 @@ import ( "strings" "testing" - "github.com/golang-jwt/jwt" - "github.com/labstack/echo/v4" + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) +func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) { + // This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running + keyFunc := func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != signingMethod { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + return signingKey, nil + } + + return func(c echo.Context, auth string) (interface{}, error) { + token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil + } +} + // jwtCustomInfo defines some custom types we're going to use within our tokens. type jwtCustomInfo struct { Name string `json:"name"` @@ -25,7 +43,7 @@ type jwtCustomInfo struct { // jwtCustomClaims are custom claims expanding default ones. type jwtCustomClaims struct { - *jwt.StandardClaims + *jwt.RegisteredClaims jwtCustomInfo } @@ -37,7 +55,7 @@ func TestJWT(t *testing.T) { return c.JSON(http.StatusOK, token.Claims) }) - e.Use(JWT([]byte("secret"))) + e.Use(JWT(createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")))) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") @@ -49,247 +67,197 @@ func TestJWT(t *testing.T) { assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) } -func TestJWTRace(t *testing.T) { +func TestJWT_combinations(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss" - validKey := []byte("secret") - - h := JWTWithConfig(JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: validKey, - })(handler) - - makeReq := func(token string) echo.Context { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token) - c := e.NewContext(req, res) - assert.NoError(t, h(c)) - return c - } - - c := makeReq(initialToken) - user := c.Get("user").(*jwt.Token) - claims := user.Claims.(*jwtCustomClaims) - assert.Equal(t, claims.Name, "John Doe") - - makeReq(raceToken) - user = c.Get("user").(*jwt.Token) - claims = user.Claims.(*jwtCustomClaims) - // Initial context should still be "John Doe", not "Race Condition" - assert.Equal(t, claims.Name, "John Doe") - assert.Equal(t, claims.Admin, true) -} - -func TestJWTConfig(t *testing.T) { handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" validKey := []byte("secret") invalidKey := []byte("invalid-key") - validAuth := DefaultJWTConfig.AuthScheme + " " + token - - testCases := []struct { - name string - expPanic bool - expErrCode int // 0 for Success - config JWTConfig - reqURL string // "/" if empty - hdrAuth string - hdrCookie string // test.Request doesn't provide SetCookie(); use name=val - formValues map[string]string + validAuth := "Bearer " + token + + var testCases = []struct { + name string + config JWTConfig + reqURL string // "/" if empty + hdrAuth string + hdrCookie string // test.Request doesn't provide SetCookie(); use name=val + formValues map[string]string + expectPanic bool + expectToMiddlewareError string + expectError string }{ { - name: "No signing key provided", - expPanic: true, + name: "No signing key provided", + expectToMiddlewareError: "echo jwt middleware requires parse token function", }, { - name: "Unexpected signing method", - expErrCode: http.StatusBadRequest, + name: "invalid TokenLookup", config: JWTConfig{ - SigningKey: validKey, - SigningMethod: "RS256", + ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey), + TokenLookup: "q", }, + expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", }, { - name: "Invalid key", - expErrCode: http.StatusUnauthorized, - hdrAuth: validAuth, - config: JWTConfig{SigningKey: invalidKey}, + name: "Unexpected signing method", + hdrAuth: validAuth, + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey), + }, + expectError: "code=401, message=invalid or expired jwt, internal=unexpected jwt signing method=HS256", + }, + { + name: "Invalid key", + hdrAuth: validAuth, + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey), + }, + expectError: "code=401, message=invalid or expired jwt, internal=signature is invalid", }, { name: "Valid JWT", hdrAuth: validAuth, - config: JWTConfig{SigningKey: validKey}, + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, }, { name: "Valid JWT with custom AuthScheme", hdrAuth: "Token" + " " + token, - config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, + config: JWTConfig{ + TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, }, { name: "Valid JWT with custom claims", hdrAuth: validAuth, config: JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), }, }, { - name: "Invalid Authorization header", - hdrAuth: "invalid-auth", - expErrCode: http.StatusBadRequest, - config: JWTConfig{SigningKey: validKey}, + name: "Invalid Authorization header", + hdrAuth: "invalid-auth", + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, + expectError: "code=401, message=missing or malformed jwt, internal=invalid value in request header", }, { - name: "Empty header auth field", - config: JWTConfig{SigningKey: validKey}, - expErrCode: http.StatusBadRequest, + name: "Empty header auth field", + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, + expectError: "code=401, message=missing or malformed jwt, internal=invalid value in request header", }, { name: "Valid query method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=" + token, }, { name: "Invalid query param name", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, - reqURL: "/?a=b&jwtxyz=" + token, - expErrCode: http.StatusBadRequest, + reqURL: "/?a=b&jwtxyz=" + token, + expectError: "code=401, message=missing or malformed jwt, internal=missing value in the query string", }, { name: "Invalid query param value", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, - reqURL: "/?a=b&jwt=invalid-token", - expErrCode: http.StatusUnauthorized, + reqURL: "/?a=b&jwt=invalid-token", + expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments", }, { name: "Empty query", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, - reqURL: "/?a=b", - expErrCode: http.StatusBadRequest, + reqURL: "/?a=b", + expectError: "code=401, message=missing or malformed jwt, internal=missing value in the query string", }, { - name: "Valid param method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "param:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "param:jwt", }, reqURL: "/" + token, + name: "Valid param method", }, { - name: "Valid cookie method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, hdrCookie: "jwt=" + token, + name: "Valid cookie method", }, { - name: "Multiple jwt lookuop", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt,cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt,cookie:jwt", }, hdrCookie: "jwt=" + token, + name: "Multiple jwt lookuop", }, { name: "Invalid token with cookie method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, - expErrCode: http.StatusUnauthorized, - hdrCookie: "jwt=invalid", + hdrCookie: "jwt=invalid", + expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments", }, { name: "Empty cookie", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, - expErrCode: http.StatusBadRequest, + expectError: "code=401, message=missing or malformed jwt, internal=missing value in cookies", }, { name: "Valid form method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, formValues: map[string]string{"jwt": token}, }, { name: "Invalid token with form method", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, - expErrCode: http.StatusUnauthorized, - formValues: map[string]string{"jwt": "invalid"}, + formValues: map[string]string{"jwt": "invalid"}, + expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments", }, { name: "Empty form field", config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - expErrCode: http.StatusBadRequest, - }, - { - name: "Valid JWT with a valid key using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return validKey, nil - }, + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, - }, - { - name: "Valid JWT with an invalid key using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return invalidKey, nil - }, - }, - expErrCode: http.StatusUnauthorized, - }, - { - name: "Token verification does not pass using a user-defined KeyFunc", - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return nil, errors.New("faulty KeyFunc") - }, - }, - expErrCode: http.StatusUnauthorized, - }, - { - name: "Valid JWT with lower case AuthScheme", - hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, - config: JWTConfig{SigningKey: validKey}, + expectError: "code=401, message=missing or malformed jwt, internal=missing value in the form", }, } + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := echo.New() if tc.reqURL == "" { tc.reqURL = "/" } @@ -312,128 +280,36 @@ func TestJWTConfig(t *testing.T) { c := e.NewContext(req, res) if tc.reqURL == "/"+token { - c.SetParamNames("jwt") - c.SetParamValues(token) + cc := c.(echo.ServableContext) + cc.SetPathParams(echo.PathParams{ + {Name: "jwt", Value: token}, + }) } - if tc.expPanic { - assert.Panics(t, func() { - JWTWithConfig(tc.config) - }, tc.name) + mw, err := tc.config.ToMiddleware() + if tc.expectToMiddlewareError != "" { + assert.EqualError(t, err, tc.expectToMiddlewareError) return } - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.name) + hErr := mw(handler)(c) + if tc.expectError != "" { + assert.EqualError(t, hErr, tc.expectError) return } + assert.NoError(t, hErr) - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.name) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.name) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.name) - assert.Equal(t, claims.Admin, true, tc.name) - default: - panic("unexpected type of claims") - } - } - }) - } -} - -func TestJWTwithKID(t *testing.T) { - test := assert.New(t) - - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk" - secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM" - wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90" - staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o" - validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")} - invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")} - staticSecret := []byte("static_secret") - invalidStaticSecret := []byte("invalid_secret") - - for _, tc := range []struct { - expErrCode int // 0 for Success - config JWTConfig - hdrAuth string - info string - }{ - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "First token valid", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Second token valid", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Wrong key id token", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: staticSecret}, - info: "Valid static secret token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: invalidStaticSecret}, - info: "Invalid static secret", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys first token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys second token", - }, - } { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - c := e.NewContext(req, res) - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - test.Equal(tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if test.NoError(h(c), tc.info) { user := c.Get("user").(*jwt.Token) switch claims := user.Claims.(type) { case jwt.MapClaims: - test.Equal(claims["name"], "John Doe", tc.info) + assert.Equal(t, claims["name"], "John Doe") case *jwtCustomClaims: - test.Equal(claims.Name, "John Doe", tc.info) - test.Equal(claims.Admin, true, tc.info) + assert.Equal(t, claims.Name, "John Doe") + assert.Equal(t, claims.Admin, true) default: panic("unexpected type of claims") } - } + }) } } @@ -444,7 +320,7 @@ func TestJWTConfig_skipper(t *testing.T) { Skipper: func(context echo.Context) bool { return true // skip everything }, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), })) isCalled := false @@ -472,11 +348,11 @@ func TestJWTConfig_BeforeFunc(t *testing.T) { BeforeFunc: func(context echo.Context) { isCalled = true }, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), })) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) @@ -493,18 +369,8 @@ func TestJWTConfig_extractorErrorHandling(t *testing.T) { { name: "ok, ErrorHandler is executed", given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { - return echo.NewHTTPError(http.StatusTeapot, "custom_error") - }, - }, - expectStatusCode: http.StatusTeapot, - }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), + ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "custom_error") }, }, @@ -539,23 +405,13 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { { name: "ok, ErrorHandler is executed", given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), + ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) }, }, expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { - return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error()) - }, - }, - expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n", - }, } for _, tc := range testCases { @@ -568,14 +424,14 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { config := tc.given parseTokenCalled := false - config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { + config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) { parseTokenCalled = true return nil, errors.New("parsing failed") } e.Use(JWTWithConfig(config)) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) @@ -598,7 +454,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { signingKey := []byte("secret") config := JWTConfig{ - ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { keyFunc := func(t *jwt.Token) (interface{}, error) { if t.Method.Alg() != "HS256" { return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) @@ -621,125 +477,130 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { e.Use(JWTWithConfig(config)) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) } -func TestJWTConfig_TokenLookupFuncs(t *testing.T) { +func TestMustJWTWithConfig_SuccessHandler(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) + success := c.Get("success").(string) + user := c.Get("user").(string) + return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user)) }) - e.Use(JWTWithConfig(JWTConfig{ - TokenLookupFuncs: []ValuesExtractor{ - func(c echo.Context) ([]string, error) { - return []string{c.Request().Header.Get("X-API-Key")}, nil - }, + mw, err := JWTConfig{ + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + return auth, nil }, - SigningKey: []byte("secret"), - })) + SuccessHandler: func(c echo.Context) { + c.Set("success", "yes") + }, + }.ToMiddleware() + assert.NoError(t, err) + e.Use(mw) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64") res := httptest.NewRecorder() e.ServeHTTP(res, req) - assert.Equal(t, http.StatusOK, res.Code) - assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) + assert.Equal(t, "yes:valid_token_base64", res.Body.String()) + assert.Equal(t, http.StatusTeapot, res.Code) } -func TestJWTConfig_SuccessHandler(t *testing.T) { +func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { var testCases = []struct { - name string - givenToken string - expectCalled bool - expectStatus int + name string + givenContinueOnIgnoredError bool + givenErrorHandler JWTErrorHandlerWithContext + givenTokenLookup string + whenAuthHeaders []string + whenCookies []string + whenParseReturn string + whenParseError error + expectHandlerCalled bool + expect string + expectCode int }{ { - name: "ok, success handler is called", - givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", - expectCalled: true, - expectStatus: http.StatusOK, + name: "ok, with valid JWT from auth header", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + return nil + }, + whenAuthHeaders: []string{"Bearer valid_token_base64"}, + whenParseReturn: "valid_token", + expectCode: http.StatusTeapot, + expect: "valid_token", }, { - name: "nok, success handler is not called", - givenToken: "x.x.x", - expectCalled: false, - expectStatus: http.StatusUnauthorized, + name: "ok, missing header, callNext and set public_token from error handler", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + if errors.Is(err, &ValueExtractorError{}) { + panic("must get ErrJWTMissing") + } + c.Set("user", "public_token") + return nil + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusTeapot, + expect: "public_token", }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - e.GET("/", func(c echo.Context) error { - token := c.Get("user").(*jwt.Token) - return c.JSON(http.StatusOK, token.Claims) - }) - - wasCalled := false - e.Use(JWTWithConfig(JWTConfig{ - SuccessHandler: func(c echo.Context) { - wasCalled = true - }, - SigningKey: []byte("secret"), - })) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) - res := httptest.NewRecorder() - - e.ServeHTTP(res, req) - - assert.Equal(t, tc.expectCalled, wasCalled) - assert.Equal(t, tc.expectStatus, res.Code) - }) - } -} - -func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { - var testCases = []struct { - name string - whenContinueOnIgnoredError bool - givenToken string - expectStatus int - expectBody string - }{ { - name: "no error handler is called", - whenContinueOnIgnoredError: true, - givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", - expectStatus: http.StatusTeapot, - expectBody: "", + name: "ok, invalid token, callNext and set public_token from error handler", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + // this is probably not realistic usecase. on parse error you probably want to return error + if err.Error() != "parser_error" { + panic("must get parser_error") + } + c.Set("user", "public_token") + return nil + }, + whenAuthHeaders: []string{"Bearer invalid_header"}, + whenParseError: errors.New("parser_error"), + expectCode: http.StatusTeapot, + expect: "public_token", }, { - name: "ContinueOnIgnoredError is false and error handler is called for missing token", - whenContinueOnIgnoredError: false, - givenToken: "", - // empty response with 200. This emulates previous behaviour when error handler swallowed the error - expectStatus: http.StatusOK, - expectBody: "", + name: "nok, invalid token, return error from error handler", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + if err.Error() != "parser_error" { + panic("must get parser_error") + } + return err + }, + whenAuthHeaders: []string{"Bearer invalid_header"}, + whenParseError: errors.New("parser_error"), + expectCode: http.StatusInternalServerError, + expect: "{\"message\":\"Internal Server Error\"}\n", }, { - name: "error handler is called for missing token", - whenContinueOnIgnoredError: true, - givenToken: "", - expectStatus: http.StatusTeapot, - expectBody: "public-token", + name: "nok, ContinueOnIgnoredError but return error from error handler", + givenContinueOnIgnoredError: true, + givenErrorHandler: func(c echo.Context, err error) error { + return echo.ErrUnauthorized.WithInternal(err) + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusUnauthorized, + expect: "{\"message\":\"Unauthorized\"}\n", }, { - name: "error handler is called for invalid token", - whenContinueOnIgnoredError: true, - givenToken: "x.x.x", - expectStatus: http.StatusUnauthorized, - expectBody: "{\"message\":\"Unauthorized\"}\n", + name: "nok, ContinueOnIgnoredError=false", + givenContinueOnIgnoredError: false, + givenErrorHandler: func(c echo.Context, err error) error { + return echo.ErrUnauthorized.WithInternal(err) + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusUnauthorized, + expect: "{\"message\":\"Unauthorized\"}\n", }, } @@ -748,32 +609,56 @@ func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { - testValue, _ := c.Get("test").(string) - return c.String(http.StatusTeapot, testValue) + token := c.Get("user").(string) + return c.String(http.StatusTeapot, token) }) - e.Use(JWTWithConfig(JWTConfig{ - ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, c echo.Context) error { - if err == ErrJWTMissing { - c.Set("test", "public-token") - return nil - } - return echo.ErrUnauthorized + mw, err := JWTConfig{ + ContinueOnIgnoredError: tc.givenContinueOnIgnoredError, + TokenLookup: tc.givenTokenLookup, + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + return tc.whenParseReturn, tc.whenParseError }, - })) + ErrorHandler: tc.givenErrorHandler, + }.ToMiddleware() + assert.NoError(t, err) + e.Use(mw) req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenToken != "" { - req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) + for _, a := range tc.whenAuthHeaders { + req.Header.Add(echo.HeaderAuthorization, a) } res := httptest.NewRecorder() - e.ServeHTTP(res, req) - assert.Equal(t, tc.expectStatus, res.Code) - assert.Equal(t, tc.expectBody, res.Body.String()) + assert.Equal(t, tc.expect, res.Body.String()) + assert.Equal(t, tc.expectCode, res.Code) }) } } + +func TestJWTConfig_TokenLookupFuncs(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + token := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusOK, token.Claims) + }) + + e.Use(JWTWithConfig(JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), + TokenLookupFuncs: []ValuesExtractor{ + func(c echo.Context) ([]string, error) { + return []string{c.Request().Header.Get("X-API-Key")}, nil + }, + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) +} diff --git a/middleware/key_auth.go b/middleware/key_auth.go index e8a6b0853..77a001ea8 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -2,81 +2,69 @@ package middleware import ( "errors" - "github.com/labstack/echo/v4" + "fmt" + "github.com/labstack/echo/v5" "net/http" ) -type ( - // KeyAuthConfig defines the config for KeyAuth middleware. - KeyAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // KeyLookup is a string in the form of ":" or ":,:" that is used - // to extract key from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. - // - "query:" - // - "form:" - // - "cookie:" - // Multiple sources example: - // - "header:Authorization,header:X-Api-Key" - KeyLookup string - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // Validator is a function to validate key. - // Required. - Validator KeyAuthValidator - - // ErrorHandler defines a function which is executed for an invalid key. - // It may be used to define a custom error. - ErrorHandler KeyAuthErrorHandler - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandler to set a default public key auth value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. - ContinueOnIgnoredError bool - } +// KeyAuthConfig defines the config for KeyAuth middleware. +type KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // KeyLookup is a string in the form of ":" or ":,:" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. + // - "query:" + // - "form:" + // - "cookie:" + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator + + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain. + ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool +} - // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(auth string, c echo.Context) (bool, error) +// KeyAuthValidator defines a function to validate KeyAuth credentials. +type KeyAuthValidator func(c echo.Context, key string) (bool, error) - // KeyAuthErrorHandler defines a function which is executed for an invalid key. - KeyAuthErrorHandler func(err error, c echo.Context) error -) +// KeyAuthErrorHandler defines a function which is executed for an invalid key. +type KeyAuthErrorHandler func(c echo.Context, err error) error -var ( - // DefaultKeyAuthConfig is the default KeyAuth middleware config. - DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - } -) +// ErrKeyMissing denotes an error raised when key value could not be extracted from request +var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") -// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups -type ErrKeyAuthMissing struct { - Err error -} +// ErrInvalidKey denotes an error raised when key value is invalid by validator +var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") -// Error returns errors text -func (e *ErrKeyAuthMissing) Error() string { - return e.Err.Error() -} - -// Unwrap unwraps error -func (e *ErrKeyAuthMissing) Unwrap() error { - return e.Err +// DefaultKeyAuthConfig is the default KeyAuth middleware config. +var DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", } // KeyAuth returns an KeyAuth middleware. @@ -90,27 +78,33 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { return KeyAuthWithConfig(c) } -// KeyAuthWithConfig returns an KeyAuth middleware with config. -// See `KeyAuth()`. +// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid. +// +// For first valid key it calls the next handler. +// For invalid key, it sends "401 - Unauthorized" response. +// For missing key, it sends "400 - Bad Request" response. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration +func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultKeyAuthConfig.Skipper } - // Defaults - if config.AuthScheme == "" { - config.AuthScheme = DefaultKeyAuthConfig.AuthScheme - } if config.KeyLookup == "" { config.KeyLookup = DefaultKeyAuthConfig.KeyLookup } if config.Validator == nil { - panic("echo: key-auth middleware requires a validator function") + return nil, errors.New("echo key-auth middleware requires a validator function") } - extractors, err := createExtractors(config.KeyLookup, config.AuthScheme) + extractors, err := createExtractors(config.KeyLookup) if err != nil { - panic(err) + return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err) + } + if len(extractors) == 0 { + return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -122,59 +116,41 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { var lastExtractorErr error var lastValidatorErr error for _, extractor := range extractors { - keys, err := extractor(c) - if err != nil { - lastExtractorErr = err + keys, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr continue } for _, key := range keys { - valid, err := config.Validator(key, c) + valid, err := config.Validator(c, key) if err != nil { lastValidatorErr = err continue } - if valid { - return next(c) + if !valid { + lastValidatorErr = ErrInvalidKey + continue } - lastValidatorErr = errors.New("invalid key") + return next(c) } } - // we are here only when we did not successfully extract and validate any of keys + // prioritize validator errors over extracting errors err := lastValidatorErr - if err == nil { // prioritize validator errors over extracting errors - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { - err = errors.New("missing key in the query string") - } else if lastExtractorErr == errCookieExtractorValueMissing { - err = errors.New("missing key in cookies") - } else if lastExtractorErr == errFormExtractorValueMissing { - err = errors.New("missing key in the form") - } else if lastExtractorErr == errHeaderExtractorValueMissing { - err = errors.New("missing key in request header") - } else if lastExtractorErr == errHeaderExtractorValueInvalid { - err = errors.New("invalid key in the request header") - } else { - err = lastExtractorErr - } - err = &ErrKeyAuthMissing{Err: err} + if err == nil { + err = lastExtractorErr } - if config.ErrorHandler != nil { - tmpErr := config.ErrorHandler(err, c) + tmpErr := config.ErrorHandler(c, err) if config.ContinueOnIgnoredError && tmpErr == nil { return next(c) } return tmpErr } - if lastValidatorErr != nil { // prioritize validator errors over extracting errors - return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "Unauthorized", - Internal: lastValidatorErr, - } + if lastValidatorErr == nil { + return ErrKeyMissing.WithInternal(err) } - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + return echo.ErrUnauthorized.WithInternal(err) } - } + }, nil } diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index ff8968c38..1b64865fb 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -7,11 +7,11 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func testKeyValidator(key string, c echo.Context) (bool, error) { +func testKeyValidator(c echo.Context, key string) (bool, error) { switch key { case "valid-key": return true, nil @@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized, internal=invalid key", + expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key", }, { name: "nok, defaults, invalid scheme in header", @@ -84,24 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") }, expectHandlerCalled: false, - expectError: "code=400, message=invalid key in the request header", + expectError: "code=401, message=missing key, internal=invalid value in request header", }, { name: "nok, defaults, missing header", givenRequest: func(req *http.Request) {}, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", - }, - { - name: "ok, custom key lookup from multiple places, query and header", - givenRequest: func(req *http.Request) { - req.URL.RawQuery = "key=invalid-key" - req.Header.Set("API-Key", "valid-key") - }, - whenConfig: func(conf *KeyAuthConfig) { - conf.KeyLookup = "query:key,header:API-Key" - }, - expectHandlerCalled: true, + expectError: "code=401, message=missing key, internal=missing value in request header", }, { name: "ok, custom key lookup, header", @@ -121,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "header:API-Key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", + expectError: "code=401, message=missing key, internal=missing value in request header", }, { name: "ok, custom key lookup, query", @@ -141,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "query:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the query string", + expectError: "code=401, message=missing key, internal=missing value in the query string", }, { name: "ok, custom key lookup, form", @@ -166,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "form:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the form", + expectError: "code=401, message=missing key, internal=missing value in the form", }, { name: "ok, custom key lookup, cookie", @@ -190,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in cookies", + expectError: "code=401, message=missing key, internal=missing value in cookies", }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:token" - conf.ErrorHandler = func(err error, context echo.Context) error { + conf.ErrorHandler = func(c echo.Context, err error) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError } }, expectHandlerCalled: false, - expectError: "code=418, message=custom, internal=missing key in request header", + expectError: "code=418, message=custom, internal=missing value in request header", }, { name: "nok, custom errorHandler, error from validator", @@ -211,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { - conf.ErrorHandler = func(err error, context echo.Context) error { + conf.ErrorHandler = func(c echo.Context, err error) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError @@ -269,108 +258,96 @@ func TestKeyAuthWithConfig(t *testing.T) { } } -func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) { - assert.PanicsWithError( - t, - "extractor source for lookup could not be split into needed parts: a", - func() { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - KeyLookup: "a", - })(handler) - }, - ) -} - -func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) { - assert.PanicsWithValue( - t, - "echo: key-auth middleware requires a validator function", - func() { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - KeyAuthWithConfig(KeyAuthConfig{ - Validator: nil, - })(handler) - }, - ) -} - -func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) { +func TestKeyAuthWithConfig_errors(t *testing.T) { var testCases = []struct { - name string - whenContinueOnIgnoredError bool - givenKey string - expectStatus int - expectBody string + name string + whenConfig KeyAuthConfig + expectError string }{ { - name: "no error handler is called", - whenContinueOnIgnoredError: true, - givenKey: "valid-key", - expectStatus: http.StatusTeapot, - expectBody: "", + name: "ok, no error", + whenConfig: KeyAuthConfig{ + Validator: func(c echo.Context, key string) (bool, error) { + return false, nil + }, + }, }, { - name: "ContinueOnIgnoredError is false and error handler is called for missing token", - whenContinueOnIgnoredError: false, - givenKey: "", - // empty response with 200. This emulates previous behaviour when error handler swallowed the error - expectStatus: http.StatusOK, - expectBody: "", + name: "ok, missing validator func", + whenConfig: KeyAuthConfig{ + Validator: nil, + }, + expectError: "echo key-auth middleware requires a validator function", }, { - name: "error handler is called for missing token", - whenContinueOnIgnoredError: true, - givenKey: "", - expectStatus: http.StatusTeapot, - expectBody: "public-auth", + name: "ok, extractor source can not be split", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope", + Validator: func(c echo.Context, key string) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope", }, { - name: "error handler is called for invalid token", - whenContinueOnIgnoredError: true, - givenKey: "x.x.x", - expectStatus: http.StatusUnauthorized, - expectBody: "{\"message\":\"Unauthorized\"}\n", + name: "ok, no extractors", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope:nope", + Validator: func(c echo.Context, key string) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create extractors from KeyLookup string", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := echo.New() + mw, err := tc.whenConfig.ToMiddleware() + if tc.expectError != "" { + assert.Nil(t, mw) + assert.EqualError(t, err, tc.expectError) + } else { + assert.NotNil(t, mw) + assert.NoError(t, err) + } + }) + } +} - e.GET("/", func(c echo.Context) error { - testValue, _ := c.Get("test").(string) - return c.String(http.StatusTeapot, testValue) - }) +func TestMustKeyAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + KeyAuthWithConfig(KeyAuthConfig{}) + }) +} - e.Use(KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - ErrorHandler: func(err error, c echo.Context) error { - if _, ok := err.(*ErrKeyAuthMissing); ok { - c.Set("test", "public-auth") - return nil - } - return echo.ErrUnauthorized - }, - KeyLookup: "header:X-API-Key", - ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, - })) +func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) { + handlerCalled := false + var authValue string + handler := func(c echo.Context) error { + handlerCalled = true + authValue = c.Get("auth").(string) + return c.String(http.StatusOK, "test") + } + middlewareChain := KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(c echo.Context, err error) error { + // could check error to decide if we can swallow the error + c.Set("auth", "public") + return nil + }, + ContinueOnIgnoredError: true, + })(handler) - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenKey != "" { - req.Header.Set("X-API-Key", tc.givenKey) - } - res := httptest.NewRecorder() + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // no auth header this time + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - e.ServeHTTP(res, req) + err := middlewareChain(c) - assert.Equal(t, tc.expectStatus, res.Code) - assert.Equal(t, tc.expectBody, res.Body.String()) - }) - } + assert.NoError(t, err) + assert.True(t, handlerCalled) + assert.Equal(t, "public", authValue) } diff --git a/middleware/logger.go b/middleware/logger.go index 9baac4769..bd2d3d932 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,88 +3,86 @@ package middleware import ( "bytes" "encoding/json" + "fmt" "io" "strconv" "strings" "sync" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/color" + "github.com/labstack/echo/v5" "github.com/valyala/fasttemplate" ) -type ( - // LoggerConfig defines the config for Logger middleware. - LoggerConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Tags to construct the logger format. - // - // - time_unix - // - time_unix_nano - // - time_rfc3339 - // - time_rfc3339_nano - // - time_custom - // - id (Request ID) - // - remote_ip - // - uri - // - host - // - method - // - path - // - protocol - // - referer - // - user_agent - // - status - // - error - // - latency (In nanoseconds) - // - latency_human (Human readable) - // - bytes_in (Bytes received) - // - bytes_out (Bytes sent) - // - header: - // - query: - // - form: - // - // Example "${remote_ip} ${status}" - // - // Optional. Default value DefaultLoggerConfig.Format. - Format string `yaml:"format"` - - // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. - CustomTimeFormat string `yaml:"custom_time_format"` - - // Output is a writer where logs in JSON format are written. - // Optional. Default value os.Stdout. - Output io.Writer - - template *fasttemplate.Template - colorer *color.Color - pool *sync.Pool - } -) +// LoggerConfig defines the config for Logger middleware. +type LoggerConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -var ( - // DefaultLoggerConfig is the default Logger middleware config. - DefaultLoggerConfig = LoggerConfig{ - Skipper: DefaultSkipper, - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + - `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + - `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + - `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat: "2006-01-02 15:04:05.00000", - colorer: color.New(), - } -) + // Tags to construct the logger format. + // + // - time_unix + // - time_unix_nano + // - time_rfc3339 + // - time_rfc3339_nano + // - time_custom + // - id (Request ID) + // - remote_ip + // - uri + // - host + // - method + // - path + // - protocol + // - referer + // - user_agent + // - status + // - error + // - latency (In nanoseconds) + // - latency_human (Human readable) + // - bytes_in (Bytes received) + // - bytes_out (Bytes sent) + // - header: + // - query: + // - form: + // + // Example "${remote_ip} ${status}" + // + // Optional. Default value DefaultLoggerConfig.Format. + Format string + + // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. + CustomTimeFormat string + + // Output is a writer where logs in JSON format are written. + // Optional. Default destination `echo.Logger.Infof()` + Output io.Writer + + template *fasttemplate.Template + pool *sync.Pool +} + +// DefaultLoggerConfig is the default Logger middleware config. +var DefaultLoggerConfig = LoggerConfig{ + Skipper: DefaultSkipper, + Format: `{"time":"${time_rfc3339_nano}","level":"INFO","id":"${id}","remote_ip":"${remote_ip}",` + + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", + CustomTimeFormat: "2006-01-02 15:04:05.00000", +} // Logger returns a middleware that logs HTTP requests. func Logger() echo.MiddlewareFunc { return LoggerWithConfig(DefaultLoggerConfig) } -// LoggerWithConfig returns a Logger middleware with config. -// See: `Logger()`. +// LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration. func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration +func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultLoggerConfig.Skipper @@ -92,13 +90,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { if config.Format == "" { config.Format = DefaultLoggerConfig.Format } - if config.Output == nil { - config.Output = DefaultLoggerConfig.Output - } config.template = fasttemplate.New(config.Format, "${", "}") - config.colorer = color.New() - config.colorer.SetOutput(config.Output) config.pool = &sync.Pool{ New: func() interface{} { return bytes.NewBuffer(make([]byte, 256)) @@ -106,23 +99,23 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() + start := time.Now() - if err = next(c); err != nil { - c.Error(err) - } + err := next(c) stop := time.Now() + buf := config.pool.Get().(*bytes.Buffer) buf.Reset() defer config.pool.Put(buf) - if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { + _, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { switch tag { case "time_unix": return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) @@ -161,17 +154,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case "user_agent": return buf.WriteString(req.UserAgent()) case "status": - n := res.Status - s := config.colorer.Green(n) - switch { - case n >= 500: - s = config.colorer.Red(n) - case n >= 400: - s = config.colorer.Yellow(n) - case n >= 300: - s = config.colorer.Cyan(n) + status := res.Status + if err != nil { + if httpErr, ok := err.(*echo.HTTPError); ok { + status = httpErr.Code + } } - return buf.WriteString(s) + return buf.WriteString(strconv.Itoa(status)) case "error": if err != nil { // Error may contain invalid JSON e.g. `"` @@ -201,23 +190,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case strings.HasPrefix(tag, "form:"): return buf.Write([]byte(c.FormValue(tag[5:]))) case strings.HasPrefix(tag, "cookie:"): - cookie, err := c.Cookie(tag[7:]) - if err == nil { + cookie, cookieErr := c.Cookie(tag[7:]) + if cookieErr == nil { return buf.Write([]byte(cookie.Value)) } } } return 0, nil - }); err != nil { - return + }) + if tmplErr != nil { + if err != nil { + return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err) + } + return fmt.Errorf("failed to create log from template: %w", tmplErr) } - if config.Output == nil { - _, err = c.Logger().Output().Write(buf.Bytes()) - return + if config.Output != nil { + if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil { + return lErr + } + } else { + if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil { + return lErr + } } - _, err = config.Output.Write(buf.Bytes()) - return + return err } - } + }, nil } diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 394f62712..2f1230dda 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -12,7 +12,7 @@ import ( "time" "unsafe" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -61,7 +61,7 @@ func TestLoggerIPAddress(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = &testLogger{output: buf} ip := "127.0.0.1" h := Logger()(func(c echo.Context) error { return c.String(http.StatusOK, "test") diff --git a/middleware/method_override.go b/middleware/method_override.go index 92b14d2ed..202862f3b 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -3,31 +3,27 @@ package middleware import ( "net/http" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // MethodOverrideConfig defines the config for MethodOverride middleware. - MethodOverrideConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// MethodOverrideConfig defines the config for MethodOverride middleware. +type MethodOverrideConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Getter is a function that gets overridden method from the request. - // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). - Getter MethodOverrideGetter - } + // Getter is a function that gets overridden method from the request. + // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). + Getter MethodOverrideGetter +} - // MethodOverrideGetter is a function that gets overridden method from the request - MethodOverrideGetter func(echo.Context) string -) +// MethodOverrideGetter is a function that gets overridden method from the request +type MethodOverrideGetter func(echo.Context) string -var ( - // DefaultMethodOverrideConfig is the default MethodOverride middleware config. - DefaultMethodOverrideConfig = MethodOverrideConfig{ - Skipper: DefaultSkipper, - Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), - } -) +// DefaultMethodOverrideConfig is the default MethodOverride middleware config. +var DefaultMethodOverrideConfig = MethodOverrideConfig{ + Skipper: DefaultSkipper, + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), +} // MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and @@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// MethodOverrideWithConfig returns a MethodOverride middleware with config. -// See: `MethodOverride()`. +// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration +func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultMethodOverrideConfig.Skipper @@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 5760b1581..266a575ba 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -6,7 +6,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -22,28 +22,70 @@ func TestMethodOverride(t *testing.T) { rec := httptest.NewRecorder() req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) c := e.NewContext(req, rec) - m(h)(c) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_formParam(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + // Override with form parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) - req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) - rec = httptest.NewRecorder() + m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) + rec := httptest.NewRecorder() req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c = e.NewContext(req, rec) - m(h)(c) + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_queryParam(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } // Override with query parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) - req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - m(h)(c) + m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_ignoreGet(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } // Ignore `GET` - req = httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodGet, req.Method) } diff --git a/middleware/middleware.go b/middleware/middleware.go index f250ca49a..2f8c8b5c8 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -6,17 +6,14 @@ import ( "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // Skipper defines a function to skip middleware. Returning true skips processing - // the middleware. - Skipper func(c echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing the middleware. +type Skipper func(c echo.Context) bool - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(c echo.Context) -) +// BeforeFunc defines a function which is executed just before the middleware. +type BeforeFunc func(c echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error func DefaultSkipper(echo.Context) bool { return false } + +func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} diff --git a/middleware/proxy.go b/middleware/proxy.go index 6cfd6731e..1efbc2432 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "io" "math/rand" @@ -15,90 +16,86 @@ import ( "sync/atomic" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // TODO: Handle TLS proxy -type ( - // ProxyConfig defines the config for Proxy middleware. - ProxyConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Balancer defines a load balancing technique. - // Required. - Balancer ProxyBalancer - - // Rewrite defines URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Examples: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - Rewrite map[string]string - - // RegexRewrite defines rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRewrite map[*regexp.Regexp]string - - // Context key to store selected ProxyTarget into context. - // Optional. Default value "target". - ContextKey string - - // To customize the transport to remote. - // Examples: If custom TLS certificates are required. - Transport http.RoundTripper - - // ModifyResponse defines function to modify response from ProxyTarget. - ModifyResponse func(*http.Response) error - } +// ProxyConfig defines the config for Proxy middleware. +type ProxyConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Balancer defines a load balancing technique. + // Required. + Balancer ProxyBalancer + + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Examples: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rewrite map[string]string + + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + + // Context key to store selected ProxyTarget into context. + // Optional. Default value "target". + ContextKey string + + // To customize the transport to remote. + // Examples: If custom TLS certificates are required. + Transport http.RoundTripper + + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error +} - // ProxyTarget defines the upstream target. - ProxyTarget struct { - Name string - URL *url.URL - Meta echo.Map - } +// ProxyTarget defines the upstream target. +type ProxyTarget struct { + Name string + URL *url.URL + Meta echo.Map +} - // ProxyBalancer defines an interface to implement a load balancing technique. - ProxyBalancer interface { - AddTarget(*ProxyTarget) bool - RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget - } +// ProxyBalancer defines an interface to implement a load balancing technique. +type ProxyBalancer interface { + AddTarget(*ProxyTarget) bool + RemoveTarget(string) bool + Next(echo.Context) *ProxyTarget +} - commonBalancer struct { - targets []*ProxyTarget - mutex sync.RWMutex - } +type commonBalancer struct { + targets []*ProxyTarget + mutex sync.RWMutex +} - // RandomBalancer implements a random load balancing technique. - randomBalancer struct { - *commonBalancer - random *rand.Rand - } +// RandomBalancer implements a random load balancing technique. +type randomBalancer struct { + *commonBalancer + random *rand.Rand +} - // RoundRobinBalancer implements a round-robin load balancing technique. - roundRobinBalancer struct { - *commonBalancer - i uint32 - } -) +// RoundRobinBalancer implements a round-robin load balancing technique. +type roundRobinBalancer struct { + *commonBalancer + i uint32 +} -var ( - // DefaultProxyConfig is the default Proxy middleware config. - DefaultProxyConfig = ProxyConfig{ - Skipper: DefaultSkipper, - ContextKey: "target", - } -) +// DefaultProxyConfig is the default Proxy middleware config. +var DefaultProxyConfig = ProxyConfig{ + Skipper: DefaultSkipper, + ContextKey: "target", +} -func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { +func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { @@ -203,15 +200,23 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { return ProxyWithConfig(c) } -// ProxyWithConfig returns a Proxy middleware with config. -// See: `Proxy()` +// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid. +// +// Proxy middleware forwards the request to upstream server using a configured load balancing technique. func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration +func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } + if config.ContextKey == "" { + config.ContextKey = DefaultProxyConfig.ContextKey + } if config.Balancer == nil { - panic("echo: proxy middleware requires balancer") + return nil, errors.New("echo proxy middleware requires balancer") } if config.Rewrite != nil { @@ -254,10 +259,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // Proxy switch { case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) + proxyRaw(c, tgt).ServeHTTP(res, req) case req.Header.Get(echo.HeaderAccept) == "text/event-stream": default: - proxyHTTP(tgt, c, config).ServeHTTP(res, req) + proxyHTTP(c, tgt, config).ServeHTTP(res, req) } if e, ok := c.Get("_error").(error); ok { err = e @@ -265,7 +270,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { return } - } + }, nil } // StatusCodeContextCanceled is a custom HTTP status code for situations @@ -275,7 +280,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // 499 too instead of the more problematic 5xx, which does not allow to detect this situation const StatusCodeContextCanceled = 499 -func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { +func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { desc := tgt.URL.String() diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 7939fc5c2..1d0dee91e 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -14,7 +14,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -55,7 +55,7 @@ func TestProxy(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -77,7 +77,7 @@ func TestProxy(t *testing.T) { // Round-robin rrb := NewRoundRobinBalancer(targets) e = echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -113,15 +113,20 @@ func TestProxy(t *testing.T) { return nil } } - rrb1 := NewRoundRobinBalancer(targets) e = echo.New() e.Use(contextObserver) - e.Use(Proxy(rrb1)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } +func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) { + assert.Panics(t, func() { + ProxyWithConfig(ProxyConfig{Balancer: nil}) + }) +} + func TestProxyRealIPHeader(t *testing.T) { // Setup upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -129,7 +134,7 @@ func TestProxyRealIPHeader(t *testing.T) { url, _ := url.Parse(upstream.URL) rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) e := echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -334,7 +339,7 @@ func TestProxyError(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { rb := NewRandomBalancer(nil) assert.True(t, rb.AddTarget(target)) e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) ctx, cancel := context.WithCancel(req.Context()) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index be2b348db..09237f05b 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -1,47 +1,42 @@ package middleware import ( + "errors" "net/http" "sync" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "golang.org/x/time/rate" ) -type ( - // RateLimiterStore is the interface to be implemented by custom stores. - RateLimiterStore interface { - // Stores for the rate limiter have to implement the Allow method - Allow(identifier string) (bool, error) - } -) +// RateLimiterStore is the interface to be implemented by custom stores. +type RateLimiterStore interface { + Allow(identifier string) (bool, error) +} -type ( - // RateLimiterConfig defines the configuration for the rate limiter - RateLimiterConfig struct { - Skipper Skipper - BeforeFunc BeforeFunc - // IdentifierExtractor uses echo.Context to extract the identifier for a visitor - IdentifierExtractor Extractor - // Store defines a store for the rate limiter - Store RateLimiterStore - // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error - ErrorHandler func(context echo.Context, err error) error - // DenyHandler provides a handler to be called when RateLimiter denies access - DenyHandler func(context echo.Context, identifier string, err error) error - } - // Extractor is used to extract data from echo.Context - Extractor func(context echo.Context) (string, error) -) +// RateLimiterConfig defines the configuration for the rate limiter +type RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error +} -// errors -var ( - // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded - ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") - // ErrExtractorError denotes an error raised when extractor function is unsuccessful - ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") -) +// Extractor is used to extract data from echo.Context +type Extractor func(context echo.Context) (string, error) + +// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded +var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + +// ErrExtractorError denotes an error raised when extractor function is unsuccessful +var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ @@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware }, middleware.RateLimiterWithConfig(config)) */ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration +func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultRateLimiterConfig.Skipper } @@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { config.DenyHandler = DefaultRateLimiterConfig.DenyHandler } if config.Store == nil { - panic("Store configuration must be provided") + return nil, errors.New("echo rate limiter store configuration must be provided") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -137,36 +137,32 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { identifier, err := config.IdentifierExtractor(c) if err != nil { - c.Error(config.ErrorHandler(c, err)) - return nil + return config.ErrorHandler(c, err) } - if allow, err := config.Store.Allow(identifier); !allow { - c.Error(config.DenyHandler(c, identifier, err)) - return nil + if allow, allowErr := config.Store.Allow(identifier); !allow { + return config.DenyHandler(c, identifier, allowErr) } return next(c) } - } + }, nil } -type ( - // RateLimiterMemoryStore is the built-in store implementation for RateLimiter - RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate rate.Limit //for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. +// RateLimiterMemoryStore is the built-in store implementation for RateLimiter +type RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit + burst int + expiresIn time.Duration + lastCleanup time.Time +} - burst int - expiresIn time.Duration - lastCleanup time.Time - } - // Visitor signifies a unique user's limiter details - Visitor struct { - *rate.Limiter - lastSeen time.Time - } -) +// Visitor signifies a unique user's limiter details +type Visitor struct { + *rate.Limiter + lastSeen time.Time +} /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 89d9a6edc..de546a19c 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -10,8 +10,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -25,19 +24,19 @@ func TestRateLimiter(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - mw := RateLimiter(inMemoryStore) + mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -47,20 +46,25 @@ func TestRateLimiter(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - _ = mw(handler)(c) - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } -func TestRateLimiter_panicBehaviour(t *testing.T) { +func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) assert.Panics(t, func() { - RateLimiter(nil) + RateLimiterWithConfig(RateLimiterConfig{}) }) assert.NotPanics(t, func() { - RateLimiter(inMemoryStore) + RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) }) } @@ -73,7 +77,7 @@ func TestRateLimiterWithConfig(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { @@ -88,7 +92,8 @@ func TestRateLimiterWithConfig(t *testing.T) { return ctx.JSON(http.StatusBadRequest, nil) }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { id string @@ -111,8 +116,9 @@ func TestRateLimiterWithConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) + err := mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, tc.code, rec.Code) } } @@ -126,7 +132,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { @@ -135,19 +141,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return id, nil }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"", http.StatusForbidden}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -158,9 +165,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } @@ -174,21 +185,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -199,9 +211,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } } @@ -222,7 +238,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Skipper: func(c echo.Context) bool { return true }, @@ -233,10 +249,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, false, beforeFuncRan) } @@ -256,7 +274,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Skipper: func(c echo.Context) bool { return false }, @@ -267,7 +285,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) _ = mw(handler)(c) @@ -291,7 +310,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ BeforeFunc: func(c echo.Context) { beforeRan = true }, @@ -299,10 +318,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, true, beforeRan) } @@ -413,7 +434,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) { func generateAddressList(count int) []string { addrs := make([]string, count) for i := 0; i < count; i++ { - addrs[i] = random.String(15) + addrs[i] = randomString(15) } return addrs } diff --git a/middleware/recover.go b/middleware/recover.go index 7b6128533..7e46ccd7b 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,53 +5,35 @@ import ( "net/http" "runtime" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" + "github.com/labstack/echo/v5" ) -type ( - // LogErrorFunc defines a function for custom logging in the middleware. - LogErrorFunc func(c echo.Context, err error, stack []byte) error +// RecoverConfig defines the config for Recover middleware. +type RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // RecoverConfig defines the config for Recover middleware. - RecoverConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper + // Size of the stack to be printed. + // Optional. Default value 4KB. + StackSize int - // Size of the stack to be printed. - // Optional. Default value 4KB. - StackSize int `yaml:"stack_size"` + // DisableStackAll disables formatting stack traces of all other goroutines + // into buffer after the trace for the current goroutine. + // Optional. Default value false. + DisableStackAll bool - // DisableStackAll disables formatting stack traces of all other goroutines - // into buffer after the trace for the current goroutine. - // Optional. Default value false. - DisableStackAll bool `yaml:"disable_stack_all"` - - // DisablePrintStack disables printing stack trace. - // Optional. Default value as false. - DisablePrintStack bool `yaml:"disable_print_stack"` - - // LogLevel is log level to printing stack trace. - // Optional. Default value 0 (Print). - LogLevel log.Lvl - - // LogErrorFunc defines a function for custom logging in the middleware. - // If it's set you don't need to provide LogLevel for config. - LogErrorFunc LogErrorFunc - } -) + // DisablePrintStack disables printing stack trace. + // Optional. Default value as false. + DisablePrintStack bool +} -var ( - // DefaultRecoverConfig is the default Recover middleware config. - DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - LogLevel: 0, - LogErrorFunc: nil, - } -) +// DefaultRecoverConfig is the default Recover middleware config. +var DefaultRecoverConfig = RecoverConfig{ + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, +} // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. @@ -59,9 +41,13 @@ func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } -// RecoverWithConfig returns a Recover middleware with config. -// See: `Recover()`. +// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration +func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultRecoverConfig.Skipper @@ -71,7 +57,7 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c echo.Context) (err error) { if config.Skipper(c) { return next(c) } @@ -81,42 +67,19 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { if r == http.ErrAbortHandler { panic(r) } - err, ok := r.(error) + tmpErr, ok := r.(error) if !ok { - err = fmt.Errorf("%v", r) + tmpErr = fmt.Errorf("%v", r) } - var stack []byte - var length int - if !config.DisablePrintStack { - stack = make([]byte, config.StackSize) - length = runtime.Stack(stack, !config.DisableStackAll) - stack = stack[:length] + stack := make([]byte, config.StackSize) + length := runtime.Stack(stack, !config.DisableStackAll) + tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length]) } - - if config.LogErrorFunc != nil { - err = config.LogErrorFunc(c, err, stack) - } else if !config.DisablePrintStack { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) - switch config.LogLevel { - case log.DEBUG: - c.Logger().Debug(msg) - case log.INFO: - c.Logger().Info(msg) - case log.WARN: - c.Logger().Warn(msg) - case log.ERROR: - c.Logger().Error(msg) - case log.OFF: - // None. - default: - c.Logger().Print(msg) - } - } - c.Error(err) + err = tmpErr } }() return next(c) } - } + }, nil } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index b27f3b41c..091fef899 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,36 +2,57 @@ package middleware import ( "bytes" - "errors" - "fmt" "net/http" "net/http/httptest" "testing" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = &testLogger{output: buf} req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + h := Recover()(func(c echo.Context) error { panic("test") - })) - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, buf.String(), "PANIC RECOVER") + }) + err := h(c) + assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine") + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + assert.Contains(t, buf.String(), "") // nothing is logged +} + +func TestRecover_skipper(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := RecoverConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + } + h := RecoverWithConfig(config)(func(c echo.Context) error { + panic("testPANIC") + }) + + var err error + assert.Panics(t, func() { + err = h(c) + }) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain } func TestRecoverErrAbortHandler(t *testing.T) { e := echo.New() - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -51,115 +72,66 @@ func TestRecoverErrAbortHandler(t *testing.T) { } }() - h(c) + hErr := h(c) assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.NotContains(t, buf.String(), "PANIC RECOVER") + assert.NotContains(t, hErr.Error(), "PANIC RECOVER") } -func TestRecoverWithConfig_LogLevel(t *testing.T) { - tests := []struct { - logLevel log.Lvl - levelName string - }{{ - logLevel: log.DEBUG, - levelName: "DEBUG", - }, { - logLevel: log.INFO, - levelName: "INFO", - }, { - logLevel: log.WARN, - levelName: "WARN", - }, { - logLevel: log.ERROR, - levelName: "ERROR", - }, { - logLevel: log.OFF, - levelName: "OFF", - }} - - for _, tt := range tests { - tt := tt - t.Run(tt.levelName, func(t *testing.T) { - e := echo.New() - e.Logger.SetLevel(log.DEBUG) +func TestRecoverWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenNoPanic bool + whenConfig RecoverConfig + expectErrContain string + expectErr string + }{ + { + name: "ok, default config", + whenConfig: DefaultRecoverConfig, + expectErrContain: "[PANIC RECOVER] testPANIC goroutine", + }, + { + name: "ok, no panic", + givenNoPanic: true, + whenConfig: DefaultRecoverConfig, + expectErrContain: "", + }, + { + name: "ok, DisablePrintStack", + whenConfig: RecoverConfig{ + DisablePrintStack: true, + }, + expectErr: "testPANIC", + }, + } - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - config := DefaultRecoverConfig - config.LogLevel = tt.logLevel - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("test") - })) - - h(c) + config := tc.whenConfig + h := RecoverWithConfig(config)(func(c echo.Context) error { + if tc.givenNoPanic { + return nil + } + panic("testPANIC") + }) - assert.Equal(t, http.StatusInternalServerError, rec.Code) + err := h(c) - output := buf.String() - if tt.logLevel == log.OFF { - assert.Empty(t, output) + if tc.expectErrContain != "" { + assert.Contains(t, err.Error(), tc.expectErrContain) + } else if tc.expectErr != "" { + assert.Contains(t, err.Error(), tc.expectErr) } else { - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) + assert.NoError(t, err) } + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain }) } } - -func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { - e := echo.New() - e.Logger.SetLevel(log.DEBUG) - - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - testError := errors.New("test") - config := DefaultRecoverConfig - config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) - if errors.Is(err, testError) { - c.Logger().Debug(msg) - } else { - c.Logger().Error(msg) - } - return err - } - - t.Run("first branch case for LogErrorFunc", func(t *testing.T) { - buf.Reset() - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic(testError) - })) - - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"DEBUG"`) - }) - - t.Run("else branch case for LogErrorFunc", func(t *testing.T) { - buf.Reset() - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("other") - })) - - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"ERROR"`) - }) -} diff --git a/middleware/redirect.go b/middleware/redirect.go index 13877db38..bda5ac204 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -1,10 +1,11 @@ package middleware import ( + "errors" "net/http" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // RedirectConfig defines the config for Redirect middleware. @@ -14,7 +15,9 @@ type RedirectConfig struct { // Status code to be used when redirecting the request. // Optional. Default value http.StatusMovedPermanently. - Code int `yaml:"code"` + Code int + + redirect redirectLogic } // redirectLogic represents a function that given a scheme, host and uri @@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string) const www = "www." -// DefaultRedirectConfig is the default Redirect middleware config. -var DefaultRedirectConfig = RedirectConfig{ - Skipper: DefaultSkipper, - Code: http.StatusMovedPermanently, -} +// RedirectHTTPSConfig is the HTTPS Redirect middleware config. +var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS} + +// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config. +var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW} + +// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config. +var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW} + +// RedirectWWWConfig is the WWW Redirect middleware config. +var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW} + +// RedirectNonWWWConfig is the non WWW Redirect middleware config. +var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. // // Usage `Echo#Pre(HTTPSRedirect())` func HTTPSRedirect() echo.MiddlewareFunc { - return HTTPSRedirectWithConfig(DefaultRedirectConfig) + return HTTPSRedirectWithConfig(RedirectHTTPSConfig) } -// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSRedirect()`. +// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" { - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPS + return toMiddlewareOrPanic(config) } // HTTPSWWWRedirect redirects http requests to https www. @@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSWWWRedirect())` func HTTPSWWWRedirect() echo.MiddlewareFunc { - return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig) } -// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSWWWRedirect()`. +// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" && !strings.HasPrefix(host, www) { - return true, "https://www." + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPSWWW + return toMiddlewareOrPanic(config) } // HTTPSNonWWWRedirect redirects http requests to https non www. @@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSNonWWWRedirect())` func HTTPSNonWWWRedirect() echo.MiddlewareFunc { - return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig) } -// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSNonWWWRedirect()`. +// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if scheme != "https" { - host = strings.TrimPrefix(host, www) - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectNonHTTPSWWW + return toMiddlewareOrPanic(config) } // WWWRedirect redirects non www requests to www. @@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(WWWRedirect())` func WWWRedirect() echo.MiddlewareFunc { - return WWWRedirectWithConfig(DefaultRedirectConfig) + return WWWRedirectWithConfig(RedirectWWWConfig) } -// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `WWWRedirect()`. +// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if !strings.HasPrefix(host, www) { - return true, scheme + "://www." + host + uri - } - return false, "" - }) + config.redirect = redirectWWW + return toMiddlewareOrPanic(config) } // NonWWWRedirect redirects www requests to non www. @@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(NonWWWRedirect())` func NonWWWRedirect() echo.MiddlewareFunc { - return NonWWWRedirectWithConfig(DefaultRedirectConfig) + return NonWWWRedirectWithConfig(RedirectNonWWWConfig) } -// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `NonWWWRedirect()`. +// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if strings.HasPrefix(host, www) { - return true, scheme + "://" + host[4:] + uri - } - return false, "" - }) + config.redirect = redirectNonWWW + return toMiddlewareOrPanic(config) } -func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { +// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration +func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRedirectConfig.Skipper + config.Skipper = DefaultSkipper } if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code + config.Code = http.StatusMovedPermanently + } + if config.redirect == nil { + return nil, errors.New("redirectConfig is missing redirect function") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { req, scheme := c.Request(), c.Scheme() host := req.Host - if ok, url := cb(scheme, host, req.RequestURI); ok { + if ok, url := config.redirect(scheme, host, req.RequestURI); ok { return c.Redirect(config.Code, url) } return next(c) } + }, nil +} + +var redirectHTTPS = func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri + } + return false, "" +} + +var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) { + if scheme != "https" && !strings.HasPrefix(host, www) { + return true, "https://www." + host + uri + } + return false, "" +} + +var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) { + if scheme != "https" { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri + } + return false, "" +} + +var redirectWWW = func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri + } + return false, "" +} + +var redirectNonWWW = func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri } + return false, "" } diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 9d1b56205..9484bdf20 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) diff --git a/middleware/request_id.go b/middleware/request_id.go index 8c5ff6605..b553321ec 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -1,50 +1,42 @@ package middleware import ( - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" ) -type ( - // RequestIDConfig defines the config for RequestID middleware. - RequestIDConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RequestIDConfig defines the config for RequestID middleware. +type RequestIDConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). - Generator func() string + // Generator defines a function to generate an ID. + // Optional. Default value random.String(32). + Generator func() string - // RequestIDHandler defines a function which is executed for a request id. - RequestIDHandler func(echo.Context, string) + // RequestIDHandler defines a function which is executed for a request id. + RequestIDHandler func(c echo.Context, requestID string) - // TargetHeader defines what header to look for to populate the id - TargetHeader string - } -) - -var ( - // DefaultRequestIDConfig is the default RequestID middleware config. - DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, - TargetHeader: echo.HeaderXRequestID, - } -) + // TargetHeader defines what header to look for to populate the id + TargetHeader string +} // RequestID returns a X-Request-ID middleware. func RequestID() echo.MiddlewareFunc { - return RequestIDWithConfig(DefaultRequestIDConfig) + return RequestIDWithConfig(RequestIDConfig{}) } -// RequestIDWithConfig returns a X-Request-ID middleware with config. +// RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration. func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration +func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRequestIDConfig.Skipper + config.Skipper = DefaultSkipper } if config.Generator == nil { - config.Generator = generator + config.Generator = createRandomStringGenerator(32) } if config.TargetHeader == "" { config.TargetHeader = echo.HeaderXRequestID @@ -69,9 +61,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { return next(c) } - } -} - -func generator() string { - return random.String(32) + }, nil } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 21b777826..fd0ef5d56 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) { return c.String(http.StatusOK, "test") } - rid := RequestIDWithConfig(RequestIDConfig{}) + rid := RequestID() + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) +} + +func TestMustRequestIDWithConfig_skipper(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + generatorCalled := false + e.Use(RequestIDWithConfig(RequestIDConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + Generator: func() string { + generatorCalled = true + return "customGenerator" + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "test", res.Body.String()) + + assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "") + assert.False(t, generatorCalled) +} + +func TestMustRequestIDWithConfig_customGenerator(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") +} + +func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + called := false + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + RequestIDHandler: func(c echo.Context, s string) { + called = true + }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") + assert.True(t, called) +} + +func TestRequestIDWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid, err := RequestIDConfig{}.ToMiddleware() + assert.NoError(t, err) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) - // Custom generator and handler - customID := "customGenerator" - calledHandler := false + // Custom generator rid = RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return customID }, - RequestIDHandler: func(_ echo.Context, id string) { - calledHandler = true - assert.Equal(t, customID, id) - }, + Generator: func() string { return "customGenerator" }, }) h = rid(handler) h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") - assert.True(t, calledHandler) } func TestRequestID_IDNotAltered(t *testing.T) { diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 1b3e3eaad..63b6402fb 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -2,7 +2,7 @@ package middleware import ( "errors" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/http" "time" ) @@ -24,6 +24,7 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // logger.Info(). +// Date("request_start", v.StartTime). // Str("URI", v.URI). // Int("status", v.Status). // Msg("request") @@ -39,6 +40,7 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // logger.Info("request", +// zap.Time("request_start", v.StartTime), // zap.String("URI", v.URI), // zap.Int("status", v.Status), // ) @@ -54,8 +56,9 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { // log.WithFields(logrus.Fields{ -// "URI": values.URI, -// "status": values.Status, +// "request_start": values.StartTime, +// "URI": values.URI, +// "status": values.Status, // }).Info("request") // // return nil @@ -158,15 +161,15 @@ type RequestLoggerValues struct { // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. ResponseSize int64 // Headers are list of headers from request. Note: request can contain more than one header with same value so slice - // of values is been logger for each given header. + // of values is what will be returned/logged for each given header. // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". Headers map[string][]string // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter - // with same name so slice of values is been logger for each given query param name. + // with same name so slice of values is what will be returned/logged for each given query param name. QueryParams map[string][]string // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with - // same name so slice of values is been logger for each given form value name. + // same name so slice of values is what will be returned/logged for each given form value name. FormValues map[string][]string } diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 5118b1216..c5ddced75 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -1,7 +1,7 @@ package middleware import ( - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -289,7 +289,7 @@ func TestRequestLogger_allFields(t *testing.T) { req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.ServableContext) c.SetPath("/test*") diff --git a/middleware/rewrite.go b/middleware/rewrite.go index e5b0a6b56..16677263f 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,62 +1,58 @@ package middleware import ( + "errors" "regexp" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // RewriteConfig defines the config for Rewrite middleware. - RewriteConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RewriteConfig defines the config for Rewrite middleware. +type RewriteConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Rules defines the URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Example: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - // Required. - Rules map[string]string `yaml:"rules"` + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + // Required. + Rules map[string]string - // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` - } -) - -var ( - // DefaultRewriteConfig is the default Rewrite middleware config. - DefaultRewriteConfig = RewriteConfig{ - Skipper: DefaultSkipper, - } -) + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string +} // Rewrite returns a Rewrite middleware. // // Rewrite middleware rewrites the URL path based on the provided rules. func Rewrite(rules map[string]string) echo.MiddlewareFunc { - c := DefaultRewriteConfig + c := RewriteConfig{} c.Rules = rules return RewriteWithConfig(c) } -// RewriteWithConfig returns a Rewrite middleware with config. -// See: `Rewrite()`. +// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration. +// +// Rewrite middleware rewrites the URL path based on the provided rules. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { - // Defaults - if config.Rules == nil && config.RegexRules == nil { - panic("echo: rewrite middleware requires url path rewrite rules or regex rules") - } + return toMiddlewareOrPanic(config) +} +// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration +func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Rules == nil && config.RegexRules == nil { + return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules") } if config.RegexRules == nil { @@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 0ac04bb2f..1f3419f04 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -8,7 +8,7 @@ import ( "regexp" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) { }, })) e.GET("/public/*", func(c echo.Context) error { - return c.String(http.StatusOK, c.Param("*")) + return c.String(http.StatusOK, c.PathParam("*")) }) e.GET("/*", func(c echo.Context) error { - return c.String(http.StatusOK, c.Param("*")) + return c.String(http.StatusOK, c.PathParam("*")) }) var testCases = []struct { @@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) { } } +func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) { + assert.Panics(t, func() { + RewriteWithConfig(RewriteConfig{}) + }) +} + +func TestMustRewriteWithConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + givenSkipper func(c echo.Context) bool + whenURL string + expectURL string + expectStatus int + }{ + { + name: "not skipped", + whenURL: "/old", + expectURL: "/new", + expectStatus: http.StatusOK, + }, + { + name: "skipped", + givenSkipper: func(c echo.Context) bool { + return true + }, + whenURL: "/old", + expectURL: "/old", + expectStatus: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig( + RewriteConfig{ + Skipper: tc.givenSkipper, + Rules: map[string]string{"/old": "/new"}}, + )) + + e.GET("/new", func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectURL, req.URL.EscapedPath()) + assert.Equal(t, tc.expectStatus, rec.Code) + }) + } +} + // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() - r := e.Router() // Rewrite old url to new one // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches - e.Pre(Rewrite(map[string]string{ - "/old": "/new", - }, - )) + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{"/old": "/new"}}), + ) // Route - r.Add(http.MethodGet, "/new", func(c echo.Context) error { + e.Add(http.MethodGet, "/new", func(c echo.Context) error { return c.NoContent(http.StatusOK) }) @@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() - r := e.Router() // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ @@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { }, })) - r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { return c.String(http.StatusOK, "hosts") }) - r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { return c.String(http.StatusOK, "eng") }) diff --git a/middleware/secure.go b/middleware/secure.go index 6c4051723..571b35877 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -3,87 +3,83 @@ package middleware import ( "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // SecureConfig defines the config for Secure middleware. - SecureConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // XSSProtection provides protection against cross-site scripting attack (XSS) - // by setting the `X-XSS-Protection` header. - // Optional. Default value "1; mode=block". - XSSProtection string `yaml:"xss_protection"` - - // ContentTypeNosniff provides protection against overriding Content-Type - // header by setting the `X-Content-Type-Options` header. - // Optional. Default value "nosniff". - ContentTypeNosniff string `yaml:"content_type_nosniff"` - - // XFrameOptions can be used to indicate whether or not a browser should - // be allowed to render a page in a ,