In [1]:
import findspark

findspark.init()

from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
import pyspark.sql.functions as F

# Run Spark locally with 4 worker threads under the application name "1789".
conf = (
    SparkConf()
    .setMaster("local[4]")
    .setAppName("1789")
)
# getOrCreate() reuses an already-running session if there is one.
spark = SparkSession.builder.config(conf=conf).getOrCreate()
spark  # last expression of the cell: displays the session summary

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/07/30 00:59:17 WARN Utils: Your hostname, de24, resolves to a loopback address: 127.0.1.1; using 192.168.0.102 instead (on interface enp0s3)
25/07/30 00:59:17 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/30 00:59:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [None]:
'''
Table: Employee

+---------------+---------+
| Column Name   |  Type   |
+---------------+---------+
| employee_id   | int     |
| department_id | int     |
| primary_flag  | varchar |
+---------------+---------+
(employee_id, department_id) is the primary key 
(combination of columns with unique values) for this table.
employee_id is the id of the employee.
department_id is the id of the department to which the employee belongs.
primary_flag is an ENUM (category) of type ('Y', 'N'). If the flag is 'Y', 
the department is the primary department for the employee. If the flag is 'N', 
the department is not the primary.
 

Employees can belong to multiple departments. When the employee joins other departments, 
they need to decide which department is their primary department. 
Note that when an employee belongs to only one department, their primary column is 'N'.

Write a solution to report all the employees with their primary department. 
For employees who belong to one department, report their only department.

Return the result table in any order.

The result format is in the following example.

 

Example 1:

Input: 
Employee table:
+-------------+---------------+--------------+
| employee_id | department_id | primary_flag |
+-------------+---------------+--------------+
| 1           | 1             | N            |
| 2           | 1             | Y            |
| 2           | 2             | N            |
| 3           | 3             | N            |
| 4           | 2             | N            |
| 4           | 3             | Y            |
| 4           | 4             | N            |
+-------------+---------------+--------------+
Output: 
+-------------+---------------+
| employee_id | department_id |
+-------------+---------------+
| 1           | 1             |
| 2           | 1             |
| 3           | 3             |
| 4           | 3             |
+-------------+---------------+
Explanation: 
- The Primary department for employee 1 is 1.
- The Primary department for employee 2 is 1.
- The Primary department for employee 3 is 3.
- The Primary department for employee 4 is 3.
'''

In [2]:
# Column names for the Employee table, in the same order as the tuple fields below.
schema = ['employee_id', 'department_id', 'primary_flag']

# Sample rows copied from "Example 1" of the problem statement.
data = [
    (1, 1, 'N'),
    (2, 1, 'Y'),
    (2, 2, 'N'),
    (3, 3, 'N'),
    (4, 2, 'N'),
    (4, 3, 'Y'),
    (4, 4, 'N'),
]

In [3]:
# Build the Employee DataFrame from the sample rows; `schema` only supplies
# the column names.
df = spark.createDataFrame(data, schema)
df.show()

                                                                                

+-----------+-------------+------------+
|employee_id|department_id|primary_flag|
+-----------+-------------+------------+
|          1|            1|           N|
|          2|            1|           Y|
|          2|            2|           N|
|          3|            3|           N|
|          4|            2|           N|
|          4|            3|           Y|
|          4|            4|           N|
+-----------+-------------+------------+



In [12]:
# First attempt, intentionally kept as a counterexample.
# The second branch aggregates down to (employee_id, count), so department_id
# no longer exists there. union() pairs columns by position, not by name, so
# `count` ends up stacked underneath `department_id` in the result.
(
    df.select("employee_id", "department_id")
    .where(F.col("primary_flag") == 'Y')
    .union(
        df.select("employee_id", "department_id")
        .groupBy("employee_id")
        .agg(F.count("department_id").alias("count"))
        .where(F.col("count") == 1)
    )
    .show()
)

+-----------+-------------+
|employee_id|department_id|
+-----------+-------------+
|          2|            1|
|          4|            3|
|          1|            1|
|          3|            1|
+-----------+-------------+



## Wrong output 

The SQL solution at the end of this notebook works as written, but translating it line-for-line into the PySpark DataFrame API does not. <br>
The attempt is really close, and the instinct to embed both logic paths in one chained statement is sharp. <br>🚀 However, there's one subtle catch: in the second half, the `groupBy(...).agg(...)` keeps only `employee_id` and `count`, so the `department_id` column is lost during aggregation. `union` matches columns by position rather than by name, so it does not raise an error; it silently places `count` under `department_id`. That is why employee 3 is reported with department 1 (its row count) instead of department 3.


## Correct Solution

In [24]:
# Employees that appear in exactly one row, i.e. belong to a single department.
df_flag = (
    df.groupBy("employee_id")
    .agg(F.count("*").alias("count"))
    .where(F.col("count") == 1)
)

# Recover the department of those employees. Joining on the column name keeps
# a single, unambiguous employee_id column, and an inner join is sufficient
# because every employee_id in df_flag originates from df.
df_single = (
    df_flag.join(df, on="employee_id", how="inner")
    .select("employee_id", "department_id")
)

# Rows flagged 'Y' plus the single-department rows. The filter runs before the
# projection because primary_flag is dropped by the select. DataFrame.union
# behaves like SQL UNION ALL, so distinct() is applied to match the UNION used
# in the SQL solution: an employee whose only department is flagged 'Y' would
# otherwise be reported twice.
(
    df.where(F.col("primary_flag") == 'Y')
    .select("employee_id", "department_id")
    .union(df_single)
    .distinct()
    .show()
)

+-----------+-------------+
|employee_id|department_id|
+-----------+-------------+
|          2|            1|
|          4|            3|
|          1|            1|
|          3|            3|
+-----------+-------------+



## SQL Solution

<pre>

-- Departments explicitly flagged as primary.
SELECT employee_id, department_id
FROM Employee
WHERE primary_flag = 'Y'
UNION
-- Employees with exactly one department: MIN() returns that single value and
-- keeps the query valid when ONLY_FULL_GROUP_BY is enabled.
SELECT employee_id, MIN(department_id) AS department_id
FROM Employee
GROUP BY employee_id
HAVING COUNT(department_id) = 1
</pre>