### Mapping data

In [24]:
from pyspark.sql import SparkSession,Row
from termcolor import cprint

# Start (or reuse) the Spark session used by the cells below.
spark = (
    SparkSession.builder
    .appName('mapping')
    .getOrCreate()
)

In [25]:
from pyspark.sql.functions import lit, map_from_arrays, array, col

# Lookup table: level label -> numeric rank.
_dict = dict(High=1, Medium=2, Low=3)

# Single-column sample frame; each inner list is one row holding a level label.
df = spark.createDataFrame(
    [[label] for label in ("Medium", "Medium", "Medium", "High", "Medium",
                           "Medium", "Low", "Low", "High")],
    ["level"],
)
df.show()

+------+
| level|
+------+
|Medium|
|Medium|
|Medium|
|  High|
|Medium|
|Medium|
|   Low|
|   Low|
|  High|
+------+



In [26]:
# Display the label -> number lookup dictionary defined above.
print(_dict)

{'High': 1, 'Medium': 2, 'Low': 3}


In [27]:
# Show each dictionary key wrapped as a literal Column.
print([lit(label) for label in _dict.keys()])

[Column<''High''>, Column<''Medium''>, Column<''Low''>]


In [28]:
# Wrap the dictionary keys and values as literal columns and pair them into a
# map column (map_from_arrays: first array = keys, second array = values).
keys = array(list(map(lit, _dict.keys())))  # or alternatively [lit(k) for k in _dict.keys()]
values = array(list(map(lit, _dict.values())))
_map = map_from_arrays(keys, values)
# Index the map with [] instead of getItem(): passing a Column as the key to
# getItem() is deprecated since Spark 3.0 (getItem simply delegates to []).
print(_map[col("level")])

Column<'map_from_arrays(array('High', 'Medium', 'Low'), array(1, 2, 3))[level]'>




Use Python's built-in `map` together with `map_from_arrays` to build a key-based lookup that fills in the `level_num` column.

In [29]:
# Literal key/value arrays built from the lookup dictionary.
keys = array(list(map(lit, _dict.keys())))  # or alternatively [lit(k) for k in _dict.keys()]
values = array(list(map(lit, _dict.values())))
# map_from_arrays(col1, col2) creates a new map: col1 holds the keys, col2 the values.
_map = map_from_arrays(keys, values)

# Look up each row's `level` in the map to produce `level_num`.
# [] indexing is used instead of getItem(): a Column key passed to getItem()
# is deprecated since Spark 3.0. element_at(_map, col("level")) is an alternative.
df_1 = df.withColumn("level_num", _map[col("level")])

df_1.show()

+------+---------+
| level|level_num|
+------+---------+
|Medium|        2|
|Medium|        2|
|Medium|        2|
|  High|        1|
|Medium|        2|
|Medium|        2|
|   Low|        3|
|   Low|        3|
|  High|        1|
+------+---------+



In [30]:
# Rebind `df` to a one-row frame with two array columns, `k` and `v`.
df = spark.createDataFrame(
    [([2, 5], ['a', 'b'])],
    schema=['k', 'v'],
)
df.show()

+------+------+
|     k|     v|
+------+------+
|[2, 5]|[a, b]|
+------+------+



In [31]:
# map_from_arrays(col1, col2) builds a map: col1 supplies the keys, col2 the values.
df.select(
    map_from_arrays(df["k"], df["v"]).alias("map")
).show()

+----------------+
|             map|
+----------------+
|{2 -> a, 5 -> b}|
+----------------+



### Using Broadcast join

In [32]:
from pyspark.sql.functions import broadcast, col, explode
from pyspark.sql.types import IntegerType, MapType, StringType
from pyspark.sql.types import StructType, StructField

# Set up the lookup data: one row per country, each holding a single-entry
# map of country id -> country name.
map_df = spark.createDataFrame(
    [
        ({country_id: country},)
        for country_id, country in enumerate(
            ("Spain", "Germany", "Czech Republic", "Malta"), start=1
        )
    ],
    schema=StructType([StructField("map", MapType(IntegerType(), StringType()))]),
)

map_df.show(truncate=False)

+---------------------+
|map                  |
+---------------------+
|{1 -> Spain}         |
|{2 -> Germany}       |
|{3 -> Czech Republic}|
|{4 -> Malta}         |
+---------------------+



In [33]:
# Sales keyed by country id (ids 1-3 only; there is no sale row for id 4).
sale_df = spark.createDataFrame(
    [(1, 200), (2, 565), (3, 467)],
    schema=["country_id", "Sale"],
)
sale_df.show()

