In [0]:
# Generate a sample JSON file (not XML) at a /Volumes path.
# The document nests bank -> branches -> employees / customers, and each
# customer carries accounts, loans and credit cards with their own child arrays.
# NOTE(review): most records use the key "id", but the second account uses
# "account_id" and the credit-card transaction uses "txn_id" — confirm whether
# this key variation is intentional test data or a typo.
sample_json_content = """{
  "bank": {
    "id": "B001",
    "name": "Global Commercial Bank",
    "branches": [
      {
        "id": "BR001",
        "location": {
          "city": "Mumbai",
          "state": "Maharashtra",
          "country": "India"
        },
        "employees": [
          {
            "id": "E1001",
            "name": {
              "first": "Rajesh",
              "last": "Mehta"
            },
            "role": "Relationship Manager",
            "clients_assigned": ["CUST001", "CUST002"]
          },
          {
            "id": "E1002",
            "name": {
              "first": "Anita",
              "last": "Sharma"
            },
            "role": "Credit Analyst",
            "clients_assigned": ["CUST003"]
          }
        ],
        "customers": [
          {
            "id": "CUST001",
            "name": {
              "first": "Amit",
              "last": "Patel"
            },
            "contact": {
              "email": "amit.patel@example.com",
              "phones": [
                {"type": "mobile", "number": "+91-9876543210"},
                {"type": "home", "number": "+91-2267890123"}
              ]
            },
            "accounts": [
              {
                "id": "ACC001",
                "type": "Checking",
                "balance": 150000.50,
                "currency": "INR",
                "transactions": [
                  {
                    "id": "TXN001",
                    "date": "2025-08-01",
                    "amount": -5000,
                    "merchant": {
                      "name": "Reliance Retail",
                      "category": "Shopping"
                    },
                    "channel": "Debit Card",
                    "tags": ["shopping", "debit"]
                  },
                  {
                    "id": "TXN002",
                    "date": "2025-08-03",
                    "amount": 20000,
                    "merchant": {
                      "name": "NEFT Transfer",
                      "category": "Transfer"
                    },
                    "channel": "Online Banking",
                    "tags": ["salary", "credit"]
                  }
                ]
              },
              {
                "account_id": "ACC002",
                "type": "Savings",
                "balance": 350000.75,
                "currency": "INR",
                "transactions": []
              }
            ],
            "loans": [
              {
                "id": "LN001",
                "type": "Home Loan",
                "principal": 5000000,
                "interest_rate": 7.5,
                "tenure_months": 240,
                "repayments": [
                  {
                    "installment_no": 1,
                    "date": "2025-09-01",
                    "amount_due": 48250.75,
                    "status": "Pending"
                  }
                ]
              }
            ],
            "credit_cards": [
              {
                "id": "CC001",
                "card_type": "VISA Platinum",
                "limit": 200000,
                "current_due": 15300,
                "transactions": [
                  {
                    "txn_id": "CCTXN001",
                    "date": "2025-07-25",
                    "merchant": {
                      "name": "Indigo Airlines",
                      "category": "Travel"
                    },
                    "amount": 12500,
                    "currency": "INR",
                    "status": "Posted"
                  }
                ]
              }
            ]
          }
        ]
      }
    ]
  }
}
"""

# Write the document to the volume; overwrite=True replaces any file already
# at this path, so the cell can be re-run.
dbutils.fs.put("/Volumes/workspace/default/test/sample.json", sample_json_content, overwrite=True)



In [0]:
# Load the sample JSON document into a Spark DataFrame.
# "multiline" is enabled because the file holds one JSON object spread over
# many lines rather than one object per line.
# NOTE(review): "inferSchema" is documented as a CSV reader option — confirm it
# has any effect on the JSON reader (string-only columns would normally be
# requested with "primitivesAsString").
sample_json_path = "/Volumes/workspace/default/test/sample.json"
json_reader = (
    spark.read
    .option("inferSchema", "false")
    .option("multiline", "true")
)
df = json_reader.json(sample_json_path)

# Render the loaded DataFrame in the notebook output.
display(df)

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T


