In [103]:
import os
from pyspark.sql import SparkSession
from opencage.geocoder import OpenCageGeocode
from pprint import pprint
from pyspark.sql.functions import udf, avg
from pyspark.sql.types import FloatType, StringType
import geohash2 as gh
import numpy as np
# Load the environment variables from the .env file
from dotenv import load_dotenv
load_dotenv()

True

# Create a SparkSession

In [111]:
# Init SparkSession
spark = SparkSession.builder\
    .appName('workshop1')\
    .master("local[*]") \
    .config("spark.executor.memory", "8g") \
    .config("spark.driver.memory", "8g") \
    .getOrCreate()


In [112]:
spark

# Restaurant Data

In [113]:
# Get data from csv files
csv_dir = "restaurant_csv/restaurant_csv"
csv_files = [os.path.join(csv_dir, f) for f in os.listdir(csv_dir) if f.endswith('.csv')]
df = spark.read.format("csv").option("header", "true").load(csv_files)
df.show()

+------------+------------+--------------------+-----------------------+-------+--------------+------+--------+
|          id|franchise_id|      franchise_name|restaurant_franchise_id|country|          city|   lat|     lng|
+------------+------------+--------------------+-----------------------+-------+--------------+------+--------+
|197568495625|          10|    The Golden Spoon|                  24784|     US|       Decatur|34.578| -87.021|
| 17179869242|          59|         Azalea Cafe|                  10902|     FR|         Paris|48.861|   2.368|
|214748364826|          27|     The Corner Cafe|                  92040|     US|    Rapid City|44.080|-103.250|
|154618822706|          51|        The Pizzeria|                  41484|     AT|        Vienna|48.213|  16.413|
|163208757312|          65|       Chef's Corner|                  96638|     GB|        London|51.495|  -0.191|
| 68719476763|          28|    The Spicy Pickle|                  77517|     US|      Grayling|44.657| -

In [107]:
# get api key from environment variable
api_key = os.environ.get('API_KEY')

# Define UDF fuynctions for latitude and longitude
@udf(returnType=StringType())
def latitude_udf(city, country, value):
    geocoder = OpenCageGeocode(api_key)
    if value is None:
        query = f"{city}, {country}"
        result = geocoder.geocode(query)
        value = result[0]['geometry']['lat']
    return value

@udf(returnType=StringType())
def longitude_udf(city, country, value):
    geocoder = OpenCageGeocode(api_key)
    if value is None:
        query = f"{city}, {country}"
        result = geocoder.geocode(query)
        value = result[0]['geometry']['lng']
    return value

# Apply the UDF to the lat and lng columns and updated values
df = df.withColumn('lat', latitude_udf('city', 'country', 'lat'))\
       .withColumn('lng', longitude_udf('city', 'country', 'lng'))

# Display the updated data
df.show()

+------------+------------+--------------------+-----------------------+-------+--------------+------+--------+
|          id|franchise_id|      franchise_name|restaurant_franchise_id|country|          city|   lat|     lng|
+------------+------------+--------------------+-----------------------+-------+--------------+------+--------+
|197568495625|          10|    The Golden Spoon|                  24784|     US|       Decatur|34.578| -87.021|
| 17179869242|          59|         Azalea Cafe|                  10902|     FR|         Paris|48.861|   2.368|
|214748364826|          27|     The Corner Cafe|                  92040|     US|    Rapid City|44.080|-103.250|
|154618822706|          51|        The Pizzeria|                  41484|     AT|        Vienna|48.213|  16.413|
|163208757312|          65|       Chef's Corner|                  96638|     GB|        London|51.495|  -0.191|
| 68719476763|          28|    The Spicy Pickle|                  77517|     US|      Grayling|44.657| -

In [84]:
@udf(returnType=StringType())
def geohash_udf(lat, lng):
    lat = float(lat)
    lng = float(lng)
    return gh.encode(lat, lng, precision=4)

df = df.withColumn('geohash', geohash_udf('lat', 'lng'))

df.show()

