Skip to content

Commit

Permalink
Fix scene query filtered by fingerprint (stashapp#693)
Browse files Browse the repository at this point in the history
  • Loading branch information
InfiniteStash authored and feederbox826 committed Nov 15, 2023
1 parent 682ef7d commit f9b6028
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 14 deletions.
4 changes: 2 additions & 2 deletions pkg/sqlx/querybuilder_edit.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,10 @@ func (qb *editQueryBuilder) buildQuery(filter models.EditQueryInput, userID uuid
switch *filter.Voted {
case models.UserVotedFilterEnumNotVoted:
where := fmt.Sprintf("%s.user_id = ?", editVoteTable.name)
query.AddJoinTableFilter(editVoteTable, where, nil, true, userID)
query.AddJoinTableFilter(editVoteTable, where, false, nil, true, userID)
default:
where := fmt.Sprintf("%[1]s.user_id = ? AND %[1]s.vote = ?", editVoteTable.Name())
query.AddJoinTableFilter(editVoteTable, where, nil, false, userID, q.String())
query.AddJoinTableFilter(editVoteTable, where, false, nil, false, userID, q.String())
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlx/querybuilder_performer.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func (qb *performerQueryBuilder) buildQuery(filter models.PerformerQueryInput, u

if q := filter.URL; q != nil && *q != "" {
where := fmt.Sprintf("%s.url = ?", performerURLTable.Name())
query.AddJoinTableFilter(performerURLTable, where, nil, false, *q)
query.AddJoinTableFilter(performerURLTable, where, false, nil, false, *q)
}

if q := filter.Name; q != nil && *q != "" {
Expand Down
18 changes: 10 additions & 8 deletions pkg/sqlx/querybuilder_scene.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func (qb *sceneQueryBuilder) buildQuery(filter models.SceneQueryInput, userID uu

if q := filter.URL; q != nil && *q != "" {
where := fmt.Sprintf("%s.url = ?", sceneURLTable.Name())
query.AddJoinTableFilter(sceneURLTable, where, nil, false, *q)
query.AddJoinTableFilter(sceneURLTable, where, false, nil, false, *q)
}

if filter.ParentStudio != nil {
Expand All @@ -357,19 +357,19 @@ func (qb *sceneQueryBuilder) buildQuery(filter models.SceneQueryInput, userID uu
}

if q := filter.Performers; q != nil && len(q.Value) > 0 {
if err := setMultiCriterionClause(query, scenePerformerTable, performerJoinKey, q); err != nil {
if err := setMultiCriterionClause(query, scenePerformerTable, performerJoinKey, q, false); err != nil {
return nil, err
}
}

if q := filter.Tags; q != nil && len(q.Value) > 0 {
if err := setMultiCriterionClause(query, sceneTagTable, tagJoinKey, q); err != nil {
if err := setMultiCriterionClause(query, sceneTagTable, tagJoinKey, q, false); err != nil {
return nil, err
}
}

if q := filter.Fingerprints; q != nil && len(q.Value) > 0 {
if err := setMultiCriterionClause(query, sceneFingerprintTable, "hash", q); err != nil {
if err := setMultiCriterionClause(query, sceneFingerprintTable, "hash", q, true); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -504,23 +504,25 @@ func (qb *sceneQueryBuilder) QueryCount(filter models.SceneQueryInput, userID uu
return qb.dbi.CountOnly(*query)
}

func setMultiCriterionClause(query *queryBuilder, joinTable tableJoin, joinTableField string, criterion models.MultiCriterionInput) error {
func setMultiCriterionClause(query *queryBuilder, joinTable tableJoin, joinTableField string, criterion models.MultiCriterionInput, group bool) error {
args := criterion.GetValues()
inClause := fmt.Sprintf("%s.%s IN %s", joinTable.Name(), joinTableField, getInBinding(criterion.Count()))

groupBy := group || len(args) > 1

switch criterion.GetModifier() {
case models.CriterionModifierIncludes:
// includes any of the provided ids
query.AddJoinTableFilter(joinTable, inClause, nil, false, args...)
query.AddJoinTableFilter(joinTable, inClause, groupBy, nil, false, args...)

case models.CriterionModifierIncludesAll:
// includes all of the provided ids
having := fmt.Sprintf("COUNT(*) = %d", criterion.Count())
query.AddJoinTableFilter(joinTable, inClause, &having, false, args...)
query.AddJoinTableFilter(joinTable, inClause, true, &having, false, args...)

case models.CriterionModifierExcludes:
// excludes all of the provided ids
query.AddJoinTableFilter(joinTable, inClause, nil, true, args...)
query.AddJoinTableFilter(joinTable, inClause, groupBy, nil, true, args...)

default:
return fmt.Errorf("unsupported modifier %s for %s.%s", criterion.GetModifier(), joinTable.Name(), joinTableField)
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlx/querybuilder_studio.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (qb *studioQueryBuilder) Query(filter models.StudioQueryInput, userID uuid.

if q := filter.URL; q != nil && *q != "" {
where := fmt.Sprintf("%s.url = ?", studioURLTable.Name())
query.AddJoinTableFilter(studioURLTable, where, nil, false, *q)
query.AddJoinTableFilter(studioURLTable, where, false, nil, false, *q)
}

if q := filter.Name; q != nil && *q != "" {
Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlx/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ func (qb *queryBuilder) AddJoin(jt table, on string) {
qb.Body += " JOIN " + jt.Name() + " ON " + on
}

func (qb *queryBuilder) AddJoinTableFilter(tj tableJoin, query string, having *string, not bool, args ...interface{}) {
func (qb *queryBuilder) AddJoinTableFilter(tj tableJoin, query string, group bool, having *string, not bool, args ...interface{}) {
clause := fmt.Sprintf(" JOIN (SELECT %[1]s.%[2]s FROM %[1]s WHERE %[3]s", tj.Name(), tj.joinColumn, query)
if len(args) > 1 {
if group {
clause += fmt.Sprintf(" GROUP BY %s.%s", tj.Name(), tj.joinColumn)

if having != nil {
Expand Down

0 comments on commit f9b6028

Please sign in to comment.