Skip to content

Commit 0b01441

Browse files
authored
feat(Query): Add random keyword in DQL (#7693)
This change implements the `random` function argument. If the argument `random: k` is passed to a DQL function, k randomly chosen results are returned. For example, `func(has(name), random:5)` returns 5 random nodes that have the `name` predicate.
1 parent 8b1cfdc commit 0b01441

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

gql/parser.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2595,7 +2595,7 @@ func parseLanguageList(it *lex.ItemIterator) ([]string, error) {
25952595

25962596
func validKeyAtRoot(k string) bool {
25972597
switch k {
2598-
case "func", "orderasc", "orderdesc", "first", "offset", "after":
2598+
case "func", "orderasc", "orderdesc", "first", "offset", "after", "random":
25992599
return true
26002600
case "from", "to", "numpaths", "minweight", "maxweight":
26012601
// Specific to shortest path
@@ -2609,7 +2609,7 @@ func validKeyAtRoot(k string) bool {
26092609
// Check for validity of key at non-root nodes.
26102610
func validKey(k string) bool {
26112611
switch k {
2612-
case "orderasc", "orderdesc", "first", "offset", "after":
2612+
case "orderasc", "orderdesc", "first", "offset", "after", "random":
26132613
return true
26142614
}
26152615
return false

query/query.go

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"fmt"
2222
"math"
23+
"math/rand"
2324
"sort"
2425
"strconv"
2526
"strings"
@@ -114,6 +115,8 @@ type params struct {
114115
Count int
115116
// Offset is the value of the "offset" parameter.
116117
Offset int
118+
// Random is the value of the "random" parameter.
119+
Random int
117120
// AfterUID is the value of the "after" parameter.
118121
AfterUID uint64
119122
// DoCount is true if the count of the predicate is requested instead of its value.
@@ -745,6 +748,15 @@ func (args *params) fill(gq *gql.GraphQuery) error {
745748
}
746749
args.Count = int(first)
747750
}
751+
752+
if v, ok := gq.Args["random"]; ok {
753+
random, err := strconv.ParseInt(v, 0, 32)
754+
if err != nil {
755+
return err
756+
}
757+
args.Random = int(random)
758+
}
759+
748760
return nil
749761
}
750762

@@ -2307,6 +2319,13 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) {
23072319
}
23082320
}
23092321

2322+
if sg.Params.Random > 0 {
2323+
if err = sg.applyRandom(ctx); err != nil {
2324+
rch <- err
2325+
return
2326+
}
2327+
}
2328+
23102329
// Here we consider handling count with filtering. We do this after
23112330
// pagination because otherwise, we need to do the count with pagination
23122331
// taken into account. For example, a PL might have only 50 entries but the
@@ -2404,6 +2423,43 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) {
24042423
rch <- childErr
24052424
}
24062425

2426+
// UidKey stores the position of a uid: x is the index of its list in uidMatrix
2427+
// and y is the index of the uid within that list.
2428+
type UidKey struct {
2429+
x int
2430+
y int
2431+
}
2432+
2433+
// applyRandom applies the "random" parameter to the lists inside uidMatrix.
2434+
// Each uid list is shuffled in place and cut to at most sg.Params.Random uids,
2435+
// so no uid is selected twice (random selection without replacement).
2436+
// If sg.Params.Random is more than the number of uids in a list, all of them
2437+
// are kept. DestMap is then rebuilt with codec.Merge; the error is always nil.
2438+
func (sg *SubGraph) applyRandom(ctx context.Context) error {
2439+
sg.updateUidMatrix()
2440+
2441+
for i := 0; i < len(sg.uidMatrix); i++ {
2442+
// Shuffle the uid list in place, then keep only the
2443+
// first sg.Params.Random uids (or all of them if there are fewer).
2444+
2445+
uidList := sg.uidMatrix[i].Uids
2446+
2447+
rand.Shuffle(len(uidList), func(i, j int) {
2448+
uidList[i], uidList[j] = uidList[j], uidList[i]
2449+
})
2450+
2451+
numRandom := sg.Params.Random
2452+
if sg.Params.Random > len(uidList) {
2453+
numRandom = len(uidList)
2454+
}
2455+
2456+
sg.uidMatrix[i].Uids = uidList[:numRandom]
2457+
}
2458+
2459+
sg.DestMap = codec.Merge(sg.uidMatrix)
2460+
return nil
2461+
}
2462+
24072463
// applyPagination applies count and offset to lists inside uidMatrix.
24082464
func (sg *SubGraph) applyPagination(ctx context.Context) error {
24092465
if sg.Params.Count == 0 && sg.Params.Offset == 0 { // No pagination.
@@ -2647,7 +2703,7 @@ func (sg *SubGraph) sortAndPaginateUsingVar(ctx context.Context) error {
26472703
func isValidArg(a string) bool {
26482704
switch a {
26492705
case "numpaths", "from", "to", "orderasc", "orderdesc", "first", "offset", "after", "depth",
2650-
"minweight", "maxweight":
2706+
"minweight", "maxweight", "random":
26512707
return true
26522708
}
26532709
return false

0 commit comments

Comments
 (0)