In [None]:
import os
import random
from datetime import datetime, timedelta

from pyspark.sql import SparkSession
from pyspark.sql.functions import *

In [None]:
# MinIO connection settings come from the environment so credentials are not
# committed with the notebook. The fallbacks reproduce the values that were
# previously hard-coded here.
# NOTE(review): remove the fallbacks once the environment variables are set everywhere.
MINIO_ENDPOINT = os.environ.get("MINIO_ENDPOINT", "http://minio:9000")
MINIO_ACCESS_KEY = os.environ.get("MINIO_ACCESS_KEY", "admin*12345")
MINIO_SECRET_KEY = os.environ.get("MINIO_SECRET_KEY", "psswrd*12345")

# A parenthesized chain (instead of backslash continuations) lets optional
# settings be commented in/out in place.
spark = (
    SparkSession.builder
    .appName("MinIO-Test")
    .master("spark://spark-master:7077")
    # S3A -> MinIO: custom endpoint, credentials, and path-style bucket addressing.
    .config("spark.hadoop.fs.s3a.endpoint", MINIO_ENDPOINT)
    .config("spark.hadoop.fs.s3a.access.key", MINIO_ACCESS_KEY)
    .config("spark.hadoop.fs.s3a.secret.key", MINIO_SECRET_KEY)
    .config("spark.hadoop.fs.s3a.path.style.access", "true")
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
    # Resource limits for this application.
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "4g")
    .config("spark.cores.max", "2")
    .config("spark.driver.maxResultSize", "4g")
    # Delta Lake session settings. These previously sat after .getOrCreate(),
    # where uncommenting them would have been a syntax error rather than config.
    # Uncomment here before running the Delta cells (the Delta Lake jars must
    # also be available to the cluster -- confirm for this deployment).
    # .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    # .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)


In [None]:
sc = spark.sparkContext

# Report the Spark version and cluster connection details. A bare `spark.version`
# in the middle of a cell is never displayed (only a cell's last expression is),
# so it is printed explicitly alongside the other details.
print(f"Spark version: {spark.version}")
print(f"Master URL: {sc.master}")
print(f"Application ID: {sc.applicationId}")
print(f"Spark UI: {sc.uiWebUrl}")

In [None]:
def generate_test_data(num_records=100, seed=None):
    """Build a list of synthetic record dicts for the storage round-trip tests.

    Each record has the keys:
      - ``id``: 0 .. num_records - 1
      - ``date``: ``YYYY-MM-DD`` string, cycling through 365 consecutive days
        starting at 2024-01-01
      - ``value``: float between 1 and 1000
      - ``category``: one of ``'A'``, ``'B'``, ``'C'``, ``'D'``
      - ``quantity``: int between 1 and 100 inclusive

    Args:
        num_records: Number of records to generate; 0 yields an empty list.
        seed: Optional seed for a private ``random.Random`` so the output is
            reproducible without touching global ``random`` state. When ``None``
            the module-level ``random`` functions are used, as before.

    Returns:
        A list of ``num_records`` dicts.
    """
    rng = random if seed is None else random.Random(seed)
    start_date = datetime(2024, 1, 1)
    categories = ['A', 'B', 'C', 'D']

    data = []
    for i in range(num_records):
        date = start_date + timedelta(days=i % 365)
        # Draw order (uniform, choice, randint) is fixed so a given seed always
        # produces the same records.
        data.append({
            'id': i,
            'date': date.strftime('%Y-%m-%d'),
            'value': rng.uniform(1, 1000),
            'category': rng.choice(categories),
            'quantity': rng.randint(1, 100),
        })
    return data


# Fixed seed so Restart & Run All reproduces the same test data.
SEED = 42
test_data = generate_test_data(seed=SEED)
# Preview a sample instead of dumping every record.
test_data[:5]

In [None]:
# Turn the generated records into a Spark DataFrame (schema inferred from the
# dicts) and preview the first rows.
test_df = spark.createDataFrame(data=test_data)
test_df.show(n=5)

In [None]:
def test_minio_bucket(bucket_name, df=None):
    """Check that Spark can write to and read back from a bucket over S3A.

    Prints the configured ``fs.s3a.impl`` class, writes at most one row of ``df``
    as Parquet to ``s3a://<bucket_name>/test_file`` (overwriting it), then reads
    the file back and compares row counts.

    Args:
        bucket_name: Bucket to probe.
        df: DataFrame to sample the probe row from; defaults to the notebook's
            global ``test_df``.

    Returns:
        True if the write and read-back succeed with matching row counts,
        otherwise False. Errors are printed rather than raised so the probe
        result is visible in the notebook.
    """
    if df is None:
        df = test_df
    probe_path = f"s3a://{bucket_name}/test_file"
    try:
        # Report which FileSystem implementation is configured for s3a:// URIs
        # (this reads a config value; it does not list the bucket).
        s3a_impl = spark.sparkContext._jsc.hadoopConfiguration().get("fs.s3a.impl")
        print(f"S3A Implementation: {s3a_impl}")

        # Try to write a small test file
        sample = df.limit(1)
        sample.write.mode("overwrite").parquet(probe_path)
        print(f"Successfully wrote test file to {probe_path}")

        # Read it back so the probe also proves read access.
        expected_rows = sample.count()
        read_rows = spark.read.parquet(probe_path).count()
        if read_rows != expected_rows:
            print(f"Read back {read_rows} row(s) from {probe_path}, expected {expected_rows}")
            return False
        print(f"Successfully read back {read_rows} row(s) from {probe_path}")

        return True
    except Exception as e:
        # Broad catch is deliberate: any connectivity/auth/config failure should
        # surface as a False probe result with the message printed.
        print(f"Error testing bucket {bucket_name}: {str(e)}")
        return False

