In [0]:
from pyspark.sql import SparkSession
from pyspark.context import SparkContext

In [0]:
spark = SparkSession.builder.appName('cas').getOrCreate()

In [0]:
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

In [0]:
df_train = spark.read.csv("/FileStore/tables/train.csv", header=True, inferSchema=True)
df_train.show()

+---+------------+------------+-------+------+---+---------------+---------+-------------------------+-------------------+--------------------+-----------------+-------------+------------------+---------+------+------------------------+-----+----------+-----+--------+---------+---------------+----+-----+----------------+-----+----------+-------------+---------------------+-----------+----------------+--------+-------------+-----------+----------+--------------------+-------------+--------+-----------+---------+---------+------------------+-----------+----------+-----------+-----------+---+----+--------------------+---------------+-------------+-----------+-----+----+-----+------------------+-----+------------+-------+----------+--------+-----------+
|_c0|           X|           Y|bicycle|bridge|bus|carStationWagon|cliffBank|crashDirectionDescription| crashFinancialYear|      crashLocation1|   crashLocation2|crashSeverity|crashSHDescription|crashYear|debris|directionRoleDescription|ditc

In [0]:
df_train = df_train.select("_c0","crashFinancialYear","speedLimit","crashSeverity","crashSHDescription","roadLane","claimAmount", "region")

In [0]:
df_test = spark.read.csv("/FileStore/tables/test.csv", header=True, inferSchema=True)
df_test.count()

Out[26]: 44839

In [0]:
df_test = df_test.select("_c0","crashFinancialYear","speedLimit","crashSeverity","crashSHDescription","roadLane","claimAmount", "region")

In [0]:
df_train.count()

Out[28]: 179808

In [0]:
bootstrapped_data = df_train.sample(withReplacement = True, fraction = 100.0)

In [0]:
bootstrapped_data.count()

Out[30]: 179821453

In [0]:
df_train.describe().show()

In [0]:
bootstrapped_data.describe().show()

In [0]:
df = df_train.toPandas()

In [0]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.metrics import mean_absolute_error

In [0]:
df1 = df_test.toPandas()

In [0]:
model = sm.MixedLM.from_formula("claimAmount ~ crashFinancialYear+speedLimit+crashSeverity+crashSHDescription+roadLane", groups=df["region"], data=df)
result = model.fit()

# The BLUPs
re = result.random_effects

# Multiply each BLUP by the random effects design matrix for one group
rex = [np.dot(model.exog_re_li[j], re[k]) for (j, k) in enumerate(model.group_labels)]

# Add the fixed and random terms to get the overall prediction
rex = np.concatenate(rex)
yp = result.fittedvalues + rex

pred = result.predict(exog=df1)
print("RMSE: ", np.sqrt(np.mean(np.square(pred - df1['claimAmount']))))
print("MAE: ", mean_absolute_error(pred, df1['claimAmount']))

RMSE:  3585454.3055692245
MAE:  2065862.8859656514


In [0]:
result.summary()

0,1,2,3
Model:,MixedLM,Dependent Variable:,claimAmount
No. Observations:,179808,Method:,REML
No. Groups:,17,Scale:,12581791591559.6621
Min. group size:,657,Log-Likelihood:,-2966894.5945
Max. group size:,52280,Converged:,Yes
Mean group size:,10576.9,,

0,1,2,3,4,5,6
,Coef.,Std.Err.,z,P>|z|,[0.025,0.975]
Intercept,679902.915,55870.438,12.169,0.000,570398.870,789406.961
crashFinancialYear,-29910.961,1368.129,-21.863,0.000,-32592.446,-27229.477
speedLimit,14561.306,423.598,34.375,0.000,13731.070,15391.543
crashSeverity,2546871.619,10249.961,248.476,0.000,2526782.064,2566961.173
crashSHDescription,167777.662,10704.222,15.674,0.000,146797.773,188757.550
roadLane,235524.565,24887.331,9.464,0.000,186746.292,284302.838
Group Var,18994584282.457,2183.417,,,,


