# Preprocess Data

---

In this notebook, we will edit the data to include only columns that we want. 

There are two types of variables that we want to keep:
- Those which are practical for the user to input when looking up a price prediction
- Those which would actually contribute to the price of a flight

---

## Data Dictionary Reminder

Below is a reminder of the variables we have, each followed by the planned treatment after `//`:

- legId: An identifier for the flight. // Not needed
- searchDate: The date (YYYY-MM-DD) on which this entry was taken from Expedia. // Change search date and flight date into one variable - days before flight
- flightDate: The date (YYYY-MM-DD) of the flight. // Potentially separate this into just the day. Month is meaningless as the data is not even for a full year
- startingAirport: Three-character IATA airport code for the initial location. // One hot encoding
- destinationAirport: Three-character IATA airport code for the arrival location. // One hot encoding
- fareBasisCode: The fare basis code. // Not required
- travelDuration: The travel duration in hours and minutes. // Not required
- elapsedDays: The number of elapsed days (usually 0). // Not required
- isBasicEconomy: Boolean for whether the ticket is for basic economy. // Change into binary
- isRefundable: Boolean for whether the ticket is refundable. // Change into binary 
- isNonStop: Boolean for whether the flight is non-stop. // Change into binary 
- baseFare: The price of the ticket (in USD). // Not required
- totalFare: The price of the ticket (in USD) including taxes and other fees. // No change needed
- seatsRemaining: Integer for the number of seats remaining. // Not required
- totalTravelDistance: The total travel distance in miles. This data is sometimes missing. // Not required
- segmentsDepartureTimeEpochSeconds: String containing the departure time (Unix time) for each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsDepartureTimeRaw: String containing the departure time (ISO 8601 format: YYYY-MM-DDThh:mm:ss.000±[hh]:00) for each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsArrivalTimeEpochSeconds: String containing the arrival time (Unix time) for each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsArrivalTimeRaw: String containing the arrival time (ISO 8601 format: YYYY-MM-DDThh:mm:ss.000±[hh]:00) for each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsArrivalAirportCode: String containing the IATA airport code for the arrival location for each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsDepartureAirportCode: String containing the IATA airport code for the departure location for each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsAirlineName: String containing the name of the airline that services each leg of the trip. The entries for each of the legs are separated by '||'. // Most trips use the same airlines the whole way, drop trips that don't do this and change trips that do into one hot encoding. We can also use this to count how many layovers there are.
- segmentsAirlineCode: String containing the two-letter airline code that services each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsEquipmentDescription: String containing the type of airplane used for each leg of the trip (e.g. "Airbus A321" or "Boeing 737-800"). The entries for each of the legs are separated by '||'. // Not required
- segmentsDurationInSeconds: String containing the duration of the flight (in seconds) for each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsDistance: String containing the distance traveled (in miles) for each leg of the trip. The entries for each of the legs are separated by '||'. // Not required
- segmentsCabinCode: String containing the cabin for each leg of the trip (e.g. "coach"). The entries for each of the legs are separated by '||'. // Not required

## Load Spark and Data

In [145]:
from itertools import groupby

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, datediff, dayofmonth, when, udf, split
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml import Pipeline

from pyspark.sql.types import IntegerType, StringType

In [78]:
spark = SparkSession.builder.appName("flights").getOrCreate()

# REPLACE WITH DATA FILEPATH
DATA_PATH = "../data/itineraries.csv"

df = spark.read.csv(DATA_PATH, header=True, inferSchema=True)

## Drop the Unnecessary Columns

In [79]:
cols_to_drop = ["legId", "fareBasisCode", "travelDuration", "elapsedDays", "baseFare", "seatsRemaining", "totalTravelDistance",
                "segmentsDepartureTimeEpochSeconds", "segmentsDepartureTimeRaw", "segmentsArrivalTimeEpochSeconds", "segmentsArrivalTimeRaw",
                "segmentsArrivalAirportCode", "segmentsDepartureAirportCode", "segmentsAirlineCode", "segmentsEquipmentDescription",
                "segmentsDurationInSeconds", "segmentsDistance", "segmentsCabinCode"]

In [80]:
# drop cols
df = df.drop(*cols_to_drop)
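
`DataFrame.drop` ignores names that are not columns, so a typo in `cols_to_drop` would pass unnoticed. A minimal sketch of a sanity check in plain Python, assuming the CSV header matches the data dictionary above; `ALL_COLUMNS` and `check_drop_list` are hypothetical names, not part of the notebook.

```python
# Columns as listed in the data dictionary (assumed to match the CSV header).
ALL_COLUMNS = [
    "legId", "searchDate", "flightDate", "startingAirport", "destinationAirport",
    "fareBasisCode", "travelDuration", "elapsedDays", "isBasicEconomy",
    "isRefundable", "isNonStop", "baseFare", "totalFare", "seatsRemaining",
    "totalTravelDistance", "segmentsDepartureTimeEpochSeconds",
    "segmentsDepartureTimeRaw", "segmentsArrivalTimeEpochSeconds",
    "segmentsArrivalTimeRaw", "segmentsArrivalAirportCode",
    "segmentsDepartureAirportCode", "segmentsAirlineName", "segmentsAirlineCode",
    "segmentsEquipmentDescription", "segmentsDurationInSeconds",
    "segmentsDistance", "segmentsCabinCode",
]

cols_to_drop = [
    "legId", "fareBasisCode", "travelDuration", "elapsedDays", "baseFare",
    "seatsRemaining", "totalTravelDistance", "segmentsDepartureTimeEpochSeconds",
    "segmentsDepartureTimeRaw", "segmentsArrivalTimeEpochSeconds",
    "segmentsArrivalTimeRaw", "segmentsArrivalAirportCode",
    "segmentsDepartureAirportCode", "segmentsAirlineCode",
    "segmentsEquipmentDescription", "segmentsDurationInSeconds",
    "segmentsDistance", "segmentsCabinCode",
]


def check_drop_list(all_columns, to_drop):
    """Return (names in to_drop that are not columns, columns left after the drop)."""
    known = set(all_columns)
    unknown = [name for name in to_drop if name not in known]
    dropped = set(to_drop)
    kept = [name for name in all_columns if name not in dropped]
    return unknown, kept


unknown, kept = check_drop_list(ALL_COLUMNS, cols_to_drop)
print("Unknown names:", unknown)
print("Remaining columns:", kept)
```

In the notebook itself, `df.columns` (read before the drop) could stand in for `ALL_COLUMNS`.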

In [81]:
df.show()

+----------+----------+---------------+------------------+--------------+------------+---------+---------+--------------------+
|searchDate|flightDate|startingAirport|destinationAirport|isBasicEconomy|isRefundable|isNonStop|totalFare| segmentsAirlineName|
+----------+----------+---------------+------------------+--------------+------------+---------+---------+--------------------+
|2022-04-16|2022-04-17|            ATL|               BOS|         false|       false|     true|    248.6|               Delta|
|2022-04-16|2022-04-17|            ATL|               BOS|         false|       false|     true|    248.6|               Delta|
|2022-04-16|2022-04-17|            ATL|               BOS|         false|       false|     true|    248.6|               Delta|
|2022-04-16|2022-04-17|            ATL|               BOS|         false|       false|     true|    248.6|               Delta|
|2022-04-16|2022-04-17|            ATL|               BOS|         false|       false|     true|    248.

## Days Before Flight Column

In [82]:
df = df.withColumn("days_before_flight", datediff(col("flightDate"), col("searchDate")))

## Day of Month Column

In [83]:
df = df.withColumn("day", dayofmonth(col("flightDate")))
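
The two derived columns can be checked by hand for a single row. The sketch below mirrors them in plain Python with the standard `datetime` module, assuming ISO `YYYY-MM-DD` date strings as described in the data dictionary; the function names are hypothetical. Note that `datediff` takes the end date first, so the result is flight date minus search date.

```python
from datetime import date


def days_before_flight(search_date, flight_date):
    """Whole days from the search date to the flight date (end minus start)."""
    return (date.fromisoformat(flight_date) - date.fromisoformat(search_date)).days


def day_of_month(flight_date):
    """Day-of-month component of an ISO date string."""
    return date.fromisoformat(flight_date).day


# Values taken from the first row of the df.show() output above.
print(days_before_flight("2022-04-16", "2022-04-17"))
print(day_of_month("2022-04-17"))
```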

## Drop: searchDate and flightDate

In [84]:
df = df.drop(*["searchDate", "flightDate"])

In [85]:
df.show()

+---------------+------------------+--------------+------------+---------+---------+--------------------+------------------+---+
|startingAirport|destinationAirport|isBasicEconomy|isRefundable|isNonStop|totalFare| segmentsAirlineName|days_before_flight|day|
+---------------+------------------+--------------+------------+---------+---------+--------------------+------------------+---+
|            ATL|               BOS|         false|       false|     true|    248.6|               Delta|                 1| 17|
|            ATL|               BOS|         false|       false|     true|    248.6|               Delta|                 1| 17|
|            ATL|               BOS|         false|       false|     true|    248.6|               Delta|                 1| 17|
|            ATL|               BOS|         false|       false|     true|    248.6|               Delta|                 1| 17|
|            ATL|               BOS|         false|       false|     true|    248.6|             

## One Hot Encode Starting and Destination Airport

In [86]:
# string indexers
string_indexer = StringIndexer(inputCol="startingAirport", outputCol="startingAirport_index")
string_indexer_2 = StringIndexer(inputCol="destinationAirport", outputCol="destinationAirport_index")

# encoders
encoder = OneHotEncoder(inputCol="startingAirport_index", outputCol="startingAirport_encoded")
encoder_2 = OneHotEncoder(inputCol="destinationAirport_index", outputCol="destinationAirport_encoded")

# pipeline
pipeline = Pipeline(stages=[string_indexer, string_indexer_2, encoder, encoder_2])
pipeline_model = pipeline.fit(df)
df = pipeline_model.transform(df)
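
The encoded columns are sparse vectors, which `df.show()` prints as `(size, [non-zero positions], [values])`. The sketch below imitates the two stages in plain Python, assuming the default settings: `StringIndexer` ordering labels by descending frequency and `OneHotEncoder` dropping the last category. Under those defaults, a vector size of 15 would correspond to 16 distinct labels. The function names and the sample list are made up for illustration.

```python
from collections import Counter


def fit_string_index(values):
    """Map each label to an index, most frequent first (ties broken alphabetically)."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: index for index, label in enumerate(ordered)}


def one_hot_sparse(index, n_categories, drop_last=True):
    """Return (size, indices, values), the sparse form that df.show() prints."""
    size = n_categories - 1 if drop_last else n_categories
    if index >= size:
        # The dropped last category is encoded as an all-zero vector.
        return (size, [], [])
    return (size, [index], [1.0])


# Made-up sample column, not taken from the dataset.
airports = ["ATL", "ATL", "ATL", "BOS", "BOS", "DEN"]
mapping = fit_string_index(airports)
encoded = {label: one_hot_sparse(i, len(mapping)) for label, i in mapping.items()}
print(mapping)
print(encoded)
```

Because the last category becomes the all-zero vector, the encoded columns stay linearly independent, which matters for linear models.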

## Drop: startingAirport, destinationAirport, startingAirport_index, destinationAirport_index

In [87]:
df = df.drop(*["startingAirport", "destinationAirport", "startingAirport_index", "destinationAirport_index"])

## Convert Boolean Columns into Binary

In [88]:
bool_columns = ["isBasicEconomy", "isRefundable", "isNonStop"]

In [89]:
for col_name in bool_columns:
    df = df.withColumn(col_name, when(col(col_name), 1).otherwise(0))
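
One detail of this conversion is how missing flags are treated: `when(...).otherwise(0)` sends a null flag to 0, whereas casting the boolean column to an integer would leave it null. The plain-Python sketch below mirrors both behaviours, with `None` standing in for null; the function names are hypothetical, and whether the flags are ever missing in this dataset has not been checked here.

```python
def when_otherwise(flag):
    """Mirror when(col, 1).otherwise(0): a null (None) condition falls through to 0."""
    return 1 if flag else 0


def cast_to_int(flag):
    """Mirror col.cast("int"): true -> 1, false -> 0, null stays null."""
    return None if flag is None else int(flag)


for flag in (True, False, None):
    print(flag, when_otherwise(flag), cast_to_int(flag))
```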

In [92]:
df.show()

+--------------+------------+---------+---------+--------------------+------------------+---+-----------------------+--------------------------+
|isBasicEconomy|isRefundable|isNonStop|totalFare| segmentsAirlineName|days_before_flight|day|startingAirport_encoded|destinationAirport_encoded|
+--------------+------------+---------+---------+--------------------+------------------+---+-----------------------+--------------------------+
|             0|           0|        1|    248.6|               Delta|                 1| 17|         (15,[1],[1.0])|            (15,[2],[1.0])|
|             0|           0|        1|    248.6|               Delta|                 1| 17|         (15,[1],[1.0])|            (15,[2],[1.0])|
|             0|           0|        1|    248.6|               Delta|                 1| 17|         (15,[1],[1.0])|            (15,[2],[1.0])|
|             0|           0|        1|    248.6|               Delta|                 1| 17|         (15,[1],[1.0])|            (

## Create Number of Layovers and Airline Name
## TODO: FIX THIS

In [99]:
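def all_equal(iterable):
    # True when every element is the same (or the iterable is empty)
    g = groupby(iterable)
    return next(g, True) and not next(g, False)

# get airline name; itineraries served by more than one airline are grouped together
def get_airline(segmentsAirlineName):
    if segmentsAirlineName is None:
        return None
    splits = segmentsAirlineName.split("||")
    if all_equal(splits):
        return splits[0]
    else:
        return "Multi-Airline"

# get number of segments (legs); the number of layovers is one less than this
def get_nsplits(segmentsAirlineName):
    if segmentsAirlineName is None:
        return None
    splits = segmentsAirlineName.split("||")
    return len(splits)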
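
The parsing logic can be exercised in plain Python before it is wrapped in a UDF, which separates logic errors from Spark environment errors. The sketch below combines both helpers into one hypothetical function, `parse_airlines`, and counts layovers as segments minus one; the sample strings are made up and only assume the `||` separator described in the data dictionary.

```python
def parse_airlines(segments_airline_name, separator="||"):
    """Return (airline label, number of layovers) for one itinerary string."""
    if segments_airline_name is None:
        # Keep missing values missing rather than raising inside a UDF.
        return None, None
    airlines = segments_airline_name.split(separator)
    label = airlines[0] if len(set(airlines)) == 1 else "Multi-Airline"
    return label, len(airlines) - 1


# Made-up examples: non-stop, one layover on one airline, one layover on two airlines.
for raw in ["Delta", "United||United", "Delta||American Airlines", None]:
    print(repr(raw), "->", parse_airlines(raw))
```

The same logic can also be expressed with Spark's built-in column functions `split`, `size`, and `array_distinct`, which run without a Python worker. `split` takes a regular expression, so the separator would need to be escaped as `\|\|`.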

In [106]:
# register as udf 
airline_udf = udf(get_airline)
nsplit_udf = udf(get_nsplits, IntegerType())

In [None]:
df.withColumn("age_plus_5", col("segmentsAirlineName") + 5)

In [141]:
test_df = df.select("segmentsAirlineName").limit(3)

In [142]:
test_df.withColumn("airline_name", airline_udf(col("segmentsAirlineName"))).show()

# Not working yet: the UDF calls fail with the Python worker error shown below
#df.withColumn("airline_name", airline_udf(df["segmentsAirlineName"])).show()
#df.withColumn("nStops", nsplit_udf(df["segmentsAirlineName"])).show()

24/03/19 16:24:20 ERROR Executor: Exception in task 0.0 in stage 95.0 (TID 81)1]
org.apache.spark.SparkException: 
Bad data in pyspark.daemon's standard output. Invalid port number:
  1231975525 (0x496e7465)
Python command to execute the daemon was:
  python3 -m pyspark.daemon
Check that you don't have any unexpected modules or libraries in
your PYTHONPATH:
  /Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip:/Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/python/lib/py4j-0.10.9.7-src.zip:/Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/jars/spark-core_2.12-3.5.1.jar
Also, check if you have a sitecustomize.py module in your python path,
or in your python installation, that is printing to standard output
	at org.apache.spark.api.python.PythonWorkerFactory.startDaemon(PythonWorkerFactory.scala:267)
	at org.apache.spark.api.python.PythonWorkerFactory.createThroughDaemon(PythonWorkerFa

Py4JJavaError: An error occurred while calling o1608.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 95.0 failed 1 times, most recent failure: Lost task 0.0 in stage 95.0 (TID 81) (192.168.110.41 executor driver): org.apache.spark.SparkException: 
Bad data in pyspark.daemon's standard output. Invalid port number:
  1231975525 (0x496e7465)
Python command to execute the daemon was:
  python3 -m pyspark.daemon
Check that you don't have any unexpected modules or libraries in
your PYTHONPATH:
  /Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip:/Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/python/lib/py4j-0.10.9.7-src.zip:/Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/jars/spark-core_2.12-3.5.1.jar
Also, check if you have a sitecustomize.py module in your python path,
or in your python installation, that is printing to standard output
	at org.apache.spark.api.python.PythonWorkerFactory.startDaemon(PythonWorkerFactory.scala:267)
	at org.apache.spark.api.python.PythonWorkerFactory.createThroughDaemon(PythonWorkerFactory.scala:139)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:107)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:54)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:131)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:858)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:750)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:530)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:483)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:61)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4332)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3314)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4322)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4320)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4320)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:3314)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:3537)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:280)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:315)
	at sun.reflect.GeneratedMethodAccessor100.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:750)
Caused by: org.apache.spark.SparkException: 
Bad data in pyspark.daemon's standard output. Invalid port number:
  1231975525 (0x496e7465)
Python command to execute the daemon was:
  python3 -m pyspark.daemon
Check that you don't have any unexpected modules or libraries in
your PYTHONPATH:
  /Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip:/Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/python/lib/py4j-0.10.9.7-src.zip:/Users/eddie/opt/anaconda3/envs/bigdata/lib/python3.12/site-packages/pyspark/jars/spark-core_2.12-3.5.1.jar
Also, check if you have a sitecustomize.py module in your python path,
or in your python installation, that is printing to standard output
	at org.apache.spark.api.python.PythonWorkerFactory.startDaemon(PythonWorkerFactory.scala:267)
	at org.apache.spark.api.python.PythonWorkerFactory.createThroughDaemon(PythonWorkerFactory.scala:139)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:107)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:54)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:131)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:858)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more
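
The failure happens while Spark starts the Python worker, before `get_airline` is ever called. The trace indicates that the JVM read the daemon's port number from the first bytes of the worker's standard output and got `0x496e7465`. Decoding those four bytes as text is a quick way to see what was printed there instead. This is a minimal sketch, assuming the value is a big-endian 32-bit integer; `decode_port_bytes` is a hypothetical helper.

```python
def decode_port_bytes(value):
    """Interpret a 32-bit integer as four big-endian ASCII bytes."""
    return value.to_bytes(4, byteorder="big").decode("ascii", errors="replace")


# The "invalid port number" reported in the trace above.
print(decode_port_bytes(0x496E7465))
```

The bytes spell `Inte`, which is consistent with the hint in the error message that something in the Python environment writes text to standard output at startup. The trace alone does not show which module printed it.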


## Save the Final Data

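In [None]:
# Where the file should be saved
TARGET_PATH = "../data/itineraries_processed.csv"

In [152]:
# Note: Spark's CSV writer does not support vector columns such as startingAirport_encoded,
# so this call fails while they are present; Parquet (df.write.parquet) can store them.
df.write.csv(TARGET_PATH, header=True)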

+--------------+------------+---------+---------+--------------------+------------------+---+-----------------------+--------------------------+
|isBasicEconomy|isRefundable|isNonStop|totalFare| segmentsAirlineName|days_before_flight|day|startingAirport_encoded|destinationAirport_encoded|
+--------------+------------+---------+---------+--------------------+------------------+---+-----------------------+--------------------------+
|             0|           0|        1|    248.6|               Delta|                 1| 17|         (15,[1],[1.0])|            (15,[2],[1.0])|
|             0|           0|        1|    248.6|               Delta|                 1| 17|         (15,[1],[1.0])|            (15,[2],[1.0])|
|             0|           0|        1|    248.6|               Delta|                 1| 17|         (15,[1],[1.0])|            (15,[2],[1.0])|
|             0|           0|        1|    248.6|               Delta|                 1| 17|         (15,[1],[1.0])|            (