Skip to content

Commit

Permalink
high_level: Accept lists for data arguments
Browse files Browse the repository at this point in the history
Signed-off-by: Hashim Chaudry <hashimchaudry23@gmail.com>
  • Loading branch information
mhash1m authored and pdxjohnny committed Jan 26, 2022
1 parent 3d09e28 commit 714d325
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
same, it still accepts a name or a path.
- Renamed `accuracy()` to `score()`.
- Renamed `Optimizer` to `Tuner`.
- High-level functions now accept lists for their data arguments.
### Fixed
- Record object key properties are now always strings
- High level functions (`train()`, etc.) now work on existing open contexts
Expand Down
57 changes: 52 additions & 5 deletions dffml/high_level/ml.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import contextlib
from typing import Union, Dict, Any
from typing import Union, Dict, Any, List

from ..record import Record
from ..source.source import BaseSource
from ..feature import Feature, Features
from ..model import Model, ModelContext
from ..util.internal import records_to_sources
from ..util.internal import records_to_sources, list_records_to_dict
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext


async def train(model, *args: Union[BaseSource, Record, Dict[str, Any]]):
async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
"""
Train a machine learning model.
Expand Down Expand Up @@ -51,6 +51,23 @@ async def train(model, *args: Union[BaseSource, Record, Dict[str, Any]]):
>>>
>>> asyncio.run(main())
"""
if (
hasattr(model.config, "features")
and any(isinstance(arg, list) for arg in args)
and hasattr(model.config, "predict")
):
if isinstance(model.config.predict, Features):
predict_feature = [
feature.name for feature in model.config.predict
]
else:
predict_feature = [model.config.predict.name]
args = list_records_to_dict(
[feature.name for feature in model.config.features]
+ predict_feature,
*args,
model=model,
)
async with contextlib.AsyncExitStack() as astack:
# Open sources
sctx = await astack.enter_async_context(records_to_sources(*args))
Expand All @@ -68,7 +85,7 @@ async def score(
model,
accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
features: Union[Feature, Features],
*args: Union[BaseSource, Record, Dict[str, Any]],
*args: Union[BaseSource, Record, Dict[str, Any], List],
) -> float:
"""
Assess the accuracy of a machine learning model.
Expand Down Expand Up @@ -138,6 +155,21 @@ async def score(
)
if isinstance(features, Feature):
features = Features(features)
if any(isinstance(arg, list) for arg in args) and hasattr(
model.config, "predict"
):
if isinstance(model.config.predict, Features):
predict_feature = [
feature.name for feature in model.config.predict
]
else:
predict_feature = [model.config.predict.name]
args = list_records_to_dict(
[feature.name for feature in model.config.features]
+ predict_feature,
*args,
model=model,
)

async with contextlib.AsyncExitStack() as astack:
# Open sources
Expand All @@ -164,7 +196,7 @@ async def score(

async def predict(
model,
*args: Union[BaseSource, Record, Dict[str, Any]],
*args: Union[BaseSource, Record, Dict[str, Any], List],
update: bool = False,
keep_record: bool = False,
):
Expand Down Expand Up @@ -228,6 +260,21 @@ async def predict(
{'Years': 6, 'Salary': 70}
{'Years': 7, 'Salary': 80}
"""
if any(isinstance(arg, list) for arg in args) and hasattr(
model.config, "predict"
):
if isinstance(model.config.predict, Features):
predict_feature = [
feature.name for feature in model.config.predict
]
else:
predict_feature = [model.config.predict.name]
args = list_records_to_dict(
[feature.name for feature in model.config.features]
+ predict_feature,
*args,
model=model,
)
async with contextlib.AsyncExitStack() as astack:
# Open sources
sctx = await astack.enter_async_context(records_to_sources(*args))
Expand Down
17 changes: 17 additions & 0 deletions dffml/util/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
from ..source.memory import MemorySource, MemorySourceConfig


class CannotConvertToRecord(Exception):
    """
    Raised when list data is given for conversion into records but no model
    was supplied alongside it.
    """


@contextlib.asynccontextmanager
async def records_to_sources(*args):
"""
Expand Down Expand Up @@ -55,3 +62,13 @@ async def records_to_sources(*args):
for already_open_sctx in sctxs:
sctx.append(already_open_sctx)
yield sctx


def list_records_to_dict(features, *args, model=None):
    """
    Convert every ``list`` in ``args`` into a dict keyed by ``features``.

    Each positional argument that is a ``list`` is replaced by
    ``dict(zip(features, arg))``; any other argument is passed through
    untouched. Because ``zip`` stops at the shorter of its inputs, a list
    with fewer values than there are feature names only fills the leading
    names, and surplus values are dropped.

    ``features`` is the ordered collection of names used as dict keys.
    ``model`` acts purely as a guard: it is not inspected, it only has to be
    provided.

    Returns a new ``list`` holding the converted arguments in their original
    order.

    Raises ``CannotConvertToRecord`` when ``model`` is ``None``.
    """
    # Compare against None explicitly: a model object may legitimately be
    # falsy (for example by defining __len__ or __bool__), and that must not
    # be mistaken for "no model given".
    if model is None:
        raise CannotConvertToRecord("Model does not exist!")
    return [
        dict(zip(features, arg)) if isinstance(arg, list) else arg
        for arg in args
    ]
10 changes: 10 additions & 0 deletions tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,16 @@ async def test_predict(self):
self.assertEqual(round(predictions[0][2]["Salary"]["value"]), 70)
self.assertEqual(round(predictions[1][2]["Salary"]["value"]), 80)

# Test input data as list
await train(model, *self.train_data)
await score(model, scorer, Feature("Salary", int, 1), *self.test_data)
predictions = [
prediction
async for prediction in predict(model, *self.predict_data)
]
self.assertEqual(round(predictions[0][2]["Salary"]["value"]), 70)
self.assertEqual(round(predictions[1][2]["Salary"]["value"]), 80)


class TestDataFlow(TestOrchestrator):
@contextlib.asynccontextmanager
Expand Down

0 comments on commit 714d325

Please sign in to comment.