## Importing Libraries

In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window

**2153. The Number of Passengers in Each Bus II (Hard)**

**Table: Buses**

| Column Name  | Type |
|--------------|------|
| bus_id       | int  |
| arrival_time | int  |
| capacity     | int  |

bus_id contains unique values.
Each row of this table contains information about the arrival time of a bus at the LeetCode station and its capacity (the number of empty seats it has).
No two buses will arrive at the same time and all bus capacities will be positive integers.
 
**Table: Passengers**

| Column Name  | Type |
|--------------|------|
| passenger_id | int  |
| arrival_time | int  |

passenger_id contains unique values.
Each row of this table contains information about the arrival time of a passenger at the LeetCode station.
 
Buses and passengers arrive at the LeetCode station. If a bus arrives at the station at time t_bus and a passenger arrived at time t_passenger, where t_passenger <= t_bus, and the passenger has not yet caught a bus, the passenger will use that bus. In addition, each bus has a capacity. If more passengers are waiting when the bus arrives than its capacity allows, only capacity passengers will use the bus.

**Write a solution to report the number of passengers that used each bus.**

Return the result table ordered by bus_id in ascending order.

The result format is in the following example.

**Example 1:**

**Input:** 

**Buses table:**

| bus_id | arrival_time | capacity |
|--------|--------------|----------|
| 1      | 2            | 1        |
| 2      | 4            | 10       |
| 3      | 7            | 2        |

**Passengers table:**

| passenger_id | arrival_time |
|--------------|--------------|
| 11           | 1            |
| 12           | 1            |
| 13           | 5            |
| 14           | 6            |
| 15           | 7            |

**Output:**

| bus_id | passengers_cnt |
|--------|----------------|
| 1      | 1              |
| 2      | 1              |
| 3      | 2              |

**Explanation:** 
- Passenger 11 arrives at time 1.
- Passenger 12 arrives at time 1.
- Bus 1 arrives at time 2 and collects passenger 11 as it has one empty seat.
- Bus 2 arrives at time 4 and collects passenger 12 as it has ten empty seats.
- Passenger 13 arrives at time 5.
- Passenger 14 arrives at time 6.
- Passenger 15 arrives at time 7.
- Bus 3 arrives at time 7 and collects passengers 13 and 14 as it has two empty seats.
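
The boarding rules can be checked without Spark using a small queue simulation. The sketch below is illustrative only: the function name `simulate_boarding` and the tuple layouts are assumptions, not part of the problem statement. It assumes waiting passengers board in arrival order; the problem only asks for counts, so the tie-break on `passenger_id` does not affect the result.

```python
from collections import deque


def simulate_boarding(buses, passengers):
    """Return {bus_id: number of passengers boarded}.

    buses:      iterable of (bus_id, arrival_time, capacity)
    passengers: iterable of (passenger_id, arrival_time)
    """
    # Passengers who have not reached the station yet, earliest first.
    pending = deque(sorted(passengers, key=lambda p: (p[1], p[0])))
    # Passengers at the station who have not caught a bus yet.
    waiting = deque()
    counts = {}

    for bus_id, bus_time, capacity in sorted(buses, key=lambda b: b[1]):
        # Everyone who arrived at or before the bus joins the queue.
        while pending and pending[0][1] <= bus_time:
            waiting.append(pending.popleft())

        # The bus takes as many waiting passengers as it has seats for.
        boarded = min(capacity, len(waiting))
        for _ in range(boarded):
            waiting.popleft()
        counts[bus_id] = boarded

    # Report in ascending bus_id order, as the problem requires.
    return dict(sorted(counts.items()))


buses = [(1, 2, 1), (2, 4, 10), (3, 7, 2)]
passengers = [(11, 1), (12, 1), (13, 5), (14, 6), (15, 7)]
print(simulate_boarding(buses, passengers))  # {1: 1, 2: 1, 3: 2}
```

A loop like this is a convenient reference for validating a set-based solution on small inputs, since it follows the rules step by step.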

In [0]:
buses_data_2153 = [
    (1, 2, 1),
    (2, 4, 10),
    (3, 7, 2)
]

buses_columns_2153 = ["bus_id", "arrival_time", "capacity"]
buses_df_2153 = spark.createDataFrame(buses_data_2153, buses_columns_2153)
buses_df_2153.show()

passengers_data_2153 = [
    (11, 1),
    (12, 1),
    (13, 5),
    (14, 6),
    (15, 7)
]

passengers_columns_2153 = ["passenger_id", "arrival_time"]
passengers_df_2153 = spark.createDataFrame(passengers_data_2153, passengers_columns_2153)
passengers_df_2153.show()

+------+------------+--------+
|bus_id|arrival_time|capacity|
+------+------------+--------+
|     1|           2|       1|
|     2|           4|      10|
|     3|           7|       2|
+------+------------+--------+

+------------+------------+
|passenger_id|arrival_time|
+------------+------------+
|          11|           1|
|          12|           1|
|          13|           5|
|          14|           6|
|          15|           7|
+------------+------------+



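In [0]:
# Order passengers by arrival time; passenger_id breaks ties deterministically.
w_pass = Window.orderBy(col("arrival_time"), col("passenger_id"))

In [0]:
# passenger_rank is each passenger's position in the arrival queue (1 = first to arrive).
passengers_ranked_df_2153 = passengers_df_2153 \
    .withColumn("passenger_rank", row_number().over(w_pass)) \
    .withColumnRenamed("arrival_time", "passenger_arrival_time")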



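In [0]:
# Running frame: every bus from the earliest arrival up to the current one.
w_bus = Window.orderBy("arrival_time").rowsBetween(Window.unboundedPreceding, 0)

In [0]:
# cum_capacity is the total number of seats offered by this bus and all earlier buses.
# Note: sum here is pyspark.sql.functions.sum (the wildcard import shadows the builtin).
buses_cumulative_df_2153 = buses_df_2153 \
    .withColumn("cum_capacity", sum("capacity").over(w_bus))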



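In [0]:
# For each bus, count the passengers who have arrived by its arrival time.
# Ranks are consecutive, so the highest rank seen equals that count (0 if nobody has arrived).
arrived_df_2153 = buses_cumulative_df_2153 \
    .join(
        passengers_ranked_df_2153,
        col("passenger_arrival_time") <= col("arrival_time"),
        "left"
    ) \
    .groupBy("bus_id", "arrival_time", "cum_capacity") \
    .agg(coalesce(max("passenger_rank"), lit(0)).alias("cum_arrived"))



In [0]:
# Total boarded up to bus i: B_i = min(B_(i-1) + capacity_i, cum_arrived_i).
# Unrolled: B_i = cum_capacity_i + min(0, running minimum of (cum_arrived - cum_capacity)).
# Seats left empty by an earlier bus are therefore never carried over to a later one.
boarded_df_2153 = arrived_df_2153 \
    .withColumn("min_gap", min(col("cum_arrived") - col("cum_capacity")).over(w_bus)) \
    .withColumn("cum_boarded", col("cum_capacity") + least(lit(0), col("min_gap")))



In [0]:
# lag() uses its own fixed frame, so it needs a window without rowsBetween.
w_bus_order = Window.orderBy("arrival_time")

In [0]:
# Each bus carries the increase in the cumulative boarded total.
bus_counts_df_2153 = boarded_df_2153 \
    .withColumn(
        "passengers_cnt",
        col("cum_boarded") - coalesce(lag("cum_boarded").over(w_bus_order), lit(0))
    )



In [0]:
bus_counts_df_2153 \
    .orderBy("bus_id") \
    .select("bus_id", "passengers_cnt").show()



+------+--------------+
|bus_id|passengers_cnt|
+------+--------------+
|     1|             1|
|     2|             1|
|     3|             2|
+------+--------------+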
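
For a Spark-free cross-check, the per-bus counts can also be derived from running totals alone. If B_i is the total number of passengers boarded up to and including bus i, then B_i = min(B_(i-1) + capacity_i, arrived_i), which unrolls to B_i = cum_capacity_i + min(0, running minimum of (arrived - cum_capacity)). The sketch below is a plain-Python illustration of that identity; the function name `counts_from_running_totals` is hypothetical.

```python
from itertools import accumulate


def counts_from_running_totals(buses, passenger_times):
    """Return {bus_id: passengers boarded} using cumulative sums and a running minimum.

    buses:           iterable of (bus_id, arrival_time, capacity)
    passenger_times: iterable of passenger arrival times
    """
    ordered = sorted(buses, key=lambda b: b[1])
    passenger_times = list(passenger_times)

    # Seats offered by each bus and all earlier buses.
    cum_capacity = list(accumulate(cap for _, _, cap in ordered))
    # Passengers who have reached the station by each bus's arrival.
    cum_arrived = [
        sum(1 for t in passenger_times if t <= bus_time)
        for _, bus_time, _ in ordered
    ]

    result = {}
    running_min = 0   # starts at 0, which covers the min(0, ...) term
    prev_boarded = 0  # cumulative boarded total after the previous bus
    for (bus_id, _, _), cap_total, arrived in zip(ordered, cum_capacity, cum_arrived):
        running_min = min(running_min, arrived - cap_total)
        cum_boarded = cap_total + running_min
        result[bus_id] = cum_boarded - prev_boarded
        prev_boarded = cum_boarded

    return dict(sorted(result.items()))


example_buses = [(1, 2, 1), (2, 4, 10), (3, 7, 2)]
example_times = [1, 1, 5, 6, 7]
print(counts_from_running_totals(example_buses, example_times))  # {1: 1, 2: 1, 3: 2}
```

Because it needs only cumulative sums, a running minimum, and a difference between consecutive rows, this formulation maps naturally onto window functions.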

