diff --git a/internal/memdb/memdb.go b/internal/memdb/memdb.go
index 719eafe9c1..479596baae 100644
--- a/internal/memdb/memdb.go
+++ b/internal/memdb/memdb.go
@@ -211,6 +211,8 @@ func (c *client) Write(ctx context.Context, msgs <-chan message.WriteMessage) er
 			c.migrate(ctx, msg.Table)
 		case *message.WriteDeleteStale:
 			c.deleteStale(ctx, msg)
+		case *message.WriteDeleteRecord:
+			c.deleteRecord(ctx, msg)
 		case *message.WriteInsert:
 			sc := msg.Record.Schema()
 			tableName, ok := sc.Metadata().GetValue(schema.MetadataTableName)
@@ -257,3 +259,76 @@ func (c *client) deleteStale(_ context.Context, msg *message.WriteDeleteStale) {
 	}
 	c.memoryDB[tableName] = filteredTable
 }
+
+// deleteRecord drops every stored record of msg.TableName that satisfies
+// msg.WhereClause and keeps the rest. Predicate groups are combined with AND;
+// within a group the predicates are combined according to its GroupingType
+// ("AND" requires all of them, anything else requires at least one). An empty
+// WhereClause matches, and therefore deletes, every record of the table.
+func (c *client) deleteRecord(_ context.Context, msg *message.WriteDeleteRecord) {
+	var filteredTable []arrow.Record
+	tableName := msg.TableName
+	for _, row := range c.memoryDB[tableName] {
+		if !matchesWhereClause(msg.WhereClause, row) {
+			filteredTable = append(filteredTable, row)
+		}
+	}
+	c.memoryDB[tableName] = filteredTable
+}
+
+// matchesWhereClause reports whether record satisfies every predicate group.
+func matchesWhereClause(groups message.PredicateGroups, record arrow.Record) bool {
+	for _, group := range groups {
+		if !matchesPredicateGroup(group, record) {
+			return false
+		}
+	}
+	return true
+}
+
+// matchesPredicateGroup evaluates a single group: "AND" needs every predicate
+// to hold, any other grouping type needs at least one. A group without
+// predicates places no constraint on the record.
+func matchesPredicateGroup(group message.PredicateGroup, record arrow.Record) bool {
+	if len(group.Predicates) == 0 {
+		return true
+	}
+	requireAll := group.GroupingType == "AND"
+	for _, pred := range group.Predicates {
+		if evaluatePredicate(pred, record) != requireAll {
+			// The first failing predicate settles an AND group and the
+			// first passing predicate settles an OR group.
+			return !requireAll
+		}
+	}
+	return requireAll
+}
+
+// evaluatePredicate reports whether record satisfies pred. Only the "eq"
+// operator is supported; it compares the string rendering of the whole target
+// column with that of the first column of pred.Record. A missing column, a
+// predicate without a value column, mismatched data types and unknown
+// operators all evaluate to false.
+func evaluatePredicate(pred message.Predicate, record arrow.Record) bool {
+	indices := record.Schema().FieldIndices(pred.Column)
+	if len(indices) == 0 {
+		return false
+	}
+	if pred.Record == nil || pred.Record.NumCols() == 0 {
+		return false
+	}
+	column := record.Column(indices[0])
+	want := pred.Record.Column(0)
+
+	// Compare types structurally: two descriptions of the same parameterized
+	// type (a timestamp, for instance) are equal without being identical.
+	if !arrow.TypeEqual(column.DataType(), want.DataType()) {
+		return false
+	}
+	switch pred.Operator {
+	case "eq":
+		return column.String() == want.String()
+	default:
+		return false
+	}
+}
diff --git a/internal/servers/plugin/v3/plugin.go b/internal/servers/plugin/v3/plugin.go
index c88205c5c2..1cf001f193 100644
--- a/internal/servers/plugin/v3/plugin.go
+++ b/internal/servers/plugin/v3/plugin.go
@@ -156,6 +156,41 @@ func (s *Server) Sync(req *pb.Sync_Request, stream pb.Plugin_SyncServer) error {
 					Record: recordBytes,
 				},
 			}
