Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement casting of the most column types to categoricals #3365

Merged
merged 1 commit into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 23 additions & 6 deletions src/core/types/type_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,35 @@ size_t Type_Cat::hash() const noexcept {
*
* The following type casts are currently supported:
* * void -> cat*<T>
* * bool -> cat*<T>
* * int* -> cat*<T>
* * date32 -> cat*<T>
* * time64 -> cat*<T>
* * float* -> cat*<T>
* * str* -> cat*<T>
* * obj64 -> cat*<T>
*/
Column Type_Cat::cast_column(Column&& col) const {
switch (col.stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), make_type());

case SType::BOOL:
case SType::INT8:
case SType::INT16:
case SType::DATE32:
case SType::INT32:
case SType::TIME64:
case SType::INT64:
case SType::FLOAT32:
case SType::FLOAT64:
case SType::STR32:
case SType::STR64:
case SType::OBJ: {
switch (stype()) {
case SType::CAT8: cast_obj_column_<uint8_t>(col); break;
case SType::CAT16: cast_obj_column_<uint16_t>(col); break;
case SType::CAT32: cast_obj_column_<uint32_t>(col); break;
case SType::CAT8: cast_column_impl<uint8_t>(col); break;
case SType::CAT16: cast_column_impl<uint16_t>(col); break;
case SType::CAT32: cast_column_impl<uint32_t>(col); break;
default: throw RuntimeError() << "Unknown categorical type";
}
return std::move(col);
Expand All @@ -152,7 +169,7 @@ Column Type_Cat::cast_column(Column&& col) const {
* finally assigning it to `col`.
*/
template <typename T>
void Type_Cat::cast_obj_column_(Column& col) const {
void Type_Cat::cast_column_impl(Column& col) const {
size_t nrows = col.nrows(); // save nrows as `col` will be modified in-place

// First, cast `col` to the requested `elementType_` and obtain
Expand All @@ -172,8 +189,8 @@ void Type_Cat::cast_obj_column_(Column& col) const {

if (gb.size() > MAX_CATS) {
throw ValueError() << "Number of categories in the column is `" << gb.size()
<< "`, that is larger than " << to_string() << " type "
<< "can accomodate, i.e. `" << MAX_CATS << "`";
<< "`, that is larger than " << to_string() << " type supports, "
<< "i.e. `" << MAX_CATS << "`";
}

// Fill out two buffers:
Expand Down
2 changes: 1 addition & 1 deletion src/core/types/type_categorical.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace dt {
class Type_Cat : public TypeImpl {
private:
template <typename T>
void cast_obj_column_(Column& col) const;
void cast_column_impl(Column& col) const;

protected:
Type elementType_;
Expand Down
178 changes: 163 additions & 15 deletions tests/types/test-categorical.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#-------------------------------------------------------------------------------
# Copyright 2021 H2O.ai
# Copyright 2022 H2O.ai
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -116,8 +116,8 @@ def test_create_too_many_cats_error(t):
src = list(range(1000))

if t is dt.Type.cat8:
msg = r"Number of categories in the column is 1000, that is larger than cat8\(int32\) " \
"type can accomodate, i.e. 256"
msg = "Number of categories in the column is 1000, that is larger " \
r"than cat8\(int32\) type supports, i.e. 256"
with pytest.raises(ValueError, match=msg):
DT = dt.Frame(src, types = [t(dt.Type.int32)])
else:
Expand Down Expand Up @@ -261,25 +261,173 @@ def test_create_multicolumn(t):
# Casting to other types
#-------------------------------------------------------------------------------

@pytest.mark.parametrize('t', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_void_to_cat(t):
@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_void_to_cat(cat_type):
src = [None] * 11
DT = dt.Frame(src)
DT[0] = t(dt.Type.str32)
DT_ref = dt.Frame(src, type=t(dt.Type.str32))
DT[0] = cat_type(dt.Type.str32)
DT_ref = dt.Frame(src, type=cat_type(dt.Type.str32))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('t', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_obj_to_cat(t):
@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_bool_to_cat(cat_type):
src = [True, False, False, False, True, True]
DT = dt.Frame(src)
DT[0] = cat_type(dt.Type.bool8)
DT_ref = dt.Frame(src, type=cat_type(dt.Type.bool8))
assert_equals(DT, DT_ref)



@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('type', [dt.Type.int8,
dt.Type.int16,
dt.Type.int32,
dt.Type.int64])
def test_int_to_cat(cat_type, type):
src = [100, 500, 100500, 500, 100, 1]
DT = dt.Frame(src, type=type)
DT[0] = cat_type(type)
DT_ref = dt.Frame(src, type=cat_type(type))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('type', [dt.Type.float32,
dt.Type.float64])
def test_float_to_cat(cat_type, type):
src = [100.1, -3.14, 500.1, 100500.5, None, 100500900.9, None]
DT = dt.Frame(src, type=type)
DT[0] = cat_type(type)
DT_ref = dt.Frame(src, type=cat_type(type))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('type', [dt.Type.str32,
dt.Type.str64])
def test_str_to_cat(cat_type, type):
src = ["cat", "dog", "mouse", "dog", "mouse", "a"]
DT = dt.Frame(src, type=type)
DT[0] = cat_type(type)
DT_ref = dt.Frame(src, type=cat_type(type))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('type', [dt.Type.date32,
dt.Type.time64])
def test_str_to_cat(cat_type, type):
from datetime import datetime as d
src = [d(1991, 1, 9, 10, 2, 3, 500),
d(2014, 7, 2, 1, 3, 5, 100),
d(2021, 2, 2, 1, 3, 5, 1000),
d(2014, 7, 2, 1, 3, 5, 100),
d(2020, 2, 22, 1, 3, 5, 129),
d(2020, 2, 22, 1, 3, 5, 129),
d(1950, 8, 2, 4, 3, 5, 10)]
DT = dt.Frame(src, type=type)
print(DT)
DT[0] = cat_type(type)
print(DT)
DT_ref = dt.Frame(src, type=cat_type(type))
print(DT_ref)
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_obj_to_cat(cat_type):
src = [None, "cat", "cat", "dog", "mouse", None, "panda", "dog"]
DT = dt.Frame(A=src, type=object)
DT['A'] = t(dt.Type.str32)
DT_ref = dt.Frame(A=src, type=t(dt.Type.str32))
DT['A'] = cat_type(dt.Type.str32)
DT_ref = dt.Frame(A=src, type=cat_type(dt.Type.str32))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_str_to_cat_bool(cat_type):
src = ["True", "weird", "False", "False", "None", "True", "True"]
src_ref = [True, None, False, False, None, True, True]
DT = dt.Frame(src)
DT[0] = cat_type(dt.Type.bool8)
DT_ref = dt.Frame(src_ref, type=cat_type(dt.Type.bool8))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_bool_to_cat_str(cat_type):
src = [True, None, False, False, None, True, True]
src_ref = ["True", None, "False", "False", None, "True", "True"]
DT = dt.Frame(src)
DT[0] = cat_type(dt.Type.str32)
DT_ref = dt.Frame(src_ref, type=cat_type(dt.Type.str32))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_str_to_cat_int(cat_type):
src = ["weird", "500", "100500", None, "100500900"]
src_ref = [None, 500, 100500, None, 100500900]
DT = dt.Frame(src)
DT[0] = cat_type(dt.Type.int32)
DT_ref = dt.Frame(src_ref, type=cat_type(dt.Type.int32))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_int_to_cat_str(cat_type):
src = [None, 500, 100500, None, 100500900]
src_ref = [None, "500", "100500", None, "100500900"]
DT = dt.Frame(src)
DT[0] = cat_type(dt.Type.str32)
DT_ref = dt.Frame(src_ref, type=cat_type(dt.Type.str32))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_str_to_cat_float(cat_type):
src = ["100.1", "500.1", "100500.5", None, "100500900.9", "yeah"]
src_ref = [100.1, 500.1, 100500.5, None, 100500900.9, None]
DT = dt.Frame(src)
DT[0] = cat_type(dt.Type.float32)
DT_ref = dt.Frame(src_ref, type=cat_type(dt.Type.float32))
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_float_to_cat_str(cat_type):
src = [100.1, 500.1, 100500.5, None, 100500900.9, None]
src_ref = ["100.1", "500.1", "100500.5", None, "100500900.9", None]
DT = dt.Frame(src)
DT[0] = cat_type(dt.Type.str32)
DT_ref = dt.Frame(src_ref, type=cat_type(dt.Type.str32))
assert_equals(DT, DT_ref)


Expand Down