In [1]:
import os
from typing import Tuple

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from lightautoml.dataset.roles import DatetimeRole

from lightautoml.spark.tasks.base import SparkTask
from lightautoml.spark.utils import SparkDataFrame
from lightautoml.spark.automl.presets.tabular_presets import SparkTabularAutoML

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_spark_session():
    """Return the SparkSession used by this notebook.

    When the ``SCRIPT_ENV`` environment variable equals ``"cluster"`` the
    session is obtained from the builder without any extra settings.
    Otherwise a ``local[*]`` session is built with the jar, serializer,
    cleaner, shuffle, memory and Arrow options listed below.

    In both cases the checkpoint directory is set to
    ``/tmp/spark_checkpoints`` and the log level is set to ``"OFF"``.
    """
    running_on_cluster = os.environ.get("SCRIPT_ENV") == "cluster"

    if running_on_cluster:
        session = SparkSession.builder.getOrCreate()
    else:
        # Insertion order is kept so options are applied in a fixed sequence.
        local_options = {
            "spark.jars": "../../jars/spark-lightautoml_2.12-0.1.jar",
            "spark.jars.packages": "com.microsoft.azure:synapseml_2.12:0.9.5",
            "spark.jars.repositories": "https://mmlspark.azureedge.net/maven",
            "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
            "spark.kryoserializer.buffer.max": "512m",
            "spark.cleaner.referenceTracking.cleanCheckpoints": "true",
            "spark.cleaner.referenceTracking": "true",
            "spark.cleaner.periodicGC.interval": "1min",
            "spark.sql.shuffle.partitions": "16",
            "spark.driver.memory": "55g",
            "spark.executor.memory": "55g",
            "spark.sql.execution.arrow.pyspark.enabled": "true",
        }

        builder = SparkSession.builder.master("local[*]")
        for option_name, option_value in local_options.items():
            builder = builder.config(option_name, option_value)
        session = builder.getOrCreate()

    context = session.sparkContext
    context.setCheckpointDir("/tmp/spark_checkpoints")
    context.setLogLevel("OFF")

    return session

In [3]:
# Session handle used by the read and AutoML cells below.
spark = get_spark_session()

https://mmlspark.azureedge.net/maven added as a remote repository with the name: repo-1


:: loading settings :: url = jar:file:/home/user/projects/LightAutoML/.venv/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/user/.ivy2/cache
The jars for the packages stored in: /home/user/.ivy2/jars
com.microsoft.azure#synapseml_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-2255a58a-d00b-4a02-8998-6a0f112d8434;1.0
	confs: [default]
	found com.microsoft.azure#synapseml_2.12;0.9.5 in central
	found com.microsoft.azure#synapseml-core_2.12;0.9.5 in central
	found org.scalactic#scalactic_2.12;3.0.5 in central
	found org.scala-lang#scala-reflect;2.12.4 in central
	found io.spray#spray-json_2.12;1.3.2 in central
	found com.jcraft#jsch;0.1.54 in user-list
	found org.apache.httpcomponents#httpclient;4.5.6 in user-list
	found org.apache.httpcomponents#httpcore;4.4.10 in user-list
	found commons-logging#commons-logging;1.2 in user-list
	found commons-codec#commons-codec;1.10 in user-list
	found org.apache.httpcomponents#httpmime;4.5.6 in user-list
	found com.linkedin.isolation-forest#isolation-forest_3.2.0_2.12;2.0.8 in central
	found com.

22/05/24 11:26:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


# Dataset reading

In [4]:
# Read the 1990 flights file from the local filesystem with a header row.
# No schema is passed and inferSchema is not enabled, so every column arrives
# as a string (visible in the printSchema output below).
data = spark.read.csv("file:///opt/spark_data/expo/1990.csv", header=True, escape="\"")

In [5]:
# Show the column names and types as loaded from the CSV.
data.printSchema()