def flatten_df(df: DataFrame, col_to_drop: list = None) -> DataFrame:
    """
    Repeatedly flatten StructType and ArrayType columns until none remain.

    Each pass walks the current schema and:
      * replaces every struct column with one column per sub-field, named
        ``<parent>_<child>``;
      * replaces every array column with its elements using ``explode_outer``,
        so rows whose array is null or empty are kept (with a null element);
      * keeps every other column as it is.
    Passes repeat until a pass meets no struct or array column. Columns of any
    other type (for example maps) are left untouched.

    Args:
        df: DataFrame to flatten.
        col_to_drop: Optional list of column names to remove. They are dropped
            before the first pass and again after every pass, so a flattened
            column with a listed name (e.g. ``branches_employees``) is removed
            before a later pass can expand it.

    Returns:
        DataFrame whose columns are no longer struct or array typed.

    Notes:
        * Array columns that sit side by side are each exploded, so the row
          count is multiplied by the length of every such array. Use
          ``col_to_drop`` to exclude branches that should not be combined.
        * Flattened names are built by joining with ``_``; no de-duplication is
          done if that collides with an existing column name.
    """

    def _quoted(name: str) -> str:
        # Backtick-quote a single name part so characters such as "." or " "
        # inside a field name are taken literally instead of being parsed as a
        # nested-field path; embedded backticks are escaped by doubling.
        return "`" + name.replace("`", "``") + "`"

    if col_to_drop:
        # Drop at the very start so unwanted top-level columns are never expanded.
        df = df.drop(*col_to_drop)

    complex_fields = True
    while complex_fields:
        complex_fields = False
        flat_cols = []

        for field in df.schema.fields:
            col_name = field.name
            dtype = field.dataType

            if isinstance(dtype, T.StructType):
                complex_fields = True
                # Promote each sub-field to a top-level column.
                for subfield in dtype.fields:
                    flat_cols.append(
                        F.col(f"{_quoted(col_name)}.{_quoted(subfield.name)}")
                        .alias(f"{col_name}_{subfield.name}")
                    )
            elif isinstance(dtype, T.ArrayType):
                # Flag another pass: the exploded element may itself be a
                # struct or array that still needs flattening.
                complex_fields = True
                # explode_outer keeps rows whose array is null or empty.
                df = df.withColumn(
                    col_name, F.explode_outer(F.col(_quoted(col_name)))
                )
                flat_cols.append(F.col(_quoted(col_name)))
            else:
                flat_cols.append(F.col(_quoted(col_name)))

        df = df.select(flat_cols)

        # Drop again: this pass may have produced a column named in col_to_drop.
        if col_to_drop:
            df = df.drop(*col_to_drop)

    return df


In [0]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextvars
from pyspark.sql.functions import *

# Manually maintained inputs for the shredding step below.
# Each shred point is a (dotted column path, type label) pair.
shred_points = [
    ("bank.branches", "array"),
]
# Flattened column names passed to flatten_df for removal.
col_to_drop = [
    "branches_employees",
]

def process_shred_point(i):
    """
    Build the raw and flattened DataFrames for a single shred point.

    The dotted path in ``i[0]`` is selected from the notebook-level ``df``
    together with an ``id`` column for every ancestor prefix of that path
    (for ``"a.b.c"`` these are ``a.id`` and ``a.b.id``, aliased to ``a_id`` and
    ``a_b_id``). The selection is then flattened with ``flatten_df`` using the
    notebook-level ``col_to_drop``.

    Args:
        i (tuple): Shred point whose first item is the dotted column path.
            Any further items are not read by this function.

    Returns:
        tuple: ``(selected_df, flattened_df)`` — the selection before and
        after flattening.
    """
    path_parts = i[0].split(".")
    # Every proper prefix of the path contributes a "<prefix>.id" column.
    ancestor_id_paths = [
        ".".join(path_parts[:depth]) + ".id"
        for depth in range(1, len(path_parts))
    ]
    id_columns = [
        col(id_path).alias(id_path.replace(".", "_"))
        for id_path in ancestor_id_paths
    ]
    selected = df.select(col(i[0]), *id_columns)
    return selected, flatten_df(selected, col_to_drop)

# Flatten every shred point on a small thread pool, then display and persist
# each result as soon as its worker finishes (completion order, not input order).
results = []
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = []
    for i in shred_points:
        # Run the task inside a copy of the current context so context
        # variables set in the notebook thread are visible to the worker.
        # NOTE(review): the original author flagged this copy as possibly
        # unnecessary — confirm before removing it.
        ctx = contextvars.copy_context()
        futures.append(executor.submit(ctx.run, process_shred_point, i))
    for future in as_completed(futures):
        # result() re-raises any exception thrown inside the worker.
        df_local, flattened_df_local = future.result()
        display(df_local)
        display(flattened_df_local)
        # NOTE(review): mode("append") adds the rows again on every run of this
        # cell — confirm that duplicates in the target table are acceptable.
        flattened_df_local.write.mode("append").option("mergeSchema","true").saveAsTable("default.flattened_df_local")
        results.append(flattened_df_local)

# Show the last completed result once more. Reading it from `results` avoids a
# NameError on the loop variable when shred_points is empty.
if results:
    display(results[-1])