+		case *message.SyncDeleteRecord:
+			whereClause := make([]*pb.PredicatesGroup, len(m.WhereClause))
+			for j, predicateGroup := range m.WhereClause {
+				whereClause[j] = &pb.PredicatesGroup{
+					GroupingType: pb.PredicatesGroup_GroupingType(pb.PredicatesGroup_GroupingType_value[predicateGroup.GroupingType]),
+					Predicates:   make([]*pb.Predicate, len(predicateGroup.Predicates)),
+				}
+				for i, predicate := range predicateGroup.Predicates {
+					record, err := pb.RecordToBytes(predicate.Record)
+					if err != nil {
+						return status.Errorf(codes.Internal, "failed to encode record: %v", err)
+					}
+
+					whereClause[j].Predicates[i] = &pb.Predicate{
+						Record:   record,
+						Column:   predicate.Column,
+						Operator: pb.Predicate_Operator(pb.Predicate_Operator_value[predicate.Operator]),
+					}
+				}
+			}
+
+			tableRelations := make([]*pb.TableRelation, len(m.TableRelations))
+			for i, tr := range m.TableRelations {
+				tableRelations[i] = &pb.TableRelation{
+					TableName:   tr.TableName,
+					ParentTable: tr.ParentTable,
+				}
+			}
+			pbMsg.Message = &pb.Sync_Response_DeleteRecord{
+				DeleteRecord: &pb.Sync_MessageDeleteRecord{
+					TableName:      m.TableName,
+					TableRelations: tableRelations,
+					WhereClause:    whereClause,
+				},
+			}
 		default:
 			return status.Errorf(codes.Internal, "unknown message type: %T", msg)
 		}
@@ -230,6 +265,48 @@ func (s *Server) Write(stream pb.Plugin_WriteServer) error {
 				SourceName: pbMsg.Delete.SourceName,
 				SyncTime:   pbMsg.Delete.SyncTime.AsTime(),
 			}
+
+		case *pb.Write_Request_DeleteRecord:
+			whereClause := make(message.PredicateGroups, len(pbMsg.DeleteRecord.WhereClause))
+
+			for j, predicateGroup := range pbMsg.DeleteRecord.WhereClause {
+				// Carry the grouping type over as well; otherwise every group
+				// reaches the destination with an empty GroupingType.
+				whereClause[j].GroupingType = predicateGroup.GroupingType.String()
+				whereClause[j].Predicates = make(message.Predicates, len(predicateGroup.Predicates))
+				for i, predicate := range predicateGroup.Predicates {
+					record, err := pb.NewRecordFromBytes(predicate.Record)
+					if err != nil {
+						pbMsgConvertErr = status.Errorf(codes.InvalidArgument, "failed to create record: %v", err)
+						break
+					}
+					whereClause[j].Predicates[i] = message.Predicate{
+						Record:   record,
+						Column:   predicate.Column,
+						Operator: predicate.Operator.String(),
+					}
+				}
+				// The break above only leaves the predicate loop, so stop
+				// converting the remaining groups once a record failed to decode.
+				if pbMsgConvertErr != nil {
+					break
+				}
+			}
+
+			tableRelations := make([]message.TableRelation, len(pbMsg.DeleteRecord.TableRelations))
+			for i, tr := range pbMsg.DeleteRecord.TableRelations {
+				tableRelations[i] = message.TableRelation{
+					TableName:   tr.TableName,
+					ParentTable: tr.ParentTable,
+				}
+			}
+			pluginMessage = &message.WriteDeleteRecord{
+				DeleteRecord: message.DeleteRecord{
+					TableName:      pbMsg.DeleteRecord.TableName,
+					TableRelations: tableRelations,
+					WhereClause:    whereClause,
+				},
+			}
 		}
 
 		if pbMsgConvertErr != nil {
diff --git a/message/sync_message.go b/message/sync_message.go
index 87cea8ea70..e18bf43ee9 100644
--- a/message/sync_message.go
+++ b/message/sync_message.go
@@ -111,3 +111,15 @@ func (m SyncInserts) GetRecordsForTable(table *schema.Table) []arrow.Record {
 	}
 	return slices.Clip(res)
 }
