Skip to content

Commit

Permalink
feat: add announceToTrainer in scheduler
Browse files Browse the repository at this point in the history
Signed-off-by: Gaius <gaius.qi@gmail.com>
  • Loading branch information
gaius-qi committed May 22, 2023
1 parent 0cddb97 commit ffbde0d
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 74 deletions.
197 changes: 179 additions & 18 deletions scheduler/announcer/announcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,25 @@ package announcer

import (
"context"
"fmt"
"io"
"time"

"golang.org/x/sync/errgroup"

managerv2 "d7y.io/api/pkg/apis/manager/v2"
trainerv1 "d7y.io/api/pkg/apis/trainer/v1"

logger "d7y.io/dragonfly/v2/internal/dflog"
managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client"
trainerclient "d7y.io/dragonfly/v2/pkg/rpc/trainer/client"
"d7y.io/dragonfly/v2/scheduler/config"
"d7y.io/dragonfly/v2/scheduler/storage"
)

const (
// UploadBufferSize is the buffer size for upload.
UploadBufferSize = 1024 * 1024
)

// Announcer is the interface used for announce service.
Expand All @@ -41,31 +54,32 @@ type Announcer interface {
type announcer struct {
config *config.Config
managerClient managerclient.V2
trainerClient trainerclient.V1
storage storage.Storage
done chan struct{}
}

// WithTrainerClient sets the grpc client of trainer.
func WithTrainerClient(client trainerclient.V1) Option {
return func(a *announcer) {
a.trainerClient = client
}
}

// Option is a functional option for configuring the announcer.
type Option func(s *announcer)

// New returns a new Announcer interface.
func New(cfg *config.Config, managerClient managerclient.V2) (Announcer, error) {
func New(cfg *config.Config, managerClient managerclient.V2, storage storage.Storage, options ...Option) (Announcer, error) {
a := &announcer{
config: cfg,
managerClient: managerClient,
storage: storage,
done: make(chan struct{}),
}

// Register to manager.
if _, err := a.managerClient.UpdateScheduler(context.Background(), &managerv2.UpdateSchedulerRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
Port: int32(a.config.Server.AdvertisePort),
Idc: a.config.Host.IDC,
Location: a.config.Host.Location,
SchedulerClusterId: uint64(a.config.Manager.SchedulerClusterID),
}); err != nil {
return nil, err
for _, opt := range options {
opt(a)
}

return a, nil
Expand All @@ -78,6 +92,13 @@ func (a *announcer) Serve() error {
return err
}

if a.trainerClient != nil {
logger.Info("announce scheduler to trainer")
if err := a.announceToTrainer(); err != nil {
return err
}
}

return nil
}

Expand All @@ -89,13 +110,153 @@ func (a *announcer) Stop() error {

// announceSeedPeer announces peer information to manager.
func (a *announcer) announceToManager() error {
// Register to manager.
if _, err := a.managerClient.UpdateScheduler(context.Background(), &managerv2.UpdateSchedulerRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
Port: int32(a.config.Server.AdvertisePort),
Idc: a.config.Host.IDC,
Location: a.config.Host.Location,
SchedulerClusterId: uint64(a.config.Manager.SchedulerClusterID),
}); err != nil {
return err
}

// Start keepalive to manager.
a.managerClient.KeepAlive(a.config.Manager.KeepAlive.Interval, &managerv2.KeepAliveRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
ClusterId: uint64(a.config.Manager.SchedulerClusterID),
}, a.done)
go func() {
a.managerClient.KeepAlive(a.config.Manager.KeepAlive.Interval, &managerv2.KeepAliveRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
ClusterId: uint64(a.config.Manager.SchedulerClusterID),
}, a.done)
}()

return nil
}

// announceSeedPeer announces dataset to trainer.
func (a *announcer) announceToTrainer() error {
tick := time.NewTicker(a.config.Trainer.Interval)
for {
select {
case <-tick.C:
if err := a.train(); err != nil {
logger.Error(err)
}
case <-a.done:
return nil
}
}
}

// train uploads dataset to trainer and trigger training.
func (a *announcer) train() error {
ctx, cancel := context.WithTimeout(context.Background(), a.config.Trainer.UploadTimeout)
defer cancel()

stream, err := a.trainerClient.Train(ctx)
if err != nil {
return err
}

eg := errgroup.Group{}
eg.Go(func() error {
if err := a.uploadDownloadToTrainer(stream); err != nil {
return fmt.Errorf("upload download: %w", err)
}

return nil
})

eg.Go(func() error {
if err := a.uploadNetworkTopologyToTrainer(stream); err != nil {
return fmt.Errorf("upload network topology: %w", err)
}

return nil
})

if err := eg.Wait(); err != nil {
return err
}

if _, err := stream.CloseAndRecv(); err != nil {
return err
}

return nil
}

