In [1]:
import os
import numpy as np
import geohash2 as gh

from pyspark.sql import SparkSession
from pyspark.sql.types import FloatType, StringType
from pyspark.sql.functions import udf, avg, split, col
from opencage.geocoder import OpenCageGeocode
from pprint import pprint
from dotenv import load_dotenv

## Initialization and Definitions

In [55]:
# Load variables from a local .env file (if one exists) into the process
# environment so that API_KEY can be read below. load_dotenv is imported at the
# top of the notebook but was never invoked; by default it does not override
# variables that are already set.
load_dotenv()

# Init SparkSession (local mode, using all available cores)
spark = SparkSession.builder\
    .appName('workshop1')\
    .master("local[*]") \
    .config("spark.executor.memory", "8g") \
    .config("spark.driver.memory", "8g") \
    .getOrCreate()

# Init OpenCageGeocode with the key taken from the API_KEY environment variable
geocoder = OpenCageGeocode(os.environ.get('API_KEY'))

# Define a UDF function to find coordinates by country and city name
@udf(returnType=StringType())
def coordinates_udf(city, country):
    """Geocode "<city>, <country>" and return the first hit as a "lat,lng" string.

    Args:
        city: City name column value.
        country: Country name/code column value.

    Returns:
        A string "<lat>,<lng>" with both values formatted to three decimals,
        or None (a SQL null) when either input is null or the geocoder
        returns no results.

    Note:
        Exceptions raised by the geocoder client itself are not caught here
        and will fail the Spark task.
    """
    # A null input would otherwise be interpolated as the literal text "None"
    if city is None or country is None:
        return None

    results = geocoder.geocode(f"{city}, {country}")

    # Guard against an empty/None response instead of failing on results[0]
    if not results:
        return None

    geometry = results[0]['geometry']
    return "%.3f,%.3f" % (geometry['lat'], geometry['lng'])

# Define a UDF function to calculate a geohash by coordinates
@udf(returnType=StringType())
def geohash_udf(lat, lng):
    """Encode a latitude/longitude pair as a four-character geohash.

    Args:
        lat: Latitude as a number or numeric string.
        lng: Longitude as a number or numeric string.

    Returns:
        The geohash string at precision 4, or None (a SQL null) when either
        coordinate is null.
    """
    # float(None) raises TypeError; propagate nulls instead of failing the job
    if lat is None or lng is None:
        return None

    # Values may arrive as strings (e.g. from a CSV read without a schema)
    return gh.encode(float(lat), float(lng), precision=4)



In [56]:
# Evaluate the SparkSession created above so it is shown as this cell's output
spark

## Task 1: 
### Check restaurant data for incorrect (null) values (latitude and longitude). For incorrect values, map latitude and longitude from the OpenCage Geocoding API in a job via the REST API.

In [28]:
# Reading Restaurant Data from csv files
# (sorted, because os.listdir returns entries in arbitrary order)
csv_files = sorted(
    os.path.join("restaurant_csv", f)
    for f in os.listdir("restaurant_csv")
    if f.endswith('.csv'))

df_res = spark.read\
    .format("csv")\
    .option("header", "true")\
    .load(csv_files)

# Check Incorrect Data: rows where either coordinate is missing
df_inc = df_res.filter("lat is NULL or lng is NULL")

# Finding coordinates of Incorrect Data.
# The UDF calls the OpenCage REST API and Spark evaluates lazily, so without
# caching every later action (count, show, write) would geocode the same rows
# again. cache() lets those actions reuse the first result while it stays cached.
df_inc = df_inc.withColumn('coordinates', coordinates_udf('city', 'country')).cache()

# Updating the latitude and longitude values and deleting the coordinates column
df_cor = df_inc.withColumn("lat", split(col("coordinates"), ",").getItem(0))\
               .withColumn("lng", split(col("coordinates"), ",").getItem(1))\
               .drop("coordinates")

# Union Restaurant and Corrected Data.
# unionByName matches columns by name rather than by position, so the result
# stays correct even if the column order of the two frames ever diverges.
df_res = df_res.filter("lat is NOT NULL and lng is NOT NULL").unionByName(df_cor)

In [29]:
# Print the row count and a preview for each stage of the Task 1 clean-up
for label, frame in (
    ("Incorrect Data:  ", df_inc),
    ("Corrected Data:  ", df_cor),
    ("Restaurant Data: ", df_res),
):
    print(f"{label}{frame.count()} rows")
    frame.show()

# Spot-check the restaurant row with this id in the merged data
print("Checking: ID = 85899345920")
df_res.filter("id = 85899345920").show()

Incorrect Data:  1 rows
+-----------+------------+--------------+-----------------------+-------+------+----+----+--------------------+
|         id|franchise_id|franchise_name|restaurant_franchise_id|country|  city| lat| lng|         coordinates|
+-----------+------------+--------------+-----------------------+-------+------+----+----+--------------------+
|85899345920|           1|       Savoria|                  18952|     US|Dillon|null|null|34.4014089,-79.38...|
+-----------+------------+--------------+-----------------------+-------+------+----+----+--------------------+

