Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use local Terraform state only when lineage match #1588

Merged
merged 15 commits into from
Jul 18, 2024
115 changes: 75 additions & 40 deletions bundle/deploy/terraform/state_pull.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package terraform

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"io/fs"
Expand All @@ -12,10 +12,14 @@ import (
"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/deploy"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/log"
)

type tfState struct {
Serial int64 `json:"serial"`
Lineage string `json:"lineage"`
}

type statePull struct {
filerFactory deploy.FilerFactory
}
Expand All @@ -24,74 +28,105 @@ func (l *statePull) Name() string {
return "terraform:state-pull"
}

func (l *statePull) remoteState(ctx context.Context, f filer.Filer) (*bytes.Buffer, error) {
// Download state file from filer to local cache directory.
remote, err := f.Read(ctx, TerraformStateFileName)
func (l *statePull) remoteState(ctx context.Context, b *bundle.Bundle) (*tfState, []byte, error) {
f, err := l.filerFactory(b)
if err != nil {
// On first deploy this state file doesn't yet exist.
if errors.Is(err, fs.ErrNotExist) {
return nil, nil
}
return nil, err
return nil, nil, err
}

defer remote.Close()

var buf bytes.Buffer
_, err = io.Copy(&buf, remote)
r, err := f.Read(ctx, TerraformStateFileName)
if err != nil {
return nil, err
return nil, nil, err
}
defer r.Close()

return &buf, nil
}
content, err := io.ReadAll(r)
if err != nil {
return nil, nil, err
}

func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
f, err := l.filerFactory(b)
state := &tfState{}
err = json.Unmarshal(content, state)
if err != nil {
return diag.FromErr(err)
return nil, nil, err
}

return state, content, nil
}

func (l *statePull) localState(ctx context.Context, b *bundle.Bundle) (*tfState, error) {
dir, err := Dir(ctx, b)
if err != nil {
return diag.FromErr(err)
return nil, err
}

// Download state file from filer to local cache directory.
log.Infof(ctx, "Opening remote state file")
remote, err := l.remoteState(ctx, f)
content, err := os.ReadFile(filepath.Join(dir, TerraformStateFileName))
if err != nil {
log.Infof(ctx, "Unable to open remote state file: %s", err)
return diag.FromErr(err)
return nil, err
}
if remote == nil {
log.Infof(ctx, "Remote state file does not exist")
return nil

state := &tfState{}
err = json.Unmarshal(content, state)
if err != nil {
return nil, err
}

// Expect the state file to live under dir.
local, err := os.OpenFile(filepath.Join(dir, TerraformStateFileName), os.O_CREATE|os.O_RDWR, 0600)
return state, nil
}

func (l *statePull) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
dir, err := Dir(ctx, b)
if err != nil {
return diag.FromErr(err)
}
defer local.Close()

if !IsLocalStateStale(local, bytes.NewReader(remote.Bytes())) {
log.Infof(ctx, "Local state is the same or newer, ignoring remote state")
localStatePath := filepath.Join(dir, TerraformStateFileName)

// Case: Remote state file does not exist. In this case we fallback to using the
// local Terraform state. This allows users to change the "root_path" their bundle is
// configured with.
remoteState, remoteContent, err := l.remoteState(ctx, b)
if errors.Is(err, fs.ErrNotExist) {
log.Infof(ctx, "Remote state file does not exist. Using local Terraform state.")
return nil
}
if err != nil {
return diag.Errorf("failed to read remote state file: %v", err)
}

// Truncating the file before writing
local.Truncate(0)
local.Seek(0, 0)
// Expected invariant: remote state file should have a lineage UUID. Error
// if that's not the case.
if remoteState.Lineage == "" {
return diag.Errorf("remote state file does not have a lineage")
}

// Write file to disk.
log.Infof(ctx, "Writing remote state file to local cache directory")
_, err = io.Copy(local, bytes.NewReader(remote.Bytes()))
// Case: Local state file does not exist. In this case we should rely on the remote state file.
localState, err := l.localState(ctx, b)
if errors.Is(err, fs.ErrNotExist) {
log.Infof(ctx, "Local state file does not exist. Using remote Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
return diag.FromErr(err)
}
if err != nil {
return diag.Errorf("failed to read local state file: %v", err)
}

// If the lineage does not match, the Terraform state files do not correspond to the same deployment.
if localState.Lineage != remoteState.Lineage {
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
log.Infof(ctx, "Remote and local state lineages do not match. Using remote Terraform state. Invalidating local Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
return diag.FromErr(err)
}

// If the remote state is newer than the local state, we should use the remote state.
if remoteState.Serial > localState.Serial {
log.Infof(ctx, "Remote state is newer than local state. Using remote Terraform state.")
err := os.WriteFile(localStatePath, remoteContent, 0600)
return diag.FromErr(err)
}

// default: local state is newer or equal to remote state in terms of serial sequence.
// It is also of the same lineage. Keep using the local state.
return nil
}

Expand Down
177 changes: 107 additions & 70 deletions bundle/deploy/terraform/state_pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/mock"
)

func mockStateFilerForPull(t *testing.T, contents map[string]int, merr error) filer.Filer {
func mockStateFilerForPull(t *testing.T, contents map[string]any, merr error) filer.Filer {
buf, err := json.Marshal(contents)
assert.NoError(t, err)

Expand All @@ -41,86 +41,123 @@ func statePullTestBundle(t *testing.T) *bundle.Bundle {
}
}

func TestStatePullLocalMissingRemoteMissing(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, nil, os.ErrNotExist)),
}
func TestStatePullLocalErrorWhenRemoteHasNoLineage(t *testing.T) {
m := &statePull{}

ctx := context.Background()
b := statePullTestBundle(t)
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
t.Run("no local state", func(t *testing.T) {
// setup remote state.
m.filerFactory = identityFiler(mockStateFilerForPull(t, map[string]any{"serial": 5}, nil))

// Confirm that no local state file has been written.
_, err := os.Stat(localStateFile(t, ctx, b))
assert.ErrorIs(t, err, fs.ErrNotExist)
}
ctx := context.Background()
b := statePullTestBundle(t)
diags := bundle.Apply(ctx, b, m)
assert.EqualError(t, diags.Error(), "remote state file does not have a lineage")
})

func TestStatePullLocalMissingRemotePresent(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5}, nil)),
}
t.Run("local state with lineage", func(t *testing.T) {
// setup remote state.
m.filerFactory = identityFiler(mockStateFilerForPull(t, map[string]any{"serial": 5}, nil))

ctx := context.Background()
b := statePullTestBundle(t)
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
ctx := context.Background()
b := statePullTestBundle(t)
writeLocalState(t, ctx, b, map[string]any{"serial": 5, "lineage": "aaaa"})

// Confirm that the local state file has been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState)
diags := bundle.Apply(ctx, b, m)
assert.EqualError(t, diags.Error(), "remote state file does not have a lineage")
})
}

