From e5026817efa0a463efc80735325bcdbfe064734a Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Sun, 14 May 2023 16:06:13 +0300 Subject: [PATCH 01/14] feat: Add SDKV3 support for sources and update protocol --- go.mod | 6 +- go.sum | 7 +- .../servers/destination/v1/destinations.go | 191 ++++++++ internal/servers/source/v2/source.go | 173 +++++++ plugins/destination/plugin_testing_migrate.go | 9 +- plugins/source/benchmark_test.go | 428 +++++++++++++++++ plugins/source/docs.go | 243 ++++++++++ plugins/source/docs_test.go | 146 ++++++ plugins/source/metrics.go | 125 +++++ plugins/source/metrics_test.go | 37 ++ plugins/source/options.go | 39 ++ plugins/source/plugin.go | 335 +++++++++++++ plugins/source/plugin_test.go | 452 ++++++++++++++++++ plugins/source/scheduler.go | 163 +++++++ plugins/source/scheduler_dfs.go | 230 +++++++++ plugins/source/scheduler_round_robin.go | 104 ++++ plugins/source/scheduler_round_robin_test.go | 65 +++ plugins/source/templates/all_tables.md.go.tpl | 5 + .../templates/all_tables_entry.md.go.tpl | 5 + plugins/source/templates/table.md.go.tpl | 44 ++ .../TestGeneratePluginDocs-JSON-__tables.json | 197 ++++++++ .../TestGeneratePluginDocs-Markdown-README.md | 10 + ...tePluginDocs-Markdown-incremental_table.md | 20 + ...Docs-Markdown-relation_relation_table_a.md | 21 + ...Docs-Markdown-relation_relation_table_b.md | 21 + ...eratePluginDocs-Markdown-relation_table.md | 25 + ...tGeneratePluginDocs-Markdown-test_table.md | 25 + plugins/source/testing.go | 140 ++++++ plugins/source/validate.go | 25 + scalar/LICENSE | 23 + scalar/README.md | 4 + scalar/binary.go | 74 +++ scalar/binary_test.go | 28 ++ scalar/bool.go | 66 +++ scalar/bool_test.go | 33 ++ scalar/convert.go | 161 +++++++ scalar/errors.go | 37 ++ scalar/float.go | 219 +++++++++ scalar/float_test.go | 39 ++ scalar/inet.go | 127 +++++ scalar/inet_test.go | 116 +++++ scalar/int.go | 157 ++++++ scalar/int_test.go | 40 ++ scalar/json.go | 160 +++++++ scalar/json_test.go | 59 +++ scalar/list.go | 92 ++++ scalar/list_test.go | 33 ++ scalar/mac.go | 79 +++ scalar/mac_test.go | 43 ++ scalar/scalar.go | 130 +++++ scalar/string.go | 78 +++ scalar/string_test.go | 27 ++ scalar/type_test.go | 8 + scalar/uint.go | 166 +++++++ scalar/uint_test.go | 36 ++ scalar/uuid.go | 105 ++++ scalar/uuid_test.go | 66 +++ schema/arrow.go | 38 ++ schema/arrow_test.go | 44 ++ schema/meta.go | 4 +- schema/resource.go | 74 ++- schema/table.go | 14 +- serve/destination.go | 14 +- serve/destination_v1_test.go | 187 ++++++++ serve/source.go | 233 +++++++++ serve/source_v2_test.go | 244 ++++++++++ types/mac.go | 8 +- types/register.go | 20 + types/uuid.go | 2 +- 69 files changed, 6351 insertions(+), 28 deletions(-) create mode 100644 internal/servers/destination/v1/destinations.go create mode 100644 internal/servers/source/v2/source.go create mode 100644 plugins/source/benchmark_test.go create mode 100644 plugins/source/docs.go create mode 100644 plugins/source/docs_test.go create mode 100644 plugins/source/metrics.go create mode 100644 plugins/source/metrics_test.go create mode 100644 plugins/source/options.go create mode 100644 plugins/source/plugin.go create mode 100644 plugins/source/plugin_test.go create mode 100644 plugins/source/scheduler.go create mode 100644 plugins/source/scheduler_dfs.go create mode 100644 plugins/source/scheduler_round_robin.go create mode 100644 plugins/source/scheduler_round_robin_test.go create mode 100644 plugins/source/templates/all_tables.md.go.tpl create mode 
100644 plugins/source/templates/all_tables_entry.md.go.tpl create mode 100644 plugins/source/templates/table.md.go.tpl create mode 100644 plugins/source/testdata/TestGeneratePluginDocs-JSON-__tables.json create mode 100644 plugins/source/testdata/TestGeneratePluginDocs-Markdown-README.md create mode 100644 plugins/source/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md create mode 100644 plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md create mode 100644 plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md create mode 100644 plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_table.md create mode 100644 plugins/source/testdata/TestGeneratePluginDocs-Markdown-test_table.md create mode 100644 plugins/source/testing.go create mode 100644 plugins/source/validate.go create mode 100644 scalar/LICENSE create mode 100644 scalar/README.md create mode 100644 scalar/binary.go create mode 100644 scalar/binary_test.go create mode 100644 scalar/bool.go create mode 100644 scalar/bool_test.go create mode 100644 scalar/convert.go create mode 100644 scalar/errors.go create mode 100644 scalar/float.go create mode 100644 scalar/float_test.go create mode 100644 scalar/inet.go create mode 100644 scalar/inet_test.go create mode 100644 scalar/int.go create mode 100644 scalar/int_test.go create mode 100644 scalar/json.go create mode 100644 scalar/json_test.go create mode 100644 scalar/list.go create mode 100644 scalar/list_test.go create mode 100644 scalar/mac.go create mode 100644 scalar/mac_test.go create mode 100644 scalar/scalar.go create mode 100644 scalar/string.go create mode 100644 scalar/string_test.go create mode 100644 scalar/type_test.go create mode 100644 scalar/uint.go create mode 100644 scalar/uint_test.go create mode 100644 scalar/uuid.go create mode 100644 scalar/uuid_test.go create mode 100644 schema/arrow_test.go create mode 100644 serve/destination_v1_test.go create mode 100644 serve/source.go create mode 100644 serve/source_v2_test.go create mode 100644 types/register.go diff --git a/go.mod b/go.mod index ce00954aac..2632d4666f 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 require ( github.com/apache/arrow/go/v13 v13.0.0-20230509040948-de6c3cd2b604 + github.com/bradleyjkemp/cupaloy/v2 v2.8.0 github.com/cloudquery/plugin-pb-go v1.0.8 github.com/cloudquery/plugin-sdk/v2 v2.7.0 github.com/getsentry/sentry-go v0.20.0 @@ -18,6 +19,7 @@ require ( github.com/stretchr/testify v1.8.2 github.com/thoas/go-funk v0.9.3 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 + golang.org/x/net v0.9.0 golang.org/x/sync v0.1.0 golang.org/x/text v0.9.0 google.golang.org/grpc v1.54.0 @@ -26,6 +28,8 @@ require ( replace github.com/apache/arrow/go/v13 => github.com/cloudquery/arrow/go/v13 v13.0.0-20230509053643-898a79b1d3c8 +replace github.com/cloudquery/plugin-pb-go => ../plugin-pb-go + require ( github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/thrift v0.16.0 // indirect @@ -42,11 +46,11 @@ require ( github.com/mattn/go-isatty v0.0.18 // indirect github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect + github.com/pierrec/lz4/v4 v4.1.15 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/mod v0.8.0 // indirect - golang.org/x/net v0.9.0 // indirect golang.org/x/sys v0.7.0 // indirect golang.org/x/tools 
v0.6.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/go.sum b/go.sum index 3b776d8875..48f39ba21e 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/ github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/apache/thrift v0.16.0 h1:qEy6UW60iVOlUy+b9ZR0d5WzUWYGOo4HfopoyBaNmoY= github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU= +github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M= +github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= @@ -45,8 +47,6 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudquery/arrow/go/v13 v13.0.0-20230509053643-898a79b1d3c8 h1:CmgLSEGQNLHpUQ5cU4L4aF7cuJZRnc1toIIWqC1gmPg= github.com/cloudquery/arrow/go/v13 v13.0.0-20230509053643-898a79b1d3c8/go.mod h1:/XatdE3kDIBqZKhZ7OBUHwP2jaASDFZHqF4puOWM8po= -github.com/cloudquery/plugin-pb-go v1.0.8 h1:wn3GXhcNItcP+6wUUZuzUFbvdL59liKBO37/izMi+FQ= -github.com/cloudquery/plugin-pb-go v1.0.8/go.mod h1:vAGA27psem7ZZNAY4a3S9TKuA/JDQWstjKcHPJX91Mc= github.com/cloudquery/plugin-sdk/v2 v2.7.0 h1:hRXsdEiaOxJtsn/wZMFQC9/jPfU1MeMK3KF+gPGqm7U= github.com/cloudquery/plugin-sdk/v2 v2.7.0/go.mod h1:pAX6ojIW99b/Vg4CkhnsGkRIzNaVEceYMR+Bdit73ug= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= @@ -174,6 +174,7 @@ github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8D github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= +github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -196,10 +197,12 @@ github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUq github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod 
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= diff --git a/internal/servers/destination/v1/destinations.go b/internal/servers/destination/v1/destinations.go new file mode 100644 index 0000000000..c3595a942b --- /dev/null +++ b/internal/servers/destination/v1/destinations.go @@ -0,0 +1,191 @@ +package destination + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/ipc" + pb "github.com/cloudquery/plugin-pb-go/pb/destination/v1" + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v3/plugins/destination" + "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type Server struct { + pb.UnimplementedDestinationServer + Plugin *destination.Plugin + Logger zerolog.Logger + spec specs.Destination +} + +func (s *Server) Configure(ctx context.Context, req *pb.Configure_Request) (*pb.Configure_Response, error) { + var spec specs.Destination + if err := json.Unmarshal(req.Config, &spec); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to unmarshal spec: %v", err) + } + s.spec = spec + return &pb.Configure_Response{}, s.Plugin.Init(ctx, s.Logger, spec) +} + +func (s *Server) GetName(context.Context, *pb.GetName_Request) (*pb.GetName_Response, error) { + return &pb.GetName_Response{ + Name: s.Plugin.Name(), + }, nil +} + +func (s *Server) GetVersion(context.Context, *pb.GetVersion_Request) (*pb.GetVersion_Response, error) { + return &pb.GetVersion_Response{ + Version: s.Plugin.Version(), + }, nil +} + +func (s *Server) Migrate(ctx context.Context, req *pb.Migrate_Request) (*pb.Migrate_Response, error) { + schemas, err := schema.NewSchemasFromBytes(req.Tables) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to create schemas: %v", err) + } + tables, err := schema.NewTablesFromArrowSchemas(schemas) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to create tables: %v", err) + } + s.setPKsForTables(tables) + + return &pb.Migrate_Response{}, s.Plugin.Migrate(ctx, tables) +} + +// Note the order of operations in this method is important! +// Trying to insert into the `resources` channel before starting the reader goroutine will cause a deadlock. 
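+//
+// The stream shape this handler relies on (as implemented below): the first
+// message carries the encoded table schemas, the source spec and the sync
+// timestamp; each following message carries Arrow IPC-encoded records in
+// Resource; EOF drains the writer goroutine and closes the stream.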
+func (s *Server) Write(msg pb.Destination_WriteServer) error { + resources := make(chan arrow.Record) + + r, err := msg.Recv() + if err != nil { + if err == io.EOF { + return msg.SendAndClose(&pb.Write_Response{}) + } + return status.Errorf(codes.Internal, "failed to receive msg: %v", err) + } + + schemas, err := schema.NewSchemasFromBytes(r.Tables) + if err != nil { + return status.Errorf(codes.InvalidArgument, "failed to create schemas: %v", err) + } + tables, err := schema.NewTablesFromArrowSchemas(schemas) + if err != nil { + return status.Errorf(codes.InvalidArgument, "failed to create tables: %v", err) + } + var sourceSpec specs.Source + if r.SourceSpec == nil { + // this is for backward compatibility + sourceSpec = specs.Source{ + Name: r.Source, + } + } else { + if err := json.Unmarshal(r.SourceSpec, &sourceSpec); err != nil { + return status.Errorf(codes.InvalidArgument, "failed to unmarshal source spec: %v", err) + } + } + syncTime := r.Timestamp.AsTime() + s.setPKsForTables(tables) + eg, ctx := errgroup.WithContext(msg.Context()) + eg.Go(func() error { + return s.Plugin.Write(ctx, sourceSpec, tables, syncTime, resources) + }) + + for { + r, err := msg.Recv() + if err == io.EOF { + close(resources) + if err := eg.Wait(); err != nil { + return status.Errorf(codes.Internal, "write failed: %v", err) + } + return msg.SendAndClose(&pb.Write_Response{}) + } + if err != nil { + close(resources) + if wgErr := eg.Wait(); wgErr != nil { + return status.Errorf(codes.Internal, "failed to receive msg: %v and write failed: %v", err, wgErr) + } + return status.Errorf(codes.Internal, "failed to receive msg: %v", err) + } + rdr, err := ipc.NewReader(bytes.NewReader(r.Resource)) + if err != nil { + close(resources) + if wgErr := eg.Wait(); wgErr != nil { + return status.Errorf(codes.InvalidArgument, "failed to create reader: %v and write failed: %v", err, wgErr) + } + return status.Errorf(codes.InvalidArgument, "failed to create reader: %v", err) + } + for rdr.Next() { + rec := rdr.Record() + rec.Retain() + select { + case resources <- rec: + case <-ctx.Done(): + close(resources) + if err := eg.Wait(); err != nil { + return status.Errorf(codes.Internal, "Context done: %v and failed to wait for plugin: %v", ctx.Err(), err) + } + return status.Errorf(codes.Internal, "Context done: %v", ctx.Err()) + } + } + if err := rdr.Err(); err != nil { + return status.Errorf(codes.InvalidArgument, "failed to read resource: %v", err) + } + } +} + +func setCQIDAsPrimaryKeysForTables(tables schema.Tables) { + for _, table := range tables { + for i, col := range table.Columns { + table.Columns[i].CreationOptions.PrimaryKey = col.Name == schema.CqIDColumn.Name + } + setCQIDAsPrimaryKeysForTables(table.Relations) + } +} + +func (s *Server) GetMetrics(context.Context, *pb.GetDestinationMetrics_Request) (*pb.GetDestinationMetrics_Response, error) { + stats := s.Plugin.Metrics() + b, err := json.Marshal(stats) + if err != nil { + return nil, fmt.Errorf("failed to marshal stats: %w", err) + } + return &pb.GetDestinationMetrics_Response{ + Metrics: b, + }, nil +} + +func (s *Server) DeleteStale(ctx context.Context, req *pb.DeleteStale_Request) (*pb.DeleteStale_Response, error) { + schemas, err := schema.NewSchemasFromBytes(req.Tables) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to create schemas: %v", err) + } + tables, err := schema.NewTablesFromArrowSchemas(schemas) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to create tables: %v", err) + } + + if err := 
s.Plugin.DeleteStale(ctx, tables, req.Source, req.Timestamp.AsTime()); err != nil { + return nil, err + } + + return &pb.DeleteStale_Response{}, nil +} + +func (s *Server) setPKsForTables(tables schema.Tables) { + if s.spec.PKMode == specs.PKModeCQID { + setCQIDAsPrimaryKeysForTables(tables) + } +} + +func (s *Server) Close(ctx context.Context, _ *pb.Close_Request) (*pb.Close_Response, error) { + return &pb.Close_Response{}, s.Plugin.Close(ctx) +} diff --git a/internal/servers/source/v2/source.go b/internal/servers/source/v2/source.go new file mode 100644 index 0000000000..1909b85b78 --- /dev/null +++ b/internal/servers/source/v2/source.go @@ -0,0 +1,173 @@ +package source + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/ipc" + "github.com/apache/arrow/go/v13/arrow/memory" + pb "github.com/cloudquery/plugin-pb-go/pb/source/v2" + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v3/plugins/source" + "github.com/cloudquery/plugin-sdk/v3/scalar" + "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/getsentry/sentry-go" + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +const MaxMsgSize = 100 * 1024 * 1024 // 100 MiB + +type Server struct { + pb.UnimplementedSourceServer + Plugin *source.Plugin + Logger zerolog.Logger +} + +func (s *Server) GetTables(context.Context, *pb.GetTables_Request) (*pb.GetTables_Response, error) { + tables := s.Plugin.Tables().ToArrowSchemas() + encoded, err := tables.Encode() + if err != nil { + return nil, fmt.Errorf("failed to encode tables: %w", err) + } + return &pb.GetTables_Response{ + Tables: encoded, + }, nil +} + +func (s *Server) GetDynamicTables(context.Context, *pb.GetDynamicTables_Request) (*pb.GetDynamicTables_Response, error) { + tables := s.Plugin.GetDynamicTables().ToArrowSchemas() + encoded, err := tables.Encode() + if err != nil { + return nil, fmt.Errorf("failed to encode tables: %w", err) + } + return &pb.GetDynamicTables_Response{ + Tables: encoded, + }, nil +} + +func (s *Server) GetName(context.Context, *pb.GetName_Request) (*pb.GetName_Response, error) { + return &pb.GetName_Response{ + Name: s.Plugin.Name(), + }, nil +} + +func (s *Server) GetVersion(context.Context, *pb.GetVersion_Request) (*pb.GetVersion_Response, error) { + return &pb.GetVersion_Response{ + Version: s.Plugin.Version(), + }, nil +} + +func (s *Server) Init(ctx context.Context, req *pb.Init_Request) (*pb.Init_Response, error) { + var spec specs.Source + dec := json.NewDecoder(bytes.NewReader(req.Spec)) + dec.UseNumber() + // TODO: warn about unknown fields + if err := dec.Decode(&spec); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to decode spec: %v", err) + } + + if err := s.Plugin.Init(ctx, spec); err != nil { + return nil, status.Errorf(codes.Internal, "failed to init plugin: %v", err) + } + return &pb.Init_Response{}, nil +} + +func (s *Server) Sync(_ *pb.Sync_Request, stream pb.Source_SyncServer) error { + resources := make(chan *schema.Resource) + var syncErr error + ctx := stream.Context() + + go func() { + defer close(resources) + err := s.Plugin.Sync(ctx, resources) + if err != nil { + syncErr = fmt.Errorf("failed to sync resources: %w", err) + } + }() + + for resource := range resources { + vector := resource.GetValues() + bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema()) 
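+		// build a single-row Arrow record from the values resolved for this resource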
+		scalar.AppendToRecordBuilder(bldr, vector)
+		rec := bldr.NewRecord()
+
+		var buf bytes.Buffer
+		w := ipc.NewWriter(&buf, ipc.WithSchema(rec.Schema()))
+		if err := w.Write(rec); err != nil {
+			return status.Errorf(codes.Internal, "failed to write record: %v", err)
+		}
+		if err := w.Close(); err != nil {
+			return status.Errorf(codes.Internal, "failed to close writer: %v", err)
+		}
+
+		msg := &pb.Sync_Response{
+			Resource: buf.Bytes(),
+		}
+		err := checkMessageSize(msg, resource)
+		if err != nil {
+			s.Logger.Warn().Str("table", resource.Table.Name).
+				Int("bytes", proto.Size(msg)).
+				Msg("Row exceeding max bytes ignored")
+			continue
+		}
+		if err := stream.Send(msg); err != nil {
+			return status.Errorf(codes.Internal, "failed to send resource: %v", err)
+		}
+	}
+
+	return syncErr
+}
+
+func (s *Server) GetMetrics(context.Context, *pb.GetMetrics_Request) (*pb.GetMetrics_Response, error) {
+	// Aggregate metrics before sending to keep response size small.
+	// Temporary fix for https://github.com/cloudquery/cloudquery/issues/3962
+	m := s.Plugin.Metrics()
+	agg := &source.TableClientMetrics{}
+	for _, table := range m.TableClient {
+		for _, tableClient := range table {
+			agg.Resources += tableClient.Resources
+			agg.Errors += tableClient.Errors
+			agg.Panics += tableClient.Panics
+		}
+	}
+	b, err := json.Marshal(&source.Metrics{
+		TableClient: map[string]map[string]*source.TableClientMetrics{"": {"": agg}},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal source metrics: %w", err)
+	}
+	return &pb.GetMetrics_Response{
+		Metrics: b,
+	}, nil
+}
+
+func (s *Server) GenDocs(_ context.Context, req *pb.GenDocs_Request) (*pb.GenDocs_Response, error) {
+	err := s.Plugin.GeneratePluginDocs(req.Path, req.Format.String())
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate docs: %w", err)
+	}
+	return &pb.GenDocs_Response{}, nil
+}
+
+func checkMessageSize(msg proto.Message, resource *schema.Resource) error {
+	size := proto.Size(msg)
+	// log error to Sentry if row exceeds half of the max size
+	if size > MaxMsgSize/2 {
+		sentry.WithScope(func(scope *sentry.Scope) {
+			scope.SetTag("table", resource.Table.Name)
+			scope.SetExtra("bytes", size)
+			sentry.CurrentHub().CaptureMessage("Large message detected")
+		})
+	}
+	if size > MaxMsgSize {
+		return errors.New("message exceeds max size")
+	}
+	return nil
+}
diff --git a/plugins/destination/plugin_testing_migrate.go b/plugins/destination/plugin_testing_migrate.go
index 5c859583c5..70ac7a0cc5 100644
--- a/plugins/destination/plugin_testing_migrate.go
+++ b/plugins/destination/plugin_testing_migrate.go
@@ -98,6 +98,7 @@ func (*PluginTestSuite) destinationPluginTestMigrate(
 	}
 	tableName := "add_column_" + tableUUIDSuffix()
 	source := &schema.Table{
+		Name: tableName,
 		Columns: schema.ColumnList{
 			schema.CqSourceNameColumn,
 			
schema.CqSyncTimeColumn, @@ -213,7 +216,7 @@ func (*PluginTestSuite) destinationPluginTestMigrate( schema.CqSourceNameColumn, schema.CqSyncTimeColumn, schema.CqIDColumn, - {Name: "id", Type: types.ExtensionTypes.UUID, NotNull: true}, + {Name: "id", Type: types.ExtensionTypes.UUID}, }} p := newPlugin() diff --git a/plugins/source/benchmark_test.go b/plugins/source/benchmark_test.go new file mode 100644 index 0000000000..eb81e31b9f --- /dev/null +++ b/plugins/source/benchmark_test.go @@ -0,0 +1,428 @@ +package source + +import ( + "context" + "fmt" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" +) + +type BenchmarkScenario struct { + Client Client + Scheduler specs.Scheduler + Clients int + Tables int + ChildrenPerTable int + Columns int + ColumnResolvers int // number of columns with custom resolvers + ResourcesPerTable int + ResourcesPerPage int + NoPreResourceResolver bool + Concurrency uint64 +} + +func (s *BenchmarkScenario) SetDefaults() { + if s.Clients == 0 { + s.Clients = 1 + } + if s.Tables == 0 { + s.Tables = 1 + } + if s.Columns == 0 { + s.Columns = 10 + } + if s.ResourcesPerTable == 0 { + s.ResourcesPerTable = 100 + } + if s.ResourcesPerPage == 0 { + s.ResourcesPerPage = 10 + } +} + +type Client interface { + Call(clientID, tableName string) error +} + +type Benchmark struct { + *BenchmarkScenario + + b *testing.B + tables []*schema.Table + plugin *Plugin + + apiCalls atomic.Int64 +} + +func NewBenchmark(b *testing.B, scenario BenchmarkScenario) *Benchmark { + scenario.SetDefaults() + sb := &Benchmark{ + BenchmarkScenario: &scenario, + b: b, + tables: nil, + plugin: nil, + } + sb.setup(b) + return sb +} + +func (s *Benchmark) setup(b *testing.B) { + createResolvers := func(tableName string) (schema.TableResolver, schema.RowResolver, schema.ColumnResolver) { + tableResolver := func(ctx context.Context, meta schema.ClientMeta, parent *schema.Resource, res chan<- any) error { + total := 0 + for total < s.ResourcesPerTable { + s.simulateAPICall(meta.ID(), tableName) + num := min(s.ResourcesPerPage, s.ResourcesPerTable-total) + resources := make([]struct { + Column1 string + }, num) + for i := 0; i < num; i++ { + resources[i] = struct { + Column1 string + }{ + Column1: "test-column", + } + } + res <- resources + total += num + } + return nil + } + preResourceResolver := func(ctx context.Context, meta schema.ClientMeta, resource *schema.Resource) error { + s.simulateAPICall(meta.ID(), tableName) + resource.Item = struct { + Column1 string + }{ + Column1: "test-pre", + } + return nil + } + columnResolver := func(ctx context.Context, meta schema.ClientMeta, resource *schema.Resource, c schema.Column) error { + s.simulateAPICall(meta.ID(), tableName) + return resource.Set(c.Name, "test") + } + return tableResolver, preResourceResolver, columnResolver + } + + s.tables = make([]*schema.Table, s.Tables) + for i := 0; i < s.Tables; i++ { + tableResolver, preResourceResolver, columnResolver := createResolvers(fmt.Sprintf("table%d", i)) + columns := make([]schema.Column, s.Columns) + for u := 0; u < s.Columns; u++ { + columns[u] = schema.Column{ + Name: fmt.Sprintf("column%d", u), + Type: arrow.BinaryTypes.String, + } + if u < s.ColumnResolvers { + columns[u].Resolver = columnResolver + } + } + relations := make([]*schema.Table, s.ChildrenPerTable) + for u := 0; u < s.ChildrenPerTable; u++ { 
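+			// child tables reuse the parent's columns and resolvers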
+ relations[u] = &schema.Table{ + Name: fmt.Sprintf("table%d_child%d", i, u), + Columns: columns, + Resolver: tableResolver, + } + if !s.NoPreResourceResolver { + relations[u].PreResourceResolver = preResourceResolver + } + } + s.tables[i] = &schema.Table{ + Name: fmt.Sprintf("table%d", i), + Columns: columns, + Relations: relations, + Resolver: tableResolver, + Multiplex: nMultiplexer(s.Clients), + } + if !s.NoPreResourceResolver { + s.tables[i].PreResourceResolver = preResourceResolver + } + for u := range relations { + relations[u].Parent = s.tables[i] + } + } + + plugin := NewPlugin( + "testPlugin", + "1.0.0", + s.tables, + newTestExecutionClient, + ) + plugin.SetLogger(zerolog.New(zerolog.NewTestWriter(b)).Level(zerolog.WarnLevel)) + s.plugin = plugin + s.b = b +} + +func (s *Benchmark) simulateAPICall(clientID, tableName string) { + for { + s.apiCalls.Add(1) + err := s.Client.Call(clientID, tableName) + if err == nil { + // if no error, we are done + break + } + // if error, we have to retry + // we simulate a random backoff + time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (s *Benchmark) Run() { + for n := 0; n < s.b.N; n++ { + s.b.StopTimer() + ctx := context.Background() + spec := specs.Source{ + Name: "testSource", + Path: "cloudquery/testSource", + Tables: []string{"*"}, + Version: "v1.0.0", + Destinations: []string{"test"}, + Concurrency: s.Concurrency, + Scheduler: s.Scheduler, + } + if err := s.plugin.Init(ctx, spec); err != nil { + s.b.Fatal(err) + } + resources := make(chan *schema.Resource) + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + defer close(resources) + return s.plugin.Sync(ctx, + resources) + }) + s.b.StartTimer() + start := time.Now() + + totalResources := 0 + for range resources { + // read resources channel until empty + totalResources++ + } + if err := g.Wait(); err != nil { + s.b.Fatal(err) + } + + end := time.Now() + s.b.ReportMetric(0, "ns/op") // drop default ns/op output + s.b.ReportMetric(float64(totalResources)/(end.Sub(start).Seconds()), "resources/s") + + // Enable the below metrics for more verbose information about the scenario: + // s.b.ReportMetric(float64(s.apiCalls.Load())/(end.Sub(start).Seconds()), "api-calls/s") + // s.b.ReportMetric(float64(totalResources), "resources") + // s.b.ReportMetric(float64(s.apiCalls.Load()), "apiCalls") + } +} + +type benchmarkClient struct { + num int +} + +func (b benchmarkClient) ID() string { + return fmt.Sprintf("client%d", b.num) +} + +func nMultiplexer(n int) schema.Multiplexer { + return func(meta schema.ClientMeta) []schema.ClientMeta { + clients := make([]schema.ClientMeta, n) + for i := 0; i < n; i++ { + clients[i] = benchmarkClient{ + num: i, + } + } + return clients + } +} + +func BenchmarkDefaultConcurrencyDFS(b *testing.B) { + benchmarkWithScheduler(b, specs.SchedulerDFS) +} + +func BenchmarkDefaultConcurrencyRoundRobin(b *testing.B) { + benchmarkWithScheduler(b, specs.SchedulerRoundRobin) +} + +func benchmarkWithScheduler(b *testing.B, scheduler specs.Scheduler) { + b.ReportAllocs() + minTime := 1 * time.Millisecond + mean := 10 * time.Millisecond + stdDev := 100 * time.Millisecond + client := NewDefaultClient(minTime, mean, stdDev) + bs := BenchmarkScenario{ + Client: client, + Clients: 25, + Tables: 5, + Columns: 10, + ColumnResolvers: 1, + ResourcesPerTable: 100, + ResourcesPerPage: 50, + Scheduler: scheduler, + } + sb := NewBenchmark(b, bs) + sb.Run() +} + +func 
BenchmarkTablesWithChildrenDFS(b *testing.B) {
+	benchmarkTablesWithChildrenScheduler(b, specs.SchedulerDFS)
+}
+
+func BenchmarkTablesWithChildrenRoundRobin(b *testing.B) {
+	benchmarkTablesWithChildrenScheduler(b, specs.SchedulerRoundRobin)
+}
+
+func benchmarkTablesWithChildrenScheduler(b *testing.B, scheduler specs.Scheduler) {
+	b.ReportAllocs()
+	minTime := 1 * time.Millisecond
+	mean := 10 * time.Millisecond
+	stdDev := 100 * time.Millisecond
+	client := NewDefaultClient(minTime, mean, stdDev)
+	bs := BenchmarkScenario{
+		Client:            client,
+		Clients:           2,
+		Tables:            2,
+		ChildrenPerTable:  2,
+		Columns:           10,
+		ColumnResolvers:   1,
+		ResourcesPerTable: 100,
+		ResourcesPerPage:  50,
+		Scheduler:         scheduler,
+	}
+	sb := NewBenchmark(b, bs)
+	sb.Run()
+}
+
+type DefaultClient struct {
+	min, stdDev, mean time.Duration
+}
+
+func NewDefaultClient(min, mean, stdDev time.Duration) *DefaultClient {
+	if min == 0 {
+		min = time.Millisecond
+	}
+	if mean == 0 {
+		mean = 10 * time.Millisecond
+	}
+	if stdDev == 0 {
+		stdDev = 100 * time.Millisecond
+	}
+	return &DefaultClient{
+		min:    min,
+		mean:   mean,
+		stdDev: stdDev,
+	}
+}
+
+func (c *DefaultClient) Call(_, _ string) error {
+	sample := int(rand.NormFloat64()*float64(c.stdDev) + float64(c.mean))
+	duration := time.Duration(sample)
+	if duration < c.min {
+		duration = c.min
+	}
+	time.Sleep(duration)
+	return nil
+}
+
+type RateLimitClient struct {
+	*DefaultClient
+	calls             map[string][]time.Time
+	callsLock         sync.Mutex
+	window            time.Duration
+	maxCallsPerWindow int
+}
+
+func NewRateLimitClient(min, mean, stdDev time.Duration, maxCallsPerWindow int, window time.Duration) *RateLimitClient {
+	return &RateLimitClient{
+		DefaultClient:     NewDefaultClient(min, mean, stdDev),
+		calls:             map[string][]time.Time{},
+		window:            window,
+		maxCallsPerWindow: maxCallsPerWindow,
+	}
+}
+
+func (r *RateLimitClient) Call(clientID, table string) error {
+	// this will sleep for the appropriate amount of time before responding
+	err := r.DefaultClient.Call(clientID, table)
+	if err != nil {
+		return err
+	}
+
+	r.callsLock.Lock()
+	defer r.callsLock.Unlock()
+
+	// limit the number of calls per window by table
+	key := table
+
+	// remove calls from outside the call window
+	updated := make([]time.Time, 0, len(r.calls[key]))
+	for i := range r.calls[key] {
+		if time.Since(r.calls[key][i]) < r.window {
+			updated = append(updated, r.calls[key][i])
+		}
+	}
+
+	// return error if we've exceeded the max calls in the time window
+	if len(updated) >= r.maxCallsPerWindow {
+		return fmt.Errorf("rate limit exceeded")
+	}
+
+	r.calls[key] = append(updated, time.Now())
+	return nil
+}
+
+// BenchmarkTablesWithRateLimiting represents a benchmark scenario where rate limiting is applied
+// by the cloud provider. In this rate limiter, the limit is applied globally per table.
+// This mirrors the behavior of GCP, where rate limiting is applied per project *token*, not
+// per project. A good scheduler should spread the load across tables so that other tables can make
+// progress while waiting for the rate limit to reset.
+func BenchmarkTablesWithRateLimitingDFS(b *testing.B) {
+	benchmarkTablesWithRateLimitingScheduler(b, specs.SchedulerDFS)
+}
+
+func BenchmarkTablesWithRateLimitingRoundRobin(b *testing.B) {
+	benchmarkTablesWithRateLimitingScheduler(b, specs.SchedulerRoundRobin)
+}
+
+// In this benchmark, we set up a scenario where each table has a global rate limit of 1 call per 100ms.
+// Every table requires 1 call to resolve, and has 10 clients. 
This means, at best, each table can resolve in 1 second. +// We have 100 such tables and a concurrency that allows 1000 calls at a time. A good scheduler for this scenario +// should be able to resolve all tables in a bit more than 1 second. +func benchmarkTablesWithRateLimitingScheduler(b *testing.B, scheduler specs.Scheduler) { + b.ReportAllocs() + minTime := 1 * time.Millisecond + mean := 1 * time.Millisecond + stdDev := 1 * time.Millisecond + maxCallsPerWindow := 1 + window := 100 * time.Millisecond + c := NewRateLimitClient(minTime, mean, stdDev, maxCallsPerWindow, window) + + bs := BenchmarkScenario{ + Client: c, + Scheduler: scheduler, + Clients: 10, + Tables: 100, + ChildrenPerTable: 0, + Columns: 10, + ColumnResolvers: 0, + ResourcesPerTable: 1, + ResourcesPerPage: 1, + Concurrency: 1000, + NoPreResourceResolver: true, + } + sb := NewBenchmark(b, bs) + sb.Run() +} diff --git a/plugins/source/docs.go b/plugins/source/docs.go new file mode 100644 index 0000000000..b9b10d11f4 --- /dev/null +++ b/plugins/source/docs.go @@ -0,0 +1,243 @@ +package source + +import ( + "bytes" + "embed" + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "text/template" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-sdk/v3/caser" + "github.com/cloudquery/plugin-sdk/v3/schema" +) + +//go:embed templates/*.go.tpl +var templatesFS embed.FS + +var reMatchNewlines = regexp.MustCompile(`\n{3,}`) +var reMatchHeaders = regexp.MustCompile(`(#{1,6}.+)\n+`) + +var DefaultTitleExceptions = map[string]string{ + // common abbreviations + "acl": "ACL", + "acls": "ACLs", + "api": "API", + "apis": "APIs", + "ca": "CA", + "cidr": "CIDR", + "cidrs": "CIDRs", + "db": "DB", + "dbs": "DBs", + "dhcp": "DHCP", + "iam": "IAM", + "iot": "IOT", + "ip": "IP", + "ips": "IPs", + "ipv4": "IPv4", + "ipv6": "IPv6", + "mfa": "MFA", + "ml": "ML", + "oauth": "OAuth", + "vpc": "VPC", + "vpcs": "VPCs", + "vpn": "VPN", + "vpns": "VPNs", + "waf": "WAF", + "wafs": "WAFs", + + // cloud providers + "aws": "AWS", + "gcp": "GCP", +} + +func DefaultTitleTransformer(table *schema.Table) string { + if table.Title != "" { + return table.Title + } + csr := caser.New(caser.WithCustomExceptions(DefaultTitleExceptions)) + return csr.ToTitle(table.Name) +} + +func sortTables(tables schema.Tables) { + sort.SliceStable(tables, func(i, j int) bool { + return tables[i].Name < tables[j].Name + }) + + for _, table := range tables { + sortTables(table.Relations) + } +} + +type templateData struct { + PluginName string + Tables schema.Tables +} + +// GeneratePluginDocs creates table documentation for the source plugin based on its list of tables +func (p *Plugin) GeneratePluginDocs(dir, format string) error { + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + return err + } + + setDestinationManagedCqColumns(p.Tables()) + + sortedTables := make(schema.Tables, 0, len(p.Tables())) + for _, t := range p.Tables() { + sortedTables = append(sortedTables, t.Copy(nil)) + } + sortTables(sortedTables) + + switch format { + case "markdown": + return p.renderTablesAsMarkdown(dir, p.name, sortedTables) + case "json": + return p.renderTablesAsJSON(dir, sortedTables) + default: + return fmt.Errorf("unsupported format: %v", format) + } +} + +// setDestinationManagedCqColumns overwrites or adds the CQ columns that are managed by the destination plugins (_cq_sync_time, _cq_source_name). 
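+// It recurses into relation tables so nested tables are documented with the same destination-managed columns.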
+func setDestinationManagedCqColumns(tables []*schema.Table) {
+	for _, table := range tables {
+		table.OverwriteOrAddColumn(&schema.CqSyncTimeColumn)
+		table.OverwriteOrAddColumn(&schema.CqSourceNameColumn)
+		setDestinationManagedCqColumns(table.Relations)
+	}
+}
+
+type jsonTable struct {
+	Name        string       `json:"name"`
+	Title       string       `json:"title"`
+	Description string       `json:"description"`
+	Columns     []jsonColumn `json:"columns"`
+	Relations   []jsonTable  `json:"relations"`
+}
+
+type jsonColumn struct {
+	Name             string `json:"name"`
+	Type             string `json:"type"`
+	IsPrimaryKey     bool   `json:"is_primary_key,omitempty"`
+	IsIncrementalKey bool   `json:"is_incremental_key,omitempty"`
+}
+
+func (p *Plugin) renderTablesAsJSON(dir string, tables schema.Tables) error {
+	jsonTables := p.jsonifyTables(tables)
+	b, err := json.MarshalIndent(jsonTables, "", "  ")
+	if err != nil {
+		return fmt.Errorf("failed to marshal tables as json: %v", err)
+	}
+	outputPath := filepath.Join(dir, "__tables.json")
+	return os.WriteFile(outputPath, b, 0644)
+}
+
+func (p *Plugin) jsonifyTables(tables schema.Tables) []jsonTable {
+	jsonTables := make([]jsonTable, len(tables))
+	for i, table := range tables {
+		jsonColumns := make([]jsonColumn, len(table.Columns))
+		for c, col := range table.Columns {
+			jsonColumns[c] = jsonColumn{
+				Name:             col.Name,
+				Type:             col.Type.String(),
+				IsPrimaryKey:     col.CreationOptions.PrimaryKey,
+				IsIncrementalKey: col.CreationOptions.IncrementalKey,
+			}
+		}
+		jsonTables[i] = jsonTable{
+			Name:        table.Name,
+			Title:       p.titleTransformer(table),
+			Description: table.Description,
+			Columns:     jsonColumns,
+			Relations:   p.jsonifyTables(table.Relations),
+		}
+	}
+	return jsonTables
+}
+
+func (p *Plugin) renderTablesAsMarkdown(dir string, pluginName string, tables schema.Tables) error {
+	for _, table := range tables {
+		if err := p.renderAllTables(table, dir); err != nil {
+			return err
+		}
+	}
+	t, err := template.New("all_tables.md.go.tpl").Funcs(template.FuncMap{
+		"indentToDepth": indentToDepth,
+	}).ParseFS(templatesFS, "templates/all_tables*.md.go.tpl")
+	if err != nil {
+		return fmt.Errorf("failed to parse template for README.md: %v", err)
+	}
+
+	var b bytes.Buffer
+	if err := t.Execute(&b, templateData{PluginName: pluginName, Tables: tables}); err != nil {
+		return fmt.Errorf("failed to execute template: %v", err)
+	}
+	content := formatMarkdown(b.String())
+	outputPath := filepath.Join(dir, "README.md")
+	f, err := os.Create(outputPath)
+	if err != nil {
+		return fmt.Errorf("failed to create file %v: %v", outputPath, err)
+	}
+	f.WriteString(content)
+	return f.Close()
+}
+
+func (p *Plugin) renderAllTables(t *schema.Table, dir string) error {
+	if err := p.renderTable(t, dir); err != nil {
+		return err
+	}
+	for _, r := range t.Relations {
+		if err := p.renderAllTables(r, dir); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (p *Plugin) renderTable(table *schema.Table, dir string) error {
+	t := template.New("").Funcs(map[string]any{
+		"formatType": formatType,
+		"title":      p.titleTransformer,
+	})
+	t, err := t.New("table.md.go.tpl").ParseFS(templatesFS, "templates/table.md.go.tpl")
+	if err != nil {
+		return fmt.Errorf("failed to parse template: %v", err)
+	}
+
+	outputPath := filepath.Join(dir, fmt.Sprintf("%s.md", table.Name))
+
+	var b bytes.Buffer
+	if err := t.Execute(&b, table); err != nil {
+		return fmt.Errorf("failed to execute template: %v", err)
+	}
+	content := formatMarkdown(b.String())
+	f, err := os.Create(outputPath)
+	if err != nil {
+		return fmt.Errorf("failed to create file %v: 
%v", outputPath, err) + } + f.WriteString(content) + return f.Close() +} + +func formatMarkdown(s string) string { + s = reMatchNewlines.ReplaceAllString(s, "\n\n") + return reMatchHeaders.ReplaceAllString(s, `$1`+"\n\n") +} + +func formatType(v arrow.DataType) string { + return v.String() +} + +func indentToDepth(table *schema.Table) string { + s := "" + t := table + for t.Parent != nil { + s += " " + t = t.Parent + } + return s +} diff --git a/plugins/source/docs_test.go b/plugins/source/docs_test.go new file mode 100644 index 0000000000..7097668dfe --- /dev/null +++ b/plugins/source/docs_test.go @@ -0,0 +1,146 @@ +//go:build !windows + +package source + +import ( + "os" + "path" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/bradleyjkemp/cupaloy/v2" + "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/stretchr/testify/require" +) + +var testTables = []*schema.Table{ + { + Name: "test_table", + Description: "Description for test table", + Columns: []schema.Column{ + { + Name: "int_col", + Type: arrow.PrimitiveTypes.Int64, + }, + { + Name: "id_col", + Type: arrow.PrimitiveTypes.Int64, + CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true}, + }, + { + Name: "id_col2", + Type: arrow.PrimitiveTypes.Int64, + CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true}, + }, + }, + Relations: []*schema.Table{ + { + Name: "relation_table", + Description: "Description for relational table", + Columns: []schema.Column{ + { + Name: "string_col", + Type: arrow.BinaryTypes.String, + }, + }, + Relations: []*schema.Table{ + { + Name: "relation_relation_table_b", + Description: "Description for relational table's relation", + Columns: []schema.Column{ + { + Name: "string_col", + Type: arrow.BinaryTypes.String, + }, + }, + }, + { + Name: "relation_relation_table_a", + Description: "Description for relational table's relation", + Columns: []schema.Column{ + { + Name: "string_col", + Type: arrow.BinaryTypes.String, + }, + }, + }, + }, + }, + { + Name: "relation_table2", + Description: "Description for second relational table", + Columns: []schema.Column{ + { + Name: "string_col", + Type: arrow.BinaryTypes.String, + }, + }, + }, + }, + }, + { + Name: "incremental_table", + Description: "Description for incremental table", + IsIncremental: true, + Columns: []schema.Column{ + { + Name: "int_col", + Type: arrow.PrimitiveTypes.Int64, + }, + { + Name: "id_col", + Type: arrow.PrimitiveTypes.Int64, + CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true, IncrementalKey: true}, + }, + { + Name: "id_col2", + Type: arrow.PrimitiveTypes.Int64, + CreationOptions: schema.ColumnCreationOptions{IncrementalKey: true}, + }, + }, + }, +} + +func TestGeneratePluginDocs(t *testing.T) { + p := NewPlugin("test", "v1.0.0", testTables, newTestExecutionClient) + + cup := cupaloy.New(cupaloy.SnapshotSubdirectory("testdata")) + + t.Run("Markdown", func(t *testing.T) { + tmpdir := t.TempDir() + + err := p.GeneratePluginDocs(tmpdir, "markdown") + if err != nil { + t.Fatalf("unexpected error calling GeneratePluginDocs: %v", err) + } + + expectFiles := []string{"test_table.md", "relation_table.md", "relation_relation_table_a.md", "relation_relation_table_b.md", "incremental_table.md", "README.md"} + for _, exp := range expectFiles { + t.Run(exp, func(t *testing.T) { + output := path.Join(tmpdir, exp) + got, err := os.ReadFile(output) + require.NoError(t, err) + cup.SnapshotT(t, got) + }) + } + }) + + t.Run("JSON", func(t *testing.T) { + tmpdir := t.TempDir() + + err := 
p.GeneratePluginDocs(tmpdir, "json")
+		if err != nil {
+			t.Fatalf("unexpected error calling GeneratePluginDocs: %v", err)
+		}
+
+		expectFiles := []string{"__tables.json"}
+		for _, exp := range expectFiles {
+			t.Run(exp, func(t *testing.T) {
+				output := path.Join(tmpdir, exp)
+				got, err := os.ReadFile(output)
+				require.NoError(t, err)
+				cup.SnapshotT(t, got)
+			})
+		}
+	})
+}
diff --git a/plugins/source/metrics.go b/plugins/source/metrics.go
new file mode 100644
index 0000000000..a4924a664c
--- /dev/null
+++ b/plugins/source/metrics.go
@@ -0,0 +1,125 @@
+package source
+
+import (
+	"sync/atomic"
+	"time"
+
+	"github.com/cloudquery/plugin-sdk/v3/schema"
+)
+
+type Metrics struct {
+	TableClient map[string]map[string]*TableClientMetrics
+}
+
+type TableClientMetrics struct {
+	Resources uint64
+	Errors    uint64
+	Panics    uint64
+	StartTime time.Time
+	EndTime   time.Time
+}
+
+func (s *TableClientMetrics) Equal(other *TableClientMetrics) bool {
+	return s.Resources == other.Resources && s.Errors == other.Errors && s.Panics == other.Panics
+}
+
+// Equal compares two stats. Mostly useful in testing
+func (s *Metrics) Equal(other *Metrics) bool {
+	for table, clientStats := range s.TableClient {
+		for client, stats := range clientStats {
+			if _, ok := other.TableClient[table]; !ok {
+				return false
+			}
+			if _, ok := other.TableClient[table][client]; !ok {
+				return false
+			}
+			if !stats.Equal(other.TableClient[table][client]) {
+				return false
+			}
+		}
+	}
+	for table, clientStats := range other.TableClient {
+		for client, stats := range clientStats {
+			if _, ok := s.TableClient[table]; !ok {
+				return false
+			}
+			if _, ok := s.TableClient[table][client]; !ok {
+				return false
+			}
+			if !stats.Equal(s.TableClient[table][client]) {
+				return false
+			}
+		}
+	}
+	return true
+}
+
+func (s *Metrics) initWithClients(table *schema.Table, clients []schema.ClientMeta) {
+	s.TableClient[table.Name] = make(map[string]*TableClientMetrics, len(clients))
+	for _, client := range clients {
+		s.TableClient[table.Name][client.ID()] = &TableClientMetrics{}
+	}
+	for _, relation := range table.Relations {
+		s.initWithClients(relation, clients)
+	}
+}
+
+func (s *Metrics) TotalErrors() uint64 {
+	var total uint64
+	for _, clientMetrics := range s.TableClient {
+		for _, metrics := range clientMetrics {
+			total += metrics.Errors
+		}
+	}
+	return total
+}
+
+func (s *Metrics) TotalErrorsAtomic() uint64 {
+	var total uint64
+	for _, clientMetrics := range s.TableClient {
+		for _, metrics := range clientMetrics {
+			total += atomic.LoadUint64(&metrics.Errors)
+		}
+	}
+	return total
+}
+
+func (s *Metrics) TotalPanics() uint64 {
+	var total uint64
+	for _, clientMetrics := range s.TableClient {
+		for _, metrics := range clientMetrics {
+			total += metrics.Panics
+		}
+	}
+	return total
+}
+
+func (s *Metrics) TotalPanicsAtomic() uint64 {
+	var total uint64
+	for _, clientMetrics := range s.TableClient {
+		for _, metrics := range clientMetrics {
+			total += atomic.LoadUint64(&metrics.Panics)
+		}
+	}
+	return total
+}
+
+func (s *Metrics) TotalResources() uint64 {
+	var total uint64
+	for _, clientMetrics := range s.TableClient {
+		for _, metrics := range clientMetrics {
+			total += metrics.Resources
+		}
+	}
+	return total
+}
+
+func (s *Metrics) TotalResourcesAtomic() uint64 {
+	var total uint64
+	for _, clientMetrics := range s.TableClient {
+		for _, metrics := range clientMetrics {
+			total += atomic.LoadUint64(&metrics.Resources)
+		}
+	}
+	return total
+}
diff --git a/plugins/source/metrics_test.go b/plugins/source/metrics_test.go
new file mode 100644
index 0000000000..b18cdb387a
--- /dev/null
+++ b/plugins/source/metrics_test.go
@@ -0,0 +1,37 @@
+package source
+
+import "testing"
+
+func TestMetrics(t *testing.T) {
+	s := &Metrics{
+		TableClient: make(map[string]map[string]*TableClientMetrics),
+	}
+	s.TableClient["test_table"] = make(map[string]*TableClientMetrics)
+	s.TableClient["test_table"]["testExecutionClient"] = &TableClientMetrics{
+		Resources: 1,
+		Errors:    2,
+		Panics:    3,
+	}
+	if s.TotalResources() != 1 {
+		t.Fatal("expected 1 resource")
+	}
+	if s.TotalErrors() != 2 {
+		t.Fatal("expected 2 errors")
+	}
+	if s.TotalPanics() != 3 {
+		t.Fatal("expected 3 panics")
+	}
+
+	other := &Metrics{
+		TableClient: make(map[string]map[string]*TableClientMetrics),
+	}
+	other.TableClient["test_table"] = make(map[string]*TableClientMetrics)
+	other.TableClient["test_table"]["testExecutionClient"] = &TableClientMetrics{
+		Resources: 1,
+		Errors:    2,
+		Panics:    3,
+	}
+	if !s.Equal(other) {
+		t.Fatal("expected metrics to be equal")
+	}
+}
diff --git a/plugins/source/options.go b/plugins/source/options.go
new file mode 100644
index 0000000000..72ddc5acc7
--- /dev/null
+++ b/plugins/source/options.go
@@ -0,0 +1,39 @@
+package source
+
+import (
+	"context"
+
+	"github.com/cloudquery/plugin-sdk/v3/schema"
+)
+
+type GetTables func(ctx context.Context, c schema.ClientMeta) (schema.Tables, error)
+
+type Option func(*Plugin)
+
+// WithDynamicTableOption allows the plugin to return a list of tables after the call to New
+func WithDynamicTableOption(getDynamicTables GetTables) Option {
+	return func(p *Plugin) {
+		p.getDynamicTables = getDynamicTables
+	}
+}
+
+// WithNoInternalColumns won't add internal columns (_cq_id, _cq_parent_cq_id) to the plugin tables
+func WithNoInternalColumns() Option {
+	return func(p *Plugin) {
+		p.internalColumns = false
+	}
+}
+
+func WithUnmanaged() Option {
+	return func(p *Plugin) {
+		p.unmanaged = true
+	}
+}
+
+// WithTitleTransformer allows the plugin to control how table names get turned into titles for the
+// generated documentation.
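+//
+// A hypothetical usage sketch (not part of this patch), upper-casing table names:
+//
+//	WithTitleTransformer(func(t *schema.Table) string {
+//		return strings.ToUpper(t.Name)
+//	})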
+func WithTitleTransformer(t func(*schema.Table) string) Option {
+	return func(p *Plugin) {
+		p.titleTransformer = t
+	}
+}
diff --git a/plugins/source/plugin.go b/plugins/source/plugin.go
new file mode 100644
index 0000000000..4aefcd0ec7
--- /dev/null
+++ b/plugins/source/plugin.go
@@ -0,0 +1,335 @@
+package source
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/cloudquery/plugin-pb-go/specs"
+	"github.com/cloudquery/plugin-sdk/v3/backend"
+	"github.com/cloudquery/plugin-sdk/v3/caser"
+	"github.com/cloudquery/plugin-sdk/v3/internal/backends/local"
+	"github.com/cloudquery/plugin-sdk/v3/internal/backends/nop"
+	"github.com/cloudquery/plugin-sdk/v3/schema"
+	"github.com/rs/zerolog"
+	"golang.org/x/sync/semaphore"
+)
+
+type Options struct {
+	Backend backend.Backend
+}
+
+type NewExecutionClientFunc func(context.Context, zerolog.Logger, specs.Source, Options) (schema.ClientMeta, error)
+
+type UnmanagedClient interface {
+	schema.ClientMeta
+	Sync(ctx context.Context, metrics *Metrics, res chan<- *schema.Resource) error
+}
+
+// Plugin is the base structure required to pass to sdk.serve
+// We take a declarative approach to API here similar to Cobra
+type Plugin struct {
+	// Name of the plugin, e.g. aws, gcp, azure, etc.
+	name string
+	// Version of the plugin
+	version string
+	// Called upon configure call to validate and init configuration
+	newExecutionClient NewExecutionClientFunc
+	// dynamic table function if specified
+	getDynamicTables GetTables
+	// Tables is all tables supported by this source plugin
+	tables schema.Tables
+	// status sync metrics
+	metrics *Metrics
+	// Logger to call, this logger is passed to the serve.Serve Client, if not defined Serve will create one instead.
+	logger zerolog.Logger
+	// resourceSem is a semaphore that limits the number of concurrent resources being fetched
+	resourceSem *semaphore.Weighted
+	// tableSem is a semaphore that limits the number of concurrent tables being fetched
+	tableSems []*semaphore.Weighted
+	// maxDepth is the max depth of tables
+	maxDepth uint64
+	// caser
+	caser *caser.Caser
+	// mu is a mutex that limits the number of concurrent init/syncs (can only be one at a time)
+	mu sync.Mutex
+
+	// client is the initialized session client
+	client schema.ClientMeta
+	// sessionTables are the tables selected and filtered for the current sync session
+	sessionTables schema.Tables
+	// backend is the backend used to store the cursor state
+	backend backend.Backend
+	// spec is the spec the client was initialized with
+	spec specs.Source
+	// NoInternalColumns if set to true will not add internal columns to tables such as _cq_id and _cq_parent_id
+	// useful for sources such as PostgreSQL and other databases
+	internalColumns bool
+	// unmanaged if set to true then the plugin will call Sync directly and not use the scheduler
+	unmanaged bool
+	// titleTransformer allows the plugin to control how table names get turned into titles for generated documentation
+	titleTransformer func(*schema.Table) string
+}
+
+const (
+	maxAllowedDepth = 4
+)
+
+// Add internal columns
+func addInternalColumns(tables []*schema.Table) error {
+	for _, table := range tables {
+		if c := table.Column("_cq_id"); c != nil {
+			return fmt.Errorf("table %s already has column _cq_id", table.Name)
+		}
+		cqID := schema.CqIDColumn
+		if len(table.PrimaryKeys()) == 0 {
+			cqID.CreationOptions.PrimaryKey = true
+		}
+		table.Columns = append([]schema.Column{cqID, schema.CqParentIDColumn}, table.Columns...)
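+		// recurse so relation tables also receive _cq_id and _cq_parent_cq_id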
+ if err := addInternalColumns(table.Relations); err != nil { + return err + } + } + return nil +} + +// Set parent links on relational tables +func setParents(tables schema.Tables, parent *schema.Table) { + for _, table := range tables { + table.Parent = parent + setParents(table.Relations, table) + } +} + +// Apply transformations to tables +func transformTables(tables schema.Tables) error { + for _, table := range tables { + if table.Transform != nil { + if err := table.Transform(table); err != nil { + return fmt.Errorf("failed to transform table %s: %w", table.Name, err) + } + } + if err := transformTables(table.Relations); err != nil { + return err + } + } + return nil +} + +func maxDepth(tables schema.Tables) uint64 { + var depth uint64 + if len(tables) == 0 { + return 0 + } + for _, table := range tables { + newDepth := 1 + maxDepth(table.Relations) + if newDepth > depth { + depth = newDepth + } + } + return depth +} + +// NewPlugin returns a new plugin with a given name, version, tables, newExecutionClient +// and additional options. +func NewPlugin(name string, version string, tables []*schema.Table, newExecutionClient NewExecutionClientFunc, options ...Option) *Plugin { + p := Plugin{ + name: name, + version: version, + tables: tables, + newExecutionClient: newExecutionClient, + metrics: &Metrics{TableClient: make(map[string]map[string]*TableClientMetrics)}, + caser: caser.New(), + titleTransformer: DefaultTitleTransformer, + internalColumns: true, + } + for _, opt := range options { + opt(&p) + } + setParents(p.tables, nil) + if err := transformTables(p.tables); err != nil { + panic(err) + } + if p.internalColumns { + if err := addInternalColumns(p.tables); err != nil { + panic(err) + } + } + if err := p.validate(); err != nil { + panic(err) + } + p.maxDepth = maxDepth(p.tables) + if p.maxDepth > maxAllowedDepth { + panic(fmt.Errorf("max depth of tables is %d, max allowed is %d", p.maxDepth, maxAllowedDepth)) + } + return &p +} + +func (p *Plugin) SetLogger(logger zerolog.Logger) { + p.logger = logger.With().Str("module", p.name+"-src").Logger() +} + +// Tables returns all tables supported by this source plugin +func (p *Plugin) Tables() schema.Tables { + return p.tables +} + +func (p *Plugin) HasDynamicTables() bool { + return p.getDynamicTables != nil +} + +func (p *Plugin) GetDynamicTables() schema.Tables { + return p.sessionTables +} + +// TablesForSpec returns all tables supported by this source plugin that match the given spec. +// It validates the tables part of the spec and will return an error if it is found to be invalid. 
+// This method is deprecated.
+func (p *Plugin) TablesForSpec(spec specs.Source) (schema.Tables, error) {
+	spec.SetDefaults()
+	if err := spec.Validate(); err != nil {
+		return nil, fmt.Errorf("invalid spec: %w", err)
+	}
+	tables, err := p.tables.FilterDfs(spec.Tables, spec.SkipTables, spec.SkipDependentTables)
+	if err != nil {
+		return nil, fmt.Errorf("failed to filter tables: %w", err)
+	}
+	return tables, nil
+}
+
+// Name returns the name of this plugin
+func (p *Plugin) Name() string {
+	return p.name
+}
+
+// Version returns the version of this plugin
+func (p *Plugin) Version() string {
+	return p.version
+}
+
+func (p *Plugin) Metrics() *Metrics {
+	return p.metrics
+}
+
+func (p *Plugin) Init(ctx context.Context, spec specs.Source) error {
+	if !p.mu.TryLock() {
+		return fmt.Errorf("plugin already in use")
+	}
+	defer p.mu.Unlock()
+
+	var err error
+	spec.SetDefaults()
+	if err := spec.Validate(); err != nil {
+		return fmt.Errorf("invalid spec: %w", err)
+	}
+	p.spec = spec
+
+	switch spec.Backend {
+	case specs.BackendNone:
+		p.backend = nop.New()
+	case specs.BackendLocal:
+		p.backend, err = local.New(spec)
+		if err != nil {
+			return fmt.Errorf("failed to initialize local backend: %w", err)
+		}
+	default:
+		return fmt.Errorf("unknown backend: %s", spec.Backend)
+	}
+
+	tables := p.tables
+	if p.getDynamicTables != nil {
+		p.client, err = p.newExecutionClient(ctx, p.logger, spec, Options{Backend: p.backend})
+		if err != nil {
+			return fmt.Errorf("failed to create execution client for source plugin %s: %w", p.name, err)
+		}
+		tables, err = p.getDynamicTables(ctx, p.client)
+		if err != nil {
+			return fmt.Errorf("failed to get dynamic tables: %w", err)
+		}
+
+		tables, err = tables.FilterDfs(spec.Tables, spec.SkipTables, spec.SkipDependentTables)
+		if err != nil {
+			return fmt.Errorf("failed to filter tables: %w", err)
+		}
+		if len(tables) == 0 {
+			return fmt.Errorf("no tables to sync - please check your spec 'tables' and 'skip_tables' settings")
+		}
+
+		setParents(tables, nil)
+		if err := transformTables(tables); err != nil {
+			return err
+		}
+		if p.internalColumns {
+			if err := addInternalColumns(tables); err != nil {
+				return err
+			}
+		}
+		if err := p.validate(); err != nil {
+			return err
+		}
+		p.maxDepth = maxDepth(tables)
+		if p.maxDepth > maxAllowedDepth {
+			return fmt.Errorf("max depth of tables is %d, max allowed is %d", p.maxDepth, maxAllowedDepth)
+		}
+	} else {
+		tables, err = tables.FilterDfs(spec.Tables, spec.SkipTables, spec.SkipDependentTables)
+		if err != nil {
+			return fmt.Errorf("failed to filter tables: %w", err)
+		}
+	}
+
+	p.sessionTables = tables
+	return nil
+}
+
+// Sync streams data from the requested tables in the spec to the given channel
+func (p *Plugin) Sync(ctx context.Context, res chan<- *schema.Resource) error {
+	if !p.mu.TryLock() {
+		return fmt.Errorf("plugin already in use")
+	}
+	defer p.mu.Unlock()
+
+	if p.client == nil {
+		var err error
+		p.client, err = p.newExecutionClient(ctx, p.logger, p.spec, Options{Backend: p.backend})
+		if err != nil {
+			return fmt.Errorf("failed to create execution client for source plugin %s: %w", p.name, err)
+		}
+	}
+
+	startTime := time.Now()
+	if p.unmanaged {
+		unmanagedClient := p.client.(UnmanagedClient)
+		if err := unmanagedClient.Sync(ctx, p.metrics, res); err != nil {
+			return fmt.Errorf("failed to sync unmanaged client: %w", err)
+		}
+	} else {
+		switch p.spec.Scheduler {
+		case specs.SchedulerDFS:
+			p.syncDfs(ctx, p.spec, p.client, p.sessionTables, res)
+		case specs.SchedulerRoundRobin:
+			
p.syncRoundRobin(ctx, p.spec, p.client, p.sessionTables, res) + default: + return fmt.Errorf("unknown scheduler %s. Options are: %v", p.spec.Scheduler, specs.AllSchedulers.String()) + } + } + + p.logger.Info().Uint64("resources", p.metrics.TotalResources()).Uint64("errors", p.metrics.TotalErrors()).Uint64("panics", p.metrics.TotalPanics()).TimeDiff("duration", time.Now(), startTime).Msg("sync finished") + return nil +} + +func (p *Plugin) Close(ctx context.Context) error { + if !p.mu.TryLock() { + return fmt.Errorf("plugin already in use") + } + defer p.mu.Unlock() + if p.backend != nil { + err := p.backend.Close(ctx) + if err != nil { + return fmt.Errorf("failed to close backend: %w", err) + } + p.backend = nil + } + return nil +} diff --git a/plugins/source/plugin_test.go b/plugins/source/plugin_test.go new file mode 100644 index 0000000000..8b3c84db82 --- /dev/null +++ b/plugins/source/plugin_test.go @@ -0,0 +1,452 @@ +package source + +import ( + "context" + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v3/scalar" + "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/cloudquery/plugin-sdk/v3/transformers" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" +) + +type testExecutionClient struct{} + +var _ schema.ClientMeta = &testExecutionClient{} + +var deterministicStableUUID = uuid.MustParse("c25355aab52c5b70a4e0c9991f5a3b87") +var randomStableUUID = uuid.MustParse("00000000000040008000000000000000") + +func testResolverSuccess(_ context.Context, _ schema.ClientMeta, _ *schema.Resource, res chan<- any) error { + res <- map[string]any{ + "TestColumn": 3, + } + return nil +} + +func testResolverPanic(context.Context, schema.ClientMeta, *schema.Resource, chan<- any) error { + panic("Resolver") +} + +func testPreResourceResolverPanic(context.Context, schema.ClientMeta, *schema.Resource) error { + panic("PreResourceResolver") +} + +func testColumnResolverPanic(context.Context, schema.ClientMeta, *schema.Resource, schema.Column) error { + panic("ColumnResolver") +} + +func testTableSuccess() *schema.Table { + return &schema.Table{ + Name: "test_table_success", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + +func testTableSuccessWithPK() *schema.Table { + return &schema.Table{ + Name: "test_table_success", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + CreationOptions: schema.ColumnCreationOptions{ + PrimaryKey: true, + }, + }, + }, + } +} + +func testTableResolverPanic() *schema.Table { + return &schema.Table{ + Name: "test_table_resolver_panic", + Resolver: testResolverPanic, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + +func testTablePreResourceResolverPanic() *schema.Table { + return &schema.Table{ + Name: "test_table_pre_resource_resolver_panic", + PreResourceResolver: testPreResourceResolverPanic, + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + +func testTableColumnResolverPanic() *schema.Table { + return &schema.Table{ + Name: "test_table_column_resolver_panic", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: 
arrow.PrimitiveTypes.Int64, + }, + { + Name: "test_column1", + Type: arrow.PrimitiveTypes.Int64, + Resolver: testColumnResolverPanic, + }, + }, + } +} + +func testTableRelationSuccess() *schema.Table { + return &schema.Table{ + Name: "test_table_relation_success", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + Relations: []*schema.Table{ + testTableSuccess(), + }, + } +} + +func (*testExecutionClient) ID() string { + return "testExecutionClient" +} + +func newTestExecutionClient(context.Context, zerolog.Logger, specs.Source, Options) (schema.ClientMeta, error) { + return &testExecutionClient{}, nil +} + +type syncTestCase struct { + table *schema.Table + stats Metrics + data []scalar.Vector + deterministicCQID bool +} + +var syncTestCases = []syncTestCase{ + { + table: testTableSuccess(), + stats: Metrics{ + TableClient: map[string]map[string]*TableClientMetrics{ + "test_table_success": { + "testExecutionClient": { + Resources: 1, + }, + }, + }, + }, + data: []scalar.Vector{ + { + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.UUID{}, + &scalar.Int64{Value: 3, Valid: true}, + }, + }, + }, + { + table: testTableResolverPanic(), + stats: Metrics{ + TableClient: map[string]map[string]*TableClientMetrics{ + "test_table_resolver_panic": { + "testExecutionClient": { + Panics: 1, + }, + }, + }, + }, + data: nil, + }, + { + table: testTablePreResourceResolverPanic(), + stats: Metrics{ + TableClient: map[string]map[string]*TableClientMetrics{ + "test_table_pre_resource_resolver_panic": { + "testExecutionClient": { + Panics: 1, + }, + }, + }, + }, + data: nil, + }, + + { + table: testTableRelationSuccess(), + stats: Metrics{ + TableClient: map[string]map[string]*TableClientMetrics{ + "test_table_relation_success": { + "testExecutionClient": { + Resources: 1, + }, + }, + "test_table_success": { + "testExecutionClient": { + Resources: 1, + }, + }, + }, + }, + data: []scalar.Vector{ + { + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.UUID{}, + &scalar.Int64{Value: 3, Valid: true}, + }, + { + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.Int64{Value: 3, Valid: true}, + }, + }, + }, + { + table: testTableSuccess(), + stats: Metrics{ + TableClient: map[string]map[string]*TableClientMetrics{ + "test_table_success": { + "testExecutionClient": { + Resources: 1, + }, + }, + }, + }, + data: []scalar.Vector{ + { + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.UUID{}, + &scalar.Int64{Value: 3, Valid: true}, + }, + }, + deterministicCQID: true, + }, + { + table: testTableColumnResolverPanic(), + stats: Metrics{ + TableClient: map[string]map[string]*TableClientMetrics{ + "test_table_column_resolver_panic": { + "testExecutionClient": { + Panics: 1, + Resources: 1, + }, + }, + }, + }, + data: []scalar.Vector{ + { + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.UUID{}, + &scalar.Int64{Value: 3, Valid: true}, + &scalar.Int64{}, + }, + }, + deterministicCQID: true, + }, + { + table: testTableRelationSuccess(), + stats: Metrics{ + TableClient: map[string]map[string]*TableClientMetrics{ + "test_table_relation_success": { + "testExecutionClient": { + Resources: 1, + }, + }, + "test_table_success": { + "testExecutionClient": { + Resources: 1, + }, + }, + }, + }, + data: []scalar.Vector{ + { + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.UUID{}, + &scalar.Int64{Value: 3, Valid: true}, + }, 
+ { + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.UUID{Value: randomStableUUID, Valid: true}, + &scalar.Int64{Value: 3, Valid: true}, + }, + }, + deterministicCQID: true, + }, + { + table: testTableSuccessWithPK(), + stats: Metrics{ + TableClient: map[string]map[string]*TableClientMetrics{ + "test_table_success": { + "testExecutionClient": { + Resources: 1, + }, + }, + }, + }, + data: []scalar.Vector{ + { + &scalar.UUID{Value: deterministicStableUUID, Valid: true}, + &scalar.UUID{}, + &scalar.Int64{Value: 3, Valid: true}, + }, + }, + deterministicCQID: true, + }, +} + +type testRand struct{} + +func (testRand) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(0) + } + return len(p), nil +} + +func TestSync(t *testing.T) { + uuid.SetRand(testRand{}) + for _, scheduler := range specs.AllSchedulers { + for _, tc := range syncTestCases { + tc := tc + tc.table = tc.table.Copy(nil) + t.Run(tc.table.Name+"_"+scheduler.String(), func(t *testing.T) { + testSyncTable(t, tc, scheduler, tc.deterministicCQID) + }) + } + } +} + +func testSyncTable(t *testing.T, tc syncTestCase, scheduler specs.Scheduler, deterministicCQID bool) { + ctx := context.Background() + tables := []*schema.Table{ + tc.table, + } + + plugin := NewPlugin( + "testSourcePlugin", + "1.0.0", + tables, + newTestExecutionClient, + ) + plugin.SetLogger(zerolog.New(zerolog.NewTestWriter(t))) + spec := specs.Source{ + Name: "testSource", + Path: "cloudquery/testSource", + Tables: []string{"*"}, + Version: "v1.0.0", + Destinations: []string{"test"}, + Concurrency: 1, // choose a very low value to check that we don't run into deadlocks + Scheduler: scheduler, + DeterministicCQID: deterministicCQID, + } + if err := plugin.Init(ctx, spec); err != nil { + t.Fatal(err) + } + + resources := make(chan *schema.Resource) + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + defer close(resources) + return plugin.Sync(ctx, + resources) + }) + + var i int + for resource := range resources { + if tc.data == nil { + t.Fatalf("Unexpected resource %v", resource) + } + if i >= len(tc.data) { + t.Fatalf("expected %d resources. got %d", len(tc.data), i) + } + if !resource.GetValues().Equal(tc.data[i]) { + t.Fatalf("expected at i=%d: %v. got %v", i, tc.data[i], resource.GetValues()) + } + i++ + } + if len(tc.data) != i { + t.Fatalf("expected %d resources. 
got %d", len(tc.data), i) + } + + stats := plugin.Metrics() + if !tc.stats.Equal(stats) { + t.Fatalf("unexpected stats: %v", cmp.Diff(tc.stats, stats)) + } + if err := g.Wait(); err != nil { + t.Fatal(err) + } +} + +func TestIgnoredColumns(t *testing.T) { + validateResources(t, schema.Resources{{ + Item: struct{ A *string }{}, + Table: &schema.Table{ + Columns: schema.ColumnList{ + { + Name: "a", + Type: arrow.BinaryTypes.String, + IgnoreInTests: true, + }, + }, + }, + }}) +} + +var testTable struct { + PrimaryKey string + SecondaryKey string + TertiaryKey string + Quaternary string +} + +func TestNewPluginPrimaryKeys(t *testing.T) { + testTransforms := []struct { + transformerOptions []transformers.StructTransformerOption + resultKeys []string + }{ + { + transformerOptions: []transformers.StructTransformerOption{transformers.WithPrimaryKeys("PrimaryKey")}, + resultKeys: []string{"primary_key"}, + }, + { + transformerOptions: []transformers.StructTransformerOption{}, + resultKeys: []string{"_cq_id"}, + }, + } + for _, tc := range testTransforms { + tables := []*schema.Table{ + { + Name: "test_table", + Transform: transformers.TransformWithStruct( + &testTable, tc.transformerOptions..., + ), + }, + } + + plugin := NewPlugin("testSourcePlugin", "1.0.0", tables, newTestExecutionClient) + assert.Equal(t, tc.resultKeys, plugin.tables[0].PrimaryKeys()) + } +} diff --git a/plugins/source/scheduler.go b/plugins/source/scheduler.go new file mode 100644 index 0000000000..e1ec4953d8 --- /dev/null +++ b/plugins/source/scheduler.go @@ -0,0 +1,163 @@ +package source + +import ( + "context" + "errors" + "fmt" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/getsentry/sentry-go" + "github.com/rs/zerolog" + "github.com/thoas/go-funk" +) + +const ( + minTableConcurrency = 1 + minResourceConcurrency = 100 +) + +const periodicMetricLoggerInterval = 30 * time.Second + +func (p *Plugin) logTablesMetrics(tables schema.Tables, client schema.ClientMeta) { + clientName := client.ID() + for _, table := range tables { + metrics := p.metrics.TableClient[table.Name][clientName] + p.logger.Info().Str("table", table.Name).Str("client", clientName).Uint64("resources", metrics.Resources).Uint64("errors", metrics.Errors).Msg("table sync finished") + p.logTablesMetrics(table.Relations, client) + } +} + +func (p *Plugin) resolveResource(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, item any) *schema.Resource { + var validationErr *schema.ValidationError + ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) + defer cancel() + resource := schema.NewResourceData(table, parent, item) + objectStartTime := time.Now() + clientID := client.ID() + tableMetrics := p.metrics.TableClient[table.Name][clientID] + logger := p.logger.With().Str("table", table.Name).Str("client", clientID).Logger() + defer func() { + if err := recover(); err != nil { + stack := fmt.Sprintf("%s\n%s", err, string(debug.Stack())) + logger.Error().Interface("error", err).TimeDiff("duration", time.Now(), objectStartTime).Str("stack", stack).Msg("resource resolver finished with panic") + atomic.AddUint64(&tableMetrics.Panics, 1) + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(stack) + }) + } + }() + if table.PreResourceResolver != nil { + if err := table.PreResourceResolver(ctx, client, resource); err != nil { + logger.Error().Err(err).Msg("pre resource resolver failed") + 
atomic.AddUint64(&tableMetrics.Errors, 1) + if errors.As(err, &validationErr) { + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) + }) + } + return nil + } + } + + for _, c := range table.Columns { + p.resolveColumn(ctx, logger, tableMetrics, client, resource, c) + } + + if table.PostResourceResolver != nil { + if err := table.PostResourceResolver(ctx, client, resource); err != nil { + logger.Error().Stack().Err(err).Msg("post resource resolver finished with error") + atomic.AddUint64(&tableMetrics.Errors, 1) + if errors.As(err, &validationErr) { + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) + }) + } + } + } + atomic.AddUint64(&tableMetrics.Resources, 1) + return resource +} + +func (p *Plugin) resolveColumn(ctx context.Context, logger zerolog.Logger, tableMetrics *TableClientMetrics, client schema.ClientMeta, resource *schema.Resource, c schema.Column) { + var validationErr *schema.ValidationError + columnStartTime := time.Now() + defer func() { + if err := recover(); err != nil { + stack := fmt.Sprintf("%s\n%s", err, string(debug.Stack())) + logger.Error().Str("column", c.Name).Interface("error", err).TimeDiff("duration", time.Now(), columnStartTime).Str("stack", stack).Msg("column resolver finished with panic") + atomic.AddUint64(&tableMetrics.Panics, 1) + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", resource.Table.Name) + scope.SetTag("column", c.Name) + sentry.CurrentHub().CaptureMessage(stack) + }) + } + }() + + if c.Resolver != nil { + if err := c.Resolver(ctx, client, resource, c); err != nil { + logger.Error().Err(err).Msg("column resolver finished with error") + atomic.AddUint64(&tableMetrics.Errors, 1) + if errors.As(err, &validationErr) { + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", resource.Table.Name) + scope.SetTag("column", c.Name) + sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) + }) + } + } + } else { + // base use case: try to get column with CamelCase name + v := funk.Get(resource.GetItem(), p.caser.ToPascal(c.Name), funk.WithAllowZero()) + if v != nil { + err := resource.Set(c.Name, v) + if err != nil { + logger.Error().Err(err).Msg("column resolver finished with error") + atomic.AddUint64(&tableMetrics.Errors, 1) + if errors.As(err, &validationErr) { + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", resource.Table.Name) + scope.SetTag("column", c.Name) + sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) + }) + } + } + } + } +} + +func (p *Plugin) periodicMetricLogger(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + + ticker := time.NewTicker(periodicMetricLoggerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + p.logger.Info(). + Uint64("total_resources", p.metrics.TotalResourcesAtomic()). + Uint64("total_errors", p.metrics.TotalErrorsAtomic()). + Uint64("total_panics", p.metrics.TotalPanicsAtomic()). + Msg("Sync in progress") + } + } +} + +// unparam's suggestion to remove the second parameter is not good advice here. 
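+// (Note: Go 1.21 later gained a built-in max for ordered types; a local helper like this
+// one is only required on toolchains older than that.)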
+// nolint:unparam
+func max(a, b uint64) uint64 {
+	if a > b {
+		return a
+	}
+	return b
+}
diff --git a/plugins/source/scheduler_dfs.go b/plugins/source/scheduler_dfs.go
new file mode 100644
index 0000000000..17e7feb2fa
--- /dev/null
+++ b/plugins/source/scheduler_dfs.go
@@ -0,0 +1,230 @@
+package source
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"runtime/debug"
+	"sync"
+	"sync/atomic"
+
+	"github.com/cloudquery/plugin-pb-go/specs"
+	"github.com/cloudquery/plugin-sdk/v3/helpers"
+	"github.com/cloudquery/plugin-sdk/v3/schema"
+	"github.com/getsentry/sentry-go"
+	"golang.org/x/sync/semaphore"
+)
+
+func (p *Plugin) syncDfs(ctx context.Context, spec specs.Source, client schema.ClientMeta, tables schema.Tables, resolvedResources chan<- *schema.Resource) {
+	// This is very similar to the concurrent web crawler problem with some minor changes.
+	// We are using DFS to make sure memory usage is capped at O(h) where h is the height of the tree.
+	tableConcurrency := max(spec.Concurrency/minResourceConcurrency, minTableConcurrency)
+	resourceConcurrency := tableConcurrency * minResourceConcurrency
+
+	p.tableSems = make([]*semaphore.Weighted, p.maxDepth)
+	for i := uint64(0); i < p.maxDepth; i++ {
+		p.tableSems[i] = semaphore.NewWeighted(int64(tableConcurrency))
+		// halve the table concurrency at every depth level, bounded below by minTableConcurrency
+		tableConcurrency = max(tableConcurrency/2, minTableConcurrency)
+	}
+	p.resourceSem = semaphore.NewWeighted(int64(resourceConcurrency))
+
+	// we pre-initialise the clients because multiplexers may return clients in a
+	// non-deterministic order, which would cause differences between one run and the next.
+	preInitialisedClients := make([][]schema.ClientMeta, len(tables))
+	for i, table := range tables {
+		clients := []schema.ClientMeta{client}
+		if table.Multiplex != nil {
+			clients = table.Multiplex(client)
+		}
+		// Detect duplicate clients while multiplexing
+		seenClients := make(map[string]bool)
+		for _, c := range clients {
+			if _, ok := seenClients[c.ID()]; !ok {
+				seenClients[c.ID()] = true
+			} else {
+				sentry.WithScope(func(scope *sentry.Scope) {
+					scope.SetTag("table", table.Name)
+					sentry.CurrentHub().CaptureMessage("duplicate client ID in " + table.Name)
+				})
+				p.logger.Warn().Str("client", c.ID()).Str("table", table.Name).Msg("multiplex returned duplicate client")
+			}
+		}
+		preInitialisedClients[i] = clients
+		// we initialise the metrics structure once here, in the main goroutine, to avoid locks:
+		// the worker goroutines only read the structure concurrently and update the counters atomically.
+		p.metrics.initWithClients(table, clients)
+	}
+
+	// We start a goroutine that logs the metrics periodically.
+	// It needs its own WaitGroup.
+	var logWg sync.WaitGroup
+	logWg.Add(1)
+
+	logCtx, logCancel := context.WithCancel(ctx)
+	go p.periodicMetricLogger(logCtx, &logWg)
+
+	var wg sync.WaitGroup
+	for i, table := range tables {
+		table := table
+		clients := preInitialisedClients[i]
+		for _, client := range clients {
+			client := client
+			if err := p.tableSems[0].Acquire(ctx, 1); err != nil {
+				// This means the context was cancelled
+				wg.Wait()
+				// gracefully shut down the logger goroutine
+				logCancel()
+				logWg.Wait()
+				return
+			}
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				defer p.tableSems[0].Release(1)
+				// not checking for error here as there is nothing much to do.
+ // the error is logged and this happens when context is cancelled + p.resolveTableDfs(ctx, table, client, nil, resolvedResources, 1) + }() + } + } + + // Wait for all the worker goroutines to finish + wg.Wait() + + // gracefully shut down the logger goroutine + logCancel() + logWg.Wait() +} + +func (p *Plugin) resolveTableDfs(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, resolvedResources chan<- *schema.Resource, depth int) { + var validationErr *schema.ValidationError + clientName := client.ID() + logger := p.logger.With().Str("table", table.Name).Str("client", clientName).Logger() + + if parent == nil { // Log only for root tables, otherwise we spam too much. + logger.Info().Msg("top level table resolver started") + } + tableMetrics := p.metrics.TableClient[table.Name][clientName] + + res := make(chan any) + go func() { + defer func() { + if err := recover(); err != nil { + stack := fmt.Sprintf("%s\n%s", err, string(debug.Stack())) + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(stack) + }) + logger.Error().Interface("error", err).Str("stack", stack).Msg("table resolver finished with panic") + atomic.AddUint64(&tableMetrics.Panics, 1) + } + close(res) + }() + if err := table.Resolver(ctx, client, parent, res); err != nil { + logger.Error().Err(err).Msg("table resolver finished with error") + atomic.AddUint64(&tableMetrics.Errors, 1) + if errors.As(err, &validationErr) { + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(validationErr.MaskedError()) + }) + } + return + } + }() + + for r := range res { + p.resolveResourcesDfs(ctx, table, client, parent, r, resolvedResources, depth) + } + + // we don't need any waitgroups here because we are waiting for the channel to close + if parent == nil { // Log only for root tables and relations only after resolving is done, otherwise we spam per object instead of per table. + logger.Info().Uint64("resources", tableMetrics.Resources).Uint64("errors", tableMetrics.Errors).Msg("table sync finished") + p.logTablesMetrics(table.Relations, client) + } +} + +func (p *Plugin) resolveResourcesDfs(ctx context.Context, table *schema.Table, client schema.ClientMeta, parent *schema.Resource, resources any, resolvedResources chan<- *schema.Resource, depth int) { + resourcesSlice := helpers.InterfaceSlice(resources) + if len(resourcesSlice) == 0 { + return + } + resourcesChan := make(chan *schema.Resource, len(resourcesSlice)) + go func() { + defer close(resourcesChan) + var wg sync.WaitGroup + sentValidationErrors := sync.Map{} + for i := range resourcesSlice { + i := i + if err := p.resourceSem.Acquire(ctx, 1); err != nil { + p.logger.Warn().Err(err).Msg("failed to acquire semaphore. 
context cancelled") + wg.Wait() + // we have to continue emptying the channel to exit gracefully + return + } + wg.Add(1) + go func() { + defer p.resourceSem.Release(1) + defer wg.Done() + //nolint:all + resolvedResource := p.resolveResource(ctx, table, client, parent, resourcesSlice[i]) + if resolvedResource == nil { + return + } + + if err := resolvedResource.CalculateCQID(p.spec.DeterministicCQID); err != nil { + tableMetrics := p.metrics.TableClient[table.Name][client.ID()] + p.logger.Error().Err(err).Str("table", table.Name).Str("client", client.ID()).Msg("resource resolver finished with primary key calculation error") + if _, found := sentValidationErrors.LoadOrStore(table.Name, struct{}{}); !found { + // send resource validation errors to Sentry only once per table, + // to avoid sending too many duplicate messages + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(err.Error()) + }) + } + atomic.AddUint64(&tableMetrics.Errors, 1) + return + } + if err := resolvedResource.Validate(); err != nil { + tableMetrics := p.metrics.TableClient[table.Name][client.ID()] + p.logger.Error().Err(err).Str("table", table.Name).Str("client", client.ID()).Msg("resource resolver finished with validation error") + if _, found := sentValidationErrors.LoadOrStore(table.Name, struct{}{}); !found { + // send resource validation errors to Sentry only once per table, + // to avoid sending too many duplicate messages + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetTag("table", table.Name) + sentry.CurrentHub().CaptureMessage(err.Error()) + }) + } + atomic.AddUint64(&tableMetrics.Errors, 1) + return + } + resourcesChan <- resolvedResource + }() + } + wg.Wait() + }() + + var wg sync.WaitGroup + for resource := range resourcesChan { + resource := resource + resolvedResources <- resource + for _, relation := range resource.Table.Relations { + relation := relation + if err := p.tableSems[depth].Acquire(ctx, 1); err != nil { + // This means context was cancelled + wg.Wait() + return + } + wg.Add(1) + go func() { + defer wg.Done() + defer p.tableSems[depth].Release(1) + p.resolveTableDfs(ctx, relation, client, resource, resolvedResources, depth+1) + }() + } + } + wg.Wait() +} diff --git a/plugins/source/scheduler_round_robin.go b/plugins/source/scheduler_round_robin.go new file mode 100644 index 0000000000..00b1030f68 --- /dev/null +++ b/plugins/source/scheduler_round_robin.go @@ -0,0 +1,104 @@ +package source + +import ( + "context" + "sync" + + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v3/schema" + "golang.org/x/sync/semaphore" +) + +type tableClient struct { + table *schema.Table + client schema.ClientMeta +} + +func (p *Plugin) syncRoundRobin(ctx context.Context, spec specs.Source, client schema.ClientMeta, tables schema.Tables, resolvedResources chan<- *schema.Resource) { + tableConcurrency := max(spec.Concurrency/minResourceConcurrency, minTableConcurrency) + resourceConcurrency := tableConcurrency * minResourceConcurrency + + p.tableSems = make([]*semaphore.Weighted, p.maxDepth) + for i := uint64(0); i < p.maxDepth; i++ { + p.tableSems[i] = semaphore.NewWeighted(int64(tableConcurrency)) + // reduce table concurrency logarithmically for every depth level + tableConcurrency = max(tableConcurrency/2, minTableConcurrency) + } + p.resourceSem = semaphore.NewWeighted(int64(resourceConcurrency)) + + // we have this because plugins can return sometimes clients in a random way which will cause + // 
differences between this run and the next one. + preInitialisedClients := make([][]schema.ClientMeta, len(tables)) + for i, table := range tables { + clients := []schema.ClientMeta{client} + if table.Multiplex != nil { + clients = table.Multiplex(client) + } + preInitialisedClients[i] = clients + // we do this here to avoid locks so we initial the metrics structure once in the main goroutines + // and then we can just read from it in the other goroutines concurrently given we are not writing to it. + p.metrics.initWithClients(table, clients) + } + + // We start a goroutine that logs the metrics periodically. + // It needs its own waitgroup + var logWg sync.WaitGroup + logWg.Add(1) + + logCtx, logCancel := context.WithCancel(ctx) + go p.periodicMetricLogger(logCtx, &logWg) + + tableClients := roundRobinInterleave(tables, preInitialisedClients) + + var wg sync.WaitGroup + for _, tc := range tableClients { + table := tc.table + cl := tc.client + if err := p.tableSems[0].Acquire(ctx, 1); err != nil { + // This means context was cancelled + wg.Wait() + // gracefully shut down the logger goroutine + logCancel() + logWg.Wait() + return + } + wg.Add(1) + go func() { + defer wg.Done() + defer p.tableSems[0].Release(1) + // not checking for error here as nothing much to do. + // the error is logged and this happens when context is cancelled + // Round Robin currently uses the DFS algorithm to resolve the tables, but this + // may change in the future. + p.resolveTableDfs(ctx, table, cl, nil, resolvedResources, 1) + }() + } + + // Wait for all the worker goroutines to finish + wg.Wait() + + // gracefully shut down the logger goroutine + logCancel() + logWg.Wait() +} + +// interleave table-clients so that we get: +// table1-client1, table2-client1, table3-client1, table1-client2, table2-client2, table3-client2, ... 
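+// Tables with fewer clients simply drop out of later rounds: the loop keeps sweeping all
+// tables until a full pass adds no new pair, so ragged client lists interleave correctly.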
+func roundRobinInterleave(tables schema.Tables, preInitialisedClients [][]schema.ClientMeta) []tableClient { + tableClients := make([]tableClient, 0) + c := 0 + for { + addedNew := false + for i, table := range tables { + if c < len(preInitialisedClients[i]) { + tableClients = append(tableClients, tableClient{table: table, client: preInitialisedClients[i][c]}) + addedNew = true + } + } + c++ + if !addedNew { + break + } + } + return tableClients +} diff --git a/plugins/source/scheduler_round_robin_test.go b/plugins/source/scheduler_round_robin_test.go new file mode 100644 index 0000000000..8f7e3425f5 --- /dev/null +++ b/plugins/source/scheduler_round_robin_test.go @@ -0,0 +1,65 @@ +package source + +import ( + "testing" + + "github.com/cloudquery/plugin-sdk/v3/schema" +) + +func TestRoundRobinInterleave(t *testing.T) { + table1 := &schema.Table{Name: "test_table"} + table2 := &schema.Table{Name: "test_table2"} + client1 := &testExecutionClient{} + client2 := &testExecutionClient{} + client3 := &testExecutionClient{} + cases := []struct { + name string + tables schema.Tables + preInitialisedClients [][]schema.ClientMeta + want []tableClient + }{ + { + name: "single table", + tables: schema.Tables{table1}, + preInitialisedClients: [][]schema.ClientMeta{{client1}}, + want: []tableClient{{table: table1, client: client1}}, + }, + { + name: "two tables with different clients", + tables: schema.Tables{table1, table2}, + preInitialisedClients: [][]schema.ClientMeta{{client1}, {client1, client2}}, + want: []tableClient{ + {table: table1, client: client1}, + {table: table2, client: client1}, + {table: table2, client: client2}, + }, + }, + { + name: "two tables with different clients", + tables: schema.Tables{table1, table2}, + preInitialisedClients: [][]schema.ClientMeta{{client1, client3}, {client1, client2}}, + want: []tableClient{ + {table: table1, client: client1}, + {table: table2, client: client1}, + {table: table1, client: client3}, + {table: table2, client: client2}, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := roundRobinInterleave(tc.tables, tc.preInitialisedClients) + if len(got) != len(tc.want) { + t.Fatalf("got %d tableClients, want %d", len(got), len(tc.want)) + } + for i := range got { + if got[i].table != tc.want[i].table { + t.Errorf("got table %v, want %v", got[i].table, tc.want[i].table) + } + if got[i].client != tc.want[i].client { + t.Errorf("got client %v, want %v", got[i].client, tc.want[i].client) + } + } + }) + } +} diff --git a/plugins/source/templates/all_tables.md.go.tpl b/plugins/source/templates/all_tables.md.go.tpl new file mode 100644 index 0000000000..008afb66fd --- /dev/null +++ b/plugins/source/templates/all_tables.md.go.tpl @@ -0,0 +1,5 @@ +# Source Plugin: {{.PluginName}} +## Tables +{{- range $table := $.Tables }} +{{- template "all_tables_entry.md.go.tpl" $table}} +{{- end }} \ No newline at end of file diff --git a/plugins/source/templates/all_tables_entry.md.go.tpl b/plugins/source/templates/all_tables_entry.md.go.tpl new file mode 100644 index 0000000000..6166b1983b --- /dev/null +++ b/plugins/source/templates/all_tables_entry.md.go.tpl @@ -0,0 +1,5 @@ + +{{. 
| indentToDepth}}- [{{.Name}}]({{.Name}}.md){{ if .IsIncremental}} (Incremental){{ end }} +{{- range $index, $rel := .Relations}} +{{- template "all_tables_entry.md.go.tpl" $rel}} +{{- end}} \ No newline at end of file diff --git a/plugins/source/templates/table.md.go.tpl b/plugins/source/templates/table.md.go.tpl new file mode 100644 index 0000000000..45bee702cd --- /dev/null +++ b/plugins/source/templates/table.md.go.tpl @@ -0,0 +1,44 @@ +# Table: {{$.Name}} + +This table shows data for {{.|title}}. + +{{ $.Description }} +{{ $length := len $.PrimaryKeys -}} +{{ if eq $length 1 }} +The primary key for this table is **{{ index $.PrimaryKeys 0 }}**. +{{ else }} +The composite primary key for this table is ({{ range $index, $pk := $.PrimaryKeys -}} + {{if $index }}, {{end -}} + **{{$pk}}** + {{- end -}}). +{{ end }} +{{- if $.IsIncremental -}} +It supports incremental syncs +{{- $ikLength := len $.IncrementalKeys -}} +{{- if eq $ikLength 1 }} based on the **{{ index $.IncrementalKeys 0 }}** column +{{- else if gt $ikLength 1 }} based on the ({{ range $index, $pk := $.IncrementalKeys -}} + {{- if $index -}}, {{end -}} + **{{$pk}}** + {{- end -}}) columns +{{- end -}}. +{{- end -}} + +{{- if or ($.Relations) ($.Parent) }} +## Relations +{{- end }} +{{- if $.Parent }} +This table depends on [{{ $.Parent.Name }}]({{ $.Parent.Name }}.md). +{{- end}} +{{ if $.Relations }} +The following tables depend on {{.Name}}: +{{- range $rel := $.Relations }} + - [{{ $rel.Name }}]({{ $rel.Name }}.md) +{{- end }} +{{- end }} + +## Columns +| Name | Type | +| ------------- | ------------- | +{{- range $column := $.Columns }} +|{{$column.Name}}{{if $column.CreationOptions.PrimaryKey}} (PK){{end}}{{if $column.CreationOptions.IncrementalKey}} (Incremental Key){{end}}|{{$column.Type | formatType}}| +{{- end }} \ No newline at end of file diff --git a/plugins/source/testdata/TestGeneratePluginDocs-JSON-__tables.json b/plugins/source/testdata/TestGeneratePluginDocs-JSON-__tables.json new file mode 100644 index 0000000000..e8bd9f7593 --- /dev/null +++ b/plugins/source/testdata/TestGeneratePluginDocs-JSON-__tables.json @@ -0,0 +1,197 @@ +[ + { + "name": "incremental_table", + "title": "Incremental Table", + "description": "Description for incremental table", + "columns": [ + { + "name": "_cq_source_name", + "type": "utf8" + }, + { + "name": "_cq_sync_time", + "type": "timestamp[us, tz=UTC]" + }, + { + "name": "_cq_id", + "type": "uuid" + }, + { + "name": "_cq_parent_id", + "type": "uuid" + }, + { + "name": "int_col", + "type": "int64" + }, + { + "name": "id_col", + "type": "int64", + "is_primary_key": true, + "is_incremental_key": true + }, + { + "name": "id_col2", + "type": "int64", + "is_incremental_key": true + } + ], + "relations": [] + }, + { + "name": "test_table", + "title": "Test Table", + "description": "Description for test table", + "columns": [ + { + "name": "_cq_source_name", + "type": "utf8" + }, + { + "name": "_cq_sync_time", + "type": "timestamp[us, tz=UTC]" + }, + { + "name": "_cq_id", + "type": "uuid" + }, + { + "name": "_cq_parent_id", + "type": "uuid" + }, + { + "name": "int_col", + "type": "int64" + }, + { + "name": "id_col", + "type": "int64", + "is_primary_key": true + }, + { + "name": "id_col2", + "type": "int64", + "is_primary_key": true + } + ], + "relations": [ + { + "name": "relation_table", + "title": "Relation Table", + "description": "Description for relational table", + "columns": [ + { + "name": "_cq_source_name", + "type": "utf8" + }, + { + "name": "_cq_sync_time", + "type": 
"timestamp[us, tz=UTC]" + }, + { + "name": "_cq_id", + "type": "uuid", + "is_primary_key": true + }, + { + "name": "_cq_parent_id", + "type": "uuid" + }, + { + "name": "string_col", + "type": "utf8" + } + ], + "relations": [ + { + "name": "relation_relation_table_a", + "title": "Relation Relation Table A", + "description": "Description for relational table's relation", + "columns": [ + { + "name": "_cq_source_name", + "type": "utf8" + }, + { + "name": "_cq_sync_time", + "type": "timestamp[us, tz=UTC]" + }, + { + "name": "_cq_id", + "type": "uuid", + "is_primary_key": true + }, + { + "name": "_cq_parent_id", + "type": "uuid" + }, + { + "name": "string_col", + "type": "utf8" + } + ], + "relations": [] + }, + { + "name": "relation_relation_table_b", + "title": "Relation Relation Table B", + "description": "Description for relational table's relation", + "columns": [ + { + "name": "_cq_source_name", + "type": "utf8" + }, + { + "name": "_cq_sync_time", + "type": "timestamp[us, tz=UTC]" + }, + { + "name": "_cq_id", + "type": "uuid", + "is_primary_key": true + }, + { + "name": "_cq_parent_id", + "type": "uuid" + }, + { + "name": "string_col", + "type": "utf8" + } + ], + "relations": [] + } + ] + }, + { + "name": "relation_table2", + "title": "Relation Table2", + "description": "Description for second relational table", + "columns": [ + { + "name": "_cq_source_name", + "type": "utf8" + }, + { + "name": "_cq_sync_time", + "type": "timestamp[us, tz=UTC]" + }, + { + "name": "_cq_id", + "type": "uuid", + "is_primary_key": true + }, + { + "name": "_cq_parent_id", + "type": "uuid" + }, + { + "name": "string_col", + "type": "utf8" + } + ], + "relations": [] + } + ] + } +] diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-README.md b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-README.md new file mode 100644 index 0000000000..9480a0598a --- /dev/null +++ b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-README.md @@ -0,0 +1,10 @@ +# Source Plugin: test + +## Tables + +- [incremental_table](incremental_table.md) (Incremental) +- [test_table](test_table.md) + - [relation_table](relation_table.md) + - [relation_relation_table_a](relation_relation_table_a.md) + - [relation_relation_table_b](relation_relation_table_b.md) + - [relation_table2](relation_table2.md) diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md new file mode 100644 index 0000000000..67ca4b8539 --- /dev/null +++ b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-incremental_table.md @@ -0,0 +1,20 @@ +# Table: incremental_table + +This table shows data for Incremental Table. + +Description for incremental table + +The primary key for this table is **id_col**. +It supports incremental syncs based on the (**id_col**, **id_col2**) columns. 
+ +## Columns + +| Name | Type | +| ------------- | ------------- | +|_cq_source_name|utf8| +|_cq_sync_time|timestamp[us, tz=UTC]| +|_cq_id|uuid| +|_cq_parent_id|uuid| +|int_col|int64| +|id_col (PK) (Incremental Key)|int64| +|id_col2 (Incremental Key)|int64| diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md new file mode 100644 index 0000000000..038791b13e --- /dev/null +++ b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_a.md @@ -0,0 +1,21 @@ +# Table: relation_relation_table_a + +This table shows data for Relation Relation Table A. + +Description for relational table's relation + +The primary key for this table is **_cq_id**. + +## Relations + +This table depends on [relation_table](relation_table.md). + +## Columns + +| Name | Type | +| ------------- | ------------- | +|_cq_source_name|utf8| +|_cq_sync_time|timestamp[us, tz=UTC]| +|_cq_id (PK)|uuid| +|_cq_parent_id|uuid| +|string_col|utf8| diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md new file mode 100644 index 0000000000..432f6533f8 --- /dev/null +++ b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_relation_table_b.md @@ -0,0 +1,21 @@ +# Table: relation_relation_table_b + +This table shows data for Relation Relation Table B. + +Description for relational table's relation + +The primary key for this table is **_cq_id**. + +## Relations + +This table depends on [relation_table](relation_table.md). + +## Columns + +| Name | Type | +| ------------- | ------------- | +|_cq_source_name|utf8| +|_cq_sync_time|timestamp[us, tz=UTC]| +|_cq_id (PK)|uuid| +|_cq_parent_id|uuid| +|string_col|utf8| diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_table.md b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_table.md new file mode 100644 index 0000000000..7db8baff7e --- /dev/null +++ b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-relation_table.md @@ -0,0 +1,25 @@ +# Table: relation_table + +This table shows data for Relation Table. + +Description for relational table + +The primary key for this table is **_cq_id**. + +## Relations + +This table depends on [test_table](test_table.md). + +The following tables depend on relation_table: + - [relation_relation_table_a](relation_relation_table_a.md) + - [relation_relation_table_b](relation_relation_table_b.md) + +## Columns + +| Name | Type | +| ------------- | ------------- | +|_cq_source_name|utf8| +|_cq_sync_time|timestamp[us, tz=UTC]| +|_cq_id (PK)|uuid| +|_cq_parent_id|uuid| +|string_col|utf8| diff --git a/plugins/source/testdata/TestGeneratePluginDocs-Markdown-test_table.md b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-test_table.md new file mode 100644 index 0000000000..02afdcbc1e --- /dev/null +++ b/plugins/source/testdata/TestGeneratePluginDocs-Markdown-test_table.md @@ -0,0 +1,25 @@ +# Table: test_table + +This table shows data for Test Table. + +Description for test table + +The composite primary key for this table is (**id_col**, **id_col2**). 
+ +## Relations + +The following tables depend on test_table: + - [relation_table](relation_table.md) + - [relation_table2](relation_table2.md) + +## Columns + +| Name | Type | +| ------------- | ------------- | +|_cq_source_name|utf8| +|_cq_sync_time|timestamp[us, tz=UTC]| +|_cq_id|uuid| +|_cq_parent_id|uuid| +|int_col|int64| +|id_col (PK)|int64| +|id_col2 (PK)|int64| diff --git a/plugins/source/testing.go b/plugins/source/testing.go new file mode 100644 index 0000000000..0f86081ec8 --- /dev/null +++ b/plugins/source/testing.go @@ -0,0 +1,140 @@ +package source + +import ( + "context" + "testing" + + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v3/schema" +) + +type Validator func(t *testing.T, plugin *Plugin, resources []*schema.Resource) + +func TestPluginSync(t *testing.T, plugin *Plugin, spec specs.Source, opts ...TestPluginOption) { + t.Helper() + + o := &testPluginOptions{ + parallel: true, + validators: []Validator{validatePlugin}, + } + for _, opt := range opts { + opt(o) + } + if o.parallel { + t.Parallel() + } + + resourcesChannel := make(chan *schema.Resource) + var syncErr error + + if err := plugin.Init(context.Background(), spec); err != nil { + t.Fatal(err) + } + + go func() { + defer close(resourcesChannel) + syncErr = plugin.Sync(context.Background(), resourcesChannel) + }() + + syncedResources := make([]*schema.Resource, 0) + for resource := range resourcesChannel { + syncedResources = append(syncedResources, resource) + } + if syncErr != nil { + t.Fatal(syncErr) + } + for _, validator := range o.validators { + validator(t, plugin, syncedResources) + } +} + +type TestPluginOption func(*testPluginOptions) + +func WithTestPluginNoParallel() TestPluginOption { + return func(f *testPluginOptions) { + f.parallel = false + } +} + +func WithTestPluginAdditionalValidators(v Validator) TestPluginOption { + return func(f *testPluginOptions) { + f.validators = append(f.validators, v) + } +} + +type testPluginOptions struct { + parallel bool + validators []Validator +} + +func getTableResources(t *testing.T, table *schema.Table, resources []*schema.Resource) []*schema.Resource { + t.Helper() + + tableResources := make([]*schema.Resource, 0) + + for _, resource := range resources { + if resource.Table.Name == table.Name { + tableResources = append(tableResources, resource) + } + } + + return tableResources +} + +func validateTable(t *testing.T, table *schema.Table, resources []*schema.Resource) { + t.Helper() + tableResources := getTableResources(t, table, resources) + if len(tableResources) == 0 { + t.Errorf("Expected table %s to be synced but it was not found", table.Name) + return + } + validateResources(t, tableResources) +} + +func validatePlugin(t *testing.T, plugin *Plugin, resources []*schema.Resource) { + t.Helper() + tables := extractTables(plugin.tables) + for _, table := range tables { + validateTable(t, table, resources) + } +} + +func extractTables(tables schema.Tables) []*schema.Table { + result := make([]*schema.Table, 0) + for _, table := range tables { + result = append(result, table) + result = append(result, extractTables(table.Relations)...) + } + return result +} + +// Validates that every column has at least one non-nil value. +// Also does some additional validations. +func validateResources(t *testing.T, resources []*schema.Resource) { + t.Helper() + + table := resources[0].Table + + // A set of column-names that have values in at least one of the resources. 
+	columnsWithValues := make([]bool, len(table.Columns))
+
+	for _, resource := range resources {
+		for i, value := range resource.GetValues() {
+			if value == nil {
+				continue
+			}
+			if value.IsValid() {
+				columnsWithValues[i] = true
+			}
+		}
+	}
+
+	// Make sure every column has at least one value.
+	for i, hasValue := range columnsWithValues {
+		col := table.Columns[i]
+		emptyExpected := col.Name == "_cq_parent_id" && table.Parent == nil
+		if !hasValue && !emptyExpected && !col.IgnoreInTests {
+			t.Errorf("table: %s column %s has no values", table.Name, col.Name)
+		}
+	}
+}
diff --git a/plugins/source/validate.go b/plugins/source/validate.go
new file mode 100644
index 0000000000..835b798c7e
--- /dev/null
+++ b/plugins/source/validate.go
@@ -0,0 +1,25 @@
+package source
+
+import (
+	"fmt"
+)
+
+func (p *Plugin) validate() error {
+	if err := p.tables.ValidateDuplicateColumns(); err != nil {
+		return fmt.Errorf("found duplicate columns in source plugin: %s: %w", p.name, err)
+	}
+
+	if err := p.tables.ValidateDuplicateTables(); err != nil {
+		return fmt.Errorf("found duplicate tables in source plugin: %s: %w", p.name, err)
+	}
+
+	if err := p.tables.ValidateTableNames(); err != nil {
+		return fmt.Errorf("found table with invalid name in source plugin: %s: %w", p.name, err)
+	}
+
+	if err := p.tables.ValidateColumnNames(); err != nil {
+		return fmt.Errorf("found column with invalid name in source plugin: %s: %w", p.name, err)
+	}
+
+	return nil
+}
diff --git a/scalar/LICENSE b/scalar/LICENSE
new file mode 100644
index 0000000000..530597e6f4
--- /dev/null
+++ b/scalar/LICENSE
@@ -0,0 +1,23 @@
+Copyright (c) 2013-2021 Jack Christensen
+Copyright (c) 2022 CloudQuery inc.
+
+MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
\ No newline at end of file
diff --git a/scalar/README.md b/scalar/README.md
new file mode 100644
index 0000000000..e4090ea350
--- /dev/null
+++ b/scalar/README.md
@@ -0,0 +1,4 @@
+# CloudQuery Type System
+
+This directory is heavily based on [jackc/pgtype](https://github.com/jackc/pgtype), modified to fit CloudQuery's needs,
+and thus falls under the original [MIT license and copyright](./LICENSE).
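+
+## Example
+
+A minimal sketch of the intended usage, using the `Bool` scalar defined in this
+package (error handling elided):
+
+```go
+var b scalar.Bool
+_ = b.Set("true")        // Set accepts bools, strings such as "true"/"t", and pointers to either
+fmt.Println(b.String())  // "true"; an unset or null scalar prints "(null)"
+fmt.Println(b.IsValid()) // true
+```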
diff --git a/scalar/binary.go b/scalar/binary.go new file mode 100644 index 0000000000..1085780b11 --- /dev/null +++ b/scalar/binary.go @@ -0,0 +1,74 @@ +package scalar + +import ( + "bytes" + + "github.com/apache/arrow/go/v13/arrow" +) + +type Binary struct { + Valid bool + Value []byte +} + + +func (s *Binary) IsValid() bool { + return s.Valid +} + +func (s *Binary) Equal(rhs Scalar) bool { + if rhs == nil { + return false + } + r, ok := rhs.(*Binary) + if !ok { + return false + } + return s.Valid == r.Valid && bytes.Equal(s.Value, r.Value) +} + +func (s *Binary) String() string { + if !s.Valid { + return "(null)" + } + return string(s.Value) +} + +func (s *Binary) Set(val any) error { + if val == nil { + s.Valid = false + return nil + } + + switch value := val.(type) { + case []byte: + if value == nil { + return nil + } + s.Value = value + case string: + s.Value = []byte(value) + case *string: + return s.Set(*value) + default: + if originalSrc, ok := underlyingBytesType(value); ok { + return s.Set(originalSrc) + } + return &ValidationError{Type: arrow.BinaryTypes.Binary, Msg: noConversion, Value: value} + } + + s.Valid = true + return nil +} + +func (s *Binary) DataType() arrow.DataType { + return arrow.BinaryTypes.Binary +} + +type LargeBinary struct { + Binary +} + +func (s *LargeBinary) DataType() arrow.DataType { + return arrow.BinaryTypes.LargeBinary +} \ No newline at end of file diff --git a/scalar/binary_test.go b/scalar/binary_test.go new file mode 100644 index 0000000000..a3cf0b54f0 --- /dev/null +++ b/scalar/binary_test.go @@ -0,0 +1,28 @@ +package scalar + +import "testing" + +func TestBinarySet(t *testing.T) { + successfulTests := []struct { + source any + result Binary + }{ + {source: []byte{1, 2, 3}, result: Binary{Value: []byte{1, 2, 3}, Valid: true}}, + {source: []byte{}, result: Binary{Value: []byte{}, Valid: true}}, + {source: []byte(nil), result: Binary{}}, + {source: _byteSlice{1, 2, 3}, result: Binary{Value: []byte{1, 2, 3}, Valid: true}}, + {source: _byteSlice(nil), result: Binary{}}, + } + + for i, tt := range successfulTests { + var r Binary + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !r.Equal(&tt.result) { + t.Errorf("%d: %v != %v", i, r, tt.result) + } + } +} \ No newline at end of file diff --git a/scalar/bool.go b/scalar/bool.go new file mode 100644 index 0000000000..71bf4675d5 --- /dev/null +++ b/scalar/bool.go @@ -0,0 +1,66 @@ +package scalar + +import ( + "strconv" + + "github.com/apache/arrow/go/v13/arrow" +) + +type Bool struct { + Valid bool + Value bool +} + +func (s *Bool) IsValid() bool { + return s.Valid +} + +func (s *Bool) DataType() arrow.DataType { + return arrow.FixedWidthTypes.Boolean +} + +func (s *Bool) Equal(rhs Scalar) bool { + if rhs == nil { + return false + } + r, ok := rhs.(*Bool) + if !ok { + return false + } + return s.Value == r.Value && s.Valid == r.Valid +} + +func (s *Bool) String() string { + if !s.Valid { + return "(null)" + } + return strconv.FormatBool(s.Value) +} + +func (s *Bool) Set(val any) error { + if val == nil { + s.Valid = false + return nil + } + switch value := val.(type) { + case bool: + s.Value = value + case string: + bb, err := strconv.ParseBool(value) + if err != nil { + return &ValidationError{Type: arrow.FixedWidthTypes.Boolean, Msg: "failed to ParseBool", Value: value, Err: err} + } + s.Value = bb + case *bool: + return s.Set(*value) + case *string: + return s.Set(*value) + default: + if originalSrc, ok := underlyingBoolType(value); ok { + return s.Set(originalSrc) + 
} + return &ValidationError{Type: arrow.FixedWidthTypes.Boolean, Msg: noConversion, Value: val} + } + s.Valid = true + return nil +} \ No newline at end of file diff --git a/scalar/bool_test.go b/scalar/bool_test.go new file mode 100644 index 0000000000..01a4ea5a07 --- /dev/null +++ b/scalar/bool_test.go @@ -0,0 +1,33 @@ +package scalar + +import ( + "testing" +) + +func TestBoolSet(t *testing.T) { + successfulTests := []struct { + source any + result Bool + }{ + {source: true, result: Bool{Value: true, Valid: true}}, + {source: false, result: Bool{Value: false, Valid: true}}, + {source: "true", result: Bool{Value: true, Valid: true}}, + {source: "false", result: Bool{Value: false, Valid: true}}, + {source: "t", result: Bool{Value: true, Valid: true}}, + {source: "f", result: Bool{Value: false, Valid: true}}, + {source: _bool(true), result: Bool{Value: true, Valid: true}}, + {source: _bool(false), result: Bool{Value: false, Valid: true}}, + {source: nil, result: Bool{}}, + } + + for i, tt := range successfulTests { + var r Bool + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + if !r.Equal(&tt.result) { + t.Errorf("%d: %v != %v", i, r, tt.result) + } + } +} \ No newline at end of file diff --git a/scalar/convert.go b/scalar/convert.go new file mode 100644 index 0000000000..b6c3cc4e10 --- /dev/null +++ b/scalar/convert.go @@ -0,0 +1,161 @@ +package scalar + +import ( + "reflect" +) + +// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 +func underlyingNumberType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + a := reflect.TypeOf(convVal) + b := refVal.Type() + return convVal, a != b + case reflect.Int8: + convVal := int8(refVal.Int()) + a := reflect.TypeOf(convVal) + b := refVal.Type() + return convVal, a != b + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := refVal.Int() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := refVal.Uint() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float32: + convVal := float32(refVal.Float()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float64: + convVal := refVal.Float() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingStringType gets the underlying type that can be converted to String +func underlyingStringType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { 
+ return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBytesType gets the underlying type that can be converted to []byte +func underlyingBytesType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + if refVal.Type().Elem().Kind() == reflect.Uint8 { + convVal := refVal.Bytes() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + } + + return nil, false +} + +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false +} + +// underlyingPtrType dereferences a pointer +func underlyingPtrType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + //nolint:gocritic + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + return nil, false +} \ No newline at end of file diff --git a/scalar/errors.go b/scalar/errors.go new file mode 100644 index 0000000000..3fa772ba26 --- /dev/null +++ b/scalar/errors.go @@ -0,0 +1,37 @@ +package scalar + +import ( + "fmt" + + "github.com/apache/arrow/go/v13/arrow" +) + +const ( + noConversion = "no conversion available" +) + +type ValidationError struct { + Err error + Msg string + Type arrow.DataType + Value any +} + +func (e *ValidationError) Error() string { + if e.Err == nil { + return fmt.Sprintf("cannot set `%s` with value `%v`: %s", e.Type, e.Value, e.Msg) + } + return fmt.Sprintf("cannot set `%s` with value `%v`: %s (%s)", e.Type, e.Value, e.Msg, e.Err) +} + +// this prints the error without the value +func (e *ValidationError) MaskedError() string { + if e.Err == nil { + return fmt.Sprintf("cannot set `%s`: %s", e.Type, e.Msg) + } + return fmt.Sprintf("cannot set `%s`: %s (%s)", e.Type, e.Msg, e.Err) +} + +func (e *ValidationError) Unwrap() error { + return e.Err +} \ No newline at end of file diff --git a/scalar/float.go b/scalar/float.go new file mode 100644 index 0000000000..7967a43cc6 --- /dev/null +++ b/scalar/float.go @@ -0,0 +1,219 @@ +package scalar + +import ( + "math" + "strconv" + + "github.com/apache/arrow/go/v13/arrow" +) + +type Float32 struct { + Valid bool + Value float32 +} + +func (s *Float32) IsValid() bool { + return s.Valid +} + +func (s *Float32) DataType() arrow.DataType { + return arrow.PrimitiveTypes.Float32 +} + +func (s *Float32) String() string { + if !s.Valid { + return "(null)" + 
}
+	return strconv.FormatFloat(float64(s.Value), 'f', -1, 32)
+}
+
+func (s *Float32) Equal(rhs Scalar) bool {
+	if rhs == nil {
+		return false
+	}
+	r, ok := rhs.(*Float32)
+	if !ok {
+		return false
+	}
+	return s.Valid == r.Valid && s.Value == r.Value
+}
+
+func (s *Float32) Set(val any) error {
+	if val == nil {
+		s.Valid = false
+		return nil
+	}
+
+	switch value := val.(type) {
+	case int8:
+		s.Value = float32(value)
+	case int16:
+		s.Value = float32(value)
+	case int32:
+		// 1<<24 is the largest integer magnitude a float32 represents exactly.
+		if value > 1<<24 || value < -(1<<24) {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: "int32 cannot be represented exactly as float32", Value: value}
+		}
+		s.Value = float32(value)
+	case int64:
+		if value > 1<<24 || value < -(1<<24) {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: "int64 cannot be represented exactly as float32", Value: value}
+		}
+		s.Value = float32(value)
+	case uint8:
+		s.Value = float32(value)
+	case uint16:
+		s.Value = float32(value)
+	case uint32:
+		if value > 1<<24 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: "uint32 cannot be represented exactly as float32", Value: value}
+		}
+		s.Value = float32(value)
+	case uint64:
+		if value > 1<<24 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: "uint64 cannot be represented exactly as float32", Value: value}
+		}
+		s.Value = float32(value)
+	case float32:
+		s.Value = value
+	case float64:
+		if value > math.MaxFloat32 || value < -math.MaxFloat32 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: "float64 bigger than MaxFloat32", Value: value}
+		}
+		s.Value = float32(value)
+	case string:
+		num, err := strconv.ParseFloat(value, 32)
+		if err != nil {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: "invalid string", Value: value}
+		}
+		s.Value = float32(num)
+	case *int8:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int16:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint8:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint16:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *float32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *float64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	default:
+		if originalSrc, ok := underlyingNumberType(value); ok {
+			return s.Set(originalSrc)
+		}
+		return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: noConversion, Value: value}
+	}
+	s.Valid = true
+	return nil
+}
+
+type Float64 struct {
+	Valid bool
+	Value float64
+}
+
+func (s *Float64) IsValid() bool {
+	return s.Valid
+}
+
+func (s *Float64) DataType() arrow.DataType {
+	return arrow.PrimitiveTypes.Float64
+}
+
+func (s *Float64) Equal(rhs Scalar) bool {
+	if rhs == nil {
+		return false
+	}
+	r, ok := rhs.(*Float64)
+	if !ok {
+		return false
+	}
+	return s.Valid == r.Valid && s.Value == r.Value
+}
+
+func (s *Float64) String() string {
+	if !s.Valid {
+		return "(null)"
+	}
+	return strconv.FormatFloat(s.Value, 'f', -1, 64)
+}
+
+func (s *Float64) Set(val any) error {
+	if val == nil {
+		s.Valid = false
+		return nil
+	}
+
+	switch value := val.(type) {
+	case int8:
+		s.Value = float64(value)
+	case int16:
+		s.Value = float64(value)
+	case int32:
+		s.Value = float64(value)
+	case int64:
+		s.Value = float64(value)
+	case uint8:
+		s.Value = float64(value)
+	case uint16:
+		s.Value = float64(value)
+	case uint32:
+		s.Value = float64(value)
+	case uint64:
+		s.Value = float64(value)
+	case float32:
+		s.Value = float64(value)
+	case float64:
+		s.Value = value
+	case string:
+		num, err := strconv.ParseFloat(value, 64)
+		if err != nil {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Float64, Msg: "invalid string", Value: value}
+		}
+		s.Value = num
+	case *int8:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int16:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint8:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint16:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *float32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *float64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	default:
+		if originalSrc, ok := underlyingNumberType(value); ok {
+			return s.Set(originalSrc)
+		}
+		return &ValidationError{Type: arrow.PrimitiveTypes.Float64, Msg: noConversion, Value: value}
+	}
+	s.Valid = true
+	return nil
+}
\ No newline at end of file
diff --git a/scalar/float_test.go b/scalar/float_test.go
new file mode 100644
index 0000000000..bb5df96990
--- /dev/null
+++ b/scalar/float_test.go
@@ -0,0 +1,39 @@
+package scalar
+
+import "testing"
+
+func TestFloat64Set(t *testing.T) {
+	successfulTests := []struct {
+		source any
+		result Float64
+	}{
+		{source: float32(1), result: Float64{Value: 1, Valid: true}},
+		{source: float64(1), result: Float64{Value: 1, Valid: true}},
+		{source: int8(1), result: Float64{Value: 1, Valid: true}},
+		{source: int16(1), result: Float64{Value: 1, Valid: true}},
+		{source: int32(1), result: Float64{Value: 1, Valid: true}},
+		{source: int64(1), result: Float64{Value: 1, Valid: true}},
+		{source: int8(-1), result: Float64{Value: -1, Valid: true}},
+		{source: int16(-1), result: Float64{Value: -1, Valid: true}},
+		{source: int32(-1), result: Float64{Value: -1, Valid: true}},
+		{source: int64(-1), result: Float64{Value: -1, Valid: true}},
+		{source: uint8(1), result: Float64{Value: 1, Valid: true}},
+		{source: uint16(1), result: Float64{Value: 1, Valid: true}},
+		{source: uint32(1), result: Float64{Value: 1, Valid: true}},
+		{source: uint64(1), result: Float64{Value: 1, Valid: true}},
+		{source: "1", result: Float64{Value: 1, Valid: true}},
+		{source: _int8(1), result: Float64{Value: 1, Valid: true}},
+	}
+
+	for i, tt := range successfulTests {
+		var r Float64
+		err := r.Set(tt.source)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+		}
+
+		if !r.Equal(&tt.result) {
+			t.Errorf("%d: %v != %v", i, r, tt.result)
+		}
+	}
+}
\ No newline at end of file
diff --git a/scalar/inet.go b/scalar/inet.go
new file mode 100644
index 0000000000..fbfc59ee3e
--- /dev/null
+++ b/scalar/inet.go
@@ -0,0 +1,127 @@
+package scalar
+
+import (
+	"encoding"
+	"fmt"
+	"net"
+	"strings"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/cloudquery/plugin-sdk/v3/types"
+)
+
+type Inet struct {
+	Valid bool
+	Value *net.IPNet
+}
+
+func (s *Inet) IsValid() bool {
+	return s.Valid
+}
+
+func (s *Inet) DataType() arrow.DataType {
+	return types.ExtensionTypes.Inet
+}
+
+func (s *Inet) Equal(rhs Scalar) bool {
+	if rhs == nil {
+		return false
+	}
+	r, ok := rhs.(*Inet)
+	if !ok {
+		return false
+	}
+	return s.Valid == r.Valid && s.Value.String() == r.Value.String()
+}
+
+func (s *Inet) String() string {
+	if !s.Valid {
+		return "(null)"
+	}
+	return s.Value.String()
+}
+
+func (s *Inet) Set(val any) error {
+	if val == nil {
+		return nil
+	}
+
+	switch value := val.(type) {
+	case net.IPNet:
+		s.Value = &value
+	case net.IP:
+		if len(value) == 0 {
+			return nil
+		}
+		bitCount := len(value) * 8
+		mask := net.CIDRMask(bitCount, bitCount)
+		s.Value = &net.IPNet{Mask: mask, IP: value}
+	case string:
+		ip, ipnet, err := net.ParseCIDR(value)
+		if err != nil {
+			ip := net.ParseIP(value)
+			if ip == nil {
+				return &ValidationError{Type: types.ExtensionTypes.Inet, Msg: "cannot parse string as IP", Value: value}
+			}
+
+			if ipv4 := maybeGetIPv4(value, ip); ipv4 != nil {
+				ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)}
+			} else {
+				ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
+			}
+		} else {
+			ipnet.IP = ip
+			if ipv4 := maybeGetIPv4(value, ipnet.IP); ipv4 != nil {
+				ipnet.IP = ipv4
+				if len(ipnet.Mask) == 16 {
+					ipnet.Mask = ipnet.Mask[12:] // Not sure this is ever needed.
+				}
+			}
+		}
+		s.Value = ipnet
+	case *net.IPNet:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *net.IP:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *string:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	default:
+		if tv, ok := value.(encoding.TextMarshaler); ok {
+			text, err := tv.MarshalText()
+			if err != nil {
+				return &ValidationError{Type: types.ExtensionTypes.Inet, Msg: "cannot marshal text", Err: err, Value: val}
+			}
+			return s.Set(string(text))
+		}
+		if sv, ok := value.(fmt.Stringer); ok {
+			return s.Set(sv.String())
+		}
+		if originalSrc, ok := underlyingPtrType(val); ok {
+			return s.Set(originalSrc)
+		}
+		return &ValidationError{Type: types.ExtensionTypes.Inet, Msg: noConversion, Value: value}
+	}
+	s.Valid = true
+	return nil
+}
+
+// Convert the net.IP to IPv4, if appropriate.
+//
+// When parsing a string to a net.IP using net.ParseIP() and the like, we get a
+// 16 byte slice for IPv4 addresses as well as IPv6 addresses. This function
+// calls To4() to convert them to a 4 byte slice. This is useful as it allows
+// users of the net.IP to check for IPv4 addresses based on the length, and it
+// makes clear we are handling IPv4 as opposed to IPv6 or IPv4-mapped IPv6
+// addresses.
+func maybeGetIPv4(input string, ip net.IP) net.IP {
+	// Do not do this if the provided input looks like IPv6. This is because
+	// To4() on IPv4-mapped IPv6 addresses converts them to IPv4, which behaves
+	// differently in some cases.
+	if strings.Contains(input, ":") {
+		return nil
+	}
+
+	return ip.To4()
+}
\ No newline at end of file
diff --git a/scalar/inet_test.go b/scalar/inet_test.go
new file mode 100644
index 0000000000..d8fdb92132
--- /dev/null
+++ b/scalar/inet_test.go
@@ -0,0 +1,116 @@
+package scalar
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net"
+	"strings"
+	"testing"
+)
+
+type textMarshaler struct {
+	Text string
+}
+
+func (t textMarshaler) MarshalText() (text []byte, err error) {
+	return []byte(t.Text), nil
+}
+
+func mustParseCIDR(t testing.TB, s string) *net.IPNet {
+	_, ipnet, err := net.ParseCIDR(s)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	return ipnet
+}
+
+func mustParseInet(t testing.TB, s string) *net.IPNet {
+	ip, ipnet, err := net.ParseCIDR(s)
+	if err == nil {
+		if ipv4 := ip.To4(); ipv4 != nil {
+			ipnet.IP = ipv4
+		}
+		return ipnet
+	}
+
+	// May be a bare IP address.
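+	// Mirror Inet.Set's bare-IP fallback above: a bare IPv4 address gets a
+	// /32 host mask and anything else a /128 mask.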
+ // + ip = net.ParseIP(s) + if ip == nil { + t.Fatal(errors.New("unable to parse inet address")) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) + } + return ipnet +} + +func TestInetSet(t *testing.T) { + successfulTests := []struct { + source any + result Inet + }{ + {source: mustParseCIDR(t, "127.0.0.1/32"), result: Inet{Value: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: Inet{Value: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: "127.0.0.1/32", result: Inet{Value: mustParseCIDR(t, "127.0.0.1/32"), Valid: true}}, + {source: "1.2.3.4/24", result: Inet{Value: &net.IPNet{IP: net.ParseIP("1.2.3.4").To4(), Mask: net.CIDRMask(24, 32)}, Valid: true}}, + {source: "10.0.0.1", result: Inet{Value: mustParseInet(t, "10.0.0.1"), Valid: true}}, + {source: "2607:f8b0:4009:80b::200e", result: Inet{Value: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Valid: true}}, + {source: net.ParseIP(""), result: Inet{}}, + {source: "0.0.0.0/8", result: Inet{Value: mustParseInet(t, "0.0.0.0/8"), Valid: true}}, + {source: "::ffff:0.0.0.0/104", result: Inet{Value: &net.IPNet{IP: net.ParseIP("::ffff:0.0.0.0"), Mask: net.CIDRMask(104, 128)}, Valid: true}}, + {source: textMarshaler{"127.0.0.1"}, result: Inet{Value: mustParseInet(t, "127.0.0.1"), Valid: true}}, + {source: func(s string) fmt.Stringer { + var b strings.Builder + b.WriteString(s) + return &b + }("127.0.0.1"), result: Inet{Value: mustParseInet(t, "127.0.0.1"), Valid: true}}, + } + + for i, tt := range successfulTests { + var r Inet + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !r.Equal(&tt.result) { + t.Errorf("%d: %v != %v", i, r, tt.result) + } + } +} + +func TestInetMarshalUnmarshal(t *testing.T) { + var r Inet + err := r.Set("10.244.0.0/24") + if err != nil { + t.Fatal(err) + } + b, err := json.Marshal(r) + if err != nil { + t.Fatal(err) + } + var r2 Inet + err = json.Unmarshal(b, &r2) + if err != nil { + t.Fatal(err) + } + if !r.Equal(&r2) { + t.Errorf("%v != %v", r, r2) + } + + // workaround this Golang bug: https://github.com/golang/go/issues/35727 + if !bytes.Equal(r.Value.Mask, r2.Value.Mask) { + t.Errorf("%v != %v", r.Value.Mask, r2.Value.Mask) + } + if !net.IP.Equal(r.Value.IP, r2.Value.IP) { + t.Errorf("%v != %v", r.Value.IP, r2.Value.IP) + } +} diff --git a/scalar/int.go b/scalar/int.go new file mode 100644 index 0000000000..cf080f4f13 --- /dev/null +++ b/scalar/int.go @@ -0,0 +1,157 @@ +package scalar + +import ( + "math" + "strconv" + + "github.com/apache/arrow/go/v13/arrow" +) + +type Int64 struct { + Valid bool + Value int64 +} + +func (s *Int64) IsValid() bool { + return s.Valid +} + +func (s *Int64) DataType() arrow.DataType { + return arrow.PrimitiveTypes.Int64 +} + +func (s *Int64) String() string { + if !s.Valid { + return "(null)" + } + return strconv.FormatInt(int64(s.Value), 10) +} + +func (s *Int64) Equal(rhs Scalar) bool { + if rhs == nil { + return false + } + r, ok := rhs.(*Int64) + if !ok { + return false + } + return s.Valid == r.Valid && s.Value == r.Value +} + +func (s *Int64) Set(val any) error { + if val == nil { + s.Valid = false + return nil + } + + switch value := val.(type) { + case int8: + s.Value = int64(value) + case int16: + s.Value = int64(value) + case int32: + s.Value = int64(value) + case int64: + s.Value = value + case int: + s.Value = int64(value) + case uint8: + s.Value = int64(value) + 
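// uint8 through uint32 always fit in an int64, so they are widened with
+	// no range check; only uint64 and uint below need the MaxInt64 guard.
+	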
case uint16:
+		s.Value = int64(value)
+	case uint32:
+		s.Value = int64(value)
+	case uint64:
+		if value > math.MaxInt64 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Int64, Msg: "uint64 bigger than MaxInt64", Value: value}
+		}
+		s.Value = int64(value)
+	case uint:
+		if value > math.MaxInt64 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Int64, Msg: "uint bigger than MaxInt64", Value: value}
+		}
+		s.Value = int64(value)
+	case float32:
+		s.Value = int64(value)
+	case float64:
+		s.Value = int64(value)
+	case string:
+		num, err := strconv.ParseInt(value, 10, 64)
+		if err != nil {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Int64, Msg: "invalid string", Value: value}
+		}
+		s.Value = num
+	case *int8:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int16:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *int:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint8:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint16:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *uint:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *float32:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *float64:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case *string:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	default:
+		if originalSrc, ok := underlyingNumberType(value); ok {
+			return s.Set(originalSrc)
+		}
+		return &ValidationError{Type: arrow.PrimitiveTypes.Int64, Msg: noConversion, Value: value}
+	}
+	s.Valid = true
+	return nil
+}
\ No newline at end of file
diff --git a/scalar/int_test.go b/scalar/int_test.go
new file mode 100644
index 0000000000..28f7ca42e1
--- /dev/null
+++ b/scalar/int_test.go
@@ -0,0 +1,40 @@
+package scalar
+
+import "testing"
+
+func TestInt64Set(t *testing.T) {
+	successfulTests := []struct {
+		source any
+		result Int64
+	}{
+		{source: int8(1), result: Int64{Value: 1, Valid: true}},
+		{source: int16(1), result: Int64{Value: 1, Valid: true}},
+		{source: int32(1), result: Int64{Value: 1, Valid: true}},
+		{source: int64(1), result: Int64{Value: 1, Valid: true}},
+		{source: int8(-1), result: Int64{Value: -1, Valid: true}},
+		{source: int16(-1), result: Int64{Value: -1, Valid: true}},
+		{source: int32(-1), result: Int64{Value: -1, Valid: true}},
+		{source: int64(-1), result: Int64{Value: -1, Valid: true}},
+		{source: uint8(1), result: Int64{Value: 1, Valid: true}},
+		{source: uint16(1), result: Int64{Value: 1, Valid: true}},
+		{source: uint32(1), result: Int64{Value: 1, Valid: true}},
+		{source: uint64(1), result: Int64{Value: 1, Valid: true}},
+		{source: float32(1), result: Int64{Value: 1, Valid: true}},
+		{source: float64(1), result: Int64{Value: 1, Valid: true}},
+		{source: "1", result: Int64{Value: 1, Valid: true}},
+		{source: _int8(1), result: Int64{Value: 1, Valid: true}},
+	}
+
+	for i, tt := range successfulTests {
+		var r Int64
+		err := r.Set(tt.source)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+		}
+
+		if !r.Equal(&tt.result) {
+			t.Errorf("%d: %v != %v", i, r, tt.result)
+		}
+	}
+}
\ No newline at end of file
diff --git a/scalar/json.go b/scalar/json.go
new file mode 100644
index
0000000000..86ac65ec30
--- /dev/null
+++ b/scalar/json.go
@@ -0,0 +1,160 @@
+package scalar
+
+import (
+	"bytes"
+	"encoding/json"
+	"reflect"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/cloudquery/plugin-sdk/v3/types"
+)
+
+type JSON struct {
+	Valid bool
+	Value []byte
+}
+
+func (s *JSON) IsValid() bool {
+	return s.Valid
+}
+
+func (s *JSON) DataType() arrow.DataType {
+	return types.ExtensionTypes.JSON
+}
+
+func (s *JSON) Equal(rhs Scalar) bool {
+	if rhs == nil {
+		return false
+	}
+	r, ok := rhs.(*JSON)
+	if !ok {
+		return false
+	}
+	if !s.Valid && !r.Valid {
+		return true
+	}
+
+	if s.Valid != r.Valid {
+		return false
+	}
+
+	equal, err := jsonBytesEqual(s.Value, r.Value)
+	if err != nil {
+		return false
+	}
+	return equal
+}
+
+func (s *JSON) String() string {
+	if !s.Valid {
+		return "(null)"
+	}
+	return string(s.Value)
+}
+
+func (s *JSON) Set(val any) error {
+	if val == nil {
+		return nil
+	}
+
+	switch value := val.(type) {
+	case string:
+		if value == "" {
+			return nil
+		}
+		if !json.Valid([]byte(value)) {
+			return &ValidationError{Type: types.ExtensionTypes.JSON, Msg: "invalid json string", Value: value}
+		}
+		s.Value = []byte(value)
+	case *string:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	case []byte:
+		if len(value) == 0 {
+			return nil
+		}
+		if !json.Valid(value) {
+			return &ValidationError{Type: types.ExtensionTypes.JSON, Msg: "invalid json byte array", Value: value}
+		}
+		s.Value = value
+	// Encode* methods are defined on *JSON. If a JSON value is passed directly
+	// then the struct itself would be encoded instead of its Value bytes. This
+	// is clearly a footgun so detect and return an error.
+	// See https://github.com/jackc/pgx/issues/350.
+	case JSON:
+		return &ValidationError{Type: types.ExtensionTypes.JSON, Msg: "use pointer to JSON instead of value", Value: value}
+	default:
+		buffer := &bytes.Buffer{}
+		encoder := json.NewEncoder(buffer)
+		encoder.SetEscapeHTML(false)
+		err := encoder.Encode(value)
+		if err != nil {
+			return err
+		}
+
+		// The JSON encoder adds a newline to the end of the output that we don't want.
+		buf := bytes.TrimSuffix(buffer.Bytes(), []byte("\n"))
+		// For map and slice values, it is easier for users to work with '{}' or '[]' instead of JSON 'null'.
+		if bytes.Equal(buf, []byte(`null`)) {
+			if isEmptyStringMap(value) {
+				s.Value = []byte("{}")
+				s.Valid = true
+				return nil
+			}
+
+			if isEmptySlice(value) {
+				s.Value = []byte("[]")
+				s.Valid = true
+				return nil
+			}
+		}
+		s.Value = buf
+	}
+	s.Valid = true
+	return nil
+}
+
+// isEmptyStringMap returns true if the value is an empty map with string keys
+// (e.g. map[string]any). We need to use reflection for this, because it is
+// impossible to type-assert a map[string]string into a map[string]any.
+// See https://go.dev/doc/faq#convert_slice_of_interface.
+func isEmptyStringMap(value any) bool {
+	if reflect.TypeOf(value).Kind() != reflect.Map {
+		return false
+	}
+
+	if reflect.TypeOf(value).Key().Kind() != reflect.String {
+		return false
+	}
+
+	return reflect.ValueOf(value).Len() == 0
+}
+
+// isEmptySlice returns true if the value is an empty slice of any element type.
+// We need to use reflection for this, because it is impossible to type-assert
+// a concrete slice such as []int into a []any.
+// See https://go.dev/doc/faq#convert_slice_of_interface.
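+// For example, isEmptySlice([]string{}) and isEmptySlice([]Foo(nil)) are true,
+// while isEmptySlice(map[string]any{}) and isEmptySlice([]int{1}) are false.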
+func isEmptySlice(value any) bool {
+	if reflect.TypeOf(value).Kind() != reflect.Slice {
+		return false
+	}
+
+	return reflect.ValueOf(value).Len() == 0
+}
+
+// jsonBytesEqual reports whether two byte slices contain the same JSON value.
+func jsonBytesEqual(a, b []byte) (bool, error) {
+	var j, j2 any
+	if err := json.Unmarshal(a, &j); err != nil {
+		return false, err
+	}
+	if err := json.Unmarshal(b, &j2); err != nil {
+		return false, err
+	}
+	return reflect.DeepEqual(j2, j), nil
+}
\ No newline at end of file
diff --git a/scalar/json_test.go b/scalar/json_test.go
new file mode 100644
index 0000000000..42417e8be3
--- /dev/null
+++ b/scalar/json_test.go
@@ -0,0 +1,59 @@
+package scalar
+
+import "testing"
+
+type Foo struct {
+	Num int
+}
+
+func TestJSONSet(t *testing.T) {
+	successfulTests := []struct {
+		source any
+		result JSON
+	}{
+		{source: "", result: JSON{}},
+		{source: "{}", result: JSON{Value: []byte("{}"), Valid: true}},
+		{source: `"test"`, result: JSON{Value: []byte(`"test"`), Valid: true}},
+		{source: "1", result: JSON{Value: []byte("1"), Valid: true}},
+		{source: "[1, 2, 3]", result: JSON{Value: []byte("[1, 2, 3]"), Valid: true}},
+		{source: []byte("{}"), result: JSON{Value: []byte("{}"), Valid: true}},
+		{source: []byte(`"test"`), result: JSON{Value: []byte(`"test"`), Valid: true}},
+		{source: []byte("1"), result: JSON{Value: []byte("1"), Valid: true}},
+		{source: []byte("[1, 2, 3]"), result: JSON{Value: []byte("[1, 2, 3]"), Valid: true}},
+		{source: ([]byte)(nil), result: JSON{}},
+		{source: (*string)(nil), result: JSON{}},
+
+		{source: []int{1, 2, 3}, result: JSON{Value: []byte("[1,2,3]"), Valid: true}},
+		{source: []int(nil), result: JSON{Value: []byte(`[]`), Valid: true}},
+		{source: []int{}, result: JSON{Value: []byte(`[]`), Valid: true}},
+		{source: []Foo(nil), result: JSON{Value: []byte(`[]`), Valid: true}},
+		{source: []Foo{}, result: JSON{Value: []byte(`[]`), Valid: true}},
+		{source: []Foo{{1}}, result: JSON{Value: []byte(`[{"Num":1}]`), Valid: true}},
+
+		{source: map[string]any{"foo": "bar"}, result: JSON{Value: []byte(`{"foo":"bar"}`), Valid: true}},
+		{source: map[string]any(nil), result: JSON{Value: []byte(`{}`), Valid: true}},
+		{source: map[string]any{}, result: JSON{Value: []byte(`{}`), Valid: true}},
+		{source: map[string]string{"foo": "bar"}, result: JSON{Value: []byte(`{"foo":"bar"}`), Valid: true}},
+		{source: map[string]string(nil), result: JSON{Value: []byte(`{}`), Valid: true}},
+		{source: map[string]string{}, result: JSON{Value: []byte(`{}`), Valid: true}},
+		{source: map[string]Foo{"foo": {1}}, result: JSON{Value: []byte(`{"foo":{"Num":1}}`), Valid: true}},
+		{source: map[string]Foo(nil), result: JSON{Value: []byte(`{}`), Valid: true}},
+		{source: map[string]Foo{}, result: JSON{Value: []byte(`{}`), Valid: true}},
+
+		{source: nil, result: JSON{}},
+
+		{source: map[string]any{"test1": "a&b", "test2": "😀"}, result: JSON{Value: []byte(`{"test1": "a&b", "test2": "😀"}`), Valid: true}},
+	}
+
+	for i, tt := range successfulTests {
+		var d JSON
+		err := d.Set(tt.source)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+		}
+
+		if !d.Equal(&tt.result) {
+			t.Errorf("%d: %v != %v", i, d, tt.result)
+		}
+	}
+}
\ No newline at end of file
diff --git a/scalar/list.go b/scalar/list.go
new file mode 100644
index 0000000000..50ce3b0b4d
--- /dev/null
+++ b/scalar/list.go
@@ -0,0 +1,92 @@
+package scalar
+
+import (
+	"reflect"
+	"strings"
+
+	"github.com/apache/arrow/go/v13/arrow"
+)
+
+type List struct {
+	Valid bool
+	Value Vector
+	Type  arrow.DataType
+}
+
+func
(s *List) IsValid() bool { + return s.Valid +} + +func (s *List) DataType() arrow.DataType { + return s.Type +} + +func (s *List) String() string { + if !s.Valid { + return "(null)" + } + var sb strings.Builder + sb.WriteString("[") + for i, v := range s.Value { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(v.String()) + } + sb.WriteString("]") + return sb.String() +} + +func (s *List) Equal(rhs Scalar) bool { + if rhs == nil { + return false + } + r, ok := rhs.(*List) + if !ok { + return false + } + if s.Valid != r.Valid { + return false + } + if len(s.Value) != len(r.Value) { + return false + } + for i := range s.Value { + if !s.Value[i].Equal(r.Value[i]) { + return false + } + } + return true +} + +func (s *List) Set(val any) error { + if val == nil { + s.Valid = false + return nil + } + if s.Type == nil { + panic("List type is nil") + } + + reflectedValue := reflect.ValueOf(val) + if !reflectedValue.IsValid() || reflectedValue.IsZero() { + return nil + } + + switch reflectedValue.Kind() { + case reflect.Array: + fallthrough + case reflect.Slice: + length := reflectedValue.Len() + s.Value = make(Vector, length) + for i := 0; i < length; i++ { + s.Value[i] = NewScalar(s.Type.(*arrow.ListType).Elem()) + if err := s.Value[i].Set(reflectedValue.Index(i).Interface()); err != nil { + return err + } + } + } + + s.Valid = true + return nil +} \ No newline at end of file diff --git a/scalar/list_test.go b/scalar/list_test.go new file mode 100644 index 0000000000..f21c1da683 --- /dev/null +++ b/scalar/list_test.go @@ -0,0 +1,33 @@ +package scalar + +import ( + "testing" + + "github.com/apache/arrow/go/v13/arrow" +) + +func TestListSet(t *testing.T) { + successfulTests := []struct { + source any + result List + }{ + {source: []int{1,2}, result: List{Value: []Scalar{ + &Int64{Value: 1, Valid: true}, + &Int64{Value: 2, Valid: true}, + }, Valid: true, Type: arrow.ListOf(arrow.PrimitiveTypes.Int64)}}, + } + + for i, tt := range successfulTests { + r := List{ + Type: tt.result.Type, + } + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !r.Equal(&tt.result) { + t.Errorf("%d: %v != %v", i, r, tt.result) + } + } +} \ No newline at end of file diff --git a/scalar/mac.go b/scalar/mac.go new file mode 100644 index 0000000000..5e8c625180 --- /dev/null +++ b/scalar/mac.go @@ -0,0 +1,79 @@ +package scalar + +import ( + "net" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-sdk/v3/types" +) + + +type Mac struct { + Valid bool + Value net.HardwareAddr +} + +func (s *Mac) IsValid() bool { + return s.Valid +} + +func (s *Mac) DataType() arrow.DataType { + return types.ExtensionTypes.Mac +} + +func (s *Mac) String() string { + if !s.Valid { + return "(null)" + } + return s.Value.String() +} + +func (s *Mac) Equal(rhs Scalar) bool { + if rhs == nil { + return false + } + r, ok := rhs.(*Mac) + if !ok { + return false + } + return s.Valid == r.Valid && s.Value.String() == r.Value.String() +} + +func (s *Mac) Set(val any) error { + if val == nil { + return nil + } + + switch value := val.(type) { + case net.HardwareAddr: + addr := make(net.HardwareAddr, len(value)) + copy(addr, value) + s.Value = addr + case string: + addr, err := net.ParseMAC(value) + if err != nil { + return err + } + s.Value = addr + case *net.HardwareAddr: + if value == nil { + return nil + } else { + return s.Set(*value) + } + case *string: + if value == nil { + return nil + } else { + return s.Set(*value) + } + default: + if originalSrc, ok := underlyingPtrType(value); ok { + 
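// For instance, a hypothetical named pointer type such as
+			// `type hwAddrPtr *net.HardwareAddr` is dereferenced here and
+			// re-dispatched so the net.HardwareAddr case above can handle it.
+			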
return s.Set(originalSrc) + } + return &ValidationError{Type: types.ExtensionTypes.Mac, Msg: noConversion, Value: value} + } + + s.Valid = true + return nil +} \ No newline at end of file diff --git a/scalar/mac_test.go b/scalar/mac_test.go new file mode 100644 index 0000000000..86099b1af9 --- /dev/null +++ b/scalar/mac_test.go @@ -0,0 +1,43 @@ +package scalar + +import ( + "net" + "testing" +) + +func TestMacaddrSet(t *testing.T) { + successfulTests := []struct { + source any + result Mac + }{ + { + source: mustParseMacaddr(t, "01:23:45:67:89:ab"), + result: Mac{Value: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + }, + { + source: "01:23:45:67:89:ab", + result: Mac{Value: mustParseMacaddr(t, "01:23:45:67:89:ab"), Valid: true}, + }, + } + + for i, tt := range successfulTests { + var r Mac + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !r.Equal(&tt.result) { + t.Errorf("%d: %v != %v", i, r, tt.result) + } + } +} + +func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { + addr, err := net.ParseMAC(s) + if err != nil { + t.Fatal(err) + } + + return addr +} \ No newline at end of file diff --git a/scalar/scalar.go b/scalar/scalar.go new file mode 100644 index 0000000000..1c4316c176 --- /dev/null +++ b/scalar/scalar.go @@ -0,0 +1,130 @@ +package scalar + +import ( + "fmt" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/cloudquery/plugin-sdk/v3/types" +) + +// Scalar represents a single value of a specific DataType as opposed to +// an array. +// +// Scalars are useful for passing single value inputs to compute functions +// (not yet implemented) or for representing individual array elements, +// (with a non-trivial cost though). +type Scalar interface { + fmt.Stringer + // IsValid returns true if the value is non-null, otherwise false. 
+	IsValid() bool
+	// DataType returns the type of the value in this scalar.
+	DataType() arrow.DataType
+	// Performs cheap validation checks, returns nil if successful
+	// Validate() error
+	// Set tries to set the value of the scalar to the given value.
+	Set(val any) error
+	Equal(other Scalar) bool
+}
+
+type Vector []Scalar
+
+func (v Vector) Equal(r Vector) bool {
+	if len(v) != len(r) {
+		return false
+	}
+	for i := range v {
+		if !v[i].Equal(r[i]) {
+			return false
+		}
+	}
+	return true
+}
+
+func NewScalar(dt arrow.DataType) Scalar {
+	switch dt.ID() {
+	case arrow.BINARY:
+		return &Binary{}
+	case arrow.STRING:
+		return &String{}
+	case arrow.INT64:
+		return &Int64{}
+	case arrow.UINT64:
+		return &Uint64{}
+	case arrow.FLOAT32:
+		return &Float32{}
+	case arrow.FLOAT64:
+		return &Float64{}
+	case arrow.BOOL:
+		return &Bool{}
+	case arrow.EXTENSION:
+		if arrow.TypeEqual(dt, types.ExtensionTypes.UUID) {
+			return &UUID{}
+		} else if arrow.TypeEqual(dt, types.ExtensionTypes.JSON) {
+			return &JSON{}
+		} else if arrow.TypeEqual(dt, types.ExtensionTypes.Mac) {
+			return &Mac{}
+		} else if arrow.TypeEqual(dt, types.ExtensionTypes.Inet) {
+			return &Inet{}
+		} else {
+			panic("not implemented extension: " + dt.Name())
+		}
+	case arrow.LIST:
+		return &List{
+			Type: dt,
+		}
+	default:
+		panic("not implemented: " + dt.Name())
+	}
+}
+
+func AppendToBuilder(bldr array.Builder, s Scalar) {
+	// Invalid (null) scalars are appended as nulls regardless of their type.
+	if !s.IsValid() {
+		bldr.AppendNull()
+		return
+	}
+	switch s.DataType().ID() {
+	case arrow.BINARY:
+		bldr.(*array.BinaryBuilder).Append(s.(*Binary).Value)
+	case arrow.LARGE_BINARY:
+		bldr.(*array.BinaryBuilder).Append(s.(*LargeBinary).Value)
+	case arrow.STRING:
+		bldr.(*array.StringBuilder).Append(s.(*String).Value)
+	case arrow.INT64:
+		bldr.(*array.Int64Builder).Append(s.(*Int64).Value)
+	case arrow.UINT64:
+		bldr.(*array.Uint64Builder).Append(s.(*Uint64).Value)
+	case arrow.FLOAT32:
+		bldr.(*array.Float32Builder).Append(s.(*Float32).Value)
+	case arrow.FLOAT64:
+		bldr.(*array.Float64Builder).Append(s.(*Float64).Value)
+	case arrow.BOOL:
+		bldr.(*array.BooleanBuilder).Append(s.(*Bool).Value)
+	case arrow.LIST:
+		lb := bldr.(*array.ListBuilder)
+		lb.Append(true)
+		for _, v := range s.(*List).Value {
+			AppendToBuilder(lb.ValueBuilder(), v)
+		}
+	case arrow.EXTENSION:
+		if arrow.TypeEqual(s.DataType(), types.ExtensionTypes.UUID) {
+			bldr.(*types.UUIDBuilder).Append(s.(*UUID).Value)
+		} else if arrow.TypeEqual(s.DataType(), types.ExtensionTypes.JSON) {
+			bldr.(*types.JSONBuilder).Append(s.(*JSON).Value)
+		} else if arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Mac) {
+			bldr.(*types.MacBuilder).Append(s.(*Mac).Value)
+		} else if arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Inet) {
+			bldr.(*types.InetBuilder).Append(s.(*Inet).Value)
+		} else {
+			panic("not implemented extension: " + s.DataType().Name())
+		}
+	default:
+		panic("not implemented: " + s.DataType().String())
+	}
+}
+
+func AppendToRecordBuilder(bldr *array.RecordBuilder, vector Vector) {
+	for i, scalar := range vector {
+		AppendToBuilder(bldr.Field(i), scalar)
+	}
+}
\ No newline at end of file
diff --git a/scalar/string.go b/scalar/string.go
new file mode 100644
index 0000000000..64c7c749fc
--- /dev/null
+++ b/scalar/string.go
@@ -0,0 +1,78 @@
+package scalar
+
+import (
+	"fmt"
+
+	"github.com/apache/arrow/go/v13/arrow"
+)
+
+type String struct {
+	Valid bool
+	Value string
+}
+
+func (s *String) IsValid() bool {
+	return s.Valid
+}
+
+func (s *String) DataType() arrow.DataType {
+	return arrow.BinaryTypes.String
+}
+
+func (s *String) String() string {
+	if !s.Valid {
+		return "(null)"
+	}
+	return s.Value
+}
+
+func (s *String) Equal(rhs Scalar) bool {
+	if rhs == nil {
+		return false
+	}
+	r, ok := rhs.(*String)
+	if !ok {
+		return false
+	}
+	return s.Valid == r.Valid && s.Value == r.Value
+}
+
+func (s *String) Set(val any) error {
+	if val == nil {
+		s.Valid = false
+		return nil
+	}
+
+	switch value := val.(type) {
+	case []byte:
+		s.Value = string(value)
+	case string:
+		s.Value = value
+	case fmt.Stringer:
+		s.Value = value.String()
+	case *string:
+		if value == nil {
+			return nil
+		}
+		return s.Set(*value)
+	default:
+		if originalSrc, ok := underlyingStringType(value); ok {
+			return s.Set(originalSrc)
+		}
+		return &ValidationError{Type: arrow.BinaryTypes.String, Msg: noConversion, Value: value}
+	}
+
+	s.Valid = true
+	return nil
+}
+
+// LargeString embeds String so it shares its Set/String/Equal logic and only
+// overrides the data type.
+type LargeString struct {
+	String
+}
+
+func (s *LargeString) DataType() arrow.DataType {
+	return arrow.BinaryTypes.LargeString
+}
+
diff --git a/scalar/string_test.go b/scalar/string_test.go
new file mode 100644
index 0000000000..813aa0050a
--- /dev/null
+++ b/scalar/string_test.go
@@ -0,0 +1,27 @@
+package scalar
+
+import "testing"
+
+func TestStringSet(t *testing.T) {
+	successfulTests := []struct {
+		source any
+		result String
+	}{
+		{source: "foo", result: String{Value: "foo", Valid: true}},
+		{source: _string("bar"), result: String{Value: "bar", Valid: true}},
+		{source: (*string)(nil), result: String{}},
+	}
+
+	for i, tt := range successfulTests {
+		var d String
+		err := d.Set(tt.source)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+		}
+
+		if d != tt.result {
+			t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d)
+		}
+	}
+}
\ No newline at end of file
diff --git a/scalar/type_test.go b/scalar/type_test.go
new file mode 100644
index 0000000000..6a9e7e44c4
--- /dev/null
+++ b/scalar/type_test.go
@@ -0,0 +1,8 @@
+package scalar
+
+// Test types for conversions from renamed (defined) types.
+type _string string
+type _bool bool
+type _int8 int8
+
+type _byteSlice []byte
\ No newline at end of file
diff --git a/scalar/uint.go b/scalar/uint.go
new file mode 100644
index 0000000000..8785ddc316
--- /dev/null
+++ b/scalar/uint.go
@@ -0,0 +1,166 @@
+package scalar
+
+import (
+	"strconv"
+
+	"github.com/apache/arrow/go/v13/arrow"
+)
+
+type Uint64 struct {
+	Valid bool
+	Value uint64
+}
+
+func (n *Uint64) IsValid() bool {
+	return n.Valid
+}
+
+func (n *Uint64) DataType() arrow.DataType {
+	return arrow.PrimitiveTypes.Uint64
+}
+
+func (n *Uint64) String() string {
+	if !n.Valid {
+		return "(null)"
+	}
+	return strconv.FormatUint(n.Value, 10)
+}
+
+func (n *Uint64) Equal(rhs Scalar) bool {
+	if rhs == nil {
+		return false
+	}
+	r, ok := rhs.(*Uint64)
+	if !ok {
+		return false
+	}
+	return n.Valid == r.Valid && n.Value == r.Value
+}
+
+func (n *Uint64) Set(val any) error {
+	if val == nil {
+		n.Valid = false
+		return nil
+	}
+
+	switch value := val.(type) {
+	case int8:
+		if value < 0 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int8 less than 0", Value: value}
+		}
+		n.Value = uint64(value)
+	case int16:
+		if value < 0 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int16 less than 0", Value: value}
+		}
+		n.Value = uint64(value)
+	case int32:
+		if value < 0 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int32 less than 0", Value: value}
+		}
+		n.Value = uint64(value)
+	case int64:
+		if value < 0 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int64 less than 0", Value: value}
+		}
+		n.Value = uint64(value)
+	case int:
+		if value < 0 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int less than 0", Value: value}
+		}
+		n.Value = uint64(value)
+	case uint8:
+		n.Value = uint64(value)
+	case uint16:
+		n.Value = uint64(value)
+	case uint32:
+		n.Value = uint64(value)
+	case uint64:
+		n.Value = value
+	case uint:
+		n.Value = uint64(value)
+	case float32:
+		if value < 0 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "float32 less than 0", Value: value}
+		}
+		n.Value = uint64(value)
+	case float64:
+		if value < 0 {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "float64 less than 0", Value: value}
+		}
+		n.Value = uint64(value)
+	case string:
+		num, err := strconv.ParseUint(value, 10, 64)
+		if err != nil {
+			return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "invalid string", Value: value}
+		}
+		n.Value = num
+	case *int8:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *int16:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *int32:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *int64:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *int:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *uint8:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *uint16:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *uint32:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *uint64:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *uint:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *float32:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	case *float64:
+		if value == nil {
+			return nil
+		}
+		return n.Set(*value)
+	default:
+		if originalSrc, ok := underlyingNumberType(value); ok {
+			return n.Set(originalSrc)
+		}
+		return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: noConversion, Value: value}
+	}
+	n.Valid = true
+	return nil
+}
\ No newline at end of file
diff --git a/scalar/uint_test.go b/scalar/uint_test.go
new file mode 100644
index 0000000000..2281c9ca72
--- /dev/null
+++ b/scalar/uint_test.go
@@ -0,0 +1,36 @@
+package scalar
+
+import "testing"
+
+func TestUint64Set(t *testing.T) {
+	successfulTests := []struct {
+		source any
+		result Uint64
+	}{
+		{source: int8(1), result: Uint64{Value: 1, Valid: true}},
+		{source: int16(1), result: Uint64{Value: 1, Valid: true}},
+		{source: int32(1), result: Uint64{Value: 1, Valid: true}},
+		{source: int64(1), result: Uint64{Value: 1, Valid: true}},
+		{source: uint8(1), result: Uint64{Value: 1, Valid: true}},
+		{source: uint16(1), result: Uint64{Value: 1, Valid: true}},
+		{source: uint32(1), result: Uint64{Value: 1, Valid: true}},
+		{source: uint64(1), result: Uint64{Value: 1, Valid: true}},
+		{source: float32(1), result: Uint64{Value: 1, Valid: true}},
+		{source: float64(1), result: Uint64{Value: 1, Valid: true}},
+		{source: "1", result: Uint64{Value: 1, Valid: true}},
+		{source: _int8(1), result: Uint64{Value: 1, Valid: true}},
+	}
+
+	for i, tt := range successfulTests {
+		var r Uint64
+		err := r.Set(tt.source)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+		}
+
+		if !r.Equal(&tt.result) {
+			t.Errorf("%d: %v != %v", i, r, tt.result)
+		}
+	}
+}
\ No newline at end of file
diff --git a/scalar/uuid.go b/scalar/uuid.go
new file mode 100644
index 0000000000..5ec5a160c3
--- /dev/null
+++ b/scalar/uuid.go
@@ -0,0 +1,105 @@
+package scalar
+
+import (
+	
"encoding/hex" + "fmt" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-sdk/v3/types" + "github.com/google/uuid" +) + +type UUID struct { + Valid bool + Value uuid.UUID +} + +func (s *UUID) IsValid() bool { + return s.Valid +} + +func (s *UUID) DataType() arrow.DataType { + return types.ExtensionTypes.UUID +} + +func (s *UUID) String() string { + if !s.Valid { + return "(null)" + } + return s.Value.String() +} + +func (s *UUID) Equal(rhs Scalar) bool { + if rhs == nil { + return false + } + r, ok := rhs.(*UUID) + if !ok { + return false + } + return s.Valid == r.Valid && s.Value == r.Value +} + +func (s *UUID) Set(src any) error { + if src == nil { + return nil + } + + switch value := src.(type) { + case fmt.Stringer: + value2 := value.String() + return s.Set(value2) + case [16]byte: + s.Value = uuid.UUID(value) + case []byte: + if value != nil { + if len(value) != 16 { + return &ValidationError{Type: types.ExtensionTypes.UUID, Msg: "[]byte must be 16 bytes to convert to UUID", Value: value} + } + copy(s.Value[:], value) + } else { + return nil + } + case string: + uuid, err := parseUUID(value) + if err != nil { + return err + } + s.Value = uuid + case *string: + if value == nil { + return nil + } else { + return s.Set(*value) + } + default: + if originalSrc, ok := underlyingUUIDType(src); ok { + return s.Set(originalSrc) + } + return &ValidationError{Type: types.ExtensionTypes.UUID, Msg: noConversion, Value: value} + } + s.Valid = true + return nil +} + +// parseUUID converts a string UUID in standard form to a byte array. +func parseUUID(src string) (dst [16]byte, err error) { + switch len(src) { + case 36: + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + case 32: + // dashes already stripped, assume valid + default: + // assume invalid. 
+		return dst, &ValidationError{Type: types.ExtensionTypes.UUID, Msg: fmt.Sprintf("invalid UUID length %d", len(src)), Value: src}
+	}
+
+	buf, err := hex.DecodeString(src)
+	if err != nil {
+		return dst, err
+	}
+
+	copy(dst[:], buf)
+	return dst, nil
+}
+
diff --git a/scalar/uuid_test.go b/scalar/uuid_test.go
new file mode 100644
index 0000000000..6301b4f5a0
--- /dev/null
+++ b/scalar/uuid_test.go
@@ -0,0 +1,66 @@
+package scalar
+
+import "testing"
+
+type SomeUUIDWrapper struct {
+	SomeUUIDType
+}
+
+type SomeUUIDType [16]byte
+type StringUUIDType string
+
+func (s StringUUIDType) String() string {
+	return string(s)
+}
+
+func TestUUIDSet(t *testing.T) {
+	successfulTests := []struct {
+		source any
+		result UUID
+	}{
+		{
+			source: nil,
+			result: UUID{},
+		},
+		{
+			source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+			result: UUID{Value: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true},
+		},
+		{
+			source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+			result: UUID{Value: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true},
+		},
+		{
+			source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+			result: UUID{Value: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true},
+		},
+		{
+			source: ([]byte)(nil),
+			result: UUID{},
+		},
+		{
+			source: "00010203-0405-0607-0809-0a0b0c0d0e0f",
+			result: UUID{Value: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true},
+		},
+		{
+			source: "000102030405060708090a0b0c0d0e0f",
+			result: UUID{Value: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true},
+		},
+		{
+			source: StringUUIDType("00010203-0405-0607-0809-0a0b0c0d0e0f"),
+			result: UUID{Value: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true},
+		},
+	}
+
+	for i, tt := range successfulTests {
+		var r UUID
+		err := r.Set(tt.source)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+		}
+
+		if !r.Equal(&tt.result) {
+			t.Errorf("%d: %v != %v", i, r, tt.result)
+		}
+	}
+}
diff --git a/schema/arrow.go b/schema/arrow.go
index 56e51de354..b24d49d835 100644
--- a/schema/arrow.go
+++ b/schema/arrow.go
@@ -1,7 +1,11 @@
 package schema
 
 import (
+	"bytes"
+	"fmt"
+
 	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/ipc"
 )
 
 const (
@@ -34,3 +38,37 @@
 	}
 	return nil
 }
+
+
+func (s Schemas) Encode() ([][]byte, error) {
+	ret := make([][]byte, len(s))
+	for i, sc := range s {
+		var buf bytes.Buffer
+		wr := ipc.NewWriter(&buf, ipc.WithSchema(sc))
+		if err := wr.Close(); err != nil {
+			return nil, err
+		}
+		ret[i] = buf.Bytes()
+	}
+	return ret, nil
+}
+
+func NewSchemasFromBytes(b [][]byte) (Schemas, error) {
+	ret := make([]*arrow.Schema, len(b))
+	for i, buf := range b {
+		rdr, err := ipc.NewReader(bytes.NewReader(buf))
+		if err != nil {
+			return nil, err
+		}
+		ret[i] = rdr.Schema()
+	}
+	return ret, nil
+}
+
+func NewTablesFromBytes(b [][]byte) (Tables, error) {
+	schemas, err := NewSchemasFromBytes(b)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode schemas: %w", err)
+	}
+	return NewTablesFromArrowSchemas(schemas)
+}
\ No newline at end of file
diff --git a/schema/arrow_test.go b/schema/arrow_test.go
new file mode 100644
index 0000000000..505e3925c7
--- /dev/null
+++ b/schema/arrow_test.go
@@ -0,0 +1,44 @@
+package schema
+
+import (
+	"testing"
+
+	"github.com/apache/arrow/go/v13/arrow"
+)
+
+func TestSchemaEncode(t
*testing.T) { + md := arrow.NewMetadata([]string{"true"}, []string{"false"}) + md1 := arrow.NewMetadata([]string{"false"}, []string{"true"}) + schemas := Schemas{ + arrow.NewSchema( + []arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int64}, + {Name: "name", Type: arrow.BinaryTypes.String}, + }, + &md, + ), + arrow.NewSchema( + []arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int64}, + {Name: "name", Type: arrow.BinaryTypes.String}, + }, + &md1, + ), + } + b, err := schemas.Encode() + if err != nil { + t.Fatal(err) + } + decodedSchemas, err := NewSchemasFromBytes(b) + if err != nil { + t.Fatal(err) + } + if len(decodedSchemas) != len(schemas) { + t.Fatalf("expected %d schemas, got %d", len(schemas), len(decodedSchemas)) + } + for i := range schemas { + if !schemas[i].Equal(decodedSchemas[i]) { + t.Fatalf("expected schema %d to be %v, got %v", i, schemas[i], decodedSchemas[i]) + } + } +} \ No newline at end of file diff --git a/schema/meta.go b/schema/meta.go index b6b3c16cd7..ec8be655ad 100644 --- a/schema/meta.go +++ b/schema/meta.go @@ -4,8 +4,8 @@ import ( "context" "github.com/apache/arrow/go/v13/arrow" + "github.com/cloudquery/plugin-sdk/v3/scalar" "github.com/cloudquery/plugin-sdk/v3/types" - "github.com/google/uuid" ) type ClientMeta interface { @@ -49,7 +49,7 @@ func parentCqUUIDResolver() ColumnResolver { if parentCqID == nil { return r.Set(c.Name, nil) } - pUUID, ok := parentCqID.(*uuid.UUID) + pUUID, ok := parentCqID.(*scalar.UUID) if !ok { return r.Set(c.Name, nil) } diff --git a/schema/resource.go b/schema/resource.go index 320d87d6e4..cb4592114e 100644 --- a/schema/resource.go +++ b/schema/resource.go @@ -1,7 +1,12 @@ package schema import ( + "crypto/sha256" + "fmt" + + "github.com/cloudquery/plugin-sdk/v3/scalar" "github.com/google/uuid" + "golang.org/x/exp/slices" ) type Resources []*Resource @@ -16,7 +21,7 @@ type Resource struct { // internal fields Table *Table // This is sorted result data by column name - data map[string]any + data scalar.Vector // bldr array.RecordBuilder } @@ -25,16 +30,22 @@ func NewResourceData(t *Table, parent *Resource, item any) *Resource { Item: item, Parent: parent, Table: t, - data: make(map[string]any, len(t.Columns)), + data: make(scalar.Vector, len(t.Columns)), } - for _, c := range t.Columns { - r.data[c.Name] = nil + for i := range r.data { + r.data[i] = scalar.NewScalar(t.Columns[i].Type) } return &r } -func (r *Resource) Get(columnName string) any { - return r.data[columnName] +func (r *Resource) Get(columnName string) scalar.Scalar { + index := r.Table.Columns.Index(columnName) + if index == -1 { + // we panic because we want to distinguish between code error and api error + // this also saves additional checks in our testing code + panic(columnName + " column not found") + } + return r.data[index] } // Set sets a column with value. 
This does validation and conversion to @@ -47,7 +58,9 @@ func (r *Resource) Set(columnName string, value any) error { // this also saves additional checks in our testing code panic(columnName + " column not found") } - r.data[columnName] = value + if err := r.data[index].Set(value); err != nil { + panic(fmt.Errorf("failed to set column %s: %w", columnName, err)) + } return nil } @@ -60,19 +73,50 @@ func (r *Resource) GetItem() any { return r.Item } -//nolint:revive -func (*Resource) CalculateCQID(deterministicCQID bool) error { - panic("not implemented") +func (r *Resource) GetValues() scalar.Vector { + return r.data } -//nolint:unused,revive -func (*Resource) storeCQID(value uuid.UUID) error { - panic("not implemented") +func (r *Resource) CalculateCQID(deterministicCQID bool) error { + if !deterministicCQID { + return r.storeCQID(uuid.New()) + } + names := r.Table.PrimaryKeys() + if len(names) == 0 || (len(names) == 1 && names[0] == CqIDColumn.Name) { + return r.storeCQID(uuid.New()) + } + slices.Sort(names) + h := sha256.New() + for _, name := range names { + // We need to include the column name in the hash because the same value can be present in multiple columns and therefore lead to the same hash + h.Write([]byte(name)) + h.Write([]byte(r.Get(name).String())) + } + return r.storeCQID(uuid.NewSHA1(uuid.UUID{}, h.Sum(nil))) +} + +func (r *Resource) storeCQID(value uuid.UUID) error { + b, err := value.MarshalBinary() + if err != nil { + return err + } + return r.Set(CqIDColumn.Name, b) } // Validates that all primary keys have values. -func (*Resource) Validate() error { - panic("not implemented") +func (r *Resource) Validate() error { + var missingPks []string + for i, c := range r.Table.Columns { + if c.CreationOptions.PrimaryKey { + if !r.data[i].IsValid() { + missingPks = append(missingPks, c.Name) + } + } + } + if len(missingPks) > 0 { + return fmt.Errorf("missing primary key on columns: %v", missingPks) + } + return nil } func (rr Resources) TableName() string { diff --git a/schema/table.go b/schema/table.go index c70e21e767..9e4f572833 100644 --- a/schema/table.go +++ b/schema/table.go @@ -93,6 +93,18 @@ var ( reValidColumnName = regexp.MustCompile(`^[a-z_][a-z\d_]*$`) ) +func NewTablesFromArrowSchemas(schemas []*arrow.Schema) (Tables, error) { + tables := make(Tables, len(schemas)) + for i, schema := range schemas { + table, err := NewTableFromArrowSchema(schema) + if err != nil { + return nil, err + } + tables[i] = table + } + return tables, nil +} + // Create a CloudQuery Table abstraction from an arrow schema // arrow schema is a low level representation of a table that can be sent // over the wire in a cross-language way @@ -367,7 +379,7 @@ func (t *Table) GetChanges(old *Table) []TableColumnChange { continue } // Column type or options (e.g. 
PK, Not Null) changed in the new table definition - if c.Type != otherColumn.Type || c.NotNull != otherColumn.NotNull || c.PrimaryKey != otherColumn.PrimaryKey { + if !arrow.TypeEqual(c.Type, otherColumn.Type) || c.NotNull != otherColumn.NotNull || c.PrimaryKey != otherColumn.PrimaryKey { changes = append(changes, TableColumnChange{ Type: TableColumnChangeTypeUpdate, ColumnName: c.Name, diff --git a/serve/destination.go b/serve/destination.go index 4dfbf482e6..72a3e4770a 100644 --- a/serve/destination.go +++ b/serve/destination.go @@ -10,10 +10,13 @@ import ( "syscall" pbv0 "github.com/cloudquery/plugin-pb-go/pb/destination/v0" + pbv1 "github.com/cloudquery/plugin-pb-go/pb/destination/v1" pbdiscoveryv0 "github.com/cloudquery/plugin-pb-go/pb/discovery/v0" servers "github.com/cloudquery/plugin-sdk/v3/internal/servers/destination/v0" + serversv1 "github.com/cloudquery/plugin-sdk/v3/internal/servers/destination/v1" discoveryServerV0 "github.com/cloudquery/plugin-sdk/v3/internal/servers/discovery/v0" "github.com/cloudquery/plugin-sdk/v3/plugins/destination" + "github.com/cloudquery/plugin-sdk/v3/types" "github.com/getsentry/sentry-go" grpczerolog "github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" @@ -115,8 +118,12 @@ func newCmdDestinationServe(serve *destinationServe) *cobra.Command { Plugin: serve.plugin, Logger: logger, }) + pbv1.RegisterDestinationServer(s, &serversv1.Server{ + Plugin: serve.plugin, + Logger: logger, + }) pbdiscoveryv0.RegisterDiscoveryServer(s, &discoveryServerV0.Server{ - Versions: []string{"v0"}, + Versions: []string{"v0", "v1"}, }) version := serve.plugin.Version() @@ -144,6 +151,11 @@ func newCmdDestinationServe(serve *destinationServe) *cobra.Command { log.Error().Err(err).Msg("Error initializing sentry") } } + + if err := types.RegisterAllExtensions(); err != nil { + return err + } + ctx := cmd.Context() c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) diff --git a/serve/destination_v1_test.go b/serve/destination_v1_test.go new file mode 100644 index 0000000000..4a66aba688 --- /dev/null +++ b/serve/destination_v1_test.go @@ -0,0 +1,187 @@ +package serve + +import ( + "bytes" + "context" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/array" + "github.com/apache/arrow/go/v13/arrow/ipc" + pb "github.com/cloudquery/plugin-pb-go/pb/destination/v1" + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v3/internal/memdb" + "github.com/cloudquery/plugin-sdk/v3/plugins/destination" + "github.com/cloudquery/plugin-sdk/v3/schema" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestDestinationV1(t *testing.T) { + plugin := destination.NewPlugin("testDestinationPlugin", "development", memdb.NewClient) + s := &destinationServe{ + plugin: plugin, + } + cmd := newCmdDestinationRoot(s) + cmd.SetArgs([]string{"serve", "--network", "test"}) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + var wg sync.WaitGroup + wg.Add(1) + var serverErr error + go func() { + defer wg.Done() + serverErr = cmd.ExecuteContext(ctx) + }() + defer func() { + cancel() + wg.Wait() + }() + + // wait for the server to start + for { + testDestinationListenerLock.Lock() + if testDestinationListener != nil { + testDestinationListenerLock.Unlock() + break + } + 
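// Listener not published yet: release the lock before sleeping so
+		// the server goroutine can store it.
+		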
testDestinationListenerLock.Unlock() + t.Log("waiting for grpc server to start") + time.Sleep(time.Millisecond * 200) + } + + // https://stackoverflow.com/questions/42102496/testing-a-grpc-service + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDestinationDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + if err != nil { + t.Fatalf("Failed to dial bufnet: %v", err) + } + c := pb.NewDestinationClient(conn) + spec := specs.Destination{ + WriteMode: specs.WriteModeAppend, + } + specBytes, err := json.Marshal(spec) + if err != nil { + t.Fatal(err) + } + if _, err := c.Configure(ctx, &pb.Configure_Request{Config: specBytes}); err != nil { + t.Fatal(err) + } + + getNameRes, err := c.GetName(ctx, &pb.GetName_Request{}) + if err != nil { + t.Fatal(err) + } + if getNameRes.Name != "testDestinationPlugin" { + t.Fatalf("expected name to be testDestinationPlugin but got %s", getNameRes.Name) + } + + getVersionRes, err := c.GetVersion(ctx, &pb.GetVersion_Request{}) + if err != nil { + t.Fatal(err) + } + if getVersionRes.Version != "development" { + t.Fatalf("expected version to be development but got %s", getVersionRes.Version) + } + + tableName := "test_destination_serve" + sourceName := "test_destination_serve_source" + syncTime := time.Now() + table := schema.TestTable(tableName) + tables := schema.Tables{table} + sourceSpec := specs.Source{ + Name: sourceName, + } + encodedTables, err := tables.ToArrowSchemas().Encode() + if err != nil { + t.Fatal(err) + } + + if _, err := c.Migrate(ctx, &pb.Migrate_Request{ + Tables: encodedTables, + }); err != nil { + t.Fatal(err) + } + + rec := schema.GenTestData(table, schema.GenTestDataOptions{ + SourceName: sourceName, + SyncTime: syncTime, + MaxRows: 1, + })[0] + + sourceSpecBytes, err := json.Marshal(sourceSpec) + if err != nil { + t.Fatal(err) + } + writeClient, err := c.Write(ctx) + if err != nil { + t.Fatal(err) + } + if err := writeClient.Send(&pb.Write_Request{ + SourceSpec: sourceSpecBytes, + Source: sourceSpec.Name, + Timestamp: timestamppb.New(syncTime.Truncate(time.Microsecond)), + Tables: encodedTables, + }); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(rec.Schema())) + if err := wr.Write(rec); err != nil { + t.Fatal(err) + } + if err := wr.Close(); err != nil { + t.Fatal(err) + } + if err := writeClient.Send(&pb.Write_Request{ + Resource: buf.Bytes(), + }); err != nil { + t.Fatal(err) + } + + if _, err := writeClient.CloseAndRecv(); err != nil { + t.Fatal(err) + } + // serversDestination + readCh := make(chan arrow.Record, 1) + if err := plugin.Read(ctx, table, sourceName, readCh); err != nil { + t.Fatal(err) + } + close(readCh) + totalResources := 0 + for resource := range readCh { + totalResources++ + if !array.RecordEqual(rec, resource) { + diff := destination.RecordDiff(rec, resource) + t.Fatalf("expected %v but got %v. 
Diff: %v", rec, resource, diff) + } + } + if totalResources != 1 { + t.Fatalf("expected 1 resource but got %d", totalResources) + } + if _, err := c.DeleteStale(ctx, &pb.DeleteStale_Request{ + Source: "testSource", + Timestamp: timestamppb.New(time.Now().Truncate(time.Microsecond)), + Tables: encodedTables, + }); err != nil { + t.Fatal(err) + } + + _, err = c.GetMetrics(ctx, &pb.GetDestinationMetrics_Request{}) + if err != nil { + t.Fatal(err) + } + + if _, err := c.Close(ctx, &pb.Close_Request{}); err != nil { + t.Fatalf("failed to call Close: %v", err) + } + + cancel() + wg.Wait() + if serverErr != nil { + t.Fatal(serverErr) + } +} diff --git a/serve/source.go b/serve/source.go new file mode 100644 index 0000000000..ae57c83d07 --- /dev/null +++ b/serve/source.go @@ -0,0 +1,233 @@ +package serve + +import ( + "fmt" + "net" + "os" + "os/signal" + "strings" + "sync" + "syscall" + + pbdiscoveryv0 "github.com/cloudquery/plugin-pb-go/pb/discovery/v0" + pbv2 "github.com/cloudquery/plugin-pb-go/pb/source/v2" + discoveryServerV0 "github.com/cloudquery/plugin-sdk/v3/internal/servers/discovery/v0" + + serversv2 "github.com/cloudquery/plugin-sdk/v3/internal/servers/source/v2" + "github.com/cloudquery/plugin-sdk/v3/plugins/source" + "github.com/getsentry/sentry-go" + grpczerolog "github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "github.com/thoas/go-funk" + "golang.org/x/net/netutil" + "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" +) + +type sourceServe struct { + plugin *source.Plugin + sentryDSN string +} + +type SourceOption func(*sourceServe) + +func WithSourceSentryDSN(dsn string) SourceOption { + return func(s *sourceServe) { + s.sentryDSN = dsn + } +} + +// lis used for unit testing grpc server and client +var testSourceListener *bufconn.Listener +var testSourceListenerLock sync.Mutex + +const serveSourceShort = `Start source plugin server` + +func Source(plugin *source.Plugin, opts ...SourceOption) { + s := &sourceServe{ + plugin: plugin, + } + for _, opt := range opts { + opt(s) + } + if err := newCmdSourceRoot(s).Execute(); err != nil { + sentry.CaptureMessage(err.Error()) + fmt.Println(err) + os.Exit(1) + } +} + +// nolint:dupl +func newCmdSourceServe(serve *sourceServe) *cobra.Command { + var address string + var network string + var noSentry bool + logLevel := newEnum([]string{"trace", "debug", "info", "warn", "error"}, "info") + logFormat := newEnum([]string{"text", "json"}, "text") + telemetryLevel := newEnum([]string{"none", "errors", "stats", "all"}, "all") + err := telemetryLevel.Set(getEnvOrDefault("CQ_TELEMETRY_LEVEL", telemetryLevel.Value)) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to set telemetry level: "+err.Error()) + os.Exit(1) + } + + cmd := &cobra.Command{ + Use: "serve", + Short: serveSourceShort, + Long: serveSourceShort, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + zerologLevel, err := zerolog.ParseLevel(logLevel.String()) + if err != nil { + return err + } + var logger zerolog.Logger + if logFormat.String() == "json" { + logger = zerolog.New(os.Stdout).Level(zerologLevel) + } else { + logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout}).Level(zerologLevel) + } + + // opts.Plugin.Logger = logger + var listener net.Listener + if network == "test" { + testSourceListenerLock.Lock() + listener = bufconn.Listen(testBufSize) + 
testSourceListener = listener.(*bufconn.Listener) + testSourceListenerLock.Unlock() + } else { + listener, err = net.Listen(network, address) + if err != nil { + return fmt.Errorf("failed to listen %s:%s: %w", network, address, err) + } + } + // source plugins can only accept one connection at a time + // unlike destination plugins that can accept multiple connections + limitListener := netutil.LimitListener(listener, 1) + // See logging pattern https://github.com/grpc-ecosystem/go-grpc-middleware/blob/v2/providers/zerolog/examples_test.go + s := grpc.NewServer( + grpc.ChainUnaryInterceptor( + logging.UnaryServerInterceptor(grpczerolog.InterceptorLogger(logger)), + ), + grpc.ChainStreamInterceptor( + logging.StreamServerInterceptor(grpczerolog.InterceptorLogger(logger)), + ), + grpc.MaxRecvMsgSize(MaxMsgSize), + grpc.MaxSendMsgSize(MaxMsgSize), + ) + serve.plugin.SetLogger(logger) + pbv2.RegisterSourceServer(s, &serversv2.Server{ + Plugin: serve.plugin, + Logger: logger, + }) + pbdiscoveryv0.RegisterDiscoveryServer(s, &discoveryServerV0.Server{ + Versions: []string{"v2"}, + }) + + version := serve.plugin.Version() + + if serve.sentryDSN != "" && !strings.EqualFold(version, "development") && !noSentry { + err = sentry.Init(sentry.ClientOptions{ + Dsn: serve.sentryDSN, + Debug: false, + AttachStacktrace: false, + Release: version, + Transport: sentry.NewHTTPSyncTransport(), + ServerName: "oss", // set to "oss" on purpose to avoid sending any identifying information + // https://docs.sentry.io/platforms/go/configuration/options/#removing-default-integrations + Integrations: func(integrations []sentry.Integration) []sentry.Integration { + var filteredIntegrations []sentry.Integration + for _, integration := range integrations { + if integration.Name() == "Modules" { + continue + } + filteredIntegrations = append(filteredIntegrations, integration) + } + return filteredIntegrations + }, + }) + if err != nil { + log.Error().Err(err).Msg("Error initializing sentry") + } + } + + ctx := cmd.Context() + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + defer func() { + signal.Stop(c) + }() + + go func() { + select { + case sig := <-c: + logger.Info().Str("address", listener.Addr().String()).Str("signal", sig.String()).Msg("Got stop signal. Source plugin server shutting down") + s.Stop() + case <-ctx.Done(): + logger.Info().Str("address", listener.Addr().String()).Msg("Context cancelled. Source plugin server shutting down") + s.Stop() + } + }() + + logger.Info().Str("address", listener.Addr().String()).Msg("Source plugin server listening") + if err := s.Serve(limitListener); err != nil { + return fmt.Errorf("failed to serve: %w", err) + } + return nil + }, + } + cmd.Flags().StringVar(&address, "address", "localhost:7777", "address to serve on. can be tcp: `localhost:7777` or unix socket: `/tmp/plugin.rpc.sock`") + cmd.Flags().StringVar(&network, "network", "tcp", `the network must be "tcp", "tcp4", "tcp6", "unix" or "unixpacket"`) + cmd.Flags().Var(logLevel, "log-level", fmt.Sprintf("log level. one of: %s", strings.Join(logLevel.Allowed, ","))) + cmd.Flags().Var(logFormat, "log-format", fmt.Sprintf("log format. 
one of: %s", strings.Join(logFormat.Allowed, ","))) + cmd.Flags().BoolVar(&noSentry, "no-sentry", false, "disable sentry") + sendErrors := funk.ContainsString([]string{"all", "errors"}, telemetryLevel.String()) + if !sendErrors { + noSentry = true + } + + return cmd +} + +const ( + sourceDocShort = "Generate documentation for tables" + sourceDocLong = `Generate documentation for tables + +If format is markdown, a destination directory will be created (if necessary) containing markdown files. +Example: +doc ./output + +If format is JSON, a destination directory will be created (if necessary) with a single json file called __tables.json. +Example: +doc --format json . +` +) + +func newCmdSourceDoc(serve *sourceServe) *cobra.Command { + format := newEnum([]string{"json", "markdown"}, "markdown") + cmd := &cobra.Command{ + Use: "doc ", + Short: sourceDocShort, + Long: sourceDocLong, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return serve.plugin.GeneratePluginDocs(args[0], format.Value) + }, + } + cmd.Flags().Var(format, "format", fmt.Sprintf("output format. one of: %s", strings.Join(format.Allowed, ","))) + return cmd +} + +func newCmdSourceRoot(serve *sourceServe) *cobra.Command { + cmd := &cobra.Command{ + Use: fmt.Sprintf("%s ", serve.plugin.Name()), + } + cmd.AddCommand(newCmdSourceServe(serve)) + cmd.AddCommand(newCmdSourceDoc(serve)) + cmd.CompletionOptions.DisableDefaultCmd = true + cmd.Version = serve.plugin.Version() + return cmd +} diff --git a/serve/source_v2_test.go b/serve/source_v2_test.go new file mode 100644 index 0000000000..8f32014a21 --- /dev/null +++ b/serve/source_v2_test.go @@ -0,0 +1,244 @@ +package serve + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/ipc" + pb "github.com/cloudquery/plugin-pb-go/pb/source/v2" + "github.com/cloudquery/plugin-pb-go/specs" + "github.com/cloudquery/plugin-sdk/v3/plugins/source" + "github.com/cloudquery/plugin-sdk/v3/schema" + "github.com/rs/zerolog" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type TestSourcePluginSpec struct { + Accounts []string `json:"accounts,omitempty" yaml:"accounts,omitempty"` +} + +type testExecutionClient struct{} + +var _ schema.ClientMeta = &testExecutionClient{} + +var errTestExecutionClientErr = fmt.Errorf("error in newTestExecutionClientErr") + +func testTable(name string) *schema.Table { + return &schema.Table{ + Name: name, + Resolver: func(ctx context.Context, meta schema.ClientMeta, parent *schema.Resource, res chan<- any) error { + res <- map[string]any{ + "TestColumn": 3, + } + return nil + }, + Columns: []schema.Column{ + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + }, + }, + } +} + +func (*testExecutionClient) ID() string { + return "testExecutionClient" +} + +func newTestExecutionClient(context.Context, zerolog.Logger, specs.Source, source.Options) (schema.ClientMeta, error) { + return &testExecutionClient{}, nil +} + +func newTestExecutionClientErr(context.Context, zerolog.Logger, specs.Source, source.Options) (schema.ClientMeta, error) { + return nil, errTestExecutionClientErr +} + +func bufSourceDialer(context.Context, string) (net.Conn, error) { + testSourceListenerLock.Lock() + defer testSourceListenerLock.Unlock() + return testSourceListener.Dial() +} + +func TestSourceSuccess(t *testing.T) { + plugin := source.NewPlugin( + "testPlugin", + "v1.0.0", + 
[]*schema.Table{testTable("test_table"), testTable("test_table2")}, + newTestExecutionClient) + + cmd := newCmdSourceRoot(&sourceServe{ + plugin: plugin, + }) + cmd.SetArgs([]string{"serve", "--network", "test"}) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + var wg sync.WaitGroup + wg.Add(1) + var serverErr error + go func() { + defer wg.Done() + serverErr = cmd.ExecuteContext(ctx) + }() + defer func() { + cancel() + wg.Wait() + }() + for { + testSourceListenerLock.Lock() + if testSourceListener != nil { + testSourceListenerLock.Unlock() + break + } + testSourceListenerLock.Unlock() + t.Log("waiting for grpc server to start") + time.Sleep(time.Millisecond * 200) + } + + // https://stackoverflow.com/questions/42102496/testing-a-grpc-service + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufSourceDialer), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + if err != nil { + t.Fatalf("Failed to dial bufnet: %v", err) + } + c := pb.NewSourceClient(conn) + + getNameRes, err := c.GetName(ctx, &pb.GetName_Request{}) + if err != nil { + t.Fatal(err) + } + if getNameRes.Name != "testPlugin" { + t.Fatalf("expected name to be testPlugin but got %s", getNameRes.Name) + } + + getVersionResponse, err := c.GetVersion(ctx, &pb.GetVersion_Request{}) + if err != nil { + t.Fatal(err) + } + if getVersionResponse.Version != "v1.0.0" { + t.Fatalf("Expected version to be v1.0.0 but got %s", getVersionResponse.Version) + } + + spec := specs.Source{ + Name: "testSourcePlugin", + Version: "v1.0.0", + Path: "cloudquery/testSourcePlugin", + Registry: specs.RegistryGithub, + Tables: []string{"test_table"}, + Spec: TestSourcePluginSpec{Accounts: []string{"cloudquery/plugin-sdk"}}, + Destinations: []string{"test"}, + } + specMarshaled, err := json.Marshal(spec) + if err != nil { + t.Fatalf("Failed to marshal spec: %v", err) + } + + + getTablesRes, err := c.GetTables(ctx, &pb.GetTables_Request{}) + if err != nil { + t.Fatal(err) + } + + tables, err := schema.NewTablesFromBytes(getTablesRes.Tables) + if err != nil { + t.Fatal(err) + } + + if len(tables) != 2 { + t.Fatalf("Expected 2 tables but got %d", len(tables)) + } + if _, err := c.Init(ctx, &pb.Init_Request{Spec: specMarshaled}); err != nil { + t.Fatal(err) + } + + getTablesForSpecRes, err := c.GetDynamicTables(ctx, &pb.GetDynamicTables_Request{}) + if err != nil { + t.Fatal(err) + } + tables, err = schema.NewTablesFromBytes(getTablesForSpecRes.Tables) + if err != nil { + t.Fatal(err) + } + + if len(tables) != 1 { + t.Fatalf("Expected 1 table but got %d", len(tables)) + } + + syncClient, err := c.Sync(ctx, &pb.Sync_Request{}) + if err != nil { + t.Fatal(err) + } + var resources []arrow.Record + for { + r, err := syncClient.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + rdr, err := ipc.NewReader(bytes.NewReader(r.Resource)) + if err != nil { + t.Fatal(err) + } + for rdr.Next() { + rec := rdr.Record() + rec.Retain() + resources = append(resources, rec) + } + } + + totalResources := 0 + for _, resource := range resources { + sc := resource.Schema() + tableName, ok := sc.Metadata().GetValue(schema.MetadataTableName) + if !ok { + t.Fatal("Expected table name metadata to be set") + } + if tableName != "test_table" { + t.Fatalf("Expected resource with table name test_table. 
got: %s", tableName) + } + if len(resource.Columns()) != 3 { + t.Fatalf("Expected resource with data length 3 but got %d", len(resource.Columns())) + } + totalResources++ + } + if totalResources != 1 { + t.Fatalf("Expected 1 resource on channel but got %d", totalResources) + } + + getMetricsRes, err := c.GetMetrics(ctx, &pb.GetMetrics_Request{}) + if err != nil { + t.Fatal(err) + } + var stats source.Metrics + if err := json.Unmarshal(getMetricsRes.Metrics, &stats); err != nil { + t.Fatal(err) + } + + clientStats := stats.TableClient[""][""] + if clientStats.Resources != 1 { + t.Fatalf("Expected 1 resource but got %d", clientStats.Resources) + } + + if clientStats.Errors != 0 { + t.Fatalf("Expected 0 errors but got %d", clientStats.Errors) + } + + if clientStats.Panics != 0 { + t.Fatalf("Expected 0 panics but got %d", clientStats.Panics) + } + + cancel() + wg.Wait() + if serverErr != nil { + t.Fatal(serverErr) + } +} \ No newline at end of file diff --git a/types/mac.go b/types/mac.go index b02a4e8861..d70a1841cc 100644 --- a/types/mac.go +++ b/types/mac.go @@ -197,7 +197,11 @@ func (*MACType) ExtensionName() string { return "mac" } -// Serialize returns "MAC-serialized" for testing proper metadata passing +func (*MACType) String() string { + return "mac" +} + +// Serialize returns "mac-serialized" for testing proper metadata passing func (*MACType) Serialize() string { return "mac-serialized" } @@ -211,7 +215,7 @@ func (*MACType) Deserialize(storageType arrow.DataType, data string) (arrow.Exte if !arrow.TypeEqual(storageType, &arrow.BinaryType{}) { return nil, fmt.Errorf("invalid storage type for MACType: %s", storageType.Name()) } - return NewInetType(), nil + return NewMACType(), nil } // ExtensionEquals returns true if both extensions have the same name diff --git a/types/register.go b/types/register.go new file mode 100644 index 0000000000..07695dc820 --- /dev/null +++ b/types/register.go @@ -0,0 +1,20 @@ +package types + +import "github.com/apache/arrow/go/v13/arrow" + + +func RegisterAllExtensions() error { + if err := arrow.RegisterExtensionType(&UUIDType{}); err != nil { + return err + } + if err := arrow.RegisterExtensionType(&JSONType{}); err != nil { + return err + } + if err := arrow.RegisterExtensionType(&InetType{}); err != nil { + return err + } + if err := arrow.RegisterExtensionType(&MacType{}); err != nil { + return err + } + return nil +} \ No newline at end of file diff --git a/types/uuid.go b/types/uuid.go index 4b43678702..5375dcb662 100644 --- a/types/uuid.go +++ b/types/uuid.go @@ -202,7 +202,7 @@ func (*UUIDType) ExtensionName() string { } func (e *UUIDType) String() string { - return fmt.Sprintf("extension_type", e.Storage) + return "uuid" } func (e *UUIDType) MarshalJSON() ([]byte, error) { From c9d4321dbd02b8be8e889847176a42e67b2da690 Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Sun, 14 May 2023 16:06:34 +0300 Subject: [PATCH 02/14] fmt --- scalar/binary.go | 3 +-- scalar/binary_test.go | 2 +- scalar/bool.go | 2 +- scalar/bool_test.go | 2 +- scalar/convert.go | 2 +- scalar/errors.go | 4 ++-- scalar/float.go | 3 ++- scalar/float_test.go | 2 +- scalar/inet.go | 5 ++--- scalar/inet_test.go | 1 - scalar/int.go | 2 +- scalar/int_test.go | 3 +-- scalar/json.go | 5 ++--- scalar/json_test.go | 4 ++-- scalar/list.go | 4 ++-- scalar/list_test.go | 4 ++-- scalar/mac.go | 3 +-- scalar/mac_test.go | 2 +- scalar/scalar.go | 2 +- scalar/string.go | 1 - scalar/string_test.go | 3 +-- scalar/type_test.go | 2 +- scalar/uint.go 
| 2 +- scalar/uint_test.go | 3 +-- scalar/uuid.go | 1 - schema/arrow.go | 3 +-- schema/arrow_test.go | 2 +- serve/destination_v1_test.go | 4 ++-- serve/source_v2_test.go | 3 +-- types/register.go | 3 +-- 30 files changed, 35 insertions(+), 47 deletions(-) diff --git a/scalar/binary.go b/scalar/binary.go index 1085780b11..68ed4d55b5 100644 --- a/scalar/binary.go +++ b/scalar/binary.go @@ -11,7 +11,6 @@ type Binary struct { Value []byte } - func (s *Binary) IsValid() bool { return s.Valid } @@ -71,4 +70,4 @@ type LargeBinary struct { func (s *LargeBinary) DataType() arrow.DataType { return arrow.BinaryTypes.LargeBinary -} \ No newline at end of file +} diff --git a/scalar/binary_test.go b/scalar/binary_test.go index a3cf0b54f0..4445ef7306 100644 --- a/scalar/binary_test.go +++ b/scalar/binary_test.go @@ -25,4 +25,4 @@ func TestBinarySet(t *testing.T) { t.Errorf("%d: %v != %v", i, r, tt.result) } } -} \ No newline at end of file +} diff --git a/scalar/bool.go b/scalar/bool.go index 71bf4675d5..2e87b7f7a3 100644 --- a/scalar/bool.go +++ b/scalar/bool.go @@ -63,4 +63,4 @@ func (s *Bool) Set(val any) error { } s.Valid = true return nil -} \ No newline at end of file +} diff --git a/scalar/bool_test.go b/scalar/bool_test.go index 01a4ea5a07..43a9ef725f 100644 --- a/scalar/bool_test.go +++ b/scalar/bool_test.go @@ -30,4 +30,4 @@ func TestBoolSet(t *testing.T) { t.Errorf("%d: %v != %v", i, r, tt.result) } } -} \ No newline at end of file +} diff --git a/scalar/convert.go b/scalar/convert.go index b6c3cc4e10..75ae88494b 100644 --- a/scalar/convert.go +++ b/scalar/convert.go @@ -158,4 +158,4 @@ func underlyingPtrType(val any) (any, bool) { } return nil, false -} \ No newline at end of file +} diff --git a/scalar/errors.go b/scalar/errors.go index 3fa772ba26..f64ad4b402 100644 --- a/scalar/errors.go +++ b/scalar/errors.go @@ -7,7 +7,7 @@ import ( ) const ( - noConversion = "no conversion available" + noConversion = "no conversion available" ) type ValidationError struct { @@ -34,4 +34,4 @@ func (e *ValidationError) MaskedError() string { func (e *ValidationError) Unwrap() error { return e.Err -} \ No newline at end of file +} diff --git a/scalar/float.go b/scalar/float.go index 7967a43cc6..2d9c69c7a2 100644 --- a/scalar/float.go +++ b/scalar/float.go @@ -129,6 +129,7 @@ type Float64 struct { Valid bool Value float64 } + func (s *Float64) IsValid() bool { return s.Valid } @@ -216,4 +217,4 @@ func (s *Float64) Set(val any) error { } s.Valid = true return nil -} \ No newline at end of file +} diff --git a/scalar/float_test.go b/scalar/float_test.go index bb5df96990..84c55c27f5 100644 --- a/scalar/float_test.go +++ b/scalar/float_test.go @@ -36,4 +36,4 @@ func TestFloat64Set(t *testing.T) { t.Errorf("%d: %v != %v", i, r, tt.result) } } -} \ No newline at end of file +} diff --git a/scalar/inet.go b/scalar/inet.go index fbfc59ee3e..6f7d6867fc 100644 --- a/scalar/inet.go +++ b/scalar/inet.go @@ -12,7 +12,7 @@ import ( type Inet struct { Valid bool - Value *net.IPNet + Value *net.IPNet } func (s *Inet) IsValid() bool { @@ -106,7 +106,6 @@ func (s *Inet) Set(val any) error { return nil } - // Convert the net.IP to IPv4, if appropriate. 
// // When parsing a string to a net.IP using net.ParseIP() and the like, we get a @@ -124,4 +123,4 @@ func maybeGetIPv4(input string, ip net.IP) net.IP { } return ip.To4() -} \ No newline at end of file +} diff --git a/scalar/inet_test.go b/scalar/inet_test.go index d8fdb92132..509e928719 100644 --- a/scalar/inet_test.go +++ b/scalar/inet_test.go @@ -10,7 +10,6 @@ import ( "testing" ) - type textMarshaler struct { Text string } diff --git a/scalar/int.go b/scalar/int.go index cf080f4f13..81513fb55d 100644 --- a/scalar/int.go +++ b/scalar/int.go @@ -154,4 +154,4 @@ func (s *Int64) Set(val any) error { } s.Valid = true return nil -} \ No newline at end of file +} diff --git a/scalar/int_test.go b/scalar/int_test.go index 28f7ca42e1..aaed6baa0b 100644 --- a/scalar/int_test.go +++ b/scalar/int_test.go @@ -2,7 +2,6 @@ package scalar import "testing" - func TestInt8Set(t *testing.T) { successfulTests := []struct { source any @@ -37,4 +36,4 @@ func TestInt8Set(t *testing.T) { t.Errorf("%d: %v != %v", i, r, tt.result) } } -} \ No newline at end of file +} diff --git a/scalar/json.go b/scalar/json.go index 86ac65ec30..1ab4a912cf 100644 --- a/scalar/json.go +++ b/scalar/json.go @@ -30,7 +30,7 @@ func (s *JSON) Equal(rhs Scalar) bool { if !ok { return false } - if !s.Valid && !r.Valid{ + if !s.Valid && !r.Valid { return true } @@ -146,7 +146,6 @@ func isEmptySlice(value any) bool { return reflect.ValueOf(value).Len() == 0 } - // JSONBytesEqual compares the JSON in two byte slices. func jsonBytesEqual(a, b []byte) (bool, error) { var j, j2 any @@ -157,4 +156,4 @@ func jsonBytesEqual(a, b []byte) (bool, error) { return false, err } return reflect.DeepEqual(j2, j), nil -} \ No newline at end of file +} diff --git a/scalar/json_test.go b/scalar/json_test.go index 42417e8be3..09a67cbf3a 100644 --- a/scalar/json_test.go +++ b/scalar/json_test.go @@ -11,7 +11,7 @@ func TestJSONSet(t *testing.T) { source any result JSON }{ - {source: "", result: JSON{Value: []byte(""), }}, + {source: "", result: JSON{Value: []byte("")}}, {source: "{}", result: JSON{Value: []byte("{}"), Valid: true}}, {source: `"test"`, result: JSON{Value: []byte(`"test"`), Valid: true}}, {source: "1", result: JSON{Value: []byte("1"), Valid: true}}, @@ -56,4 +56,4 @@ func TestJSONSet(t *testing.T) { t.Errorf("%d: %v != %v", i, d, tt.result) } } -} \ No newline at end of file +} diff --git a/scalar/list.go b/scalar/list.go index 50ce3b0b4d..2d3ac00417 100644 --- a/scalar/list.go +++ b/scalar/list.go @@ -10,7 +10,7 @@ import ( type List struct { Valid bool Value Vector - Type arrow.DataType + Type arrow.DataType } func (s *List) IsValid() bool { @@ -89,4 +89,4 @@ func (s *List) Set(val any) error { s.Valid = true return nil -} \ No newline at end of file +} diff --git a/scalar/list_test.go b/scalar/list_test.go index f21c1da683..82b4ad3f64 100644 --- a/scalar/list_test.go +++ b/scalar/list_test.go @@ -11,7 +11,7 @@ func TestListSet(t *testing.T) { source any result List }{ - {source: []int{1,2}, result: List{Value: []Scalar{ + {source: []int{1, 2}, result: List{Value: []Scalar{ &Int64{Value: 1, Valid: true}, &Int64{Value: 2, Valid: true}, }, Valid: true, Type: arrow.ListOf(arrow.PrimitiveTypes.Int64)}}, @@ -30,4 +30,4 @@ func TestListSet(t *testing.T) { t.Errorf("%d: %v != %v", i, r, tt.result) } } -} \ No newline at end of file +} diff --git a/scalar/mac.go b/scalar/mac.go index 5e8c625180..9d3711d4ce 100644 --- a/scalar/mac.go +++ b/scalar/mac.go @@ -7,7 +7,6 @@ import ( "github.com/cloudquery/plugin-sdk/v3/types" ) - type Mac struct { Valid 
bool Value net.HardwareAddr @@ -76,4 +75,4 @@ func (s *Mac) Set(val any) error { s.Valid = true return nil -} \ No newline at end of file +} diff --git a/scalar/mac_test.go b/scalar/mac_test.go index 86099b1af9..08863aba8c 100644 --- a/scalar/mac_test.go +++ b/scalar/mac_test.go @@ -40,4 +40,4 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { } return addr -} \ No newline at end of file +} diff --git a/scalar/scalar.go b/scalar/scalar.go index 1c4316c176..244c5d4d74 100644 --- a/scalar/scalar.go +++ b/scalar/scalar.go @@ -127,4 +127,4 @@ func AppendToRecordBuilder(bldr *array.RecordBuilder, vector Vector) { for i, scalar := range vector { AppendToBuilder(bldr.Field(i), scalar) } -} \ No newline at end of file +} diff --git a/scalar/string.go b/scalar/string.go index 64c7c749fc..c5b9e1c1cb 100644 --- a/scalar/string.go +++ b/scalar/string.go @@ -75,4 +75,3 @@ type LargeString struct { func (s *LargeString) DataType() arrow.DataType { return arrow.BinaryTypes.LargeString } - diff --git a/scalar/string_test.go b/scalar/string_test.go index 813aa0050a..97f7444d46 100644 --- a/scalar/string_test.go +++ b/scalar/string_test.go @@ -2,7 +2,6 @@ package scalar import "testing" - func TestStringSet(t *testing.T) { successfulTests := []struct { source any @@ -24,4 +23,4 @@ func TestStringSet(t *testing.T) { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) } } -} \ No newline at end of file +} diff --git a/scalar/type_test.go b/scalar/type_test.go index 6a9e7e44c4..a429148cfc 100644 --- a/scalar/type_test.go +++ b/scalar/type_test.go @@ -5,4 +5,4 @@ type _string string type _bool bool type _int8 int8 -type _byteSlice []byte \ No newline at end of file +type _byteSlice []byte diff --git a/scalar/uint.go b/scalar/uint.go index 8785ddc316..77ceaece05 100644 --- a/scalar/uint.go +++ b/scalar/uint.go @@ -163,4 +163,4 @@ func (n *Uint64) Set(val any) error { } n.Valid = true return nil -} \ No newline at end of file +} diff --git a/scalar/uint_test.go b/scalar/uint_test.go index 2281c9ca72..3d66e812d5 100644 --- a/scalar/uint_test.go +++ b/scalar/uint_test.go @@ -2,7 +2,6 @@ package scalar import "testing" - func TestUint64Set(t *testing.T) { successfulTests := []struct { source any @@ -33,4 +32,4 @@ func TestUint64Set(t *testing.T) { t.Errorf("%d: %v != %v", i, r, tt.result) } } -} \ No newline at end of file +} diff --git a/scalar/uuid.go b/scalar/uuid.go index 5ec5a160c3..adc6d45c56 100644 --- a/scalar/uuid.go +++ b/scalar/uuid.go @@ -102,4 +102,3 @@ func parseUUID(src string) (dst [16]byte, err error) { copy(dst[:], buf) return dst, err } - diff --git a/schema/arrow.go b/schema/arrow.go index b24d49d835..f7f61dbe61 100644 --- a/schema/arrow.go +++ b/schema/arrow.go @@ -39,7 +39,6 @@ func (s Schemas) SchemaByName(name string) *arrow.Schema { return nil } - func (s Schemas) Encode() ([][]byte, error) { ret := make([][]byte, len(s)) for i, sc := range s { @@ -71,4 +70,4 @@ func NewTablesFromBytes(b [][]byte) (Tables, error) { return nil, fmt.Errorf("failed to decode schemas: %w", err) } return NewTablesFromArrowSchemas(schemas) -} \ No newline at end of file +} diff --git a/schema/arrow_test.go b/schema/arrow_test.go index 505e3925c7..377cc5718f 100644 --- a/schema/arrow_test.go +++ b/schema/arrow_test.go @@ -41,4 +41,4 @@ func TestSchemaEncode(t *testing.T) { t.Fatalf("expected schema %d to be %v, got %v", i, schemas[i], decodedSchemas[i]) } } -} \ No newline at end of file +} diff --git a/serve/destination_v1_test.go b/serve/destination_v1_test.go 
index 4a66aba688..24d403785f 100644 --- a/serve/destination_v1_test.go +++ b/serve/destination_v1_test.go @@ -99,7 +99,7 @@ func TestDestinationV1(t *testing.T) { if err != nil { t.Fatal(err) } - + if _, err := c.Migrate(ctx, &pb.Migrate_Request{ Tables: encodedTables, }); err != nil { @@ -109,7 +109,7 @@ func TestDestinationV1(t *testing.T) { rec := schema.GenTestData(table, schema.GenTestDataOptions{ SourceName: sourceName, SyncTime: syncTime, - MaxRows: 1, + MaxRows: 1, })[0] sourceSpecBytes, err := json.Marshal(sourceSpec) diff --git a/serve/source_v2_test.go b/serve/source_v2_test.go index 8f32014a21..f2ce2cebe6 100644 --- a/serve/source_v2_test.go +++ b/serve/source_v2_test.go @@ -140,7 +140,6 @@ func TestSourceSuccess(t *testing.T) { t.Fatalf("Failed to marshal spec: %v", err) } - getTablesRes, err := c.GetTables(ctx, &pb.GetTables_Request{}) if err != nil { t.Fatal(err) @@ -241,4 +240,4 @@ func TestSourceSuccess(t *testing.T) { if serverErr != nil { t.Fatal(serverErr) } -} \ No newline at end of file +} diff --git a/types/register.go b/types/register.go index 07695dc820..148d64b45d 100644 --- a/types/register.go +++ b/types/register.go @@ -2,7 +2,6 @@ package types import "github.com/apache/arrow/go/v13/arrow" - func RegisterAllExtensions() error { if err := arrow.RegisterExtensionType(&UUIDType{}); err != nil { return err @@ -17,4 +16,4 @@ func RegisterAllExtensions() error { return err } return nil -} \ No newline at end of file +} From 84ea9d29a5d33ebf31600ced5eb3ed710b3d4756 Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Sun, 14 May 2023 16:20:35 +0300 Subject: [PATCH 03/14] fix lint --- go.mod | 2 -- go.sum | 2 ++ scalar/binary.go | 4 +-- scalar/bool.go | 2 +- scalar/convert.go | 3 +- scalar/float.go | 4 +-- scalar/inet.go | 21 ++++++++------ scalar/inet_test.go | 1 + scalar/int.go | 4 +-- scalar/json.go | 17 ++++++----- scalar/mac.go | 8 ++---- scalar/scalar.go | 22 +++++++------- scalar/string.go | 4 +-- scalar/uint.go | 64 ++++++++++++++++++++--------------------- scalar/uuid.go | 20 ++++++------- schema/resource.go | 1 + serve/source_v2_test.go | 7 +---- types/register.go | 5 +--- types/uuid.go | 2 +- 19 files changed, 95 insertions(+), 98 deletions(-) diff --git a/go.mod b/go.mod index 2632d4666f..cac28db99c 100644 --- a/go.mod +++ b/go.mod @@ -28,8 +28,6 @@ require ( replace github.com/apache/arrow/go/v13 => github.com/cloudquery/arrow/go/v13 v13.0.0-20230509053643-898a79b1d3c8 -replace github.com/cloudquery/plugin-pb-go => ../plugin-pb-go - require ( github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/thrift v0.16.0 // indirect diff --git a/go.sum b/go.sum index 48f39ba21e..a6bbc31786 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudquery/arrow/go/v13 v13.0.0-20230509053643-898a79b1d3c8 h1:CmgLSEGQNLHpUQ5cU4L4aF7cuJZRnc1toIIWqC1gmPg= github.com/cloudquery/arrow/go/v13 v13.0.0-20230509053643-898a79b1d3c8/go.mod h1:/XatdE3kDIBqZKhZ7OBUHwP2jaASDFZHqF4puOWM8po= +github.com/cloudquery/plugin-pb-go v1.0.6 h1:jWLXFUGgobO28rBERuipSUbr6Rqoth4nzGZ/XEQD86w= +github.com/cloudquery/plugin-pb-go v1.0.6/go.mod h1:vAGA27psem7ZZNAY4a3S9TKuA/JDQWstjKcHPJX91Mc= github.com/cloudquery/plugin-sdk/v2 v2.7.0 h1:hRXsdEiaOxJtsn/wZMFQC9/jPfU1MeMK3KF+gPGqm7U= github.com/cloudquery/plugin-sdk/v2 v2.7.0/go.mod 
h1:pAX6ojIW99b/Vg4CkhnsGkRIzNaVEceYMR+Bdit73ug= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= diff --git a/scalar/binary.go b/scalar/binary.go index 68ed4d55b5..2ffa5b3483 100644 --- a/scalar/binary.go +++ b/scalar/binary.go @@ -60,7 +60,7 @@ func (s *Binary) Set(val any) error { return nil } -func (s *Binary) DataType() arrow.DataType { +func (*Binary) DataType() arrow.DataType { return arrow.BinaryTypes.Binary } @@ -68,6 +68,6 @@ type LargeBinary struct { Binary } -func (s *LargeBinary) DataType() arrow.DataType { +func (*LargeBinary) DataType() arrow.DataType { return arrow.BinaryTypes.LargeBinary } diff --git a/scalar/bool.go b/scalar/bool.go index 2e87b7f7a3..626674ec0e 100644 --- a/scalar/bool.go +++ b/scalar/bool.go @@ -15,7 +15,7 @@ func (s *Bool) IsValid() bool { return s.Valid } -func (s *Bool) DataType() arrow.DataType { +func (*Bool) DataType() arrow.DataType { return arrow.FixedWidthTypes.Boolean } diff --git a/scalar/convert.go b/scalar/convert.go index 75ae88494b..fbd12531e2 100644 --- a/scalar/convert.go +++ b/scalar/convert.go @@ -126,6 +126,7 @@ func underlyingBytesType(val any) (any, bool) { func underlyingUUIDType(val any) (any, bool) { refVal := reflect.ValueOf(val) + //nolint:revive,gocritic switch refVal.Kind() { case reflect.Ptr: if refVal.IsNil() { @@ -147,7 +148,7 @@ func underlyingUUIDType(val any) (any, bool) { func underlyingPtrType(val any) (any, bool) { refVal := reflect.ValueOf(val) - //nolint:gocritic + //nolint:gocritic,revive switch refVal.Kind() { case reflect.Ptr: if refVal.IsNil() { diff --git a/scalar/float.go b/scalar/float.go index 2d9c69c7a2..b413e374ec 100644 --- a/scalar/float.go +++ b/scalar/float.go @@ -16,7 +16,7 @@ func (s *Float32) IsValid() bool { return s.Valid } -func (s *Float32) DataType() arrow.DataType { +func (*Float32) DataType() arrow.DataType { return arrow.PrimitiveTypes.Float32 } @@ -134,7 +134,7 @@ func (s *Float64) IsValid() bool { return s.Valid } -func (s *Float64) DataType() arrow.DataType { +func (*Float64) DataType() arrow.DataType { return arrow.PrimitiveTypes.Float64 } diff --git a/scalar/inet.go b/scalar/inet.go index 6f7d6867fc..e951275ecb 100644 --- a/scalar/inet.go +++ b/scalar/inet.go @@ -19,7 +19,7 @@ func (s *Inet) IsValid() bool { return s.Valid } -func (s *Inet) DataType() arrow.DataType { +func (*Inet) DataType() arrow.DataType { return types.ExtensionTypes.Inet } @@ -52,11 +52,10 @@ func (s *Inet) Set(val any) error { case net.IP: if len(value) == 0 { return nil - } else { - bitCount := len(value) * 8 - mask := net.CIDRMask(bitCount, bitCount) - s.Value = &net.IPNet{Mask: mask, IP: value} } + bitCount := len(value) * 8 + mask := net.CIDRMask(bitCount, bitCount) + s.Value = &net.IPNet{Mask: mask, IP: value} case string: ip, ipnet, err := net.ParseCIDR(value) if err != nil { @@ -81,11 +80,17 @@ func (s *Inet) Set(val any) error { } s.Value = ipnet case *net.IPNet: - s.Set(*value) + if err := s.Set(*value); err != nil { + return err + } case *net.IP: - s.Set(*value) + if err := s.Set(*value); err != nil { + return err + } case *string: - s.Set(*value) + if err := s.Set(*value); err != nil { + return err + } default: if tv, ok := value.(encoding.TextMarshaler); ok { text, err := tv.MarshalText() diff --git a/scalar/inet_test.go b/scalar/inet_test.go index 509e928719..e0f8d7b890 100644 --- a/scalar/inet_test.go +++ b/scalar/inet_test.go @@ -18,6 +18,7 @@ func (t textMarshaler) MarshalText() (text []byte, err error) { return []byte(t.Text), err } 
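+// mustParseCIDR is a test helper that parses s in CIDR notation and fails the
+// test immediately on error, so expected values can be constructed inline.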
+// nolint:unparam func mustParseCIDR(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { diff --git a/scalar/int.go b/scalar/int.go index 81513fb55d..12ffaeec44 100644 --- a/scalar/int.go +++ b/scalar/int.go @@ -16,7 +16,7 @@ func (s *Int64) IsValid() bool { return s.Valid } -func (s *Int64) DataType() arrow.DataType { +func (*Int64) DataType() arrow.DataType { return arrow.PrimitiveTypes.Int64 } @@ -24,7 +24,7 @@ func (s *Int64) String() string { if !s.Valid { return "(null)" } - return strconv.FormatInt(int64(s.Value), 10) + return strconv.FormatInt(s.Value, 10) } func (s *Int64) Equal(rhs Scalar) bool { diff --git a/scalar/json.go b/scalar/json.go index 1ab4a912cf..28d4362e84 100644 --- a/scalar/json.go +++ b/scalar/json.go @@ -18,7 +18,7 @@ func (s *JSON) IsValid() bool { return s.Valid } -func (s *JSON) DataType() arrow.DataType { +func (*JSON) DataType() arrow.DataType { return types.ExtensionTypes.JSON } @@ -74,16 +74,15 @@ func (s *JSON) Set(val any) error { case []byte: if value == nil { return nil - } else { - if string(value) == "" { - return nil - } + } + if string(value) == "" { + return nil + } - if !json.Valid(value) { - return &ValidationError{Type: types.ExtensionTypes.UUID, Msg: "invalid json byte array", Value: value} - } - s.Value = value + if !json.Valid(value) { + return &ValidationError{Type: types.ExtensionTypes.UUID, Msg: "invalid json byte array", Value: value} } + s.Value = value // Encode* methods are defined on *JSON. If JSON is passed directly then the // struct itself would be encoded instead of Bytes. This is clearly a footgun // so detect and return an error. See https://github.com/jackc/pgx/issues/350. diff --git a/scalar/mac.go b/scalar/mac.go index 9d3711d4ce..fb1e061fe8 100644 --- a/scalar/mac.go +++ b/scalar/mac.go @@ -16,7 +16,7 @@ func (s *Mac) IsValid() bool { return s.Valid } -func (s *Mac) DataType() arrow.DataType { +func (*Mac) DataType() arrow.DataType { return types.ExtensionTypes.Mac } @@ -57,15 +57,13 @@ func (s *Mac) Set(val any) error { case *net.HardwareAddr: if value == nil { return nil - } else { - return s.Set(*value) } + return s.Set(*value) case *string: if value == nil { return nil - } else { - return s.Set(*value) } + return s.Set(*value) default: if originalSrc, ok := underlyingPtrType(value); ok { return s.Set(originalSrc) diff --git a/scalar/scalar.go b/scalar/scalar.go index 244c5d4d74..4b665b66c7 100644 --- a/scalar/scalar.go +++ b/scalar/scalar.go @@ -58,15 +58,16 @@ func NewScalar(dt arrow.DataType) Scalar { case arrow.BOOL: return &Bool{} case arrow.EXTENSION: - if arrow.TypeEqual(dt, types.ExtensionTypes.UUID) { + switch { + case arrow.TypeEqual(dt, types.ExtensionTypes.UUID): return &UUID{} - } else if arrow.TypeEqual(dt, types.ExtensionTypes.JSON) { + case arrow.TypeEqual(dt, types.ExtensionTypes.JSON): return &JSON{} - } else if arrow.TypeEqual(dt, types.ExtensionTypes.Mac) { + case arrow.TypeEqual(dt, types.ExtensionTypes.Mac): return &Mac{} - } else if arrow.TypeEqual(dt, types.ExtensionTypes.Inet) { + case arrow.TypeEqual(dt, types.ExtensionTypes.Inet): return &Inet{} - } else { + default: panic("not implemented extension: " + dt.Name()) } case arrow.LIST: @@ -107,15 +108,16 @@ func AppendToBuilder(bldr array.Builder, s Scalar) { lb.AppendNull() } case arrow.EXTENSION: - if arrow.TypeEqual(s.DataType(), types.ExtensionTypes.UUID) { + switch { + case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.UUID): bldr.(*types.UUIDBuilder).Append(s.(*UUID).Value) - } else if 
arrow.TypeEqual(s.DataType(), types.ExtensionTypes.JSON) { + case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.JSON): bldr.(*types.JSONBuilder).Append(s.(*JSON).Value) - } else if arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Mac) { + case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Mac): bldr.(*types.MacBuilder).Append(s.(*Mac).Value) - } else if arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Inet) { + case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Inet): bldr.(*types.InetBuilder).Append(s.(*Inet).Value) - } else { + default: panic("not implemented extension: " + s.DataType().Name()) } default: diff --git a/scalar/string.go b/scalar/string.go index c5b9e1c1cb..0dc807d2b8 100644 --- a/scalar/string.go +++ b/scalar/string.go @@ -15,7 +15,7 @@ func (s *String) IsValid() bool { return s.Valid } -func (s *String) DataType() arrow.DataType { +func (*String) DataType() arrow.DataType { return arrow.BinaryTypes.String } @@ -72,6 +72,6 @@ type LargeString struct { _smallString } -func (s *LargeString) DataType() arrow.DataType { +func (*LargeString) DataType() arrow.DataType { return arrow.BinaryTypes.LargeString } diff --git a/scalar/uint.go b/scalar/uint.go index 77ceaece05..8a98e21bde 100644 --- a/scalar/uint.go +++ b/scalar/uint.go @@ -11,11 +11,11 @@ type Uint64 struct { Value uint64 } -func (n *Uint64) IsValid() bool { - return n.Valid +func (s *Uint64) IsValid() bool { + return s.Valid } -func (n *Uint64) DataType() arrow.DataType { +func (*Uint64) DataType() arrow.DataType { return arrow.PrimitiveTypes.Uint64 } @@ -37,9 +37,9 @@ func (s *Uint64) Equal(rhs Scalar) bool { return s.Valid == r.Valid && s.Value == r.Value } -func (n *Uint64) Set(val any) error { +func (s *Uint64) Set(val any) error { if val == nil { - n.Valid = false + s.Valid = false return nil } @@ -48,119 +48,119 @@ func (n *Uint64) Set(val any) error { if value < 0 { return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int8 less than 0", Value: value} } - n.Value = uint64(value) + s.Value = uint64(value) case int16: if value < 0 { return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int16 less than 0", Value: value} } - n.Value = uint64(value) + s.Value = uint64(value) case int32: if value < 0 { return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int32 less than 0", Value: value} } - n.Value = uint64(value) + s.Value = uint64(value) case int64: if value < 0 { return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int64 less than 0", Value: value} } - n.Value = uint64(value) + s.Value = uint64(value) case int: if value < 0 { return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "int less than 0", Value: value} } - n.Value = uint64(value) + s.Value = uint64(value) case uint8: - n.Value = uint64(value) + s.Value = uint64(value) case uint16: - n.Value = uint64(value) + s.Value = uint64(value) case uint32: - n.Value = uint64(value) + s.Value = uint64(value) case uint64: - n.Value = value + s.Value = value case uint: - n.Value = uint64(value) + s.Value = uint64(value) case float32: if value < 0 { return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "float32 less than 0", Value: value} } - n.Value = uint64(value) + s.Value = uint64(value) case float64: if value < 0 { return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: "float64 less than 0", Value: value} } - n.Value = uint64(value) + s.Value = uint64(value) case string: num, err := strconv.ParseUint(value, 10, 64) if err != nil { return &ValidationError{Type: 
arrow.PrimitiveTypes.Int8, Msg: "invalid string", Value: value} } - n.Value = num + s.Value = num case *int8: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *int16: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *int32: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *int64: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *int: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *uint8: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *uint16: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *uint32: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *uint64: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *uint: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *float32: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) case *float64: if value == nil { return nil } - return n.Set(*value) + return s.Set(*value) default: if originalSrc, ok := underlyingNumberType(value); ok { - return n.Set(originalSrc) + return s.Set(originalSrc) } return &ValidationError{Type: arrow.PrimitiveTypes.Uint64, Msg: noConversion, Value: value} } - n.Valid = true + s.Valid = true return nil } diff --git a/scalar/uuid.go b/scalar/uuid.go index adc6d45c56..64972c9dde 100644 --- a/scalar/uuid.go +++ b/scalar/uuid.go @@ -18,7 +18,7 @@ func (s *UUID) IsValid() bool { return s.Valid } -func (s *UUID) DataType() arrow.DataType { +func (*UUID) DataType() arrow.DataType { return types.ExtensionTypes.UUID } @@ -52,26 +52,24 @@ func (s *UUID) Set(src any) error { case [16]byte: s.Value = uuid.UUID(value) case []byte: - if value != nil { - if len(value) != 16 { - return &ValidationError{Type: types.ExtensionTypes.UUID, Msg: "[]byte must be 16 bytes to convert to UUID", Value: value} - } - copy(s.Value[:], value) - } else { + if value == nil { return nil } + if len(value) != 16 { + return &ValidationError{Type: types.ExtensionTypes.UUID, Msg: "[]byte must be 16 bytes to convert to UUID", Value: value} + } + copy(s.Value[:], value) case string: - uuid, err := parseUUID(value) + uuidVal, err := parseUUID(value) if err != nil { return err } - s.Value = uuid + s.Value = uuidVal case *string: if value == nil { return nil - } else { - return s.Set(*value) } + return s.Set(*value) default: if originalSrc, ok := underlyingUUIDType(src); ok { return s.Set(originalSrc) diff --git a/schema/resource.go b/schema/resource.go index cb4592114e..833e22fc45 100644 --- a/schema/resource.go +++ b/schema/resource.go @@ -77,6 +77,7 @@ func (r *Resource) GetValues() scalar.Vector { return r.data } +//nolint:revive func (r *Resource) CalculateCQID(deterministicCQID bool) error { if !deterministicCQID { return r.storeCQID(uuid.New()) diff --git a/serve/source_v2_test.go b/serve/source_v2_test.go index f2ce2cebe6..1fdc1a5ce0 100644 --- a/serve/source_v2_test.go +++ b/serve/source_v2_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" "net" "sync" @@ -30,7 +29,7 @@ type testExecutionClient struct{} var _ schema.ClientMeta = &testExecutionClient{} -var errTestExecutionClientErr = fmt.Errorf("error in newTestExecutionClientErr") +// var errTestExecutionClientErr = fmt.Errorf("error in newTestExecutionClientErr") func testTable(name string) *schema.Table { return 
&schema.Table{ @@ -58,10 +57,6 @@ func newTestExecutionClient(context.Context, zerolog.Logger, specs.Source, sourc return &testExecutionClient{}, nil } -func newTestExecutionClientErr(context.Context, zerolog.Logger, specs.Source, source.Options) (schema.ClientMeta, error) { - return nil, errTestExecutionClientErr -} - func bufSourceDialer(context.Context, string) (net.Conn, error) { testSourceListenerLock.Lock() defer testSourceListenerLock.Unlock() diff --git a/types/register.go b/types/register.go index 148d64b45d..ba44b34a4b 100644 --- a/types/register.go +++ b/types/register.go @@ -12,8 +12,5 @@ func RegisterAllExtensions() error { if err := arrow.RegisterExtensionType(&InetType{}); err != nil { return err } - if err := arrow.RegisterExtensionType(&MacType{}); err != nil { - return err - } - return nil + return arrow.RegisterExtensionType(&MacType{}) } diff --git a/types/uuid.go b/types/uuid.go index 5375dcb662..c85d55464f 100644 --- a/types/uuid.go +++ b/types/uuid.go @@ -201,7 +201,7 @@ func (*UUIDType) ExtensionName() string { return "uuid" } -func (e *UUIDType) String() string { +func (*UUIDType) String() string { return "uuid" } From 2734aefa46a07bfdc12e930d89bcd0ae9b411f75 Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Sun, 14 May 2023 19:42:07 +0300 Subject: [PATCH 04/14] minor fixes --- plugins/source/plugin.go | 20 +++-- plugins/source/plugin_test.go | 16 ++++ scalar/convert.go | 23 +++++ scalar/scalar.go | 4 + scalar/timestamp.go | 153 ++++++++++++++++++++++++++++++++++ scalar/timestamp_test.go | 51 ++++++++++++ serve/destination.go | 5 ++ serve/source_v2_test.go | 2 +- types/register.go | 13 +++ 9 files changed, 281 insertions(+), 6 deletions(-) create mode 100644 scalar/timestamp.go create mode 100644 scalar/timestamp_test.go diff --git a/plugins/source/plugin.go b/plugins/source/plugin.go index 4aefcd0ec7..a426f3f74a 100644 --- a/plugins/source/plugin.go +++ b/plugins/source/plugin.go @@ -70,6 +70,7 @@ type Plugin struct { unmanaged bool // titleTransformer allows the plugin to control how table names get turned into titles for generated documentation titleTransformer func(*schema.Table) string + syncTime *time.Time } const ( @@ -77,7 +78,7 @@ const ( ) // Add internal columns -func addInternalColumns(tables []*schema.Table) error { +func (p *Plugin) addInternalColumns(tables []*schema.Table) error { for _, table := range tables { if c := table.Column("_cq_id"); c != nil { return fmt.Errorf("table %s already has column _cq_id", table.Name) @@ -86,8 +87,17 @@ func addInternalColumns(tables []*schema.Table) error { if len(table.PrimaryKeys()) == 0 { cqID.CreationOptions.PrimaryKey = true } - table.Columns = append([]schema.Column{cqID, schema.CqParentIDColumn}, table.Columns...) - if err := addInternalColumns(table.Relations); err != nil { + cqSourceName := schema.CqSourceNameColumn + cqSyncTime := schema.CqSyncTimeColumn + cqSourceName.Resolver = func(_ context.Context, _ schema.ClientMeta, resource *schema.Resource, c schema.Column) error { + return resource.Set(c.Name, p.spec.Name) + } + cqSyncTime.Resolver = func(_ context.Context, _ schema.ClientMeta, resource *schema.Resource, c schema.Column) error { + return resource.Set(c.Name, p.syncTime) + } + + table.Columns = append([]schema.Column{schema.CqSourceNameColumn, schema.CqSyncTimeColumn, cqID, schema.CqParentIDColumn}, table.Columns...) 
+ if err := p.addInternalColumns(table.Relations); err != nil { return err } } @@ -152,7 +162,7 @@ func NewPlugin(name string, version string, tables []*schema.Table, newExecution panic(err) } if p.internalColumns { - if err := addInternalColumns(p.tables); err != nil { + if err := p.addInternalColumns(p.tables); err != nil { panic(err) } } @@ -261,7 +271,7 @@ func (p *Plugin) Init(ctx context.Context, spec specs.Source) error { return err } if p.internalColumns { - if err := addInternalColumns(tables); err != nil { + if err := p.addInternalColumns(tables); err != nil { return err } } diff --git a/plugins/source/plugin_test.go b/plugins/source/plugin_test.go index 8b3c84db82..acb0727de3 100644 --- a/plugins/source/plugin_test.go +++ b/plugins/source/plugin_test.go @@ -161,6 +161,8 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { + &scalar.String{}, + &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, @@ -212,11 +214,15 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { + &scalar.String{}, + &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, }, { + &scalar.String{}, + &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.Int64{Value: 3, Valid: true}, @@ -236,6 +242,8 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { + &scalar.String{}, + &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, @@ -257,6 +265,8 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { + &scalar.String{}, + &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, @@ -283,11 +293,15 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { + &scalar.String{}, + &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, }, { + &scalar.String{}, + &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.Int64{Value: 3, Valid: true}, @@ -308,6 +322,8 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { + &scalar.String{}, + &scalar.Timestamp{}, &scalar.UUID{Value: deterministicStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, diff --git a/scalar/convert.go b/scalar/convert.go index fbd12531e2..c40f0acdb2 100644 --- a/scalar/convert.go +++ b/scalar/convert.go @@ -2,6 +2,7 @@ package scalar import ( "reflect" + "time" ) // underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 @@ -122,6 +123,28 @@ func underlyingBytesType(val any) (any, bool) { return nil, false } +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val any) (any, bool) { + refVal := reflect.ValueOf(val) + + // nolint:gocritic,revive + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return nil, false +} + // underlyingUUIDType gets the underlying type that can be converted to 
[16]byte func underlyingUUIDType(val any) (any, bool) { refVal := reflect.ValueOf(val) diff --git a/scalar/scalar.go b/scalar/scalar.go index 4b665b66c7..b1f29bc56c 100644 --- a/scalar/scalar.go +++ b/scalar/scalar.go @@ -43,6 +43,8 @@ func (v Vector) Equal(r Vector) bool { func NewScalar(dt arrow.DataType) Scalar { switch dt.ID() { + case arrow.TIMESTAMP: + return &Timestamp{} case arrow.BINARY: return &Binary{} case arrow.STRING: @@ -97,6 +99,8 @@ func AppendToBuilder(bldr array.Builder, s Scalar) { bldr.(*array.Float64Builder).Append(s.(*Float64).Value) case arrow.BOOL: bldr.(*array.BooleanBuilder).Append(s.(*Bool).Value) + case arrow.TIMESTAMP: + bldr.(*array.TimestampBuilder).Append(arrow.Timestamp(s.(*Timestamp).Value.UnixMicro())) case arrow.LIST: lb := bldr.(*array.ListBuilder) if s.IsValid() { diff --git a/scalar/timestamp.go b/scalar/timestamp.go new file mode 100644 index 0000000000..fe93f7fbb1 --- /dev/null +++ b/scalar/timestamp.go @@ -0,0 +1,153 @@ +package scalar + +import ( + "encoding" + "fmt" + "math" + "time" + + "github.com/apache/arrow/go/v13/arrow" +) + +// const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +// const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +// const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" + +// this is the default format used by time.Time.String() +const defaultStringFormat = "2006-01-02 15:04:05.999999999 -0700 MST" + +// this is used by arrow string format (time is in UTC) +const arrowStringFormat = "2006-01-02 15:04:05.999999999" + +// const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( +// negativeInfinityMicrosecondOffset = -9223372036854775808 +// infinityMicrosecondOffset = 9223372036854775807 +) + +type Timestamp struct { + Valid bool + Value time.Time +} + +func (s *Timestamp) IsValid() bool { + return s.Valid +} + +func (*Timestamp) DataType() arrow.DataType { + return arrow.FixedWidthTypes.Timestamp_us +} + +func (s *Timestamp) Equal(rhs Scalar) bool { + if rhs == nil { + return false + } + r, ok := rhs.(*Timestamp) + if !ok { + return false + } + return s.Valid == r.Valid && s.Value.Equal(r.Value) +} + +func (s *Timestamp) String() string { + if !s.Valid { + return "(null)" + } + return s.Value.Format(time.RFC3339) +} + +func (s *Timestamp) Set(val any) error { + if val == nil { + return nil + } + + switch value := val.(type) { + case int: + if value < 0 { + return &ValidationError{Type: arrow.FixedWidthTypes.Timestamp_us, Msg: "negative timestamp"} + } + s.Value = time.Unix(int64(value), 0).UTC() + case int64: + if value < 0 { + return &ValidationError{Type: arrow.FixedWidthTypes.Timestamp_us, Msg: "negative timestamp"} + } + s.Value = time.Unix(value, 0).UTC() + case uint64: + if value > math.MaxInt64 { + return &ValidationError{Type: arrow.FixedWidthTypes.Timestamp_us, Msg: "uint64 bigger than MaxInt64", Value: value} + } + s.Value = time.Unix(int64(value), 0).UTC() + case time.Time: + s.Value = value.UTC() + case *time.Time: + if value == nil { + return nil + } + return s.Set(*value) + case string: + return s.DecodeText([]byte(value)) + case *string: + if value == nil { + return nil + } + return s.Set(*value) + default: + if originalSrc, ok := underlyingTimeType(val); ok { + return s.Set(originalSrc) + } + if value, ok := value.(encoding.TextMarshaler); ok { + text, err := value.MarshalText() + if err == nil { + return s.Set(string(text)) + } + // fall through to String() method + } + if value, ok := value.(fmt.Stringer); ok { + str := value.String() + 
return s.Set(str) + } + return &ValidationError{Type: arrow.FixedWidthTypes.Timestamp_us, Msg: noConversion, Value: value} + } + s.Valid = true + return nil +} + +func (s *Timestamp) DecodeText(src []byte) error { + if len(src) == 0 { + return nil + } + + sbuf := string(src) + // nolint:gocritic,revive + switch sbuf { + default: + var tim time.Time + var err error + + if len(sbuf) > len(defaultStringFormat)+1 && sbuf[len(defaultStringFormat)+1] == 'm' { + sbuf = sbuf[:len(defaultStringFormat)] + } + + // there is no good way of detecting format so we just try few of them + tim, err = time.Parse(time.RFC3339, sbuf) + if err == nil { + s.Value = tim.UTC() + s.Valid = true + return nil + } + tim, err = time.Parse(defaultStringFormat, sbuf) + if err == nil { + s.Value = tim.UTC() + s.Valid = true + return nil + } + tim, err = time.Parse(arrowStringFormat, sbuf) + if err == nil { + s.Value = tim.UTC() + s.Valid = true + return nil + } + return &ValidationError{Type: arrow.FixedWidthTypes.Timestamp_us, Msg: "cannot parse timestamp", Value: sbuf, Err: err} + } +} diff --git a/scalar/timestamp_test.go b/scalar/timestamp_test.go new file mode 100644 index 0000000000..79c0c98673 --- /dev/null +++ b/scalar/timestamp_test.go @@ -0,0 +1,51 @@ +package scalar + +import ( + "testing" + "time" +) + +type TimestampSt struct { + time.Time +} + +func TestTimestampSet(t *testing.T) { + type _time time.Time + + timeInstance := time.Date(2105, 7, 23, 22, 23, 37, 750076110, time.UTC) + timeRFC3339NanoBytes, _ := timeInstance.MarshalText() + + successfulTests := []struct { + source any + result Timestamp + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: Timestamp{Value: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: Timestamp{Value: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: Timestamp{Value: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: Timestamp{Value: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: Timestamp{Value: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Valid: true}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: Timestamp{Value: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: int(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local).Unix()), result: Timestamp{Value: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: uint64(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local).Unix()), result: Timestamp{Value: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local).Unix(), result: Timestamp{Value: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: Timestamp{Value: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, + {source: string(timeRFC3339NanoBytes), result: Timestamp{Value: time.Date(2105, 7, 23, 22, 23, 37, 750076110, time.UTC), Valid: true}}, + {source: "2150-10-15 07:25:09.75007611 +0000 UTC", result: Timestamp{Value: time.Date(2150, 10, 15, 7, 25, 9, 750076110, time.UTC), Valid: true}}, + {source: timeInstance.String(), result: Timestamp{Value: time.Date(2105, 7, 23, 22, 23, 37, 750076110, time.UTC), Valid: true}}, + {source: 
TimestampSt{timeInstance}, result: Timestamp{Value: time.Date(2105, 7, 23, 22, 23, 37, 750076110, time.UTC), Valid: true}},
+ {source: "", result: Timestamp{}},
+ }
+
+ for i, tt := range successfulTests {
+ var r Timestamp
+ err := r.Set(tt.source)
+ if err != nil {
+ t.Errorf("%d: %v", i, err)
+ continue
+ }
+
+ if !r.Equal(&tt.result) {
+ t.Errorf("%d: %v != %v", i, r, tt.result)
+ }
+ }
+}
diff --git a/serve/destination.go b/serve/destination.go
index 72a3e4770a..cba93b90a5 100644
--- a/serve/destination.go
+++ b/serve/destination.go
@@ -155,6 +155,11 @@ func newCmdDestinationServe(serve *destinationServe) *cobra.Command {
 if err := types.RegisterAllExtensions(); err != nil {
 return err
 }
+ defer func() {
+ if err := types.UnregisterAllExtensions(); err != nil {
+ logger.Error().Err(err).Msg("Failed to unregister extensions")
+ }
+ }()
 ctx := cmd.Context()
 c := make(chan os.Signal, 1)
diff --git a/serve/source_v2_test.go b/serve/source_v2_test.go
index 1fdc1a5ce0..8a541611e9 100644
--- a/serve/source_v2_test.go
+++ b/serve/source_v2_test.go
@@ -199,7 +199,7 @@ func TestSourceSuccess(t *testing.T) {
 if tableName != "test_table" {
 t.Fatalf("Expected resource with table name test_table. got: %s", tableName)
 }
- if len(resource.Columns()) != 3 {
- t.Fatalf("Expected resource with data length 3 but got %d", len(resource.Columns()))
+ if len(resource.Columns()) != 5 {
+ t.Fatalf("Expected resource with 5 columns but got %d", len(resource.Columns()))
 }
 totalResources++
diff --git a/types/register.go b/types/register.go
index ba44b34a4b..709c955824 100644
--- a/types/register.go
+++ b/types/register.go
@@ -14,3 +14,16 @@ func RegisterAllExtensions() error {
 }
 return arrow.RegisterExtensionType(&MacType{})
 }
+
+func UnregisterAllExtensions() error {
+ if err := arrow.UnregisterExtensionType(ExtensionTypes.Mac.ExtensionName()); err != nil {
+ return err
+ }
+ if err := arrow.UnregisterExtensionType(ExtensionTypes.Inet.ExtensionName()); err != nil {
+ return err
+ }
+ if err := arrow.UnregisterExtensionType(ExtensionTypes.JSON.ExtensionName()); err != nil {
+ return err
+ }
+ return arrow.UnregisterExtensionType(ExtensionTypes.UUID.ExtensionName())
+}

From 093f04fbf54ad4f505998c2018f63e87a42fe70f Mon Sep 17 00:00:00 2001
From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com>
Date: Mon, 15 May 2023 14:46:41 +0300
Subject: [PATCH 05/14] rebase

---
 internal/servers/destination/v1/destinations.go | 2 +-
 plugins/source/docs.go | 4 ++--
 plugins/source/plugin.go | 2 +-
 scalar/mac.go | 4 ++--
 scalar/scalar.go | 6 +++---
 schema/resource.go | 2 +-
 types/register.go | 4 ++--
 7 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/internal/servers/destination/v1/destinations.go b/internal/servers/destination/v1/destinations.go
index c3595a942b..447c03b596 100644
--- a/internal/servers/destination/v1/destinations.go
+++ b/internal/servers/destination/v1/destinations.go
@@ -146,7 +146,7 @@ func (s *Server) Write(msg pb.Destination_WriteServer) error {
 func setCQIDAsPrimaryKeysForTables(tables schema.Tables) {
 for _, table := range tables {
 for i, col := range table.Columns {
- table.Columns[i].CreationOptions.PrimaryKey = col.Name == schema.CqIDColumn.Name
+ table.Columns[i].PrimaryKey = col.Name == schema.CqIDColumn.Name
 }
 setCQIDAsPrimaryKeysForTables(table.Relations)
 }
diff --git a/plugins/source/docs.go b/plugins/source/docs.go
index b9b10d11f4..64689260e5 100644
--- a/plugins/source/docs.go
+++ b/plugins/source/docs.go
@@ -144,8 +144,8 @@ func (p *Plugin) jsonifyTables(tables schema.Tables) []jsonTable {
 jsonColumns[c] = jsonColumn{
 Name: col.Name,
 Type: 
col.Type.String(), - IsPrimaryKey: col.CreationOptions.PrimaryKey, - IsIncrementalKey: col.CreationOptions.IncrementalKey, + IsPrimaryKey: col.PrimaryKey, + IsIncrementalKey: col.IncrementalKey, } } jsonTables[i] = jsonTable{ diff --git a/plugins/source/plugin.go b/plugins/source/plugin.go index a426f3f74a..407621fc28 100644 --- a/plugins/source/plugin.go +++ b/plugins/source/plugin.go @@ -85,7 +85,7 @@ func (p *Plugin) addInternalColumns(tables []*schema.Table) error { } cqID := schema.CqIDColumn if len(table.PrimaryKeys()) == 0 { - cqID.CreationOptions.PrimaryKey = true + cqID.PrimaryKey = true } cqSourceName := schema.CqSourceNameColumn cqSyncTime := schema.CqSyncTimeColumn diff --git a/scalar/mac.go b/scalar/mac.go index fb1e061fe8..51821edbad 100644 --- a/scalar/mac.go +++ b/scalar/mac.go @@ -17,7 +17,7 @@ func (s *Mac) IsValid() bool { } func (*Mac) DataType() arrow.DataType { - return types.ExtensionTypes.Mac + return types.ExtensionTypes.MAC } func (s *Mac) String() string { @@ -68,7 +68,7 @@ func (s *Mac) Set(val any) error { if originalSrc, ok := underlyingPtrType(value); ok { return s.Set(originalSrc) } - return &ValidationError{Type: types.ExtensionTypes.Mac, Msg: noConversion, Value: value} + return &ValidationError{Type: types.ExtensionTypes.MAC, Msg: noConversion, Value: value} } s.Valid = true diff --git a/scalar/scalar.go b/scalar/scalar.go index b1f29bc56c..b1d8a80d78 100644 --- a/scalar/scalar.go +++ b/scalar/scalar.go @@ -65,7 +65,7 @@ func NewScalar(dt arrow.DataType) Scalar { return &UUID{} case arrow.TypeEqual(dt, types.ExtensionTypes.JSON): return &JSON{} - case arrow.TypeEqual(dt, types.ExtensionTypes.Mac): + case arrow.TypeEqual(dt, types.ExtensionTypes.MAC): return &Mac{} case arrow.TypeEqual(dt, types.ExtensionTypes.Inet): return &Inet{} @@ -117,8 +117,8 @@ func AppendToBuilder(bldr array.Builder, s Scalar) { bldr.(*types.UUIDBuilder).Append(s.(*UUID).Value) case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.JSON): bldr.(*types.JSONBuilder).Append(s.(*JSON).Value) - case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Mac): - bldr.(*types.MacBuilder).Append(s.(*Mac).Value) + case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.MAC): + bldr.(*types.MACBuilder).Append(s.(*Mac).Value) case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Inet): bldr.(*types.InetBuilder).Append(s.(*Inet).Value) default: diff --git a/schema/resource.go b/schema/resource.go index 833e22fc45..fbbaf6667b 100644 --- a/schema/resource.go +++ b/schema/resource.go @@ -108,7 +108,7 @@ func (r *Resource) storeCQID(value uuid.UUID) error { func (r *Resource) Validate() error { var missingPks []string for i, c := range r.Table.Columns { - if c.CreationOptions.PrimaryKey { + if c.PrimaryKey { if !r.data[i].IsValid() { missingPks = append(missingPks, c.Name) } diff --git a/types/register.go b/types/register.go index 709c955824..65d9822281 100644 --- a/types/register.go +++ b/types/register.go @@ -12,11 +12,11 @@ func RegisterAllExtensions() error { if err := arrow.RegisterExtensionType(&InetType{}); err != nil { return err } - return arrow.RegisterExtensionType(&MacType{}) + return arrow.RegisterExtensionType(&MACType{}) } func UnregisterAllExtensions() error { - if err := arrow.UnregisterExtensionType(ExtensionTypes.Mac.ExtensionName()); err != nil { + if err := arrow.UnregisterExtensionType(ExtensionTypes.MAC.ExtensionName()); err != nil { return err } if err := arrow.UnregisterExtensionType(ExtensionTypes.Inet.ExtensionName()); err != nil { From 
09f9478d7598d04b1b3ef7842ba327373611f14e Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Mon, 15 May 2023 14:48:32 +0300 Subject: [PATCH 06/14] fix lints --- plugins/source/docs_test.go | 25 ++++++++++++------------ plugins/source/plugin_test.go | 8 +++----- plugins/source/templates/table.md.go.tpl | 2 +- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/plugins/source/docs_test.go b/plugins/source/docs_test.go index 7097668dfe..f26e4399cf 100644 --- a/plugins/source/docs_test.go +++ b/plugins/source/docs_test.go @@ -23,14 +23,14 @@ var testTables = []*schema.Table{ Type: arrow.PrimitiveTypes.Int64, }, { - Name: "id_col", - Type: arrow.PrimitiveTypes.Int64, - CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true}, + Name: "id_col", + Type: arrow.PrimitiveTypes.Int64, + PrimaryKey: true, }, { - Name: "id_col2", - Type: arrow.PrimitiveTypes.Int64, - CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true}, + Name: "id_col2", + Type: arrow.PrimitiveTypes.Int64, + PrimaryKey: true, }, }, Relations: []*schema.Table{ @@ -88,14 +88,15 @@ var testTables = []*schema.Table{ Type: arrow.PrimitiveTypes.Int64, }, { - Name: "id_col", - Type: arrow.PrimitiveTypes.Int64, - CreationOptions: schema.ColumnCreationOptions{PrimaryKey: true, IncrementalKey: true}, + Name: "id_col", + Type: arrow.PrimitiveTypes.Int64, + PrimaryKey: true, + IncrementalKey: true, }, { - Name: "id_col2", - Type: arrow.PrimitiveTypes.Int64, - CreationOptions: schema.ColumnCreationOptions{IncrementalKey: true}, + Name: "id_col2", + Type: arrow.PrimitiveTypes.Int64, + IncrementalKey: true, }, }, }, diff --git a/plugins/source/plugin_test.go b/plugins/source/plugin_test.go index acb0727de3..68823ae449 100644 --- a/plugins/source/plugin_test.go +++ b/plugins/source/plugin_test.go @@ -61,11 +61,9 @@ func testTableSuccessWithPK() *schema.Table { Resolver: testResolverSuccess, Columns: []schema.Column{ { - Name: "test_column", - Type: arrow.PrimitiveTypes.Int64, - CreationOptions: schema.ColumnCreationOptions{ - PrimaryKey: true, - }, + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + PrimaryKey: true, }, }, } diff --git a/plugins/source/templates/table.md.go.tpl b/plugins/source/templates/table.md.go.tpl index 45bee702cd..45cc64d0b0 100644 --- a/plugins/source/templates/table.md.go.tpl +++ b/plugins/source/templates/table.md.go.tpl @@ -40,5 +40,5 @@ The following tables depend on {{.Name}}: | Name | Type | | ------------- | ------------- | {{- range $column := $.Columns }} -|{{$column.Name}}{{if $column.CreationOptions.PrimaryKey}} (PK){{end}}{{if $column.CreationOptions.IncrementalKey}} (Incremental Key){{end}}|{{$column.Type | formatType}}| +|{{$column.Name}}{{if $column.PrimaryKey}} (PK){{end}}{{if $column.IncrementalKey}} (Incremental Key){{end}}|{{$column.Type | formatType}}| {{- end }} \ No newline at end of file From c479525b4b0ed4c2ee2ebf5cd9236d72bb13700c Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Mon, 15 May 2023 17:49:20 +0300 Subject: [PATCH 07/14] fix json on sdk side --- scalar/scalar.go | 2 +- types/json.go | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/scalar/scalar.go b/scalar/scalar.go index b1d8a80d78..57a933c6c2 100644 --- a/scalar/scalar.go +++ b/scalar/scalar.go @@ -116,7 +116,7 @@ func AppendToBuilder(bldr array.Builder, s Scalar) { case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.UUID): bldr.(*types.UUIDBuilder).Append(s.(*UUID).Value) 
case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.JSON): - bldr.(*types.JSONBuilder).Append(s.(*JSON).Value) + bldr.(*types.JSONBuilder).AppendBytes(s.(*JSON).Value) case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.MAC): bldr.(*types.MACBuilder).Append(s.(*Mac).Value) case arrow.TypeEqual(s.DataType(), types.ExtensionTypes.Inet): diff --git a/types/json.go b/types/json.go index 28f4cca87d..457ba273a7 100644 --- a/types/json.go +++ b/types/json.go @@ -20,6 +20,15 @@ func NewJSONBuilder(builder *array.ExtensionBuilder) *JSONBuilder { return &JSONBuilder{ExtensionBuilder: builder} } +func (b *JSONBuilder) AppendBytes(v []byte) { + if v == nil { + b.AppendNull() + return + } + + b.ExtensionBuilder.Builder.(*array.BinaryBuilder).Append(v) +} + func (b *JSONBuilder) Append(v any) { if v == nil { b.AppendNull() From a37d929f83fb67abbf852f37b0f223c9451e0381 Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Mon, 15 May 2023 18:39:47 +0300 Subject: [PATCH 08/14] fix: source_name --- go.sum | 4 ++-- internal/servers/source/v2/source.go | 8 ++++---- plugins/source/benchmark_test.go | 1 + plugins/source/plugin.go | 4 ++-- plugins/source/plugin_test.go | 18 ++++++++++-------- plugins/source/testing.go | 3 ++- scalar/scalar.go | 4 ++++ 7 files changed, 25 insertions(+), 17 deletions(-) diff --git a/go.sum b/go.sum index a6bbc31786..8f1a67186c 100644 --- a/go.sum +++ b/go.sum @@ -47,8 +47,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudquery/arrow/go/v13 v13.0.0-20230509053643-898a79b1d3c8 h1:CmgLSEGQNLHpUQ5cU4L4aF7cuJZRnc1toIIWqC1gmPg= github.com/cloudquery/arrow/go/v13 v13.0.0-20230509053643-898a79b1d3c8/go.mod h1:/XatdE3kDIBqZKhZ7OBUHwP2jaASDFZHqF4puOWM8po= -github.com/cloudquery/plugin-pb-go v1.0.6 h1:jWLXFUGgobO28rBERuipSUbr6Rqoth4nzGZ/XEQD86w= -github.com/cloudquery/plugin-pb-go v1.0.6/go.mod h1:vAGA27psem7ZZNAY4a3S9TKuA/JDQWstjKcHPJX91Mc= +github.com/cloudquery/plugin-pb-go v1.0.8 h1:wn3GXhcNItcP+6wUUZuzUFbvdL59liKBO37/izMi+FQ= +github.com/cloudquery/plugin-pb-go v1.0.8/go.mod h1:vAGA27psem7ZZNAY4a3S9TKuA/JDQWstjKcHPJX91Mc= github.com/cloudquery/plugin-sdk/v2 v2.7.0 h1:hRXsdEiaOxJtsn/wZMFQC9/jPfU1MeMK3KF+gPGqm7U= github.com/cloudquery/plugin-sdk/v2 v2.7.0/go.mod h1:pAX6ojIW99b/Vg4CkhnsGkRIzNaVEceYMR+Bdit73ug= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= diff --git a/internal/servers/source/v2/source.go b/internal/servers/source/v2/source.go index 1909b85b78..a7a8f2dbba 100644 --- a/internal/servers/source/v2/source.go +++ b/internal/servers/source/v2/source.go @@ -72,21 +72,21 @@ func (s *Server) Init(ctx context.Context, req *pb.Init_Request) (*pb.Init_Respo if err := dec.Decode(&spec); err != nil { return nil, status.Errorf(codes.InvalidArgument, "failed to decode spec: %v", err) } - + if err := s.Plugin.Init(ctx, spec); err != nil { return nil, status.Errorf(codes.Internal, "failed to init plugin: %v", err) } return &pb.Init_Response{}, nil } -func (s *Server) Sync(_ *pb.Sync_Request, stream pb.Source_SyncServer) error { +func (s *Server) Sync(req *pb.Sync_Request, stream pb.Source_SyncServer) error { resources := make(chan *schema.Resource) var syncErr error ctx := stream.Context() - + go func() { defer close(resources) - err := s.Plugin.Sync(ctx, resources) + err := s.Plugin.Sync(ctx, req.SyncTime.AsTime(), resources) if err 
!= nil { syncErr = fmt.Errorf("failed to sync resources: %w", err) } diff --git a/plugins/source/benchmark_test.go b/plugins/source/benchmark_test.go index eb81e31b9f..71ccdc929d 100644 --- a/plugins/source/benchmark_test.go +++ b/plugins/source/benchmark_test.go @@ -204,6 +204,7 @@ func (s *Benchmark) Run() { g.Go(func() error { defer close(resources) return s.plugin.Sync(ctx, + time.Now(), resources) }) s.b.StartTimer() diff --git a/plugins/source/plugin.go b/plugins/source/plugin.go index 407621fc28..8da91d3aa3 100644 --- a/plugins/source/plugin.go +++ b/plugins/source/plugin.go @@ -96,7 +96,7 @@ func (p *Plugin) addInternalColumns(tables []*schema.Table) error { return resource.Set(c.Name, p.syncTime) } - table.Columns = append([]schema.Column{schema.CqSourceNameColumn, schema.CqSyncTimeColumn, cqID, schema.CqParentIDColumn}, table.Columns...) + table.Columns = append([]schema.Column{cqSourceName, cqSyncTime, cqID, schema.CqParentIDColumn}, table.Columns...) if err := p.addInternalColumns(table.Relations); err != nil { return err } @@ -294,7 +294,7 @@ func (p *Plugin) Init(ctx context.Context, spec specs.Source) error { } // Sync is syncing data from the requested tables in spec to the given channel -func (p *Plugin) Sync(ctx context.Context, res chan<- *schema.Resource) error { +func (p *Plugin) Sync(ctx context.Context, syncTime time.Time, res chan<- *schema.Resource) error { if !p.mu.TryLock() { return fmt.Errorf("plugin already in use") } diff --git a/plugins/source/plugin_test.go b/plugins/source/plugin_test.go index 68823ae449..0f3547813a 100644 --- a/plugins/source/plugin_test.go +++ b/plugins/source/plugin_test.go @@ -3,6 +3,7 @@ package source import ( "context" "testing" + "time" "github.com/apache/arrow/go/v13/arrow" "github.com/cloudquery/plugin-pb-go/specs" @@ -159,7 +160,7 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{}, + &scalar.String{Value:"testSource" ,Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, @@ -212,14 +213,14 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{}, + &scalar.String{Value:"testSource" ,Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, }, { - &scalar.String{}, + &scalar.String{Value:"testSource" ,Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, @@ -240,7 +241,7 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{}, + &scalar.String{Value:"testSource" ,Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, @@ -263,7 +264,7 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{}, + &scalar.String{Value:"testSource" ,Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, @@ -291,14 +292,14 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{}, + &scalar.String{Value:"testSource" ,Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, }, { - &scalar.String{}, + &scalar.String{Value:"testSource" ,Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, @@ -320,7 +321,7 @@ var syncTestCases = 
[]syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{}, + &scalar.String{Value:"testSource" ,Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: deterministicStableUUID, Valid: true}, &scalar.UUID{}, @@ -385,6 +386,7 @@ func testSyncTable(t *testing.T, tc syncTestCase, scheduler specs.Scheduler, det g.Go(func() error { defer close(resources) return plugin.Sync(ctx, + time.Now(), resources) }) diff --git a/plugins/source/testing.go b/plugins/source/testing.go index 0f86081ec8..161778bca9 100644 --- a/plugins/source/testing.go +++ b/plugins/source/testing.go @@ -3,6 +3,7 @@ package source import ( "context" "testing" + "time" "github.com/cloudquery/plugin-pb-go/specs" "github.com/cloudquery/plugin-sdk/v3/schema" @@ -33,7 +34,7 @@ func TestPluginSync(t *testing.T, plugin *Plugin, spec specs.Source, opts ...Tes go func() { defer close(resourcesChannel) - syncErr = plugin.Sync(context.Background(), resourcesChannel) + syncErr = plugin.Sync(context.Background(), time.Now(), resourcesChannel) }() syncedResources := make([]*schema.Resource, 0) diff --git a/scalar/scalar.go b/scalar/scalar.go index 57a933c6c2..4bc7778aba 100644 --- a/scalar/scalar.go +++ b/scalar/scalar.go @@ -82,6 +82,10 @@ func NewScalar(dt arrow.DataType) Scalar { } func AppendToBuilder(bldr array.Builder, s Scalar) { + if !s.IsValid() { + bldr.AppendNull() + return + } switch s.DataType().ID() { case arrow.BINARY: bldr.(*array.BinaryBuilder).Append(s.(*Binary).Value) From 4e585159d62c58a93dc09e66f2e433cf4c554213 Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Mon, 15 May 2023 18:49:06 +0300 Subject: [PATCH 09/14] fix more stuff --- plugins/source/plugin.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/source/plugin.go b/plugins/source/plugin.go index 8da91d3aa3..5a0363af1e 100644 --- a/plugins/source/plugin.go +++ b/plugins/source/plugin.go @@ -70,7 +70,7 @@ type Plugin struct { unmanaged bool // titleTransformer allows the plugin to control how table names get turned into titles for generated documentation titleTransformer func(*schema.Table) string - syncTime *time.Time + syncTime time.Time } const ( @@ -299,7 +299,7 @@ func (p *Plugin) Sync(ctx context.Context, syncTime time.Time, res chan<- *schem return fmt.Errorf("plugin already in use") } defer p.mu.Unlock() - + p.syncTime = syncTime if p.client == nil { var err error p.client, err = p.newExecutionClient(ctx, p.logger, p.spec, Options{Backend: p.backend}) From 0a83c354f7dcbe588772944f41cc9687acc63c0a Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Mon, 15 May 2023 18:59:22 +0300 Subject: [PATCH 10/14] fix fmt --- internal/servers/source/v2/source.go | 4 ++-- plugins/source/plugin_test.go | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/servers/source/v2/source.go b/internal/servers/source/v2/source.go index a7a8f2dbba..a010fefef3 100644 --- a/internal/servers/source/v2/source.go +++ b/internal/servers/source/v2/source.go @@ -72,7 +72,7 @@ func (s *Server) Init(ctx context.Context, req *pb.Init_Request) (*pb.Init_Respo if err := dec.Decode(&spec); err != nil { return nil, status.Errorf(codes.InvalidArgument, "failed to decode spec: %v", err) } - + if err := s.Plugin.Init(ctx, spec); err != nil { return nil, status.Errorf(codes.Internal, "failed to init plugin: %v", err) } @@ -83,7 +83,7 @@ func (s *Server) Sync(req *pb.Sync_Request, stream pb.Source_SyncServer) error { resources := 
make(chan *schema.Resource) var syncErr error ctx := stream.Context() - + go func() { defer close(resources) err := s.Plugin.Sync(ctx, req.SyncTime.AsTime(), resources) diff --git a/plugins/source/plugin_test.go b/plugins/source/plugin_test.go index 0f3547813a..94d2473ae6 100644 --- a/plugins/source/plugin_test.go +++ b/plugins/source/plugin_test.go @@ -160,7 +160,7 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{Value:"testSource" ,Valid: true}, + &scalar.String{Value: "testSource", Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, @@ -213,14 +213,14 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{Value:"testSource" ,Valid: true}, + &scalar.String{Value: "testSource", Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, }, { - &scalar.String{Value:"testSource" ,Valid: true}, + &scalar.String{Value: "testSource", Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, @@ -241,7 +241,7 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{Value:"testSource" ,Valid: true}, + &scalar.String{Value: "testSource", Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, @@ -264,7 +264,7 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{Value:"testSource" ,Valid: true}, + &scalar.String{Value: "testSource", Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, @@ -292,14 +292,14 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{Value:"testSource" ,Valid: true}, + &scalar.String{Value: "testSource", Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, }, { - &scalar.String{Value:"testSource" ,Valid: true}, + &scalar.String{Value: "testSource", Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, @@ -321,7 +321,7 @@ var syncTestCases = []syncTestCase{ }, data: []scalar.Vector{ { - &scalar.String{Value:"testSource" ,Valid: true}, + &scalar.String{Value: "testSource", Valid: true}, &scalar.Timestamp{}, &scalar.UUID{Value: deterministicStableUUID, Valid: true}, &scalar.UUID{}, From 4f1e3378da789f360902acf4b174ed0c7e6eabce Mon Sep 17 00:00:00 2001 From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com> Date: Mon, 15 May 2023 19:05:26 +0300 Subject: [PATCH 11/14] fix tests --- plugins/source/plugin_test.go | 20 ++++++++++--------- scalar/float.go | 12 ------------ scalar/float_test.go | 37 +++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/plugins/source/plugin_test.go b/plugins/source/plugin_test.go index 94d2473ae6..0c288e4b5c 100644 --- a/plugins/source/plugin_test.go +++ b/plugins/source/plugin_test.go @@ -24,6 +24,8 @@ var _ schema.ClientMeta = &testExecutionClient{} var deterministicStableUUID = uuid.MustParse("c25355aab52c5b70a4e0c9991f5a3b87") var randomStableUUID = uuid.MustParse("00000000000040008000000000000000") +var testSyncTime = time.Now() + func testResolverSuccess(_ context.Context, _ schema.ClientMeta, _ *schema.Resource, res chan<- any) error { res <- map[string]any{ 
"TestColumn": 3, @@ -161,7 +163,7 @@ var syncTestCases = []syncTestCase{ data: []scalar.Vector{ { &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{}, + &scalar.Timestamp{Value: testSyncTime, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, @@ -214,14 +216,14 @@ var syncTestCases = []syncTestCase{ data: []scalar.Vector{ { &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{}, + &scalar.Timestamp{Value: testSyncTime, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, }, { &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{}, + &scalar.Timestamp{Value: testSyncTime, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.Int64{Value: 3, Valid: true}, @@ -242,7 +244,7 @@ var syncTestCases = []syncTestCase{ data: []scalar.Vector{ { &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{}, + &scalar.Timestamp{Value: testSyncTime, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, @@ -265,7 +267,7 @@ var syncTestCases = []syncTestCase{ data: []scalar.Vector{ { &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{}, + &scalar.Timestamp{Value: testSyncTime, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, @@ -293,14 +295,14 @@ var syncTestCases = []syncTestCase{ data: []scalar.Vector{ { &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{}, + &scalar.Timestamp{Value: testSyncTime, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, }, { &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{}, + &scalar.Timestamp{Value: testSyncTime, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.UUID{Value: randomStableUUID, Valid: true}, &scalar.Int64{Value: 3, Valid: true}, @@ -322,7 +324,7 @@ var syncTestCases = []syncTestCase{ data: []scalar.Vector{ { &scalar.String{Value: "testSource", Valid: true}, - &scalar.Timestamp{}, + &scalar.Timestamp{Value: testSyncTime, Valid: true}, &scalar.UUID{Value: deterministicStableUUID, Valid: true}, &scalar.UUID{}, &scalar.Int64{Value: 3, Valid: true}, @@ -386,7 +388,7 @@ func testSyncTable(t *testing.T, tc syncTestCase, scheduler specs.Scheduler, det g.Go(func() error { defer close(resources) return plugin.Sync(ctx, - time.Now(), + testSyncTime, resources) }) diff --git a/scalar/float.go b/scalar/float.go index b413e374ec..820c157245 100644 --- a/scalar/float.go +++ b/scalar/float.go @@ -48,14 +48,8 @@ func (s *Float32) Set(val any) error { case int8: s.Value = float32(value) case int16: - if value > math.MaxInt8 { - return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: "int16 bigger than MaxInt8", Value: value} - } s.Value = float32(value) case int32: - if value > math.MaxInt8 { - return &ValidationError{Type: arrow.PrimitiveTypes.Float32, Msg: "int32 bigger than MaxInt8", Value: value} - } s.Value = float32(value) case int64: if value > math.MaxInt32 { @@ -63,14 +57,8 @@ func (s *Float32) Set(val any) error { } s.Value = float32(value) case uint8: - if value > math.MaxInt8 { - return &ValidationError{Type: arrow.PrimitiveTypes.Int8, Msg: "uint8 bigger than MaxInt8", Value: value} - } s.Value = 
float32(value)
 case uint16:
- if value > math.MaxInt8 {
- return &ValidationError{Type: arrow.PrimitiveTypes.Int8, Msg: "uint16 bigger than MaxInt8", Value: value}
- }
 s.Value = float32(value)
 case uint32:
 if value > math.MaxInt32 {
diff --git a/scalar/float_test.go b/scalar/float_test.go
index 84c55c27f5..89299013e5 100644
--- a/scalar/float_test.go
+++ b/scalar/float_test.go
@@ -2,6 +2,43 @@ package scalar

 import "testing"

+func TestFloat32Set(t *testing.T) {
+ successfulTests := []struct {
+ source any
+ result Float32
+ }{
+ {source: float32(1), result: Float32{Value: 1, Valid: true}},
+ {source: float64(1), result: Float32{Value: 1, Valid: true}},
+ {source: int8(1), result: Float32{Value: 1, Valid: true}},
+ {source: int16(1), result: Float32{Value: 1, Valid: true}},
+ {source: int32(1), result: Float32{Value: 1, Valid: true}},
+ {source: int64(1), result: Float32{Value: 1, Valid: true}},
+ {source: int8(-1), result: Float32{Value: -1, Valid: true}},
+ {source: int16(-1), result: Float32{Value: -1, Valid: true}},
+ {source: int32(-1), result: Float32{Value: -1, Valid: true}},
+ {source: int64(-1), result: Float32{Value: -1, Valid: true}},
+ {source: uint8(1), result: Float32{Value: 1, Valid: true}},
+ {source: uint16(1), result: Float32{Value: 1, Valid: true}},
+ {source: uint32(1), result: Float32{Value: 1, Valid: true}},
+ {source: uint64(1), result: Float32{Value: 1, Valid: true}},
+ {source: "1", result: Float32{Value: 1, Valid: true}},
+ {source: _int8(1), result: Float32{Value: 1, Valid: true}},
+ }
+
+ for i, tt := range successfulTests {
+ var r Float32
+ err := r.Set(tt.source)
+ if err != nil {
+ t.Errorf("%d: %v", i, err)
+ }
+
+ if !r.Equal(&tt.result) {
+ t.Errorf("%d: %v != %v", i, r, tt.result)
+ }
+ }
+}
+
+
 func TestFloat64Set(t *testing.T) {
 successfulTests := []struct {
 source any
From 099728d926f827a08cc7023144bed093447255dc Mon Sep 17 00:00:00 2001
From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com>
Date: Mon, 15 May 2023 20:46:06 +0300
Subject: [PATCH 12/14] fix transformer

---
 transformers/struct.go | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/transformers/struct.go b/transformers/struct.go
index 45de643fa0..f8ae936948 100644
--- a/transformers/struct.go
+++ b/transformers/struct.go
@@ -338,6 +338,10 @@ func defaultGoTypeToSchemaType(v reflect.Type) (arrow.DataType, error) {
 if err != nil {
 return nil, err
 }
+ // if it's already JSON then we don't want to create a list of JSON
+ if arrow.TypeEqual(elemValueType, types.ExtensionTypes.JSON) {
+ return elemValueType, nil
+ }
 return arrow.ListOf(elemValueType), nil

 default:
From 9e1a40968f9a32c1400322147bfa745aa940f1be Mon Sep 17 00:00:00 2001
From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com>
Date: Mon, 15 May 2023 20:46:24 +0300
Subject: [PATCH 13/14] fmt fix

---
 scalar/float_test.go | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scalar/float_test.go b/scalar/float_test.go
index 89299013e5..ee260d51e4 100644
--- a/scalar/float_test.go
+++ b/scalar/float_test.go
@@ -38,7 +38,6 @@ func TestFloat32Set(t *testing.T) {
 }
 }
-
 func TestFloat64Set(t *testing.T) {
 successfulTests := []struct {
 source any
From 7a337109cb96af94baaca52cd109bb27b24ff4ab Mon Sep 17 00:00:00 2001
From: Yevgeny Pats <16490766+yevgenypats@users.noreply.github.com>
Date: Tue, 16 May 2023 00:30:00 +0300
Subject: [PATCH 14/14] comment some stuff out until another iteration

---
 plugins/destination/plugin_testing_overwrite.go | 2 +-
 schema/testdata.go | 15 ++++++++-------
 schema/testdata_test.go | 6 +++---
 3 
files changed, 12 insertions(+), 11 deletions(-) diff --git a/plugins/destination/plugin_testing_overwrite.go b/plugins/destination/plugin_testing_overwrite.go index 72260d546a..29c1bdd7b5 100644 --- a/plugins/destination/plugin_testing_overwrite.go +++ b/plugins/destination/plugin_testing_overwrite.go @@ -67,7 +67,7 @@ func (*PluginTestSuite) destinationPluginTestWriteOverwrite(ctx context.Context, secondSyncTime := syncTime.Add(time.Second).UTC() // copy first resource but update the sync time - cqIDInds := resources[0].Schema().FieldIndices(schema.PKColumnNames[0]) + cqIDInds := resources[0].Schema().FieldIndices(schema.CqIDColumn.Name) u := resources[0].Column(cqIDInds[0]).(*types.UUIDArray).Value(0) opts = schema.GenTestDataOptions{ SourceName: sourceName, diff --git a/schema/testdata.go b/schema/testdata.go index 2b1c0d178e..faafb6da85 100644 --- a/schema/testdata.go +++ b/schema/testdata.go @@ -94,7 +94,7 @@ func TestSourceColumns(testOpts ...func(o *TestSourceOptions)) []Column { // cq columns var cqColumns []Column - cqColumns = append(cqColumns, Column{Name: CqIDColumn.Name, Type: types.NewUUIDType(), NotNull: true, Unique: true}) + cqColumns = append(cqColumns, Column{Name: CqIDColumn.Name, Type: types.NewUUIDType(), NotNull: true, Unique: true, PrimaryKey: true}) cqColumns = append(cqColumns, Column{Name: CqParentIDColumn.Name, Type: types.NewUUIDType(), NotNull: true}) var basicColumns []Column @@ -155,9 +155,9 @@ func TestSourceColumns(testOpts ...func(o *TestSourceOptions)) []Column { compositeColumns = append(compositeColumns, listOfColumns(basicColumnsWithExclusions)...) } - if !opts.SkipMaps { - compositeColumns = append(compositeColumns, mapOfColumns(basicColumnsWithExclusions)...) - } + // if !opts.SkipMaps { + // compositeColumns = append(compositeColumns, mapOfColumns(basicColumnsWithExclusions)...) + // } // add JSON later, we don't want to include it as a list or map right now (it causes complications with JSON unmarshalling) basicColumns = append(basicColumns, Column{Name: "json", Type: types.NewJSONType()}) @@ -270,6 +270,7 @@ func listOfColumns(baseColumns []Column) []Column { } // mapOfColumns returns a list of columns that are maps of the given columns. +// nolint:unused func mapOfColumns(baseColumns []Column) []Column { columns := make([]Column, len(baseColumns)) for i := 0; i < len(baseColumns); i++ { @@ -289,13 +290,13 @@ func columnsToFields(columns ...Column) []arrow.Field { return fields } -var PKColumnNames = []string{"uuid_pk", "string_pk"} +// var PKColumnNames = []string{"uuid_pk"} // TestTable returns a table with columns of all types. Useful for destination testing purposes func TestTable(name string, opts ...func(o *TestSourceOptions)) *Table { var columns []Column - columns = append(columns, Column{Name: "uuid_pk", Type: types.NewUUIDType(), PrimaryKey: true, Unique: true}) - columns = append(columns, Column{Name: "string_pk", Type: arrow.BinaryTypes.String, PrimaryKey: true, Unique: true}) + // columns = append(columns, Column{Name: "uuid", Type: types.NewUUIDType()}) + // columns = append(columns, Column{Name: "string_pk", Type: arrow.BinaryTypes.String}) columns = append(columns, Column{Name: CqSourceNameColumn.Name, Type: arrow.BinaryTypes.String}) columns = append(columns, Column{Name: CqSyncTimeColumn.Name, Type: arrow.FixedWidthTypes.Timestamp_us}) columns = append(columns, TestSourceColumns(opts...)...) 
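With _cq_id promoted to the default primary key in the generated test data, destination tests look rows up by that column instead of the removed uuid_pk/string_pk pair. The following is a minimal sketch of that lookup, mirroring the FieldIndices call in plugin_testing_overwrite.go above; the helper name cqIDOfFirstRow and the record argument are illustrative, not part of this patch:

package example

import (
	"fmt"

	"github.com/apache/arrow/go/v13/arrow"
	"github.com/cloudquery/plugin-sdk/v3/schema"
	"github.com/cloudquery/plugin-sdk/v3/types"
)

// cqIDOfFirstRow is a hypothetical helper that finds the _cq_id column in a
// synced record and returns the UUID of its first row, the same lookup the
// overwrite test above performs via FieldIndices.
func cqIDOfFirstRow(record arrow.Record) (string, error) {
	inds := record.Schema().FieldIndices(schema.CqIDColumn.Name)
	if len(inds) == 0 {
		return "", fmt.Errorf("record has no %s column", schema.CqIDColumn.Name)
	}
	// _cq_id is stored via the UUID extension type, as in the diff above
	u := record.Column(inds[0]).(*types.UUIDArray).Value(0)
	return u.String(), nil
}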
diff --git a/schema/testdata_test.go b/schema/testdata_test.go
index d1a406cc04..53903a3cf9 100644
--- a/schema/testdata_test.go
+++ b/schema/testdata_test.go
@@ -5,11 +5,11 @@ import "testing"
 func TestTestSourceColumns_Default(t *testing.T) {
 // basic sanity check for tested columns
 defaults := TestSourceColumns()
- if len(defaults) < 100 {
- t.Fatal("expected at least 100 columns by default")
+ if len(defaults) < 73 {
+ t.Fatalf("expected at least 73 columns by default, got: %d", len(defaults))
 }
 // test some specific columns
- checkColumnsExist(t, defaults, []string{"int64", "date32", "timestamp_us", "string", "struct", "string_map", "string_list"})
+ checkColumnsExist(t, defaults, []string{"int64", "date32", "timestamp_us", "string", "struct", "string_list"})
 }

 func TestTestSourceColumns_SkipAll(t *testing.T) {
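The checkColumnsExist helper called by these tests is defined elsewhere in the package and is not shown in this patch series. A minimal sketch consistent with the call sites above, assuming only that Column exposes a Name field (as used throughout this diff), could look like:

package schema

import "testing"

// checkColumnsExist asserts that every named column is present in the given
// column list. Sketch only; the real helper may differ.
func checkColumnsExist(t *testing.T, columns []Column, names []string) {
	t.Helper()
	existing := make(map[string]struct{}, len(columns))
	for _, c := range columns {
		existing[c.Name] = struct{}{}
	}
	for _, name := range names {
		if _, ok := existing[name]; !ok {
			t.Errorf("expected column %q to exist", name)
		}
	}
}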