Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add unit test for local crds client #46

Merged
merged 4 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions pkg/openapiclient/local_crds.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,32 +46,46 @@ func NewLocalCRDFiles(fs fs.FS, dirPath string) openapi.Client {
}

func (k *localCRDsClient) Paths() (map[string]openapi.GroupVersion, error) {
if len(k.dir) == 0 && k.fs == nil {
if len(k.dir) == 0 {
return nil, nil
}
files, err := utils.ReadDir(k.fs, k.dir)
if err != nil {
return nil, fmt.Errorf("error listing %s: %w", k.dir, err)
}
codecs := serializer.NewCodecFactory(apiserver.Scheme).UniversalDecoder()
crds := map[schema.GroupVersion]*spec3.OpenAPI{}
var documents []utils.Document
for _, f := range files {
path := filepath.Join(k.dir, f.Name())
if f.IsDir() {
continue
}

if !utils.IsYamlOrJson(f.Name()) {
continue
}

yamlFile, err := utils.ReadFile(k.fs, path)
if err != nil {
return nil, fmt.Errorf("failed to read %s: %w", path, err)
if utils.IsJson(f.Name()) {
fileBytes, err := utils.ReadFile(k.fs, path)
if err != nil {
return nil, fmt.Errorf("failed to read %s: %w", path, err)
}
documents = append(documents, fileBytes)
} else if utils.IsYaml(f.Name()) {
fileBytes, err := utils.ReadFile(k.fs, path)
if err != nil {
return nil, fmt.Errorf("failed to read %s: %w", path, err)
}
yamlDocs, err := utils.SplitYamlDocuments(fileBytes)
if err != nil {
return nil, fmt.Errorf("failed to read %s: %w", path, err)
}
for _, document := range yamlDocs {
if !utils.IsEmptyYamlDocument(document) {
documents = append(documents, document)
}
}
}

}
codecs := serializer.NewCodecFactory(apiserver.Scheme).UniversalDecoder()
crds := map[schema.GroupVersion]*spec3.OpenAPI{}
for _, document := range documents {
crdObj, _, err := codecs.Decode(
yamlFile,
document,
&schema.GroupVersionKind{
Group: "apiextensions.k8s.io",
Version: runtime.APIVersionInternal,
Expand All @@ -80,24 +94,20 @@ func (k *localCRDsClient) Paths() (map[string]openapi.GroupVersion, error) {
if err != nil {
return nil, err
}

crd, ok := crdObj.(*apiextensions.CustomResourceDefinition)
if !ok {
return nil, fmt.Errorf("crd deserialized into incorrect type: %T", crdObj)
}

for _, v := range crd.Spec.Versions {
// Convert schema to spec.Schema
jsProps, err := apiextensions.GetSchemaForVersion(crd, v.Name)
if err != nil {
return nil, err
}

ss, err := structuralschema.NewStructural(jsProps.OpenAPIV3Schema)
if err != nil {
return nil, err
}

sch := ss.ToKubeOpenAPI()
gvk := schema.GroupVersionKind{
Group: crd.Spec.Group,
Expand All @@ -108,15 +118,14 @@ func (k *localCRDsClient) Paths() (map[string]openapi.GroupVersion, error) {
if err != nil {
return nil, err
}

gvr := gvk.GroupVersion().WithResource(crd.Spec.Names.Plural)
sch.AddExtension("x-kubernetes-group-version-kind", []interface{}{gvkObj})

// Add schema extension to propagate the scope
sch.AddExtension("x-kubectl-validate-scope", string(crd.Spec.Scope))
key := fmt.Sprintf("%s/%s.%s", gvk.Group, gvk.Version, gvk.Kind)
if existing, exists := crds[gvr.GroupVersion()]; exists {
if existing, exists := crds[gvk.GroupVersion()]; exists {
existing.Components.Schemas[key] = sch
} else {
crds[gvr.GroupVersion()] = &spec3.OpenAPI{
crds[gvk.GroupVersion()] = &spec3.OpenAPI{
Components: &spec3.Components{
Schemas: map[string]*spec.Schema{
key: sch,
Expand All @@ -126,7 +135,6 @@ func (k *localCRDsClient) Paths() (map[string]openapi.GroupVersion, error) {
}
}
}

res := map[string]openapi.GroupVersion{}
for k, v := range crds {
res[fmt.Sprintf("apis/%s/%s", k.Group, k.Version)] = inmemoryGroupVersion{v}
Expand Down
140 changes: 140 additions & 0 deletions pkg/openapiclient/local_crds_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package openapiclient

import (
"io/fs"
"os"
"reflect"
"testing"

"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/client-go/openapi"
)

func TestNewLocalCRDFiles(t *testing.T) {
tests := []struct {
name string
fs fs.FS
dirPath string
want openapi.Client
}{{
name: "fs nil and dir empty",
want: &localCRDsClient{},
}, {
name: "only dir",
dirPath: "test",
want: &localCRDsClient{
dir: "test",
},
}, {
name: "only fs",
fs: os.DirFS("."),
want: &localCRDsClient{
fs: os.DirFS("."),
},
}, {
name: "both fs and dir",
fs: os.DirFS("."),
dirPath: "test",
want: &localCRDsClient{
fs: os.DirFS("."),
dir: "test",
},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewLocalCRDFiles(tt.fs, tt.dirPath); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewLocalCRDFiles() = %v, want %v", got, tt.want)
}
})
}
}

func Test_localCRDsClient_Paths(t *testing.T) {
tests := []struct {
name string
fs fs.FS
dir string
want map[string]sets.Set[string]
wantErr bool
}{{
name: "fs nil and dir empty",
}, {
name: "only dir",
dir: "../../testcases/crds",
want: map[string]sets.Set[string]{
"apis/batch.x-k8s.io/v1alpha1": sets.New(
"batch.x-k8s.io/v1alpha1.JobSet",
),
"apis/stable.example.com/v1": sets.New(
"stable.example.com/v1.CELBasic",
),
"apis/acme.cert-manager.io/v1": sets.New(
"acme.cert-manager.io/v1.Challenge",
"acme.cert-manager.io/v1.Order",
),
"apis/cert-manager.io/v1": sets.New(
"cert-manager.io/v1.Certificate",
"cert-manager.io/v1.CertificateRequest",
"cert-manager.io/v1.ClusterIssuer",
"cert-manager.io/v1.Issuer",
),
},
}, {
name: "only fs",
fs: os.DirFS("../../testcases/crds"),
}, {
name: "both fs and dir",
fs: os.DirFS("../../testcases"),
dir: "crds",
want: map[string]sets.Set[string]{
"apis/batch.x-k8s.io/v1alpha1": sets.New(
"batch.x-k8s.io/v1alpha1.JobSet",
),
"apis/stable.example.com/v1": sets.New(
"stable.example.com/v1.CELBasic",
),
"apis/acme.cert-manager.io/v1": sets.New(
"acme.cert-manager.io/v1.Challenge",
"acme.cert-manager.io/v1.Order",
),
"apis/cert-manager.io/v1": sets.New(
"cert-manager.io/v1.Certificate",
"cert-manager.io/v1.CertificateRequest",
"cert-manager.io/v1.ClusterIssuer",
"cert-manager.io/v1.Issuer",
),
},
}, {
name: "invalid dir",
dir: "invalid",
wantErr: true,
}, {
name: "invalid fs",
fs: os.DirFS("../../invalid"),
dir: ".",
wantErr: true,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := NewLocalCRDFiles(tt.fs, tt.dir)
paths, err := k.Paths()
if (err != nil) != tt.wantErr {
t.Errorf("localCRDsClient.Paths() error = %v, wantErr %v", err, tt.wantErr)
return
}
var got map[string]sets.Set[string]
if paths != nil {
got = map[string]sets.Set[string]{}
for key, value := range paths {
got[key] = sets.New[string]()
for component := range value.(inmemoryGroupVersion).Components.Schemas {
got[key] = got[key].Insert(component)
}
}
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("localCRDsClient.Paths() = %v, want %v", got, tt.want)
}
})
}
}
9 changes: 8 additions & 1 deletion pkg/validatorfactory/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,14 @@ func (s *ValidatorFactory) ValidatorsForGVK(gvk schema.GroupVersionKind) (*Valid
continue
}

val := newValidatorEntry(nam, namespaced.Has(gvk), def, ssf)
// Try to infer the scope from paths
nsScoped := namespaced.Has(gvk)
// Check schema extensions to see if the scope was manually added
if scope, ok := def.Extensions.GetString("x-kubectl-validate-scope"); ok {
nsScoped = strings.EqualFold(scope, string(apiextensions.NamespaceScoped))
}

val := newValidatorEntry(nam, nsScoped, def, ssf)

for _, specGVK := range gvks {
s.validatorCache[specGVK] = val
Expand Down
Loading