In [5]:
import pytest
from pyspark.sql import SparkSession, functions as Func
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType
from pyspark.sql.window import Window


@pytest.fixture(scope="module")
def spark():
    """Provide one local SparkSession shared by every test in this module.

    The session runs on two local worker threads. It is stopped once the
    last test in the module has finished with it.
    """
    session = (
        SparkSession.builder
        .master("local[2]")
        .appName("unitTest")
        .getOrCreate()
    )
    yield session
    # Teardown: runs after the module's tests complete.
    session.stop()

def test_filtering_criteria(spark):
    """Cohort filter keeps only overdose encounters after the cutoff date for
    patients no older than 35 at the time of the encounter.

    Each sample row below violates at most one criterion. Only ``patient_1``
    satisfies all of them.
    """
    # Cohort criteria exercised by the sample rows.
    cutoff = "1999-07-15 00:00:00"
    max_age_at_encounter = 35

    data = [
        # PATIENT, enc_START, enc_STOP, encounter_type, BIRTHDATE
        ("patient_1", "2000-01-01 10:00:00", "2000-01-01 12:00:00", "overdose", "1980-01-01 00:00:00"),      # meets every criterion
        ("patient_2", "1999-07-01 10:00:00", "1999-07-01 12:00:00", "overdose", "1985-01-01 00:00:00"),      # before cutoff
        ("patient_3", "2001-01-01 10:00:00", "2001-01-01 12:00:00", "non-overdose", "1990-01-01 00:00:00"),  # wrong type
        ("patient_4", "2000-05-01 10:00:00", "2000-05-01 12:00:00", "overdose", "1960-01-01 00:00:00"),      # older than 35 at encounter
    ]
    schema = StructType([
        StructField("PATIENT", StringType(), True),
        StructField("enc_START", StringType(), True),
        StructField("enc_STOP", StringType(), True),
        StructField("encounter_type", StringType(), True),
        StructField("BIRTHDATE", StringType(), True)
    ])
    df = spark.createDataFrame(data, schema)

    # Parse the string columns so comparisons are chronological, not lexicographic.
    enc_start = Func.col("enc_START").cast(TimestampType())
    birthdate = Func.col("BIRTHDATE").cast(TimestampType())
    # Whole years elapsed between birth and the start of the encounter.
    age_at_encounter = Func.floor(Func.months_between(enc_start, birthdate) / 12)

    # Simplified stand-in for the notebook's cohort logic.
    # TODO: call the real build_cohort(df) helper here once it is importable.
    filtered = df.filter(
        (Func.col("encounter_type") == "overdose")
        & (enc_start > Func.lit(cutoff).cast(TimestampType()))
        & (age_at_encounter <= max_age_at_encounter)
    )

    # Compare the full set of surviving patients so extra rows cannot slip through.
    patients = sorted(row["PATIENT"] for row in filtered.collect())
    assert patients == ["patient_1"]


