Skip to content

Commit

Permalink
Fixes #169 (#170)
Browse files Browse the repository at this point in the history
* Check status code before caching

* Added Test for status code checks

* Update rfc/standalone.go

Co-authored-by: darkweak <darkweak@protonmail.com>

* Mirrored change into vendor packages

Co-authored-by: choelzl <cedric.hoelzl@helcel.net>
Co-authored-by: darkweak <darkweak@protonmail.com>
  • Loading branch information
3 people authored Jan 17, 2022
1 parent 8bdd8cf commit 5ef09a9
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 9 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions rfc/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (t *VaryTransport) UpdateCacheEventually(req *http.Request) (*http.Response

req.Response = cachedResp

if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(req.Response.Header)) {
if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(req.Response.Header), req.Response.StatusCode) {
_ = validateVary(req, req.Response, cacheKey, t)
} else {
req.Response.Header.Set("Cache-Status", "Souin; fwd=uri-miss")
Expand Down Expand Up @@ -186,7 +186,7 @@ func (t *VaryTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
resp.Header.Set("Cache-Status", "Souin; fwd=uri-miss")
}
resp, _ = transport.RoundTrip(req)
if !(cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) && validateVary(req, resp, cacheKey, t)) {
if !(cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header), req.Response.StatusCode) && validateVary(req, resp, cacheKey, t)) {
go func() {
t.Transport.CoalescingLayerStorage.Set(cacheKey)
}()
Expand Down
13 changes: 12 additions & 1 deletion rfc/standalone.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,11 @@ func getEndToEndHeaders(respHeaders http.Header) []string {
return endToEndHeaders
}

func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {
func canStore(reqCacheControl cacheControl, respCacheControl cacheControl, status int) (canStore bool) {
if !cachableStatusCode(status){
return false
}

for _, t := range []string{"no-cache", "no-store"} {
if _, ok := respCacheControl[t]; ok {
return false
Expand All @@ -229,6 +233,13 @@ func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {
return true
}

func cachableStatusCode(statusCode int) bool {
switch statusCode {
case 200, 203, 204, 206, 300, 301, 404, 405, 410, 414, 501: return true
default: return false
}
}

func newGatewayTimeoutResponse(req *http.Request) *http.Response {
var b bytes.Buffer
b.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n")
Expand Down
29 changes: 26 additions & 3 deletions rfc/standalone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,46 @@ func TestCanStore(t *testing.T) {
resCacheControl := make(map[string]string)
reqCacheControl := make(map[string]string)

if !canStore(reqCacheControl, resCacheControl) {
if !canStore(reqCacheControl, resCacheControl, 200) {
errors.GenerateError(t, "Res and Req doesn't contains headers, it should return true")
}

if canStore(reqCacheControl, resCacheControl, 502) {
errors.GenerateError(t, "Status code shouldnt be stored, it should return false")
}

reqCacheControl["no-store"] = "any"

if canStore(reqCacheControl, resCacheControl) {
if canStore(reqCacheControl, resCacheControl, 200) {
errors.GenerateError(t, "Req contains headers, it should return false")
}

resCacheControl["no-store"] = "any"

if canStore(reqCacheControl, resCacheControl) {
if canStore(reqCacheControl, resCacheControl, 200) {
errors.GenerateError(t, "Res contains headers, it should return false")
}
}

func TestcachableStatusCode(t *testing.T) {
cachable := map[int]bool{
200: true,
300: true,
301: true,
404: true,
500: false,
502: false,
}

for key, value := range cachable {
res := cachableStatusCode(key)
if (res != value) {
msg := fmt.Sprintf("Unexpected response for statusCode %d: %t (expected: %t)", key, res, value)
errors.GenerateError(t, msg)
}
}
}

func TestNewGatewayTimeoutResponse(t *testing.T) {
if newGatewayTimeoutResponse(httptest.NewRequest("GET", "http://domain.com/testing", nil)).StatusCode != http.StatusGatewayTimeout {
errors.GenerateError(t, "Status code should be 504 if valid request provided")
Expand Down

0 comments on commit 5ef09a9

Please sign in to comment.