+
+// SyncDeleteRecord is the sync message that carries a DeleteRecord request.
+type SyncDeleteRecord struct {
+	syncBaseMessage
+	// TODO: Instead of using this struct we should derive the DeletionKeys and parent/child relation from the schema.Table itself
+	DeleteRecord
+}
+
+// GetTable returns a table stub that carries only the name of the affected table.
+func (m SyncDeleteRecord) GetTable() *schema.Table {
+	return &schema.Table{Name: m.TableName}
+}
diff --git a/message/write_message.go b/message/write_message.go
index 88991e8b3b..0e8fbad72a 100644
--- a/message/write_message.go
+++ b/message/write_message.go
@@ -128,3 +128,56 @@ func (m WriteDeleteStales) Exists(tableName string) bool {
 		return msg.TableName == tableName
 	})
 }
+
+// TableRelation names a table together with its parent table.
+type TableRelation struct {
+	TableName   string
+	ParentTable string
+}
+
+// TableRelations is a list of TableRelation values.
+type TableRelations []TableRelation
+
+// Predicate is a single condition on Column. Record carries the value to
+// compare against using Operator.
+type Predicate struct {
+	Operator string
+	Column   string
+	Record   arrow.Record
+}
+
+// Predicates is a list of Predicate values.
+type Predicates []Predicate
+
+// PredicateGroup is a set of predicates combined according to GroupingType.
+type PredicateGroup struct {
+	// GroupingType is either "AND" or "OR".
+	GroupingType string
+	Predicates   Predicates
+}
+
+// PredicateGroups is a list of PredicateGroup values.
+type PredicateGroups []PredicateGroup
+
+// DeleteRecord describes a request to delete the records of TableName that
+// match WhereClause.
+type DeleteRecord struct {
+	TableName      string
+	TableRelations TableRelations
+	WhereClause    PredicateGroups
+	SyncTime       time.Time
+}
+
+// WriteDeleteRecord is the write message that carries a DeleteRecord request.
+type WriteDeleteRecord struct {
+	writeBaseMessage
+	DeleteRecord
+}
+
+// GetTable returns a table stub that carries only the name of the affected table.
+func (m WriteDeleteRecord) GetTable() *schema.Table {
+	return &schema.Table{Name: m.TableName}
+}
+
+// WriteDeleteRecords is a batch of WriteDeleteRecord messages.
+type WriteDeleteRecords []*WriteDeleteRecord
diff --git a/plugin/testing_write.go b/plugin/testing_write.go
index 6d52573167..658865b750 100644
--- a/plugin/testing_write.go
+++ b/plugin/testing_write.go
@@ -148,6 +148,15 @@ func TestWriterSuiteRunner(t *testing.T, p *Plugin, tests WriterTestSuiteTests,
 		})
 	})
 