func TestStatePullLocalStale(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5}, nil)),
}

ctx := context.Background()
b := statePullTestBundle(t)
func TestStatePullLocal(t *testing.T) {
tcases := []struct {
name string

// Write a stale local state file.
writeLocalState(t, ctx, b, map[string]int{"serial": 4})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
// remote state before applying the pull mutators
remote map[string]any

// Confirm that the local state file has been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState)
}
// local state before applying the pull mutators
local map[string]any

func TestStatePullLocalEqual(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5, "some_other_key": 123}, nil)),
// expected local state after applying the pull mutators
expected map[string]any
}{
{
name: "remote missing, local missing",
remote: nil,
local: nil,
expected: nil,
},
{
name: "remote missing, local present",
remote: nil,
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// fallback to local state, since remote state is missing.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local stale",
remote: map[string]any{"serial": 10, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// use remote, since remote is newer.
expected: map[string]any{"serial": float64(10), "lineage": "aaaa", "some_other_key": float64(123)},
},
{
name: "local equal",
remote: map[string]any{"serial": 5, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 5, "lineage": "aaaa"},
// use local state, since they are equal in terms of serial sequence.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local newer",
remote: map[string]any{"serial": 5, "lineage": "aaaa", "some_other_key": 123},
local: map[string]any{"serial": 6, "lineage": "aaaa"},
// use local state, since local is newer.
expected: map[string]any{"serial": float64(6), "lineage": "aaaa"},
},
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
{
name: "remote and local have different lineages",
remote: map[string]any{"serial": 5, "lineage": "aaaa"},
local: map[string]any{"serial": 10, "lineage": "bbbb"},
// use remote, since lineages do not match.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
{
name: "local is missing lineage",
remote: map[string]any{"serial": 5, "lineage": "aaaa"},
local: map[string]any{"serial": 10},
// use remote, since local does not have lineage.
expected: map[string]any{"serial": float64(5), "lineage": "aaaa"},
},
}

ctx := context.Background()
b := statePullTestBundle(t)

// Write a local state file with the same serial as the remote.
writeLocalState(t, ctx, b, map[string]int{"serial": 5})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())

// Confirm that the local state file has not been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 5}, localState)
}

func TestStatePullLocalNewer(t *testing.T) {
m := &statePull{
identityFiler(mockStateFilerForPull(t, map[string]int{"serial": 5, "some_other_key": 123}, nil)),
for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
m := &statePull{}
if tc.remote == nil {
// nil represents no remote state file.
m.filerFactory = identityFiler(mockStateFilerForPull(t, nil, os.ErrNotExist))
} else {
m.filerFactory = identityFiler(mockStateFilerForPull(t, tc.remote, nil))
}

ctx := context.Background()
b := statePullTestBundle(t)
if tc.local != nil {
writeLocalState(t, ctx, b, tc.local)
}

diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())

if tc.expected == nil {
// nil represents no local state file is expected.
_, err := os.Stat(localStateFile(t, ctx, b))
assert.ErrorIs(t, err, fs.ErrNotExist)
} else {
localState := readLocalState(t, ctx, b)
assert.Equal(t, tc.expected, localState)

}
})
}

ctx := context.Background()
b := statePullTestBundle(t)

// Write a local state file with a newer serial as the remote.
writeLocalState(t, ctx, b, map[string]int{"serial": 6})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())

// Confirm that the local state file has not been updated.
localState := readLocalState(t, ctx, b)
assert.Equal(t, map[string]int{"serial": 6}, localState)
}
2 changes: 1 addition & 1 deletion bundle/deploy/terraform/state_push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestStatePush(t *testing.T) {
b := statePushTestBundle(t)

// Write a stale local state file.
writeLocalState(t, ctx, b, map[string]int{"serial": 4})
writeLocalState(t, ctx, b, map[string]any{"serial": 4})
diags := bundle.Apply(ctx, b, m)
assert.NoError(t, diags.Error())
}
Loading
Loading