Skip to content

Commit

Permalink
Implement casting of bool/int/float/str to categoricals (#3365)
Browse files Browse the repository at this point in the history
Implement casting of boolean, integer, float, date, time and string columns to categoricals.

WIP for #1691
  • Loading branch information
oleksiyskononenko committed Oct 4, 2022
1 parent 8219cea commit b4866f5
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 22 deletions.
29 changes: 23 additions & 6 deletions src/core/types/type_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,35 @@ size_t Type_Cat::hash() const noexcept {
*
* The following type casts are currently supported:
* * void -> cat*<T>
* * bool -> cat*<T>
* * int* -> cat*<T>
* * date32 -> cat*<T>
* * time64 -> cat*<T>
* * float* -> cat*<T>
* * str* -> cat*<T>
* * obj64 -> cat*<T>
*/
Column Type_Cat::cast_column(Column&& col) const {
switch (col.stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), make_type());

case SType::BOOL:
case SType::INT8:
case SType::INT16:
case SType::DATE32:
case SType::INT32:
case SType::TIME64:
case SType::INT64:
case SType::FLOAT32:
case SType::FLOAT64:
case SType::STR32:
case SType::STR64:
case SType::OBJ: {
switch (stype()) {
case SType::CAT8: cast_obj_column_<uint8_t>(col); break;
case SType::CAT16: cast_obj_column_<uint16_t>(col); break;
case SType::CAT32: cast_obj_column_<uint32_t>(col); break;
case SType::CAT8: cast_column_impl<uint8_t>(col); break;
case SType::CAT16: cast_column_impl<uint16_t>(col); break;
case SType::CAT32: cast_column_impl<uint32_t>(col); break;
default: throw RuntimeError() << "Unknown categorical type";
}
return std::move(col);
Expand All @@ -152,7 +169,7 @@ Column Type_Cat::cast_column(Column&& col) const {
* finally assigning it to `col`.
*/
template <typename T>
void Type_Cat::cast_obj_column_(Column& col) const {
void Type_Cat::cast_column_impl(Column& col) const {
size_t nrows = col.nrows(); // save nrows as `col` will be modified in-place

// First, cast `col` to the requested `elementType_` and obtain
Expand All @@ -172,8 +189,8 @@ void Type_Cat::cast_obj_column_(Column& col) const {

if (gb.size() > MAX_CATS) {
throw ValueError() << "Number of categories in the column is `" << gb.size()
<< "`, that is larger than " << to_string() << " type "
<< "can accomodate, i.e. `" << MAX_CATS << "`";
<< "`, that is larger than " << to_string() << " type supports, "
<< "i.e. `" << MAX_CATS << "`";
}

// Fill out two buffers:
Expand Down
2 changes: 1 addition & 1 deletion src/core/types/type_categorical.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace dt {
class Type_Cat : public TypeImpl {
private:
template <typename T>
void cast_obj_column_(Column& col) const;
void cast_column_impl(Column& col) const;

protected:
Type elementType_;
Expand Down
178 changes: 163 additions & 15 deletions tests/types/test-categorical.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#-------------------------------------------------------------------------------
# Copyright 2021 H2O.ai
# Copyright 2022 H2O.ai
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -116,8 +116,8 @@ def test_create_too_many_cats_error(t):
src = list(range(1000))

if t is dt.Type.cat8:
msg = r"Number of categories in the column is 1000, that is larger than cat8\(int32\) " \
"type can accomodate, i.e. 256"
msg = "Number of categories in the column is 1000, that is larger " \
r"than cat8\(int32\) type supports, i.e. 256"
with pytest.raises(ValueError, match=msg):
DT = dt.Frame(src, types = [t(dt.Type.int32)])
else:
Expand Down Expand Up @@ -261,25 +261,173 @@ def test_create_multicolumn(t):
# Casting to other types
#-------------------------------------------------------------------------------

@pytest.mark.parametrize('t', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_void_to_cat(t):
@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_void_to_cat(cat_type):
src = [None] * 11
DT = dt.Frame(src)
DT[0] = t(dt.Type.str32)
DT_ref = dt.Frame(src, type=t(dt.Type.str32))
DT[0] = cat_type(dt.Type.str32)
DT_ref = dt.Frame(src, type=cat_type(dt.Type.str32))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('t', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_obj_to_cat(t):
@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
def test_bool_to_cat(cat_type):
    """
    Assigning a `cat*(bool8)` type to a column built from booleans must
    give the same frame as constructing it with that type directly.
    """
    values = [True, False, False, False, True, True]
    actual = dt.Frame(values)
    actual[0] = cat_type(dt.Type.bool8)
    assert_equals(actual, dt.Frame(values, type=cat_type(dt.Type.bool8)))



@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
@pytest.mark.parametrize(
    'type',
    [dt.Type.int8, dt.Type.int16, dt.Type.int32, dt.Type.int64],
)
def test_int_to_cat(cat_type, type):
    """
    Assigning `cat_type(type)` to a column created with the integer `type`
    must give the same frame as constructing it with the categorical type.
    """
    values = [100, 500, 100500, 500, 100, 1]
    actual = dt.Frame(values, type=type)
    actual[0] = cat_type(type)
    assert_equals(actual, dt.Frame(values, type=cat_type(type)))


@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
@pytest.mark.parametrize(
    'type',
    [dt.Type.float32, dt.Type.float64],
)
def test_float_to_cat(cat_type, type):
    """
    Assigning `cat_type(type)` to a column created with the float `type`
    (including `None` entries) must give the same frame as constructing
    it with the categorical type.
    """
    values = [100.1, -3.14, 500.1, 100500.5, None, 100500900.9, None]
    actual = dt.Frame(values, type=type)
    actual[0] = cat_type(type)
    assert_equals(actual, dt.Frame(values, type=cat_type(type)))


@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
@pytest.mark.parametrize(
    'type',
    [dt.Type.str32, dt.Type.str64],
)
def test_str_to_cat(cat_type, type):
    """
    Assigning `cat_type(type)` to a column created with the string `type`
    must give the same frame as constructing it with the categorical type.
    """
    values = ["cat", "dog", "mouse", "dog", "mouse", "a"]
    actual = dt.Frame(values, type=type)
    actual[0] = cat_type(type)
    assert_equals(actual, dt.Frame(values, type=cat_type(type)))


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
                                      dt.Type.cat16,
                                      dt.Type.cat32])
@pytest.mark.parametrize('type', [dt.Type.date32,
                                  dt.Type.time64])
def test_date_time_to_cat(cat_type, type):
    """
    Assigning `cat_type(type)` to a column created from `datetime` values
    with the date32/time64 `type` must give the same frame as constructing
    it with the categorical type.

    This test has its own name: it was previously also called
    `test_str_to_cat`, which rebound that module-level name and kept the
    string-column test from being collected by pytest.
    """
    from datetime import datetime as d
    # The source contains repeated values.
    src = [d(1991, 1, 9, 10, 2, 3, 500),
           d(2014, 7, 2, 1, 3, 5, 100),
           d(2021, 2, 2, 1, 3, 5, 1000),
           d(2014, 7, 2, 1, 3, 5, 100),
           d(2020, 2, 22, 1, 3, 5, 129),
           d(2020, 2, 22, 1, 3, 5, 129),
           d(1950, 8, 2, 4, 3, 5, 10)]
    DT = dt.Frame(src, type=type)
    DT[0] = cat_type(type)
    DT_ref = dt.Frame(src, type=cat_type(type))
    assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_obj_to_cat(cat_type):
src = [None, "cat", "cat", "dog", "mouse", None, "panda", "dog"]
DT = dt.Frame(A=src, type=object)
DT['A'] = t(dt.Type.str32)
DT_ref = dt.Frame(A=src, type=t(dt.Type.str32))
DT['A'] = cat_type(dt.Type.str32)
DT_ref = dt.Frame(A=src, type=cat_type(dt.Type.str32))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
def test_str_to_cat_bool(cat_type):
    """
    Assigning `cat*(bool8)` to a column of strings: the expected frame has
    True/False for "True"/"False" and `None` for the other strings used
    here ("weird", "None").
    """
    strings = ["True", "weird", "False", "False", "None", "True", "True"]
    expected_values = [True, None, False, False, None, True, True]
    actual = dt.Frame(strings)
    actual[0] = cat_type(dt.Type.bool8)
    assert_equals(
        actual,
        dt.Frame(expected_values, type=cat_type(dt.Type.bool8)),
    )


@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
def test_bool_to_cat_str(cat_type):
    """
    Assigning `cat*(str32)` to a column of booleans: the expected frame
    holds "True"/"False" strings, with `None` entries staying `None`.
    """
    flags = [True, None, False, False, None, True, True]
    expected_values = ["True", None, "False", "False", None, "True", "True"]
    actual = dt.Frame(flags)
    actual[0] = cat_type(dt.Type.str32)
    assert_equals(
        actual,
        dt.Frame(expected_values, type=cat_type(dt.Type.str32)),
    )


@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
def test_str_to_cat_int(cat_type):
    """
    Assigning `cat*(int32)` to a column of strings: the expected frame
    holds the parsed integers, with `None` for the non-numeric string
    "weird" and for the `None` entry.
    """
    strings = ["weird", "500", "100500", None, "100500900"]
    expected_values = [None, 500, 100500, None, 100500900]
    actual = dt.Frame(strings)
    actual[0] = cat_type(dt.Type.int32)
    assert_equals(
        actual,
        dt.Frame(expected_values, type=cat_type(dt.Type.int32)),
    )


@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
def test_int_to_cat_str(cat_type):
    """
    Assigning `cat*(str32)` to a column of integers: the expected frame
    holds the decimal string of each integer, with `None` entries staying
    `None`.
    """
    numbers = [None, 500, 100500, None, 100500900]
    expected_values = [None, "500", "100500", None, "100500900"]
    actual = dt.Frame(numbers)
    actual[0] = cat_type(dt.Type.str32)
    assert_equals(
        actual,
        dt.Frame(expected_values, type=cat_type(dt.Type.str32)),
    )


@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
def test_str_to_cat_float(cat_type):
    """
    Assigning `cat*(float32)` to a column of strings: the expected frame
    holds the parsed floats, with `None` for the non-numeric string "yeah"
    and for the `None` entry.
    """
    strings = ["100.1", "500.1", "100500.5", None, "100500900.9", "yeah"]
    expected_values = [100.1, 500.1, 100500.5, None, 100500900.9, None]
    actual = dt.Frame(strings)
    actual[0] = cat_type(dt.Type.float32)
    assert_equals(
        actual,
        dt.Frame(expected_values, type=cat_type(dt.Type.float32)),
    )


@pytest.mark.parametrize(
    'cat_type',
    [dt.Type.cat8, dt.Type.cat16, dt.Type.cat32],
)
def test_float_to_cat_str(cat_type):
    """
    Assigning `cat*(str32)` to a column of floats: the expected frame
    holds the string form of each float, with `None` entries staying
    `None`.
    """
    numbers = [100.1, 500.1, 100500.5, None, 100500900.9, None]
    expected_values = ["100.1", "500.1", "100500.5", None, "100500900.9", None]
    actual = dt.Frame(numbers)
    actual[0] = cat_type(dt.Type.str32)
    assert_equals(
        actual,
        dt.Frame(expected_values, type=cat_type(dt.Type.str32)),
    )


Expand Down

0 comments on commit b4866f5

Please sign in to comment.