# Data preaparation and xgboost regressor training

In [1]:
!pip install --upgrade xgboost



In [2]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 3070 (UUID: GPU-5f5a84a5-7cf6-72b1-17eb-40f450c3509a)


# Spark set up

In [3]:
import os
import platform
local_os = platform.system()

if local_os == 'Linux':
    os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
    os.environ["SPARK_HOME"] = "/content/spark-3.3.1-bin-hadoop3"
elif local_os == 'Windows':
    os.environ["JAVA_HOME"] = "C:/Program Files/Java/jdk-19/"
else:
    os.environ["JAVA_HOME"] = "/opt/homebrew/opt/openjdk/"

In [4]:
if local_os == 'Linux':
    !apt-get install openjdk-8-jdk-headless -qq > /dev/null
    !wget https://dlcdn.apache.org/spark/spark-3.3.1/spark-3.3.1-bin-hadoop3.tgz
    !tar zxvf /content/spark-3.3.1-bin-hadoop3.tgz
    !pip install -q findspark
    import findspark
    findspark.init()

In [5]:
# Mount google drive
if local_os == 'Linux':
    from google.colab import drive
    drive.mount('/content/drive')

In [6]:
from pyspark.sql import SparkSession

In [7]:
spark = SparkSession.builder\
        .master("local")\
        .appName("flights")\
        .config("spark.executor.cores", 4) \
        .config("spark.executor.memory", "40g") \
        .config("spark.driver.memory", "40g") \
        .config("spark.driver.maxResultSize", "20g") \
        .getOrCreate()

In [8]:
spark

# Libs imports

In [9]:
from pyspark.sql.functions import desc, isnan, when, count, col, isnull
from collections import Counter

# Data loading

In [None]:
from pathlib import Path

source = "../data/"
print(source)
source_path = Path(source).glob('*.parquet')
file_names = sorted(list(source_path))
file_names

In [11]:
def merge_data(file_names):
  first_file = file_names.pop(0)
  data = spark.read.parquet(first_file.as_posix())
  for file_name in file_names:
    temp_data = spark.read.parquet(file_name.as_posix())
    data = data.union(temp_data)
    print(file_name.as_posix())
  return data
data = merge_data(file_names)

D:/data_sets/airlines/Combined_Flights_2019.parquet
D:/data_sets/airlines/Combined_Flights_2020.parquet
D:/data_sets/airlines/Combined_Flights_2021.parquet
D:/data_sets/airlines/Combined_Flights_2022.parquet


In [12]:
def count_dtypes(dataframe):
    col_dict = dict(dataframe.dtypes).values()
    c = Counter(col_dict)
    return c.most_common()
count_dtypes(data)

[('bigint', 23),
 ('double', 19),
 ('string', 17),
 ('boolean', 2),
 ('timestamp', 1)]

In [13]:
data.printSchema()

root
 |-- FlightDate: timestamp (nullable = true)
 |-- Airline: string (nullable = true)
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- Cancelled: boolean (nullable = true)
 |-- Diverted: boolean (nullable = true)
 |-- CRSDepTime: long (nullable = true)
 |-- DepTime: double (nullable = true)
 |-- DepDelayMinutes: double (nullable = true)
 |-- DepDelay: double (nullable = true)
 |-- ArrTime: double (nullable = true)
 |-- ArrDelayMinutes: double (nullable = true)
 |-- AirTime: double (nullable = true)
 |-- CRSElapsedTime: double (nullable = true)
 |-- ActualElapsedTime: double (nullable = true)
 |-- Distance: double (nullable = true)
 |-- Year: long (nullable = true)
 |-- Quarter: long (nullable = true)
 |-- Month: long (nullable = true)
 |-- DayofMonth: long (nullable = true)
 |-- DayOfWeek: long (nullable = true)
 |-- Marketing_Airline_Network: string (nullable = true)
 |-- Operated_or_Branded_Code_Share_Partners: string (nullable = true)
 |-- DOT_ID_Mar

# EDA

In [14]:
data.count(), len(data.columns)

(29193782, 62)

In [15]:
data.show(5)

