diff --git a/pkg/guacanalytics/toposort.go b/pkg/guacanalytics/toposort.go index 245a9f5625d..36a783d0dc4 100644 --- a/pkg/guacanalytics/toposort.go +++ b/pkg/guacanalytics/toposort.go @@ -22,58 +22,65 @@ import ( "github.com/Khan/genqlient/graphql" ) -// TODO: add tests -func ToposortFromBfsNodeMap(ctx context.Context, gqlClient graphql.Client, nodeMap map[string]BfsNode) (map[int][]string, []string, error) { - frontiers := make(map[int][]string) - parentsMap, infoNodes := copyParents(nodeMap) - frontierLevel := 0 +// TopoSortFromBfsNodeMap sorts the nodes such that it returns a map of level -> list of nodeIDs at that level +func TopoSortFromBfsNodeMap(ctx context.Context, gqlClient graphql.Client, nodeMap map[string]BfsNode) (map[int][]string, []string, error) { + sortedNodes := make(map[int][]string) // map of level -> list of nodeIDs at that level + parentsMap, childrensMap, infoNodes := copyParents(nodeMap) + // parentsMap: map of nodeID (child) -> list of parents in the form of a map + // childrensMap: map of nodeID (parent) -> list of children in the form of a list + bfsLevel := 0 numNodes := 0 totalNodes := len(parentsMap) for numNodes < totalNodes { foundIDs := make(map[string]bool) - for id, parentsList := range parentsMap { - if len(parentsList) == 0 || (parentsList[0] == "" && len(parentsList) == 1) { - frontiers[frontierLevel] = append(frontiers[frontierLevel], id) - foundIDs[id] = true + for id, pMap := range parentsMap { + if pMap.parents != nil && len(pMap.parents) == 0 { // if this node has no parents, it is a root node + sortedNodes[bfsLevel] = append(sortedNodes[bfsLevel], id) numNodes++ + foundIDs[id] = true + delete(parentsMap, id) } } - if len(foundIDs) == 0 { - // TODO: print out offending cycle - return frontiers, infoNodes, fmt.Errorf("error: cycle detected") - } - for id := range foundIDs { - delete(parentsMap, id) - } - - for id, parentsList := range parentsMap { - newParentsList := []string{} - for _, parentID := range parentsList { - if !foundIDs[parentID] { - newParentsList = append(newParentsList, parentID) - } + for _, childID := range childrensMap[id] { // loop through all the children of this node + delete(parentsMap[childID].parents, id) // remove this node from the map of parents of the child } + } - parentsMap[id] = newParentsList + if len(foundIDs) == 0 { + return sortedNodes, infoNodes, fmt.Errorf("error: cycle detected") } - frontierLevel++ + + bfsLevel++ } - return frontiers, infoNodes, nil + return sortedNodes, infoNodes, nil } -func copyParents(inputMap map[string]BfsNode) (map[string][]string, []string) { - retMap := map[string][]string{} +func copyParents(inputMap map[string]BfsNode) (map[string]parent, map[string][]string, []string) { + parentsMap := map[string]parent{} // map of nodeID (child) -> map of the childs parents + childrenMap := map[string][]string{} // map of nodeID (parent) -> list of the parents children var infoNodes []string for key, value := range inputMap { if !value.NotInBlastRadius { - retMap[key] = append(retMap[key], value.Parents...) + if _, ok := parentsMap[key]; !ok { + parentsMap[key] = parent{make(map[string]bool)} + } + + for _, parent := range value.Parents { + parentsMap[key].parents[parent] = true + childrenMap[parent] = append(childrenMap[parent], key) + } } else { infoNodes = append(infoNodes, key) } } - return retMap, infoNodes + + return parentsMap, childrenMap, infoNodes +} + +type parent struct { + parents map[string]bool // Consider the map[string]bool as a set, the value doesn't matter just the key } diff --git a/pkg/guacanalytics/toposort_test.go b/pkg/guacanalytics/toposort_test.go new file mode 100644 index 00000000000..2be99ee4ca8 --- /dev/null +++ b/pkg/guacanalytics/toposort_test.go @@ -0,0 +1,113 @@ +// +// Copyright 2023 The GUAC Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package guacanalytics + +import ( + "context" + "reflect" + "sort" + "testing" + + "github.com/Khan/genqlient/graphql" +) + +func TestTopoSortFromBfsNodeMap(t *testing.T) { + type args struct { + nodeMap map[string]BfsNode + } + tests := []struct { + name string + args args + want map[int][]string + want1 []string + wantErr bool + }{ + { + name: "default", + args: args{ + /* + 1 + / \ + 2 3 + / \ \ + 4 5 6 + */ + nodeMap: map[string]BfsNode{ + "1": {Parents: []string{}}, + "2": {Parents: []string{"1"}}, + "3": {Parents: []string{"1"}}, + "4": {Parents: []string{"2"}}, + "5": {Parents: []string{"2"}}, + "6": {Parents: []string{"3"}}, + }, + }, + want: map[int][]string{ + 0: {"1"}, + 1: {"2", "3"}, + 2: {"4", "5", "6"}, + }, + want1: *new([]string), + }, + { + name: "cycle error", + args: args{ + /* + 1 + / \ + 2 - 3 + */ + nodeMap: map[string]BfsNode{ + "1": {Parents: []string{"3"}}, + "2": {Parents: []string{"1"}}, + "3": {Parents: []string{"2"}}, + }, + }, + wantErr: true, + }, + { + name: "infoNodes not empty", + args: args{ + nodeMap: map[string]BfsNode{ + "1": {NotInBlastRadius: true}, + }, + }, + want: map[int][]string{}, + want1: []string{"1"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + gqlClient := graphql.NewClient("test", nil) + got, got1, err := TopoSortFromBfsNodeMap(ctx, gqlClient, tt.args.nodeMap) + if (err != nil) != tt.wantErr { + t.Errorf("TopoSortFromBfsNodeMap() error = %v, wantErr %v", err, tt.wantErr) + return + } + for k := range got { + sort.Strings(got[k]) + sort.Strings(tt.want[k]) + if !reflect.DeepEqual(got[k], tt.want[k]) { + t.Errorf("TopoSortFromBfsNodeMap() got = %v, want %v", got, tt.want) + } + } + if !reflect.DeepEqual(got1, tt.want1) { + t.Errorf("TopoSortFromBfsNodeMap() got1 = %v, want %v", got1, tt.want1) + } + + }) + } +}