# Databricks Distributed Grid Search - 2,800 Hyperparameter Combinations Tuning

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator
import time, json, os, hashlib
from datetime import datetime
from typing import Dict, List, Tuple
import mlflow
from mlflow.tracking import MlflowClient

class GridSearchCheckpoint:
    """Resume support for a grid search, backed by an MLflow experiment.

    Every evaluated hyperparameter combination is logged as one MLflow run
    tagged with ``param_hash`` (a short digest of the parameters) and carrying
    an ``r2_score`` metric. On restart, ``get_completed`` reads those runs back
    so the caller can skip combinations that were already evaluated.
    """

    # Seconds a fetched result set is reused before MLflow is queried again.
    CACHE_TTL_SECONDS = 300

    def __init__(self, experiment_name: str):
        """Create or look up the MLflow experiment used as the checkpoint store.

        Args:
            experiment_name: MLflow experiment name/path to create or reuse.

        Raises:
            SystemExit: if the experiment can be neither created nor found.
        """
        self.client = MlflowClient()

        # Check MLflow configuration
        tracking_uri = mlflow.get_tracking_uri()
        print(f"MLflow tracking URI: {tracking_uri}")

        # Safely create or get experiment
        try:
            # set_experiment creates experiment if not exists and returns experiment
            mlflow.set_experiment(experiment_name)
            self.experiment = mlflow.get_experiment_by_name(experiment_name)

            # Verify experiment was created/found
            if self.experiment is None:
                raise ValueError(f"Failed to create/find experiment: {experiment_name}")

            print(f"MLflow experiment ready: {experiment_name} (ID: {self.experiment.experiment_id})")

        except Exception as e:
            print(f"MLflow experiment setup failed: {str(e)}")
            print("Attempting to create experiment directly...")

            try:
                # Try creating experiment directly with client
                experiment_id = self.client.create_experiment(experiment_name)
                self.experiment = self.client.get_experiment(experiment_id)
            except Exception as create_error:
                print(f"Direct experiment creation failed: {str(create_error)}")
                # Try to get existing experiment by searching
                experiments = self.client.search_experiments()
                matching_exp = next((exp for exp in experiments if exp.name == experiment_name), None)

                if matching_exp:
                    self.experiment = matching_exp
                    mlflow.set_experiment(experiment_name)
                else:
                    raise SystemExit(
                        f"Cannot initialize MLflow experiment '{experiment_name}'. Please check MLflow configuration."
                    ) from create_error

        # _cache maps param_hash -> r2_score; a _cache_time of 0 forces the
        # first get_completed() call to query MLflow.
        self._cache, self._cache_time = {}, 0

    def get_hash(self, params: Dict) -> str:
        """Return an 8-hex-character digest identifying a parameter combination.

        Keys are sorted before hashing, so the digest does not depend on the
        insertion order of ``params``.
        """
        return hashlib.md5(json.dumps(params, sort_keys=True).encode()).hexdigest()[:8]

    def get_completed(self) -> Dict[str, float]:
        """Return ``{param_hash: r2_score}`` for combinations already evaluated.

        Only FINISHED runs with ``r2_score > 0`` are returned, so combinations
        that scored at or below zero are treated as not completed. Results are
        reused for ``CACHE_TTL_SECONDS`` to avoid querying MLflow on every call.
        """
        if time.time() - self._cache_time < self.CACHE_TTL_SECONDS:
            return self._cache

        runs = self.client.search_runs(
            experiment_ids=[self.experiment.experiment_id],
            filter_string="status = 'FINISHED' and metrics.r2_score > 0", max_results=5000)
        self._cache = {
            run.data.tags['param_hash']: run.data.metrics.get('r2_score', 0.0)
            for run in runs
            if 'param_hash' in run.data.tags
        }
        self._cache_time = time.time()
        return self._cache

    def log_result(self, params: Dict, r2_score: float, exec_time: float):
        """Record one evaluated combination as an MLflow run.

        Args:
            params: hyperparameter name -> value for the combination.
            r2_score: cross-validated R2 obtained for it.
            exec_time: wall-clock seconds spent evaluating it.
        """
        param_hash = self.get_hash(params)
        # Pin the run to this checkpoint's experiment. The fallback paths in
        # __init__ do not always make it the globally active experiment, and a
        # run logged elsewhere would be invisible to get_completed().
        with mlflow.start_run(experiment_id=self.experiment.experiment_id,
                              run_name=f"gs_{param_hash}"):
            mlflow.log_params(params)
            mlflow.log_metrics({"r2_score": r2_score, "execution_time": exec_time})
            mlflow.set_tags({"param_hash": param_hash, "job": "grid_search"})

        # Keep the in-memory view in step with what was just logged, applying
        # the same r2_score > 0 rule that get_completed() uses in its query.
        if r2_score > 0:
            self._cache[param_hash] = r2_score