+	t.Run("TestDeleteRecord", func(t *testing.T) {
+		t.Run("Basic", func(t *testing.T) {
+			suite.testDeleteRecordBasic(ctx)
+		})
+		t.Run("DeleteAll", func(t *testing.T) {
+			suite.testDeleteAllRecords(ctx)
+		})
+	})
+
 	t.Run("TestMigrate", func(t *testing.T) {
 		if suite.tests.SkipMigrate {
 			t.Skip("skipping " + t.Name())
diff --git a/plugin/testing_write_delete.go b/plugin/testing_write_delete.go
index baab1d666f..6709c16f72 100644
--- a/plugin/testing_write_delete.go
+++ b/plugin/testing_write_delete.go
@@ -142,3 +142,159 @@ func (s *WriterTestSuite) testDeleteStaleAll(ctx context.Context) {
 	require.EqualValuesf(s.t, rowsPerRecord, TotalRows(readRecords), "unexpected amount of items after second delete stale")
 	require.Emptyf(s.t, RecordsDiff(table.ToArrowSchema(), readRecords, []arrow.Record{nullRecord}), "record differs")
 }
+
+// testDeleteRecordBasic inserts one row and checks that a WriteDeleteRecord
+// whose predicate matches no row leaves the table untouched, while one whose
+// predicate matches the row removes it.
+func (s *WriterTestSuite) testDeleteRecordBasic(ctx context.Context) {
+	tableName := s.tableNameForTest("delete_record_basic")
+	syncTime := time.Now().UTC().Truncate(s.genDatOptions.TimePrecision).
+		Truncate(time.Microsecond) // https://github.com/golang/go/issues/41087
+	table := &schema.Table{
+		Name: tableName,
+		Columns: schema.ColumnList{
+			schema.Column{Name: "id", Type: arrow.PrimitiveTypes.Int64, PrimaryKey: true, NotNull: true},
+			schema.CqSourceNameColumn,
+			schema.CqSyncTimeColumn,
+		},
+	}
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteMigrateTable{Table: table}), "failed to create table")
+	const sourceName = "source-test"
+
+	bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema())
+	bldr.Field(0).(*array.Int64Builder).Append(0)
+	bldr.Field(1).(*array.StringBuilder).Append(sourceName)
+	bldr.Field(2).(*array.TimestampBuilder).AppendTime(syncTime)
+	record1 := bldr.NewRecord()
+
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteInsert{Record: record1}), "failed to insert record")
+	record1 = s.handleNulls(record1) // we process nulls after writing
+
+	records, err := s.plugin.readAll(ctx, table)
+	require.NoErrorf(s.t, err, "failed to read")
+	require.EqualValuesf(s.t, 1, TotalRows(records), "unexpected amount of items")
+
+	// Build a delete value that matches no row: no record with this id exists, so nothing is deleted.
+	bldrDeleteNoMatch := array.NewRecordBuilder(memory.DefaultAllocator, (&schema.Table{
+		Name: tableName,
+		Columns: schema.ColumnList{
+			schema.Column{Name: "id", Type: arrow.PrimitiveTypes.Int64},
+		},
+	}).ToArrowSchema())
+	bldrDeleteNoMatch.Field(0).(*array.Int64Builder).Append(1)
+	deleteValue := bldrDeleteNoMatch.NewRecord()
+
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteDeleteRecord{
+		DeleteRecord: message.DeleteRecord{
+			TableName: table.Name,
+			WhereClause: message.PredicateGroups{
+				{
+					GroupingType: "AND",
+					Predicates: []message.Predicate{
+						{
+							Operator: "eq",
+							Column:   "id",
+							Record:   deleteValue,
+						},
+					},
+				},
+			},
+		},
+	}), "failed to delete record no match")
+
+	records, err = s.plugin.readAll(ctx, table)
+	require.NoErrorf(s.t, err, "failed to read after delete with no match")
+	require.EqualValuesf(s.t, 1, TotalRows(records), "unexpected amount of items after delete with no match")
+	require.Emptyf(s.t, RecordsDiff(table.ToArrowSchema(), records, []arrow.Record{record1}), "record differs after delete with no match")
+
+	// Build a delete value that matches the inserted row, so exactly one record is deleted.
+	bldrDeleteMatch := array.NewRecordBuilder(memory.DefaultAllocator, (&schema.Table{
+		Name: tableName,
+		Columns: schema.ColumnList{
+			schema.Column{Name: "id", Type: arrow.PrimitiveTypes.Int64},
+		},
+	}).ToArrowSchema())
+	bldrDeleteMatch.Field(0).(*array.Int64Builder).Append(0)
+	deleteValue = bldrDeleteMatch.NewRecord()
+
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteDeleteRecord{
+		DeleteRecord: message.DeleteRecord{
+			TableName: table.Name,
+			WhereClause: message.PredicateGroups{
+				{
+					GroupingType: "AND",
+					Predicates: []message.Predicate{
+						{
+							Operator: "eq",
+							Column:   "id",
+							Record:   deleteValue,
+						},
+					},
+				},
+			},
+		},
+	}), "failed to delete record with match")
+
+	records, err = s.plugin.readAll(ctx, table)
+	require.NoErrorf(s.t, err, "failed to read after delete with match")
+	require.EqualValuesf(s.t, 0, TotalRows(records), "unexpected amount of items after delete with match")
+}
+
+// testDeleteAllRecords checks that a WriteDeleteRecord without a WhereClause
+// removes every row of the table, and that it does so again after new rows
+// have been inserted.
+func (s *WriterTestSuite) testDeleteAllRecords(ctx context.Context) {
+	tableName := s.tableNameForTest("delete_all_records")
+	syncTime := time.Now().UTC().Truncate(s.genDatOptions.TimePrecision).
+		Truncate(time.Microsecond) // https://github.com/golang/go/issues/41087
+	table := &schema.Table{
+		Name: tableName,
+		Columns: schema.ColumnList{
+			schema.Column{Name: "id", Type: arrow.PrimitiveTypes.Int64, PrimaryKey: true, NotNull: true},
+			schema.CqSourceNameColumn,
+			schema.CqSyncTimeColumn,
+		},
+	}
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteMigrateTable{Table: table}), "failed to create table")
+	const sourceName = "source-test"
+
+	bldr := array.NewRecordBuilder(memory.DefaultAllocator, table.ToArrowSchema())
+	bldr.Field(0).(*array.Int64Builder).Append(0)
+	bldr.Field(1).(*array.StringBuilder).Append(sourceName)
+	bldr.Field(2).(*array.TimestampBuilder).AppendTime(syncTime)
+	record1 := bldr.NewRecord()
+
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteInsert{Record: record1}), "failed to insert record")
+
+	records, err := s.plugin.readAll(ctx, table)
+	require.NoErrorf(s.t, err, "failed to read")
+	require.EqualValuesf(s.t, 1, TotalRows(records), "unexpected amount of items")
+
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteDeleteRecord{
+		DeleteRecord: message.DeleteRecord{
+			TableName: table.Name,
+		},
+	}), "failed to delete records")
+
+	records, err = s.plugin.readAll(ctx, table)
+	require.NoErrorf(s.t, err, "failed to read after delete all records")
+	require.EqualValuesf(s.t, 0, TotalRows(records), "unexpected amount of items after delete all records")
+
+	bldr.Field(0).(*array.Int64Builder).Append(1)
+	bldr.Field(1).(*array.StringBuilder).Append(sourceName)
+	bldr.Field(2).(*array.TimestampBuilder).AppendTime(syncTime.Add(time.Second))
+	record2 := bldr.NewRecord()
+
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteInsert{Record: record2}), "failed to insert second record")
+
+	require.NoErrorf(s.t, s.plugin.writeOne(ctx, &message.WriteDeleteRecord{
+		DeleteRecord: message.DeleteRecord{
+			TableName: table.Name,
+		},
+	}), "failed to delete records second time")
+
+	records, err = s.plugin.readAll(ctx, table)
+	require.NoErrorf(s.t, err, "failed to read second time")
+	sortRecords(table, records, "id")
+	require.EqualValuesf(s.t, 0, TotalRows(records), "unexpected amount of items second time")
+}
diff --git a/writers/mixedbatchwriter/mixedbatchwriter.go b/writers/mixedbatchwriter/mixedbatchwriter.go
index f3eb1cc2d8..10fe2aeb29 100644
--- a/writers/mixedbatchwriter/mixedbatchwriter.go
+++ b/writers/mixedbatchwriter/mixedbatchwriter.go
@@ -15,6 +15,7 @@ type Client interface {
 	MigrateTableBatch(ctx context.Context, messages message.WriteMigrateTables) error
 	InsertBatch(ctx context.Context, messages message.WriteInserts) error
 	DeleteStaleBatch(ctx context.Context, messages message.WriteDeleteStales) error
+	DeleteRecordsBatch(ctx context.Context, messages message.WriteDeleteRecords) error
 }
 
 type MixedBatchWriter struct {
@@ -97,6 +98,12 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri
 		batch:     make([]*message.WriteDeleteStale, 0, w.batchSize),
 		writeFunc: w.client.DeleteStaleBatch,
 	}
