Create dummy data with:
- `device_id`: 10 different devices
- `record_id`: 10k unique records
- `feature_1`: a feature for model training
- `feature_2`: a feature for model training
- `feature_3`: a feature for model training
- `label`: the variable we're trying to predict

In [6]:
import pyspark.sql.functions as f

# Base frame: 100*100 sequential ids exposed as record_id, with each record
# assigned to a device via id modulo 10.
df = spark.range(100 * 100).select(
    f.col("id").alias("record_id"),
    (f.col("id") % 10).alias("device_id"),
)

# Three feature columns: separate f.rand() draws scaled by 1, 2 and 3.
df = df.withColumn("feature_1", f.rand() * 1)
df = df.withColumn("feature_2", f.rand() * 2)
df = df.withColumn("feature_3", f.rand() * 3)

# Label is the sum of the three features plus one more f.rand() draw as noise.
df = df.withColumn(
    "label",
    f.col("feature_1") + f.col("feature_2") + f.col("feature_3") + f.rand(),
)

df.show()

Accordion(children=(VBox(),), layout=Layout(display='none'), selected_index=None)

+---------+---------+--------------------+--------------------+-------------------+------------------+
|record_id|device_id|           feature_1|           feature_2|          feature_3|             label|
+---------+---------+--------------------+--------------------+-------------------+------------------+
|        0|        0|  0.3798100624574545|  1.4202490062082653| 1.6781901282345544| 3.525612733862513|
|        1|        1|  0.9107536883090278|   1.442838269067278|  2.611383485089366| 5.814245271889851|
|        2|        2| 0.41590416052434676| 0.10418460693028986|0.22604878358315217|1.6008841116808505|
|        3|        3|  0.5512047715185051|  1.8502137667712624| 2.7280610877991984| 5.679620353053467|
|        4|        4|  0.6497210433084482|  1.4206298832076816| 0.7105026181927976| 3.736893714115734|
|        5|        5|  0.9773319027898405| 0.12361450125920048| 1.3702032077472008|2.7353747472986503|
|        6|        6|  0.4380566234523088|  1.3427775095169203|  1.073934

object.__init__() takes exactly one argument (the instance to initialize)
This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.
  super(Widget, self).__init__(**kwargs)


Enable Apache Arrow

In [7]:
# Switch on Spark's Apache Arrow integration through the session configuration.
# NOTE(review): newer Spark releases rename this key to
# "spark.sql.execution.arrow.pyspark.enabled" -- confirm against the cluster's Spark version.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

Accordion(children=(VBox(),), layout=Layout(display='none'), selected_index=None)

Define the return schema

In [8]:
import pyspark.sql.types as t

# Schema of the single summary row that the training UDF returns per device.
trainReturnSchema = (
    t.StructType()
    .add('device_id', t.IntegerType())    # unique device ID
    .add('n_used', t.IntegerType())       # number of records used in training
    .add('model_path', t.StringType())    # path to the model for a given device
    .add('mse', t.FloatType())            # metric for model performance
)

Accordion(children=(VBox(),), layout=Layout(display='none'), selected_index=None)

Define a pandas UDF that takes all the data for a given device, trains a model, saves it as a nested run, and returns a Spark DataFrame with the above schema

In [11]:
import mlflow
import mlflow.sklearn
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

