In [1]:
import os
from typing import Tuple

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from lightautoml.dataset.roles import DatetimeRole

from lightautoml.spark.tasks.base import SparkTask
# from lightautoml.spark.utils import SparkDataFrame
from lightautoml.spark.automl.presets.tabular_presets import SparkTabularAutoML

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_spark_session(
    checkpoint_dir: str = "/tmp/spark_checkpoints",
    log_level: str = "OFF",
) -> "SparkSession":
    """Create (or reuse) the SparkSession used throughout this notebook.

    If the ``SCRIPT_ENV`` environment variable equals ``"cluster"``, the session
    is obtained with a bare ``getOrCreate()``, i.e. no extra configuration is
    applied here. Otherwise a local session is built on all available cores
    with the LightAutoML Spark jar and the SynapseML package attached.

    Args:
        checkpoint_dir: Directory handed to ``SparkContext.setCheckpointDir``.
            Defaults to the previously hard-coded ``/tmp/spark_checkpoints``.
        log_level: Level handed to ``SparkContext.setLogLevel``. The default
            ``"OFF"`` silences all Spark logging, including errors; pass
            e.g. ``"WARN"`` when something needs debugging.

    Returns:
        The configured SparkSession.
    """
    if os.environ.get("SCRIPT_ENV", None) == "cluster":
        spark_sess = SparkSession.builder.getOrCreate()
    else:
        spark_sess = (
            SparkSession
            .builder
            .master("local[*]")
            # NOTE(review): relative path — presumably resolved against the
            # notebook's working directory; confirm before moving the notebook.
            .config("spark.jars", "../../jars/spark-lightautoml_2.12-0.1.jar")
            .config("spark.jars.packages", "com.microsoft.azure:synapseml_2.12:0.9.5")
            .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
            .config("spark.kryoserializer.buffer.max", "512m")
            .config("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
            .config("spark.cleaner.referenceTracking", "true")
            .config("spark.cleaner.periodicGC.interval", "1min")
            .config("spark.sql.shuffle.partitions", "16")
            # NOTE(review): in local mode the executor runs inside the driver
            # JVM, so spark.executor.memory likely has no effect here — confirm.
            .config("spark.driver.memory", "55g")
            .config("spark.executor.memory", "55g")
            .config("spark.sql.execution.arrow.pyspark.enabled", "true")
            .getOrCreate()
        )

    spark_sess.sparkContext.setCheckpointDir(checkpoint_dir)

    spark_sess.sparkContext.setLogLevel(log_level)

    return spark_sess

In [3]:
# One shared session, reused by every cell below.
spark: SparkSession = get_spark_session()

https://mmlspark.azureedge.net/maven added as a remote repository with the name: repo-1
Ivy Default Cache set to: /home/user/.ivy2/cache
The jars for the packages stored in: /home/user/.ivy2/jars
com.microsoft.azure#synapseml_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-cf2489c4-6f52-4a6b-b5ca-ff04387988cc;1.0
	confs: [default]


:: loading settings :: url = jar:file:/home/user/projects/LightAutoML/.venv/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


	found com.microsoft.azure#synapseml_2.12;0.9.5 in central
	found com.microsoft.azure#synapseml-core_2.12;0.9.5 in central
	found org.scalactic#scalactic_2.12;3.0.5 in central
	found org.scala-lang#scala-reflect;2.12.4 in central
	found io.spray#spray-json_2.12;1.3.2 in central
	found com.jcraft#jsch;0.1.54 in user-list
	found org.apache.httpcomponents#httpclient;4.5.6 in user-list
	found org.apache.httpcomponents#httpcore;4.4.10 in user-list
	found commons-logging#commons-logging;1.2 in user-list
	found commons-codec#commons-codec;1.10 in user-list
	found org.apache.httpcomponents#httpmime;4.5.6 in user-list
	found com.linkedin.isolation-forest#isolation-forest_3.2.0_2.12;2.0.8 in central
	found com.chuusai#shapeless_2.12;2.3.2 in user-list
	found org.typelevel#macro-compat_2.12;1.1.1 in user-list
	found org.apache.spark#spark-avro_2.12;3.2.0 in central
	found org.tukaani#xz;1.8 in local-m2-cache
	found org.spark-project.spark#unused;1.0.0 in user-list
	found org.testng#testng;6.8.8 i

22/05/24 12:43:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/05/24 12:43:03 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/05/24 12:43:03 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


# Read the dataset

In [4]:
# Location of the parquet dataset. Set the DATASET_PATH environment variable
# to point the notebook at a copy stored elsewhere; the default is unchanged.
dataset_path = os.environ.get(
    "DATASET_PATH", "/opt/spark_data/parquet_dataset/train_low_mem.parquet"
)
data = spark.read.parquet(dataset_path)

In [7]:
# Mark the frame for caching, then force a full evaluation through the no-op
# sink so the cache is populated before the following actions run.
data = data.cache()
(
    data.write
    .format('noop')
    .mode('overwrite')
    .save()
)

                                                                                

In [6]:
# Total number of rows in the loaded dataset.
row_count = data.count()
row_count

3141410

In [5]:
# Print the column names and types of the frame read from parquet.
data.printSchema()

root
 |-- row_id: string (nullable = true)
 |-- time_id: integer (nullable = true)
 |-- investment_id: integer (nullable = true)
 |-- target: float (nullable = true)
 |-- f_0: float (nullable = true)
 |-- f_1: float (nullable = true)
 |-- f_2: float (nullable = true)
 |-- f_3: float (nullable = true)
 |-- f_4: float (nullable = true)
 |-- f_5: float (nullable = true)
 |-- f_6: float (nullable = true)
 |-- f_7: float (nullable = true)
 |-- f_8: float (nullable = true)
 |-- f_9: float (nullable = true)
 |-- f_10: float (nullable = true)
 |-- f_11: float (nullable = true)
 |-- f_12: float (nullable = true)
 |-- f_13: float (nullable = true)
 |-- f_14: float (nullable = true)
 |-- f_15: float (nullable = true)
 |-- f_16: float (nullable = true)
 |-- f_17: float (nullable = true)
 |-- f_18: float (nullable = true)
 |-- f_19: float (nullable = true)
 |-- f_20: float (nullable = true)
 |-- f_21: float (nullable = true)
 |-- f_22: float (nullable = true)
 |-- f_23: float (nullable = true)
 |--

# Split into train and test sets

In [8]:
# Fixed seed so the 80/20 split is reproducible across runs.
seed = 42
train_data, test_data = data.randomSplit([0.8, 0.2], seed)

# Cache both splits and materialize them with the no-op sink *before* releasing
# the parent frame. Without cache() the no-op writes would compute the splits
# and discard the result, and after data.unpersist() every later action would
# re-read the parquet source.
train_data = train_data.cache()
test_data = test_data.cache()
train_data.write.mode('overwrite').format('noop').save()
test_data.write.mode('overwrite').format('noop').save()

# Trailing semicolon suppresses the very wide DataFrame repr in the cell output.
data.unpersist();

                                                                                

DataFrame[row_id: string, time_id: int, investment_id: int, target: float, f_0: float, f_1: float, f_2: float, f_3: float, f_4: float, f_5: float, f_6: float, f_7: float, f_8: float, f_9: float, f_10: float, f_11: float, f_12: float, f_13: float, f_14: float, f_15: float, f_16: float, f_17: float, f_18: float, f_19: float, f_20: float, f_21: float, f_22: float, f_23: float, f_24: float, f_25: float, f_26: float, f_27: float, f_28: float, f_29: float, f_30: float, f_31: float, f_32: float, f_33: float, f_34: float, f_35: float, f_36: float, f_37: float, f_38: float, f_39: float, f_40: float, f_41: float, f_42: float, f_43: float, f_44: float, f_45: float, f_46: float, f_47: float, f_48: float, f_49: float, f_50: float, f_51: float, f_52: float, f_53: float, f_54: float, f_55: float, f_56: float, f_57: float, f_58: float, f_59: float, f_60: float, f_61: float, f_62: float, f_63: float, f_64: float, f_65: float, f_66: float, f_67: float, f_68: float, f_69: float, f_70: float, f_71: float,

In [11]:
# Peek at the first target value of the training split as a sanity check.
train_data.select("target").first()

                                                                                

Row(target=-0.35412395000457764)

# AutoML parameters

In [12]:
# Only the target column is declared; no roles are given for other columns.
roles = dict(target="target")

# Regression task (the target column is a float, see the schema above).
task = SparkTask("reg")

# Algorithms to use: a single group containing only "lgb".
use_algos = [["lgb"]]

# Passed to the reader as its "cv" setting (number of folds).
cv = 2

# Fitting and prediction

In [None]:
# Collect the preset's configuration in one place before constructing it.
automl_params = dict(
    spark=spark,
    task=task,
    general_params={"use_algos": use_algos},
    lgb_params={'use_single_dataset_mode': True},
    reader_params={"cv": cv, "advanced_roles": False},
)
automl = SparkTabularAutoML(**automl_params)

# Out-of-fold predictions on the training split, scored in the next section.
oof_predictions = automl.fit_predict(train_data, roles=roles)



# Score calculation

In [None]:
# Build the task's dataset-level metric once and reuse it for both the
# out-of-fold predictions and the hold-out test predictions.
score = task.get_dataset_metric()
metric_value = score(oof_predictions)

te_pred = automl.predict(test_data, add_reader_attrs=True)
test_metric_value = score(te_pred)

print(f"OOF score: {metric_value}")
print(f"TEST score: {test_metric_value}")