Skip to content

Commit

Permalink
Changed csr to add edge_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
Dtenwolde committed Nov 15, 2022
1 parent f5e4efd commit cd51d51
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 19 deletions.
6 changes: 3 additions & 3 deletions extension/sqlpgq/sqlpgq_common.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "sqlpgq_common.hpp"

#include "duckdb/main/client_data.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/main/client_data.hpp"

#include <utility>

Expand Down Expand Up @@ -41,8 +41,8 @@ unique_ptr<FunctionData> CSRFunctionData::CSREdgeBind(ClientContext &context, Sc
throw InvalidInputException("Id must be constant.");
}
Value id = ExpressionExecutor::EvaluateScalar(context, *arguments[0]);
if (arguments.size() == 6) {
return make_unique<CSRFunctionData>(context, id.GetValue<int32_t>(), arguments[5]->return_type);
if (arguments.size() == 7) {
return make_unique<CSRFunctionData>(context, id.GetValue<int32_t>(), arguments[6]->return_type);
} else {
auto logical_type = LogicalType::SQLNULL;
return make_unique<CSRFunctionData>(context, id.GetValue<int32_t>(), logical_type);
Expand Down
40 changes: 25 additions & 15 deletions extension/sqlpgq/sqlpgq_functions/sqlpgq_csr_creation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,34 +127,39 @@ static void CreateCsrEdgeFunction(DataChunk &args, ExpressionState &state, Vecto
CsrInitializeEdge(info.context, info.id, vertex_size, edge_size);
}
if (info.weight_type == LogicalType::SQLNULL) {
BinaryExecutor::Execute<int64_t, int64_t, int32_t>(args.data[3], args.data[4], result, args.size(),
[&](int64_t src, int64_t dst) {
auto pos = ++csr_entry->second->v[src + 1];
csr_entry->second->e[(int64_t)pos - 1] = dst;
return 1;
});
TernaryExecutor::Execute<int64_t, int64_t, int64_t, int32_t>(
args.data[3], args.data[4], args.data[5], result, args.size(),
[&](int64_t src, int64_t dst, int64_t edge_id) {
auto pos = ++csr_entry->second->v[src + 1];
csr_entry->second->e[(int64_t)pos - 1] = dst;
csr_entry->second->edge_ids[(int64_t)pos - 1] = edge_id;
return 1;
});
return;
}
auto weight_type = args.data[5].GetType().InternalType();
if (!csr_entry->second->initialized_w) {
CsrInitializeWeight(info.context, info.id, edge_size, args.data[5].GetType().InternalType());
}
if (weight_type == PhysicalType::INT64) {
TernaryExecutor::Execute<int64_t, int64_t, int64_t, int32_t>(
args.data[3], args.data[4], args.data[5], result, args.size(),
[&](int64_t src, int64_t dst, int64_t weight) {
QuaternaryExecutor::Execute<int64_t, int64_t, int64_t, int64_t, int32_t>(
args.data[3], args.data[4], args.data[5], args.data[6], result, args.size(),
[&](int64_t src, int64_t dst, int64_t edge_id, int64_t weight) {
auto pos = ++csr_entry->second->v[src + 1];
csr_entry->second->e[(int64_t)pos - 1] = dst;
csr_entry->second->edge_ids[(int64_t)pos - 1] = edge_id;
csr_entry->second->w[(int64_t)pos - 1] = weight;
return weight;
});
return;
}

TernaryExecutor::Execute<int64_t, int64_t, double_t, int32_t>(
args.data[3], args.data[4], args.data[5], result, args.size(), [&](int64_t src, int64_t dst, double_t weight) {
QuaternaryExecutor::Execute<int64_t, int64_t, int64_t, double_t, int32_t>(
args.data[3], args.data[4], args.data[5], args.data[6], result, args.size(),
[&](int64_t src, int64_t dst, int64_t edge_id, double_t weight) {
auto pos = ++csr_entry->second->v[src + 1];
csr_entry->second->e[(int64_t)pos - 1] = dst;
csr_entry->second->edge_ids[(int64_t)pos - 1] = edge_id;
csr_entry->second->w_double[(int64_t)pos - 1] = weight;
return weight;
});
Expand All @@ -174,14 +179,19 @@ CreateScalarFunctionInfo SQLPGQFunctions::GetCsrVertexFunction() {

CreateScalarFunctionInfo SQLPGQFunctions::GetCsrEdgeFunction() {
ScalarFunctionSet set("create_csr_edge");
set.AddFunction(ScalarFunction(
{LogicalType::INTEGER, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT},
LogicalType::INTEGER, CreateCsrEdgeFunction, CSRFunctionData::CSREdgeBind));
//! No edge weight
set.AddFunction(ScalarFunction({LogicalType::INTEGER, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT,
LogicalType::BIGINT, LogicalType::BIGINT},
LogicalType::INTEGER, CreateCsrEdgeFunction, CSRFunctionData::CSREdgeBind));

//! Integer for edge weight
set.AddFunction(ScalarFunction({LogicalType::INTEGER, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT,
LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT},
LogicalType::INTEGER, CreateCsrEdgeFunction, CSRFunctionData::CSREdgeBind));

//! Double for edge weight
set.AddFunction(ScalarFunction({LogicalType::INTEGER, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT,
LogicalType::BIGINT, LogicalType::DOUBLE},
LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE},
LogicalType::INTEGER, CreateCsrEdgeFunction, CSRFunctionData::CSREdgeBind));

return CreateScalarFunctionInfo(set);
Expand Down
2 changes: 1 addition & 1 deletion test/sql/function/sql-pgq/test_create_csr.test
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ FROM (
LEFT JOIN Transfers t ON t.from_id = c.cid
GROUP BY c.rowid
) sub) AS BIGINT),
src.rowid, dst.rowid, t.amount))
src.rowid, dst.rowid, t.rowid, t.amount))
FROM
Transfers t
JOIN Customer src ON t.from_id = src.cid
Expand Down

0 comments on commit cd51d51

Please sign in to comment.