Skip to content

Commit

Permalink
Add unit test for validatePackageArchitecture() method (#1957)
Browse files Browse the repository at this point in the history
## Description

- Replaces use of the concrete `*kubernetes.Clientset` implementation
with the `kubernetes.Interface` for client operations for easier
mocking/testing

- Adds unit test for `validatePackageArchitecture()` method

## Related Issue

Relates to #1750

## Type of change

- [x] Other (security config, docs update, etc)

## Checklist before merging

- [x] Test, docs, adr added or updated as needed
- [x] [Contributor Guide
Steps](https://github.com/defenseunicorns/zarf/blob/main/CONTRIBUTING.md#developer-workflow)
followed
  • Loading branch information
lucasrod16 committed Aug 10, 2023
1 parent 9b3c37f commit 8d5d9d8
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 38 deletions.
10 changes: 1 addition & 9 deletions src/internal/api/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/defenseunicorns/zarf/src/internal/api/common"
"github.com/defenseunicorns/zarf/src/internal/cluster"
"github.com/defenseunicorns/zarf/src/pkg/k8s"
"github.com/defenseunicorns/zarf/src/pkg/message"
"github.com/defenseunicorns/zarf/src/types"
"k8s.io/client-go/tools/clientcmd"
Expand All @@ -34,7 +33,7 @@ func Summary(w http.ResponseWriter, _ *http.Request) {
distro, _ = c.Kube.DetectDistro()
state, _ = c.LoadZarfState()
hasZarf = state.Distro != ""
k8sRevision = getServerVersion(c.Kube)
k8sRevision, _ = c.Kube.GetServerVersion()
}

data := types.ClusterSummary{
Expand All @@ -48,10 +47,3 @@ func Summary(w http.ResponseWriter, _ *http.Request) {

common.WriteJSONResponse(w, data, http.StatusOK)
}

// Retrieve and return the k8s revision.
func getServerVersion(k *k8s.K8s) string {
info, _ := k.Clientset.DiscoveryClient.ServerVersion()

return info.String()
}
21 changes: 16 additions & 5 deletions src/internal/cluster/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ var labels = k8s.Labels{
config.ZarfManagedByLabel: "zarf",
}

// NewClusterOrDie creates a new cluster instance and waits up to 30 seconds for the cluster to be ready or throws a fatal error.
// NewClusterOrDie creates a new Cluster instance and waits up to 30 seconds for the cluster to be ready or throws a fatal error.
func NewClusterOrDie() *Cluster {
c, err := NewClusterWithWait(DefaultTimeout, true)
if err != nil {
Expand All @@ -37,7 +37,7 @@ func NewClusterOrDie() *Cluster {
return c
}

// NewClusterWithWait creates a new cluster instance and waits for the given timeout for the cluster to be ready.
// NewClusterWithWait creates a new Cluster instance and waits for the given timeout for the cluster to be ready.
func NewClusterWithWait(timeout time.Duration, withSpinner bool) (*Cluster, error) {
var spinner *message.Spinner
if withSpinner {
Expand Down Expand Up @@ -65,10 +65,21 @@ func NewClusterWithWait(timeout time.Duration, withSpinner bool) (*Cluster, erro
return c, nil
}

// NewCluster creates a new cluster instance without waiting for the cluster to be ready.
// NewCluster creates a new Cluster instance and validates connection to the cluster by fetching the Kubernetes version.
func NewCluster() (*Cluster, error) {
var err error
c := &Cluster{}
var err error

c.Kube, err = k8s.New(message.Debugf, labels)
return c, err
if err != nil {
return nil, err
}

// Dogsled the version output. We just want to ensure no errors were returned to validate cluster connection.
_, err = c.Kube.GetServerVersion()
if err != nil {
return nil, err
}

return c, nil
}
4 changes: 2 additions & 2 deletions src/internal/packager/helm/chart.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,12 @@ func (h *Helm) loadChartData() (*chart.Chart, chartutil.Values, error) {

func (h *Helm) migrateDeprecatedAPIs(latestRelease *release.Release) error {
// Get the Kubernetes version from the current cluster
kubeVersion, err := h.Cluster.Kube.Clientset.ServerVersion()
kubeVersion, err := h.Cluster.Kube.GetServerVersion()
if err != nil {
return err
}

kubeGitVersion, err := semver.NewVersion(kubeVersion.GitVersion)
kubeGitVersion, err := semver.NewVersion(kubeVersion)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions src/pkg/k8s/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ func (k *K8s) WaitForHealthyCluster(timeout time.Duration) error {
expired := time.After(timeout)

for {
// delay check 1 seconds
time.Sleep(1 * time.Second)
select {

// on timeout abort
case <-expired:
return fmt.Errorf("timed out waiting for cluster to report healthy")
Expand Down Expand Up @@ -103,6 +100,9 @@ func (k *K8s) WaitForHealthyCluster(timeout time.Duration) error {

k.Log("No pods reported 'succeeded' or 'running' state yet.")
}

// delay check 1 seconds
time.Sleep(1 * time.Second)
}
}

Expand Down
11 changes: 11 additions & 0 deletions src/pkg/k8s/distro.go → src/pkg/k8s/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package k8s

import (
"errors"
"fmt"
"regexp"
)

Expand Down Expand Up @@ -129,3 +130,13 @@ func (k *K8s) GetArchitecture() (string, error) {

return "", errors.New("could not identify node architecture")
}

// GetServerVersion retrieves and returns the k8s revision.
func (k *K8s) GetServerVersion() (version string, err error) {
versionInfo, err := k.Clientset.Discovery().ServerVersion()
if err != nil {
return "", fmt.Errorf("unable to get Kubernetes version from the cluster : %w", err)
}

return versionInfo.String(), nil
}
1 change: 0 additions & 1 deletion src/pkg/k8s/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@ func (k *K8s) GetNodes() (*corev1.NodeList, error) {
func (k *K8s) GetNode(nodeName string) (*corev1.Node, error) {
return k.Clientset.CoreV1().Nodes().Get(context.TODO(), nodeName, metav1.GetOptions{})
}

2 changes: 1 addition & 1 deletion src/pkg/k8s/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Labels map[string]string

// K8s is a client for interacting with a Kubernetes cluster.
type K8s struct {
Clientset *kubernetes.Clientset
Clientset kubernetes.Interface
RestConfig *rest.Config
Log Log
Labels Labels
Expand Down
28 changes: 15 additions & 13 deletions src/pkg/packager/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,20 +437,22 @@ func (p *Packager) handleIfPartialPkg() error {
}

// validatePackageArchitecture validates that the package architecture matches the target cluster architecture.
func (p *Packager) validatePackageArchitecture() error {
// Ignore this check if the architecture is explicitly "multi"
if p.arch != "multi" {
// Attempt to connect to a cluster to get the architecture.
if cluster, err := cluster.NewCluster(); err == nil {
clusterArch, err := cluster.Kube.GetArchitecture()
if err != nil {
return lang.ErrUnableToCheckArch
}
func (p *Packager) validatePackageArchitecture() (err error) {
// Ignore this check if the package architecture is explicitly "multi"
if p.arch == "multi" {
return nil
}

// Check if the package architecture and the cluster architecture are the same.
if p.arch != clusterArch {
return fmt.Errorf(lang.CmdPackageDeployValidateArchitectureErr, p.arch, clusterArch)
}
// Fetch cluster architecture only if we're already connected to a cluster.
if p.cluster != nil {
clusterArch, err := p.cluster.Kube.GetArchitecture()
if err != nil {
return lang.ErrUnableToCheckArch
}

// Check if the package architecture and the cluster architecture are the same.
if p.arch != clusterArch {
return fmt.Errorf(lang.CmdPackageDeployValidateArchitectureErr, p.arch, clusterArch)
}
}

Expand Down
96 changes: 96 additions & 0 deletions src/pkg/packager/common_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,111 @@
package packager

import (
"errors"
"fmt"
"testing"

"github.com/defenseunicorns/zarf/src/config"
"github.com/defenseunicorns/zarf/src/config/lang"
"github.com/defenseunicorns/zarf/src/internal/cluster"
"github.com/defenseunicorns/zarf/src/pkg/k8s"
"github.com/defenseunicorns/zarf/src/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/fake"
k8sTesting "k8s.io/client-go/testing"
)

// TestValidatePackageArchitecture verifies that Zarf validates package architecture against cluster architecture correctly.
func TestValidatePackageArchitecture(t *testing.T) {
t.Parallel()

type testCase struct {
name string
pkgArch string
clusterArch string
expectedError error
getArchError error
}

testCases := []testCase{
{
name: "architecture match",
pkgArch: "amd64",
clusterArch: "amd64",
expectedError: nil,
},
{
name: "architecture mismatch",
pkgArch: "arm64",
clusterArch: "amd64",
expectedError: fmt.Errorf(lang.CmdPackageDeployValidateArchitectureErr, "arm64", "amd64"),
},
{
name: "ignore validation when package arch equals 'multi'",
pkgArch: "multi",
clusterArch: "not evaluated",
expectedError: nil,
},
{
name: "test the error path when fetching cluster architecture fails",
pkgArch: "amd64",
getArchError: errors.New("error fetching cluster architecture"),
expectedError: lang.ErrUnableToCheckArch,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.name, func(t *testing.T) {
t.Parallel()

mockClient := fake.NewSimpleClientset()
logger := func(string, ...interface{}) {}

// Create a Packager instance with package architecture set and a mock Kubernetes client.
p := &Packager{
arch: testCase.pkgArch,
cluster: &cluster.Cluster{
Kube: &k8s.K8s{
Clientset: mockClient,
Log: logger,
},
},
}

// Set up test data for fetching cluster architecture.
mockClient.Fake.PrependReactor("list", "nodes", func(action k8sTesting.Action) (handled bool, ret runtime.Object, err error) {
// Return an error for cases that test this error path.
if testCase.getArchError != nil {
return true, nil, testCase.getArchError
}

// Create a NodeList object to fetch cluster architecture with the mock client.
nodeList := &v1.NodeList{
Items: []v1.Node{
{
Status: v1.NodeStatus{
NodeInfo: v1.NodeSystemInfo{
Architecture: testCase.clusterArch,
},
},
},
},
}
return true, nodeList, nil
})

err := p.validatePackageArchitecture()

require.Equal(t, testCase.expectedError, err)
})
}
}

// TestValidateLastNonBreakingVersion verifies that Zarf validates the lastNonBreakingVersion of packages against the CLI version correctly.
func TestValidateLastNonBreakingVersion(t *testing.T) {
t.Parallel()
Expand Down
9 changes: 8 additions & 1 deletion src/pkg/packager/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,16 @@ var (
)

// Deploy attempts to deploy the given PackageConfig.
func (p *Packager) Deploy() error {
func (p *Packager) Deploy() (err error) {
message.Debug("packager.Deploy()")

// Attempt to connect to a Kubernetes cluster.
// Not all packages require Kubernetes, so we only want to log a debug message rather than return the error when we can't connect to a cluster.
p.cluster, err = cluster.NewCluster()
if err != nil {
message.Debug(err)
}

if helpers.IsOCIURL(p.cfg.DeployOpts.PackagePath) {
err := p.SetOCIRemote(p.cfg.DeployOpts.PackagePath)
if err != nil {
Expand Down
5 changes: 2 additions & 3 deletions src/pkg/packager/interactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/AlecAivazis/survey/v2"
"github.com/defenseunicorns/zarf/src/config"
"github.com/defenseunicorns/zarf/src/internal/cluster"
"github.com/defenseunicorns/zarf/src/internal/packager/sbom"
"github.com/defenseunicorns/zarf/src/pkg/message"
"github.com/defenseunicorns/zarf/src/pkg/packager/deprecated"
Expand Down Expand Up @@ -57,8 +56,8 @@ func (p *Packager) confirmAction(stage string, sbomViewFiles []string) (confirm
}

// Connect to the cluster (if available) to check the Zarf Agent for breaking changes
if cluster, err := cluster.NewCluster(); err == nil {
if initPackage, err := cluster.GetDeployedPackage("init"); err == nil {
if p.cluster != nil {
if initPackage, err := p.cluster.GetDeployedPackage("init"); err == nil {
// We use the build.version for now because it is the most reliable way to get this version info pre v0.26.0
deprecated.PrintBreakingChanges(initPackage.Data.Build.Version)
}
Expand Down

0 comments on commit 8d5d9d8

Please sign in to comment.