Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions writers/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (w *BatchWriter) Close(context.Context) error {

func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.WriteInsert, flush <-chan chan bool) {
sizeBytes := int64(0)
resources := make([]*message.WriteInsert, 0)
resources := make([]*message.WriteInsert, 0, w.batchSize)
for {
select {
case r, ok := <-ch:
Expand All @@ -131,25 +131,23 @@ func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *m
}
return
}
resources = append(resources, r)
sizeBytes += util.TotalRecordSize(r.Record)

if len(resources) >= w.batchSize || sizeBytes >= int64(w.batchSizeBytes) {
if (w.batchSize > 0 && len(resources) >= w.batchSize) || (w.batchSizeBytes > 0 && sizeBytes+util.TotalRecordSize(r.Record) >= int64(w.batchSizeBytes)) {
w.flushTable(ctx, tableName, resources)
resources = make([]*message.WriteInsert, 0)
sizeBytes = 0
resources, sizeBytes = resources[:0], 0
}

resources = append(resources, r)
sizeBytes += util.TotalRecordSize(r.Record)
case <-time.After(w.batchTimeout):
if len(resources) > 0 {
w.flushTable(ctx, tableName, resources)
resources = make([]*message.WriteInsert, 0)
sizeBytes = 0
resources, sizeBytes = resources[:0], 0
}
case done := <-flush:
if len(resources) > 0 {
w.flushTable(ctx, tableName, resources)
resources = make([]*message.WriteInsert, 0)
sizeBytes = 0
resources, sizeBytes = resources[:0], 0
}
done <- true
case <-ctx.Done():
Expand Down Expand Up @@ -258,7 +256,7 @@ func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessag
w.deleteStaleMessages = append(w.deleteStaleMessages, m)
l := len(w.deleteStaleMessages)
w.deleteStaleLock.Unlock()
if l > w.batchSize {
if w.batchSize > 0 && l > w.batchSize {
if err := w.flushDeleteStaleTables(ctx); err != nil {
return err
}
Expand All @@ -282,7 +280,7 @@ func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessag
w.migrateTableMessages = append(w.migrateTableMessages, m)
l := len(w.migrateTableMessages)
w.migrateTableLock.Unlock()
if l > w.batchSize {
if w.batchSize > 0 && l > w.batchSize {
if err := w.flushMigrateTables(ctx); err != nil {
return err
}
Expand Down
26 changes: 11 additions & 15 deletions writers/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,6 @@ var batchTestTables = schema.Tables{
},
},
},
{
Name: "table2",
Columns: []schema.Column{
{
Name: "id",
Type: arrow.PrimitiveTypes.Int64,
},
},
},
}

// TestBatchFlushDifferentMessages tests that if writer receives a message of a new type all other pending
Expand Down Expand Up @@ -106,7 +97,7 @@ func TestBatchFlushDifferentMessages(t *testing.T) {
}

if testClient.MigrateTablesLen() != 1 {
t.Fatalf("expected 1 migrate table messages, got %d", testClient.MigrateTablesLen())
t.Fatalf("expected 1 migrate table message, got %d", testClient.MigrateTablesLen())
}

if testClient.InsertsLen() != 0 {
Expand All @@ -118,7 +109,7 @@ func TestBatchFlushDifferentMessages(t *testing.T) {
}

if testClient.InsertsLen() != 1 {
t.Fatalf("expected 1 insert messages, got %d", testClient.InsertsLen())
t.Fatalf("expected 1 insert message, got %d", testClient.InsertsLen())
}
}

Expand All @@ -142,9 +133,14 @@ func TestBatchSize(t *testing.T) {
t.Fatalf("expected 0 insert messages, got %d", testClient.InsertsLen())
}

if err := wr.writeAll(ctx, []message.WriteMessage{&message.WriteInsert{
Record: record,
}}); err != nil {
if err := wr.writeAll(ctx, []message.WriteMessage{
&message.WriteInsert{
Record: record,
},
&message.WriteInsert{ // third message to exceed the batch size
Record: record,
},
}); err != nil {
t.Fatal(err)
}
// we need to wait for the batch to be flushed
Expand Down Expand Up @@ -186,7 +182,7 @@ func TestBatchTimeout(t *testing.T) {
time.Sleep(time.Second * 1)

if testClient.InsertsLen() != 1 {
t.Fatalf("expected 1 insert messages, got %d", testClient.InsertsLen())
t.Fatalf("expected 1 insert message, got %d", testClient.InsertsLen())
}
}

Expand Down