Corrected Data:  1 rows
+-----------+------------+--------------+-----------------------+-------+------+----------+-----------+
|         id|franchise_id|franchise_name|restaurant_franchise_id|country|  city|       lat|        lng|
+-----------+------------+--------------+-----------------------+-------+------+----------+-----------+
|85899345920|           1|       Savoria|                  18952|     US|Dillon

## Task 2:
### Generate a geohash by latitude and longitude using a geohash library like geohash-java. Your geohash should be four characters long and placed in an extra column.

In [30]:
# Add a 'geohash' column computed from each restaurant's lat/lng
df_res = df_res.withColumn('geohash', geohash_udf(col('lat'), col('lng')))

# Persist the restaurant data as parquet, replacing any previous output
df_res.write.mode('overwrite').parquet('restaurant.parquet')

# Reload the restaurant data from the parquet output just written
df_res = spark.read.parquet("restaurant.parquet")

In [32]:
# Row count followed by a preview of the restaurant data with its geohash column
print("Restaurant Data with GeoHash: {} rows".format(df_res.count()))
df_res.show()

Restaurant Data with GeoHash: 1997 rows
+------------+------------+--------------------+-----------------------+-------+--------------+------+--------+-------+
|          id|franchise_id|      franchise_name|restaurant_franchise_id|country|          city|   lat|     lng|geohash|
+------------+------------+--------------------+-----------------------+-------+--------------+------+--------+-------+
|197568495625|          10|    The Golden Spoon|                  24784|     US|       Decatur|34.578| -87.021|   dn4h|
| 17179869242|          59|         Azalea Cafe|                  10902|     FR|         Paris|48.861|   2.368|   u09t|
|214748364826|          27|     The Corner Cafe|                  92040|     US|    Rapid City|44.080|-103.250|   9xyd|
|154618822706|          51|        The Pizzeria|                  41484|     AT|        Vienna|48.213|  16.413|   u2ed|
|163208757312|          65|       Chef's Corner|                  96638|     GB|        London|51.495|  -0.191|   gcpu|


## Task 3: 
### Left-join weather and restaurant data using the four-character geohash. Make sure to avoid data multiplication and keep your job idempotent

In [7]:
# Reading Weather Data from parquet file
df_wea = spark.read\
    .option("mergeSchema", "true")\
    .option("recursiveFileLookup", "true")\
    .parquet('weather')

# Generate a geohash by latitude and longitude for Weather Data
df_wea = df_wea.withColumn('geohash', geohash_udf('lat', 'lng'))

# Drop unnecessary columns
df_wea = df_wea.drop('lat', 'lng')

# Collapse Weather Data to exactly one row per (wthr_date, geohash).
# dropDuplicates keeps an arbitrary row of each group, so repeated runs could
# yield different temperatures; averaging the readings is deterministic and
# keeps the job idempotent. Averages are not rounded. The final select restores
# the column order the deduplicated frame had before.
df_wea = df_wea.groupBy('wthr_date', 'geohash')\
    .agg(avg('avg_tmpr_f').alias('avg_tmpr_f'),
         avg('avg_tmpr_c').alias('avg_tmpr_c'))\
    .select('avg_tmpr_f', 'avg_tmpr_c', 'wthr_date', 'geohash')

# Save Weather Data to parquet file
df_wea.write.parquet('weather.parquet', mode='overwrite')

# Read Weather Data from parquet file
df_wea = spark.read.parquet("weather.parquet")

# Join Restaurant and Weather Data.
# Weather has one row per (wthr_date, geohash), so each restaurant gets at most
# one weather row per date; restaurants without weather are kept by the left join.
res_data = df_res.join(df_wea, 'geohash', 'left_outer')

In [37]:
# Row count and preview of the weather data
print("Weather Data:  {} rows".format(df_wea.count()))
df_wea.show()

# Row count and preview of the joined restaurant/weather data
print("Result Data: {} rows".format(res_data.count()))
res_data.show()

Weather Data:  31882677 rows
+----------+----------+----------+-------+
|avg_tmpr_f|avg_tmpr_c| wthr_date|geohash|
+----------+----------+----------+-------+
|      83.6|      28.7|2017-08-29|   d75x|
|      75.8|      24.3|2017-08-29|   9et5|
|      81.6|      27.6|2017-08-29|   9gf0|
|      83.3|      28.5|2017-08-29|   d7t9|
|      84.4|      29.1|2017-08-29|   dk6b|
|      81.6|      27.6|2017-08-29|   9skp|
|      84.4|      29.1|2017-08-29|   9uf2|
|      83.6|      28.7|2017-08-29|   9ubc|
|      79.2|      26.2|2017-08-29|   9sy9|
|      82.3|      27.9|2017-08-29|   9ufs|
|      81.8|      27.7|2017-08-29|   9ufj|
|      82.8|      28.2|2017-08-29|   9v0c|
|      69.1|      20.6|2017-08-29|   9mjc|
|      92.1|      33.4|2017-08-29|   9mpq|
|      74.0|      23.3|2017-08-29|   9vrk|
|      81.7|      27.6|2017-08-29|   9mnq|
|      76.7|      24.8|2017-08-29|   9veq|
|      73.8|      23.2|2017-08-29|   9vtp|
|      70.0|      21.1|2017-08-29|   djb2|
|      73.9|      23.3|20

