Skip to content

Commit

Permalink
high level: ml: Updated to ensure contexts can be kept open
Browse files Browse the repository at this point in the history
Fixes: #1112
  • Loading branch information
programmer290399 committed Oct 7, 2021
1 parent 475dabe commit 0db7777
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Renamed `Optimizer` to `Tuner`.
### Fixed
- Record object key properties are now always strings
- High level functions (`train()`, etc.) now work on existing open contexts

## [0.4.0] - 2021-02-18
### Added
Expand Down
14 changes: 11 additions & 3 deletions dffml/high_level/ml.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import contextlib
from typing import Union, Dict, Any

from ..model import Model
from ..record import Record
from ..feature import Feature, Features
from ..source.source import BaseSource
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
from ..feature import Feature, Features
from ..model import Model, ModelContext
from ..util.internal import records_to_sources
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext


async def train(model, *args: Union[BaseSource, Record, Dict[str, Any]]):
Expand Down Expand Up @@ -58,6 +58,8 @@ async def train(model, *args: Union[BaseSource, Record, Dict[str, Any]]):
if isinstance(model, Model):
model = await astack.enter_async_context(model)
mctx = await astack.enter_async_context(model())
elif isinstance(model, ModelContext):
mctx = model
# Run training
return await mctx.train(sctx)

Expand Down Expand Up @@ -144,10 +146,14 @@ async def score(
if isinstance(model, Model):
model = await astack.enter_async_context(model)
mctx = await astack.enter_async_context(model())
elif isinstance(model, ModelContext):
mctx = model
# Allow for keep models open
if isinstance(accuracy_scorer, AccuracyScorer):
accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
actx = await astack.enter_async_context(accuracy_scorer())
elif isinstance(accuracy_scorer, AccuracyContext):
actx = accuracy_scorer
else:
# TODO Replace this with static type checking and maybe dynamic
# through something like pydantic. See issue #36
Expand Down Expand Up @@ -229,6 +235,8 @@ async def predict(
if isinstance(model, Model):
model = await astack.enter_async_context(model)
mctx = await astack.enter_async_context(model())
elif isinstance(model, ModelContext):
mctx = model
# Run predictions
async for record in mctx.predict(sctx):
yield record if keep_record else (
Expand Down

0 comments on commit 0db7777

Please sign in to comment.