Skip to content

Commit

Permalink
Implement casts from categorical types, add to_csv() support for ca…
Browse files Browse the repository at this point in the history
…tegorical columns (#3372)

In this PR we 
- implement casts from `dt.cat*(...)` to all of the basic types;
- as a consequence, support for converting categorical columns to CSV has been added.

WIP for #1691
  • Loading branch information
oleksiyskononenko committed Oct 21, 2022
1 parent 4110455 commit 404d2e4
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/core/types/type_array.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//------------------------------------------------------------------------------
// Copyright 2021 H2O.ai
// Copyright 2021-2022 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -111,7 +111,7 @@ Type Type_Array::child_type() const {
*/
Column Type_Array::cast_column(Column&& col) const {
const auto st = stype();
switch (col.stype()) {
switch (col.data_stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), make_type());

Expand Down
8 changes: 7 additions & 1 deletion src/core/types/type_bool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ const char* Type_Bool8::struct_format() const {
// - TIME64->BOOL: --
//
Column Type_Bool8::cast_column(Column&& col) const {
switch (col.stype()) {
switch (col.data_stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), SType::BOOL);

case SType::BOOL:
if (col.type().is_categorical()) {
col.replace_type_unsafe(Type::bool8());
}
return std::move(col);

case SType::INT8:
return Column(new CastNumericToBool_ColumnImpl<int8_t>(std::move(col)));

Expand Down
7 changes: 5 additions & 2 deletions src/core/types/type_date.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//------------------------------------------------------------------------------
// Copyright 2021 H2O.ai
// Copyright 2021-2022 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -79,7 +79,7 @@ TypeImpl* Type_Date32::common_type(TypeImpl* other) {
//
Column Type_Date32::cast_column(Column&& col) const {
constexpr SType st = SType::DATE32;
switch (col.stype()) {
switch (col.data_stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), st);

Expand All @@ -97,6 +97,9 @@ Column Type_Date32::cast_column(Column&& col) const {
return Column(new CastNumeric_ColumnImpl<double>(st, std::move(col)));

case SType::DATE32:
if (col.type().is_categorical()) {
col.replace_type_unsafe(Type::date32());
}
return std::move(col);

case SType::TIME64:
Expand Down
4 changes: 2 additions & 2 deletions src/core/types/type_object.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//------------------------------------------------------------------------------
// Copyright 2021 H2O.ai
// Copyright 2021-2022 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -58,7 +58,7 @@ const char* Type_Object::struct_format() const {
//
Column Type_Object::cast_column(Column&& col) const {
constexpr auto st = SType::OBJ;
switch (col.stype()) {
switch (col.data_stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), st);

Expand Down
14 changes: 8 additions & 6 deletions src/core/types/type_string.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//------------------------------------------------------------------------------
// Copyright 2021 H2O.ai
// Copyright 2021-2022 H2O.ai
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
Expand Down Expand Up @@ -62,7 +62,7 @@ TypeImpl* Type_String::common_type(TypeImpl* other) {
//
Column Type_String::cast_column(Column&& col) const {
const auto st = stype();
switch (col.stype()) {
switch (col.data_stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), st);

Expand Down Expand Up @@ -95,7 +95,12 @@ Column Type_String::cast_column(Column&& col) const {

case SType::STR32:
case SType::STR64:
if (st == col.stype()) return std::move(col);
if (st == col.data_stype()) {
if (col.type().is_categorical()) {
col.replace_type_unsafe(Type::from_stype(st));
}
return std::move(col);
}
return Column(new CastString_ColumnImpl(st, std::move(col)));

case SType::OBJ:
Expand All @@ -109,8 +114,6 @@ Column Type_String::cast_column(Column&& col) const {





//------------------------------------------------------------------------------
// Type_String32
//-----------------------------------------------------------------------------
Expand All @@ -124,7 +127,6 @@ std::string Type_String32::to_string() const {




//------------------------------------------------------------------------------
// Type_String64
//------------------------------------------------------------------------------
Expand Down
5 changes: 4 additions & 1 deletion src/core/types/type_time.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Column Type_Time64::cast_column(Column&& col) const {
constexpr SType st = SType::TIME64;
constexpr int64_t SECONDS = 1000000000;
constexpr int64_t DAYS = 24 * 3600 * SECONDS;
switch (col.stype()) {
switch (col.data_stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), st);

Expand All @@ -109,6 +109,9 @@ Column Type_Time64::cast_column(Column&& col) const {
}

case SType::TIME64:
if (col.type().is_categorical()) {
col.replace_type_unsafe(Type::time64());
}
return std::move(col);

case SType::OBJ:
Expand Down
2 changes: 1 addition & 1 deletion src/core/types/typeimpl_numeric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ TypeImpl* TypeImpl_Numeric::common_type(TypeImpl* other) {
//
Column TypeImpl_Numeric::cast_column(Column&& col) const {
const SType st = stype();
switch (col.stype()) {
switch (col.data_stype()) {
case SType::VOID:
return Column::new_na_column(col.nrows(), st);

Expand Down
21 changes: 21 additions & 0 deletions tests/frame/test-tocsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,27 @@ def test_save_str64():
",bvqpoeqnperoin;dj\n")


@pytest.mark.parametrize('t', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
def test_save_cat(t):
src = [[True, True, True, None, False],
[None, 3, 14, None, 15],
[-1.5, 1.5, 100.1, None, 18.0],
["mouse", "cat", "dog", None, "mouse"],
[None] * 5]
DT = dt.Frame(src,
types=[t(bool), t(int), t(float), t(str), t(None)],
names=["Booleans", "Integers", "Floats", "Strings", "Voids"])
assert DT.to_csv() == (
"Booleans,Integers,Floats,Strings,Voids\n"
"True,,-1.5,mouse,\n"
"True,3,1.5,cat,\n"
"True,14,100.1,dog,\n"
",,,,\n"
"False,15,18.0,mouse,\n")


def test_issue_1278():
f0 = dt.Frame([[True, False] * 10] * 1000)
a = f0.to_csv()
Expand Down
107 changes: 106 additions & 1 deletion tests/types/test-categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_create_multicolumn(t):


#-------------------------------------------------------------------------------
# Casting to other types
# Cast to categorical types
#-------------------------------------------------------------------------------

@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
Expand Down Expand Up @@ -443,6 +443,111 @@ def test_float_to_cat_str(cat_type):



#-------------------------------------------------------------------------------
# Cast from categorical types
#-------------------------------------------------------------------------------

@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('data_type', [None, bool, int, float, str])
def test_cast_to_void(cat_type, data_type):
src = [None] * 10
DT = dt.Frame(src, type = cat_type(data_type))
DT[0] = dt.Type.void
DT_ref = dt.Frame(src)
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('data_type', [bool, int, float, str])
def test_cast_to_bool(cat_type, data_type):
src = [None, True, False, None, False, False]
DT = dt.Frame(src, type = cat_type(data_type))
DT[0] = dt.Type.bool8
DT_ref = dt.Frame(src)
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('data_type', [int, float, str])
def test_cast_to_int(cat_type, data_type):
src = [3, None, 1, 4, 1, None, 5, 9, 2, 6]
DT = dt.Frame(src, type = cat_type(data_type))
DT[0] = dt.Type.int32
DT_ref = dt.Frame(src)
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('data_type', [float, str])
def test_cast_to_float(cat_type, data_type):
src = [3.14, None, 1.15, 4.92, 1.6, None, 5.5, 9, 2.35, 6.0]
DT = dt.Frame(src, type = cat_type(data_type))
DT[0] = dt.Type.float64
DT_ref = dt.Frame(src)
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('data_type', [dt.Type.str32, dt.Type.str64])
def test_cast_to_string(cat_type, data_type):
src = ["dog", "mouse", None, "dog", "cat", None, "1", "pig"]
DT = dt.Frame(src, type = cat_type(data_type))
DT[0] = dt.Type.str32
DT_ref = dt.Frame(src)
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('data_type', [str, dt.Type.date32, dt.Type.time64])
def test_cast_to_date(cat_type, data_type):
from datetime import date as d
src = [d(1997, 9, 1), d(2002, 7, 31), d(2000, 2, 20), None]
DT = dt.Frame(src, type = cat_type(data_type))
DT[0] = dt.Type.date32
DT_ref = dt.Frame(src)
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('data_type', [str, dt.Type.time64])
def test_cast_to_time(cat_type, data_type):
from datetime import datetime as d
src = [d(2000, 10, 18, 3, 30),
d(2010, 11, 13, 15, 11, 59),
d(2020, 2, 29, 20, 20, 20, 20), None]
DT = dt.Frame(src, type = cat_type(data_type))
DT[0] = dt.Type.time64
DT_ref = dt.Frame(src)
assert_equals(DT, DT_ref)


@pytest.mark.parametrize('cat_type', [dt.Type.cat8,
dt.Type.cat16,
dt.Type.cat32])
@pytest.mark.parametrize('data_type', [int, float])
def test_cast_to_obj(cat_type, data_type):
src = [1, 2, None, 100, -10]
DT = dt.Frame(src, type = cat_type(data_type))
DT[0] = dt.Type.obj64
DT_ref = dt.Frame(src, type=dt.Type.obj64)
assert_equals(DT, DT_ref)



#-------------------------------------------------------------------------------
# Conversion to other formats
#-------------------------------------------------------------------------------
Expand Down

0 comments on commit 404d2e4

Please sign in to comment.