Skip to content

Commit

Permalink
update fake adls resource to work with leases (#7587)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed May 3, 2022
1 parent 671cea6 commit c55e3f4
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .fake_adls2_resource import FakeADLS2Resource, FakeADLS2ServiceClient
from .fake_adls2_resource import FakeADLS2Resource, FakeADLS2ServiceClient, fake_adls2_resource
from .file_cache import ADLS2FileCache, adls2_file_cache
from .file_manager import ADLS2FileHandle, ADLS2FileManager
from .io_manager import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import io
import random
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict
from unittest import mock

from dagster_azure.blob import FakeBlobServiceClient

from dagster import resource

from .resources import ADLS2Resource
from .utils import ResourceNotFoundError


@resource({"account_name": str})
def fake_adls2_resource(context):
    # Resource factory for tests: reads the configured account name and hands
    # back the stateful FakeADLS2Resource mock built for that account.
    account_name = context.resource_config["account_name"]
    return FakeADLS2Resource(account_name=account_name)


class FakeADLS2Resource(ADLS2Resource):
"""Stateful mock of an ADLS2Resource for testing.
Expand All @@ -21,6 +27,31 @@ def __init__(
): # pylint: disable=unused-argument,super-init-not-called
self._adls2_client = FakeADLS2ServiceClient(account_name)
self._blob_client = FakeBlobServiceClient(account_name)
self._lease_client_constructor = FakeLeaseClient


class FakeLeaseClient:
    """Fake lease client used together with the fake ADLS2 clients in tests.

    Constructing it stores the instance on ``client._lease`` so that the
    client can check whether a lease id passed to one of its operations is
    the one currently held.
    """

    def __init__(self, client):
        self.client = client
        self.id = None

        # Back-reference: the client consults this object to validate leases.
        client._lease = self

    def acquire(self, lease_duration=-1):  # pylint: disable=unused-argument
        """Take the lease by assigning it a random integer id.

        Raises:
            Exception: if the lease is already held.
        """
        if self.id is not None:
            raise Exception("Lease already held")
        self.id = random.randint(0, 2**9)

    def release(self):
        """Give up the lease by clearing its id."""
        self.id = None

    def is_valid(self, lease):
        """Return whether ``lease`` allows an operation on the client.

        When no lease is held any operation is valid; otherwise ``lease``
        must equal the held id.
        """
        return self.id is None or lease == self.id


class FakeADLS2ServiceClient:
Expand Down Expand Up @@ -61,7 +92,7 @@ class FakeADLS2FilesystemClient:
"""Stateful mock of an ADLS2 filesystem client for testing."""

def __init__(self, account_name, file_system_name):
self._file_system = defaultdict(FakeADLS2FileClient)
self._file_system: Dict[str, FakeADLS2FileClient] = {}
self._account_name = account_name
self._file_system_name = file_system_name

Expand All @@ -83,9 +114,14 @@ def has_file(self, path):
return bool(self._file_system.get(path))

def get_file_client(self, file_path):
    """Return the file client stored under ``file_path``, creating it if absent.

    A newly created file client is given a reference to this filesystem
    client and its own name so the file can delete itself.
    """
    # Keyword arguments matter here: FakeADLS2FileClient.__init__ takes
    # (name, fs_client), so passing (self, file_path) positionally would bind
    # the filesystem client to ``name`` and the path string to ``fs_client``.
    self._file_system.setdefault(
        file_path, FakeADLS2FileClient(fs_client=self, name=file_path)
    )
    return self._file_system[file_path]

def create_file(self, file):
    """Return the file client registered under ``file``, registering a new one if needed."""
    # The file client gets a ref to this filesystem client and its own name so
    # the file can delete itself by accessing the self._file_system dict.
    candidate = FakeADLS2FileClient(fs_client=self, name=file)
    return self._file_system.setdefault(file, candidate)

def delete_file(self, file):
Expand All @@ -97,18 +133,25 @@ def delete_file(self, file):
class FakeADLS2FileClient:
"""Stateful mock of an ADLS2 file client for testing."""

def __init__(self):
def __init__(self, name, fs_client):
self.name = name
self.contents = None
self.lease = None
self._lease = None
self.fs_client = fs_client

@property
def lease(self):
return self._lease if self._lease is None else self._lease.id

def get_file_properties(self):
if self.contents is None:
raise ResourceNotFoundError("File does not exist!")
return {"lease": self.lease}
lease_id = None if self._lease is None else self._lease.id
return {"lease": lease_id}

def upload_data(self, contents, overwrite=False, lease=None):
if self.lease is not None:
if lease != self.lease:
if self._lease is not None:
if not self._lease.is_valid(lease):
raise Exception("Invalid lease!")
if self.contents is not None or overwrite is True:
if isinstance(contents, str):
Expand All @@ -122,22 +165,17 @@ def upload_data(self, contents, overwrite=False, lease=None):
else:
self.contents = contents

@contextmanager
def acquire_lease(self, lease_duration=-1): # pylint: disable=unused-argument
if self.lease is None:
self.lease = random.randint(0, 2**9)
try:
yield self.lease
finally:
self.lease = None
else:
raise Exception("Lease already held")

def download_file(self):
    """Wrap the stored contents in a ``FakeADLS2FileDownloader``.

    Raises:
        ResourceNotFoundError: if nothing has been stored (contents is ``None``).
    """
    data = self.contents
    if data is None:
        raise ResourceNotFoundError("File does not exist!")
    return FakeADLS2FileDownloader(contents=data)

def delete_file(self, lease=None):
if self._lease is not None:
if not self._lease.is_valid(lease):
raise Exception("Invalid lease!")
self.fs_client.delete_file(self.name)


class FakeADLS2FileDownloader:
"""Mock of an ADLS2 file downloader for testing."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pickle
from contextlib import contextmanager

from azure.storage.filedatalake import DataLakeLeaseClient
from dagster_azure.adls2.utils import ResourceNotFoundError

from dagster import Field, IOManager, StringSource, check, io_manager
Expand All @@ -11,14 +10,17 @@


class PickledObjectADLS2IOManager(IOManager):
def __init__(self, file_system, adls2_client, blob_client, prefix="dagster"):
def __init__(
self, file_system, adls2_client, blob_client, lease_client_constructor, prefix="dagster"
):
self.adls2_client = adls2_client
self.file_system_client = self.adls2_client.get_file_system_client(file_system)
# We also need a blob client to handle copying as ADLS doesn't have a copy API yet
self.blob_client = blob_client
self.blob_container_client = self.blob_client.get_container_client(file_system)
self.prefix = check.str_param(prefix, "prefix")

self.lease_client_constructor = lease_client_constructor
self.lease_duration = _LEASE_DURATION
self.file_system_client.get_file_system_properties()

Expand Down Expand Up @@ -68,7 +70,7 @@ def _uri_for_key(self, key, protocol=None):

@contextmanager
def _acquire_lease(self, client, is_rm=False):
lease_client = DataLakeLeaseClient(client=client)
lease_client = self.lease_client_constructor(client=client)
try:
lease_client.acquire(lease_duration=self.lease_duration)
yield lease_client.id
Expand Down Expand Up @@ -140,10 +142,12 @@ def my_job():
adls_resource = init_context.resources.adls2
adls2_client = adls_resource.adls2_client
blob_client = adls_resource.blob_client
lease_client = adls_resource.lease_client_constructor
pickled_io_manager = PickledObjectADLS2IOManager(
init_context.resource_config["adls2_file_system"],
adls2_client,
blob_client,
lease_client,
init_context.resource_config.get("adls2_prefix"),
)
return pickled_io_manager
Expand Down Expand Up @@ -194,10 +198,12 @@ def adls2_pickle_asset_io_manager(init_context):
adls_resource = init_context.resources.adls2
adls2_client = adls_resource.adls2_client
blob_client = adls_resource.blob_client
lease_client = adls_resource.lease_client_constructor
pickled_io_manager = PickledObjectADLS2AssetIOManager(
init_context.resource_config["adls2_file_system"],
adls2_client,
blob_client,
lease_client,
init_context.resource_config.get("adls2_prefix"),
)
return pickled_io_manager
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from azure.storage.filedatalake import DataLakeLeaseClient
from dagster_azure.blob.utils import create_blob_client

from dagster import Field, Selector, StringSource, resource
Expand Down Expand Up @@ -101,6 +102,7 @@ class ADLS2Resource:
def __init__(self, storage_account, credential):
    # Both service clients are created from the same account/credential pair.
    adls2 = create_adls2_client(storage_account, credential)
    blob = create_blob_client(storage_account, credential)
    self._adls2_client, self._blob_client = adls2, blob
    # Kept as a constructor rather than an instance; it is exposed through
    # the lease_client_constructor property.
    self._lease_client_constructor = DataLakeLeaseClient

@property
def adls2_client(self):
Expand All @@ -110,6 +112,10 @@ def adls2_client(self):
def blob_client(self):
return self._blob_client

@property
def lease_client_constructor(self):
return self._lease_client_constructor


def _adls2_resource_from_config(config):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from uuid import uuid4

import pytest
from azure.storage.filedatalake import DataLakeLeaseClient
from dagster_azure.adls2 import create_adls2_client
from dagster_azure.adls2.fake_adls2_resource import fake_adls2_resource
from dagster_azure.adls2.io_manager import (
PickledObjectADLS2IOManager,
adls2_pickle_asset_io_manager,
Expand Down Expand Up @@ -55,7 +57,7 @@ def get_step_output(step_events, step_key, output_name="result"):
return None


def define_inty_job():
def define_inty_job(adls_io_resource=adls2_resource):
@op(output_defs=[OutputDefinition(Int)])
def return_one():
return 1
Expand All @@ -73,7 +75,7 @@ def basic_external_plan_execution():
add_one(return_one())

return basic_external_plan_execution.to_job(
resource_defs={"io_manager": adls2_pickle_io_manager, "adls2": adls2_resource}
resource_defs={"io_manager": adls2_pickle_io_manager, "adls2": adls_io_resource}
)


Expand Down Expand Up @@ -124,6 +126,7 @@ def test_adls2_pickle_io_manager_execution(storage_account, file_system, credent
file_system=file_system,
adls2_client=create_adls2_client(storage_account, credential),
blob_client=create_blob_client(storage_account, credential),
lease_client_constructor=DataLakeLeaseClient,
)
assert io_manager.load_input(context) == 1

Expand Down Expand Up @@ -181,3 +184,17 @@ def downstream(upstream):

result = asset_job.execute_in_process(run_config=run_config)
assert result.success


def test_with_fake_adls2_resource():
    # Runs the int job with the fake ADLS2 resource swapped in for the real one.
    io_manager_config = {"adls2_file_system": "fake_file_system"}
    adls2_config = {"account_name": "my_account"}
    run_config = {
        "resources": {
            "io_manager": {"config": io_manager_config},
            "adls2": {"config": adls2_config},
        }
    }

    job = define_inty_job(adls_io_resource=fake_adls2_resource)
    outcome = job.execute_in_process(run_config=run_config)
    assert outcome.success

0 comments on commit c55e3f4

Please sign in to comment.