[structured config] Migrate resources from project-fully-featured to struct config resources (#11785)

## Summary

Migrates the experimental copy of `project-fully-featured` to use struct
config resources instead of traditional function-based resources.
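
For context, the shape of the change looks roughly like the minimal before/after sketch below. The names (`SomeApiClient`, `api_url`) are illustrative only, not the resources actually touched here; the `ConfigurableResource` import path matches the experimental module this commit uses.

```python
from dagster import InitResourceContext, resource

# Experimental location as of this commit; later releases expose this from the
# top-level `dagster` package.
from dagster._config.structured_config import ConfigurableResource


class SomeApiClient:
    """Plain client returned by the function-based resource (illustrative)."""

    def __init__(self, api_url: str):
        self.api_url = api_url


# Before: a function-based resource configured via a config_schema dict and
# built inside an @resource-decorated factory.
@resource(config_schema={"api_url": str})
def some_api_client_resource(init_context: InitResourceContext) -> SomeApiClient:
    return SomeApiClient(api_url=init_context.resource_config["api_url"])


# After: a struct config resource. Typed class attributes double as the config
# schema, and an instance of the class is itself the resource handed to assets.
class SomeApiClientResource(ConfigurableResource):
    api_url: str

    def fetch(self, path: str) -> str:
        return f"{self.api_url}/{path}"
```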

## Test Plan

Existing unit tests.
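
The asset call sites change shape but not behavior, which is why the existing unit tests still cover this change: struct config resources arrive as typed parameters instead of being pulled off `context.resources`, and plain instances go into the resource mapping. A hedged consumption sketch follows; the asset name and the `Definitions` wiring are illustrative and not part of this diff.

```python
from dagster import Definitions, OpExecutionContext, asset

from project_fully_featured_v2_resources.resources.hn_resource import (
    HNAPIClient,
    HNClient,
)


@asset
def first_item(context: OpExecutionContext, hn_client: HNClient):
    # The resource is a typed parameter, replacing the old
    # required_resource_keys={"hn_client"} / context.resources.hn_client pattern.
    context.log.info("Fetching item 1 from the configured HN client")
    return hn_client.fetch_item_by_id(1)


# Illustrative wiring only: the real project builds its resource mapping from
# the RESOURCES_LOCAL / RESOURCES_STAGING / RESOURCES_PROD dicts in this diff.
defs = Definitions(
    assets=[first_item],
    resources={"hn_client": HNAPIClient()},
)
```
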
benpankow committed Feb 16, 2023
1 parent 8783d2f commit b592861
Showing 13 changed files with 144 additions and 125 deletions.
==== changed file (path not shown) ====
@@ -1,7 +1,14 @@
from datetime import datetime, timezone
from typing import Any, Mapping, Tuple

from dagster import _check as check
from dagster import (
OpExecutionContext,
_check as check,
)

from project_fully_featured_v2_resources.resources.hn_resource import (
HNClient,
)


def binary_search_nearest_left(get_value, start, end, min_target):
@@ -50,7 +57,7 @@ def binary_search_nearest_right(get_value, start, end, max_target):
return end


def _id_range_for_time(start: int, end: int, hn_client):
def _id_range_for_time(start: int, end: int, hn_client: HNClient):
check.invariant(end >= start, "End time comes before start time")

def _get_item_timestamp(item_id):
@@ -83,9 +90,11 @@ def _get_item_timestamp(item_id):
return id_range, metadata


def id_range_for_time(context) -> Tuple[Tuple[int, int], Mapping[str, Any]]:
def id_range_for_time(
context: OpExecutionContext, hn_client: HNClient
) -> Tuple[Tuple[int, int], Mapping[str, Any]]:
"""
For the configured time partition, searches for the range of ids that were created in that time.
"""
start, end = context.asset_partitions_time_window_for_output()
return _id_range_for_time(start.timestamp(), end.timestamp(), context.resources.hn_client)
return _id_range_for_time(int(start.timestamp()), int(end.timestamp()), hn_client)
==== changed file (path not shown) ====
@@ -6,6 +6,9 @@
from pyspark.sql.types import ArrayType, DoubleType, LongType, StringType, StructField, StructType

from project_fully_featured_v2_resources.partitions import hourly_partitions
from project_fully_featured_v2_resources.resources.hn_resource import (
HNClient,
)

from .id_range_for_time import id_range_for_time

@@ -30,19 +33,18 @@

@asset(
io_manager_key="parquet_io_manager",
required_resource_keys={"hn_client"},
partitions_def=hourly_partitions,
key_prefix=["s3", "core"],
)
def items(context) -> Output[DataFrame]:
def items(context, hn_client: HNClient) -> Output[DataFrame]:
"""Items from the Hacker News API: each is a story or a comment on a story."""
(start_id, end_id), item_range_metadata = id_range_for_time(context)
(start_id, end_id), item_range_metadata = id_range_for_time(context, hn_client)

context.log.info(f"Downloading range {start_id} up to {end_id}: {end_id - start_id} items.")

rows = []
for item_id in range(start_id, end_id):
rows.append(context.resources.hn_client.fetch_item_by_id(item_id))
rows.append(hn_client.fetch_item_by_id(item_id))
if len(rows) % 100 == 0:
context.log.info(f"Downloaded {len(rows)} items!")

==== changed file (path not shown) ====
@@ -5,13 +5,10 @@
from dagster_dbt import dbt_cli_resource
from dagster_pyspark import pyspark_resource

from .common_bucket_s3_pickle_io_manager import common_bucket_s3_pickle_io_manager
from .duckdb_parquet_io_manager import duckdb_partitioned_parquet_io_manager
from .common_bucket_s3_pickle_io_manager import CommonBucketS3PickleIOManager
from .duckdb_parquet_io_manager import DuckDBPartitionedParquetIOManager
from .hn_resource import HNAPIClient, HNAPISubsampleClient
from .parquet_io_manager import (
local_partitioned_parquet_io_manager,
s3_partitioned_parquet_io_manager,
)
from .parquet_io_manager import LocalPartitionedParquetIOManager, S3PartitionedParquetIOManager
from .snowflake_io_manager import SnowflakeIOManager

DBT_PROJECT_DIR = file_relative_path(__file__, "../../dbt_project")
@@ -53,37 +50,36 @@
}

RESOURCES_PROD = {
"s3_bucket": "hackernews-elementl-prod",
"io_manager": common_bucket_s3_pickle_io_manager,
"s3": s3_resource,
"parquet_io_manager": s3_partitioned_parquet_io_manager,
"warehouse_io_manager": SnowflakeIOManager(dict(database="DEMO_DB", **SHARED_SNOWFLAKE_CONF)),
"pyspark": configured_pyspark,
"io_manager": CommonBucketS3PickleIOManager(
s3=s3_resource, s3_bucket="hackernews-elementl-prod"
),
"parquet_io_manager": S3PartitionedParquetIOManager(
pyspark=configured_pyspark, s3_bucket="hackernews-elementl-dev"
),
"warehouse_io_manager": SnowflakeIOManager(database="DEMO_DB", **SHARED_SNOWFLAKE_CONF),
"hn_client": HNAPISubsampleClient(subsample_rate=10),
"dbt": dbt_prod_resource,
}


RESOURCES_STAGING = {
"s3_bucket": "hackernews-elementl-dev",
"io_manager": common_bucket_s3_pickle_io_manager,
"s3": s3_resource,
"parquet_io_manager": s3_partitioned_parquet_io_manager,
"warehouse_io_manager": SnowflakeIOManager(
dict(database="DEMO_DB_STAGING", **SHARED_SNOWFLAKE_CONF)
"io_manager": CommonBucketS3PickleIOManager(
s3=s3_resource, s3_bucket="hackernews-elementl-dev"
),
"parquet_io_manager": S3PartitionedParquetIOManager(
pyspark=configured_pyspark, s3_bucket="hackernews-elementl-dev"
),
"pyspark": configured_pyspark,
"warehouse_io_manager": SnowflakeIOManager(database="DEMO_DB_STAGING", **SHARED_SNOWFLAKE_CONF),
"hn_client": HNAPISubsampleClient(subsample_rate=10),
"dbt": dbt_staging_resource,
}


RESOURCES_LOCAL = {
"parquet_io_manager": local_partitioned_parquet_io_manager,
"warehouse_io_manager": duckdb_partitioned_parquet_io_manager.configured(
{"duckdb_path": os.path.join(DBT_PROJECT_DIR, "hackernews.duckdb")},
"parquet_io_manager": LocalPartitionedParquetIOManager(pyspark=configured_pyspark),
"warehouse_io_manager": DuckDBPartitionedParquetIOManager(
pyspark=configured_pyspark, duckdb_path=os.path.join(DBT_PROJECT_DIR, "hackernews.duckdb")
),
"pyspark": configured_pyspark,
"hn_client": HNAPIClient(),
"dbt": dbt_local_resource,
}
==== changed file (path not shown) ====
@@ -1,15 +1,23 @@
from dagster import build_init_resource_context, io_manager
from typing import Any

from dagster import build_init_resource_context
from dagster._config.structured_config import ConfigurableIOManagerInjector, ResourceDependency
from dagster._core.storage.io_manager import IOManager
from dagster_aws.s3 import s3_pickle_io_manager


@io_manager(required_resource_keys={"s3_bucket", "s3"})
def common_bucket_s3_pickle_io_manager(init_context):
class CommonBucketS3PickleIOManager(ConfigurableIOManagerInjector):
"""
A version of the s3_pickle_io_manager that gets its bucket from another resource.
"""
return s3_pickle_io_manager(
build_init_resource_context(
config={"s3_bucket": init_context.resources.s3_bucket},
resources={"s3": init_context.resources.s3},

s3_bucket: str
s3: ResourceDependency[Any]

def create_io_manager_to_inject(self, context) -> IOManager:
return s3_pickle_io_manager(
build_init_resource_context(
config={"s3_bucket": self.s3_bucket},
resources={"s3": self.s3},
)
)
)
==== changed file (path not shown) ====
@@ -3,10 +3,8 @@
import duckdb
import pandas as pd
from dagster import (
Field,
PartitionKeyRange,
_check as check,
io_manager,
)
from dagster._seven.temp_dir import get_system_temp_directory

@@ -16,9 +14,12 @@
class DuckDBPartitionedParquetIOManager(PartitionedParquetIOManager):
"""Stores data in parquet files and creates duckdb views over those files."""

def __init__(self, base_path: str, duckdb_path: str):
super().__init__(base_path=base_path)
self._duckdb_path = check.str_param(duckdb_path, "duckdb_path")
duckdb_path: str
base_path: str = get_system_temp_directory()

@property
def _base_path(self):
return self.base_path

def handle_output(self, context, obj):
if obj is not None: # if this is a dbt output, then the value will be None
@@ -63,15 +64,4 @@ def _schema(self, context) -> str:
return f"{context.asset_key.path[-2]}"

def _connect_duckdb(self):
return duckdb.connect(database=self._duckdb_path, read_only=False)


@io_manager(
config_schema={"base_path": Field(str, is_required=False), "duckdb_path": str},
required_resource_keys={"pyspark"},
)
def duckdb_partitioned_parquet_io_manager(init_context):
return DuckDBPartitionedParquetIOManager(
base_path=init_context.resource_config.get("base_path", get_system_temp_directory()),
duckdb_path=init_context.resource_config["duckdb_path"],
)
return duckdb.connect(database=self.duckdb_path, read_only=False)
==== changed file (path not shown) ====
@@ -4,14 +4,16 @@
from typing import Any, Dict, Optional

import requests
from dagster._config.structured_config import ConfigurableResource
from dagster._utils import file_relative_path
from dagster._utils.cached_method import cached_method

HNItemRecord = Dict[str, Any]

HN_BASE_URL = "https://hacker-news.firebaseio.com/v0"


class HNClient(ABC):
class HNClient(ConfigurableResource, ABC):
@abstractmethod
def fetch_item_by_id(self, item_id: int) -> Optional[HNItemRecord]:
pass
@@ -39,19 +41,20 @@ def min_item_id(self) -> int:


class HNSnapshotClient(HNClient):
def __init__(self):
@cached_method
def load_items(self) -> Dict[str, HNItemRecord]:
file_path = file_relative_path(__file__, "../utils/snapshot.gzip")
with gzip.open(file_path, "r") as f:
self._items: Dict[str, HNItemRecord] = json.loads(f.read().decode())
return json.loads(f.read().decode())

def fetch_item_by_id(self, item_id: int) -> Optional[HNItemRecord]:
return self._items.get(str(item_id))
return self.load_items().get(str(item_id))

def fetch_max_item_id(self) -> int:
return int(list(self._items.keys())[-1])
return int(list(self.load_items().keys())[-1])

def min_item_id(self) -> int:
return int(list(self._items.keys())[0])
return int(list(self.load_items().keys())[0])


class HNAPISubsampleClient(HNClient):
@@ -60,9 +63,8 @@ class HNAPISubsampleClient(HNClient):
which is useful for testing / demoing purposes.
"""

def __init__(self, subsample_rate):
self._items = {}
self.subsample_rate = subsample_rate
subsample_rate: int
_items: Dict[int, HNItemRecord] = {}

def fetch_item_by_id(self, item_id: int) -> Optional[HNItemRecord]:
# map self.subsample_rate items to the same item_id, caching it for faster performance
==== changed file (path not shown) ====
@@ -2,20 +2,18 @@
from typing import Union

import pandas
import pyspark
import pyspark.sql
from dagster import (
Field,
InputContext,
IOManager,
OutputContext,
_check as check,
io_manager,
)
from dagster._config.structured_config import ConfigurableIOManager, ResourceDependency
from dagster._seven.temp_dir import get_system_temp_directory
from dagster_pyspark.resources import PySparkResource
from pyspark.sql import DataFrame as PySparkDataFrame


class PartitionedParquetIOManager(IOManager):
class PartitionedParquetIOManager(ConfigurableIOManager):
"""
This IOManager will take in a pandas or pyspark dataframe and store it in parquet at the
specified path.
@@ -26,12 +24,13 @@ class PartitionedParquetIOManager(IOManager):
to where the data is stored.
"""

def __init__(self, base_path):
self._base_path = base_path
pyspark: ResourceDependency[PySparkResource]

def handle_output(
self, context: OutputContext, obj: Union[pandas.DataFrame, pyspark.sql.DataFrame]
):
@property
def _base_path(self):
raise NotImplementedError()

def handle_output(self, context: OutputContext, obj: Union[pandas.DataFrame, PySparkDataFrame]):
path = self._get_path(context)
if "://" not in self._base_path:
os.makedirs(os.path.dirname(path), exist_ok=True)
@@ -40,19 +39,19 @@ def handle_output(
row_count = len(obj)
context.log.info(f"Row count: {row_count}")
obj.to_parquet(path=path, index=False)
elif isinstance(obj, pyspark.sql.DataFrame):
elif isinstance(obj, PySparkDataFrame):
row_count = obj.count()
obj.write.parquet(path=path, mode="overwrite")
else:
raise Exception(f"Outputs of type {type(obj)} not supported.")

context.add_output_metadata({"row_count": row_count, "path": path})

def load_input(self, context) -> Union[pyspark.sql.DataFrame, str]:
def load_input(self, context) -> Union[PySparkDataFrame, str]:
path = self._get_path(context)
if context.dagster_type.typing_type == pyspark.sql.DataFrame:
if context.dagster_type.typing_type == PySparkDataFrame:
# return pyspark dataframe
return context.resources.pyspark.spark_session.read.parquet(path)
return self.pyspark.spark_session.read.parquet(path)

return check.failed(
f"Inputs of type {context.dagster_type} not supported. Please specify a valid type "
@@ -71,16 +70,17 @@ def _get_path(self, context: Union[InputContext, OutputContext]):
return os.path.join(self._base_path, f"{key}.pq")


@io_manager(
config_schema={"base_path": Field(str, is_required=False)},
required_resource_keys={"pyspark"},
)
def local_partitioned_parquet_io_manager(init_context):
return PartitionedParquetIOManager(
base_path=init_context.resource_config.get("base_path", get_system_temp_directory())
)
class LocalPartitionedParquetIOManager(PartitionedParquetIOManager):
base_path: str = get_system_temp_directory()

@property
def _base_path(self):
return self.base_path


class S3PartitionedParquetIOManager(PartitionedParquetIOManager):
s3_bucket: str

@io_manager(required_resource_keys={"pyspark", "s3_bucket"})
def s3_partitioned_parquet_io_manager(init_context):
return PartitionedParquetIOManager(base_path="s3://" + init_context.resources.s3_bucket)
@property
def _base_path(self):
return "s3://" + self.s3_bucket