+------------+------------+--------------------+-----------------------+-------+--------------+------+--------+-------+
|          id|franchise_id|      franchise_name|restaurant_franchise_id|country|          city|   lat|     lng|geohash|
+------------+------------+--------------------+-----------------------+-------+--------------+------+--------+-------+
|197568495625|          10|    The Golden Spoon|                  24784|     US|       Decatur|34.578| -87.021|   dn4h|
| 17179869242|          59|         Azalea Cafe|                  10902|     FR|         Paris|48.861|   2.368|   u09t|
|214748364826|          27|     The Corner Cafe|                  92040|     US|    Rapid City|44.080|-103.250|   9xyd|
|154618822706|          51|        The Pizzeria|                  41484|     AT|        Vienna|48.213|  16.413|   u2ed|
|163208757312|          65|       Chef's Corner|                  96638|     GB|        London|51.495|  -0.191|   gcpu|
| 68719476763|          28|    The Spicy

In [7]:
# save in parquet
df.write.parquet('restaurant.parquet', mode='overwrite')

In [85]:
# read the cleaned restaurant dataframe from parquet
restaurant_df = spark.read.parquet("restaurant.parquet")

# Weather Data

In [86]:
#Read weather dataset
path = 'weather_dataset'
dfw = spark.read.option("mergeSchema", "true").option("recursiveFileLookup", "true").parquet(path)

In [87]:
#Geohash weather dataset lat and lng
dfw = dfw.withColumn('geohash', geohash_udf('lat', 'lng'))

In [88]:
dfw.show()

+--------+-------+----------+----------+----------+-------+
|     lng|    lat|avg_tmpr_f|avg_tmpr_c| wthr_date|geohash|
+--------+-------+----------+----------+----------+-------+
| -111.09|18.6251|      80.7|      27.1|2017-08-29|   9e31|
|-111.042|18.6305|      80.7|      27.1|2017-08-29|   9e31|
|-110.995|18.6358|      80.7|      27.1|2017-08-29|   9e34|
|-110.947|18.6412|      80.9|      27.2|2017-08-29|   9e34|
|  -110.9|18.6465|      80.9|      27.2|2017-08-29|   9e34|
|-110.852|18.6518|      80.9|      27.2|2017-08-29|   9e34|
|-110.804|18.6571|      80.9|      27.2|2017-08-29|   9e34|
|-105.068|19.1765|      82.4|      28.0|2017-08-29|   9emm|
| -105.02|19.1799|      82.0|      27.8|2017-08-29|   9emm|
|-104.972|19.1832|      82.0|      27.8|2017-08-29|   9emm|
|-104.924|19.1866|      82.0|      27.8|2017-08-29|   9emm|
|-104.876|19.1899|      82.0|      27.8|2017-08-29|   9emm|
|-104.828|19.1932|      81.6|      27.6|2017-08-29|   9emm|
| -104.78|19.1964|      81.6|      27.6|

In [89]:
# rename lat and lng columns to avoid ambiguity
dfw = dfw.withColumnRenamed('lat', 'lat_weather') \
    .withColumnRenamed('lng', 'lng_weather')

In [90]:
dfw.show()

+-----------+-----------+----------+----------+----------+-------+
|lng_weather|lat_weather|avg_tmpr_f|avg_tmpr_c| wthr_date|geohash|
+-----------+-----------+----------+----------+----------+-------+
|    -111.09|    18.6251|      80.7|      27.1|2017-08-29|   9e31|
|   -111.042|    18.6305|      80.7|      27.1|2017-08-29|   9e31|
|   -110.995|    18.6358|      80.7|      27.1|2017-08-29|   9e34|
|   -110.947|    18.6412|      80.9|      27.2|2017-08-29|   9e34|
|     -110.9|    18.6465|      80.9|      27.2|2017-08-29|   9e34|
|   -110.852|    18.6518|      80.9|      27.2|2017-08-29|   9e34|
|   -110.804|    18.6571|      80.9|      27.2|2017-08-29|   9e34|
|   -105.068|    19.1765|      82.4|      28.0|2017-08-29|   9emm|
|    -105.02|    19.1799|      82.0|      27.8|2017-08-29|   9emm|
|   -104.972|    19.1832|      82.0|      27.8|2017-08-29|   9emm|
|   -104.924|    19.1866|      82.0|      27.8|2017-08-29|   9emm|
|   -104.876|    19.1899|      82.0|      27.8|2017-08-29|   9