root
 |-- Year: string (nullable = true)
 |-- Month: string (nullable = true)
 |-- DayofMonth: string (nullable = true)
 |-- DayOfWeek: string (nullable = true)
 |-- DepTime: string (nullable = true)
 |-- CRSDepTime: string (nullable = true)
 |-- ArrTime: string (nullable = true)
 |-- CRSArrTime: string (nullable = true)
 |-- UniqueCarrier: string (nullable = true)
 |-- FlightNum: string (nullable = true)
 |-- TailNum: string (nullable = true)
 |-- ActualElapsedTime: string (nullable = true)
 |-- CRSElapsedTime: string (nullable = true)
 |-- AirTime: string (nullable = true)
 |-- ArrDelay: string (nullable = true)
 |-- DepDelay: string (nullable = true)
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- Distance: string (nullable = true)
 |-- TaxiIn: string (nullable = true)
 |-- TaxiOut: string (nullable = true)
 |-- Cancelled: string (nullable = true)
 |-- CancellationCode: string (nullable = true)
 |-- Diverted: string (nullable = true)
 |-- CarrierDelay:


* Year [int]: Year of the dataset (1999 & 2000)
* Month [int]: Month of the observation (1 - Jan, 2 - Feb, etc.)
* DayofMonth [int]: Day of the month (1 - 31, if applicable)
* DayOfWeek [int]: Day of the week (1 - Mon, 2 - Tue, etc.)
* DepTime [int]: Actual departure time (local time zone %H%M format)
* CRSDepTime [int]: Scheduled departure time (local time zone %H%M format)
* ArrTime [int]: Actual arrival time (local time zone %H%M format)
* CRSArrTime [int]: Scheduled arrival time (local time zone %H%M format)
* UniqueCarrier [str]: Unique carrier code to identify the carriers in carriers.csv
* FlightNum [int]: Flight number
* TailNum [str]: Unique tail number to identify the planes in plane-data.csv
* ActualElapsedTime [int]: Difference between ArrTime and DepTime in minutes, also sum of AirTime, TaxiIn, TaxiOut
* CRSElapsedTime [int]: Difference between CRSArrTime and CRSDepTime in minutes
* AirTime [int]: Air time in minutes
* ArrDelay [int]: Difference between ArrTime and CRSArrTime in minutes
* DepDelay [int]: Difference between DepTime and CRSDepTime in minutes
* Origin [str]: Unique IATA airport code that flight was departed from, can be identified in airports.csv
* Dest [str]: Unique IATA airport code for flight destination, can be identified in airports.csv
* Distance [int]: Flight distance in miles
* TaxiIn [int]: Taxi-in time in minutes
* TaxiOut [int]: Taxi-out time in minutes
* Cancelled [int]: Flight cancellation (1 - Cancelled, 0 - Not Cancelled)
* CancellationCode [str]: Flight cancellation reason (A - Carrier, B - Weather, C - National Aviation System, D - Security)
* Diverted [int]: Flight diverted (1 - Diverted, 0 - Not diverted)
* CarrierDelay [int]: Delay caused by carrier in minutes
* WeatherDelay [int]: Delay caused by weather in minutes
* NASDelay [int]: Delay caused by National Aviation System in minutes
* LateAircraftDelay [int]: Delay caused by previous late flight arrivals in minutes


# Target calculation

In [6]:
# Binary target: 1 when the actual elapsed time exceeds the scheduled
# (CRS) elapsed time, 0 otherwise.
# NOTE(review): both columns were loaded as strings, so the subtraction
# relies on Spark's implicit numeric cast; values that do not parse
# presumably end up in the `otherwise(0)` branch — confirm this is intended.
elapsed_overrun = F.col("ActualElapsedTime") - F.col("CRSElapsedTime")
total_delay = (
    F.when(elapsed_overrun > 0, 1)
    .otherwise(0)
    .alias("total_delay")
)

data = data.select("*", total_delay)

# Features and target