+----------+----+
|country_id|Sale|
+----------+----+
|         1| 200|
|         2| 565|
|         3| 467|
+----------+----+



In [34]:
# Left-join the sales onto the exploded lookup (one (country_id, country) row
# per map entry), marking the lookup side for broadcast.
sale_df.join(
    broadcast(map_df.select(explode(col("map")).alias("country_id", "country"))),
    "country_id",
    "left",
).select("country", "Sale").show()

+--------------+----+
|       country|Sale|
+--------------+----+
|         Spain| 200|
|       Germany| 565|
|Czech Republic| 467|
+--------------+----+



If, instead, you have your mapping as a single MapType expression, you can avoid the join entirely: the map lookup is evaluated directly inside the projection step of the execution plan.

In [35]:
from pyspark.sql.functions import array, map_from_arrays, lit

# Country id -> country name lookup, turned into a literal map column.
my_dict = {1: "Spain", 2: "Germany", 3: "Czech Republic", 4: "Malta"}
my_map = map_from_arrays(
    array(*[lit(country_id) for country_id in my_dict]),
    array(*[lit(country) for country in my_dict.values()]),
)

In [36]:
# Resolve each country_id through the literal map instead of joining.
# [] indexing replaces getItem(): a Column key passed to getItem() is
# deprecated since Spark 3.0.
sale_df.select(my_map[col("country_id")].alias("country"), "Sale").show()

+--------------+----+
|       country|Sale|
+--------------+----+
|         Spain| 200|
|       Germany| 565|
|Czech Republic| 467|
+--------------+----+



In [37]:
# Print the physical plan of the map-based lookup.
# [] indexing replaces getItem(): a Column key passed to getItem() is
# deprecated since Spark 3.0.
sale_df.select(my_map[col("country_id")].alias("country"), "Sale").explain()

