Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow use of variables references in primitive non-string fields #1219

Merged
merged 2 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion bundle/config/mutator/resolve_variable_references.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/cli/libs/dyn/convert"
"github.com/databricks/cli/libs/dyn/dynvar"
"github.com/databricks/cli/libs/log"
)

type resolveVariableReferences struct {
Expand Down Expand Up @@ -58,7 +59,7 @@ func (m *resolveVariableReferences) Apply(ctx context.Context, b *bundle.Bundle)
}

// Resolve variable references in all values.
return dynvar.Resolve(root, func(path dyn.Path) (dyn.Value, error) {
root, err := dynvar.Resolve(root, func(path dyn.Path) (dyn.Value, error) {
// Rewrite the shorthand path ${var.foo} into ${variables.foo.value}.
if path.HasPrefix(varPath) && len(path) == 2 {
path = dyn.NewPath(
Expand All @@ -77,5 +78,18 @@ func (m *resolveVariableReferences) Apply(ctx context.Context, b *bundle.Bundle)

return dyn.InvalidValue, dynvar.ErrSkipResolution
})
if err != nil {
return dyn.InvalidValue, err
}

// Normalize the result because variable resolution may have been applied to non-string fields.
// For example, a variable reference may have been resolved to a integer.
root, diags := convert.Normalize(b.Config, root)
for _, diag := range diags {
// This occurs when a variable's resolved value is incompatible with the field's type.
// Log a warning until we have a better way to surface these diagnostics to the user.
log.Warnf(ctx, "normalization diagnostic: %s", diag.Summary)
}
return root, nil
})
}
97 changes: 97 additions & 0 deletions bundle/config/mutator/resolve_variable_references_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/bundle/config/variable"
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -95,3 +98,97 @@ func TestResolveVariableReferencesToEmptyFields(t *testing.T) {
// The job settings should have been interpolated to an empty string.
require.Equal(t, "", b.Config.Resources.Jobs["job1"].JobSettings.Tags["git_branch"])
}

func TestResolveVariableReferencesForPrimitiveNonStringFields(t *testing.T) {
var err error

b := &bundle.Bundle{
Config: config.Root{
Variables: map[string]*variable.Variable{
"no_alert_for_canceled_runs": {},
"no_alert_for_skipped_runs": {},
"min_workers": {},
"max_workers": {},
"spot_bid_max_price": {},
},
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"job1": {
JobSettings: &jobs.JobSettings{
NotificationSettings: &jobs.JobNotificationSettings{
NoAlertForCanceledRuns: false,
NoAlertForSkippedRuns: false,
},
Tasks: []jobs.Task{
{
NewCluster: &compute.ClusterSpec{
Autoscale: &compute.AutoScale{
MinWorkers: 0,
MaxWorkers: 0,
},
AzureAttributes: &compute.AzureAttributes{
SpotBidMaxPrice: 0.0,
},
},
},
},
},
},
},
},
},
}

ctx := context.Background()

// Initialize the variables.
err = bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) error {
return b.Config.InitializeVariables([]string{
"no_alert_for_canceled_runs=true",
"no_alert_for_skipped_runs=true",
"min_workers=1",
"max_workers=2",
"spot_bid_max_price=0.5",
})
})
require.NoError(t, err)

// Assign the variables to the dynamic configuration.
err = bundle.ApplyFunc(ctx, b, func(ctx context.Context, b *bundle.Bundle) error {
return b.Config.Mutate(func(v dyn.Value) (dyn.Value, error) {
var p dyn.Path
var err error

// Set the notification settings.
p = dyn.MustPathFromString("resources.jobs.job1.notification_settings")
v, err = dyn.SetByPath(v, p.Append(dyn.Key("no_alert_for_canceled_runs")), dyn.V("${var.no_alert_for_canceled_runs}"))
require.NoError(t, err)
v, err = dyn.SetByPath(v, p.Append(dyn.Key("no_alert_for_skipped_runs")), dyn.V("${var.no_alert_for_skipped_runs}"))
require.NoError(t, err)

// Set the min and max workers.
p = dyn.MustPathFromString("resources.jobs.job1.tasks[0].new_cluster.autoscale")
v, err = dyn.SetByPath(v, p.Append(dyn.Key("min_workers")), dyn.V("${var.min_workers}"))
require.NoError(t, err)
v, err = dyn.SetByPath(v, p.Append(dyn.Key("max_workers")), dyn.V("${var.max_workers}"))
require.NoError(t, err)

// Set the spot bid max price.
p = dyn.MustPathFromString("resources.jobs.job1.tasks[0].new_cluster.azure_attributes")
v, err = dyn.SetByPath(v, p.Append(dyn.Key("spot_bid_max_price")), dyn.V("${var.spot_bid_max_price}"))
require.NoError(t, err)

return v, nil
})
})
require.NoError(t, err)

