Skip to content

Commit

Permalink
add topology sort and unit tests (#2)
Browse files Browse the repository at this point in the history
Co-authored-by: Hamed Yousefi <hamed@betnomi.com>
  • Loading branch information
hmdsefi and hamednomi committed Feb 19, 2023
1 parent ba1f7ba commit 843e919
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 50 deletions.
95 changes: 47 additions & 48 deletions graph/directed_acyclic_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package graph
import "errors"

var (
ErrDAGVertexNotFound = errors.New("vertex not found")
ErrDAGHasCycle = errors.New("edge would create cycle")
ErrDAGCycle = errors.New("edge would create cycle")
ErrDAGHasCycle = errors.New("the graph contains a cycle")
)

type DAGVertex struct {
Expand All @@ -25,9 +25,11 @@ func NewDAG() *DAG {
}

// AddVertexWithID adds a new vertex with the given id to the graph.
func (d *DAG) AddVertexWithID(id int64) {
func (d *DAG) AddVertexWithID(id int64) *DAGVertex {
v := &DAGVertex{ID: id}
d.Vertices = append(d.Vertices, v)

return v
}

func (d *DAG) AddVertex(v *DAGVertex) {
Expand All @@ -38,77 +40,74 @@ func (d *DAG) AddVertex(v *DAGVertex) {
// the vertex with the 'to' id, after checking if the edge would create
// a cycle.
//
// AddEdge guarantees that the graph remain DAG after adding new edge.
//
// It returns error if it finds a cycle between 'from' and 'to'.
func (d *DAG) AddEdge(from, to *DAGVertex) error {
// Add the new edge
from.Neighbors = append(from.Neighbors, to)
to.inDegree++

// Perform a topological sort to check for cycles
var sortedVertices []*DAGVertex
queue := make([]*DAGVertex, 0)
// If topological sort returns an error, new edge created a cycle
_, err := d.TopologySort()
if err != nil {
// Remove the new edge
from.Neighbors = from.Neighbors[:len(from.Neighbors)-1]
to.inDegree--

return ErrDAGCycle
}

// Add all vertices with inDegree 0 to the queue
return nil
}

// TopologySort performs a topological sort of the graph using
// Kahn's algorithm. If the sorted list of vertices does not contain
// all vertices in the graph, it means there is a cycle in the graph.
//
// It returns error if it finds a cycle in the graph.
func (d *DAG) TopologySort() ([]*DAGVertex, error) {
// Initialize a map to store the inDegree of each vertex
inDegrees := make(map[*DAGVertex]int)
for _, v := range d.Vertices {
if v.inDegree == 0 {
inDegrees[v] = v.inDegree
}

// Initialize a queue with vertices of inDegrees zero
queue := make([]*DAGVertex, 0)
for v, inDegree := range inDegrees {
if inDegree == 0 {
queue = append(queue, v)
}
}

// Traverse the graph using Kahn's algorithm
// Initialize the sorted list of vertices
sortedVertices := make([]*DAGVertex, 0)

// Loop through the vertices with inDegree zero
for len(queue) > 0 {
// Dequeue a vertex
v := queue[0]
// Get the next vertex with inDegree zero
curr := queue[0]
queue = queue[1:]

// Add the vertex to the sorted list
sortedVertices = append(sortedVertices, v)
sortedVertices = append(sortedVertices, curr)

// Decrement the inDegree of all neighbors of the dequeued vertex
for _, neighbor := range v.Neighbors {
neighbor.inDegree--
if neighbor.inDegree == 0 {
// Decrement the inDegree of each of the vertex's neighbors
for _, neighbor := range curr.Neighbors {
inDegrees[neighbor]--
if inDegrees[neighbor] == 0 {
queue = append(queue, neighbor)
}
}
}

// If the sorted list does not contain all vertices, there is a cycle
if len(sortedVertices) != len(d.Vertices) {
// Remove the new edge
from.Neighbors = from.Neighbors[:len(from.Neighbors)-1]
to.inDegree--

return errors.New("adding this edge would create a cycle in the graph")
}

return nil
}

func (d *DAG) hasCycle(current, parent *DAGVertex, visited map[*DAGVertex]bool) bool {
// Mark the current vertex as visited
visited[current] = true

// Check all neighbors of the current vertex
for _, neighbor := range current.Neighbors {
// If the neighbor is the parent vertex, continue to the next neighbor
if neighbor == parent {
continue
}

// If the neighbor has already been visited, there is a cycle
if visited[neighbor] {
return true
}

// Recursively check for cycles in the neighbor's subtree
if d.hasCycle(neighbor, current, visited) {
return true
}
return nil, ErrDAGHasCycle
}

// No cycle was found
return false
return sortedVertices, nil
}

// findVertex searches for the given id in the vertices. It returns
Expand Down
60 changes: 58 additions & 2 deletions graph/directed_acyclic_graph_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graph

import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -32,11 +33,11 @@ func TestDAG_AddEdge(t *testing.T) {
// Add edges from 1 to 2 and from 2 to 3
err := dag.AddEdge(v1, v2)
if err != nil {
t.Errorf("Unexpected error: %v", err)
t.Errorf("unexpected error: %v", err)
}
err = dag.AddEdge(v2, v3)
if err != nil {
t.Errorf("Unexpected error: %v", err)
t.Errorf("unexpected error: %v", err)
}

// Try to add an edge from 3 to 1, which should result in an error
Expand All @@ -45,3 +46,58 @@ func TestDAG_AddEdge(t *testing.T) {
t.Error("Expected error, but got none")
}
}

func TestDAG_TopologySort(t *testing.T) {
// Create a DAG with 6 vertices and 6 edges
dag := NewDAG()
v0 := dag.AddVertexWithID(0)
v1 := dag.AddVertexWithID(1)
v2 := dag.AddVertexWithID(2)
v3 := dag.AddVertexWithID(3)
v4 := dag.AddVertexWithID(4)
v5 := dag.AddVertexWithID(5)

err := dag.AddEdge(v5, v2)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v5, v0)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v4, v0)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v4, v1)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v2, v3)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v3, v1)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// Perform a topological sort
sortedVertices, err := dag.TopologySort()

// Check that there was no error
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// Check that the sorted order is correct
expectedOrder := []*DAGVertex{v4, v5, v2, v0, v3, v1}
if !reflect.DeepEqual(sortedVertices, expectedOrder) {
t.Errorf("unexpected sort order. Got %v, expected %v", sortedVertices, expectedOrder)
}
}

0 comments on commit 843e919

Please sign in to comment.