**collect_set():**
- An **aggregate function** in Apache Spark that collects the values of a column into an **array of unique elements**; **duplicates** are **removed** automatically.
- Typically used after a **groupBy** operation, where it returns one array of unique values for **each group**. Without **groupBy**, it aggregates over the whole DataFrame.
- **Order of elements** is **not guaranteed**, since the result is treated as a **set**.
- **collect_list()** is the counterpart that **keeps duplicates**.

**Syntax**
	
     collect_set(column)

| Parameter Name | Required | Description |
|----------------|----------|-------------|
| column (str, Column) | Yes | The column (a name or a Column expression) whose values are collected into the array |

**Difference Between collect_set() and collect_list()**

| Function         | Keeps Duplicates | Return Type  | Common Use Case                                       |
| ---------------- | ---------------- | ------------ | ----------------------------------------------------- |
| `collect_set()`  | ❌ No           | Array        | When you only need **unique values**                      |
| `collect_list()` | ✅ Yes          | Array        | When you want to preserve all values, **even duplicates** |
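
As a rough mental model of the difference (plain Python, not Spark code), the two aggregates can be pictured as below. The helper names `model_collect_list` and `model_collect_set` are hypothetical and exist only for this illustration. Spark's real implementation is distributed and, unlike this sketch, makes no promise about element order.

```python
from collections import defaultdict

def model_collect_list(rows, key, value):
    """Hypothetical model: keep every non-null value per group, duplicates included."""
    groups = defaultdict(list)
    for row in rows:
        bucket = groups[row[key]]          # the group exists even if all its values are null
        if row[value] is not None:
            bucket.append(row[value])
    return dict(groups)

def model_collect_set(rows, key, value):
    """Hypothetical model: keep each distinct non-null value once per group."""
    groups = defaultdict(list)
    for row in rows:
        bucket = groups[row[key]]
        v = row[value]
        if v is not None and v not in bucket:
            bucket.append(v)
    return dict(groups)

# Made-up sample rows for the illustration
rows = [
    {"department": "Sales", "salary": 3000},
    {"department": "Sales", "salary": 3000},
    {"department": "Sales", "salary": 4100},
    {"department": "Finance", "salary": 3300},
]
as_list = model_collect_list(rows, "department", "salary")
as_set = model_collect_set(rows, "department", "salary")
print(as_list)
print(as_set)
```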

**1) collect_set() function returns all values from an input column with duplicate values eliminated**
- The examples below **collect** the “salary” column values **without duplication**.
- The result contains every distinct value exactly once; the sample data has **no null values**.
- The collect_set() function omits **null values**, so a null never appears in the resulting array.

In [0]:
from pyspark.sql.functions import col, collect_list, collect_set
import pyspark.sql.functions as f

In [0]:
simpleData = [("James", "Sales", 3000),
    ("Michael", "Sales", 4600),
    ("Robert", "Sales", 4100),
    ("Maria", "Finance", 3000),
    ("James", "Sales", 3000),
    ("Scott", "Finance", 3300),
    ("Jen", "Finance", 3900),
    ("Jeff", "Marketing", 3000),
    ("Kumar", "Marketing", 2000),
    ("Saif", "Sales", 4100)
  ]
schema = ["employee_name", "department", "salary"]

df = spark.createDataFrame(data=simpleData, schema = schema)
display(df)

employee_name,department,salary
James,Sales,3000
Michael,Sales,4600
Robert,Sales,4100
Maria,Finance,3000
James,Sales,3000
Scott,Finance,3300
Jen,Finance,3900
Jeff,Marketing,3000
Kumar,Marketing,2000
Saif,Sales,4100


##### Single Column

In [0]:
df_list_set = df.select(
    collect_list("salary").alias("Salary_List"),
    collect_set("salary").alias("Salary_Set")
)
display(df_list_set)

Salary_List,Salary_Set
"List(3000, 3000, 3000, 3000, 4600, 4100, 4100, 3300, 3900, 2000)","List(3000, 4600, 4100, 3300, 3900, 2000)"


##### Multiple Columns

- **collect_set()** accepts only **one column** at a time.
- To apply it to multiple columns together, either combine them into a single column first (for example, using **concat_ws()**), or apply **collect_set()** separately to **each column**.

**a) collect_set() for each column separately**

In [0]:
from pyspark.sql import functions as F

df_list_set = df.select(
    F.collect_set("employee_name").alias("Employee_Set"),
    F.collect_set("department").alias("Department_Set"),
    F.collect_set("salary").alias("Salary_Set")
)
display(df_list_set)


Employee_Set,Department_Set,Salary_Set
"List(James, Michael, Robert, Maria, Scott, Jen, Jeff, Kumar, Saif)","List(Sales, Finance, Marketing)","List(3000, 4600, 4100, 3300, 3900, 2000)"


**b) Collect set of combinations of multiple columns**

In [0]:
from pyspark.sql import functions as F

df_list_set = df.select(F.collect_set(
                        F.concat_ws(" : ", F.col("department"), F.col("salary").cast("string"))
                        ).alias("Dept_Salary_Set"))
display(df_list_set)

Dept_Salary_Set
"List(Sales : 3000, Sales : 4600, Sales : 4100, Finance : 3000, Finance : 3300, Finance : 3900, Marketing : 3000, Marketing : 2000)"


##### 2) groupBy: Single Column

In [0]:
df_list_set_grp = df.groupBy("department").agg(collect_list("salary").alias("Salaries_List"),
                                               collect_set("salary").alias("Salaries_Set"))
df_list_set_grp.display()

department,Salaries_List,Salaries_Set
Sales,"List(3000, 3000, 4600, 4100, 4100)","List(3000, 4600, 4100)"
Finance,"List(3000, 3300, 3900)","List(3000, 3300, 3900)"
Marketing,"List(3000, 2000)","List(3000, 2000)"


In [0]:
data = [
    ("S1", "Engineering", 2024, "Math", 85),
    ("S1", "Engineering", 2024, "Physics", 78),
    ("S1", "Engineering", 2024, "Math", 85),     # duplicate subject
    ("S2", "Engineering", 2024, "Chemistry", 72),
    ("S2", "Engineering", 2024, "Math", 80),
    ("S3", "Arts", 2024, "History", 90),
    ("S3", "Arts", 2024, "History", 90),          # duplicate
    ("S3", "Arts", 2024, "Political Science", 88),
    ("S4", "Science", 2023, "Biology", 76),
    ("S4", "Science", 2023, "Physics", 69),
    ("S4", "Science", 2023, "Biology", 76),       # duplicate
    ("S5", "Science", 2023, "Chemistry", 91),
    ("S5", "Science", 2023, "Math", 82),
    ("S5", "Science", 2024, "Math", 83)
]

columns = ["student_id", "department", "year", "subject", "marks"]

df_single = spark.createDataFrame(data, columns)
display(df_single)

student_id,department,year,subject,marks
S1,Engineering,2024,Math,85
S1,Engineering,2024,Physics,78
S1,Engineering,2024,Math,85
S2,Engineering,2024,Chemistry,72
S2,Engineering,2024,Math,80
S3,Arts,2024,History,90
S3,Arts,2024,History,90
S3,Arts,2024,Political Science,88
S4,Science,2023,Biology,76
S4,Science,2023,Physics,69


In [0]:
# Example 1: collect_set() per student
print("Unique and total subjects taken by each student:")

df_single_col = df_single.groupBy("student_id").agg(
    collect_list("subject").alias("total_subjects_list"),
    collect_set("subject").alias("unique_subjects_set")
)

display(df_single_col)

Unique and total subjects taken by each student:


student_id,total_subjects_list,unique_subjects_set
S1,"List(Math, Math, Physics)","List(Math, Physics)"
S2,"List(Chemistry, Math)","List(Chemistry, Math)"
S3,"List(History, History, Political Science)","List(History, Political Science)"
S4,"List(Biology, Biology, Physics)","List(Biology, Physics)"
S5,"List(Chemistry, Math, Math)","List(Chemistry, Math)"


In [0]:
# Example 2: collect_set() per department and year
print("Unique subjects offered in each department per year:")

df_mulple_col = df_single.groupBy("department", "year").agg(
    collect_list("subject").alias("total_subjects_list"),
    collect_set("subject").alias("unique_subjects_set"))
display(df_mulple_col)

Unique subjects offered in each department per year:


department,year,total_subjects_list,unique_subjects_set
Engineering,2024,"List(Math, Math, Math, Physics, Chemistry)","List(Math, Physics, Chemistry)"
Arts,2024,"List(History, History, Political Science)","List(History, Political Science)"
Science,2023,"List(Biology, Biology, Physics, Chemistry, Math)","List(Biology, Physics, Chemistry, Math)"
Science,2024,List(Math),List(Math)


In [0]:
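# Example 3: collect_set() inside SQL query
# Note: the view is built from the already aggregated DataFrame, so both
# columns are arrays and collect_set() returns an array of arrays per group.
df_mulple_col.createOrReplaceTempView("student_table")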

# Check the columns in the DataFrame
print(df_mulple_col.columns)

sql_query = """
SELECT department,
       year,
       collect_set(total_subjects_list) AS unique_subjects,
       collect_set(unique_subjects_set) AS subjects_set
FROM student_table
GROUP BY department, year
ORDER BY department, year
"""

display(spark.sql(sql_query))

['department', 'year', 'total_subjects_list', 'unique_subjects_set']


department,year,unique_subjects,subjects_set
Arts,2024,"List(List(History, History, Political Science))","List(List(History, Political Science))"
Engineering,2024,"List(List(Math, Math, Math, Physics, Chemistry))","List(List(Math, Physics, Chemistry))"
Science,2023,"List(List(Biology, Biology, Physics, Chemistry, Math))","List(List(Biology, Physics, Chemistry, Math))"
Science,2024,List(List(Math)),List(List(Math))


##### 3) groupBy: Multiple Columns

In [0]:
data = [
    ("A", "Math", "Science", 2023, 85, "Physics"),
    ("A", "Math", "Science", 2023, 92, "Chemistry"),
    ("A", "Math", "Science", 2023, 85, "Physics"),     # duplicate record
    ("B", "Arts", "Humanities", 2024, 76, "History"),
    ("B", "Arts", "Humanities", 2024, 80, "Civics"),
    ("B", "Arts", "Humanities", 2024, 76, "History"),  # duplicate record
    ("C", "Commerce", "Business", 2023, 91, "Economics"),
    ("C", "Commerce", "Business", 2023, 95, "Accounts"),
    ("C", "Commerce", "Business", 2023, 91, "Economics"),
    ("D", "Science", "Biology", 2024, 88, "Botany"),
    ("D", "Science", "Biology", 2024, 90, "Zoology"),
    ("E", "Science", "Physics", 2023, 94, "Quantum Mechanics"),
]

columns = ["student_id", "department", "stream", "year", "marks", "subject"]

df_mulple_set = spark.createDataFrame(data, columns)
display(df_mulple_set)

student_id,department,stream,year,marks,subject
A,Math,Science,2023,85,Physics
A,Math,Science,2023,92,Chemistry
A,Math,Science,2023,85,Physics
B,Arts,Humanities,2024,76,History
B,Arts,Humanities,2024,80,Civics
B,Arts,Humanities,2024,76,History
C,Commerce,Business,2023,91,Economics
C,Commerce,Business,2023,95,Accounts
C,Commerce,Business,2023,91,Economics
D,Science,Biology,2024,88,Botany


In [0]:
result_df = df_mulple_set.groupBy("department", "year") \
              .agg(
                  collect_set("student_id").alias("unique_students"),
                  collect_set("subject").alias("unique_subjects"),
                  collect_set("marks").alias("unique_marks")
              )

print("Aggregated Results using collect_set():")
result_df.display()

Aggregated Results using collect_set():


department,year,unique_students,unique_subjects,unique_marks
Math,2023,List(A),"List(Physics, Chemistry)","List(85, 92)"
Arts,2024,List(B),"List(History, Civics)","List(76, 80)"
Commerce,2023,List(C),"List(Economics, Accounts)","List(91, 95)"
Science,2024,List(D),"List(Botany, Zoology)","List(88, 90)"
Science,2023,List(E),List(Quantum Mechanics),List(94)


In [0]:
result_df.createOrReplaceTempView("student_table_multi")

# Check the columns in the DataFrame
print(result_df.columns)

sql_result_multi = """
SELECT department,
       year,
       unique_students,
       unique_subjects,
       unique_marks
FROM student_table_multi
ORDER BY department
"""

display(spark.sql(sql_result_multi))

['department', 'year', 'unique_students', 'unique_subjects', 'unique_marks']


department,year,unique_students,unique_subjects,unique_marks
Arts,2024,List(B),"List(History, Civics)","List(76, 80)"
Commerce,2023,List(C),"List(Economics, Accounts)","List(91, 95)"
Math,2023,List(A),"List(Physics, Chemistry)","List(85, 92)"
Science,2024,List(D),"List(Botany, Zoology)","List(88, 90)"
Science,2023,List(E),List(Quantum Mechanics),List(94)


In [0]:
nested_group = df_mulple_set.groupBy("department", "stream") \
                 .agg(
                     collect_set("student_id").alias("students"),
                     collect_set("subject").alias("subjects"),
                     collect_set("year").alias("years")
                 )

nested_group.display()

department,stream,students,subjects,years
Math,Science,List(A),"List(Physics, Chemistry)",List(2023)
Arts,Humanities,List(B),"List(History, Civics)",List(2024)
Commerce,Business,List(C),"List(Economics, Accounts)",List(2023)
Science,Biology,List(D),"List(Botany, Zoology)",List(2024)
Science,Physics,List(E),List(Quantum Mechanics),List(2023)
