diff --git a/scheduler/announcer/announcer.go b/scheduler/announcer/announcer.go index 862a9945c46..6c98d75df40 100644 --- a/scheduler/announcer/announcer.go +++ b/scheduler/announcer/announcer.go @@ -20,12 +20,25 @@ package announcer import ( "context" + "fmt" + "io" + "time" + + "golang.org/x/sync/errgroup" managerv2 "d7y.io/api/pkg/apis/manager/v2" + trainerv1 "d7y.io/api/pkg/apis/trainer/v1" logger "d7y.io/dragonfly/v2/internal/dflog" managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client" + trainerclient "d7y.io/dragonfly/v2/pkg/rpc/trainer/client" "d7y.io/dragonfly/v2/scheduler/config" + "d7y.io/dragonfly/v2/scheduler/storage" +) + +const ( + // UploadBufferSize is the buffer size for upload. + UploadBufferSize = 1024 * 1024 ) // Announcer is the interface used for announce service. @@ -41,20 +54,34 @@ type Announcer interface { type announcer struct { config *config.Config managerClient managerclient.V2 + trainerClient trainerclient.V1 + storage storage.Storage done chan struct{} } +// WithTrainerClient sets the grpc client of trainer. +func WithTrainerClient(client trainerclient.V1) Option { + return func(a *announcer) { + a.trainerClient = client + } +} + // Option is a functional option for configuring the announcer. type Option func(s *announcer) // New returns a new Announcer interface. -func New(cfg *config.Config, managerClient managerclient.V2) (Announcer, error) { +func New(cfg *config.Config, managerClient managerclient.V2, storage storage.Storage, options ...Option) (Announcer, error) { a := &announcer{ config: cfg, managerClient: managerClient, + storage: storage, done: make(chan struct{}), } + for _, opt := range options { + opt(a) + } + // Register to manager. if _, err := a.managerClient.UpdateScheduler(context.Background(), &managerv2.UpdateSchedulerRequest{ SourceType: managerv2.SourceType_SCHEDULER_SOURCE, @@ -78,6 +105,13 @@ func (a *announcer) Serve() error { return err } + if a.trainerClient != nil { + logger.Info("announce scheduler to trainer") + if err := a.announceToTrainer(); err != nil { + return err + } + } + return nil } @@ -90,12 +124,139 @@ func (a *announcer) Stop() error { // announceSeedPeer announces peer information to manager. func (a *announcer) announceToManager() error { // Start keepalive to manager. - a.managerClient.KeepAlive(a.config.Manager.KeepAlive.Interval, &managerv2.KeepAliveRequest{ - SourceType: managerv2.SourceType_SCHEDULER_SOURCE, - Hostname: a.config.Server.Host, - Ip: a.config.Server.AdvertiseIP.String(), - ClusterId: uint64(a.config.Manager.SchedulerClusterID), - }, a.done) + go func() { + a.managerClient.KeepAlive(a.config.Manager.KeepAlive.Interval, &managerv2.KeepAliveRequest{ + SourceType: managerv2.SourceType_SCHEDULER_SOURCE, + Hostname: a.config.Server.Host, + Ip: a.config.Server.AdvertiseIP.String(), + ClusterId: uint64(a.config.Manager.SchedulerClusterID), + }, a.done) + }() + + return nil +} + +// announceSeedPeer announces dataset to trainer. +func (a *announcer) announceToTrainer() error { + tick := time.NewTicker(a.config.Trainer.Interval) + for { + select { + case <-tick.C: + if err := a.train(); err != nil { + logger.Error(err) + } + case <-a.done: + return nil + } + } +} + +// train uploads dataset to trainer and trigger training. +func (a *announcer) train() error { + ctx, cancel := context.WithTimeout(context.Background(), a.config.Trainer.UploadTimeout) + defer cancel() + + stream, err := a.trainerClient.Train(ctx) + if err != nil { + return err + } + + eg := errgroup.Group{} + eg.Go(func() error { + if err := a.uploadDownloadToTrainer(stream); err != nil { + return fmt.Errorf("upload download: %w", err) + } + + return nil + }) + + eg.Go(func() error { + if err := a.uploadNetworkTopologyToTrainer(stream); err != nil { + return fmt.Errorf("upload network topology: %w", err) + } + + return nil + }) + + if err := eg.Wait(); err != nil { + return err + } + + if _, err := stream.CloseAndRecv(); err != nil { + return err + } + + return nil +} + +// uploadDownloadToTrainer uploads download information to trainer. +func (a *announcer) uploadDownloadToTrainer(stream trainerv1.Trainer_TrainClient) error { + readCloser, err := a.storage.OpenDownload() + if err != nil { + return err + } + defer readCloser.Close() + + buf := make([]byte, UploadBufferSize) + for { + n, err := readCloser.Read(buf) + if err != nil && err != io.EOF { + return err + } + + if err := stream.Send(&trainerv1.TrainRequest{ + Hostname: a.config.Server.Host, + Ip: a.config.Server.AdvertiseIP.String(), + ClusterId: uint64(a.config.Manager.SchedulerClusterID), + Request: &trainerv1.TrainRequest_TrainMlpRequest{ + TrainMlpRequest: &trainerv1.TrainMLPRequest{ + Dataset: buf[:n], + }, + }, + }); err != nil { + return err + } + + if err == io.EOF { + break + } + } + + return nil +} + +// uploadNetworkTopologyToTrainer uploads network topology to trainer. +func (a *announcer) uploadNetworkTopologyToTrainer(stream trainerv1.Trainer_TrainClient) error { + readCloser, err := a.storage.OpenNetworkTopology() + if err != nil { + return err + } + defer readCloser.Close() + + buf := make([]byte, UploadBufferSize) + for { + n, err := readCloser.Read(buf) + if err != nil && err != io.EOF { + return err + } + + if err := stream.Send(&trainerv1.TrainRequest{ + Hostname: a.config.Server.Host, + Ip: a.config.Server.AdvertiseIP.String(), + ClusterId: uint64(a.config.Manager.SchedulerClusterID), + Request: &trainerv1.TrainRequest_TrainGnnRequest{ + TrainGnnRequest: &trainerv1.TrainGNNRequest{ + Dataset: buf[:n], + }, + }, + }); err != nil { + return err + } + + if err == io.EOF { + break + } + } return nil } diff --git a/scheduler/announcer/announcer_test.go b/scheduler/announcer/announcer_test.go index 6894d7ebaf5..3792052de98 100644 --- a/scheduler/announcer/announcer_test.go +++ b/scheduler/announcer/announcer_test.go @@ -24,15 +24,16 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks" + clientmocks "d7y.io/dragonfly/v2/pkg/rpc/manager/client/mocks" "d7y.io/dragonfly/v2/scheduler/config" + storagemocks "d7y.io/dragonfly/v2/scheduler/storage/mocks" ) func TestAnnouncer_New(t *testing.T) { tests := []struct { name string config *config.Config - mock func(m *mocks.MockV2MockRecorder) + mock func(m *clientmocks.MockV2MockRecorder) expect func(t *testing.T, announcer Announcer, err error) }{ { @@ -52,7 +53,7 @@ func TestAnnouncer_New(t *testing.T) { SchedulerClusterID: 1, }, }, - mock: func(m *mocks.MockV2MockRecorder) { + mock: func(m *clientmocks.MockV2MockRecorder) { m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) }, expect: func(t *testing.T, a Announcer, err error) { @@ -80,7 +81,7 @@ func TestAnnouncer_New(t *testing.T) { SchedulerClusterID: 1, }, }, - mock: func(m *mocks.MockV2MockRecorder) { + mock: func(m *clientmocks.MockV2MockRecorder) { m.UpdateScheduler(gomock.Any(), gomock.Any()).Return(nil, errors.New("foo")).Times(1) }, expect: func(t *testing.T, a Announcer, err error) { @@ -94,10 +95,11 @@ func TestAnnouncer_New(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctl := gomock.NewController(t) defer ctl.Finish() - mockManagerClient := mocks.NewMockV2(ctl) + mockManagerClient := clientmocks.NewMockV2(ctl) + mockStorage := storagemocks.NewMockStorage(ctl) tc.mock(mockManagerClient.EXPECT()) - a, err := New(tc.config, mockManagerClient) + a, err := New(tc.config, mockManagerClient, mockStorage) tc.expect(t, a, err) }) } diff --git a/scheduler/config/config.go b/scheduler/config/config.go index 7c0f48e75a5..fe15710b32a 100644 --- a/scheduler/config/config.go +++ b/scheduler/config/config.go @@ -135,7 +135,7 @@ type SchedulerConfig struct { } type GCConfig struct { - // PieceDownloadTimeout is timout of downloading piece. + // PieceDownloadTimeout is timeout of downloading piece. PieceDownloadTimeout time.Duration `yaml:"pieceDownloadTimeout" mapstructure:"pieceDownloadTimeout"` // PeerGCInterval is interval of peer gc. @@ -328,6 +328,9 @@ type TrainerConfig struct { // Interval is the interval of training. Interval time.Duration `yaml:"interval" mapstructure:"interval"` + + // UploadTimeout is the timeout of uploading dataset to trainer. + UploadTimeout time.Duration `yaml:"uploadTimeout" mapstructure:"uploadTimeout"` } // New default configuration. @@ -412,9 +415,10 @@ func New() *Config { }, }, Trainer: TrainerConfig{ - Enable: false, - Addr: DefaultTrainerAddr, - Interval: DefaultTrainerInterval, + Enable: false, + Addr: DefaultTrainerAddr, + Interval: DefaultTrainerInterval, + UploadTimeout: DefaultTrainerUploadTimeout, }, } } @@ -595,6 +599,10 @@ func (cfg *Config) Validate() error { if cfg.Trainer.Interval <= 0 { return errors.New("trainer requires parameter interval") } + + if cfg.Trainer.UploadTimeout <= 0 { + return errors.New("trainer requires parameter uploadTimeout") + } } return nil diff --git a/scheduler/config/config_test.go b/scheduler/config/config_test.go index b9e1653046b..ecfe93fb74d 100644 --- a/scheduler/config/config_test.go +++ b/scheduler/config/config_test.go @@ -170,9 +170,10 @@ func TestConfig_Load(t *testing.T) { }, }, Trainer: TrainerConfig{ - Enable: false, - Addr: "127.0.0.1:9000", - Interval: 10 * time.Minute, + Enable: false, + Addr: "127.0.0.1:9000", + Interval: 10 * time.Minute, + UploadTimeout: 2 * time.Hour, }, } @@ -780,6 +781,21 @@ func TestConfig_Validate(t *testing.T) { assert.EqualError(err, "trainer requires parameter interval") }, }, + { + name: "trainer requires parameter interval", + config: New(), + mock: func(cfg *Config) { + cfg.Manager = mockManagerConfig + cfg.Database.Redis = mockRedisConfig + cfg.Job = mockJobConfig + cfg.Trainer.Enable = true + cfg.Trainer.UploadTimeout = 0 + }, + expect: func(t *testing.T, err error) { + assert := assert.New(t) + assert.EqualError(err, "trainer requires parameter uploadTimeout") + }, + }, } for _, tc := range tests { diff --git a/scheduler/config/constants.go b/scheduler/config/constants.go index 3e62bdfb119..9ee8813d679 100644 --- a/scheduler/config/constants.go +++ b/scheduler/config/constants.go @@ -185,4 +185,7 @@ const ( // DefaultTrainerInterval is the default interval of training. DefaultTrainerInterval = 7 * 24 * time.Hour + + // DefaultTrainerUploadTimeout is the default timeout of uploading dataset to trainer. + DefaultTrainerUploadTimeout = 1 * time.Hour ) diff --git a/scheduler/config/testdata/scheduler.yaml b/scheduler/config/testdata/scheduler.yaml index eb2d660a245..0fa376d17fd 100644 --- a/scheduler/config/testdata/scheduler.yaml +++ b/scheduler/config/testdata/scheduler.yaml @@ -94,3 +94,4 @@ trainer: enable: false addr: "127.0.0.1:9000" interval: 10m + uploadTimeout: 2h diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index 69f85b0cf7d..520741c4e3c 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -44,6 +44,7 @@ import ( "d7y.io/dragonfly/v2/pkg/rpc" managerclient "d7y.io/dragonfly/v2/pkg/rpc/manager/client" securityclient "d7y.io/dragonfly/v2/pkg/rpc/security/client" + trainerclient "d7y.io/dragonfly/v2/pkg/rpc/trainer/client" "d7y.io/dragonfly/v2/pkg/types" "d7y.io/dragonfly/v2/scheduler/announcer" "d7y.io/dragonfly/v2/scheduler/config" @@ -78,6 +79,9 @@ type Server struct { // Security client. securityClient securityclient.V1 + // Trainer client. + trainerClient trainerclient.V1 + // Resource interface. resource resource.Resource @@ -115,7 +119,19 @@ func New(ctx context.Context, cfg *config.Config, d dfpath.Dfpath) (*Server, err return nil, err } - // Initialize manager client and dial options of manager grpc client. + // Initialize Storage. + storage, err := storage.New( + d.DataDir(), + cfg.Storage.MaxSize, + cfg.Storage.MaxBackups, + cfg.Storage.BufferSize, + ) + if err != nil { + return nil, err + } + s.storage = storage + + // Initialize dial options of manager grpc client. managerDialOptions := []grpc.DialOption{} if cfg.Security.AutoIssueCert { clientTransportCredentials, err := rpc.NewClientCredentials(cfg.Security.TLSPolicy, nil, []byte(cfg.Security.CACert)) @@ -135,8 +151,36 @@ func New(ctx context.Context, cfg *config.Config, d dfpath.Dfpath) (*Server, err } s.managerClient = managerClient + // Initialize dial options of trainer grpc client. + if cfg.Trainer.Enable { + trainerDialOptions := []grpc.DialOption{} + if cfg.Security.AutoIssueCert { + clientTransportCredentials, err := rpc.NewClientCredentials(cfg.Security.TLSPolicy, nil, []byte(cfg.Security.CACert)) + if err != nil { + return nil, err + } + + trainerDialOptions = append(trainerDialOptions, grpc.WithTransportCredentials(clientTransportCredentials)) + } else { + trainerDialOptions = append(trainerDialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + // Initialize trainer client. + trainerClient, err := trainerclient.GetV1ByAddr(ctx, cfg.Trainer.Addr, trainerDialOptions...) + if err != nil { + return nil, err + } + s.trainerClient = trainerClient + } + + // Initialize dial options of announcer. + announcerOptions := []announcer.Option{} + if s.trainerClient != nil { + announcerOptions = append(announcerOptions, announcer.WithTrainerClient(s.trainerClient)) + } + // Initialize announcer. - announcer, err := announcer.New(cfg, s.managerClient) + announcer, err := announcer.New(cfg, s.managerClient, storage, announcerOptions...) if err != nil { return nil, err } @@ -203,19 +247,7 @@ func New(ctx context.Context, cfg *config.Config, d dfpath.Dfpath) (*Server, err // Initialize scheduling. scheduling := scheduling.New(&cfg.Scheduler, dynconfig, d.PluginDir()) - // Initialize Storage. - storage, err := storage.New( - d.DataDir(), - cfg.Storage.MaxSize, - cfg.Storage.MaxBackups, - cfg.Storage.BufferSize, - ) - if err != nil { - return nil, err - } - s.storage = storage - - // Initialize grpc service and server options of scheduler grpc server. + // Initialize server options of scheduler grpc server. schedulerServerOptions := []grpc.ServerOption{} if certifyClient != nil { serverTransportCredentials, err := rpc.NewServerCredentialsByCertify(cfg.Security.TLSPolicy, cfg.Security.TLSVerify, []byte(cfg.Security.CACert), certifyClient) @@ -375,6 +407,15 @@ func (s *Server) Stop() { } } + // Stop trainer client. + if s.trainerClient != nil { + if err := s.trainerClient.Close(); err != nil { + logger.Errorf("trainer client failed to stop: %s", err.Error()) + } else { + logger.Info("trainer client closed") + } + } + // Stop security client. if s.securityClient != nil { if err := s.securityClient.Close(); err != nil {