+-------------------+-----------------+------+----+---------+--------+----------+-------+---------------+--------+-------+---------------+-------+--------------+-----------------+--------+----+-------+-----+----------+---------+-------------------------+---------------------------------------+------------------------+---------------------------+-------------------------------+-----------------+------------------------+---------------------------+-----------+-------------------------------+---------------+------------------+------------------+--------------+-----------+---------------+---------------+---------+-------------+----------------+----------------+------------+---------+-------------+-------------+-------+--------+--------------------+----------+-------+---------+--------+------+----------+--------+--------+------------------+----------+-------------+------------------+-----------------+
|         FlightDate|          Airline|Origin|Dest|Cancelled|Diverted|CRSDepTime|DepTime|D

## Data types exploration
Handling categorical values to train model

In [16]:
data = data.drop(data.__index_level_0__)

In [17]:
dtypes = set()
[dtypes.add(item[1]) for item in data.dtypes]
dtypes

{'bigint', 'boolean', 'double', 'string', 'timestamp'}

### Boolean exploration

In [18]:
[item[0] for item in data.dtypes if item[1] == 'boolean']

['Cancelled', 'Diverted']

In [19]:
data.groupBy('Cancelled').count().orderBy('count').collect()

[Row(Cancelled=True, count=777267), Row(Cancelled=False, count=28416515)]

In [20]:
data.groupBy('Diverted').count().orderBy('count').collect()

[Row(Diverted=True, count=68349), Row(Diverted=False, count=29125433)]

### String exploration

In [21]:
str_columns = [item[0] for item in data.dtypes if item[1] == 'string']

In [22]:
data.select(str_columns).show(5)

+-----------------+------+----+-------------------------+---------------------------------------+---------------------------+-----------------+---------------------------+-----------+--------------+-----------+---------------+------------+---------+-------------+----------+----------+
|          Airline|Origin|Dest|Marketing_Airline_Network|Operated_or_Branded_Code_Share_Partners|IATA_Code_Marketing_Airline|Operating_Airline|IATA_Code_Operating_Airline|Tail_Number|OriginCityName|OriginState|OriginStateName|DestCityName|DestState|DestStateName|DepTimeBlk|ArrTimeBlk|
+-----------------+------+----+-------------------------+---------------------------------------+---------------------------+-----------------+---------------------------+-----------+--------------+-----------+---------------+------------+---------+-------------+----------+----------+
|Endeavor Air Inc.|   ABY| ATL|                       DL|                           DL_CODESHARE|                         DL|               9E

### Timestamp exploration

In [23]:
timestamp_columns = [item[0] for item in data.dtypes if item[1] == 'timestamp']

In [24]:
data.select(timestamp_columns).show(5)

+-------------------+
|         FlightDate|
+-------------------+
|2018-01-22 18:00:00|
|2018-01-23 18:00:00|
|2018-01-24 18:00:00|
|2018-01-25 18:00:00|
|2018-01-26 18:00:00|
+-------------------+
only showing top 5 rows



## Null values handling

In [25]:
null_count = data.select([count(when(col(c).isNull(), c)).alias(c) for c in data.columns])
null_count.show()

+----------+-------+------+----+---------+--------+----------+-------+---------------+--------+-------+---------------+-------+--------------+-----------------+--------+----+-------+-----+----------+---------+-------------------------+---------------------------------------+------------------------+---------------------------+-------------------------------+-----------------+------------------------+---------------------------+-----------+-------------------------------+---------------+------------------+------------------+--------------+-----------+---------------+---------------+---------+-------------+----------------+----------------+------------+---------+-------------+-------------+-------+--------+--------------------+----------+-------+---------+--------+------+----------+--------+--------+------------------+----------+-------------+------------------+
|FlightDate|Airline|Origin|Dest|Cancelled|Diverted|CRSDepTime|DepTime|DepDelayMinutes|DepDelay|ArrTime|ArrDelayMinutes|AirTime|

In [26]:
null_values = null_count.collect()[0].asDict()

In [27]:
only_nulls = null_count.select([key for key in null_values if null_values[key] != 0])
only_nulls.show()