// Apply for the variable prefix. This should resolve the variables to their values.
err = bundle.Apply(context.Background(), b, ResolveVariableReferences("variables"))
require.NoError(t, err)
assert.Equal(t, true, b.Config.Resources.Jobs["job1"].JobSettings.NotificationSettings.NoAlertForCanceledRuns)
assert.Equal(t, true, b.Config.Resources.Jobs["job1"].JobSettings.NotificationSettings.NoAlertForSkippedRuns)
assert.Equal(t, 1, b.Config.Resources.Jobs["job1"].JobSettings.Tasks[0].NewCluster.Autoscale.MinWorkers)
assert.Equal(t, 2, b.Config.Resources.Jobs["job1"].JobSettings.Tasks[0].NewCluster.Autoscale.MaxWorkers)
assert.Equal(t, 0.5, b.Config.Resources.Jobs["job1"].JobSettings.Tasks[0].NewCluster.AzureAttributes.SpotBidMaxPrice)
}
16 changes: 16 additions & 0 deletions libs/dyn/convert/from_typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"slices"

"github.com/databricks/cli/libs/dyn"
"github.com/databricks/cli/libs/dyn/dynvar"
)

type fromTypedOptions int
Expand Down Expand Up @@ -185,6 +186,11 @@ func fromTypedBool(src reflect.Value, ref dyn.Value, options ...fromTypedOptions
return dyn.NilValue, nil
}
return dyn.V(src.Bool()), nil
case dyn.KindString:
// Ignore pure variable references (e.g. ${var.foo}).
if dynvar.IsPureVariableReference(ref.MustString()) {
return ref, nil
}
}

return dyn.InvalidValue, fmt.Errorf("unhandled type: %s", ref.Kind())
Expand All @@ -205,6 +211,11 @@ func fromTypedInt(src reflect.Value, ref dyn.Value, options ...fromTypedOptions)
return dyn.NilValue, nil
}
return dyn.V(src.Int()), nil
case dyn.KindString:
// Ignore pure variable references (e.g. ${var.foo}).
if dynvar.IsPureVariableReference(ref.MustString()) {
return ref, nil
}
}

return dyn.InvalidValue, fmt.Errorf("unhandled type: %s", ref.Kind())
Expand All @@ -225,6 +236,11 @@ func fromTypedFloat(src reflect.Value, ref dyn.Value, options ...fromTypedOption
return dyn.NilValue, nil
}
return dyn.V(src.Float()), nil
case dyn.KindString:
// Ignore pure variable references (e.g. ${var.foo}).
if dynvar.IsPureVariableReference(ref.MustString()) {
return ref, nil
}
}

return dyn.InvalidValue, fmt.Errorf("unhandled type: %s", ref.Kind())
Expand Down
24 changes: 24 additions & 0 deletions libs/dyn/convert/from_typed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,14 @@ func TestFromTypedBoolRetainsLocationsIfUnchanged(t *testing.T) {
assert.Equal(t, dyn.NewValue(true, dyn.Location{File: "foo"}), nv)
}

func TestFromTypedBoolVariableReference(t *testing.T) {
var src bool = true
var ref = dyn.V("${var.foo}")
nv, err := FromTyped(src, ref)
require.NoError(t, err)
assert.Equal(t, dyn.V("${var.foo}"), nv)
}

