Skip to content

Commit

Permalink
add more flexibility to training
Browse files Browse the repository at this point in the history
  • Loading branch information
bourdakos1 committed Jan 5, 2020
1 parent 2fe1a13 commit 2770f08
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
6 changes: 5 additions & 1 deletion cacli/cmd/train.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ func init() {
trainCmd.Flags().String("name", "", "Optional project name")
trainCmd.Flags().String("output", "", "Optional output bucket")
trainCmd.Flags().Int("steps", 1000, "Number of training steps")
trainCmd.Flags().String("gpu", "k80", "k80 | k80x2 | k80x4 | v100 | v100x2")
trainCmd.Flags().String("gpu", "k80", "k80x2, k80x4, v100, v100x2")
trainCmd.Flags().String("script", "", "Custom training script.zip")

trainCmd.Flags().String("framework", "tensorflow", "keras, pytorch, caffe")
trainCmd.Flags().String("frameworkv", "1.14", "Framework version")
trainCmd.Flags().String("pythonv", "3.6", "Python version")
}
6 changes: 5 additions & 1 deletion cacli/cmd/train/train.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ func Run(cmd *cobra.Command, args []string) {
gpu, err := cmd.Flags().GetString("gpu")
script, err := cmd.Flags().GetString("script")

framework, err := cmd.Flags().GetString("framework")
frameworkVersion, err := cmd.Flags().GetString("frameworkv")
pythonVersion, err := cmd.Flags().GetString("pythonv")

if err != nil {
e.Exit(err)
}
Expand Down Expand Up @@ -109,7 +113,7 @@ func Run(cmd *cobra.Command, args []string) {
// if non default steps, include it in project name.
projectName = projectName + " (" + strconv.Itoa(steps) + ")"
}
model, err := session.StartTraining(script, projectName, trainingBucket, outputBucket, steps, gpu)
model, err := session.StartTraining(script, projectName, trainingBucket, outputBucket, steps, gpu, framework, frameworkVersion, pythonVersion)
if err != nil {
e.Exit(err)
}
Expand Down
8 changes: 4 additions & 4 deletions cacli/ibmcloud/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,17 +332,17 @@ func (s *AccountSession) CreateCredential(params CreateCredentialParams) (*Crede
return credential, nil
}

func (s *CredentialSession) StartTraining(trainingZip string, projectName string, bucket *s3.BucketExtended, output *s3.BucketExtended, steps int, gpu string) (*Model, error) {
func (s *CredentialSession) StartTraining(trainingZip string, projectName string, bucket *s3.BucketExtended, output *s3.BucketExtended, steps int, gpu string, framework string, frameworkVersion string, pythonVersion string) (*Model, error) {
// TODO: We shouldn't hard code all of this.
trainingDefinition := &TrainingDefinition{
Name: projectName,
Framework: Framework{
Name: "tensorflow",
Version: "1.12",
Name: framework,
Version: frameworkVersion,
Runtimes: []Runtimes{
Runtimes{
Name: "python",
Version: "3.6",
Version: pythonVersion,
},
},
},
Expand Down

0 comments on commit 2770f08

Please sign in to comment.