Skip to content

Commit

Permalink
Merge pull request #86 from c-bata/enqueue-trial
Browse files Browse the repository at this point in the history
Experimental support of EnqueueTrial.
  • Loading branch information
c-bata committed Mar 11, 2020
2 parents 66ce706 + 0e11a80 commit d9758c3
Show file tree
Hide file tree
Showing 12 changed files with 1,007 additions and 49 deletions.
6 changes: 6 additions & 0 deletions export_test.go
@@ -0,0 +1,6 @@
package goptuna

var (
ExportFrozenTrialValidate = (*FrozenTrial).validate
ExportStudyAppendTrial = (*Study).appendTrial
)
4 changes: 4 additions & 0 deletions rdb/converter.go
Expand Up @@ -134,6 +134,8 @@ func toStateExternalRepresentation(state string) (goptuna.TrialState, error) {
return goptuna.TrialStatePruned, nil
case trialStateFail:
return goptuna.TrialStateFail, nil
case trialStateWaiting:
return goptuna.TrialStateWaiting, nil
default:
return goptuna.TrialStateRunning, errors.New("invalid trial state")
}
Expand All @@ -149,6 +151,8 @@ func toStateInternalRepresentation(state goptuna.TrialState) (string, error) {
return trialStatePruned, nil
case goptuna.TrialStateFail:
return trialStateFail, nil
case goptuna.TrialStateWaiting:
return trialStateWaiting, nil
default:
return "", errors.New("invalid trial state")
}
Expand Down
1 change: 1 addition & 0 deletions rdb/model.go
Expand Up @@ -17,6 +17,7 @@ const (
trialStateComplete = "COMPLETE"
trialStatePruned = "PRUNED"
trialStateFail = "FAIL"
trialStateWaiting = "WAITING"
)

// https://gorm.io/docs/models.html
Expand Down
168 changes: 166 additions & 2 deletions rdb/storage.go
@@ -1,13 +1,13 @@
package rdb

import (
"fmt"
"strconv"
"time"

"github.com/c-bata/goptuna"
"github.com/google/uuid"
"github.com/jinzhu/gorm"

"github.com/c-bata/goptuna"
)

var _ goptuna.Storage = &Storage{}
Expand Down Expand Up @@ -248,6 +248,136 @@ func (s *Storage) CreateNewTrial(studyID int) (int, error) {
return trial.ID, err
}

