Skip to content

Commit

Permalink
Merge pull request #29 from c-bata/handle-signals
Browse files Browse the repository at this point in the history
Provide GetContext from Trial object for handling signals.
  • Loading branch information
c-bata committed Aug 14, 2019
2 parents 45a64b6 + 42cab38 commit d071280
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 2 deletions.
19 changes: 19 additions & 0 deletions _examples/signalhandling/check.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/sh

export GO111MODULE=on
DIR=$(cd $(dirname $0); pwd)
REPOSITORY_ROOT=$(cd $(dirname $(dirname $(dirname $0))); pwd)

rm db.sqlite3

gtimeout 6 go run ${DIR}/main.go sqlite3 db.sqlite3 # brew install coreutils

echo ""
echo "*** check trials ***"
echo ""

sqlite3 db.sqlite3 <<END_SQL
.header on
.mode column
select * from trials;
END_SQL
93 changes: 93 additions & 0 deletions _examples/signalhandling/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package main

import (
"context"
"fmt"
"math"
"os"
"os/exec"
"os/signal"
"sync"
"syscall"

"github.com/c-bata/goptuna"
"github.com/c-bata/goptuna/rdb"
"github.com/jinzhu/gorm"
"go.uber.org/zap"

_ "github.com/jinzhu/gorm/dialects/sqlite"
)

func objective(trial goptuna.Trial) (float64, error) {
ctx := trial.GetContext()

x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)

cmd := exec.CommandContext(ctx, "sleep", "1")
err := cmd.Run()
if err != nil {
return -1, err
}
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

func main() {
logger, err := zap.NewDevelopment()
if err != nil {
os.Exit(1)
}
defer logger.Sync()

db, err := gorm.Open("sqlite3", "db.sqlite3")
if err != nil {
logger.Fatal("failed to open db", zap.Error(err))
}
defer db.Close()
rdb.RunAutoMigrate(db)

// create a study
study, err := goptuna.CreateStudy(
"goptuna-example",
goptuna.StudyOptionStorage(rdb.NewStorage(db)),
goptuna.StudyOptionSetDirection(goptuna.StudyDirectionMinimize),
goptuna.StudyOptionSetLogger(logger),
)
if err != nil {
logger.Fatal("failed to create study", zap.Error(err))
}

// create a context with cancel function
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
study.WithContext(ctx)

// set signal handler
signalch := make(chan os.Signal, 1)
defer close(signalch)
signal.Notify(signalch, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)

// run optimize with context
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
sig, ok := <-signalch
if !ok {
return
}
cancel()
logger.Error("Catch a kill signal", zap.String("signal", sig.String()))
}()
go func() {
defer wg.Done()
err = study.Optimize(objective, 10)
}()
wg.Wait()
if err != nil {
logger.Fatal("got error while run optimize", zap.Error(err))
}

v, _ := study.GetBestValue()
fmt.Println("Best evaluation value:", v)
}
6 changes: 5 additions & 1 deletion study.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,11 @@ func (s *Study) Optimize(objective FuncObjective, evaluateMax int) error {
if s.ctx != nil {
select {
case <-s.ctx.Done():
return s.ctx.Err()
err := s.ctx.Err()
if err != nil && s.logger != nil {
s.logger.Debug("context is canceled", zap.Error(err))
}
return err
default:
// do nothing
}
Expand Down
10 changes: 9 additions & 1 deletion trial.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package goptuna

import "errors"
import (
"context"
"errors"
)

//go:generate stringer -trimprefix TrialState -output stringer_trial_state.go -type=TrialState

Expand Down Expand Up @@ -162,3 +165,8 @@ func (t *Trial) GetUserAttrs() (map[string]string, error) {
func (t *Trial) GetSystemAttrs() (map[string]string, error) {
return t.Study.Storage.GetTrialSystemAttrs(t.ID)
}

// GetContext returns a context which is registered at 'study.WithContext()'.
func (t *Trial) GetContext() context.Context {
return t.Study.ctx
}

0 comments on commit d071280

Please sign in to comment.