groupBy() in PySpark groups the rows of a DataFrame by one or more columns. After grouping, we can apply aggregate functions such as count(), sum(), avg(), min(), and max(). It works like GROUP BY in SQL.

In [2]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("PySparkGroupByFunction").getOrCreate()




In [5]:
# Create a sample DataFrame
data = [
    (1, "Manta", 75000, "IT", 24),
    (2, "Dipankar", 30000, "Post Master", 27),
    (3, "Souvik", 60000, "Army Officer", 27),
    (4, "Soukarjya", 45000, "BDO", None),
    (5, "Arvind", 35000, "Business Analyst", 28),
    (6, "Prodipta", 25000, "Data Analyst", 28),
    (7, "Padma", 20000, "Data Analyst", 27),
    (8, "Panta", 125000, "Business Analyst", 27),
    (9, "Sougato", 25000, "IT", 29)
]

columns = ["id", "name", "salary", "department", "age"]

df = spark.createDataFrame(data, schema=columns)
df.show()




+---+---------+------+----------------+----+
| id|     name|salary|      department| age|
+---+---------+------+----------------+----+
|  1|    Manta| 75000|              IT|  24|
|  2| Dipankar| 30000|     Post Master|  27|
|  3|   Souvik| 60000|    Army Officer|  27|
|  4|Soukarjya| 45000|             BDO|NULL|
|  5|   Arvind| 35000|Business Analyst|  28|
|  6| Prodipta| 25000|    Data Analyst|  28|
|  7|    Padma| 20000|    Data Analyst|  27|
|  8|    Panta|125000|Business Analyst|  27|
|  9|  Sougato| 25000|              IT|  29|
+---+---------+------+----------------+----+



                                                                                

In [6]:
# groupBy() with count() - count of employees in each department
print("Count of employees in each Department: ")
df.groupBy("department").count().show()


Count of employees in each Department: 




+----------------+-----+
|      department|count|
+----------------+-----+
|     Post Master|    1|
|              IT|    2|
|             BDO|    1|
|    Army Officer|    1|
|Business Analyst|    2|
|    Data Analyst|    2|
+----------------+-----+
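
As a cross-check on the SQL comparison made in the introduction, the sketch below runs the equivalent GROUP BY query on the same nine rows. It uses Python's built-in sqlite3 module instead of Spark, purely so that it runs without a Spark installation. The table name `employees` and the alias `employee_count` are assumptions made for illustration.

```python
import sqlite3

# Same (id, name, salary, department, age) rows as the sample DataFrame.
rows = [
    (1, "Manta", 75000, "IT", 24),
    (2, "Dipankar", 30000, "Post Master", 27),
    (3, "Souvik", 60000, "Army Officer", 27),
    (4, "Soukarjya", 45000, "BDO", None),
    (5, "Arvind", 35000, "Business Analyst", 28),
    (6, "Prodipta", 25000, "Data Analyst", 28),
    (7, "Padma", 20000, "Data Analyst", 27),
    (8, "Panta", 125000, "Business Analyst", 27),
    (9, "Sougato", 25000, "IT", 29),
]

# In-memory database; "employees" is a hypothetical table name.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE employees "
    "(id INTEGER, name TEXT, salary INTEGER, department TEXT, age INTEGER)"
)
conn.executemany("INSERT INTO employees VALUES (?, ?, ?, ?, ?)", rows)

# SQL counterpart of df.groupBy("department").count()
query = """
    SELECT department, COUNT(*) AS employee_count
    FROM employees
    GROUP BY department
    ORDER BY department
"""
dept_counts = conn.execute(query).fetchall()
conn.close()

for department, employee_count in dept_counts:
    print(f"{department}: {employee_count}")
```

The counts agree with the table above. In Spark, the same kind of query can be issued with spark.sql() after registering the DataFrame through df.createOrReplaceTempView("employees"). Row order in the Spark output is not guaranteed unless the query sorts explicitly.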



                                                                                

In [7]:
# groupBy() with sum() - Total salary by department
# Note: this import shadows Python's built-in sum() for the rest of the notebook.
from pyspark.sql.functions import sum

print("Total salary by department: ")
df.groupBy("department") \
    .agg(sum("salary").alias("total_salary")) \
    .show()


Total salary by department: 



+----------------+------------+
|      department|total_salary|
+----------------+------------+
|     Post Master|       30000|
|              IT|      100000|
|             BDO|       45000|
|    Army Officer|       60000|
|Business Analyst|      160000|
|    Data Analyst|       45000|
+----------------+------------+



                                                                                

In [8]:
# groupBy() with avg(), min(), max() - department salary stats
# Note: min and max below shadow Python's built-ins of the same name.
from pyspark.sql.functions import avg, min, max

print("Salary stats by department (average, min, max): ")
df.groupBy("department") \
    .agg(
        avg("salary").alias("avg_salary"),
        min("salary").alias("min_salary"),
        max("salary").alias("max_salary")
    ).show()


Salary stats by department (average, min, max): 




+----------------+----------+----------+----------+
|      department|avg_salary|min_salary|max_salary|
+----------------+----------+----------+----------+
|     Post Master|   30000.0|     30000|     30000|
|              IT|   50000.0|     25000|     75000|
|             BDO|   45000.0|     45000|     45000|
|    Army Officer|   60000.0|     60000|     60000|
|Business Analyst|   80000.0|     35000|    125000|
|    Data Analyst|   22500.0|     20000|     25000|
+----------------+----------+----------+----------+
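
To make the two phases behind groupBy().agg() explicit, here is a minimal plain-Python sketch of "group, then aggregate" on the same salaries. It illustrates the semantics only and is not how Spark executes the query. The names `records`, `salaries_by_dept`, and `stats` are assumptions introduced for this example. The built-ins are reached through the `builtins` module because the notebook's imports from pyspark.sql.functions replace sum, min, and max.

```python
import builtins
from collections import defaultdict

# (department, salary) pairs taken from the sample data above.
records = [
    ("IT", 75000),
    ("Post Master", 30000),
    ("Army Officer", 60000),
    ("BDO", 45000),
    ("Business Analyst", 35000),
    ("Data Analyst", 25000),
    ("Data Analyst", 20000),
    ("Business Analyst", 125000),
    ("IT", 25000),
]

# Step 1: group. Collect every salary under its department key.
salaries_by_dept = defaultdict(list)
for department, salary in records:
    salaries_by_dept[department].append(salary)

# Step 2: aggregate. Reduce each group to a single row of statistics.
# builtins.* is used so the sketch still works where sum/min/max are shadowed.
stats = {
    department: {
        "avg_salary": builtins.sum(values) / len(values),
        "min_salary": builtins.min(values),
        "max_salary": builtins.max(values),
    }
    for department, values in salaries_by_dept.items()
}

for department, row in stats.items():
    print(department, row)
```

Each dictionary entry corresponds to one row of the table above: one output row per distinct department, with one column per aggregate.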



                                                                                

In [9]:
# groupBy() on name and department - total salary per (name, department) pair
# (every pair is unique in this data, so each total equals that employee's salary)
print("Sum of salary by name and department: ")

df.groupBy("name", "department") \
    .agg(
        sum("salary").alias("total_salary")
    ).show()


Sum of salary by name and department: 




+---------+----------------+------------+
|     name|      department|total_salary|
+---------+----------------+------------+
| Dipankar|     Post Master|       30000|
|    Manta|              IT|       75000|
|Soukarjya|             BDO|       45000|
|   Souvik|    Army Officer|       60000|
|   Arvind|Business Analyst|       35000|
| Prodipta|    Data Analyst|       25000|
|    Padma|    Data Analyst|       20000|
|  Sougato|              IT|       25000|
|    Panta|Business Analyst|      125000|
+---------+----------------+------------+
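
The result above has one row per employee because no (name, department) pair repeats in the sample data. The plain-Python sketch below shows when a two-column grouping actually merges rows. The helper `total_by_name_and_department` and the extra row for Manta are hypothetical and are not part of the original notebook.

```python
from collections import defaultdict

# (name, department, salary) triples from the sample data above.
records = [
    ("Manta", "IT", 75000),
    ("Dipankar", "Post Master", 30000),
    ("Souvik", "Army Officer", 60000),
    ("Soukarjya", "BDO", 45000),
    ("Arvind", "Business Analyst", 35000),
    ("Prodipta", "Data Analyst", 25000),
    ("Padma", "Data Analyst", 20000),
    ("Panta", "Business Analyst", 125000),
    ("Sougato", "IT", 25000),
]


def total_by_name_and_department(rows):
    """Sum salary per (name, department) key, mirroring a two-column groupBy."""
    totals = defaultdict(int)
    for name, department, salary in rows:
        totals[(name, department)] += salary
    return dict(totals)


totals = total_by_name_and_department(records)
# Every key is unique, so the number of groups equals the number of rows.
print(len(totals), len(records))

# HYPOTHETICAL extra row (not in the original data): a second IT entry for Manta.
with_duplicate = records + [("Manta", "IT", 5000)]
merged = total_by_name_and_department(with_duplicate)
# The two Manta/IT rows now collapse into one group with a combined total.
print(merged[("Manta", "IT")])
```

The composite key is what matters: two rows fall into the same group only if they match on every grouping column.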



                                                                                