func TestFromTypedBoolTypeError(t *testing.T) {
var src bool = true
var ref = dyn.V("string")
Expand Down Expand Up @@ -542,6 +550,14 @@ func TestFromTypedIntRetainsLocationsIfUnchanged(t *testing.T) {
assert.Equal(t, dyn.NewValue(1234, dyn.Location{File: "foo"}), nv)
}

func TestFromTypedIntVariableReference(t *testing.T) {
var src int = 1234
var ref = dyn.V("${var.foo}")
nv, err := FromTyped(src, ref)
require.NoError(t, err)
assert.Equal(t, dyn.V("${var.foo}"), nv)
}

func TestFromTypedIntTypeError(t *testing.T) {
var src int = 1234
var ref = dyn.V("string")
Expand Down Expand Up @@ -589,6 +605,14 @@ func TestFromTypedFloatRetainsLocationsIfUnchanged(t *testing.T) {
assert.Equal(t, dyn.NewValue(1.23, dyn.Location{File: "foo"}), nv)
}

func TestFromTypedFloatVariableReference(t *testing.T) {
var src float64 = 1.23
var ref = dyn.V("${var.foo}")
nv, err := FromTyped(src, ref)
require.NoError(t, err)
assert.Equal(t, dyn.V("${var.foo}"), nv)
}

func TestFromTypedFloatTypeError(t *testing.T) {
var src float64 = 1.23
var ref = dyn.V("string")
Expand Down
16 changes: 16 additions & 0 deletions libs/dyn/convert/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/cli/libs/dyn/dynvar"
)

// NormalizeOption is the type for options that can be passed to Normalize.
Expand Down Expand Up @@ -245,6 +246,11 @@ func (n normalizeOptions) normalizeBool(typ reflect.Type, src dyn.Value) (dyn.Va
case "false", "n", "N", "no", "No", "NO", "off", "Off", "OFF":
out = false
default:
// Return verbatim if it's a pure variable reference.
if dynvar.IsPureVariableReference(src.MustString()) {
return src, nil
}

// Cannot interpret as a boolean.
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindBool, src))
}
Expand All @@ -266,6 +272,11 @@ func (n normalizeOptions) normalizeInt(typ reflect.Type, src dyn.Value) (dyn.Val
var err error
out, err = strconv.ParseInt(src.MustString(), 10, 64)
if err != nil {
// Return verbatim if it's a pure variable reference.
if dynvar.IsPureVariableReference(src.MustString()) {
return src, nil
}

return dyn.InvalidValue, diags.Append(diag.Diagnostic{
Severity: diag.Error,
Summary: fmt.Sprintf("cannot parse %q as an integer", src.MustString()),
Expand All @@ -290,6 +301,11 @@ func (n normalizeOptions) normalizeFloat(typ reflect.Type, src dyn.Value) (dyn.V
var err error
out, err = strconv.ParseFloat(src.MustString(), 64)
if err != nil {
// Return verbatim if it's a pure variable reference.
if dynvar.IsPureVariableReference(src.MustString()) {
return src, nil
}

return dyn.InvalidValue, diags.Append(diag.Diagnostic{
Severity: diag.Error,
Summary: fmt.Sprintf("cannot parse %q as a floating point number", src.MustString()),
Expand Down
24 changes: 24 additions & 0 deletions libs/dyn/convert/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,14 @@ func TestNormalizeBoolFromString(t *testing.T) {
}
}

func TestNormalizeBoolFromStringVariableReference(t *testing.T) {
var typ bool
vin := dyn.V("${var.foo}")
vout, err := Normalize(&typ, vin)
assert.Empty(t, err)
assert.Equal(t, vin, vout)
}

func TestNormalizeBoolFromStringError(t *testing.T) {
var typ bool
vin := dyn.V("abc")
Expand Down Expand Up @@ -542,6 +550,14 @@ func TestNormalizeIntFromString(t *testing.T) {
assert.Equal(t, dyn.V(int64(123)), vout)
}

func TestNormalizeIntFromStringVariableReference(t *testing.T) {
var typ int
vin := dyn.V("${var.foo}")
vout, err := Normalize(&typ, vin)
assert.Empty(t, err)
assert.Equal(t, vin, vout)
}

func TestNormalizeIntFromStringError(t *testing.T) {
var typ int
vin := dyn.V("abc")
Expand Down Expand Up @@ -594,6 +610,14 @@ func TestNormalizeFloatFromString(t *testing.T) {
assert.Equal(t, dyn.V(1.2), vout)
}

func TestNormalizeFloatFromStringVariableReference(t *testing.T) {
var typ float64
vin := dyn.V("${var.foo}")
vout, err := Normalize(&typ, vin)
assert.Empty(t, err)
assert.Equal(t, vin, vout)
}

func TestNormalizeFloatFromStringError(t *testing.T) {
var typ float64
vin := dyn.V("abc")
Expand Down
16 changes: 16 additions & 0 deletions libs/dyn/convert/to_typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strconv"

"github.com/databricks/cli/libs/dyn"
"github.com/databricks/cli/libs/dyn/dynvar"
)

func ToTyped(dst any, src dyn.Value) error {
Expand Down Expand Up @@ -195,6 +196,11 @@ func toTypedBool(dst reflect.Value, src dyn.Value) error {
dst.SetBool(false)
return nil
}
// Ignore pure variable references (e.g. ${var.foo}).
if dynvar.IsPureVariableReference(src.MustString()) {
dst.SetZero()
return nil
}
}

return TypeError{
Expand All @@ -213,6 +219,11 @@ func toTypedInt(dst reflect.Value, src dyn.Value) error {
dst.SetInt(i64)
return nil
}
// Ignore pure variable references (e.g. ${var.foo}).
if dynvar.IsPureVariableReference(src.MustString()) {
dst.SetZero()
return nil
}
}

return TypeError{
Expand All @@ -231,6 +242,11 @@ func toTypedFloat(dst reflect.Value, src dyn.Value) error {
dst.SetFloat(f64)
return nil
}
// Ignore pure variable references (e.g. ${var.foo}).
if dynvar.IsPureVariableReference(src.MustString()) {
dst.SetZero()
return nil
}
}

return TypeError{
Expand Down
Loading
Loading