@f.pandas_udf(trainReturnSchema, functionType=f.PandasUDFType.GROUPED_MAP)
def train_model(df_pandas):
    """
    Fit a random forest on one group of rows and log it to MLflow.

    The input pandas DataFrame must contain the columns ``device_id``,
    ``run_id``, ``feature_1``, ``feature_2``, ``feature_3`` and ``label``.
    ``device_id`` and ``run_id`` are read from the first row only, so every
    row in the group is assumed to share them.

    Returns a one-row pandas DataFrame with the columns ``device_id``,
    ``n_used``, ``model_path`` and ``mse`` (matching ``trainReturnSchema``).
    """
    # Group-level metadata, taken from the first row
    group_device = df_pandas['device_id'].iloc[0]
    parent_run_id = df_pandas['run_id'].iloc[0]
    row_count = len(df_pandas)
    device_label = str(group_device)

    # Fit the regressor on the three feature columns
    features = df_pandas[['feature_1', 'feature_2', 'feature_3']]
    target = df_pandas['label']
    forest = RandomForestRegressor()
    forest.fit(features, target)

    # Score on the same rows used for fitting -- NOTE: no train/test split
    mse_value = mean_squared_error(target, forest.predict(features))

    # The outer context re-opens the parent run by ID inside this worker;
    # the inner context creates the per-device run nested beneath it.
    with mlflow.start_run(run_id=parent_run_id), \
            mlflow.start_run(run_name=device_label, nested=True) as child_run:
        mlflow.sklearn.log_model(forest, device_label)
        mlflow.log_metric("mse", mse_value)
        model_location = child_run.info.artifact_uri + "/" + device_label

        # Single summary row, in the column order of the declared return schema
        summary_row = [group_device, row_count, model_location, mse_value]
        returnDF = pd.DataFrame(
            [summary_row],
            columns=["device_id", "n_used", "model_path", "mse"],
        )

    return returnDF


Accordion(children=(VBox(),), layout=Layout(display='none'), selected_index=None)

In [15]:
# Start a run named "Parent Run" and print its ID. The `with` block ends the
# run on exit, so this cell only demonstrates run creation.
with mlflow.start_run(run_name="Parent Run") as run:
    # RunInfo.run_uuid is deprecated in MLflow; run_id holds the same value.
    run_id = run.info.run_id

    print(run_id)
    

Accordion(children=(VBox(),), layout=Layout(display='none'), selected_index=None)

4aa97073258b4726bb1381f3ce713a2a


Apply the pandas UDF to grouped data

In [17]:
import tempfile

# NOTE: the experiment path is hard-coded to one user's workspace folder.
mlflow.set_experiment("/Users/niall.turbitt@databricks/test_workspace_2")

# Create the parent run and pass its ID to the workers (as a literal column)
# so that each train_model call can re-open it and attach a nested run.
with mlflow.start_run(run_name="Parent Run") as run:
    # RunInfo.run_uuid is deprecated in MLflow; run_id holds the same value.
    run_id = run.info.run_id

    # This is doing most of the work.
    # Spark evaluates lazily: without cache(), every action on
    # modelDirectoriesDF (toPandas() and show() below, and any later join)
    # would re-run train_model, retraining every model and logging duplicate
    # nested runs. cache() keeps the result of the first action for reuse.
    # NOTE(review): caching is best-effort -- evicted partitions would be
    # recomputed; write the result to a table if that must never happen.
    modelDirectoriesDF = (
        df.withColumn("run_id", f.lit(run_id))
        .groupby("device_id")
        .apply(train_model)
        .cache()
    )

    # Log modelDirectoriesDF to the parent run using a temporary file.
    # The context manager closes (and thereby deletes) the file on exit.
    # NOTE(review): the second argument of log_artifact is the artifact
    # *directory*, so the CSV is stored under "modelDirectoriesDF.csv/";
    # kept as-is to preserve the existing artifact layout.
    with tempfile.NamedTemporaryFile(prefix="modelDirectoriesDF-", suffix=".csv") as temp:
        modelDirectoriesDF.toPandas().to_csv(temp.name, index=False)
        mlflow.log_artifact(temp.name, "modelDirectoriesDF.csv")

modelDirectoriesDF.show()

Accordion(children=(VBox(),), layout=Layout(display='none'), selected_index=None)

INFO: '/Users/niall.turbitt@databricks/test_workspace_2' does not exist. Creating a new experiment


object.__init__() takes exactly one argument (the instance to initialize)
This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.
  super(Widget, self).__init__(**kwargs)
  An error occurred while calling o1769.getResult.
: org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:358)
	at org.apache.spark.security.SocketAuthServer.getResult(SocketAuthServer.scala:67)
	at org.apache.spark.security.SocketAuthServer.getResult(SocketAuthServer.scala:63)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
	a

