Skip to content

Commit

Permalink
Implement statistics for categorical columns (#3373)
Browse files Browse the repository at this point in the history
Implement most of the statistical functions for categorical columns. Once we implement sorting of categorical columns, we could add the missing `dt.mode()`.

WIP for #1691
  • Loading branch information
oleksiyskononenko authored and samukweku committed Jan 3, 2023
1 parent ad728aa commit a1dd57c
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 70 deletions.
9 changes: 2 additions & 7 deletions src/core/column.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//------------------------------------------------------------------------------
// Copyright 2018-2021 H2O.ai
// Copyright 2018-2022 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -173,12 +173,7 @@ dt::SType Column::stype() const noexcept {
}

dt::SType Column::data_stype() const noexcept {
if (impl_->type_.is_categorical()) {
if (n_children()) return child(0).stype();
else return dt::SType::VOID;
} else {
return stype();
}
return impl_->data_stype();
}

dt::LType Column::ltype() const noexcept {
Expand Down
8 changes: 8 additions & 0 deletions src/core/column/column_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ size_t ColumnImpl::null_count() const {
}


SType ColumnImpl::data_stype() const {
if (type_.is_categorical()) {
return n_children()? child(0).stype()
: SType::VOID;
}
return stype();
}



} // namespace dt
5 changes: 3 additions & 2 deletions src/core/column/column_impl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//------------------------------------------------------------------------------
// Copyright 2018-2021 H2O.ai
// Copyright 2018-2022 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -93,7 +93,8 @@ class ColumnImpl
//------------------------------------
public:
size_t nrows() const noexcept { return nrows_; }
SType stype() const { return type_.stype(); }
SType stype() const { return type_.stype(); }
SType data_stype() const;
const Type& type() const { return type_; }
virtual bool is_virtual() const noexcept = 0;
virtual bool computationally_expensive() const { return false; }
Expand Down
16 changes: 8 additions & 8 deletions src/core/stats.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//------------------------------------------------------------------------------
// Copyright 2018-2021 H2O.ai
// Copyright 2018-2022 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -1165,7 +1165,7 @@ size_t VoidStats::nmodal(bool* isvalid) {

static std::unique_ptr<Stats> _make_stats(const dt::ColumnImpl* col) {
using StatsPtr = std::unique_ptr<Stats>;
switch (col->stype()) {
switch (col->data_stype()) {
case dt::SType::VOID: return StatsPtr(new VoidStats(col));
case dt::SType::BOOL: return StatsPtr(new BooleanStats(col));
case dt::SType::INT8: return StatsPtr(new IntegerStats<int8_t>(col));
Expand Down Expand Up @@ -1383,7 +1383,7 @@ py::oobj Stats::get_stat_as_pyobject(Stat stat) {
return pywrap_stat<double>(stat);
}
case Stat::Sum: {
switch (dt::stype_to_ltype(column->stype())) {
switch (dt::stype_to_ltype(column->data_stype())) {
case dt::LType::MU:
case dt::LType::BOOL:
case dt::LType::INT: return pywrap_stat<int64_t>(stat);
Expand All @@ -1394,7 +1394,7 @@ py::oobj Stats::get_stat_as_pyobject(Stat stat) {
case Stat::Min:
case Stat::Max:
case Stat::Mode: {
switch (dt::stype_to_ltype(column->stype())) {
switch (dt::stype_to_ltype(column->data_stype())) {
case dt::LType::BOOL:
case dt::LType::INT: return pywrap_stat<int64_t>(stat);
case dt::LType::REAL: return pywrap_stat<double>(stat);
Expand Down Expand Up @@ -1469,7 +1469,7 @@ Column Stats::strcolwrap_stat(Stat stat) {
dt::CString value;
bool isvalid = get_stat(stat, &value);
return isvalid? _make_column_str(value)
: _make_nacol(column->stype());
: _make_nacol(column->data_stype());
}


Expand All @@ -1491,7 +1491,7 @@ Column Stats::get_stat_as_column(Stat stat) {
return colwrap_stat<double, double>(stat, dt::SType::FLOAT64);
}
case Stat::Sum: {
switch (column->stype()) {
switch (column->data_stype()) {
case dt::SType::VOID: return colwrap_stat<int64_t, int64_t>(stat, dt::SType::INT64);
case dt::SType::BOOL: return colwrap_stat<int64_t, int64_t>(stat, dt::SType::INT64);
case dt::SType::INT8: return colwrap_stat<int64_t, int64_t>(stat, dt::SType::INT64);
Expand All @@ -1506,7 +1506,7 @@ Column Stats::get_stat_as_column(Stat stat) {
case Stat::Min:
case Stat::Max:
case Stat::Mode: {
switch (column->stype()) {
switch (column->data_stype()) {
case dt::SType::BOOL: return colwrap_stat<int64_t, int8_t>(stat, dt::SType::BOOL);
case dt::SType::INT8: return colwrap_stat<int64_t, int8_t>(stat, dt::SType::INT8);
case dt::SType::INT16: return colwrap_stat<int64_t, int16_t>(stat, dt::SType::INT16);
Expand All @@ -1518,7 +1518,7 @@ Column Stats::get_stat_as_column(Stat stat) {
case dt::SType::STR64: return strcolwrap_stat(stat);
case dt::SType::DATE32: return colwrap_stat<int64_t, int32_t>(stat, dt::SType::DATE32);
case dt::SType::TIME64: return colwrap_stat<int64_t, int64_t>(stat, dt::SType::TIME64);
default: return _make_nacol(column->stype());
default: return _make_nacol(column->data_stype());
}
}
default:
Expand Down
Loading

0 comments on commit a1dd57c

Please sign in to comment.