# Test bucket connection
bucket_name = "lake"
test_minio_bucket(bucket_name)

In [None]:
# Test Parquet
# Round-trip test_df through the bucket as Parquet, then preview the reloaded rows.
parquet_path = "s3a://{}/parquet_test".format(bucket_name)
test_df.write.parquet(parquet_path, mode="overwrite")
parquet_df = spark.read.format("parquet").load(parquet_path)
print("Parquet read/write test:")
parquet_df.show(n=5)

In [None]:
# Test Delta Lake
# Same round trip as the Parquet cell, using the Delta format.
# NOTE(review): this needs Delta Lake support in the session (compare the
# commented-out Delta settings in the SparkSession cell) -- confirm it is enabled.
delta_path = "s3a://{}/delta_test".format(bucket_name)
test_df.write.save(delta_path, format="delta", mode="overwrite")
delta_df = spark.read.load(delta_path, format="delta")
print("Delta Lake read/write test:")
delta_df.show(n=5)

In [None]:
# Test Data Consistency
def _format_stat(value):
    """Format an aggregate for the report, tolerating SQL NULL (None).

    Spark returns NULL for mean/stddev over empty input (and stddev can be NULL
    for a single row), which the ':.2f' format spec cannot handle.
    """
    return "n/a" if value is None else f"{value:.2f}"


def verify_data_consistency(original_df, loaded_df, format_name, value_col="value"):
    """Compare a DataFrame with its reloaded copy and print a short report.

    Checks row counts, schema equality, and the mean / standard deviation of a
    numeric column.

    Args:
        original_df: DataFrame that was written.
        loaded_df: DataFrame read back from storage.
        format_name: Label used in the report header (e.g. "Parquet").
        value_col: Column whose mean and stddev are compared; defaults to "value".

    Returns:
        True if both the row counts and the schemas match, else False.
    """
    from pyspark.sql import functions as F

    def profile(df):
        # One aggregation job yields the row count and the column statistics,
        # instead of separate count() and select().collect() jobs.
        return df.agg(
            F.count(F.lit(1)).alias("n"),
            F.mean(value_col).alias("mean"),
            F.stddev(value_col).alias("std"),
        ).first()

    original = profile(original_df)
    loaded = profile(loaded_df)
    counts_match = original["n"] == loaded["n"]
    # StructType equality covers column names, order, types and nullability.
    schema_match = original_df.schema == loaded_df.schema

    print(f"{format_name} Data Consistency Check:")
    print(f"Original count: {original['n']}")
    print(f"Loaded count: {loaded['n']}")
    print(f"Counts match: {counts_match}")
    print(f"Schemas match: {schema_match}")
    # Distribution check: identical data should give identical summary statistics.
    print(f"Original mean: {_format_stat(original['mean'])}, std: {_format_stat(original['std'])}")
    print(f"Loaded mean: {_format_stat(loaded['mean'])}, std: {_format_stat(loaded['std'])}")
    return counts_match and schema_match

# Verify Parquet consistency
parquet_consistent = verify_data_consistency(test_df, parquet_df, "Parquet")
print("\n")
# Verify Delta consistency
delta_consistent = verify_data_consistency(test_df, delta_df, "Delta")

In [None]:
# Register Delta table for SQL queries
# Expose the reloaded Delta data to Spark SQL as the temporary view `test_table`.
delta_df.createOrReplaceTempView("test_table")

# Per-category row count and averages, ordered by category.
category_summary_query = """
    SELECT 
        category,
        COUNT(*) as count,
        AVG(value) as avg_value,
        AVG(quantity) as avg_quantity
    FROM test_table
    GROUP BY category
    ORDER BY category
"""
print("Average value by category:")
spark.sql(category_summary_query).show()

In [None]:
# Function to clean up test data
def cleanup_test_data():
    from py4j.protocol import Py4JJavaError
    
    paths = [parquet_path, delta_path]
    for path in paths:
        try:
            # Delete test data
            spark.sparkContext._jsc.hadoopConfiguration().set("fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
            hadoop_fs = spark._jvm.org.apache.hadoop.fs.FileSystem.get(spark._jsc.hadoopConfiguration())
            hadoop_fs.delete(spark._jvm.org.apache.hadoop.fs.Path(path), True)
            print(f"Successfully deleted {path}")
        except Py4JJavaError as e:
            print(f"Error deleting {path}: {str(e)}")

# Uncomment to clean up test data
# cleanup_test_data()

In [None]:
# Stop the SparkSession (and its SparkContext) to release the resources held on the cluster.
spark.stop()