# Simple Logistic Regression with MLflow

In [1]:
# Prerequisites
from pyspark.sql import SparkSession

In [2]:
# Build (or reuse) a SparkSession on the "local" master, requesting the
# MLflow Spark package through spark.jars.packages.
spark = (
    SparkSession.builder
    .master("local")
    .appName("Simple MLFlow Linear Regression")
    .config("spark.jars.packages", "org.mlflow:mlflow-spark:2.5.0")
    .getOrCreate()
)
print("Spark Version: ", spark.version)

Spark Version:  3.4.1


In [4]:
import mlflow
import mlflow.spark

# Point the MLflow client at the tracking server and select the experiment
# that subsequent runs are recorded under.
TRACKING_URI = "http://localhost:5000"  # Set your MLflow tracking server URI
EXPERIMENT_NAME = "PySpark_MLflow_Experiment"

mlflow.set_tracking_uri(TRACKING_URI)
mlflow.set_experiment(EXPERIMENT_NAME)

# Example PySpark workflow with MLflow
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler

# Sample data: one integer label followed by three numeric features per row
data = [(0, 1.0, 2.0, 3.0), (1, 2.0, 3.5, 4.0), (0, 3.0, 4.0, 5.0)]
columns = ["label", "feature1", "feature2", "feature3"]
df = spark.createDataFrame(data, columns)

# Assemble the three feature columns into the single vector column that
# Spark ML estimators read from.
assembler = VectorAssembler(inputCols=["feature1", "feature2", "feature3"], outputCol="features")
df = assembler.transform(df)

# Logistic regression estimator reading the assembled "features" column
lr = LogisticRegression(featuresCol="features", labelCol="label")

# Train inside an MLflow run so the parameter, metric and model artifact are
# all attached to the same run.
with mlflow.start_run():
    model = lr.fit(df)

    # Log the hyperparameter that was actually used for this fit
    mlflow.log_param("maxIter", lr.getMaxIter())

    # Log the accuracy measured on the training data instead of a hard-coded
    # placeholder, so the tracked metric reflects the fitted model.
    training_accuracy = model.summary.accuracy
    mlflow.log_metric("training_accuracy", training_accuracy)

    # Log the PySpark MLlib model as a run artifact
    mlflow.spark.log_model(model, "logistic_regression_model")

    print("Model logged successfully!")



Model logged successfully!
🏃 View run zealous-auk-814 at: http://localhost:5000/#/experiments/805412585363482067/runs/c864456f969c4992b4aa730227460f4b
🧪 View experiment at: http://localhost:5000/#/experiments/805412585363482067


----------------------------------------
Exception occurred during processing of request from ('127.0.0.1', 57642)
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/conda/lib/python3.11/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/opt/conda/lib/python3.11/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/conda/lib/python3.11/socketserver.py", line 755, in __init__
    self.handle()
  File "/usr/local/spark/python/pyspark/accumulators.py", line 281, in handle
    poll(accum_updates)
  File "/usr/local/spark/python/pyspark/accumulators.py", line 253, in poll
    if func():
       ^^^^^^
  File "/usr/local/spark/python/pyspark/accumulators.py", line 257, in accum_updates
    num_updates = read_int(self.rfile)
                  