diff --git a/specs/destination.go b/specs/destination.go index 8d5abc7d80..8682e3685b 100644 --- a/specs/destination.go +++ b/specs/destination.go @@ -30,9 +30,6 @@ func (d *Destination) SetDefaults() { if d.Path == "" { d.Path = d.Name } - if d.Version == "" { - d.Version = "latest" - } if d.Registry == RegistryGithub && !strings.Contains(d.Path, "/") { d.Path = "cloudquery/" + d.Path } @@ -43,10 +40,23 @@ func (d *Destination) UnmarshalSpec(out interface{}) error { if err != nil { return err } - dec := json.NewDecoder(nil) + dec := json.NewDecoder(bytes.NewReader(b)) dec.UseNumber() dec.DisallowUnknownFields() - return json.Unmarshal(b, out) + return dec.Decode(out) +} + +func (d *Destination) Validate() error { + if d.Name == "" { + return fmt.Errorf("name is required") + } + if d.Version == "" { + return fmt.Errorf("version is required") + } + if !strings.HasPrefix(d.Version, "v") { + return fmt.Errorf("version must start with v") + } + return nil } func (m WriteMode) String() string { diff --git a/specs/destination_test.go b/specs/destination_test.go index 1729dd317f..ce11f45b74 100644 --- a/specs/destination_test.go +++ b/specs/destination_test.go @@ -2,6 +2,8 @@ package specs import ( "testing" + + "github.com/google/go-cmp/cmp" ) type testDestinationSpec struct { @@ -24,23 +26,7 @@ func TestWriteModeFromString(t *testing.T) { } } -func TestDestinationSetDefaults(t *testing.T) { - destination := Destination{ - Name: "testDestination", - } - destination.SetDefaults() - if destination.Registry != RegistryGithub { - t.Fatalf("expected RegistryGithub, got %v", destination.Registry) - } - if destination.Path != "cloudquery/testDestination" { - t.Fatalf("expected destination.Path (%s), got %s", destination.Name, destination.Path) - } - if destination.Version != "latest" { - t.Fatalf("expected latest, got %s", destination.Version) - } -} - -func TestDestinationUnmarshalSpec(t *testing.T) { +func TestDestinationSpecUnmarshalSpec(t *testing.T) { destination := Destination{ Spec: map[string]interface{}{ "connection_string": "postgres://user:pass@host:port/db", @@ -54,3 +40,137 @@ func TestDestinationUnmarshalSpec(t *testing.T) { t.Fatalf("expected postgres://user:pass@host:port/db, got %s", spec.ConnectionString) } } + +var destinationUnmarshalSpecTestCases = []struct { + name string + spec string + err string + source *Source +}{ + { + "invalid_kind", + `kind: nice`, + "failed to decode spec: unknown kind nice", + nil, + }, + { + "invalid_type", + `kind: source +spec: + name: 3 +`, + "failed to decode spec: json: cannot unmarshal number into Go struct field Source.name of type string", + &Source{ + Name: "test", + Tables: []string{"*"}, + }, + }, + { + "unknown_field", + `kind: source +spec: + namea: 3 +`, + `failed to decode spec: json: unknown field "namea"`, + &Source{ + Name: "test", + Tables: []string{"*"}, + }, + }, +} + +func TestDestinationUnmarshalSpec(t *testing.T) { + for _, tc := range destinationUnmarshalSpecTestCases { + t.Run(tc.name, func(t *testing.T) { + var err error + var spec Spec + err = SpecUnmarshalYamlStrict([]byte(tc.spec), &spec) + if err != nil { + if err.Error() != tc.err { + t.Fatalf("expected:%s got:%s", tc.err, err.Error()) + } + return + } + + source := spec.Spec.(*Source) + if cmp.Diff(source, tc.source) != "" { + t.Fatalf("expected:%v got:%v", tc.source, source) + } + }) + } +} + +var destinationUnmarshalSpecValidateTestCases = []struct { + name string + spec string + err string + destination *Destination +}{ + { + "required_name", + `kind: destination +spec:`, + "name is required", + nil, + }, + { + "required_version", + `kind: destination +spec: + name: test +`, + "version is required", + nil, + }, + { + "required_version_format", + `kind: destination +spec: + name: test + version: 1.1.0 +`, + "version must start with v", + nil, + }, + { + "success", + `kind: destination +spec: + name: test + version: v1.1.0 +`, + "", + &Destination{ + Name: "test", + Registry: RegistryGithub, + Path: "cloudquery/test", + Version: "v1.1.0", + }, + }, +} + +func TestDestinationUnmarshalSpecValidate(t *testing.T) { + for _, tc := range destinationUnmarshalSpecValidateTestCases { + t.Run(tc.name, func(t *testing.T) { + var err error + var spec Spec + err = SpecUnmarshalYamlStrict([]byte(tc.spec), &spec) + if err != nil { + t.Fatal(err) + } + destination := spec.Spec.(*Destination) + destination.SetDefaults() + err = destination.Validate() + if err != nil { + if err.Error() != tc.err { + t.Fatalf("expected:%s got:%s", tc.err, err.Error()) + } + return + } + + if cmp.Diff(destination, tc.destination) != "" { + t.Fatalf("expected:%v got:%v", tc.destination, destination) + } + }) + } +} diff --git a/specs/source.go b/specs/source.go index b89c456a6a..c8373aa308 100644 --- a/specs/source.go +++ b/specs/source.go @@ -1,10 +1,10 @@ package specs import ( + "bytes" "encoding/json" + "fmt" "strings" - - "github.com/xeipuuv/gojsonschema" ) // Source is the spec for a source plugin @@ -40,12 +40,12 @@ func (s *Source) SetDefaults() { if s.Path == "" { s.Path = s.Name } - if s.Version == "" { - s.Version = "latest" - } if s.Registry == RegistryGithub && !strings.Contains(s.Path, "/") { s.Path = "cloudquery/" + s.Path } + if s.Tables == nil { + s.Tables = []string{"*"} + } } // UnmarshalSpec unmarshals the internal spec into the given interface @@ -54,12 +54,24 @@ func (s *Source) UnmarshalSpec(out interface{}) error { if err != nil { return err } - dec := json.NewDecoder(nil) + dec := json.NewDecoder(bytes.NewReader(b)) dec.UseNumber() dec.DisallowUnknownFields() - return json.Unmarshal(b, out) + return dec.Decode(out) } -func (*Source) Validate() (*gojsonschema.Result, error) { - return nil, nil +func (s *Source) Validate() error { + if s.Name == "" { + return fmt.Errorf("name is required") + } + if s.Version == "" { + return fmt.Errorf("version is required") + } + if !strings.HasPrefix(s.Version, "v") { + return fmt.Errorf("version must start with v") + } + if len(s.Destinations) == 0 { + return fmt.Errorf("at least one destination is required") + } + return nil } diff --git a/specs/source_test.go b/specs/source_test.go index bd178b2c01..d065068d76 100644 --- a/specs/source_test.go +++ b/specs/source_test.go @@ -1,44 +1,154 @@ package specs -import "testing" +import ( + "testing" -type testSourceSpec struct { - Accounts []string `json:"accounts"` + "github.com/google/go-cmp/cmp" +) + +var sourceUnmarshalSpecTestCases = []struct { + name string + spec string + err string + source *Source +}{ + { + "invalid_kind", + `kind: nice`, + "failed to decode spec: unknown kind nice", + nil, + }, + { + "invalid_type", + `kind: source +spec: + name: 3 +`, + "failed to decode spec: json: cannot unmarshal number into Go struct field Source.name of type string", + &Source{ + Name: "test", + Tables: []string{"*"}, + }, + }, + { + "unknown_field", + `kind: source +spec: + namea: 3 +`, + `failed to decode spec: json: unknown field "namea"`, + &Source{ + Name: "test", + Tables: []string{"*"}, + }, + }, } -func TestSourceSetDefaults(t *testing.T) { - source := Source{ - Name: "testSource", - } - source.SetDefaults() - if source.Registry != RegistryGithub { - t.Fatalf("expected RegistryGithub, got %v", source.Registry) - } - if source.Path != "cloudquery/testSource" { - t.Fatalf("expected source.Path (%s), got %s", source.Name, source.Path) - } - if source.Version != "latest" { - t.Fatalf("expected latest, got %s", source.Version) +func TestSourceUnmarshalSpec(t *testing.T) { + for _, tc := range sourceUnmarshalSpecTestCases { + t.Run(tc.name, func(t *testing.T) { + var err error + var spec Spec + err = SpecUnmarshalYamlStrict([]byte(tc.spec), &spec) + if err != nil { + if err.Error() != tc.err { + t.Fatalf("expected:%s got:%s", tc.err, err.Error()) + } + return + } + + source := spec.Spec.(*Source) + if cmp.Diff(source, tc.source) != "" { + t.Fatalf("expected:%v got:%v", tc.source, source) + } + }) } } -func TestSourceUnmarshalSpec(t *testing.T) { - source := Source{ - Spec: map[string]interface{}{ - "accounts": []string{"test_account1", "test_account2"}, +var sourceUnmarshalSpecValidateTestCases = []struct { + name string + spec string + err string + source *Source +}{ + { + "required_name", + `kind: source +spec:`, + "name is required", + nil, + }, + { + "required_version", + `kind: source +spec: + name: test +`, + "version is required", + nil, + }, + { + "required_version_format", + `kind: source +spec: + name: test + version: 1.1.0 +`, + "version must start with v", + nil, + }, + { + "destination_required", + `kind: source +spec: + name: test + version: v1.1.0 +`, + "at least one destination is required", + nil, + }, + { + "success", + `kind: source +spec: + name: test + version: v1.1.0 + destinations: ["test"] +`, + "", + &Source{ + Name: "test", + Registry: RegistryGithub, + Path: "cloudquery/test", + Version: "v1.1.0", + Tables: []string{"*"}, + Destinations: []string{"test"}, }, - } - var spec testSourceSpec - if err := source.UnmarshalSpec(&spec); err != nil { - t.Fatal(err) - } - if len(spec.Accounts) != 2 { - t.Fatalf("expected 2 accounts, got %d", len(spec.Accounts)) - } - if spec.Accounts[0] != "test_account1" { - t.Fatalf("expected test_account1, got %s", spec.Accounts[0]) - } - if spec.Accounts[1] != "test_account2" { - t.Fatalf("expected test_account2, got %s", spec.Accounts[1]) + }, +} + +func TestSourceUnmarshalSpecValidate(t *testing.T) { + for _, tc := range sourceUnmarshalSpecValidateTestCases { + t.Run(tc.name, func(t *testing.T) { + var err error + var spec Spec + err = SpecUnmarshalYamlStrict([]byte(tc.spec), &spec) + if err != nil { + t.Fatal(err) + } + source := spec.Spec.(*Source) + source.SetDefaults() + err = source.Validate() + if err != nil { + if err.Error() != tc.err { + t.Fatalf("expected:%s got:%s", tc.err, err.Error()) + } + return + } + + if cmp.Diff(source, tc.source) != "" { + t.Fatalf("expected:%v got:%v", tc.source, source) + } + }) } } diff --git a/specs/spec.go b/specs/spec.go index 24928ef8ea..ea358de13d 100644 --- a/specs/spec.go +++ b/specs/spec.go @@ -58,8 +58,10 @@ func (s *Spec) UnmarshalJSON(data []byte) error { Kind Kind `json:"kind"` Spec interface{} `json:"spec"` } - - if err := json.Unmarshal(data, &t); err != nil { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.DisallowUnknownFields() + dec.UseNumber() + if err := dec.Decode(&t); err != nil { return err } s.Kind = t.Kind @@ -75,7 +77,10 @@ func (s *Spec) UnmarshalJSON(data []byte) error { if err != nil { return err } - return json.Unmarshal(b, s.Spec) + dec = json.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + dec.DisallowUnknownFields() + return dec.Decode(s.Spec) } func UnmarshalJSONStrict(b []byte, out interface{}) error { @@ -94,15 +99,7 @@ func SpecUnmarshalYamlStrict(b []byte, spec *Spec) error { dec.DisallowUnknownFields() dec.UseNumber() if err := dec.Decode(spec); err != nil { - return fmt.Errorf("failed to decode json: %w", err) - } - switch spec.Kind { - case KindSource: - spec.Spec.(*Source).SetDefaults() - case KindDestination: - spec.Spec.(*Destination).SetDefaults() - default: - return fmt.Errorf("unknown kind %s", spec.Kind) + return fmt.Errorf("failed to decode spec: %w", err) } return nil } diff --git a/specs/spec_reader_test.go b/specs/spec_reader_test.go index 9fbe1e413f..f6689b4c19 100644 --- a/specs/spec_reader_test.go +++ b/specs/spec_reader_test.go @@ -2,46 +2,17 @@ package specs import ( "testing" - - "github.com/stretchr/testify/require" ) -var sources = map[string]Source{ - "aws.yml": { - Name: "aws", - Path: "aws", - Version: "v1.0.0", - Concurrency: 10, - Registry: RegistryLocal, - }, -} - -var destinations = map[string]Destination{ - "postgresql.yml": { - Name: "postgresql", - Path: "postgresql", - Version: "v1.0.0", - Registry: RegistryGrpc, - WriteMode: WriteModeOverwrite, - }, -} - func TestLoadSpecs(t *testing.T) { - specReader, err := NewSpecReader("testdata/valid") + specReader, err := NewSpecReader("testdata") if err != nil { t.Fatal(err) } - - require.Equal(t, sources, specReader.sources) - require.Equal(t, destinations, specReader.destinations) -} - -func TestWrongKind(t *testing.T) { - _, err := NewSpecReader("testdata/wrong_source") - require.Equal(t, err.Error(), "failed to unmarshal file invalid.yml: failed to decode json: unknown kind test") -} - -func TestNoSpecs(t *testing.T) { - _, err := NewSpecReader("testdata") - require.Equal(t, err.Error(), "no valid config files found in directory testdata") + if len(specReader.sources) != 1 { + t.Fatalf("got: %d expected: 1", len(specReader.sources)) + } + if len(specReader.destinations) != 1 { + t.Fatalf("got: %d expected: 1", len(specReader.destinations)) + } } diff --git a/specs/spec_test.go b/specs/spec_test.go deleted file mode 100644 index 83611434ce..0000000000 --- a/specs/spec_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package specs - -import ( - "os" - "reflect" - "testing" -) - -var testSpecs = map[string]Spec{ - "testdata/valid/postgresql.yml": { - Kind: KindDestination, - Spec: &Destination{ - Name: "postgresql", - Path: "postgresql", - Version: "v1.0.0", - Registry: RegistryGrpc, - WriteMode: WriteModeOverwrite, - }, - }, - "testdata/valid/aws.yml": { - Kind: KindSource, - Spec: &Source{ - Name: "aws", - Path: "aws", - Version: "v1.0.0", - Concurrency: 10, - Registry: RegistryLocal, - }, - }, -} - -func TestSpecYamlMarshal(t *testing.T) { - for fileName, expectedSpec := range testSpecs { - t.Run(fileName, func(t *testing.T) { - b, err := os.ReadFile(fileName) - if err != nil { - t.Fatal(err) - } - - var spec Spec - if err := SpecUnmarshalYamlStrict(b, &spec); err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(spec, expectedSpec) { - t.Errorf("expected spec %s to be:\n%+v\nbut got:\n%+v", fileName, expectedSpec.Spec, spec.Spec) - } - }) - } -} diff --git a/specs/testdata/valid/aws.yml b/specs/testdata/aws.yml similarity index 100% rename from specs/testdata/valid/aws.yml rename to specs/testdata/aws.yml diff --git a/specs/testdata/valid/postgresql.yml b/specs/testdata/postgresql.yml similarity index 100% rename from specs/testdata/valid/postgresql.yml rename to specs/testdata/postgresql.yml diff --git a/specs/testdata/wrong_source/invalid.yml b/specs/testdata/wrong_source/invalid.yml deleted file mode 100644 index 787d446c31..0000000000 --- a/specs/testdata/wrong_source/invalid.yml +++ /dev/null @@ -1,6 +0,0 @@ -kind: test -spec: - name: test - version: v1.0.0 - registry: grpc - write_mode: overwrite