In [91]:
#drop rows where geohash and wthr_date together duplicate
dfw = dfw.dropDuplicates(['wthr_date', 'geohash'])

In [15]:
#save in parquet
dfw.write.parquet('weather_without_duplicates.parquet', mode='overwrite')

In [92]:
# read the cleaned weather dataframe from parquet
weather_df = spark.read.parquet("weather_without_duplicates.parquet")

In [17]:
#show where geohash and wthr_date together duplicate
weather_df.groupBy('wthr_date', 'geohash').count().where('count > 1').show()

+---------+-------+-----+
|wthr_date|geohash|count|
+---------+-------+-----+
+---------+-------+-----+



In [18]:
weather_df.show()

+-----------+-----------+----------+----------+----------+-------+
|lng_weather|lat_weather|avg_tmpr_f|avg_tmpr_c| wthr_date|geohash|
+-----------+-----------+----------+----------+----------+-------+
|   -101.238|    19.7128|      61.0|      16.1|2017-08-29|   9g80|
|   -71.6781|    18.4932|      78.7|      25.9|2017-08-29|   d7m1|
|   -104.629|    20.3924|      63.3|      17.4|2017-08-29|   9ets|
|   -88.2381|    20.6006|      82.5|      28.1|2017-08-29|   d59m|
|   -73.4394|    19.5544|      85.6|      29.8|2017-08-29|   d77z|
|   -81.2557|    25.1487|      84.3|      29.1|2017-08-29|   dhqp|
|   -101.569|    27.0982|      82.1|      27.8|2017-08-29|   9szf|
|    -110.47|    27.0939|      89.0|      31.7|2017-08-29|   9sc6|
|   -112.818|    27.2501|      89.6|      32.0|2017-08-29|   9kzg|
|    -103.66|    28.3193|      71.7|      22.1|2017-08-29|   9tn3|
|   -105.114|    28.3219|      73.8|      23.2|2017-08-29|   9tj3|
|   -102.302|    29.0287|      70.9|      21.6|2017-08-29|   9

In [93]:
#drop lng_weather and lat_weather
weather_df = weather_df.drop('lng_weather', 'lat_weather')

# Join df_geocoded and weather_df

In [94]:
# Left-join weather and restaurant_df data using the four-character geohash. Make sure to avoid data multiplication and keep your job idempotent
df_joined = restaurant_df.join(weather_df, 'geohash', 'left_outer')

In [95]:
df_joined.show()

+-------+------------+------------+--------------+-----------------------+-------+-------+------+-------+----------+----------+----------+
|geohash|          id|franchise_id|franchise_name|restaurant_franchise_id|country|   city|   lat|    lng|avg_tmpr_f|avg_tmpr_c| wthr_date|
+-------+------------+------------+--------------+-----------------------+-------+-------+------+-------+----------+----------+----------+
|   9yt8|120259084300|          13| The Firehouse|                  59829|     US|Branson|36.633|-93.272|      72.9|      22.7|2017-08-28|
|   9yt8|120259084300|          13| The Firehouse|                  59829|     US|Branson|36.633|-93.272|      78.5|      25.8|2017-08-19|
|   9yt8|120259084300|          13| The Firehouse|                  59829|     US|Branson|36.633|-93.272|      57.5|      14.2|2016-10-20|
|   9yt8|120259084300|          13| The Firehouse|                  59829|     US|Branson|36.633|-93.272|      72.6|      22.6|2016-10-19|
|   9yt8|120259084300|     

In [96]:
df_joined.count()

172591

In [97]:
#drop Nan values
df_joined = df_joined.dropna()

In [98]:
#drop rows where geohash and wthr_date together duplicate
df_joined = df_joined.dropDuplicates(['wthr_date', 'geohash'])

In [25]:
df_joined.count()

62897

In [26]:
df_joined.where(df_joined['geohash'] == 'dnpe').count()

92

In [27]:
#show where geohash equals dnpe
df_joined.where(df_joined['geohash'] == 'dnpe').show()