// CloneTrial creates new Trial from the given base Trial.
func (s *Storage) CloneTrial(studyID int, baseTrial goptuna.FrozenTrial) (int, error) {
tx := s.db.Begin()
if tx.Error != nil {
return -1, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()

// Because only `RUNNING` trials can be updated,
// we temporarily set the state of the new trial to `RUNNING`.
// After all fields of the trial have been updated,
// the state is set to `template_trial.state`.
tempState := trialStateWaiting

trial := &trialModel{
TrialReferStudy: studyID,
State: tempState,
Value: baseTrial.Value,
DatetimeStart: &baseTrial.DatetimeStart,
DatetimeComplete: &baseTrial.DatetimeComplete,
}
if err := tx.Create(trial).Error; err != nil {
tx.Rollback()
return -1, err
}

// params
for name := range baseTrial.InternalParams {
d, ok := baseTrial.Distributions[name]
if !ok {
tx.Rollback()
return -1, fmt.Errorf("'%s' distribution is not found", name)
}
jsonBytes, err := goptuna.DistributionToJSON(d)
if err != nil {
tx.Rollback()
return -1, err
}
err = tx.Create(&trialParamModel{
TrialParamReferTrial: trial.ID,
Name: name,
Value: baseTrial.InternalParams[name],
DistributionJSON: string(jsonBytes),
}).Error
}

// user attrs
for key := range baseTrial.UserAttrs {
err := tx.Create(&trialUserAttributeModel{
UserAttributeReferTrial: trial.ID,
Key: key,
ValueJSON: encodeAttrValue(baseTrial.UserAttrs[key]),
}).Error
if err != nil {
tx.Rollback()
return -1, err
}
}

// system attrs
for key := range baseTrial.SystemAttrs {
if key == "_number" {
continue
}
err := tx.Create(&trialSystemAttributeModel{
SystemAttributeReferTrial: trial.ID,
Key: key,
ValueJSON: encodeAttrValue(baseTrial.SystemAttrs[key]),
}).Error
if err != nil {
tx.Rollback()
return -1, err
}
}

// intermediate values
for step := range baseTrial.IntermediateValues {
err := tx.Create(&trialValueModel{
TrialValueReferTrial: trial.ID,
Step: step,
Value: baseTrial.IntermediateValues[step],
}).Error
if err != nil {
tx.Rollback()
return -1, err
}
}

// state
state, err := toStateInternalRepresentation(baseTrial.State)
if err != nil {
tx.Rollback()
return -1, err
}
err = tx.Model(&trialModel{}).
Where("trial_id = ?", trial.ID).
Update("state", state).Error
if err != nil {
tx.Rollback()
return -1, err
}

// trial number
var number int
err = tx.Model(&trialModel{}).
Where("study_id = ?", studyID).
Where("trial_id < ?", trial.ID).
Count(&number).Error
if err != nil {
tx.Rollback()
return -1, err
}
err = tx.Create(&trialSystemAttributeModel{
SystemAttributeReferTrial: trial.ID,
Key: keyNumber,
ValueJSON: strconv.Itoa(number),
}).Error
if err != nil {
tx.Rollback()
return -1, err
}

err = tx.Commit().Error
return trial.ID, err
}

// SetTrialValue sets the value of trial.
func (s *Storage) SetTrialValue(trialID int, value float64) error {
tx := s.db.Begin()
Expand Down Expand Up @@ -375,6 +505,40 @@ func (s *Storage) SetTrialState(trialID int, state goptuna.TrialState) error {
return tx.Error
}

// TODO(c-bata): Add `FOR UPDATE` clause.
//
// result := tx.Set("gorm:query_option", "FOR UPDATE").
// First(&trial, "trial_id = ?", trialID)
//
// But SQLite3 doesn't interpret `FOR UPDATE` clause.
// SQLAlchemy can automatically remove it, but Gorm can't.
//
// Another solution is to add `EnableForUpdateClause=false` option.
//
// See following pages for the information.
// * https://github.com/optuna/optuna/pull/1014
// * http://gorm.io/docs/query.html#Extra-Querying-option
var trial trialModel
result := tx.First(&trial, "trial_id = ?", trialID)
if result.RecordNotFound() {
tx.Rollback()
return goptuna.ErrInvalidTrialID
}
if result.Error != nil {
tx.Rollback()
return result.Error
}

previousState, err := toStateExternalRepresentation(trial.State)
if err != nil {
tx.Rollback()
return err
}
if previousState.IsFinished() {
tx.Rollback()
return goptuna.ErrTrialCannotBeUpdated
}

err = tx.Model(&trialModel{}).
Where("trial_id = ?", trialID).
Update("state", xr).Error
Expand Down
107 changes: 107 additions & 0 deletions rdb/storage_test.go
Expand Up @@ -4,6 +4,7 @@ import (
"os"
"reflect"
"testing"
"time"

"github.com/c-bata/goptuna"

Expand Down Expand Up @@ -638,3 +639,109 @@ func TestStorage_SetTrialIntermediateValue(t *testing.T) {
t.Errorf("want two intermediate vales, but got %#v", trial.IntermediateValues)
}
}

func TestStorage_CloneTrial(t *testing.T) {
db, teardown, err := SetupSQLite3Test(t, "goptuna-test.db")
defer teardown()
if err != nil {
t.Errorf("failed to setup tests with %s", err)
return
}

storage := rdb.NewStorage(db)
now := time.Now()

baseTrial := goptuna.FrozenTrial{
ID: -1, // dummy value (unused)
Number: -1, // dummy value (unused)
State: goptuna.TrialStateComplete,
Value: 10000,
IntermediateValues: map[int]float64{
1: 10,
2: 100,
3: 1000,
},
DatetimeStart: now,
DatetimeComplete: now,
InternalParams: map[string]float64{
"x": 0.5,
},
Params: map[string]interface{}{
"x": 0.5,
},
Distributions: map[string]interface{}{
"x": goptuna.UniformDistribution{High: 1, Low: 0},
},
UserAttrs: map[string]string{
"foo": "bar",
},
SystemAttrs: map[string]string{
"baz": "123",
},
}

studyID, err := storage.CreateNewStudy("")
if err != nil {
t.Errorf("should be nil, but got %s", err)
return
}

err = storage.SetStudyDirection(studyID, goptuna.StudyDirectionMinimize)
if err != nil {
t.Errorf("should be nil, but got %s", err)
return
}

trialID, err := storage.CloneTrial(studyID, baseTrial)
if err != nil {
t.Errorf("should be nil, but got %s", err)
return
}

trials, err := storage.GetAllTrials(studyID)
if err != nil {
t.Errorf("should be nil, but got %s", err)
return
}

if len(trials) != 1 {
t.Errorf("should get one trial, but got %d", len(trials))
return
}

if trials[0].ID != trialID {
t.Errorf("trialID should be %d, but got %d", trialID, trials[0].ID)
}

if trials[0].Number != 0 {
t.Errorf("number should be 0, but got %d", trials[0].Number)
}

if trials[0].State != goptuna.TrialStateComplete {
t.Errorf("state should be complete, but got %s", trials[0].State)
}

if !reflect.DeepEqual(trials[0].Distributions, baseTrial.Distributions) {
t.Errorf("Distributions should be %v, but got %v", trials[0].Distributions, baseTrial.Distributions)
}

if !reflect.DeepEqual(trials[0].Params, baseTrial.Params) {
t.Errorf("Params should be %v, but got %v", trials[0].Params, baseTrial.Params)
}

if !reflect.DeepEqual(trials[0].InternalParams, baseTrial.InternalParams) {
t.Errorf("InternalParams should be %v, but got %v", trials[0].InternalParams, baseTrial.InternalParams)
}

if !reflect.DeepEqual(trials[0].IntermediateValues, baseTrial.IntermediateValues) {
t.Errorf("InternalValues should be %v, but got %v", trials[0].IntermediateValues, baseTrial.IntermediateValues)
}

if trials[0].DatetimeStart.Second() != baseTrial.DatetimeStart.Second() {
t.Errorf("DatetimeStart should be %s, but got %s", trials[0].DatetimeStart, baseTrial.DatetimeStart)
}

if trials[0].DatetimeComplete.Second() != baseTrial.DatetimeComplete.Second() {
t.Errorf("DatetimeComplete should be %s, but got %s", trials[0].DatetimeComplete, baseTrial.DatetimeComplete)
}
}

0 comments on commit d9758c3

Please sign in to comment.