In [0]:
"""
https://www.youtube.com/watch?v=7W7B0y5WsaQ
Write a SQL query to print the highest- and lowest-salary employees in each department.
"""

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window

# Reuse the active Spark session if there is one; otherwise create it.
spark = SparkSession.builder.getOrCreate()

# Sample rows in the order (emp_name, dep_id, salary).
data = [
    ("Siva", 1, 30000),
    ("Ravi", 2, 40000),
    ("Prasad", 1, 50000),
    ("Sai", 2, 20000),
]

# Explicit schema so column names and types are declared rather than inferred.
schema = (
    StructType()
    .add("emp_name", StringType())
    .add("dep_id", IntegerType())
    .add("salary", IntegerType())
)

df = spark.createDataFrame(data, schema)
df.show()
df.printSchema()

+--------+------+------+
|emp_name|dep_id|salary|
+--------+------+------+
|    Siva|     1| 30000|
|    Ravi|     2| 40000|
|  Prasad|     1| 50000|
|     Sai|     2| 20000|
+--------+------+------+

root
 |-- emp_name: string (nullable = true)
 |-- dep_id: integer (nullable = true)
 |-- salary: integer (nullable = true)



In [0]:
# For every department, report which employee earns the highest salary and
# which earns the lowest.
#
# The window is partition-only on purpose: without an orderBy the frame spans
# the whole department, so max()/min() are true per-department extremes. Adding
# an orderBy would switch the default frame to a running one (unbounded
# preceding .. current row) and make the result depend on the sort direction.
#
# NOTE: max/min/when/col here are the pyspark.sql.functions versions brought in
# by the star import, not the Python builtins.
dept_window = Window.partitionBy("dep_id")

(
    df.withColumn("max_salary", max("salary").over(dept_window))
    .withColumn("min_salary", min("salary").over(dept_window))
    # Keep the employee name only on rows that attain the extreme salary;
    # when() without otherwise() leaves every other row null.
    .withColumn("max_sal_emp", when(col("max_salary") == col("salary"), col("emp_name")))
    .withColumn("min_sal_emp", when(col("min_salary") == col("salary"), col("emp_name")))
    # Collapse to one row per department. The aggregate max() skips nulls, so it
    # picks up the surviving name; if several employees tie on the extreme
    # salary, the lexicographically greatest name is reported.
    .groupBy("dep_id")
    .agg(
        max("max_sal_emp").alias("max_sal_emp"),
        max("min_sal_emp").alias("min_sal_emp"),
    )
    .show()
)

+------+-----------+-----------+
|dep_id|max_sal_emp|min_sal_emp|
+------+-----------+-----------+
|     1|     Prasad|       Siva|
|     2|       Ravi|        Sai|
+------+-----------+-----------+