## Task 4:
### Store the enriched data (i.e., the joined data with all the fields from both datasets) in the local file system, preserving data partitioning in the parquet format.

In [38]:
# Persist the joined data as parquet, partitioned on disk by country and city;
# any previous output at this path is replaced
res_data.write\
    .mode('overwrite')\
    .partitionBy('country', 'city')\
    .parquet('result_data.parquet')

## Task 5:
### Implement tests

In [101]:
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

class TestUDFs(unittest.TestCase):
    """Tests for the notebook's coordinates_udf and geohash_udf.

    Both tests run the UDFs through a real SparkSession. test_coordinates_udf
    additionally calls the live geocoding service through coordinates_udf, so
    it needs network access and a valid API key, and its expected values
    depend on what the service currently returns.
    """

    @classmethod
    def setUpClass(cls):
        """Obtain a SparkSession and record whether this class created it."""
        # getOrCreate() returns an already running session when one exists
        # (e.g. the notebook's own `spark`). Remember whether a session was
        # active beforehand so tearDownClass does not stop one it does not own.
        # If getActiveSession is unavailable, fall back to always stopping.
        get_active = getattr(SparkSession, "getActiveSession", None)
        cls._owns_spark = get_active is None or get_active() is None

        cls.spark = SparkSession\
            .builder\
            .appName("test_udfs")\
            .config("spark.executor.memory", "4g") \
            .config("spark.driver.memory", "4g") \
            .getOrCreate()

    @classmethod
    def tearDownClass(cls):
        """Stop the SparkSession only if setUpClass created it."""
        if cls._owns_spark:
            cls.spark.stop()

    def _assert_udf_column(self, columns, data, expected_data, udf_column):
        """Overwrite the last column with udf_column and compare with expected_data.

        Args:
            columns: Column names shared by data and expected_data.
            data: Input rows; the last column is a placeholder.
            expected_data: Rows expected after the UDF has been applied.
            udf_column: Column expression produced by calling the UDF.
        """
        # Create dataframes
        df = self.spark.createDataFrame(data, columns)
        expected_df = self.spark.createDataFrame(expected_data, columns)

        # Apply UDF
        df = df.withColumn(columns[-1], udf_column)

        # Compare the full row lists: unlike a row-by-row loop over the actual
        # rows, this also fails when the result is empty or has missing rows.
        self.assertEqual(df.collect(), expected_df.collect())

    def test_coordinates_udf(self):
        """coordinates_udf returns "lat,lng" strings for known cities."""
        # Set up test data
        columns = ["country", "city", "coordinates"]
        data = [["USA", "New York",  ""],["France", "Paris",  ""],["Australia", "Sydney",  ""]]

        # Set up expected data
        expected_data = [
            ["USA", "New York",  "40.713,-74.006"],
            ["France", "Paris",  "48.853,2.348"],
            ["Australia", "Sydney",  "-33.870,151.208"]]

        # Arguments follow the UDF signature (city, country); they were
        # previously passed in the opposite order.
        self._assert_udf_column(
            columns, data, expected_data, coordinates_udf("city", "country"))

    def test_geohash_udf(self):
        """geohash_udf returns the expected four-character geohash."""
        # Set up test data
        columns = ["latitude", "longitude", "geohash", ]
        data = [[40.713, -74.006,  ""],[48.853, 2.348,  ""],[-33.870, 151.208,  ""]]

        # Set up expected data
        expected_data = [
            [40.713, -74.006, "dr5r"],
            [48.853, 2.348, "u09t"],
            [-33.870, 151.208, "r3gx"]]

        self._assert_udf_column(
            columns, data, expected_data, geohash_udf("latitude", "longitude"))

In [102]:
# Build a suite from the test methods of TestUDFs
test_suite = unittest.TestLoader().loadTestsFromTestCase(TestUDFs)

# Execute the suite with per-test output; the returned result object is the
# value of this cell's last expression
test_runner = unittest.TextTestRunner(verbosity=2)
test_runner.run(test_suite)


test_coordinates_udf (__main__.TestUDFs) ... ok
test_geohash_udf (__main__.TestUDFs) ... ok

----------------------------------------------------------------------
Ran 2 tests in 121.888s

OK


<unittest.runner.TextTestResult run=2 errors=0 failures=0>