- The agg() method in PySpark computes one or more aggregate expressions over a DataFrame in a single call.
- Applied after groupBy(), it aggregates within each group using functions such as sum(), avg(), min(), max(), and count(); called directly on a DataFrame, it aggregates over all rows.

In [1]:
from pyspark.sql import SparkSession
# Note: these names shadow Python's built-in sum(), min(), and max() for the rest of the notebook
from pyspark.sql.functions import sum, avg, min, max, count

spark = SparkSession.builder.appName("AggFunctionExample").getOrCreate()


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/09/09 12:52:08 WARN Utils: Your hostname, KLZPC0015, resolves to a loopback address: 127.0.1.1; using 172.25.17.96 instead (on interface eth0)
25/09/09 12:52:08 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/09/09 12:52:22 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Create a sample DataFrame
data = [
    (1, "Manta", 75000, "IT", 24),
    (2, "Dipankar", 30000, "Post Master", 27),
    (3, "Souvik", 60000, "Army Officer", 27),
    (4, "Soukarjya", 45000, "BDO", None),
    (5, "Arvind", 35000, "Business Data Analyst", 28),
    (6, "Prodipta", 25000, "Data Analyst", 28),
    (7, "Padma", 20000, "Data Analyst", 27),
    (8, "Panta", 125000, "Business Analyst", 27),
    (9, "Sougato", 25000, "Para Medical Stuff", 29)
]

columns = ["id", "name", "salary", "department", "age"]

df = spark.createDataFrame(data, schema=columns)
df.show()


+---+---------+------+--------------------+----+
| id|     name|salary|          department| age|
+---+---------+------+--------------------+----+
|  1|    Manta| 75000|                  IT|  24|
|  2| Dipankar| 30000|         Post Master|  27|
|  3|   Souvik| 60000|        Army Officer|  27|
|  4|Soukarjya| 45000|                 BDO|NULL|
|  5|   Arvind| 35000|Business Data Ana...|  28|
|  6| Prodipta| 25000|        Data Analyst|  28|
|  7|    Padma| 20000|        Data Analyst|  27|
|  8|    Panta|125000|    Business Analyst|  27|
|  9|  Sougato| 25000|  Para Medical Stuff|  29|
+---+---------+------+--------------------+----+



In [3]:
# agg() with groupBy() for multiple aggregate functions
agg_df = df.groupBy("department").agg(
    sum("salary").alias("total_salary"),
    avg("salary").alias("average_salary"),
    min("salary").alias("min_salary"),
    max("salary").alias("max_salary"),
    count("name").alias("employee_count")
)
print("Aggregated Data using agg(): ")
agg_df.show()


Aggregated Data using agg(): 




+--------------------+------------+--------------+----------+----------+--------------+
|          department|total_salary|average_salary|min_salary|max_salary|employee_count|
+--------------------+------------+--------------+----------+----------+--------------+
|         Post Master|       30000|       30000.0|     30000|     30000|             1|
|                  IT|       75000|       75000.0|     75000|     75000|             1|
|                 BDO|       45000|       45000.0|     45000|     45000|             1|
|        Army Officer|       60000|       60000.0|     60000|     60000|             1|
|        Data Analyst|       45000|       22500.0|     20000|     25000|             2|
|Business Data Ana...|       35000|       35000.0|     35000|     35000|             1|
|    Business Analyst|      125000|      125000.0|    125000|    125000|             1|
|  Para Medical Stuff|       25000|       25000.0|     25000|     25000|             1|
+--------------------+------------+--------------+----------+----------+--------------+

In [4]:
# agg() without groupBy() for overall aggregation
overall_agg_df = df.agg(
    sum("salary").alias("total_salary"),
    avg("salary").alias("average_salary")
)

print("Overall aggregation without groupBy(): ")
overall_agg_df.show()


Overall aggregation without groupBy(): 


+------------+-----------------+
|total_salary|   average_salary|
+------------+-----------------+
|      440000|48888.88888888889|
+------------+-----------------+

