Skip to content

Commit

Permalink
Training persistence templated algorithm of reconciliation DB and kube…
Browse files Browse the repository at this point in the history
…rnetes (#268)
  • Loading branch information
vlad-tokarev committed Jul 29, 2020
1 parent f4cf3ec commit 91872f3
Showing 1 changed file with 51 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"errors"
"fmt"
odahuflowv1alpha1 "github.com/odahu/odahu-flow/packages/operator/api/v1alpha1"
"github.com/odahu/odahu-flow/packages/operator/pkg/apis/training"
odahu_errs "github.com/odahu/odahu-flow/packages/operator/pkg/errors"
mt_repository "github.com/odahu/odahu-flow/packages/operator/pkg/repository/training"
ctrl "sigs.k8s.io/controller-runtime"
Expand All @@ -19,17 +17,35 @@ const (
)

var (
//log = ctrl.Log.WithName("model-training-runner")
log = logf.Log.WithName("runner-manager")
log = logf.Log.WithName("model-training-runner")
)

// OdahuEntity is an entity that can be created on ODAHU Kubernetes.
//
// The Runner in this file only talks to entities through this interface:
// Launch calls CreateInKube, and Reconcile calls UpdateInKube and, when the
// update reports a not-found error, DeleteInKube.
type OdahuEntity interface {
// GetID returns the identifier of the entity. Launch compares it with the
// excluded (penalized) IDs and uses it in log messages.
GetID() string

// CreateInKube launches the entity. Launch treats an error recognised by
// odahu_errs.IsAlreadyExistError as "already launched" rather than a failure.
CreateInKube() error
// UpdateInKube is called by Reconcile after an update notification. An
// error recognised by odahu_errs.IsNotFoundError makes Reconcile call
// DeleteInKube; any other error is returned to the caller.
// NOTE(review): Reconcile's log messages describe this step as persisting
// state in the DB — confirm the intended direction of the update.
UpdateInKube() error
// DeleteInKube is called by Reconcile to stop an entity whose update
// reported a not-found error.
DeleteInKube() error
}

// KubernetesOdahuService is the backend a Runner uses to attach itself to a
// kube manager and to look up the OdahuEntity values it launches and reconciles.
type KubernetesOdahuService interface {
// AttachKubeMgr attaches the service to the given controller manager.
// Run calls it with the Runner's manager before starting the launch ticker
// and returns the error if attaching fails.
AttachKubeMgr(ctrl.Manager) error
// ListNotFinished returns the entities that are not finished on the
// Kubernetes operator. Launch uses the result as the set of launch candidates.
ListNotFinished() ([]OdahuEntity, error)
// Get returns the OdahuEntity with the given ID from storage.
// Reconcile calls it with the name carried by the reconcile request.
Get(ID string) (OdahuEntity, error)
}

// Runner periodically launches not-finished entities (Launch) and reacts to
// per-entity update notifications (Reconcile), delegating entity lookup and
// controller wiring to odahuService.
type Runner struct {
// NOTE(review): storage and service are not referenced by the methods
// visible in this view, which now go through odahuService — confirm they
// are still needed.
storage mt_repository.Storage
service mt_repository.Service
mgr ctrl.Manager // manager passed to odahuService.AttachKubeMgr in Run
name string // returned by String
mu sync.Mutex // NOTE(review): what this guards is not visible in this view
launchPeriod time.Duration // period of the ticker created in Run
odahuService KubernetesOdahuService // source of entities for Launch and Reconcile
}

func NewRunner(
Expand All @@ -50,36 +66,34 @@ func NewRunner(

// Reconcile handles one update notification for the entity named by request.Name.
//
// It fetches the entity through odahuService.Get and calls UpdateInKube on it.
// If the update reports a not-found error (odahu_errs.IsNotFoundError), the
// entity is treated as orphaned and DeleteInKube is called to stop it.
//
// A non-nil error is returned when the lookup fails, when the update fails with
// anything other than not-found, or when the clean-up delete fails. In the last
// case the delete error itself is logged and returned, so the caller sees the
// actual cause instead of the earlier not-found error. The returned ctrl.Result
// is always the zero value.
func (r *Runner) Reconcile(request ctrl.Request) (ctrl.Result, error) {
	id := request.Name
	itemLog := log.WithValues("ID", id)
	itemLog.Info("New update is received")

	item, err := r.odahuService.Get(id)
	if err != nil {
		itemLog.Error(err, "Unable to fetch state from kubernetes")
		return ctrl.Result{}, err
	}

	itemLog.Info("Trying to persist new state")
	err = item.UpdateInKube()

	switch {
	case odahu_errs.IsNotFoundError(err):
		itemLog.Info("Kubernetes sent update but there is no item in DB with this ID. Trying to stop")
		// Report the delete failure itself, not the not-found error that led here.
		if delErr := item.DeleteInKube(); delErr != nil {
			itemLog.Error(delErr, "Unable to stop on kubernetes")
			return ctrl.Result{}, delErr
		}
	case err != nil:
		itemLog.Error(err, "Unable to persist state in DB")
		return ctrl.Result{}, err
	}

	return ctrl.Result{}, nil
}

// String returns the name of the runner.
func (r *Runner) String() string {
return r.name
}
Expand All @@ -99,44 +113,45 @@ func (r *Runner) Launch(_ context.Context, excluded []string) ([]string, error)
// we will penalize them to increase delay for such cases

// Fetch all not finished trainings
notFinished, err := r.storage.GetModelTrainingList()

// Remove trainings that are probably already running (penalized trainings)
notFinished, err := r.odahuService.ListNotFinished()
if err != nil {
log.Error(err, "Error while fetch training list from DB")
return excluded, err
}

// Remove trainings that are probably already running (penalized trainings)
trains := make([]training.ModelTraining, 0)
items := make([]OdahuEntity, 0)
for _, nft := range notFinished {
elemExcluded := false
for _, et := range excluded {
if nft.ID == et {
if nft.GetID() == et {
elemExcluded = true
}
}
if !elemExcluded {
trains = append(trains, nft)
items = append(items, nft)
}
}

if (len(trains)) > 0 {
log.Info(fmt.Sprintf("%v not launched trains", len(trains)))
if (len(items)) > 0 {
log.Info(fmt.Sprintf("%v not launched items", len(items)))
}

// Launch trainings in parallel.
newPenalties := make(chan string, len(trains))
// We need to drop the entity status during PUT (Update)
newPenalties := make(chan string, len(items))
errCh := make(chan error)
for _, mt := range trains {
for _, mt := range items {
mt := mt
go func() {
err := r.service.CreateModelTraining(&mt)
err := mt.CreateInKube()
if err != nil {
if odahu_errs.IsAlreadyExistError(err) {
log.Info(fmt.Sprintf("%s training is already launched", mt.ID))
newPenalties <- mt.ID
log.Info(fmt.Sprintf("%s item is already launched", mt.GetID()))
newPenalties <- mt.GetID()
errCh <- nil
} else {
log.Error(err, fmt.Sprintf("Error while launch training %s", mt.ID))
log.Error(err, fmt.Sprintf("Error while launch item %s", mt.GetID()))
errCh <- err
}
} else {
Expand All @@ -147,7 +162,7 @@ func (r *Runner) Launch(_ context.Context, excluded []string) ([]string, error)

// Gather errors or nils
var resErr error
for i := 0; i < len(trains); i ++ {
for i := 0; i < len(items); i ++ {
cerr := <-errCh
if cerr != nil {
resErr = errors.New("one or more errors occurred while launch new trainings")
Expand All @@ -162,7 +177,7 @@ func (r *Runner) Launch(_ context.Context, excluded []string) ([]string, error)
}

if resErr != nil {
return excluded, errors.New("one or more errors occurred while launch new trainings")
return excluded, errors.New("one or more errors occurred while launch new items")
}
return excluded, nil
}
Expand All @@ -173,14 +188,12 @@ func (r *Runner) Run(ctx context.Context) (err error) {

// Attach controller to manager
// This controller gets updates from the K8S resource and updates the state in the DB
err = ctrl.NewControllerManagedBy(r.mgr).
For(&odahuflowv1alpha1.ModelTraining{}).
Complete(r)
err = r.odahuService.AttachKubeMgr(r.mgr)
if err != nil{
log.Error(err, "unable to initialize controller")
return err
}
log.Info("Persistence controller for model training attached to kube manager")
log.Info("Persistence controller for attached to kube manager")

// We launch new or hung trainings once every `launchPeriod` (a time.Duration)
t := time.NewTicker(r.launchPeriod)
Expand Down

0 comments on commit 91872f3

Please sign in to comment.