In [2]:
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder\
    .appName("array-column-join")\
    .master("local[*]").getOrCreate()

In [6]:
import pyspark.sql.functions as F

df1 = spark.createDataFrame([(2, [3, 4]), (3, [4]),(4, [3,5]),(5, [4,5]),(6, [5,4])], ["a", "b"])\
    .select("*",F.expr("posexplode(b) as (pos, key)"))
df2 = spark.createDataFrame([(3, "Three"), (4, "Four"),(5, "Five")], ["b", "c"]).withColumnRenamed("b","key")

df3 = df1.join(df2, on="key")\
    .orderBy("a", "pos")\
    .groupBy("a").agg(
    F.first("b").alias("b"),
    F.collect_list("c").alias("c")
)

df3.show()

+---+------+-------------+
|  a|     b|            c|
+---+------+-------------+
|  6|[5, 4]| [Five, Four]|
|  5|[4, 5]| [Four, Five]|
|  3|   [4]|       [Four]|
|  2|[3, 4]|[Three, Four]|
|  4|[3, 5]|[Three, Five]|
+---+------+-------------+



In [8]:
tripDf = spark.createDataFrame([
        ("PMI",        "OPO",             [2, 1]),
        ("ATH", "BCN", [3]),
        ("JFK",        "MAD",          [5, 4, 6]),
        ("HND",        "LAX",       [8, 9, 7, 0])
    ], 
    ["origin","destination","internal_flight_id"]
)

In [9]:
flightDF = spark.createDataFrame(
    [
    (0,'FR5763'),
    (1,'UT9586'),
    (2,'B4325'),
    (3,'RW35675'),
    (4,'LP656'),
    (5,'NB4321'),
    (6,'CX4599'),
    (7,'AZ8844'),
    (8,'KH8851'),
    (9,'OP8777')
    ],
    ["internal_flight_id","public_flight_number"])

In [11]:
flightTripDF = tripDf.alias("tripDf").join(
    flightDF.alias("flightDF"),
    F.expr("array_contains(tripDf.internal_flight_id, flightDF.internal_flight_id)"),
    "left"
).groupBy("tripDf.internal_flight_id").agg(
    F.first("tripDf.internal_flight_id").alias("internal_flight_ids"),
    F.collect_list("flightDF.public_flight_number").alias("public_flight_number")
)

flightTripDF.show(truncate=False)


+------------------+-------------------+--------------------------------+
|internal_flight_id|internal_flight_ids|public_flight_number            |
+------------------+-------------------+--------------------------------+
|[2, 1]            |[2, 1]             |[UT9586, B4325]                 |
|[3]               |[3]                |[RW35675]                       |
|[5, 4, 6]         |[5, 4, 6]          |[LP656, NB4321, CX4599]         |
|[8, 9, 7, 0]      |[8, 9, 7, 0]       |[FR5763, AZ8844, KH8851, OP8777]|
+------------------+-------------------+--------------------------------+



# Example 3

In [4]:
from pyspark.sql.functions import posexplode

# Sample input data
df1 = spark.createDataFrame([(1, ["a", "b"]), (2, ["c", "d", "e"]), (3, ["f"])], ["id", "arr"])
df2 = spark.createDataFrame([("a", 10), ("b", 20), ("c", 30), ("d", 40), ("e", 50), ("f", 60)], ["key", "value"])

# Explode the array column along with its index
df1_exploded = df1.selectExpr("id", "posexplode(arr) as (pos, key)")
# Join with the mapped column using the index
joined_df = df1_exploded.join(df2, on="key").orderBy("id", "pos")
# Group by id and aggregate the values in an array
final_df = joined_df.groupBy("id").agg({"value": "collect_list"}).withColumnRenamed("collect_list(value)", "values")

final_df.show()

+---+------------+
| id|      values|
+---+------------+
|  1|    [10, 20]|
|  3|        [60]|
|  2|[30, 40, 50]|
+---+------------+

