diff --git a/pkg/api/v1/atlascluster_types.go b/pkg/api/v1/atlascluster_types.go index d64ece104b..1dee8fec50 100644 --- a/pkg/api/v1/atlascluster_types.go +++ b/pkg/api/v1/atlascluster_types.go @@ -31,7 +31,6 @@ func init() { } type ProviderName string -type ClusterType string const ( ProviderAWS ProviderName = "AWS" @@ -40,6 +39,8 @@ const ( ProviderTenant ProviderName = "TENANT" ) +type ClusterType string + const ( TypeReplicaSet ClusterType = "REPLICASET" TypeSharded ClusterType = "SHARDED" diff --git a/pkg/controller/atlascluster/cluster.go b/pkg/controller/atlascluster/cluster.go index 12d3afb74c..5d0ec914e5 100644 --- a/pkg/controller/atlascluster/cluster.go +++ b/pkg/controller/atlascluster/cluster.go @@ -40,31 +40,31 @@ func (r *AtlasClusterReconciler) ensureClusterState(ctx *workflow.Context, proje switch c.StateName { case "IDLE": - if done, err := clusterMatchesSpec(ctx.Log, c, cluster.Spec); err != nil { + resultingCluster, err := mergedCluster(*c, cluster.Spec) + if err != nil { return c, workflow.Terminate(workflow.Internal, err.Error()) - } else if done { - return c, workflow.OK() } - spec, err := cluster.Spec.Cluster() - if err != nil { - return c, workflow.Terminate(workflow.Internal, err.Error()) + if done := clustersEqual(ctx.Log, *c, resultingCluster); done { + return c, workflow.OK() } if cluster.Spec.Paused != nil { if c.Paused == nil || *c.Paused != *cluster.Spec.Paused { // paused is different from Atlas // we need to first send a special (un)pause request before reconciling everything else - spec = &mongodbatlas.Cluster{ + resultingCluster = mongodbatlas.Cluster{ Paused: cluster.Spec.Paused, } } else { // otherwise, don't send the paused field - spec.Paused = nil + resultingCluster.Paused = nil } } - c, _, err = ctx.Client.Clusters.Update(context.Background(), project.Status.ID, cluster.Spec.Name, spec) + resultingCluster = cleanupCluster(resultingCluster) + + c, _, err = ctx.Client.Clusters.Update(context.Background(), project.Status.ID, cluster.Spec.Name, &resultingCluster) if err != nil { return c, workflow.Terminate(workflow.ClusterNotUpdatedInAtlas, err.Error()) } @@ -84,22 +84,49 @@ func (r *AtlasClusterReconciler) ensureClusterState(ctx *workflow.Context, proje } } -// clusterMatchesSpec will merge everything from the Spec into existing Cluster and use that to detect change. -// Direct comparison is not feasible because Atlas will set a lot of fields to default values, so we need to apply our changes on top of that. -func clusterMatchesSpec(log *zap.SugaredLogger, cluster *mongodbatlas.Cluster, spec mdbv1.AtlasClusterSpec) (bool, error) { - clusterMerged := mongodbatlas.Cluster{} - if err := compat.JSONCopy(&clusterMerged, cluster); err != nil { - return false, err +// cleanupCluster will unset some fields that cannot be changed via API or are deprecated. 
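+// The ID, state, MongoDB version, URIs and connection strings are set by
+// Atlas and read-only, while ReplicationFactor and the ReplicationSpec map
+// are deprecated in the Atlas API (superseded by ReplicationSpecs), so none
+// of them should be sent back on update.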
+func cleanupCluster(cluster mongodbatlas.Cluster) mongodbatlas.Cluster { + cluster.ID = "" + cluster.MongoDBVersion = "" + cluster.MongoURI = "" + cluster.MongoURIUpdated = "" + cluster.MongoURIWithOptions = "" + cluster.SrvAddress = "" + cluster.StateName = "" + cluster.ReplicationFactor = nil + cluster.ReplicationSpec = nil + cluster.ConnectionStrings = nil + return cluster +} + +// mergedCluster will return the result of merging AtlasClusterSpec with Atlas Cluster +func mergedCluster(cluster mongodbatlas.Cluster, spec mdbv1.AtlasClusterSpec) (result mongodbatlas.Cluster, err error) { + if err = compat.JSONCopy(&result, cluster); err != nil { + return + } + + if err = compat.JSONCopy(&result, spec); err != nil { + return } - if err := compat.JSONCopy(&clusterMerged, spec); err != nil { - return false, err + // TODO: might need to do this with other slices + if err = compat.JSONSliceMerge(&result.ReplicationSpecs, cluster.ReplicationSpecs); err != nil { + return } - d := cmp.Diff(*cluster, clusterMerged, cmpopts.EquateEmpty()) + if err = compat.JSONSliceMerge(&result.ReplicationSpecs, spec.ReplicationSpecs); err != nil { + return + } + + return +} + +// clustersEqual compares two Atlas Clusters +func clustersEqual(log *zap.SugaredLogger, clusterA mongodbatlas.Cluster, clusterB mongodbatlas.Cluster) bool { + d := cmp.Diff(clusterA, clusterB, cmpopts.EquateEmpty()) if d != "" { - log.Debugf("Cluster differs from spec: %s", d) + log.Debugf("Clusters are different: %s", d) } - return d == "", nil + return d == "" } diff --git a/pkg/controller/atlascluster/cluster_test.go b/pkg/controller/atlascluster/cluster_test.go index 82a22b6078..c7de050d27 100644 --- a/pkg/controller/atlascluster/cluster_test.go +++ b/pkg/controller/atlascluster/cluster_test.go @@ -12,29 +12,34 @@ import ( func TestClusterMatchesSpec(t *testing.T) { t.Run("Clusters match (enums)", func(t *testing.T) { - atlasClusterEnum := mongodbatlas.Cluster{ + atlasCluster := mongodbatlas.Cluster{ ProviderSettings: &mongodbatlas.ProviderSettings{ ProviderName: "AWS", }, ClusterType: "GEOSHARDED", } - operatorClusterEnum := mdbv1.AtlasClusterSpec{ + operatorCluster := mdbv1.AtlasClusterSpec{ ProviderSettings: &mdbv1.ProviderSettingsSpec{ ProviderName: mdbv1.ProviderAWS, }, ClusterType: mdbv1.TypeGeoSharded, } - match, err := clusterMatchesSpec(zap.S(), &atlasClusterEnum, operatorClusterEnum) + merged, err := mergedCluster(atlasCluster, operatorCluster) assert.NoError(t, err) - assert.True(t, match) + + equal := clustersEqual(zap.S(), atlasCluster, merged) + assert.True(t, equal) }) + t.Run("Clusters don't match (enums)", func(t *testing.T) { atlasClusterEnum := mongodbatlas.Cluster{ClusterType: "GEOSHARDED"} operatorClusterEnum := mdbv1.AtlasClusterSpec{ClusterType: mdbv1.TypeReplicaSet} - match, err := clusterMatchesSpec(zap.S(), &atlasClusterEnum, operatorClusterEnum) + merged, err := mergedCluster(atlasClusterEnum, operatorClusterEnum) assert.NoError(t, err) - assert.False(t, match) + + equal := clustersEqual(zap.S(), atlasClusterEnum, merged) + assert.False(t, equal) }) } diff --git a/pkg/util/compat/json_slice_merge.go b/pkg/util/compat/json_slice_merge.go new file mode 100644 index 0000000000..ed12c5f87e --- /dev/null +++ b/pkg/util/compat/json_slice_merge.go @@ -0,0 +1,59 @@ +package compat + +import ( + "errors" + "fmt" + "reflect" +) + +// JSONSliceMerge will merge two slices using JSONCopy according to these rules: +// +// 1. If `dst` and `src` are the same length, all elements are merged +// +// 2. 
If `dst` is longer, only the first `len(src)` elements are merged +// +// 3. If `src` is longer, first `len(dst)` elements are merged, then remaining elements are appended to `dst` +func JSONSliceMerge(dst, src interface{}) error { + dstVal := reflect.ValueOf(dst) + srcVal := reflect.ValueOf(src) + + if dstVal.Kind() != reflect.Ptr { + return errors.New("dst must be a pointer to slice") + } + + dstVal = reflect.Indirect(dstVal) + srcVal = reflect.Indirect(srcVal) + + if dstVal.Kind() != reflect.Slice { + return errors.New("dst must be pointing to a slice") + } + + if srcVal.Kind() != reflect.Slice { + return errors.New("src must be a slice or a pointer to slice") + } + + minLen := dstVal.Len() + if srcVal.Len() < minLen { + minLen = srcVal.Len() + } + + // merge common elements + for i := 0; i < minLen; i++ { + dstX := dstVal.Index(i).Addr().Interface() + if err := JSONCopy(dstX, srcVal.Index(i).Interface()); err != nil { + return fmt.Errorf("cannot copy value at index %d: %w", i, err) + } + } + + // append extra elements (if any) + dstType := reflect.TypeOf(dst).Elem().Elem() + for i := minLen; i < srcVal.Len(); i++ { + newVal := reflect.New(dstType).Interface() + if err := JSONCopy(&newVal, srcVal.Index(i).Interface()); err != nil { + return fmt.Errorf("cannot copy value at index %d: %w", i, err) + } + dstVal.Set(reflect.Append(dstVal, reflect.ValueOf(newVal).Elem())) + } + + return nil +} diff --git a/pkg/util/compat/json_slice_merge_test.go b/pkg/util/compat/json_slice_merge_test.go new file mode 100644 index 0000000000..3892de7b37 --- /dev/null +++ b/pkg/util/compat/json_slice_merge_test.go @@ -0,0 +1,98 @@ +package compat_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + . "github.com/mongodb/mongodb-atlas-kubernetes/pkg/util/compat" +) + +func TestJSONSliceMerge(t *testing.T) { + type Item struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + } + + type OtherItem struct { + OtherID string `json:"id,omitempty"` + OtherName string `json:"name,omitempty"` + } + + tests := []struct { + name string + dst, src, expected interface{} + expectedError error + }{ + { + name: "src is longer", + dst: &[]*Item{ + {"00001", "dst1"}, + {"00002", "dst2"}, + {"00003", "dst3"}, + }, + src: []OtherItem{ // copying from different element type + {"99999", "src1"}, // different key, different value + {"", "src2"}, // no key, different value + {"", ""}, // no key, no value + {"12345", "extra"}, // extra value + }, + expected: &[]*Item{ // kept dst element type + {"99999", "src1"}, // key & value replaced by src + {"00002", "src2"}, // only value replaced by src + {"00003", "dst3"}, // untouched + {"12345", "extra"}, // appended from src + }, + }, + { + name: "dst is longer", + dst: &[]*Item{ + {"00001", "dst1"}, + {"00002", "dst2"}, + {"00003", "dst3"}, + }, + src: []OtherItem{ + {"99999", "src1"}, + }, + expected: &[]*Item{ + {"99999", "src1"}, // key & value replaced by src + {"00002", "dst2"}, // untouched + {"00003", "dst3"}, // untouched + }, + }, + { + name: "src is nil", + dst: &[]*Item{ + {"00001", "dst1"}, + {"00002", "dst2"}, + {"00003", "dst3"}, + }, + src: nil, + expectedError: errors.New("src must be a slice or a pointer to slice"), + expected: &[]*Item{ + {"00001", "dst1"}, // untouched + {"00002", "dst2"}, // untouched + {"00003", "dst3"}, // untouched + }, + }, + { + name: "dst is nil", + dst: nil, + expectedError: errors.New("dst must be a pointer to slice"), + src: []OtherItem{ + {"99999", "src1"}, + }, + expected: nil, + }, 
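+		{
+			// Illustrative case for rule 1: equal-length slices are merged
+			// element-wise; empty src fields are assumed to leave the dst
+			// values in place (omitempty), as in the cases above.
+			name: "same length",
+			dst: &[]*Item{
+				{"00001", "dst1"},
+				{"00002", "dst2"},
+			},
+			src: []OtherItem{
+				{"99999", "src1"}, // different key, different value
+				{"", ""},          // no key, no value
+			},
+			expected: &[]*Item{
+				{"99999", "src1"}, // key & value replaced by src
+				{"00002", "dst2"}, // untouched
+			},
+		},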
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			require := require.New(t)
+			err := JSONSliceMerge(tt.dst, tt.src)
+			require.Equal(tt.expectedError, err)
+			require.Equal(tt.expected, tt.dst)
+		})
+	}
+}
diff --git a/test/int/cluster_test.go b/test/int/cluster_test.go
index 4429576823..c9b720a593 100644
--- a/test/int/cluster_test.go
+++ b/test/int/cluster_test.go
@@ -112,15 +112,111 @@ var _ = Describe("AtlasCluster", func() {
 		})
 	}
 
-	performUpdate := func() {
+	performUpdate := func(timeout interface{}) {
 		Expect(k8sClient.Update(context.Background(), createdCluster)).To(Succeed())
 
 		Eventually(testutil.WaitFor(k8sClient, createdCluster, status.TrueCondition(status.ReadyType), validateClusterUpdatingFunc()),
-			1200, interval).Should(BeTrue())
+			timeout, interval).Should(BeTrue())
 
 		lastGeneration++
 	}
 
+	Describe("Create cluster & change ReplicationSpecs", func() {
+		It("Should Succeed", func() {
+			expectedCluster := mdbv1.DefaultGCPCluster(namespace.Name, createdProject.Name)
+
+			By(fmt.Sprintf("Creating the Cluster %s", kube.ObjectKeyFromObject(expectedCluster)), func() {
+				createdCluster.ObjectMeta = expectedCluster.ObjectMeta
+				Expect(k8sClient.Create(context.Background(), expectedCluster)).ToNot(HaveOccurred())
+
+				Eventually(testutil.WaitFor(k8sClient, createdCluster, status.TrueCondition(status.ReadyType), validateClusterCreatingFunc()),
+					30*time.Minute, interval).Should(BeTrue())
+
+				doCommonChecks()
+				checkAtlasState()
+			})
+
+			By("Updating ReplicationSpecs", func() {
+				createdCluster.Spec.ReplicationSpecs = append(createdCluster.Spec.ReplicationSpecs, mdbv1.ReplicationSpec{
+					NumShards: int64ptr(2),
+				})
+				createdCluster.Spec.ClusterType = "SHARDED"
+				performUpdate(40 * time.Minute)
+				doCommonChecks()
+				checkAtlasState()
+			})
+		})
+	})
+
+	Describe("Create cluster & increase InstanceSize", func() {
+		It("Should Succeed", func() {
+			expectedCluster := mdbv1.DefaultGCPCluster(namespace.Name, createdProject.Name)
+
+			By(fmt.Sprintf("Creating the Cluster %s", kube.ObjectKeyFromObject(expectedCluster)), func() {
+				createdCluster.ObjectMeta = expectedCluster.ObjectMeta
+				Expect(k8sClient.Create(context.Background(), expectedCluster)).ToNot(HaveOccurred())
+
+				Eventually(testutil.WaitFor(k8sClient, createdCluster, status.TrueCondition(status.ReadyType), validateClusterCreatingFunc()),
+					30*time.Minute, interval).Should(BeTrue())
+
+				doCommonChecks()
+				checkAtlasState()
+			})
+
+			By("Increasing InstanceSize", func() {
+				createdCluster.Spec.ProviderSettings.InstanceSizeName = "M30"
+				performUpdate(40 * time.Minute)
+				doCommonChecks()
+				checkAtlasState()
+			})
+		})
+	})
+
+	Describe("Create cluster & change it to GEOSHARDED", func() {
+		It("Should Succeed", func() {
+			expectedCluster := mdbv1.DefaultGCPCluster(namespace.Name, createdProject.Name)
+
+			By(fmt.Sprintf("Creating the Cluster %s", kube.ObjectKeyFromObject(expectedCluster)), func() {
+				createdCluster.ObjectMeta = expectedCluster.ObjectMeta
+				Expect(k8sClient.Create(context.Background(), expectedCluster)).ToNot(HaveOccurred())
+
+				Eventually(testutil.WaitFor(k8sClient, createdCluster, status.TrueCondition(status.ReadyType), validateClusterCreatingFunc()),
+					30*time.Minute, interval).Should(BeTrue())
+
+				doCommonChecks()
+				checkAtlasState()
+			})
+
+			By("Changing the cluster to GEOSHARDED", func() {
+				createdCluster.Spec.ClusterType = "GEOSHARDED"
+				createdCluster.Spec.ProviderSettings.RegionName = ""
+				createdCluster.Spec.ReplicationSpecs = []mdbv1.ReplicationSpec{
+					{
+						NumShards:     int64ptr(1),
+						ZoneName:      "Zone 1",
+						RegionsConfig: 
map[string]mdbv1.RegionsConfig{ + "EASTERN_US": { + AnalyticsNodes: int64ptr(1), + ElectableNodes: int64ptr(2), + Priority: int64ptr(7), + ReadOnlyNodes: int64ptr(0), + }, + "WESTERN_US": { + AnalyticsNodes: int64ptr(0), + ElectableNodes: int64ptr(1), + Priority: int64ptr(6), + ReadOnlyNodes: int64ptr(0), + }, + }, + }, + } + performUpdate(80 * time.Minute) + doCommonChecks() + checkAtlasState() + }) + }) + }) + Describe("Create/Update the cluster", func() { It("Should fail, then be fixed", func() { createdCluster = mdbv1.DefaultGCPCluster(namespace.Name, createdProject.Name).WithAtlasName("") @@ -150,7 +246,7 @@ var _ = Describe("AtlasCluster", func() { Expect(k8sClient.Update(context.Background(), createdCluster)).To(Succeed()) Eventually(testutil.WaitFor(k8sClient, createdCluster, status.TrueCondition(status.ReadyType)), - 1200, interval).Should(BeTrue()) + 20*time.Minute, interval).Should(BeTrue()) doCommonChecks() checkAtlasState() @@ -172,14 +268,14 @@ var _ = Describe("AtlasCluster", func() { By("Updating the Cluster labels", func() { createdCluster.Spec.Labels = []mdbv1.LabelSpec{{Key: "int-test", Value: "true"}} - performUpdate() + performUpdate(20 * time.Minute) doCommonChecks() checkAtlasState() }) By("Updating the Cluster backups settings", func() { createdCluster.Spec.ProviderBackupEnabled = boolptr(true) - performUpdate() + performUpdate(20 * time.Minute) doCommonChecks() checkAtlasState(func(c *mongodbatlas.Cluster) { Expect(c.ProviderBackupEnabled).To(Equal(createdCluster.Spec.ProviderBackupEnabled)) @@ -188,7 +284,7 @@ var _ = Describe("AtlasCluster", func() { By("Decreasing the Cluster disk size", func() { createdCluster.Spec.DiskSizeGB = intptr(10) - performUpdate() + performUpdate(20 * time.Minute) doCommonChecks() checkAtlasState(func(c *mongodbatlas.Cluster) { Expect(*c.DiskSizeGB).To(BeEquivalentTo(*createdCluster.Spec.DiskSizeGB)) @@ -200,7 +296,7 @@ var _ = Describe("AtlasCluster", func() { By("Pausing the cluster", func() { createdCluster.Spec.Paused = boolptr(true) - performUpdate() + performUpdate(20 * time.Minute) doCommonChecks() checkAtlasState(func(c *mongodbatlas.Cluster) { Expect(c.Paused).To(Equal(createdCluster.Spec.Paused)) @@ -229,7 +325,7 @@ var _ = Describe("AtlasCluster", func() { By("Unpausing the cluster", func() { createdCluster.Spec.Paused = boolptr(false) - performUpdate() + performUpdate(20 * time.Minute) doCommonChecks() checkAtlasState(func(c *mongodbatlas.Cluster) { Expect(c.Paused).To(Equal(createdCluster.Spec.Paused)) @@ -268,7 +364,7 @@ var _ = Describe("AtlasCluster", func() { By("Fixing the Cluster", func() { createdCluster.Spec.ProviderSettings.AutoScaling = nil - performUpdate() + performUpdate(20 * time.Minute) doCommonChecks() checkAtlasState() }) @@ -296,7 +392,7 @@ var _ = Describe("AtlasCluster", func() { By("Fixing the Cluster", func() { createdCluster.Spec.ProviderSettings.InstanceSizeName = oldSizeName - performUpdate() + performUpdate(20 * time.Minute) doCommonChecks() checkAtlasState() }) @@ -364,6 +460,10 @@ func checkAtlasClusterRemoved(projectID string, clusterName string) func() bool } } +func int64ptr(i int64) *int64 { + return &i +} + func intptr(i int) *int { return &i }
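
For context: the element-wise merge in JSONSliceMerge inherits its field-level semantics from compat.JSONCopy, which is assumed here to be a JSON marshal/unmarshal round-trip (the actual helper lives in pkg/util/compat and may differ). A minimal sketch of that assumption:

	package compat

	import "encoding/json"

	// JSONCopy deep-copies src into dst via a JSON round-trip. Because
	// json.Unmarshal only sets fields that are present in the payload,
	// zero-valued src fields tagged `omitempty` are dropped on Marshal and
	// leave the corresponding dst fields untouched, which is what lets
	// JSONSliceMerge keep existing dst data when merging elements.
	func JSONCopy(dst, src interface{}) error {
		data, err := json.Marshal(src)
		if err != nil {
			return err
		}
		return json.Unmarshal(data, dst)
	}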