/
s_cluster.go
96 lines (87 loc) · 2.46 KB
/
s_cluster.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package cluster
import (
"fmt"
"log"
"net"
"sync"
"time"
pb "github.com/go-sif/sif/internal/rpc"
"golang.org/x/net/context"
"google.golang.org/grpc/peer"
)
// clusterServer implements the cluster gRPC service and keeps track of the
// workers that have registered with it.
type clusterServer struct {
	pb.UnimplementedClusterServiceServer

	// workers maps a worker id to its *pb.MWorkerDescriptor.
	workers sync.Map

	// numWorkersLock guards numWorkers; RegisterWorker also holds it for the
	// duration of a registration.
	numWorkersLock sync.Mutex

	// numWorkers is the number of workers registered so far.
	numWorkers int

	// opts holds the node options; opts.NumWorkers is the worker count that
	// registration is capped at and that waitForWorkers waits for.
	opts *NodeOptions
}
// createClusterServer builds a clusterServer configured with opts. The worker
// registry starts out empty and the worker count starts at zero.
func createClusterServer(opts *NodeOptions) *clusterServer {
	// The zero value of sync.Map is an empty, ready-to-use map.
	server := new(clusterServer)
	server.opts = opts
	return server
}
// RegisterWorker handles a worker's request to register with the cluster.
//
// The request is rejected with an error when a worker with the same id is
// already registered, when the configured number of workers
// (opts.NumWorkers) has been reached, when the caller's address cannot be
// read from ctx or is not a TCP address, or when the server cannot dial back
// to the worker. A rejected worker is never recorded.
//
// On success the worker's descriptor (id from the request, host taken from
// the caller's peer address, port from the request) is stored, the worker
// count is incremented, and the response carries the current Unix time in
// seconds.
func (s *clusterServer) RegisterWorker(ctx context.Context, req *pb.MRegisterRequest) (*pb.MRegisterResponse, error) {
	// Held for the whole call so the duplicate check, the capacity check and
	// the store/increment are atomic with respect to other registrations.
	s.numWorkersLock.Lock()
	defer s.numWorkersLock.Unlock()

	if _, exists := s.workers.Load(req.Id); exists {
		return nil, fmt.Errorf("worker %s is already registered", req.Id)
	}
	// >= rather than == so that a count which overshoots the limit still
	// refuses new workers.
	if s.numWorkers >= s.opts.NumWorkers {
		return nil, fmt.Errorf("maximum number of workers reached")
	}

	// Named p rather than peer so the peer package is not shadowed.
	p, ok := peer.FromContext(ctx)
	if !ok {
		return nil, fmt.Errorf("unable to fetch peer data for connecting worker %s", req.Id)
	}
	tcpAddr, ok := p.Addr.(*net.TCPAddr)
	if !ok {
		return nil, fmt.Errorf("connecting worker %s is not using TCP", req.Id)
	}

	wDescriptor := &pb.MWorkerDescriptor{
		Id:   req.Id,
		Host: tcpAddr.IP.String(),
		Port: int32(req.Port),
	}

	// Test the connection before recording the worker: an unreachable worker
	// is reported back to the caller instead of terminating this process, and
	// it does not take up a registration slot.
	conn, err := dialWorker(wDescriptor)
	if err != nil {
		return nil, fmt.Errorf("unable to connect to worker %s: %w", wDescriptor.Id, err)
	}
	defer conn.Close()

	s.workers.Store(req.Id, wDescriptor)
	s.numWorkers++

	log.Printf("Registered worker %s at %s:%d", wDescriptor.Id, wDescriptor.Host, wDescriptor.Port)
	return &pb.MRegisterResponse{Time: time.Now().Unix()}, nil
}
// NumberOfWorkers reports how many workers are currently registered. The
// count is read while holding numWorkersLock.
func (s *clusterServer) NumberOfWorkers() int {
	s.numWorkersLock.Lock()
	count := s.numWorkers
	s.numWorkersLock.Unlock()
	return count
}
// Workers returns the descriptors of all registered workers as a new slice.
// The slice is empty, not nil, when no worker has registered.
func (s *clusterServer) Workers() []*pb.MWorkerDescriptor {
	descriptors := []*pb.MWorkerDescriptor{}
	collect := func(_, value interface{}) bool {
		descriptors = append(descriptors, value.(*pb.MWorkerDescriptor))
		// Returning true keeps Range iterating over the remaining entries.
		return true
	}
	s.workers.Range(collect)
	return descriptors
}
// waitForWorkers blocks until the number of registered workers reaches
// s.opts.NumWorkers, re-checking the count once per second.
//
// It returns nil as soon as enough workers have registered, or ctx.Err() if
// ctx is cancelled or times out first.
func (s *clusterServer) waitForWorkers(ctx context.Context) error {
	// A single ticker is reused for every poll instead of allocating a new
	// time.After timer on each iteration.
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	// "<" rather than "!=" so the wait also ends if the count ever exceeds
	// the target.
	for s.NumberOfWorkers() < s.opts.NumWorkers {
		select {
		case <-ctx.Done():
			// Cancelled or timed out before all workers registered.
			return ctx.Err()
		case <-ticker.C:
			// A second has passed; check the count again.
		}
	}
	return nil
}