// uploadDownloadToTrainer uploads download information to trainer.
func (a *announcer) uploadDownloadToTrainer(stream trainerv1.Trainer_TrainClient) error {
readCloser, err := a.storage.OpenDownload()
if err != nil {
return err
}
defer readCloser.Close()

buf := make([]byte, UploadBufferSize)
for {
n, err := readCloser.Read(buf)
if err != nil && err != io.EOF {
return err
}

if err := stream.Send(&trainerv1.TrainRequest{
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
ClusterId: uint64(a.config.Manager.SchedulerClusterID),
Request: &trainerv1.TrainRequest_TrainMlpRequest{
TrainMlpRequest: &trainerv1.TrainMLPRequest{
Dataset: buf[:n],
},
},
}); err != nil {
return err
}

if err == io.EOF {
break
}
}

return nil
}

// uploadNetworkTopologyToTrainer uploads network topology to trainer.
func (a *announcer) uploadNetworkTopologyToTrainer(stream trainerv1.Trainer_TrainClient) error {
readCloser, err := a.storage.OpenNetworkTopology()
if err != nil {
return err
}
defer readCloser.Close()

buf := make([]byte, UploadBufferSize)
for {
n, err := readCloser.Read(buf)
if err != nil && err != io.EOF {
return err
}

if err := stream.Send(&trainerv1.TrainRequest{
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
ClusterId: uint64(a.config.Manager.SchedulerClusterID),
Request: &trainerv1.TrainRequest_TrainGnnRequest{
TrainGnnRequest: &trainerv1.TrainGNNRequest{
Dataset: buf[:n],
},
},
}); err != nil {
return err
}

if err == io.EOF {
break
}
}

return nil
}
39 changes: 5 additions & 34 deletions scheduler/announcer/announcer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,21 @@
package announcer

import (
"errors"
"net"
"testing"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"

"d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks"
clientmocks "d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks"
"d7y.io/dragonfly/v2/scheduler/config"
storagemocks "d7y.io/dragonfly/v2/scheduler/storage/mocks"
)

func TestAnnouncer_New(t *testing.T) {
tests := []struct {
name string
config *config.Config
mock func(m *mocks.MockV2MockRecorder)
expect func(t *testing.T, announcer Announcer, err error)
}{
{
Expand All @@ -52,9 +51,6 @@ func TestAnnouncer_New(t *testing.T) {
SchedulerClusterID: 1,
},
},
mock: func(m *mocks.MockV2MockRecorder) {
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
},
expect: func(t *testing.T, a Announcer, err error) {
assert := assert.New(t)
instance := a.(*announcer)
Expand All @@ -63,41 +59,16 @@ func TestAnnouncer_New(t *testing.T) {
assert.NotNil(instance.managerClient)
},
},
{
name: "update scheduler failed",
config: &config.Config{
Server: config.ServerConfig{
Host: "localhost",
AdvertiseIP: net.ParseIP("127.0.0.1"),
AdvertisePort: 8004,
Port: 8080,
},
Host: config.HostConfig{
IDC: "foo",
Location: "bar",
},
Manager: config.ManagerConfig{
SchedulerClusterID: 1,
},
},
mock: func(m *mocks.MockV2MockRecorder) {
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")).Times(1)
},
expect: func(t *testing.T, a Announcer, err error) {
assert := assert.New(t)
assert.Error(err)
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockManagerClient := mocks.NewMockV2(ctl)
tc.mock(mockManagerClient.EXPECT())
mockManagerClient := clientmocks.NewMockV2(ctl)
mockStorage := storagemocks.NewMockStorage(ctl)

a, err := New(tc.config, mockManagerClient)
a, err := New(tc.config, mockManagerClient, mockStorage)
tc.expect(t, a, err)
})
}
Expand Down
16 changes: 12 additions & 4 deletions scheduler/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ type SchedulerConfig struct {
}

type GCConfig struct {
// PieceDownloadTimeout is timout of downloading piece.
// PieceDownloadTimeout is timeout of downloading piece.
PieceDownloadTimeout time.Duration `yaml:"pieceDownloadTimeout" mapstructure:"pieceDownloadTimeout"`

// PeerGCInterval is interval of peer gc.
Expand Down Expand Up @@ -328,6 +328,9 @@ type TrainerConfig struct {

// Interval is the interval of training.
Interval time.Duration `yaml:"interval" mapstructure:"interval"`

// UploadTimeout is the timeout of uploading dataset to trainer.
UploadTimeout time.Duration `yaml:"uploadTimeout" mapstructure:"uploadTimeout"`
}

// New default configuration.
Expand Down Expand Up @@ -412,9 +415,10 @@ func New() *Config {
},
},
Trainer: TrainerConfig{
Enable: false,
Addr: DefaultTrainerAddr,
Interval: DefaultTrainerInterval,
Enable: false,
Addr: DefaultTrainerAddr,
Interval: DefaultTrainerInterval,
UploadTimeout: DefaultTrainerUploadTimeout,
},
}
}
Expand Down Expand Up @@ -595,6 +599,10 @@ func (cfg *Config) Validate() error {
if cfg.Trainer.Interval <= 0 {
return errors.New("trainer requires parameter interval")
}

if cfg.Trainer.UploadTimeout <= 0 {
return errors.New("trainer requires parameter uploadTimeout")
}
}

return nil
Expand Down

0 comments on commit ffbde0d

Please sign in to comment.