Skip to content

Commit

Permalink
Convert nan to 0 in avg_num_tokens() (#2046)
Browse files: browse the repository at this point in the history
  • Loading branch information
hungcs committed May 21, 2022
1 parent 55b7672 commit f54bf05
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ludwig/automl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, List

from dataclasses_json import dataclass_json, LetterCase
from numpy import nan_to_num
from pandas import Series

from ludwig.constants import COMBINER, CONFIG, HYPEROPT, NAME, NUMBER, PARAMETERS, SAMPLER, TRAINER, TYPE
Expand Down Expand Up @@ -58,7 +59,7 @@ def avg_num_tokens(field: Series) -> int:
if len(field) > 5000:
field = field.sample(n=5000, random_state=40)
unique_entries = field.unique()
avg_words = round(Series(unique_entries).str.split().str.len().mean())
avg_words = round(nan_to_num(Series(unique_entries).str.split().str.len().mean()))
return avg_words


Expand Down
9 changes: 9 additions & 0 deletions tests/ludwig/automl/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pandas as pd
import pytest

from ludwig.automl.utils import avg_num_tokens


@pytest.mark.parametrize(
    "field,expected",
    [
        # A Series holding only a missing value is expected to yield 0.
        pytest.param(pd.Series([None]), 0),
        # Three single-word strings are expected to yield 1.
        pytest.param(pd.Series(["string1", "string2", "string3"]), 1),
    ],
)
def test_avg_num_tokens(field, expected):
    """Check that avg_num_tokens returns the expected value for each parametrized Series."""
    actual = avg_num_tokens(field)
    assert actual == expected

0 comments on commit f54bf05

Please sign in to comment.