In [None]:
import os
from typing import Tuple

import pyspark.sql.functions as F
from geopy.distance import geodesic
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.types import FloatType

# Spark session shared by every cell below; the 'local' master runs Spark
# in-process on this machine. getOrCreate() reuses an existing session on re-run.
spark = (
    SparkSession.builder
    .master('local')
    .getOrCreate()
)

In [None]:
def _read_csv(path: str) -> DataFrame:
    """Read a headed, comma-delimited CSV and strip whitespace from its column names."""
    frame = spark.read.options(
        header='True',
        inferSchema='True',
        delimiter=',',
    ).csv(os.path.expanduser(path))
    # Raw headers can carry stray spaces (e.g. " TimeSt", " Latitude");
    # normalising them keeps the renames below from silently doing nothing.
    return frame.toDF(*[name.strip() for name in frame.columns])


def read_data(data_sample_path: str, poi_path: str) -> Tuple[DataFrame, DataFrame]:
    """
    Read the request sample and the POI list from CSV files.

    Column names are stripped of surrounding whitespace, then the coordinate
    columns are renamed with a source prefix so the two tables can be
    cross-joined without name clashes.

    :param data_sample_path: Path to the request sample CSV (``~`` is expanded).
    :param poi_path: Path to the POI CSV (``~`` is expanded).
    :return: ``(df, poi)`` -- the request table (``Latitude``/``Longitude``
        renamed to ``DataLatitude``/``DataLongitude``) and the POI table
        (renamed to ``POILatitude``/``POILongitude``).
    """
    df = _read_csv(data_sample_path)
    poi = _read_csv(poi_path)

    df = (
        df.withColumnRenamed("Latitude", "DataLatitude")
        .withColumnRenamed("Longitude", "DataLongitude")
    )
    poi = (
        poi.withColumnRenamed("Latitude", "POILatitude")
        .withColumnRenamed("Longitude", "POILongitude")
    )

    return df, poi

In [None]:
def cleanup(df: DataFrame, key_cols=("TimeSt", "DataLatitude", "DataLongitude")) -> DataFrame:
    """
    Drop every request whose timestamp + coordinates combination is shared
    with another request.

    All copies of a repeated combination are removed (not just the extras),
    since identical geo-info and timestamp make each of them suspicious.

    :param df: Request table to clean.
    :param key_cols: Columns whose combined values must occur exactly once for
        a row to be kept. Defaults to the timestamp and request coordinates.
    :return: Rows of ``df`` whose key combination is unique; columns unchanged.
    :raises ValueError: If ``key_cols`` is empty.
    """
    if not key_cols:
        raise ValueError("key_cols must name at least one column")

    # Unusual temp name so it cannot collide with a real column (a
    # pre-existing "cnt" column would make the filter ambiguous).
    count_col = "__dup_count"
    partition = ", ".join(f"`{col}`" for col in key_cols)

    return (
        df.selectExpr("*", f"count(*) over (partition by {partition}) as `{count_col}`")
        .filter(F.col(count_col) == 1)
        .drop(count_col)
    )

In [None]:
def check_duplicates(cleaned_df: DataFrame, labeled_df: DataFrame) -> None:
    """
    Warn when labeling changed the number of request rows.

    Labeling is expected to keep exactly one row per request. More rows means
    some request IDs were duplicated (matched to several POIs). Fewer rows
    means some requests were dropped during labeling.

    :param cleaned_df: Request table before labeling.
    :param labeled_df: Table produced by the labeling step.
    """
    n_before = cleaned_df.count()
    n_after = labeled_df.count()

    if n_after > n_before:
        print("WARNING: Duplicate found!")
    elif n_after < n_before:
        # Previously reported as a "duplicate", which was misleading.
        print(f"WARNING: {n_before - n_after} request row(s) lost during labeling!")

In [None]:
@F.udf(returnType="double")
def geodesic_udf(point_a, point_b):
    """
    Spark UDF returning the geodesic distance, in metres, between two points.

    geopy's ``geodesic`` expects each point as ``(latitude, longitude)`` in
    degrees. The result is returned as a double, because Python floats are
    doubles and a float32 column would silently lose precision.

    :param point_a: ``(latitude, longitude)`` of the first point.
    :param point_b: ``(latitude, longitude)`` of the second point.
    :return: Distance between the points in metres, or ``None`` when either
        point or any of its coordinates is missing.
    """
    # A null coordinate would make geopy raise and fail the whole Spark job;
    # propagate it as a null distance instead.
    if point_a is None or point_b is None or None in point_a or None in point_b:
        return None
    return geodesic(point_a, point_b).m

In [None]:
def label(df: DataFrame, poi: DataFrame) -> DataFrame:
    """
    Find the closest POI to each request.

    Every request is paired with every POI, and the geodesic distance is
    computed for each pair. The pair(s) at the minimum distance per request
    ``_ID`` are kept. If several POIs tie for the minimum, all tied pairs are
    kept, so an ``_ID`` can appear more than once; ``check_duplicates``
    reports a row-count change.

    :param df: Request records with ``_ID``, ``DataLatitude``, ``DataLongitude``.
    :param poi: POI records with ``POILatitude``, ``POILongitude``.
    :return: Request columns plus the closest POI's columns and a ``distance``
        column (metres), one row per (request, closest POI) pair.
    """
    pairs = df.crossJoin(poi)

    labeled_data = pairs.withColumn('distance', geodesic_udf(
        F.array("POILatitude", "POILongitude"),
        F.array("DataLatitude", "DataLongitude"),
    ))

    # Windowed minimum instead of groupBy + self-join: the plan no longer has
    # two branches that each need the (expensive) Python UDF output.
    per_request = Window.partitionBy('_ID')
    min_labeled_data = (
        labeled_data
        .withColumn('min_distance', F.min('distance').over(per_request))
        .filter(F.col('distance') == F.col('min_distance'))
        .drop('min_distance')
    )

    check_duplicates(df, min_labeled_data)

    return min_labeled_data

In [None]:
# Input CSV locations; `~` is expanded when the files are read.
data_sample_path, poi_path = (
    '~/data/DataSample.csv',
    '~/data/POIList.csv',
)

In [None]:
# Load the request sample and the POI list.
df, poi = read_data(data_sample_path=data_sample_path, poi_path=poi_path)

In [None]:
# Remove requests sharing timestamp and coordinates, then preview the result.
cleaned_df = cleanup(df)
cleaned_df.show(n=20)

In [None]:
# Attach the closest POI (and its distance) to every remaining request.
labeled_df = label(cleaned_df, poi)
labeled_df.show(n=20)

## WARNING: Duplicate found!

POI1 and POI2 have identical latitude and longitude. Any request whose closest POI is one of them is therefore exactly as close to the other. Both matches tie for the minimum distance and are kept by the join, so those request IDs appear twice.

The task does not specify how such ties should be resolved, so the labeled DataFrame is kept as is: both POI matches are retained for the affected IDs.

In total, there are `8748` duplicated IDs.

In [None]:
# Collect to pandas and list every row of each `_ID` that was matched to more
# than one POI. keep=False shows all tied matches, not only the repeats.
pandas_df = labeled_df.toPandas()
duplicated_rows = pandas_df[pandas_df['_ID'].duplicated(keep=False)].sort_values('_ID')
print(f"Duplicated IDs: {duplicated_rows['_ID'].nunique()}")
duplicated_rows