In [24]:
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, FloatType
from pyspark.sql.window import Window
from pyspark import keyword_only
from pyspark.ml import Transformer, Estimator, Model
from pyspark.ml.evaluation import Evaluator
from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, Params, Param, TypeConverters, HasLabelCol, HasPredictionCol, HasFeaturesCol, HasThreshold
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable, JavaMLReadable, JavaMLWritable
from pyspark.sql.types import *
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor, _GBTRegressorParams, GBTRegressionModel
from pyspark.ml.tuning import ParamGridBuilder

In [16]:
spark = SparkSession \
    .builder \
    .appName("geotab") \
    .getOrCreate()

In [17]:
file_location = "../../data/raw/train.csv"
file_type = "csv"

# CSV options
infer_schema = "true"
first_row_is_header = "true"
delimiter = ","

# The applied options are for CSV files. For other file types, these will be ignored.
spark_df = spark.read.format(file_type) \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .load(file_location)


In [18]:
spark_df.columns

['RowId',
 'IntersectionId',
 'Latitude',
 'Longitude',
 'EntryStreetName',
 'ExitStreetName',
 'EntryHeading',
 'ExitHeading',
 'Hour',
 'Weekend',
 'Month',
 'Path',
 'TotalTimeStopped_p20',
 'TotalTimeStopped_p40',
 'TotalTimeStopped_p50',
 'TotalTimeStopped_p60',
 'TotalTimeStopped_p80',
 'TimeFromFirstStop_p20',
 'TimeFromFirstStop_p40',
 'TimeFromFirstStop_p50',
 'TimeFromFirstStop_p60',
 'TimeFromFirstStop_p80',
 'DistanceToFirstStop_p20',
 'DistanceToFirstStop_p40',
 'DistanceToFirstStop_p50',
 'DistanceToFirstStop_p60',
 'DistanceToFirstStop_p80',
 'City']

In [19]:
print((spark_df.count(), len(spark_df.columns)))

(856387, 28)


In [20]:
spark_df.select('IntersectionId').distinct().show()

+--------------+
|IntersectionId|
+--------------+
|           471|
|           496|
|           148|
|           463|
|          1238|
|           833|
|          1088|
|          1342|
|          1580|
|          1591|
|          1645|
|          1829|
|          1959|
|          2122|
|          2142|
|          2366|
|          2659|
|          2866|
|           392|
|           243|
+--------------+
only showing top 20 rows



In [21]:
spark_df.select([F.count(F.when(F.isnan(c), c)).alias(c) for c in spark_df.columns]).show()

+-----+--------------+--------+---------+---------------+--------------+------------+-----------+----+-------+-----+----+--------------------+--------------------+--------------------+--------------------+--------------------+---------------------+---------------------+---------------------+---------------------+---------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+----+
|RowId|IntersectionId|Latitude|Longitude|EntryStreetName|ExitStreetName|EntryHeading|ExitHeading|Hour|Weekend|Month|Path|TotalTimeStopped_p20|TotalTimeStopped_p40|TotalTimeStopped_p50|TotalTimeStopped_p60|TotalTimeStopped_p80|TimeFromFirstStop_p20|TimeFromFirstStop_p40|TimeFromFirstStop_p50|TimeFromFirstStop_p60|TimeFromFirstStop_p80|DistanceToFirstStop_p20|DistanceToFirstStop_p40|DistanceToFirstStop_p50|DistanceToFirstStop_p60|DistanceToFirstStop_p80|City|
+-----+--------------+--------+---------+---------------+--------------+----

In [22]:
spark_df.dtypes

[('RowId', 'int'),
 ('IntersectionId', 'int'),
 ('Latitude', 'double'),
 ('Longitude', 'double'),
 ('EntryStreetName', 'string'),
 ('ExitStreetName', 'string'),
 ('EntryHeading', 'string'),
 ('ExitHeading', 'string'),
 ('Hour', 'int'),
 ('Weekend', 'int'),
 ('Month', 'int'),
 ('Path', 'string'),
 ('TotalTimeStopped_p20', 'double'),
 ('TotalTimeStopped_p40', 'double'),
 ('TotalTimeStopped_p50', 'double'),
 ('TotalTimeStopped_p60', 'double'),
 ('TotalTimeStopped_p80', 'double'),
 ('TimeFromFirstStop_p20', 'double'),
 ('TimeFromFirstStop_p40', 'double'),
 ('TimeFromFirstStop_p50', 'double'),
 ('TimeFromFirstStop_p60', 'double'),
 ('TimeFromFirstStop_p80', 'double'),
 ('DistanceToFirstStop_p20', 'double'),
 ('DistanceToFirstStop_p40', 'double'),
 ('DistanceToFirstStop_p50', 'double'),
 ('DistanceToFirstStop_p60', 'double'),
 ('DistanceToFirstStop_p80', 'double'),
 ('City', 'string')]

In [23]:
spark_df.select([(F.count(F.when(F.col(c).isNull(), c))/spark_df.count()).alias(c) for c in spark_df.columns]).collect()

[Row(RowId=0.0, IntersectionId=0.0, Latitude=0.0, Longitude=0.0, EntryStreetName=0.00951439010634211, ExitStreetName=0.007341307142681989, EntryHeading=0.0, ExitHeading=0.0, Hour=0.0, Weekend=0.0, Month=0.0, Path=0.0, TotalTimeStopped_p20=0.0, TotalTimeStopped_p40=0.0, TotalTimeStopped_p50=0.0, TotalTimeStopped_p60=0.0, TotalTimeStopped_p80=0.0, TimeFromFirstStop_p20=0.0, TimeFromFirstStop_p40=0.0, TimeFromFirstStop_p50=0.0, TimeFromFirstStop_p60=0.0, TimeFromFirstStop_p80=0.0, DistanceToFirstStop_p20=0.0, DistanceToFirstStop_p40=0.0, DistanceToFirstStop_p50=0.0, DistanceToFirstStop_p60=0.0, DistanceToFirstStop_p80=0.0, City=0.0)]

In [None]:
class NullThresholdRemover(Transformer, HasThreshold, HasInputCols, DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCols=None, threshold=0.3) -> None:
        super().__init__()
        self._setDefault(inputCols=inputCols, Colsthreshold=threshold)
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
    
    @keyword_only
    def setParams(self, inputCols=None, threshold=0.3):
        kwargs = self._input_kwargs
        self._set(**kwargs)
    
    def _transform(self, dataset):
        threshold = self.getThreshold()
        cols = dataset.columns
        datasetRowCount = dataset.count()
        inputCols = list(set(self.g))
        
        colsNullCount = dataset.select([(F.count(F.when(F.col(c).isNull(), c))/datasetRowCount).alias(c) for c in dataset.columns]).collect()

        colsNullCount = [row.asDict() for row in aggregated_row]
        