Merging arrayAUC #9815

Merged
merged 25 commits into from
Mar 24, 2020
146 changes: 146 additions & 0 deletions dbms/src/Functions/array/arrayAUC.cpp
@@ -0,0 +1,146 @@
#include <algorithm>
#include <vector>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include "arrayScalarProduct.h"


namespace DB
{

namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}


/** The function takes two arrays: scores and labels.
* Label can be one of two values: positive and negative.
* Score can be arbitrary number.
*
* The scores are treated as the output of a classifier, and the labels are the true labels of the objects.
* The classifier assigns scores to objects, and these scores predict the labels in the following way:
* - we can choose an arbitrary threshold on the score and predict that the label is positive if the score is greater than the threshold:
*
* f(object) = score
* predicted_label = score > threshold
*
* This way the classifier may predict a positive or negative value correctly (true positive or true negative)
* or produce a false positive or false negative result.
* Varying the threshold, we can get different probabilities of false positives, false negatives, true positives, etc.
*
* We can also calculate the True Positive Rate and the False Positive Rate:
*
* TPR (also called "sensitivity", "recall" or "probability of detection")
* is the probability that the classifier gives a positive result if the object has a positive label:
* TPR = P(score > threshold | label = positive)
*
* FPR is the probability that the classifier gives a positive result if the object has a negative label:
* FPR = P(score > threshold | label = negative)
*
* We can draw a curve of FPR and TPR values for different thresholds on the [0..1] x [0..1] unit square.
* This curve is named "ROC curve" (Receiver Operating Characteristic).
*
* For the ROC curve we can calculate, literally, the Area Under the Curve, which will be in the range [0..1].
* The higher the AUC, the better the classifier.
*
* AUC is also equal to the probability that the score of a randomly chosen positive object is greater than the score of a randomly chosen negative object.
*
* https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc
* https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
*
* To calculate AUC, we will draw points of (FPR, TPR) for different thresholds = score_i.
* FPR_raw = countIf(score > score_i, label = negative) = count negative labels above certain score
* TPR_raw = countIf(score > score_i, label = positive) = count positive labels above certain score
*
* Let's look at the example:
* arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
*
* 1. We have pairs: (-, 0.1), (-, 0.4), (+, 0.35), (+, 0.8)
*
* 2. Let's sort by score: (-, 0.1), (+, 0.35), (-, 0.4), (+, 0.8)
*
* 3. Let's draw the points:
*
* threshold = 0, TPR = 1, FPR = 1, TPR_raw = 2, FPR_raw = 2
* threshold = 0.1, TPR = 1, FPR = 0.5, TPR_raw = 2, FPR_raw = 1
* threshold = 0.35, TPR = 0.5, FPR = 0.5, TPR_raw = 1, FPR_raw = 1
* threshold = 0.4, TPR = 0.5, FPR = 0, TPR_raw = 1, FPR_raw = 0
* threshold = 0.8, TPR = 0, FPR = 0, TPR_raw = 0, FPR_raw = 0
*
* The "curve" will be represented by a line that moves one step either to the right or to the top on each threshold change.
*/


struct NameArrayAUC
{
static constexpr auto name = "arrayAUC";
};


class ArrayAUCImpl
{
public:
using ResultType = Float64;

static DataTypePtr getReturnType(const DataTypePtr & /* score_type */, const DataTypePtr & label_type)
{
if (!(isNumber(label_type) || isEnum(label_type)))
throw Exception(std::string(NameArrayAUC::name) + " label must have numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

return std::make_shared<DataTypeNumber<ResultType>>();
}

template <typename T, typename U>
static ResultType apply(
const T * scores,
const U * labels,
size_t size)
{
struct ScoreLabel
{
T score;
bool label;
};

PODArrayWithStackMemory<ScoreLabel, 1024> sorted_labels(size);

for (size_t i = 0; i < size; ++i)
{
bool label = labels[i] > 0;
sorted_labels[i].score = scores[i];
sorted_labels[i].label = label;
}

std::sort(sorted_labels.begin(), sorted_labels.end(), [](const auto & lhs, const auto & rhs) { return lhs.score > rhs.score; });

/// We will first calculate non-normalized area.

size_t area = 0;
size_t count_positive = 0;
for (size_t i = 0; i < size; ++i)
{
if (sorted_labels[i].label)
++count_positive; /// The curve moves one step up. No area increase.
else
area += count_positive; /// The curve moves one step right. Area is increased by 1 * height = count_positive.
}

/// Then divide the area by the area of the rectangle.

if (count_positive == 0 || count_positive == size)
return std::numeric_limits<ResultType>::quiet_NaN();

return ResultType(area) / count_positive / (size - count_positive);
}
};


/// arrayAUC(array_score, array_label) - Calculate AUC from arrays of scores and labels
using FunctionArrayAUC = FunctionArrayScalarProduct<ArrayAUCImpl, NameArrayAUC>;

void registerFunctionArrayAUC(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayAUC>();
}
}
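The step-counting loop in `ArrayAUCImpl::apply` can be tried outside ClickHouse. The following is a minimal sketch, assuming plain `std::vector` inputs instead of `PODArrayWithStackMemory`; the name `aucSketch` is hypothetical and not part of this PR. It mirrors the sort-descending and area-accumulation steps shown above.

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical standalone helper (not part of the PR). It follows the same
// steps as the loop in ArrayAUCImpl::apply, using std::vector for storage.
double aucSketch(const std::vector<double> & scores, const std::vector<int> & labels)
{
    // Assumption: mismatched lengths are truncated here; the real function
    // rejects arrays of unequal size before apply() is called.
    const std::size_t size = std::min(scores.size(), labels.size());

    // Pair each score with a boolean label (positive if label > 0).
    std::vector<std::pair<double, bool>> sorted;
    sorted.reserve(size);
    for (std::size_t i = 0; i < size; ++i)
        sorted.emplace_back(scores[i], labels[i] > 0);

    // Highest score first, so the curve is walked from (0, 0) towards (1, 1).
    std::sort(sorted.begin(), sorted.end(),
        [](const std::pair<double, bool> & lhs, const std::pair<double, bool> & rhs)
        { return lhs.first > rhs.first; });

    std::size_t area = 0;
    std::size_t count_positive = 0;
    for (const auto & item : sorted)
    {
        if (item.second)
            ++count_positive;       // One step up, no area added.
        else
            area += count_positive; // One step right, adds a column of height count_positive.
    }

    // With only one class present the rectangle has zero area, so AUC is undefined.
    if (count_positive == 0 || count_positive == size)
        return std::numeric_limits<double>::quiet_NaN();

    return static_cast<double>(area) / count_positive / (size - count_positive);
}
```

For the example from the comment, `aucSketch({0.1, 0.4, 0.35, 0.8}, {0, 0, 1, 1})` accumulates a raw area of 3 over a 2 x 2 rectangle and returns 0.75.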
145 changes: 145 additions & 0 deletions dbms/src/Functions/array/arrayScalarProduct.h
@@ -0,0 +1,145 @@
#pragma once

#include <Columns/ColumnArray.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypeArray.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>


namespace DB
{

class Context;

namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int BAD_ARGUMENTS;
}


template <typename Method, typename Name>
class FunctionArrayScalarProduct : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context &) { return std::make_shared<FunctionArrayScalarProduct>(); }

private:
using ResultColumnType = ColumnVector<typename Method::ResultType>;

template <typename T>
bool executeNumber(Block & block, const ColumnNumbers & arguments, size_t result)
{
return executeNumberNumber<T, UInt8>(block, arguments, result)
|| executeNumberNumber<T, UInt16>(block, arguments, result)
|| executeNumberNumber<T, UInt32>(block, arguments, result)
|| executeNumberNumber<T, UInt64>(block, arguments, result)
|| executeNumberNumber<T, Int8>(block, arguments, result)
|| executeNumberNumber<T, Int16>(block, arguments, result)
|| executeNumberNumber<T, Int32>(block, arguments, result)
|| executeNumberNumber<T, Int64>(block, arguments, result)
|| executeNumberNumber<T, Float32>(block, arguments, result)
|| executeNumberNumber<T, Float64>(block, arguments, result);
}


template <typename T, typename U>
bool executeNumberNumber(Block & block, const ColumnNumbers & arguments, size_t result)
{
ColumnPtr col1 = block.getByPosition(arguments[0]).column->convertToFullColumnIfConst();
ColumnPtr col2 = block.getByPosition(arguments[1]).column->convertToFullColumnIfConst();
if (!col1 || !col2)
return false;

const ColumnArray * col_array1 = checkAndGetColumn<ColumnArray>(col1.get());
const ColumnArray * col_array2 = checkAndGetColumn<ColumnArray>(col2.get());
if (!col_array1 || !col_array2)
return false;

if (!col_array1->hasEqualOffsets(*col_array2))
throw Exception("Array arguments for function " + getName() + " must have equal sizes", ErrorCodes::BAD_ARGUMENTS);

const ColumnVector<T> * col_nested1 = checkAndGetColumn<ColumnVector<T>>(col_array1->getData());
const ColumnVector<U> * col_nested2 = checkAndGetColumn<ColumnVector<U>>(col_array2->getData());
if (!col_nested1 || !col_nested2)
return false;

auto col_res = ResultColumnType::create();

vector(
col_nested1->getData(),
col_nested2->getData(),
col_array1->getOffsets(),
col_res->getData());

block.getByPosition(result).column = std::move(col_res);
return true;
}

template <typename T, typename U>
static NO_INLINE void vector(
const PaddedPODArray<T> & data1,
const PaddedPODArray<U> & data2,
const ColumnArray::Offsets & offsets,
PaddedPODArray<typename Method::ResultType> & result)
{
size_t size = offsets.size();
result.resize(size);

ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < size; ++i)
{
size_t array_size = offsets[i] - current_offset;
result[i] = Method::apply(&data1[current_offset], &data2[current_offset], array_size);
current_offset = offsets[i];
}
}

public:
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }

DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
// Basic type check
std::vector<DataTypePtr> nested_types(2, nullptr);
for (size_t i = 0; i < getNumberOfArguments(); ++i)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
if (!array_type)
throw Exception("All arguments for function " + getName() + " must be an array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

auto & nested_type = array_type->getNestedType();
if (!isNativeNumber(nested_type) && !isEnum(nested_type))
throw Exception(
getName() + " cannot process values of type " + nested_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
nested_types[i] = nested_type;
}

// The detailed type check is done in Method, which also returns the result type
return Method::getReturnType(nested_types[0], nested_types[1]);
}

void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /* input_rows_count */) override
{
if (!(executeNumber<UInt8>(block, arguments, result)
|| executeNumber<UInt16>(block, arguments, result)
|| executeNumber<UInt32>(block, arguments, result)
|| executeNumber<UInt64>(block, arguments, result)
|| executeNumber<Int8>(block, arguments, result)
|| executeNumber<Int16>(block, arguments, result)
|| executeNumber<Int32>(block, arguments, result)
|| executeNumber<Int64>(block, arguments, result)
|| executeNumber<Float32>(block, arguments, result)
|| executeNumber<Float64>(block, arguments, result)))
throw Exception{"Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function "
+ getName(),
ErrorCodes::ILLEGAL_COLUMN};
}
};

}
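The `vector` method above walks a flattened array column: all rows' elements are stored contiguously, and `offsets[i]` holds the end position of row `i`. The sketch below illustrates that iteration pattern under simplifying assumptions: `std::vector` stands in for `PaddedPODArray` and `ColumnArray::Offsets`, and both `applyPerArray` and `dotProduct` are hypothetical names used only for this illustration.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for FunctionArrayScalarProduct::vector.
// offsets[i] is the cumulative end position of the i-th array in the flat data.
template <typename T, typename U, typename Apply>
std::vector<double> applyPerArray(
    const std::vector<T> & data1,
    const std::vector<U> & data2,
    const std::vector<std::size_t> & offsets,
    Apply apply)
{
    std::vector<double> result(offsets.size());

    std::size_t current_offset = 0;
    for (std::size_t i = 0; i < offsets.size(); ++i)
    {
        // The i-th array occupies [current_offset, offsets[i]) in both flat buffers.
        const std::size_t array_size = offsets[i] - current_offset;
        result[i] = apply(data1.data() + current_offset, data2.data() + current_offset, array_size);
        current_offset = offsets[i];
    }
    return result;
}

// Toy per-row method used only for demonstration: a plain dot product.
inline double dotProduct(const double * a, const int * b, std::size_t size)
{
    double sum = 0;
    for (std::size_t i = 0; i < size; ++i)
        sum += a[i] * b[i];
    return sum;
}
```

With flat data `{1, 2, 3, 4, 5}` and `{1, 1, 2, 0, 1}` and offsets `{2, 5}`, the first row covers two elements and the second covers three, so `applyPerArray(d1, d2, offsets, dotProduct)` yields 3 and 11. Sharing one offsets array between both columns is only valid because the function checks `hasEqualOffsets` first.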

65 changes: 32 additions & 33 deletions dbms/src/Functions/array/registerFunctionsArray.cpp
@@ -1,39 +1,38 @@
namespace DB
{

class FunctionFactory;

void registerFunctionArray(FunctionFactory & factory);
void registerFunctionArrayElement(FunctionFactory & factory);
void registerFunctionArrayResize(FunctionFactory & factory);
void registerFunctionHas(FunctionFactory & factory);
void registerFunctionHasAll(FunctionFactory & factory);
void registerFunctionHasAny(FunctionFactory & factory);
void registerFunctionIndexOf(FunctionFactory & factory);
void registerFunctionCountEqual(FunctionFactory & factory);
void registerFunctionArrayIntersect(FunctionFactory & factory);
void registerFunctionArrayPushFront(FunctionFactory & factory);
void registerFunctionArrayPushBack(FunctionFactory & factory);
void registerFunctionArrayPopFront(FunctionFactory & factory);
void registerFunctionArrayPopBack(FunctionFactory & factory);
void registerFunctionArrayConcat(FunctionFactory & factory);
void registerFunctionArraySlice(FunctionFactory & factory);
void registerFunctionArrayReverse(FunctionFactory & factory);
void registerFunctionArrayReduce(FunctionFactory & factory);
void registerFunctionRange(FunctionFactory & factory);
void registerFunctionsEmptyArray(FunctionFactory & factory);
void registerFunctionEmptyArrayToSingle(FunctionFactory & factory);
void registerFunctionArrayEnumerate(FunctionFactory & factory);
void registerFunctionArrayEnumerateUniq(FunctionFactory & factory);
void registerFunctionArrayEnumerateDense(FunctionFactory & factory);
void registerFunctionArrayEnumerateUniqRanked(FunctionFactory & factory);
void registerFunctionArrayEnumerateDenseRanked(FunctionFactory & factory);
void registerFunctionArrayUniq(FunctionFactory & factory);
void registerFunctionArrayDistinct(FunctionFactory & factory);
void registerFunctionArrayFlatten(FunctionFactory & factory);
void registerFunctionArrayWithConstant(FunctionFactory & factory);
void registerFunctionArrayZip(FunctionFactory & factory);

void registerFunctionArray(FunctionFactory &);
void registerFunctionArrayElement(FunctionFactory &);
void registerFunctionArrayResize(FunctionFactory &);
void registerFunctionHas(FunctionFactory &);
void registerFunctionHasAll(FunctionFactory &);
void registerFunctionHasAny(FunctionFactory &);
void registerFunctionIndexOf(FunctionFactory &);
void registerFunctionCountEqual(FunctionFactory &);
void registerFunctionArrayIntersect(FunctionFactory &);
void registerFunctionArrayPushFront(FunctionFactory &);
void registerFunctionArrayPushBack(FunctionFactory &);
void registerFunctionArrayPopFront(FunctionFactory &);
void registerFunctionArrayPopBack(FunctionFactory &);
void registerFunctionArrayConcat(FunctionFactory &);
void registerFunctionArraySlice(FunctionFactory &);
void registerFunctionArrayReverse(FunctionFactory &);
void registerFunctionArrayReduce(FunctionFactory &);
void registerFunctionRange(FunctionFactory &);
void registerFunctionsEmptyArray(FunctionFactory &);
void registerFunctionEmptyArrayToSingle(FunctionFactory &);
void registerFunctionArrayEnumerate(FunctionFactory &);
void registerFunctionArrayEnumerateUniq(FunctionFactory &);
void registerFunctionArrayEnumerateDense(FunctionFactory &);
void registerFunctionArrayEnumerateUniqRanked(FunctionFactory &);
void registerFunctionArrayEnumerateDenseRanked(FunctionFactory &);
void registerFunctionArrayUniq(FunctionFactory &);
void registerFunctionArrayDistinct(FunctionFactory &);
void registerFunctionArrayFlatten(FunctionFactory &);
void registerFunctionArrayWithConstant(FunctionFactory &);
void registerFunctionArrayZip(FunctionFactory &);
void registerFunctionArrayAUC(FunctionFactory &);

void registerFunctionsArray(FunctionFactory & factory)
{
@@ -67,7 +66,7 @@ void registerFunctionsArray(FunctionFactory & factory)
registerFunctionArrayFlatten(factory);
registerFunctionArrayWithConstant(factory);
registerFunctionArrayZip(factory);
registerFunctionArrayAUC(factory);
}

}

9 changes: 9 additions & 0 deletions dbms/tests/performance/array_auc.xml
@@ -0,0 +1,9 @@
<test>
<stop_conditions>
<all_of>
<total_time_ms>10000</total_time_ms>
</all_of>
</stop_conditions>

<query>SELECT avg(ifNotFinite(arrayAUC(arrayMap(x -> rand(x) / 0x100000000, range(2 + rand() % 100)), arrayMap(x -> rand(x) % 2, range(2 + rand() % 100))), 0)) FROM numbers(100000)</query>
</test>
16 changes: 16 additions & 0 deletions dbms/tests/queries/0_stateless/01064_array_auc.reference
@@ -0,0 +1,16 @@
0.75
0.75
0.75
0.75
0.75
0.75
0.75
0.75
0.75
0.25
0.25
0.25
0.25
0.25
0.125
0.25
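The queries that produce these reference values are not shown in this diff. One way to cross-check AUC values independently is the pairwise interpretation mentioned in the `arrayAUC.cpp` comment: the fraction of (positive, negative) pairs in which the positive object has the higher score. The sketch below is a hypothetical helper (`pairwiseAucSketch` is not part of the PR), and it assumes there are no tied scores, since tie handling is not covered here.

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical cross-check helper (not part of the PR): O(n^2) pairwise AUC.
// Assumption: scores contain no ties between positive and negative objects.
double pairwiseAucSketch(const std::vector<double> & scores, const std::vector<int> & labels)
{
    std::size_t pairs = 0;
    std::size_t wins = 0;

    for (std::size_t i = 0; i < scores.size() && i < labels.size(); ++i)
    {
        if (labels[i] <= 0)
            continue; // i must be a positive object.

        for (std::size_t j = 0; j < scores.size() && j < labels.size(); ++j)
        {
            if (labels[j] > 0)
                continue; // j must be a negative object.

            ++pairs;
            if (scores[i] > scores[j])
                ++wins; // The positive object outranks the negative one.
        }
    }

    // No (positive, negative) pair exists when only one class is present.
    if (pairs == 0)
        return std::numeric_limits<double>::quiet_NaN();

    return static_cast<double>(wins) / pairs;
}
```

For the illustrative input `{0.1, 0.4, 0.35, 0.8}` with labels `{0, 0, 1, 1}`, three of the four pairs are ranked correctly, giving 0.75; swapping the labels to `{1, 1, 0, 0}` gives 0.25.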