### Initialize Spark Session

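The Sparkmagic `%%configure -f` magic recreates the Livy session with the spark-xml package on the classpath.

```
%%configure -f
{ "conf": {"spark.jars.packages": "com.databricks:spark-xml_2.10:0.4.1" }}
```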
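The `_2.10` suffix in the package coordinate is the Scala binary version that the spark-xml artifact was compiled for, and it should match the Scala version of the Spark build on the cluster. The sketch below is a hypothetical helper (`SparkXmlCoordinate` is not part of spark-xml) that builds the coordinate from a Scala version string. Before using the result, check that the spark-xml release you pick is actually published for that Scala version.

```scala
// Hypothetical helper (not part of spark-xml): build the spark-xml package coordinate
// for a given Scala version instead of hard-coding the "_2.10" suffix.
object SparkXmlCoordinate {
  // Scala binary version: "2.11.12" -> "2.11"
  def scalaBinaryVersion(fullVersion: String): String =
    fullVersion.split('.').take(2).mkString(".")

  // "com.databricks:spark-xml_<scala binary version>:<spark-xml version>"
  def coordinate(sparkXmlVersion: String, scalaFullVersion: String): String =
    s"com.databricks:spark-xml_${scalaBinaryVersion(scalaFullVersion)}:$sparkXmlVersion"

  def main(args: Array[String]): Unit = {
    // versionNumberString is the version of the Scala library on the current classpath
    println(coordinate("0.4.1", scala.util.Properties.versionNumberString))
  }
}
```

Run inside a Spark session, `scala.util.Properties.versionNumberString` reports the Scala library version of that session, which is the version the suffix has to match.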

Our XML input looks like the sample below.

The first `record` contains a nested struct field, `ns`, and a nested list of structs, `nls`. The other three records contain only flat fields.

```
<?xml version="1.0" encoding="UTF-8"?>
<root>
   <record attribute1="AAAA">
      <field1>1</field1>
      <field2>three</field2>
      <ns>
         <nf1>nf1</nf1>
         <nf2>nf2</nf2>
         <nf3>nf3</nf3>
      </ns>
      <nls>
         <ns2>
            <nf11>nf11</nf11>
            <nf12>nf12</nf12>
         </ns2>
         <ns2>
            <nf13>nf13</nf13>
         </ns2>
      </nls>
   </record>
   <record attribute1="AAAA">
      <field1>2</field1>
      <field2>three</field2>
   </record>
   <record attribute1="AAAA">
      <field1>3</field1>
      <field2>three</field2>
   </record>
   <record attribute1="AAAA">
      <field1>4</field1>
      <field2>three</field2>
   </record>
</root>
```
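To make the mapping from XML to columns concrete, here is a hedged sketch of Scala case classes that mirror the schema spark-xml infers for this sample (as printed in the next cell's output). The class names are hypothetical. `Option` marks elements that only some records contain. The attribute appears as `_attribute1` because of spark-xml's default `_` attribute prefix. `Ns2` carries all of `nf11`, `nf12` and `nf13` because the inferred element type of the `ns2` list combines the fields seen across its elements. With `import spark.implicits._`, classes like these should also let you view the DataFrame as a typed `Dataset` via `df.as[Record]`.

```scala
// Hypothetical case classes mirroring the schema inferred for one <record>.
// Option marks fields that may be absent (records 2-4 have no <ns> or <nls>).
case class Ns(nf1: Option[String], nf2: Option[String], nf3: Option[String])
case class Ns2(nf11: Option[String], nf12: Option[String], nf13: Option[String])
case class Nls(ns2: Option[Seq[Ns2]]) // <nls> wraps the repeated <ns2> elements
case class Record(
  _attribute1: Option[String], // XML attribute, exposed with spark-xml's default "_" prefix
  field1: Option[Long],
  field2: Option[String],
  nls: Option[Nls],
  ns: Option[Ns]
)

object SampleRecords {
  // The first <record> of the sample: one <ns> struct and two <ns2> list elements
  val first: Record = Record(
    _attribute1 = Some("AAAA"),
    field1 = Some(1L),
    field2 = Some("three"),
    nls = Some(Nls(Some(Seq(
      Ns2(Some("nf11"), Some("nf12"), None),
      Ns2(None, None, Some("nf13"))
    )))),
    ns = Some(Ns(Some("nf1"), Some("nf2"), Some("nf3")))
  )
  // Records 2-4 carry only the attribute and the two flat fields
  val rest: Seq[Record] =
    (2L to 4L).map(i => Record(Some("AAAA"), Some(i), Some("three"), None, None))
  val all: Seq[Record] = first +: rest
}
```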

### Read Input File

```scala
// rowTag selects the XML element that becomes one DataFrame row
var df = spark.read.format("xml").option("rowTag","record").load("s3://neilawstmp2/xml/sample.xml")
df.printSchema()
println(df.count())
```


```
df: org.apache.spark.sql.DataFrame = [_attribute1: string, field1: bigint ... 3 more fields]
root
 |-- _attribute1: string (nullable = true)
 |-- field1: long (nullable = true)
 |-- field2: string (nullable = true)
 |-- nls: struct (nullable = true)
 |    |-- ns2: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- nf11: string (nullable = true)
 |    |    |    |-- nf12: string (nullable = true)
 |    |    |    |-- nf13: string (nullable = true)
 |-- ns: struct (nullable = true)
 |    |-- nf1: string (nullable = true)
 |    |-- nf2: string (nullable = true)
 |    |-- nf3: string (nullable = true)

4
```
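Schema inference requires spark-xml to scan the input before loading it. As an alternative, the inferred schema above can be written out explicitly with Spark's `StructType` API and passed to the reader with `.schema(...)`, which skips the inference pass and pins the column types. This is a sketch; the `RecordSchema` object name is hypothetical.

```scala
import org.apache.spark.sql.types._

// Sketch: the inferred schema printed above, written out explicitly.
object RecordSchema {
  // Nullable string fields with the given names
  private def strings(names: String*): Seq[StructField] =
    names.map(n => StructField(n, StringType, nullable = true))

  val schema: StructType = StructType(Seq(
    StructField("_attribute1", StringType, nullable = true),
    StructField("field1", LongType, nullable = true),
    StructField("field2", StringType, nullable = true),
    // nls is a struct whose single field ns2 is an array of structs
    StructField("nls", StructType(Seq(
      StructField("ns2", ArrayType(StructType(strings("nf11", "nf12", "nf13")), containsNull = true), nullable = true)
    )), nullable = true),
    StructField("ns", StructType(strings("nf1", "nf2", "nf3")), nullable = true)
  ))
}
```

In the notebook this would be used as `spark.read.format("xml").option("rowTag","record").schema(RecordSchema.schema).load("s3://neilawstmp2/xml/sample.xml")`.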

### Flatten the DataFrame

```scala
import spark.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{col, explode_outer}
import org.apache.spark.sql.types._

// Repeatedly expand struct columns and explode array columns until only primitive columns remain.
def flatten(df: DataFrame): DataFrame = {
   // Top-level columns whose type is a struct or an array, keyed by column name
   def complexFields(d: DataFrame): Map[String, DataType] =
      d.schema.fields.collect {
         case f if f.dataType.isInstanceOf[ArrayType] || f.dataType.isInstanceOf[StructType] => (f.name, f.dataType)
      }.toMap

   var input_df = df
   var complex_fields = complexFields(input_df)
   println(complex_fields)

   while (complex_fields.nonEmpty) {
      val col_name = complex_fields.keys.head
      println("Processing : " + col_name + ", Type : " + complex_fields(col_name).getClass)

      complex_fields(col_name) match {
         // Struct: add one column per child field, named <parent>_<child>, then drop the parent
         case s: StructType =>
            val expanded = s.fields.map(f => col(col_name + "." + f.name).alias(col_name + "_" + f.name))
            input_df = input_df.select(input_df.columns.map(input_df(_)) ++ expanded: _*).drop(col_name)
         // Array: one output row per element; explode_outer keeps rows whose array is null or empty
         case _: ArrayType =>
            input_df = input_df.withColumn(col_name, explode_outer(col(col_name)))
         case _ => // unreachable: complexFields only returns structs and arrays
      }

      // Recompute the remaining complex fields in the schema
      complex_fields = complexFields(input_df)
   }
   // Return the final flattened DataFrame
   input_df
}
```
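The loop is easier to follow on a toy model. The sketch below is a hypothetical, Spark-free re-implementation of the same steps over a minimal schema type (`Leaf`, `Struct` and `Arr` are invented for illustration). A struct is replaced by `<parent>_<child>` columns appended at the end, and an array keeps its name but takes on its element type, as `explode_outer` does. Unlike the Spark version, which picks the next column from a `Map` with no guaranteed order, the model always processes the leftmost complex column; for this sample it yields the same columns in the same order as the flattened schema printed below.

```scala
// Hypothetical, Spark-free model of the flatten() loop over a minimal schema type.
sealed trait Field extends Product with Serializable
final case class Leaf(tpe: String) extends Field                      // primitive column, e.g. "string"
final case class Struct(children: List[(String, Field)]) extends Field // nested struct
final case class Arr(element: Field) extends Field                     // array of some element type

object FlattenModel {
  // Returns the (name, type) pairs of the columns left after flattening `top`.
  def flattenNames(top: List[(String, Field)]): List[(String, String)] = {
    var cols = top
    // Complex columns of the current schema, leftmost first (re-evaluated on each call)
    def complex: List[(String, Field)] =
      cols.collect { case (n, f @ (_: Struct | _: Arr)) => (n, f) }

    while (complex.nonEmpty) {
      val (name, field) = complex.head
      cols = field match {
        // Struct: drop the parent and append <parent>_<child> columns at the end,
        // like select(existing ++ expanded).drop(parent) in the Spark version
        case Struct(children) =>
          cols.filterNot(_._1 == name) ++ children.map { case (c, f) => (s"${name}_$c", f) }
        // Array: the column keeps its name and now holds one element per row,
        // like withColumn(name, explode_outer(col(name)))
        case Arr(element) =>
          cols.map { case (n, f) => if (n == name) (n, element) else (n, f) }
        case Leaf(_) => cols // unreachable: `complex` never yields leaves
      }
    }
    cols.collect { case (n, Leaf(t)) => (n, t) }
  }

  // The schema inferred for the sample XML, in this toy representation
  val sampleSchema: List[(String, Field)] = List(
    "_attribute1" -> Leaf("string"),
    "field1"      -> Leaf("long"),
    "field2"      -> Leaf("string"),
    "nls" -> Struct(List(
      "ns2" -> Arr(Struct(List("nf11" -> Leaf("string"), "nf12" -> Leaf("string"), "nf13" -> Leaf("string"))))
    )),
    "ns" -> Struct(List("nf1" -> Leaf("string"), "nf2" -> Leaf("string"), "nf3" -> Leaf("string")))
  )
}
```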




```scala
df = flatten(df)
df.printSchema()
println(df.count())
```


```
Map(nls -> StructType(StructField(ns2,ArrayType(StructType(StructField(nf11,StringType,true), StructField(nf12,StringType,true), StructField(nf13,StringType,true)),true),true)), ns -> StructType(StructField(nf1,StringType,true), StructField(nf2,StringType,true), StructField(nf3,StringType,true)))
Processing : nls, Type : class org.apache.spark.sql.types.StructType
Processing : ns, Type : class org.apache.spark.sql.types.StructType
Processing : nls_ns2, Type : class org.apache.spark.sql.types.ArrayType
Processing : nls_ns2, Type : class org.apache.spark.sql.types.StructType
df: org.apache.spark.sql.DataFrame = [_attribute1: string, field1: bigint ... 7 more fields]
root
 |-- _attribute1: string (nullable = true)
 |-- field1: long (nullable = true)
 |-- field2: string (nullable = true)
 |-- ns_nf1: string (nullable = true)
 |-- ns_nf2: string (nullable = true)
 |-- ns_nf3: string (nullable = true)
 |-- nls_ns2_nf11: string (nullable = true)
 |-- nls_ns2_nf12: string (nullable = true)
 |-- nls_ns2_nf13: string (nullable = true)
```

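Observe that our DataFrame is now flat: the struct field `ns` has been expanded into the columns `ns_nf1` to `ns_nf3`, and the list of structs `nls` has been exploded into one row per element and expanded into `nls_ns2_nf11` to `nls_ns2_nf13`.

We started with 4 records and now have 5. The first record's `nls` list holds two `ns2` elements and becomes two rows, while the other three records have no `nls` and each survives as a single row with nulls, because `explode_outer` (rather than `explode`) is used.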
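The row count follows directly from the semantics of `explode_outer`. Here is a hedged, Spark-free sketch (the `ExplodeOuterModel` object is hypothetical): each input row becomes one row per array element, or a single row holding null when the array is null or empty, whereas plain `explode` drops such rows. The last helper illustrates a caveat of `flatten`: because it explodes array columns one at a time, a record with two independent array columns yields the cross product of their elements.

```scala
// Hypothetical, Spark-free model of how exploding arrays changes the row count.
// Each input row is described by the size of its array column: None = null array.
object ExplodeOuterModel {
  // explode_outer: one row per element, or a single null row for a null/empty array
  def rowsAfterExplodeOuter(arraySizes: Seq[Option[Int]]): Int =
    arraySizes.map {
      case Some(n) if n > 0 => n
      case _                => 1
    }.sum

  // explode: one row per element; rows with a null or empty array disappear
  def rowsAfterExplode(arraySizes: Seq[Option[Int]]): Int =
    arraySizes.map(_.getOrElse(0)).sum

  // Exploding two independent array columns of the same row one after the other,
  // as flatten() would, multiplies their outer row counts (a per-row cross product)
  def rowsAfterTwoExplodes(a: Option[Int], b: Option[Int]): Int =
    rowsAfterExplodeOuter(Seq(a)) * rowsAfterExplodeOuter(Seq(b))

  def main(args: Array[String]): Unit = {
    // Sample: record 1 has two ns2 elements; records 2-4 have no nls at all
    val sample = Seq(Some(2), None, None, None)
    println(rowsAfterExplodeOuter(sample)) // 5
    println(rowsAfterExplode(sample))      // 2
  }
}
```

For the sample, `explode_outer` gives the five rows reported above, while `explode` would have kept only the two rows from the first record.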

```scala
df.select("field1","ns_nf1","nls_ns2_nf11","nls_ns2_nf12","nls_ns2_nf13").show(5)
```


```
+------+------+------------+------------+------------+
|field1|ns_nf1|nls_ns2_nf11|nls_ns2_nf12|nls_ns2_nf13|
+------+------+------------+------------+------------+
|     1|   nf1|        nf11|        nf12|        null|
|     1|   nf1|        null|        null|        nf13|
|     2|  null|        null|        null|        null|
|     3|  null|        null|        null|        null|
|     4|  null|        null|        null|        null|
+------+------+------------+------------+------------+
```



### Write Out as Parquet

```scala
df.write.mode("overwrite").parquet("s3://neilawstmp2/xml/parquet/")
```


### Read Parquet data

```scala
val input_df = spark.read.parquet("s3://neilawstmp2/xml/parquet/")
input_df.printSchema()
input_df.select("field1","ns_nf1","nls_ns2_nf11","nls_ns2_nf12","nls_ns2_nf13").show(5)
```


```
input_df: org.apache.spark.sql.DataFrame = [_attribute1: string, field1: bigint ... 7 more fields]
root
 |-- _attribute1: string (nullable = true)
 |-- field1: long (nullable = true)
 |-- field2: string (nullable = true)
 |-- ns_nf1: string (nullable = true)
 |-- ns_nf2: string (nullable = true)
 |-- ns_nf3: string (nullable = true)
 |-- nls_ns2_nf11: string (nullable = true)
 |-- nls_ns2_nf12: string (nullable = true)
 |-- nls_ns2_nf13: string (nullable = true)

+------+------+------------+------------+------------+
|field1|ns_nf1|nls_ns2_nf11|nls_ns2_nf12|nls_ns2_nf13|
+------+------+------------+------------+------------+
|     1|   nf1|        nf11|        nf12|        null|
|     1|   nf1|        null|        null|        nf13|
|     2|  null|        null|        null|        null|
|     3|  null|        null|        null|        null|
|     4|  null|        null|        null|        null|
+------+------+------------+------------+------------+
```

