diff --git a/cloud/config/config.go b/cloud/config/config.go index c836f20a..28a45b6e 100644 --- a/cloud/config/config.go +++ b/cloud/config/config.go @@ -27,6 +27,9 @@ import ( "gopkg.in/yaml.v2" ) +// set this way to allow for mocks to be used for testing +var instancePrincipalProviderFunc = auth.InstancePrincipalConfigurationProvider + const ( UseInstancePrincipal = "useInstancePrincipal" Tenancy = "tenancy" @@ -150,7 +153,7 @@ func NewConfigurationProvider(cfg *AuthConfig) (common.ConfigurationProvider, er return nil, errors.New("auth config must not be nil") } if cfg.UseInstancePrincipals { - return auth.InstancePrincipalConfigurationProvider() + return instancePrincipalProviderFunc() } else { return NewConfigurationProviderWithUserPrincipal(cfg) } diff --git a/cloud/config/config_test.go b/cloud/config/config_test.go index ad494b20..9a941b9c 100644 --- a/cloud/config/config_test.go +++ b/cloud/config/config_test.go @@ -17,14 +17,51 @@ limitations under the License. package config import ( + "crypto/rsa" "os" "path/filepath" "reflect" "testing" "github.com/oracle/oci-go-sdk/v65/common" + "github.com/oracle/oci-go-sdk/v65/common/auth" ) +// mockConfigurationProvider implements common.ConfigurationProvider for testing +type mockInstancePrincipalConfigurationProvider struct{} + +func (m *mockInstancePrincipalConfigurationProvider) TenancyOCID() (string, error) { + return "mock-tenancy", nil +} + +func (m *mockInstancePrincipalConfigurationProvider) UserOCID() (string, error) { + return "mock-user", nil +} + +func (m *mockInstancePrincipalConfigurationProvider) KeyID() (string, error) { + return "mock-key-id", nil +} + +func (m *mockInstancePrincipalConfigurationProvider) KeyFingerprint() (string, error) { + return "mock-fingerprint", nil +} + +func (m *mockInstancePrincipalConfigurationProvider) Key() (string, error) { + return "mock-key", nil +} + +func (m *mockInstancePrincipalConfigurationProvider) PrivateRSAKey() (*rsa.PrivateKey, error) { + return nil, nil // Mock implementation +} + +func (m *mockInstancePrincipalConfigurationProvider) Region() (string, error) { + return "us-phoenix-1", nil +} + +func (m *mockInstancePrincipalConfigurationProvider) AuthType() (common.AuthConfig, error) { + return common.AuthConfig{AuthType: "instance_principal"}, nil +} + func TestFromDir(t *testing.T) { type args struct { path string @@ -263,6 +300,8 @@ func TestNewConfigurationProvider(t *testing.T) { name string args args wantErr bool + setup func() + cleanup func() }{ { name: "nil config", @@ -275,10 +314,24 @@ func TestNewConfigurationProvider(t *testing.T) { UseInstancePrincipals: true, }}, wantErr: false, + setup: func() { + instancePrincipalProviderFunc = func() (common.ConfigurationProvider, error) { + return &mockInstancePrincipalConfigurationProvider{}, nil + } + }, + cleanup: func() { + instancePrincipalProviderFunc = auth.InstancePrincipalConfigurationProvider + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + if tt.cleanup != nil { + defer tt.cleanup() + } got, err := NewConfigurationProvider(tt.args.cfg) if (err != nil) != tt.wantErr { t.Errorf("NewConfigurationProvider() error = %v, wantErr %v", err, tt.wantErr)