In [7]:
# Keep only the schedule-side columns plus the target. ActualElapsedTime is
# not selected: the target is derived from it, so keeping it would leak
# the answer into the features.
selected_columns = [
    "Year",
    "Month",
    "DayOfWeek",
    "CRSDepTime",
    "CRSArrTime",
    "CRSElapsedTime",
    "Distance",
    "total_delay",
]
data = data.select(selected_columns)


In [8]:
# Mark the frame for caching, then run a write with the 'noop' format so the
# plan is executed right away — presumably to populate the cache eagerly
# rather than on first use (TODO confirm intent).
data = data.cache()
data.write.mode('overwrite').format('noop').save()

                                                                                

In [17]:
# Number of rows in the selected feature frame.
data.count()

5270893

In [9]:
# Column names remaining after feature selection.
data.columns

['Year',
 'Month',
 'DayOfWeek',
 'CRSDepTime',
 'CRSArrTime',
 'CRSElapsedTime',
 'Distance',
 'total_delay']

# Divide into train and test parts

In [10]:
# Fixed seed so the 80/20 split is reproducible between runs.
seed = 42
train_data, test_data = data.randomSplit([0.8, 0.2], seed)

# Cache both splits BEFORE materialising them. Without the cache the noop
# writes below compute each split and throw the result away, and once the
# parent frame is unpersisted every later use of train_data / test_data
# would re-read the CSV and redo the split.
train_data = train_data.cache()
test_data = test_data.cache()

# Writes with the 'noop' format execute the plan without producing output,
# which fills the two caches while the parent frame is still cached.
train_data.write.mode('overwrite').format('noop').save()
test_data.write.mode('overwrite').format('noop').save()

# The parent frame is no longer needed once both splits are cached.
# The trailing semicolon keeps the returned DataFrame repr out of the output.
data.unpersist();

                                                                                

DataFrame[Year: string, Month: string, DayOfWeek: string, CRSDepTime: string, CRSArrTime: string, CRSElapsedTime: string, Distance: string, total_delay: int]

# AutoML params

In [11]:
# Column roles passed to the AutoML fit call: only the target is named.
roles = dict(target="total_delay")

# Binary classification task.
task = SparkTask("binary")

# A single level containing only the "lgb" algorithm.
use_algos = [["lgb"]]

# Value passed to the reader as its "cv" parameter.
cv = 2

# Fitting and prediction

In [12]:
# Build the preset from the task and parameters defined in the cell above.
automl = SparkTabularAutoML(
    spark=spark,
    task=task,
    general_params=dict(use_algos=use_algos),
    lgb_params=dict(use_single_dataset_mode=True),
    reader_params=dict(cv=cv, advanced_roles=False),
)

# Fit on the training split; the result is kept for scoring below.
oof_predictions = automl.fit_predict(train_data, roles=roles)



[LightGBM] [Info] Number of positive: 868841, number of negative: 1238372
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 1705
[LightGBM] [Info] Number of data points in the train set: 2107213, number of used features: 6
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0,412318 -> initscore=-0,354393
[LightGBM] [Info] Start training from score -0,354393




[LightGBM] [Info] Number of positive: 868841, number of negative: 1238372
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 1039
[LightGBM] [Info] Number of data points in the train set: 2107213, number of used features: 6
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0,412318 -> initscore=-0,354393
[LightGBM] [Info] Start training from score -0,354393


                                                                                

# Score calculation

In [13]:
# Metric on the predictions returned by fit_predict.
score = task.get_dataset_metric()
metric_value = score(oof_predictions)

# Metric on the held-out test split; a fresh metric object is requested
# from the task before scoring the test predictions.
te_pred = automl.predict(test_data, add_reader_attrs=True)
score = task.get_dataset_metric()
test_metric_value = score(te_pred)

print(
    f"OOF score: {metric_value}",
    f"TEST score: {test_metric_value}",
    sep="\n",
)

                                                                                

OOF score: 0.648197628359205
TEST score: 0.6606222253443365
