Skip to content

Commit

Permalink
Refactored and Included Tests for TopoSortFromBfsNodeMap
Browse files Browse the repository at this point in the history
* Refactored the code for TopoSortFromBfsNodeMap for readability and time complexity purposes.
* Included tests for TopoSortFromBfsNodeMap

Signed-off-by: nathannaveen <42319948+nathannaveen@users.noreply.github.com>
  • Loading branch information
nathannaveen committed Aug 9, 2023
1 parent 7d1960b commit 9407373
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 30 deletions.
67 changes: 37 additions & 30 deletions pkg/guacanalytics/toposort.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,58 +22,65 @@ import (
"github.com/Khan/genqlient/graphql"
)

// TODO: add tests
func ToposortFromBfsNodeMap(ctx context.Context, gqlClient graphql.Client, nodeMap map[string]BfsNode) (map[int][]string, []string, error) {
frontiers := make(map[int][]string)
parentsMap, infoNodes := copyParents(nodeMap)
frontierLevel := 0
// TopoSortFromBfsNodeMap sorts the nodes such that it returns a map of level -> list of nodeIDs at that level
func TopoSortFromBfsNodeMap(ctx context.Context, gqlClient graphql.Client, nodeMap map[string]BfsNode) (map[int][]string, []string, error) {
sortedNodes := make(map[int][]string) // map of level -> list of nodeIDs at that level
parentsMap, childrensMap, infoNodes := copyParents(nodeMap)
// parentsMap: map of nodeID (child) -> list of parents in the form of a map
// childrensMap: map of nodeID (parent) -> list of children in the form of a list
bfsLevel := 0
numNodes := 0
totalNodes := len(parentsMap)

for numNodes < totalNodes {
foundIDs := make(map[string]bool)
for id, parentsList := range parentsMap {
if len(parentsList) == 0 || (parentsList[0] == "" && len(parentsList) == 1) {
frontiers[frontierLevel] = append(frontiers[frontierLevel], id)
foundIDs[id] = true
for id, pMap := range parentsMap {
if pMap.parents != nil && len(pMap.parents) == 0 { // if this node has no parents, it is a root node
sortedNodes[bfsLevel] = append(sortedNodes[bfsLevel], id)
numNodes++
foundIDs[id] = true
delete(parentsMap, id)
}
}

if len(foundIDs) == 0 {
// TODO: print out offending cycle
return frontiers, infoNodes, fmt.Errorf("error: cycle detected")
}

for id := range foundIDs {
delete(parentsMap, id)
}

for id, parentsList := range parentsMap {
newParentsList := []string{}
for _, parentID := range parentsList {
if !foundIDs[parentID] {
newParentsList = append(newParentsList, parentID)
}
for _, childID := range childrensMap[id] { // loop through all the children of this node
delete(parentsMap[childID].parents, id) // remove this node from the map of parents of the child
}
}

parentsMap[id] = newParentsList
if len(foundIDs) == 0 {
return sortedNodes, infoNodes, fmt.Errorf("error: cycle detected")
}
frontierLevel++

bfsLevel++
}

return frontiers, infoNodes, nil
return sortedNodes, infoNodes, nil
}

func copyParents(inputMap map[string]BfsNode) (map[string][]string, []string) {
retMap := map[string][]string{}
func copyParents(inputMap map[string]BfsNode) (map[string]parent, map[string][]string, []string) {
parentsMap := map[string]parent{} // map of nodeID (child) -> map of the childs parents
childrenMap := map[string][]string{} // map of nodeID (parent) -> list of the parents children
var infoNodes []string
for key, value := range inputMap {
if !value.NotInBlastRadius {
retMap[key] = append(retMap[key], value.Parents...)
if _, ok := parentsMap[key]; !ok {
parentsMap[key] = parent{make(map[string]bool)}
}

for _, parent := range value.Parents {
parentsMap[key].parents[parent] = true
childrenMap[parent] = append(childrenMap[parent], key)
}
} else {
infoNodes = append(infoNodes, key)
}
}
return retMap, infoNodes

return parentsMap, childrenMap, infoNodes
}

type parent struct {
parents map[string]bool // Consider the map[string]bool as a set, the value doesn't matter just the key
}
113 changes: 113 additions & 0 deletions pkg/guacanalytics/toposort_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//
// Copyright 2023 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package guacanalytics

import (
"context"
"reflect"
"sort"
"testing"

"github.com/Khan/genqlient/graphql"
)

func TestTopoSortFromBfsNodeMap(t *testing.T) {
type args struct {
nodeMap map[string]BfsNode
}
tests := []struct {
name string
args args
want map[int][]string
want1 []string
wantErr bool
}{
{
name: "default",
args: args{
/*
1
/ \
2 3
/ \ \
4 5 6
*/
nodeMap: map[string]BfsNode{
"1": {Parents: []string{}},
"2": {Parents: []string{"1"}},
"3": {Parents: []string{"1"}},
"4": {Parents: []string{"2"}},
"5": {Parents: []string{"2"}},
"6": {Parents: []string{"3"}},
},
},
want: map[int][]string{
0: {"1"},
1: {"2", "3"},
2: {"4", "5", "6"},
},
want1: *new([]string),
},
{
name: "cycle error",
args: args{
/*
1
/ \
2 - 3
*/
nodeMap: map[string]BfsNode{
"1": {Parents: []string{"3"}},
"2": {Parents: []string{"1"}},
"3": {Parents: []string{"2"}},
},
},
wantErr: true,
},
{
name: "infoNodes not empty",
args: args{
nodeMap: map[string]BfsNode{
"1": {NotInBlastRadius: true},
},
},
want: map[int][]string{},
want1: []string{"1"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
gqlClient := graphql.NewClient("test", nil)
got, got1, err := TopoSortFromBfsNodeMap(ctx, gqlClient, tt.args.nodeMap)
if (err != nil) != tt.wantErr {
t.Errorf("TopoSortFromBfsNodeMap() error = %v, wantErr %v", err, tt.wantErr)
return
}
for k := range got {
sort.Strings(got[k])
sort.Strings(tt.want[k])
if !reflect.DeepEqual(got[k], tt.want[k]) {
t.Errorf("TopoSortFromBfsNodeMap() got = %v, want %v", got, tt.want)
}
}
if !reflect.DeepEqual(got1, tt.want1) {
t.Errorf("TopoSortFromBfsNodeMap() got1 = %v, want %v", got1, tt.want1)
}

})
}
}

0 comments on commit 9407373

Please sign in to comment.