# TEST s3

Below is an expanded pytest test suite that demonstrates both unit tests (mocked) and a real integration test against an actual S3 location. It also includes negative tests for edge cases such as a non-existent bucket or a non-existent prefix.

Important:

Integration Tests require valid AWS/S3 credentials and network access to an S3 endpoint.
For local testing with MinIO or a custom S3 endpoint, you must configure your environment variables accordingly (e.g., AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_DEFAULT_REGION, S3_ENDPOINT_URL, etc.).
Make sure PySpark is configured to access your S3 endpoint (via spark._jsc.hadoopConfiguration().set(...) or the spark-submit CLI configs).
Sample test_s3_read.py with Pytest

In [None]:
import os
import pytest
from unittest.mock import MagicMock
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException
from pyspark.sql.types import StructType, StructField, StringType
from s3 import (
    read_from_s3_spark,
    _create_df_daterange,
    _parse_raw_json,
    _apply_filters,
    init_s3_spark
)

# ------------------------------------------------------------------------------
#                               Pytest Fixtures
# ------------------------------------------------------------------------------

@pytest.fixture(scope="session")
def spark_session():
    """
    Session-scoped SparkSession shared by every test in the run.

    Builds (or reuses, via getOrCreate) a single-core local Spark instance,
    hands it to the tests, and stops it once the pytest session is over.
    """
    builder = (
        SparkSession.builder
        .master("local[1]")
        .appName("PytestSpark")
    )
    session = builder.getOrCreate()

    # To reach S3 through the s3a connector, populate the Hadoop configuration
    # here before yielding, for example:
    # hadoop_conf = session._jsc.hadoopConfiguration()
    # hadoop_conf.set("fs.s3a.access.key", os.environ.get("AWS_ACCESS_KEY_ID"))
    # hadoop_conf.set("fs.s3a.secret.key", os.environ.get("AWS_SECRET_ACCESS_KEY"))
    # hadoop_conf.set("fs.s3a.endpoint", os.environ.get("S3_ENDPOINT_URL", "s3.amazonaws.com"))
    # hadoop_conf.set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")

    yield session
    session.stop()

@pytest.fixture
def mock_s3_spark(spark_session):
    """
    Spark session handed to the S3-facing tests.

    As written this simply forwards the local ``spark_session`` fixture, so no
    S3-specific configuration is applied. Alternatives:

    * isolation: return a ``MagicMock`` instead;
    * real S3 connectivity: build a session with ``init_s3_spark`` and real
      credentials, e.g.::

          return init_s3_spark(
              endpoint=os.environ["S3_ENDPOINT_URL"],
              s3_access_key=os.environ["AWS_ACCESS_KEY_ID"],
              s3_secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
              spark_home="/path/to/spark/home",
              spark_remote=None,  # or "spark://..." depending on your setup
              ca_certs=None,
              session_name="PytestSession",
          )
    """
    shared_session = spark_session
    return shared_session

# ------------------------------------------------------------------------------
#                           Unit Tests (Mocked)
# ------------------------------------------------------------------------------

def test_read_from_s3_spark_no_select_no_filters_no_parse(spark_session, mocker):
    """
    Minimal call of read_from_s3_spark: no column selection, no filters and no
    JSON column to parse. The DataFrame produced by the (mocked) date-range
    loader must come back with the same schema and content.
    """
    stub_schema = StructType([StructField("dummy_col", StringType(), True)])
    stub_df = spark_session.createDataFrame([("test_val",)], schema=stub_schema)

    # Bypass S3 entirely: the date-range loader hands back the stub.
    loader = mocker.patch("s3._create_df_daterange", return_value=stub_df)

    out = read_from_s3_spark(
        s3_spark=spark_session,
        bucket="fake-bucket",
        prefix="fake-prefix",
        start_date=datetime(2021, 1, 1),
        end_date=datetime(2021, 1, 2),
        select=None,
        filters=None,
        parse_column=None,
    )

    assert loader.called, "Expected _create_df_daterange to be called."
    assert out.schema == stub_df.schema, "Result schema should match the stub schema."
    assert out.count() == 1, "Result dataframe should have one row."
    first_row = out.collect()[0]
    assert first_row["dummy_col"] == "test_val", "Row value should match the stubbed value."

def test_read_from_s3_spark_with_select(spark_session, mocker):
    """
    A 'select' list must narrow the loaded DataFrame to exactly the listed
    columns.
    """
    two_col_schema = StructType(
        [
            StructField("dummy_col", StringType(), True),
            StructField("another_col", StringType(), True),
        ]
    )
    loaded = spark_session.createDataFrame(
        [("test_val", "another_val")], schema=two_col_schema
    )
    loader = mocker.patch("s3._create_df_daterange", return_value=loaded)

    wanted = ["dummy_col"]  # keep only this column
    out = read_from_s3_spark(
        s3_spark=spark_session,
        bucket="fake-bucket",
        prefix="fake-prefix",
        start_date=datetime(2021, 1, 1),
        end_date=datetime(2021, 1, 2),
        select=wanted,
        filters=None,
        parse_column=None,
    )

    assert loader.called
    columns = out.columns
    assert len(columns) == 1, "Result should have exactly 1 column from 'select'."
    assert "dummy_col" in columns, "Result must include 'dummy_col'."
    assert "another_col" not in columns, "'another_col' should be excluded by the select."

def test_read_from_s3_spark_with_parse_column(spark_session, mocker):
    """
    Test read_from_s3_spark when parse_column is provided.

    Both helpers are mocked: _create_df_daterange returns a raw-JSON stub and
    _parse_raw_json returns an already-parsed stub. The test checks three things:

    * the loader is used;
    * _parse_raw_json is called exactly once, with the loaded DataFrame as its
      first positional argument;
    * the parsed DataFrame is what the caller gets back.
    """
    df_stub_initial = spark_session.createDataFrame(
        [("raw_json_data",)],
        schema=StructType([StructField("raw_json", StringType(), True)]),
    )
    df_stub_parsed = spark_session.createDataFrame(
        [("extracted_val",)],
        schema=StructType([StructField("parsed_col", StringType(), True)]),
    )

    mock_create_df = mocker.patch("s3._create_df_daterange", return_value=df_stub_initial)
    mock_parse_json = mocker.patch("s3._parse_raw_json", return_value=df_stub_parsed)

    result_df = read_from_s3_spark(
        s3_spark=spark_session,
        bucket="fake-bucket",
        prefix="fake-prefix",
        start_date=datetime(2021, 1, 1),
        end_date=datetime(2021, 1, 2),
        select=["parsed_col"],
        filters={},
        parse_column="raw_json",
    )

    # Make sure the stubbed loader was actually the data source.
    assert mock_create_df.called, "Expected _create_df_daterange to be called."

    mock_parse_json.assert_called_once()
    called_args, _ = mock_parse_json.call_args
    # Only the DataFrame argument is checked here.
    # NOTE(review): the remaining positional arguments are presumably
    # (select, filters, column_to_parse) — confirm against s3._parse_raw_json
    # before asserting on them.
    assert called_args[0].schema == df_stub_initial.schema, \
        "Should pass the original DF to _parse_raw_json."

    # The final DataFrame should come from the parsed stub.
    assert result_df.schema == df_stub_parsed.schema
    assert result_df.collect()[0]["parsed_col"] == "extracted_val"

def test_read_from_s3_spark_with_filters(spark_session, mocker):
    """
    Test read_from_s3_spark with a 'filters' dict and no JSON parsing.

    Only _create_df_daterange is mocked, so read_from_s3_spark's real filtering
    path runs against a two-row stub. The filter filter_col == "match" must
    keep exactly the "keep_row" row.

    This test checks the filtered output only. It does not assert which
    internal helper (e.g. _apply_filters) performs the filtering.
    """
    schema = StructType([
        StructField("dummy_col", StringType(), True),
        StructField("filter_col", StringType(), True),
    ])
    df_stub = spark_session.createDataFrame(
        [
            ("keep_row", "match"),
            ("discard_row", "no_match"),
        ],
        schema=schema,
    )
    mock_create_df = mocker.patch("s3._create_df_daterange", return_value=df_stub)

    # Keep only rows where filter_col == "match".
    filters = {"filter_col": "match"}
    result_df = read_from_s3_spark(
        s3_spark=spark_session,
        bucket="fake-bucket",
        prefix="fake-prefix",
        start_date=datetime(2021, 1, 1),
        end_date=datetime(2021, 1, 2),
        select=None,
        filters=filters,
        parse_column=None,
    )

    assert mock_create_df.called, "Expected _create_df_daterange to be called."
    # Should only keep the row matching filter_col == "match".
    assert result_df.count() == 1, "Only one row should remain after filter."
    row = result_df.collect()[0]
    assert row["dummy_col"] == "keep_row", "The row with 'discard_row' should be filtered out."

# ------------------------------------------------------------------------------
#                Direct Tests for Helper Methods (Optional)
# ------------------------------------------------------------------------------

def test__create_df_daterange_single_date(spark_session, mocker):
    """
    A one-element date list must trigger exactly one parquet read, for the
    path built from bucket, prefix, the (mocked) date label and the ISO date.
    """
    mocker.patch("s3.get_text_before_date", return_value="date_label=")

    fake_spark = MagicMock()
    fake_spark.read.parquet.return_value = spark_session.createDataFrame(
        [("test",)], ["col1"]
    )

    out = _create_df_daterange(
        s3_spark=fake_spark,
        daterange=[datetime(2021, 1, 1)],
        bucket="fake-bucket",
        prefix="fake-prefix",
    )

    assert out is not None
    # Note that no separator is expected between the prefix and the label.
    fake_spark.read.parquet.assert_called_once_with(
        "s3a://fake-bucket/fake-prefixdate_label=2021-01-01/"
    )

def test__create_df_daterange_range_of_dates(spark_session, mocker):
    """
    A two-element date list [start, end] is expected to be expanded to every
    day in between. One parquet read per day, 2021-01-01 through 2021-01-03,
    must appear among the calls.
    """
    mocker.patch("s3.get_text_before_date", return_value="date_label=")

    fake_spark = MagicMock()
    fake_spark.read.parquet.return_value = spark_session.createDataFrame(
        [("test",)], ["col1"]
    )

    out = _create_df_daterange(
        s3_spark=fake_spark,
        daterange=[datetime(2021, 1, 1), datetime(2021, 1, 3)],
        bucket="fake-bucket",
        prefix="fake-prefix",
    )
    assert out is not None

    # Each recorded call unpacks to (positional_args, keyword_args).
    read_paths = {
        str(positional[0])
        for positional, _kw in fake_spark.read.parquet.call_args_list
    }
    for day in ("2021-01-01", "2021-01-02", "2021-01-03"):
        assert f"s3a://fake-bucket/fake-prefixdate_label={day}/" in read_paths

# ------------------------------------------------------------------------------
#                   Integration Tests (Actual S3 Access)
# ------------------------------------------------------------------------------

@pytest.mark.integration
def test_read_from_s3_spark_integration(mock_s3_spark):
    """
    Smoke test against real S3: read one day of data from the bucket/prefix
    named by TEST_S3_BUCKET / TEST_S3_PREFIX.

    Needs valid credentials, network access and data you may read. It carries
    the 'integration' marker, so it can be skipped with:
      `pytest -m "not integration"`
    """
    bucket = os.environ.get("TEST_S3_BUCKET", "my-real-bucket")
    prefix = os.environ.get("TEST_S3_PREFIX", "my-data/")
    day = datetime(2021, 1, 1)

    df = read_from_s3_spark(
        s3_spark=mock_s3_spark,  # must be a session actually configured for S3
        bucket=bucket,
        prefix=prefix,
        start_date=day,
        end_date=day,
        select=None,
        filters=None,
        parse_column=None,
    )

    # Only checks that a DataFrame comes back without an exception.
    # If the day is known to contain data, tighten this, e.g.:
    #     assert df.count() > 0
    assert df is not None, "Should successfully return a DataFrame from S3"

@pytest.mark.integration
def test_read_from_s3_spark_nonexistent_bucket(mock_s3_spark):
    """
    Negative test: reading from a bucket that does not exist must raise.

    The exact exception type and wording depend on the Spark/Hadoop stack.
    Any Exception is therefore accepted (AnalysisException is listed only for
    readability), as long as its message looks like a "missing" error.
    """
    with pytest.raises((AnalysisException, Exception)) as caught:
        read_from_s3_spark(
            s3_spark=mock_s3_spark,
            bucket="bucket-that-likely-does-not-exist-1234",
            prefix="some-prefix/",
            start_date=datetime(2021, 1, 1),
            end_date=datetime(2021, 1, 1)
        )

    # Adjust these markers to the messages your environment produces.
    message = str(caught.value).lower()
    missing_markers = ("does not exist", "statuscode=404", "not found")
    assert any(marker in message for marker in missing_markers)

@pytest.mark.integration
def test_read_from_s3_spark_nonexistent_prefix(mock_s3_spark):
    """
    Negative test: reading a made-up prefix inside an existing bucket must
    raise an error whose message looks like "path not found".

    Some Spark/Hadoop versions return an empty DataFrame instead of raising.
    In such environments this test has to be adapted.
    """
    bucket = os.environ.get("TEST_S3_BUCKET", "my-real-bucket")
    bogus_prefix = "totally-made-up/prefix/"

    with pytest.raises((AnalysisException, Exception)) as caught:
        read_from_s3_spark(
            s3_spark=mock_s3_spark,
            bucket=bucket,
            prefix=bogus_prefix,
            start_date=datetime(2021, 1, 1),
            end_date=datetime(2021, 1, 1),
        )

    message = str(caught.value).lower()
    missing_markers = ("path does not exist", "not found", "statuscode=404")
    assert any(marker in message for marker in missing_markers)


Explanation and Key Points
Unit Tests (Mocked)

These tests (test_read_from_s3_spark_no_select_no_filters_no_parse, test_read_from_s3_spark_with_select, etc.) mock out the calls to S3 and only validate the internal logic of read_from_s3_spark.
They prevent real AWS calls, which keeps tests fast and consistent.
Integration Tests (Real S3)

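Marked with @pytest.mark.integration. Run only these tests with pytest -m integration, or skip them with pytest -m "not integration"; note that a plain pytest run (no -m option) runs them too. Register the marker (for example, markers = integration: tests that need real S3, in pytest.ini) to avoid PytestUnknownMarkWarning.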
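They use the mock_s3_spark fixture. As written, that fixture just returns the plain local SparkSession. For real S3 access, change it to build a session with init_s3_spark, or set the fs.s3a.* options in the spark_session fixture.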
You set environment variables like TEST_S3_BUCKET and TEST_S3_PREFIX to control which real bucket/prefix you want to test.
The test reads real data from S3. If the data is missing, it might raise an exception or produce an empty DataFrame, depending on your environment and Spark version.
Edge Cases / Negative Tests

test_read_from_s3_spark_nonexistent_bucket() tries to read from a clearly bogus bucket, expecting an exception.
test_read_from_s3_spark_nonexistent_prefix() tries a valid bucket with a non-existent folder/prefix, often resulting in an error (e.g., AnalysisException) or an empty DataFrame. This can vary by Spark version; adjust the test to your environment.
Running Tests

Install dependencies: pytest, pytest-mock, pyspark, etc.
Set environment variables for AWS credentials (e.g., AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) and the S3 endpoint if you’re not using AWS.
Run tests:
```bash
pytest test_s3_read.py
```
or
```bash
pytest -m integration  # Only integration tests
pytest -m "not integration"  # Only unit tests
```
This layout gives you both confidence that your code behaves as expected (unit tests) and verifies that the pipeline works with an actual S3 environment (integration tests). Adjust bucket names, prefixes, and environment variable usage to fit your setup.

# Test db

Below is an example of how you might organize unit tests and integration tests for db.py using pytest. This example assumes:

You have the DB class in db.py (as you showed) that can connect to either Elasticsearch (es) or S3 (s3) or read from a file.
You have a working read_from_s3_spark function in s3.py.
You want to test both how DB interacts with read_from_s3_spark (mocked unit tests) and run an integration test that actually hits a real S3 environment.
Feel free to mix these tests into the same file (e.g., test_db.py) or break them into multiple files. The example below uses a single file for clarity.

Example: test_db.py

In [None]:
import os
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime

# Import your DB class and the read_from_s3_spark function
from db import DB
from s3 import read_from_s3_spark

##########################
#      Pytest Fixtures
##########################

@pytest.fixture
def s3_config():
    """
    Sample DB config whose ``loc`` selects the S3 backend.

    Endpoint, bucket, prefix and filters are placeholders. Tests that need
    real values overwrite them on the returned (per-test) dict.
    """
    s3_section = {
        "endpoint": "fake-endpoint",
        "bucket": "fake-bucket",
        "prefix": "some-prefix/",
        "filters": {"key": "value"},
    }
    return {"loc": "s3", "s3": s3_section}

@pytest.fixture
def es_config():
    """
    Sample DB config whose ``loc`` selects the Elasticsearch backend.
    """
    es_section = {
        "endpoint": "fake-es-host",
        "index": "my-index",
        "filters": {"es_key": "es_val"},
    }
    return {"loc": "es", "es": es_section}

@pytest.fixture
def file_config():
    """
    Sample DB config whose ``loc`` selects reading from a local file.
    """
    file_section = {"path": "/tmp/fake_file.txt"}
    return {"loc": "file", "file": file_section}

@pytest.fixture
def infra_config():
    """
    Sample infrastructure map consulted by DB for environment details.

    Each ``prod_*`` entry lists the hosts treated as production for that
    service. The S3 and ES hosts match the placeholder endpoints of the
    sample configs above. Trim or extend as needed.
    """
    spark_master = "spark://fake-spark-cluster:7077"
    return {
        "prod_es": ["fake-es-host"],
        "prod_s3": ["fake-endpoint"],
        "prod_spark": [spark_master],
    }

##########################
#        Unit Tests
##########################

def test_db_extract_config_s3(s3_config, infra_config):
    """
    With ``loc == "s3"``, DB must expose type ``"s3"`` and the config's
    ``s3`` section.
    """
    instance = DB(config=s3_config, infra_config=infra_config)
    assert instance.db_type == "s3"
    expected_section = s3_config["s3"]
    assert instance.db_config == expected_section

def test_db_extract_config_es(es_config, infra_config):
    """
    With ``loc == "es"``, DB must expose type ``"es"`` and the config's
    ``es`` section.
    """
    instance = DB(config=es_config, infra_config=infra_config)
    assert instance.db_type == "es"
    expected_section = es_config["es"]
    assert instance.db_config == expected_section

def test_db_extract_config_file(file_config, infra_config):
    """
    With ``loc == "file"``, DB must expose type ``"file"`` and the config's
    ``file`` section.
    """
    instance = DB(config=file_config, infra_config=infra_config)
    assert instance.db_type == "file"
    expected_section = file_config["file"]
    assert instance.db_config == expected_section

def test_db_connect_s3_valid(s3_config, infra_config, mocker):
    """
    DB.connect() on a complete S3 config should build its Spark session via
    init_s3_spark and return it with no error. init_s3_spark is patched, so
    no real Spark starts.
    """
    fake_init = mocker.patch("db.init_s3_spark", return_value="mocked_spark_session")

    handle = DB(config=s3_config, infra_config=infra_config)
    session, error = handle.connect()

    expected_positional = (
        "fake-endpoint",                    # config["s3"]["endpoint"]
        os.getenv("S3_ACCESS_KEY"),         # credentials are read from the environment
        os.getenv("S3_SECRET_KEY"),
        mocker.ANY,                         # spark_home: deliberately not pinned
        "spark://fake-spark-cluster:7077",  # expected to come from infra_config["prod_spark"]
        os.getenv("CA_CERTS"),
    )
    fake_init.assert_called_once_with(*expected_positional)
    assert session == "mocked_spark_session"
    assert error is None

def test_db_connect_s3_invalid_config(s3_config, infra_config, mocker):
    """
    Test DB.connect() for an S3 config that is missing its bucket.

    The expected result is (None, "INVALID_S3_CONFIG"). init_s3_spark is
    patched so that a regression in the config validation makes this test
    fail on the assertion below, rather than starting a real Spark session
    inside a unit test.
    """
    spark_init = mocker.patch("db.init_s3_spark", return_value="should_not_be_used")

    # Break the config by removing 'bucket'. The fixture is function-scoped,
    # so this mutation does not leak into other tests.
    del s3_config["s3"]["bucket"]

    db_obj = DB(config=s3_config, infra_config=infra_config)
    s3_session, err = db_obj.connect()

    assert s3_session is None
    assert err == "INVALID_S3_CONFIG"
    # An invalid config must be rejected before any session is created.
    spark_init.assert_not_called()

def test_db_connect_es_valid(es_config, infra_config, mocker):
    """
    DB.connect() for Elasticsearch, with init_es patched so that no real
    cluster is contacted.
    """
    es_init = mocker.patch("db.init_es", return_value=("mocked_es_session", None))

    handle = DB(config=es_config, infra_config=infra_config)
    client, error = handle.connect()

    assert client == "mocked_es_session"
    assert error is None
    # connect() is expected to reach init_es exactly once.
    es_init.assert_called_once()

def test_db_read_data_s3_unit(s3_config, infra_config, mocker):
    """
    DB.read_data() for S3 should forward these values to read_from_s3_spark
    (which is mocked here):

    * the stored session;
    * the configured bucket, prefix and filters;
    * the caller's dates and parse_column.
    """
    handle = DB(config=s3_config, infra_config=infra_config)
    handle.s3 = "fake_spark_session"  # stands in for an established connection

    reader = mocker.patch("db.read_from_s3_spark", return_value="some_dataframe")

    window = (datetime(2021, 1, 1), datetime(2021, 1, 2))
    frame = handle.read_data(*window, parse_column="some_col")

    # No field list is passed, so select is expected to default to ["*"].
    reader.assert_called_once_with(
        "fake_spark_session",
        "fake-bucket",
        "some-prefix/",
        *window,
        select=["*"],
        filters={"key": "value"},
        parse_column="some_col",
    )
    assert frame == "some_dataframe"

def test_db_disconnect_s3(s3_config, infra_config, mocker):
    """
    DB.disconnect() on an S3-backed DB must call stop() on the stored Spark
    session exactly once.
    """
    handle = DB(config=s3_config, infra_config=infra_config)
    handle.s3 = MagicMock()
    handle.db_type = "s3"  # set explicitly so the test does not depend on config parsing

    handle.disconnect()

    handle.s3.stop.assert_called_once()

##########################
#  Integration Test (S3)
##########################

@pytest.mark.integration
def test_db_read_data_s3_integration(s3_config, infra_config):
    """
    Integration test: connect to a real S3 + Spark and read actual data.

    Requires valid credentials and a reachable endpoint. Run with:
      pytest -m integration

    The config is overridden from the environment:
     - s3_config['s3']['endpoint'] <- REAL_S3_ENDPOINT
     - s3_config['s3']['bucket']   <- REAL_S3_BUCKET
     - s3_config['s3']['prefix']   <- REAL_S3_PREFIX
     - s3_config['s3']['filters']  is cleared (the fixture's placeholder
       filter is only meant for unit tests)
     - infra_config['prod_s3'] is set to that endpoint
    infra_config['prod_spark'] must point at a real Spark master, and the
    S3 credentials must be set in the environment.

    The Spark session is always released through disconnect(), even if the
    read or an assertion fails.
    """
    s3_section = s3_config["s3"]
    s3_section["endpoint"] = os.environ.get("REAL_S3_ENDPOINT", "your-s3-endpoint.amazonaws.com")
    s3_section["bucket"] = os.environ.get("REAL_S3_BUCKET", "your-bucket")
    s3_section["prefix"] = os.environ.get("REAL_S3_PREFIX", "some/path/")
    # The fixture's {"key": "value"} filter would be applied to real data,
    # probably on a column that does not exist. Read unfiltered instead.
    # NOTE(review): assumes an empty dict means "no filters" for DB/read_from_s3_spark.
    s3_section["filters"] = {}

    # Mark the endpoint as "prod" so the logic picks up environment-based credentials.
    infra_config["prod_s3"] = [s3_section["endpoint"]]

    db_obj = DB(config=s3_config, infra_config=infra_config)

    s3_session, err = db_obj.connect()
    assert s3_session is not None, f"Expected a valid S3 Spark session, got err={err}"

    try:
        df = db_obj.read_data(
            start_date=datetime(2023, 1, 1),
            end_date=datetime(2023, 1, 1),
            fields=["*"],
            parse_column=None,
        )
        # Depending on your data, tighten this (e.g. df.count() > 0).
        assert df is not None, "Expected a Spark DataFrame from S3."
    finally:
        # Release the Spark session even when the read or assertion fails.
        db_obj.disconnect()


Explanation
Pytest Fixtures

s3_config, es_config, file_config: Provide test configs for different DB types.
infra_config: Mock infrastructure info that the DB class uses to determine if an endpoint is “prod” or not.
Unit Tests

test_db_extract_config_*: Verifies _extract_config() sets db_type and db_config properly based on config.
test_db_connect_s3_valid: Mocks init_s3_spark so we don’t spin up a real Spark session. Checks that init_s3_spark is called with the expected arguments and that DB.connect() returns (s3_session, None).
test_db_connect_s3_invalid_config: If endpoint or bucket is missing, code returns (None, "INVALID_S3_CONFIG").
test_db_connect_es_valid: Mocks init_es (for Elasticsearch) so it doesn’t connect to a real ES cluster.
test_db_read_data_s3_unit: Mocks read_from_s3_spark to ensure the DB.read_data() calls it with correct arguments.
test_db_disconnect_s3: Ensures db_obj.s3.stop() is called when disconnect() is called on an S3-based DB.
Integration Test

test_db_read_data_s3_integration:
Marked with @pytest.mark.integration so you can selectively run it with pytest -m integration.
Adjust the config to real S3 credentials and endpoints at runtime.
Calls connect() on DB, ensures no error, and tries read_data().
Optionally checks row count, schema, or whether the call completed without error.
Finally calls disconnect().
Running the Tests

Make sure you have pytest and pytest-mock installed (pytest-mock provides the mocker fixture these tests rely on):
```bash
pip install pytest pytest-mock
```
Place this file (test_db.py) in your test directory or next to db.py.
Run the tests:
All tests:
```bash
pytest test_db.py
```
Just integration:
```bash
pytest -m integration
```
All except integration:
```bash
pytest -m "not integration"
```
Environment Variables

For real integration, you’ll need AWS credentials for S3 in the environment (e.g., AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) or whichever environment variables your code references (S3_ACCESS_KEY, S3_SECRET_KEY).
You can override them with your own environment variables if needed.
With this structure, you’ll have comprehensive coverage:

Unit tests verify your DB class logic and how it calls read_from_s3_spark.
Integration tests confirm it actually works with live S3.
Feel free to refine or expand these tests based on your specific environment, naming conventions, or additional requirements.

# TEST metrics


Below is a Pytest-based unit test module for the MetricsSpark class (imported from src.analysis.metrics_spark) that follows the same parametrized style and structure as the code snippet you provided. It creates synthetic data, runs the vdr, tdr, and tfpr methods, and compares the results against manually computed expectations. You can place it in a file such as test_metrics_spark.py, anywhere that src.analysis.metrics_spark is importable.

In [None]:
import pytest
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql import Row
from src.analysis.metrics_spark import MetricsSpark

#
# --------------- Pytest Fixture for SparkSession ---------------
#

@pytest.fixture(scope="module")
def spark_session():
    """
    Provide a local SparkSession shared by all tests in this module.

    The fixture is module-scoped. The session is built (or reused via
    getOrCreate) once per test module and stopped after the module's last
    test has run.
    """
    spark = SparkSession.builder.appName("Test").master("local[*]").getOrCreate()
    yield spark
    # NOTE(review): getOrCreate returns any already-active session, so this
    # also stops a session created elsewhere in the same process.
    spark.stop()

#
# --------------- Helpers for Generating & Computing ---------------
#

def _gen_random_values(val_range=10, n=1000, offset=0, seed=42, **kwargs):
    """
    Generate reproducible synthetic scores, binary labels and amounts.

    Args:
        val_range: scores are drawn uniformly between ``-val_range`` and
            ``val_range`` before ``offset`` is added.
            NOTE(review): a negative value makes low > high, which NumPy
            documents as undefined behaviour for ``uniform``.
        n: number of samples.
        offset: constant added to every score.
        seed: seed for the private random generator. The default 42
            reproduces the values this helper has always produced.
        **kwargs: ignored. This lets callers pass a whole case-argument dict
            (e.g. one containing ``threshold``) straight through.

    Returns:
        dict of NumPy arrays:
        ``"values"`` (scores), ``"labels"`` (0/1) and ``"amt_values"``
        (uniform in [1, 100)).
    """
    # A private RandomState yields the same stream that np.random.seed(seed)
    # plus module-level draws would, without reseeding NumPy's global
    # generator as a side effect on every other test in the process.
    rng = np.random.RandomState(seed)
    values = rng.uniform(-val_range, val_range, n) + offset
    labels = rng.choice([0, 1], size=n)
    amt_values = rng.uniform(1, 100, n)  # amounts between 1 and 100
    return {"values": values, "labels": labels, "amt_values": amt_values}

def _cal_vdr(df, field, threshold, amt_field):
    """
    Reference "Value Detection Rate" computed row by row on the driver:

        VDR = (amount of label-1 rows whose score is > threshold)
              / (amount of all label-1 rows)

    The label column is ``field + "_label"`` and the score comparison is
    strict. Returns None when the label-1 amounts sum to zero, including
    when there are no label-1 rows.
    """
    label_col = field + "_label"
    captured = 0
    fraud_total = 0
    for rec in df.collect():
        amt, lbl, scr = rec[amt_field], rec[label_col], rec[field]
        if lbl != 1:
            continue  # only ground-truth fraud contributes
        fraud_total += amt
        if scr > threshold:
            captured += amt
    if fraud_total == 0:
        return None
    return captured / fraud_total

def _cal_tdr(df, field, threshold):
    """
    Reference "True Detection Rate" computed on the driver:

        TDR = (# label-1 rows with score > threshold) / (# label-1 rows)

    The label column is ``field + "_label"``. Returns None when there are
    no label-1 rows.
    """
    label_col = field + "_label"
    pairs = [(rec[label_col], rec[field]) for rec in df.collect()]
    fraud_scores = [scr for lbl, scr in pairs if lbl == 1]
    if len(fraud_scores) == 0:
        return None
    hits = sum(1 for scr in fraud_scores if scr > threshold)
    return hits / len(fraud_scores)

def _cal_tfpr(df, field, threshold):
    """
    Manually compute ratio of false positives to true positives:
      fp / tp
    """
    tp_count = 0
    fp_count = 0

    for row in df.collect():
        label = row[field + "_label"]
        score = row[field]

        if score > threshold:
            if label == 1:
                tp_count += 1
            else:
                fp_count += 1

    return fp_count / tp_count if tp_count != 0 else None

def _gen_case(case_name, metric, func, args, spark_session):
    """
    Build one metric test case from synthetic data.

    ``func(**args)`` must return a dict with equal-length sequences under
    ``"values"``, ``"labels"`` and ``"amt_values"``. These become a Spark
    DataFrame with columns ``score`` (double), ``score_label`` (bigint) and
    ``amount`` (double). The expected metric is then computed with the
    pure-Python ``_cal_*`` reference calculators.

    Args:
        case_name: human-readable label for the case.
        metric: ``"vdr"``, ``"tdr"`` or ``"tfpr"``. Any other value yields an
            expected result of None.
        func: generator of the raw inputs (e.g. ``_gen_random_values``).
        args: keyword arguments for ``func``. An optional ``"threshold"``
            key (default 0.5) is also used as the scoring threshold.
        spark_session: active SparkSession used to build the DataFrame.

    Returns:
        dict with ``"case"``, ``"inputs"`` (df, field, threshold, amt_field)
        and ``"expected"`` (``{"res": ..., "tag": None}``).
    """
    # 1) Generate or retrieve synthetic inputs.
    inputs = func(**args)

    # 2) Build the Spark DataFrame with an explicit schema. Spark cannot infer
    #    a schema from an empty dataset (the "empty data" case), and the
    #    explicit schema also fixes column order and types whether the
    #    generator produced ints or floats. Values are converted to plain
    #    Python scalars because schema verification (and older PySpark
    #    inference) rejects NumPy integer scalars and ints in a double column.
    rows = [
        (float(val), int(label), float(amt))
        for val, label, amt in zip(inputs["values"], inputs["labels"], inputs["amt_values"])
    ]
    df = spark_session.createDataFrame(
        rows, schema="score: double, score_label: bigint, amount: double"
    )

    field = "score"
    threshold = args.get("threshold", 0.5)
    amt_field = "amount"

    # 3) Compute the 'expected' result manually.
    if metric == "vdr":
        exp_res = _cal_vdr(df, field, threshold, amt_field)
    elif metric == "tdr":
        exp_res = _cal_tdr(df, field, threshold)
    elif metric == "tfpr":
        exp_res = _cal_tfpr(df, field, threshold)
    else:
        exp_res = None

    return {
        "case": case_name,
        "inputs": {
            "df": df,
            "field": field,
            "threshold": threshold,
            "amt_field": amt_field,
        },
        "expected": {"res": exp_res, "tag": None},
    }

def _gen_data(metric, spark_session):
    """
    Build the full list of parametrized cases for *metric*.

    Each entry is the dict returned by ``_gen_case``: the DataFrame inputs
    plus the reference result. Cases are produced in the order listed
    below, so case order (and therefore test-id order) is fixed.
    """
    def fixed(values, labels, amt_values):
        # Adapt hard-coded columns to the ``func(**args)`` generator shape
        # that _gen_case expects; any keyword arguments are ignored.
        return lambda **kw: {
            "values": values,
            "labels": labels,
            "amt_values": amt_values,
        }

    # (case name, input generator, generator kwargs), in output order.
    specs = [
        ("all positive values", _gen_random_values, {}),
        ("no fraud amount", _gen_random_values, {"val_range": -10}),
        ("mixed values", _gen_random_values, {"offset": 0.5}),
        ("large values", _gen_random_values, {"val_range": 1e7, "offset": 0.5}),
        ("small values", _gen_random_values, {"val_range": 1e-9, "offset": 0.5}),
        ("empty data", fixed([], [], []), {}),
        ("single class", fixed([1, 1, 1, 1], [1, 1, 1, 1], [10, 10, 10, 10]), {}),
        ("all zero values", fixed([0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]), {}),
        ("all one values", fixed([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]), {}),
        ("high threshold", _gen_random_values, {"threshold": 100}),
        ("low threshold", _gen_random_values, {"threshold": -100}),
        (
            "nan values",
            # Kept as a lazy lambda so ``np.nan`` is only looked up when
            # this particular case is generated.
            lambda **kw: {
                "values": [1, 2, np.nan],
                "labels": [1, 0, 1],
                "amt_values": [10, 20, 30],
            },
            {},
        ),
    ]

    return [
        _gen_case(name, metric, generator, gen_kwargs, spark_session)
        for name, generator, gen_kwargs in specs
    ]

#
# --------------- Parametrized Tests for Each Metric ---------------
#

@pytest.fixture()
def gen_data(request):
    """Indirect-parametrization fixture: hand the case dict to the test unchanged."""
    return request.param


def _collection_time_spark():
    """
    Return a SparkSession that can be used while pytest collects tests.

    ``parametrize`` arguments are evaluated when this module is imported,
    before any fixture has run. Pytest refuses direct calls to a fixture
    function such as ``spark_session()``. This builder uses the same
    master and app name as the ``spark_session`` fixture. ``getOrCreate()``
    returns an already-active session, so the fixture later reuses this
    session instead of starting a second one.
    """
    return (
        SparkSession.builder
        .master("local[1]")
        .appName("PytestSpark")
        .getOrCreate()
    )


def _metric_params(metric):
    """
    Generate the cases for *metric* once and wrap each as a ``pytest.param``.

    Each param's id is the human-readable case name. Generating only once
    avoids building every DataFrame (and drawing fresh random inputs) a
    second time just to derive the ids.
    """
    return [
        pytest.param(case, id=case["case"])
        for case in _gen_data(metric, _collection_time_spark())
    ]


def _assert_expected(res, tag, case):
    """
    Check a metric's ``(res, tag)`` against the case's expectation.

    ``res`` is compared with ``pytest.approx`` because Spark may sum
    floating-point values in a different order than the pure-Python
    reference, which can produce last-bit differences. ``nan_ok=True``
    lets a NaN expectation match a NaN result. A ``None`` expectation
    still requires ``res`` to be exactly ``None``, because approx falls
    back to strict equality for non-numeric values.
    """
    expected = case["expected"]
    assert res == pytest.approx(expected["res"], nan_ok=True), f"Case: {case['case']}"
    if "tag" in expected:
        assert tag == expected["tag"]


@pytest.mark.parametrize("gen_data", _metric_params("vdr"), indirect=True)
def test_vdr(gen_data, spark_session):
    """
    Test the vdr() method of MetricsSpark with various synthetic data sets.

    ``spark_session`` is requested so the fixture manages the session's
    lifecycle for the test run.
    """
    inputs = gen_data["inputs"]
    res, tag = MetricsSpark.vdr(
        inputs["df"], inputs["field"], inputs["threshold"], inputs["amt_field"]
    )
    _assert_expected(res, tag, gen_data)


@pytest.mark.parametrize("gen_data", _metric_params("tdr"), indirect=True)
def test_tdr(gen_data, spark_session):
    """
    Test the tdr() method of MetricsSpark.
    """
    inputs = gen_data["inputs"]
    res, tag = MetricsSpark.tdr(inputs["df"], inputs["field"], inputs["threshold"])
    _assert_expected(res, tag, gen_data)


@pytest.mark.parametrize("gen_data", _metric_params("tfpr"), indirect=True)
def test_tfpr(gen_data, spark_session):
    """
    Test the tfpr() method of MetricsSpark.
    """
    inputs = gen_data["inputs"]
    res, tag = MetricsSpark.tfpr(inputs["df"], inputs["field"], inputs["threshold"])
    _assert_expected(res, tag, gen_data)


How This Works
spark_session Fixture

Creates a local Spark session once per test session (the fixture uses scope="session").
Synthetic Data Functions

_gen_random_values(): produces random scores, binary labels, and amounts.
_cal_*() functions: manually compute the expected metric (vdr, tdr, tfpr) by iterating over the DataFrame.
_gen_case()

Uses the above functions to build a small DataFrame of (score, score_label, amount).
Determines what the correct "expected" result should be by calling the manual _cal_*() counterpart.
_gen_data()

Constructs a list of distinct test cases (e.g., “large values,” “no fraud amount,” “nan values,” etc.).
Each entry is a dictionary containing the test DataFrame, threshold, etc., and the expected metric result.
Parametrized Tests

For each metric (vdr, tdr, tfpr), we parametrize over the entire list of test cases returned by _gen_data().
test_vdr, test_tdr, test_tfpr each call the corresponding MetricsSpark method, then assert that the returned (res, tag) matches the pre-computed expected["res"] (and expected["tag"] if present).
Running the Tests

Make sure pytest is installed:
```bash
pip install pytest
```
From your project directory (where test_metrics_spark.py is located), run:
```bash
pytest test_metrics_spark.py
```
or simply
```bash
pytest
```
if the file follows standard test discovery rules.
This setup thoroughly checks each metric function (vdr, tdr, tfpr) against a variety of data scenarios. It uses a consistent approach to fixture-based testing, as shown in your example.

In [None]:
# TEST Ground Truth

Below is an example pytest test suite for the new SPM flow that covers both non–ground-truth metrics and ground-truth metrics. The tests demonstrate how you might:

Mock out calls to DB, Analysis, etc.
Provide dummy data for both ground-truth and non–ground-truth use cases.
Follow a similar pattern to the metrics_spark.py tests (where we generate or mock data) and the db.py/s3_db tests (where we often mock out the underlying database calls).
Note: This is a sample test showing the structure and approach. You will likely need to adapt it to your actual codebase – for instance, how your SPM and Analysis classes are imported, or where your test configs live. Also, since SPM depends on reading config files, some portions are mocked in the examples.

In [None]:
import os
import pytest
import time
from unittest.mock import MagicMock, patch
from datetime import datetime

# Example imports from your codebase:
# from my_project.spm.spm_flow import SPM  # Adjust to your actual import path
# from my_project.database.db import DB
# from my_project.analysis.analysis_new import Analysis

#
# --------------- Fixtures & Helpers ---------------
#

@pytest.fixture
def spm_test_dir(tmp_path):
    """
    Provide a scratch directory for SPM task files.

    This is pytest's per-test ``tmp_path``. Other fixtures write the task
    lists and ``spm_configs/*.yml`` files into it.
    """
    scratch_dir = tmp_path
    return scratch_dir

@pytest.fixture
def mock_s3_db(mocker):
    """
    Patch the DB class so tests never touch a real database.

    Patches ``DB.connect`` (returns ``(None, None)``), ``DB.read_data``
    (returns ``[]``) and ``DB.disconnect`` on
    ``my_project.database.db.DB``.

    Returns:
        dict mapping ``connect_patch`` / ``read_data_patch`` /
        ``disconnect_patch`` to the corresponding mocks, so tests can
        inspect call counts.
    """
    return {
        "connect_patch": mocker.patch(
            "my_project.database.db.DB.connect", return_value=(None, None)
        ),
        "read_data_patch": mocker.patch(
            "my_project.database.db.DB.read_data", return_value=[]
        ),
        "disconnect_patch": mocker.patch("my_project.database.db.DB.disconnect"),
    }

@pytest.fixture
def spm_config_ground_truth(spm_test_dir):
    """
    Write the YAML config for the ground-truth task ``task_gt``.

    The file is written to ``<spm_test_dir>/spm_configs/task_gt.yml``
    (``spm_configs`` is created if missing). It declares a ``gt`` field
    and a ``gt_fields.ground_truth`` section with backfill settings.

    Returns:
        Path of the written config file.
    """
    configs_dir = spm_test_dir / "spm_configs"
    target = configs_dir / "task_gt.yml"
    yaml_text = """
model_name: "test_model"
model_version: "v1"
model_type: "binary"
fields:
  ground_truth_field: 
    type: "gt"
    # ...
  score_field:
    type: "non_gt"
data:
  loc: "s3"
  endpoint: "fake-endpoint"
  bucket: "fake-bucket"
gt_fields:
  ground_truth:
    backfill:
      interval: "1d"
      n: 2
    id: ["unique_id"]
"""
    configs_dir.mkdir(exist_ok=True)
    target.write_text(yaml_text)
    return target

@pytest.fixture
def spm_config_non_ground_truth(spm_test_dir):
    """
    Write the YAML config for the non-ground-truth task ``task_nogt``.

    The file is written to ``<spm_test_dir>/spm_configs/task_nogt.yml``
    (``spm_configs`` is created if missing) and declares only a
    ``non_gt`` score field.

    Returns:
        Path of the written config file.
    """
    configs_dir = spm_test_dir / "spm_configs"
    target = configs_dir / "task_nogt.yml"
    yaml_text = """
model_name: "test_model_nogt"
model_version: "v2"
model_type: "regression"
fields:
  score_field:
    type: "non_gt"
data:
  loc: "s3"
  endpoint: "fake-endpoint"
  bucket: "fake-bucket"
"""
    configs_dir.mkdir(exist_ok=True)
    target.write_text(yaml_text)
    return target

@pytest.fixture
def spm_tasks_file(spm_test_dir):
    """
    Write ``tasks.json`` listing the task names SPM should process.

    The file contains the JSON array ``["task_gt", "task_nogt"]``.

    Returns:
        Path of the written tasks file.
    """
    import json

    target = spm_test_dir / "tasks.json"
    task_names = ["task_gt", "task_nogt"]
    with target.open("w") as handle:
        json.dump(task_names, handle)
    return target

@pytest.fixture
def spm_hist_tasks_file(spm_test_dir):
    """
    Write ``hist_tasks.json`` containing an empty JSON object.

    This represents "no historical tasks".

    Returns:
        Path of the written file.
    """
    import json

    target = spm_test_dir / "hist_tasks.json"
    no_history = {}
    with target.open("w") as handle:
        json.dump(no_history, handle)
    return target

#
# --------------- Example Tests ---------------
#

def test_spm_ground_truth_metrics(
    mocker,
    spm_test_dir,
    spm_tasks_file,
    spm_hist_tasks_file,
    spm_config_ground_truth,
    mock_s3_db
):
    """
    Run SPM over the task list with the ground-truth config in place.

    Steps:
    - DB access is stubbed via ``mock_s3_db``.
    - ES indexing and the ``Analysis`` class are patched.
    - ``SPM.process_tasks()`` is called.
    - The test asserts that data was read, ``Analysis`` was instantiated
      exactly once, and results were indexed.

    NOTE(review): tasks.json lists both "task_gt" and "task_nogt", but this
    test only writes task_gt.yml. Confirm SPM tolerates the missing
    task_nogt config (or also request ``spm_config_non_ground_truth``).
    NOTE(review): ``mocker.patch`` replaces the name in the module given.
    If spm_flow binds ``Analysis`` / ``index_document`` via ``from ... import``,
    they may need patching at ``my_project.spm.spm_flow`` instead. Confirm
    against the real import style.
    """
    # Fake credentials and a stubbed Elasticsearch indexer.
    mocker.patch.dict(os.environ, {"USERNAME": "user", "PASSWORD": "pass"})
    es_index = mocker.patch(
        "my_project.database.es.index_document", return_value=("mock_es_res", None)
    )

    # Stub Analysis so the metric calculation path can be observed.
    analysis_cls = mocker.patch("my_project.analysis.analysis_new.Analysis", autospec=True)
    analysis = analysis_cls.return_value
    analysis.cal_metrics.return_value = None
    analysis.get_result.return_value = {"some_metric": 123, "some_other_metric": 0.99}

    from my_project.spm.spm_flow import SPM  # or whichever path your SPM code is in

    runner = SPM(
        task_filename=spm_tasks_file.name,
        task_dir=str(spm_test_dir),
        curr_time=datetime(2025, 1, 1),
        window_size=7,
    )
    runner.process_tasks()

    # Data fetched, Analysis constructed once, and results indexed.
    assert mock_s3_db["read_data_patch"].call_count > 0, "Expected to load data from DB"
    analysis_cls.assert_called_once()
    es_index.assert_called()

def test_spm_non_ground_truth_metrics(
    mocker,
    spm_test_dir,
    spm_tasks_file,
    spm_hist_tasks_file,
    spm_config_non_ground_truth,
    mock_s3_db
):
    """
    Run SPM over the task list with only the non-ground-truth config in place.

    The setup mirrors the ground-truth test: DB, ES indexing and
    ``Analysis`` are stubbed, then ``SPM.process_tasks()`` is called.
    The test asserts that data was read, ``Analysis`` was instantiated
    exactly once, and results were indexed.

    NOTE(review): tasks.json also lists "task_gt", whose config is not
    written for this test. Confirm SPM skips or handles that task as
    intended.
    """
    # Fake credentials and a stubbed Elasticsearch indexer.
    mocker.patch.dict(os.environ, {"USERNAME": "user", "PASSWORD": "pass"})
    es_index = mocker.patch(
        "my_project.database.es.index_document", return_value=("mock_es_res", None)
    )

    # Stub Analysis with a single non-ground-truth result.
    analysis_cls = mocker.patch("my_project.analysis.analysis_new.Analysis", autospec=True)
    analysis = analysis_cls.return_value
    analysis.cal_metrics.return_value = None
    analysis.get_result.return_value = {"non_gt_metric": 42}

    from my_project.spm.spm_flow import SPM

    runner = SPM(
        task_filename=spm_tasks_file.name,
        task_dir=str(spm_test_dir),
        curr_time=datetime(2025, 1, 10),
        window_size=14,
    )
    runner.process_tasks()

    assert mock_s3_db["read_data_patch"].call_count > 0, "Should read data for non-gt calculations"
    analysis_cls.assert_called_once()
    es_index.assert_called()



Explanation
Fixtures

spm_test_dir: Creates a temporary directory for storing config files (task_nogt.yml, task_gt.yml, etc.).
mock_s3_db: Patches DB.connect, DB.read_data, DB.disconnect so calls to your database are mocked.
spm_config_ground_truth and spm_config_non_ground_truth: Writes out minimal .yml config files for each scenario (ground-truth vs. non–ground-truth).
spm_tasks_file: Writes a tasks.json containing tasks referencing those config files.
Ground-Truth vs. Non–Ground-Truth

We have two separate tests:
test_spm_ground_truth_metrics references task_gt, which loads spm_configs/task_gt.yml containing gt_fields. We expect the code to run ground-truth metric logic.
test_spm_non_ground_truth_metrics references task_nogt, which only has non–ground-truth fields.
Mocking

We mock index_document from my_project.database.es to avoid real ES indexing.
We mock Analysis from my_project.analysis.analysis_new to confirm it gets instantiated and check the metrics results flow.
mock_s3_db ensures SPM’s calls to DB.connect(), read_data(), disconnect() do not hit a real database.
Assertions

We verify that read_data() was called (so we know data was actually fetched).
We confirm Analysis is called, meaning SPM attempts to do metric computations.
We confirm the final indexing into ES (index_document) is invoked with some result data.
Test Approach

This approach mimics both the style of metrics_spark.py tests (where we verify metric calculations) and the style of db.py/s3_db tests (mocking out DB interactions).
For a true integration test, you could remove the mocking of DB and index_document so the code actually reads from S3 (or a real data source) and writes to Elasticsearch. That would require real credentials and network connectivity, so typically those tests are labeled as integration or end-to-end tests.
Running the Tests

Make sure pytest is installed (pip install pytest).
Save this file (e.g., test_spm.py) in your tests directory and run:
```bash
pytest test_spm.py
```
If you prefer to separate out integration vs. unit tests, you could mark one with @pytest.mark.integration and run them selectively.
This structure demonstrates how to test:

(a) Ground-truth metrics: ensuring your code runs the path that loads ground-truth data and calls analysis.db_gt_metrics_spark or similar.
(b) Non–ground-truth metrics: ensuring your code runs the path for baseline comparison, etc.
Feel free to add or remove details according to your actual code paths, class names, and how your pipeline is set up.