In [0]:
# MLflow Debug Function
def check_mlflow_status():
    """Print basic MLflow diagnostics and report whether they succeeded.

    Prints the MLflow version, the tracking URI and up to three experiments
    returned by the client.

    Returns:
        True when every step ran without raising, False otherwise (the
        failure is printed, not re-raised).
    """
    try:
        # Imported here so that a missing MLflow install is reported as a
        # failed check rather than raising.
        import mlflow

        print(f"MLflow version: {mlflow.__version__}")
        print(f"Tracking URI: {mlflow.get_tracking_uri()}")

        # Listing a few experiments doubles as the actual access test.
        visible = mlflow.tracking.MlflowClient().search_experiments(max_results=3)
        print(f"Accessible experiments: {len(visible)}")
        for experiment in visible[:3]:
            print(f"  - {experiment.name} (ID: {experiment.experiment_id})")
    except Exception as err:
        print(f"MLflow check failed: {str(err)}")
        return False
    else:
        return True

# Uncomment to run MLflow diagnostic
# check_mlflow_status()

In [0]:
# Build (or reuse) the Spark session used by every later cell.
spark = SparkSession.builder \
    .appName("Distributed_RandomForest_GridSearch") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .getOrCreate()

# Create checkpoint only if not exists
# Replace the path with your own path
# NOTE(review): the experiment name embeds today's date, so a restart on a
# later day resolves to a different experiment and will not see runs logged
# earlier - confirm this is intended for multi-day searches.
if 'checkpoint' not in globals() or checkpoint is None:
    checkpoint = GridSearchCheckpoint(f"/Users/user/RF_GridSearch_{datetime.now().strftime('%Y%m%d')}")
    print("New checkpoint created")
else:
    print("Using existing checkpoint (cache preserved)")

