In [1]:
import os

# Switch to the project folder before starting Spark.
# Use a raw string: in a normal literal "\p" is an invalid escape sequence
# (SyntaxWarning on Python 3.12+, slated to become an error).
# Set PYSPARK_PROJECT_DIR to run the notebook from a different checkout.
PROJECT_DIR = os.environ.get("PYSPARK_PROJECT_DIR", r"H:\pyspark_advanced-coding_interview")
os.chdir(PROJECT_DIR)
print(os.getcwd())


from pyspark.sql import SparkSession

# Create (or reuse, via getOrCreate) a Spark session with explicit resource settings.
# NOTE(review): when running locally the executor lives inside the driver JVM, so
# spark.executor.* and spark.cores.max are presumably ignored. Also,
# spark.driver.memory only applies if no JVM was already started in this kernel.
# Confirm both against the Spark configuration docs.
spark = (
    SparkSession.builder
    .appName("OptimizedLocalSpark")
    .config("spark.driver.memory", "8g")
    .config("spark.executor.memory", "8g")
    .config("spark.executor.cores", "4")
    .config("spark.cores.max", "12")
    # Partitions used for shuffles (joins, aggregations, windows); Spark's default is 200.
    .config("spark.sql.shuffle.partitions", "28")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .getOrCreate()
)
sc = spark.sparkContext

H:\pyspark_advanced-coding_interview


# Reset the Running Total When It Goes Negative (or When the `reset` Flag Is Set)

In [3]:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

# getOrCreate() hands back the session already running in this kernel, if any,
# so this cell can also be executed on its own.
spark = SparkSession.builder.appName("RunningTotalReset").getOrCreate()

# Toy transactions: a row with reset == 1 explicitly restarts the running total.
columns = ["transaction_id", "amount", "reset"]
data = [
    (1, 100, 0),
    (2, 50, 0),
    (3, -30, 0),
    (4, 20, 0),
    (5, 30, 1),    # Reset trigger
    (6, 10, 0),
    (7, -10, 0),
    (8, 40, 0),
]

# Build the DataFrame and expose it to Spark SQL under a fixed view name.
df = spark.createDataFrame(data, schema=columns)
df.createOrReplaceTempView("transaction_table")
df.show()


+--------------+------+-----+
|transaction_id|amount|reset|
+--------------+------+-----+
|             1|   100|    0|
|             2|    50|    0|
|             3|   -30|    0|
|             4|    20|    0|
|             5|    30|    1|
|             6|    10|    0|
|             7|   -10|    0|
|             8|    40|    0|
+--------------+------+-----+



In [4]:
# Running total that restarts at every explicit reset row and is clamped at zero
# whenever it would go negative. Within one reset segment:
#     total_i = max(0, total_{i-1} + amount_i)
# A single window SUM cannot express this recurrence, but it has the closed form
#     total_i = S_i - min(0, min_{j<=i} S_j)
# where S is the plain running sum inside the segment.
# Negativity must be judged on the segment's own total, not the global running
# sum. The global sum ignores earlier resets, and once it dips below zero it
# would flag every following row.
res = spark.sql("""
    WITH Segmented AS (
        -- Each row with reset = 1 opens a new segment.
        SELECT
            transaction_id,
            amount,
            reset,
            SUM(CASE WHEN reset = 1 THEN 1 ELSE 0 END)
                OVER (ORDER BY transaction_id
                      ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS reset_group
        FROM transaction_table
    ),
    SegmentTotals AS (
        -- Plain (unclamped) running sum within the segment.
        SELECT
            transaction_id,
            amount,
            reset,
            reset_group,
            SUM(amount)
                OVER (PARTITION BY reset_group ORDER BY transaction_id
                      ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total
        FROM Segmented
    ),
    SegmentLows AS (
        -- Lowest running sum seen so far within the segment.
        SELECT
            transaction_id,
            amount,
            reset,
            running_total,
            MIN(running_total)
                OVER (PARTITION BY reset_group ORDER BY transaction_id
                      ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lowest_total
        FROM SegmentTotals
    )
    SELECT
        transaction_id,
        amount,
        reset,
        -- Lift the total by however far it has ever dropped below zero.
        running_total - LEAST(0, lowest_total) AS cumulative_total
    FROM SegmentLows
    ORDER BY transaction_id
""")

res.show()


+--------------+------+-----+----------------+
|transaction_id|amount|reset|cumulative_total|
+--------------+------+-----+----------------+
|             1|   100|    0|             100|
|             2|    50|    0|             150|
|             3|   -30|    0|             120|
|             4|    20|    0|             140|
|             5|    30|    1|              30|
|             6|    10|    0|              40|
|             7|   -10|    0|              30|
|             8|    40|    0|              70|
+--------------+------+-----+----------------+



In [5]:
# DataFrame-API version of the same computation. The running total restarts at each
# explicit reset row and is clamped at zero whenever it would go negative:
#     total_i = max(0, total_{i-1} + amount_i)     (within one reset segment)
# It is computed in closed form as  S_i - min(0, min_{j<=i} S_j),
# with S the plain running sum inside the segment.

# Frame from the first row up to the current row, ordered by transaction_id.
# There is no partitionBy, so Spark moves all rows to a single partition.
# That is fine for this toy data set.
window_spec = Window.orderBy("transaction_id").rowsBetween(Window.unboundedPreceding, Window.currentRow)

# reset_flag marks explicit resets only; negativity is handled by the clamp below.
# Testing the global running sum for < 0 would ignore earlier resets and, once
# negative, would flag every following row.
df_with_reset_group = (
    df
    .withColumn("reset_flag", F.when(F.col("reset") == 1, 1).otherwise(0))
    .withColumn("reset_group", F.sum("reset_flag").over(window_spec))
)

# Same running frame, restarted inside each reset segment.
window_cumulative_total = (
    Window.partitionBy("reset_group")
    .orderBy("transaction_id")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

df_result = (
    df_with_reset_group
    # Plain (unclamped) running sum within the segment.
    .withColumn("running_total", F.sum("amount").over(window_cumulative_total))
    # Lowest running sum seen so far within the segment.
    .withColumn("lowest_total", F.min("running_total").over(window_cumulative_total))
    # Lift the total by however far it has ever dropped below zero.
    .withColumn("cumulative_total", F.col("running_total") - F.least(F.lit(0), F.col("lowest_total")))
)

# Partitioned windows do not guarantee output order, so sort before displaying.
df_result.select("transaction_id", "amount", "reset", "cumulative_total").orderBy("transaction_id").show()


+--------------+------+-----+----------------+
|transaction_id|amount|reset|cumulative_total|
+--------------+------+-----+----------------+
|             1|   100|    0|             100|
|             2|    50|    0|             150|
|             3|   -30|    0|             120|
|             4|    20|    0|             140|
|             5|    30|    1|              30|
|             6|    10|    0|              40|
|             7|   -10|    0|              30|
|             8|    40|    0|              70|
+--------------+------+-----+----------------+

