diff --git a/.github/workflows/lint_markdown.yml b/.github/workflows/lint_markdown.yml index cecd2c8bdc..aabcd9530d 100644 --- a/.github/workflows/lint_markdown.yml +++ b/.github/workflows/lint_markdown.yml @@ -16,7 +16,7 @@ jobs: - name: Vale uses: errata-ai/vale-action@v2 with: - vale_flags: "--glob=!{plugins/source/testdata/*,CHANGELOG.md,.github/styles/proselint/README.md}" + vale_flags: "--glob=!{docs/testdata/*,CHANGELOG.md,.github/styles/proselint/README.md}" filter_mode: nofilter env: GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} @@ -31,4 +31,4 @@ jobs: with: files: . config_file: .markdownlint.yaml - ignore_files: "{plugins/source/testdata/*,CHANGELOG.md}" + ignore_files: "{docs/testdata/*,CHANGELOG.md}" diff --git a/.gitignore b/.gitignore index d15ff8fe72..2d30804e99 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,6 @@ config.hcl vendor cover.out .delta.* -bench.json \ No newline at end of file +bench.json +serve/^TestPluginDocs$/ +cover diff --git a/backend/backend.go b/backend/backend.go deleted file mode 100644 index fc4e639233..0000000000 --- a/backend/backend.go +++ /dev/null @@ -1,12 +0,0 @@ -package backend - -import "context" - -type Backend interface { - // Set sets the value for the given table and client id. - Set(ctx context.Context, table, clientID, value string) error - // Get returns the value for the given table and client id. - Get(ctx context.Context, table, clientID string) (string, error) - // Close closes the backend. - Close(ctx context.Context) error -} diff --git a/buf.yaml b/buf.yaml deleted file mode 100644 index b348cd312c..0000000000 --- a/buf.yaml +++ /dev/null @@ -1,12 +0,0 @@ -version: v1 -breaking: - use: - - FILE -lint: - use: - - BASIC - ignore: - # We are ignoring those as this is an old version and we are not doing any changes here anymore - - cloudquery/destination/v0/destination.proto - - cloudquery/source/v0/source.proto - - cloudquery/base/v0/base.proto diff --git a/docs/generator.go b/docs/generator.go new file mode 100644 index 0000000000..62dba4f67b --- /dev/null +++ b/docs/generator.go @@ -0,0 +1,137 @@ +package docs + +import ( + "embed" + "fmt" + "os" + "regexp" + "sort" + + "github.com/cloudquery/plugin-sdk/v4/caser" + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +//go:embed templates/*.go.tpl +var templatesFS embed.FS + +var reMatchNewlines = regexp.MustCompile(`\n{3,}`) +var reMatchHeaders = regexp.MustCompile(`(#{1,6}.+)\n+`) + +var DefaultTitleExceptions = map[string]string{ + // common abbreviations + "acl": "ACL", + "acls": "ACLs", + "api": "API", + "apis": "APIs", + "ca": "CA", + "cidr": "CIDR", + "cidrs": "CIDRs", + "db": "DB", + "dbs": "DBs", + "dhcp": "DHCP", + "iam": "IAM", + "iot": "IOT", + "ip": "IP", + "ips": "IPs", + "ipv4": "IPv4", + "ipv6": "IPv6", + "mfa": "MFA", + "ml": "ML", + "oauth": "OAuth", + "vpc": "VPC", + "vpcs": "VPCs", + "vpn": "VPN", + "vpns": "VPNs", + "waf": "WAF", + "wafs": "WAFs", + + // cloud providers + "aws": "AWS", + "gcp": "GCP", +} + +type Format int + +const ( + FormatMarkdown Format = iota + FormatJSON +) + +func (r Format) String() string { + return [...]string{"markdown", "json"}[r] +} + +func FormatFromString(s string) (Format, error) { + switch s { + case "markdown": + return FormatMarkdown, nil + case "json": + return FormatJSON, nil + default: + return FormatMarkdown, fmt.Errorf("unknown format %s", s) + } +} + +type Generator struct { + tables schema.Tables + titleTransformer func(*schema.Table) string + pluginName string +} + +func DefaultTitleTransformer(table *schema.Table) 
string { + if table.Title != "" { + return table.Title + } + csr := caser.New(caser.WithCustomExceptions(DefaultTitleExceptions)) + return csr.ToTitle(table.Name) +} + +func sortTables(tables schema.Tables) { + sort.SliceStable(tables, func(i, j int) bool { + return tables[i].Name < tables[j].Name + }) + + for _, table := range tables { + sortTables(table.Relations) + } +} + +// NewGenerator creates a new generator for the given tables. +// The tables are sorted by name. pluginName is optional and is used in markdown only +func NewGenerator(pluginName string, tables schema.Tables) *Generator { + sortedTables := make(schema.Tables, 0, len(tables)) + for _, t := range tables { + sortedTables = append(sortedTables, t.Copy(nil)) + } + sortTables(sortedTables) + + return &Generator{ + tables: sortedTables, + titleTransformer: DefaultTitleTransformer, + pluginName: pluginName, + } +} + +func (g *Generator) Generate(dir string, format Format) error { + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + return err + } + + switch format { + case FormatMarkdown: + return g.renderTablesAsMarkdown(dir) + case FormatJSON: + return g.renderTablesAsJSON(dir) + default: + return fmt.Errorf("unsupported format: %v", format) + } +} + +// setDestinationManagedCqColumns overwrites or adds the CQ columns that are managed by the destination plugins (_cq_sync_time, _cq_source_name). +// func setDestinationManagedCqColumns(tables []*schema.Table) { +// for _, table := range tables { +// table.OverwriteOrAddColumn(&schema.CqSyncTimeColumn) +// table.OverwriteOrAddColumn(&schema.CqSourceNameColumn) +// setDestinationManagedCqColumns(table.Relations) +// } +// } diff --git a/plugins/source/docs_test.go b/docs/generator_test.go similarity index 92% rename from plugins/source/docs_test.go rename to docs/generator_test.go index 30d34814d3..22d4001719 100644 --- a/plugins/source/docs_test.go +++ b/docs/generator_test.go @@ -1,6 +1,6 @@ //go:build !windows -package source +package docs import ( "os" @@ -9,8 +9,8 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/bradleyjkemp/cupaloy/v2" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/cloudquery/plugin-sdk/v4/types" "github.com/stretchr/testify/require" ) @@ -120,14 +120,13 @@ var testTables = []*schema.Table{ } func TestGeneratePluginDocs(t *testing.T) { - p := NewPlugin("test", "v1.0.0", testTables, newTestExecutionClient) - + g := NewGenerator("test", testTables) cup := cupaloy.New(cupaloy.SnapshotSubdirectory("testdata")) t.Run("Markdown", func(t *testing.T) { tmpdir := t.TempDir() - err := p.GeneratePluginDocs(tmpdir, "markdown") + err := g.Generate(tmpdir, FormatMarkdown) if err != nil { t.Fatalf("unexpected error calling GeneratePluginDocs: %v", err) } @@ -146,7 +145,7 @@ func TestGeneratePluginDocs(t *testing.T) { t.Run("JSON", func(t *testing.T) { tmpdir := t.TempDir() - err := p.GeneratePluginDocs(tmpdir, "json") + err := g.Generate(tmpdir, FormatJSON) if err != nil { t.Fatalf("unexpected error calling GeneratePluginDocs: %v", err) } diff --git a/docs/json.go b/docs/json.go new file mode 100644 index 0000000000..8972a86b8c --- /dev/null +++ b/docs/json.go @@ -0,0 +1,62 @@ +package docs + +import ( + "bytes" + "encoding/json" + "os" + "path/filepath" + + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +type jsonTable struct { + Name string `json:"name"` + Title string `json:"title"` + Description string `json:"description"` + Columns 
[]jsonColumn `json:"columns"` + Relations []jsonTable `json:"relations"` +} + +type jsonColumn struct { + Name string `json:"name"` + Type string `json:"type"` + IsPrimaryKey bool `json:"is_primary_key,omitempty"` + IsIncrementalKey bool `json:"is_incremental_key,omitempty"` +} + +func (g *Generator) renderTablesAsJSON(dir string) error { + jsonTables := g.jsonifyTables(g.tables) + buffer := &bytes.Buffer{} + m := json.NewEncoder(buffer) + m.SetIndent("", " ") + m.SetEscapeHTML(false) + err := m.Encode(jsonTables) + if err != nil { + return err + } + outputPath := filepath.Join(dir, "__tables.json") + return os.WriteFile(outputPath, buffer.Bytes(), 0644) +} + +func (g *Generator) jsonifyTables(tables schema.Tables) []jsonTable { + jsonTables := make([]jsonTable, len(tables)) + for i, table := range tables { + jsonColumns := make([]jsonColumn, len(table.Columns)) + for c, col := range table.Columns { + jsonColumns[c] = jsonColumn{ + Name: col.Name, + Type: col.Type.String(), + IsPrimaryKey: col.PrimaryKey, + IsIncrementalKey: col.IncrementalKey, + } + } + jsonTables[i] = jsonTable{ + Name: table.Name, + Title: g.titleTransformer(table), + Description: table.Description, + Columns: jsonColumns, + Relations: g.jsonifyTables(table.Relations), + } + } + return jsonTables +} diff --git a/docs/markdown.go b/docs/markdown.go new file mode 100644 index 0000000000..6f8fe9dcaa --- /dev/null +++ b/docs/markdown.go @@ -0,0 +1,94 @@ +package docs + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "text/template" + + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +type templateData struct { + PluginName string + Tables schema.Tables +} + +func (g *Generator) renderTablesAsMarkdown(dir string) error { + for _, table := range g.tables { + if err := g.renderAllTables(dir, table); err != nil { + return err + } + } + t, err := template.New("all_tables.md.go.tpl").Funcs(template.FuncMap{ + "indentToDepth": indentToDepth, + }).ParseFS(templatesFS, "templates/all_tables*.md.go.tpl") + if err != nil { + return fmt.Errorf("failed to parse template for README.md: %v", err) + } + + var b bytes.Buffer + if err := t.Execute(&b, templateData{PluginName: g.pluginName, Tables: g.tables}); err != nil { + return fmt.Errorf("failed to execute template: %v", err) + } + content := formatMarkdown(b.String()) + outputPath := filepath.Join(dir, "README.md") + f, err := os.Create(outputPath) + if err != nil { + return fmt.Errorf("failed to create file %v: %v", outputPath, err) + } + f.WriteString(content) + return nil +} + +func (g *Generator) renderAllTables(dir string, t *schema.Table) error { + if err := g.renderTable(dir, t); err != nil { + return err + } + for _, r := range t.Relations { + if err := g.renderAllTables(dir, r); err != nil { + return err + } + } + return nil +} + +func (g *Generator) renderTable(dir string, table *schema.Table) error { + t := template.New("").Funcs(map[string]any{ + "title": g.titleTransformer, + }) + t, err := t.New("table.md.go.tpl").ParseFS(templatesFS, "templates/table.md.go.tpl") + if err != nil { + return fmt.Errorf("failed to parse template: %v", err) + } + + outputPath := filepath.Join(dir, fmt.Sprintf("%s.md", table.Name)) + + var b bytes.Buffer + if err := t.Execute(&b, table); err != nil { + return fmt.Errorf("failed to execute template: %v", err) + } + content := formatMarkdown(b.String()) + f, err := os.Create(outputPath) + if err != nil { + return fmt.Errorf("failed to create file %v: %v", outputPath, err) + } + f.WriteString(content) + return f.Close() +} + +func 
formatMarkdown(s string) string { + s = reMatchNewlines.ReplaceAllString(s, "\n\n") + return reMatchHeaders.ReplaceAllString(s, `$1`+"\n\n") +} + +func indentToDepth(table *schema.Table) string { + s := "" + t := table + for t.Parent != nil { + s += " " + t = t.Parent + } + return s +} diff --git a/plugins/source/templates/all_tables.md.go.tpl b/docs/templates/all_tables.md.go.tpl similarity index 100% rename from plugins/source/templates/all_tables.md.go.tpl rename to docs/templates/all_tables.md.go.tpl diff --git a/plugins/source/templates/all_tables_entry.md.go.tpl b/docs/templates/all_tables_entry.md.go.tpl similarity index 100% rename from plugins/source/templates/all_tables_entry.md.go.tpl rename to docs/templates/all_tables_entry.md.go.tpl diff --git a/plugins/source/templates/table.md.go.tpl b/docs/templates/table.md.go.tpl similarity index 95% rename from plugins/source/templates/table.md.go.tpl rename to docs/templates/table.md.go.tpl index 202d343e39..21a8ed135e 100644 --- a/plugins/source/templates/table.md.go.tpl +++ b/docs/templates/table.md.go.tpl @@ -40,5 +40,5 @@ The following tables depend on {{.Name}}: | Name | Type | | ------------- | ------------- | {{- range $column := $.Columns }} -|{{$column.Name}}{{if $column.PrimaryKey}} (PK){{end}}{{if $column.IncrementalKey}} (Incremental Key){{end}}|`{{$column.Type}}`| +|{{$column.Name}}{{if $column.PrimaryKey}} (PK){{end}}{{if $column.IncrementalKey}} (Incremental Key){{end}}|{{$column.Type}}| {{- end }} \ No newline at end of file diff --git a/plugins/source/testdata/TestGeneratePluginDocs-JSON-__tables.json b/docs/testdata/TestGeneratePluginDocs-JSON-__tables.json similarity index 52% rename from plugins/source/testdata/TestGeneratePluginDocs-JSON-__tables.json rename to docs/testdata/TestGeneratePluginDocs-JSON-__tables.json index 7a8280833e..2623746cb5 100644 --- a/plugins/source/testdata/TestGeneratePluginDocs-JSON-__tables.json +++ b/docs/testdata/TestGeneratePluginDocs-JSON-__tables.json @@ -4,22 +4,6 @@ "title": "Incremental Table", "description": "Description for incremental table", "columns": [ - { - "name": "_cq_source_name", - "type": "utf8" - }, - { - "name": "_cq_sync_time", - "type": "timestamp[us, tz=UTC]" - }, - { - "name": "_cq_id", - "type": "uuid" - }, - { - "name": "_cq_parent_id", - "type": "uuid" - }, { "name": "int_col", "type": "int64" @@ -43,22 +27,6 @@ "title": "Test Table", "description": "Description for test table", "columns": [ - { - "name": "_cq_source_name", - "type": "utf8" - }, - { - "name": "_cq_sync_time", - "type": "timestamp[us, tz=UTC]" - }, - { - "name": "_cq_id", - "type": "uuid" - }, - { - "name": "_cq_parent_id", - "type": "uuid" - }, { "name": "int_col", "type": "int64" @@ -96,23 +64,6 @@ "title": "Relation Table", "description": "Description for relational table", "columns": [ - { - "name": "_cq_source_name", - "type": "utf8" - }, - { - "name": "_cq_sync_time", - "type": "timestamp[us, tz=UTC]" - }, - { - "name": "_cq_id", - "type": "uuid", - "is_primary_key": true - }, - { - "name": "_cq_parent_id", - "type": "uuid" - }, { "name": "string_col", "type": "utf8" @@ -124,23 +75,6 @@ "title": "Relation Relation Table A", "description": "Description for relational table's relation", "columns": [ - { - "name": "_cq_source_name", - "type": "utf8" - }, - { - "name": "_cq_sync_time", - "type": "timestamp[us, tz=UTC]" - }, - { - "name": "_cq_id", - "type": "uuid", - "is_primary_key": true - }, - { - "name": "_cq_parent_id", - "type": "uuid" - }, { "name": "string_col", "type": "utf8" @@ 
-153,23 +87,6 @@ "title": "Relation Relation Table B", "description": "Description for relational table's relation", "columns": [ - { - "name": "_cq_source_name", - "type": "utf8" - }, - { - "name": "_cq_sync_time", - "type": "timestamp[us, tz=UTC]" - }, - { - "name": "_cq_id", - "type": "uuid", - "is_primary_key": true - }, - { - "name": "_cq_parent_id", - "type": "uuid" - }, { "name": "string_col", "type": "utf8" @@ -184,23 +101,6 @@ "title": "Relation Table2", "description": "Description for second relational table", "columns": [ - { - "name": "_cq_source_name", - "type": "utf8" - }, - { - "name": "_cq_sync_time", - "type": "timestamp[us, tz=UTC]" - }, - { - "name": "_cq_id", - "type": "uuid", - "is_primary_key": true - }, - { - "name": "_cq_parent_id", - "type": "uuid" - }, { "name": "string_col", "type": "utf8" diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-README.md b/docs/testdata/TestGeneratePluginDocs-Markdown-README.md similarity index 100% rename from plugins/source/testdata/TestGeneratePluginDocs-Markdown-README.md rename to docs/testdata/TestGeneratePluginDocs-Markdown-README.md diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md b/docs/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md similarity index 61% rename from plugins/source/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md rename to docs/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md index d0b1530577..4148e838eb 100644 --- a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md +++ b/docs/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md @@ -11,10 +11,6 @@ It supports incremental syncs based on the (**id_col**, **id_col2**) columns. | Name | Type | | ------------- | ------------- | -|_cq_source_name|`utf8`| -|_cq_sync_time|`timestamp[us, tz=UTC]`| -|_cq_id|`uuid`| -|_cq_parent_id|`uuid`| -|int_col|`int64`| -|id_col (PK) (Incremental Key)|`int64`| -|id_col2 (Incremental Key)|`int64`| +|int_col|int64| +|id_col (PK) (Incremental Key)|int64| +|id_col2 (Incremental Key)|int64| diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md b/docs/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md similarity index 62% rename from plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md rename to docs/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md index 9ee22d1ba1..1c0b8b63c8 100644 --- a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md +++ b/docs/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md @@ -4,7 +4,7 @@ This table shows data for Relation Relation Table A. Description for relational table's relation -The primary key for this table is **_cq_id**. +The composite primary key for this table is (). ## Relations @@ -14,8 +14,4 @@ This table depends on [relation_table](relation_table.md). 
| Name | Type | | ------------- | ------------- | -|_cq_source_name|`utf8`| -|_cq_sync_time|`timestamp[us, tz=UTC]`| -|_cq_id (PK)|`uuid`| -|_cq_parent_id|`uuid`| -|string_col|`utf8`| +|string_col|utf8| diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md b/docs/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md similarity index 62% rename from plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md rename to docs/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md index f6d68a71e1..77dce363dc 100644 --- a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md +++ b/docs/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md @@ -4,7 +4,7 @@ This table shows data for Relation Relation Table B. Description for relational table's relation -The primary key for this table is **_cq_id**. +The composite primary key for this table is (). ## Relations @@ -14,8 +14,4 @@ This table depends on [relation_table](relation_table.md). | Name | Type | | ------------- | ------------- | -|_cq_source_name|`utf8`| -|_cq_sync_time|`timestamp[us, tz=UTC]`| -|_cq_id (PK)|`uuid`| -|_cq_parent_id|`uuid`| -|string_col|`utf8`| +|string_col|utf8| diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_table.md b/docs/testdata/TestGeneratePluginDocs-Markdown-relation_table.md similarity index 70% rename from plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_table.md rename to docs/testdata/TestGeneratePluginDocs-Markdown-relation_table.md index 95c4125aa7..96b152a8fe 100644 --- a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_table.md +++ b/docs/testdata/TestGeneratePluginDocs-Markdown-relation_table.md @@ -4,7 +4,7 @@ This table shows data for Relation Table. Description for relational table -The primary key for this table is **_cq_id**. +The composite primary key for this table is (). 
## Relations @@ -18,8 +18,4 @@ The following tables depend on relation_table: | Name | Type | | ------------- | ------------- | -|_cq_source_name|`utf8`| -|_cq_sync_time|`timestamp[us, tz=UTC]`| -|_cq_id (PK)|`uuid`| -|_cq_parent_id|`uuid`| -|string_col|`utf8`| +|string_col|utf8| diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-test_table.md b/docs/testdata/TestGeneratePluginDocs-Markdown-test_table.md similarity index 53% rename from plugins/source/testdata/TestGeneratePluginDocs-Markdown-test_table.md rename to docs/testdata/TestGeneratePluginDocs-Markdown-test_table.md index cdd1df3317..089a0b4b3e 100644 --- a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-test_table.md +++ b/docs/testdata/TestGeneratePluginDocs-Markdown-test_table.md @@ -16,14 +16,10 @@ The following tables depend on test_table: | Name | Type | | ------------- | ------------- | -|_cq_source_name|`utf8`| -|_cq_sync_time|`timestamp[us, tz=UTC]`| -|_cq_id|`uuid`| -|_cq_parent_id|`uuid`| -|int_col|`int64`| -|id_col (PK)|`int64`| -|id_col2 (PK)|`int64`| -|json_col|`json`| -|list_col|`list`| -|map_col|`map`| -|struct_col|`struct`| +|int_col|int64| +|id_col (PK)|int64| +|id_col2 (PK)|int64| +|json_col|json| +|list_col|list| +|map_col|map| +|struct_col|struct| diff --git a/internal/glob/LICENSE b/glob/LICENSE similarity index 100% rename from internal/glob/LICENSE rename to glob/LICENSE diff --git a/internal/glob/README.md b/glob/README.md similarity index 100% rename from internal/glob/README.md rename to glob/README.md diff --git a/internal/glob/glob.go b/glob/glob.go similarity index 85% rename from internal/glob/glob.go rename to glob/glob.go index e67db3be18..b4fd6535db 100644 --- a/internal/glob/glob.go +++ b/glob/glob.go @@ -5,6 +5,20 @@ import "strings" // The character which is treated like a glob const GLOB = "*" +func IncludeTable(name string, tables []string, skipTables []string) bool { + for _, t := range skipTables { + if Glob(t, name) { + return false + } + } + for _, t := range tables { + if Glob(t, name) { + return true + } + } + return false +} + // Glob will test a string pattern, potentially containing globs, against a // subject string. The result is a simple true/false, determining whether or // not the glob pattern matched the subject text. 
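The newly exported `IncludeTable` helper above gives skip patterns precedence over include patterns: a table is excluded if any skip glob matches, included only if some include glob matches, and excluded otherwise. A minimal usage sketch against the new import path (the table names are hypothetical):

```go
package main

import (
	"fmt"

	"github.com/cloudquery/plugin-sdk/v4/glob"
)

func main() {
	include := []string{"aws_ec2_*", "aws_s3_buckets"}
	skip := []string{"aws_ec2_instance_statuses"}

	// Matches "aws_ec2_*" and no skip pattern: included.
	fmt.Println(glob.IncludeTable("aws_ec2_instances", include, skip)) // true
	// Skip globs are checked first, so an exclusion always wins.
	fmt.Println(glob.IncludeTable("aws_ec2_instance_statuses", include, skip)) // false
	// No include pattern matches: excluded by default.
	fmt.Println(glob.IncludeTable("gcp_compute_instances", include, skip)) // false
}
```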
diff --git a/internal/glob/glob_test.go b/glob/glob_test.go similarity index 100% rename from internal/glob/glob_test.go rename to glob/glob_test.go diff --git a/go.mod b/go.mod index 29f76cd3b1..f940e8ae0a 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/cloudquery/plugin-sdk/v3 +module github.com/cloudquery/plugin-sdk/v4 go 1.19 @@ -14,12 +14,10 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2 v2.0.0-rc.3 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0-rc.3 github.com/rs/zerolog v1.29.1 - github.com/spf13/cast v1.5.0 github.com/spf13/cobra v1.6.1 github.com/stretchr/testify v1.8.4 github.com/thoas/go-funk v0.9.3 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 - golang.org/x/net v0.9.0 golang.org/x/sync v0.1.0 golang.org/x/text v0.9.0 google.golang.org/grpc v1.55.0 @@ -43,6 +41,7 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/mod v0.8.0 // indirect + golang.org/x/net v0.9.0 // indirect golang.org/x/sys v0.7.0 // indirect golang.org/x/tools v0.6.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/go.sum b/go.sum index 0067abecab..76f7b2023a 100644 --- a/go.sum +++ b/go.sum @@ -59,7 +59,6 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= github.com/getsentry/sentry-go v0.20.0 h1:bwXW98iMRIWxn+4FgPW7vMrjmbym6HblXALmhjHmQaQ= github.com/getsentry/sentry-go v0.20.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= @@ -148,8 +147,8 @@ github.com/klauspost/compress v1.16.0 h1:iULayQNOReoYUe+1qtKOqw9CwJv3aNQu8ivo7lw github.com/klauspost/compress v1.16.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU= github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -171,15 +170,12 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= 
github.com/rs/zerolog v1.19.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w= -github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU= github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/helpers/integers.go b/helpers/integers.go deleted file mode 100644 index a539552377..0000000000 --- a/helpers/integers.go +++ /dev/null @@ -1,19 +0,0 @@ -package helpers - -import "math" - -// Uint64ToInt64 if value is greater than math.MaxInt64 return math.MaxInt64 -// otherwise returns original value cast to int64 -func Uint64ToInt64(i uint64) int64 { - if i > math.MaxInt64 { - return math.MaxInt64 - } - return int64(i) -} - -func Uint64ToInt(i uint64) int { - if i > math.MaxInt { - return math.MaxInt - } - return int(i) -} diff --git a/helpers/pointers.go b/helpers/pointers.go deleted file mode 100644 index 2f5a008535..0000000000 --- a/helpers/pointers.go +++ /dev/null @@ -1,20 +0,0 @@ -package helpers - -import "reflect" - -// ToPointer takes an any object and will return a pointer to this object -// if the object is not already a pointer. Otherwise, it will return the original value. -// It is safe to typecast the return-value of GetPointer into a pointer of the right type, -// except in very special cases (such as passing in nil without an explicit type) -func ToPointer(v any) any { - val := reflect.ValueOf(v) - if val.Kind() == reflect.Ptr { - return v - } - if !val.IsValid() { - return v - } - p := reflect.New(val.Type()) - p.Elem().Set(val) - return p.Interface() -} diff --git a/helpers/pointers_test.go b/helpers/pointers_test.go deleted file mode 100644 index 2ae81ed7a3..0000000000 --- a/helpers/pointers_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package helpers - -import ( - "testing" -) - -type testStruct struct { - test string -} - -func TestToPointer(t *testing.T) { - // passing string should return pointer to string - give := "test" - got := ToPointer(give) - if _, ok := got.(*string); !ok { - t.Errorf("ToPointer(%q) returned %q, expected type *string", give, got) - } - - // passing struct by value should return pointer to (copy of the) struct - giveObj := testStruct{ - test: "value", - } - gotStruct := ToPointer(giveObj) - if _, ok := gotStruct.(*testStruct); !ok { - t.Errorf("ToPointer(%q) returned %q, expected type *testStruct", giveObj, gotStruct) - } - - // passing pointer should return the original pointer - ptr := &giveObj - gotPointer := ToPointer(ptr) - if gotPointer != ptr { - t.Errorf("ToPointer(%q) returned %q, expected %q", ptr, gotPointer, ptr) - } - - // passing nil should return nil back without panicking - gotNil := ToPointer(nil) - if gotNil != nil { - t.Errorf("ToPointer(%v) returned %q, expected nil", nil, gotNil) - } - - // passing number should return pointer to number - giveNumber := int64(0) - gotNumber := ToPointer(giveNumber) - if v, ok := gotNumber.(*int64); !ok { - t.Errorf("ToPointer(%q) returned %q, expected type *int64", giveNumber, gotNumber) - if *v != 0 { - t.Errorf("ToPointer(%q) returned %q, 
expected 0", giveNumber, gotNumber) - } - } -} diff --git a/helpers/strings.go b/helpers/strings.go deleted file mode 100644 index e522a3c5ea..0000000000 --- a/helpers/strings.go +++ /dev/null @@ -1,39 +0,0 @@ -package helpers - -import ( - "fmt" - "sort" - "strings" - - "github.com/spf13/cast" -) - -func FormatSlice(a []string) string { - // sort slice for consistency - sort.Strings(a) - q := make([]string, len(a)) - for i, s := range a { - q[i] = fmt.Sprintf("%q", s) - } - return fmt.Sprintf("[\n\t%s\n]", strings.Join(q, ",\n\t")) -} - -func HasDuplicates(resources []string) bool { - dups := make(map[string]bool, len(resources)) - for _, r := range resources { - if _, ok := dups[r]; ok { - return true - } - dups[r] = true - } - return false -} - -func ToStringSliceE(i any) ([]string, error) { - switch v := i.(type) { - case *[]string: - return cast.ToStringSliceE(*v) - default: - return cast.ToStringSliceE(i) - } -} diff --git a/helpers/strings_test.go b/helpers/strings_test.go deleted file mode 100644 index 991492df8e..0000000000 --- a/helpers/strings_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package helpers - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestHasDuplicates(t *testing.T) { - assert.False(t, HasDuplicates([]string{"A", "b", "c"})) - assert.False(t, HasDuplicates([]string{"A", "a", "c"})) - assert.True(t, HasDuplicates([]string{"a", "a", "c"})) - assert.True(t, HasDuplicates([]string{"a", "a", "c", "c", "f"})) -} - -func TestToStingSliceE(t *testing.T) { - arr := &[]string{"a", "b", "c"} - newArr, _ := ToStringSliceE(arr) - assert.Equal(t, newArr, []string{"a", "b", "c"}) -} diff --git a/internal/backends/local/local.go b/internal/backends/local/local.go deleted file mode 100644 index 0593d8b0b0..0000000000 --- a/internal/backends/local/local.go +++ /dev/null @@ -1,157 +0,0 @@ -package local - -import ( - "context" - "encoding/json" - "fmt" - "io" - "os" - "path" - "strings" - "sync" - - "github.com/cloudquery/plugin-pb-go/specs" -) - -type Local struct { - sourceName string - spec Spec - tables map[string]entries // table -> key -> value - tablesLock sync.RWMutex -} - -type entries map[string]string - -func New(sourceSpec specs.Source) (*Local, error) { - spec := Spec{} - err := sourceSpec.UnmarshalBackendSpec(&spec) - if err != nil { - return nil, err - } - spec.SetDefaults() - - l := &Local{ - sourceName: sourceSpec.Name, - spec: spec, - } - tables, err := l.loadPreviousState() - if err != nil { - return nil, err - } - if tables == nil { - tables = map[string]entries{} - } - l.tables = tables - return l, nil -} - -func (l *Local) loadPreviousState() (map[string]entries, error) { - files, err := os.ReadDir(l.spec.Path) - if os.IsNotExist(err) { - return nil, nil - } - var tables = map[string]entries{} - for _, f := range files { - if f.IsDir() || !f.Type().IsRegular() { - continue - } - name := f.Name() - if !strings.HasSuffix(name, ".json") || !strings.HasPrefix(name, l.sourceName+"-") { - continue - } - table, kv, err := l.readFile(name) - if err != nil { - return nil, err - } - tables[table] = kv - } - return tables, nil -} - -func (l *Local) readFile(name string) (table string, kv entries, err error) { - p := path.Join(l.spec.Path, name) - f, err := os.Open(p) - if err != nil { - return "", nil, fmt.Errorf("failed to open state file: %w", err) - } - b, err := io.ReadAll(f) - if err != nil { - return "", nil, fmt.Errorf("failed to read state file: %w", err) - } - err = f.Close() - if err != nil { - return "", nil, fmt.Errorf("failed to close 
state file: %w", err) - } - err = json.Unmarshal(b, &kv) - if err != nil { - return "", nil, fmt.Errorf("failed to unmarshal state file: %w", err) - } - table = strings.TrimPrefix(strings.TrimSuffix(name, ".json"), l.sourceName+"-") - return table, kv, nil -} - -func (l *Local) Get(_ context.Context, table, clientID string) (string, error) { - l.tablesLock.RLock() - defer l.tablesLock.RUnlock() - - if _, ok := l.tables[table]; !ok { - return "", nil - } - return l.tables[table][clientID], nil -} - -func (l *Local) Set(_ context.Context, table, clientID, value string) error { - l.tablesLock.Lock() - defer l.tablesLock.Unlock() - - if _, ok := l.tables[table]; !ok { - l.tables[table] = map[string]string{} - } - prev := l.tables[table][clientID] - l.tables[table][clientID] = value - if prev != value { - // only flush if the value changed - return l.flushTable(table, l.tables[table]) - } - return nil -} - -func (l *Local) Close(_ context.Context) error { - l.tablesLock.RLock() - defer l.tablesLock.RUnlock() - - return l.flush() -} - -func (l *Local) flush() error { - for table, kv := range l.tables { - err := l.flushTable(table, kv) - if err != nil { - return err - } - } - return nil -} - -func (l *Local) flushTable(table string, entries entries) error { - if len(entries) == 0 { - return nil - } - - err := os.MkdirAll(l.spec.Path, 0755) - if err != nil { - return fmt.Errorf("failed to create state directory %v: %w", l.spec.Path, err) - } - - b, err := json.MarshalIndent(entries, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal state for table %v: %w", table, err) - } - f := path.Join(l.spec.Path, l.sourceName+"-"+table+".json") - err = os.WriteFile(f, b, 0644) - if err != nil { - return fmt.Errorf("failed to write state for table %v: %w", table, err) - } - - return nil -} diff --git a/internal/backends/local/local_test.go b/internal/backends/local/local_test.go deleted file mode 100644 index 4e3423f9d8..0000000000 --- a/internal/backends/local/local_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package local - -import ( - "context" - "testing" - - "github.com/cloudquery/plugin-pb-go/specs" -) - -func TestLocal(t *testing.T) { - tmpDir := t.TempDir() - ctx := context.Background() - ss := specs.Source{ - Name: "test", - Version: "vtest", - Path: "test", - Backend: specs.BackendLocal, - BackendSpec: Spec{ - Path: tmpDir, - }, - } - local, err := New(ss) - if err != nil { - t.Fatalf("failed to create local backend: %v", err) - } - if local.spec.Path != tmpDir { - t.Fatalf("expected path to be %s, but got %s", tmpDir, local.spec.Path) - } - - tableName := "test_table" - clientID := "test_client" - got, err := local.Get(ctx, tableName, clientID) - if err != nil { - t.Fatalf("failed to get value: %v", err) - } - if got != "" { - t.Fatalf("expected empty value, but got %s", got) - } - - err = local.Set(ctx, tableName, clientID, "test_value") - if err != nil { - t.Fatalf("failed to set value: %v", err) - } - - got, err = local.Get(ctx, tableName, clientID) - if err != nil { - t.Fatalf("failed to get value after setting it: %v", err) - } - if got != "test_value" { - t.Fatalf("expected value to be test_value, but got %s", got) - } - - err = local.Close(ctx) - if err != nil { - t.Fatalf("failed to close local backend: %v", err) - } - - local, err = New(ss) - if err != nil { - t.Fatalf("failed to open local backend after closing it: %v", err) - } - - got, err = local.Get(ctx, tableName, clientID) - if err != nil { - t.Fatalf("failed to get value after closing and reopening local backend: %v", 
err) - } - if got != "test_value" { - t.Fatalf("expected value to be test_value, but got %s", got) - } - - got, err = local.Get(ctx, "some_other_table", clientID) - if err != nil { - t.Fatalf("failed to get value after closing and reopening local backend: %v", err) - } - if got != "" { - t.Fatalf("expected empty value for some_other_table -> test_key, but got %s", got) - } - err = local.Close(ctx) - if err != nil { - t.Fatalf("failed to close local backend the second time: %v", err) - } - - // check that state is namespaced by source name - ss.Name = "test2" - local2, err := New(ss) - if err != nil { - t.Fatalf("failed to create local backend for test2: %v", err) - } - - got, err = local2.Get(ctx, "test_table", clientID) - if err != nil { - t.Fatalf("failed to get value for local backend test2: %v", err) - } - if got != "" { - t.Fatalf("expected empty value for test2 -> test_table -> test_key, but got %s", got) - } - err = local2.Close(ctx) - if err != nil { - t.Fatalf("failed to close second local backend: %v", err) - } -} diff --git a/internal/backends/local/spec.go b/internal/backends/local/spec.go deleted file mode 100644 index f2b7040c1d..0000000000 --- a/internal/backends/local/spec.go +++ /dev/null @@ -1,12 +0,0 @@ -package local - -type Spec struct { - // Path is the path to the local directory. - Path string `json:"path"` -} - -func (s *Spec) SetDefaults() { - if s.Path == "" { - s.Path = ".cq/state" - } -} diff --git a/internal/backends/nop/nop.go b/internal/backends/nop/nop.go deleted file mode 100644 index 45e713608a..0000000000 --- a/internal/backends/nop/nop.go +++ /dev/null @@ -1,23 +0,0 @@ -package nop - -import "context" - -func New() *Backend { - return &Backend{} -} - -// Backend can be used in cases where no backend is specified to avoid the need to check for nil -// pointers in all resolvers. 
-type Backend struct{} - -func (*Backend) Set(_ context.Context, _, _, _ string) error { - return nil -} - -func (*Backend) Get(_ context.Context, _, _ string) (string, error) { - return "", nil -} - -func (*Backend) Close(_ context.Context) error { - return nil -} diff --git a/internal/clients/state/v3/state.go b/internal/clients/state/v3/state.go new file mode 100644 index 0000000000..fc713ed9a2 --- /dev/null +++ b/internal/clients/state/v3/state.go @@ -0,0 +1,183 @@ +package state + +import ( + "bytes" + "context" + "fmt" + "io" + "sync" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/ipc" + "github.com/apache/arrow/go/v13/arrow/memory" + pb "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +const keyColumn = "key" +const valueColumn = "value" + +type Client struct { + client pb.PluginClient + tableName string + mem map[string]string + mutex *sync.RWMutex + keys []string + values []string + schema *arrow.Schema +} + +func NewClient(ctx context.Context, pbClient pb.PluginClient, tableName string) (*Client, error) { + c := &Client{ + client: pbClient, + tableName: tableName, + mem: make(map[string]string), + mutex: &sync.RWMutex{}, + keys: make([]string, 0), + values: make([]string, 0), + } + table := &schema.Table{ + Name: tableName, + Columns: []schema.Column{ + { + Name: keyColumn, + Type: arrow.BinaryTypes.String, + PrimaryKey: true, + }, + { + Name: valueColumn, + Type: arrow.BinaryTypes.String, + }, + }, + } + sc := table.ToArrowSchema() + c.schema = sc + tableBytes, err := pb.SchemaToBytes(sc) + if err != nil { + return nil, err + } + + writeClient, err := c.client.Write(ctx) + if err != nil { + return nil, err + } + if err := writeClient.Send(&pb.Write_Request{ + Message: &pb.Write_Request_Options{ + Options: &pb.WriteOptions{MigrateForce: false}, + }, + }); err != nil { + return nil, err + } + if err := writeClient.Send(&pb.Write_Request{ + Message: &pb.Write_Request_MigrateTable{ + MigrateTable: &pb.MessageMigrateTable{ + Table: tableBytes, + }, + }, + }); err != nil { + return nil, err + } + + syncClient, err := c.client.Sync(ctx, &pb.Sync_Request{ + Tables: []string{tableName}, + }) + if err != nil { + return nil, err + } + c.mutex.Lock() + defer c.mutex.Unlock() + for { + res, err := syncClient.Recv() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + var insertMessage *pb.Sync_Response_Insert + switch m := res.Message.(type) { + case *pb.Sync_Response_Delete: + continue + case *pb.Sync_Response_MigrateTable: + continue + case *pb.Sync_Response_Insert: + insertMessage = m + } + rdr, err := ipc.NewReader(bytes.NewReader(insertMessage.Insert.Record)) + if err != nil { + return nil, err + } + for { + record, err := rdr.Read() + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + if record.NumRows() == 0 { + continue + } + keys := record.Columns()[0].(*array.String) + values := record.Columns()[1].(*array.String) + for i := 0; i < keys.Len(); i++ { + c.mem[keys.Value(i)] = values.Value(i) + } + } + } + return c, nil +} + +func (c *Client) SetKey(_ context.Context, key string, value string) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.mem[key] = value + return nil +} + +func (c *Client) Flush(ctx context.Context) error { + c.mutex.RLock() + defer c.mutex.RUnlock() + bldr := array.NewRecordBuilder(memory.DefaultAllocator, c.schema) + for k, v := range c.mem { + 
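+	// Each entry of the in-memory map becomes one row of the two-column
+	// (key, value) state record; the two string builders below are appended
+	// in lockstep so row i holds a matching key/value pair.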
bldr.Field(0).(*array.StringBuilder).Append(k) + bldr.Field(1).(*array.StringBuilder).Append(v) + } + rec := bldr.NewRecord() + recordBytes, err := pb.RecordToBytes(rec) + if err != nil { + return err + } + writeClient, err := c.client.Write(ctx) + if err != nil { + return err + } + if err := writeClient.Send(&pb.Write_Request{ + Message: &pb.Write_Request_Options{}, + }); err != nil { + return err + } + if err := writeClient.Send(&pb.Write_Request{ + Message: &pb.Write_Request_Insert{ + Insert: &pb.MessageInsert{ + Record: recordBytes, + }, + }, + }); err != nil { + return err + } + if _, err := writeClient.CloseAndRecv(); err != nil { + return err + } + return nil +} + +func (c *Client) GetKey(_ context.Context, key string) (string, error) { + c.mutex.RLock() + defer c.mutex.RUnlock() + if val, ok := c.mem[key]; ok { + return val, nil + } + return "", fmt.Errorf("key not found") +} diff --git a/internal/clients/state/v3/state_test.go b/internal/clients/state/v3/state_test.go new file mode 100644 index 0000000000..ab446d55fc --- /dev/null +++ b/internal/clients/state/v3/state_test.go @@ -0,0 +1,3 @@ +package state + +// Note: State is tested under serve/state_test.go with a real plugin server. diff --git a/internal/memdb/memdb.go b/internal/memdb/memdb.go index 9c6bbb74d1..2f7ce4fb5e 100644 --- a/internal/memdb/memdb.go +++ b/internal/memdb/memdb.go @@ -3,22 +3,18 @@ package memdb import ( "context" "fmt" - "os" "sync" - "testing" - "time" "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/plugins/destination" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" "github.com/rs/zerolog" ) // client is mostly used for testing the destination plugin. 
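// It keeps written records in an in-memory map keyed by table name, guarded by memoryDBLock for concurrent access.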
type client struct { - spec specs.Destination memoryDB map[string][]arrow.Record tables map[string]*schema.Table memoryDBLock sync.RWMutex @@ -28,6 +24,9 @@ type client struct { type Option func(*client) +type Spec struct { +} + func WithErrOnWrite() Option { return func(c *client) { c.errOnWrite = true @@ -40,42 +39,36 @@ func WithBlockingWrite() Option { } } -func GetNewClient(options ...Option) destination.NewClientFunc { +func GetNewClient(options ...Option) plugin.NewClientFunc { c := &client{ memoryDB: make(map[string][]arrow.Record), memoryDBLock: sync.RWMutex{}, + tables: make(map[string]*schema.Table), } for _, opt := range options { opt(c) } - return func(context.Context, zerolog.Logger, specs.Destination) (destination.Client, error) { + return func(context.Context, zerolog.Logger, []byte) (plugin.Client, error) { return c, nil } } -func getTestLogger(t *testing.T) zerolog.Logger { - t.Helper() - zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs - return zerolog.New(zerolog.NewTestWriter(t)).Output( - zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.StampMicro}, - ).Level(zerolog.DebugLevel).With().Timestamp().Logger() +func NewMemDBClient(ctx context.Context, l zerolog.Logger, spec []byte) (plugin.Client, error) { + return GetNewClient()(ctx, l, spec) } -func NewClient(_ context.Context, _ zerolog.Logger, spec specs.Destination) (destination.Client, error) { - return &client{ - memoryDB: make(map[string][]arrow.Record), - tables: make(map[string]*schema.Table), - spec: spec, - }, nil -} - -func NewClientErrOnNew(context.Context, zerolog.Logger, specs.Destination) (destination.Client, error) { +func NewMemDBClientErrOnNew(context.Context, zerolog.Logger, []byte) (plugin.Client, error) { return nil, fmt.Errorf("newTestDestinationMemDBClientErrOnNew") } func (c *client) overwrite(table *schema.Table, data arrow.Record) { - pksIndex := table.PrimaryKeysIndexes() tableName := table.Name + pksIndex := table.PrimaryKeysIndexes() + if len(pksIndex) == 0 { + c.memoryDB[tableName] = append(c.memoryDB[tableName], data) + return + } + for i, row := range c.memoryDB[tableName] { found := true for _, pkIndex := range pksIndex { @@ -94,83 +87,69 @@ func (c *client) overwrite(table *schema.Table, data arrow.Record) { c.memoryDB[tableName] = append(c.memoryDB[tableName], data) } -func (c *client) Migrate(_ context.Context, tables schema.Tables) error { - for _, table := range tables { - tableName := table.Name - memTable := c.memoryDB[tableName] - if memTable == nil { - c.memoryDB[tableName] = make([]arrow.Record, 0) - c.tables[tableName] = table - continue - } +func (*client) ID() string { + return "testDestinationMemDB" +} - changes := table.GetChanges(c.tables[tableName]) - // memdb doesn't support any auto-migrate - if changes == nil { - continue - } - c.memoryDB[tableName] = make([]arrow.Record, 0) - c.tables[tableName] = table - } - return nil +func (*client) GetSpec() any { + return &Spec{} } -func (c *client) Read(_ context.Context, table *schema.Table, source string, res chan<- arrow.Record) error { +func (c *client) Read(_ context.Context, table *schema.Table, res chan<- arrow.Record) error { + c.memoryDBLock.RLock() + defer c.memoryDBLock.RUnlock() + tableName := table.Name - if c.memoryDB[tableName] == nil { - return nil - } - sourceColIndex := table.Columns.Index(schema.CqSourceNameColumn.Name) - if sourceColIndex == -1 { - return fmt.Errorf("table %s doesn't have source column", tableName) + for _, row := range c.memoryDB[tableName] { + res <- row } - var sortedRes 
[]arrow.Record + return nil +} + +func (c *client) Sync(_ context.Context, options plugin.SyncOptions, res chan<- message.Message) error { c.memoryDBLock.RLock() - for _, row := range c.memoryDB[tableName] { - arr := row.Column(sourceColIndex) - if arr.(*array.String).Value(0) == source { - sortedRes = append(sortedRes, row) + + for tableName := range c.memoryDB { + if !plugin.MatchesTable(tableName, options.Tables, options.SkipTables) { + continue + } + for _, row := range c.memoryDB[tableName] { + res <- &message.Insert{ + Record: row, + } } } c.memoryDBLock.RUnlock() - - for _, row := range sortedRes { - res <- row - } return nil } -func (c *client) Write(ctx context.Context, _ schema.Tables, resources <-chan arrow.Record) error { - if c.errOnWrite { - return fmt.Errorf("errOnWrite") +func (c *client) Tables(_ context.Context) (schema.Tables, error) { + tables := make(schema.Tables, 0, len(c.tables)) + for _, table := range c.tables { + tables = append(tables, table) } - if c.blockingWrite { - <-ctx.Done() - if c.errOnWrite { - return fmt.Errorf("errOnWrite") - } - return nil + return tables, nil +} + +func (c *client) migrate(_ context.Context, table *schema.Table) { + tableName := table.Name + memTable := c.memoryDB[tableName] + if memTable == nil { + c.memoryDB[tableName] = make([]arrow.Record, 0) + c.tables[tableName] = table + return } - for resource := range resources { - c.memoryDBLock.Lock() - sc := resource.Schema() - tableName, ok := sc.Metadata().GetValue(schema.MetadataTableName) - if !ok { - return fmt.Errorf("table name not found in schema metadata") - } - table := c.tables[tableName] - if c.spec.WriteMode == specs.WriteModeAppend { - c.memoryDB[tableName] = append(c.memoryDB[tableName], resource) - } else { - c.overwrite(table, resource) - } - c.memoryDBLock.Unlock() + changes := table.GetChanges(c.tables[tableName]) + // memdb doesn't support any auto-migrate + if changes == nil { + return } - return nil + c.memoryDB[tableName] = make([]arrow.Record, 0) + c.tables[tableName] = table } -func (c *client) WriteTableBatch(ctx context.Context, table *schema.Table, resources []arrow.Record) error { +func (c *client) Write(ctx context.Context, _ plugin.WriteOptions, msgs <-chan message.Message) error { if c.errOnWrite { return fmt.Errorf("errOnWrite") } @@ -181,44 +160,54 @@ func (c *client) WriteTableBatch(ctx context.Context, table *schema.Table, resou } return nil } - tableName := table.Name - for _, resource := range resources { + + for msg := range msgs { c.memoryDBLock.Lock() - if c.spec.WriteMode == specs.WriteModeAppend { - c.memoryDB[tableName] = append(c.memoryDB[tableName], resource) - } else { - c.overwrite(table, resource) + + switch msg := msg.(type) { + case *message.MigrateTable: + c.migrate(ctx, msg.Table) + case *message.DeleteStale: + c.deleteStale(ctx, msg) + case *message.Insert: + sc := msg.Record.Schema() + tableName, ok := sc.Metadata().GetValue(schema.MetadataTableName) + if !ok { + return fmt.Errorf("table name not found in schema metadata") + } + table := c.tables[tableName] + c.overwrite(table, msg.Record) } + c.memoryDBLock.Unlock() } return nil } -func (*client) Metrics() destination.Metrics { - return destination.Metrics{} -} - func (c *client) Close(context.Context) error { c.memoryDB = nil return nil } -func (c *client) DeleteStale(ctx context.Context, tables schema.Tables, source string, syncTime time.Time) error { - for _, table := range tables { - c.deleteStaleTable(ctx, table, source, syncTime) - } - return nil -} - -func (c *client) 
deleteStaleTable(_ context.Context, table *schema.Table, source string, syncTime time.Time) { - sourceColIndex := table.Columns.Index(schema.CqSourceNameColumn.Name) - syncColIndex := table.Columns.Index(schema.CqSyncTimeColumn.Name) - tableName := table.Name +func (c *client) deleteStale(_ context.Context, msg *message.DeleteStale) { var filteredTable []arrow.Record + tableName := msg.Table.Name for i, row := range c.memoryDB[tableName] { - if row.Column(sourceColIndex).(*array.String).Value(0) == source { + sc := row.Schema() + indices := sc.FieldIndices(schema.CqSourceNameColumn.Name) + if len(indices) == 0 { + continue + } + sourceColIndex := indices[0] + indices = sc.FieldIndices(schema.CqSyncTimeColumn.Name) + if len(indices) == 0 { + continue + } + syncColIndex := indices[0] + + if row.Column(sourceColIndex).(*array.String).Value(0) == msg.SourceName { rowSyncTime := row.Column(syncColIndex).(*array.Timestamp).Value(0).ToTime(arrow.Microsecond).UTC() - if !rowSyncTime.Before(syncTime) { + if !rowSyncTime.Before(msg.SyncTime) { filteredTable = append(filteredTable, c.memoryDB[tableName][i]) } } diff --git a/internal/memdb/memdb_test.go b/internal/memdb/memdb_test.go index 7f9e8a5759..cd06eaa230 100644 --- a/internal/memdb/memdb_test.go +++ b/internal/memdb/memdb_test.go @@ -3,102 +3,29 @@ package memdb import ( "context" "testing" - "time" - "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/plugins/destination" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/google/uuid" - "github.com/rs/zerolog" - "github.com/stretchr/testify/require" + "github.com/cloudquery/plugin-sdk/v4/plugin" ) -var migrateStrategyOverwrite = destination.MigrateStrategy{ - AddColumn: specs.MigrateModeForced, - AddColumnNotNull: specs.MigrateModeForced, - RemoveColumn: specs.MigrateModeForced, - RemoveColumnNotNull: specs.MigrateModeForced, - ChangeColumn: specs.MigrateModeForced, -} - -var migrateStrategyAppend = destination.MigrateStrategy{ - AddColumn: specs.MigrateModeForced, - AddColumnNotNull: specs.MigrateModeForced, - RemoveColumn: specs.MigrateModeForced, - RemoveColumnNotNull: specs.MigrateModeForced, - ChangeColumn: specs.MigrateModeForced, -} - -func TestPluginUnmanagedClient(t *testing.T) { - destination.PluginTestSuiteRunner( +func TestPlugin(t *testing.T) { + ctx := context.Background() + p := plugin.NewPlugin("test", "development", NewMemDBClient) + if err := p.Init(ctx, nil); err != nil { + t.Fatal(err) + } + plugin.TestWriterSuiteRunner( t, - func() *destination.Plugin { - return destination.NewPlugin("test", "development", NewClient) - }, - specs.Destination{}, - destination.PluginTestSuiteTests{ - MigrateStrategyOverwrite: migrateStrategyOverwrite, - MigrateStrategyAppend: migrateStrategyAppend, + p, + plugin.WriterTestSuiteTests{ + SafeMigrations: plugin.SafeMigrations{}, }, ) } -func TestPluginManagedClient(t *testing.T) { - destination.PluginTestSuiteRunner(t, - func() *destination.Plugin { - return destination.NewPlugin("test", "development", NewClient, destination.WithManagedWriter()) - }, - specs.Destination{}, - destination.PluginTestSuiteTests{ - MigrateStrategyOverwrite: migrateStrategyOverwrite, - MigrateStrategyAppend: migrateStrategyAppend, - }) -} - -func TestPluginManagedClientWithSmallBatchSize(t *testing.T) { - destination.PluginTestSuiteRunner(t, - func() *destination.Plugin { - return destination.NewPlugin("test", "development", NewClient, destination.WithManagedWriter(), - 
destination.WithDefaultBatchSize(1), - destination.WithDefaultBatchSizeBytes(1)) - }, specs.Destination{}, - destination.PluginTestSuiteTests{ - MigrateStrategyOverwrite: migrateStrategyOverwrite, - MigrateStrategyAppend: migrateStrategyAppend, - }) -} - -func TestPluginManagedClientWithLargeBatchSize(t *testing.T) { - destination.PluginTestSuiteRunner(t, - func() *destination.Plugin { - return destination.NewPlugin("test", "development", NewClient, destination.WithManagedWriter(), - destination.WithDefaultBatchSize(100000000), - destination.WithDefaultBatchSizeBytes(100000000)) - }, - specs.Destination{}, - destination.PluginTestSuiteTests{ - MigrateStrategyOverwrite: migrateStrategyOverwrite, - MigrateStrategyAppend: migrateStrategyAppend, - }) -} - -func TestPluginManagedClientWithCQPKs(t *testing.T) { - destination.PluginTestSuiteRunner(t, - func() *destination.Plugin { - return destination.NewPlugin("test", "development", NewClient) - }, - specs.Destination{PKMode: specs.PKModeCQID}, - destination.PluginTestSuiteTests{ - MigrateStrategyOverwrite: migrateStrategyOverwrite, - MigrateStrategyAppend: migrateStrategyAppend, - }) -} - func TestPluginOnNewError(t *testing.T) { ctx := context.Background() - p := destination.NewPlugin("test", "development", NewClientErrOnNew) - err := p.Init(ctx, getTestLogger(t), specs.Destination{}) + p := plugin.NewPlugin("test", "development", NewMemDBClientErrOnNew) + err := p.Init(ctx, nil) if err == nil { t.Fatal("expected error") @@ -108,94 +35,46 @@ func TestPluginOnNewError(t *testing.T) { func TestOnWriteError(t *testing.T) { ctx := context.Background() newClientFunc := GetNewClient(WithErrOnWrite()) - p := destination.NewPlugin("test", "development", newClientFunc) - if err := p.Init(ctx, getTestLogger(t), specs.Destination{}); err != nil { - t.Fatal(err) - } - table := schema.TestTable("test", schema.TestSourceOptions{}) - tables := schema.Tables{ - table, - } - sourceName := "TestDestinationOnWriteError" - syncTime := time.Now() - sourceSpec := specs.Source{ - Name: sourceName, - } - ch := make(chan arrow.Record, 1) - opts := schema.GenTestDataOptions{ - SourceName: "test", - SyncTime: time.Now(), - MaxRows: 1, - StableUUID: uuid.Nil, - } - record := schema.GenTestData(table, opts)[0] - ch <- record - close(ch) - err := p.Write(ctx, sourceSpec, tables, syncTime, ch) - if err == nil { - t.Fatal("expected error") - } - if err.Error() != "errOnWrite" { - t.Fatalf("expected errOnWrite, got %s", err.Error()) - } -} - -func TestOnWriteCtxCancelled(t *testing.T) { - ctx := context.Background() - newClientFunc := GetNewClient(WithBlockingWrite()) - p := destination.NewPlugin("test", "development", newClientFunc) - if err := p.Init(ctx, getTestLogger(t), specs.Destination{}); err != nil { + p := plugin.NewPlugin("test", "development", newClientFunc) + if err := p.Init(ctx, nil); err != nil { t.Fatal(err) } - table := schema.TestTable("test", schema.TestSourceOptions{}) - tables := schema.Tables{ - table, - } - sourceName := "TestDestinationOnWriteError" - syncTime := time.Now() - sourceSpec := specs.Source{ - Name: sourceName, - } - ch := make(chan arrow.Record, 1) - ctx, cancel := context.WithTimeout(ctx, 2*time.Second) - opts := schema.GenTestDataOptions{ - SourceName: "test", - SyncTime: time.Now(), - MaxRows: 1, - StableUUID: uuid.Nil, - } - record := schema.GenTestData(table, opts)[0] - ch <- record - defer cancel() - err := p.Write(ctx, sourceSpec, tables, syncTime, ch) - if err != nil { - t.Fatal(err) + if err := p.WriteAll(ctx, 
plugin.WriteOptions{}, nil); err.Error() != "errOnWrite" { + t.Fatalf("expected errOnWrite, got %s", err) } } -func TestPluginInit(t *testing.T) { - const ( - batchSize = 100 - batchSizeBytes = 1000 - ) - - var ( - batchSizeObserved int - batchSizeBytesObserved int - ) - p := destination.NewPlugin( - "test", - "development", - func(ctx context.Context, logger zerolog.Logger, s specs.Destination) (destination.Client, error) { - batchSizeObserved = s.BatchSize - batchSizeBytesObserved = s.BatchSizeBytes - return NewClient(ctx, logger, s) - }, - destination.WithDefaultBatchSize(batchSize), - destination.WithDefaultBatchSizeBytes(batchSizeBytes), - ) - require.NoError(t, p.Init(context.TODO(), getTestLogger(t), specs.Destination{})) - - require.Equal(t, batchSize, batchSizeObserved) - require.Equal(t, batchSizeBytes, batchSizeBytesObserved) -} +// func TestOnWriteCtxCancelled(t *testing.T) { +// ctx := context.Background() +// newClientFunc := GetNewClient(WithBlockingWrite()) +// p := plugin.NewPlugin("test", "development", newClientFunc) +// if err := p.Init(ctx, pbPlugin.Spec{ +// WriteSpec: &pbPlugin.WriteSpec{}, +// }); err != nil { +// t.Fatal(err) +// } +// table := schema.TestTable("test", schema.TestSourceOptions{}) +// tables := schema.Tables{ +// table, +// } +// sourceName := "TestDestinationOnWriteError" +// syncTime := time.Now() +// sourceSpec := pbPlugin.Spec{ +// Name: sourceName, +// } +// ch := make(chan arrow.Record, 1) +// ctx, cancel := context.WithTimeout(ctx, 2*time.Second) +// opts := schema.GenTestDataOptions{ +// SourceName: "test", +// SyncTime: time.Now(), +// MaxRows: 1, +// StableUUID: uuid.Nil, +// } +// record := schema.GenTestData(table, opts)[0] +// ch <- record +// defer cancel() +// err := p.Write(ctx, sourceSpec, tables, syncTime, ch) +// if err != nil { +// t.Fatal(err) +// } +// } diff --git a/internal/pk/pk.go b/internal/pk/pk.go index 22b2b277db..ca8c5f2806 100644 --- a/internal/pk/pk.go +++ b/internal/pk/pk.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/schema" ) func String(resource arrow.Record) string { diff --git a/internal/servers/destination/v0/destinations.go b/internal/servers/destination/v0/destinations.go index c09b242e4c..6532d89440 100644 --- a/internal/servers/destination/v0/destinations.go +++ b/internal/servers/destination/v0/destinations.go @@ -3,17 +3,18 @@ package destination import ( "context" "encoding/json" - "fmt" "io" + "sync" - "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/memory" pbBase "github.com/cloudquery/plugin-pb-go/pb/base/v0" pb "github.com/cloudquery/plugin-pb-go/pb/destination/v0" "github.com/cloudquery/plugin-pb-go/specs" schemav2 "github.com/cloudquery/plugin-sdk/v2/schema" - "github.com/cloudquery/plugin-sdk/v3/plugins/destination" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" @@ -22,7 +23,7 @@ import ( type Server struct { pb.UnimplementedDestinationServer - Plugin *destination.Plugin + Plugin *plugin.Plugin Logger zerolog.Logger spec specs.Destination } @@ -39,7 +40,11 @@ func (s *Server) Configure(ctx context.Context, req *pbBase.Configure_Request) ( return nil, status.Errorf(codes.InvalidArgument, 
"failed to unmarshal spec: %v", err) } s.spec = spec - return &pbBase.Configure_Response{}, s.Plugin.Init(ctx, s.Logger, spec) + pluginSpec, err := json.Marshal(s.spec.Spec) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to marshal spec: %v", err) + } + return &pbBase.Configure_Response{}, s.Plugin.Init(ctx, pluginSpec) } func (s *Server) GetName(context.Context, *pbBase.GetName_Request) (*pbBase.GetName_Response, error) { @@ -62,8 +67,23 @@ func (s *Server) Migrate(ctx context.Context, req *pb.Migrate_Request) (*pb.Migr tables := TablesV2ToV3(tablesV2).FlattenTables() SetDestinationManagedCqColumns(tables) s.setPKsForTables(tables) - - return &pb.Migrate_Response{}, s.Plugin.Migrate(ctx, tables) + writeCh := make(chan message.Message) + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return s.Plugin.Write(ctx, plugin.WriteOptions{ + MigrateForce: s.spec.MigrateMode == specs.MigrateModeForced, + }, writeCh) + }) + for _, table := range tables { + writeCh <- &message.MigrateTable{ + Table: table, + } + } + close(writeCh) + if err := eg.Wait(); err != nil { + return nil, status.Errorf(codes.Internal, "failed to write: %v", err) + } + return &pb.Migrate_Response{}, nil } func (*Server) Write(pb.Destination_WriteServer) error { @@ -73,7 +93,7 @@ func (*Server) Write(pb.Destination_WriteServer) error { // Note the order of operations in this method is important! // Trying to insert into the `resources` channel before starting the reader goroutine will cause a deadlock. func (s *Server) Write2(msg pb.Destination_Write2Server) error { - resources := make(chan arrow.Record) + msgs := make(chan message.Message) r, err := msg.Recv() if err != nil { @@ -102,9 +122,19 @@ func (s *Server) Write2(msg pb.Destination_Write2Server) error { SetDestinationManagedCqColumns(tables) s.setPKsForTables(tables) eg, ctx := errgroup.WithContext(msg.Context()) + // sourceName := r.Source eg.Go(func() error { - return s.Plugin.Write(ctx, sourceSpec, tables, syncTime, resources) + return s.Plugin.Write(ctx, plugin.WriteOptions{ + MigrateForce: s.spec.MigrateMode == specs.MigrateModeForced, + }, msgs) }) + + for _, table := range tables { + msgs <- &message.MigrateTable{ + Table: table, + } + } + sourceColumn := &schemav2.Text{} _ = sourceColumn.Set(sourceSpec.Name) syncTimeColumn := &schemav2.Timestamptz{} @@ -113,30 +143,32 @@ func (s *Server) Write2(msg pb.Destination_Write2Server) error { for { r, err := msg.Recv() if err == io.EOF { - close(resources) + close(msgs) if err := eg.Wait(); err != nil { return status.Errorf(codes.Internal, "write failed: %v", err) } return msg.SendAndClose(&pb.Write2_Response{}) } if err != nil { - close(resources) + close(msgs) if wgErr := eg.Wait(); wgErr != nil { return status.Errorf(codes.Internal, "failed to receive msg: %v and write failed: %v", err, wgErr) } return status.Errorf(codes.Internal, "failed to receive msg: %v", err) } + var origResource schemav2.DestinationResource if err := json.Unmarshal(r.Resource, &origResource); err != nil { - close(resources) + close(msgs) if wgErr := eg.Wait(); wgErr != nil { return status.Errorf(codes.InvalidArgument, "failed to unmarshal resource: %v and write failed: %v", err, wgErr) } return status.Errorf(codes.InvalidArgument, "failed to unmarshal resource: %v", err) } + table := tables.Get(origResource.TableName) if table == nil { - close(resources) + close(msgs) if wgErr := eg.Wait(); wgErr != nil { return status.Errorf(codes.InvalidArgument, "failed to get table: %s and write failed: %v", 
origResource.TableName, wgErr) } @@ -148,11 +180,14 @@ func (s *Server) Write2(msg pb.Destination_Write2Server) error { origResource.Data = append([]schemav2.CQType{sourceColumn, syncTimeColumn}, origResource.Data...) } convertedResource := CQTypesToRecord(memory.DefaultAllocator, []schemav2.CQTypes{origResource.Data}, table.ToArrowSchema()) + msg := &message.Insert{ + Record: convertedResource, + } + select { - case resources <- convertedResource: + case msgs <- msg: case <-ctx.Done(): - convertedResource.Release() - close(resources) + close(msgs) if err := eg.Wait(); err != nil { return status.Errorf(codes.Internal, "Context done: %v and failed to wait for plugin: %v", ctx.Err(), err) } @@ -185,15 +220,8 @@ func SetDestinationManagedCqColumns(tables []*schema.Table) { } } -func (s *Server) GetMetrics(context.Context, *pb.GetDestinationMetrics_Request) (*pb.GetDestinationMetrics_Response, error) { - stats := s.Plugin.Metrics() - b, err := json.Marshal(stats) - if err != nil { - return nil, fmt.Errorf("failed to marshal stats: %w", err) - } - return &pb.GetDestinationMetrics_Response{ - Metrics: b, - }, nil +func (*Server) GetMetrics(context.Context, *pb.GetDestinationMetrics_Request) (*pb.GetDestinationMetrics_Response, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetMetrics is deprecated. Please update CLI") } func (s *Server) DeleteStale(ctx context.Context, req *pb.DeleteStale_Request) (*pb.DeleteStale_Response, error) { @@ -203,11 +231,28 @@ func (s *Server) DeleteStale(ctx context.Context, req *pb.DeleteStale_Request) ( } tables := TablesV2ToV3(tablesV2).FlattenTables() SetDestinationManagedCqColumns(tables) - if err := s.Plugin.DeleteStale(ctx, tables, req.Source, req.Timestamp.AsTime()); err != nil { - return nil, err - } - return &pb.DeleteStale_Response{}, nil + msgs := make(chan message.Message) + var writeErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + writeErr = s.Plugin.Write(ctx, plugin.WriteOptions{}, msgs) + }() + for _, table := range tables { + bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema()) + bldr.Field(table.Columns.Index(schema.CqSourceNameColumn.Name)).(*array.StringBuilder).Append(req.Source) + bldr.Field(table.Columns.Index(schema.CqSyncTimeColumn.Name)).(*array.TimestampBuilder).AppendTime(req.Timestamp.AsTime()) + msgs <- &message.DeleteStale{ + Table: table, + SourceName: req.Source, + SyncTime: req.Timestamp.AsTime(), + } + } + close(msgs) + wg.Wait() + return &pb.DeleteStale_Response{}, writeErr } func (s *Server) setPKsForTables(tables schema.Tables) { diff --git a/internal/servers/destination/v0/schemav2tov3.go b/internal/servers/destination/v0/schemav2tov3.go index eabd37fd94..3b63448b15 100644 --- a/internal/servers/destination/v0/schemav2tov3.go +++ b/internal/servers/destination/v0/schemav2tov3.go @@ -8,8 +8,8 @@ import ( "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/memory" schemav2 "github.com/cloudquery/plugin-sdk/v2/schema" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/cloudquery/plugin-sdk/v4/types" ) func TablesV2ToV3(tables schemav2.Tables) schema.Tables { diff --git a/internal/servers/destination/v1/convert.go b/internal/servers/destination/v1/convert.go new file mode 100644 index 0000000000..7fc57f2f01 --- /dev/null +++ b/internal/servers/destination/v1/convert.go @@ -0,0 +1,32 @@ +package destination + +import ( + 
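`DeleteStale` uses the same drive-the-writer pattern, but with a `sync.WaitGroup` capturing the write error. A condensed generic version (helper name ours); note that the SDK's own `Plugin.WriteAll`, introduced later in this diff, achieves the same with a buffered channel:

```go
package example

import (
	"context"
	"sync"

	"github.com/cloudquery/plugin-sdk/v4/message"
	"github.com/cloudquery/plugin-sdk/v4/plugin"
)

// writeAllMessages feeds a fixed batch of messages through Plugin.Write
// and returns the writer's error once the batch is drained.
func writeAllMessages(ctx context.Context, p *plugin.Plugin, msgs []message.Message) error {
	ch := make(chan message.Message)
	var writeErr error
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		writeErr = p.Write(ctx, plugin.WriteOptions{}, ch)
	}()
	for _, m := range msgs {
		ch <- m
	}
	close(ch)
	wg.Wait()
	return writeErr
}
```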
"bytes" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/ipc" + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +// Legacy conversion functions to and from Arrow bytes. From plugin v3 onwards +// this responsibility is handled by plugin-pb-go. + +func NewFromBytes(b []byte) (*arrow.Schema, error) { + rdr, err := ipc.NewReader(bytes.NewReader(b)) + if err != nil { + return nil, err + } + return rdr.Schema(), nil +} + +func NewSchemasFromBytes(b [][]byte) (schema.Schemas, error) { + var err error + ret := make([]*arrow.Schema, len(b)) + for i, buf := range b { + ret[i], err = NewFromBytes(buf) + if err != nil { + return nil, err + } + } + return ret, nil +} diff --git a/internal/servers/destination/v1/destination_test.go b/internal/servers/destination/v1/destination_test.go new file mode 100644 index 0000000000..9d398f0599 --- /dev/null +++ b/internal/servers/destination/v1/destination_test.go @@ -0,0 +1,152 @@ +package destination + +import ( + "context" + "encoding/json" + "io" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/memory" + pb "github.com/cloudquery/plugin-pb-go/pb/destination/v1" + pbSource "github.com/cloudquery/plugin-pb-go/pb/source/v2" + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v4/internal/memdb" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +func TestGetName(t *testing.T) { + ctx := context.Background() + s := Server{ + Plugin: plugin.NewPlugin("test", "development", memdb.NewMemDBClient), + } + res, err := s.GetName(ctx, &pb.GetName_Request{}) + if err != nil { + t.Fatal(err) + } + if res.Name != "test" { + t.Fatalf("expected test, got %s", res.GetName()) + } +} + +func TestGetVersion(t *testing.T) { + ctx := context.Background() + s := Server{ + Plugin: plugin.NewPlugin("test", "development", memdb.NewMemDBClient), + } + resVersion, err := s.GetVersion(ctx, &pb.GetVersion_Request{}) + if err != nil { + t.Fatal(err) + } + if resVersion.Version != "development" { + t.Fatalf("expected development, got %s", resVersion.GetVersion()) + } +} + +type mockWriteServer struct { + grpc.ServerStream + messages []*pb.Write_Request +} + +func (*mockWriteServer) SendAndClose(*pb.Write_Response) error { + return nil +} +func (s *mockWriteServer) Recv() (*pb.Write_Request, error) { + if len(s.messages) > 0 { + msg := s.messages[0] + s.messages = s.messages[1:] + return msg, nil + } + return nil, io.EOF +} +func (*mockWriteServer) SetHeader(metadata.MD) error { + return nil +} +func (*mockWriteServer) SendHeader(metadata.MD) error { + return nil +} +func (*mockWriteServer) SetTrailer(metadata.MD) { +} +func (*mockWriteServer) Context() context.Context { + return context.Background() +} +func (*mockWriteServer) SendMsg(any) error { + return nil +} +func (*mockWriteServer) RecvMsg(any) error { + return nil +} + +func TestPluginSync(t *testing.T) { + ctx := context.Background() + s := Server{ + Plugin: plugin.NewPlugin("test", "development", memdb.NewMemDBClient), + } + destinationSpec := specs.Destination{ + Name: "test", + } + destinationSpecBytes, err := json.Marshal(destinationSpec) + if err != nil { + t.Fatal(err) + } + _, err = s.Configure(ctx, &pb.Configure_Request{ + Config: destinationSpecBytes, + }) + if err != nil { + t.Fatal(err) + } + + writeMockServer := &mockWriteServer{} + if err := 
s.Write(writeMockServer); err != nil { + t.Fatal(err) + } + table := &schema.Table{ + Name: "test", + Columns: []schema.Column{ + { + Name: "test", + Type: arrow.BinaryTypes.String, + }, + }, + } + schemas := schema.Tables{table}.ToArrowSchemas() + schemaBytes, err := pbSource.SchemasToBytes(schemas) + if err != nil { + t.Fatal(err) + } + sc := table.ToArrowSchema() + bldr := array.NewRecordBuilder(memory.DefaultAllocator, sc) + bldr.Field(0).(*array.StringBuilder).Append("test") + record := bldr.NewRecord() + recordBytes, err := pbSource.RecordToBytes(record) + if err != nil { + t.Fatal(err) + } + + sourceSpec := specs.Source{ + Name: "source_test", + } + sourceSpecBytes, err := json.Marshal(sourceSpec) + if err != nil { + t.Fatal(err) + } + + writeMockServer.messages = []*pb.Write_Request{ + { + Tables: schemaBytes, + Resource: recordBytes, + SourceSpec: sourceSpecBytes, + }, + } + if err := s.Write(writeMockServer); err != nil { + t.Fatal(err) + } + + if _, err := s.Close(ctx, &pb.Close_Request{}); err != nil { + t.Fatal(err) + } +} diff --git a/internal/servers/destination/v1/destinations.go b/internal/servers/destination/v1/destinations.go index 447c03b596..fc9b688800 100644 --- a/internal/servers/destination/v1/destinations.go +++ b/internal/servers/destination/v1/destinations.go @@ -4,15 +4,17 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" + "sync" - "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/ipc" + "github.com/apache/arrow/go/v13/arrow/memory" pb "github.com/cloudquery/plugin-pb-go/pb/destination/v1" "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/plugins/destination" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" @@ -21,9 +23,10 @@ import ( type Server struct { pb.UnimplementedDestinationServer - Plugin *destination.Plugin - Logger zerolog.Logger - spec specs.Destination + Plugin *plugin.Plugin + Logger zerolog.Logger + spec specs.Destination + migrateMode plugin.MigrateMode } func (s *Server) Configure(ctx context.Context, req *pb.Configure_Request) (*pb.Configure_Response, error) { @@ -32,7 +35,11 @@ func (s *Server) Configure(ctx context.Context, req *pb.Configure_Request) (*pb. 
return nil, status.Errorf(codes.InvalidArgument, "failed to unmarshal spec: %v", err) } s.spec = spec - return &pb.Configure_Response{}, s.Plugin.Init(ctx, s.Logger, spec) + pluginSpec, err := json.Marshal(s.spec.Spec) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to marshal spec: %v", err) + } + return &pb.Configure_Response{}, s.Plugin.Init(ctx, pluginSpec) } func (s *Server) GetName(context.Context, *pb.GetName_Request) (*pb.GetName_Response, error) { @@ -48,7 +55,7 @@ func (s *Server) GetVersion(context.Context, *pb.GetVersion_Request) (*pb.GetVer } func (s *Server) Migrate(ctx context.Context, req *pb.Migrate_Request) (*pb.Migrate_Response, error) { - schemas, err := schema.NewSchemasFromBytes(req.Tables) + schemas, err := NewSchemasFromBytes(req.Tables) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "failed to create schemas: %v", err) } @@ -58,13 +65,29 @@ func (s *Server) Migrate(ctx context.Context, req *pb.Migrate_Request) (*pb.Migr } s.setPKsForTables(tables) - return &pb.Migrate_Response{}, s.Plugin.Migrate(ctx, tables) + writeCh := make(chan message.Message) + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return s.Plugin.Write(ctx, plugin.WriteOptions{ + MigrateForce: s.migrateMode == plugin.MigrateModeForce, + }, writeCh) + }) + for _, table := range tables { + writeCh <- &message.MigrateTable{ + Table: table, + } + } + close(writeCh) + if err := eg.Wait(); err != nil { + return nil, status.Errorf(codes.Internal, "failed to write: %v", err) + } + return &pb.Migrate_Response{}, nil } // Note the order of operations in this method is important! // Trying to insert into the `resources` channel before starting the reader goroutine will cause a deadlock. func (s *Server) Write(msg pb.Destination_WriteServer) error { - resources := make(chan arrow.Record) + msgs := make(chan message.Message) r, err := msg.Recv() if err != nil { @@ -74,7 +97,7 @@ func (s *Server) Write(msg pb.Destination_WriteServer) error { return status.Errorf(codes.Internal, "failed to receive msg: %v", err) } - schemas, err := schema.NewSchemasFromBytes(r.Tables) + schemas, err := NewSchemasFromBytes(r.Tables) if err != nil { return status.Errorf(codes.InvalidArgument, "failed to create schemas: %v", err) } @@ -93,24 +116,32 @@ func (s *Server) Write(msg pb.Destination_WriteServer) error { return status.Errorf(codes.InvalidArgument, "failed to unmarshal source spec: %v", err) } } - syncTime := r.Timestamp.AsTime() s.setPKsForTables(tables) eg, ctx := errgroup.WithContext(msg.Context()) + eg.Go(func() error { - return s.Plugin.Write(ctx, sourceSpec, tables, syncTime, resources) + return s.Plugin.Write(ctx, plugin.WriteOptions{ + MigrateForce: s.spec.MigrateMode == specs.MigrateModeForced, + }, msgs) }) + for _, table := range tables { + msgs <- &message.MigrateTable{ + Table: table, + } + } + for { r, err := msg.Recv() if err == io.EOF { - close(resources) + close(msgs) if err := eg.Wait(); err != nil { return status.Errorf(codes.Internal, "write failed: %v", err) } return msg.SendAndClose(&pb.Write_Response{}) } if err != nil { - close(resources) + close(msgs) if wgErr := eg.Wait(); wgErr != nil { return status.Errorf(codes.Internal, "failed to receive msg: %v and write failed: %v", err, wgErr) } @@ -118,7 +149,7 @@ func (s *Server) Write(msg pb.Destination_WriteServer) error { } rdr, err := ipc.NewReader(bytes.NewReader(r.Resource)) if err != nil { - close(resources) + close(msgs) if wgErr := eg.Wait(); wgErr != nil { return 
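As in v0, the v1 server reduces the legacy `specs.MigrateMode` to a single `MigrateForce` flag on the new `plugin.WriteOptions`. That translation, as a standalone sketch (helper name ours):

```go
package example

import (
	"github.com/cloudquery/plugin-pb-go/specs"
	"github.com/cloudquery/plugin-sdk/v4/plugin"
)

// writeOptionsFromSpec converts the legacy destination spec into v4 write options.
// MigrateForce is the only migration knob destinations see from v4 onwards.
func writeOptionsFromSpec(spec specs.Destination) plugin.WriteOptions {
	return plugin.WriteOptions{
		MigrateForce: spec.MigrateMode == specs.MigrateModeForced,
	}
}
```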
status.Errorf(codes.InvalidArgument, "failed to create reader: %v and write failed: %v", err, wgErr) } @@ -127,10 +158,13 @@ func (s *Server) Write(msg pb.Destination_WriteServer) error { for rdr.Next() { rec := rdr.Record() rec.Retain() + msg := &message.Insert{ + Record: rec, + } select { - case resources <- rec: + case msgs <- msg: case <-ctx.Done(): - close(resources) + close(msgs) if err := eg.Wait(); err != nil { return status.Errorf(codes.Internal, "Context done: %v and failed to wait for plugin: %v", ctx.Err(), err) } @@ -152,19 +186,12 @@ func setCQIDAsPrimaryKeysForTables(tables schema.Tables) { } } -func (s *Server) GetMetrics(context.Context, *pb.GetDestinationMetrics_Request) (*pb.GetDestinationMetrics_Response, error) { - stats := s.Plugin.Metrics() - b, err := json.Marshal(stats) - if err != nil { - return nil, fmt.Errorf("failed to marshal stats: %w", err) - } - return &pb.GetDestinationMetrics_Response{ - Metrics: b, - }, nil +func (*Server) GetMetrics(context.Context, *pb.GetDestinationMetrics_Request) (*pb.GetDestinationMetrics_Response, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetMetrics is deprecated. please upgrade CLI") } func (s *Server) DeleteStale(ctx context.Context, req *pb.DeleteStale_Request) (*pb.DeleteStale_Response, error) { - schemas, err := schema.NewSchemasFromBytes(req.Tables) + schemas, err := NewSchemasFromBytes(req.Tables) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "failed to create schemas: %v", err) } @@ -173,11 +200,27 @@ func (s *Server) DeleteStale(ctx context.Context, req *pb.DeleteStale_Request) ( return nil, status.Errorf(codes.InvalidArgument, "failed to create tables: %v", err) } - if err := s.Plugin.DeleteStale(ctx, tables, req.Source, req.Timestamp.AsTime()); err != nil { - return nil, err + msgs := make(chan message.Message) + var writeErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + writeErr = s.Plugin.Write(ctx, plugin.WriteOptions{}, msgs) + }() + for _, table := range tables { + bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema()) + bldr.Field(table.Columns.Index(schema.CqSourceNameColumn.Name)).(*array.StringBuilder).Append(req.Source) + bldr.Field(table.Columns.Index(schema.CqSyncTimeColumn.Name)).(*array.TimestampBuilder).AppendTime(req.Timestamp.AsTime()) + msgs <- &message.DeleteStale{ + Table: table, + SourceName: req.Source, + SyncTime: req.Timestamp.AsTime(), + } } - - return &pb.DeleteStale_Response{}, nil + close(msgs) + wg.Wait() + return &pb.DeleteStale_Response{}, writeErr } func (s *Server) setPKsForTables(tables schema.Tables) { diff --git a/internal/servers/discovery/v0/discovery_test.go b/internal/servers/discovery/v0/discovery_test.go new file mode 100644 index 0000000000..0eaab884d7 --- /dev/null +++ b/internal/servers/discovery/v0/discovery_test.go @@ -0,0 +1,28 @@ +package discovery + +import ( + "context" + "testing" + + pb "github.com/cloudquery/plugin-pb-go/pb/discovery/v0" +) + +func TestDiscovery(t *testing.T) { + ctx := context.Background() + s := &Server{ + Versions: []string{"1", "2"}, + } + resp, err := s.GetVersions(ctx, &pb.GetVersions_Request{}) + if err != nil { + t.Fatal(err) + } + if len(resp.Versions) != 2 { + t.Fatal("expected 2 versions") + } + if resp.Versions[0] != "1" { + t.Fatal("expected version 1") + } + if resp.Versions[1] != "2" { + t.Fatal("expected version 2") + } +} diff --git a/internal/servers/discovery/v1/discovery.go b/internal/servers/discovery/v1/discovery.go new file mode 
100644 index 0000000000..896e8a9cea --- /dev/null +++ b/internal/servers/discovery/v1/discovery.go @@ -0,0 +1,16 @@ +package discovery + +import ( + "context" + + pb "github.com/cloudquery/plugin-pb-go/pb/discovery/v1" +) + +type Server struct { + pb.UnimplementedDiscoveryServer + Versions []int32 +} + +func (s *Server) GetVersions(context.Context, *pb.GetVersions_Request) (*pb.GetVersions_Response, error) { + return &pb.GetVersions_Response{Versions: s.Versions}, nil +} diff --git a/internal/servers/discovery/v1/discovery_test.go b/internal/servers/discovery/v1/discovery_test.go new file mode 100644 index 0000000000..a54b24c746 --- /dev/null +++ b/internal/servers/discovery/v1/discovery_test.go @@ -0,0 +1,28 @@ +package discovery + +import ( + "context" + "testing" + + pb "github.com/cloudquery/plugin-pb-go/pb/discovery/v1" +) + +func TestDiscovery(t *testing.T) { + ctx := context.Background() + s := &Server{ + Versions: []int32{1, 2}, + } + resp, err := s.GetVersions(ctx, &pb.GetVersions_Request{}) + if err != nil { + t.Fatal(err) + } + if len(resp.Versions) != 2 { + t.Fatal("expected 2 versions") + } + if resp.Versions[0] != 1 { + t.Fatal("expected version 1") + } + if resp.Versions[1] != 2 { + t.Fatal("expected version 2") + } +} diff --git a/internal/servers/plugin/v3/plugin.go b/internal/servers/plugin/v3/plugin.go new file mode 100644 index 0000000000..b424b0bd26 --- /dev/null +++ b/internal/servers/plugin/v3/plugin.go @@ -0,0 +1,239 @@ +package plugin + +import ( + "context" + "fmt" + "io" + + pb "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const MaxMsgSize = 100 * 1024 * 1024 // 100 MiB + +type Server struct { + pb.UnimplementedPluginServer + Plugin *plugin.Plugin + Logger zerolog.Logger + Directory string + NoSentry bool +} + +func (s *Server) GetTables(ctx context.Context, _ *pb.GetTables_Request) (*pb.GetTables_Response, error) { + tables, err := s.Plugin.Tables(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get tables: %v", err) + } + schemas := tables.ToArrowSchemas() + encoded := make([][]byte, len(schemas)) + for i, sc := range schemas { + encoded[i], err = pb.SchemaToBytes(sc) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to encode tables: %v", err) + } + } + return &pb.GetTables_Response{ + Tables: encoded, + }, nil +} + +func (s *Server) GetName(context.Context, *pb.GetName_Request) (*pb.GetName_Response, error) { + return &pb.GetName_Response{ + Name: s.Plugin.Name(), + }, nil +} + +func (s *Server) GetVersion(context.Context, *pb.GetVersion_Request) (*pb.GetVersion_Response, error) { + return &pb.GetVersion_Response{ + Version: s.Plugin.Version(), + }, nil +} + +func (s *Server) Init(ctx context.Context, req *pb.Init_Request) (*pb.Init_Response, error) { + if err := s.Plugin.Init(ctx, req.Spec); err != nil { + return nil, status.Errorf(codes.Internal, "failed to init plugin: %v", err) + } + return &pb.Init_Response{}, nil +} + +func (s *Server) Sync(req *pb.Sync_Request, stream pb.Plugin_SyncServer) error { + msgs := make(chan message.Message) + var syncErr error + ctx := stream.Context() + + syncOptions := plugin.SyncOptions{ + Tables: req.Tables, + 
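Discovery v1 reports supported protocol versions as plain integers, leaving negotiation to the caller. One plausible client-side strategy (the helper is ours, not SDK API) is to pick the highest version both sides support:

```go
package example

// highestCommonVersion returns the largest version present in both lists, or -1
// when the two sides share no version.
func highestCommonVersion(server, client []int32) int32 {
	supported := make(map[int32]struct{}, len(client))
	for _, v := range client {
		supported[v] = struct{}{}
	}
	best := int32(-1)
	for _, v := range server {
		if _, ok := supported[v]; ok && v > best {
			best = v
		}
	}
	return best
}
```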
SkipTables: req.SkipTables, + } + + go func() { + defer close(msgs) + err := s.Plugin.Sync(ctx, syncOptions, msgs) + if err != nil { + syncErr = fmt.Errorf("failed to sync records: %w", err) + } + }() + + for msg := range msgs { + pbMsg := &pb.Sync_Response{} + switch m := msg.(type) { + case *message.MigrateTable: + tableSchema := m.Table.ToArrowSchema() + schemaBytes, err := pb.SchemaToBytes(tableSchema) + if err != nil { + return status.Errorf(codes.Internal, "failed to encode table schema: %v", err) + } + pbMsg.Message = &pb.Sync_Response_MigrateTable{ + MigrateTable: &pb.MessageMigrateTable{ + Table: schemaBytes, + }, + } + + case *message.Insert: + recordBytes, err := pb.RecordToBytes(m.Record) + if err != nil { + return status.Errorf(codes.Internal, "failed to encode record: %v", err) + } + pbMsg.Message = &pb.Sync_Response_Insert{ + Insert: &pb.MessageInsert{ + Record: recordBytes, + }, + } + case *message.DeleteStale: + sc := m.Table.ToArrowSchema() + tableBytes, err := pb.SchemaToBytes(sc) + if err != nil { + return status.Errorf(codes.Internal, "failed to encode record: %v", err) + } + pbMsg.Message = &pb.Sync_Response_Delete{ + Delete: &pb.MessageDeleteStale{ + Table: tableBytes, + SourceName: m.SourceName, + SyncTime: timestamppb.New(m.SyncTime), + }, + } + default: + return status.Errorf(codes.Internal, "unknown message type: %T", msg) + } + + size := proto.Size(pbMsg) + if size > MaxMsgSize { + s.Logger.Error().Int("bytes", size).Msg("Message exceeds max size") + continue + } + if err := stream.Send(pbMsg); err != nil { + return status.Errorf(codes.Internal, "failed to send message: %v", err) + } + } + + return syncErr +} + +func (s *Server) Write(msg pb.Plugin_WriteServer) error { + msgs := make(chan message.Message) + r, err := msg.Recv() + if err != nil { + return status.Errorf(codes.Internal, "failed to receive msg: %v", err) + } + pbWriteOptions, ok := r.Message.(*pb.Write_Request_Options) + if !ok { + return status.Errorf(codes.Internal, "expected options message, got %T", r.Message) + } + eg, ctx := errgroup.WithContext(msg.Context()) + eg.Go(func() error { + return s.Plugin.Write(ctx, plugin.WriteOptions{ + MigrateForce: pbWriteOptions.Options.MigrateForce, + }, msgs) + }) + + for { + r, err := msg.Recv() + if err == io.EOF { + close(msgs) + if err := eg.Wait(); err != nil { + return status.Errorf(codes.Internal, "write failed: %v", err) + } + return msg.SendAndClose(&pb.Write_Response{}) + } + if err != nil { + close(msgs) + if wgErr := eg.Wait(); wgErr != nil { + return status.Errorf(codes.Internal, "failed to receive msg: %v and write failed: %v", err, wgErr) + } + return status.Errorf(codes.Internal, "failed to receive msg: %v", err) + } + var pluginMessage message.Message + var pbMsgConvertErr error + switch pbMsg := r.Message.(type) { + case *pb.Write_Request_MigrateTable: + sc, err := pb.NewSchemaFromBytes(pbMsg.MigrateTable.Table) + if err != nil { + pbMsgConvertErr = status.Errorf(codes.InvalidArgument, "failed to create schema from bytes: %v", err) + break + } + table, err := schema.NewTableFromArrowSchema(sc) + if err != nil { + pbMsgConvertErr = status.Errorf(codes.InvalidArgument, "failed to create table from schema: %v", err) + break + } + pluginMessage = &message.MigrateTable{ + Table: table, + } + case *pb.Write_Request_Insert: + record, err := pb.NewRecordFromBytes(pbMsg.Insert.Record) + if err != nil { + pbMsgConvertErr = status.Errorf(codes.InvalidArgument, "failed to create record: %v", err) + break + } + pluginMessage = &message.Insert{ + Record: 
record, + } + case *pb.Write_Request_Delete: + sc, err := pb.NewSchemaFromBytes(pbMsg.Delete.Table) + if err != nil { + pbMsgConvertErr = status.Errorf(codes.InvalidArgument, "failed to create schema from bytes: %v", err) + break + } + table, err := schema.NewTableFromArrowSchema(sc) + if err != nil { + pbMsgConvertErr = status.Errorf(codes.InvalidArgument, "failed to create table from schema: %v", err) + break + } + pluginMessage = &message.DeleteStale{ + Table: table, + SourceName: pbMsg.Delete.SourceName, + SyncTime: pbMsg.Delete.SyncTime.AsTime(), + } + } + + if pbMsgConvertErr != nil { + close(msgs) + if wgErr := eg.Wait(); wgErr != nil { + return status.Errorf(codes.Internal, "failed to convert message: %v and write failed: %v", pbMsgConvertErr, wgErr) + } + return pbMsgConvertErr + } + + select { + case msgs <- pluginMessage: + case <-ctx.Done(): + close(msgs) + if err := eg.Wait(); err != nil { + return status.Errorf(codes.Internal, "Context done: %v and failed to wait for plugin: %v", ctx.Err(), err) + } + return status.Errorf(codes.Internal, "Context done: %v", ctx.Err()) + } + } +} + +func (s *Server) Close(ctx context.Context, _ *pb.Close_Request) (*pb.Close_Response, error) { + return &pb.Close_Response{}, s.Plugin.Close(ctx) +} diff --git a/internal/servers/plugin/v3/plugin_test.go b/internal/servers/plugin/v3/plugin_test.go new file mode 100644 index 0000000000..4b03fc53dd --- /dev/null +++ b/internal/servers/plugin/v3/plugin_test.go @@ -0,0 +1,193 @@ +package plugin + +import ( + "context" + "io" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/memory" + pb "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "github.com/cloudquery/plugin-sdk/v4/internal/memdb" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +func TestGetName(t *testing.T) { + ctx := context.Background() + s := Server{ + Plugin: plugin.NewPlugin("test", "development", memdb.NewMemDBClient), + } + res, err := s.GetName(ctx, &pb.GetName_Request{}) + if err != nil { + t.Fatal(err) + } + if res.Name != "test" { + t.Fatalf("expected test, got %s", res.GetName()) + } +} + +func TestGetVersion(t *testing.T) { + ctx := context.Background() + s := Server{ + Plugin: plugin.NewPlugin("test", "development", memdb.NewMemDBClient), + } + resVersion, err := s.GetVersion(ctx, &pb.GetVersion_Request{}) + if err != nil { + t.Fatal(err) + } + if resVersion.Version != "development" { + t.Fatalf("expected development, got %s", resVersion.GetVersion()) + } +} + +type mockSyncServer struct { + grpc.ServerStream + messages []*pb.Sync_Response +} + +func (s *mockSyncServer) Send(*pb.Sync_Response) error { + s.messages = append(s.messages, &pb.Sync_Response{}) + return nil +} + +func (*mockSyncServer) SetHeader(metadata.MD) error { + return nil +} +func (*mockSyncServer) SendHeader(metadata.MD) error { + return nil +} +func (*mockSyncServer) SetTrailer(metadata.MD) { +} +func (*mockSyncServer) Context() context.Context { + return context.Background() +} +func (*mockSyncServer) SendMsg(any) error { + return nil +} +func (*mockSyncServer) RecvMsg(any) error { + return nil +} + +type mockWriteServer struct { + grpc.ServerStream + messages []*pb.Write_Request +} + +func (*mockWriteServer) SendAndClose(*pb.Write_Response) error { + return nil +} +func (s *mockWriteServer) Recv() (*pb.Write_Request, error) { + if len(s.messages) > 0 { + 
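The v3 `Write` stream enforces ordering: the first request must carry `WriteOptions`, and every later request is a `MigrateTable`, `Insert`, or `Delete` message. A client-side sketch of a well-formed sequence (`tableBytes` and `recordBytes` are assumed to come from `pb.SchemaToBytes` and `pb.RecordToBytes`):

```go
package example

import (
	pb "github.com/cloudquery/plugin-pb-go/pb/plugin/v3"
)

// buildWriteSequence returns a well-formed v3 write stream payload:
// options first, then schema migration, then data.
func buildWriteSequence(tableBytes, recordBytes []byte) []*pb.Write_Request {
	return []*pb.Write_Request{
		{Message: &pb.Write_Request_Options{Options: &pb.WriteOptions{MigrateForce: false}}},
		{Message: &pb.Write_Request_MigrateTable{MigrateTable: &pb.MessageMigrateTable{Table: tableBytes}}},
		{Message: &pb.Write_Request_Insert{Insert: &pb.MessageInsert{Record: recordBytes}}},
	}
}
```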
msg := s.messages[0] + s.messages = s.messages[1:] + return msg, nil + } + return nil, io.EOF +} +func (*mockWriteServer) SetHeader(metadata.MD) error { + return nil +} +func (*mockWriteServer) SendHeader(metadata.MD) error { + return nil +} +func (*mockWriteServer) SetTrailer(metadata.MD) { +} +func (*mockWriteServer) Context() context.Context { + return context.Background() +} +func (*mockWriteServer) SendMsg(any) error { + return nil +} +func (*mockWriteServer) RecvMsg(any) error { + return nil +} + +func TestPluginSync(t *testing.T) { + ctx := context.Background() + s := Server{ + Plugin: plugin.NewPlugin("test", "development", memdb.NewMemDBClient), + } + + _, err := s.Init(ctx, &pb.Init_Request{}) + if err != nil { + t.Fatal(err) + } + + streamSyncServer := &mockSyncServer{} + if err := s.Sync(&pb.Sync_Request{}, streamSyncServer); err != nil { + t.Fatal(err) + } + if len(streamSyncServer.messages) != 0 { + t.Fatalf("expected 0 messages, got %d", len(streamSyncServer.messages)) + } + writeMockServer := &mockWriteServer{} + + if err := s.Write(writeMockServer); err == nil { + t.Fatal("expected error, got nil") + } + table := &schema.Table{ + Name: "test", + Columns: []schema.Column{ + { + Name: "test", + Type: arrow.BinaryTypes.String, + }, + }, + } + sc := table.ToArrowSchema() + b, err := pb.SchemaToBytes(sc) + if err != nil { + t.Fatal(err) + } + bldr := array.NewRecordBuilder(memory.DefaultAllocator, sc) + bldr.Field(0).(*array.StringBuilder).Append("test") + record := bldr.NewRecord() + recordBytes, err := pb.RecordToBytes(record) + if err != nil { + t.Fatal(err) + } + + writeMockServer.messages = []*pb.Write_Request{ + { + Message: &pb.Write_Request_Options{ + Options: &pb.WriteOptions{}, + }, + }, + { + Message: &pb.Write_Request_MigrateTable{ + MigrateTable: &pb.MessageMigrateTable{ + Table: b, + }, + }, + }, + { + Message: &pb.Write_Request_Insert{ + Insert: &pb.MessageInsert{ + Record: recordBytes, + }, + }, + }, + } + + if err := s.Write(writeMockServer); err != nil { + t.Fatal(err) + } + + streamSyncServer = &mockSyncServer{} + if err := s.Sync(&pb.Sync_Request{ + Tables: []string{"*"}, + }, streamSyncServer); err != nil { + t.Fatal(err) + } + if len(streamSyncServer.messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(streamSyncServer.messages)) + } + + if _, err := s.Close(ctx, &pb.Close_Request{}); err != nil { + t.Fatal(err) + } +} diff --git a/internal/servers/source/v2/source.go b/internal/servers/source/v2/source.go deleted file mode 100644 index a010fefef3..0000000000 --- a/internal/servers/source/v2/source.go +++ /dev/null @@ -1,173 +0,0 @@ -package source - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - - "github.com/apache/arrow/go/v13/arrow/array" - "github.com/apache/arrow/go/v13/arrow/ipc" - "github.com/apache/arrow/go/v13/arrow/memory" - pb "github.com/cloudquery/plugin-pb-go/pb/source/v2" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/plugins/source" - "github.com/cloudquery/plugin-sdk/v3/scalar" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/getsentry/sentry-go" - "github.com/rs/zerolog" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" -) - -const MaxMsgSize = 100 * 1024 * 1024 // 100 MiB - -type Server struct { - pb.UnimplementedSourceServer - Plugin *source.Plugin - Logger zerolog.Logger -} - -func (s *Server) GetTables(context.Context, *pb.GetTables_Request) (*pb.GetTables_Response, error) { - tables := 
s.Plugin.Tables().ToArrowSchemas() - encoded, err := tables.Encode() - if err != nil { - return nil, fmt.Errorf("failed to encode tables: %w", err) - } - return &pb.GetTables_Response{ - Tables: encoded, - }, nil -} - -func (s *Server) GetDynamicTables(context.Context, *pb.GetDynamicTables_Request) (*pb.GetDynamicTables_Response, error) { - tables := s.Plugin.GetDynamicTables().ToArrowSchemas() - encoded, err := tables.Encode() - if err != nil { - return nil, fmt.Errorf("failed to encode tables: %w", err) - } - return &pb.GetDynamicTables_Response{ - Tables: encoded, - }, nil -} - -func (s *Server) GetName(context.Context, *pb.GetName_Request) (*pb.GetName_Response, error) { - return &pb.GetName_Response{ - Name: s.Plugin.Name(), - }, nil -} - -func (s *Server) GetVersion(context.Context, *pb.GetVersion_Request) (*pb.GetVersion_Response, error) { - return &pb.GetVersion_Response{ - Version: s.Plugin.Version(), - }, nil -} - -func (s *Server) Init(ctx context.Context, req *pb.Init_Request) (*pb.Init_Response, error) { - var spec specs.Source - dec := json.NewDecoder(bytes.NewReader(req.Spec)) - dec.UseNumber() - // TODO: warn about unknown fields - if err := dec.Decode(&spec); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to decode spec: %v", err) - } - - if err := s.Plugin.Init(ctx, spec); err != nil { - return nil, status.Errorf(codes.Internal, "failed to init plugin: %v", err) - } - return &pb.Init_Response{}, nil -} - -func (s *Server) Sync(req *pb.Sync_Request, stream pb.Source_SyncServer) error { - resources := make(chan *schema.Resource) - var syncErr error - ctx := stream.Context() - - go func() { - defer close(resources) - err := s.Plugin.Sync(ctx, req.SyncTime.AsTime(), resources) - if err != nil { - syncErr = fmt.Errorf("failed to sync resources: %w", err) - } - }() - - for resource := range resources { - vector := resource.GetValues() - bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema()) - scalar.AppendToRecordBuilder(bldr, vector) - rec := bldr.NewRecord() - - var buf bytes.Buffer - w := ipc.NewWriter(&buf, ipc.WithSchema(rec.Schema())) - if err := w.Write(rec); err != nil { - return status.Errorf(codes.Internal, "failed to write record: %v", err) - } - if err := w.Close(); err != nil { - return status.Errorf(codes.Internal, "failed to close writer: %v", err) - } - - msg := &pb.Sync_Response{ - Resource: buf.Bytes(), - } - err := checkMessageSize(msg, resource) - if err != nil { - s.Logger.Warn().Str("table", resource.Table.Name). - Int("bytes", len(msg.String())). - Msg("Row exceeding max bytes ignored") - continue - } - if err := stream.Send(msg); err != nil { - return status.Errorf(codes.Internal, "failed to send resource: %v", err) - } - } - - return syncErr -} - -func (s *Server) GetMetrics(context.Context, *pb.GetMetrics_Request) (*pb.GetMetrics_Response, error) { - // Aggregate metrics before sending to keep response size small. 
-	// Temporary fix for https://github.com/cloudquery/cloudquery/issues/3962
-	m := s.Plugin.Metrics()
-	agg := &source.TableClientMetrics{}
-	for _, table := range m.TableClient {
-		for _, tableClient := range table {
-			agg.Resources += tableClient.Resources
-			agg.Errors += tableClient.Errors
-			agg.Panics += tableClient.Panics
-		}
-	}
-	b, err := json.Marshal(&source.Metrics{
-		TableClient: map[string]map[string]*source.TableClientMetrics{"": {"": agg}},
-	})
-	if err != nil {
-		return nil, fmt.Errorf("failed to marshal source metrics: %w", err)
-	}
-	return &pb.GetMetrics_Response{
-		Metrics: b,
-	}, nil
-}
-
-func (s *Server) GenDocs(_ context.Context, req *pb.GenDocs_Request) (*pb.GenDocs_Response, error) {
-	err := s.Plugin.GeneratePluginDocs(req.Path, req.Format.String())
-	if err != nil {
-		return nil, fmt.Errorf("failed to generate docs: %w", err)
-	}
-	return &pb.GenDocs_Response{}, nil
-}
-
-func checkMessageSize(msg proto.Message, resource *schema.Resource) error {
-	size := proto.Size(msg)
-	// log error to Sentry if row exceeds half of the max size
-	if size > MaxMsgSize/2 {
-		sentry.WithScope(func(scope *sentry.Scope) {
-			scope.SetTag("table", resource.Table.Name)
-			scope.SetExtra("bytes", size)
-			sentry.CurrentHub().CaptureMessage("Large message detected")
-		})
-	}
-	if size > MaxMsgSize {
-		return errors.New("message exceeds max size")
-	}
-	return nil
-}
diff --git a/message/message.go b/message/message.go
new file mode 100644
index 0000000000..8377cc7777
--- /dev/null
+++ b/message/message.go
@@ -0,0 +1,110 @@
+package message
+
+import (
+	"time"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/cloudquery/plugin-sdk/v4/schema"
+)
+
+type Message interface {
+	GetTable() *schema.Table
+}
+
+type MigrateTable struct {
+	Table *schema.Table
+}
+
+func (m MigrateTable) GetTable() *schema.Table {
+	return m.Table
+}
+
+type Insert struct {
+	Record arrow.Record
+}
+
+func (m *Insert) GetTable() *schema.Table {
+	table, err := schema.NewTableFromArrowSchema(m.Record.Schema())
+	if err != nil {
+		panic(err)
+	}
+	return table
+}
+
+// DeleteStale is a fairly specific message that requires the destination to be aware of a CLI use-case,
+// so it might be deprecated in the future in favour of MessageDelete or MessageRawQuery.
+// The message indicates that the destination needs to run something like
+// "DELETE FROM table WHERE _cq_source_name=$1 AND _cq_sync_time < $2"
+type DeleteStale struct {
+	Table      *schema.Table
+	SourceName string
+	SyncTime   time.Time
+}
+
+func (m DeleteStale) GetTable() *schema.Table {
+	return m.Table
+}
+
+type Messages []Message
+
+type MigrateTables []*MigrateTable
+
+type Inserts []*Insert
+
+func (messages Messages) InsertItems() int64 {
+	items := int64(0)
+	for _, msg := range messages {
+		if m, ok := msg.(*Insert); ok {
+			items += m.Record.NumRows()
+		}
+	}
+	return items
+}
+
+func (messages Messages) InsertMessage() Inserts {
+	inserts := []*Insert{}
+	for _, msg := range messages {
+		if m, ok := msg.(*Insert); ok {
+			inserts = append(inserts, m)
+		}
+	}
+	return inserts
+}
+
+func (m MigrateTables) Exists(tableName string) bool {
+	for _, table := range m {
+		if table.Table.Name == tableName {
+			return true
+		}
+	}
+	return false
+}
+
+func (m Inserts) Exists(tableName string) bool {
+	for _, insert := range m {
+		md := insert.Record.Schema().Metadata()
+		tableNameMeta, ok := md.GetValue(schema.MetadataTableName)
+		if !ok {
+			continue
+		}
+		if tableNameMeta == tableName {
+			return true
+		}
+	}
+	return false
+}
+
+func (m Inserts) GetRecordsForTable(table *schema.Table) []arrow.Record {
+	res := []arrow.Record{}
+	for _, insert := range m {
+		md := insert.Record.Schema().Metadata()
+		tableNameMeta, ok := md.GetValue(schema.MetadataTableName)
+		if !ok {
+			continue
+		}
+		if tableNameMeta == table.Name {
+			res = append(res, insert.Record)
+		}
+	}
+	return res
+}
diff --git a/plugins/destination/diff.go b/plugin/diff.go
similarity index 85%
rename from plugins/destination/diff.go
rename to plugin/diff.go
index dc3c555ce0..a5e532a9fe 100644
--- a/plugins/destination/diff.go
+++ b/plugin/diff.go
@@ -1,4 +1,4 @@
-package destination
+package plugin
 
 import (
 	"fmt"
@@ -31,7 +31,3 @@ func RecordDiff(l, r arrow.Record) string {
 	}
 	return sb.String()
 }
-
-func recordApproxEqual(l, r arrow.Record) bool {
-	return array.RecordApproxEqual(l, r, array.WithUnorderedMapKeys(true))
-}
diff --git a/plugin/options.go b/plugin/options.go
new file mode 100644
index 0000000000..966f692e60
--- /dev/null
+++ b/plugin/options.go
@@ -0,0 +1,18 @@
+package plugin
+
+type MigrateMode int
+
+const (
+	MigrateModeSafe MigrateMode = iota
+	MigrateModeForce
+)
+
+var (
+	migrateModeStrings = []string{"safe", "force"}
+)
+
+func (m MigrateMode) String() string {
+	return migrateModeStrings[m]
+}
+
+type Option func(*Plugin)
diff --git a/plugin/plugin.go b/plugin/plugin.go
new file mode 100644
index 0000000000..8bc04e516e
--- /dev/null
+++ b/plugin/plugin.go
@@ -0,0 +1,130 @@
+package plugin
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/cloudquery/plugin-sdk/v4/message"
+	"github.com/cloudquery/plugin-sdk/v4/schema"
+	"github.com/rs/zerolog"
+)
+
+var ErrNotImplemented = fmt.Errorf("not implemented")
+
+type NewClientFunc func(context.Context, zerolog.Logger, []byte) (Client, error)
+
+type Client interface {
+	SourceClient
+	DestinationClient
+}
+
+type UnimplementedDestination struct{}
+
+func (UnimplementedDestination) Write(context.Context, WriteOptions, <-chan message.Message) error {
+	return ErrNotImplemented
+}
+
+func (UnimplementedDestination) Read(context.Context, *schema.Table, chan<- arrow.Record) error {
+	return ErrNotImplemented
+}
+
+type UnimplementedSource struct{}
+
+func (UnimplementedSource) Sync(context.Context, SyncOptions, chan<- message.Message) error {
+	return ErrNotImplemented
+}
+
+func (UnimplementedSource) Tables(context.Context) (schema.Tables, error) {
+	return nil, ErrNotImplemented
+}
+
+// Plugin is the base structure required to pass to sdk.serve.
+// We take a declarative approach to the API here, similar to Cobra.
+type Plugin struct {
+	// Name of the plugin, e.g. aws, gcp or azure
+	name string
+	// Version of the plugin
+	version string
+	// Called on Init to validate and initialize the configuration
+	newClient NewClientFunc
+	// Logger used by the plugin; passed to the serve.Serve client. If not defined, Serve will create one instead.
+	logger zerolog.Logger
+	// mu is a mutex that limits the number of concurrent init/syncs (can only be one at a time)
+	mu sync.Mutex
+	// client is the initialized session client
+	client Client
+	// spec is the spec the client was initialized with
+	spec any
+	// internalColumns, when true, adds internal columns such as _cq_id and _cq_parent_id to tables.
+	// Disabling this is useful for sources such as PostgreSQL and other databases.
+	internalColumns bool
+}
+
+// NewPlugin returns a new CloudQuery Plugin with the given name, version and implementation.
+// Depending on the options, it can be a write-only plugin, read-only plugin, or both.
+func NewPlugin(name string, version string, newClient NewClientFunc, options ...Option) *Plugin {
+	p := Plugin{
+		name:            name,
+		version:         version,
+		internalColumns: true,
+		newClient:       newClient,
+	}
+	for _, opt := range options {
+		opt(&p)
+	}
+	return &p
+}
+
+// Name returns the name of this plugin
+func (p *Plugin) Name() string {
+	return p.name
+}
+
+// Version returns the version of this plugin
+func (p *Plugin) Version() string {
+	return p.version
+}
+
+func (p *Plugin) SetLogger(logger zerolog.Logger) {
+	p.logger = logger.With().Str("module", p.name+"-src").Logger()
+}
+
+func (p *Plugin) Tables(ctx context.Context) (schema.Tables, error) {
+	if p.client == nil {
+		return nil, fmt.Errorf("plugin not initialized")
+	}
+	tables, err := p.client.Tables(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get tables: %w", err)
+	}
+	return tables, nil
+}
+
+// Init initializes the plugin with the given spec.
+func (p *Plugin) Init(ctx context.Context, spec []byte) error {
+	if !p.mu.TryLock() {
+		return fmt.Errorf("plugin already in use")
+	}
+	defer p.mu.Unlock()
+	var err error
+	p.client, err = p.newClient(ctx, p.logger, spec)
+	if err != nil {
+		return fmt.Errorf("failed to initialize client: %w", err)
+	}
+	p.spec = spec
+
+	return nil
+}
+
+func (p *Plugin) Close(ctx context.Context) error {
+	if !p.mu.TryLock() {
+		return fmt.Errorf("plugin already in use")
+	}
+	defer p.mu.Unlock()
+	if p.client == nil {
+		return nil
+	}
+	return p.client.Close(ctx)
+}
diff --git a/plugin/plugin_destination.go b/plugin/plugin_destination.go
new file mode 100644
index 0000000000..2a1871152d
--- /dev/null
+++ b/plugin/plugin_destination.go
@@ -0,0 +1,43 @@
+package plugin
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/cloudquery/plugin-sdk/v4/message"
+	"github.com/cloudquery/plugin-sdk/v4/schema"
+)
+
+type WriteOptions struct {
+	MigrateForce bool
+}
+
+type DestinationClient interface {
+	Close(ctx context.Context) error
+	Read(ctx context.Context, table *schema.Table, res chan<- arrow.Record) error
+	Write(ctx context.Context, options WriteOptions, res <-chan message.Message) error
+}
+
+// writeOne is currently used mostly for testing, so it's not a public API.
+func (p *Plugin) writeOne(ctx context.Context, options WriteOptions, resource message.Message) error {
+	resources := []message.Message{resource}
+	return p.WriteAll(ctx, options, resources)
+}
+
+// WriteAll is currently used mostly for testing, so it's not considered a stable public API.
+func (p *Plugin) WriteAll(ctx context.Context, options WriteOptions, resources []message.Message) error {
+	ch := make(chan message.Message, len(resources))
+	for _, resource := range resources {
+		ch <- resource
+	}
+	close(ch)
+	return p.Write(ctx, options, ch)
+}
+
+func (p *Plugin) Write(ctx context.Context, options WriteOptions, res <-chan message.Message) error {
+	if p.client == nil {
+		return fmt.Errorf("plugin is not initialized.
call Init first") + } + return p.client.Write(ctx, options, res) +} diff --git a/plugin/plugin_source.go b/plugin/plugin_source.go new file mode 100644 index 0000000000..e5cdf1ad9b --- /dev/null +++ b/plugin/plugin_source.go @@ -0,0 +1,108 @@ +package plugin + +import ( + "context" + "fmt" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-sdk/v4/glob" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/rs/zerolog" +) + +type SyncOptions struct { + Tables []string + SkipTables []string + DeterministicCQID bool +} + +type SourceClient interface { + Close(ctx context.Context) error + Tables(ctx context.Context) (schema.Tables, error) + Sync(ctx context.Context, options SyncOptions, res chan<- message.Message) error +} + +func MatchesTable(name string, includeTablesPattern []string, skipTablesPattern []string) bool { + for _, pattern := range skipTablesPattern { + if glob.Glob(pattern, name) { + return false + } + } + for _, pattern := range includeTablesPattern { + if glob.Glob(pattern, name) { + return true + } + } + return false +} + +type NewSourceClientFunc func(context.Context, zerolog.Logger, any) (SourceClient, error) + +// NewSourcePlugin returns a new CloudQuery Plugin with the given name, version and implementation. +// Source plugins only support read operations. For Read & Write plugin use NewPlugin. +func NewSourcePlugin(name string, version string, newClient NewSourceClientFunc, options ...Option) *Plugin { + newClientWrapper := func(ctx context.Context, logger zerolog.Logger, spec []byte) (Client, error) { + sourceClient, err := newClient(ctx, logger, spec) + if err != nil { + return nil, err + } + wrapperClient := struct { + SourceClient + UnimplementedDestination + }{ + SourceClient: sourceClient, + } + return wrapperClient, nil + } + return NewPlugin(name, version, newClientWrapper, options...) +} + +func (p *Plugin) readAll(ctx context.Context, table *schema.Table) ([]arrow.Record, error) { + var err error + ch := make(chan arrow.Record) + go func() { + defer close(ch) + err = p.client.Read(ctx, table, ch) + }() + // nolint:prealloc + var records []arrow.Record + for record := range ch { + records = append(records, record) + } + return records, err +} + +func (p *Plugin) SyncAll(ctx context.Context, options SyncOptions) (message.Messages, error) { + var err error + ch := make(chan message.Message) + go func() { + defer close(ch) + err = p.Sync(ctx, options, ch) + }() + // nolint:prealloc + var resources []message.Message + for resource := range ch { + resources = append(resources, resource) + } + return resources, err +} + +// Sync is syncing data from the requested tables in spec to the given channel +func (p *Plugin) Sync(ctx context.Context, options SyncOptions, res chan<- message.Message) error { + if !p.mu.TryLock() { + return fmt.Errorf("plugin already in use") + } + defer p.mu.Unlock() + if p.client == nil { + return fmt.Errorf("plugin not initialized. 
call Init() first") + } + // startTime := time.Now() + + if err := p.client.Sync(ctx, options, res); err != nil { + return fmt.Errorf("failed to sync unmanaged client: %w", err) + } + + // p.logger.Info().Uint64("resources", p.metrics.TotalResources()).Uint64("errors", p.metrics.TotalErrors()).Uint64("panics", p.metrics.TotalPanics()).TimeDiff("duration", time.Now(), startTime).Msg("sync finished") + return nil +} diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go new file mode 100644 index 0000000000..57c3b8ebf9 --- /dev/null +++ b/plugin/plugin_test.go @@ -0,0 +1,85 @@ +package plugin + +import ( + "context" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/rs/zerolog" +) + +type testPluginClient struct { + messages []message.Message +} + +func newTestPluginClient(context.Context, zerolog.Logger, []byte) (Client, error) { + return &testPluginClient{}, nil +} + +func (*testPluginClient) GetSpec() any { + return &struct{}{} +} + +func (*testPluginClient) Tables(context.Context) (schema.Tables, error) { + return schema.Tables{}, nil +} + +func (*testPluginClient) Read(context.Context, *schema.Table, chan<- arrow.Record) error { + return nil +} + +func (c *testPluginClient) Sync(_ context.Context, _ SyncOptions, res chan<- message.Message) error { + for _, msg := range c.messages { + res <- msg + } + return nil +} +func (c *testPluginClient) Write(_ context.Context, _ WriteOptions, res <-chan message.Message) error { + for msg := range res { + c.messages = append(c.messages, msg) + } + return nil +} +func (*testPluginClient) Close(context.Context) error { + return nil +} + +func TestPluginSuccess(t *testing.T) { + ctx := context.Background() + p := NewPlugin("test", "v1.0.0", newTestPluginClient) + if err := p.Init(ctx, []byte("")); err != nil { + t.Fatal(err) + } + tables, err := p.Tables(ctx) + if err != nil { + t.Fatal(err) + } + if len(tables) != 0 { + t.Fatal("expected 0 tables") + } + if err := p.WriteAll(ctx, WriteOptions{}, nil); err != nil { + t.Fatal(err) + } + if err := p.WriteAll(ctx, WriteOptions{}, []message.Message{ + message.MigrateTable{}, + }); err != nil { + t.Fatal(err) + } + if len(p.client.(*testPluginClient).messages) != 1 { + t.Fatal("expected 1 message") + } + + messages, err := p.SyncAll(ctx, SyncOptions{}) + if err != nil { + t.Fatal(err) + } + if len(messages) != 1 { + t.Fatal("expected 1 message") + } + + if err := p.Close(ctx); err != nil { + t.Fatal(err) + } +} diff --git a/plugin/testing_upsert.go b/plugin/testing_upsert.go new file mode 100644 index 0000000000..9dfaa83d4b --- /dev/null +++ b/plugin/testing_upsert.go @@ -0,0 +1,65 @@ +package plugin + +import ( + "context" + "fmt" + "time" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/memory" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +func (s *WriterTestSuite) testUpsert(ctx context.Context) error { + tableName := fmt.Sprintf("cq_test_upsert_%d", time.Now().Unix()) + table := &schema.Table{ + Name: tableName, + Columns: []schema.Column{ + {Name: "name", Type: arrow.BinaryTypes.String, PrimaryKey: true}, + }, + } + if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.MigrateTable{ + Table: table, + }); err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, 
+	bldr.Field(0).(*array.StringBuilder).Append("foo")
+	record := bldr.NewRecord()
+
+	if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.Insert{
+		Record: record,
+	}); err != nil {
+		return fmt.Errorf("failed to insert record: %w", err)
+	}
+
+	records, err := s.plugin.readAll(ctx, table)
+	if err != nil {
+		return fmt.Errorf("failed to readAll: %w", err)
+	}
+	totalItems := TotalRows(records)
+	if totalItems != 1 {
+		return fmt.Errorf("expected 1 item, got %d", totalItems)
+	}
+
+	if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.Insert{
+		Record: record,
+	}); err != nil {
+		return fmt.Errorf("failed to insert record: %w", err)
+	}
+
+	records, err = s.plugin.readAll(ctx, table)
+	if err != nil {
+		return fmt.Errorf("failed to readAll: %w", err)
+	}
+
+	totalItems = TotalRows(records)
+	if totalItems != 1 {
+		return fmt.Errorf("expected 1 item, got %d", totalItems)
+	}
+
+	return nil
+}
diff --git a/plugin/testing_write.go b/plugin/testing_write.go
new file mode 100644
index 0000000000..4f198e80a8
--- /dev/null
+++ b/plugin/testing_write.go
@@ -0,0 +1,124 @@
+package plugin
+
+import (
+	"context"
+	"testing"
+
+	"github.com/cloudquery/plugin-sdk/v4/schema"
+)
+
+type WriterTestSuite struct {
+	tests WriterTestSuiteTests
+
+	plugin *Plugin
+
+	// AllowNull is a custom func to determine whether a data type may be correctly represented as null.
+	// Destinations that have problems representing some data types should provide a custom implementation here.
+	// If this param is empty, the default is to allow all data types to be nullable.
+	// When the value returned by this func is `true` the comparison is made with the empty value instead of null.
+	// allowNull AllowNullFunc
+
+	// IgnoreNullsInLists allows stripping null values from lists before comparison.
+	// Destination setups that don't support nulls in lists should set this to true.
+	ignoreNullsInLists bool
+
+	// genDatOptions defines how to generate test data and which data types to skip
+	genDatOptions schema.TestSourceOptions
+}
+
+// SafeMigrations defines which migrations are supported by the plugin in safe migrate mode
+type SafeMigrations struct {
+	AddColumn           bool
+	AddColumnNotNull    bool
+	RemoveColumn        bool
+	RemoveColumnNotNull bool
+	ChangeColumn        bool
+}
+
+type WriterTestSuiteTests struct {
+	// SkipUpsert skips testing with message.Insert and Upsert=true.
+	// Usually needed when a destination doesn't support primary keys.
+	SkipUpsert bool
+
+	// SkipDeleteStale skips testing message.DeleteStale events.
+	SkipDeleteStale bool
+
+	// SkipInsert skips testing message.Insert with Upsert=false.
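+	// Destinations that can't perform plain (non-upsert) inserts should set this to true.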
+	SkipInsert bool
+
+	// SkipMigrate skips testing migrations
+	SkipMigrate bool
+
+	// SafeMigrations defines which migration tests are expected to pass in safe mode
+	// and which are only expected to pass with a forced migration
+	SafeMigrations SafeMigrations
+}
+
+type NewPluginFunc func() *Plugin
+
+// func WithTestSourceAllowNull(allowNull func(arrow.DataType) bool) func(o *WriterTestSuite) {
+// 	return func(o *WriterTestSuite) {
+// 		o.allowNull = allowNull
+// 	}
+// }
+
+func WithTestIgnoreNullsInLists() func(o *WriterTestSuite) {
+	return func(o *WriterTestSuite) {
+		o.ignoreNullsInLists = true
+	}
+}
+
+func WithTestDataOptions(opts schema.TestSourceOptions) func(o *WriterTestSuite) {
+	return func(o *WriterTestSuite) {
+		o.genDatOptions = opts
+	}
+}
+
+func TestWriterSuiteRunner(t *testing.T, p *Plugin, tests WriterTestSuiteTests, opts ...func(o *WriterTestSuite)) {
+	suite := &WriterTestSuite{
+		tests:  tests,
+		plugin: p,
+	}
+
+	for _, opt := range opts {
+		opt(suite)
+	}
+
+	ctx := context.Background()
+
+	t.Run("TestUpsert", func(t *testing.T) {
+		if suite.tests.SkipUpsert {
+			t.Skip("skipping " + t.Name())
+		}
+		if err := suite.testUpsert(ctx); err != nil {
+			t.Fatal(err)
+		}
+	})
+
+	t.Run("TestInsert", func(t *testing.T) {
+		if suite.tests.SkipInsert {
+			t.Skip("skipping " + t.Name())
+		}
+		if err := suite.testInsert(ctx); err != nil {
+			t.Fatal(err)
+		}
+	})
+
+	t.Run("TestDeleteStale", func(t *testing.T) {
+		if suite.tests.SkipDeleteStale {
+			t.Skip("skipping " + t.Name())
+		}
+		if err := suite.testDeleteStale(ctx); err != nil {
+			t.Fatal(err)
+		}
+	})
+
+	t.Run("TestMigrate", func(t *testing.T) {
+		if suite.tests.SkipMigrate {
+			t.Skip("skipping " + t.Name())
+		}
+		suite.testMigrate(ctx, t, false)
+		suite.testMigrate(ctx, t, true)
+	})
+}
diff --git a/plugin/testing_write_delete.go b/plugin/testing_write_delete.go
new file mode 100644
index 0000000000..5ec89b8d93
--- /dev/null
+++ b/plugin/testing_write_delete.go
@@ -0,0 +1,75 @@
+package plugin
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/cloudquery/plugin-sdk/v4/message"
+	"github.com/cloudquery/plugin-sdk/v4/schema"
+)
+
+func (s *WriterTestSuite) testDeleteStale(ctx context.Context) error {
+	tableName := fmt.Sprintf("cq_delete_%d", time.Now().Unix())
+	syncTime := time.Now().UTC().Round(1 * time.Second)
+	table := &schema.Table{
+		Name: tableName,
+		Columns: []schema.Column{
+			schema.CqSourceNameColumn,
+			schema.CqSyncTimeColumn,
+		},
+	}
+	if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.MigrateTable{
+		Table: table,
+	}); err != nil {
+		return fmt.Errorf("failed to create table: %w", err)
+	}
+
+	bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema())
+	bldr.Field(0).(*array.StringBuilder).Append("test")
+	bldr.Field(1).(*array.TimestampBuilder).AppendTime(syncTime)
+	record := bldr.NewRecord()
+
+	if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.Insert{
+		Record: record,
+	}); err != nil {
+		return fmt.Errorf("failed to insert record: %w", err)
+	}
+
+	records, err := s.plugin.readAll(ctx, table)
+	if err != nil {
+		return fmt.Errorf("failed to readAll: %w", err)
+	}
+	totalItems := TotalRows(records)
+
+	if totalItems != 1 {
+		return fmt.Errorf("expected 1 item, got %d", totalItems)
+	}
+
+	bldr = array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema())
+	bldr.Field(0).(*array.StringBuilder).Append("test")
+
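// note: this second row gets a newer sync time but is never inserted before DeleteStale runs +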
bldr.Field(1).(*array.TimestampBuilder).AppendTime(syncTime.Add(time.Second)) + + if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.DeleteStale{ + Table: table, + SourceName: "test", + SyncTime: syncTime, + }); err != nil { + return fmt.Errorf("failed to delete stale records: %w", err) + } + + records, err = s.plugin.readAll(ctx, table) + if err != nil { + return fmt.Errorf("failed to sync: %w", err) + } + totalItems = TotalRows(records) + + if totalItems != 1 { + return fmt.Errorf("expected 1 item, got %d", totalItems) + } + + return nil +} diff --git a/plugin/testing_write_insert.go b/plugin/testing_write_insert.go new file mode 100644 index 0000000000..892f7b659a --- /dev/null +++ b/plugin/testing_write_insert.go @@ -0,0 +1,73 @@ +package plugin + +import ( + "context" + "fmt" + "time" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/memory" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +func TotalRows(records []arrow.Record) int64 { + totalRows := int64(0) + for _, record := range records { + totalRows += record.NumRows() + } + return totalRows +} + +func (s *WriterTestSuite) testInsert(ctx context.Context) error { + tableName := fmt.Sprintf("cq_test_insert_%d", time.Now().Unix()) + table := &schema.Table{ + Name: tableName, + Columns: []schema.Column{ + {Name: "name", Type: arrow.BinaryTypes.String}, + }, + } + if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.MigrateTable{ + Table: table, + }); err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema()) + bldr.Field(0).(*array.StringBuilder).Append("foo") + record := bldr.NewRecord() + + if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.Insert{ + Record: record, + }); err != nil { + return fmt.Errorf("failed to insert record: %w", err) + } + readRecords, err := s.plugin.readAll(ctx, table) + if err != nil { + return fmt.Errorf("failed to sync: %w", err) + } + + totalItems := TotalRows(readRecords) + if totalItems != 1 { + return fmt.Errorf("expected 1 item, got %d", totalItems) + } + + if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.Insert{ + Record: record, + }); err != nil { + return fmt.Errorf("failed to insert record: %w", err) + } + + readRecords, err = s.plugin.readAll(ctx, table) + if err != nil { + return fmt.Errorf("failed to sync: %w", err) + } + + totalItems = TotalRows(readRecords) + if totalItems != 2 { + return fmt.Errorf("expected 2 item, got %d", totalItems) + } + + return nil +} diff --git a/plugin/testing_write_migrate.go b/plugin/testing_write_migrate.go new file mode 100644 index 0000000000..af6224cbe2 --- /dev/null +++ b/plugin/testing_write_migrate.go @@ -0,0 +1,218 @@ +package plugin + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/cloudquery/plugin-sdk/v4/types" + "github.com/google/uuid" +) + +func tableUUIDSuffix() string { + return strings.ReplaceAll(uuid.NewString(), "-", "_")[:8] // use only first 8 chars +} + +// nolint:revive +func (s *WriterTestSuite) migrate(ctx context.Context, target *schema.Table, source *schema.Table, supportsSafeMigrate bool, writeOptionMigrateForce bool) error { + if err := s.plugin.writeOne(ctx, WriteOptions{ + MigrateForce: writeOptionMigrateForce, + }, 
&message.MigrateTable{ + Table: source, + }); err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + + sourceName := target.Name + syncTime := time.Now().UTC().Round(1 * time.Second) + opts := schema.GenTestDataOptions{ + SourceName: sourceName, + SyncTime: syncTime, + MaxRows: 1, + TimePrecision: s.genDatOptions.TimePrecision, + } + + resource1 := schema.GenTestData(source, opts)[0] + + if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.Insert{ + Record: resource1, + }); err != nil { + return fmt.Errorf("failed to insert first record: %w", err) + } + + records, err := s.plugin.readAll(ctx, source) + if err != nil { + return fmt.Errorf("failed to sync: %w", err) + } + totalItems := TotalRows(records) + if totalItems != 1 { + return fmt.Errorf("expected 1 item, got %d", totalItems) + } + + if err := s.plugin.writeOne(ctx, WriteOptions{MigrateForce: writeOptionMigrateForce}, &message.MigrateTable{ + Table: target, + }); err != nil { + return fmt.Errorf("failed to create table: %w", err) + } + + resource2 := schema.GenTestData(target, opts)[0] + if err := s.plugin.writeOne(ctx, WriteOptions{}, &message.Insert{ + Record: resource2, + }); err != nil { + return fmt.Errorf("failed to insert second record: %w", err) + } + + records, err = s.plugin.readAll(ctx, target) + if err != nil { + return fmt.Errorf("failed to readAll: %w", err) + } + // if force migration is not required, we don't expect any items to be dropped (so there should be 2 items) + if !writeOptionMigrateForce || supportsSafeMigrate { + totalItems = TotalRows(records) + if totalItems != 2 { + return fmt.Errorf("expected 2 item, got %d", totalItems) + } + } else { + totalItems = TotalRows(records) + if totalItems != 1 { + return fmt.Errorf("expected 1 item, got %d", totalItems) + } + } + + return nil +} + +// nolint:revive +func (s *WriterTestSuite) testMigrate( + ctx context.Context, + t *testing.T, + forceMigrate bool, +) { + suffix := "_safe" + if forceMigrate { + suffix = "_force" + } + t.Run("add_column"+suffix, func(t *testing.T) { + if !forceMigrate && !s.tests.SafeMigrations.AddColumn { + t.Skip("skipping test: add_column") + } + tableName := "add_column" + suffix + "_" + tableUUIDSuffix() + source := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + }, + } + + target := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean}, + }, + } + if err := s.migrate(ctx, target, source, s.tests.SafeMigrations.AddColumn, forceMigrate); err != nil { + t.Fatalf("failed to migrate %s: %v", tableName, err) + } + }) + + t.Run("add_column_not_null"+suffix, func(t *testing.T) { + if !forceMigrate && !s.tests.SafeMigrations.AddColumnNotNull { + t.Skip("skipping test: add_column_not_null") + } + tableName := "add_column_not_null" + suffix + "_" + tableUUIDSuffix() + source := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + }, + } + + target := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, NotNull: true}, + }} + if err := s.migrate(ctx, target, source, s.tests.SafeMigrations.AddColumnNotNull, forceMigrate); err != nil { + t.Fatalf("failed to migrate add_column_not_null: %v", err) + } + }) + + t.Run("remove_column"+suffix, func(t *testing.T) { + if !forceMigrate 
&& !s.tests.SafeMigrations.RemoveColumn { + t.Skip("skipping test: remove_column") + } + tableName := "remove_column" + suffix + "_" + tableUUIDSuffix() + source := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean}, + }} + target := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + }} + if err := s.migrate(ctx, target, source, s.tests.SafeMigrations.RemoveColumn, forceMigrate); err != nil { + t.Fatalf("failed to migrate remove_column: %v", err) + } + }) + + t.Run("remove_column_not_null"+suffix, func(t *testing.T) { + if !forceMigrate && !s.tests.SafeMigrations.RemoveColumnNotNull { + t.Skip("skipping test: remove_column_not_null") + } + tableName := "remove_column_not_null" + suffix + "_" + tableUUIDSuffix() + source := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, NotNull: true}, + }, + } + target := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + }} + if err := s.migrate(ctx, target, source, s.tests.SafeMigrations.RemoveColumnNotNull, forceMigrate); err != nil { + t.Fatalf("failed to migrate remove_column_not_null: %v", err) + } + }) + + t.Run("change_column"+suffix, func(t *testing.T) { + if !forceMigrate && !s.tests.SafeMigrations.ChangeColumn { + t.Skip("skipping test: change_column") + } + tableName := "change_column" + suffix + "_" + tableUUIDSuffix() + source := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, NotNull: true}, + }} + target := &schema.Table{ + Name: tableName, + Columns: schema.ColumnList{ + {Name: "id", Type: types.ExtensionTypes.UUID}, + {Name: "bool", Type: arrow.BinaryTypes.String, NotNull: true}, + }} + if err := s.migrate(ctx, target, source, s.tests.SafeMigrations.ChangeColumn, forceMigrate); err != nil { + t.Fatalf("failed to migrate change_column: %v", err) + } + }) + + t.Run("double_migration", func(t *testing.T) { + // tableName := "double_migration_" + tableUUIDSuffix() + // table := schema.TestTable(tableName, testOpts.TestSourceOptions) + // require.NoError(t, p.Migrate(ctx, schema.Tables{table}, MigrateOptions{MigrateMode: MigrateModeForce})) + // require.NoError(t, p.Migrate(ctx, schema.Tables{table}, MigrateOptions{MigrateMode: MigrateModeForce})) + }) +} diff --git a/plugins/destination/managed_writer.go b/plugins/destination/managed_writer.go deleted file mode 100644 index 0d00f14bc3..0000000000 --- a/plugins/destination/managed_writer.go +++ /dev/null @@ -1,169 +0,0 @@ -package destination - -import ( - "context" - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/apache/arrow/go/v13/arrow" - "github.com/apache/arrow/go/v13/arrow/util" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/internal/pk" - "github.com/cloudquery/plugin-sdk/v3/schema" -) - -type worker struct { - count int - wg *sync.WaitGroup - ch chan arrow.Record - flush chan chan bool -} - -func (p *Plugin) worker(ctx context.Context, metrics *Metrics, table *schema.Table, ch <-chan arrow.Record, flush <-chan chan bool) { - sizeBytes := int64(0) - resources := make([]arrow.Record, 0) - for { - select { - case r, ok := <-ch: - if !ok { - if len(resources) > 0 { - 
p.flush(ctx, metrics, table, resources) - } - return - } - if len(resources) == p.spec.BatchSize || sizeBytes+util.TotalRecordSize(r) > int64(p.spec.BatchSizeBytes) { - p.flush(ctx, metrics, table, resources) - resources = resources[:0] // allows for mem reuse - sizeBytes = 0 - } - resources = append(resources, r) - sizeBytes += util.TotalRecordSize(r) - case <-time.After(p.batchTimeout): - if len(resources) > 0 { - p.flush(ctx, metrics, table, resources) - resources = resources[:0] // allows for mem reuse - sizeBytes = 0 - } - case done := <-flush: - if len(resources) > 0 { - p.flush(ctx, metrics, table, resources) - resources = resources[:0] // allows for mem reuse - sizeBytes = 0 - } - done <- true - case <-ctx.Done(): - // this means the request was cancelled - return // after this NO other call will succeed - } - } -} - -func (p *Plugin) flush(ctx context.Context, metrics *Metrics, table *schema.Table, resources []arrow.Record) { - resources = p.removeDuplicatesByPK(table, resources) - start := time.Now() - batchSize := len(resources) - if err := p.client.WriteTableBatch(ctx, table, resources); err != nil { - p.logger.Err(err).Str("table", table.Name).Int("len", batchSize).Dur("duration", time.Since(start)).Msg("failed to write batch") - // we don't return an error as we need to continue until channel is closed otherwise there will be a deadlock - atomic.AddUint64(&metrics.Errors, uint64(batchSize)) - } else { - p.logger.Info().Str("table", table.Name).Int("len", batchSize).Dur("duration", time.Since(start)).Msg("batch written successfully") - atomic.AddUint64(&metrics.Writes, uint64(batchSize)) - } -} - -func (*Plugin) removeDuplicatesByPK(table *schema.Table, resources []arrow.Record) []arrow.Record { - pkIndices := table.PrimaryKeysIndexes() - // special case where there's no PK at all - if len(pkIndices) == 0 { - return resources - } - - pks := make(map[string]struct{}, len(resources)) - res := make([]arrow.Record, 0, len(resources)) - for _, r := range resources { - if r.NumRows() > 1 { - panic(fmt.Sprintf("record with more than 1 row: %d", r.NumRows())) - } - key := pk.String(r) - _, ok := pks[key] - if !ok { - pks[key] = struct{}{} - res = append(res, r) - continue - } - // duplicate, release - r.Release() - } - - return res -} - -func (p *Plugin) writeManagedTableBatch(ctx context.Context, _ specs.Source, tables schema.Tables, _ time.Time, res <-chan arrow.Record) error { - workers := make(map[string]*worker, len(tables)) - metrics := &Metrics{} - - p.workersLock.Lock() - for _, table := range tables { - table := table - if p.workers[table.Name] == nil { - ch := make(chan arrow.Record) - flush := make(chan chan bool) - wg := &sync.WaitGroup{} - p.workers[table.Name] = &worker{ - count: 1, - ch: ch, - flush: flush, - wg: wg, - } - wg.Add(1) - go func() { - defer wg.Done() - p.worker(ctx, metrics, table, ch, flush) - }() - } else { - p.workers[table.Name].count++ - } - // we save this locally because we don't want to access the map after that so we can - // keep the workersLock for as short as possible - workers[table.Name] = p.workers[table.Name] - } - p.workersLock.Unlock() - - for r := range res { - tableName, ok := r.Schema().Metadata().GetValue(schema.MetadataTableName) - if !ok { - return fmt.Errorf("missing table name in record metadata") - } - if _, ok := workers[tableName]; !ok { - return fmt.Errorf("table %s not found in destination", tableName) - } - workers[tableName].ch <- r - } - - // flush and wait for all workers to finish flush before finish and calling delete 
stale - // This is because destinations can be longed lived and called from multiple sources - flushChannels := make(map[string]chan bool, len(workers)) - for tableName, w := range workers { - flushCh := make(chan bool) - flushChannels[tableName] = flushCh - w.flush <- flushCh - } - for tableName := range flushChannels { - <-flushChannels[tableName] - } - - p.workersLock.Lock() - for tableName := range workers { - p.workers[tableName].count-- - if p.workers[tableName].count == 0 { - close(p.workers[tableName].ch) - p.workers[tableName].wg.Wait() - delete(p.workers, tableName) - } - } - p.workersLock.Unlock() - return nil -} diff --git a/plugins/destination/metrics.go b/plugins/destination/metrics.go deleted file mode 100644 index d00613ecf8..0000000000 --- a/plugins/destination/metrics.go +++ /dev/null @@ -1,8 +0,0 @@ -package destination - -type Metrics struct { - // Errors number of errors / failed writes - Errors uint64 - // Writes number of successful writes - Writes uint64 -} diff --git a/plugins/destination/nulls.go b/plugins/destination/nulls.go deleted file mode 100644 index 6f965106e4..0000000000 --- a/plugins/destination/nulls.go +++ /dev/null @@ -1,72 +0,0 @@ -package destination - -import ( - "github.com/apache/arrow/go/v13/arrow" - "github.com/apache/arrow/go/v13/arrow/array" - "github.com/apache/arrow/go/v13/arrow/memory" -) - -func stripNullsFromLists(records []arrow.Record) { - for i := range records { - cols := records[i].Columns() - for c, col := range cols { - if col.DataType().ID() != arrow.LIST { - continue - } - - list := col.(*array.List) - bldr := array.NewListBuilder(memory.DefaultAllocator, list.DataType().(*arrow.ListType).Elem()) - for j := 0; j < list.Len(); j++ { - if list.IsNull(j) { - bldr.AppendNull() - continue - } - bldr.Append(true) - vBldr := bldr.ValueBuilder() - from, to := list.ValueOffsets(j) - slc := array.NewSlice(list.ListValues(), from, to) - for k := 0; k < int(to-from); k++ { - if slc.IsNull(k) { - continue - } - err := vBldr.AppendValueFromString(slc.ValueStr(k)) - if err != nil { - panic(err) - } - } - } - cols[c] = bldr.NewArray() - } - records[i] = array.NewRecord(records[i].Schema(), cols, records[i].NumRows()) - } -} - -type AllowNullFunc func(arrow.DataType) bool - -func (f AllowNullFunc) replaceNullsByEmpty(records []arrow.Record) { - if f == nil { - return - } - for i := range records { - cols := records[i].Columns() - for c, col := range records[i].Columns() { - if col.NullN() == 0 || f(col.DataType()) { - continue - } - - builder := array.NewBuilder(memory.DefaultAllocator, records[i].Column(c).DataType()) - for j := 0; j < col.Len(); j++ { - if col.IsNull(j) { - builder.AppendEmptyValue() - continue - } - - if err := builder.AppendValueFromString(col.ValueStr(j)); err != nil { - panic(err) - } - } - cols[c] = builder.NewArray() - } - records[i] = array.NewRecord(records[i].Schema(), cols, records[i].NumRows()) - } -} diff --git a/plugins/destination/plugin.go b/plugins/destination/plugin.go deleted file mode 100644 index 1d40f6af80..0000000000 --- a/plugins/destination/plugin.go +++ /dev/null @@ -1,314 +0,0 @@ -package destination - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/rs/zerolog" -) - -type writerType int - -const ( - unmanaged writerType = iota - managed -) - -const ( - defaultBatchTimeoutSeconds = 20 - defaultBatchSize = 10000 - defaultBatchSizeBytes = 5 * 1024 * 1024 // 
5 MiB -) - -type NewClientFunc func(context.Context, zerolog.Logger, specs.Destination) (Client, error) - -type ManagedWriter interface { - WriteTableBatch(ctx context.Context, table *schema.Table, data []arrow.Record) error -} - -type UnimplementedManagedWriter struct{} - -var _ ManagedWriter = UnimplementedManagedWriter{} - -func (UnimplementedManagedWriter) WriteTableBatch(context.Context, *schema.Table, []arrow.Record) error { - panic("WriteTableBatch not implemented") -} - -type UnmanagedWriter interface { - Write(ctx context.Context, tables schema.Tables, res <-chan arrow.Record) error - Metrics() Metrics -} - -var _ UnmanagedWriter = UnimplementedUnmanagedWriter{} - -type UnimplementedUnmanagedWriter struct{} - -func (UnimplementedUnmanagedWriter) Write(context.Context, schema.Tables, <-chan arrow.Record) error { - panic("Write not implemented") -} - -func (UnimplementedUnmanagedWriter) Metrics() Metrics { - panic("Metrics not implemented") -} - -type Client interface { - Migrate(ctx context.Context, tables schema.Tables) error - Read(ctx context.Context, table *schema.Table, sourceName string, res chan<- arrow.Record) error - ManagedWriter - UnmanagedWriter - DeleteStale(ctx context.Context, tables schema.Tables, sourceName string, syncTime time.Time) error - Close(ctx context.Context) error -} - -type ClientResource struct { - TableName string - Data []any -} - -type Option func(*Plugin) - -type Plugin struct { - // Name of destination plugin i.e postgresql,snowflake - name string - // Version of the destination plugin - version string - // Called upon configure call to validate and init configuration - newClient NewClientFunc - writerType writerType - // initialized destination client - client Client - // spec the client was initialized with - spec specs.Destination - // Logger to call, this logger is passed to the serve.Serve Client, if not define Serve will create one instead. 
- logger zerolog.Logger - - // This is in use if the user passed a managed client - metrics map[string]*Metrics - metricsLock *sync.RWMutex - - workers map[string]*worker - workersLock *sync.Mutex - - batchTimeout time.Duration - defaultBatchSize int - defaultBatchSizeBytes int -} - -func WithManagedWriter() Option { - return func(p *Plugin) { - p.writerType = managed - } -} - -func WithBatchTimeout(seconds int) Option { - return func(p *Plugin) { - p.batchTimeout = time.Duration(seconds) * time.Second - } -} - -func WithDefaultBatchSize(defaultBatchSize int) Option { - return func(p *Plugin) { - p.defaultBatchSize = defaultBatchSize - } -} - -func WithDefaultBatchSizeBytes(defaultBatchSizeBytes int) Option { - return func(p *Plugin) { - p.defaultBatchSizeBytes = defaultBatchSizeBytes - } -} - -// NewPlugin creates a new destination plugin -func NewPlugin(name string, version string, newClientFunc NewClientFunc, opts ...Option) *Plugin { - p := &Plugin{ - name: name, - version: version, - newClient: newClientFunc, - metrics: make(map[string]*Metrics), - metricsLock: &sync.RWMutex{}, - workers: make(map[string]*worker), - workersLock: &sync.Mutex{}, - batchTimeout: time.Duration(defaultBatchTimeoutSeconds) * time.Second, - defaultBatchSize: defaultBatchSize, - defaultBatchSizeBytes: defaultBatchSizeBytes, - } - if newClientFunc == nil { - // we do this check because we only call this during runtime later on so it can fail - // before the server starts - panic("newClientFunc can't be nil") - } - for _, opt := range opts { - opt(p) - } - return p -} - -func (p *Plugin) Name() string { - return p.name -} - -func (p *Plugin) Version() string { - return p.version -} - -func (p *Plugin) Metrics() Metrics { - switch p.writerType { - case unmanaged: - return p.client.Metrics() - case managed: - metrics := Metrics{} - p.metricsLock.RLock() - for _, m := range p.metrics { - metrics.Errors += m.Errors - metrics.Writes += m.Writes - } - p.metricsLock.RUnlock() - return metrics - default: - panic("unknown client type") - } -} - -// we need lazy loading because we want to be able to initialize after -func (p *Plugin) Init(ctx context.Context, logger zerolog.Logger, spec specs.Destination) error { - var err error - p.logger = logger - p.spec = spec - p.spec.SetDefaults(p.defaultBatchSize, p.defaultBatchSizeBytes) - p.client, err = p.newClient(ctx, logger, p.spec) - if err != nil { - return err - } - return nil -} - -// we implement all DestinationClient functions so we can hook into pre-post behavior -func (p *Plugin) Migrate(ctx context.Context, tables schema.Tables) error { - if err := checkDestinationColumns(tables); err != nil { - return err - } - return p.client.Migrate(ctx, tables) -} - -func (p *Plugin) readAll(ctx context.Context, table *schema.Table, sourceName string) ([]arrow.Record, error) { - var readErr error - ch := make(chan arrow.Record) - go func() { - defer close(ch) - readErr = p.Read(ctx, table, sourceName, ch) - }() - // nolint:prealloc - var resources []arrow.Record - for resource := range ch { - resources = append(resources, resource) - } - return resources, readErr -} - -func (p *Plugin) Read(ctx context.Context, table *schema.Table, sourceName string, res chan<- arrow.Record) error { - return p.client.Read(ctx, table, sourceName, res) -} - -// this function is currently used mostly for testing so it's not a public api -func (p *Plugin) writeOne(ctx context.Context, sourceSpec specs.Source, syncTime time.Time, resource arrow.Record) error { - resources := []arrow.Record{resource} 
- return p.writeAll(ctx, sourceSpec, syncTime, resources) -} - -// this function is currently used mostly for testing so it's not a public api -func (p *Plugin) writeAll(ctx context.Context, sourceSpec specs.Source, syncTime time.Time, resources []arrow.Record) error { - ch := make(chan arrow.Record, len(resources)) - for _, resource := range resources { - ch <- resource - } - close(ch) - tables := make(schema.Tables, 0) - tableNames := make(map[string]struct{}) - for _, resource := range resources { - sc := resource.Schema() - tableMD := sc.Metadata() - name, found := tableMD.GetValue(schema.MetadataTableName) - if !found { - return fmt.Errorf("missing table name") - } - if _, ok := tableNames[name]; ok { - continue - } - table, err := schema.NewTableFromArrowSchema(resource.Schema()) - if err != nil { - return err - } - tables = append(tables, table) - tableNames[table.Name] = struct{}{} - } - return p.Write(ctx, sourceSpec, tables, syncTime, ch) -} - -func (p *Plugin) Write(ctx context.Context, sourceSpec specs.Source, tables schema.Tables, syncTime time.Time, res <-chan arrow.Record) error { - syncTime = syncTime.UTC() - if err := checkDestinationColumns(tables); err != nil { - return err - } - switch p.writerType { - case unmanaged: - if err := p.writeUnmanaged(ctx, sourceSpec, tables, syncTime, res); err != nil { - return err - } - case managed: - if err := p.writeManagedTableBatch(ctx, sourceSpec, tables, syncTime, res); err != nil { - return err - } - default: - panic("unknown client type") - } - if p.spec.WriteMode == specs.WriteModeOverwriteDeleteStale { - tablesToDelete := tables - if sourceSpec.Backend != specs.BackendNone { - tablesToDelete = make(schema.Tables, 0, len(tables)) - for _, t := range tables { - if !t.IsIncremental { - tablesToDelete = append(tablesToDelete, t) - } - } - } - if err := p.DeleteStale(ctx, tablesToDelete, sourceSpec.Name, syncTime); err != nil { - return err - } - } - return nil -} - -func (p *Plugin) DeleteStale(ctx context.Context, tables schema.Tables, sourceName string, syncTime time.Time) error { - syncTime = syncTime.UTC() - return p.client.DeleteStale(ctx, tables, sourceName, syncTime) -} - -func (p *Plugin) Close(ctx context.Context) error { - return p.client.Close(ctx) -} - -func checkDestinationColumns(tables schema.Tables) error { - for _, table := range tables { - if table.Columns.Index(schema.CqSourceNameColumn.Name) == -1 { - return fmt.Errorf("table %s is missing column %s. please consider upgrading source plugin", table.Name, schema.CqSourceNameColumn.Name) - } - if table.Columns.Index(schema.CqSyncTimeColumn.Name) == -1 { - return fmt.Errorf("table %s is missing column %s. please consider upgrading source plugin", table.Name, schema.CqSourceNameColumn.Name) - } - column := table.Columns.Get(schema.CqIDColumn.Name) - if column != nil { - if !column.NotNull { - return fmt.Errorf("column %s.%s cannot be nullable. please consider upgrading source plugin", table.Name, schema.CqIDColumn.Name) - } - if !column.Unique { - return fmt.Errorf("column %s.%s must be unique. 
please consider upgrading source plugin", table.Name, schema.CqIDColumn.Name) - } - } - } - return nil -} diff --git a/plugins/destination/plugin_testing.go b/plugins/destination/plugin_testing.go deleted file mode 100644 index c3ee806aed..0000000000 --- a/plugins/destination/plugin_testing.go +++ /dev/null @@ -1,294 +0,0 @@ -package destination - -import ( - "context" - "os" - "sort" - "strings" - "testing" - "time" - - "github.com/apache/arrow/go/v13/arrow" - "github.com/apache/arrow/go/v13/arrow/array" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/types" - "github.com/rs/zerolog" -) - -type PluginTestSuite struct { - tests PluginTestSuiteTests -} - -// MigrateStrategy defines which tests we should include -type MigrateStrategy struct { - AddColumn specs.MigrateMode - AddColumnNotNull specs.MigrateMode - RemoveColumn specs.MigrateMode - RemoveColumnNotNull specs.MigrateMode - ChangeColumn specs.MigrateMode -} - -type PluginTestSuiteTests struct { - // SkipOverwrite skips testing for "overwrite" mode. Use if the destination - // plugin doesn't support this feature. - SkipOverwrite bool - - // SkipDeleteStale skips testing "delete-stale" mode. Use if the destination - // plugin doesn't support this feature. - SkipDeleteStale bool - - // SkipAppend skips testing for "append" mode. Use if the destination - // plugin doesn't support this feature. - SkipAppend bool - - // SkipSecondAppend skips the second append step in the test. - // This is useful in cases like cloud storage where you can't append to an - // existing object after the file has been closed. - SkipSecondAppend bool - - // SkipMigrateAppend skips a test for the migrate function where a column is added, - // data is appended, then the column is removed and more data appended, checking that the migrations handle - // this correctly. - SkipMigrateAppend bool - // SkipMigrateAppendForce skips a test for the migrate function where a column is changed in force mode - SkipMigrateAppendForce bool - - // SkipMigrateOverwrite skips a test for the migrate function where a column is added, - // data is appended, then the column is removed and more data overwritten, checking that the migrations handle - // this correctly. - SkipMigrateOverwrite bool - // SkipMigrateOverwriteForce skips a test for the migrate function where a column is changed in force mode - SkipMigrateOverwriteForce bool - - MigrateStrategyOverwrite MigrateStrategy - MigrateStrategyAppend MigrateStrategy -} - -func getTestLogger(t *testing.T) zerolog.Logger { - t.Helper() - zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs - return zerolog.New(zerolog.NewTestWriter(t)).Output( - zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.StampMicro}, - ).Level(zerolog.TraceLevel).With().Timestamp().Logger() -} - -type NewPluginFunc func() *Plugin - -type PluginTestSuiteRunnerOptions struct { - // IgnoreNullsInLists allows stripping null values from lists before comparison. - // Destination setups that don't support nulls in lists should set this to true. - IgnoreNullsInLists bool - - // AllowNull is a custom func to determine whether a data type may be correctly represented as null. - // Destinations that have problems representing some data types should provide a custom implementation here. - // If this param is empty, the default is to allow all data types to be nullable. - // When the value returned by this func is `true` the comparison is made with the empty value instead of null. 
- AllowNull AllowNullFunc - - schema.TestSourceOptions -} - -func WithTestSourceAllowNull(allowNull func(arrow.DataType) bool) func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.AllowNull = allowNull - } -} - -func WithTestIgnoreNullsInLists() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.IgnoreNullsInLists = true - } -} - -func WithTestSourceTimePrecision(precision time.Duration) func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.TimePrecision = precision - } -} - -func WithTestSourceSkipLists() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipLists = true - } -} - -func WithTestSourceSkipTimestamps() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipTimestamps = true - } -} - -func WithTestSourceSkipDates() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipDates = true - } -} - -func WithTestSourceSkipMaps() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipMaps = true - } -} - -func WithTestSourceSkipStructs() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipStructs = true - } -} - -func WithTestSourceSkipIntervals() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipIntervals = true - } -} - -func WithTestSourceSkipDurations() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipDurations = true - } -} - -func WithTestSourceSkipTimes() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipTimes = true - } -} - -func WithTestSourceSkipLargeTypes() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipLargeTypes = true - } -} - -func WithTestSourceSkipDecimals() func(o *PluginTestSuiteRunnerOptions) { - return func(o *PluginTestSuiteRunnerOptions) { - o.SkipDecimals = true - } -} - -func PluginTestSuiteRunner(t *testing.T, newPlugin NewPluginFunc, destSpec specs.Destination, tests PluginTestSuiteTests, testOptions ...func(o *PluginTestSuiteRunnerOptions)) { - t.Helper() - destSpec.Name = "testsuite" - - suite := &PluginTestSuite{ - tests: tests, - } - - opts := PluginTestSuiteRunnerOptions{ - TestSourceOptions: schema.TestSourceOptions{ - TimePrecision: time.Microsecond, - }, - } - for _, o := range testOptions { - o(&opts) - } - - ctx := context.Background() - logger := getTestLogger(t) - - t.Run("TestWriteOverwrite", func(t *testing.T) { - t.Helper() - if suite.tests.SkipOverwrite { - t.Skip("skipping " + t.Name()) - } - destSpec.Name = "test_write_overwrite" - p := newPlugin() - if err := suite.destinationPluginTestWriteOverwrite(ctx, p, logger, destSpec, opts); err != nil { - t.Fatal(err) - } - if err := p.Close(ctx); err != nil { - t.Fatal(err) - } - }) - - t.Run("TestWriteOverwriteDeleteStale", func(t *testing.T) { - t.Helper() - if suite.tests.SkipOverwrite || suite.tests.SkipDeleteStale { - t.Skip("skipping " + t.Name()) - } - destSpec.Name = "test_write_overwrite_delete_stale" - p := newPlugin() - if err := suite.destinationPluginTestWriteOverwriteDeleteStale(ctx, p, logger, destSpec, opts); err != nil { - t.Fatal(err) - } - if err := p.Close(ctx); err != nil { - t.Fatal(err) - } - }) - - t.Run("TestMigrateOverwrite", func(t *testing.T) { - 
t.Helper() - if suite.tests.SkipMigrateOverwrite { - t.Skip("skipping " + t.Name()) - } - destSpec.WriteMode = specs.WriteModeOverwrite - destSpec.MigrateMode = specs.MigrateModeSafe - destSpec.Name = "test_migrate_overwrite" - suite.destinationPluginTestMigrate(ctx, t, newPlugin, logger, destSpec, tests.MigrateStrategyOverwrite, opts) - }) - - t.Run("TestMigrateOverwriteForce", func(t *testing.T) { - t.Helper() - if suite.tests.SkipMigrateOverwriteForce { - t.Skip("skipping " + t.Name()) - } - destSpec.WriteMode = specs.WriteModeOverwrite - destSpec.MigrateMode = specs.MigrateModeForced - destSpec.Name = "test_migrate_overwrite_force" - suite.destinationPluginTestMigrate(ctx, t, newPlugin, logger, destSpec, tests.MigrateStrategyOverwrite, opts) - }) - - t.Run("TestWriteAppend", func(t *testing.T) { - t.Helper() - if suite.tests.SkipAppend { - t.Skip("skipping " + t.Name()) - } - destSpec.Name = "test_write_append" - p := newPlugin() - if err := suite.destinationPluginTestWriteAppend(ctx, p, logger, destSpec, opts); err != nil { - t.Fatal(err) - } - if err := p.Close(ctx); err != nil { - t.Fatal(err) - } - }) - - t.Run("TestMigrateAppend", func(t *testing.T) { - t.Helper() - if suite.tests.SkipMigrateAppend { - t.Skip("skipping " + t.Name()) - } - destSpec.WriteMode = specs.WriteModeAppend - destSpec.MigrateMode = specs.MigrateModeSafe - destSpec.Name = "test_migrate_append" - suite.destinationPluginTestMigrate(ctx, t, newPlugin, logger, destSpec, tests.MigrateStrategyAppend, opts) - }) - - t.Run("TestMigrateAppendForce", func(t *testing.T) { - t.Helper() - if suite.tests.SkipMigrateAppendForce { - t.Skip("skipping " + t.Name()) - } - destSpec.WriteMode = specs.WriteModeAppend - destSpec.MigrateMode = specs.MigrateModeForced - destSpec.Name = "test_migrate_append_force" - suite.destinationPluginTestMigrate(ctx, t, newPlugin, logger, destSpec, tests.MigrateStrategyAppend, opts) - }) -} - -func sortRecordsBySyncTime(table *schema.Table, records []arrow.Record) { - syncTimeIndex := table.Columns.Index(schema.CqSyncTimeColumn.Name) - cqIDIndex := table.Columns.Index(schema.CqIDColumn.Name) - sort.Slice(records, func(i, j int) bool { - // sort by sync time, then UUID - first := records[i].Column(syncTimeIndex).(*array.Timestamp).Value(0).ToTime(arrow.Millisecond) - second := records[j].Column(syncTimeIndex).(*array.Timestamp).Value(0).ToTime(arrow.Millisecond) - if first.Equal(second) { - firstUUID := records[i].Column(cqIDIndex).(*types.UUIDArray).Value(0).String() - secondUUID := records[j].Column(cqIDIndex).(*types.UUIDArray).Value(0).String() - return strings.Compare(firstUUID, secondUUID) < 0 - } - return first.Before(second) - }) -} diff --git a/plugins/destination/plugin_testing_migrate.go b/plugins/destination/plugin_testing_migrate.go deleted file mode 100644 index b28ef18f50..0000000000 --- a/plugins/destination/plugin_testing_migrate.go +++ /dev/null @@ -1,283 +0,0 @@ -package destination - -import ( - "context" - "fmt" - "strings" - "testing" - "time" - - "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/types" - "github.com/google/uuid" - "github.com/rs/zerolog" - "github.com/stretchr/testify/require" -) - -func tableUUIDSuffix() string { - return strings.ReplaceAll(uuid.NewString(), "-", "_") -} - -func testMigration(ctx context.Context, _ *testing.T, p *Plugin, logger zerolog.Logger, spec specs.Destination, target *schema.Table, source *schema.Table, mode 
specs.MigrateMode, testOpts PluginTestSuiteRunnerOptions) error { - if err := p.Init(ctx, logger, spec); err != nil { - return fmt.Errorf("failed to init plugin: %w", err) - } - - if err := p.Migrate(ctx, schema.Tables{source}); err != nil { - return fmt.Errorf("failed to migrate tables: %w", err) - } - - sourceName := target.Name - sourceSpec := specs.Source{ - Name: sourceName, - } - syncTime := time.Now().UTC().Round(1 * time.Second) - opts := schema.GenTestDataOptions{ - SourceName: sourceName, - SyncTime: syncTime, - MaxRows: 1, - TimePrecision: testOpts.TimePrecision, - } - resource1 := schema.GenTestData(source, opts)[0] - if err := p.writeOne(ctx, sourceSpec, syncTime, resource1); err != nil { - return fmt.Errorf("failed to write one: %w", err) - } - - if err := p.Migrate(ctx, schema.Tables{target}); err != nil { - return fmt.Errorf("failed to migrate existing table: %w", err) - } - opts.SyncTime = syncTime.Add(time.Second).UTC() - resource2 := schema.GenTestData(target, opts) - if err := p.writeAll(ctx, sourceSpec, syncTime, resource2); err != nil { - return fmt.Errorf("failed to write one after migration: %w", err) - } - - testOpts.AllowNull.replaceNullsByEmpty(resource2) - if testOpts.IgnoreNullsInLists { - stripNullsFromLists(resource2) - } - - resourcesRead, err := p.readAll(ctx, target, sourceName) - if err != nil { - return fmt.Errorf("failed to read all: %w", err) - } - sortRecordsBySyncTime(target, resourcesRead) - if mode == specs.MigrateModeSafe { - if len(resourcesRead) != 2 { - return fmt.Errorf("expected 2 resources after write, got %d", len(resourcesRead)) - } - if !recordApproxEqual(resourcesRead[1], resource2[0]) { - diff := RecordDiff(resourcesRead[1], resource2[0]) - return fmt.Errorf("resource1 and resource2 are not equal. diff: %s", diff) - } - } else { - if len(resourcesRead) != 1 { - return fmt.Errorf("expected 1 resource after write, got %d", len(resourcesRead)) - } - if !recordApproxEqual(resourcesRead[0], resource2[0]) { - diff := RecordDiff(resourcesRead[0], resource2[0]) - return fmt.Errorf("resource1 and resource2 are not equal. 
diff: %s", diff) - } - } - - return nil -} - -func (*PluginTestSuite) destinationPluginTestMigrate( - ctx context.Context, - t *testing.T, - newPlugin NewPluginFunc, - logger zerolog.Logger, - spec specs.Destination, - strategy MigrateStrategy, - testOpts PluginTestSuiteRunnerOptions, -) { - spec.BatchSize = 1 - - t.Run("add_column", func(t *testing.T) { - if strategy.AddColumn == specs.MigrateModeForced && spec.MigrateMode == specs.MigrateModeSafe { - t.Skip("skipping as migrate mode is safe") - return - } - tableName := "add_column_" + tableUUIDSuffix() - source := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - }, - } - - target := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - {Name: "bool", Type: arrow.FixedWidthTypes.Boolean}, - }, - } - - p := newPlugin() - if err := testMigration(ctx, t, p, logger, spec, target, source, strategy.AddColumn, testOpts); err != nil { - t.Fatalf("failed to migrate %s: %v", tableName, err) - } - if err := p.Close(ctx); err != nil { - t.Fatal(err) - } - }) - - t.Run("add_column_not_null", func(t *testing.T) { - if strategy.AddColumnNotNull == specs.MigrateModeForced && spec.MigrateMode == specs.MigrateModeSafe { - t.Skip("skipping as migrate mode is safe") - return - } - tableName := "add_column_not_null_" + tableUUIDSuffix() - source := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - }, - } - - target := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, NotNull: true}, - }} - p := newPlugin() - if err := testMigration(ctx, t, p, logger, spec, target, source, strategy.AddColumnNotNull, testOpts); err != nil { - t.Fatalf("failed to migrate add_column_not_null: %v", err) - } - if err := p.Close(ctx); err != nil { - t.Fatal(err) - } - }) - - t.Run("remove_column", func(t *testing.T) { - if strategy.RemoveColumn == specs.MigrateModeForced && spec.MigrateMode == specs.MigrateModeSafe { - t.Skip("skipping as migrate mode is safe") - return - } - tableName := "remove_column_" + tableUUIDSuffix() - source := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - {Name: "bool", Type: arrow.FixedWidthTypes.Boolean}, - }} - target := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - }} - - p := newPlugin() - if err := testMigration(ctx, t, p, logger, spec, target, source, strategy.RemoveColumn, testOpts); err != nil { - t.Fatalf("failed to migrate remove_column: %v", err) - } - if err := p.Close(ctx); err != nil { - t.Fatal(err) - } - }) - - t.Run("remove_column_not_null", func(t *testing.T) { - if strategy.RemoveColumnNotNull == specs.MigrateModeForced && spec.MigrateMode == specs.MigrateModeSafe { - t.Skip("skipping as migrate mode is safe") - return - } - tableName := 
"remove_column_not_null_" + tableUUIDSuffix() - source := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, NotNull: true}, - }, - } - target := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - }} - - p := newPlugin() - if err := testMigration(ctx, t, p, logger, spec, target, source, strategy.RemoveColumnNotNull, testOpts); err != nil { - t.Fatalf("failed to migrate remove_column_not_null: %v", err) - } - if err := p.Close(ctx); err != nil { - t.Fatal(err) - } - }) - - t.Run("change_column", func(t *testing.T) { - if strategy.ChangeColumn == specs.MigrateModeForced && spec.MigrateMode == specs.MigrateModeSafe { - t.Skip("skipping as migrate mode is safe") - return - } - tableName := "change_column_" + tableUUIDSuffix() - source := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, NotNull: true}, - }} - target := &schema.Table{ - Name: tableName, - Columns: schema.ColumnList{ - schema.CqSourceNameColumn, - schema.CqSyncTimeColumn, - schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID}, - {Name: "bool", Type: arrow.BinaryTypes.String, NotNull: true}, - }} - - p := newPlugin() - if err := testMigration(ctx, t, p, logger, spec, target, source, strategy.ChangeColumn, testOpts); err != nil { - t.Fatalf("failed to migrate change_column: %v", err) - } - if err := p.Close(ctx); err != nil { - t.Fatal(err) - } - }) - - t.Run("double_migration", func(t *testing.T) { - tableName := "double_migration_" + tableUUIDSuffix() - table := schema.TestTable(tableName, testOpts.TestSourceOptions) - - p := newPlugin() - require.NoError(t, p.Init(ctx, logger, spec)) - require.NoError(t, p.Migrate(ctx, schema.Tables{table})) - - nonForced := spec - nonForced.MigrateMode = specs.MigrateModeSafe - require.NoError(t, p.Init(ctx, logger, nonForced)) - require.NoError(t, p.Migrate(ctx, schema.Tables{table})) - }) -} diff --git a/plugins/destination/plugin_testing_overwrite.go b/plugins/destination/plugin_testing_overwrite.go deleted file mode 100644 index f77285ff63..0000000000 --- a/plugins/destination/plugin_testing_overwrite.go +++ /dev/null @@ -1,111 +0,0 @@ -package destination - -import ( - "context" - "fmt" - "time" - - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/types" - "github.com/google/uuid" - "github.com/rs/zerolog" -) - -func (*PluginTestSuite) destinationPluginTestWriteOverwrite(ctx context.Context, p *Plugin, logger zerolog.Logger, spec specs.Destination, testOpts PluginTestSuiteRunnerOptions) error { - spec.WriteMode = specs.WriteModeOverwrite - if err := p.Init(ctx, logger, spec); err != nil { - return fmt.Errorf("failed to init plugin: %w", err) - } - tableName := fmt.Sprintf("cq_%s_%d", spec.Name, time.Now().Unix()) - table := schema.TestTable(tableName, testOpts.TestSourceOptions) - syncTime := time.Now().UTC().Round(1 * time.Second) - tables := schema.Tables{ - table, - } - if err := p.Migrate(ctx, tables); err != nil { - return fmt.Errorf("failed to migrate tables: %w", err) - } 
- - sourceName := "testOverwriteSource" + uuid.NewString() - sourceSpec := specs.Source{ - Name: sourceName, - } - - opts := schema.GenTestDataOptions{ - SourceName: sourceName, - SyncTime: syncTime, - MaxRows: 2, - TimePrecision: testOpts.TimePrecision, - } - resources := schema.GenTestData(table, opts) - if err := p.writeAll(ctx, sourceSpec, syncTime, resources); err != nil { - return fmt.Errorf("failed to write all: %w", err) - } - sortRecordsBySyncTime(table, resources) - testOpts.AllowNull.replaceNullsByEmpty(resources) - if testOpts.IgnoreNullsInLists { - stripNullsFromLists(resources) - } - resourcesRead, err := p.readAll(ctx, table, sourceName) - if err != nil { - return fmt.Errorf("failed to read all: %w", err) - } - sortRecordsBySyncTime(table, resourcesRead) - - if len(resourcesRead) != 2 { - return fmt.Errorf("expected 2 resources, got %d", len(resourcesRead)) - } - - if !recordApproxEqual(resources[0], resourcesRead[0]) { - diff := RecordDiff(resources[0], resourcesRead[0]) - return fmt.Errorf("expected first resource to be equal. diff=%s", diff) - } - - if !recordApproxEqual(resources[1], resourcesRead[1]) { - diff := RecordDiff(resources[1], resourcesRead[1]) - return fmt.Errorf("expected second resource to be equal. diff=%s", diff) - } - - secondSyncTime := syncTime.Add(time.Second).UTC() - - // copy first resource but update the sync time - cqIDInds := resources[0].Schema().FieldIndices(schema.CqIDColumn.Name) - u := resources[0].Column(cqIDInds[0]).(*types.UUIDArray).Value(0) - opts = schema.GenTestDataOptions{ - SourceName: sourceName, - SyncTime: secondSyncTime, - MaxRows: 1, - StableUUID: u, - TimePrecision: testOpts.TimePrecision, - } - updatedResource := schema.GenTestData(table, opts) - // write second time - if err := p.writeAll(ctx, sourceSpec, secondSyncTime, updatedResource); err != nil { - return fmt.Errorf("failed to write one second time: %w", err) - } - - testOpts.AllowNull.replaceNullsByEmpty(updatedResource) - if testOpts.IgnoreNullsInLists { - stripNullsFromLists(updatedResource) - } - resourcesRead, err = p.readAll(ctx, table, sourceName) - if err != nil { - return fmt.Errorf("failed to read all second time: %w", err) - } - sortRecordsBySyncTime(table, resourcesRead) - if len(resourcesRead) != 2 { - return fmt.Errorf("after overwrite expected 2 resources, got %d", len(resourcesRead)) - } - - if !recordApproxEqual(resources[1], resourcesRead[0]) { - diff := RecordDiff(resources[1], resourcesRead[0]) - return fmt.Errorf("after overwrite expected first resource to be equal. diff=%s", diff) - } - if !recordApproxEqual(updatedResource[0], resourcesRead[1]) { - diff := RecordDiff(updatedResource[0], resourcesRead[1]) - return fmt.Errorf("after overwrite expected second resource to be equal. 
diff=%s", diff) - } - - return nil -} diff --git a/plugins/destination/plugin_testing_overwrite_delete_stale.go b/plugins/destination/plugin_testing_overwrite_delete_stale.go deleted file mode 100644 index 4339bb1d43..0000000000 --- a/plugins/destination/plugin_testing_overwrite_delete_stale.go +++ /dev/null @@ -1,152 +0,0 @@ -package destination - -import ( - "context" - "fmt" - "time" - - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/types" - "github.com/google/uuid" - "github.com/rs/zerolog" -) - -func (*PluginTestSuite) destinationPluginTestWriteOverwriteDeleteStale(ctx context.Context, p *Plugin, logger zerolog.Logger, spec specs.Destination, testOpts PluginTestSuiteRunnerOptions) error { - spec.WriteMode = specs.WriteModeOverwriteDeleteStale - if err := p.Init(ctx, logger, spec); err != nil { - return fmt.Errorf("failed to init plugin: %w", err) - } - tableName := fmt.Sprintf("cq_%s_%d", spec.Name, time.Now().Unix()) - table := schema.TestTable(tableName, testOpts.TestSourceOptions) - incTable := schema.TestTable(tableName+"_incremental", testOpts.TestSourceOptions) - incTable.IsIncremental = true - syncTime := time.Now().UTC().Round(1 * time.Second) - tables := schema.Tables{ - table, - incTable, - } - if err := p.Migrate(ctx, tables); err != nil { - return fmt.Errorf("failed to migrate tables: %w", err) - } - - sourceName := "testOverwriteSource" + uuid.NewString() - sourceSpec := specs.Source{ - Name: sourceName, - Backend: specs.BackendLocal, - } - - opts := schema.GenTestDataOptions{ - SourceName: sourceName, - SyncTime: syncTime, - MaxRows: 2, - TimePrecision: testOpts.TimePrecision, - } - resources := schema.GenTestData(table, opts) - incResources := schema.GenTestData(incTable, opts) - allResources := resources - allResources = append(allResources, incResources...) - if err := p.writeAll(ctx, sourceSpec, syncTime, allResources); err != nil { - return fmt.Errorf("failed to write all: %w", err) - } - sortRecordsBySyncTime(table, resources) - - resourcesRead, err := p.readAll(ctx, table, sourceName) - if err != nil { - return fmt.Errorf("failed to read all: %w", err) - } - sortRecordsBySyncTime(table, resourcesRead) - - if len(resourcesRead) != 2 { - return fmt.Errorf("expected 2 resources, got %d", len(resourcesRead)) - } - testOpts.AllowNull.replaceNullsByEmpty(resources) - if testOpts.IgnoreNullsInLists { - stripNullsFromLists(resources) - } - if !recordApproxEqual(resources[0], resourcesRead[0]) { - diff := RecordDiff(resources[0], resourcesRead[0]) - return fmt.Errorf("expected first resource to be equal. diff: %s", diff) - } - - if !recordApproxEqual(resources[1], resourcesRead[1]) { - diff := RecordDiff(resources[1], resourcesRead[1]) - return fmt.Errorf("expected second resource to be equal. 
diff: %s", diff) - } - - // read from incremental table - resourcesRead, err = p.readAll(ctx, incTable, sourceName) - if err != nil { - return fmt.Errorf("failed to read all: %w", err) - } - if len(resourcesRead) != 2 { - return fmt.Errorf("expected 2 resources in incremental table, got %d", len(resourcesRead)) - } - - secondSyncTime := syncTime.Add(time.Second).UTC() - // copy first resource but update the sync time - cqIDInds := resources[0].Schema().FieldIndices(schema.CqIDColumn.Name) - u := resources[0].Column(cqIDInds[0]).(*types.UUIDArray).Value(0) - opts = schema.GenTestDataOptions{ - SourceName: sourceName, - SyncTime: secondSyncTime, - StableUUID: u, - MaxRows: 1, - TimePrecision: testOpts.TimePrecision, - } - updatedResources := schema.GenTestData(table, opts) - updatedIncResources := schema.GenTestData(incTable, opts) - allUpdatedResources := updatedResources - allUpdatedResources = append(allUpdatedResources, updatedIncResources...) - - if err := p.writeAll(ctx, sourceSpec, secondSyncTime, allUpdatedResources); err != nil { - return fmt.Errorf("failed to write all second time: %w", err) - } - - resourcesRead, err = p.readAll(ctx, table, sourceName) - if err != nil { - return fmt.Errorf("failed to read all second time: %w", err) - } - sortRecordsBySyncTime(table, resourcesRead) - if len(resourcesRead) != 1 { - return fmt.Errorf("after overwrite expected 1 resource, got %d", len(resourcesRead)) - } - testOpts.AllowNull.replaceNullsByEmpty(resources) - if testOpts.IgnoreNullsInLists { - stripNullsFromLists(resources) - } - if recordApproxEqual(resources[0], resourcesRead[0]) { - diff := RecordDiff(resources[0], resourcesRead[0]) - return fmt.Errorf("after overwrite expected first resource to be different. diff: %s", diff) - } - - resourcesRead, err = p.readAll(ctx, table, sourceName) - if err != nil { - return fmt.Errorf("failed to read all second time: %w", err) - } - if len(resourcesRead) != 1 { - return fmt.Errorf("expected 1 resource after delete stale, got %d", len(resourcesRead)) - } - - // we expect the only resource returned to match the updated resource we wrote - testOpts.AllowNull.replaceNullsByEmpty(updatedResources) - if testOpts.IgnoreNullsInLists { - stripNullsFromLists(updatedResources) - } - if !recordApproxEqual(updatedResources[0], resourcesRead[0]) { - diff := RecordDiff(updatedResources[0], resourcesRead[0]) - return fmt.Errorf("after delete stale expected resource to be equal. 
diff: %s", diff) - } - - // we expect the incremental table to still have 3 resources, because delete-stale should - // not apply there - resourcesRead, err = p.readAll(ctx, incTable, sourceName) - if err != nil { - return fmt.Errorf("failed to read all from incremental table: %w", err) - } - if len(resourcesRead) != 3 { - return fmt.Errorf("expected 3 resources in incremental table after delete-stale, got %d", len(resourcesRead)) - } - - return nil -} diff --git a/plugins/destination/plugin_testing_write_append.go b/plugins/destination/plugin_testing_write_append.go deleted file mode 100644 index a3f0445c27..0000000000 --- a/plugins/destination/plugin_testing_write_append.go +++ /dev/null @@ -1,95 +0,0 @@ -package destination - -import ( - "context" - "fmt" - "time" - - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/google/uuid" - "github.com/rs/zerolog" -) - -func (s *PluginTestSuite) destinationPluginTestWriteAppend(ctx context.Context, p *Plugin, logger zerolog.Logger, spec specs.Destination, testOpts PluginTestSuiteRunnerOptions) error { - spec.WriteMode = specs.WriteModeAppend - if err := p.Init(ctx, logger, spec); err != nil { - return fmt.Errorf("failed to init plugin: %w", err) - } - tableName := fmt.Sprintf("cq_%s_%d", spec.Name, time.Now().Unix()) - table := schema.TestTable(tableName, testOpts.TestSourceOptions) - syncTime := time.Now().UTC().Round(1 * time.Second) - tables := schema.Tables{ - table, - } - if err := p.Migrate(ctx, tables); err != nil { - return fmt.Errorf("failed to migrate tables: %w", err) - } - - sourceName := "testAppendSource" + uuid.NewString() - specSource := specs.Source{ - Name: sourceName, - } - - opts := schema.GenTestDataOptions{ - SourceName: sourceName, - SyncTime: syncTime, - MaxRows: 2, - TimePrecision: testOpts.TimePrecision, - } - record1 := schema.GenTestData(table, opts) - if err := p.writeAll(ctx, specSource, syncTime, record1); err != nil { - return fmt.Errorf("failed to write record first time: %w", err) - } - - secondSyncTime := syncTime.Add(10 * time.Second).UTC() - opts.SyncTime = secondSyncTime - opts.MaxRows = 1 - record2 := schema.GenTestData(table, opts) - - if !s.tests.SkipSecondAppend { - // write second time - if err := p.writeAll(ctx, specSource, secondSyncTime, record2); err != nil { - return fmt.Errorf("failed to write one second time: %w", err) - } - } - - resourcesRead, err := p.readAll(ctx, tables[0], sourceName) - if err != nil { - return fmt.Errorf("failed to read all second time: %w", err) - } - sortRecordsBySyncTime(table, resourcesRead) - - expectedResource := 3 - if s.tests.SkipSecondAppend { - expectedResource = 2 - } - - if len(resourcesRead) != expectedResource { - return fmt.Errorf("expected %d resources, got %d", expectedResource, len(resourcesRead)) - } - - testOpts.AllowNull.replaceNullsByEmpty(record1) - testOpts.AllowNull.replaceNullsByEmpty(record2) - if testOpts.IgnoreNullsInLists { - stripNullsFromLists(record1) - stripNullsFromLists(record2) - } - if !recordApproxEqual(record1[0], resourcesRead[0]) { - diff := RecordDiff(record1[0], resourcesRead[0]) - return fmt.Errorf("first expected resource diff at row 0: %s", diff) - } - if !recordApproxEqual(record1[1], resourcesRead[1]) { - diff := RecordDiff(record1[1], resourcesRead[1]) - return fmt.Errorf("first expected resource diff at row 1: %s", diff) - } - - if !s.tests.SkipSecondAppend { - if !recordApproxEqual(record2[0], resourcesRead[2]) { - diff := RecordDiff(record2[0], resourcesRead[2]) - 
return fmt.Errorf("second expected resource diff: %s", diff) - } - } - - return nil -} diff --git a/plugins/destination/unmanaged_writer.go b/plugins/destination/unmanaged_writer.go deleted file mode 100644 index cdb3466b09..0000000000 --- a/plugins/destination/unmanaged_writer.go +++ /dev/null @@ -1,14 +0,0 @@ -package destination - -import ( - "context" - "time" - - "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" -) - -func (p *Plugin) writeUnmanaged(ctx context.Context, _ specs.Source, tables schema.Tables, _ time.Time, res <-chan arrow.Record) error { - return p.client.Write(ctx, tables, res) -} diff --git a/plugins/docs.go b/plugins/docs.go deleted file mode 100644 index 2e21a01945..0000000000 --- a/plugins/docs.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package plugins defines APIs for source and destination plugins -package plugins diff --git a/plugins/source/docs.go b/plugins/source/docs.go deleted file mode 100644 index f21d926856..0000000000 --- a/plugins/source/docs.go +++ /dev/null @@ -1,241 +0,0 @@ -package source - -import ( - "bytes" - "embed" - "encoding/json" - "fmt" - "os" - "path/filepath" - "regexp" - "sort" - "text/template" - - "github.com/cloudquery/plugin-sdk/v3/caser" - "github.com/cloudquery/plugin-sdk/v3/schema" -) - -//go:embed templates/*.go.tpl -var templatesFS embed.FS - -var reMatchNewlines = regexp.MustCompile(`\n{3,}`) -var reMatchHeaders = regexp.MustCompile(`(#{1,6}.+)\n+`) - -var DefaultTitleExceptions = map[string]string{ - // common abbreviations - "acl": "ACL", - "acls": "ACLs", - "api": "API", - "apis": "APIs", - "ca": "CA", - "cidr": "CIDR", - "cidrs": "CIDRs", - "db": "DB", - "dbs": "DBs", - "dhcp": "DHCP", - "iam": "IAM", - "iot": "IOT", - "ip": "IP", - "ips": "IPs", - "ipv4": "IPv4", - "ipv6": "IPv6", - "mfa": "MFA", - "ml": "ML", - "oauth": "OAuth", - "vpc": "VPC", - "vpcs": "VPCs", - "vpn": "VPN", - "vpns": "VPNs", - "waf": "WAF", - "wafs": "WAFs", - - // cloud providers - "aws": "AWS", - "gcp": "GCP", -} - -func DefaultTitleTransformer(table *schema.Table) string { - if table.Title != "" { - return table.Title - } - csr := caser.New(caser.WithCustomExceptions(DefaultTitleExceptions)) - return csr.ToTitle(table.Name) -} - -func sortTables(tables schema.Tables) { - sort.SliceStable(tables, func(i, j int) bool { - return tables[i].Name < tables[j].Name - }) - - for _, table := range tables { - sortTables(table.Relations) - } -} - -type templateData struct { - PluginName string - Tables schema.Tables -} - -// GeneratePluginDocs creates table documentation for the source plugin based on its list of tables -func (p *Plugin) GeneratePluginDocs(dir, format string) error { - if err := os.MkdirAll(dir, os.ModePerm); err != nil { - return err - } - - setDestinationManagedCqColumns(p.Tables()) - - sortedTables := make(schema.Tables, 0, len(p.Tables())) - for _, t := range p.Tables() { - sortedTables = append(sortedTables, t.Copy(nil)) - } - sortTables(sortedTables) - - switch format { - case "markdown": - return p.renderTablesAsMarkdown(dir, p.name, sortedTables) - case "json": - return p.renderTablesAsJSON(dir, sortedTables) - default: - return fmt.Errorf("unsupported format: %v", format) - } -} - -// setDestinationManagedCqColumns overwrites or adds the CQ columns that are managed by the destination plugins (_cq_sync_time, _cq_source_name). 
-func setDestinationManagedCqColumns(tables []*schema.Table) { - for _, table := range tables { - table.OverwriteOrAddColumn(&schema.CqSyncTimeColumn) - table.OverwriteOrAddColumn(&schema.CqSourceNameColumn) - setDestinationManagedCqColumns(table.Relations) - } -} - -type jsonTable struct { - Name string `json:"name"` - Title string `json:"title"` - Description string `json:"description"` - Columns []jsonColumn `json:"columns"` - Relations []jsonTable `json:"relations"` -} - -type jsonColumn struct { - Name string `json:"name"` - Type string `json:"type"` - IsPrimaryKey bool `json:"is_primary_key,omitempty"` - IsIncrementalKey bool `json:"is_incremental_key,omitempty"` -} - -func (p *Plugin) renderTablesAsJSON(dir string, tables schema.Tables) error { - jsonTables := p.jsonifyTables(tables) - buffer := &bytes.Buffer{} - m := json.NewEncoder(buffer) - m.SetIndent("", " ") - m.SetEscapeHTML(false) - err := m.Encode(jsonTables) - if err != nil { - return err - } - outputPath := filepath.Join(dir, "__tables.json") - return os.WriteFile(outputPath, buffer.Bytes(), 0644) -} - -func (p *Plugin) jsonifyTables(tables schema.Tables) []jsonTable { - jsonTables := make([]jsonTable, len(tables)) - for i, table := range tables { - jsonColumns := make([]jsonColumn, len(table.Columns)) - for c, col := range table.Columns { - jsonColumns[c] = jsonColumn{ - Name: col.Name, - Type: col.Type.String(), - IsPrimaryKey: col.PrimaryKey, - IsIncrementalKey: col.IncrementalKey, - } - } - jsonTables[i] = jsonTable{ - Name: table.Name, - Title: p.titleTransformer(table), - Description: table.Description, - Columns: jsonColumns, - Relations: p.jsonifyTables(table.Relations), - } - } - return jsonTables -} - -func (p *Plugin) renderTablesAsMarkdown(dir string, pluginName string, tables schema.Tables) error { - for _, table := range tables { - if err := p.renderAllTables(table, dir); err != nil { - return err - } - } - t, err := template.New("all_tables.md.go.tpl").Funcs(template.FuncMap{ - "indentToDepth": indentToDepth, - }).ParseFS(templatesFS, "templates/all_tables*.md.go.tpl") - if err != nil { - return fmt.Errorf("failed to parse template for README.md: %v", err) - } - - var b bytes.Buffer - if err := t.Execute(&b, templateData{PluginName: pluginName, Tables: tables}); err != nil { - return fmt.Errorf("failed to execute template: %v", err) - } - content := formatMarkdown(b.String()) - outputPath := filepath.Join(dir, "README.md") - f, err := os.Create(outputPath) - if err != nil { - return fmt.Errorf("failed to create file %v: %v", outputPath, err) - } - f.WriteString(content) - return nil -} - -func (p *Plugin) renderAllTables(t *schema.Table, dir string) error { - if err := p.renderTable(t, dir); err != nil { - return err - } - for _, r := range t.Relations { - if err := p.renderAllTables(r, dir); err != nil { - return err - } - } - return nil -} - -func (p *Plugin) renderTable(table *schema.Table, dir string) error { - t := template.New("").Funcs(map[string]any{ - "title": p.titleTransformer, - }) - t, err := t.New("table.md.go.tpl").ParseFS(templatesFS, "templates/table.md.go.tpl") - if err != nil { - return fmt.Errorf("failed to parse template: %v", err) - } - - outputPath := filepath.Join(dir, fmt.Sprintf("%s.md", table.Name)) - - var b bytes.Buffer - if err := t.Execute(&b, table); err != nil { - return fmt.Errorf("failed to execute template: %v", err) - } - content := formatMarkdown(b.String()) - f, err := os.Create(outputPath) - if err != nil { - return fmt.Errorf("failed to create file %v: %v", 
outputPath, err) - } - f.WriteString(content) - return f.Close() -} - -func formatMarkdown(s string) string { - s = reMatchNewlines.ReplaceAllString(s, "\n\n") - return reMatchHeaders.ReplaceAllString(s, `$1`+"\n\n") -} - -func indentToDepth(table *schema.Table) string { - s := "" - t := table - for t.Parent != nil { - s += " " - t = t.Parent - } - return s -} diff --git a/plugins/source/metrics_test.go b/plugins/source/metrics_test.go deleted file mode 100644 index fb7488d47e..0000000000 --- a/plugins/source/metrics_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package source - -import ( - "testing" - "time" - - "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/stretchr/testify/assert" -) - -func TestMetrics(t *testing.T) { - s := &Metrics{ - TableClient: make(map[string]map[string]*TableClientMetrics), - } - s.TableClient["test_table"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table"]["testExecutionClient"] = &TableClientMetrics{ - Resources: 1, - Errors: 2, - Panics: 3, - } - if s.TotalResources() != 1 { - t.Fatal("expected 1 resource") - } - if s.TotalErrors() != 2 { - t.Fatal("expected 2 error") - } - if s.TotalPanics() != 3 { - t.Fatal("expected 3 panics") - } - - other := &Metrics{ - TableClient: make(map[string]map[string]*TableClientMetrics), - } - other.TableClient["test_table"] = make(map[string]*TableClientMetrics) - other.TableClient["test_table"]["testExecutionClient"] = &TableClientMetrics{ - Resources: 1, - Errors: 2, - Panics: 3, - } - if !s.Equal(other) { - t.Fatal("expected metrics to be equal") - } -} - -func TestInProgressTables(t *testing.T) { - s := &Metrics{ - TableClient: make(map[string]map[string]*TableClientMetrics), - } - s.TableClient["test_table_done"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table_done"]["testExecutionClient"] = &TableClientMetrics{ - Resources: 1, - Errors: 2, - Panics: 3, - startTime: time.Now(), - endTime: time.Now().Add(time.Second), - } - - s.TableClient["test_table_running1"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table_running1"]["testExecutionClient"] = &TableClientMetrics{ - Resources: 1, - Errors: 2, - Panics: 3, - startTime: time.Now(), - } - - s.TableClient["test_table_running2"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table_running2"]["testExecutionClient"] = &TableClientMetrics{ - Resources: 1, - Errors: 2, - Panics: 3, - startTime: time.Now(), - } - s.TableClient["test_table_running3"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table_running3"]["testExecutionClient"] = &TableClientMetrics{} - assert.ElementsMatch(t, []string{"test_table_running1", "test_table_running2"}, s.InProgressTables()) -} - -func TestQueuedTables(t *testing.T) { - s := &Metrics{ - TableClient: make(map[string]map[string]*TableClientMetrics), - } - s.TableClient["test_table_done"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table_done"]["testExecutionClient"] = &TableClientMetrics{ - Resources: 1, - Errors: 2, - Panics: 3, - startTime: time.Now(), - endTime: time.Now().Add(time.Second), - } - - s.TableClient["test_table_running1"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table_running1"]["testExecutionClient"] = &TableClientMetrics{ - Resources: 1, - Errors: 2, - Panics: 3, - startTime: time.Now(), - } - - s.TableClient["test_table_running2"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table_running2"]["testExecutionClient"] = &TableClientMetrics{ - 
Resources: 1, - Errors: 2, - Panics: 3, - startTime: time.Now(), - } - s.TableClient["test_table_running3"] = make(map[string]*TableClientMetrics) - s.TableClient["test_table_running3"]["testExecutionClient"] = &TableClientMetrics{} - assert.ElementsMatch(t, []string{"test_table_running3"}, s.QueuedTables()) -} - -type MockClientMeta struct { -} - -func (*MockClientMeta) ID() string { - return "id" -} - -var exampleTableSchema = &schema.Table{ - Name: "toplevel", - Columns: schema.ColumnList{ - { - Name: "col1", - Type: &arrow.Int32Type{}, - }, - }, - Relations: []*schema.Table{ - { - Name: "child", - Columns: schema.ColumnList{ - { - Name: "col1", - Type: &arrow.Int32Type{}, - }, - }, - }, - }, -} - -// When a top-level table is marked as done, all child tables should be marked as done as well. -// For child-tables, only the specified table should be marked as done. -func TestMarkEndChildTableNotRecursive(t *testing.T) { - mockClientMeta := &MockClientMeta{} - - metrics := &Metrics{ - TableClient: make(map[string]map[string]*TableClientMetrics), - } - metrics.TableClient["toplevel"] = nil - metrics.TableClient["child"] = nil - - parentTable := exampleTableSchema - childTable := exampleTableSchema.Relations[0] - - metrics.initWithClients(parentTable, []schema.ClientMeta{mockClientMeta}) - metrics.MarkStart(parentTable, mockClientMeta.ID()) - metrics.MarkStart(childTable, mockClientMeta.ID()) - - assert.ElementsMatch(t, []string{"toplevel", "child"}, metrics.InProgressTables()) - - metrics.MarkEnd(childTable, mockClientMeta.ID()) - - assert.ElementsMatch(t, []string{"toplevel"}, metrics.InProgressTables()) -} - -func TestMarkEndTopLevelTableRecursive(t *testing.T) { - mockClientMeta := &MockClientMeta{} - - metrics := &Metrics{ - TableClient: make(map[string]map[string]*TableClientMetrics), - } - metrics.TableClient["toplevel"] = nil - metrics.TableClient["child"] = nil - - parentTable := exampleTableSchema - childTable := exampleTableSchema.Relations[0] - - metrics.initWithClients(parentTable, []schema.ClientMeta{mockClientMeta}) - metrics.MarkStart(parentTable, mockClientMeta.ID()) - metrics.MarkStart(childTable, mockClientMeta.ID()) - - assert.ElementsMatch(t, []string{"toplevel", "child"}, metrics.InProgressTables()) - - metrics.MarkEnd(parentTable, mockClientMeta.ID()) - - assert.Empty(t, metrics.InProgressTables()) -} diff --git a/plugins/source/options.go b/plugins/source/options.go deleted file mode 100644 index 72ddc5acc7..0000000000 --- a/plugins/source/options.go +++ /dev/null @@ -1,39 +0,0 @@ -package source - -import ( - "context" - - "github.com/cloudquery/plugin-sdk/v3/schema" -) - -type GetTables func(ctx context.Context, c schema.ClientMeta) (schema.Tables, error) - -type Option func(*Plugin) - -// WithDynamicTableOption allows the plugin to return list of tables after call to New -func WithDynamicTableOption(getDynamicTables GetTables) Option { - return func(p *Plugin) { - p.getDynamicTables = getDynamicTables - } -} - -// WithNoInternalColumns won't add internal columns (_cq_id, _cq_parent_cq_id) to the plugin tables -func WithNoInternalColumns() Option { - return func(p *Plugin) { - p.internalColumns = false - } -} - -func WithUnmanaged() Option { - return func(p *Plugin) { - p.unmanaged = true - } -} - -// WithTitleTransformer allows the plugin to control how table names get turned into titles for the -// generated documentation. 
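A minimal usage sketch (titleAllCaps and newExamplePlugin are illustrative names and assume the strings import; NewExecutionClientFunc, NewPlugin and WithTitleTransformer are the package's own declarations):

func newExamplePlugin(tables []*schema.Table, newClient NewExecutionClientFunc) *Plugin {
	titleAllCaps := func(table *schema.Table) string {
		return strings.ToUpper(strings.ReplaceAll(table.Name, "_", " "))
	}
	return NewPlugin("example", "v1.0.0", tables, newClient,
		WithTitleTransformer(titleAllCaps))
}
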
-func WithTitleTransformer(t func(*schema.Table) string) Option { - return func(p *Plugin) { - p.titleTransformer = t - } -} diff --git a/plugins/source/plugin.go b/plugins/source/plugin.go deleted file mode 100644 index 5a0363af1e..0000000000 --- a/plugins/source/plugin.go +++ /dev/null @@ -1,345 +0,0 @@ -package source - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/backend" - "github.com/cloudquery/plugin-sdk/v3/caser" - "github.com/cloudquery/plugin-sdk/v3/internal/backends/local" - "github.com/cloudquery/plugin-sdk/v3/internal/backends/nop" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/rs/zerolog" - "golang.org/x/sync/semaphore" -) - -type Options struct { - Backend backend.Backend -} - -type NewExecutionClientFunc func(context.Context, zerolog.Logger, specs.Source, Options) (schema.ClientMeta, error) - -type UnmanagedClient interface { - schema.ClientMeta - Sync(ctx context.Context, metrics *Metrics, res chan<- *schema.Resource) error -} - -// Plugin is the base structure required to pass to sdk.serve -// We take a declarative approach to API here similar to Cobra -type Plugin struct { - // Name of plugin i.e aws,gcp, azure etc' - name string - // Version of the plugin - version string - // Called upon configure call to validate and init configuration - newExecutionClient NewExecutionClientFunc - // dynamic table function if specified - getDynamicTables GetTables - // Tables is all tables supported by this source plugin - tables schema.Tables - // status sync metrics - metrics *Metrics - // Logger to call, this logger is passed to the serve.Serve Client, if not defined Serve will create one instead. - logger zerolog.Logger - // resourceSem is a semaphore that limits the number of concurrent resources being fetched - resourceSem *semaphore.Weighted - // tableSem is a semaphore that limits the number of concurrent tables being fetched - tableSems []*semaphore.Weighted - // maxDepth is the max depth of tables - maxDepth uint64 - // caser - caser *caser.Caser - // mu is a mutex that limits the number of concurrent init/syncs (can only be one at a time) - mu sync.Mutex - - // client is the initialized session client - client schema.ClientMeta - // sessionTables are the - sessionTables schema.Tables - // backend is the backend used to store the cursor state - backend backend.Backend - // spec is the spec the client was initialized with - spec specs.Source - // NoInternalColumns if set to true will not add internal columns to tables such as _cq_id and _cq_parent_id - // useful for sources such as PostgreSQL and other databases - internalColumns bool - // unmanaged if set to true then the plugin will call Sync directly and not use the scheduler - unmanaged bool - // titleTransformer allows the plugin to control how table names get turned into titles for generated documentation - titleTransformer func(*schema.Table) string - syncTime time.Time -} - -const ( - maxAllowedDepth = 4 -) - -// Add internal columns -func (p *Plugin) addInternalColumns(tables []*schema.Table) error { - for _, table := range tables { - if c := table.Column("_cq_id"); c != nil { - return fmt.Errorf("table %s already has column _cq_id", table.Name) - } - cqID := schema.CqIDColumn - if len(table.PrimaryKeys()) == 0 { - cqID.PrimaryKey = true - } - cqSourceName := schema.CqSourceNameColumn - cqSyncTime := schema.CqSyncTimeColumn - cqSourceName.Resolver = func(_ context.Context, _ schema.ClientMeta, resource 
*schema.Resource, c schema.Column) error { - return resource.Set(c.Name, p.spec.Name) - } - cqSyncTime.Resolver = func(_ context.Context, _ schema.ClientMeta, resource *schema.Resource, c schema.Column) error { - return resource.Set(c.Name, p.syncTime) - } - - table.Columns = append([]schema.Column{cqSourceName, cqSyncTime, cqID, schema.CqParentIDColumn}, table.Columns...) - if err := p.addInternalColumns(table.Relations); err != nil { - return err - } - } - return nil -} - -// Set parent links on relational tables -func setParents(tables schema.Tables, parent *schema.Table) { - for _, table := range tables { - table.Parent = parent - setParents(table.Relations, table) - } -} - -// Apply transformations to tables -func transformTables(tables schema.Tables) error { - for _, table := range tables { - if table.Transform != nil { - if err := table.Transform(table); err != nil { - return fmt.Errorf("failed to transform table %s: %w", table.Name, err) - } - } - if err := transformTables(table.Relations); err != nil { - return err - } - } - return nil -} - -func maxDepth(tables schema.Tables) uint64 { - var depth uint64 - if len(tables) == 0 { - return 0 - } - for _, table := range tables { - newDepth := 1 + maxDepth(table.Relations) - if newDepth > depth { - depth = newDepth - } - } - return depth -} - -// NewPlugin returns a new plugin with a given name, version, tables, newExecutionClient -// and additional options. -func NewPlugin(name string, version string, tables []*schema.Table, newExecutionClient NewExecutionClientFunc, options ...Option) *Plugin { - p := Plugin{ - name: name, - version: version, - tables: tables, - newExecutionClient: newExecutionClient, - metrics: &Metrics{TableClient: make(map[string]map[string]*TableClientMetrics)}, - caser: caser.New(), - titleTransformer: DefaultTitleTransformer, - internalColumns: true, - } - for _, opt := range options { - opt(&p) - } - setParents(p.tables, nil) - if err := transformTables(p.tables); err != nil { - panic(err) - } - if p.internalColumns { - if err := p.addInternalColumns(p.tables); err != nil { - panic(err) - } - } - if err := p.validate(); err != nil { - panic(err) - } - p.maxDepth = maxDepth(p.tables) - if p.maxDepth > maxAllowedDepth { - panic(fmt.Errorf("max depth of tables is %d, max allowed is %d", p.maxDepth, maxAllowedDepth)) - } - return &p -} - -func (p *Plugin) SetLogger(logger zerolog.Logger) { - p.logger = logger.With().Str("module", p.name+"-src").Logger() -} - -// Tables returns all tables supported by this source plugin -func (p *Plugin) Tables() schema.Tables { - return p.tables -} - -func (p *Plugin) HasDynamicTables() bool { - return p.getDynamicTables != nil -} - -func (p *Plugin) GetDynamicTables() schema.Tables { - return p.sessionTables -} - -// TablesForSpec returns all tables supported by this source plugin that match the given spec. -// It validates the tables part of the spec and will return an error if it is found to be invalid. 
-// This is deprecated method -func (p *Plugin) TablesForSpec(spec specs.Source) (schema.Tables, error) { - spec.SetDefaults() - if err := spec.Validate(); err != nil { - return nil, fmt.Errorf("invalid spec: %w", err) - } - tables, err := p.tables.FilterDfs(spec.Tables, spec.SkipTables, spec.SkipDependentTables) - if err != nil { - return nil, fmt.Errorf("failed to filter tables: %w", err) - } - return tables, nil -} - -// Name return the name of this plugin -func (p *Plugin) Name() string { - return p.name -} - -// Version returns the version of this plugin -func (p *Plugin) Version() string { - return p.version -} - -func (p *Plugin) Metrics() *Metrics { - return p.metrics -} - -func (p *Plugin) Init(ctx context.Context, spec specs.Source) error { - if !p.mu.TryLock() { - return fmt.Errorf("plugin already in use") - } - defer p.mu.Unlock() - - var err error - spec.SetDefaults() - if err := spec.Validate(); err != nil { - return fmt.Errorf("invalid spec: %w", err) - } - p.spec = spec - - switch spec.Backend { - case specs.BackendNone: - p.backend = nop.New() - case specs.BackendLocal: - p.backend, err = local.New(spec) - if err != nil { - return fmt.Errorf("failed to initialize local backend: %w", err) - } - default: - return fmt.Errorf("unknown backend: %s", spec.Backend) - } - - tables := p.tables - if p.getDynamicTables != nil { - p.client, err = p.newExecutionClient(ctx, p.logger, spec, Options{Backend: p.backend}) - if err != nil { - return fmt.Errorf("failed to create execution client for source plugin %s: %w", p.name, err) - } - tables, err = p.getDynamicTables(ctx, p.client) - if err != nil { - return fmt.Errorf("failed to get dynamic tables: %w", err) - } - - tables, err = tables.FilterDfs(spec.Tables, spec.SkipTables, spec.SkipDependentTables) - if err != nil { - return fmt.Errorf("failed to filter tables: %w", err) - } - if len(tables) == 0 { - return fmt.Errorf("no tables to sync - please check your spec 'tables' and 'skip_tables' settings") - } - - setParents(tables, nil) - if err := transformTables(tables); err != nil { - return err - } - if p.internalColumns { - if err := p.addInternalColumns(tables); err != nil { - return err - } - } - if err := p.validate(); err != nil { - return err - } - p.maxDepth = maxDepth(tables) - if p.maxDepth > maxAllowedDepth { - return fmt.Errorf("max depth of tables is %d, max allowed is %d", p.maxDepth, maxAllowedDepth) - } - } else { - tables, err = tables.FilterDfs(spec.Tables, spec.SkipTables, spec.SkipDependentTables) - if err != nil { - return fmt.Errorf("failed to filter tables: %w", err) - } - } - - p.sessionTables = tables - return nil -} - -// Sync is syncing data from the requested tables in spec to the given channel -func (p *Plugin) Sync(ctx context.Context, syncTime time.Time, res chan<- *schema.Resource) error { - if !p.mu.TryLock() { - return fmt.Errorf("plugin already in use") - } - defer p.mu.Unlock() - p.syncTime = syncTime - if p.client == nil { - var err error - p.client, err = p.newExecutionClient(ctx, p.logger, p.spec, Options{Backend: p.backend}) - if err != nil { - return fmt.Errorf("failed to create execution client for source plugin %s: %w", p.name, err) - } - } - - startTime := time.Now() - if p.unmanaged { - unmanagedClient := p.client.(UnmanagedClient) - if err := unmanagedClient.Sync(ctx, p.metrics, res); err != nil { - return fmt.Errorf("failed to sync unmanaged client: %w", err) - } - } else { - switch p.spec.Scheduler { - case specs.SchedulerDFS: - p.syncDfs(ctx, p.spec, p.client, p.sessionTables, res) - case 
specs.SchedulerRoundRobin: - p.syncRoundRobin(ctx, p.spec, p.client, p.sessionTables, res) - default: - return fmt.Errorf("unknown scheduler %s. Options are: %v", p.spec.Scheduler, specs.AllSchedulers.String()) - } - } - - p.logger.Info().Uint64("resources", p.metrics.TotalResources()).Uint64("errors", p.metrics.TotalErrors()).Uint64("panics", p.metrics.TotalPanics()).TimeDiff("duration", time.Now(), startTime).Msg("sync finished") - return nil -} - -func (p *Plugin) Close(ctx context.Context) error { - if !p.mu.TryLock() { - return fmt.Errorf("plugin already in use") - } - defer p.mu.Unlock() - if p.backend != nil { - err := p.backend.Close(ctx) - if err != nil { - return fmt.Errorf("failed to close backend: %w", err) - } - p.backend = nil - } - return nil -} diff --git a/plugins/source/plugin_test.go b/plugins/source/plugin_test.go deleted file mode 100644 index 08b38da24d..0000000000 --- a/plugins/source/plugin_test.go +++ /dev/null @@ -1,470 +0,0 @@ -package source - -import ( - "context" - "testing" - "time" - - "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/scalar" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/transformers" - "github.com/google/go-cmp/cmp" - "github.com/google/uuid" - "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" -) - -type testExecutionClient struct{} - -var _ schema.ClientMeta = &testExecutionClient{} - -var deterministicStableUUID = uuid.MustParse("c25355aab52c5b70a4e0c9991f5a3b87") -var randomStableUUID = uuid.MustParse("00000000000040008000000000000000") - -var testSyncTime = time.Now() - -func testResolverSuccess(_ context.Context, _ schema.ClientMeta, _ *schema.Resource, res chan<- any) error { - res <- map[string]any{ - "TestColumn": 3, - } - return nil -} - -func testResolverPanic(context.Context, schema.ClientMeta, *schema.Resource, chan<- any) error { - panic("Resolver") -} - -func testPreResourceResolverPanic(context.Context, schema.ClientMeta, *schema.Resource) error { - panic("PreResourceResolver") -} - -func testColumnResolverPanic(context.Context, schema.ClientMeta, *schema.Resource, schema.Column) error { - panic("ColumnResolver") -} - -func testTableSuccess() *schema.Table { - return &schema.Table{ - Name: "test_table_success", - Resolver: testResolverSuccess, - Columns: []schema.Column{ - { - Name: "test_column", - Type: arrow.PrimitiveTypes.Int64, - }, - }, - } -} - -func testTableSuccessWithPK() *schema.Table { - return &schema.Table{ - Name: "test_table_success", - Resolver: testResolverSuccess, - Columns: []schema.Column{ - { - Name: "test_column", - Type: arrow.PrimitiveTypes.Int64, - PrimaryKey: true, - }, - }, - } -} - -func testTableResolverPanic() *schema.Table { - return &schema.Table{ - Name: "test_table_resolver_panic", - Resolver: testResolverPanic, - Columns: []schema.Column{ - { - Name: "test_column", - Type: arrow.PrimitiveTypes.Int64, - }, - }, - } -} - -func testTablePreResourceResolverPanic() *schema.Table { - return &schema.Table{ - Name: "test_table_pre_resource_resolver_panic", - PreResourceResolver: testPreResourceResolverPanic, - Resolver: testResolverSuccess, - Columns: []schema.Column{ - { - Name: "test_column", - Type: arrow.PrimitiveTypes.Int64, - }, - }, - } -} - -func testTableColumnResolverPanic() *schema.Table { - return &schema.Table{ - Name: "test_table_column_resolver_panic", - Resolver: testResolverSuccess, - Columns: []schema.Column{ - { - Name: 
"test_column", - Type: arrow.PrimitiveTypes.Int64, - }, - { - Name: "test_column1", - Type: arrow.PrimitiveTypes.Int64, - Resolver: testColumnResolverPanic, - }, - }, - } -} - -func testTableRelationSuccess() *schema.Table { - return &schema.Table{ - Name: "test_table_relation_success", - Resolver: testResolverSuccess, - Columns: []schema.Column{ - { - Name: "test_column", - Type: arrow.PrimitiveTypes.Int64, - }, - }, - Relations: []*schema.Table{ - testTableSuccess(), - }, - } -} - -func (*testExecutionClient) ID() string { - return "testExecutionClient" -} - -func newTestExecutionClient(context.Context, zerolog.Logger, specs.Source, Options) (schema.ClientMeta, error) { - return &testExecutionClient{}, nil -} - -type syncTestCase struct { - table *schema.Table - stats Metrics - data []scalar.Vector - deterministicCQID bool -} - -var syncTestCases = []syncTestCase{ - { - table: testTableSuccess(), - stats: Metrics{ - TableClient: map[string]map[string]*TableClientMetrics{ - "test_table_success": { - "testExecutionClient": { - Resources: 1, - }, - }, - }, - }, - data: []scalar.Vector{ - { - &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{Value: testSyncTime, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - &scalar.UUID{}, - &scalar.Int{Value: 3, Valid: true}, - }, - }, - }, - { - table: testTableResolverPanic(), - stats: Metrics{ - TableClient: map[string]map[string]*TableClientMetrics{ - "test_table_resolver_panic": { - "testExecutionClient": { - Panics: 1, - }, - }, - }, - }, - data: nil, - }, - { - table: testTablePreResourceResolverPanic(), - stats: Metrics{ - TableClient: map[string]map[string]*TableClientMetrics{ - "test_table_pre_resource_resolver_panic": { - "testExecutionClient": { - Panics: 1, - }, - }, - }, - }, - data: nil, - }, - - { - table: testTableRelationSuccess(), - stats: Metrics{ - TableClient: map[string]map[string]*TableClientMetrics{ - "test_table_relation_success": { - "testExecutionClient": { - Resources: 1, - }, - }, - "test_table_success": { - "testExecutionClient": { - Resources: 1, - }, - }, - }, - }, - data: []scalar.Vector{ - { - &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{Value: testSyncTime, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - &scalar.UUID{}, - &scalar.Int{Value: 3, Valid: true}, - }, - { - &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{Value: testSyncTime, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - &scalar.Int{Value: 3, Valid: true}, - }, - }, - }, - { - table: testTableSuccess(), - stats: Metrics{ - TableClient: map[string]map[string]*TableClientMetrics{ - "test_table_success": { - "testExecutionClient": { - Resources: 1, - }, - }, - }, - }, - data: []scalar.Vector{ - { - &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{Value: testSyncTime, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - &scalar.UUID{}, - &scalar.Int{Value: 3, Valid: true}, - }, - }, - deterministicCQID: true, - }, - { - table: testTableColumnResolverPanic(), - stats: Metrics{ - TableClient: map[string]map[string]*TableClientMetrics{ - "test_table_column_resolver_panic": { - "testExecutionClient": { - Panics: 1, - Resources: 1, - }, - }, - }, - }, - data: []scalar.Vector{ - { - &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{Value: testSyncTime, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - 
&scalar.UUID{}, - &scalar.Int{Value: 3, Valid: true}, - &scalar.Int{}, - }, - }, - deterministicCQID: true, - }, - { - table: testTableRelationSuccess(), - stats: Metrics{ - TableClient: map[string]map[string]*TableClientMetrics{ - "test_table_relation_success": { - "testExecutionClient": { - Resources: 1, - }, - }, - "test_table_success": { - "testExecutionClient": { - Resources: 1, - }, - }, - }, - }, - data: []scalar.Vector{ - { - &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{Value: testSyncTime, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - &scalar.UUID{}, - &scalar.Int{Value: 3, Valid: true}, - }, - { - &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{Value: testSyncTime, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - &scalar.UUID{Value: randomStableUUID, Valid: true}, - &scalar.Int{Value: 3, Valid: true}, - }, - }, - deterministicCQID: true, - }, - { - table: testTableSuccessWithPK(), - stats: Metrics{ - TableClient: map[string]map[string]*TableClientMetrics{ - "test_table_success": { - "testExecutionClient": { - Resources: 1, - }, - }, - }, - }, - data: []scalar.Vector{ - { - &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{Value: testSyncTime, Valid: true}, - &scalar.UUID{Value: deterministicStableUUID, Valid: true}, - &scalar.UUID{}, - &scalar.Int{Value: 3, Valid: true}, - }, - }, - deterministicCQID: true, - }, -} - -type testRand struct{} - -func (testRand) Read(p []byte) (n int, err error) { - for i := range p { - p[i] = byte(0) - } - return len(p), nil -} - -func TestSync(t *testing.T) { - uuid.SetRand(testRand{}) - for _, scheduler := range specs.AllSchedulers { - for _, tc := range syncTestCases { - tc := tc - tc.table = tc.table.Copy(nil) - t.Run(tc.table.Name+"_"+scheduler.String(), func(t *testing.T) { - testSyncTable(t, tc, scheduler, tc.deterministicCQID) - }) - } - } -} - -func testSyncTable(t *testing.T, tc syncTestCase, scheduler specs.Scheduler, deterministicCQID bool) { - ctx := context.Background() - tables := []*schema.Table{ - tc.table, - } - - plugin := NewPlugin( - "testSourcePlugin", - "1.0.0", - tables, - newTestExecutionClient, - ) - plugin.SetLogger(zerolog.New(zerolog.NewTestWriter(t))) - spec := specs.Source{ - Name: "testSource", - Path: "cloudquery/testSource", - Tables: []string{"*"}, - Version: "v1.0.0", - Destinations: []string{"test"}, - Concurrency: 1, // choose a very low value to check that we don't run into deadlocks - Scheduler: scheduler, - DeterministicCQID: deterministicCQID, - } - if err := plugin.Init(ctx, spec); err != nil { - t.Fatal(err) - } - - resources := make(chan *schema.Resource) - g, ctx := errgroup.WithContext(ctx) - g.Go(func() error { - defer close(resources) - return plugin.Sync(ctx, - testSyncTime, - resources) - }) - - var i int - for resource := range resources { - if tc.data == nil { - t.Fatalf("Unexpected resource %v", resource) - } - if i >= len(tc.data) { - t.Fatalf("expected %d resources. got %d", len(tc.data), i) - } - if !resource.GetValues().Equal(tc.data[i]) { - t.Fatalf("expected at i=%d: %v. got %v", i, tc.data[i], resource.GetValues()) - } - i++ - } - if len(tc.data) != i { - t.Fatalf("expected %d resources. 
got %d", len(tc.data), i) - } - - stats := plugin.Metrics() - if !tc.stats.Equal(stats) { - t.Fatalf("unexpected stats: %v", cmp.Diff(tc.stats, stats)) - } - if err := g.Wait(); err != nil { - t.Fatal(err) - } -} - -func TestIgnoredColumns(t *testing.T) { - validateResources(t, schema.Resources{{ - Item: struct{ A *string }{}, - Table: &schema.Table{ - Columns: schema.ColumnList{ - { - Name: "a", - Type: arrow.BinaryTypes.String, - IgnoreInTests: true, - }, - }, - }, - }}) -} - -var testTable struct { - PrimaryKey string - SecondaryKey string - TertiaryKey string - Quaternary string -} - -func TestNewPluginPrimaryKeys(t *testing.T) { - testTransforms := []struct { - transformerOptions []transformers.StructTransformerOption - resultKeys []string - }{ - { - transformerOptions: []transformers.StructTransformerOption{transformers.WithPrimaryKeys("PrimaryKey")}, - resultKeys: []string{"primary_key"}, - }, - { - transformerOptions: []transformers.StructTransformerOption{}, - resultKeys: []string{"_cq_id"}, - }, - } - for _, tc := range testTransforms { - tables := []*schema.Table{ - { - Name: "test_table", - Transform: transformers.TransformWithStruct( - &testTable, tc.transformerOptions..., - ), - }, - } - - plugin := NewPlugin("testSourcePlugin", "1.0.0", tables, newTestExecutionClient) - assert.Equal(t, tc.resultKeys, plugin.tables[0].PrimaryKeys()) - } -} diff --git a/plugins/source/scheduler.go b/plugins/source/scheduler.go deleted file mode 100644 index 1967f3cc1a..0000000000 --- a/plugins/source/scheduler.go +++ /dev/null @@ -1,177 +0,0 @@ -package source - -import ( - "context" - "errors" - "fmt" - "runtime/debug" - "sync" - "sync/atomic" - "time" - - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/getsentry/sentry-go" - "github.com/rs/zerolog" - "github.com/thoas/go-funk" -) - -const ( - minTableConcurrency = 1 - minResourceConcurrency = 100 -) - -const periodicMetricLoggerInterval = 30 * time.Second -const periodicMetricLoggerLogTablesLimit = 30 // The max number of in_progress_tables to log in the periodic metric logger - -func (p *Plugin) logTablesMetrics(tables schema.Tables, client schema.ClientMeta) { - clientName := client.ID() - for _, table := range tables { - metrics := p.metrics.TableClient[table.Name][clientName] - p.logger.Info().Str("table", table.Name).Str("client", clientName).Uint64("resources", metrics.Resources).Uint64("errors", metrics.Errors).Msg("table sync finished") - p.logTablesMetrics(table.Relations, client) - } -} - -func (p *Plugin) resolveResource(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, item any) *schema.Resource { - var validationErr *schema.ValidationError - ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) - defer cancel() - resource := schema.NewResourceData(table, parent, item) - objectStartTime := time.Now() - clientID := client.ID() - tableMetrics := p.metrics.TableClient[table.Name][clientID] - logger := p.logger.With().Str("table", table.Name).Str("client", clientID).Logger() - defer func() { - if err := recover(); err != nil { - stack := fmt.Sprintf("%s\n%s", err, string(debug.Stack())) - logger.Error().Interface("error", err).TimeDiff("duration", time.Now(), objectStartTime).Str("stack", stack).Msg("resource resolver finished with panic") - atomic.AddUint64(&tableMetrics.Panics, 1) - sentry.WithScope(func(scope *sentry.Scope) { - scope.SetTag("table", table.Name) - sentry.CurrentHub().CaptureMessage(stack) - }) - } - }() - if table.PreResourceResolver != nil { - if err := 
table.PreResourceResolver(ctx, client, resource); err != nil { - logger.Error().Err(err).Msg("pre resource resolver failed") - atomic.AddUint64(&tableMetrics.Errors, 1) - if errors.As(err, &validationErr) { - sentry.WithScope(func(scope *sentry.Scope) { - scope.SetTag("table", table.Name) - sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) - }) - } - return nil - } - } - - for _, c := range table.Columns { - p.resolveColumn(ctx, logger, tableMetrics, client, resource, c) - } - - if table.PostResourceResolver != nil { - if err := table.PostResourceResolver(ctx, client, resource); err != nil { - logger.Error().Stack().Err(err).Msg("post resource resolver finished with error") - atomic.AddUint64(&tableMetrics.Errors, 1) - if errors.As(err, &validationErr) { - sentry.WithScope(func(scope *sentry.Scope) { - scope.SetTag("table", table.Name) - sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) - }) - } - } - } - atomic.AddUint64(&tableMetrics.Resources, 1) - return resource -} - -func (p *Plugin) resolveColumn(ctx context.Context, logger zerolog.Logger, tableMetrics *TableClientMetrics, client schema.ClientMeta, resource *schema.Resource, c schema.Column) { - var validationErr *schema.ValidationError - columnStartTime := time.Now() - defer func() { - if err := recover(); err != nil { - stack := fmt.Sprintf("%s\n%s", err, string(debug.Stack())) - logger.Error().Str("column", c.Name).Interface("error", err).TimeDiff("duration", time.Now(), columnStartTime).Str("stack", stack).Msg("column resolver finished with panic") - atomic.AddUint64(&tableMetrics.Panics, 1) - sentry.WithScope(func(scope *sentry.Scope) { - scope.SetTag("table", resource.Table.Name) - scope.SetTag("column", c.Name) - sentry.CurrentHub().CaptureMessage(stack) - }) - } - }() - - if c.Resolver != nil { - if err := c.Resolver(ctx, client, resource, c); err != nil { - logger.Error().Err(err).Msg("column resolver finished with error") - atomic.AddUint64(&tableMetrics.Errors, 1) - if errors.As(err, &validationErr) { - sentry.WithScope(func(scope *sentry.Scope) { - scope.SetTag("table", resource.Table.Name) - scope.SetTag("column", c.Name) - sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) - }) - } - } - } else { - // base use case: try to get column with CamelCase name - v := funk.Get(resource.GetItem(), p.caser.ToPascal(c.Name), funk.WithAllowZero()) - if v != nil { - err := resource.Set(c.Name, v) - if err != nil { - logger.Error().Err(err).Msg("column resolver finished with error") - atomic.AddUint64(&tableMetrics.Errors, 1) - if errors.As(err, &validationErr) { - sentry.WithScope(func(scope *sentry.Scope) { - scope.SetTag("table", resource.Table.Name) - scope.SetTag("column", c.Name) - sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) - }) - } - } - } - } -} - -func (p *Plugin) periodicMetricLogger(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() - - ticker := time.NewTicker(periodicMetricLoggerInterval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - inProgressTables := p.metrics.InProgressTables() - queuedTables := p.metrics.QueuedTables() - logLine := p.logger.Info(). - Uint64("total_resources", p.metrics.TotalResourcesAtomic()). - Uint64("total_errors", p.metrics.TotalErrorsAtomic()). - Uint64("total_panics", p.metrics.TotalPanicsAtomic()). - Int("num_in_progress_tables", len(inProgressTables)). 
- Int("num_queued_tables", len(queuedTables)) - - if len(inProgressTables) <= periodicMetricLoggerLogTablesLimit { - logLine.Strs("in_progress_tables", inProgressTables) - } - - if len(queuedTables) <= periodicMetricLoggerLogTablesLimit { - logLine.Strs("queued_tables", queuedTables) - } - - logLine.Msg("Sync in progress") - } - } -} - -// unparam's suggestion to remove the second parameter is not good advice here. -// nolint:unparam -func max(a, b uint64) uint64 { - if a > b { - return a - } - return b -} diff --git a/plugins/source/testing.go b/plugins/source/testing.go deleted file mode 100644 index 161778bca9..0000000000 --- a/plugins/source/testing.go +++ /dev/null @@ -1,141 +0,0 @@ -package source - -import ( - "context" - "testing" - "time" - - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" -) - -type Validator func(t *testing.T, plugin *Plugin, resources []*schema.Resource) - -func TestPluginSync(t *testing.T, plugin *Plugin, spec specs.Source, opts ...TestPluginOption) { - t.Helper() - - o := &testPluginOptions{ - parallel: true, - validators: []Validator{validatePlugin}, - } - for _, opt := range opts { - opt(o) - } - if o.parallel { - t.Parallel() - } - - resourcesChannel := make(chan *schema.Resource) - var syncErr error - - if err := plugin.Init(context.Background(), spec); err != nil { - t.Fatal(err) - } - - go func() { - defer close(resourcesChannel) - syncErr = plugin.Sync(context.Background(), time.Now(), resourcesChannel) - }() - - syncedResources := make([]*schema.Resource, 0) - for resource := range resourcesChannel { - syncedResources = append(syncedResources, resource) - } - if syncErr != nil { - t.Fatal(syncErr) - } - for _, validator := range o.validators { - validator(t, plugin, syncedResources) - } -} - -type TestPluginOption func(*testPluginOptions) - -func WithTestPluginNoParallel() TestPluginOption { - return func(f *testPluginOptions) { - f.parallel = false - } -} - -func WithTestPluginAdditionalValidators(v Validator) TestPluginOption { - return func(f *testPluginOptions) { - f.validators = append(f.validators, v) - } -} - -type testPluginOptions struct { - parallel bool - validators []Validator -} - -func getTableResources(t *testing.T, table *schema.Table, resources []*schema.Resource) []*schema.Resource { - t.Helper() - - tableResources := make([]*schema.Resource, 0) - - for _, resource := range resources { - if resource.Table.Name == table.Name { - tableResources = append(tableResources, resource) - } - } - - return tableResources -} - -func validateTable(t *testing.T, table *schema.Table, resources []*schema.Resource) { - t.Helper() - tableResources := getTableResources(t, table, resources) - if len(tableResources) == 0 { - t.Errorf("Expected table %s to be synced but it was not found", table.Name) - return - } - validateResources(t, tableResources) -} - -func validatePlugin(t *testing.T, plugin *Plugin, resources []*schema.Resource) { - t.Helper() - tables := extractTables(plugin.tables) - for _, table := range tables { - validateTable(t, table, resources) - } -} - -func extractTables(tables schema.Tables) []*schema.Table { - result := make([]*schema.Table, 0) - for _, table := range tables { - result = append(result, table) - result = append(result, extractTables(table.Relations)...) - } - return result -} - -// Validates that every column has at least one non-nil value. -// Also does some additional validations. 
-func validateResources(t *testing.T, resources []*schema.Resource) { - t.Helper() - - table := resources[0].Table - - // A set of column-names that have values in at least one of the resources. - columnsWithValues := make([]bool, len(table.Columns)) - - for _, resource := range resources { - for i, value := range resource.GetValues() { - if value == nil { - continue - } - if value.IsValid() { - columnsWithValues[i] = true - } - } - } - - // Make sure every column has at least one value. - for i, hasValue := range columnsWithValues { - col := table.Columns[i] - emptyExpected := col.Name == "_cq_parent_id" && table.Parent == nil - if !hasValue && !emptyExpected && !col.IgnoreInTests { - t.Errorf("table: %s column %s has no values", table.Name, table.Columns[i].Name) - } - } -} diff --git a/plugins/source/validate.go b/plugins/source/validate.go deleted file mode 100644 index 835b798c7e..0000000000 --- a/plugins/source/validate.go +++ /dev/null @@ -1,25 +0,0 @@ -package source - -import ( - "fmt" -) - -func (p *Plugin) validate() error { - if err := p.tables.ValidateDuplicateColumns(); err != nil { - return fmt.Errorf("found duplicate columns in source plugin: %s: %w", p.name, err) - } - - if err := p.tables.ValidateDuplicateTables(); err != nil { - return fmt.Errorf("found duplicate tables in source plugin: %s: %w", p.name, err) - } - - if err := p.tables.ValidateTableNames(); err != nil { - return fmt.Errorf("found table with invalid name in source plugin: %s: %w", p.name, err) - } - - if err := p.tables.ValidateColumnNames(); err != nil { - return fmt.Errorf("found column with invalid name in source plugin: %s: %w", p.name, err) - } - - return nil -} diff --git a/scalar/inet.go b/scalar/inet.go index f693a479e0..3d6163cfc7 100644 --- a/scalar/inet.go +++ b/scalar/inet.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/cloudquery/plugin-sdk/v4/types" ) type Inet struct { diff --git a/scalar/json.go b/scalar/json.go index ed6761351b..c0c5fceea3 100644 --- a/scalar/json.go +++ b/scalar/json.go @@ -6,7 +6,7 @@ import ( "reflect" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/cloudquery/plugin-sdk/v4/types" ) type JSON struct { diff --git a/scalar/mac.go b/scalar/mac.go index cef4ac27f6..5350a64bee 100644 --- a/scalar/mac.go +++ b/scalar/mac.go @@ -4,7 +4,7 @@ import ( "net" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/cloudquery/plugin-sdk/v4/types" ) type Mac struct { diff --git a/scalar/scalar.go b/scalar/scalar.go index 5f471e0258..7236cd7109 100644 --- a/scalar/scalar.go +++ b/scalar/scalar.go @@ -6,7 +6,8 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/float16" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/apache/arrow/go/v13/arrow/memory" + "github.com/cloudquery/plugin-sdk/v4/types" "golang.org/x/exp/maps" ) @@ -33,7 +34,12 @@ type Scalar interface { type Vector []Scalar -const nullValueStr = array.NullValueStr +func (v Vector) ToArrowRecord(sc *arrow.Schema) arrow.Record { + bldr := array.NewRecordBuilder(memory.DefaultAllocator, sc) + AppendToRecordBuilder(bldr, v) + rec := bldr.NewRecord() + return rec +} func (v Vector) Equal(r Vector) bool { if len(v) != len(r) { diff --git a/scalar/string.go b/scalar/string.go index 0d191d844e..7997aded97 100644 --- a/scalar/string.go +++ b/scalar/string.go @@ 
-4,8 +4,11 @@ import ( "fmt" "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" ) +const nullValueStr = array.NullValueStr + type String struct { Valid bool Value string diff --git a/scalar/uuid.go b/scalar/uuid.go index f8a79c94b0..dfae523cbd 100644 --- a/scalar/uuid.go +++ b/scalar/uuid.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/cloudquery/plugin-sdk/v4/types" "github.com/google/uuid" ) diff --git a/scheduler/benchmark_test.go b/scheduler/benchmark_test.go new file mode 100644 index 0000000000..6990da0fd7 --- /dev/null +++ b/scheduler/benchmark_test.go @@ -0,0 +1 @@ +package scheduler diff --git a/plugins/source/benchmark_test.go b/scheduler/benchmark_test.go.backup similarity index 99% rename from plugins/source/benchmark_test.go rename to scheduler/benchmark_test.go.backup index 71ccdc929d..a1bf87d5a8 100644 --- a/plugins/source/benchmark_test.go +++ b/scheduler/benchmark_test.go.backup @@ -1,4 +1,4 @@ -package source +package plugin import ( "context" @@ -11,7 +11,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/schema" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" ) @@ -48,7 +48,7 @@ func (s *BenchmarkScenario) SetDefaults() { } } -type Client interface { +type ClientTest interface { Call(clientID, tableName string) error } diff --git a/plugins/source/metrics.go b/scheduler/metrics.go similarity index 52% rename from plugins/source/metrics.go rename to scheduler/metrics.go index 9975933779..f5b6c73ef6 100644 --- a/plugins/source/metrics.go +++ b/scheduler/metrics.go @@ -1,28 +1,23 @@ -package source +package scheduler import ( - "sync" "sync/atomic" "time" - "github.com/cloudquery/plugin-sdk/v3/schema" - "golang.org/x/exp/slices" + "github.com/cloudquery/plugin-sdk/v4/schema" ) +// Metrics is deprecated as we move toward open telemetry for tracing and metrics type Metrics struct { TableClient map[string]map[string]*TableClientMetrics } type TableClientMetrics struct { - // These should only be accessed with 'Atomic*' methods. Resources uint64 Errors uint64 Panics uint64 - - // These accesses must be protected by the mutex. - startTime time.Time - endTime time.Time - mutex sync.Mutex + StartTime time.Time + EndTime time.Time } func (s *TableClientMetrics) Equal(other *TableClientMetrics) bool { @@ -129,79 +124,3 @@ func (s *Metrics) TotalResourcesAtomic() uint64 { } return total } - -func (s *Metrics) MarkStart(table *schema.Table, clientID string) { - now := time.Now() - - s.TableClient[table.Name][clientID].mutex.Lock() - defer s.TableClient[table.Name][clientID].mutex.Unlock() - s.TableClient[table.Name][clientID].startTime = now -} - -// if the table is a top-level table, we need to mark all of its descendents as 'done' as well. -// This is because, when a top-level table is empty (no resources), its descendants are never actually -// synced. 
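A minimal sketch of the observable effect, assuming metrics, parentTable, childTable and clientID are initialized as in TestMarkEndTopLevelTableRecursive above:

metrics.MarkStart(parentTable, clientID)
metrics.MarkStart(childTable, clientID)
// parentTable.Parent == nil, so MarkEnd walks every relation recursively.
metrics.MarkEnd(parentTable, clientID)
// metrics.InProgressTables() is now empty; ending only childTable would
// have left parentTable in progress.
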
-func (s *Metrics) MarkEnd(table *schema.Table, clientID string) {
-	now := time.Now()
-
-	if table.Parent == nil {
-		s.markEndRecursive(table, clientID, now)
-		return
-	}
-
-	s.TableClient[table.Name][clientID].mutex.Lock()
-	defer s.TableClient[table.Name][clientID].mutex.Unlock()
-	s.TableClient[table.Name][clientID].endTime = now
-}
-
-func (s *Metrics) markEndRecursive(table *schema.Table, clientID string, now time.Time) {
-	// We don't use defer with Unlock(), because we want to unlock the mutex as soon as possible.
-	s.TableClient[table.Name][clientID].mutex.Lock()
-	s.TableClient[table.Name][clientID].endTime = now
-	s.TableClient[table.Name][clientID].mutex.Unlock()
-
-	for _, relation := range table.Relations {
-		s.markEndRecursive(relation, clientID, now)
-	}
-}
-
-func (s *Metrics) InProgressTables() []string {
-	var inProgressTables []string
-
-	for table, tableMetrics := range s.TableClient {
-		for _, clientMetrics := range tableMetrics {
-			clientMetrics.mutex.Lock()
-			endTime := clientMetrics.endTime
-			startTime := clientMetrics.startTime
-			clientMetrics.mutex.Unlock()
-			if endTime.IsZero() && !startTime.IsZero() {
-				inProgressTables = append(inProgressTables, table)
-				break
-			}
-		}
-	}
-
-	slices.Sort(inProgressTables)
-
-	return inProgressTables
-}
-
-func (s *Metrics) QueuedTables() []string {
-	var queuedTables []string
-
-	for table, tableMetrics := range s.TableClient {
-		for _, clientMetrics := range tableMetrics {
-			clientMetrics.mutex.Lock()
-			startTime := clientMetrics.startTime
-			endTime := clientMetrics.endTime
-			clientMetrics.mutex.Unlock()
-			if startTime.IsZero() && endTime.IsZero() {
-				queuedTables = append(queuedTables, table)
-				break
-			}
-		}
-	}
-
-	slices.Sort(queuedTables)
-	return queuedTables
-}
diff --git a/scheduler/metrics_test.go b/scheduler/metrics_test.go
new file mode 100644
index 0000000000..1bc11daa58
--- /dev/null
+++ b/scheduler/metrics_test.go
@@ -0,0 +1,37 @@
+package scheduler
+
+import "testing"
+
+func TestMetrics(t *testing.T) {
+	s := &Metrics{
+		TableClient: make(map[string]map[string]*TableClientMetrics),
+	}
+	s.TableClient["test_table"] = make(map[string]*TableClientMetrics)
+	s.TableClient["test_table"]["testExecutionClient"] = &TableClientMetrics{
+		Resources: 1,
+		Errors:    2,
+		Panics:    3,
+	}
+	if s.TotalResources() != 1 {
+		t.Fatal("expected 1 resource")
+	}
+	if s.TotalErrors() != 2 {
+		t.Fatal("expected 2 errors")
+	}
+	if s.TotalPanics() != 3 {
+		t.Fatal("expected 3 panics")
+	}
+
+	other := &Metrics{
+		TableClient: make(map[string]map[string]*TableClientMetrics),
+	}
+	other.TableClient["test_table"] = make(map[string]*TableClientMetrics)
+	other.TableClient["test_table"]["testExecutionClient"] = &TableClientMetrics{
+		Resources: 1,
+		Errors:    2,
+		Panics:    3,
+	}
+	if !s.Equal(other) {
+		t.Fatal("expected metrics to be equal")
+	}
+}
diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go
new file mode 100644
index 0000000000..89ea17ed14
--- /dev/null
+++ b/scheduler/scheduler.go
@@ -0,0 +1,317 @@
+package scheduler
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"runtime/debug"
+	"sync/atomic"
+	"time"
+
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/cloudquery/plugin-sdk/v4/caser"
+	"github.com/cloudquery/plugin-sdk/v4/message"
+	"github.com/cloudquery/plugin-sdk/v4/scalar"
+	"github.com/cloudquery/plugin-sdk/v4/schema"
+	"github.com/getsentry/sentry-go"
+	"github.com/rs/zerolog"
+	"github.com/thoas/go-funk"
+	"golang.org/x/sync/semaphore"
+)
+
+const (
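+	// concurrency floors used by the sync strategies, plus the fallbacks
+	// NewScheduler applies when WithConcurrency/WithMaxDepth are not set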
+	minTableConcurrency    = 1
+	minResourceConcurrency = 100
+	defaultConcurrency     = 200000
+	defaultMaxDepth        = 4
+)
+
+type Strategy int
+
+const (
+	StrategyDFS Strategy = iota
+	StrategyRoundRobin
+)
+
+var AllSchedulers = Strategies{StrategyDFS, StrategyRoundRobin}
+var AllSchedulerNames = [...]string{
+	StrategyDFS:        "dfs",
+	StrategyRoundRobin: "round-robin",
+}
+
+type Strategies []Strategy
+
+func (s Strategies) String() string {
+	var buffer bytes.Buffer
+	for i, strategy := range s {
+		if i > 0 {
+			buffer.WriteString(", ")
+		}
+		buffer.WriteString(strategy.String())
+	}
+	return buffer.String()
+}
+
+func (s Strategy) String() string {
+	return AllSchedulerNames[s]
+}
+
+type Option func(*Scheduler)
+
+func WithLogger(logger zerolog.Logger) Option {
+	return func(s *Scheduler) {
+		s.logger = logger
+	}
+}
+
+func WithDeterministicCQId(deterministicCQId bool) Option {
+	return func(s *Scheduler) {
+		s.deterministicCQId = deterministicCQId
+	}
+}
+
+func WithConcurrency(concurrency uint64) Option {
+	return func(s *Scheduler) {
+		s.concurrency = concurrency
+	}
+}
+
+func WithMaxDepth(maxDepth uint64) Option {
+	return func(s *Scheduler) {
+		s.maxDepth = maxDepth
+	}
+}
+
+func WithSchedulerStrategy(strategy Strategy) Option {
+	return func(s *Scheduler) {
+		s.strategy = strategy
+	}
+}
+
+type Client interface {
+	ID() string
+}
+
+type Scheduler struct {
+	tables   schema.Tables
+	client   schema.ClientMeta
+	caser    *caser.Caser
+	strategy Strategy
+	// status sync metrics
+	metrics  *Metrics
+	maxDepth uint64
+	// resourceSem is a semaphore that limits the number of concurrent resources being fetched
+	resourceSem *semaphore.Weighted
+	// tableSem is a semaphore that limits the number of concurrent tables being fetched
+	tableSems []*semaphore.Weighted
+	// logger is passed to the serve.Serve client; if it is not set, Serve will create one instead.
+	logger            zerolog.Logger
+	deterministicCQId bool
+	concurrency       uint64
+}
+
+func NewScheduler(client schema.ClientMeta, opts ...Option) *Scheduler {
+	s := Scheduler{
+		client:      client,
+		metrics:     &Metrics{TableClient: make(map[string]map[string]*TableClientMetrics)},
+		caser:       caser.New(),
+		concurrency: defaultConcurrency,
+		maxDepth:    defaultMaxDepth,
+	}
+	for _, opt := range opts {
+		opt(&s)
+	}
+	return &s
+}
+
+// SyncAll is mostly used for testing: it buffers the messages for all tables
+// in memory and can run out of memory in the real world. Use Sync in production.
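+//
+// A minimal test-only sketch (ctx, tables, and t assumed in scope):
+//
+//	msgs, err := s.SyncAll(ctx, tables)
+//	if err != nil {
+//		t.Fatal(err)
+//	}
+//	for _, msg := range msgs {
+//		if ins, ok := msg.(*message.Insert); ok {
+//			_ = ins.Record // inspect synced rows here
+//		}
+//	}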
+func (s *Scheduler) SyncAll(ctx context.Context, tables schema.Tables) (message.Messages, error) { + res := make(chan message.Message) + var err error + go func() { + defer close(res) + err = s.Sync(ctx, tables, res) + }() + // nolint:prealloc + var messages []message.Message + for msg := range res { + messages = append(messages, msg) + } + return messages, err +} + +func (s *Scheduler) Sync(ctx context.Context, tables schema.Tables, res chan<- message.Message) error { + if len(tables) == 0 { + return nil + } + + if maxDepth(tables) > s.maxDepth { + return fmt.Errorf("max depth exceeded, max depth is %d", s.maxDepth) + } + s.tables = tables + + // send migrate messages first + for _, table := range tables.FlattenTables() { + res <- &message.MigrateTable{ + Table: table, + } + } + + resources := make(chan *schema.Resource) + go func() { + defer close(resources) + switch s.strategy { + case StrategyDFS: + s.syncDfs(ctx, resources) + case StrategyRoundRobin: + s.syncRoundRobin(ctx, resources) + default: + panic(fmt.Errorf("unknown scheduler %s", s.strategy)) + } + }() + for resource := range resources { + vector := resource.GetValues() + bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema()) + scalar.AppendToRecordBuilder(bldr, vector) + rec := bldr.NewRecord() + res <- &message.Insert{Record: rec} + } + return nil +} + +func (s *Scheduler) logTablesMetrics(tables schema.Tables, client Client) { + clientName := client.ID() + for _, table := range tables { + metrics := s.metrics.TableClient[table.Name][clientName] + s.logger.Info().Str("table", table.Name).Str("client", clientName).Uint64("resources", metrics.Resources).Uint64("errors", metrics.Errors).Msg("table sync finished") + s.logTablesMetrics(table.Relations, client) + } +} + +func (s *Scheduler) resolveResource(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, item any) *schema.Resource { + var validationErr *schema.ValidationError + ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) + defer cancel() + resource := schema.NewResourceData(table, parent, item) + objectStartTime := time.Now() + clientID := client.ID() + tableMetrics := s.metrics.TableClient[table.Name][clientID] + logger := s.logger.With().Str("table", table.Name).Str("client", clientID).Logger() + defer func() { + if err := recover(); err != nil { + stack := fmt.Sprintf("%s\n%s", err, string(debug.Stack())) + logger.Error().Interface("error", err).TimeDiff("duration", time.Now(), objectStartTime).Str("stack", stack).Msg("resource resolver finished with panic") + atomic.AddUint64(&tableMetrics.Panics, 1) + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(stack) + }) + } + }() + if table.PreResourceResolver != nil { + if err := table.PreResourceResolver(ctx, client, resource); err != nil { + logger.Error().Err(err).Msg("pre resource resolver failed") + atomic.AddUint64(&tableMetrics.Errors, 1) + if errors.As(err, &validationErr) { + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) + }) + } + return nil + } + } + + for _, c := range table.Columns { + s.resolveColumn(ctx, logger, tableMetrics, client, resource, c) + } + + if table.PostResourceResolver != nil { + if err := table.PostResourceResolver(ctx, client, resource); err != nil { + logger.Error().Stack().Err(err).Msg("post resource resolver finished with error") + 
atomic.AddUint64(&tableMetrics.Errors, 1)
+			if errors.As(err, &validationErr) {
+				sentry.WithScope(func(scope *sentry.Scope) {
+					scope.SetTag("table", table.Name)
+					sentry.CurrentHub().CaptureMessage(validationErr.MaskedError())
+				})
+			}
+		}
+	}
+	atomic.AddUint64(&tableMetrics.Resources, 1)
+	return resource
+}
+
+func (s *Scheduler) resolveColumn(ctx context.Context, logger zerolog.Logger, tableMetrics *TableClientMetrics, client schema.ClientMeta, resource *schema.Resource, c schema.Column) {
+	var validationErr *schema.ValidationError
+	columnStartTime := time.Now()
+	defer func() {
+		if err := recover(); err != nil {
+			stack := fmt.Sprintf("%s\n%s", err, string(debug.Stack()))
+			logger.Error().Str("column", c.Name).Interface("error", err).TimeDiff("duration", time.Now(), columnStartTime).Str("stack", stack).Msg("column resolver finished with panic")
+			atomic.AddUint64(&tableMetrics.Panics, 1)
+			sentry.WithScope(func(scope *sentry.Scope) {
+				scope.SetTag("table", resource.Table.Name)
+				scope.SetTag("column", c.Name)
+				sentry.CurrentHub().CaptureMessage(stack)
+			})
+		}
+	}()
+
+	if c.Resolver != nil {
+		if err := c.Resolver(ctx, client, resource, c); err != nil {
+			logger.Error().Err(err).Msg("column resolver finished with error")
+			atomic.AddUint64(&tableMetrics.Errors, 1)
+			if errors.As(err, &validationErr) {
+				sentry.WithScope(func(scope *sentry.Scope) {
+					scope.SetTag("table", resource.Table.Name)
+					scope.SetTag("column", c.Name)
+					sentry.CurrentHub().CaptureMessage(validationErr.MaskedError())
+				})
+			}
+		}
+	} else {
+		// base use case: look the value up on the item by the PascalCase version of the column name
+		v := funk.Get(resource.GetItem(), s.caser.ToPascal(c.Name), funk.WithAllowZero())
+		if v != nil {
+			err := resource.Set(c.Name, v)
+			if err != nil {
+				logger.Error().Err(err).Msg("column resolver finished with error")
+				atomic.AddUint64(&tableMetrics.Errors, 1)
+				if errors.As(err, &validationErr) {
+					sentry.WithScope(func(scope *sentry.Scope) {
+						scope.SetTag("table", resource.Table.Name)
+						scope.SetTag("column", c.Name)
+						sentry.CurrentHub().CaptureMessage(validationErr.MaskedError())
+					})
+				}
+			}
+		}
+	}
+}
+
+func maxDepth(tables schema.Tables) uint64 {
+	var depth uint64
+	if len(tables) == 0 {
+		return 0
+	}
+	for _, table := range tables {
+		newDepth := 1 + maxDepth(table.Relations)
+		if newDepth > depth {
+			depth = newDepth
+		}
+	}
+	return depth
+}
+
+// unparam's suggestion to remove the second parameter is not good advice here.
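+// Every call site in this package currently passes the constant minTableConcurrency
+// as b, which is what trips the linter, but the helper is kept generic.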
+// nolint:unparam +func max(a, b uint64) uint64 { + if a > b { + return a + } + return b +} diff --git a/plugins/source/scheduler_dfs.go b/scheduler/scheduler_dfs.go similarity index 65% rename from plugins/source/scheduler_dfs.go rename to scheduler/scheduler_dfs.go index 1cd5142624..86f2874ec6 100644 --- a/plugins/source/scheduler_dfs.go +++ b/scheduler/scheduler_dfs.go @@ -1,4 +1,4 @@ -package source +package scheduler import ( "context" @@ -8,34 +8,33 @@ import ( "sync" "sync/atomic" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/helpers" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/helpers" + "github.com/cloudquery/plugin-sdk/v4/schema" "github.com/getsentry/sentry-go" "golang.org/x/sync/semaphore" ) -func (p *Plugin) syncDfs(ctx context.Context, spec specs.Source, client schema.ClientMeta, tables schema.Tables, resolvedResources chan<- *schema.Resource) { +func (s *Scheduler) syncDfs(ctx context.Context, resolvedResources chan<- *schema.Resource) { // This is very similar to the concurrent web crawler problem with some minor changes. // We are using DFS to make sure memory usage is capped at O(h) where h is the height of the tree. - tableConcurrency := max(spec.Concurrency/minResourceConcurrency, minTableConcurrency) + tableConcurrency := max(s.concurrency/minResourceConcurrency, minTableConcurrency) resourceConcurrency := tableConcurrency * minResourceConcurrency - p.tableSems = make([]*semaphore.Weighted, p.maxDepth) - for i := uint64(0); i < p.maxDepth; i++ { - p.tableSems[i] = semaphore.NewWeighted(int64(tableConcurrency)) + s.tableSems = make([]*semaphore.Weighted, s.maxDepth) + for i := uint64(0); i < s.maxDepth; i++ { + s.tableSems[i] = semaphore.NewWeighted(int64(tableConcurrency)) // reduce table concurrency logarithmically for every depth level tableConcurrency = max(tableConcurrency/2, minTableConcurrency) } - p.resourceSem = semaphore.NewWeighted(int64(resourceConcurrency)) + s.resourceSem = semaphore.NewWeighted(int64(resourceConcurrency)) // we have this because plugins can return sometimes clients in a random way which will cause // differences between this run and the next one. - preInitialisedClients := make([][]schema.ClientMeta, len(tables)) - for i, table := range tables { - clients := []schema.ClientMeta{client} + preInitialisedClients := make([][]schema.ClientMeta, len(s.tables)) + for i, table := range s.tables { + clients := []schema.ClientMeta{s.client} if table.Multiplex != nil { - clients = table.Multiplex(client) + clients = table.Multiplex(s.client) } // Detect duplicate clients while multiplexing seenClients := make(map[string]bool) @@ -47,69 +46,50 @@ func (p *Plugin) syncDfs(ctx context.Context, spec specs.Source, client schema.C scope.SetTag("table", table.Name) sentry.CurrentHub().CaptureMessage("duplicate client ID in " + table.Name) }) - p.logger.Warn().Str("client", c.ID()).Str("table", table.Name).Msg("multiplex returned duplicate client") + s.logger.Warn().Str("client", c.ID()).Str("table", table.Name).Msg("multiplex returned duplicate client") } } preInitialisedClients[i] = clients // we do this here to avoid locks so we initial the metrics structure once in the main goroutines // and then we can just read from it in the other goroutines concurrently given we are not writing to it. - p.metrics.initWithClients(table, clients) + s.metrics.initWithClients(table, clients) } - // We start a goroutine that logs the metrics periodically. 
- // It needs its own waitgroup - var logWg sync.WaitGroup - logWg.Add(1) - - logCtx, logCancel := context.WithCancel(ctx) - go p.periodicMetricLogger(logCtx, &logWg) - var wg sync.WaitGroup - for i, table := range tables { + for i, table := range s.tables { table := table clients := preInitialisedClients[i] for _, client := range clients { client := client - if err := p.tableSems[0].Acquire(ctx, 1); err != nil { + if err := s.tableSems[0].Acquire(ctx, 1); err != nil { // This means context was cancelled wg.Wait() - // gracefully shut down the logger goroutine - logCancel() - logWg.Wait() return } wg.Add(1) go func() { defer wg.Done() - defer p.tableSems[0].Release(1) + defer s.tableSems[0].Release(1) // not checking for error here as nothing much todo. // the error is logged and this happens when context is cancelled - p.resolveTableDfs(ctx, table, client, nil, resolvedResources, 1) + s.resolveTableDfs(ctx, table, client, nil, resolvedResources, 1) }() } } // Wait for all the worker goroutines to finish wg.Wait() - - // gracefully shut down the logger goroutine - logCancel() - logWg.Wait() } -func (p *Plugin) resolveTableDfs(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, resolvedResources chan<- *schema.Resource, depth int) { - clientName := client.ID() - - p.metrics.MarkStart(table, clientName) - defer p.Metrics().MarkEnd(table, clientName) - +func (s *Scheduler) resolveTableDfs(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, resolvedResources chan<- *schema.Resource, depth int) { var validationErr *schema.ValidationError - logger := p.logger.With().Str("table", table.Name).Str("client", clientName).Logger() + clientName := client.ID() + logger := s.logger.With().Str("table", table.Name).Str("client", clientName).Logger() if parent == nil { // Log only for root tables, otherwise we spam too much. logger.Info().Msg("top level table resolver started") } - tableMetrics := p.metrics.TableClient[table.Name][clientName] + tableMetrics := s.metrics.TableClient[table.Name][clientName] res := make(chan any) go func() { @@ -139,17 +119,17 @@ func (p *Plugin) resolveTableDfs(ctx context.Context, table *schema.Table, clien }() for r := range res { - p.resolveResourcesDfs(ctx, table, client, parent, r, resolvedResources, depth) + s.resolveResourcesDfs(ctx, table, client, parent, r, resolvedResources, depth) } // we don't need any waitgroups here because we are waiting for the channel to close if parent == nil { // Log only for root tables and relations only after resolving is done, otherwise we spam per object instead of per table. 
logger.Info().Uint64("resources", tableMetrics.Resources).Uint64("errors", tableMetrics.Errors).Msg("table sync finished") - p.logTablesMetrics(table.Relations, client) + s.logTablesMetrics(table.Relations, client) } } -func (p *Plugin) resolveResourcesDfs(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, resources any, resolvedResources chan<- *schema.Resource, depth int) { +func (s *Scheduler) resolveResourcesDfs(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, resources any, resolvedResources chan<- *schema.Resource, depth int) { resourcesSlice := helpers.InterfaceSlice(resources) if len(resourcesSlice) == 0 { return @@ -161,25 +141,25 @@ func (p *Plugin) resolveResourcesDfs(ctx context.Context, table *schema.Table, c sentValidationErrors := sync.Map{} for i := range resourcesSlice { i := i - if err := p.resourceSem.Acquire(ctx, 1); err != nil { - p.logger.Warn().Err(err).Msg("failed to acquire semaphore. context cancelled") + if err := s.resourceSem.Acquire(ctx, 1); err != nil { + s.logger.Warn().Err(err).Msg("failed to acquire semaphore. context cancelled") wg.Wait() // we have to continue emptying the channel to exit gracefully return } wg.Add(1) go func() { - defer p.resourceSem.Release(1) + defer s.resourceSem.Release(1) defer wg.Done() //nolint:all - resolvedResource := p.resolveResource(ctx, table, client, parent, resourcesSlice[i]) + resolvedResource := s.resolveResource(ctx, table, client, parent, resourcesSlice[i]) if resolvedResource == nil { return } - if err := resolvedResource.CalculateCQID(p.spec.DeterministicCQID); err != nil { - tableMetrics := p.metrics.TableClient[table.Name][client.ID()] - p.logger.Error().Err(err).Str("table", table.Name).Str("client", client.ID()).Msg("resource resolver finished with primary key calculation error") + if err := resolvedResource.CalculateCQID(s.deterministicCQId); err != nil { + tableMetrics := s.metrics.TableClient[table.Name][client.ID()] + s.logger.Error().Err(err).Str("table", table.Name).Str("client", client.ID()).Msg("resource resolver finished with primary key calculation error") if _, found := sentValidationErrors.LoadOrStore(table.Name, struct{}{}); !found { // send resource validation errors to Sentry only once per table, // to avoid sending too many duplicate messages @@ -192,8 +172,8 @@ func (p *Plugin) resolveResourcesDfs(ctx context.Context, table *schema.Table, c return } if err := resolvedResource.Validate(); err != nil { - tableMetrics := p.metrics.TableClient[table.Name][client.ID()] - p.logger.Error().Err(err).Str("table", table.Name).Str("client", client.ID()).Msg("resource resolver finished with validation error") + tableMetrics := s.metrics.TableClient[table.Name][client.ID()] + s.logger.Error().Err(err).Str("table", table.Name).Str("client", client.ID()).Msg("resource resolver finished with validation error") if _, found := sentValidationErrors.LoadOrStore(table.Name, struct{}{}); !found { // send resource validation errors to Sentry only once per table, // to avoid sending too many duplicate messages @@ -217,7 +197,7 @@ func (p *Plugin) resolveResourcesDfs(ctx context.Context, table *schema.Table, c resolvedResources <- resource for _, relation := range resource.Table.Relations { relation := relation - if err := p.tableSems[depth].Acquire(ctx, 1); err != nil { + if err := s.tableSems[depth].Acquire(ctx, 1); err != nil { // This means context was cancelled wg.Wait() return @@ -225,8 +205,8 @@ func (p *Plugin) 
resolveResourcesDfs(ctx context.Context, table *schema.Table, c wg.Add(1) go func() { defer wg.Done() - defer p.tableSems[depth].Release(1) - p.resolveTableDfs(ctx, relation, client, resource, resolvedResources, depth+1) + defer s.tableSems[depth].Release(1) + s.resolveTableDfs(ctx, relation, client, resource, resolvedResources, depth+1) }() } } diff --git a/plugins/source/scheduler_round_robin.go b/scheduler/scheduler_round_robin.go similarity index 58% rename from plugins/source/scheduler_round_robin.go rename to scheduler/scheduler_round_robin.go index 00b1030f68..f800caebc6 100644 --- a/plugins/source/scheduler_round_robin.go +++ b/scheduler/scheduler_round_robin.go @@ -1,11 +1,10 @@ -package source +package scheduler import ( "context" "sync" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/schema" "golang.org/x/sync/semaphore" ) @@ -14,72 +13,57 @@ type tableClient struct { client schema.ClientMeta } -func (p *Plugin) syncRoundRobin(ctx context.Context, spec specs.Source, client schema.ClientMeta, tables schema.Tables, resolvedResources chan<- *schema.Resource) { - tableConcurrency := max(spec.Concurrency/minResourceConcurrency, minTableConcurrency) +func (s *Scheduler) syncRoundRobin(ctx context.Context, resolvedResources chan<- *schema.Resource) { + tableConcurrency := max(s.concurrency/minResourceConcurrency, minTableConcurrency) resourceConcurrency := tableConcurrency * minResourceConcurrency - p.tableSems = make([]*semaphore.Weighted, p.maxDepth) - for i := uint64(0); i < p.maxDepth; i++ { - p.tableSems[i] = semaphore.NewWeighted(int64(tableConcurrency)) + s.tableSems = make([]*semaphore.Weighted, s.maxDepth) + for i := uint64(0); i < s.maxDepth; i++ { + s.tableSems[i] = semaphore.NewWeighted(int64(tableConcurrency)) // reduce table concurrency logarithmically for every depth level tableConcurrency = max(tableConcurrency/2, minTableConcurrency) } - p.resourceSem = semaphore.NewWeighted(int64(resourceConcurrency)) + s.resourceSem = semaphore.NewWeighted(int64(resourceConcurrency)) // we have this because plugins can return sometimes clients in a random way which will cause // differences between this run and the next one. - preInitialisedClients := make([][]schema.ClientMeta, len(tables)) - for i, table := range tables { - clients := []schema.ClientMeta{client} + preInitialisedClients := make([][]schema.ClientMeta, len(s.tables)) + for i, table := range s.tables { + clients := []schema.ClientMeta{s.client} if table.Multiplex != nil { - clients = table.Multiplex(client) + clients = table.Multiplex(s.client) } preInitialisedClients[i] = clients // we do this here to avoid locks so we initial the metrics structure once in the main goroutines // and then we can just read from it in the other goroutines concurrently given we are not writing to it. - p.metrics.initWithClients(table, clients) + s.metrics.initWithClients(table, clients) } - // We start a goroutine that logs the metrics periodically. 
- // It needs its own waitgroup - var logWg sync.WaitGroup - logWg.Add(1) - - logCtx, logCancel := context.WithCancel(ctx) - go p.periodicMetricLogger(logCtx, &logWg) - - tableClients := roundRobinInterleave(tables, preInitialisedClients) + tableClients := roundRobinInterleave(s.tables, preInitialisedClients) var wg sync.WaitGroup for _, tc := range tableClients { table := tc.table cl := tc.client - if err := p.tableSems[0].Acquire(ctx, 1); err != nil { + if err := s.tableSems[0].Acquire(ctx, 1); err != nil { // This means context was cancelled wg.Wait() - // gracefully shut down the logger goroutine - logCancel() - logWg.Wait() return } wg.Add(1) go func() { defer wg.Done() - defer p.tableSems[0].Release(1) + defer s.tableSems[0].Release(1) // not checking for error here as nothing much to do. // the error is logged and this happens when context is cancelled // Round Robin currently uses the DFS algorithm to resolve the tables, but this // may change in the future. - p.resolveTableDfs(ctx, table, cl, nil, resolvedResources, 1) + s.resolveTableDfs(ctx, table, cl, nil, resolvedResources, 1) }() } // Wait for all the worker goroutines to finish wg.Wait() - - // gracefully shut down the logger goroutine - logCancel() - logWg.Wait() } // interleave table-clients so that we get: diff --git a/plugins/source/scheduler_round_robin_test.go b/scheduler/scheduler_round_robin_test.go similarity index 96% rename from plugins/source/scheduler_round_robin_test.go rename to scheduler/scheduler_round_robin_test.go index 8f7e3425f5..5e60765063 100644 --- a/plugins/source/scheduler_round_robin_test.go +++ b/scheduler/scheduler_round_robin_test.go @@ -1,9 +1,9 @@ -package source +package scheduler import ( "testing" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/schema" ) func TestRoundRobinInterleave(t *testing.T) { diff --git a/scheduler/scheduler_test.go b/scheduler/scheduler_test.go new file mode 100644 index 0000000000..1fe5bc57ea --- /dev/null +++ b/scheduler/scheduler_test.go @@ -0,0 +1,258 @@ +package scheduler + +import ( + "context" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/scalar" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/rs/zerolog" +) + +type testExecutionClient struct { +} + +func (*testExecutionClient) ID() string { + return "test" +} + +var _ schema.ClientMeta = &testExecutionClient{} + +func testResolverSuccess(_ context.Context, _ schema.ClientMeta, _ *schema.Resource, res chan<- any) error { + res <- map[string]any{ + "TestColumn": 3, + } + return nil +} + +func testResolverPanic(context.Context, schema.ClientMeta, *schema.Resource, chan<- any) error { + panic("Resolver") +} + +func testPreResourceResolverPanic(context.Context, schema.ClientMeta, *schema.Resource) error { + panic("PreResourceResolver") +} + +func testColumnResolverPanic(context.Context, schema.ClientMeta, *schema.Resource, schema.Column) error { + panic("ColumnResolver") +} + +func testTableSuccess() *schema.Table { + return &schema.Table{ + Name: "test_table_success", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + +func testTableSuccessWithPK() *schema.Table { + return &schema.Table{ + Name: "test_table_success", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: 
arrow.PrimitiveTypes.Int64, + PrimaryKey: true, + }, + }, + } +} + +func testTableResolverPanic() *schema.Table { + return &schema.Table{ + Name: "test_table_resolver_panic", + Resolver: testResolverPanic, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + +func testTablePreResourceResolverPanic() *schema.Table { + return &schema.Table{ + Name: "test_table_pre_resource_resolver_panic", + PreResourceResolver: testPreResourceResolverPanic, + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + +func testTableColumnResolverPanic() *schema.Table { + return &schema.Table{ + Name: "test_table_column_resolver_panic", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + { + Name: "test_column1", + Type: arrow.PrimitiveTypes.Int64, + Resolver: testColumnResolverPanic, + }, + }, + } +} + +func testTableRelationSuccess() *schema.Table { + return &schema.Table{ + Name: "test_table_relation_success", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + Relations: []*schema.Table{ + testTableSuccess(), + }, + } +} + +type syncTestCase struct { + table *schema.Table + data []scalar.Vector + deterministicCQID bool +} + +var syncTestCases = []syncTestCase{ + { + table: testTableSuccess(), + data: []scalar.Vector{ + { + &scalar.Int{Value: 3, Valid: true}, + }, + }, + }, + { + table: testTableResolverPanic(), + data: nil, + }, + { + table: testTablePreResourceResolverPanic(), + data: nil, + }, + + { + table: testTableRelationSuccess(), + data: []scalar.Vector{ + { + &scalar.Int{Value: 3, Valid: true}, + }, + { + &scalar.Int{Value: 3, Valid: true}, + }, + }, + }, + { + table: testTableSuccess(), + data: []scalar.Vector{ + { + &scalar.Int{Value: 3, Valid: true}, + }, + }, + deterministicCQID: true, + }, + { + table: testTableColumnResolverPanic(), + data: []scalar.Vector{ + { + &scalar.Int{Value: 3, Valid: true}, + &scalar.Int{}, + }, + }, + // deterministicCQID: true, + }, + { + table: testTableRelationSuccess(), + data: []scalar.Vector{ + { + &scalar.Int{Value: 3, Valid: true}, + }, + { + &scalar.Int{Value: 3, Valid: true}, + }, + }, + // deterministicCQID: true, + }, + { + table: testTableSuccessWithPK(), + data: []scalar.Vector{ + { + &scalar.Int{Value: 3, Valid: true}, + }, + }, + // deterministicCQID: true, + }, +} + +func TestScheduler(t *testing.T) { + // uuid.SetRand(testRand{}) + for _, scheduler := range AllSchedulers { + for _, tc := range syncTestCases { + tc := tc + tc.table = tc.table.Copy(nil) + t.Run(tc.table.Name+"_"+scheduler.String(), func(t *testing.T) { + testSyncTable(t, tc, scheduler, tc.deterministicCQID) + }) + } + } +} + +func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, deterministicCQID bool) { + ctx := context.Background() + tables := []*schema.Table{ + tc.table, + } + c := testExecutionClient{} + opts := []Option{ + WithLogger(zerolog.New(zerolog.NewTestWriter(t))), + WithSchedulerStrategy(strategy), + WithDeterministicCQId(deterministicCQID), + } + sc := NewScheduler(&c, opts...) 
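+	// Sync writes into msgs before anything reads from it, so the buffer must be
+	// large enough for every message a test case emits; the channel is drained
+	// only after Sync returns and it has been closed.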
+ msgs := make(chan message.Message, 10) + if err := sc.Sync(ctx, tables, msgs); err != nil { + t.Fatal(err) + } + close(msgs) + + var i int + for msg := range msgs { + switch v := msg.(type) { + case *message.Insert: + record := v.Record + rec := tc.data[i].ToArrowRecord(record.Schema()) + if !array.RecordEqual(rec, record) { + t.Fatalf("expected at i=%d: %v. got %v", i, tc.data[i], record) + } + i++ + case *message.MigrateTable: + // ignore + default: + t.Fatalf("expected insert message. got %T", msg) + } + } + if len(tc.data) != i { + t.Fatalf("expected %d resources. got %d", len(tc.data), i) + } +} diff --git a/schema/arrow.go b/schema/arrow.go index f7f61dbe61..56e51de354 100644 --- a/schema/arrow.go +++ b/schema/arrow.go @@ -1,11 +1,7 @@ package schema import ( - "bytes" - "fmt" - "github.com/apache/arrow/go/v13/arrow" - "github.com/apache/arrow/go/v13/arrow/ipc" ) const ( @@ -38,36 +34,3 @@ func (s Schemas) SchemaByName(name string) *arrow.Schema { } return nil } - -func (s Schemas) Encode() ([][]byte, error) { - ret := make([][]byte, len(s)) - for i, sc := range s { - var buf bytes.Buffer - wr := ipc.NewWriter(&buf, ipc.WithSchema(sc)) - if err := wr.Close(); err != nil { - return nil, err - } - ret[i] = buf.Bytes() - } - return ret, nil -} - -func NewSchemasFromBytes(b [][]byte) (Schemas, error) { - ret := make([]*arrow.Schema, len(b)) - for i, buf := range b { - rdr, err := ipc.NewReader(bytes.NewReader(buf)) - if err != nil { - return nil, err - } - ret[i] = rdr.Schema() - } - return ret, nil -} - -func NewTablesFromBytes(b [][]byte) (Tables, error) { - schemas, err := NewSchemasFromBytes(b) - if err != nil { - return nil, fmt.Errorf("failed to decode schemas: %w", err) - } - return NewTablesFromArrowSchemas(schemas) -} diff --git a/schema/arrow_test.go b/schema/arrow_test.go index 377cc5718f..8bcf6db8ae 100644 --- a/schema/arrow_test.go +++ b/schema/arrow_test.go @@ -1,44 +1,33 @@ package schema import ( - "testing" + "fmt" + "strings" "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" ) -func TestSchemaEncode(t *testing.T) { - md := arrow.NewMetadata([]string{"true"}, []string{"false"}) - md1 := arrow.NewMetadata([]string{"false"}, []string{"true"}) - schemas := Schemas{ - arrow.NewSchema( - []arrow.Field{ - {Name: "id", Type: arrow.PrimitiveTypes.Int64}, - {Name: "name", Type: arrow.BinaryTypes.String}, - }, - &md, - ), - arrow.NewSchema( - []arrow.Field{ - {Name: "id", Type: arrow.PrimitiveTypes.Int64}, - {Name: "name", Type: arrow.BinaryTypes.String}, - }, - &md1, - ), +func RecordDiff(l arrow.Record, r arrow.Record) string { + var sb strings.Builder + if l.NumCols() != r.NumCols() { + return fmt.Sprintf("different number of columns: %d vs %d", l.NumCols(), r.NumCols()) } - b, err := schemas.Encode() - if err != nil { - t.Fatal(err) + if l.NumRows() != r.NumRows() { + return fmt.Sprintf("different number of rows: %d vs %d", l.NumRows(), r.NumRows()) } - decodedSchemas, err := NewSchemasFromBytes(b) - if err != nil { - t.Fatal(err) - } - if len(decodedSchemas) != len(schemas) { - t.Fatalf("expected %d schemas, got %d", len(schemas), len(decodedSchemas)) - } - for i := range schemas { - if !schemas[i].Equal(decodedSchemas[i]) { - t.Fatalf("expected schema %d to be %v, got %v", i, schemas[i], decodedSchemas[i]) + for i := 0; i < int(l.NumCols()); i++ { + edits, err := array.Diff(l.Column(i), r.Column(i)) + if err != nil { + panic(fmt.Sprintf("left: %v, right: %v, error: %v", l.Column(i).DataType(), r.Column(i).DataType(), err)) + } + diff := 
edits.UnifiedDiff(l.Column(i), r.Column(i))
+		if diff != "" {
+			sb.WriteString(l.Schema().Field(i).Name)
+			sb.WriteString(": ")
+			sb.WriteString(diff)
+			sb.WriteString("\n")
		}
	}
+	return sb.String()
}
diff --git a/schema/meta.go b/schema/meta.go
index bd739bf80f..bd5ca2de7e 100644
--- a/schema/meta.go
+++ b/schema/meta.go
@@ -4,8 +4,8 @@ import (
 	"context"
 
 	"github.com/apache/arrow/go/v13/arrow"
-	"github.com/cloudquery/plugin-sdk/v3/scalar"
-	"github.com/cloudquery/plugin-sdk/v3/types"
+	"github.com/cloudquery/plugin-sdk/v4/scalar"
+	"github.com/cloudquery/plugin-sdk/v4/types"
 )
 
 type ClientMeta interface {
diff --git a/schema/resource.go b/schema/resource.go
index fbbaf6667b..e55c31c262 100644
--- a/schema/resource.go
+++ b/schema/resource.go
@@ -4,7 +4,7 @@ import (
 	"crypto/sha256"
 	"fmt"
 
-	"github.com/cloudquery/plugin-sdk/v3/scalar"
+	"github.com/cloudquery/plugin-sdk/v4/scalar"
 	"github.com/google/uuid"
 	"golang.org/x/exp/slices"
 )
@@ -97,6 +97,11 @@ func (r *Resource) CalculateCQID(deterministicCQID bool) error {
 }
 
 func (r *Resource) storeCQID(value uuid.UUID) error {
+	// We skip if _cq_id is not present.
+	// This is mostly because the transformation step is baked into the resolving step.
+	if r.Table.Columns.Get(CqIDColumn.Name) == nil {
+		return nil
+	}
 	b, err := value.MarshalBinary()
 	if err != nil {
 		return err
diff --git a/schema/table.go b/schema/table.go
index ed774f3b39..c7680a32c0 100644
--- a/schema/table.go
+++ b/schema/table.go
@@ -6,7 +6,7 @@ import (
 	"regexp"
 
 	"github.com/apache/arrow/go/v13/arrow"
-	"github.com/cloudquery/plugin-sdk/v3/internal/glob"
+	"github.com/cloudquery/plugin-sdk/v4/glob"
 	"golang.org/x/exp/slices"
 )
diff --git a/schema/testdata.go b/schema/testdata.go
index 5570c6a090..af79a95f5e 100644
--- a/schema/testdata.go
+++ b/schema/testdata.go
@@ -12,7 +12,7 @@ import (
 	"github.com/apache/arrow/go/v13/arrow"
 	"github.com/apache/arrow/go/v13/arrow/array"
 	"github.com/apache/arrow/go/v13/arrow/memory"
-	"github.com/cloudquery/plugin-sdk/v3/types"
+	"github.com/cloudquery/plugin-sdk/v4/types"
 	"github.com/google/uuid"
 	"golang.org/x/exp/rand"
 	"golang.org/x/exp/slices"
@@ -21,7 +21,6 @@ import (
 // TestSourceOptions controls which types are included by TestSourceColumns.
 type TestSourceOptions struct {
 	SkipDates     bool
-	SkipDecimals  bool
 	SkipDurations bool
 	SkipIntervals bool
 	SkipLargeTypes bool // e.g. large binary, large string
@@ -31,6 +30,7 @@ type TestSourceOptions struct {
 	SkipTimes      bool // time of day types
 	SkipTimestamps bool // timestamp types. Microsecond timestamp is always included, regardless of this setting.
 	TimePrecision  time.Duration
+	SkipDecimals   bool
 }
 
 // TestSourceColumns returns columns for all Arrow types and composites thereof.
TestSourceOptions controls diff --git a/schema/validators.go b/schema/validators.go index b42f59e223..6116e861a1 100644 --- a/schema/validators.go +++ b/schema/validators.go @@ -3,6 +3,8 @@ package schema import ( "errors" "fmt" + + "github.com/apache/arrow/go/v13/arrow" ) type TableValidator interface { @@ -53,3 +55,28 @@ func validateTableAttributesNameLength(t *Table) error { func (LengthTableValidator) Validate(t *Table) error { return validateTableAttributesNameLength(t) } + +func FindEmptyColumns(table *Table, records []arrow.Record) []string { + columnsWithValues := make([]bool, len(table.Columns)) + emptyColumns := make([]string, 0) + + for _, resource := range records { + for colIndex, arr := range resource.Columns() { + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + columnsWithValues[colIndex] = true + } + } + } + } + + // Make sure every column has at least one value. + for i, hasValue := range columnsWithValues { + col := table.Columns[i] + emptyExpected := col.Name == "_cq_parent_id" && table.Parent == nil + if !hasValue && !emptyExpected && !col.IgnoreInTests { + emptyColumns = append(emptyColumns, col.Name) + } + } + return emptyColumns +} diff --git a/serve/destination.go b/serve/destination.go deleted file mode 100644 index cba93b90a5..0000000000 --- a/serve/destination.go +++ /dev/null @@ -1,209 +0,0 @@ -package serve - -import ( - "fmt" - "net" - "os" - "os/signal" - "strings" - "sync" - "syscall" - - pbv0 "github.com/cloudquery/plugin-pb-go/pb/destination/v0" - pbv1 "github.com/cloudquery/plugin-pb-go/pb/destination/v1" - pbdiscoveryv0 "github.com/cloudquery/plugin-pb-go/pb/discovery/v0" - servers "github.com/cloudquery/plugin-sdk/v3/internal/servers/destination/v0" - serversv1 "github.com/cloudquery/plugin-sdk/v3/internal/servers/destination/v1" - discoveryServerV0 "github.com/cloudquery/plugin-sdk/v3/internal/servers/discovery/v0" - "github.com/cloudquery/plugin-sdk/v3/plugins/destination" - "github.com/cloudquery/plugin-sdk/v3/types" - "github.com/getsentry/sentry-go" - grpczerolog "github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2" - "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "github.com/spf13/cobra" - "github.com/thoas/go-funk" - "google.golang.org/grpc" - "google.golang.org/grpc/test/bufconn" -) - -type destinationServe struct { - plugin *destination.Plugin - sentryDSN string -} - -type DestinationOption func(*destinationServe) - -func WithDestinationSentryDSN(dsn string) DestinationOption { - return func(s *destinationServe) { - s.sentryDSN = dsn - } -} - -var testDestinationListener *bufconn.Listener -var testDestinationListenerLock sync.Mutex - -const serveDestinationShort = `Start destination plugin server` - -func Destination(plugin *destination.Plugin, opts ...DestinationOption) { - s := &destinationServe{ - plugin: plugin, - } - for _, opt := range opts { - opt(s) - } - if err := newCmdDestinationRoot(s).Execute(); err != nil { - sentry.CaptureMessage(err.Error()) - fmt.Println(err) - os.Exit(1) - } -} - -// nolint:dupl -func newCmdDestinationServe(serve *destinationServe) *cobra.Command { - var address string - var network string - var noSentry bool - logLevel := newEnum([]string{"trace", "debug", "info", "warn", "error"}, "info") - logFormat := newEnum([]string{"text", "json"}, "text") - telemetryLevel := newEnum([]string{"none", "errors", "stats", "all"}, "all") - err := telemetryLevel.Set(getEnvOrDefault("CQ_TELEMETRY_LEVEL", 
telemetryLevel.Value)) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to set telemetry level: "+err.Error()) - os.Exit(1) - } - - cmd := &cobra.Command{ - Use: "serve", - Short: serveDestinationShort, - Long: serveDestinationShort, - Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - zerologLevel, err := zerolog.ParseLevel(logLevel.String()) - if err != nil { - return err - } - var logger zerolog.Logger - if logFormat.String() == "json" { - logger = zerolog.New(os.Stdout).Level(zerologLevel) - } else { - logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout}).Level(zerologLevel) - } - - var listener net.Listener - if network == "test" { - testDestinationListenerLock.Lock() - listener = bufconn.Listen(testBufSize) - testDestinationListener = listener.(*bufconn.Listener) - testDestinationListenerLock.Unlock() - } else { - listener, err = net.Listen(network, address) - if err != nil { - return fmt.Errorf("failed to listen %s:%s: %w", network, address, err) - } - } - // See logging pattern https://github.com/grpc-ecosystem/go-grpc-middleware/blob/v2/providers/zerolog/examples_test.go - s := grpc.NewServer( - grpc.ChainUnaryInterceptor( - logging.UnaryServerInterceptor(grpczerolog.InterceptorLogger(logger)), - ), - grpc.ChainStreamInterceptor( - logging.StreamServerInterceptor(grpczerolog.InterceptorLogger(logger)), - ), - grpc.MaxRecvMsgSize(MaxMsgSize), - grpc.MaxSendMsgSize(MaxMsgSize), - ) - pbv0.RegisterDestinationServer(s, &servers.Server{ - Plugin: serve.plugin, - Logger: logger, - }) - pbv1.RegisterDestinationServer(s, &serversv1.Server{ - Plugin: serve.plugin, - Logger: logger, - }) - pbdiscoveryv0.RegisterDiscoveryServer(s, &discoveryServerV0.Server{ - Versions: []string{"v0", "v1"}, - }) - version := serve.plugin.Version() - - if serve.sentryDSN != "" && !strings.EqualFold(version, "development") && !noSentry { - err = sentry.Init(sentry.ClientOptions{ - Dsn: serve.sentryDSN, - Debug: false, - AttachStacktrace: false, - Release: version, - Transport: sentry.NewHTTPSyncTransport(), - ServerName: "oss", // set to "oss" on purpose to avoid sending any identifying information - // https://docs.sentry.io/platforms/go/configuration/options/#removing-default-integrations - Integrations: func(integrations []sentry.Integration) []sentry.Integration { - var filteredIntegrations []sentry.Integration - for _, integration := range integrations { - if integration.Name() == "Modules" { - continue - } - filteredIntegrations = append(filteredIntegrations, integration) - } - return filteredIntegrations - }, - }) - if err != nil { - log.Error().Err(err).Msg("Error initializing sentry") - } - } - - if err := types.RegisterAllExtensions(); err != nil { - return err - } - defer func() { - if err := types.UnregisterAllExtensions(); err != nil { - logger.Error().Err(err).Msg("Failed to unregister extensions") - } - }() - - ctx := cmd.Context() - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - defer func() { - signal.Stop(c) - }() - - go func() { - select { - case sig := <-c: - logger.Info().Str("address", listener.Addr().String()).Str("signal", sig.String()).Msg("Got stop signal. Destination plugin server shutting down") - s.Stop() - case <-ctx.Done(): - logger.Info().Str("address", listener.Addr().String()).Msg("Context cancelled. 
Destination plugin server shutting down") - s.Stop() - } - }() - - logger.Info().Str("address", listener.Addr().String()).Msg("Destination plugin server listening") - if err := s.Serve(listener); err != nil { - return fmt.Errorf("failed to serve: %w", err) - } - return nil - }, - } - cmd.Flags().StringVar(&address, "address", "localhost:7777", "address to serve on. can be tcp: `localhost:7777` or unix socket: `/tmp/plugin.rpc.sock`") - cmd.Flags().StringVar(&network, "network", "tcp", `the network must be "tcp", "tcp4", "tcp6", "unix" or "unixpacket"`) - cmd.Flags().Var(logLevel, "log-level", fmt.Sprintf("log level. one of: %s", strings.Join(logLevel.Allowed, ","))) - cmd.Flags().Var(logFormat, "log-format", fmt.Sprintf("log format. one of: %s", strings.Join(logFormat.Allowed, ","))) - cmd.Flags().BoolVar(&noSentry, "no-sentry", false, "disable sentry") - sendErrors := funk.ContainsString([]string{"all", "errors"}, telemetryLevel.String()) - if !sendErrors { - noSentry = true - } - return cmd -} - -func newCmdDestinationRoot(serve *destinationServe) *cobra.Command { - cmd := &cobra.Command{ - Use: fmt.Sprintf("%s ", serve.plugin.Name()), - } - cmd.AddCommand(newCmdDestinationServe(serve)) - cmd.CompletionOptions.DisableDefaultCmd = true - cmd.Version = serve.plugin.Version() - return cmd -} diff --git a/serve/destination_v0_test.go b/serve/destination_v0_test.go index 84c4b0e272..7f3c9fe21a 100644 --- a/serve/destination_v0_test.go +++ b/serve/destination_v0_test.go @@ -3,12 +3,10 @@ package serve import ( "context" "encoding/json" - "net" "sync" "testing" "time" - "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/memory" pbBase "github.com/cloudquery/plugin-pb-go/pb/base/v0" @@ -16,28 +14,19 @@ import ( "github.com/cloudquery/plugin-pb-go/specs" schemav2 "github.com/cloudquery/plugin-sdk/v2/schema" "github.com/cloudquery/plugin-sdk/v2/testdata" - "github.com/cloudquery/plugin-sdk/v3/internal/deprecated" - "github.com/cloudquery/plugin-sdk/v3/internal/memdb" - serversDestination "github.com/cloudquery/plugin-sdk/v3/internal/servers/destination/v0" - "github.com/cloudquery/plugin-sdk/v3/plugins/destination" + "github.com/cloudquery/plugin-sdk/v4/internal/deprecated" + "github.com/cloudquery/plugin-sdk/v4/internal/memdb" + serversDestination "github.com/cloudquery/plugin-sdk/v4/internal/servers/destination/v0" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/types/known/timestamppb" ) -func bufDestinationDialer(context.Context, string) (net.Conn, error) { - testDestinationListenerLock.Lock() - defer testDestinationListenerLock.Unlock() - return testDestinationListener.Dial() -} - func TestDestination(t *testing.T) { - plugin := destination.NewPlugin("testDestinationPlugin", "development", memdb.NewClient) - s := &destinationServe{ - plugin: plugin, - } - cmd := newCmdDestinationRoot(s) - cmd.SetArgs([]string{"serve", "--network", "test"}) + p := plugin.NewPlugin("testDestinationPlugin", "development", memdb.NewMemDBClient) + srv := Plugin(p, WithArgs("serve"), WithDestinationV0V1Server(), WithTestListener()) ctx := context.Background() ctx, cancel := context.WithCancel(ctx) var wg sync.WaitGroup @@ -45,27 +34,15 @@ func TestDestination(t *testing.T) { var serverErr error go func() { defer wg.Done() - serverErr = cmd.ExecuteContext(ctx) + serverErr = srv.Serve(ctx) }() defer 
func() { cancel() wg.Wait() }() - // wait for the server to start - for { - testDestinationListenerLock.Lock() - if testDestinationListener != nil { - testDestinationListenerLock.Unlock() - break - } - testDestinationListenerLock.Unlock() - t.Log("waiting for grpc server to start") - time.Sleep(time.Millisecond * 200) - } - // https://stackoverflow.com/questions/42102496/testing-a-grpc-service - conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDestinationDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + conn, err := grpc.DialContext(ctx, "bufnet1", grpc.WithContextDialer(srv.bufPluginDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) if err != nil { t.Fatalf("Failed to dial bufnet: %v", err) } @@ -77,10 +54,10 @@ func TestDestination(t *testing.T) { if err != nil { t.Fatal(err) } + if _, err := c.Configure(ctx, &pbBase.Configure_Request{Config: specBytes}); err != nil { t.Fatal(err) } - getNameRes, err := c.GetName(ctx, &pbBase.GetName_Request{}) if err != nil { t.Fatal(err) @@ -141,7 +118,6 @@ func TestDestination(t *testing.T) { }); err != nil { t.Fatal(err) } - if err := writeClient.Send(&pb.Write2_Request{ Resource: destResourceBytes, }); err != nil { @@ -151,25 +127,29 @@ func TestDestination(t *testing.T) { if _, err := writeClient.CloseAndRecv(); err != nil { t.Fatal(err) } + // serversDestination table := serversDestination.TableV2ToV3(tableV2) - readCh := make(chan arrow.Record, 1) - if err := plugin.Read(ctx, table, sourceName, readCh); err != nil { + msgs, err := p.SyncAll(ctx, plugin.SyncOptions{ + Tables: []string{tableName}, + }) + if err != nil { t.Fatal(err) } - close(readCh) totalResources := 0 destRecord := serversDestination.CQTypesOneToRecord(memory.DefaultAllocator, destResource.Data, table.ToArrowSchema()) - for resource := range readCh { + for _, msg := range msgs { totalResources++ - if !array.RecordEqual(destRecord, resource) { - diff := destination.RecordDiff(destRecord, resource) - t.Fatalf("expected %v but got %v. 
Diff: %v", destRecord, resource, diff) + m := msg.(*message.Insert) + if !array.RecordEqual(destRecord, m.Record) { + // diff := destination.RecordDiff(destRecord, resource) + t.Fatalf("expected %v but got %v", destRecord, m.Record) } } if totalResources != 1 { t.Fatalf("expected 1 resource but got %d", totalResources) } + if _, err := c.DeleteStale(ctx, &pb.DeleteStale_Request{ Source: "testSource", Timestamp: timestamppb.New(time.Now().Truncate(time.Microsecond)), @@ -178,15 +158,9 @@ func TestDestination(t *testing.T) { t.Fatal(err) } - _, err = c.GetMetrics(ctx, &pb.GetDestinationMetrics_Request{}) - if err != nil { - t.Fatal(err) - } - if _, err := c.Close(ctx, &pb.Close_Request{}); err != nil { t.Fatalf("failed to call Close: %v", err) } - cancel() wg.Wait() if serverErr != nil { diff --git a/serve/destination_v1_test.go b/serve/destination_v1_test.go index e5172106ad..c13b56232c 100644 --- a/serve/destination_v1_test.go +++ b/serve/destination_v1_test.go @@ -8,26 +8,23 @@ import ( "testing" "time" - "github.com/apache/arrow/go/v13/arrow" "github.com/apache/arrow/go/v13/arrow/array" "github.com/apache/arrow/go/v13/arrow/ipc" pb "github.com/cloudquery/plugin-pb-go/pb/destination/v1" + pbSource "github.com/cloudquery/plugin-pb-go/pb/source/v2" "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/internal/memdb" - "github.com/cloudquery/plugin-sdk/v3/plugins/destination" - "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v4/internal/memdb" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/types/known/timestamppb" ) func TestDestinationV1(t *testing.T) { - plugin := destination.NewPlugin("testDestinationPlugin", "development", memdb.NewClient) - s := &destinationServe{ - plugin: plugin, - } - cmd := newCmdDestinationRoot(s) - cmd.SetArgs([]string{"serve", "--network", "test"}) + p := plugin.NewPlugin("testDestinationPlugin", "development", memdb.NewMemDBClient) + srv := Plugin(p, WithArgs("serve"), WithDestinationV0V1Server(), WithTestListener()) ctx := context.Background() ctx, cancel := context.WithCancel(ctx) var wg sync.WaitGroup @@ -35,27 +32,15 @@ func TestDestinationV1(t *testing.T) { var serverErr error go func() { defer wg.Done() - serverErr = cmd.ExecuteContext(ctx) + serverErr = srv.Serve(ctx) }() defer func() { cancel() wg.Wait() }() - // wait for the server to start - for { - testDestinationListenerLock.Lock() - if testDestinationListener != nil { - testDestinationListenerLock.Unlock() - break - } - testDestinationListenerLock.Unlock() - t.Log("waiting for grpc server to start") - time.Sleep(time.Millisecond * 200) - } - // https://stackoverflow.com/questions/42102496/testing-a-grpc-service - conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDestinationDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(srv.bufPluginDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) if err != nil { t.Fatalf("Failed to dial bufnet: %v", err) } @@ -95,7 +80,8 @@ func TestDestinationV1(t *testing.T) { sourceSpec := specs.Source{ Name: sourceName, } - encodedTables, err := tables.ToArrowSchemas().Encode() + schemas := tables.ToArrowSchemas() + encodedTables, err := 
pbSource.SchemasToBytes(schemas) if err != nil { t.Fatal(err) } @@ -146,17 +132,20 @@ func TestDestinationV1(t *testing.T) { t.Fatal(err) } // serversDestination - readCh := make(chan arrow.Record, 1) - if err := plugin.Read(ctx, table, sourceName, readCh); err != nil { + msgs, err := p.SyncAll(ctx, plugin.SyncOptions{ + Tables: []string{tableName}, + }) + if err != nil { t.Fatal(err) } - close(readCh) totalResources := 0 - for resource := range readCh { + for _, msg := range msgs { totalResources++ - if !array.RecordEqual(rec, resource) { - diff := destination.RecordDiff(rec, resource) - t.Fatalf("expected %v but got %v. Diff: %v", rec, resource, diff) + m := msg.(*message.Insert) + if !array.RecordEqual(rec, m.Record) { + // diff := plugin.RecordDiff(rec, resource) + // t.Fatalf("diff at %d: %s", totalResources, diff) + t.Fatalf("expected %v but got %v", rec, m.Record) } } if totalResources != 1 { @@ -170,11 +159,6 @@ func TestDestinationV1(t *testing.T) { t.Fatal(err) } - _, err = c.GetMetrics(ctx, &pb.GetDestinationMetrics_Request{}) - if err != nil { - t.Fatal(err) - } - if _, err := c.Close(ctx, &pb.Close_Request{}); err != nil { t.Fatalf("failed to call Close: %v", err) } diff --git a/serve/docs.go b/serve/docs.go new file mode 100644 index 0000000000..442b6308f1 --- /dev/null +++ b/serve/docs.go @@ -0,0 +1,47 @@ +package serve + +import ( + "fmt" + "strings" + + "github.com/cloudquery/plugin-sdk/v4/docs" + "github.com/spf13/cobra" +) + +const ( + pluginDocShort = "Generate documentation for tables" + pluginDocLong = `Generate documentation for tables + +If format is markdown, a destination directory will be created (if necessary) containing markdown files. +Example: +doc ./output + +If format is JSON, a destination directory will be created (if necessary) with a single json file called __tables.json. +Example: +doc --format json . +` +) + +func (s *PluginServe) newCmdPluginDoc() *cobra.Command { + format := newEnum([]string{"json", "markdown"}, "markdown") + cmd := &cobra.Command{ + Use: "doc ", + Short: pluginDocShort, + Long: pluginDocLong, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + tables, err := s.plugin.Tables(cmd.Context()) + if err != nil { + return err + } + g := docs.NewGenerator(s.plugin.Name(), tables) + f := docs.FormatMarkdown + if format.Value == "json" { + f = docs.FormatJSON + } + return g.Generate(args[0], f) + }, + } + cmd.Flags().Var(format, "format", fmt.Sprintf("output format. 
one of: %s", strings.Join(format.Allowed, ","))) + return cmd +} diff --git a/serve/docs_test.go b/serve/docs_test.go new file mode 100644 index 0000000000..1548e0b1c1 --- /dev/null +++ b/serve/docs_test.go @@ -0,0 +1,26 @@ +package serve + +import ( + "context" + "testing" + + "github.com/cloudquery/plugin-sdk/v4/internal/memdb" + "github.com/cloudquery/plugin-sdk/v4/plugin" +) + +func TestPluginDocs(t *testing.T) { + tmpDir := t.TempDir() + p := plugin.NewPlugin( + "testPlugin", + "v1.0.0", + memdb.NewMemDBClient) + if err := p.Init(context.Background(), nil); err != nil { + t.Fatal(err) + } + srv := Plugin(p) + cmd := srv.newCmdPluginRoot() + cmd.SetArgs([]string{"doc", tmpDir}) + if err := cmd.Execute(); err != nil { + t.Fatal(err) + } +} diff --git a/serve/source.go b/serve/plugin.go similarity index 54% rename from serve/source.go rename to serve/plugin.go index ae57c83d07..c787399820 100644 --- a/serve/source.go +++ b/serve/plugin.go @@ -1,20 +1,28 @@ package serve import ( + "context" "fmt" "net" "os" "os/signal" "strings" - "sync" "syscall" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/types" + + pbDestinationV0 "github.com/cloudquery/plugin-pb-go/pb/destination/v0" + pbDestinationV1 "github.com/cloudquery/plugin-pb-go/pb/destination/v1" pbdiscoveryv0 "github.com/cloudquery/plugin-pb-go/pb/discovery/v0" - pbv2 "github.com/cloudquery/plugin-pb-go/pb/source/v2" - discoveryServerV0 "github.com/cloudquery/plugin-sdk/v3/internal/servers/discovery/v0" + pbdiscoveryv1 "github.com/cloudquery/plugin-pb-go/pb/discovery/v1" + pbv3 "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + discoveryServerV0 "github.com/cloudquery/plugin-sdk/v4/internal/servers/discovery/v0" + discoveryServerV1 "github.com/cloudquery/plugin-sdk/v4/internal/servers/discovery/v1" - serversv2 "github.com/cloudquery/plugin-sdk/v3/internal/servers/source/v2" - "github.com/cloudquery/plugin-sdk/v3/plugins/source" + serverDestinationV0 "github.com/cloudquery/plugin-sdk/v4/internal/servers/destination/v0" + serverDestinationV1 "github.com/cloudquery/plugin-sdk/v4/internal/servers/destination/v1" + serversv3 "github.com/cloudquery/plugin-sdk/v4/internal/servers/plugin/v3" "github.com/getsentry/sentry-go" grpczerolog "github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" @@ -22,46 +30,84 @@ import ( "github.com/rs/zerolog/log" "github.com/spf13/cobra" "github.com/thoas/go-funk" - "golang.org/x/net/netutil" "google.golang.org/grpc" "google.golang.org/grpc/test/bufconn" ) -type sourceServe struct { - plugin *source.Plugin - sentryDSN string +type PluginServe struct { + plugin *plugin.Plugin + args []string + destinationV0V1Server bool + sentryDSN string + testListener bool + testListenerConn *bufconn.Listener } -type SourceOption func(*sourceServe) +type PluginOption func(*PluginServe) -func WithSourceSentryDSN(dsn string) SourceOption { - return func(s *sourceServe) { +func WithPluginSentryDSN(dsn string) PluginOption { + return func(s *PluginServe) { s.sentryDSN = dsn } } -// lis used for unit testing grpc server and client -var testSourceListener *bufconn.Listener -var testSourceListenerLock sync.Mutex +// WithDestinationV0V1Server is used to include destination v0 and v1 server to work +// with older sources +func WithDestinationV0V1Server() PluginOption { + return func(s *PluginServe) { + s.destinationV0V1Server = true + } +} + +// WithArgs used to serve the plugin with predefined args instead of 
os.Args +func WithArgs(args ...string) PluginOption { + return func(s *PluginServe) { + s.args = args + } +} + +// WithTestListener means that the plugin will be served with an in-memory listener +// available via testListener() method instead of a network listener. +func WithTestListener() PluginOption { + return func(s *PluginServe) { + s.testListener = true + s.testListenerConn = bufconn.Listen(testBufSize) + } +} -const serveSourceShort = `Start source plugin server` +const servePluginShort = `Start plugin server` -func Source(plugin *source.Plugin, opts ...SourceOption) { - s := &sourceServe{ - plugin: plugin, +func Plugin(p *plugin.Plugin, opts ...PluginOption) *PluginServe { + s := &PluginServe{ + plugin: p, } for _, opt := range opts { opt(s) } - if err := newCmdSourceRoot(s).Execute(); err != nil { - sentry.CaptureMessage(err.Error()) - fmt.Println(err) - os.Exit(1) + return s +} + +func (s *PluginServe) bufPluginDialer(context.Context, string) (net.Conn, error) { + return s.testListenerConn.Dial() +} + +func (s *PluginServe) Serve(ctx context.Context) error { + if err := types.RegisterAllExtensions(); err != nil { + return err } + defer func() { + if err := types.UnregisterAllExtensions(); err != nil { + log.Error().Err(err).Msg("failed to unregister all extensions") + } + }() + cmd := s.newCmdPluginRoot() + if s.args != nil { + cmd.SetArgs(s.args) + } + return cmd.ExecuteContext(ctx) } -// nolint:dupl -func newCmdSourceServe(serve *sourceServe) *cobra.Command { +func (s *PluginServe) newCmdPluginServe() *cobra.Command { var address string var network string var noSentry bool @@ -76,8 +122,8 @@ func newCmdSourceServe(serve *sourceServe) *cobra.Command { cmd := &cobra.Command{ Use: "serve", - Short: serveSourceShort, - Long: serveSourceShort, + Short: servePluginShort, + Long: servePluginShort, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { zerologLevel, err := zerolog.ParseLevel(logLevel.String()) @@ -90,25 +136,22 @@ func newCmdSourceServe(serve *sourceServe) *cobra.Command { } else { logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout}).Level(zerologLevel) } - // opts.Plugin.Logger = logger var listener net.Listener - if network == "test" { - testSourceListenerLock.Lock() - listener = bufconn.Listen(testBufSize) - testSourceListener = listener.(*bufconn.Listener) - testSourceListenerLock.Unlock() + if s.testListener { + listener = s.testListenerConn } else { listener, err = net.Listen(network, address) if err != nil { return fmt.Errorf("failed to listen %s:%s: %w", network, address, err) } } + defer listener.Close() // source plugins can only accept one connection at a time // unlike destination plugins that can accept multiple connections - limitListener := netutil.LimitListener(listener, 1) + // limitListener := netutil.LimitListener(listener, 1) // See logging pattern https://github.com/grpc-ecosystem/go-grpc-middleware/blob/v2/providers/zerolog/examples_test.go - s := grpc.NewServer( + grpcServer := grpc.NewServer( grpc.ChainUnaryInterceptor( logging.UnaryServerInterceptor(grpczerolog.InterceptorLogger(logger)), ), @@ -118,20 +161,34 @@ func newCmdSourceServe(serve *sourceServe) *cobra.Command { grpc.MaxRecvMsgSize(MaxMsgSize), grpc.MaxSendMsgSize(MaxMsgSize), ) - serve.plugin.SetLogger(logger) - pbv2.RegisterSourceServer(s, &serversv2.Server{ - Plugin: serve.plugin, - Logger: logger, + s.plugin.SetLogger(logger) + pbv3.RegisterPluginServer(grpcServer, &serversv3.Server{ + Plugin: s.plugin, + Logger: logger, + NoSentry: noSentry, }) - 
pbdiscoveryv0.RegisterDiscoveryServer(s, &discoveryServerV0.Server{ - Versions: []string{"v2"}, + if s.destinationV0V1Server { + pbDestinationV1.RegisterDestinationServer(grpcServer, &serverDestinationV1.Server{ + Plugin: s.plugin, + Logger: logger, + }) + pbDestinationV0.RegisterDestinationServer(grpcServer, &serverDestinationV0.Server{ + Plugin: s.plugin, + Logger: logger, + }) + } + pbdiscoveryv0.RegisterDiscoveryServer(grpcServer, &discoveryServerV0.Server{ + Versions: []string{"v0", "v1", "v2", "v3"}, + }) + pbdiscoveryv1.RegisterDiscoveryServer(grpcServer, &discoveryServerV1.Server{ + Versions: []int32{0, 1, 2, 3}, }) - version := serve.plugin.Version() + version := s.plugin.Version() - if serve.sentryDSN != "" && !strings.EqualFold(version, "development") && !noSentry { + if s.sentryDSN != "" && !strings.EqualFold(version, "development") && !noSentry { err = sentry.Init(sentry.ClientOptions{ - Dsn: serve.sentryDSN, + Dsn: s.sentryDSN, Debug: false, AttachStacktrace: false, Release: version, @@ -165,15 +222,15 @@ func newCmdSourceServe(serve *sourceServe) *cobra.Command { select { case sig := <-c: logger.Info().Str("address", listener.Addr().String()).Str("signal", sig.String()).Msg("Got stop signal. Source plugin server shutting down") - s.Stop() + grpcServer.Stop() case <-ctx.Done(): logger.Info().Str("address", listener.Addr().String()).Msg("Context cancelled. Source plugin server shutting down") - s.Stop() + grpcServer.Stop() } }() logger.Info().Str("address", listener.Addr().String()).Msg("Source plugin server listening") - if err := s.Serve(limitListener); err != nil { + if err := grpcServer.Serve(listener); err != nil { return fmt.Errorf("failed to serve: %w", err) } return nil @@ -192,42 +249,13 @@ func newCmdSourceServe(serve *sourceServe) *cobra.Command { return cmd } -const ( - sourceDocShort = "Generate documentation for tables" - sourceDocLong = `Generate documentation for tables - -If format is markdown, a destination directory will be created (if necessary) containing markdown files. -Example: -doc ./output - -If format is JSON, a destination directory will be created (if necessary) with a single json file called __tables.json. -Example: -doc --format json . -` -) - -func newCmdSourceDoc(serve *sourceServe) *cobra.Command { - format := newEnum([]string{"json", "markdown"}, "markdown") - cmd := &cobra.Command{ - Use: "doc ", - Short: sourceDocShort, - Long: sourceDocLong, - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - return serve.plugin.GeneratePluginDocs(args[0], format.Value) - }, - } - cmd.Flags().Var(format, "format", fmt.Sprintf("output format. 
one of: %s", strings.Join(format.Allowed, ","))) - return cmd -} - -func newCmdSourceRoot(serve *sourceServe) *cobra.Command { +func (s *PluginServe) newCmdPluginRoot() *cobra.Command { cmd := &cobra.Command{ - Use: fmt.Sprintf("%s <command>", serve.plugin.Name()), + Use: fmt.Sprintf("%s <command>", s.plugin.Name()), } - cmd.AddCommand(newCmdSourceServe(serve)) - cmd.AddCommand(newCmdSourceDoc(serve)) + cmd.AddCommand(s.newCmdPluginServe()) + cmd.AddCommand(s.newCmdPluginDoc()) cmd.CompletionOptions.DisableDefaultCmd = true - cmd.Version = serve.plugin.Version() + cmd.Version = s.plugin.Version() return cmd } diff --git a/serve/plugin_test.go b/serve/plugin_test.go new file mode 100644 index 0000000000..c648d53976 --- /dev/null +++ b/serve/plugin_test.go @@ -0,0 +1,196 @@ +package serve + +import ( + "bytes" + "context" + "io" + "sync" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/ipc" + "github.com/apache/arrow/go/v13/arrow/memory" + pb "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "github.com/cloudquery/plugin-sdk/v4/internal/memdb" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +func TestPluginServe(t *testing.T) { + p := plugin.NewPlugin( + "testPluginV3", + "v1.0.0", + memdb.NewMemDBClient) + srv := Plugin(p, WithArgs("serve"), WithTestListener()) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + var wg sync.WaitGroup + wg.Add(1) + var serverErr error + go func() { + defer wg.Done() + serverErr = srv.Serve(ctx) + }() + defer func() { + cancel() + wg.Wait() + }() + + // https://stackoverflow.com/questions/42102496/testing-a-grpc-service + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(srv.bufPluginDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + if err != nil { + t.Fatalf("Failed to dial bufnet: %v", err) + } + + c := pb.NewPluginClient(conn) + + getNameRes, err := c.GetName(ctx, &pb.GetName_Request{}) + if err != nil { + t.Fatal(err) + } + if getNameRes.Name != "testPluginV3" { + t.Fatalf("expected name to be testPluginV3 but got %s", getNameRes.Name) + } + + getVersionResponse, err := c.GetVersion(ctx, &pb.GetVersion_Request{}) + if err != nil { + t.Fatal(err) + } + if getVersionResponse.Version != "v1.0.0" { + t.Fatalf("Expected version to be v1.0.0 but got %s", getVersionResponse.Version) + } + + if _, err := c.Init(ctx, &pb.Init_Request{}); err != nil { + t.Fatal(err) + } + + getTablesRes, err := c.GetTables(ctx, &pb.GetTables_Request{}) + if err != nil { + t.Fatal(err) + } + schemas, err := pb.NewSchemasFromBytes(getTablesRes.Tables) + if err != nil { + t.Fatal(err) + } + tables, err := schema.NewTablesFromArrowSchemas(schemas) + if err != nil { + t.Fatal(err) + } + + if len(tables) != 0 { + t.Fatalf("Expected 0 tables but got %d", len(tables)) + } + testTable := schema.Table{ + Name: "test_table", + Columns: []schema.Column{ + { + Name: "col1", + Type: arrow.BinaryTypes.String, + }, + }, + } + bldr := array.NewRecordBuilder(memory.DefaultAllocator, testTable.ToArrowSchema()) + bldr.Field(0).(*array.StringBuilder).Append("test") + record := bldr.NewRecord() + + recordBytes, err := pb.RecordToBytes(record) + if err != nil { + t.Fatal(err) + } + sc := testTable.ToArrowSchema() + tableBytes, err := pb.SchemaToBytes(sc) + if err != nil { + t.Fatal(err) + } + writeClient, err := c.Write(ctx) + if 
err != nil { + t.Fatal(err) + } + + if err := writeClient.Send(&pb.Write_Request{ + Message: &pb.Write_Request_Options{ + Options: &pb.WriteOptions{ + MigrateForce: true, + }, + }, + }); err != nil { + t.Fatal(err) + } + + if err := writeClient.Send(&pb.Write_Request{ + Message: &pb.Write_Request_MigrateTable{ + MigrateTable: &pb.MessageMigrateTable{ + Table: tableBytes, + }, + }, + }); err != nil { + t.Fatal(err) + } + if err := writeClient.Send(&pb.Write_Request{ + Message: &pb.Write_Request_Insert{ + Insert: &pb.MessageInsert{ + Record: recordBytes, + }, + }, + }); err != nil { + t.Fatal(err) + } + if _, err := writeClient.CloseAndRecv(); err != nil { + t.Fatal(err) + } + + syncClient, err := c.Sync(ctx, &pb.Sync_Request{ + Tables: []string{"test_table"}, + }) + if err != nil { + t.Fatal(err) + } + var resources []arrow.Record + for { + r, err := syncClient.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + m := r.Message.(*pb.Sync_Response_Insert) + rdr, err := ipc.NewReader(bytes.NewReader(m.Insert.Record)) + if err != nil { + t.Fatal(err) + } + for rdr.Next() { + rec := rdr.Record() + rec.Retain() + resources = append(resources, rec) + } + } + + totalResources := 0 + for _, resource := range resources { + sc := resource.Schema() + tableName, ok := sc.Metadata().GetValue(schema.MetadataTableName) + if !ok { + t.Fatal("Expected table name metadata to be set") + } + if tableName != "test_table" { + t.Fatalf("Expected resource with table name test_table. got: %s", tableName) + } + if len(resource.Columns()) != 1 { + t.Fatalf("Expected resource with data length 1 but got %d", len(resource.Columns())) + } + totalResources++ + } + if totalResources != 1 { + t.Fatalf("Expected 1 resource on channel but got %d", totalResources) + } + + cancel() + wg.Wait() + if serverErr != nil { + t.Fatal(serverErr) + } +} diff --git a/serve/source_v2_test.go b/serve/source_v2_test.go deleted file mode 100644 index 8a541611e9..0000000000 --- a/serve/source_v2_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package serve - -import ( - "bytes" - "context" - "encoding/json" - "io" - "net" - "sync" - "testing" - "time" - - "github.com/apache/arrow/go/v13/arrow" - "github.com/apache/arrow/go/v13/arrow/ipc" - pb "github.com/cloudquery/plugin-pb-go/pb/source/v2" - "github.com/cloudquery/plugin-pb-go/specs" - "github.com/cloudquery/plugin-sdk/v3/plugins/source" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/rs/zerolog" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" -) - -type TestSourcePluginSpec struct { - Accounts []string `json:"accounts,omitempty" yaml:"accounts,omitempty"` -} - -type testExecutionClient struct{} - -var _ schema.ClientMeta = &testExecutionClient{} - -// var errTestExecutionClientErr = fmt.Errorf("error in newTestExecutionClientErr") - -func testTable(name string) *schema.Table { - return &schema.Table{ - Name: name, - Resolver: func(ctx context.Context, meta schema.ClientMeta, parent *schema.Resource, res chan<- any) error { - res <- map[string]any{ - "TestColumn": 3, - } - return nil - }, - Columns: []schema.Column{ - { - Name: "test_column", - Type: arrow.PrimitiveTypes.Int64, - }, - }, - } -} - -func (*testExecutionClient) ID() string { - return "testExecutionClient" -} - -func newTestExecutionClient(context.Context, zerolog.Logger, specs.Source, source.Options) (schema.ClientMeta, error) { - return &testExecutionClient{}, nil -} - -func bufSourceDialer(context.Context, string) (net.Conn, error) { - testSourceListenerLock.Lock() - 
defer testSourceListenerLock.Unlock() - return testSourceListener.Dial() -} - -func TestSourceSuccess(t *testing.T) { - plugin := source.NewPlugin( - "testPlugin", - "v1.0.0", - []*schema.Table{testTable("test_table"), testTable("test_table2")}, - newTestExecutionClient) - - cmd := newCmdSourceRoot(&sourceServe{ - plugin: plugin, - }) - cmd.SetArgs([]string{"serve", "--network", "test"}) - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - var wg sync.WaitGroup - wg.Add(1) - var serverErr error - go func() { - defer wg.Done() - serverErr = cmd.ExecuteContext(ctx) - }() - defer func() { - cancel() - wg.Wait() - }() - for { - testSourceListenerLock.Lock() - if testSourceListener != nil { - testSourceListenerLock.Unlock() - break - } - testSourceListenerLock.Unlock() - t.Log("waiting for grpc server to start") - time.Sleep(time.Millisecond * 200) - } - - // https://stackoverflow.com/questions/42102496/testing-a-grpc-service - conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufSourceDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) - if err != nil { - t.Fatalf("Failed to dial bufnet: %v", err) - } - c := pb.NewSourceClient(conn) - - getNameRes, err := c.GetName(ctx, &pb.GetName_Request{}) - if err != nil { - t.Fatal(err) - } - if getNameRes.Name != "testPlugin" { - t.Fatalf("expected name to be testPlugin but got %s", getNameRes.Name) - } - - getVersionResponse, err := c.GetVersion(ctx, &pb.GetVersion_Request{}) - if err != nil { - t.Fatal(err) - } - if getVersionResponse.Version != "v1.0.0" { - t.Fatalf("Expected version to be v1.0.0 but got %s", getVersionResponse.Version) - } - - spec := specs.Source{ - Name: "testSourcePlugin", - Version: "v1.0.0", - Path: "cloudquery/testSourcePlugin", - Registry: specs.RegistryGithub, - Tables: []string{"test_table"}, - Spec: TestSourcePluginSpec{Accounts: []string{"cloudquery/plugin-sdk"}}, - Destinations: []string{"test"}, - } - specMarshaled, err := json.Marshal(spec) - if err != nil { - t.Fatalf("Failed to marshal spec: %v", err) - } - - getTablesRes, err := c.GetTables(ctx, &pb.GetTables_Request{}) - if err != nil { - t.Fatal(err) - } - - tables, err := schema.NewTablesFromBytes(getTablesRes.Tables) - if err != nil { - t.Fatal(err) - } - - if len(tables) != 2 { - t.Fatalf("Expected 2 tables but got %d", len(tables)) - } - if _, err := c.Init(ctx, &pb.Init_Request{Spec: specMarshaled}); err != nil { - t.Fatal(err) - } - - getTablesForSpecRes, err := c.GetDynamicTables(ctx, &pb.GetDynamicTables_Request{}) - if err != nil { - t.Fatal(err) - } - tables, err = schema.NewTablesFromBytes(getTablesForSpecRes.Tables) - if err != nil { - t.Fatal(err) - } - - if len(tables) != 1 { - t.Fatalf("Expected 1 table but got %d", len(tables)) - } - - syncClient, err := c.Sync(ctx, &pb.Sync_Request{}) - if err != nil { - t.Fatal(err) - } - var resources []arrow.Record - for { - r, err := syncClient.Recv() - if err == io.EOF { - break - } - if err != nil { - t.Fatal(err) - } - rdr, err := ipc.NewReader(bytes.NewReader(r.Resource)) - if err != nil { - t.Fatal(err) - } - for rdr.Next() { - rec := rdr.Record() - rec.Retain() - resources = append(resources, rec) - } - } - - totalResources := 0 - for _, resource := range resources { - sc := resource.Schema() - tableName, ok := sc.Metadata().GetValue(schema.MetadataTableName) - if !ok { - t.Fatal("Expected table name metadata to be set") - } - if tableName != "test_table" { - t.Fatalf("Expected resource with table name test_table. 
got: %s", tableName) - } - if len(resource.Columns()) != 5 { - t.Fatalf("Expected resource with data length 3 but got %d", len(resource.Columns())) - } - totalResources++ - } - if totalResources != 1 { - t.Fatalf("Expected 1 resource on channel but got %d", totalResources) - } - - getMetricsRes, err := c.GetMetrics(ctx, &pb.GetMetrics_Request{}) - if err != nil { - t.Fatal(err) - } - var stats source.Metrics - if err := json.Unmarshal(getMetricsRes.Metrics, &stats); err != nil { - t.Fatal(err) - } - - clientStats := stats.TableClient[""][""] - if clientStats.Resources != 1 { - t.Fatalf("Expected 1 resource but got %d", clientStats.Resources) - } - - if clientStats.Errors != 0 { - t.Fatalf("Expected 0 errors but got %d", clientStats.Errors) - } - - if clientStats.Panics != 0 { - t.Fatalf("Expected 0 panics but got %d", clientStats.Panics) - } - - cancel() - wg.Wait() - if serverErr != nil { - t.Fatal(serverErr) - } -} diff --git a/serve/state_test.go b/serve/state_test.go new file mode 100644 index 0000000000..14cf2aa90d --- /dev/null +++ b/serve/state_test.go @@ -0,0 +1,83 @@ +package serve + +import ( + "context" + "sync" + "testing" + + pb "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "github.com/cloudquery/plugin-sdk/v4/internal/clients/state/v3" + "github.com/cloudquery/plugin-sdk/v4/internal/memdb" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +func TestState(t *testing.T) { + p := plugin.NewPlugin( + "testPluginV3", + "v1.0.0", + memdb.NewMemDBClient) + srv := Plugin(p, WithArgs("serve"), WithTestListener()) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + var wg sync.WaitGroup + wg.Add(1) + var serverErr error + go func() { + defer wg.Done() + serverErr = srv.Serve(ctx) + }() + defer func() { + cancel() + wg.Wait() + }() + + // https://stackoverflow.com/questions/42102496/testing-a-grpc-service + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(srv.bufPluginDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + if err != nil { + t.Fatalf("Failed to dial bufnet: %v", err) + } + + c := pb.NewPluginClient(conn) + if _, err := c.Init(ctx, &pb.Init_Request{}); err != nil { + t.Fatal(err) + } + stateClient, err := state.NewClient(ctx, c, "test") + if err != nil { + t.Fatal(err) + } + + if err := stateClient.SetKey(ctx, "key", "value"); err != nil { + t.Fatal(err) + } + + val, err := stateClient.GetKey(ctx, "key") + if err != nil { + t.Fatal(err) + } + if val != "value" { + t.Fatalf("expected value to be value but got %s", val) + } + + if err := stateClient.Flush(ctx); err != nil { + t.Fatal(err) + } + stateClient, err = state.NewClient(ctx, c, "test") + if err != nil { + t.Fatal(err) + } + val, err = stateClient.GetKey(ctx, "key") + if err != nil { + t.Fatal(err) + } + if val != "value" { + t.Fatalf("expected value to be value but got %s", val) + } + + cancel() + wg.Wait() + if serverErr != nil { + t.Fatal(serverErr) + } +} diff --git a/state/state.go b/state/state.go new file mode 100644 index 0000000000..d90b595ef2 --- /dev/null +++ b/state/state.go @@ -0,0 +1,29 @@ +package state + +import ( + "context" + "fmt" + + pbDiscovery "github.com/cloudquery/plugin-pb-go/pb/discovery/v1" + pbPluginV3 "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + stateV3 "github.com/cloudquery/plugin-sdk/v4/internal/clients/state/v3" + "golang.org/x/exp/slices" + "google.golang.org/grpc" +) + +type Client interface { + SetKey(ctx context.Context, 
key string, value string) error + GetKey(ctx context.Context, key string) (string, error) +} + +func NewClient(ctx context.Context, conn *grpc.ClientConn, tableName string) (Client, error) { + discoveryClient := pbDiscovery.NewDiscoveryClient(conn) + versions, err := discoveryClient.GetVersions(ctx, &pbDiscovery.GetVersions_Request{}) + if err != nil { + return nil, err + } + if slices.Contains(versions.Versions, 3) { + return stateV3.NewClient(ctx, pbPluginV3.NewPluginClient(conn), tableName) + } + return nil, fmt.Errorf("please upgrade your state backend plugin. state supporting version 3 plugin has %v", versions.Versions) +} diff --git a/transformers/struct.go b/transformers/struct.go index 2296af865e..b6c97842c5 100644 --- a/transformers/struct.go +++ b/transformers/struct.go @@ -8,9 +8,9 @@ import ( "time" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v3/caser" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/cloudquery/plugin-sdk/v4/caser" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/cloudquery/plugin-sdk/v4/types" "github.com/thoas/go-funk" "golang.org/x/exp/slices" ) diff --git a/transformers/struct_test.go b/transformers/struct_test.go index 55acfbef16..d59cc6588b 100644 --- a/transformers/struct_test.go +++ b/transformers/struct_test.go @@ -7,8 +7,8 @@ import ( "time" "github.com/apache/arrow/go/v13/arrow" - "github.com/cloudquery/plugin-sdk/v3/schema" - "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/cloudquery/plugin-sdk/v4/types" "github.com/google/go-cmp/cmp" "golang.org/x/exp/slices" ) diff --git a/transformers/tables.go b/transformers/tables.go new file mode 100644 index 0000000000..f8e7c5b46f --- /dev/null +++ b/transformers/tables.go @@ -0,0 +1,30 @@ +package transformers + +import ( + "fmt" + + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +// Set parent links on relational tables +func SetParents(tables schema.Tables, parent *schema.Table) { + for _, table := range tables { + table.Parent = parent + SetParents(table.Relations, table) + } +} + +// Apply transformations to tables +func TransformTables(tables schema.Tables) error { + for _, table := range tables { + if table.Transform != nil { + if err := table.Transform(table); err != nil { + return fmt.Errorf("failed to transform table %s: %w", table.Name, err) + } + } + if err := TransformTables(table.Relations); err != nil { + return err + } + } + return nil +} diff --git a/writers/batch.go b/writers/batch.go new file mode 100644 index 0000000000..510418c9bb --- /dev/null +++ b/writers/batch.go @@ -0,0 +1,327 @@ +package writers + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/apache/arrow/go/v13/arrow/util" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" + "github.com/rs/zerolog" +) + +type Writer interface { + Write(ctx context.Context, writeOptions plugin.WriteOptions, res <-chan message.Message) error +} + +const ( + defaultBatchTimeoutSeconds = 20 + defaultBatchSize = 10000 + defaultBatchSizeBytes = 5 * 1024 * 1024 // 5 MiB +) + +type BatchWriterClient interface { + MigrateTables(context.Context, []*message.MigrateTable) error + WriteTableBatch(ctx context.Context, name string, msgs []*message.Insert) error + DeleteStale(context.Context, []*message.DeleteStale) error +} + +type BatchWriter struct { + client BatchWriterClient + workers 
map[string]*worker + workersLock sync.RWMutex + workersWaitGroup sync.WaitGroup + + migrateTableLock sync.Mutex + migrateTableMessages []*message.MigrateTable + deleteStaleLock sync.Mutex + deleteStaleMessages []*message.DeleteStale + + logger zerolog.Logger + batchTimeout time.Duration + batchSize int + batchSizeBytes int +} + +type Option func(*BatchWriter) + +func WithLogger(logger zerolog.Logger) Option { + return func(p *BatchWriter) { + p.logger = logger + } +} + +func WithBatchTimeout(timeout time.Duration) Option { + return func(p *BatchWriter) { + p.batchTimeout = timeout + } +} + +func WithBatchSize(size int) Option { + return func(p *BatchWriter) { + p.batchSize = size + } +} + +func WithBatchSizeBytes(size int) Option { + return func(p *BatchWriter) { + p.batchSizeBytes = size + } +} + +type worker struct { + count int + ch chan *message.Insert + flush chan chan bool +} + +func NewBatchWriter(client BatchWriterClient, opts ...Option) (*BatchWriter, error) { + c := &BatchWriter{ + client: client, + workers: make(map[string]*worker), + logger: zerolog.Nop(), + batchTimeout: defaultBatchTimeoutSeconds * time.Second, + batchSize: defaultBatchSize, + batchSizeBytes: defaultBatchSizeBytes, + } + for _, opt := range opts { + opt(c) + } + c.migrateTableMessages = make([]*message.MigrateTable, 0, c.batchSize) + c.deleteStaleMessages = make([]*message.DeleteStale, 0, c.batchSize) + return c, nil +} + +func (w *BatchWriter) Flush(ctx context.Context) error { + w.workersLock.RLock() + for _, worker := range w.workers { + done := make(chan bool) + worker.flush <- done + <-done + } + w.workersLock.RUnlock() + if err := w.flushMigrateTables(ctx); err != nil { + return err + } + return w.flushDeleteStaleTables(ctx) +} + +func (w *BatchWriter) Close(context.Context) error { + w.workersLock.Lock() + defer w.workersLock.Unlock() + for _, w := range w.workers { + close(w.ch) + } + w.workersWaitGroup.Wait() + + return nil +} + +func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.Insert, flush <-chan chan bool) { + sizeBytes := int64(0) + resources := make([]*message.Insert, 0) + for { + select { + case r, ok := <-ch: + if !ok { + if len(resources) > 0 { + w.flushTable(ctx, tableName, resources) + } + return + } + resources = append(resources, r) + sizeBytes += util.TotalRecordSize(r.Record) + + if len(resources) >= w.batchSize || sizeBytes >= int64(w.batchSizeBytes) { + w.flushTable(ctx, tableName, resources) + resources = make([]*message.Insert, 0) + sizeBytes = 0 + } + case <-time.After(w.batchTimeout): + if len(resources) > 0 { + w.flushTable(ctx, tableName, resources) + resources = make([]*message.Insert, 0) + sizeBytes = 0 + } + case done := <-flush: + if len(resources) > 0 { + w.flushTable(ctx, tableName, resources) + resources = make([]*message.Insert, 0) + sizeBytes = 0 + } + done <- true + case <-ctx.Done(): + // this means the request was cancelled + return // after this NO other call will succeed + } + } +} + +func (w *BatchWriter) flushTable(ctx context.Context, tableName string, resources []*message.Insert) { + // resources = w.removeDuplicatesByPK(table, resources) + start := time.Now() + batchSize := len(resources) + if err := w.client.WriteTableBatch(ctx, tableName, resources); err != nil { + w.logger.Err(err).Str("table", tableName).Int("len", batchSize).Dur("duration", time.Since(start)).Msg("failed to write batch") + } else { + w.logger.Info().Str("table", tableName).Int("len", batchSize).Dur("duration", time.Since(start)).Msg("batch written 
successfully") + } +} + +// func (*BatchWriter) removeDuplicatesByPK(table *schema.Table, resources []*message.Insert) []*message.Insert { +// pkIndices := table.PrimaryKeysIndexes() +// // special case where there's no PK at all +// if len(pkIndices) == 0 { +// return resources +// } + +// pks := make(map[string]struct{}, len(resources)) +// res := make([]*message.Insert, 0, len(resources)) +// for _, r := range resources { +// if r.Record.NumRows() > 1 { +// panic(fmt.Sprintf("record with more than 1 row: %d", r.Record.NumRows())) +// } +// key := pk.String(r.Record) +// _, ok := pks[key] +// if !ok { +// pks[key] = struct{}{} +// res = append(res, r) +// continue +// } +// // duplicate, release +// r.Release() +// } + +// return res +// } + +func (w *BatchWriter) flushMigrateTables(ctx context.Context) error { + w.migrateTableLock.Lock() + defer w.migrateTableLock.Unlock() + if len(w.migrateTableMessages) == 0 { + return nil + } + if err := w.client.MigrateTables(ctx, w.migrateTableMessages); err != nil { + return err + } + w.migrateTableMessages = w.migrateTableMessages[:0] + return nil +} + +func (w *BatchWriter) flushDeleteStaleTables(ctx context.Context) error { + w.deleteStaleLock.Lock() + defer w.deleteStaleLock.Unlock() + if len(w.deleteStaleMessages) == 0 { + return nil + } + if err := w.client.DeleteStale(ctx, w.deleteStaleMessages); err != nil { + return err + } + w.deleteStaleMessages = w.deleteStaleMessages[:0] + return nil +} + +func (w *BatchWriter) flushInsert(tableName string) { + w.workersLock.RLock() + worker, ok := w.workers[tableName] + if !ok { + w.workersLock.RUnlock() + // no tables to flush + return + } + w.workersLock.RUnlock() + ch := make(chan bool) + worker.flush <- ch + <-ch +} + +func (w *BatchWriter) writeAll(ctx context.Context, msgs []message.Message) error { + ch := make(chan message.Message, len(msgs)) + for _, msg := range msgs { + ch <- msg + } + close(ch) + return w.Write(ctx, ch) +} + +func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.Message) error { + for msg := range msgs { + switch m := msg.(type) { + case *message.DeleteStale: + if err := w.flushMigrateTables(ctx); err != nil { + return err + } + w.flushInsert(m.Table.Name) + w.deleteStaleLock.Lock() + w.deleteStaleMessages = append(w.deleteStaleMessages, m) + l := len(w.deleteStaleMessages) + w.deleteStaleLock.Unlock() + if l > w.batchSize { + if err := w.flushDeleteStaleTables(ctx); err != nil { + return err + } + } + case *message.Insert: + if err := w.flushMigrateTables(ctx); err != nil { + return err + } + if err := w.flushDeleteStaleTables(ctx); err != nil { + return err + } + if err := w.startWorker(ctx, m); err != nil { + return err + } + case *message.MigrateTable: + w.flushInsert(m.Table.Name) + if err := w.flushDeleteStaleTables(ctx); err != nil { + return err + } + w.migrateTableLock.Lock() + w.migrateTableMessages = append(w.migrateTableMessages, m) + l := len(w.migrateTableMessages) + w.migrateTableLock.Unlock() + if l > w.batchSize { + if err := w.flushMigrateTables(ctx); err != nil { + return err + } + } + } + } + return nil +} + +func (w *BatchWriter) startWorker(ctx context.Context, msg *message.Insert) error { + w.workersLock.RLock() + md := msg.Record.Schema().Metadata() + tableName, ok := md.GetValue(schema.MetadataTableName) + if !ok { + w.workersLock.RUnlock() + return fmt.Errorf("table name not found in metadata") + } + wr, ok := w.workers[tableName] + w.workersLock.RUnlock() + if ok { + wr.ch <- msg + return nil + } + w.workersLock.Lock() + ch := 
make(chan *message.Insert) + flush := make(chan chan bool) + wr = &worker{ + count: 1, + ch: ch, + flush: flush, + } + w.workers[tableName] = wr + w.workersLock.Unlock() + w.workersWaitGroup.Add(1) + go func() { + defer w.workersWaitGroup.Done() + w.worker(ctx, tableName, ch, flush) + }() + ch <- msg + return nil +} diff --git a/writers/batch_test.go b/writers/batch_test.go new file mode 100644 index 0000000000..a6940181d1 --- /dev/null +++ b/writers/batch_test.go @@ -0,0 +1,228 @@ +package writers + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/memory" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +type testBatchClient struct { + mutex sync.Mutex + migrateTables []*message.MigrateTable + inserts []*message.Insert + deleteStales []*message.DeleteStale +} + +func (c *testBatchClient) MigrateTablesLen() int { + c.mutex.Lock() + defer c.mutex.Unlock() + return len(c.migrateTables) +} + +func (c *testBatchClient) InsertsLen() int { + c.mutex.Lock() + defer c.mutex.Unlock() + return len(c.inserts) +} + +func (c *testBatchClient) DeleteStalesLen() int { + c.mutex.Lock() + defer c.mutex.Unlock() + return len(c.deleteStales) +} + +func (c *testBatchClient) MigrateTables(_ context.Context, msgs []*message.MigrateTable) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.migrateTables = append(c.migrateTables, msgs...) + return nil +} + +func (c *testBatchClient) WriteTableBatch(_ context.Context, _ string, msgs []*message.Insert) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.inserts = append(c.inserts, msgs...) + return nil +} +func (c *testBatchClient) DeleteStale(_ context.Context, msgs []*message.DeleteStale) error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.deleteStales = append(c.deleteStales, msgs...) + return nil +} + +var batchTestTables = schema.Tables{ + { + Name: "table1", + Columns: []schema.Column{ + { + Name: "id", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + }, + { + Name: "table2", + Columns: []schema.Column{ + { + Name: "id", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + }, +} + +// TestBatchFlushDifferentMessages tests that if writer receives a message of a new type all other pending +// batches are flushed. 
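For orientation before the tests: a destination supplies a BatchWriterClient and pumps messages through Write; batches are cut per table by insert count, by accumulated record bytes, or by timeout. A minimal sketch of that wiring (noopClient and the option values are illustrative, not part of this change):

package main

import (
	"context"
	"time"

	"github.com/cloudquery/plugin-sdk/v4/message"
	"github.com/cloudquery/plugin-sdk/v4/writers"
)

// noopClient satisfies BatchWriterClient; a real destination would persist each batch.
type noopClient struct{}

func (noopClient) MigrateTables(context.Context, []*message.MigrateTable) error   { return nil }
func (noopClient) WriteTableBatch(context.Context, string, []*message.Insert) error { return nil }
func (noopClient) DeleteStale(context.Context, []*message.DeleteStale) error      { return nil }

func main() {
	wr, err := writers.NewBatchWriter(noopClient{},
		writers.WithBatchSize(1000),              // flush a table's batch after 1000 inserts...
		writers.WithBatchTimeout(10*time.Second), // ...or after 10 seconds, whichever comes first
	)
	if err != nil {
		panic(err)
	}
	msgs := make(chan message.Message)
	close(msgs) // a real sync would feed inserts, migrations and delete-stale messages here
	if err := wr.Write(context.Background(), msgs); err != nil {
		panic(err)
	}
}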
+func TestBatchFlushDifferentMessages(t *testing.T) { + ctx := context.Background() + + testClient := &testBatchClient{} + wr, err := NewBatchWriter(testClient) + if err != nil { + t.Fatal(err) + } + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, batchTestTables[0].ToArrowSchema()) + bldr.Field(0).(*array.Int64Builder).Append(1) + record := bldr.NewRecord() + if err := wr.writeAll(ctx, []message.Message{&message.MigrateTable{Table: batchTestTables[0]}}); err != nil { + t.Fatal(err) + } + + if testClient.MigrateTablesLen() != 0 { + t.Fatalf("expected 0 create table messages, got %d", testClient.MigrateTablesLen()) + } + + if err := wr.writeAll(ctx, []message.Message{&message.Insert{Record: record}}); err != nil { + t.Fatal(err) + } + + if testClient.MigrateTablesLen() != 1 { + t.Fatalf("expected 1 migrate table messages, got %d", testClient.MigrateTablesLen()) + } + + if testClient.InsertsLen() != 0 { + t.Fatalf("expected 0 insert messages, got %d", testClient.InsertsLen()) + } + + if err := wr.writeAll(ctx, []message.Message{&message.MigrateTable{Table: batchTestTables[0]}}); err != nil { + t.Fatal(err) + } + + if testClient.InsertsLen() != 1 { + t.Fatalf("expected 1 insert messages, got %d", testClient.InsertsLen()) + } +} + +func TestBatchSize(t *testing.T) { + ctx := context.Background() + + testClient := &testBatchClient{} + wr, err := NewBatchWriter(testClient, WithBatchSize(2)) + if err != nil { + t.Fatal(err) + } + table := schema.Table{Name: "table1", Columns: []schema.Column{{Name: "id", Type: arrow.PrimitiveTypes.Int64}}} + record := array.NewRecord(table.ToArrowSchema(), nil, 0) + if err := wr.writeAll(ctx, []message.Message{&message.Insert{ + Record: record, + }}); err != nil { + t.Fatal(err) + } + + if testClient.InsertsLen() != 0 { + t.Fatalf("expected 0 insert messages, got %d", testClient.InsertsLen()) + } + + if err := wr.writeAll(ctx, []message.Message{&message.Insert{ + Record: record, + }}); err != nil { + t.Fatal(err) + } + // we need to wait for the batch to be flushed + time.Sleep(time.Second * 2) + + if testClient.InsertsLen() != 2 { + t.Fatalf("expected 2 insert messages, got %d", testClient.InsertsLen()) + } +} + +func TestBatchTimeout(t *testing.T) { + ctx := context.Background() + + testClient := &testBatchClient{} + wr, err := NewBatchWriter(testClient, WithBatchTimeout(time.Second)) + if err != nil { + t.Fatal(err) + } + table := schema.Table{Name: "table1", Columns: []schema.Column{{Name: "id", Type: arrow.PrimitiveTypes.Int64}}} + record := array.NewRecord(table.ToArrowSchema(), nil, 0) + if err := wr.writeAll(ctx, []message.Message{&message.Insert{ + Record: record, + }}); err != nil { + t.Fatal(err) + } + + if testClient.InsertsLen() != 0 { + t.Fatalf("expected 0 insert messages, got %d", testClient.InsertsLen()) + } + + // we need to wait for the batch to be flushed + time.Sleep(time.Millisecond * 250) + + if testClient.InsertsLen() != 0 { + t.Fatalf("expected 0 insert messages, got %d", testClient.InsertsLen()) + } + + // we need to wait for the batch to be flushed + time.Sleep(time.Second * 1) + + if testClient.InsertsLen() != 1 { + t.Fatalf("expected 1 insert messages, got %d", testClient.InsertsLen()) + } +} + +func TestBatchUpserts(t *testing.T) { + ctx := context.Background() + + testClient := &testBatchClient{} + wr, err := NewBatchWriter(testClient, WithBatchSize(2), WithBatchTimeout(time.Second)) + if err != nil { + t.Fatal(err) + } + table := schema.Table{Name: "table1", Columns: []schema.Column{{Name: "id", Type: 
arrow.PrimitiveTypes.Int64, PrimaryKey: true}}} + + bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema()) + bldr.Field(0).(*array.Int64Builder).Append(1) + record := bldr.NewRecord() + + if err := wr.writeAll(ctx, []message.Message{&message.Insert{ + Record: record, + }}); err != nil { + t.Fatal(err) + } + + if testClient.InsertsLen() != 0 { + t.Fatalf("expected 0 insert messages, got %d", testClient.InsertsLen()) + } + + if err := wr.writeAll(ctx, []message.Message{&message.Insert{ + Record: record, + }}); err != nil { + t.Fatal(err) + } + // we need to wait for the batch to be flushed + time.Sleep(time.Second * 2) + + if testClient.InsertsLen() != 2 { + t.Fatalf("expected 2 insert messages, got %d", testClient.InsertsLen()) + } +} diff --git a/writers/mixed_batch.go b/writers/mixed_batch.go new file mode 100644 index 0000000000..ec9a3e48e3 --- /dev/null +++ b/writers/mixed_batch.go @@ -0,0 +1,211 @@ +package writers + +import ( + "context" + "reflect" + "time" + + "github.com/apache/arrow/go/v13/arrow/util" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/rs/zerolog" +) + +const ( + msgTypeMigrateTable = iota + msgTypeInsert + msgTypeDeleteStale +) + +// MixedBatchClient is a client that will receive batches of messages with a mixture of tables. +type MixedBatchClient interface { + MigrateTableBatch(ctx context.Context, messages []*message.MigrateTable, options plugin.WriteOptions) error + InsertBatch(ctx context.Context, messages []*message.Insert, options plugin.WriteOptions) error + DeleteStaleBatch(ctx context.Context, messages []*message.DeleteStale, options plugin.WriteOptions) error +} + +type MixedBatchWriter struct { + client MixedBatchClient + logger zerolog.Logger + batchTimeout time.Duration + batchSize int + batchSizeBytes int +} + +// Assert at compile-time that MixedBatchWriter implements the Writer interface +var _ Writer = (*MixedBatchWriter)(nil) + +type MixedBatchWriterOption func(writer *MixedBatchWriter) + +func WithMixedBatchWriterLogger(logger zerolog.Logger) MixedBatchWriterOption { + return func(p *MixedBatchWriter) { + p.logger = logger + } +} + +func WithMixedBatchWriterBatchTimeout(timeout time.Duration) MixedBatchWriterOption { + return func(p *MixedBatchWriter) { + p.batchTimeout = timeout + } +} + +func WithMixedBatchWriterBatchSize(size int) MixedBatchWriterOption { + return func(p *MixedBatchWriter) { + p.batchSize = size + } +} + +func WithMixedBatchWriterSizeBytes(size int) MixedBatchWriterOption { + return func(p *MixedBatchWriter) { + p.batchSizeBytes = size + } +} + +func NewMixedBatchWriter(client MixedBatchClient, opts ...MixedBatchWriterOption) (*MixedBatchWriter, error) { + c := &MixedBatchWriter{ + client: client, + logger: zerolog.Nop(), + batchTimeout: defaultBatchTimeoutSeconds * time.Second, + batchSize: defaultBatchSize, + batchSizeBytes: defaultBatchSizeBytes, + } + for _, opt := range opts { + opt(c) + } + return c, nil +} + +func msgID(msg message.Message) int { + switch msg.(type) { + case *message.MigrateTable: + return msgTypeMigrateTable + case *message.Insert: + return msgTypeInsert + case *message.DeleteStale: + return msgTypeDeleteStale + } + panic("unknown message type: " + reflect.TypeOf(msg).Name()) +} + +// Write starts listening for messages on the msgChan channel and writes them to the client in batches. 
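As a usage sketch before the Write implementation below (the MixedBatchClient value and the message channel are assumed to come from the caller; the batch size is illustrative):

package example

import (
	"context"

	"github.com/cloudquery/plugin-sdk/v4/message"
	"github.com/cloudquery/plugin-sdk/v4/plugin"
	"github.com/cloudquery/plugin-sdk/v4/writers"
)

// run drives a MixedBatchClient implementation with the mixed-batch writer.
func run(ctx context.Context, client writers.MixedBatchClient, msgs <-chan message.Message) error {
	wr, err := writers.NewMixedBatchWriter(client,
		writers.WithMixedBatchWriterBatchSize(5000),
	)
	if err != nil {
		return err
	}
	// Batches stay homogeneous: a change in message type flushes the pending batch,
	// so migrations, inserts and delete-stale calls reach the client in order.
	return wr.Write(ctx, plugin.WriteOptions{}, msgs)
}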
+func (w *MixedBatchWriter) Write(ctx context.Context, options plugin.WriteOptions, msgChan <-chan message.Message) error { + migrateTable := &batchManager[*message.MigrateTable]{ + batch: make([]*message.MigrateTable, 0, w.batchSize), + writeFunc: w.client.MigrateTableBatch, + writeOptions: options, + } + insert := &insertBatchManager{ + batch: make([]*message.Insert, 0, w.batchSize), + writeFunc: w.client.InsertBatch, + maxBatchSizeBytes: int64(w.batchSizeBytes), + writeOptions: options, + } + deleteStale := &batchManager[*message.DeleteStale]{ + batch: make([]*message.DeleteStale, 0, w.batchSize), + writeFunc: w.client.DeleteStaleBatch, + writeOptions: options, + } + flush := func(msgType int) error { + switch msgType { + case msgTypeMigrateTable: + return migrateTable.flush(ctx) + case msgTypeInsert: + return insert.flush(ctx) + case msgTypeDeleteStale: + return deleteStale.flush(ctx) + default: + panic("unknown message type") + } + } + prevMsgType := -1 + var err error + for msg := range msgChan { + msgType := msgID(msg) + if prevMsgType != -1 && prevMsgType != msgType { + if err := flush(prevMsgType); err != nil { + return err + } + } + prevMsgType = msgType + switch v := msg.(type) { + case *message.MigrateTable: + err = migrateTable.append(ctx, v) + case *message.Insert: + err = insert.append(ctx, v) + case *message.DeleteStale: + err = deleteStale.append(ctx, v) + default: + panic("unknown message type") + } + if err != nil { + return err + } + } + if prevMsgType == -1 { + return nil + } + return flush(prevMsgType) +} + +// generic batch manager for most message types +type batchManager[T message.Message] struct { + batch []T + writeFunc func(ctx context.Context, messages []T, options plugin.WriteOptions) error + writeOptions plugin.WriteOptions +} + +func (m *batchManager[T]) append(ctx context.Context, msg T) error { + if len(m.batch) == cap(m.batch) { + if err := m.flush(ctx); err != nil { + return err + } + } + m.batch = append(m.batch, msg) + return nil +} + +func (m *batchManager[T]) flush(ctx context.Context) error { + if len(m.batch) == 0 { + return nil + } + + err := m.writeFunc(ctx, m.batch, m.writeOptions) + if err != nil { + return err + } + m.batch = m.batch[:0] + return nil +} + +// special batch manager for insert messages that also keeps track of the total size of the batch +type insertBatchManager struct { + batch []*message.Insert + writeFunc func(ctx context.Context, messages []*message.Insert, writeOptions plugin.WriteOptions) error + curBatchSizeBytes int64 + maxBatchSizeBytes int64 + writeOptions plugin.WriteOptions +} + +func (m *insertBatchManager) append(ctx context.Context, msg *message.Insert) error { + if len(m.batch) == cap(m.batch) || m.curBatchSizeBytes+util.TotalRecordSize(msg.Record) > m.maxBatchSizeBytes { + if err := m.flush(ctx); err != nil { + return err + } + } + m.batch = append(m.batch, msg) + m.curBatchSizeBytes += util.TotalRecordSize(msg.Record) + return nil +} + +func (m *insertBatchManager) flush(ctx context.Context) error { + if len(m.batch) == 0 { + return nil + } + + err := m.writeFunc(ctx, m.batch, m.writeOptions) + if err != nil { + return err + } + m.batch = m.batch[:0] + return nil +} diff --git a/writers/mixed_batch_test.go b/writers/mixed_batch_test.go new file mode 100644 index 0000000000..ee8c9bbc94 --- /dev/null +++ b/writers/mixed_batch_test.go @@ -0,0 +1,195 @@ +package writers + +import ( + "context" + "testing" + "time" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + 
"github.com/apache/arrow/go/v13/arrow/memory" + "github.com/cloudquery/plugin-sdk/v4/message" + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/schema" +) + +type testMixedBatchClient struct { + receivedBatches [][]message.Message +} + +func (c *testMixedBatchClient) MigrateTableBatch(_ context.Context, msgs []*message.MigrateTable, _ plugin.WriteOptions) error { + m := make([]message.Message, len(msgs)) + for i, msg := range msgs { + m[i] = msg + } + c.receivedBatches = append(c.receivedBatches, m) + return nil +} + +func (c *testMixedBatchClient) InsertBatch(_ context.Context, msgs []*message.Insert, _ plugin.WriteOptions) error { + m := make([]message.Message, len(msgs)) + for i, msg := range msgs { + m[i] = msg + } + c.receivedBatches = append(c.receivedBatches, m) + return nil +} + +func (c *testMixedBatchClient) DeleteStaleBatch(_ context.Context, msgs []*message.DeleteStale, _ plugin.WriteOptions) error { + m := make([]message.Message, len(msgs)) + for i, msg := range msgs { + m[i] = msg + } + c.receivedBatches = append(c.receivedBatches, m) + return nil +} + +var _ MixedBatchClient = (*testMixedBatchClient)(nil) + +func TestMixedBatchWriter(t *testing.T) { + ctx := context.Background() + + // message to create table1 + table1 := &schema.Table{ + Name: "table1", + Columns: []schema.Column{ + { + Name: "id", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } + msgMigrateTable1 := &message.MigrateTable{ + Table: table1, + } + + // message to create table2 + table2 := &schema.Table{ + Name: "table2", + Columns: []schema.Column{ + { + Name: "id", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } + msgMigrateTable2 := &message.MigrateTable{ + Table: table2, + } + + // message to insert into table1 + bldr1 := array.NewRecordBuilder(memory.DefaultAllocator, table1.ToArrowSchema()) + bldr1.Field(0).(*array.Int64Builder).Append(1) + rec1 := bldr1.NewRecord() + msgInsertTable1 := &message.Insert{ + Record: rec1, + } + + // message to insert into table2 + bldr2 := array.NewRecordBuilder(memory.DefaultAllocator, table1.ToArrowSchema()) + bldr2.Field(0).(*array.Int64Builder).Append(1) + rec2 := bldr2.NewRecord() + msgInsertTable2 := &message.Insert{ + Record: rec2, + } + + // message to delete stale from table1 + msgDeleteStale1 := &message.DeleteStale{ + Table: table1, + SourceName: "my-source", + SyncTime: time.Now(), + } + msgDeleteStale2 := &message.DeleteStale{ + Table: table1, + SourceName: "my-source", + SyncTime: time.Now(), + } + + testCases := []struct { + name string + messages []message.Message + wantBatches [][]message.Message + }{ + { + name: "create table, insert, delete stale", + messages: []message.Message{ + msgMigrateTable1, + msgMigrateTable2, + msgInsertTable1, + msgInsertTable2, + msgDeleteStale1, + msgDeleteStale2, + }, + wantBatches: [][]message.Message{ + {msgMigrateTable1, msgMigrateTable2}, + {msgInsertTable1, msgInsertTable2}, + {msgDeleteStale1, msgDeleteStale2}, + }, + }, + { + name: "interleaved messages", + messages: []message.Message{ + msgMigrateTable1, + msgInsertTable1, + msgDeleteStale1, + msgMigrateTable2, + msgInsertTable2, + msgDeleteStale2, + }, + wantBatches: [][]message.Message{ + {msgMigrateTable1}, + {msgInsertTable1}, + {msgDeleteStale1}, + {msgMigrateTable2}, + {msgInsertTable2}, + {msgDeleteStale2}, + }, + }, + { + name: "interleaved messages", + messages: []message.Message{ + msgMigrateTable1, + msgMigrateTable2, + msgInsertTable1, + msgDeleteStale2, + msgInsertTable2, + msgDeleteStale1, + }, + wantBatches: 
[][]message.Message{ + {msgMigrateTable1, msgMigrateTable2}, + {msgInsertTable1}, + {msgDeleteStale2}, + {msgInsertTable2}, + {msgDeleteStale1}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := &testMixedBatchClient{ + receivedBatches: make([][]message.Message, 0), + } + wr, err := NewMixedBatchWriter(client) + if err != nil { + t.Fatal(err) + } + ch := make(chan message.Message, len(tc.messages)) + for _, msg := range tc.messages { + ch <- msg + } + close(ch) + if err := wr.Write(ctx, plugin.WriteOptions{}, ch); err != nil { + t.Fatal(err) + } + if len(client.receivedBatches) != len(tc.wantBatches) { + t.Fatalf("got %d batches, want %d", len(client.receivedBatches), len(tc.wantBatches)) + } + for i, wantBatch := range tc.wantBatches { + if len(client.receivedBatches[i]) != len(wantBatch) { + t.Fatalf("got %d messages in batch %d, want %d", len(client.receivedBatches[i]), i, len(wantBatch)) + } + } + }) + } +}
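Taken together, the serve changes above reduce a plugin binary's entry point to a single call into serve.Plugin. A minimal sketch; the plugin value would come from plugin.NewPlugin(name, version, newClient), where newClient is the plugin's own client constructor (memdb.NewMemDBClient plays that role in the tests above, but it lives in an internal package and is not importable):

package example

import (
	"context"
	"log"

	"github.com/cloudquery/plugin-sdk/v4/plugin"
	"github.com/cloudquery/plugin-sdk/v4/serve"
)

// servePlugin is the whole entry point of a plugin binary under the new API.
func servePlugin(p *plugin.Plugin) {
	// WithDestinationV0V1Server is only needed by destinations that must keep
	// speaking the legacy destination v0/v1 protocols to older sources.
	if err := serve.Plugin(p, serve.WithDestinationV0V1Server()).Serve(context.Background()); err != nil {
		log.Fatal(err)
	}
}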
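The new state package, in turn, gives source plugins a typed key-value view over any state backend that speaks plugin protocol v3. A minimal sketch (the address, table name and key values are illustrative):

package example

import (
	"context"

	"github.com/cloudquery/plugin-sdk/v4/state"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// saveCursor stores and reads back a sync cursor in a state backend plugin.
func saveCursor(ctx context.Context) error {
	conn, err := grpc.DialContext(ctx, "localhost:7777", // assumed backend address
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return err
	}
	defer conn.Close()
	// NewClient asks the discovery service for supported protocol versions and
	// errors out unless the backend supports version 3.
	client, err := state.NewClient(ctx, conn, "my_source_state")
	if err != nil {
		return err
	}
	if err := client.SetKey(ctx, "cursor", "2023-06-01T00:00:00Z"); err != nil {
		return err
	}
	_, err = client.GetKey(ctx, "cursor")
	return err
}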