In [3]:
'''
You are given a dataset containing sales information for different salespeople. Each row contains the salesperson’s ID, their name, and the sales amount. Write a PySpark program that performs the following tasks:

Filter out records where sales are less than 50.

Create a new column sales_category that categorizes sales into:

Low (50–100)
Medium (101–200)
High (>200)

Group the data by sales_category and calculate the average sales for each category.

Input Schema & Example

Column Name    Data Type
id             Integer
name           String
sales          Double

Example Input Table

id    name     sales
1     Alice    45
2     Bob      120
3     Carol    75
4     David    250
5     Eve      180

Output Schema, Example & Explanation

Column Name       Data Type
sales_category    String
avg_sales         Double

Example Output Table

sales_category    avg_sales
Low               75.0
Medium            150.0
High              250.0

Explanation
Records with sales < 50 are filtered out (so Alice’s record is removed).

Carol (75) falls into Low, Bob (120) and Eve (180) fall into Medium, and David (250) falls into High.

The averages are then computed per category:

Low: (75) / 1 = 75.0
Medium: (120 + 180) / 2 = 150.0
High: (250) / 1 = 250.0

Starter Code
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [
    (1, "Alice", 45),
    (2, "Bob", 120),
    (3, "Carol", 77),
    (4, "David", 250),
    (5, "Eve", 180),
    (6, "Jacob", 30),
    (7, "Mike", 90),
    (8, "Tim", 65),
    (9, "Lukas", 159),
    (10, "Peter", 217),
    (11, "Henry", 100),
    (12, "Frida", 200),
    (13, "Grisha", 50)
]

columns = ["id", "name", "sales"]

df = spark.createDataFrame(data, columns)

# Write your transformations here

Use display(df_result) to show the final DataFrame.

'''
# Initialize Spark session
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName('Spark Playground').getOrCreate()

data = [
    (1, "Alice", 45),
    (2, "Bob", 120),
    (3, "Carol", 77),
    (4, "David", 250),
    (5, "Eve", 180),
    (6, "Jacob", 30),
    (7, "Mike", 90),
    (8, "Tim", 65),
    (9, "Lukas", 159),
    (10, "Peter", 217),
    (11, "Henry", 100),
    (12, "Frida", 200),
    (13, "Grisha", 50)
]

columns = ["id", "name", "sales"]

df = spark.createDataFrame(data, columns)

# Filter sales >= 50
df_filtered = df.filter(F.col("sales") >= 50)

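# Create new column sales_category. Rows below 50 are already filtered out,
# so upper bounds alone define the buckets. This also avoids a gap for
# fractional values such as 100.5 (the schema allows Double), which would
# otherwise fall through to "High".
df_with_category = df_filtered.withColumn(
    "sales_category",
    F.when(F.col("sales") <= 100, "Low")
    .when(F.col("sales") <= 200, "Medium")
    .otherwise("High"),
)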

# Aggregate: average sales per category (row order of the result is not guaranteed)
df_result = (
  df_with_category.groupBy(F.col("sales_category"))
  .agg(F.avg(F.col("sales")).alias("avg_sales"))
)

# Display result
df_result.show()

+--------------+---------+
|sales_category|avg_sales|
+--------------+---------+
|           Low|     76.4|
|        Medium|   164.75|
|          High|    233.5|
+--------------+---------+
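
The averages in the output above can be cross-checked without Spark. The following is a minimal plain-Python sketch, not part of the original solution: the helper name `categorize` is hypothetical, and it assumes the same 13 input rows and the same category boundaries as the PySpark code for this whole-number data.

```python
from collections import defaultdict

# Same rows as the PySpark cell: (id, name, sales)
data = [
    (1, "Alice", 45),
    (2, "Bob", 120),
    (3, "Carol", 77),
    (4, "David", 250),
    (5, "Eve", 180),
    (6, "Jacob", 30),
    (7, "Mike", 90),
    (8, "Tim", 65),
    (9, "Lukas", 159),
    (10, "Peter", 217),
    (11, "Henry", 100),
    (12, "Frida", 200),
    (13, "Grisha", 50),
]


def categorize(sales):
    """Hypothetical helper: bucket a sales value that is already >= 50."""
    if sales <= 100:
        return "Low"
    if sales <= 200:
        return "Medium"
    return "High"


# Drop rows below 50, then collect the remaining sales per category
buckets = defaultdict(list)
for _id, _name, sales in data:
    if sales >= 50:
        buckets[categorize(sales)].append(sales)

# Average per category
averages = {category: sum(values) / len(values) for category, values in buckets.items()}

for category in ("Low", "Medium", "High"):
    print(category, averages[category])
# Low 76.4
# Medium 164.75
# High 233.5
```

Alice (45) and Jacob (30) are dropped by the filter, which leaves five Low rows (382 / 5), four Medium rows (659 / 4), and two High rows (467 / 2), matching the `show()` output.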

