From 3ac82ea631f696fd976df8f1bc51bd0fa287f4ae Mon Sep 17 00:00:00 2001 From: hungcs <85463385+hungcs@users.noreply.github.com> Date: Thu, 9 Sep 2021 15:36:13 -0700 Subject: [PATCH] Add serialization for DatasetInfo and round avg_words to int (#1294) --- ludwig/automl/base_config.py | 2 ++ ludwig/automl/utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ludwig/automl/base_config.py b/ludwig/automl/base_config.py index 85d453df9dc..298376dbd5c 100644 --- a/ludwig/automl/base_config.py +++ b/ludwig/automl/base_config.py @@ -16,6 +16,7 @@ import os from dataclasses import dataclass +from dataclasses_json import LetterCase, dataclass_json from typing import List, Union import pandas as pd @@ -45,6 +46,7 @@ } +@dataclass_json(letter_case=LetterCase.CAMEL) @dataclass class DatasetInfo: fields: List[FieldInfo] diff --git a/ludwig/automl/utils.py b/ludwig/automl/utils.py index 4f881cd34f1..6f5c2e0ab80 100644 --- a/ludwig/automl/utils.py +++ b/ludwig/automl/utils.py @@ -55,7 +55,7 @@ def avg_num_tokens(field: Series) -> int: if len(field) > 5000: field = field.sample(n=5000, random_state=40) unique_entries = field.unique() - avg_words = Series(unique_entries).str.split().str.len().mean() + avg_words = round(Series(unique_entries).str.split().str.len().mean()) return avg_words