In [6]:
def test_readmission_indicators(spark):
    """Readmission indicators reflect the gap to each patient's next encounter.

    ``diff_days`` is the whole-day gap from an encounter to the patient's next
    one. The 90-day and 30-day indicators are 1 when that gap is non-zero and
    within the window. ``FIRST_READMISSION_DATE`` holds the next encounter's
    start, as a string, when the gap is non-zero and <= 90. Otherwise
    (including when there is no next encounter) it holds "N/A".
    """
    data = [
        # PATIENT, enc_START, enc_STOP
        ("patient_1", "2022-01-01 08:00:00", "2022-01-01 10:00:00"),
        ("patient_1", "2022-01-15 09:00:00", "2022-01-15 11:00:00"),  # 14 days after previous row
        ("patient_1", "2022-04-01 09:00:00", "2022-04-01 11:00:00")   # 76 days after previous row (> 30, <= 90)
    ]
    schema = StructType([
        StructField("PATIENT", StringType(), True),
        StructField("enc_START", StringType(), True),
        StructField("enc_STOP", StringType(), True)
    ])
    df = spark.createDataFrame(data, schema)

    # Convert date strings to timestamps; adjust if the notebook's transformation already does this.
    df = df.withColumn("enc_START", Func.col("enc_START").cast(TimestampType())) \
           .withColumn("enc_STOP", Func.col("enc_STOP").cast(TimestampType()))

    # One window per patient, ordered chronologically by encounter start.
    patient_window = Window.partitionBy("PATIENT").orderBy(Func.col("enc_START"))

    # Gap in whole days to the next encounter (null for the patient's last encounter).
    df = df.withColumn("next_enc_START", Func.lead("enc_START").over(patient_window))
    df = df.withColumn("diff_days", Func.datediff(Func.col("next_enc_START"), Func.col("enc_START")))

    # Flag readmissions: 90-day and 30-day indicators (same-day follow-ups are not readmissions).
    df = df.withColumn(
        "READMISSION_90_DAY_IND",
        Func.when((Func.col("diff_days").isNotNull()) & (Func.col("diff_days") != 0) &
                  (Func.col("diff_days") <= 90), Func.lit(1)).otherwise(Func.lit(0))
    )
    df = df.withColumn(
        "READMISSION_30_DAY_IND",
        Func.when((Func.col("diff_days").isNotNull()) & (Func.col("diff_days") != 0) &
                  (Func.col("diff_days") <= 30), Func.lit(1)).otherwise(Func.lit(0))
    )

    # Keep next_enc_START only for a non-zero gap <= 90 days. Otherwise (gap 0, > 90,
    # or no next encounter) use "N/A". The timestamp is cast to string explicitly so
    # both branches share a type instead of relying on implicit coercion, which
    # differs between ANSI and non-ANSI mode.
    df = df.withColumn(
        "next_enc_START",
        Func.when(
            (Func.col("diff_days") <= 90) & (Func.col("diff_days") != 0),
            Func.col("next_enc_START").cast(StringType()),
        ).otherwise(Func.lit("N/A"))
    )
    df = df.withColumnRenamed("next_enc_START", "FIRST_READMISSION_DATE")

    # collect() does not guarantee row order; sort so positional assertions are stable.
    results = df.orderBy("enc_START").collect()
    assert len(results) == 3

    # Row 1: diff_days = 14, within both windows.
    assert results[0]["diff_days"] == 14
    assert results[0]["READMISSION_90_DAY_IND"] == 1
    assert results[0]["READMISSION_30_DAY_IND"] == 1
    assert results[0]["FIRST_READMISSION_DATE"] == "2022-01-15 09:00:00"

    # Row 2: diff_days = 76 (Jan 15 -> Apr 1), > 30 but <= 90.
    assert results[1]["diff_days"] == 76
    assert results[1]["READMISSION_90_DAY_IND"] == 1
    assert results[1]["READMISSION_30_DAY_IND"] == 0

    # Row 3 (last encounter): no next encounter, so diff_days is null and no readmission.
    assert results[2]["diff_days"] is None
    assert results[2]["READMISSION_90_DAY_IND"] == 0
    assert results[2]["READMISSION_30_DAY_IND"] == 0
    assert results[2]["FIRST_READMISSION_DATE"] == "N/A"


In [7]:
def test_opioid_indicator(spark):
    """CURRENT_OPIOID_IND is 1 exactly for rows whose CODE is in the restricted list."""
    # Sample medications DataFrame.
    data = [
        ("patient_1", "2022-01-01 08:00:00", "MED001"),  # restricted
        ("patient_1", "2022-01-01 08:05:00", "MED002"),  # not restricted
        ("patient_2", "2022-01-01 09:00:00", "MED003")   # restricted
    ]
    schema = StructType([
        StructField("PATIENT", StringType(), True),
        StructField("enc_START", StringType(), True),
        StructField("CODE", StringType(), True)
    ])
    med_df = spark.createDataFrame(data, schema)

    # Restricted opioid codes used by the transformation under test.
    restricted_codes = ["MED001", "MED003"]

    med_df = med_df.withColumn(
        "CURRENT_OPIOID_IND",
        Func.when(Func.col("CODE").isin(*restricted_codes), Func.lit(1)).otherwise(Func.lit(0))
    )

    # Pin the expected flags explicitly rather than re-deriving them from
    # restricted_codes. Compare the whole mapping so missing or extra rows
    # (including an empty result) fail.
    flags = {
        (row["PATIENT"], row["CODE"]): row["CURRENT_OPIOID_IND"]
        for row in med_df.collect()
    }
    assert flags == {
        ("patient_1", "MED001"): 1,
        ("patient_1", "MED002"): 0,
        ("patient_2", "MED003"): 1,
    }