== Physical Plan ==
*(1) Project [map(keys: [1,2,3,4], values: [Spain,Germany,Czech Republic,Malta])[cast(country_id#176L as int)] AS country#202, Sale#177L]
+- *(1) Scan ExistingRDD[country_id#176L,Sale#177L]




### Test mapping with data

In [38]:
from pyspark.sql import Row

# DataFrame used for the mapping: every listed company has source 'Global'.
data = [
    Row(company_name=name, country_of_source='Global')
    for name in ('AEW', 'Apollo', 'Ares', 'Carlyle', 'CBREI')
]
global_df = spark.createDataFrame(data)
global_df.show()

+------------+-----------------+
|company_name|country_of_source|
+------------+-----------------+
|         AEW|           Global|
|      Apollo|           Global|
|        Ares|           Global|
|     Carlyle|           Global|
|       CBREI|           Global|
+------------+-----------------+



In [39]:
# Input test data as (company_name, country_of_source, index) records;
# None marks a missing value.
in_data = [
    Row(company_name=name, country_of_source=country, index=idx)
    for name, country, idx in (
        ('Apollo', 'Canada', 3.80),
        ('JPMorgan', 'United States', 4.80),
        ('Miysubishi', 'Japan', 4.56),
        ('Ares', None, 4.37),
        ('Carlyle', 'Canada', None),
        ('Costco', None, 3.98),
    )
]
in_df = spark.createDataFrame(in_data)
in_df.show()

+------------+-----------------+-----+
|company_name|country_of_source|index|
+------------+-----------------+-----+
|      Apollo|           Canada|  3.8|
|    JPMorgan|    United States|  4.8|
|  Miysubishi|            Japan| 4.56|
|        Ares|             NULL| 4.37|
|     Carlyle|           Canada| NULL|
|      Costco|             NULL| 3.98|
+------------+-----------------+-----+



Build the map from a DataFrame

In [40]:
# Collect the mapping frame via pandas into {column name: list of values}.
dict_df = global_df.toPandas().to_dict(orient='list')

# Pair the two value lists positionally into a literal map column.
my_map = map_from_arrays(
    array(*[lit(name) for name in dict_df['company_name']]),
    array(*[lit(source) for source in dict_df['country_of_source']]),
)

Build the map from a dictionary

In [41]:
from pyspark.sql.functions import array, map_from_arrays, lit

# Company name -> country-of-source override, written directly as a dictionary.
dat_dict = {
    'AEW': 'Global',
    'Apollo': 'Global',
    'Ares': 'Global',
    'Carlyle': 'Global',
    'CBREI': 'Global',
}

# Literal map column built from the dictionary's keys and values.
my_map = map_from_arrays(
    array(*[lit(company) for company in dat_dict]),
    array(*[lit(source) for source in dat_dict.values()]),
)

In [42]:
# Look up every company in the map and show it next to the original columns;
# in the output shown, companies missing from the map come back as NULL.
# [] indexing replaces getItem(): a Column key passed to getItem() is
# deprecated since Spark 3.0.
in_df.select(my_map[col("company_name")].alias("country_of_source_x"), "company_name", 'index', 'country_of_source').show()



+-------------------+------------+-----+-----------------+
|country_of_source_x|company_name|index|country_of_source|
+-------------------+------------+-----+-----------------+
|             Global|      Apollo|  3.8|           Canada|
|               NULL|    JPMorgan|  4.8|    United States|
|               NULL|  Miysubishi| 4.56|            Japan|
|             Global|        Ares| 4.37|             NULL|
|             Global|     Carlyle| NULL|           Canada|
|               NULL|      Costco| 3.98|             NULL|
+-------------------+------------+-----+-----------------+



In [43]:
# Keep the looked-up value (country_of_source_x) alongside the original columns.
# [] indexing replaces getItem(): a Column key passed to getItem() is
# deprecated since Spark 3.0.
df_z = in_df.select(
    my_map[col("company_name")].alias("country_of_source_x"),
    "company_name",
    'index',
    'country_of_source',
)
df_z.show()

+-------------------+------------+-----+-----------------+
|country_of_source_x|company_name|index|country_of_source|
+-------------------+------------+-----+-----------------+
|             Global|      Apollo|  3.8|           Canada|
|               NULL|    JPMorgan|  4.8|    United States|
|               NULL|  Miysubishi| 4.56|            Japan|
|             Global|        Ares| 4.37|             NULL|
|             Global|     Carlyle| NULL|           Canada|
|               NULL|      Costco| 3.98|             NULL|
+-------------------+------------+-----+-----------------+



In [44]:
# Import only what this cell needs: a wildcard import of pyspark.sql.functions
# shadows Python built-ins such as sum, min, max, abs and round.
from pyspark.sql.functions import coalesce

# Prefer the mapped value; fall back to the row's own country_of_source when
# the lookup produced NULL.
df_w = df_z.withColumn('country', coalesce(df_z['country_of_source_x'], df_z['country_of_source']))
# Drop both source columns and give the merged column the original name.
df_v = df_w.drop("country_of_source_x", "country_of_source").withColumnRenamed('country', 'country_of_source')
df_v.show()

+------------+-----+-----------------+
|company_name|index|country_of_source|
+------------+-----+-----------------+
|      Apollo|  3.8|           Global|
|    JPMorgan|  4.8|    United States|
|  Miysubishi| 4.56|            Japan|
|        Ares| 4.37|           Global|
|     Carlyle| NULL|           Global|
|      Costco| 3.98|             NULL|
+------------+-----+-----------------+



Current implementation

In [45]:
# One-pass version: compute the merged value into a temporary column, drop the
# original column, then rename the temporary column back (this places
# country_of_source last in the result).
# [] indexing replaces getItem(): a Column key passed to getItem() is
# deprecated since Spark 3.0.
df_t2 = (
    in_df
    .withColumn(
        "country_of_source_updtd",
        coalesce(my_map[col("company_name")], in_df['country_of_source']),
    )
    .drop('country_of_source')
    .withColumnRenamed('country_of_source_updtd', 'country_of_source')
)
df_t2.show()


+------------+-----+-----------------+
|company_name|index|country_of_source|
+------------+-----+-----------------+
|      Apollo|  3.8|           Global|
|    JPMorgan|  4.8|    United States|
|  Miysubishi| 4.56|            Japan|
|        Ares| 4.37|           Global|
|     Carlyle| NULL|           Global|
|      Costco| 3.98|             NULL|
+------------+-----+-----------------+



In [46]:
# Overwrite country_of_source in place: mapped value first, original value as
# the fallback (the column keeps its original position).
# [] indexing replaces getItem(): a Column key passed to getItem() is
# deprecated since Spark 3.0.
df_t3 = in_df.withColumn(
    "country_of_source",
    coalesce(my_map[col("company_name")], in_df['country_of_source']),
)
df_t3.show()

+------------+-----------------+-----+
|company_name|country_of_source|index|
+------------+-----------------+-----+
|      Apollo|           Global|  3.8|
|    JPMorgan|    United States|  4.8|
|  Miysubishi|            Japan| 4.56|
|        Ares|           Global| 4.37|
|     Carlyle|           Global| NULL|
|      Costco|             NULL| 3.98|
+------------+-----------------+-----+