Py4JJavaError: An error occurred while calling o1769.getResult.
: org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:358)
	at org.apache.spark.security.SocketAuthServer.getResult(SocketAuthServer.scala:67)
	at org.apache.spark.security.SocketAuthServer.getResult(SocketAuthServer.scala:63)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
	at py4j.Gateway.invoke(Gateway.java:295)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:251)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Task 45 in stage 858.0 failed 4 times, most recent failure: Lost task 45.3 in stage 858.0 (TID 3758, 10.0.228.150, executor 3): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/worker.py", line 480, in main
    process()
  File "/databricks/spark/python/pyspark/worker.py", line 472, in process
    serializer.dump_stream(out_iter, outfile)
  File "/databricks/spark/python/pyspark/serializers.py", line 408, in dump_stream
    timely_flush_timeout_ms=self.timely_flush_timeout_ms)
  File "/databricks/spark/python/pyspark/serializers.py", line 215, in dump_stream
    for batch in iterator:
  File "/databricks/spark/python/pyspark/serializers.py", line 398, in init_stream_yield_batches
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/databricks/spark/python/pyspark/worker.py", line 136, in <lambda>
    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
  File "/databricks/spark/python/pyspark/worker.py", line 121, in wrapped
    result = f(pd.concat(value_series, axis=1))
  File "/databricks/spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-11-9662324b31f3>", line 30, in train_model
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/fluent.py", line 122, in start_run
    active_run_obj = MlflowClient().get_run(existing_run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/client.py", line 92, in get_run
    return self._tracking_client.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/_tracking_service/client.py", line 48, in get_run
    return self.store.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 92, in get_run
    response_proto = self._call_endpoint(GetRun, req_body)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 32, in _call_endpoint
    return call_endpoint(self.get_host_creds(), endpoint, method, json_body, response_proto)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 137, in call_endpoint
    response = verify_rest_response(response, endpoint)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 103, in verify_rest_response
    raise RestException(json.loads(response.text))
mlflow.exceptions.RestException: RESOURCE_DOES_NOT_EXIST: Run 'd7cdceac941142b099cc3b44204c51de' not found.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:534)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:194)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:144)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:488)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.hasNext(ArrowConverters.scala:116)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
	at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
	at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
	at scala.collection.AbstractIterator.to(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1334)
	at org.apache.spark.sql.Dataset$$anonfun$collectAsArrowToPython$1$$anonfun$apply$20$$anonfun$apply$21.apply(Dataset.scala:3442)
	at org.apache.spark.sql.Dataset$$anonfun$collectAsArrowToPython$1$$anonfun$apply$20$$anonfun$apply$21.apply(Dataset.scala:3442)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:140)
	at org.apache.spark.scheduler.Task.run(Task.scala:113)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$13.apply(Executor.scala:533)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1541)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:539)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2362)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2350)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2349)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2349)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:1102)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:1102)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1102)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2581)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2529)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2517)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:897)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2280)
	at org.apache.spark.sql.Dataset$$anonfun$collectAsArrowToPython$1$$anonfun$apply$20.apply(Dataset.scala:3440)
	at org.apache.spark.sql.Dataset$$anonfun$collectAsArrowToPython$1$$anonfun$apply$20.apply(Dataset.scala:3409)
	at org.apache.spark.api.python.PythonRDD$$anonfun$6$$anonfun$apply$7.apply$mcV$sp(PythonRDD.scala:627)
	at org.apache.spark.api.python.PythonRDD$$anonfun$6$$anonfun$apply$7.apply(PythonRDD.scala:627)
	at org.apache.spark.api.python.PythonRDD$$anonfun$6$$anonfun$apply$7.apply(PythonRDD.scala:627)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1541)
	at org.apache.spark.api.python.PythonRDD$$anonfun$6.apply(PythonRDD.scala:628)
	at org.apache.spark.api.python.PythonRDD$$anonfun$6.apply(PythonRDD.scala:624)
	at org.apache.spark.api.python.SocketFuncServer.handleConnection(PythonRDD.scala:1172)
	at org.apache.spark.api.python.SocketFuncServer.handleConnection(PythonRDD.scala:1166)
	at org.apache.spark.security.SocketAuthServer$$anonfun$1$$anonfun$apply$1.apply(SocketAuthServer.scala:48)
	at scala.util.Try$.apply(Try.scala:192)
	at org.apache.spark.security.SocketAuthServer$$anonfun$1.apply(SocketAuthServer.scala:48)
	at org.apache.spark.security.SocketAuthServer$$anonfun$1.apply(SocketAuthServer.scala:47)
	at org.apache.spark.security.SocketAuthServer$$anon$1.run(SocketAuthServer.scala:102)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/worker.py", line 480, in main
    process()
  File "/databricks/spark/python/pyspark/worker.py", line 472, in process
    serializer.dump_stream(out_iter, outfile)
  File "/databricks/spark/python/pyspark/serializers.py", line 408, in dump_stream
    timely_flush_timeout_ms=self.timely_flush_timeout_ms)
  File "/databricks/spark/python/pyspark/serializers.py", line 215, in dump_stream
    for batch in iterator:
  File "/databricks/spark/python/pyspark/serializers.py", line 398, in init_stream_yield_batches
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/databricks/spark/python/pyspark/worker.py", line 136, in <lambda>
    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
  File "/databricks/spark/python/pyspark/worker.py", line 121, in wrapped
    result = f(pd.concat(value_series, axis=1))
  File "/databricks/spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-11-9662324b31f3>", line 30, in train_model
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/fluent.py", line 122, in start_run
    active_run_obj = MlflowClient().get_run(existing_run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/client.py", line 92, in get_run
    return self._tracking_client.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/_tracking_service/client.py", line 48, in get_run
    return self.store.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 92, in get_run
    response_proto = self._call_endpoint(GetRun, req_body)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 32, in _call_endpoint
    return call_endpoint(self.get_host_creds(), endpoint, method, json_body, response_proto)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 137, in call_endpoint
    response = verify_rest_response(response, endpoint)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 103, in verify_rest_response
    raise RestException(json.loads(response.text))
mlflow.exceptions.RestException: RESOURCE_DOES_NOT_EXIST: Run 'd7cdceac941142b099cc3b44204c51de' not found.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:534)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:194)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:144)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:488)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.hasNext(ArrowConverters.scala:116)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
	at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
	at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
	at scala.collection.AbstractIterator.to(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1334)
	at org.apache.spark.sql.Dataset$$anonfun$collectAsArrowToPython$1$$anonfun$apply$20$$anonfun$apply$21.apply(Dataset.scala:3442)
	at org.apache.spark.sql.Dataset$$anonfun$collectAsArrowToPython$1$$anonfun$apply$20$$anonfun$apply$21.apply(Dataset.scala:3442)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:140)
	at org.apache.spark.scheduler.Task.run(Task.scala:113)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$13.apply(Executor.scala:533)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1541)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:539)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)


