diff --git a/internal/protocols/httpp/location_with_trailing_slash.go b/internal/protocols/httpp/location_with_trailing_slash.go deleted file mode 100644 index 108531cda3f..00000000000 --- a/internal/protocols/httpp/location_with_trailing_slash.go +++ /dev/null @@ -1,22 +0,0 @@ -package httpp - -import "net/url" - -// LocationWithTrailingSlash returns the URL in a relative format, with a trailing slash. -func LocationWithTrailingSlash(u *url.URL) string { - l := "./" - - for i := 1; i < len(u.Path); i++ { - if u.Path[i] == '/' { - l += "../" - } - } - - l += u.Path[1:] + "/" - - if u.RawQuery != "" { - l += "?" + u.RawQuery - } - - return l -} diff --git a/internal/protocols/httpp/location_with_trailing_slash_test.go b/internal/protocols/httpp/location_with_trailing_slash_test.go deleted file mode 100644 index 9622ed9e030..00000000000 --- a/internal/protocols/httpp/location_with_trailing_slash_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package httpp - -import ( - "net/url" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestLocationWithTrailingSlash(t *testing.T) { - for _, ca := range []struct { - name string - url *url.URL - loc string - }{ - { - "with query", - &url.URL{ - Path: "/test", - RawQuery: "key=value", - }, - "./test/?key=value", - }, - { - "xss", - &url.URL{ - Path: "/www.example.com", - }, - "./www.example.com/", - }, - { - "slashes in path", - &url.URL{ - Path: "/my/path", - }, - "./../my/path/", - }, - } { - t.Run(ca.name, func(t *testing.T) { - require.Equal(t, ca.loc, LocationWithTrailingSlash(ca.url)) - }) - } -} diff --git a/internal/protocols/webrtc/whip_client.go b/internal/protocols/webrtc/whip_client.go index 933612c1abc..0a6299648bc 100644 --- a/internal/protocols/webrtc/whip_client.go +++ b/internal/protocols/webrtc/whip_client.go @@ -31,7 +31,7 @@ func (c *WHIPClient) Publish( videoTrack format.Format, audioTrack format.Format, ) ([]*OutgoingTrack, error) { - iceServers, err := c.optionsICEServers(ctx, c.URL.String()) + iceServers, err := c.optionsICEServers(ctx) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (c *WHIPClient) Publish( return nil, err } - res, err := c.postOffer(ctx, c.URL.String(), offer) + res, err := c.postOffer(ctx, offer) if err != nil { c.pc.Close() return nil, err @@ -81,7 +81,7 @@ func (c *WHIPClient) Publish( err = c.pc.SetAnswer(res.Answer) if err != nil { - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, err } @@ -93,9 +93,9 @@ outer: for { select { case ca := <-c.pc.NewLocalCandidate(): - err := c.patchCandidate(ctx, c.URL.String(), offer, res.ETag, ca) + err := c.patchCandidate(ctx, offer, res.ETag, ca) if err != nil { - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, err } @@ -106,7 +106,7 @@ outer: break outer case <-t.C: - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, fmt.Errorf("deadline exceeded while waiting connection") } @@ -117,7 +117,7 @@ outer: // Read reads tracks. func (c *WHIPClient) Read(ctx context.Context) ([]*IncomingTrack, error) { - iceServers, err := c.optionsICEServers(ctx, c.URL.String()) + iceServers, err := c.optionsICEServers(ctx) if err != nil { return nil, err } @@ -147,7 +147,7 @@ func (c *WHIPClient) Read(ctx context.Context) ([]*IncomingTrack, error) { return nil, err } - res, err := c.postOffer(ctx, c.URL.String(), offer) + res, err := c.postOffer(ctx, offer) if err != nil { c.pc.Close() return nil, err @@ -162,7 +162,7 @@ func (c *WHIPClient) Read(ctx context.Context) ([]*IncomingTrack, error) { var sdp sdp.SessionDescription err = sdp.Unmarshal([]byte(res.Answer.SDP)) if err != nil { - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, err } @@ -170,14 +170,14 @@ func (c *WHIPClient) Read(ctx context.Context) ([]*IncomingTrack, error) { // check that there are at most two tracks _, err = TrackCount(sdp.MediaDescriptions) if err != nil { - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, err } err = c.pc.SetAnswer(res.Answer) if err != nil { - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, err } @@ -189,9 +189,9 @@ outer: for { select { case ca := <-c.pc.NewLocalCandidate(): - err := c.patchCandidate(ctx, c.URL.String(), offer, res.ETag, ca) + err := c.patchCandidate(ctx, offer, res.ETag, ca) if err != nil { - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, err } @@ -202,7 +202,7 @@ outer: break outer case <-t.C: - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, fmt.Errorf("deadline exceeded while waiting connection") } @@ -210,7 +210,7 @@ outer: tracks, err := c.pc.GatherIncomingTracks(ctx, 0) if err != nil { - c.deleteSession(context.Background(), c.URL.String()) //nolint:errcheck + c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() return nil, err } @@ -220,7 +220,7 @@ outer: // Close closes the client. func (c *WHIPClient) Close() error { - err := c.deleteSession(context.Background(), c.URL.String()) + err := c.deleteSession(context.Background()) c.pc.Close() return err } @@ -238,9 +238,8 @@ func (c *WHIPClient) Wait(ctx context.Context) error { func (c *WHIPClient) optionsICEServers( ctx context.Context, - ur string, ) ([]webrtc.ICEServer, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodOptions, ur, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodOptions, c.URL.String(), nil) if err != nil { return nil, err } @@ -266,10 +265,9 @@ type whipPostOfferResponse struct { func (c *WHIPClient) postOffer( ctx context.Context, - ur string, offer *webrtc.SessionDescription, ) (*whipPostOfferResponse, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, ur, bytes.NewReader([]byte(offer.SDP))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.URL.String(), bytes.NewReader([]byte(offer.SDP))) if err != nil { return nil, err } @@ -322,7 +320,6 @@ func (c *WHIPClient) postOffer( func (c *WHIPClient) patchCandidate( ctx context.Context, - ur string, offer *webrtc.SessionDescription, etag string, candidate *webrtc.ICECandidateInit, @@ -332,7 +329,7 @@ func (c *WHIPClient) patchCandidate( return err } - req, err := http.NewRequestWithContext(ctx, http.MethodPatch, ur, bytes.NewReader(frag)) + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, c.URL.String(), bytes.NewReader(frag)) if err != nil { return err } @@ -355,9 +352,8 @@ func (c *WHIPClient) patchCandidate( func (c *WHIPClient) deleteSession( ctx context.Context, - ur string, ) error { - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, ur, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.URL.String(), nil) if err != nil { return err } diff --git a/internal/servers/hls/http_server.go b/internal/servers/hls/http_server.go index a297373b7e5..bf2885c7f90 100644 --- a/internal/servers/hls/http_server.go +++ b/internal/servers/hls/http_server.go @@ -29,6 +29,14 @@ var hlsIndex []byte //go:embed hls.min.js var hlsMinJS []byte +func mergePathAndQuery(path string, rawQuery string) string { + res := path + if rawQuery != "" { + res += "?" + rawQuery + } + return res +} + type httpServer struct { address string encryption bool @@ -134,7 +142,7 @@ func (s *httpServer) onRequest(ctx *gin.Context) { dir, fname = pa, "" if !strings.HasSuffix(dir, "/") { - ctx.Writer.Header().Set("Location", httpp.LocationWithTrailingSlash(ctx.Request.URL)) + ctx.Writer.Header().Set("Location", mergePathAndQuery(ctx.Request.URL.Path+"/", ctx.Request.URL.RawQuery)) ctx.Writer.WriteHeader(http.StatusMovedPermanently) return } diff --git a/internal/servers/webrtc/http_server.go b/internal/servers/webrtc/http_server.go index 79d2aafd6d5..7ed96f1ba91 100644 --- a/internal/servers/webrtc/http_server.go +++ b/internal/servers/webrtc/http_server.go @@ -34,14 +34,22 @@ var ( reWHIPWHEPWithID = regexp.MustCompile("^/(.+?)/(whip|whep)/(.+?)$") ) +func mergePathAndQuery(path string, rawQuery string) string { + res := path + if rawQuery != "" { + res += "?" + rawQuery + } + return res +} + func writeError(ctx *gin.Context, statusCode int, err error) { ctx.JSON(statusCode, &defs.APIError{ Error: err.Error(), }) } -func sessionLocation(publish bool, secret uuid.UUID) string { - ret := "" +func sessionLocation(publish bool, path string, secret uuid.UUID) string { + ret := "/" + path + "/" if publish { ret += "whip" } else { @@ -107,12 +115,12 @@ func (s *httpServer) close() { s.inner.Close() } -func (s *httpServer) checkAuthOutsideSession(ctx *gin.Context, path string, publish bool) bool { +func (s *httpServer) checkAuthOutsideSession(ctx *gin.Context, pathName string, publish bool) bool { user, pass, hasCredentials := ctx.Request.BasicAuth() _, err := s.pathManager.FindPathConf(defs.PathFindPathConfReq{ AccessRequest: defs.PathAccessRequest{ - Name: path, + Name: pathName, Query: ctx.Request.URL.RawQuery, Publish: publish, IP: net.ParseIP(ctx.ClientIP()), @@ -146,8 +154,8 @@ func (s *httpServer) checkAuthOutsideSession(ctx *gin.Context, path string, publ return true } -func (s *httpServer) onWHIPOptions(ctx *gin.Context, path string, publish bool) { - if !s.checkAuthOutsideSession(ctx, path, publish) { +func (s *httpServer) onWHIPOptions(ctx *gin.Context, pathName string, publish bool) { + if !s.checkAuthOutsideSession(ctx, pathName, publish) { return } @@ -164,7 +172,7 @@ func (s *httpServer) onWHIPOptions(ctx *gin.Context, path string, publish bool) ctx.Writer.WriteHeader(http.StatusNoContent) } -func (s *httpServer) onWHIPPost(ctx *gin.Context, path string, publish bool) { +func (s *httpServer) onWHIPPost(ctx *gin.Context, pathName string, publish bool) { if ctx.Request.Header.Get("Content-Type") != "application/sdp" { writeError(ctx, http.StatusBadRequest, fmt.Errorf("invalid Content-Type")) return @@ -178,7 +186,7 @@ func (s *httpServer) onWHIPPost(ctx *gin.Context, path string, publish bool) { user, pass, _ := ctx.Request.BasicAuth() res := s.parent.newSession(webRTCNewSessionReq{ - pathName: path, + pathName: pathName, remoteAddr: httpp.RemoteAddr(ctx), query: ctx.Request.URL.RawQuery, user: user, @@ -203,12 +211,12 @@ func (s *httpServer) onWHIPPost(ctx *gin.Context, path string, publish bool) { ctx.Writer.Header().Set("ID", res.sx.uuid.String()) ctx.Writer.Header().Set("Accept-Patch", "application/trickle-ice-sdpfrag") ctx.Writer.Header()["Link"] = webrtc.LinkHeaderMarshal(servers) - ctx.Writer.Header().Set("Location", sessionLocation(publish, res.sx.secret)) + ctx.Writer.Header().Set("Location", sessionLocation(publish, pathName, res.sx.secret)) ctx.Writer.WriteHeader(http.StatusCreated) ctx.Writer.Write(res.answer) } -func (s *httpServer) onWHIPPatch(ctx *gin.Context, rawSecret string) { +func (s *httpServer) onWHIPPatch(ctx *gin.Context, pathName string, rawSecret string) { secret, err := uuid.Parse(rawSecret) if err != nil { writeError(ctx, http.StatusBadRequest, fmt.Errorf("invalid secret")) @@ -232,6 +240,7 @@ func (s *httpServer) onWHIPPatch(ctx *gin.Context, rawSecret string) { } res := s.parent.addSessionCandidates(webRTCAddSessionCandidatesReq{ + pathName: pathName, secret: secret, candidates: candidates, }) @@ -247,7 +256,7 @@ func (s *httpServer) onWHIPPatch(ctx *gin.Context, rawSecret string) { ctx.Writer.WriteHeader(http.StatusNoContent) } -func (s *httpServer) onWHIPDelete(ctx *gin.Context, rawSecret string) { +func (s *httpServer) onWHIPDelete(ctx *gin.Context, pathName string, rawSecret string) { secret, err := uuid.Parse(rawSecret) if err != nil { writeError(ctx, http.StatusBadRequest, fmt.Errorf("invalid secret")) @@ -255,7 +264,8 @@ func (s *httpServer) onWHIPDelete(ctx *gin.Context, rawSecret string) { } err = s.parent.deleteSession(webRTCDeleteSessionReq{ - secret: secret, + pathName: pathName, + secret: secret, }) if err != nil { if errors.Is(err, ErrSessionNotFound) { @@ -269,8 +279,8 @@ func (s *httpServer) onWHIPDelete(ctx *gin.Context, rawSecret string) { ctx.Writer.WriteHeader(http.StatusOK) } -func (s *httpServer) onPage(ctx *gin.Context, path string, publish bool) { - if !s.checkAuthOutsideSession(ctx, path, publish) { +func (s *httpServer) onPage(ctx *gin.Context, pathName string, publish bool) { + if !s.checkAuthOutsideSession(ctx, pathName, publish) { return } @@ -320,10 +330,10 @@ func (s *httpServer) onRequest(ctx *gin.Context) { if m := reWHIPWHEPWithID.FindStringSubmatch(ctx.Request.URL.Path); m != nil { switch ctx.Request.Method { case http.MethodPatch: - s.onWHIPPatch(ctx, m[3]) + s.onWHIPPatch(ctx, m[1], m[3]) case http.MethodDelete: - s.onWHIPDelete(ctx, m[3]) + s.onWHIPDelete(ctx, m[1], m[3]) } return } @@ -339,7 +349,7 @@ func (s *httpServer) onRequest(ctx *gin.Context) { s.onPage(ctx, ctx.Request.URL.Path[1:len(ctx.Request.URL.Path)-len("/publish")], true) case ctx.Request.URL.Path[len(ctx.Request.URL.Path)-1] != '/': - ctx.Writer.Header().Set("Location", httpp.LocationWithTrailingSlash(ctx.Request.URL)) + ctx.Writer.Header().Set("Location", mergePathAndQuery(ctx.Request.URL.Path+"/", ctx.Request.URL.RawQuery)) ctx.Writer.WriteHeader(http.StatusMovedPermanently) default: diff --git a/internal/servers/webrtc/server.go b/internal/servers/webrtc/server.go index 43a05763370..5db2f6edb6c 100644 --- a/internal/servers/webrtc/server.go +++ b/internal/servers/webrtc/server.go @@ -150,6 +150,7 @@ type webRTCAddSessionCandidatesRes struct { } type webRTCAddSessionCandidatesReq struct { + pathName string secret uuid.UUID candidates []*pwebrtc.ICECandidateInit res chan webRTCAddSessionCandidatesRes @@ -160,8 +161,9 @@ type webRTCDeleteSessionRes struct { } type webRTCDeleteSessionReq struct { - secret uuid.UUID - res chan webRTCDeleteSessionRes + pathName string + secret uuid.UUID + res chan webRTCDeleteSessionRes } type serverPathManager interface { @@ -343,7 +345,7 @@ outer: case req := <-s.chAddSessionCandidates: sx, ok := s.sessionsBySecret[req.secret] - if !ok { + if !ok || sx.req.pathName != req.pathName { req.res <- webRTCAddSessionCandidatesRes{err: ErrSessionNotFound} continue } @@ -352,7 +354,7 @@ outer: case req := <-s.chDeleteSession: sx, ok := s.sessionsBySecret[req.secret] - if !ok { + if !ok || sx.req.pathName != req.pathName { req.res <- webRTCDeleteSessionRes{err: ErrSessionNotFound} continue }