Replace session scope fixtures with pytest_sessionstart and pytest_sessionfinish methods

laughingman7743 committed May 3, 2023
1 parent 3a53ac3 commit 6adaaeb
Showing 3 changed files with 29 additions and 28 deletions.
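
For reference, the refactoring pattern: a session-scoped autouse fixture is swapped for pytest's built-in session hooks, which run exactly once per test run. A minimal sketch, with setup()/teardown() as hypothetical stand-ins for the real helpers in the diff below:

import pytest

def setup():     # hypothetical: upload fixtures, create schema, etc.
    ...

def teardown():  # hypothetical: drop schema, delete objects, etc.
    ...

# Before: a session-scoped autouse fixture plus a registered finalizer.
@pytest.fixture(scope="session", autouse=True)
def _setup_session(request):
    request.addfinalizer(teardown)
    setup()

# After: plain module-level hooks in conftest.py. pytest calls
# pytest_sessionstart once before collection and pytest_sessionfinish
# once after the last test, even when tests fail.
def pytest_sessionstart(session):
    setup()

def pytest_sessionfinish(session, exitstatus):
    teardown()

Because the hooks live at module level in conftest.py and run outside fixture resolution, setup happens before test collection rather than when the first test triggers the autouse fixture.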
3 changes: 3 additions & 0 deletions tests/__init__.py
@@ -21,6 +21,9 @@ def __init__(self):
         self.schema = "pyathena_test_" + "".join(
             [random.choice(string.ascii_lowercase + string.digits) for _ in range(10)]
         )
+        self.s3_filesystem_test_file_key = (
+            f"{self.s3_staging_key}{self.schema}/S3FileSystem__test_read.dat"
+        )
 
 
 ENV = Env()
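
Defining the key on Env gives conftest.py and tests/filesystem/test_s3.py a single shared definition via the module-level ENV singleton. A sketch of what both sides now resolve, assuming the layout above:

from tests import ENV

# conftest.py uploads to this key; test_s3.py reads it back.
uri = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}"
print(uri)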
38 changes: 21 additions & 17 deletions tests/conftest.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 import contextlib
+from io import BytesIO
 from pathlib import Path
 from urllib.parse import quote_plus
 
@@ -11,6 +12,21 @@
 from tests.util import read_query
 
 
+def pytest_sessionstart(session):
+    _upload_rows()
+    with contextlib.closing(connect()) as conn:
+        with conn.cursor() as cursor:
+            _create_database(cursor)
+            _create_table(cursor)
+
+
+def pytest_sessionfinish(session):
+    with contextlib.closing(connect()) as conn:
+        with conn.cursor() as cursor:
+            _drop_database(cursor)
+    _delete_rows()
+
+
 def connect(schema_name="default", **kwargs):
     from pyathena import connect
 
@@ -51,29 +67,17 @@ def create_engine(**kwargs):
     )
 
 
-@pytest.fixture(scope="session", autouse=True)
-def _setup_session(request):
-    request.addfinalizer(_teardown_session)
-    _upload_rows()
-    with contextlib.closing(connect()) as conn:
-        with conn.cursor() as cursor:
-            _create_database(cursor)
-            _create_table(cursor)
-
-
-def _teardown_session():
-    with contextlib.closing(connect()) as conn:
-        with conn.cursor() as cursor:
-            _drop_database(cursor)
-    _delete_rows()
-
-
 def _upload_rows():
     client = boto3.client("s3")
     rows = Path(__file__).parent.resolve() / "resources" / "rows"
     for row in rows.iterdir():
         key = f"{ENV.s3_staging_key}{ENV.schema}/{row.stem}/{row.name}"
         client.upload_file(str(row), ENV.s3_staging_bucket, key)
+    client.upload_fileobj(
+        BytesIO(b"0123456789"),
+        ENV.s3_staging_bucket,
+        ENV.s3_filesystem_test_file_key,
+    )
 
 
 def _delete_rows():
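
With setup consolidated into pytest_sessionstart, _upload_rows() now also creates the 10-byte object that the S3FileSystem tests read, once per session rather than once per test class. A quick out-of-band sanity check, assuming AWS credentials and the bucket/key values from ENV:

import boto3

from tests import ENV

# Fetch the object _upload_rows() uploads and confirm its contents.
client = boto3.client("s3")
body = client.get_object(
    Bucket=ENV.s3_staging_bucket,
    Key=ENV.s3_filesystem_test_file_key,
)["Body"].read()
assert body == b"0123456789"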
16 changes: 5 additions & 11 deletions tests/filesystem/test_s3.py
@@ -12,8 +12,6 @@
 
 
 class TestS3FileSystem:
-    s3_test_file_key = f"{ENV.s3_staging_key}{ENV.schema}/S3FileSystem__test_read.dat"
-
     def test_parse_path(self):
         actual = S3FileSystem.parse_path("s3://bucket")
         assert actual[0] == "bucket"
@@ -114,12 +112,6 @@ def test_parse_path_invalid(self):
 
     @pytest.fixture(scope="class")
     def fs(self) -> Dict[str, S3FileSystem]:
-        client = boto3.client("s3")
-        client.upload_fileobj(
-            BytesIO(b"0123456789"),
-            ENV.s3_staging_bucket,
-            self.s3_test_file_key,
-        )
         fs = {
             "default": S3FileSystem(connect()),
             "small_batches": S3FileSystem(connect(), default_block_size=3),
@@ -144,11 +136,11 @@ def fs(self) -> Dict[str, S3FileSystem]:
     def test_read(self, fs, start, end, batch_mode, target_data):
         # lowest level access: use _get_object
         data = fs[batch_mode]._get_object(
-            ENV.s3_staging_bucket, self.s3_test_file_key, ranges=(start, end)
+            ENV.s3_staging_bucket, ENV.s3_filesystem_test_file_key, ranges=(start, end)
         )
         assert data == (start, target_data), data
         with fs[batch_mode].open(
-            f"s3://{ENV.s3_staging_bucket}/{self.s3_test_file_key}", "rb"
+            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}", "rb"
         ) as file:
             # mid-level access: use _fetch_range
             data = file._fetch_range(start, end)
@@ -162,7 +154,9 @@ def test_compatibility_with_s3fs(self):
         import pandas
 
         df = pandas.read_csv(
-            f"s3://{ENV.s3_staging_bucket}/{self.s3_test_file_key}", header=None, names=["col"]
+            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}",
+            header=None,
+            names=["col"],
         )
         assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)]