diff --git a/acyclic_test.go b/acyclic_test.go index d5bb98c..28b621a 100644 --- a/acyclic_test.go +++ b/acyclic_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestDAG_TopologySort(t *testing.T) { +func TestTopologySort(t *testing.T) { // Create a dag with 6 vertices and 6 edges g := New[int](Acyclic()) @@ -17,39 +17,44 @@ func TestDAG_TopologySort(t *testing.T) { t.Error(testErrMsgNotTrue) } - v0 := g.AddVertexByLabel(0) v1 := g.AddVertexByLabel(1) v2 := g.AddVertexByLabel(2) v3 := g.AddVertexByLabel(3) v4 := g.AddVertexByLabel(4) v5 := g.AddVertexByLabel(5) + v6 := g.AddVertexByLabel(6) - _, err := g.AddEdge(v5, v2) + _, err := g.AddEdge(v1, v2) if err != nil { t.Errorf("unexpected error: %v", err) } - _, err = g.AddEdge(v5, v0) + _, err = g.AddEdge(v2, v3) if err != nil { t.Errorf("unexpected error: %v", err) } - _, err = g.AddEdge(v4, v0) + _, err = g.AddEdge(v2, v4) if err != nil { t.Errorf("unexpected error: %v", err) } - _, err = g.AddEdge(v4, v1) + _, err = g.AddEdge(v2, v5) if err != nil { t.Errorf("unexpected error: %v", err) } - _, err = g.AddEdge(v2, v3) + _, err = g.AddEdge(v3, v5) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + _, err = g.AddEdge(v4, v6) if err != nil { t.Errorf("unexpected error: %v", err) } - _, err = g.AddEdge(v3, v1) + _, err = g.AddEdge(v5, v6) if err != nil { t.Errorf("unexpected error: %v", err) } @@ -63,7 +68,7 @@ func TestDAG_TopologySort(t *testing.T) { } // Check that the sorted order is correct - expectedOrder := []*Vertex[int]{v4, v5, v2, v0, v3, v1} + expectedOrder := []*Vertex[int]{v1, v2, v3, v4, v5, v6} if !reflect.DeepEqual(sortedVertices, expectedOrder) { t.Errorf("unexpected sort order. Got %v, expected %v", sortedVertices, expectedOrder) } diff --git a/traverse/random_walk_iterator.go b/traverse/random_walk_iterator.go new file mode 100644 index 0000000..86f7513 --- /dev/null +++ b/traverse/random_walk_iterator.go @@ -0,0 +1,71 @@ +package traverse + +import ( + "crypto/rand" + "math/big" + + "github.com/hmdsefi/gograph" +) + +// randomWalkIterator implements the Iterator interface to travers +// a graph in a random walk fashion. +// +// Random walk is a stochastic process used to explore a graph, where +// a walker moves through the graph by following random edges. At each +// step, the walker chooses a random neighbor of the current node and +// moves to it, and the process is repeated until a stopping condition +// is met. +type randomWalkIterator[T comparable] struct { + graph gograph.Graph[T] // the graph that being traversed. + start *gograph.Vertex[T] // the starting point of the traversal. + current *gograph.Vertex[T] // the latest node that has been returned by the iterator. + steps int // the maximum number of steps to be taken during the traversal. + currentStep int // the step counter. +} + +func NewRandomWalkIterator[T comparable](graph gograph.Graph[T], start *gograph.Vertex[T], steps int) Iterator[T] { + return &randomWalkIterator[T]{ + graph: graph, + start: start, + current: start, + steps: steps, + } +} + +func (r *randomWalkIterator[T]) HasNext() bool { + return r.current.OutDegree() > 0 && r.currentStep < r.steps +} + +func (r *randomWalkIterator[T]) Next() *gograph.Vertex[T] { + if !r.HasNext() { + return nil + } + + neighbors := r.current.Neighbors() + if len(neighbors) == 0 { + // there is no vertex to continue + r.currentStep = r.steps + return r.current + } + + i, _ := rand.Int(rand.Reader, big.NewInt(int64(len(neighbors)))) + r.current = neighbors[i.Int64()] + r.currentStep++ + + return r.current +} + +func (r *randomWalkIterator[T]) Iterate(f func(v *gograph.Vertex[T]) error) error { + for r.HasNext() { + if err := f(r.Next()); err != nil { + return err + } + } + + return nil +} + +func (r *randomWalkIterator[T]) Reset() { + r.current = r.start + r.currentStep = 0 +}