diff --git a/writers/batch.go b/writers/batch.go index f621022eef..55ed6e0b0c 100644 --- a/writers/batch.go +++ b/writers/batch.go @@ -121,7 +121,7 @@ func (w *BatchWriter) Close(context.Context) error { func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.WriteInsert, flush <-chan chan bool) { sizeBytes := int64(0) - resources := make([]*message.WriteInsert, 0) + resources := make([]*message.WriteInsert, 0, w.batchSize) for { select { case r, ok := <-ch: @@ -131,25 +131,23 @@ func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *m } return } - resources = append(resources, r) - sizeBytes += util.TotalRecordSize(r.Record) - if len(resources) >= w.batchSize || sizeBytes >= int64(w.batchSizeBytes) { + if (w.batchSize > 0 && len(resources) >= w.batchSize) || (w.batchSizeBytes > 0 && sizeBytes+util.TotalRecordSize(r.Record) >= int64(w.batchSizeBytes)) { w.flushTable(ctx, tableName, resources) - resources = make([]*message.WriteInsert, 0) - sizeBytes = 0 + resources, sizeBytes = resources[:0], 0 } + + resources = append(resources, r) + sizeBytes += util.TotalRecordSize(r.Record) case <-time.After(w.batchTimeout): if len(resources) > 0 { w.flushTable(ctx, tableName, resources) - resources = make([]*message.WriteInsert, 0) - sizeBytes = 0 + resources, sizeBytes = resources[:0], 0 } case done := <-flush: if len(resources) > 0 { w.flushTable(ctx, tableName, resources) - resources = make([]*message.WriteInsert, 0) - sizeBytes = 0 + resources, sizeBytes = resources[:0], 0 } done <- true case <-ctx.Done(): @@ -258,7 +256,7 @@ func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessag w.deleteStaleMessages = append(w.deleteStaleMessages, m) l := len(w.deleteStaleMessages) w.deleteStaleLock.Unlock() - if l > w.batchSize { + if w.batchSize > 0 && l > w.batchSize { if err := w.flushDeleteStaleTables(ctx); err != nil { return err } @@ -282,7 +280,7 @@ func (w *BatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessag w.migrateTableMessages = append(w.migrateTableMessages, m) l := len(w.migrateTableMessages) w.migrateTableLock.Unlock() - if l > w.batchSize { + if w.batchSize > 0 && l > w.batchSize { if err := w.flushMigrateTables(ctx); err != nil { return err } diff --git a/writers/batch_test.go b/writers/batch_test.go index b5abf38ca2..ac135c69db 100644 --- a/writers/batch_test.go +++ b/writers/batch_test.go @@ -68,15 +68,6 @@ var batchTestTables = schema.Tables{ }, }, }, - { - Name: "table2", - Columns: []schema.Column{ - { - Name: "id", - Type: arrow.PrimitiveTypes.Int64, - }, - }, - }, } // TestBatchFlushDifferentMessages tests that if writer receives a message of a new type all other pending @@ -106,7 +97,7 @@ func TestBatchFlushDifferentMessages(t *testing.T) { } if testClient.MigrateTablesLen() != 1 { - t.Fatalf("expected 1 migrate table messages, got %d", testClient.MigrateTablesLen()) + t.Fatalf("expected 1 migrate table message, got %d", testClient.MigrateTablesLen()) } if testClient.InsertsLen() != 0 { @@ -118,7 +109,7 @@ func TestBatchFlushDifferentMessages(t *testing.T) { } if testClient.InsertsLen() != 1 { - t.Fatalf("expected 1 insert messages, got %d", testClient.InsertsLen()) + t.Fatalf("expected 1 insert message, got %d", testClient.InsertsLen()) } } @@ -142,9 +133,14 @@ func TestBatchSize(t *testing.T) { t.Fatalf("expected 0 insert messages, got %d", testClient.InsertsLen()) } - if err := wr.writeAll(ctx, []message.WriteMessage{&message.WriteInsert{ - Record: record, - }}); err != nil { + if err := wr.writeAll(ctx, []message.WriteMessage{ + &message.WriteInsert{ + Record: record, + }, + &message.WriteInsert{ // third message to exceed the batch size + Record: record, + }, + }); err != nil { t.Fatal(err) } // we need to wait for the batch to be flushed @@ -186,7 +182,7 @@ func TestBatchTimeout(t *testing.T) { time.Sleep(time.Second * 1) if testClient.InsertsLen() != 1 { - t.Fatalf("expected 1 insert messages, got %d", testClient.InsertsLen()) + t.Fatalf("expected 1 insert message, got %d", testClient.InsertsLen()) } }