+-------+-----------+------------+---------------+-----------------------+-------+------+------+-------+----------+----------+----------+
|geohash|         id|franchise_id| franchise_name|restaurant_franchise_id|country|  city|   lat|    lng|avg_tmpr_f|avg_tmpr_c| wthr_date|
+-------+-----------+------------+---------------+-----------------------+-------+------+------+-------+----------+----------+----------+
|   dnpe|60129542152|           9|The Grill House|                  71555|     US|Dillon|34.436|-79.370|      71.7|      22.1|2016-10-01|
|   dnpe|60129542152|           9|The Grill House|                  71555|     US|Dillon|34.436|-79.370|      70.0|      21.1|2016-10-02|
|   dnpe|60129542152|           9|The Grill House|                  71555|     US|Dillon|34.436|-79.370|      70.2|      21.2|2016-10-03|
|   dnpe|60129542152|           9|The Grill House|                  71555|     US|Dillon|34.436|-79.370|      68.9|      20.5|2016-10-04|
|   dnpe|60129542152|           9|

In [28]:
#preserving data partitioning in the parquet format.
df_joined.write.partitionBy('city').parquet('final_partitioned.parquet', mode='overwrite')

# Testing


In [118]:
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

class TestUDFs(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.appName("test_udfs").getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def test_latitude_udf(self):
        df = self.spark.createDataFrame(
            [
                ("New York", "USA", None),
                ("Paris", "France", None),
                ("Sydney", "Australia", 34.0522),
            ],
            ["city", "country", "latitude"],
        )

        # Call the UDF with a null value
        df = df.withColumn(
            "latitude", latitude_udf(col("city"), col("country"), col("latitude"))
        )

        # Check that the null value has been filled
        self.assertEqual(df.filter(df.city == "New York").first().latitude, "40.7127281")

        # Call the UDF with an existing value
        df = df.withColumn(
            "latitude", latitude_udf(col("city"), col("country"), col("latitude"))
        )

        # Check that the existing value is not changed
        self.assertEqual(df.filter(df.city == "Sydney").first().latitude, "34.0522")

    def test_longitude_udf(self):
        df = self.spark.createDataFrame(
            [
                ("New York", "USA", None),
                ("Paris", "France", None),
                ("Sydney", "Australia", 151.2093),
            ],
            ["city", "country", "longitude"],
        )

        # Call the UDF with a null value
        df = df.withColumn(
            "longitude", longitude_udf(col("city"), col("country"), col("longitude"))
        )

        # Check that the null value has been filled
        self.assertEqual(df.filter(df.city == "Paris").first().longitude, "2.320041")

        # Call the UDF with an existing value
        df = df.withColumn(
            "longitude", longitude_udf(col("city"), col("country"), col("longitude"))
        )

        # Check that the existing value is not changed
        self.assertEqual(df.filter(df.city == "Sydney").first().longitude, "151.2093")

    def test_geohash_udf(self):
        df = self.spark.createDataFrame(
            [
                (40.7128, -74.0060),
                (48.8566, 2.3522),
                (-33.8650, 151.2093),
            ],
            ["latitude", "longitude"],
        )

        # Call the UDF to generate geohashes
        df = df.withColumn(
            "geohash", geohash_udf(col("latitude"), col("longitude"))
        )

        # Check that the geohashes were generated correctly
        self.assertEqual(df.filter(df.latitude == 40.7128).first().geohash, gh.encode(40.7128, -74.0060, precision=4))
        self.assertEqual(df.filter(df.latitude == 48.8566).first().geohash, gh.encode(48.8566, 2.3522, precision=4))
        self.assertEqual(df.filter(df.latitude == -33.8650).first().geohash, gh.encode(-33.8650, 151.2093, precision=4))


In [119]:
# Create an instance of the test class
test_suite = unittest.TestLoader().loadTestsFromTestCase(TestUDFs)

# Run the tests
unittest.TextTestRunner(verbosity=2).run(test_suite)


test_geohash_udf (__main__.TestUDFs) ... ok
test_latitude_udf (__main__.TestUDFs) ... ok
test_longitude_udf (__main__.TestUDFs) ... ok

----------------------------------------------------------------------
Ran 3 tests in 183.135s

OK


<unittest.runner.TextTestResult run=3 errors=0 failures=0>