From 33d5459c202fb4aad221f0c315df412f90596c09 Mon Sep 17 00:00:00 2001 From: Deluan Date: Mon, 13 Jul 2020 18:37:48 -0400 Subject: [PATCH] Escape paths in "ByPath" queries --- persistence/mediafile_repository.go | 15 ++++++++++++--- persistence/mediafile_repository_test.go | 17 ++++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/persistence/mediafile_repository.go b/persistence/mediafile_repository.go index 99973b6b96f..8b46fb29d0a 100644 --- a/persistence/mediafile_repository.go +++ b/persistence/mediafile_repository.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" @@ -84,7 +85,7 @@ func (r mediaFileRepository) FindByAlbum(albumId string) (model.MediaFiles, erro func (r mediaFileRepository) FindByPath(path string) (model.MediaFiles, error) { // Query by path based on https://stackoverflow.com/a/13911906/653632 sel0 := r.selectMediaFile().Columns(fmt.Sprintf("substr(path, %d) AS item", len(path)+2)). - Where(Like{"path": filepath.Join(path, "%")}) + Where(pathStartsWith(path)) sel := r.newSelect().Columns("*", "item NOT GLOB '*"+string(os.PathSeparator)+"*' AS isLast"). Where(Eq{"isLast": 1}).FromSelect(sel0, "sel0") @@ -93,11 +94,19 @@ func (r mediaFileRepository) FindByPath(path string) (model.MediaFiles, error) { return res, err } +func pathStartsWith(path string) Sqlizer { + escapeChar := string(os.PathListSeparator) + escapedPath := strings.ReplaceAll(path, escapeChar, escapeChar+escapeChar) + escapedPath = strings.ReplaceAll(escapedPath, "_", escapeChar+"_") + escapedPath = strings.ReplaceAll(escapedPath, "%", escapeChar+"%") + return ConcatExpr(Like{"path": filepath.Join(escapedPath, "%")}, " escape '"+escapeChar+"'") +} + // FindPathsRecursively returns a list of all subfolders of basePath, recursively func (r mediaFileRepository) FindPathsRecursively(basePath string) ([]string, error) { // Query based on https://stackoverflow.com/a/38330814/653632 sel := r.newSelect().Columns(fmt.Sprintf("distinct rtrim(path, replace(path, '%s', ''))", string(os.PathSeparator))). - Where(Like{"path": filepath.Join(basePath, "%")}) + Where(pathStartsWith(basePath)) var res []string err := r.queryAll(sel, &res) return res, err @@ -127,7 +136,7 @@ func (r mediaFileRepository) Delete(id string) error { func (r mediaFileRepository) DeleteByPath(path string) (int64, error) { path = filepath.Clean(path) del := Delete(r.tableName). - Where(And{Like{"path": filepath.Join(path, "%")}, + Where(And{pathStartsWith(path), Eq{fmt.Sprintf("substr(path, %d) glob '*%s*'", len(path)+2, string(os.PathSeparator)): 0}}) log.Debug(r.ctx, "Deleting mediafiles by path", "path", path) return r.executeSQL(del) diff --git a/persistence/mediafile_repository_test.go b/persistence/mediafile_repository_test.go index fd3af59cec9..4a492dc18b0 100644 --- a/persistence/mediafile_repository_test.go +++ b/persistence/mediafile_repository_test.go @@ -52,9 +52,13 @@ var _ = Describe("MediaRepository", func() { }) It("finds tracks by path", func() { - Expect(mr.FindByPath(P("/beatles/1/sgt"))).To(Equal(model.MediaFiles{ - songDayInALife, - })) + Expect(mr.Put(&model.MediaFile{ID: "7001", Path: P("/Find:By'Path/_/123.mp3")})).To(BeNil()) + Expect(mr.Put(&model.MediaFile{ID: "7002", Path: P("/Find:By'Path/1/123.mp3")})).To(BeNil()) + + found, err := mr.FindByPath(P("/Find:By'Path/_/")) + Expect(err).To(BeNil()) + Expect(found).To(HaveLen(1)) + Expect(found[0].ID).To(Equal("7001")) }) It("returns starred tracks", func() { @@ -80,12 +84,15 @@ var _ = Describe("MediaRepository", func() { id2 := "2222" Expect(mr.Put(&model.MediaFile{ID: id2, Path: P("/abc/123/" + id2 + ".mp3")})).To(BeNil()) id3 := "3333" - Expect(mr.Put(&model.MediaFile{ID: id3, Path: P("/abc/" + id3 + ".mp3")})).To(BeNil()) + Expect(mr.Put(&model.MediaFile{ID: id3, Path: P("/ab_/" + id3 + ".mp3")})).To(BeNil()) + id4 := "4444" + Expect(mr.Put(&model.MediaFile{ID: id4, Path: P("/abc/" + id4 + ".mp3")})).To(BeNil()) - Expect(mr.DeleteByPath(P("/abc"))).To(Equal(int64(1))) + Expect(mr.DeleteByPath(P("/ab_"))).To(Equal(int64(1))) Expect(mr.Get(id1)).ToNot(BeNil()) Expect(mr.Get(id2)).ToNot(BeNil()) + Expect(mr.Get(id4)).ToNot(BeNil()) _, err := mr.Get(id3) Expect(err).To(MatchError(model.ErrNotFound)) })