### Initialize Spark Session

In [1]:
%%configure -f
{ "conf": {"spark.jars.packages": "com.databricks:spark-xml_2.10:0.4.1" }}

## Handling xsi:nil attribute as null

Our XML input is shown below.

It contains a nested list of structs in the 'nls' field.
One of the struct fields, nf11, has xsi:nil="true" in the first struct and is populated in the second.

```
<root>
<record>
    <id>1</id>
    <nls>
       <ns2>
          <nf11 xsi:nil="true"/>
          <nf12>nf12</nf12>
       </ns2>
       <ns2>
          <nf11>nf11</nf11>
          <nf13>nf13</nf13>
       </ns2>
    </nls>
</record>
<record>
    <id>2</id>
</record>
</root>

```
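
To make the intended semantics concrete, here is a minimal Spark-free sketch that parses a trimmed copy of the sample with the JDK's DOM parser and maps each nf11 element to an `Option`: `None` when xsi:nil="true", otherwise the element text. The names `sample` and `nf11Values` are hypothetical, and this only illustrates the idea of treating xsi:nil as null; it makes no claim about how spark-xml implements it internally.

```scala
import java.io.StringReader
import javax.xml.parsers.DocumentBuilderFactory
import org.xml.sax.InputSource

// Trimmed copy of the sample record (hypothetical name: sample).
val sample =
  """<root><record><id>1</id><nls>
    |  <ns2><nf11 xsi:nil="true"/><nf12>nf12</nf12></ns2>
    |  <ns2><nf11>nf11</nf11><nf13>nf13</nf13></ns2>
    |</nls></record></root>""".stripMargin

// The default factory is not namespace-aware, so the undeclared xsi prefix
// is read as a plain attribute named "xsi:nil".
val doc = DocumentBuilderFactory.newInstance()
  .newDocumentBuilder()
  .parse(new InputSource(new StringReader(sample)))

val nodes = doc.getElementsByTagName("nf11")

// None for xsi:nil="true", Some(text) otherwise.
val nf11Values: Vector[Option[String]] =
  (0 until nodes.getLength).toVector.map { i =>
    val el = nodes.item(i).asInstanceOf[org.w3c.dom.Element]
    if (el.getAttribute("xsi:nil") == "true") None else Some(el.getTextContent)
  }
```

Here `nf11Values` holds one entry per nf11 element, in document order: the nil element first, then the populated one.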

### Read Input File

In [2]:
var df = spark.read.format("xml").option("rowTag","record").load("s3://bucket/xml/sample_nil.xml")
df.printSchema()
print (df.count())

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
15,application_1597770352735_0016,spark,idle,Link,Link,✔



SparkSession available as 'spark'.



df: org.apache.spark.sql.DataFrame = [id: bigint, nls: struct<ns2: array<struct<nf11:struct<_VALUE:string,_nil:boolean>,nf12:string,nf13:string>>>]
root
 |-- id: long (nullable = true)
 |-- nls: struct (nullable = true)
 |    |-- ns2: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- nf11: struct (nullable = true)
 |    |    |    |    |-- _VALUE: string (nullable = true)
 |    |    |    |    |-- _nil: boolean (nullable = true)
 |    |    |    |-- nf12: string (nullable = true)
 |    |    |    |-- nf13: string (nullable = true)

2

In [3]:
df.schema.fields


res3: Array[org.apache.spark.sql.types.StructField] = Array(StructField(id,LongType,true), StructField(nls,StructType(StructField(ns2,ArrayType(StructType(StructField(nf11,StructType(StructField(_VALUE,StringType,true), StructField(_nil,BooleanType,true)),true), StructField(nf12,StringType,true), StructField(nf13,StringType,true)),true),true)),true))


### Flatten the Dataframe

In [4]:
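import spark.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
// col and explode_outer live in functions; import explicitly in case the session does not pre-import them
import org.apache.spark.sql.functions._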


def flatten(df: DataFrame): DataFrame = {
   var input_df=df 
   var complex_fields = input_df.schema.fields.collect{
      case f if (f.dataType.isInstanceOf[ArrayType] || f.dataType.isInstanceOf[StructType]) => (f.name, f.dataType)
   }.toMap
   print (complex_fields)
   
   while (complex_fields.size!=0) {
    
      var col_name=complex_fields.keys.head
      print ("\n Processing : "+col_name+", Type : "+complex_fields(col_name).getClass)

      if (complex_fields(col_name).isInstanceOf[StructType]){
         var expanded=complex_fields(col_name).asInstanceOf[StructType].fields
          .map(_.name)
          // ignore the _nil attribute and read the _VALUE attribute only
          .filterNot(x=>x.endsWith("_nil"))
          .map(c=>col(col_name+'.'+c).alias(col_name+'_'+c))
         print ("\n "+expanded) 
         input_df=input_df.select(input_df.columns.map(input_df(_)) ++ expanded:_*).drop(col_name)
      }
      else if (complex_fields(col_name).isInstanceOf[ArrayType]){
         input_df=input_df.withColumn(col_name,explode_outer(col(col_name)))
      }

      //recompute remaining Complex Fields in Schema 
      complex_fields = input_df.schema.fields.collect{
         case f if (f.dataType.isInstanceOf[ArrayType] || f.dataType.isInstanceOf[StructType]) => (f.name, f.dataType)
      }.toMap
   }
   // return final flattened dataframe 
   // Truncate __VALUE from column names if it exists. 
   var newNames=input_df.schema.fields.map(_.name).map(c=>if (c.endsWith("__VALUE")) c.dropRight(7) else c)
   return input_df.toDF(newNames: _*)
}



import spark.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
flatten: (df: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame
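
The traversal inside `flatten` can be followed without a cluster by simulating it on column names only. The sketch below is an illustration under stated assumptions, not the notebook's implementation: the types `FieldType`, `Leaf`, `Arr`, `Struct` and the function `flattenNames` are hypothetical, and it always picks the first complex column in schema order, whereas the original takes `complex_fields.keys.head` from a `Map`, which coincides with schema order for the small schemas used here.

```scala
// Hypothetical, simplified stand-in for Spark's schema types.
sealed trait FieldType
case object Leaf extends FieldType
final case class Arr(element: FieldType) extends FieldType
final case class Struct(fields: Vector[(String, FieldType)]) extends FieldType

def flattenNames(schema: Vector[(String, FieldType)]): Vector[String] = {
  @annotation.tailrec
  def loop(cols: Vector[(String, FieldType)]): Vector[(String, FieldType)] =
    cols.indexWhere { case (_, t) => t != Leaf } match {
      case -1 => cols // no complex columns left
      case i =>
        cols(i) match {
          case (name, Struct(fields)) =>
            // Expand children as parent_child, skipping _nil; the parent is
            // dropped and the children are appended at the end, as in flatten.
            val expanded = fields
              .filterNot { case (child, _) => child.endsWith("_nil") }
              .map { case (child, t) => (name + "_" + child, t) }
            loop(cols.patch(i, Nil, 1) ++ expanded)
          case (name, Arr(element)) =>
            // explode_outer keeps the column name and unwraps the element type.
            loop(cols.updated(i, (name, element)))
          case _ => cols // not reached: index i always points at a complex column
        }
    }
  // Same final step as flatten: strip a trailing __VALUE.
  loop(schema).map { case (name, _) => name.stripSuffix("__VALUE") }
}

// Schema of the first sample, as printed by df.printSchema().
val nf11 = Struct(Vector("_VALUE" -> Leaf, "_nil" -> Leaf))
val ns2 = Arr(Struct(Vector("nf11" -> nf11, "nf12" -> Leaf, "nf13" -> Leaf)))
val sampleSchema = Vector("id" -> Leaf, "nls" -> Struct(Vector("ns2" -> ns2)))

val flatNames = flattenNames(sampleSchema)
```

For this schema `flatNames` is `Vector("id", "nls_ns2_nf12", "nls_ns2_nf13", "nls_ns2_nf11")`. nf11 ends up last because its struct is expanded one step later than its siblings and the expanded column is appended at the end.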


In [5]:
var flattened_df=flatten(df)
flattened_df.printSchema()
print (flattened_df.count())
flattened_df.show(20)


Map(nls -> StructType(StructField(ns2,ArrayType(StructType(StructField(nf11,StructType(StructField(_VALUE,StringType,true), StructField(_nil,BooleanType,true)),true), StructField(nf12,StringType,true), StructField(nf13,StringType,true)),true),true)))
 Processing : nls, Type : class org.apache.spark.sql.types.StructType
 [Lorg.apache.spark.sql.Column;@631f3727
 Processing : nls_ns2, Type : class org.apache.spark.sql.types.ArrayType
 Processing : nls_ns2, Type : class org.apache.spark.sql.types.StructType
 [Lorg.apache.spark.sql.Column;@e2fb44b
 Processing : nls_ns2_nf11, Type : class org.apache.spark.sql.types.StructType
 [Lorg.apache.spark.sql.Column;@5413c226
flattened_df: org.apache.spark.sql.DataFrame = [id: bigint, nls_ns2_nf12: string ... 2 more fields]
root
 |-- id: long (nullable = true)
 |-- nls_ns2_nf12: string (nullable = true)
 |-- nls_ns2_nf13: string (nullable = true)
 |-- nls_ns2_nf11: string (nullable = true)

3
+---+------------+------------+------------+
| id|nls_ns2_nf1

Observe that our DataFrame is now flattened into 3 rows.

nls_ns2_nf11 is null in the 1st row (xsi:nil="true"), has the value 'nf11' in the 2nd row, and is null in the 3rd row (record 2 has no nls element).

## Handling multiple nested lists

```
<root>
<record>
    <id>1</id>
    <nls>
       <ns2>
          <nf11 xsi:nil="true"/>
          <nf12>nf12</nf12>
       </ns2>
       <ns2>
          <nf11>nf11</nf11>
          <nf13>nf13</nf13>
       </ns2>
    </nls>
    <nls1>
       <ns2>
          <nf11 xsi:nil="true"/>
          <nf12>nf12</nf12>
       </ns2>
       <ns2>
          <nf11>nf11</nf11>
          <nf13>nf13</nf13>
       </ns2>
    </nls1>
</record>
</root>
```

In [6]:
var df = spark.read.format("xml").option("rowTag","record").load("s3://bucket/xml/sample_multiple_nested_lists.xml")
df.printSchema()
print (df.count())


df: org.apache.spark.sql.DataFrame = [id: bigint, nls: struct<ns2: array<struct<nf11:struct<_VALUE:string,_nil:boolean>,nf12:string,nf13:string>>> ... 1 more field]
root
 |-- id: long (nullable = true)
 |-- nls: struct (nullable = true)
 |    |-- ns2: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- nf11: struct (nullable = true)
 |    |    |    |    |-- _VALUE: string (nullable = true)
 |    |    |    |    |-- _nil: boolean (nullable = true)
 |    |    |    |-- nf12: string (nullable = true)
 |    |    |    |-- nf13: string (nullable = true)
 |-- nls1: struct (nullable = true)
 |    |-- ns2: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- nf11: struct (nullable = true)
 |    |    |    |    |-- _VALUE: string (nullable = true)
 |    |    |    |    |-- _nil: boolean (nullable = true)
 |    |    |    |-- nf12: string (nullable = true)
 |    |    |    |-- nf13: string (nullable = true)

1

In [7]:
var flattened_df=flatten(df)
flattened_df.printSchema()
print (flattened_df.count())
flattened_df.show(20)


Map(nls -> StructType(StructField(ns2,ArrayType(StructType(StructField(nf11,StructType(StructField(_VALUE,StringType,true), StructField(_nil,BooleanType,true)),true), StructField(nf12,StringType,true), StructField(nf13,StringType,true)),true),true)), nls1 -> StructType(StructField(ns2,ArrayType(StructType(StructField(nf11,StructType(StructField(_VALUE,StringType,true), StructField(_nil,BooleanType,true)),true), StructField(nf12,StringType,true), StructField(nf13,StringType,true)),true),true)))
 Processing : nls, Type : class org.apache.spark.sql.types.StructType
 [Lorg.apache.spark.sql.Column;@72d2c375
 Processing : nls1, Type : class org.apache.spark.sql.types.StructType
 [Lorg.apache.spark.sql.Column;@c408a73
 Processing : nls_ns2, Type : class org.apache.spark.sql.types.ArrayType
 Processing : nls_ns2, Type : class org.apache.spark.sql.types.StructType
 [Lorg.apache.spark.sql.Column;@59542c4f
 Processing : nls1_ns2, Type : class org.apache.spark.sql.types.ArrayType
 Processing : nls

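As expected, there are 4 rows: the single record has two ns2 elements under nls and two under nls1, and exploding both arrays yields their cross product (2 × 2 = 4).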
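
The row counts of both samples can be checked with a small pure-Scala model of `explode_outer`. This is a sketch under assumptions: each record is reduced to the sizes of the arrays that get exploded, `None` stands for a missing (null) array, and the names `rowsAfterExplodeOuter`, `firstSample` and `secondSample` are hypothetical.

```scala
// Rows produced for one record: each exploded array multiplies the row count
// by its size; explode_outer still emits one row (with null) for a null or
// empty array, so those contribute a factor of 1.
def rowsAfterExplodeOuter(arraySizes: Seq[Option[Int]]): Int =
  arraySizes.map {
    case Some(n) if n > 0 => n
    case _ => 1
  }.product

// First sample: record 1 has two ns2 elements, record 2 has no nls at all.
val firstSample: Seq[Seq[Option[Int]]] = Seq(Seq(Some(2)), Seq(None))
// Second sample: one record with two ns2 elements under both nls and nls1.
val secondSample: Seq[Seq[Option[Int]]] = Seq(Seq(Some(2), Some(2)))

val firstTotal = firstSample.map(rowsAfterExplodeOuter).sum
val secondTotal = secondSample.map(rowsAfterExplodeOuter).sum
```

`firstTotal` is 3 and `secondTotal` is 4, matching the counts printed for the two flattened DataFrames. With plain `explode` instead of `explode_outer`, the record without nls would be dropped rather than kept as a row of nulls.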