Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add topology sort and unit tests #2

Merged
merged 1 commit into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 47 additions & 48 deletions graph/directed_acyclic_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package graph
import "errors"

var (
ErrDAGVertexNotFound = errors.New("vertex not found")
ErrDAGHasCycle = errors.New("edge would create cycle")
ErrDAGCycle = errors.New("edge would create cycle")
ErrDAGHasCycle = errors.New("the graph contains a cycle")
)

type DAGVertex struct {
Expand All @@ -25,9 +25,11 @@ func NewDAG() *DAG {
}

// AddVertexWithID adds a new vertex with the given id to the graph.
func (d *DAG) AddVertexWithID(id int64) {
func (d *DAG) AddVertexWithID(id int64) *DAGVertex {
v := &DAGVertex{ID: id}
d.Vertices = append(d.Vertices, v)

return v
}

func (d *DAG) AddVertex(v *DAGVertex) {
Expand All @@ -38,77 +40,74 @@ func (d *DAG) AddVertex(v *DAGVertex) {
// the vertex with the 'to' id, after checking if the edge would create
// a cycle.
//
// AddEdge guarantees that the graph remain DAG after adding new edge.
//
// It returns error if it finds a cycle between 'from' and 'to'.
func (d *DAG) AddEdge(from, to *DAGVertex) error {
// Add the new edge
from.Neighbors = append(from.Neighbors, to)
to.inDegree++

// Perform a topological sort to check for cycles
var sortedVertices []*DAGVertex
queue := make([]*DAGVertex, 0)
// If topological sort returns an error, new edge created a cycle
_, err := d.TopologySort()
if err != nil {
// Remove the new edge
from.Neighbors = from.Neighbors[:len(from.Neighbors)-1]
to.inDegree--

return ErrDAGCycle
}

// Add all vertices with inDegree 0 to the queue
return nil
}

// TopologySort performs a topological sort of the graph using
// Kahn's algorithm. If the sorted list of vertices does not contain
// all vertices in the graph, it means there is a cycle in the graph.
//
// It returns error if it finds a cycle in the graph.
func (d *DAG) TopologySort() ([]*DAGVertex, error) {
// Initialize a map to store the inDegree of each vertex
inDegrees := make(map[*DAGVertex]int)
for _, v := range d.Vertices {
if v.inDegree == 0 {
inDegrees[v] = v.inDegree
}

// Initialize a queue with vertices of inDegrees zero
queue := make([]*DAGVertex, 0)
for v, inDegree := range inDegrees {
if inDegree == 0 {
queue = append(queue, v)
}
}

// Traverse the graph using Kahn's algorithm
// Initialize the sorted list of vertices
sortedVertices := make([]*DAGVertex, 0)

// Loop through the vertices with inDegree zero
for len(queue) > 0 {
// Dequeue a vertex
v := queue[0]
// Get the next vertex with inDegree zero
curr := queue[0]
queue = queue[1:]

// Add the vertex to the sorted list
sortedVertices = append(sortedVertices, v)
sortedVertices = append(sortedVertices, curr)

// Decrement the inDegree of all neighbors of the dequeued vertex
for _, neighbor := range v.Neighbors {
neighbor.inDegree--
if neighbor.inDegree == 0 {
// Decrement the inDegree of each of the vertex's neighbors
for _, neighbor := range curr.Neighbors {
inDegrees[neighbor]--
if inDegrees[neighbor] == 0 {
queue = append(queue, neighbor)
}
}
}

// If the sorted list does not contain all vertices, there is a cycle
if len(sortedVertices) != len(d.Vertices) {
// Remove the new edge
from.Neighbors = from.Neighbors[:len(from.Neighbors)-1]
to.inDegree--

return errors.New("adding this edge would create a cycle in the graph")
}

return nil
}

func (d *DAG) hasCycle(current, parent *DAGVertex, visited map[*DAGVertex]bool) bool {
// Mark the current vertex as visited
visited[current] = true

// Check all neighbors of the current vertex
for _, neighbor := range current.Neighbors {
// If the neighbor is the parent vertex, continue to the next neighbor
if neighbor == parent {
continue
}

// If the neighbor has already been visited, there is a cycle
if visited[neighbor] {
return true
}

// Recursively check for cycles in the neighbor's subtree
if d.hasCycle(neighbor, current, visited) {
return true
}
return nil, ErrDAGHasCycle
}

// No cycle was found
return false
return sortedVertices, nil
}

// findVertex searches for the given id in the vertices. It returns
Expand Down
60 changes: 58 additions & 2 deletions graph/directed_acyclic_graph_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graph

import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -32,11 +33,11 @@ func TestDAG_AddEdge(t *testing.T) {
// Add edges from 1 to 2 and from 2 to 3
err := dag.AddEdge(v1, v2)
if err != nil {
t.Errorf("Unexpected error: %v", err)
t.Errorf("unexpected error: %v", err)
}
err = dag.AddEdge(v2, v3)
if err != nil {
t.Errorf("Unexpected error: %v", err)
t.Errorf("unexpected error: %v", err)
}

// Try to add an edge from 3 to 1, which should result in an error
Expand All @@ -45,3 +46,58 @@ func TestDAG_AddEdge(t *testing.T) {
t.Error("Expected error, but got none")
}
}

func TestDAG_TopologySort(t *testing.T) {
// Create a DAG with 6 vertices and 6 edges
dag := NewDAG()
v0 := dag.AddVertexWithID(0)
v1 := dag.AddVertexWithID(1)
v2 := dag.AddVertexWithID(2)
v3 := dag.AddVertexWithID(3)
v4 := dag.AddVertexWithID(4)
v5 := dag.AddVertexWithID(5)

err := dag.AddEdge(v5, v2)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v5, v0)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v4, v0)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v4, v1)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v2, v3)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

err = dag.AddEdge(v3, v1)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// Perform a topological sort
sortedVertices, err := dag.TopologySort()

// Check that there was no error
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// Check that the sorted order is correct
expectedOrder := []*DAGVertex{v4, v5, v2, v0, v3, v1}
if !reflect.DeepEqual(sortedVertices, expectedOrder) {
t.Errorf("unexpected sort order. Got %v, expected %v", sortedVertices, expectedOrder)
}
}