Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions clients/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ import (
"time"

"github.com/cloudquery/plugin-sdk/internal/pb"
"github.com/cloudquery/plugin-sdk/internal/versions"
"github.com/cloudquery/plugin-sdk/schema"
"github.com/cloudquery/plugin-sdk/specs"
"github.com/rs/zerolog"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
)

Expand Down Expand Up @@ -78,7 +81,9 @@ func NewDestinationClient(ctx context.Context, registry specs.Registry, path str
}
return c, nil
case specs.RegistryLocal:
return c.newManagedClient(ctx, path)
if err := c.newManagedClient(ctx, path); err != nil {
return nil, err
}
case specs.RegistryGithub:
pathSplit := strings.Split(path, "/")
if len(pathSplit) != 2 {
Expand All @@ -90,26 +95,40 @@ func NewDestinationClient(ctx context.Context, registry specs.Registry, path str
if err := DownloadPluginFromGithub(ctx, localPath, org, name, version, PluginTypeDestination); err != nil {
return nil, err
}
return c.newManagedClient(ctx, localPath)
if err := c.newManagedClient(ctx, localPath); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unsupported registry %s", registry)
}
protocolVersion, err := c.GetProtocolVersion(ctx)
if err != nil {
return nil, err
}

if protocolVersion < versions.DestinationProtocolVersion {
return nil, fmt.Errorf("destination plugin protocol version %d is lower than client version %d. Try updating client", protocolVersion, versions.DestinationProtocolVersion)
} else if protocolVersion > versions.DestinationProtocolVersion {
return nil, fmt.Errorf("destination plugin protocol version %d is higher than client version %d. Try updating destination plugin", protocolVersion, versions.DestinationProtocolVersion)
}

return c, nil
}

// newManagedClient starts a new destination plugin process from local file, connects to it via gRPC server
// and returns a new DestinationClient
func (c *DestinationClient) newManagedClient(ctx context.Context, path string) (*DestinationClient, error) {
func (c *DestinationClient) newManagedClient(ctx context.Context, path string) error {
c.grpcSocketName = generateRandomUnixSocketName()
// spawn the plugin first and then connect
cmd := exec.CommandContext(ctx, path, "serve", "--network", "unix", "--address", c.grpcSocketName,
"--log-level", c.logger.GetLevel().String(), "--log-format", "json")
reader, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("failed to get stdout pipe: %w", err)
return fmt.Errorf("failed to get stdout pipe: %w", err)
}
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start plugin %s: %w", path, err)
return fmt.Errorf("failed to start plugin %s: %w", path, err)
}

c.wg.Add(1)
Expand Down Expand Up @@ -157,10 +176,26 @@ func (c *DestinationClient) newManagedClient(ctx context.Context, path string) (
if err := cmd.Process.Kill(); err != nil {
c.logger.Error().Err(err).Msg("failed to kill plugin process")
}
return c, err
return err
}
c.pbClient = pb.NewDestinationClient(c.conn)
return c, nil
return nil
}

func (c *DestinationClient) GetProtocolVersion(ctx context.Context) (uint64, error) {
res, err := c.pbClient.GetProtocolVersion(ctx, &pb.GetProtocolVersion_Request{})
if err != nil {
s, ok := status.FromError(err)
if !ok {
return 0, fmt.Errorf("failed to cal GetProtocolVersion: %w", err)
}
if s.Code() != codes.Unimplemented {
return 0, err
}
c.logger.Warn().Err(err).Msg("plugin does not support protocol version. assuming protocol version 1")
return 1, nil
}
return res.Version, nil
}

