12 changes: 11 additions & 1 deletion internal/memdb/memdb.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"

@@ -18,6 +19,7 @@ type client struct {
schema.DefaultTransformer
spec specs.Destination
memoryDB map[string][][]any
memoryDBLock sync.RWMutex
errOnWrite bool
blockingWrite bool
}
@@ -38,7 +40,8 @@ func WithBlockingWrite() Option {

func GetNewClient(options ...Option) destination.NewClientFunc {
c := &client{
memoryDB: make(map[string][][]any),
memoryDB: make(map[string][][]any),
memoryDBLock: sync.RWMutex{},
}
for _, opt := range options {
opt(c)
@@ -111,11 +114,13 @@ func (c *client) Read(_ context.Context, table *schema.Table, source string, res
}
sourceColIndex := table.Columns.Index(schema.CqSourceNameColumn.Name)
var sortedRes [][]any
c.memoryDBLock.RLock()
for _, row := range c.memoryDB[table.Name] {
if row[sourceColIndex].(*schema.Text).Str == source {
sortedRes = append(sortedRes, row)
}
}
c.memoryDBLock.RUnlock()

for _, row := range sortedRes {
res <- row
@@ -134,12 +139,15 @@ func (c *client) Write(ctx context.Context, tables schema.Tables, resources <-ch
}
return nil
}

for resource := range resources {
c.memoryDBLock.Lock()
if c.spec.WriteMode == specs.WriteModeAppend {
c.memoryDB[resource.TableName] = append(c.memoryDB[resource.TableName], resource.Data)
} else {
c.overwrite(tables.Get(resource.TableName), resource.Data)
}
c.memoryDBLock.Unlock()
}
return nil
}
@@ -156,11 +164,13 @@ func (c *client) WriteTableBatch(ctx context.Context, table *schema.Table, resou
return nil
}
for _, resource := range resources {
c.memoryDBLock.Lock()
if c.spec.WriteMode == specs.WriteModeAppend {
c.memoryDB[table.Name] = append(c.memoryDB[table.Name], resource)
} else {
c.overwrite(table, resource)
}
c.memoryDBLock.Unlock()
}
return nil
}
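The new RWMutex exists because Read can run while Write goroutines are still appending: unsynchronized access to a Go map is a data race and can panic. Below is a minimal standalone sketch of the access pattern the lock serializes — not the plugin's code, just an illustration you can check with go run -race:

package main

import "sync"

// store mirrors the memoryDB + memoryDBLock pair from the diff.
type store struct {
	mu sync.RWMutex
	db map[string][][]any
}

func (s *store) write(table string, row []any) {
	s.mu.Lock() // writers take the exclusive lock
	defer s.mu.Unlock()
	s.db[table] = append(s.db[table], row)
}

func (s *store) read(table string) [][]any {
	s.mu.RLock() // multiple readers may hold this concurrently
	defer s.mu.RUnlock()
	out := make([][]any, len(s.db[table]))
	copy(out, s.db[table])
	return out
}

func main() {
	s := &store{db: make(map[string][][]any)}
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			s.write("t1", []any{i})
			_ = s.read("t1")
		}(i)
	}
	wg.Wait()
}

Without the mutex, the race detector flags the concurrent map reads and writes; with it, the program is clean.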
7 changes: 0 additions & 7 deletions plugins/destination/managed_writer.go
@@ -7,7 +7,6 @@ import (
"time"

"github.com/cloudquery/plugin-sdk/schema"
"github.com/cloudquery/plugin-sdk/specs"
)

type worker struct {
@@ -135,11 +134,5 @@ func (p *Plugin) writeManagedTableBatch(ctx context.Context, tables schema.Table
}
}
p.workersLock.Unlock()

if p.spec.WriteMode == specs.WriteModeOverwriteDeleteStale {
if err := p.DeleteStale(ctx, tables, sourceName, syncTime); err != nil {
return err
}
}
return nil
}
8 changes: 7 additions & 1 deletion plugins/destination/plugin.go
@@ -254,7 +254,13 @@ func (p *Plugin) Write(ctx context.Context, tables schema.Tables, sourceName str
panic("unknown client type")
}
if p.spec.WriteMode == specs.WriteModeOverwriteDeleteStale {
if err := p.DeleteStale(ctx, tables, sourceName, syncTime); err != nil {
include := func(t *schema.Table) bool { return true }
exclude := func(t *schema.Table) bool { return t.IsIncremental }
nonIncrementalTables, err := tables.FilterDfsFunc(include, exclude)
Member commented:

Are tables already filtered by the glob patterns in the spec?
Making sure we don't return all tables here since we include everything.

Also since include is always true, maybe add only the exclude part for now?

Member Author replied:

Right, tables are already filtered by glob patterns at this point (previously we were passing all these tables in to delete stale; now we're filtering out the incremental ones from that list first).

FilterDfsFunc is now also used by FilterDfs internally, and that needs both include and exclude.

if err != nil {
return err
}
if err := p.DeleteStale(ctx, nonIncrementalTables, sourceName, syncTime); err != nil {
return err
}
}
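To make the review thread above concrete, here is a hedged, self-contained sketch of the new FilterDfsFunc predicate pair. The table names are invented, and it assumes constructing schema.Table literals with only Name, Relations, and IsIncremental set is enough for illustration:

package main

import (
	"fmt"

	"github.com/cloudquery/plugin-sdk/schema"
)

func main() {
	tables := schema.Tables{
		{Name: "aws_s3_buckets", Relations: []*schema.Table{
			{Name: "aws_s3_object_versions", IsIncremental: true}, // hypothetical incremental child
		}},
		{Name: "aws_ec2_instances"},
	}

	// Mirrors the call site in plugin.go: include everything (glob filtering
	// from the spec already happened upstream), exclude incremental tables.
	include := func(t *schema.Table) bool { return true }
	exclude := func(t *schema.Table) bool { return t.IsIncremental }

	nonIncremental, err := tables.FilterDfsFunc(include, exclude)
	if err != nil {
		panic(err)
	}
	for _, t := range nonIncremental {
		fmt.Println(t.Name, len(t.Relations))
	}
	// Expected: aws_s3_buckets 0 (incremental child pruned), aws_ec2_instances 0.
}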
110 changes: 104 additions & 6 deletions plugins/destination/plugin_testing.go
@@ -43,7 +43,7 @@ type PluginTestSuiteTests struct {
SkipMigrateAppend bool
}

func (s *PluginTestSuite) destinationPluginTestWriteOverwrite(ctx context.Context, p *Plugin, logger zerolog.Logger, spec specs.Destination) error {
func (*PluginTestSuite) destinationPluginTestWriteOverwrite(ctx context.Context, p *Plugin, logger zerolog.Logger, spec specs.Destination) error {
spec.WriteMode = specs.WriteModeOverwrite
if err := p.Init(ctx, logger, spec); err != nil {
return fmt.Errorf("failed to init plugin: %w", err)
@@ -117,18 +117,95 @@ func (s *PluginTestSuite) destinationPluginTestWriteOverwrite(ctx context.Contex
return fmt.Errorf("after overwrite expected second resource diff: %s", diff)
}

if !s.tests.SkipDeleteStale {
if err := p.DeleteStale(ctx, tables, sourceName, secondSyncTime); err != nil {
return fmt.Errorf("failed to delete stale data second time: %w", err)
}
return nil
}

func (*PluginTestSuite) destinationPluginTestWriteOverwriteDeleteStale(ctx context.Context, p *Plugin, logger zerolog.Logger, spec specs.Destination) error {
spec.WriteMode = specs.WriteModeOverwriteDeleteStale
if err := p.Init(ctx, logger, spec); err != nil {
return fmt.Errorf("failed to init plugin: %w", err)
}
tableName := "cq_test_write_overwrite_delete_stale"
table := testdata.TestTable(tableName)
incTable := testdata.TestTable(tableName + "_incremental")
incTable.IsIncremental = true
syncTime := time.Now().UTC().Round(1 * time.Second)
tables := []*schema.Table{
table,
incTable,
}
if err := p.Migrate(ctx, tables); err != nil {
return fmt.Errorf("failed to migrate tables: %w", err)
}

resourcesRead, err = p.readAll(ctx, tables[0], sourceName)
sourceName := "testOverwriteSource" + uuid.NewString()

resources := createTestResources(table, sourceName, syncTime, 2)
incResources := createTestResources(incTable, sourceName, syncTime, 2)
if err := p.writeAll(ctx, tables, sourceName, syncTime, append(resources, incResources...)); err != nil {
return fmt.Errorf("failed to write all: %w", err)
}
sortResources(table, resources)

resourcesRead, err := p.readAll(ctx, table, sourceName)
if err != nil {
return fmt.Errorf("failed to read all: %w", err)
}
sortCQTypes(table, resourcesRead)

if len(resourcesRead) != 2 {
return fmt.Errorf("expected 2 resources, got %d", len(resourcesRead))
}

if diff := resources[0].Data.Diff(resourcesRead[0]); diff != "" {
return fmt.Errorf("expected first resource diff: %s", diff)
}

if diff := resources[1].Data.Diff(resourcesRead[1]); diff != "" {
return fmt.Errorf("expected second resource diff: %s", diff)
}

// read from incremental table
resourcesRead, err = p.readAll(ctx, incTable, sourceName)
if err != nil {
return fmt.Errorf("failed to read all: %w", err)
}
if len(resourcesRead) != 2 {
return fmt.Errorf("expected 2 resources in incremental table, got %d", len(resourcesRead))
}

secondSyncTime := syncTime.Add(time.Second).UTC()

// copy first resource but update the sync time
updatedResource := schema.DestinationResource{
TableName: table.Name,
Data: make(schema.CQTypes, len(resources[0].Data)),
}
copy(updatedResource.Data, resources[0].Data)
_ = updatedResource.Data[1].Set(secondSyncTime)

// write second time
if err := p.writeOne(ctx, tables, sourceName, secondSyncTime, updatedResource); err != nil {
return fmt.Errorf("failed to write one second time: %w", err)
}

resourcesRead, err = p.readAll(ctx, table, sourceName)
if err != nil {
return fmt.Errorf("failed to read all second time: %w", err)
}
sortCQTypes(table, resourcesRead)
if len(resourcesRead) != 1 {
return fmt.Errorf("after overwrite expected 1 resource, got %d", len(resourcesRead))
}

if diff := resources[0].Data.Diff(resourcesRead[0]); diff != "" {
return fmt.Errorf("after overwrite expected first resource diff: %s", diff)
}

resourcesRead, err = p.readAll(ctx, tables[0], sourceName)
if err != nil {
return fmt.Errorf("failed to read all second time: %w", err)
}
if len(resourcesRead) != 1 {
return fmt.Errorf("expected 1 resource after delete stale, got %d", len(resourcesRead))
}
@@ -138,6 +215,16 @@ func (s *PluginTestSuite) destinationPluginTestWriteOverwrite(ctx context.Contex
return fmt.Errorf("after delete stale expected resource diff: %s", diff)
}

// we expect the incremental table to still have 2 resources, because delete-stale should
// not apply there
resourcesRead, err = p.readAll(ctx, tables[1], sourceName)
if err != nil {
return fmt.Errorf("failed to read all from incremental table: %w", err)
}
if len(resourcesRead) != 2 {
return fmt.Errorf("expected 2 resources in incremental table after delete-stale, got %d", len(resourcesRead))
}

return nil
}

@@ -326,6 +413,17 @@ func PluginTestSuiteRunner(t *testing.T, p *Plugin, spec any, tests PluginTestSu
}
})

t.Run("TestWriteOverwriteDeleteStale", func(t *testing.T) {
t.Helper()
if suite.tests.SkipOverwrite || suite.tests.SkipDeleteStale {
t.Skip("skipping TestWriteOverwriteDeleteStale")
return
}
if err := suite.destinationPluginTestWriteOverwriteDeleteStale(ctx, p, logger, destSpec); err != nil {
t.Fatal(err)
}
})

t.Run("TestWriteAppend", func(t *testing.T) {
t.Helper()
if suite.tests.SkipAppend {
54 changes: 34 additions & 20 deletions schema/table.go
@@ -72,6 +72,18 @@ var (
reValidColumnName = regexp.MustCompile(`^[a-z_][a-z\d_]*$`)
)

func (tt Tables) FilterDfsFunc(include, exclude func(*Table) bool) (Tables, error) {
filteredTables := make(Tables, 0, len(tt))
for _, t := range tt {
filteredTable := t.Copy(nil)
filteredTable = filteredTable.filterDfs(false, include, exclude)
if filteredTable != nil {
filteredTables = append(filteredTables, filteredTable)
}
}
return filteredTables, nil
}

func (tt Tables) FilterDfs(tables, skipTables []string) (Tables, error) {
flattenedTables := tt.FlattenTables()
for _, includePattern := range tables {
@@ -98,16 +110,23 @@ func (tt Tables) FilterDfs(tables, skipTables []string) (Tables, error) {
return nil, fmt.Errorf("skip_tables include a pattern %s with no matches", excludePattern)
}
}

filteredTables := make(Tables, 0, len(tt))
for _, t := range tt {
filteredTable := t.Copy(nil)
filteredTable = filteredTable.filterDfs(false, tables, skipTables)
if filteredTable != nil {
filteredTables = append(filteredTables, filteredTable)
include := func(t *Table) bool {
for _, includePattern := range tables {
if glob.Glob(includePattern, t.Name) {
return true
}
}
return false
}
return filteredTables, nil
exclude := func(t *Table) bool {
for _, skipPattern := range skipTables {
if glob.Glob(skipPattern, t.Name) {
return true
}
}
return false
}
return tt.FilterDfsFunc(include, exclude)
}

func (tt Tables) FlattenTables() Tables {
@@ -196,22 +215,17 @@ func (tt Tables) ValidateColumnNames() error {
}

// this will filter the tree in-place
func (t *Table) filterDfs(parentMatched bool, tables []string, skipTables []string) *Table {
matched := parentMatched
for _, includeTable := range tables {
if glob.Glob(includeTable, t.Name) {
matched = true
break
}
func (t *Table) filterDfs(parentMatched bool, include, exclude func(*Table) bool) *Table {
if exclude(t) {
return nil
}
for _, skipTable := range skipTables {
if glob.Glob(skipTable, t.Name) {
return nil
}
matched := parentMatched
if include(t) {
matched = true
}
filteredRelations := make([]*Table, 0, len(t.Relations))
for _, r := range t.Relations {
filteredChild := r.filterDfs(matched, tables, skipTables)
filteredChild := r.filterDfs(matched, include, exclude)
if filteredChild != nil {
matched = true
filteredRelations = append(filteredRelations, r)
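Two semantics are worth noting in the refactored filterDfs, since the check ordering changed: exclude now short-circuits before include is consulted, so skip_tables wins when both match; and a match on a descendant still retains its ancestors. A hedged sketch of both behaviors, with invented table names, assuming Table.Copy deep-copies relations (which FilterDfsFunc relies on) and that the truncated tail of filterDfs drops unmatched subtrees as before:

package main

import (
	"fmt"

	"github.com/cloudquery/plugin-sdk/schema"
)

func main() {
	tables := schema.Tables{
		{Name: "svc_parents", Relations: []*schema.Table{
			{Name: "svc_children"},
		}},
	}

	// exclude wins: the parent matches both patterns, and excluding it
	// prunes the entire subtree, including the child.
	dropped, _ := tables.FilterDfs([]string{"svc_*"}, []string{"svc_parents"})
	fmt.Println(len(dropped)) // 0

	// a child match keeps ancestors: only svc_children matches include,
	// but its parent is retained so the tree stays connected.
	kept, _ := tables.FilterDfs([]string{"svc_children"}, nil)
	fmt.Println(len(kept), kept[0].Name) // 1 svc_parents
}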