<a href="https://colab.research.google.com/github/jemelike/spark_snippets/blob/main/apply_function_to_nested_column.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyspark numpy pandas

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyspark
  Downloading pyspark-3.3.2.tar.gz (281.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m281.4/281.4 MB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting py4j==0.10.9.5
  Downloading py4j-0.10.9.5-py2.py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.7/199.7 KB[0m [31m725.2 kB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.3.2-py2.py3-none-any.whl size=281824028 sha256=d84f343d55e3a3d361cd662c083eb4ee39cf9b169f7825b7263a3e5966aec2e8
  Stored in directory: /root/.cache/pip/wheels/6c/e3/9b/0525ce8a69478916513509d43693511463c6468db0de237c86
Successfully built pyspark
Installing collected packages: py4j, pyspark
  A

In [None]:
import numpy as np
from datetime import datetime,timezone
import pandas as pd
import pyspark
import pyspark.sql.functions as F 
import pyspark.sql.types as T
from pyspark.sql import SparkSession

import re

In [None]:
# Spark session shared by every cell below (master "local[1]").
spark = (
    SparkSession.builder
    .master("local[1]")
    .appName('SparkByExamples.com')
    .getOrCreate()
)

In [None]:
# One sample row: an id, a flat string, a list of strings, a list holding a
# tuple that itself contains a tuple, a flat tuple, and a tuple nested in a tuple.
data = [
    (
        1,
        'value_1',
        ['value_2', 'value_3', 'value_4', 'value_5'],
        [('value_6', ('value_6', 'value_7', 'value_8', 'value_9'), 'value_8', 'value_9')],
        ('value_6', 'value_7', 'value_8', 'value_9'),
        ('value_6', ('value_6', 'value_7', 'value_8', 'value_9')),
    ),
]

# The current UTC time rendered in four different textual layouts.
t = datetime.now(timezone.utc)
now, now2, now3, now4 = (
    t.strftime(layout)
    for layout in ('%Y%m%d', '%Y/%m/%d', '%Y.%m.%d', '%Y-%m-%dT%H:%M:%S')
)
data2 = [(now, now2, now3, now4)]

# (rows, column names) pairs.
nested_example = (
    data,
    [
        'ID',
        'EntirelyFlatTest',
        'ArrayOfNonIterable',
        'ArrayOfIterable',
        'SingleLeveLStruct',
        'HeavilyNestedExample1',
    ],
)

datetime_normalization = (
    data2,
    ["date_1", "date_2", "date_3", "date_4"],
)


# Build the demo DataFrame from the (rows, column names) pair.
df = spark.createDataFrame(data=nested_example[0], schema=nested_example[1])


In [None]:
def flatten_df_struct(df:'pyspark.sql.DataFrame',column_to_flatten:str=None, levels:int=1) -> 'pyspark.sql.DataFrame':
    """Expand every top-level struct column of ``df`` by one level.

    Columns whose type string (from ``df.dtypes``) starts with ``struct`` are
    selected as ``<name>.*``; all other columns are selected unchanged. The
    non-struct columns come first, followed by the expanded struct columns,
    each group in its original relative order.

    Args:
        df: DataFrame to flatten.
        column_to_flatten: Accepted for interface compatibility; currently unused.
        levels: Accepted for interface compatibility; currently unused
            (exactly one level is expanded per call).

    Returns:
        The result of a single ``df.select`` over the reordered/expanded columns.
    """
    struct_pattern = re.compile(r'^(struct)')
    flat_columns   = []
    struct_columns = []
    # Classify each column once instead of matching the regex twice per column.
    for name, dtype in df.dtypes:
        if struct_pattern.search(dtype):
            struct_columns.append(f'{name}.*')
        else:
            flat_columns.append(name)
    # A single select; the earlier duplicate whose result was discarded is gone.
    return df.select(flat_columns+struct_columns)

In [None]:
def apply_function_to_column(df, column, column_type, func, func_target_type='string', id_column='ID', mod_column_suffix='MOD') :
  """Apply ``func`` inside a (possibly nested) column and attach the result.

  The behaviour depends on ``column_type`` (a type string as found in
  ``df.dtypes``):

  * starts with ``func_target_type``: returns ``id_column`` plus
    ``func(c)`` aliased ``<c>_<mod_column_suffix>`` for EVERY non-id column
    of ``df`` (not only ``column``); the untouched originals are not kept.
  * ``array...``: explodes ``column``, recurses on the exploded column,
    re-collects the results per ``id_column`` with ``collect_list`` and
    returns the original columns plus ``<column>_<mod_column_suffix>``.
  * ``struct...``: expands the struct's fields, applies ``func`` to fields of
    the target type, recurses into struct/array/map fields, packs everything
    into a struct named ``<column>_<mod_column_suffix>`` and returns the
    original columns plus that struct.
  * anything else: ``df`` (aliased ``'copy'``) is returned unchanged.

  Args:
    df: Input DataFrame; must contain ``id_column`` and ``column``.
    column: Name of the column to process.
    column_type: Type string of ``column`` as reported by ``df.dtypes``.
    func: Callable taking a column name and returning a Column
      (e.g. ``F.upper``).
    func_target_type: Type prefix ``func`` applies to; it is interpolated
      into a regular expression.
    id_column: Column used to join intermediate results back together.
    mod_column_suffix: Suffix appended to the names of produced columns.

  Returns:
    A DataFrame as described per branch above.

  NOTE(review): intermediate results are inner-joined on ``id_column``, so
  this presumably expects ``id_column`` to identify rows uniquely - confirm.
  NOTE(review): ``collect_list`` is used to rebuild arrays; verify that the
  element order it yields is acceptable for callers.
  """
  copy_df = df.alias('copy')
  if re.search(rf'^({func_target_type})',column_type):
    # Leaf case: func is applied to every non-id column of df. The loop
    # variable is named distinctly so it no longer shadows ``column``.
    columns = [ func(name).alias(f'{name}_{mod_column_suffix}') for name in df.columns if name != id_column]
    copy_df = copy_df.select(id_column,*columns)    
  elif re.search(r'^(array).+',column_type):    
    # One row per array element; explode_outer keeps ids with null/empty arrays.
    _df = copy_df.select(id_column, F.explode_outer(column).alias(column)).alias('array')
    # Look up the element type of the exploded column.
    c,t = [(col,typ) for col, typ in _df.dtypes if col == f'{column}' ].pop()
    # Forward mod_column_suffix so the recursion names its output exactly as
    # the collect_list below expects (previously a custom suffix was dropped).
    _df = apply_function_to_column(_df,c,t,func,func_target_type,id_column=id_column,mod_column_suffix=mod_column_suffix)  
    # Fold the per-element results back into one array per id.
    _df = _df.groupby(id_column).agg(F.collect_list(f'{column}_{mod_column_suffix}').alias(f'{column}_{mod_column_suffix}')).select('*').alias('array')
    copy_df = copy_df.join(_df, _df[id_column]==copy_df[id_column]).select('copy.*',f'{column}_{mod_column_suffix}')
  elif re.search(r'^(struct).+',column_type):
    # Work on the struct's fields as top-level columns.
    _df                    = copy_df.select(id_column,f'{column}.*')
    # Fields that are directly of the target type: apply func in place.
    flat_apply_columns     = [
        func(c).alias(f'{c}_{mod_column_suffix}') for c,t in _df.dtypes if re.search(rf'^({func_target_type})',t) and c != id_column
         ]
    
    # Fields that are themselves nested: recurse on (id, field) and keep the
    # resulting DataFrame; the suffix is forwarded for the same reason as above.
    nested_apply_columns   = [
       (c ,apply_function_to_column(_df.select(id_column,c),c,t,func,func_target_type,id_column=id_column,mod_column_suffix=mod_column_suffix)) for c,t in _df.dtypes if re.search(r'^(struct|array|map)',t) and c != id_column
    ]
    
  
    final                  = flat_apply_columns 
    for c, _df_ in nested_apply_columns:
      # Re-pack every non-id column of the recursive result into a single
      # struct named <field>_<suffix>.
      # NOTE(review): the recursive result still carries the original field
      # next to its modified version, so both end up in this struct - confirm
      # that this is intended.
      flatted_df_   = flatten_df_struct(_df_)  
      idless        = flatted_df_.drop(id_column)
      restructed_df = flatted_df_.select(id_column,F.struct(idless.columns).alias(f'{c}_{mod_column_suffix}')).alias('nested')
      _df    = _df.alias('main')\
                  .join( restructed_df,_df[id_column]==restructed_df[id_column] ) \
                  .select('main.*',f'nested.{c}_{mod_column_suffix}').alias(f'{c}_{mod_column_suffix}')
      final += [f'{c}_{mod_column_suffix}']
    # Keep only the produced columns and wrap them into the output struct.
    temp_df = _df.select(id_column,*final)
    temp_df = temp_df.select(id_column,F.struct(*temp_df.drop(id_column).columns).alias(f'{column}_{mod_column_suffix}'))
    copy_df = copy_df.alias('main').join(temp_df.alias('temp'),temp_df[id_column]==copy_df[id_column]).select("main.*",f'{column}_{mod_column_suffix}')
  return  copy_df

In [None]:
# Fetch the (name, type) entry of the nested array column; [-1] picks the
# last matching entry and raises IndexError when the column is absent.
c, t = [entry for entry in df.dtypes if entry[0] == 'ArrayOfIterable'][-1]
# Upper-case every string found inside that column and display the result.
d = apply_function_to_column(df, column=c, column_type=t, func=F.upper)
d.show()

+---+----------------+--------------------+--------------------+--------------------+---------------------+--------------------+
| ID|EntirelyFlatTest|  ArrayOfNonIterable|     ArrayOfIterable|   SingleLeveLStruct|HeavilyNestedExample1| ArrayOfIterable_MOD|
+---+----------------+--------------------+--------------------+--------------------+---------------------+--------------------+
|  1|         value_1|[value_2, value_3...|[{value_6, {value...|{value_6, value_7...| {value_6, {value_...|[{VALUE_6, VALUE_...|
+---+----------------+--------------------+--------------------+--------------------+---------------------+--------------------+

