### Initialize Spark Session

In [1]:
%%configure -f
{ "conf": {"spark.jars.packages": "com.databricks:spark-xml_2.10:0.4.1" }}

Our XML input looks like the sample below. The first `<record>` contains a nested struct field, `ns`, as well as a nested list of structs, `nls`. The other three records contain only flat fields.

```
<?xml version="1.0" encoding="UTF-8"?>
<root>
   <record attribute1="AAAA">
      <field1>1</field1>
      <field2>three</field2>
      <ns>
         <nf1>nf1</nf1>
         <nf2>nf2</nf2>
         <nf3>nf3</nf3>
      </ns>
      <nls>
         <ns2>
            <nf11>nf11</nf11>
            <nf12>nf12</nf12>
         </ns2>
         <ns2>
            <nf13>nf13</nf13>
         </ns2>
      </nls>
   </record>
   <record attribute1="AAAA">
      <field1>2</field1>
      <field2>three</field2>
   </record>
   <record attribute1="AAAA">
      <field1>3</field1>
      <field2>three</field2>
   </record>
   <record attribute1="AAAA">
      <field1>4</field1>
      <field2>three</field2>
   </record>
</root>
```
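To see what `rowTag='record'` will hand to Spark, the sample can be inspected with Python's standard-library `xml.etree.ElementTree`. This is a plain-Python sketch, not spark-xml itself. It only confirms that the file holds four `<record>` rows and that only the first one carries the nested `ns` and `nls` elements.

```python
import xml.etree.ElementTree as ET

# Copy of the sample input shown above
SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?>
<root>
   <record attribute1="AAAA">
      <field1>1</field1>
      <field2>three</field2>
      <ns>
         <nf1>nf1</nf1>
         <nf2>nf2</nf2>
         <nf3>nf3</nf3>
      </ns>
      <nls>
         <ns2>
            <nf11>nf11</nf11>
            <nf12>nf12</nf12>
         </ns2>
         <ns2>
            <nf13>nf13</nf13>
         </ns2>
      </nls>
   </record>
   <record attribute1="AAAA">
      <field1>2</field1>
      <field2>three</field2>
   </record>
   <record attribute1="AAAA">
      <field1>3</field1>
      <field2>three</field2>
   </record>
   <record attribute1="AAAA">
      <field1>4</field1>
      <field2>three</field2>
   </record>
</root>
"""

# Parse from bytes so the encoding declaration is handled by the XML parser
root = ET.fromstring(SAMPLE_XML.encode("utf-8"))

# With rowTag='record', each <record> element becomes one DataFrame row
records = root.findall("record")
print(len(records))  # 4

# Only the first record has the nested <nls> element, holding two <ns2> children
print([r.findtext("field1") for r in records if r.find("nls") is not None])  # ['1']
print(len(records[0].findall("nls/ns2")))  # 2
```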

### Read Input File

In [2]:
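# Import only what the notebook uses; a wildcard import from
# pyspark.sql.functions shadows Python builtins such as sum, min and max
from pyspark.sql.types import ArrayType, StructType
from pyspark.sql.functions import col, explode_outer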

SparkSession available as 'spark'.



In [3]:
df = spark.read \
    .format('xml') \
    .options(rowTag='record') \
    .load('s3://neilawstmp2/xml/sample.xml')
df.printSchema()


root
 |-- _attribute1: string (nullable = true)
 |-- field1: long (nullable = true)
 |-- field2: string (nullable = true)
 |-- nls: struct (nullable = true)
 |    |-- ns2: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- nf11: string (nullable = true)
 |    |    |    |-- nf12: string (nullable = true)
 |    |    |    |-- nf13: string (nullable = true)
 |-- ns: struct (nullable = true)
 |    |-- nf1: string (nullable = true)
 |    |-- nf2: string (nullable = true)
 |    |-- nf3: string (nullable = true)
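The inferred schema follows a few simple rules: attributes become fields prefixed with `_` (spark-xml's default `attributePrefix`), an element with children becomes a struct, and a child tag that repeats becomes an array. That is why `nls` is a struct holding an `ns2` array. The array's element struct lists `nf11`, `nf12` and `nf13` because the fields of both `ns2` elements are merged into one struct type, with missing fields left null. The plain-Python sketch below mimics those rules on the first record only. The helper names `to_nested` and `_leaf` are hypothetical, and unlike spark-xml it inspects a single record rather than merging types across all rows.

```python
import xml.etree.ElementTree as ET
from collections import defaultdict

# First <record> from the sample input
RECORD_XML = """<record attribute1="AAAA">
   <field1>1</field1>
   <field2>three</field2>
   <ns><nf1>nf1</nf1><nf2>nf2</nf2><nf3>nf3</nf3></ns>
   <nls>
      <ns2><nf11>nf11</nf11><nf12>nf12</nf12></ns2>
      <ns2><nf13>nf13</nf13></ns2>
   </nls>
</record>"""


def _leaf(text):
    # Integer-looking text becomes int (a crude stand-in for type inference)
    text = (text or "").strip()
    return int(text) if text.lstrip("-").isdigit() else text


def to_nested(elem, attribute_prefix="_"):
    """Convert an element into nested dicts (structs) and lists (arrays)."""
    # Attributes become fields named <prefix><attribute>
    out = {attribute_prefix + k: v for k, v in elem.attrib.items()}
    # Group children by tag so repeated tags can be detected
    children = defaultdict(list)
    for child in elem:
        children[child.tag].append(child)
    for tag, elems in children.items():
        values = [to_nested(e, attribute_prefix) if len(e) else _leaf(e.text) for e in elems]
        # A repeated tag becomes a list (array); a single one stays a value or dict
        out[tag] = values if len(values) > 1 else values[0]
    return out


record = to_nested(ET.fromstring(RECORD_XML))
print(record["nls"]["ns2"])  # [{'nf11': 'nf11', 'nf12': 'nf12'}, {'nf13': 'nf13'}]
```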

### Flatten the DataFrame

In [4]:
def flatten(df):
    """Flatten all struct and array columns of df, one complex column at a time."""

    def complex_fields(frame):
        # Map each top-level column name to its type, keeping only structs and arrays
        return {field.name: field.dataType
                for field in frame.schema.fields
                if isinstance(field.dataType, (ArrayType, StructType))}

    fields = complex_fields(df)
    while fields:
        col_name, col_type = next(iter(fields.items()))
        print("Processing: " + col_name + " Type: " + str(type(col_type)))

        if isinstance(col_type, StructType):
            # Struct: promote every sub-field to a top-level column named <struct>_<field>
            expanded = [col(col_name + '.' + sub.name).alias(col_name + '_' + sub.name)
                        for sub in col_type.fields]
            df = df.select("*", *expanded).drop(col_name)
        else:
            # Array: emit one row per element; explode_outer keeps rows whose
            # array is null or empty (as a single null) instead of dropping them
            df = df.withColumn(col_name, explode_outer(col_name))

        # Recompute the remaining complex fields, since expanding or exploding
        # a column can surface new nested structs or arrays
        fields = complex_fields(df)
    return df
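The same two rules can be sketched in plain Python without a Spark cluster, which makes them easy to test in isolation: a struct (here a `dict`) is expanded into `<parent>_<child>` columns, and an array (a `list`) is exploded into one row per element, keeping a single null row when the array is empty, as `explode_outer` does. The names `flatten_row` and `flatten_records` and the dict-based rows are illustrative assumptions, not PySpark APIs. This sketch also recurses per row instead of iterating over a schema like `flatten()` does.

```python
def flatten_row(row, prefix=""):
    """Return the flat rows produced by one nested row (dict = struct, list = array)."""
    flat = [{}]  # partial output rows built so far; starts as one empty row
    for key, value in row.items():
        name = prefix + key
        if isinstance(value, dict):
            # Struct: expand each sub-field into a column named <name>_<sub-field>
            parts = flatten_row(value, name + "_")
        elif isinstance(value, list):
            # Array: one row per element; an empty list still yields one
            # null row, mimicking explode_outer
            parts = [p for item in value for p in flatten_row({key: item}, prefix)]
            parts = parts or [{name: None}]
        else:
            parts = [{name: value}]
        # Combine every row built so far with every row produced by this field
        flat = [{**left, **right} for left in flat for right in parts]
    return flat


def flatten_records(records):
    """Flatten a list of nested dicts into rows that all share the same columns."""
    rows = [out for record in records for out in flatten_row(record)]
    # Spark keeps one fixed schema, so a column missing from a row shows up as null
    columns = []
    for row in rows:
        for column in row:
            if column not in columns:
                columns.append(column)
    return [{column: row.get(column) for column in columns} for row in rows]


# The four sample records, written as nested dicts
records = [
    {"_attribute1": "AAAA", "field1": 1, "field2": "three",
     "ns": {"nf1": "nf1", "nf2": "nf2", "nf3": "nf3"},
     "nls": {"ns2": [{"nf11": "nf11", "nf12": "nf12"}, {"nf13": "nf13"}]}},
    {"_attribute1": "AAAA", "field1": 2, "field2": "three"},
    {"_attribute1": "AAAA", "field1": 3, "field2": "three"},
    {"_attribute1": "AAAA", "field1": 4, "field2": "three"},
]

rows = flatten_records(records)
print(len(rows))  # 5
```

Because this version works per row, it only creates columns that occur in some row, whereas Spark derives the columns from the schema up front.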



In [6]:
df=flatten(df)
df.printSchema()
print (df.count())


root
 |-- _attribute1: string (nullable = true)
 |-- field1: long (nullable = true)
 |-- field2: string (nullable = true)
 |-- ns_nf1: string (nullable = true)
 |-- ns_nf2: string (nullable = true)
 |-- ns_nf3: string (nullable = true)
 |-- nls_ns2_nf11: string (nullable = true)
 |-- nls_ns2_nf12: string (nullable = true)
 |-- nls_ns2_nf13: string (nullable = true)

5

Observe that our DataFrame is now flat: the list of structs (nls) has been exploded into rows, and the struct field (ns) has been expanded into separate columns.

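We started with 4 records and now have 5 rows. Exploding the list of structs (nls) turns the first record, which holds two `ns2` elements, into two rows, while the other three records keep one row each.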
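The row count depends on `flatten()` using `explode_outer` rather than plain `explode`. Both emit one row per array element, but `explode` drops a row whose array is null or empty, while `explode_outer` keeps it with nulls. A quick arithmetic check in plain Python, using the number of `<ns2>` elements per record in the sample input (records 2 to 4 have a null `nls`, counted as 0 here):

```python
# Number of <ns2> elements per record, keyed by field1 (null arrays counted as 0)
ns2_counts = {1: 2, 2: 0, 3: 0, 4: 0}

# explode: one row per element; records with a null or empty array disappear
rows_explode = sum(ns2_counts.values())

# explode_outer: same, but a null or empty array still yields one row of nulls
rows_explode_outer = sum(max(n, 1) for n in ns2_counts.values())

print(rows_explode, rows_explode_outer)  # 2 5
```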

In [8]:
df.select("field1","ns_nf1","nls_ns2_nf11","nls_ns2_nf12","nls_ns2_nf13").show(5)


+------+------+------------+------------+------------+
|field1|ns_nf1|nls_ns2_nf11|nls_ns2_nf12|nls_ns2_nf13|
+------+------+------------+------------+------------+
|     1|   nf1|        nf11|        nf12|        null|
|     1|   nf1|        null|        null|        nf13|
|     2|  null|        null|        null|        null|
|     3|  null|        null|        null|        null|
|     4|  null|        null|        null|        null|
+------+------+------------+------------+------------+

### Write Out as Parquet

In [9]:
df.write.mode("overwrite").parquet("s3://neilawstmp2/xml/parquet/")


### Read Parquet Data

In [10]:
input_df=spark.read.parquet("s3://neilawstmp2/xml/parquet/")
input_df.printSchema()
input_df.select("field1","ns_nf1","nls_ns2_nf11","nls_ns2_nf12","nls_ns2_nf13").show(5)


root
 |-- _attribute1: string (nullable = true)
 |-- field1: long (nullable = true)
 |-- field2: string (nullable = true)
 |-- ns_nf1: string (nullable = true)
 |-- ns_nf2: string (nullable = true)
 |-- ns_nf3: string (nullable = true)
 |-- nls_ns2_nf11: string (nullable = true)
 |-- nls_ns2_nf12: string (nullable = true)
 |-- nls_ns2_nf13: string (nullable = true)

+------+------+------------+------------+------------+
|field1|ns_nf1|nls_ns2_nf11|nls_ns2_nf12|nls_ns2_nf13|
+------+------+------------+------------+------------+
|     1|   nf1|        nf11|        nf12|        null|
|     1|   nf1|        null|        null|        nf13|
|     2|  null|        null|        null|        null|
|     3|  null|        null|        null|        null|
|     4|  null|        null|        null|        null|
+------+------+------------+------------+------------+