# A quarter of the cluster's default parallelism, clamped to the range 1..5.
optimal_parallelism = min(max(1, spark.sparkContext.defaultParallelism // 4), 5)
print(f"System initialization complete (parallelism: {optimal_parallelism})")

# Check table existence
table_name = "df_selected_05"
target_col = "SalePrice"
# Reset on every run so the error report below can never show a table list
# left over from an earlier execution of this cell.
available_tables = None
try:
    # Check if table exists in catalog
    available_tables = [table.name for table in spark.catalog.listTables()]
    if table_name not in available_tables:
        raise ValueError(f"Table '{table_name}' does not exist")

    df = spark.table(table_name)

    # Check required columns before caching/counting: this only needs the
    # schema, so an unusable table is rejected without materializing it.
    if target_col not in df.columns:
        raise ValueError(f"Required column '{target_col}' not found in table")

    # Load data
    df = df.cache()
    row_count = df.count()
    col_count = len(df.columns)
    print(f"Data loaded successfully: {row_count:,} rows, {col_count} columns")

except Exception as e:
    print(f"Data loading failed: {str(e)}")
    if available_tables is not None:
        print(f"Available tables: {available_tables[:5]}")
    else:
        print("Failed to retrieve table list")
    print("Solutions:")
    print(f"1. Verify table '{table_name}' is uploaded to Databricks")
    print("2. Check table name spelling")
    print("3. Verify table access permissions")
    raise SystemExit("Cannot proceed with grid search without actual data") from e

# Every column except the target and "id" becomes a feature.
# handleInvalid="skip" drops rows containing invalid values instead of failing.
feature_cols = [c for c in df.columns if c not in [target_col, "id"]]
df_vectorized = VectorAssembler(inputCols=feature_cols,
                                outputCol="features", handleInvalid="skip").transform(df).cache()
print("Feature vectorization complete")

In [0]:
# Set Spark checkpoint directory before using .checkpoint()
# Set Spark checkpoint directory before using .checkpoint()
spark.sparkContext.setCheckpointDir("/dbfs/tmp/spark_checkpoints")

# Cartesian product of the value lists below: 7 * 4 * 5 * 5 * 4 = 2,800 maps.
# The Params are referenced on the class (not on an estimator instance), so
# each map is effectively a set of name/value pairs; the code that consumes
# them keys on `param.name`.
# maxDepth is limited to values <= 30: Spark ML tree ensembles reject deeper
# trees, so the previous entries 35, 50 and 100 could never train.
param_grid = ParamGridBuilder() \
    .addGrid(RandomForestRegressor.numTrees, [500, 600, 646, 650, 700, 750, 800]) \
    .addGrid(RandomForestRegressor.maxDepth, [15, 20, 25, 30]) \
    .addGrid(RandomForestRegressor.minInstancesPerNode, [2, 3, 4, 5, 6]) \
    .addGrid(RandomForestRegressor.subsamplingRate, [0.65, 0.69, 0.7, 0.8, 0.9]) \
    .addGrid(RandomForestRegressor.featureSubsetStrategy, ["0.3", "0.5", "0.7", "sqrt"]) \
    .build()

# Skip every combination whose parameter hash already has a logged result.
completed = checkpoint.get_completed()
remaining_params = [p for p in param_grid
                    if checkpoint.get_hash({param.name: value for param, value in p.items()}) not in completed]

print(f"Total {len(param_grid):,} combinations")
if completed:
    print(f"Restart: {len(param_grid) - len(remaining_params)} completed, {len(remaining_params)} remaining")

def run_grid_search(param_maps: List, data, target: str) -> Tuple[Dict, float]:
    """Evaluate each parameter map with 5-fold cross-validation and log the result.

    Each combination is fitted in its own CrossValidator and recorded through
    the module-level ``checkpoint`` as soon as it finishes, so an interrupted
    search can resume. A failing combination is reported and skipped.

    Args:
        param_maps: parameter maps (``{Param: value}``) to evaluate.
        data: DataFrame with a ``features`` column and the ``target`` column.
        target: name of the label column.

    Returns:
        ``(best_params, best_score)``: the name -> value dict of the best
        combination and its mean cross-validated R2. If no combination
        succeeds the result is ``({}, -1.0)``.

    Relies on the module-level names ``checkpoint`` and ``optimal_parallelism``.
    """
    total = len(param_maps)
    print(f"Grid search started: {total} combinations")
    best_score, best_params = -1.0, {}
    evaluator = RegressionEvaluator(labelCol=target, predictionCol="prediction", metricName="r2")
    search_start = time.time()

    for i, param_map in enumerate(param_maps, 1):
        start_time = time.time()
        try:
            params_dict = {param.name: value for param, value in param_map.items()}
            rf = RandomForestRegressor(featuresCol="features", labelCol=target, seed=42)
            # Re-key the map onto Params owned by this estimator. Param maps are
            # matched by (owner, name), so a map built from Params that belong to
            # another object is ignored; and assigning the values with setattr()
            # would overwrite the Param descriptors instead of setting them.
            # Either way the forest would silently train with default settings.
            bound_map = {rf.getParam(name): value for name, value in params_dict.items()}

            # NOTE(review): with a single param map per CrossValidator,
            # `parallelism` probably has little to parallelize - confirm.
            cv = CrossValidator(estimator=rf, estimatorParamMaps=[bound_map], evaluator=evaluator,
                                numFolds=5, seed=42, parallelism=optimal_parallelism, collectSubModels=False)

            r2_score = cv.fit(data).avgMetrics[0]
            checkpoint.log_result(params_dict, r2_score, time.time() - start_time)

            # `not best_params` lets the first success win even if its R2 is
            # below the -1.0 placeholder.
            if not best_params or r2_score > best_score:
                best_score, best_params = r2_score, params_dict.copy()
        except Exception as e:
            # Include the message: the exception type alone rarely says which
            # parameter value was rejected.
            print(f"Combination {i} failed: {type(e).__name__}: {e}")

        if i % 100 == 0:
            # Estimate from the average time per combination so far rather than
            # from the duration of the most recent one.
            avg_seconds = (time.time() - search_start) / i
            print(f"{i}/{total} ({i/total*100:.1f}%) | Remaining time: {(avg_seconds * (total - i)) / 3600:.1f}h")

    print(f"Best performance: {best_score:.6f}")
    return best_params, best_score

def emergency_cleanup():
    """Best-effort cleanup intended for use after an interrupted run.

    Clears Spark's cache, then reports how many completed combinations are
    held in the checkpoint's in-memory cache. Nothing is written here; the
    second step only prints a count. Failures in either step are printed and
    do not propagate.
    """
    try:
        print("🧹 Clearing Spark caches...")
        spark.catalog.clearCache()
        print("✅ Spark caches cleared")
    except Exception as cache_error:
        print(f"⚠️ Cache cleanup failed: {cache_error}")

    try:
        # Missing checkpoint, missing _cache and empty _cache all end up falsy.
        cached_results = getattr(globals().get('checkpoint'), '_cache', None)
        if cached_results:
            completed_count = len(cached_results)
            print(f"💾 Progress saved: {completed_count} combinations completed")
    except Exception as report_error:
        print(f"⚠️ Progress save failed: {report_error}")

    print("🔄 Resources cleaned up - safe to restart")

# Safety switch: the search only runs when this is set to True.
EXECUTE = False

if EXECUTE:
    start_time = datetime.now()

    try:
        print("🚀 Starting grid search with safety mechanisms...")
        # checkpoint() returns the checkpointed DataFrame and leaves the
        # original untouched, so the returned object is the one to train on.
        training_data = df_vectorized.checkpoint()
        best_params, best_score = run_grid_search(remaining_params, training_data, target_col)
        execution_time = datetime.now() - start_time
        print(f"Complete! Execution time: {execution_time}")
        print(f"Best R2: {best_score:.6f}")

        results_path = f"/dbfs/mnt/results/grid_search_{start_time.strftime('%Y%m%d_%H%M%S')}.json"
        os.makedirs(os.path.dirname(results_path), exist_ok=True)
        with open(results_path, "w") as f:
            json.dump({"execution_time_seconds": execution_time.total_seconds(),
                       "best_r2_score": float(best_score), "best_params": best_params,
                       "completion_time": datetime.now().isoformat()}, f, indent=2)
        print(f"Results saved: {results_path}")

    except (KeyboardInterrupt, Exception) as e:
        execution_time = datetime.now() - start_time
        print(f"\n💥 Grid search terminated after {execution_time}")
        print(f"Error: {type(e).__name__}: {str(e) if str(e) else 'User interruption'}")

        # Save partial results
        if 'checkpoint' in globals() and hasattr(checkpoint, '_cache'):
            partial_results_path = f"/dbfs/mnt/results/partial_grid_search_{start_time.strftime('%Y%m%d_%H%M%S')}.json"
            try:
                try:
                    # Expire the cache so the count reflects runs logged during
                    # this session, not the snapshot taken before it started.
                    checkpoint._cache_time = 0
                    completed_count = len(checkpoint.get_completed())
                except Exception as refresh_error:
                    print(f"⚠️ Could not refresh completed runs: {refresh_error}")
                    completed_count = len(checkpoint._cache)

                os.makedirs(os.path.dirname(partial_results_path), exist_ok=True)
                with open(partial_results_path, "w") as f:
                    json.dump({
                        "execution_time_seconds": execution_time.total_seconds(),
                        "completed_combinations": completed_count,
                        "total_combinations": len(param_grid),
                        "completion_percentage": completed_count / len(param_grid) * 100,
                        "termination_reason": type(e).__name__,
                        "termination_time": datetime.now().isoformat()
                    }, f, indent=2)
                print(f"📊 Partial progress saved: {partial_results_path}")
            except Exception as save_error:
                print(f"⚠️ Failed to save partial results: {save_error}")

        # Re-raise to maintain error behavior
        raise

    finally:
        # Always cleanup, regardless of success or failure
        print("🧹 Final cleanup...")
        spark.catalog.clearCache()
        print("Grid search complete - resources cleaned")
        # NOTE(review): stopping the session here means later cells cannot use
        # `spark` without recreating it - confirm this is wanted on a shared
        # notebook cluster.
        spark.stop()

else:
    print("Execution disabled: Set EXECUTE = True to start")
    print("Data loaded and ready for analysis (Spark session remains active)")

This code runs a distributed grid search to tune the hyperparameters of a random forest model.
It was run on Databricks using the Spark and Hadoop engines, with MLflow for tracking and logging.
The resulting models reached R² scores of around 0.82, which is considerably lower than the previous random search.

