Skip to content

Commit

Permalink
Merge b9ff5ab into 54764d2
Browse files Browse the repository at this point in the history
  • Loading branch information
simonwoerpel committed Oct 7, 2023
2 parents 54764d2 + b9ff5ab commit 00077d3
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 43 deletions.
4 changes: 1 addition & 3 deletions investigraph/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def cli_run(
fragments_uri: Annotated[Optional[str], typer.Option(...)] = None,
entities_uri: Annotated[Optional[str], typer.Option(...)] = None,
aggregate: Annotated[Optional[bool], typer.Option(...)] = True,
chunk_size: Annotated[Optional[int], typer.Option(...)] = None,
chunk_size: Annotated[Optional[int], typer.Option(...)] = 1_000,
):
"""
Execute a dataset pipeline
Expand Down Expand Up @@ -130,8 +130,6 @@ def cli_catalog(
investigraph build-catalog catalog.yml -u s3://mybucket/catalog.json
"""
catalog = Catalog.from_path(path)
if uri != "-":
catalog.uri = uri
if flatten:
datasets = [d.dict() for d in catalog.get_datasets()]
data = {
Expand Down
23 changes: 13 additions & 10 deletions investigraph/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,27 @@
from nomenklatura.entity import CE
from rich import print

from investigraph.logic.extract import extract_pandas
from investigraph.model import Resolver, Source
from investigraph.model import Resolver
from investigraph.model.config import Config, get_config
from investigraph.model.context import init_context
from investigraph.model.context import Context, init_context
from investigraph.util import PathLike


def print_error(msg: str) -> None:
    """
    Print an error message prefixed with a bold red `ERROR` tag.

    The prefix is written as `[bold red]...[/bold red]` markup. This module
    imports `print` from `rich`, so the markup is expected to be rendered as
    styling rather than shown literally.

    :param msg: Human-readable description of what went wrong.
    """
    print(f"[bold red]ERROR[/bold red] {msg}")


def get_records(source: Source) -> list[dict[str, Any]]:
def get_records(ctx: Context) -> list[dict[str, Any]]:
records: list[dict[str, Any]] = []
print("Fetching `%s` ..." % source.uri)
res = Resolver(source=source)
for ix, rec in enumerate(extract_pandas(res)):
print("Extracting `%s` ..." % ctx.source.uri)
res = Resolver(source=ctx.source)
if res.source.is_http and ctx.config.extract.fetch:
res._resolve_http()
for rec in ctx.config.extract.handle(ctx, res):
records.append(rec)
if ix == 5:
if len(records) == 5:
return records
return records


def inspect_config(p: PathLike) -> Config:
Expand Down Expand Up @@ -57,7 +59,8 @@ def inspect_extract(config: Config) -> Generator[tuple[str, pd.DataFrame], None,
Preview fetched & extracted records in tabular format
"""
for source in config.extract.sources:
df = pd.DataFrame(get_records(source))
ctx = init_context(config, source)
df = pd.DataFrame(get_records(ctx))
yield source.name, df


Expand All @@ -68,7 +71,7 @@ def inspect_transform(config: Config) -> Generator[tuple[str, CE], None, None]:
for source in config.extract.sources:
ctx = init_context(config, source)
proxies: list[CE] = []
for ix, rec in enumerate(get_records(source)):
for ix, rec in enumerate(get_records(ctx)):
for proxy in ctx.config.transform.handle(ctx, rec, ix):
proxies.append(proxy)
yield source.name, proxies
5 changes: 1 addition & 4 deletions investigraph/logic/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
"""


from urllib.parse import urlencode

import requests
from prefect import flow, get_run_logger, task
from prefect.tasks import task_input_hash
Expand Down Expand Up @@ -37,7 +34,7 @@ def get_request_cache_key(*args, **kwargs) -> str:
def _get(url: str, *args, **kwargs):
log = get_run_logger()
kwargs.pop("ckey", None)
log.info(f"GET {url}?{urlencode(kwargs)}")
log.info(f"GET {url}")
res = requests.get(url, *args, **kwargs)
assert res.ok
return res
Expand Down
8 changes: 5 additions & 3 deletions investigraph/logic/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
from investigraph.types import CEGenerator, SDict


def map_record(record: SDict, mapping: QueryMapping) -> CEGenerator:
def map_record(
record: SDict, mapping: QueryMapping, dataset: str | None = "default"
) -> CEGenerator:
mapping = mapping.get_mapping()
if mapping.source.check_filters(record):
entities = mapping.map(record)
for proxy in entities.values():
yield make_proxy(proxy.to_dict())
yield make_proxy(proxy.to_dict(), dataset=dataset)


def map_ftm(ctx: "Context", data: SDict, ix: int) -> CEGenerator:
for mapping in ctx.config.transform.queries:
yield from map_record(data, mapping)
yield from map_record(data, mapping, ctx.config.dataset.name)
9 changes: 7 additions & 2 deletions investigraph/model/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ftmq.util import join_slug
from nomenklatura.entity import CE
from prefect import get_run_logger
from prefect.logging.loggers import PrefectLogAdapter
from prefect.logging.loggers import MissingContextError, PrefectLogAdapter
from pydantic import BaseModel

from investigraph.cache import Cache, get_cache
Expand Down Expand Up @@ -35,7 +35,12 @@ def cache(self) -> Cache:

@property
def log(self) -> PrefectLogAdapter:
return get_run_logger()
try:
return get_run_logger()
except MissingContextError:
import logging

return logging.getLogger(__name__)

def load_fragments(self, *args, **kwargs) -> str:
kwargs["uri"] = kwargs.pop("uri", self.config.load.fragments_uri)
Expand Down
2 changes: 1 addition & 1 deletion investigraph/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def aggregate(ctx: Context, results: list[str], ckey: str) -> Coverage:
retry_delay_seconds=settings.TASK_RETRY_DELAY,
cache_key_fn=get_task_cache_key,
cache_expiration=settings.TASK_CACHE_EXPIRATION,
refresh_cache=not settings.TASK_CACHE,
refresh_cache=not settings.TASK_CACHE or not settings.LOAD_CACHE,
)
def load(ctx: Context, ckey: str) -> str:
proxies = ctx.cache.get(ckey)
Expand Down
3 changes: 2 additions & 1 deletion investigraph/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def get_env(env: str, default: Any | None = None) -> Any | None:
timedelta(TASK_CACHE_EXPIRATION) if TASK_CACHE_EXPIRATION is not None else None
)
FETCH_CACHE = as_bool(get_env("FETCH_CACHE"), TASK_CACHE)
TRANSFORM_CACHE = as_bool(get_env("TRANSFORM_CACHE"), TASK_CACHE)
EXTRACT_CACHE = as_bool(get_env("EXTRACT_CACHE"), TASK_CACHE)
TRANSFORM_CACHE = as_bool(get_env("TRANSFORM_CACHE"), TASK_CACHE)
LOAD_CACHE = as_bool(get_env("LOAD_CACHE"), TASK_CACHE)

TASK_RUNNER = get_env("PREFECT_TASK_RUNNER", "").lower()

Expand Down
Loading

0 comments on commit 00077d3

Please sign in to comment.