In [0]:
bootstrapped_data.createOrReplaceTempView("data")

In [0]:
bootstrapped_data.count()

Out[39]: 179821453

In [0]:
spark_df = spark.sql("""
select *, _c0%10 as partition_id 
from (
  select *, row_number() over (order by rand()) as user_id
  from data
) 
""")

# preview the results
#spark_df.show()

In [0]:
spark_df.count()

Out[41]: 179821453

In [0]:
import time

In [0]:
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, LongType, DoubleType

# define a schema for the result set, the user ID and model prediction
schema = StructType([StructField('user_id', LongType(), True),
                     StructField('prediction', DoubleType(), True)])  

# define the Pandas UDF 
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def apply_model(sample_df):

    # run the model on the partitioned data set 
    ids = sample_df['user_id']
    x_train = sample_df.drop(['user_id', 'partition_id', 'claimAmount'], axis=1)
    pred = result.predict(exog=sample_df)

    return pd.DataFrame({'user_id': ids, 'prediction': pred[:]})
    #return pd.DataFrame({'prediction': pred[:]})
  
# partition the data and run the UDF  
results = spark_df.groupby('partition_id').apply(apply_model)

start = time.time()
results.show()
print('time: ',start-time.time())

+---------+--------------------+
|  user_id|          prediction|
+---------+--------------------+
|        1|   5220352.261952994|
|154956903|   4981064.571745215|
|        5|    4867068.43527849|
|154956914|  3977265.8482996793|
|        9|   4156731.615955514|
|154956916|   4156731.615955514|
|       12|1.0074807809205348E7|
|154956936|    4366108.34488732|
|       13|   4208670.054752502|
|154956945|    4687245.43516806|
|       18|   5100708.416849105|
|154956946|   2882857.372154735|
|       42|   5399818.029608828|
|154956956|   9041098.124483839|
|       56|   4372643.094337186|
|154956977|   4396019.306163292|
|       64|  1759414.8036053095|
|154956996|   9041098.124483839|
|       90|   9071009.085759811|
|154957008|  3917443.9257477345|
+---------+--------------------+
only showing top 20 rows

time:  -820.8498141765594


In [0]:
spark_df.count()

Out[44]: 179821453

In [0]:
sdf = spark_df.toPandas()

  An error occurred while calling o1849.getResult.
: org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:428)
	at org.apache.spark.security.SocketAuthServer.getResult(SocketAuthServer.scala:107)
	at org.apache.spark.security.SocketAuthServer.getResult(SocketAuthServer.scala:103)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
	at py4j.Gateway.invoke(Gateway.java:295)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(Gatewa

[0;31m---------------------------------------------------------------------------[0m
[0;31mPy4JJavaError[0m                             Traceback (most recent call last)
[0;32m<command-1791933884705404>[0m in [0;36m<module>[0;34m[0m
[0;32m----> 1[0;31m [0msdf[0m [0;34m=[0m [0mspark_df[0m[0;34m.[0m[0mtoPandas[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
[0;32m/databricks/spark/python/pyspark/databricks/utils/instrumentation.py[0m in [0;36mwrapper[0;34m(self, *args, **kwargs)[0m
[1;32m     40[0m         [0;32mtry[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;32m     41[0m             [0mstart_time[0m [0;34m=[0m [0mtime[0m[0;34m.[0m[0mtime[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;32m---> 42[0;31m             [0mreturn_val[0m [0;34m=[0m [0mfunc[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m     43[0m         

In [0]:
sdf.shape

Out[66]: (17974190, 9)

In [0]:
result.predict(sdf)

Out[67]: 0           4.137119e+06
1           9.250475e+06
2           4.425930e+06
3           5.280174e+06
4           4.425930e+06
                ...     
17974185    4.092968e+06
17974186    5.194204e+06
17974187    9.041098e+06
17974188    4.854886e+06
17974189    4.705331e+06
Length: 17974190, dtype: float64