In [1]:
import pandas as pd

!pip install pyspark

from pyspark.sql import SparkSession
# Create a SparkSession (without a specified name)
spark = SparkSession.builder.getOrCreate()
spark.conf.set('spark.sql.repl.eagerEval.enabled', True) #for simple calls and better display



#  Understanding `expr()` vs. Strings in PySpark

PySpark provides flexibility in how we pass expressions, but not all methods handle strings and expressions in the same way. This guide explains when you can use raw strings and when you must use `expr()`.

---

##  **1. `filter()` (or `where()`) Accepts Both Strings and `expr()`**
- `filter()` automatically converts a SQL expression string into a `Column`, so using `expr()` explicitly is optional.

 **Both of these work the same way:**
```python
dfa = dfa.filter(expr("player_date_rank + device_id = 1"))
dfa = dfa.filter("player_date_rank + device_id = 1")  # No need for expr()
```
Since filter() interprets string arguments as SQL expressions internally, both versions work.

## **2. withColumn() Requires a Column Object**
Unlike filter(), withColumn() does not automatically convert strings into SQL expressions.
If you pass a raw string, Spark throws an error:
"Argument col should be a Column, got str."
```python
dfa = dfa.withColumn("abc", "games_played + device_id")  # ❌ Error

from pyspark.sql.functions import expr
dfa = dfa.withColumn("abc", expr("games_played + device_id"))  # ✅ Works: expr() converts the string into a Column

dfa = dfa.withColumn("abc", dfa["games_played"] + dfa["device_id"])  # ✅ Works: built from DataFrame column operations
```
| Method          | Can Use String Directly? | Requires `expr()` or Column Object? | Notes                                                                 |
|-----------------|--------------------------|-------------------------------------|-----------------------------------------------------------------------|
| `filter()` / `where()` | ✅ Yes                 | ❌ No                              | Accepts SQL expressions as strings. `expr()` is optional.             |
| `select()`      | ❌ No                    | ✅ Yes                             | Strings are read as column names only; for SQL expressions use `expr()` or `selectExpr()`. |
| `withColumn()`  | ❌ No                    | ✅ Yes                             | Always requires a Column object, use `expr()` if needed.              |
| `groupBy().agg()` | ❌ No                    | ✅ Yes                             | `agg()` takes Column expressions (e.g. `avg("x")`) or a `{column: function}` dict, not SQL strings. |
| `orderBy()` / `sort()` | ✅ Yes                 | ❌ No                              | Strings (column names) are allowed, but complex expressions need `expr()` |


#### Summary: `withColumn(colName, col)` takes the new column's name as a string and its definition as a Column object, whereas `filter()` accepts either a Column or a SQL-style expression string, which it converts to a Column automatically.

# Problem 20

Write a solution to report the fraction of players that logged in again on the day after the day they first logged in, rounded to 2 decimal places. In other words, you need to count the number of players that logged in for at least two consecutive days starting from their first login date, then divide that number by the total number of players.

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType
from datetime import datetime

# Fetch the Spark session, starting one only if none is running yet.
spark = SparkSession.builder.appName("Create DataFrames").getOrCreate()

# Activity table layout; every column is declared nullable.
activity_schema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("player_id", IntegerType()),
        ("device_id", IntegerType()),
        ("event_date", DateType()),
        ("games_played", IntegerType()),
    )
])

# Sample rows as (player_id, device_id, event_date, games_played).
# Dates are written as ISO strings here and parsed in the next step.
activity_rows = [
    (1, 2, "2016-03-01", 5),
    (1, 2, "2016-03-02", 6),
    (2, 3, "2017-06-25", 1),
    (3, 1, "2016-03-02", 0),
    (3, 4, "2018-07-03", 5),
]
activity_data = [
    (player_id, device_id, datetime.strptime(day, "%Y-%m-%d"), games_played)
    for player_id, device_id, day, games_played in activity_rows
]

# Build the DataFrame and print it as a sanity check.
activity_df = spark.createDataFrame(activity_data, schema=activity_schema)
activity_df.show()

+---------+---------+----------+------------+
|player_id|device_id|event_date|games_played|
+---------+---------+----------+------------+
|        1|        2|2016-03-01|           5|
|        1|        2|2016-03-02|           6|
|        2|        3|2017-06-25|           1|
|        3|        1|2016-03-02|           0|
|        3|        4|2018-07-03|           5|
+---------+---------+----------+------------+



In [3]:
from pyspark.sql.functions import date_sub,date_add,lead,lag,rank,expr,avg,round
from pyspark.sql.window import Window

# Each player's logins, earliest first.
window_spec = Window.partitionBy("player_id").orderBy("event_date")

dfa = (
    activity_df
    # Position of each login within the player's history (1 = first login).
    .withColumn("player_date_rank", rank().over(window_spec))
    # The calendar day right after this login ...
    .withColumn("next_day_in_calendar", date_add("event_date", 1))
    # ... and the date of the player's next recorded login (NULL if none).
    .withColumn("next_day_player", lead("event_date").over(window_spec))
    # Only the first login of each player matters for the metric.
    .filter(expr("player_date_rank = 1"))
    # 1 when the next login falls exactly on the following day, else 0.
    .withColumn(
        "is_next_day_logged",
        expr("case when next_day_in_calendar = next_day_player then 1 else 0 end "),
    )
)

# There is one row per player at this point, so the mean of the 0/1 flag is
# the requested fraction.
dfa.agg(round(avg("is_next_day_logged"),2).alias("fraction")).show()

+--------+
|fraction|
+--------+
|    0.33|
+--------+



# Problem 21

Write a solution to report the customer ids from the Customer table that bought all the products in the Product table.

Return the result table in any order.

In [4]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Fetch the Spark session, starting one only if none is running yet.
spark = SparkSession.builder.appName("Create DataFrames").getOrCreate()

# Customer table: one row per purchase, so a (customer, product) pair may repeat.
customer_schema = StructType([
    StructField(column_name, IntegerType(), True)
    for column_name in ("customer_id", "product_key")
])
customer_data = [(1, 5), (1, 5), (2, 6), (3, 5), (3, 6), (1, 6)]
customer_df = spark.createDataFrame(customer_data, schema=customer_schema)

# Product table: the full catalogue, one key per row.
product_schema = StructType([StructField("product_key", IntegerType(), True)])
product_data = [(5,), (6,)]
product_df = spark.createDataFrame(product_data, schema=product_schema)

# Print Customer first, then Product.
for table in (customer_df, product_df):
    table.show()

+-----------+-----------+
|customer_id|product_key|
+-----------+-----------+
|          1|          5|
|          1|          5|
|          2|          6|
|          3|          5|
|          3|          6|
|          1|          6|
+-----------+-----------+

+-----------+
|product_key|
+-----------+
|          5|
|          6|
+-----------+



In [5]:
# expr is imported here as well, so the cell does not depend on an earlier
# cell having brought it into scope.
from pyspark.sql.functions import sum,count,col,expr

# Distinct products bought per customer (repeat purchases are collapsed first).
dfc = customer_df.distinct().groupBy("customer_id").agg(count("*").alias("products"))
dfc.show()

# Size of the product catalogue, as a one-row DataFrame.
dfp = product_df.agg(count('*').alias("total_prod_count"))
dfp.show()

# Keep the customers whose distinct-product count equals the catalogue size.
# The same condition written with Column objects would be
#   col("dfc.products") == col("dfp.total_prod_count")
# expr() lets it be written as a short SQL string instead.
res = dfp.alias("dfp").join(dfc.alias("dfc"),expr("dfc.products = dfp.total_prod_count"),how="inner")

res.select("dfc.customer_id").show()

+-----------+--------+
|customer_id|products|
+-----------+--------+
|          1|       2|
|          3|       2|
|          2|       1|
+-----------+--------+

+----------------+
|total_prod_count|
+----------------+
|               2|
+----------------+

+-----------+
|customer_id|
+-----------+
|          1|
|          3|
+-----------+



# Problem 22

Find all numbers that appear at least three times consecutively.

In [6]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

# Fetch the Spark session, starting one only if none is running yet.
spark = SparkSession.builder.appName("PracticePySpark").getOrCreate()

# Logs table: two nullable integer columns, `id` and `num`.
logs_schema = StructType([
    StructField(column_name, IntegerType(), True) for column_name in ("id", "num")
])

# Sample rows: ids run 1..7, paired in order with the values listed here.
logs_data = list(enumerate([1, 1, 1, 2, 1, 2, 2], start=1))

# Build the DataFrame and print it as a sanity check.
logs_df = spark.createDataFrame(data=logs_data, schema=logs_schema)
logs_df.show()

+---+---+
| id|num|
+---+---+
|  1|  1|
|  2|  1|
|  3|  1|
|  4|  2|
|  5|  1|
|  6|  2|
|  7|  2|
+---+---+



In [7]:
from pyspark.sql.window import Window
# col is imported here as well, so the cell does not depend on an earlier cell.
from pyspark.sql.functions import col,lag,lead

dfl = logs_df
# The window has no partitionBy, so it spans the whole table in id order.
window_spec = Window.orderBy("id")

# Values one and two rows ahead of the current row (NULL near the end of the table).
next_num = lead("num",1).over(window_spec)
after_next_num = lead("num",2).over(window_spec)

# True on a row that starts a run of three equal values.
res = dfl.withColumn("is_next_two_cols_same",\
            (next_num == after_next_num) & (next_num == col("num"))\
                    )
res.filter("is_next_two_cols_same = TRUE").show()

# The answer is the set of numbers, not the flagged rows. A run longer than
# three flags several rows with the same num, so de-duplicate.
res.filter("is_next_two_cols_same = TRUE")\
    .select(col("num").alias("ConsecutiveNums")).distinct().show()

+---+---+---------------------+
| id|num|is_next_two_cols_same|
+---+---+---------------------+
|  1|  1|                 true|
+---+---+---------------------+



# Problem 23

Write a solution to find the prices of all products on 2019-08-16. Assume the price of all products before any change is 10.

In [8]:
import pyspark.sql.functions as F
# StringType is used by the schema below, so it is imported here rather than
# relying on an earlier cell.
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType

# Sample price changes as (product_id, new_price, change_date as an ISO string).
data = [
    (1, 20, "2019-08-14"),
    (2, 50, "2019-08-14"),
    (1, 30, "2019-08-15"),
    (1, 35, "2019-08-16"),
    (2, 65, "2019-08-17"),
    (3, 20, "2019-08-18"),
    (3, 20, "2019-08-28")
]

# The date is loaded as a string first and converted to a date afterwards.
schema = StructType([
    StructField("product_id", IntegerType(), True),
    StructField("new_price", IntegerType(), True),
    StructField("change_date_str", StringType(), True)
])

# Create the DataFrame
df = spark.createDataFrame(data, schema)

# Parse the string into a real date column and drop the raw string.
df = df.withColumn("change_date", F.to_date(F.col("change_date_str"),"yyyy-MM-dd")).drop("change_date_str")

# Show the DataFrame again to verify
df.show()

+----------+---------+-----------+
|product_id|new_price|change_date|
+----------+---------+-----------+
|         1|       20| 2019-08-14|
|         2|       50| 2019-08-14|
|         1|       30| 2019-08-15|
|         1|       35| 2019-08-16|
|         2|       65| 2019-08-17|
|         3|       20| 2019-08-18|
|         3|       20| 2019-08-28|
+----------+---------+-----------+



In [9]:
from pyspark.sql.window import Window
# expr is imported here as well, so the cell does not depend on an earlier cell.
from pyspark.sql.functions import col,rank,expr

# Price changes of each product, most recent first.
window_spec = Window.partitionBy("product_id").orderBy(col("change_date").desc())

# Latest price change per product on or before the target date.
df_nearest_date = (
    df.filter("change_date <= '2019-08-16'")
    .withColumn("rank_date",rank().over(window_spec))
    .filter("rank_date = 1")
)
df_nearest_date.show()

# A product whose first change comes after the target date has no row above.
# Joining back to the full list of product ids keeps it, with a NULL new_price.
res = df_nearest_date.alias("df_nearest_date")\
    .join(df.select("product_id").distinct().alias("df"),col("df.product_id")==col("df_nearest_date.product_id"),"outer")

# A NULL price means "never changed yet", which falls back to the default of 10.
res.select("df.product_id",expr("coalesce(new_price,10) as price")).show()

+----------+---------+-----------+---------+
|product_id|new_price|change_date|rank_date|
+----------+---------+-----------+---------+
|         1|       35| 2019-08-16|        1|
|         2|       50| 2019-08-14|        1|
+----------+---------+-----------+---------+

+----------+-----+
|product_id|price|
+----------+-----+
|         1|   35|
|         2|   50|
|         3|   10|
+----------+-----+



# Problem 24

There is a queue of people waiting to board a bus. However, the bus has a weight limit of 1000 kilograms, so there may be some people who cannot board.

Write a solution to find the person_name of the last person that can fit on the bus without exceeding the weight limit. The test cases are generated such that the first person does not exceed the weight limit.

In [10]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Fetch the Spark session, starting one only if none is running yet.
spark = SparkSession.builder.appName("QueueTable").getOrCreate()

# Queue table layout; every column is declared non-nullable.
queue_schema = StructType([
    StructField(column_name, column_type, False)
    for column_name, column_type in (
        ("person_id", IntegerType()),
        ("person_name", StringType()),
        ("weight", IntegerType()),
        ("turn", IntegerType()),
    )
])

# Sample rows as (person_id, person_name, weight, turn); `turn` is the boarding order.
data = [
    (5, "Alice", 250, 1),
    (4, "Bob", 175, 5),
    (3, "Alex", 350, 2),
    (6, "John Cena", 400, 3),
    (1, "Winston", 500, 6),
    (2, "Marie", 200, 4),
]

# Build the DataFrame and print it as a sanity check.
queue_df = spark.createDataFrame(data, schema=queue_schema)
queue_df.show()


+---------+-----------+------+----+
|person_id|person_name|weight|turn|
+---------+-----------+------+----+
|        5|      Alice|   250|   1|
|        4|        Bob|   175|   5|
|        3|       Alex|   350|   2|
|        6|  John Cena|   400|   3|
|        1|    Winston|   500|   6|
|        2|      Marie|   200|   4|
+---------+-----------+------+----+



In [11]:
from pyspark.sql.window import Window
from pyspark.sql.functions import sum,lit

# Frame covering every row from the start of the queue up to the current one,
# in boarding order.
window_spec = Window.orderBy("turn").rowsBetween(Window.unboundedPreceding,Window.currentRow)

# Running total of weight as people board.
dfq = queue_df.withColumn("cumulative_sum",sum("weight").over(window_spec))
dfq.show()

# Everyone who still fits under the 1000 kg limit.
dfq = dfq.filter("cumulative_sum <= 1000")
dfq.show()

# Turn number of the last person who fits, pulled back to the driver as a plain value.
max_turn = dfq.selectExpr("max(turn)").first()[0]
print(max_turn)

dfq.filter(f"turn = {max_turn}").select("person_name").show()

+---------+-----------+------+----+--------------+
|person_id|person_name|weight|turn|cumulative_sum|
+---------+-----------+------+----+--------------+
|        5|      Alice|   250|   1|           250|
|        3|       Alex|   350|   2|           600|
|        6|  John Cena|   400|   3|          1000|
|        2|      Marie|   200|   4|          1200|
|        4|        Bob|   175|   5|          1375|
|        1|    Winston|   500|   6|          1875|
+---------+-----------+------+----+--------------+

+---------+-----------+------+----+--------------+
|person_id|person_name|weight|turn|cumulative_sum|
+---------+-----------+------+----+--------------+
|        5|      Alice|   250|   1|           250|
|        3|       Alex|   350|   2|           600|
|        6|  John Cena|   400|   3|          1000|
+---------+-----------+------+----+--------------+

3
+-----------+
|person_name|
+-----------+
|  John Cena|
+-----------+



# Problem 25

Write a solution to swap the seat id of every two consecutive students. If the number of students is odd, the id of the last student is not swapped.

In [12]:
# Import the schema types here, as the other setup cells do, rather than
# relying on an earlier cell.
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Seat table layout; both columns are declared non-nullable.
seat_schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("student", StringType(), False)
])

# Sample rows as (seat id, student name).
seat_data = [
    (1, "Abbot"),
    (2, "Doris"),
    (3, "Emerson"),
    (4, "Green"),
    (5, "Jeames")
]

# `spark` is the session created at the top of the notebook.
seat_df = spark.createDataFrame(seat_data, schema=seat_schema)

# Show DataFrame
seat_df.show()


+---+-------+
| id|student|
+---+-------+
|  1|  Abbot|
|  2|  Doris|
|  3|Emerson|
|  4|  Green|
|  5| Jeames|
+---+-------+



In [13]:
from pyspark.sql.window import Window
from pyspark.sql.functions import col,lag,lead,when,coalesce

dfs = seat_df

# Seats in id order; no partitioning, so the neighbours are the adjacent ids.
window_spec = Window.orderBy("id")

# Student in the previous seat and in the next seat (NULL at either end).
previous_student = lag("student").over(window_spec)
following_student = lead("student").over(window_spec)

# Even ids take the student from the seat before them. Odd ids take the
# student from the seat after them, or keep their own student when there is
# no next seat (an odd-sized class).
swapped = when(col("id")%2==0,col("lag")).otherwise(coalesce(col("lead"),col("student")))

res = (
    dfs
    .withColumn("lag",previous_student)
    .withColumn("lead",following_student)
    .withColumn("swapped_names",swapped)
)
res.select("id","swapped_names").show()

+---+-------------+
| id|swapped_names|
+---+-------------+
|  1|        Doris|
|  2|        Abbot|
|  3|        Green|
|  4|      Emerson|
|  5|       Jeames|
+---+-------------+



# Problem 26

Find the name of the user who has rated the greatest number of movies. In case of a tie, return the lexicographically smaller user name.

Find the movie name with the highest average rating in February 2020. In case of a tie, return the lexicographically smaller movie name.

In [14]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType
import datetime

# Fetch the Spark session, starting one only if none is running yet.
spark = SparkSession.builder.appName("SQLToPySpark").getOrCreate()

# Movies table: (movie_id, title), both non-nullable.
movies_schema = StructType([
    StructField("movie_id", IntegerType(), False),
    StructField("title", StringType(), False),
])
movies_data = [(1, "Avengers"), (2, "Frozen 2"), (3, "Joker")]
movies_df = spark.createDataFrame(movies_data, schema=movies_schema)

# Users table: (user_id, name), both non-nullable.
users_schema = StructType([
    StructField("user_id", IntegerType(), False),
    StructField("name", StringType(), False),
])
users_data = [(1, "Daniel"), (2, "Monica"), (3, "Maria"), (4, "James")]
users_df = spark.createDataFrame(users_data, schema=users_schema)

# MovieRating table: (movie_id, user_id, rating, created_at), all non-nullable.
movie_ratings_schema = StructType([
    StructField("movie_id", IntegerType(), False),
    StructField("user_id", IntegerType(), False),
    StructField("rating", IntegerType(), False),
    StructField("created_at", DateType(), False),
])
# Dates are written as ISO strings here and parsed in the next step.
rating_rows = [
    (1, 1, 3, "2020-01-12"),
    (1, 2, 4, "2020-02-11"),
    (1, 3, 2, "2020-02-12"),
    (1, 4, 1, "2020-01-01"),
    (2, 1, 5, "2020-02-17"),
    (2, 2, 2, "2020-02-01"),
    (2, 3, 2, "2020-03-01"),
    (3, 1, 3, "2020-02-22"),
    (3, 2, 4, "2020-02-25"),
]
movie_ratings_data = [
    (movie_id, user_id, rating, datetime.datetime.strptime(day, "%Y-%m-%d"))
    for movie_id, user_id, rating, day in rating_rows
]
movie_ratings_df = spark.createDataFrame(movie_ratings_data, schema=movie_ratings_schema)

# Print the three tables as a sanity check (optional).
for table in (movies_df, users_df, movie_ratings_df):
    table.show()

+--------+--------+
|movie_id|   title|
+--------+--------+
|       1|Avengers|
|       2|Frozen 2|
|       3|   Joker|
+--------+--------+

+-------+------+
|user_id|  name|
+-------+------+
|      1|Daniel|
|      2|Monica|
|      3| Maria|
|      4| James|
+-------+------+

+--------+-------+------+----------+
|movie_id|user_id|rating|created_at|
+--------+-------+------+----------+
|       1|      1|     3|2020-01-12|
|       1|      2|     4|2020-02-11|
|       1|      3|     2|2020-02-12|
|       1|      4|     1|2020-01-01|
|       2|      1|     5|2020-02-17|
|       2|      2|     2|2020-02-01|
|       2|      3|     2|2020-03-01|
|       3|      1|     3|2020-02-22|
|       3|      2|     4|2020-02-25|
+--------+-------+------+----------+



In [15]:
from pyspark.sql.window import Window
# col is imported here as well, so the cell does not depend on an earlier cell.
from pyspark.sql.functions import col,rank,month,year,lit,avg

dfm = movies_df
dfu = users_df
dfmr = movie_ratings_df

## User who rated the greatest number of movies
# Ratings per user, joined to Users to pick up the names.
dfmr_1 = dfmr.groupBy("user_id").count()
dfmr_res = dfmr_1.alias("dfmr_1").join(dfu.alias("dfu"),col("dfmr_1.user_id")==col("dfu.user_id"),"inner")
dfmr_res.select("*").show()

# rank() gives every user tied on the top count rank 1; ordering by name and
# taking one row then picks the lexicographically smallest of them.
window_spec = Window.orderBy(col("count").desc())
dfmr_res = dfmr_res.withColumn("rank",rank().over(window_spec))
dfmr_res = dfmr_res.filter("rank=1").orderBy("name").limit(1).select("name")
dfmr_res.show()

window_spec = Window.orderBy(col("avg_rating").desc())

## Highest rated movie for the month of Feb, 2020
dfmr_2 = dfmr.filter((month("created_at")==2) & (year("created_at")==2020) )
dfmr_2 = dfmr_2.groupBy("movie_id").agg(avg("rating").alias("avg_rating"))
dfmr_2 = dfmr_2.alias("dfmr_2").join(dfm.alias("dfm"),col("dfm.movie_id")==col("dfmr_2.movie_id"),"inner")
# Same tie-break as above, this time on the movie title.
dfmr_2 = dfmr_2.withColumn("rank",rank().over(window_spec))
dfmr_2 = dfmr_2.filter("rank=1").orderBy("title").limit(1).select("title")
dfmr_2.show()

# Stack the two one-row answers; the combined column keeps the first frame's name.
dfmr_res.union(dfmr_2).show()

+-------+-----+-------+------+
|user_id|count|user_id|  name|
+-------+-----+-------+------+
|      1|    3|      1|Daniel|
|      2|    3|      2|Monica|
|      3|    2|      3| Maria|
|      4|    1|      4| James|
+-------+-----+-------+------+

+------+
|  name|
+------+
|Daniel|
+------+

+--------+
|   title|
+--------+
|Frozen 2|
+--------+

+--------+
|    name|
+--------+
|  Daniel|
|Frozen 2|
+--------+



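# Problem 27

Write a solution to find the user who has the most friends, and how many friends they have. Every accepted request counts as one friend for both the requester and the accepter.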

In [16]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, DateType
from datetime import datetime


# Accepted-request table layout; every column is declared nullable.
schema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("requester_id", IntegerType()),
        ("accepter_id", IntegerType()),
        ("accept_date", DateType()),
    )
])

# Sample rows as (requester_id, accepter_id, accept_date).
# Dates are written as slash-separated strings here and parsed in the next step.
request_rows = [
    (1, 2, "2016/06/03"),
    (1, 3, "2016/06/08"),
    (2, 3, "2016/06/08"),
    (3, 4, "2016/06/09"),
]
data = [
    (requester_id, accepter_id, datetime.strptime(day, "%Y/%m/%d"))
    for requester_id, accepter_id, day in request_rows
]

# Build the DataFrame and print it as a sanity check.
df = spark.createDataFrame(data, schema=schema)
df.show()


+------------+-----------+-----------+
|requester_id|accepter_id|accept_date|
+------------+-----------+-----------+
|           1|          2| 2016-06-03|
|           1|          3| 2016-06-08|
|           2|          3| 2016-06-08|
|           3|          4| 2016-06-09|
+------------+-----------+-----------+



In [17]:
# col and rank are imported here as well, so the cell does not depend on
# earlier cells.
from pyspark.sql.functions import coalesce,expr,col,rank
from pyspark.sql.window import Window

# Rows per requester, i.e. how many accepted requests each user sent.
df_r = df.groupBy("requester_id").count()
df_r.show()

# Rows per accepter, i.e. how many requests each user accepted.
df_a = df.groupBy("accepter_id").count()
df_a.show()

# Full outer join, so users who appear on only one side are kept as well.
df_res = df_r.alias("df_r").join(df_a.alias("df_a"),col("requester_id")==col("accepter_id"),"outer")
df_res.show()

# One id per user (whichever side is present), and the two counts added up
# with a missing side treated as 0. Both frames have a column named `count`,
# hence the alias-qualified names.
df_res = df_res.withColumn("user_id",coalesce("df_r.requester_id","df_a.accepter_id"))\
                .withColumn("total_friends",expr("coalesce(df_r.count,0)+coalesce(df_a.count,0)"))
df_res.show()

df_res = df_res.select("user_id","total_friends")
df_res.show()

# Rank by friend count; rank 1 is the user (or tied users) with the most friends.
window_spec = Window.orderBy(col("total_friends").desc())

df_res.withColumn("rank",rank().over(window_spec)).filter("rank = 1").show()

+------------+-----+
|requester_id|count|
+------------+-----+
|           1|    2|
|           2|    1|
|           3|    1|
+------------+-----+

+-----------+-----+
|accepter_id|count|
+-----------+-----+
|          2|    1|
|          3|    2|
|          4|    1|
+-----------+-----+

+------------+-----+-----------+-----+
|requester_id|count|accepter_id|count|
+------------+-----+-----------+-----+
|           1|    2|       NULL| NULL|
|           2|    1|          2|    1|
|           3|    1|          3|    2|
|        NULL| NULL|          4|    1|
+------------+-----+-----------+-----+

+------------+-----+-----------+-----+-------+-------------+
|requester_id|count|accepter_id|count|user_id|total_friends|
+------------+-----+-----------+-----+-------+-------------+
|           1|    2|       NULL| NULL|      1|            2|
|           2|    1|          2|    1|      2|            2|
|           3|    1|          3|    2|      3|            3|
|        NULL| NULL|          4|

# Problem 28

A company's executives are interested in seeing who earns the most money in each of the company's departments. A high earner in a department is an employee who has a salary in the top three unique salaries for that department.

Write a solution to find the employees who are high earners in each of the departments.

In [18]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Fetch the Spark session, starting one only if none is running yet.
spark = SparkSession.builder.appName("Practice").getOrCreate()

# Employee table layout; every column is declared non-nullable.
employee_schema = StructType([
    StructField(column_name, column_type, False)
    for column_name, column_type in (
        ("id", IntegerType()),
        ("name", StringType()),
        ("salary", IntegerType()),
        ("departmentId", IntegerType()),
    )
])

# Department table layout; both columns are declared non-nullable.
department_schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), False),
])

# Sample employees as (id, name, salary, departmentId).
employee_data = [
    (1, "Joe", 85000, 1),
    (2, "Henry", 80000, 2),
    (3, "Sam", 60000, 2),
    (4, "Max", 90000, 1),
    (5, "Janet", 69000, 1),
    (6, "Randy", 85000, 1),
    (7, "Will", 70000, 1),
]

# Sample departments as (id, name).
department_data = [(1, "IT"), (2, "Sales")]

employee_df = spark.createDataFrame(employee_data, schema=employee_schema)
department_df = spark.createDataFrame(department_data, schema=department_schema)

# Print Employee first, then Department.
for table in (employee_df, department_df):
    table.show()


+---+-----+------+------------+
| id| name|salary|departmentId|
+---+-----+------+------------+
|  1|  Joe| 85000|           1|
|  2|Henry| 80000|           2|
|  3|  Sam| 60000|           2|
|  4|  Max| 90000|           1|
|  5|Janet| 69000|           1|
|  6|Randy| 85000|           1|
|  7| Will| 70000|           1|
+---+-----+------+------------+

+---+-----+
| id| name|
+---+-----+
|  1|   IT|
|  2|Sales|
+---+-----+



In [19]:
# col is imported here as well, so the cell does not depend on an earlier cell.
from pyspark.sql.functions import col,dense_rank
from pyspark.sql.window import Window

dfe = employee_df
dfd = department_df

# Salaries within each department, highest first.
window_spec = Window.partitionBy("departmentId").orderBy(col("salary").desc())

# dense_rank() gives equal salaries the same rank and leaves no gaps, so
# ranks 1-3 are exactly the top three unique salaries of a department.
dfe_1 = dfe.withColumn("dense_rank",dense_rank().over(window_spec))
dfe_1.show()

# Keep the high earners and attach their department row.
dfe_1  = dfe_1.filter("dense_rank <= 3")
dfe_1 = dfe_1.alias("dfe_1").join(dfd.alias("dfd"),col("dfd.id")==col("dfe_1.departmentId"),"inner")
dfe_1.show()

+---+-----+------+------------+----------+
| id| name|salary|departmentId|dense_rank|
+---+-----+------+------------+----------+
|  4|  Max| 90000|           1|         1|
|  1|  Joe| 85000|           1|         2|
|  6|Randy| 85000|           1|         2|
|  7| Will| 70000|           1|         3|
|  5|Janet| 69000|           1|         4|
|  2|Henry| 80000|           2|         1|
|  3|  Sam| 60000|           2|         2|
+---+-----+------+------------+----------+

+---+-----+------+------------+----------+---+-----+
| id| name|salary|departmentId|dense_rank| id| name|
+---+-----+------+------------+----------+---+-----+
|  7| Will| 70000|           1|         3|  1|   IT|
|  6|Randy| 85000|           1|         2|  1|   IT|
|  1|  Joe| 85000|           1|         2|  1|   IT|
|  4|  Max| 90000|           1|         1|  1|   IT|
|  3|  Sam| 60000|           2|         2|  2|Sales|
|  2|Henry| 80000|           2|         1|  2|Sales|
+---+-----+------+------------+----------+---+---

# Problem 29

Write a solution to report the sum of all total investment values in 2016 tiv_2016, for all policyholders who:

i. have the same tiv_2015 value as one or more other policyholders, and

ii. are not located in the same city as any other policyholder (i.e., the (lat, lon) attribute pairs must be unique).

In [20]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType

# Fetch the Spark session, starting one only if none is running yet.
spark = SparkSession.builder.appName("InsuranceTable").getOrCreate()

# Insurance table layout: an integer policy id followed by four double
# columns, all declared nullable.
schema = StructType(
    [StructField("pid", IntegerType(), True)]
    + [
        StructField(column_name, DoubleType(), True)
        for column_name in ("tiv_2015", "tiv_2016", "lat", "lon")
    ]
)

# Sample rows as (pid, tiv_2015, tiv_2016, lat, lon).
data = [
    (1, 10.0, 5.0, 10.0, 10.0),
    (2, 20.0, 20.0, 20.0, 20.0),
    (3, 10.0, 30.0, 20.0, 20.0),
    (4, 10.0, 40.0, 40.0, 40.0),
]

# Build the DataFrame and print it as a sanity check.
df = spark.createDataFrame(data, schema)
df.show()


+---+--------+--------+----+----+
|pid|tiv_2015|tiv_2016| lat| lon|
+---+--------+--------+----+----+
|  1|    10.0|     5.0|10.0|10.0|
|  2|    20.0|    20.0|20.0|20.0|
|  3|    10.0|    30.0|20.0|20.0|
|  4|    10.0|    40.0|40.0|40.0|
+---+--------+--------+----+----+



In [38]:
# Namespaced import: the cell no longer depends on earlier cells for
# count/col/sum, nor on the builtin sum being shadowed.
from pyspark.sql import functions as F

# Number of policies at each (lat, lon) location.
df_1 = df.groupBy("lat","lon").agg(F.count("*").alias("lat_long_unique"))
df_1.show()

# Condition ii: locations used by exactly one policy.
df_1 = df_1.filter("lat_long_unique = 1")
df_1.show()

# Condition i: tiv_2015 values shared by more than one policy.
df_2 = df.groupBy("tiv_2015").agg(F.count("*").alias("tiv_2015_count"))
df_2 = df_2.filter("tiv_2015_count>1")
df_2.show()

# Attach both lookups to every policy. A policy that fails a condition gets
# NULLs from that side of the left join.
df_3 = df.alias("df").join(df_1.alias("df_1")\
                           ,( F.col("df_1.lat")==F.col("df.lat")) & (F.col("df_1.lon")==F.col("df.lon")) \
                          ,"left")\
                    .join(df_2.alias("df_2")\
                          ,(F.col("df_2.tiv_2015")==F.col("df.tiv_2015"))\
                         ,"left")

# The NULLs fail this filter, so only policies meeting both conditions remain.
df_3 = df_3.filter("tiv_2015_count>1 AND lat_long_unique = 1")
df_3 = df_3.agg(F.sum("tiv_2016").alias("tiv_2016_sum"))
df_3.show()


+----+----+---------------+
| lat| lon|lat_long_unique|
+----+----+---------------+
|10.0|10.0|              1|
|20.0|20.0|              2|
|40.0|40.0|              1|
+----+----+---------------+

+----+----+---------------+
| lat| lon|lat_long_unique|
+----+----+---------------+
|10.0|10.0|              1|
|40.0|40.0|              1|
+----+----+---------------+

+--------+--------------+
|tiv_2015|tiv_2015_count|
+--------+--------------+
|    10.0|             3|
+--------+--------------+

+------------+
|tiv_2016_sum|
+------------+
|        45.0|
+------------+

