Skip to content

Commit

Permalink
feat: add announceToTrainer in scheduler (#2371)
Browse files Browse the repository at this point in the history
Signed-off-by: Gaius <gaius.qi@gmail.com>
  • Loading branch information
gaius-qi committed May 22, 2023
1 parent 0cddb97 commit 5f457ca
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 35 deletions.
175 changes: 168 additions & 7 deletions scheduler/announcer/announcer.go
Expand Up @@ -20,12 +20,25 @@ package announcer

import (
"context"
"fmt"
"io"
"time"

"golang.org/x/sync/errgroup"

managerv2 "d7y.io/api/pkg/apis/manager/v2"
trainerv1 "d7y.io/api/pkg/apis/trainer/v1"

logger "d7y.io/dragonfly/v2/internal/dflog"
managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client"
trainerclient "d7y.io/dragonfly/v2/pkg/rpc/trainer/client"
"d7y.io/dragonfly/v2/scheduler/config"
"d7y.io/dragonfly/v2/scheduler/storage"
)

const (
// UploadBufferSize is the buffer size for upload.
UploadBufferSize = 1024 * 1024
)

// Announcer is the interface used for announce service.
Expand All @@ -41,20 +54,34 @@ type Announcer interface {
type announcer struct {
config *config.Config
managerClient managerclient.V2
trainerClient trainerclient.V1
storage storage.Storage
done chan struct{}
}

// WithTrainerClient sets the grpc client of trainer.
func WithTrainerClient(client trainerclient.V1) Option {
return func(a *announcer) {
a.trainerClient = client
}
}

// Option is a functional option for configuring the announcer.
type Option func(s *announcer)

// New returns a new Announcer interface.
func New(cfg *config.Config, managerClient managerclient.V2) (Announcer, error) {
func New(cfg *config.Config, managerClient managerclient.V2, storage storage.Storage, options ...Option) (Announcer, error) {
a := &announcer{
config: cfg,
managerClient: managerClient,
storage: storage,
done: make(chan struct{}),
}

for _, opt := range options {
opt(a)
}

// Register to manager.
if _, err := a.managerClient.UpdateScheduler(context.Background(), &managerv2.UpdateSchedulerRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Expand All @@ -78,6 +105,13 @@ func (a *announcer) Serve() error {
return err
}

if a.trainerClient != nil {
logger.Info("announce scheduler to trainer")
if err := a.announceToTrainer(); err != nil {
return err
}
}

return nil
}

Expand All @@ -90,12 +124,139 @@ func (a *announcer) Stop() error {
// announceSeedPeer announces peer information to manager.
func (a *announcer) announceToManager() error {
// Start keepalive to manager.
a.managerClient.KeepAlive(a.config.Manager.KeepAlive.Interval, &managerv2.KeepAliveRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
ClusterId: uint64(a.config.Manager.SchedulerClusterID),
}, a.done)
go func() {
a.managerClient.KeepAlive(a.config.Manager.KeepAlive.Interval, &managerv2.KeepAliveRequest{
SourceType: managerv2.SourceType_SCHEDULER_SOURCE,
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
ClusterId: uint64(a.config.Manager.SchedulerClusterID),
}, a.done)
}()

return nil
}

// announceSeedPeer announces dataset to trainer.
func (a *announcer) announceToTrainer() error {
tick := time.NewTicker(a.config.Trainer.Interval)
for {
select {
case <-tick.C:
if err := a.train(); err != nil {
logger.Error(err)
}
case <-a.done:
return nil
}
}
}

// train uploads dataset to trainer and trigger training.
func (a *announcer) train() error {
ctx, cancel := context.WithTimeout(context.Background(), a.config.Trainer.UploadTimeout)
defer cancel()

stream, err := a.trainerClient.Train(ctx)
if err != nil {
return err
}

eg := errgroup.Group{}
eg.Go(func() error {
if err := a.uploadDownloadToTrainer(stream); err != nil {
return fmt.Errorf("upload download: %w", err)
}

return nil
})

eg.Go(func() error {
if err := a.uploadNetworkTopologyToTrainer(stream); err != nil {
return fmt.Errorf("upload network topology: %w", err)
}

return nil
})

if err := eg.Wait(); err != nil {
return err
}

if _, err := stream.CloseAndRecv(); err != nil {
return err
}

return nil
}

// uploadDownloadToTrainer uploads download information to trainer.
func (a *announcer) uploadDownloadToTrainer(stream trainerv1.Trainer_TrainClient) error {
readCloser, err := a.storage.OpenDownload()
if err != nil {
return err
}
defer readCloser.Close()

buf := make([]byte, UploadBufferSize)
for {
n, err := readCloser.Read(buf)
if err != nil && err != io.EOF {
return err
}

if err := stream.Send(&trainerv1.TrainRequest{
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
ClusterId: uint64(a.config.Manager.SchedulerClusterID),
Request: &trainerv1.TrainRequest_TrainMlpRequest{
TrainMlpRequest: &trainerv1.TrainMLPRequest{
Dataset: buf[:n],
},
},
}); err != nil {
return err
}

if err == io.EOF {
break
}
}

return nil
}

// uploadNetworkTopologyToTrainer uploads network topology to trainer.
func (a *announcer) uploadNetworkTopologyToTrainer(stream trainerv1.Trainer_TrainClient) error {
readCloser, err := a.storage.OpenNetworkTopology()
if err != nil {
return err
}
defer readCloser.Close()

buf := make([]byte, UploadBufferSize)
for {
n, err := readCloser.Read(buf)
if err != nil && err != io.EOF {
return err
}

if err := stream.Send(&trainerv1.TrainRequest{
Hostname: a.config.Server.Host,
Ip: a.config.Server.AdvertiseIP.String(),
ClusterId: uint64(a.config.Manager.SchedulerClusterID),
Request: &trainerv1.TrainRequest_TrainGnnRequest{
TrainGnnRequest: &trainerv1.TrainGNNRequest{
Dataset: buf[:n],
},
},
}); err != nil {
return err
}

if err == io.EOF {
break
}
}

return nil
}
14 changes: 8 additions & 6 deletions scheduler/announcer/announcer_test.go
Expand Up @@ -24,15 +24,16 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"

"d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks"
clientmocks "d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks"
"d7y.io/dragonfly/v2/scheduler/config"
storagemocks "d7y.io/dragonfly/v2/scheduler/storage/mocks"
)

func TestAnnouncer_New(t *testing.T) {
tests := []struct {
name string
config *config.Config
mock func(m *mocks.MockV2MockRecorder)
mock func(m *clientmocks.MockV2MockRecorder)
expect func(t *testing.T, announcer Announcer, err error)
}{
{
Expand All @@ -52,7 +53,7 @@ func TestAnnouncer_New(t *testing.T) {
SchedulerClusterID: 1,
},
},
mock: func(m *mocks.MockV2MockRecorder) {
mock: func(m *clientmocks.MockV2MockRecorder) {
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
},
expect: func(t *testing.T, a Announcer, err error) {
Expand Down Expand Up @@ -80,7 +81,7 @@ func TestAnnouncer_New(t *testing.T) {
SchedulerClusterID: 1,
},
},
mock: func(m *mocks.MockV2MockRecorder) {
mock: func(m *clientmocks.MockV2MockRecorder) {
m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")).Times(1)
},
expect: func(t *testing.T, a Announcer, err error) {
Expand All @@ -94,10 +95,11 @@ func TestAnnouncer_New(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
ctl := gomock.NewController(t)
defer ctl.Finish()
mockManagerClient := mocks.NewMockV2(ctl)
mockManagerClient := clientmocks.NewMockV2(ctl)
mockStorage := storagemocks.NewMockStorage(ctl)
tc.mock(mockManagerClient.EXPECT())

a, err := New(tc.config, mockManagerClient)
a, err := New(tc.config, mockManagerClient, mockStorage)
tc.expect(t, a, err)
})
}
Expand Down
16 changes: 12 additions & 4 deletions scheduler/config/config.go
Expand Up @@ -135,7 +135,7 @@ type SchedulerConfig struct {
}

type GCConfig struct {
// PieceDownloadTimeout is timout of downloading piece.
// PieceDownloadTimeout is timeout of downloading piece.
PieceDownloadTimeout time.Duration `yaml:"pieceDownloadTimeout" mapstructure:"pieceDownloadTimeout"`

// PeerGCInterval is interval of peer gc.
Expand Down Expand Up @@ -328,6 +328,9 @@ type TrainerConfig struct {

// Interval is the interval of training.
Interval time.Duration `yaml:"interval" mapstructure:"interval"`

// UploadTimeout is the timeout of uploading dataset to trainer.
UploadTimeout time.Duration `yaml:"uploadTimeout" mapstructure:"uploadTimeout"`
}

// New default configuration.
Expand Down Expand Up @@ -412,9 +415,10 @@ func New() *Config {
},
},
Trainer: TrainerConfig{
Enable: false,
Addr: DefaultTrainerAddr,
Interval: DefaultTrainerInterval,
Enable: false,
Addr: DefaultTrainerAddr,
Interval: DefaultTrainerInterval,
UploadTimeout: DefaultTrainerUploadTimeout,
},
}
}
Expand Down Expand Up @@ -595,6 +599,10 @@ func (cfg *Config) Validate() error {
if cfg.Trainer.Interval <= 0 {
return errors.New("trainer requires parameter interval")
}

if cfg.Trainer.UploadTimeout <= 0 {
return errors.New("trainer requires parameter uploadTimeout")
}
}

return nil
Expand Down
22 changes: 19 additions & 3 deletions scheduler/config/config_test.go
Expand Up @@ -170,9 +170,10 @@ func TestConfig_Load(t *testing.T) {
},
},
Trainer: TrainerConfig{
Enable: false,
Addr: "127.0.0.1:9000",
Interval: 10 * time.Minute,
Enable: false,
Addr: "127.0.0.1:9000",
Interval: 10 * time.Minute,
UploadTimeout: 2 * time.Hour,
},
}

Expand Down Expand Up @@ -780,6 +781,21 @@ func TestConfig_Validate(t *testing.T) {
assert.EqualError(err, "trainer requires parameter interval")
},
},
{
name: "trainer requires parameter interval",
config: New(),
mock: func(cfg *Config) {
cfg.Manager = mockManagerConfig
cfg.Database.Redis = mockRedisConfig
cfg.Job = mockJobConfig
cfg.Trainer.Enable = true
cfg.Trainer.UploadTimeout = 0
},
expect: func(t *testing.T, err error) {
assert := assert.New(t)
assert.EqualError(err, "trainer requires parameter uploadTimeout")
},
},
}

for _, tc := range tests {
Expand Down
3 changes: 3 additions & 0 deletions scheduler/config/constants.go
Expand Up @@ -185,4 +185,7 @@ const (

// DefaultTrainerInterval is the default interval of training.
DefaultTrainerInterval = 7 * 24 * time.Hour

// DefaultTrainerUploadTimeout is the default timeout of uploading dataset to trainer.
DefaultTrainerUploadTimeout = 1 * time.Hour
)
1 change: 1 addition & 0 deletions scheduler/config/testdata/scheduler.yaml
Expand Up @@ -94,3 +94,4 @@ trainer:
enable: false
addr: "127.0.0.1:9000"
interval: 10m
uploadTimeout: 2h

0 comments on commit 5f457ca

Please sign in to comment.