In [1]:
import os
# Switch to the project folder. A raw string is used so the backslash stays
# literal: "\p" is not a recognised escape sequence and makes newer Python
# versions emit an invalid-escape warning when written in a normal string.
os.chdir(r"H:\pyspark_advanced-coding_interview")
# Last expression of the cell: echoes the directory now in effect.
os.getcwd()

'H:\\pyspark_advanced-coding_interview'

In [2]:
from pyspark.sql import SparkSession

# Get the active Spark session, or start one named "OptimizedLocalSpark".
# No extra configuration options are applied by this builder chain.
spark = (
    SparkSession.builder
    .appName("OptimizedLocalSpark")
    .getOrCreate()
)

# Handle to the underlying SparkContext for RDD-level APIs.
sc = spark.sparkContext

#### Running total that is reset when it goes negative or when Reset_Flag is set

In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, sum as Fsum
from pyspark.sql.window import Window

# getOrCreate() hands back the active session when there is one; otherwise it
# starts a new session with this application name.
spark = SparkSession.builder.appName("ResetRunningTotal").getOrCreate()

# Column names, in the same order as the fields of each sample tuple below.
columns = ["TransactionDate", "ProductID", "Qty", "Reset_Flag", "Grp"]

# Sample transactions.
data = [
    ("2022-09-01", 1, 100, 0, 0),
    ("2022-09-02", 1, 200, 0, 0),
    ("2022-09-03", 1, -500, 0, 0),
    ("2022-09-04", 1, 150, 0, 0),
    ("2022-09-05", 1, 400, 1, 1),
    ("2022-09-06", 1, 250, 0, 1),
    ("2022-09-07", 1, -850, 0, 1),
]

# The cumulative sum restarts for every (ProductID, Grp) pair and follows
# TransactionDate order. No frame is given, so Spark's default for an ordered
# window applies (RANGE from unbounded preceding to the current row): rows that
# share a TransactionDate inside one partition receive the same total.
windowSpec = Window.partitionBy("ProductID", "Grp").orderBy("TransactionDate")

# A row is reported as 0 when its Reset_Flag is 1 or its running total is
# negative. This masks that row's value only; the windowed sum itself is not
# restarted, so following rows keep accumulating from the unmasked total.
needs_reset = (col("Reset_Flag") == 1) | (col("RunningTotal") < 0)

df = (
    spark.createDataFrame(data, columns)
    .withColumn("RunningTotal", Fsum("Qty").over(windowSpec))
    .withColumn(
        "AdjustedRunningTotal",
        when(needs_reset, 0).otherwise(col("RunningTotal")),
    )
)

df.show()


+---------------+---------+----+----------+---+------------+--------------------+
|TransactionDate|ProductID| Qty|Reset_Flag|Grp|RunningTotal|AdjustedRunningTotal|
+---------------+---------+----+----------+---+------------+--------------------+
|     2022-09-01|        1| 100|         0|  0|         100|                 100|
|     2022-09-02|        1| 200|         0|  0|         300|                 300|
|     2022-09-03|        1|-500|         0|  0|        -200|                   0|
|     2022-09-04|        1| 150|         0|  0|         -50|                   0|
|     2022-09-05|        1| 400|         1|  1|         400|                   0|
|     2022-09-06|        1| 250|         0|  1|         650|                 650|
|     2022-09-07|        1|-850|         0|  1|        -200|                   0|
+---------------+---------+----+----------+---+------------+--------------------+



In [4]:
# Bring the rows to the driver for a sequential calculation. The loop below
# relies on seeing each product's transactions in date order, and collect() on
# its own gives no ordering guarantee, so the sort is made explicit.
data = df.orderBy("ProductID", "TransactionDate").collect()

# Running state: the product currently being accumulated and its total so far.
current_product = None
adjusted_total = 0
result = []

for row in data:
    # Start again from zero when a new product begins (totals must not leak
    # from one product into the next), when Reset_Flag is set, or when the
    # previous row left the total negative.
    if (
        row["ProductID"] != current_product
        or row["Reset_Flag"] == 1
        or adjusted_total < 0
    ):
        adjusted_total = 0
    current_product = row["ProductID"]

    # Add this transaction's quantity to the (possibly reset) total.
    adjusted_total += row["Qty"]

    # Keep the original input columns plus the recomputed total.
    result.append((row["TransactionDate"], row["ProductID"], row["Qty"], row["Reset_Flag"], row["Grp"], adjusted_total))

# Rebuild a DataFrame from the computed tuples.
result_df = spark.createDataFrame(result, columns + ["AdjustedRunningTotal"])

result_df.show()


+---------------+---------+----+----------+---+--------------------+
|TransactionDate|ProductID| Qty|Reset_Flag|Grp|AdjustedRunningTotal|
+---------------+---------+----+----------+---+--------------------+
|     2022-09-01|        1| 100|         0|  0|                 100|
|     2022-09-02|        1| 200|         0|  0|                 300|
|     2022-09-03|        1|-500|         0|  0|                -200|
|     2022-09-04|        1| 150|         0|  0|                 150|
|     2022-09-05|        1| 400|         1|  1|                 400|
|     2022-09-06|        1| 250|         0|  1|                 650|
|     2022-09-07|        1|-850|         0|  1|                -200|
+---------------+---------+----+----------+---+--------------------+



In [5]:
# Expose the DataFrame to Spark SQL under the name "transactions".
df.createOrReplaceTempView("transactions")

# Same calculation as the DataFrame version, expressed in SQL.
# The windowed sum is computed once in a CTE and then reused, instead of
# repeating the identical OVER (...) expression for the total, the CASE
# condition and the CASE fallback, which could drift apart when edited.
# ProductID is added as an ORDER BY tiebreaker so the displayed order is
# deterministic when several products share a TransactionDate.
result_sql = spark.sql("""
    WITH running AS (
        SELECT
            TransactionDate,
            ProductID,
            Qty,
            Reset_Flag,
            Grp,
            SUM(Qty) OVER (PARTITION BY ProductID, Grp ORDER BY TransactionDate) AS RunningTotal
        FROM transactions
    )
    SELECT
        TransactionDate,
        ProductID,
        Qty,
        Reset_Flag,
        Grp,
        RunningTotal,
        CASE
            WHEN RunningTotal < 0 OR Reset_Flag = 1 THEN 0
            ELSE RunningTotal
        END AS AdjustedRunningTotal
    FROM running
    ORDER BY TransactionDate, ProductID
""")

result_sql.show()


+---------------+---------+----+----------+---+------------+--------------------+
|TransactionDate|ProductID| Qty|Reset_Flag|Grp|RunningTotal|AdjustedRunningTotal|
+---------------+---------+----+----------+---+------------+--------------------+
|     2022-09-01|        1| 100|         0|  0|         100|                 100|
|     2022-09-02|        1| 200|         0|  0|         300|                 300|
|     2022-09-03|        1|-500|         0|  0|        -200|                   0|
|     2022-09-04|        1| 150|         0|  0|         -50|                   0|
|     2022-09-05|        1| 400|         1|  1|         400|                   0|
|     2022-09-06|        1| 250|         0|  1|         650|                 650|
|     2022-09-07|        1|-850|         0|  1|        -200|                   0|
+---------------+---------+----+----------+---+------------+--------------------+

