## Import libraries

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, collect_list, explode, flatten

In [2]:
# Build the Spark session used by every cell below, or reuse one that is
# already active in this process.
spark = (
    SparkSession.builder
    .appName("read-json")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/11 18:48:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Creating a dataframe

In [3]:
# Two rows; `data` holds an array of integer pairs (array<array<bigint>>).
df = spark.createDataFrame(
    data=[
        (1, [[1, 2], [3, 4], [5, 6]]),
        (2, [[7, 8], [9, 10], [11, 12]]),
    ],
    schema=["id", "data"],
)

In [4]:
# Default display: long cell values are truncated (see the "..." in the output).
df.show(truncate=True)

+---+--------------------+
| id|                data|
+---+--------------------+
|  1|[[1, 2], [3, 4], ...|
|  2|[[7, 8], [9, 10],...|
+---+--------------------+



In [5]:
# Same rows, but with truncation disabled so the full nested arrays are visible.
df.show(n=20, truncate=False)

+---+---------------------------+
|id |data                       |
+---+---------------------------+
|1  |[[1, 2], [3, 4], [5, 6]]   |
|2  |[[7, 8], [9, 10], [11, 12]]|
+---+---------------------------+



## Using the collect_list() function to gather the `data` arrays from all rows into a single array

In [6]:
# Global aggregation (no grouping keys): every row's `data` array is gathered
# into one outer array, producing a single-row DataFrame.
collect_df = df.agg(collect_list("data").alias("data"))

collect_df.show(n=20, truncate=False)

+-------------------------------------------------------+
|data                                                   |
+-------------------------------------------------------+
|[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]|
+-------------------------------------------------------+



## Merging the array elements using the flatten() function

In [7]:
# flatten() removes exactly one level of nesting: the array of per-row arrays
# becomes a single array of pairs.
flatten_df = collect_df.select(
    flatten(collect_df["data"]).alias("merged_data")
)

In [8]:
# Display the merged array without truncating the cell contents.
flatten_df.show(n=20, truncate=False)

+---------------------------------------------------+
|merged_data                                        |
+---------------------------------------------------+
|[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]|
+---------------------------------------------------+



## Using the flatten() function on a nested array

## Create a dataframe

In [11]:
# Same shape of example as before, but `data` is nested one level deeper:
# an array of arrays of integer pairs.
df_2 = spark.createDataFrame(
    data=[
        (1, [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        (2, [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]),
    ],
    schema=["id", "data"],
)

df_2.show(n=20, truncate=False)

+---+-------------------------------------------+
|id |data                                       |
+---+-------------------------------------------+
|1  |[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]       |
|2  |[[[9, 10], [11, 12]], [[13, 14], [15, 16]]]|
+---+-------------------------------------------+



## Explode the outermost array to flatten the structure

In [20]:
# Explode the outermost array: each second-level array in `data` becomes its
# own row, paired with the `id` it came from.
# The column is selected by name, so this cell needs only `explode` in scope.
exploded_df = df_2.select("id", explode("data").alias("inner_data"))

# Default display; long cell values may be truncated.
exploded_df.show()

                    

+---+--------------------+
| id|          inner_data|
+---+--------------------+
|  1|    [[1, 2], [3, 4]]|
|  1|    [[5, 6], [7, 8]]|
|  2| [[9, 10], [11, 12]]|
|  2|[[13, 14], [15, 16]]|
+---+--------------------+



In [15]:
from pyspark.sql.functions import col, explode, collect_list, flatten

In [22]:
# One output row per second-level array of `data`, keeping the source `id`.
exploded_df = df_2.select(
    df_2["id"],
    explode(df_2["data"]).alias("inner_data"),
)

# Show full cell contents instead of the truncated default.
exploded_df.show(n=20, truncate=False)

                    

+---+--------------------+
|id |inner_data          |
+---+--------------------+
|1  |[[1, 2], [3, 4]]    |
|1  |[[5, 6], [7, 8]]    |
|2  |[[9, 10], [11, 12]] |
|2  |[[13, 14], [15, 16]]|
+---+--------------------+



## Use collect_list() to gather the inner arrays back together for each id

In [23]:
# Re-assemble the exploded rows: for each `id`, collect its `inner_data`
# arrays back into one outer array.
# NOTE(review): collect_list does not guarantee element order after a
# shuffle; the order shown here is not something to rely on.
grouped_df = (
    exploded_df
    .groupBy("id")
    .agg(collect_list("inner_data").alias("merged_data"))
)

grouped_df.show(n=20, truncate=False)

                                           

+---+-------------------------------------------+
|id |merged_data                                |
+---+-------------------------------------------+
|1  |[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]       |
|2  |[[[9, 10], [11, 12]], [[13, 14], [15, 16]]]|
+---+-------------------------------------------+



## Use flatten() to merge all elements of the inner arrays

In [24]:
# Remove one level of nesting from each row's `merged_data`.
# Only the flattened column is selected, so `id` is not carried into the result.
flattened_df_2 = grouped_df.select(
    flatten(grouped_df["merged_data"]).alias("final_data")
)

flattened_df_2.show(n=20, truncate=False)

+---------------------------------------+
|final_data                             |
+---------------------------------------+
|[[1, 2], [3, 4], [5, 6], [7, 8]]       |
|[[9, 10], [11, 12], [13, 14], [15, 16]]|
+---------------------------------------+