+
+	deleteRecord := &batchManager[message.WriteDeleteRecords, *message.WriteDeleteRecord]{
+		batch:     make([]*message.WriteDeleteRecord, 0, w.batchSize),
+		writeFunc: w.client.DeleteRecordsBatch,
+	}
+
 	flush := func(msgType writers.MsgType) error {
 		if msgType == writers.MsgTypeUnset {
 			return nil
@@ -108,6 +115,8 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri
 			return insert.flush(ctx)
 		case writers.MsgTypeDeleteStale:
 			return deleteStale.flush(ctx)
+		case writers.MsgTypeDeleteRecord:
+			return deleteRecord.flush(ctx)
 		default:
 			panic("unknown message type")
 		}
@@ -138,6 +147,8 @@ loop:
 				err = insert.append(ctx, v)
 			case *message.WriteDeleteStale:
 				err = deleteStale.append(ctx, v)
+			case *message.WriteDeleteRecord:
+				err = deleteRecord.append(ctx, v)
 			default:
 				panic("unknown message type")
 			}
diff --git a/writers/mixedbatchwriter/mixedbatchwriter_test.go b/writers/mixedbatchwriter/mixedbatchwriter_test.go
index a321c08722..47d4f1eb07 100644
--- a/writers/mixedbatchwriter/mixedbatchwriter_test.go
+++ b/writers/mixedbatchwriter/mixedbatchwriter_test.go
@@ -45,6 +45,16 @@ func (c *testMixedBatchClient) DeleteStaleBatch(_ context.Context, messages mess
 	return nil
 }
 
