The filter() method in PySpark selects rows from a DataFrame based on a condition. It returns a new DataFrame that contains only the rows for which the condition is true.

where() is an alias for filter(), so the two are interchangeable and behave identically.

In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PySparkFilterFunction").getOrCreate()

In [9]:
# Create a sample DataFrame
data = [
    (1, "Manta", 75000, "IT", 24),
    (2, "Dipankar", 30000, "Post Master", 27),
    (3, "Souvik", 60000, "Army Officer", 27),
    (4, "Soukarjya", 45000, "BDO", None),
    (5, "Arvind", 35000, "Business Data Analyst", 28),
    (6, "Prodipta", 25000, "Data Analyst", 28),
    (7, "Padma", 20000, "Data Analyst", 27),
    (8, "Panta", 125000, "Business Analyst", 27),
    (9, "Sougato", 25000, "Para Medical Stuff", 29)
]

columns = ["id", "name", "salary", "department", "age"]

df = spark.createDataFrame(data, schema=columns)
df.show()

+---+---------+------+--------------------+----+
| id|     name|salary|          department| age|
+---+---------+------+--------------------+----+
|  1|    Manta| 75000|                  IT|  24|
|  2| Dipankar| 30000|         Post Master|  27|
|  3|   Souvik| 60000|        Army Officer|  27|
|  4|Soukarjya| 45000|                 BDO|NULL|
|  5|   Arvind| 35000|Business Data Ana...|  28|
|  6| Prodipta| 25000|        Data Analyst|  28|
|  7|    Padma| 20000|        Data Analyst|  27|
|  8|    Panta|125000|    Business Analyst|  27|
|  9|  Sougato| 25000|  Para Medical Stuff|  29|
+---+---------+------+--------------------+----+

In [10]:
# filter() - Filter rows where Department is 'Data Analyst'
print("Employees in Data Analyst department")

df.filter(df.department == "Data Analyst").show()


Employees in Data Analyst department

+---+--------+------+------------+---+
| id|    name|salary|  department|age|
+---+--------+------+------------+---+
|  6|Prodipta| 25000|Data Analyst| 28|
|  7|   Padma| 20000|Data Analyst| 27|
+---+--------+------+------------+---+

In [11]:
# Filter rows where salary is greater than 50000

print("Employees with salary greater than 50000: ")
df.filter(df.salary > 50000).show()


Employees with salary greater than 50000: 

+---+------+------+----------------+---+
| id|  name|salary|      department|age|
+---+------+------+----------------+---+
|  1| Manta| 75000|              IT| 24|
|  3|Souvik| 60000|    Army Officer| 27|
|  8| Panta|125000|Business Analyst| 27|
+---+------+------+----------------+---+

In [12]:
# Filter rows with multiple conditions (AND)
print("Employees in Data Analyst department with salary > 20000: ")
df.filter((df.department == "Data Analyst") & (df.salary > 20000)).show()


Employees in Data Analyst department with salary > 20000: 

+---+--------+------+------------+---+
| id|    name|salary|  department|age|
+---+--------+------+------------+---+
|  6|Prodipta| 25000|Data Analyst| 28|
+---+--------+------+------------+---+

In [15]:
# Filter rows with multiple conditions (OR)
print("Employees aged 27 OR with salary > 50000: ")
df.filter((df.age == 27) | (df.salary > 50000)).show()

# Variation: cast the column with try_cast, which returns NULL
# instead of raising an error for values that cannot be converted
from pyspark.sql.functions import expr

df = df.withColumn("age", expr("try_cast(age as int)"))
df.filter((df.age == 27) | (df.salary > 50000)).show()

# Variation: exclude NULL ages explicitly before comparing
# (filter() already drops rows whose condition is NULL, so the result is the same)
df.filter(
    ((df.age == 27) | (df.salary > 50000)) & (df.age.isNotNull())
).show()


Employees aged 27 OR with salary > 50000: 

+---+--------+------+----------------+---+
| id|    name|salary|      department|age|
+---+--------+------+----------------+---+
|  1|   Manta| 75000|              IT| 24|
|  2|Dipankar| 30000|     Post Master| 27|
|  3|  Souvik| 60000|    Army Officer| 27|
|  7|   Padma| 20000|    Data Analyst| 27|
|  8|   Panta|125000|Business Analyst| 27|
+---+--------+------+----------------+---+

+---+--------+------+----------------+---+
| id|    name|salary|      department|age|
+---+--------+------+----------------+---+
|  1|   Manta| 75000|              IT| 24|
|  2|Dipankar| 30000|     Post Master| 27|
|  3|  Souvik| 60000|    Army Officer| 27|
|  7|   Padma| 20000|    Data Analyst| 27|
|  8|   Panta|125000|Business Analyst| 27|
+---+--------+------+----------------+---+

+---+--------+------+----------------+---+
| id|    name|salary|      department|age|
+---+--------+------+----------------+---+
|  1|   Manta| 75000|              IT| 24|
|  2|Dipankar| 30000|     Post Master| 27|
|  3|  Souvik| 60000|    Army Officer| 27|
|  7|   Padma| 20000|    Data Analyst| 27|
|  8|   Panta|125000|Business Analyst| 27|
+---+--------+------+----------------+---+

In [17]:
# filter() with isin() to filter multiple names
print("Employees whose names are either Padma or Panta: ")
df.filter(df.name.isin("Padma", "Panta")).show()


Employees whose names are either Padma or Panta: 

+---+-----+------+----------------+---+
| id| name|salary|      department|age|
+---+-----+------+----------------+---+
|  7|Padma| 20000|    Data Analyst| 27|
|  8|Panta|125000|Business Analyst| 27|
+---+-----+------+----------------+---+
