diff --git a/context.go b/context.go index 27da28a9c..6a1811685 100644 --- a/context.go +++ b/context.go @@ -584,8 +584,10 @@ func (c *context) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + func (c *context) contentDisposition(file, name, dispositionType string) error { - c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name)) + c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name))) return c.File(file) } diff --git a/context_test.go b/context_test.go index 85b221446..01a8784b8 100644 --- a/context_test.go +++ b/context_test.go @@ -414,30 +414,72 @@ func TestContextStream(t *testing.T) { } func TestContextAttachment(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - - err := c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(t, 219885, rec.Body.Len()) + var testCases = []struct { + name string + whenName string + expectHeader string + }{ + { + name: "ok", + whenName: "walle.png", + expectHeader: `attachment; filename="walle.png"`, + }, + { + name: "ok, escape quotes in malicious filename", + whenName: `malicious.sh"; \"; dummy=.txt`, + expectHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.Attachment("_fixture/images/walle.png", tc.whenName) + if assert.NoError(t, err) { + assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 219885, rec.Body.Len()) + } + }) } } func TestContextInline(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - - err := c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(t, 219885, rec.Body.Len()) + var testCases = []struct { + name string + whenName string + expectHeader string + }{ + { + name: "ok", + whenName: "walle.png", + expectHeader: `inline; filename="walle.png"`, + }, + { + name: "ok, escape quotes in malicious filename", + whenName: `malicious.sh"; \"; dummy=.txt`, + expectHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.Inline("_fixture/images/walle.png", tc.whenName) + if assert.NoError(t, err) { + assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 219885, rec.Body.Len()) + } + }) } }