Skip to content

Commit

Permalink
use ExpressionExecutor, enables overriding behavior through the Catalog (such as ICU)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishj committed Apr 16, 2024
1 parent 50580b2 commit 44a0e68
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
82 changes: 74 additions & 8 deletions src/function/table/copy_csv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,65 @@ static unique_ptr<FunctionData> WriteCSVBind(ClientContext &context, CopyFunctio
}
bind_data->Finalize();

auto &options = csv_data->options;
auto &formats = options.write_date_format;

bool has_dateformat = !formats[LogicalTypeId::DATE].Empty();
bool has_timestampformat = !formats[LogicalTypeId::TIMESTAMP].Empty();

// Create a binder
auto binder = Binder::CreateBinder(context);

// Create a Binding, used by the ExpressionBinder to turn our columns into BoundReferenceExpressions
auto &bind_context = binder->bind_context;
auto table_index = binder->GenerateTableIndex();
bind_context.AddGenericBinding(table_index, "copy_csv", names, sql_types);

// Create the ParsedExpressions (cast, strftime, etc..)
vector<unique_ptr<ParsedExpression>> unbound_expressions;
// Build one unbound (parsed) expression per output column that renders the column as a string:
//  - DATE columns use strftime(...) when a date format was configured
//  - TIMESTAMP / TIMESTAMP_TZ columns use strftime(...) when a timestamp format was configured
//  - every other column gets a plain CAST to VARCHAR
// The expressions reference columns by name; they are bound further down against the generic
// binding registered on the bind context.
for (idx_t i = 0; i < sql_types.size(); i++) {
	auto &type = sql_types[i];
	auto &name = names[i];

	bool is_timestamp = type.id() == LogicalTypeId::TIMESTAMP || type.id() == LogicalTypeId::TIMESTAMP_TZ;
	if (has_dateformat && type.id() == LogicalTypeId::DATE) {
		// strftime(<name>, 'format')
		vector<unique_ptr<ParsedExpression>> children;
		children.push_back(make_uniq<ColumnRefExpression>(name));
		// TODO: set from user-provided format (the literal below is a hard-coded placeholder)
		children.push_back(make_uniq<ConstantExpression>("%m/%d/%Y, %-I:%-M %p"));
		auto func = make_uniq<FunctionExpression>("strftime", std::move(children));
		// Push the function expression we just built (the variable is 'func'; no 'expr' exists in this scope)
		unbound_expressions.push_back(std::move(func));
	} else if (has_timestampformat && is_timestamp) {
		// strftime(<name>, 'format')
		vector<unique_ptr<ParsedExpression>> children;
		children.push_back(make_uniq<ColumnRefExpression>(name));
		// TODO: set from user-provided format (the literal below is a hard-coded placeholder)
		children.push_back(make_uniq<ConstantExpression>("%Y-%m-%dT%H:%M:%S.%fZ"));
		auto func = make_uniq<FunctionExpression>("strftime", std::move(children));
		// Push the function expression we just built (the variable is 'func'; no 'expr' exists in this scope)
		unbound_expressions.push_back(std::move(func));
	} else {
		// CAST <name> AS VARCHAR
		auto column = make_uniq<ColumnRefExpression>(name);
		auto expr = make_uniq<CastExpression>(LogicalType::VARCHAR, std::move(column));
		unbound_expressions.push_back(std::move(expr));
	}
}

// Create an ExpressionBinder, bind the Expressions
vector<unique_ptr<Expression>> expressions;
ExpressionBinder expression_binder(*binder, context);
expression_binder.target_type = LogicalType::VARCHAR;
for (auto &expr : unbound_expressions) {
expressions.push_back(expression_binder.Bind(expr));
}

bind_data->cast_expressions = std::move(expressions);

// Move these into the WriteCSVData
// In 'WriteCSVInitializeLocal' we'll create an ExpressionExecutor, fed our expressions
// In 'WriteCSVChunkInternal' we use this expression executor to convert our input columns to VARCHAR

bind_data->requires_quotes = make_unsafe_uniq_array<bool>(256);
memset(bind_data->requires_quotes.get(), 0, sizeof(bool) * 256);
bind_data->requires_quotes['\n'] = true;
Expand Down Expand Up @@ -262,6 +321,14 @@ static void WriteQuotedString(WriteStream &writer, WriteCSVData &csv_data, const
// Sink
//===--------------------------------------------------------------------===//
struct LocalWriteCSVData : public LocalFunctionData {
public:
LocalWriteCSVData(ClientContext &context, vector<unique_ptr<Expression>> &expressions)
: executor(context, expressions) {
}

public:
//! Used to execute the expressions that transform input -> string
ExpressionExecutor executor;
//! The thread-local buffer to write data into
MemoryStream stream;
//! A chunk with VARCHAR columns to cast intermediates into
Expand Down Expand Up @@ -314,7 +381,7 @@ struct GlobalWriteCSVData : public GlobalFunctionData {

static unique_ptr<LocalFunctionData> WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) {
auto &csv_data = bind_data.Cast<WriteCSVData>();
auto local_data = make_uniq<LocalWriteCSVData>();
auto local_data = make_uniq<LocalWriteCSVData>(context.client, csv_data.cast_expressions);

// create the chunk with VARCHAR types
vector<LogicalType> types;
Expand Down Expand Up @@ -362,6 +429,7 @@ static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_dat
MemoryStream &writer, DataChunk &input, bool &written_anything) {
auto &csv_data = bind_data.Cast<WriteCSVData>();
auto &options = csv_data.options;
auto &formats = options.write_date_format;

// first cast the columns of the chunk to varchar
cast_chunk.Reset();
Expand All @@ -370,17 +438,15 @@ static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_dat
if (csv_data.sql_types[col_idx].id() == LogicalTypeId::VARCHAR) {
// VARCHAR, just reinterpret (cannot reference, because LogicalTypeId::VARCHAR is used by the JSON type too)
cast_chunk.data[col_idx].Reinterpret(input.data[col_idx]);
} else if (!csv_data.options.write_date_format[LogicalTypeId::DATE].Empty() &&
csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) {
} else if (!formats[LogicalTypeId::DATE].Empty() && csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) {
// use the date format to cast the chunk
csv_data.options.write_date_format[LogicalTypeId::DATE].ConvertDateVector(
input.data[col_idx], cast_chunk.data[col_idx], input.size());
} else if (!csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].Empty() &&
formats[LogicalTypeId::DATE].ConvertDateVector(input.data[col_idx], cast_chunk.data[col_idx], input.size());
} else if (!formats[LogicalTypeId::TIMESTAMP].Empty() &&
(csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP ||
csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP_TZ)) {
// use the timestamp format to cast the chunk
csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].ConvertTimestampVector(
input.data[col_idx], cast_chunk.data[col_idx], input.size());
formats[LogicalTypeId::TIMESTAMP].ConvertTimestampVector(input.data[col_idx], cast_chunk.data[col_idx],
input.size());
} else {
// non varchar column, perform the cast
VectorOperations::Cast(context, input.data[col_idx], cast_chunk.data[col_idx], input.size());
Expand Down
2 changes: 2 additions & 0 deletions src/include/duckdb/function/table/read_csv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ struct WriteCSVData : public BaseCSVData {
idx_t flush_size = 4096ULL * 8ULL;
//! For each byte whether or not the CSV file requires quotes when containing the byte
unsafe_unique_array<bool> requires_quotes;
//! Expressions used to convert the input into strings
vector<unique_ptr<Expression>> cast_expressions;
};

struct ColumnInfo {
Expand Down

0 comments on commit 44a0e68

Please sign in to comment.