In [None]:
# Install a headless Java 8 JDK (stdout is discarded), then download the
# Spark 3.0.0 build for Hadoop 3.2 from the Apache archive.
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://archive.apache.org/dist/spark/spark-3.0.0/spark-3.0.0-bin-hadoop3.2.tgz

In [None]:
# Unpack the downloaded Spark archive into the current working directory.
!tar xf spark-3.0.0-bin-hadoop3.2.tgz
# Install the Python dependencies quietly.
# NOTE(review): lightgbm is the only package here without a version pin, so a
# fresh run may pull a different release than the one that wrote model.txt --
# consider pinning it once the training version is known.
!pip install -q findspark==1.4.2 catboost==1.0.3 lightgbm

In [None]:
import os

# Tell the process where the JDK and the unpacked Spark distribution live.
# Both variables are set before findspark.init() is called.
os.environ.update(
    JAVA_HOME="/usr/lib/jvm/java-8-openjdk-amd64",
    SPARK_HOME="/content/spark-3.0.0-bin-hadoop3.2",
)

import findspark

# Make the pyspark package from the Spark distribution importable.
findspark.init()

In [None]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType
import lightgbm as lgb
import pandas as pd
from scipy import special

In [None]:
# Create (or reuse) a Spark session with a local master.
# spark.task.cpus is set to 2, i.e. each task asks for two cores.
# NOTE(review): with a local master the executor.* settings presumably have
# no effect because there are no separate executor processes -- confirm.
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('ExampleOfSparkSession')
    .config("spark.executor.cores", "2")
    .config("spark.task.cpus", "2")
    .config("spark.driver.memory", "2g")
    .config("spark.executor.memory", "2g")
    .getOrCreate()
)

In [None]:
# Bare expression: the notebook renders the session object as the cell output.
spark

In [None]:
# Load a previously saved LightGBM booster from the text model file.
# The path is relative, so model.txt must be in the working directory.
clf = lgb.Booster(model_file="model.txt")

In [None]:
@F.pandas_udf(
    returnType='_c0 int, probability float',
    functionType=F.PandasUDFType.GROUPED_MAP,
)
def predict_udf(df):
  """Score one group of rows with the module-level LightGBM booster ``clf``.

  Args:
    df: pandas DataFrame with the rows of a single group. The first column
      is skipped and every remaining column is handed to ``clf.predict``.
      Presumably the remaining columns must match the features the model
      was trained on, in the same order -- TODO confirm.

  Returns:
    pandas DataFrame with two columns: ``_c0`` (copied from the input) and
    ``probability`` (``scipy.special.expit`` of the booster output).
  """
  model_inputs = df.iloc[:, 1:]
  # NOTE(review): expit is applied to whatever clf.predict returns; that is
  # only meaningful if predict yields raw margins rather than probabilities
  # -- confirm against how the model was trained.
  raw_scores = clf.predict(model_inputs)
  df['probability'] = special.expit(raw_scores)
  return df[['_c0', 'probability']]

In [None]:
# Read the processed training set as CSV. The first line is used as the
# header and column types are inferred from the data.
df = (
    spark.read.format("csv")
    .options(header=True, inferSchema=True)
    .load("train_processed.csv")
)

In [None]:
# Apply the grouped-map pandas UDF once per distinct value of `_c0`.
# NOTE(review): if `_c0` is a unique row id, every group holds a single row,
# so the model is invoked once per row after a shuffle -- confirm whether a
# coarser grouping key or a non-grouped pandas API would be acceptable.
result = df.groupBy("_c0").apply(predict_udf)

In [None]:
# Trigger the computation and print the first rows of the predictions.
result.show()

+----+-----------+
| _c0|probability|
+----+-----------+
| 148| 0.21629286|
| 463| 0.44872618|
| 471| 0.82074386|
| 496|  0.8099793|
| 833|  0.9274396|
|1088|  0.7112847|
|1238|  0.9221655|
|1342|  0.8441214|
|1580|  0.6109994|
|1591| 0.33774775|
|1645| 0.32753477|
|1829| 0.41342494|
|1959|  0.4297743|
|2122|  0.7888445|
|2142|  0.5039513|
|2366| 0.41920605|
|2659| 0.78282666|
|2866|  0.5604404|
|3175| 0.24821128|
|3749|  0.6054854|
+----+-----------+
only showing top 20 rows



In [None]:
# Print the row count (count() runs a Spark job) and the number of columns.
print(df.count(), len(df.columns))

300000 70