+-------+---------------+--------+-------+---------------+-------+--------------+-----------------+-----------+--------+--------------------+-------+---------+--------+------+--------+--------+------------------+------------------+
|DepTime|DepDelayMinutes|DepDelay|ArrTime|ArrDelayMinutes|AirTime|CRSElapsedTime|ActualElapsedTime|Tail_Number|DepDel15|DepartureDelayGroups|TaxiOut|WheelsOff|WheelsOn|TaxiIn|ArrDelay|ArrDel15|ArrivalDelayGroups|DivAirportLandings|
+-------+---------------+--------+-------+---------------+-------+--------------+-----------------+-----------+--------+--------------------+-------+---------+--------+------+--------+--------+------------------+------------------+
| 761652|         763084|  763084| 786177|         846183| 852561|            22|           845637|     267613|  763084|              763084| 780561|   780551|  793133|793143|  846183|  846183|            846183|                90|
+-------+---------------+--------+-------+---------------+-------+------

## Drop all null values

In [28]:
data_no_na = data.dropna()

# Feature selection

In [29]:
#Drop timestamp data
data_no_na = data_no_na.drop('FlightDate')

In [30]:
data_no_na.count(), len(data_no_na.columns)

(28339510, 60)

In [31]:
cols_to_encode = ['Airline', 'Origin', 'Dest']
[str_columns.remove(x) for x in cols_to_encode]
None

In [32]:
data_no_na = data_no_na.drop(*str_columns)
data_no_na.count(), len(data_no_na.columns)

(28339510, 46)

In [33]:
data_no_na.printSchema()

root
 |-- Airline: string (nullable = true)
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- Cancelled: boolean (nullable = true)
 |-- Diverted: boolean (nullable = true)
 |-- CRSDepTime: long (nullable = true)
 |-- DepTime: double (nullable = true)
 |-- DepDelayMinutes: double (nullable = true)
 |-- DepDelay: double (nullable = true)
 |-- ArrTime: double (nullable = true)
 |-- ArrDelayMinutes: double (nullable = true)
 |-- AirTime: double (nullable = true)
 |-- CRSElapsedTime: double (nullable = true)
 |-- ActualElapsedTime: double (nullable = true)
 |-- Distance: double (nullable = true)
 |-- Year: long (nullable = true)
 |-- Quarter: long (nullable = true)
 |-- Month: long (nullable = true)
 |-- DayofMonth: long (nullable = true)
 |-- DayOfWeek: long (nullable = true)
 |-- DOT_ID_Marketing_Airline: long (nullable = true)
 |-- Flight_Number_Marketing_Airline: long (nullable = true)
 |-- DOT_ID_Operating_Airline: long (nullable = true)
 |-- Flight_Number_

In [34]:
dep_cols = [x for x in data_no_na.columns if "Dep" in x]
arr_cols = [x for x in data_no_na.columns if "Arr" in x]

In [35]:
time_columns = dep_cols + arr_cols
time_columns

['CRSDepTime',
 'DepTime',
 'DepDelayMinutes',
 'DepDelay',
 'DepDel15',
 'DepartureDelayGroups',
 'ArrTime',
 'ArrDelayMinutes',
 'CRSArrTime',
 'ArrDelay',
 'ArrDel15',
 'ArrivalDelayGroups']

In [36]:
time_columns.remove('DepDelay')
# time_columns.remove('ArrTime')
time_columns

['CRSDepTime',
 'DepTime',
 'DepDelayMinutes',
 'DepDel15',
 'DepartureDelayGroups',
 'ArrTime',
 'ArrDelayMinutes',
 'CRSArrTime',
 'ArrDelay',
 'ArrDel15',
 'ArrivalDelayGroups']

In [37]:
data_no_na.select('DepDelay').describe().show()

+-------+-----------------+
|summary|         DepDelay|
+-------+-----------------+
|  count|         28339510|
|   mean| 9.23847367156313|
| stddev|47.10140749050439|
|    min|          -1280.0|
|    max|           7223.0|
+-------+-----------------+



In [38]:
data_no_na.select('ArrDelay').describe().show()