func (c *DestinationClient) Name(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -207,7 +242,7 @@ func (c *DestinationClient) Migrate(ctx context.Context, tables []*schema.Table)

// Write writes rows as they are received from the channel to the destination plugin.
// resources is marshaled schema.Resource. We are not marshalling this inside the function
// because usually it is alreadun marshalled from the source plugin.
// because usually it is alreadun marshalled from the destination plugin.
func (c *DestinationClient) Write(ctx context.Context, source string, syncTime time.Time, resources <-chan []byte) (uint64, error) {
saveClient, err := c.pbClient.Write(ctx)
if err != nil {
Expand Down
51 changes: 43 additions & 8 deletions clients/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ import (
"sync"

"github.com/cloudquery/plugin-sdk/internal/pb"
"github.com/cloudquery/plugin-sdk/internal/versions"
"github.com/cloudquery/plugin-sdk/schema"
"github.com/cloudquery/plugin-sdk/specs"
"github.com/rs/zerolog"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)

// SourceClient
Expand Down Expand Up @@ -82,7 +85,9 @@ func NewSourceClient(ctx context.Context, registry specs.Registry, path string,
}
return c, nil
case specs.RegistryLocal:
return c.newManagedClient(ctx, path)
if err := c.newManagedClient(ctx, path); err != nil {
return nil, err
}
case specs.RegistryGithub:
pathSplit := strings.Split(path, "/")
if len(pathSplit) != 2 {
Expand All @@ -94,26 +99,40 @@ func NewSourceClient(ctx context.Context, registry specs.Registry, path string,
if err := DownloadPluginFromGithub(ctx, localPath, org, name, version, PluginTypeSource); err != nil {
return nil, err
}
return c.newManagedClient(ctx, localPath)
if err := c.newManagedClient(ctx, localPath); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unsupported registry %s", registry)
}

protocolVersion, err := c.GetProtocolVersion(ctx)
if err != nil {
return nil, err
}
if protocolVersion < versions.SourceProtocolVersion {
return nil, fmt.Errorf("source plugin protocol version %d is lower than client version %d. Try updating client", protocolVersion, versions.SourceProtocolVersion)
} else if protocolVersion > versions.SourceProtocolVersion {
return nil, fmt.Errorf("source plugin protocol version %d is higher than client version %d. Try updating destination plugin", protocolVersion, versions.SourceProtocolVersion)
}

return c, nil
}

// newManagedClient starts a new source plugin process from local path, connects to it via gRPC server
// and returns a new SourceClient
func (c *SourceClient) newManagedClient(ctx context.Context, path string) (*SourceClient, error) {
func (c *SourceClient) newManagedClient(ctx context.Context, path string) error {
c.grpcSocketName = generateRandomUnixSocketName()
// spawn the plugin first and then connect
cmd := exec.CommandContext(ctx, path, "serve", "--network", "unix", "--address", c.grpcSocketName,
"--log-level", c.logger.GetLevel().String(), "--log-format", "json")
reader, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("failed to get stdout pipe: %w", err)
return fmt.Errorf("failed to get stdout pipe: %w", err)
}
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start plugin %s: %w", path, err)
return fmt.Errorf("failed to start plugin %s: %w", path, err)
}

c.wg.Add(1)
Expand Down Expand Up @@ -161,10 +180,26 @@ func (c *SourceClient) newManagedClient(ctx context.Context, path string) (*Sour
if err := cmd.Process.Kill(); err != nil {
c.logger.Error().Err(err).Msg("failed to kill plugin process")
}
return c, err
return err
}
c.pbClient = pb.NewSourceClient(c.conn)
return c, nil
return nil
}

func (c *SourceClient) GetProtocolVersion(ctx context.Context) (uint64, error) {
res, err := c.pbClient.GetProtocolVersion(ctx, &pb.GetProtocolVersion_Request{})
if err != nil {
s, ok := status.FromError(err)
if !ok {
return 0, fmt.Errorf("failed to cal GetProtocolVersion: %w", err)
}
if s.Code() != codes.Unimplemented {
return 0, err
}
c.logger.Warn().Err(err).Msg("plugin does not support protocol version. assuming protocol version 1")
return 1, nil
}
return res.Version, nil
}

func (c *SourceClient) Name(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -197,7 +232,7 @@ func (c *SourceClient) GetTables(ctx context.Context) ([]*schema.Table, error) {

// Sync start syncing for the source client per the given spec and returning the results
// in the given channel. res is marshaled schema.Resource. We are not unmarshalling this for performance reasons
// as usually this is sent over-the-wire anyway to a destination plugin
// as usually this is sent over-the-wire anyway to a source plugin
func (c *SourceClient) Sync(ctx context.Context, spec specs.Source, res chan<- []byte) error {
b, err := json.Marshal(spec)
if err != nil {
Expand Down
Loading