Join the original data with the new DataFrame so we can use `model_path`

In [13]:
# Left join keeps every row of df and attaches the columns of
# modelDirectoriesDF to it, matched on device_id.
combinedDF = df.join(modelDirectoriesDF, on="device_id", how="left")

combinedDF.show()

Accordion(children=(VBox(),), layout=Layout(display='none'), selected_index=None)

Py4JJavaError: An error occurred while calling o1465.showString.
: org.apache.spark.SparkException: Exception thrown in Future.get: 
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:195)
	at org.apache.spark.sql.execution.InputAdapter.doExecuteBroadcast(WholeStageCodegenExec.scala:391)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeBroadcast$1.apply(SparkPlan.scala:168)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeBroadcast$1.apply(SparkPlan.scala:156)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$5.apply(SparkPlan.scala:188)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:184)
	at org.apache.spark.sql.execution.SparkPlan.executeBroadcast(SparkPlan.scala:156)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.prepareBroadcast(BroadcastHashJoinExec.scala:134)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.doProduce(BroadcastHashJoinExec.scala:111)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:94)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:89)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$5.apply(SparkPlan.scala:188)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:184)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:89)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.produce(BroadcastHashJoinExec.scala:39)
	at org.apache.spark.sql.execution.ProjectExec.doProduce(basicPhysicalOperators.scala:50)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:94)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:89)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$5.apply(SparkPlan.scala:188)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:184)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:89)
	at org.apache.spark.sql.execution.ProjectExec.produce(basicPhysicalOperators.scala:40)
	at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:548)
	at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:602)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:147)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:135)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$5.apply(SparkPlan.scala:188)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:184)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:135)
	at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:77)
	at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:86)
	at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:508)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollectResult(limit.scala:57)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectResult(Dataset.scala:2890)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3508)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2619)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2619)
	at org.apache.spark.sql.Dataset$$anonfun$54.apply(Dataset.scala:3492)
	at org.apache.spark.sql.Dataset$$anonfun$54.apply(Dataset.scala:3487)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withCustomExecutionEnv$1.apply(SQLExecution.scala:112)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:241)
	at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:98)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:171)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withAction(Dataset.scala:3487)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2619)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2833)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:266)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:303)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
	at py4j.Gateway.invoke(Gateway.java:295)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:251)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.util.concurrent.ExecutionException: org.apache.spark.SparkException: Job aborted due to stage failure: Task 49 in stage 840.0 failed 4 times, most recent failure: Lost task 49.3 in stage 840.0 (TID 3189, 10.0.228.150, executor 3): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/worker.py", line 480, in main
    process()
  File "/databricks/spark/python/pyspark/worker.py", line 472, in process
    serializer.dump_stream(out_iter, outfile)
  File "/databricks/spark/python/pyspark/serializers.py", line 408, in dump_stream
    timely_flush_timeout_ms=self.timely_flush_timeout_ms)
  File "/databricks/spark/python/pyspark/serializers.py", line 215, in dump_stream
    for batch in iterator:
  File "/databricks/spark/python/pyspark/serializers.py", line 398, in init_stream_yield_batches
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/databricks/spark/python/pyspark/worker.py", line 136, in <lambda>
    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
  File "/databricks/spark/python/pyspark/worker.py", line 121, in wrapped
    result = f(pd.concat(value_series, axis=1))
  File "/databricks/spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-11-9662324b31f3>", line 30, in train_model
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/fluent.py", line 122, in start_run
    active_run_obj = MlflowClient().get_run(existing_run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/client.py", line 92, in get_run
    return self._tracking_client.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/_tracking_service/client.py", line 48, in get_run
    return self.store.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 92, in get_run
    response_proto = self._call_endpoint(GetRun, req_body)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 32, in _call_endpoint
    return call_endpoint(self.get_host_creds(), endpoint, method, json_body, response_proto)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 137, in call_endpoint
    response = verify_rest_response(response, endpoint)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 103, in verify_rest_response
    raise RestException(json.loads(response.text))
mlflow.exceptions.RestException: RESOURCE_DOES_NOT_EXIST: Run '6c9fd0ce3c6744339db10a64cfd5487c' not found.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:534)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:194)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:144)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:488)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:640)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:62)
	at org.apache.spark.sql.execution.collect.Collector$$anonfun$1.apply(Collector.scala:151)
	at org.apache.spark.sql.execution.collect.Collector$$anonfun$1.apply(Collector.scala:150)
	at org.apache.spark.SparkContext$$anonfun$41.apply(SparkContext.scala:2377)
	at org.apache.spark.SparkContext$$anonfun$41.apply(SparkContext.scala:2377)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:140)
	at org.apache.spark.scheduler.Task.run(Task.scala:113)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$13.apply(Executor.scala:533)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1541)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:539)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at java.util.concurrent.FutureTask.report(FutureTask.java:122)
	at java.util.concurrent.FutureTask.get(FutureTask.java:206)
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:182)
	... 62 more
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Task 49 in stage 840.0 failed 4 times, most recent failure: Lost task 49.3 in stage 840.0 (TID 3189, 10.0.228.150, executor 3): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/worker.py", line 480, in main
    process()
  File "/databricks/spark/python/pyspark/worker.py", line 472, in process
    serializer.dump_stream(out_iter, outfile)
  File "/databricks/spark/python/pyspark/serializers.py", line 408, in dump_stream
    timely_flush_timeout_ms=self.timely_flush_timeout_ms)
  File "/databricks/spark/python/pyspark/serializers.py", line 215, in dump_stream
    for batch in iterator:
  File "/databricks/spark/python/pyspark/serializers.py", line 398, in init_stream_yield_batches
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/databricks/spark/python/pyspark/worker.py", line 136, in <lambda>
    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
  File "/databricks/spark/python/pyspark/worker.py", line 121, in wrapped
    result = f(pd.concat(value_series, axis=1))
  File "/databricks/spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-11-9662324b31f3>", line 30, in train_model
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/fluent.py", line 122, in start_run
    active_run_obj = MlflowClient().get_run(existing_run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/client.py", line 92, in get_run
    return self._tracking_client.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/_tracking_service/client.py", line 48, in get_run
    return self.store.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 92, in get_run
    response_proto = self._call_endpoint(GetRun, req_body)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 32, in _call_endpoint
    return call_endpoint(self.get_host_creds(), endpoint, method, json_body, response_proto)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 137, in call_endpoint
    response = verify_rest_response(response, endpoint)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 103, in verify_rest_response
    raise RestException(json.loads(response.text))
mlflow.exceptions.RestException: RESOURCE_DOES_NOT_EXIST: Run '6c9fd0ce3c6744339db10a64cfd5487c' not found.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:534)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:194)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:144)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:488)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:640)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:62)
	at org.apache.spark.sql.execution.collect.Collector$$anonfun$1.apply(Collector.scala:151)
	at org.apache.spark.sql.execution.collect.Collector$$anonfun$1.apply(Collector.scala:150)
	at org.apache.spark.SparkContext$$anonfun$41.apply(SparkContext.scala:2377)
	at org.apache.spark.SparkContext$$anonfun$41.apply(SparkContext.scala:2377)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:140)
	at org.apache.spark.scheduler.Task.run(Task.scala:113)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$13.apply(Executor.scala:533)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1541)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:539)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2362)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2350)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2349)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2349)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:1102)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:1102)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1102)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2581)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2529)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2517)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:897)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2280)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2378)
	at org.apache.spark.sql.execution.collect.Collector.runSparkJobs(Collector.scala:245)
	at org.apache.spark.sql.execution.collect.Collector.collect(Collector.scala:280)
	at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:80)
	at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:86)
	at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:508)
	at org.apache.spark.sql.execution.ResultCacheManager.getOrComputeResult(ResultCacheManager.scala:480)
	at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:325)
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:90)
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:78)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withExecutionId$1.apply(SQLExecution.scala:196)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:241)
	at org.apache.spark.sql.execution.SQLExecution$.withExecutionId(SQLExecution.scala:193)
	at org.apache.spark.sql.execution.SQLExecution$.dbrWithExecutionId(SQLExecution.scala:216)
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:77)
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:73)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable$$anonfun$run$1.apply$mcV$sp(SparkThreadLocalForwardingThreadPoolExecutor.scala:100)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable$$anonfun$run$1.apply(SparkThreadLocalForwardingThreadPoolExecutor.scala:100)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable$$anonfun$run$1.apply(SparkThreadLocalForwardingThreadPoolExecutor.scala:100)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingHelper$class.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:68)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:97)
	at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.run(SparkThreadLocalForwardingThreadPoolExecutor.scala:100)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/worker.py", line 480, in main
    process()
  File "/databricks/spark/python/pyspark/worker.py", line 472, in process
    serializer.dump_stream(out_iter, outfile)
  File "/databricks/spark/python/pyspark/serializers.py", line 408, in dump_stream
    timely_flush_timeout_ms=self.timely_flush_timeout_ms)
  File "/databricks/spark/python/pyspark/serializers.py", line 215, in dump_stream
    for batch in iterator:
  File "/databricks/spark/python/pyspark/serializers.py", line 398, in init_stream_yield_batches
    for series in iterator:
  File "<string>", line 1, in <lambda>
  File "/databricks/spark/python/pyspark/worker.py", line 136, in <lambda>
    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
  File "/databricks/spark/python/pyspark/worker.py", line 121, in wrapped
    result = f(pd.concat(value_series, axis=1))
  File "/databricks/spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-11-9662324b31f3>", line 30, in train_model
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/fluent.py", line 122, in start_run
    active_run_obj = MlflowClient().get_run(existing_run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/client.py", line 92, in get_run
    return self._tracking_client.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/tracking/_tracking_service/client.py", line 48, in get_run
    return self.store.get_run(run_id)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 92, in get_run
    response_proto = self._call_endpoint(GetRun, req_body)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/store/tracking/rest_store.py", line 32, in _call_endpoint
    return call_endpoint(self.get_host_creds(), endpoint, method, json_body, response_proto)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 137, in call_endpoint
    response = verify_rest_response(response, endpoint)
  File "/databricks/python/lib/python3.7/site-packages/mlflow/utils/rest_utils.py", line 103, in verify_rest_response
    raise RestException(json.loads(response.text))
mlflow.exceptions.RestException: RESOURCE_DOES_NOT_EXIST: Run '6c9fd0ce3c6744339db10a64cfd5487c' not found.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:534)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:194)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:144)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:488)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:640)
	at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:62)
	at org.apache.spark.sql.execution.collect.Collector$$anonfun$1.apply(Collector.scala:151)
	at org.apache.spark.sql.execution.collect.Collector$$anonfun$1.apply(Collector.scala:150)
	at org.apache.spark.SparkContext$$anonfun$41.apply(SparkContext.scala:2377)
	at org.apache.spark.SparkContext$$anonfun$41.apply(SparkContext.scala:2377)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.doRunTask(Task.scala:140)
	at org.apache.spark.scheduler.Task.run(Task.scala:113)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$13.apply(Executor.scala:533)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1541)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:539)
	... 3 more