+-------+------------------+
|summary|          ArrDelay|
+-------+------------------+
|  count|          28339510|
|   mean|3.6081859213514984|
| stddev| 49.28063347282263|
|    min|           -1290.0|
|    max|            7232.0|
+-------+------------------+



In [39]:
data_no_na = data_no_na.drop(*time_columns)

In [40]:
data_no_na.printSchema()

root
 |-- Airline: string (nullable = true)
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- Cancelled: boolean (nullable = true)
 |-- Diverted: boolean (nullable = true)
 |-- DepDelay: double (nullable = true)
 |-- AirTime: double (nullable = true)
 |-- CRSElapsedTime: double (nullable = true)
 |-- ActualElapsedTime: double (nullable = true)
 |-- Distance: double (nullable = true)
 |-- Year: long (nullable = true)
 |-- Quarter: long (nullable = true)
 |-- Month: long (nullable = true)
 |-- DayofMonth: long (nullable = true)
 |-- DayOfWeek: long (nullable = true)
 |-- DOT_ID_Marketing_Airline: long (nullable = true)
 |-- Flight_Number_Marketing_Airline: long (nullable = true)
 |-- DOT_ID_Operating_Airline: long (nullable = true)
 |-- Flight_Number_Operating_Airline: long (nullable = true)
 |-- OriginAirportID: long (nullable = true)
 |-- OriginAirportSeqID: long (nullable = true)
 |-- OriginCityMarketID: long (nullable = true)
 |-- OriginStateFips: long (n

## One hot enconding

In [41]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder, StandardScaler, VectorAssembler
from pyspark.ml import Pipeline

### Vectorized implementation of one hot encoding

In [42]:
numerical_cols = [item[0] for item in data_no_na.dtypes if item[1] not in ['string', 'boolean']]

bigint represented as long by spark

In [43]:
count_dtypes(data_no_na)

[('bigint', 20), ('double', 10), ('string', 3), ('boolean', 2)]

In [44]:
indexed_out_names = [x + '_num' for x in cols_to_encode]
vector_out_names = [x + '_vec' for x in cols_to_encode]

In [45]:
idxer = StringIndexer(inputCols=cols_to_encode, outputCols=indexed_out_names)
enconder = OneHotEncoder(inputCols=idxer.getOutputCols(), outputCols=vector_out_names)

In [46]:
# TODO: Design Transformer to convert vectorized columns to one hot encoded columns

In [47]:
pipeline = Pipeline(stages=[idxer, enconder])

In [48]:
preprocess = pipeline.fit(data_no_na)

In [49]:
df_transformed = preprocess.transform(data_no_na)
df_transformed.first()

Row(Airline='Endeavor Air Inc.', Origin='ABY', Dest='ATL', Cancelled=False, Diverted=False, DepDelay=-5.0, AirTime=38.0, CRSElapsedTime=62.0, ActualElapsedTime=59.0, Distance=145.0, Year=2018, Quarter=1, Month=1, DayofMonth=23, DayOfWeek=2, DOT_ID_Marketing_Airline=19790, Flight_Number_Marketing_Airline=3298, DOT_ID_Operating_Airline=20363, Flight_Number_Operating_Airline=3298, OriginAirportID=10146, OriginAirportSeqID=1014602, OriginCityMarketID=30146, OriginStateFips=13, OriginWac=34, DestAirportID=10397, DestAirportSeqID=1039707, DestCityMarketID=30397, DestStateFips=13, DestWac=34, TaxiOut=14.0, WheelsOff=1211.0, WheelsOn=1249.0, TaxiIn=7.0, DistanceGroup=1, DivAirportLandings=0.0, Airline_num=8.0, Origin_num=271.0, Dest_num=0.0, Airline_vec=SparseVector(27, {8: 1.0}), Origin_vec=SparseVector(387, {271: 1.0}), Dest_vec=SparseVector(387, {0: 1.0}))

In [50]:
count_dtypes(df_transformed)

[('bigint', 20), ('double', 13), ('string', 3), ('vector', 3), ('boolean', 2)]

In [51]:
from pyspark.ml.functions import vector_to_array

In [52]:
df_transformed.first()

Row(Airline='Endeavor Air Inc.', Origin='ABY', Dest='ATL', Cancelled=False, Diverted=False, DepDelay=-5.0, AirTime=38.0, CRSElapsedTime=62.0, ActualElapsedTime=59.0, Distance=145.0, Year=2018, Quarter=1, Month=1, DayofMonth=23, DayOfWeek=2, DOT_ID_Marketing_Airline=19790, Flight_Number_Marketing_Airline=3298, DOT_ID_Operating_Airline=20363, Flight_Number_Operating_Airline=3298, OriginAirportID=10146, OriginAirportSeqID=1014602, OriginCityMarketID=30146, OriginStateFips=13, OriginWac=34, DestAirportID=10397, DestAirportSeqID=1039707, DestCityMarketID=30397, DestStateFips=13, DestWac=34, TaxiOut=14.0, WheelsOff=1211.0, WheelsOn=1249.0, TaxiIn=7.0, DistanceGroup=1, DivAirportLandings=0.0, Airline_num=8.0, Origin_num=271.0, Dest_num=0.0, Airline_vec=SparseVector(27, {8: 1.0}), Origin_vec=SparseVector(387, {271: 1.0}), Dest_vec=SparseVector(387, {0: 1.0}))

In [53]:
df_transformed.select('Airline_num').filter(col('Airline_num') == 0).show(1)

+-----------+
|Airline_num|
+-----------+
|        0.0|
+-----------+
only showing top 1 row



In [54]:
df_col_onehot = df_transformed.select('*', vector_to_array('Airline_vec').alias('airline_one_hot'), vector_to_array('Origin_vec').alias('origin_one_hot'), vector_to_array('Dest_vec').alias('dest_one_hot'))
df_col_onehot.first()

Row(Airline='Endeavor Air Inc.', Origin='ABY', Dest='ATL', Cancelled=False, Diverted=False, DepDelay=-5.0, AirTime=38.0, CRSElapsedTime=62.0, ActualElapsedTime=59.0, Distance=145.0, Year=2018, Quarter=1, Month=1, DayofMonth=23, DayOfWeek=2, DOT_ID_Marketing_Airline=19790, Flight_Number_Marketing_Airline=3298, DOT_ID_Operating_Airline=20363, Flight_Number_Operating_Airline=3298, OriginAirportID=10146, OriginAirportSeqID=1014602, OriginCityMarketID=30146, OriginStateFips=13, OriginWac=34, DestAirportID=10397, DestAirportSeqID=1039707, DestCityMarketID=30397, DestStateFips=13, DestWac=34, TaxiOut=14.0, WheelsOff=1211.0, WheelsOn=1249.0, TaxiIn=7.0, DistanceGroup=1, DivAirportLandings=0.0, Airline_num=8.0, Origin_num=271.0, Dest_num=0.0, Airline_vec=SparseVector(27, {8: 1.0}), Origin_vec=SparseVector(387, {271: 1.0}), Dest_vec=SparseVector(387, {0: 1.0}), airline_one_hot=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

In [55]:
# vecs_to_encode_lens = [len(df_col_onehot.first()["airline_one_hot"]), len(df_col_onehot.first()["origin_one_hot"]), len(df_col_onehot.first()["dest_one_hot"])]
vecs_to_encode= ["airline_one_hot", "origin_one_hot", "dest_one_hot"]
cols_expanded = [(col(col_name)[i].alias(f'{col_name.split("_")[0] + preprocess.stages[0].labelsArray[idx][i].replace(" ", "").replace(".", "")}')) for idx, col_name in enumerate(vecs_to_encode) for i in range(len(df_col_onehot.first()[col_name]))]
df_transformed =  df_col_onehot.select('*', *cols_expanded)
df_transformed = df_transformed.drop(*vecs_to_encode + vector_out_names + cols_to_encode + indexed_out_names)

In [56]:
df_transformed.printSchema()

root
 |-- Cancelled: boolean (nullable = true)
 |-- Diverted: boolean (nullable = true)
 |-- DepDelay: double (nullable = true)
 |-- AirTime: double (nullable = true)
 |-- CRSElapsedTime: double (nullable = true)
 |-- ActualElapsedTime: double (nullable = true)
 |-- Distance: double (nullable = true)
 |-- Year: long (nullable = true)
 |-- Quarter: long (nullable = true)
 |-- Month: long (nullable = true)
 |-- DayofMonth: long (nullable = true)
 |-- DayOfWeek: long (nullable = true)
 |-- DOT_ID_Marketing_Airline: long (nullable = true)
 |-- Flight_Number_Marketing_Airline: long (nullable = true)
 |-- DOT_ID_Operating_Airline: long (nullable = true)
 |-- Flight_Number_Operating_Airline: long (nullable = true)
 |-- OriginAirportID: long (nullable = true)
 |-- OriginAirportSeqID: long (nullable = true)
 |-- OriginCityMarketID: long (nullable = true)
 |-- OriginStateFips: long (nullable = true)
 |-- OriginWac: long (nullable = true)
 |-- DestAirportID: long (nullable = true)
 |-- DestAirpor

In [57]:
count_dtypes(df_transformed)

[('double', 811), ('bigint', 20), ('boolean', 2)]

## Generate feature vector from columns

In [58]:
all_columns = list(set(df_transformed.columns).difference(idxer.getOutputCols()))
all_columns

['originIMT',
 'destISN',
 'originOGS',
 'destGCC',
 'DayOfWeek',
 'originLNY',
 'originPLN',
 'originCIU',
 'originGTF',
 'destTUS',
 'destAZO',
 'destPIB',
 'destBOS',
 'destPQI',
 'OriginStateFips',
 'destMOT',
 'destBZN',
 'originBHM',
 'destFAY',
 'AirTime',
 'destOAK',
 'DestStateFips',
 'originGSO',
 'destLGB',
 'destSTC',
 'originBJI',
 'originPHL',
 'destLCH',
 'destILG',
 'originDTW',
 'destPNS',
 'originFSM',
 'destACY',
 'originSGF',
 'destPUB',
 'originIND',
 'originBIL',
 'originART',
 'destSBA',
 'destGRB',
 'originJFK',
 'OriginAirportSeqID',
 'destORF',
 'originPBI',
 'destCDB',
 'originRIC',
 'originGNV',
 'destKTN',
 'originMSY',
 'originGEG',
 'originBMI',
 'destGFK',
 'originSUX',
 'originOAJ',
 'destJAC',
 'airlineEnvoyAir',
 'originCOS',
 'destYKM',
 'destLRD',
 'destIND',
 'originPSE',
 'originAMA',
 'originSCE',
 'destART',
 'originGRB',
 'originEWR',
 'originIAD',
 'destBJI',
 'originJST',
 'originALW',
 'destGRI',
 'destGJT',
 'originFLL',
 'originANC',
 'ori

In [59]:
all_columns = list(set(df_transformed.columns).difference(idxer.getOutputCols()))
vectorizer = VectorAssembler(inputCols=df_transformed.columns, outputCol='features').transform(df_transformed)
vectorizer.printSchema()

root
 |-- Cancelled: boolean (nullable = true)
 |-- Diverted: boolean (nullable = true)
 |-- DepDelay: double (nullable = true)
 |-- AirTime: double (nullable = true)
 |-- CRSElapsedTime: double (nullable = true)
 |-- ActualElapsedTime: double (nullable = true)
 |-- Distance: double (nullable = true)
 |-- Year: long (nullable = true)
 |-- Quarter: long (nullable = true)
 |-- Month: long (nullable = true)
 |-- DayofMonth: long (nullable = true)
 |-- DayOfWeek: long (nullable = true)
 |-- DOT_ID_Marketing_Airline: long (nullable = true)
 |-- Flight_Number_Marketing_Airline: long (nullable = true)
 |-- DOT_ID_Operating_Airline: long (nullable = true)
 |-- Flight_Number_Operating_Airline: long (nullable = true)
 |-- OriginAirportID: long (nullable = true)
 |-- OriginAirportSeqID: long (nullable = true)
 |-- OriginCityMarketID: long (nullable = true)
 |-- OriginStateFips: long (nullable = true)
 |-- OriginWac: long (nullable = true)
 |-- DestAirportID: long (nullable = true)
 |-- DestAirpor

In [60]:
# all_columns.append('features')
vectorizer.select('features', 'DepDelay').show(1)

+--------------------+--------+
|            features|DepDelay|
+--------------------+--------+
|(833,[2,3,4,5,6,7...|    -5.0|
+--------------------+--------+
only showing top 1 row



## Data preparation

In [61]:
train = int(data.count() * .10)
train

2919378

In [62]:
data = vectorizer.select('features', 'DepDelay')
data.first()

Row(features=SparseVector(833, {2: -5.0, 3: 38.0, 4: 62.0, 5: 59.0, 6: 145.0, 7: 2018.0, 8: 1.0, 9: 1.0, 10: 23.0, 11: 2.0, 12: 19790.0, 13: 3298.0, 14: 20363.0, 15: 3298.0, 16: 10146.0, 17: 1014602.0, 18: 30146.0, 19: 13.0, 20: 34.0, 21: 10397.0, 22: 1039707.0, 23: 30397.0, 24: 13.0, 25: 34.0, 26: 14.0, 27: 1211.0, 28: 1249.0, 29: 7.0, 30: 1.0, 40: 1.0, 330: 1.0, 446: 1.0}), DepDelay=-5.0)

## 

# Train Boosted GBTRegressor

In [63]:
from pyspark.ml.regression import GBTRegressor

In [64]:
(trainingData, testData) = data.randomSplit([0.7, 0.3])

In [65]:
trainingData.show(10)

+--------------------+--------+
|            features|DepDelay|
+--------------------+--------+
|(833,[2,3,4,5,6,7...|   -10.0|
|(833,[2,3,4,5,6,7...|    -8.0|
|(833,[2,3,4,5,6,7...|    -7.0|
|(833,[2,3,4,5,6,7...|    -7.0|
|(833,[2,3,4,5,6,7...|    -7.0|
|(833,[2,3,4,5,6,7...|    -7.0|
|(833,[2,3,4,5,6,7...|    -7.0|
|(833,[2,3,4,5,6,7...|    -7.0|
|(833,[2,3,4,5,6,7...|    -6.0|
|(833,[2,3,4,5,6,7...|    -6.0|
+--------------------+--------+
only showing top 10 rows



In [81]:
gbt = GBTRegressor(featuresCol="features",
                   labelCol="DepDelay",
                   maxIter=10)

In [82]:
model = gbt.fit(trainingData)

In [84]:
testData.first()

Row(features=SparseVector(833, {2: -8.0, 3: 204.0, 4: 210.0, 5: 230.0, 6: 1199.0, 7: 2018.0, 8: 4.0, 9: 12.0, 10: 4.0, 11: 2.0, 12: 19393.0, 13: 703.0, 14: 19393.0, 15: 703.0, 16: 10397.0, 17: 1039707.0, 18: 30397.0, 19: 13.0, 20: 34.0, 21: 11292.0, 22: 1129202.0, 23: 30325.0, 24: 8.0, 25: 82.0, 26: 20.0, 27: 1822.0, 28: 1946.0, 29: 6.0, 30: 5.0, 32: 1.0, 59: 1.0, 448: 1.0}), DepDelay=-8.0)

In [85]:
predictions = model.transform(testData)

In [88]:
predictions.select("prediction", "DepDelay", "features").show(5)

+-------------------+--------+--------------------+
|         prediction|DepDelay|            features|
+-------------------+--------+--------------------+
| -7.398869084747202|    -8.0|(833,[2,3,4,5,6,7...|
|-6.7632761422173635|    -7.0|(833,[2,3,4,5,6,7...|
|-6.7632761422173635|    -7.0|(833,[2,3,4,5,6,7...|
|-6.7632761422173635|    -7.0|(833,[2,3,4,5,6,7...|
|-6.7632761422173635|    -7.0|(833,[2,3,4,5,6,7...|
+-------------------+--------+--------------------+
only showing top 5 rows



In [89]:
from pyspark.ml.evaluation import RegressionEvaluator

In [90]:
evaluator = RegressionEvaluator(
    labelCol="DepDelay", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

Root Mean Squared Error (RMSE) on test data = 26.0804
