Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(databricks): support override Databricks instances #5358

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ func TestEndToEnd(t *testing.T) {
}
databricksConfig, err := utils.MarshalObjToStruct(databricksConfDict)
assert.NoError(t, err)
sparkJob := plugins.SparkJob{DatabricksConf: databricksConfig, DatabricksToken: "token", SparkConf: map[string]string{"spark.driver.bindAddress": "127.0.0.1"}}
sparkJob := plugins.SparkJob{
DatabricksConf: databricksConfig,
DatabricksToken: "token",
DatabricksInstance: "Foo",
SparkConf: map[string]string{"spark.driver.bindAddress": "127.0.0.1"},
}
st, err := utils.MarshalPbToStruct(&sparkJob)
assert.NoError(t, err)
inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
Expand Down
17 changes: 10 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,11 @@
}
}
databricksJob[sparkPythonTask] = map[string]interface{}{pythonFile: p.cfg.EntrypointFile, parameters: modifiedArgs}

data, err := p.sendRequest(create, databricksJob, token, "")
databricksInstance := p.cfg.DatabricksInstance
if sparkJob.DatabricksInstance != "" {
databricksInstance = sparkJob.DatabricksInstance
}
data, err := p.sendRequest(create, databricksJob, token, "", databricksInstance)
if err != nil {
return nil, nil, err
}
Expand All @@ -138,12 +141,12 @@
}
runID := fmt.Sprintf("%.0f", data["run_id"])

return ResourceMetaWrapper{runID, p.cfg.DatabricksInstance, token}, nil, nil
return ResourceMetaWrapper{runID, databricksInstance, token}, nil, nil
}

func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
exec := taskCtx.ResourceMeta().(ResourceMetaWrapper)
res, err := p.sendRequest(get, nil, exec.Token, exec.RunID)
res, err := p.sendRequest(get, nil, exec.Token, exec.RunID, exec.DatabricksInstance)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -175,7 +178,7 @@
return nil
}
exec := taskCtx.ResourceMeta().(ResourceMetaWrapper)
_, err := p.sendRequest(cancel, nil, exec.Token, exec.RunID)
_, err := p.sendRequest(cancel, nil, exec.Token, exec.RunID, exec.DatabricksInstance)

Check warning on line 181 in flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go#L181

Added line #L181 was not covered by tests
if err != nil {
return err
}
Expand All @@ -184,11 +187,11 @@
return nil
}

func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, token string, runID string) (map[string]interface{}, error) {
func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, token, runID, databricksInstance string) (map[string]interface{}, error) {
var databricksURL string
// for mocking/testing purposes
if p.cfg.databricksEndpoint == "" {
databricksURL = fmt.Sprintf("https://%v%v", p.cfg.DatabricksInstance, databricksAPI)
databricksURL = fmt.Sprintf("https://%v%v", databricksInstance, databricksAPI)
} else {
databricksURL = fmt.Sprintf("%v%v", p.cfg.databricksEndpoint, databricksAPI)
}
Expand Down
12 changes: 6 additions & 6 deletions flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestSendRequest(t *testing.T) {
}

t.Run("create a Databricks job", func(t *testing.T) {
data, err := plugin.sendRequest(create, databricksJob, token, "")
data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.NotNil(t, data)
assert.Equal(t, "someID", data["id"])
assert.Equal(t, "someData", data["data"])
Expand All @@ -88,7 +88,7 @@ func TestSendRequest(t *testing.T) {
Body: ioutils.NewBytesReadCloser([]byte(`{"message":"failed"}`)),
}, nil
}}
data, err := plugin.sendRequest(create, databricksJob, token, "")
data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.Nil(t, data)
assert.Equal(t, err.Error(), "failed to create Databricks job with error [failed]")
})
Expand All @@ -98,7 +98,7 @@ func TestSendRequest(t *testing.T) {
assert.Equal(t, req.Method, http.MethodPost)
return nil, errors.New("failed to send request")
}}
data, err := plugin.sendRequest(create, databricksJob, token, "")
data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.Nil(t, data)
assert.Equal(t, err.Error(), "failed to send request to Databricks platform with err: [failed to send request]")
})
Expand All @@ -111,7 +111,7 @@ func TestSendRequest(t *testing.T) {
Body: ioutils.NewBytesReadCloser([]byte(`123`)),
}, nil
}}
data, err := plugin.sendRequest(create, databricksJob, token, "")
data, err := plugin.sendRequest(create, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.Nil(t, data)
assert.Equal(t, err.Error(), "failed to parse response with err: [json: cannot unmarshal number into Go value of type map[string]interface {}]")
})
Expand All @@ -124,7 +124,7 @@ func TestSendRequest(t *testing.T) {
Body: ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)),
}, nil
}}
data, err := plugin.sendRequest(get, databricksJob, token, "")
data, err := plugin.sendRequest(get, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.NotNil(t, data)
assert.Nil(t, err)
})
Expand All @@ -137,7 +137,7 @@ func TestSendRequest(t *testing.T) {
Body: ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)),
}, nil
}}
data, err := plugin.sendRequest(cancel, databricksJob, token, "")
data, err := plugin.sendRequest(cancel, databricksJob, token, "", plugin.cfg.DatabricksInstance)
assert.NotNil(t, data)
assert.Nil(t, err)
})
Expand Down
Loading