Define a pandas UDF to apply the model. **Note: this needs only one read from DBFS per device.**

In [14]:
# Output schema of the grouped-map pandas UDF that applies the per-device models.
# `record_id` and `device_id` both originate from the 64-bit `id` column of `spark.range(...)`
# in the source DataFrame, so both are declared as LongType; declaring `record_id` as
# IntegerType would force a narrowing cast to int32.
applyReturnSchema = t.StructType([
  t.StructField('record_id', t.LongType()),
  t.StructField('device_id', t.LongType()),
  # NOTE(review): presumably the model emits float64 predictions that are down-cast to
  # float32 by this FloatType -- confirm that the reduced precision is intended.
  t.StructField('prediction', t.FloatType())
])

@f.pandas_udf(applyReturnSchema, functionType=f.PandasUDFType.GROUPED_MAP)
def apply_model(df_pandas):
  '''
  Scores all records of one group with the model stored at that group's `model_path`.

  The model is loaded once per call, i.e. once per group of rows handed to this UDF,
  not once per record.

  :param df_pandas: pandas DataFrame holding the rows of one group. It must contain the
    columns `device_id`, `model_path`, `record_id`, `feature_1`, `feature_2` and `feature_3`.
  :return: pandas DataFrame with the columns `record_id`, `device_id` and `prediction`,
    in that order.
  '''
  # Every row of the group shares the same device and model path, so the first row suffices.
  device_id = df_pandas['device_id'].iloc[0]
  model_path = df_pandas['model_path'].iloc[0]

  input_columns = ['feature_1', 'feature_2', 'feature_3']
  X = df_pandas[input_columns]

  model = mlflow.sklearn.load_model(model_path)
  prediction = model.predict(X)

  returnDF = pd.DataFrame({
    "record_id": df_pandas['record_id'],
    "prediction": prediction
  })
  returnDF["device_id"] = device_id

  # Emit the columns in the same order as the declared output schema so the result is
  # correct whether Spark matches the returned columns by name or by position.
  return returnDF[["record_id", "device_id", "prediction"]]

# Apply the per-device model: `apply_model` is invoked once per `device_id` group.
predictionDF = (
  combinedDF
  .groupBy("device_id")
  .apply(apply_model)
)
display(predictionDF)

record_id,device_id,prediction
2230,0,2.7645466
2240,0,3.4732797
2250,0,1.1655287
2260,0,1.8476894
2270,0,4.6743846
2280,0,2.6057968
2290,0,3.1509695
2300,0,3.2745278
2310,0,3.583009
2320,0,2.7822325


You made it to the end!  Here's a picture of a unicorn:

![unicorn](https://www.jing.fm/clipimg/detail/37-375094_galaxy-unicorn-png.png)