Skip to content

Commit

Permalink
chore: Update vendored sources to duckdb/duckdb@d0d7f7f
Browse files Browse the repository at this point in the history
Merge pull request duckdb/duckdb#11711 from Tishj/copy_csv_cast_rework
  • Loading branch information
krlmlr committed May 1, 2024
1 parent 7fbca36 commit bf6610f
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ void CSVReaderOptions::SetDateFormat(LogicalTypeId type, const string &format, b
error = StrTimeFormat::ParseFormatSpecifier(format, strpformat);
dialect_options.date_format[type].Set(strpformat);
} else {
error = StrTimeFormat::ParseFormatSpecifier(format, write_date_format[type]);
write_date_format[type] = Value(format);
}
if (!error.empty()) {
throw InvalidInputException("Could not parse DATEFORMAT: %s", error.c_str());
Expand Down
110 changes: 86 additions & 24 deletions src/duckdb/src/function/table/copy_csv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
#include "duckdb/function/scalar/string_functions.hpp"
#include "duckdb/function/table/read_csv.hpp"
#include "duckdb/parser/parsed_data/copy_info.hpp"
#include "duckdb/parser/expression/cast_expression.hpp"
#include "duckdb/parser/expression/function_expression.hpp"
#include "duckdb/parser/expression/columnref_expression.hpp"
#include "duckdb/parser/expression/constant_expression.hpp"
#include "duckdb/parser/expression/bound_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/execution/column_binding_resolver.hpp"
#include "duckdb/planner/operator/logical_dummy_scan.hpp"
#include <limits>

namespace duckdb {
Expand Down Expand Up @@ -93,6 +101,62 @@ string TransformNewLine(string new_line) {
;
}

//! Build one bound expression per output column that converts the column's values
//! to VARCHAR for CSV writing.
//! DATE columns honor a user-supplied DATEFORMAT, and TIMESTAMP/TIMESTAMP_TZ
//! columns a user-supplied TIMESTAMPFORMAT, via strftime(col, 'format');
//! every other column gets a plain CAST(col AS VARCHAR).
static vector<unique_ptr<Expression>> CreateCastExpressions(WriteCSVData &bind_data, ClientContext &context,
                                                            const vector<string> &names,
                                                            const vector<LogicalType> &sql_types) {
	auto &options = bind_data.options;
	auto &formats = options.write_date_format;

	// A non-NULL Value means the user supplied an explicit write format
	bool has_dateformat = !formats[LogicalTypeId::DATE].IsNull();
	bool has_timestampformat = !formats[LogicalTypeId::TIMESTAMP].IsNull();

	// Create a binder and register the input columns so the expressions can be bound
	auto binder = Binder::CreateBinder(context);

	auto &bind_context = binder->bind_context;
	auto table_index = binder->GenerateTableIndex();
	bind_context.AddGenericBinding(table_index, "copy_csv", names, sql_types);

	// Create the ParsedExpressions (cast, strftime, etc..)
	vector<unique_ptr<ParsedExpression>> unbound_expressions;
	for (idx_t i = 0; i < sql_types.size(); i++) {
		auto &type = sql_types[i];
		auto &name = names[i];

		bool is_timestamp = type.id() == LogicalTypeId::TIMESTAMP || type.id() == LogicalTypeId::TIMESTAMP_TZ;
		bool use_date_format = has_dateformat && type.id() == LogicalTypeId::DATE;
		bool use_timestamp_format = has_timestampformat && is_timestamp;
		if (use_date_format || use_timestamp_format) {
			// strftime(<name>, 'format') - the two format maps are keyed by DATE/TIMESTAMP respectively
			auto format_type = use_date_format ? LogicalTypeId::DATE : LogicalTypeId::TIMESTAMP;
			vector<unique_ptr<ParsedExpression>> children;
			children.push_back(make_uniq<BoundExpression>(make_uniq<BoundReferenceExpression>(name, type, i)));
			children.push_back(make_uniq<ConstantExpression>(formats[format_type]));
			auto func = make_uniq_base<ParsedExpression, FunctionExpression>("strftime", std::move(children));
			unbound_expressions.push_back(std::move(func));
		} else {
			// CAST <name> AS VARCHAR
			auto column = make_uniq<BoundExpression>(make_uniq<BoundReferenceExpression>(name, type, i));
			auto expr = make_uniq_base<ParsedExpression, CastExpression>(LogicalType::VARCHAR, std::move(column));
			unbound_expressions.push_back(std::move(expr));
		}
	}

	// Create an ExpressionBinder, bind the Expressions into executable form
	vector<unique_ptr<Expression>> expressions;
	ExpressionBinder expression_binder(*binder, context);
	expression_binder.target_type = LogicalType::VARCHAR;
	for (auto &expr : unbound_expressions) {
		expressions.push_back(expression_binder.Bind(expr));
	}

	return expressions;
}

static unique_ptr<FunctionData> WriteCSVBind(ClientContext &context, CopyFunctionBindInput &input,
const vector<string> &names, const vector<LogicalType> &sql_types) {
auto bind_data = make_uniq<WriteCSVData>(input.info.file_path, sql_types, names);
Expand All @@ -110,6 +174,9 @@ static unique_ptr<FunctionData> WriteCSVBind(ClientContext &context, CopyFunctio
}
bind_data->Finalize();

auto expressions = CreateCastExpressions(*bind_data, context, names, sql_types);
bind_data->cast_expressions = std::move(expressions);

bind_data->requires_quotes = make_unsafe_uniq_array<bool>(256);
memset(bind_data->requires_quotes.get(), 0, sizeof(bool) * 256);
bind_data->requires_quotes['\n'] = true;
Expand Down Expand Up @@ -264,6 +331,14 @@ static void WriteQuotedString(WriteStream &writer, WriteCSVData &csv_data, const
// Sink
//===--------------------------------------------------------------------===//
struct LocalWriteCSVData : public LocalFunctionData {
public:
LocalWriteCSVData(ClientContext &context, vector<unique_ptr<Expression>> &expressions)
: executor(context, expressions) {
}

public:
//! Used to execute the expressions that transform input -> string
ExpressionExecutor executor;
//! The thread-local buffer to write data into
MemoryStream stream;
//! A chunk with VARCHAR columns to cast intermediates into
Expand Down Expand Up @@ -316,7 +391,7 @@ struct GlobalWriteCSVData : public GlobalFunctionData {

static unique_ptr<LocalFunctionData> WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) {
auto &csv_data = bind_data.Cast<WriteCSVData>();
auto local_data = make_uniq<LocalWriteCSVData>();
auto local_data = make_uniq<LocalWriteCSVData>(context.client, csv_data.cast_expressions);

// create the chunk with VARCHAR types
vector<LogicalType> types;
Expand Down Expand Up @@ -361,33 +436,16 @@ idx_t WriteCSVFileSize(GlobalFunctionData &gstate) {
}

static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_data, DataChunk &cast_chunk,
MemoryStream &writer, DataChunk &input, bool &written_anything) {
MemoryStream &writer, DataChunk &input, bool &written_anything,
ExpressionExecutor &executor) {
auto &csv_data = bind_data.Cast<WriteCSVData>();
auto &options = csv_data.options;

// first cast the columns of the chunk to varchar
cast_chunk.Reset();
cast_chunk.SetCardinality(input);
for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) {
if (csv_data.sql_types[col_idx].id() == LogicalTypeId::VARCHAR) {
// VARCHAR, just reinterpret (cannot reference, because LogicalTypeId::VARCHAR is used by the JSON type too)
cast_chunk.data[col_idx].Reinterpret(input.data[col_idx]);
} else if (!csv_data.options.write_date_format[LogicalTypeId::DATE].Empty() &&
csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) {
// use the date format to cast the chunk
csv_data.options.write_date_format[LogicalTypeId::DATE].ConvertDateVector(
input.data[col_idx], cast_chunk.data[col_idx], input.size());
} else if (!csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].Empty() &&
(csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP ||
csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP_TZ)) {
// use the timestamp format to cast the chunk
csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].ConvertTimestampVector(
input.data[col_idx], cast_chunk.data[col_idx], input.size());
} else {
// non varchar column, perform the cast
VectorOperations::Cast(context, input.data[col_idx], cast_chunk.data[col_idx], input.size());
}
}

executor.Execute(input, cast_chunk);

cast_chunk.Flatten();
// now loop over the vectors and output the values
Expand Down Expand Up @@ -428,7 +486,7 @@ static void WriteCSVSink(ExecutionContext &context, FunctionData &bind_data, Glo

// write data into the local buffer
WriteCSVChunkInternal(context.client, bind_data, local_data.cast_chunk, local_data.stream, input,
local_data.written_anything);
local_data.written_anything, local_data.executor);

// check if we should flush what we have currently written
auto &writer = local_data.stream;
Expand Down Expand Up @@ -506,11 +564,15 @@ unique_ptr<PreparedBatchData> WriteCSVPrepareBatch(ClientContext &context, Funct
DataChunk cast_chunk;
cast_chunk.Initialize(Allocator::Get(context), types);

auto &original_types = collection->Types();
auto expressions = CreateCastExpressions(csv_data, context, csv_data.options.name_list, original_types);
ExpressionExecutor executor(context, expressions);

// write CSV chunks to the batch data
bool written_anything = false;
auto batch = make_uniq<WriteCSVBatchData>();
for (auto &chunk : collection->Chunks()) {
WriteCSVChunkInternal(context, bind_data, cast_chunk, batch->stream, chunk, written_anything);
WriteCSVChunkInternal(context, bind_data, cast_chunk, batch->stream, chunk, written_anything, executor);
}
return std::move(batch);
}
Expand Down
6 changes: 3 additions & 3 deletions src/duckdb/src/function/table/version/pragma_version.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef DUCKDB_PATCH_VERSION
#define DUCKDB_PATCH_VERSION "3-dev145"
#define DUCKDB_PATCH_VERSION "3-dev158"
#endif
#ifndef DUCKDB_MINOR_VERSION
#define DUCKDB_MINOR_VERSION 10
Expand All @@ -8,10 +8,10 @@
#define DUCKDB_MAJOR_VERSION 0
#endif
#ifndef DUCKDB_VERSION
#define DUCKDB_VERSION "v0.10.3-dev145"
#define DUCKDB_VERSION "v0.10.3-dev158"
#endif
#ifndef DUCKDB_SOURCE_ID
#define DUCKDB_SOURCE_ID "1ffcc8bf8e"
#define DUCKDB_SOURCE_ID "d0d7f7fd09"
#endif
#include "duckdb/function/table/system_functions.hpp"
#include "duckdb/main/database.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ struct CSVReaderOptions {
//! The date format to use (if any is specified)
map<LogicalTypeId, StrpTimeFormat> date_format = {{LogicalTypeId::DATE, {}}, {LogicalTypeId::TIMESTAMP, {}}};
//! The date format to use for writing (if any is specified)
map<LogicalTypeId, StrfTimeFormat> write_date_format = {{LogicalTypeId::DATE, {}}, {LogicalTypeId::TIMESTAMP, {}}};
map<LogicalTypeId, Value> write_date_format = {{LogicalTypeId::DATE, Value()}, {LogicalTypeId::TIMESTAMP, Value()}};
//! Whether or not a type format is specified
map<LogicalTypeId, bool> has_format = {{LogicalTypeId::DATE, false}, {LogicalTypeId::TIMESTAMP, false}};

Expand Down
2 changes: 2 additions & 0 deletions src/duckdb/src/include/duckdb/function/table/read_csv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ struct WriteCSVData : public BaseCSVData {
idx_t flush_size = 4096ULL * 8ULL;
//! For each byte whether or not the CSV file requires quotes when containing the byte
unsafe_unique_array<bool> requires_quotes;
//! Expressions used to convert the input into strings
vector<unique_ptr<Expression>> cast_expressions;
};

struct ColumnInfo {
Expand Down

0 comments on commit bf6610f

Please sign in to comment.