+// DeleteRecordsBatch records the received batch so tests can inspect it.
+func (c *testMixedBatchClient) DeleteRecordsBatch(_ context.Context, messages message.WriteDeleteRecords) error {
+	m := make([]message.WriteMessage, len(messages))
+	for i, msg := range messages {
+		m[i] = msg
+	}
+	c.receivedBatches = append(c.receivedBatches, m)
+	return nil
+}
+
 var _ Client = (*testMixedBatchClient)(nil)
 
 type testMessages struct {
diff --git a/writers/mixedbatchwriter/unimplemented.go b/writers/mixedbatchwriter/unimplemented.go
index 1b7b3f8b7d..674afa8ccb 100644
--- a/writers/mixedbatchwriter/unimplemented.go
+++ b/writers/mixedbatchwriter/unimplemented.go
@@ -19,3 +19,12 @@ type UnimplementedDeleteStaleBatch struct{}
 func (UnimplementedDeleteStaleBatch) DeleteStaleBatch(context.Context, message.WriteDeleteStales) error {
 	return fmt.Errorf("DeleteStaleBatch: %w", plugin.ErrNotImplemented)
 }
+
+// UnimplementedDeleteRecordsBatch can be embedded in a client that does not
+// support deleting records; its DeleteRecordsBatch always fails.
+type UnimplementedDeleteRecordsBatch struct{}
+
+// DeleteRecordsBatch returns an error wrapping plugin.ErrNotImplemented.
+func (UnimplementedDeleteRecordsBatch) DeleteRecordsBatch(context.Context, message.WriteDeleteRecords) error {
+	return fmt.Errorf("DeleteRecordsBatch: %w", plugin.ErrNotImplemented)
+}
diff --git a/writers/mixedbatchwriter/unimplemented_test.go b/writers/mixedbatchwriter/unimplemented_test.go
index 10814fb042..64cfd05b60 100644
--- a/writers/mixedbatchwriter/unimplemented_test.go
+++ b/writers/mixedbatchwriter/unimplemented_test.go
@@ -10,6 +10,7 @@ import (
 type testDummyClient struct {
 	mixedbatchwriter.IgnoreMigrateTableBatch
 	mixedbatchwriter.UnimplementedDeleteStaleBatch
+	mixedbatchwriter.UnimplementedDeleteRecordsBatch
 }
 
 func (testDummyClient) InsertBatch(context.Context, message.WriteInserts) error {
diff --git a/writers/msgtype.go b/writers/msgtype.go
index ebdaa71e38..360f9e8d3c 100644
--- a/writers/msgtype.go
+++ b/writers/msgtype.go
@@ -13,6 +13,7 @@ const (
 	MsgTypeMigrateTable
 	MsgTypeInsert
 	MsgTypeDeleteStale
+	MsgTypeDeleteRecord
 )
 
 func MsgID(msg message.WriteMessage) MsgType {
@@ -23,6 +24,8 @@ func MsgID(msg message.WriteMessage) MsgType {
 		return MsgTypeInsert
 	case *message.WriteDeleteStale:
 		return MsgTypeDeleteStale
+	case *message.WriteDeleteRecord:
+		return MsgTypeDeleteRecord
 	}
 	panic("unknown message type: " + reflect.TypeOf(msg).Name())
 }