From aade2eedb91fe0b44f916415c53e915c9eda9e1b Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Fri, 1 Dec 2023 20:54:18 +0100 Subject: [PATCH] hls, webrtc: prevent XSS attack when appending slash to paths (#2766) (#2767) (#2772) --- internal/core/hls_http_server.go | 6 +--- internal/core/webrtc_http_server.go | 6 +--- .../httpserv/location_with_trailing_slash.go | 12 +++++++ .../location_with_trailing_slash_test.go | 36 +++++++++++++++++++ 4 files changed, 50 insertions(+), 10 deletions(-) create mode 100644 internal/protocols/httpserv/location_with_trailing_slash.go create mode 100644 internal/protocols/httpserv/location_with_trailing_slash_test.go diff --git a/internal/core/hls_http_server.go b/internal/core/hls_http_server.go index f4d694eb98d..cd8a3b8aea1 100644 --- a/internal/core/hls_http_server.go +++ b/internal/core/hls_http_server.go @@ -146,11 +146,7 @@ func (s *hlsHTTPServer) onRequest(ctx *gin.Context) { dir, fname = pa, "" if !strings.HasSuffix(dir, "/") { - l := ctx.Request.URL.Path[1:] + "/" - if ctx.Request.URL.RawQuery != "" { - l += "?" + ctx.Request.URL.RawQuery - } - ctx.Writer.Header().Set("Location", l) + ctx.Writer.Header().Set("Location", httpserv.LocationWithTrailingSlash(ctx.Request.URL)) ctx.Writer.WriteHeader(http.StatusMovedPermanently) return } diff --git a/internal/core/webrtc_http_server.go b/internal/core/webrtc_http_server.go index 035fc22c913..ea98e9c9e33 100644 --- a/internal/core/webrtc_http_server.go +++ b/internal/core/webrtc_http_server.go @@ -352,11 +352,7 @@ func (s *webRTCHTTPServer) onRequest(ctx *gin.Context) { s.onPage(ctx, ctx.Request.URL.Path[1:len(ctx.Request.URL.Path)-len("/publish")], true) case ctx.Request.URL.Path[len(ctx.Request.URL.Path)-1] != '/': - l := ctx.Request.URL.Path[1:] + "/" - if ctx.Request.URL.RawQuery != "" { - l += "?" + ctx.Request.URL.RawQuery - } - ctx.Writer.Header().Set("Location", l) + ctx.Writer.Header().Set("Location", httpserv.LocationWithTrailingSlash(ctx.Request.URL)) ctx.Writer.WriteHeader(http.StatusMovedPermanently) default: diff --git a/internal/protocols/httpserv/location_with_trailing_slash.go b/internal/protocols/httpserv/location_with_trailing_slash.go new file mode 100644 index 00000000000..6c6db9c4bda --- /dev/null +++ b/internal/protocols/httpserv/location_with_trailing_slash.go @@ -0,0 +1,12 @@ +package httpserv + +import "net/url" + +// LocationWithTrailingSlash returns the URL in a relative format, with a trailing slash. +func LocationWithTrailingSlash(u *url.URL) string { + l := "./" + u.Path[1:] + "/" + if u.RawQuery != "" { + l += "?" + u.RawQuery + } + return l +} diff --git a/internal/protocols/httpserv/location_with_trailing_slash_test.go b/internal/protocols/httpserv/location_with_trailing_slash_test.go new file mode 100644 index 00000000000..f782115a19f --- /dev/null +++ b/internal/protocols/httpserv/location_with_trailing_slash_test.go @@ -0,0 +1,36 @@ +package httpserv + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLocationWithTrailingSlash(t *testing.T) { + for _, ca := range []struct { + name string + url *url.URL + loc string + }{ + { + "with query", + &url.URL{ + Path: "/test", + RawQuery: "key=value", + }, + "./test/?key=value", + }, + { + "xss", + &url.URL{ + Path: "/www.example.com", + }, + "./www.example.com/", + }, + } { + t.Run(ca.name, func(t *testing.T) { + require.Equal(t, ca.loc, LocationWithTrailingSlash(ca.url)) + }) + } +}