In [1]:
import os, warnings, sys, logging 
import mlflow 
import pandas as pd 
import numpy as np 
from datetime import date

#use the GPU-native implementation
from spark_rapids_ml.classification import RandomForestClassifier

from pyspark.sql import SparkSession

import pprint

from py4j.java_gateway import java_import

# ---------------------------------------------------------------------------
# Project / environment settings
# ---------------------------------------------------------------------------
USERNAME = os.environ["PROJECT_OWNER"]
DBNAME = "DEMO_" + USERNAME
CONNECTION_NAME = "go01-aw-dl"
STORAGE = os.environ["DATA_STORAGE"]
DATE = date.today()

RAPIDS_JAR = "/home/cdsw/rapids-4-spark_2.12-25.10.0.jar"

LOCAL_PACKAGES = "/home/cdsw/.local/lib/python3.10/site-packages"

# This is where the specific CUDA 12 NVRTC library lives
NVRTC_LIB_PATH = f"{LOCAL_PACKAGES}/nvidia/cuda_nvrtc/lib"
WRITABLE_CACHE_DIR = "/tmp/cupy_cache"

# NOTE(review): "spark.driverEnv.*" (set below, kept for compatibility) is not
# a documented Spark property, so it presumably never reaches the driver's
# Python process. Export the variable directly in this process as well;
# setdefault() leaves any value already provided by the environment untouched.
os.environ.setdefault("CUPY_CACHE_DIR", WRITABLE_CACHE_DIR)

# ---------------------------------------------------------------------------
# Spark session with the RAPIDS SQL plugin enabled
# ---------------------------------------------------------------------------
spark = (
    SparkSession.builder
    .appName("Spark-Rapids-32GB-Final")
    # RAPIDS plugin and GPU resource discovery
    .config("spark.jars", RAPIDS_JAR)
    .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
    .config("spark.task.resource.gpu.amount", 0.5)
    .config("spark.executor.resource.gpu.vendor", "nvidia.com")
    .config("spark.executor.resource.gpu.discoveryScript", "/home/cdsw/spark-rapids-ml/getGpusResources.sh")
    # Environment handed to the executors
    .config("spark.executorEnv.LD_LIBRARY_PATH", f"{NVRTC_LIB_PATH}:{os.environ.get('LD_LIBRARY_PATH', '')}")
    .config("spark.executorEnv.PYTHONPATH", LOCAL_PACKAGES)
    .config("spark.executorEnv.CUPY_CACHE_DIR", WRITABLE_CACHE_DIR)
    .config("spark.driverEnv.CUPY_CACHE_DIR", WRITABLE_CACHE_DIR)
    # Driver sizing
    .config("spark.driver.memory", "12g")
    .config("spark.driver.extraJavaOptions", f"-Djava.library.path={NVRTC_LIB_PATH}")
    .config("spark.driver.maxResultSize", "4g")
    # Executor sizing: one fixed executor with 2 cores and 1 GPU
    .config("spark.dynamicAllocation.enabled", "false")
    .config("spark.executor.cores", 2)
    .config("spark.executor.instances", 1)
    .config("spark.executor.heartbeatInterval", "60s")
    .config("spark.executor.memory", "10g")
    .config("spark.executor.resource.gpu.amount", 1)
    .config("spark.executor.memoryOverhead", "10g")
    # SQL behaviour
    .config("spark.sql.autoBroadcastJoinThreshold", -1)
    .config("spark.sql.broadcastTimeout", "1200")
    .config("spark.sql.cache.serializer", "com.nvidia.spark.ParquetCachedBatchSerializer")
    .config('spark.sql.shuffle.partitions', '200')
    .config("spark.network.timeout", "800s")
    # RAPIDS settings
    .config("spark.rapids.sql.enabled", "true")
    .config("spark.rapids.shims-provider-override", "com.nvidia.spark.rapids.shims.spark351.SparkShimServiceProvider")
    .config("spark.rapids.memory.pinnedPool.size", "4g")
    # Storage access and shuffle
    .config("spark.kerberos.access.hadoopFileSystems", "s3a://go01-demo/user/jprosser/spark-rapids-ml/")
    .config("spark.shuffle.service.enabled", "false")
    .config('spark.shuffle.file.buffer', '64k')
    .config('spark.shuffle.spill.compress', 'true')
    .config("spark.hadoop.fs.defaultFS", "s3a://go01-demo/")
    .getOrCreate()
)

spark.sparkContext.setLogLevel("WARN")

# View the underlying Java Spark Context
pprint.pprint(f"Java Context Object: {spark.sparkContext._jsc}")

# View the Spark Master (in CML, this usually points to the local container or YARN)
pprint.pprint(f"Master: {spark.sparkContext.master}")

# View the User running the session
pprint.pprint(f"Spark User: {spark.sparkContext.sparkUser()}")

# ---------------------------------------------------------------------------
# Runtime (session-level) RAPIDS / SQL settings
# ---------------------------------------------------------------------------
# Enable CollectLimit so that large datasets are collected on the GPU.
# Not worth it for small datasets.
spark.conf.set("spark.rapids.sql.exec.CollectLimitExec", "true")

# Let the GPU handle the random sampling of rows for large datasets
spark.conf.set("spark.rapids.sql.exec.SampleExec", "true")

# Allow more time for large broadcast joins (1200 s = 20 minutes).
# This repeats the value already passed to the session builder above.
spark.conf.set("spark.sql.broadcastTimeout", "1200")

# Use "ALL" instead of "NOT_ON_GPU" to log the placement of every operator.
spark.conf.set("spark.rapids.sql.explain", "NOT_ON_GPU")  # Only log when/why the GPU was not selected
spark.conf.set("spark.rapids.sql.variable.float.allow", "true")  # Allow float math

# Allow the GPU to cast instead of pushing back to CPU just for cast
spark.conf.set("spark.rapids.sql.castFloatToDouble.enabled", "true")
spark.conf.set("spark.rapids.sql.format.parquet.enabled", "true")

# Turning off Adaptive Query Execution (AQE) makes the entire SQL plan use the GPU
spark.conf.set("spark.sql.adaptive.enabled", "false")

spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 10000)

# ---------------------------------------------------------------------------
# Smoke tests
# ---------------------------------------------------------------------------
# Test if the JVM can actually talk to the CUDA driver
cuda_manager = spark._jvm.ai.rapids.cudf.Cuda
print(f"CUDA Driver Version: {cuda_manager.getDriverVersion()}")
print(f"Device Count: {cuda_manager.getDeviceCount()}")
print(f"Dynamic Allocation: {spark.conf.get('spark.dynamicAllocation.enabled')}")
print(f"Executor Instances: {spark.conf.get('spark.executor.instances')}")

# Test access to the SQLPlugin: instantiating it and asking for its driver
# component fails loudly if the RAPIDS jar is not on the driver classpath.
sql_plugin = spark._jvm.com.nvidia.spark.SQLPlugin()
driver_comp = sql_plugin.driverPlugin()

# Raise the RAPIDS logger to DEBUG through the log4j API exposed by the JVM.
log_manager = spark._jvm.org.apache.log4j.LogManager
level_debug = spark._jvm.org.apache.log4j.Level.DEBUG

rapids_logger = log_manager.getLogger("com.nvidia.spark.rapids")
rapids_logger.setLevel(level_debug)

# Report the logger's real state (the previous `... or True` always printed True).
print(f"Debug enabled for RAPIDS: {rapids_logger.isDebugEnabled()}")

Setting spark.hadoop.yarn.resourcemanager.principal to jprosser


'Java Context Object: org.apache.spark.api.java.JavaSparkContext@30440c9f'
'Master: k8s://https://172.20.0.1:443'
'Spark User: jprosser'
CUDA Driver Version: 13000
Device Count: 1
Dynamic Allocation: false
Executor Instances: 1
Dynamic Allocation Enabled: false
Debug enabled for RAPIDS: True


In [2]:
# Load the source table and report its width and schema.
df = spark.read.table("DataLakeTable")
column_count = len(df.columns)
print(f"Columns: {column_count}")
print(f"Schema: {df.schema}")

# Look for 'Gpu' operators in the output
df.limit(5).explain(mode="formatted")

# Names of the columns used as model inputs. They are passed to the
# classifier as a plain list of column names rather than being assembled into
# a single vector column: VectorAssembler is avoided because it creates
# VectorUDT data types that are not GPU friendly.
feature_cols = [
    "age",
    "credit_card_balance",
    "bank_account_balance",
    "mortgage_balance",
    "sec_bank_account_balance",
    "savings_account_balance",
    "sec_savings_account_balance",
    "total_est_nworth",
    "primary_loan_balance",
    "secondary_loan_balance",
    "uni_loan_balance",
    "longitude",
    "latitude",
    "transaction_amount",
]

26/02/13 18:54:45 WARN  client.HiveClientImpl: [Thread-6]: Detected HiveConf hive.execution.engine is 'tez' and will be reset to 'mr' to disable useless hive logic
Hive Session ID = bdfd90a7-4a16-416d-92c8-d747b15f423c


Columns: 15
Schema: StructType([StructField('age', FloatType(), True), StructField('credit_card_balance', FloatType(), True), StructField('bank_account_balance', FloatType(), True), StructField('mortgage_balance', FloatType(), True), StructField('sec_bank_account_balance', FloatType(), True), StructField('savings_account_balance', FloatType(), True), StructField('sec_savings_account_balance', FloatType(), True), StructField('total_est_nworth', FloatType(), True), StructField('primary_loan_balance', FloatType(), True), StructField('secondary_loan_balance', FloatType(), True), StructField('uni_loan_balance', FloatType(), True), StructField('longitude', FloatType(), True), StructField('latitude', FloatType(), True), StructField('transaction_amount', FloatType(), True), StructField('fraud_trx', IntegerType(), True)])
== Physical Plan ==
GpuColumnarToRow (6)
+- GpuGlobalLimit (5)
   +- GpuShuffleCoalesce (4)
      +- GpuColumnarExchange (3)
         +- GpuLocalLimit (2)
            +- GpuSc

In [3]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# One seed for both the split and the forest so a re-run of the notebook
# reproduces the same train/test partition and the same model.
SEED = 1234

# 80/20 train/test split.
(training_data, test_data) = df.randomSplit([0.8, 0.2], seed=SEED)

# Define the RAPIDS-native classifier. RandomForestClassifier is the
# spark_rapids_ml (GPU-native) implementation imported in the first cell.
# By using 'featuresCols' (a list of column names) we avoid VectorAssembler.
rf_classifier = RandomForestClassifier(
    labelCol="fraud_trx",
    featuresCols=feature_cols,
    numTrees=20,
    seed=SEED,  # previously unset, so every run trained a different forest
)

In [4]:
# Train the model
# This runs the training logic in C++ on the GPU via cuML
print("Training Spark RAPIDS ML model...")
rf_model = rf_classifier.fit(training_data)
print("Model training complete.")
print(type(rf_model))

# Pin the model's input columns BEFORE transforming. This call used to come
# after transform(), where it could no longer influence the predictions.
rf_model.setFeaturesCols(feature_cols)

# Predict and optimize the output
# We drop 'probability' and 'rawPrediction' because they are VectorUDT types
# that Spark SQL would otherwise force back to the CPU for formatting.
predictions = rf_model.transform(test_data).drop("probability", "rawPrediction")

# Show results (This will be fully accelerated)
predictions.select("prediction", "fraud_trx").show(5)

# Verify GPU Plan
# You should see 'GpuProject' and 'GpuFilter' nodes without the VectorUDT warning
predictions.explain(mode="formatted")

2026-02-13 18:55:02,955 - spark_rapids_ml.classification.RandomForestClassifier - INFO - Training spark-rapids-ml with 1 worker(s) ...


Training Spark RAPIDS ML model...


2026-02-13 18:55:04,084 - spark_rapids_ml.classification.RandomForestClassifier - INFO - Training tasks require the resource(cores=2, gpu=1.0)
26/02/13 18:55:04 WARN  scheduler.DAGScheduler: [dag-scheduler-event-loop]: Barrier stage in job 0 requires 1 slots, but only 0 are available. Will retry up to 40 more times
26/02/13 18:55:19 WARN  scheduler.DAGScheduler: [dag-scheduler-event-loop]: Barrier stage in job 0 requires 1 slots, but only 0 are available. Will retry up to 39 more times
26/02/13 18:55:34 WARN  scheduler.DAGScheduler: [dag-scheduler-event-loop]: Barrier stage in job 0 requires 1 slots, but only 0 are available. Will retry up to 38 more times
26/02/13 18:55:49 WARN  scheduler.DAGScheduler: [dag-scheduler-event-loop]: Barrier stage in job 0 requires 1 slots, but only 0 are available. Will retry up to 37 more times
2026-02-13 18:57:08,992 - spark_rapids_ml.classification.RandomForestClassifier - INFO - Finished training


Model training complete.
<class 'spark_rapids_ml.classification.RandomForestClassificationModel'>


26/02/13 18:57:09 WARN  util.SparkStringUtils: [Thread-6]: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

+----------+---------+
|prediction|fraud_trx|
+----------+---------+
|       0.0|        0|
|       0.0|        0|
|       0.0|        0|
|       0.0|        0|
|       0.0|        1|
+----------+---------+
only showing top 5 rows

== Physical Plan ==
GpuColumnarToRow (8)
+- GpuProject (7)
   +- GpuCoalesceBatches (6)
      +- GpuArrowEvalPython (5)
         +- GpuCoalesceBatches (4)
            +- GpuSample (3)
               +- GpuSort (2)
                  +- GpuScan parquet spark_catalog.default.datalaketable (1)


(1) GpuScan parquet spark_catalog.default.datalaketable
Output [15]: [age#0, credit_card_balance#1, bank_account_balance#2, mortgage_balance#3, sec_bank_account_balance#4, savings_account_balance#5, sec_savings_account_balance#6, total_est_nworth#7, primary_loan_balance#8, secondary_loan_balance#9, uni_loan_balance#10, longitude#11, latitude#12, transaction_amount#13, fraud_trx#14]
Batched: true
Eager_IO_Prefetch: false
Location: InMemoryFileIndex [s3a://go01-demo/wareho

In [5]:
from spark_rapids_ml.metrics.MulticlassMetrics import MulticlassClassificationEvaluator

# --- Accuracy computed directly with DataFrame operations -------------------
# Fraction of rows whose predicted class equals the label.
total_rows = predictions.count()
correct_rows = predictions.filter("prediction = fraud_trx").count()
# Guard the ratio so an empty prediction set yields NaN instead of raising
# ZeroDivisionError.
manual_accuracy = correct_rows / total_rows if total_rows else float("nan")

print(f"GPU-Accelerated Accuracy: {manual_accuracy:.4f}")

# --- Accuracy computed by the evaluator -------------------------------------
# NOTE(review): the recorded output of this cell shows a GpuOverrides warning
# (DeserializeToObjectExec cannot run on GPU) while evaluate() runs, so at
# least part of this path appears to execute on the CPU - confirm.
evaluator = MulticlassClassificationEvaluator(
    labelCol="fraud_trx",
    predictionCol="prediction",
    metricName="accuracy"
)
accuracy = evaluator.evaluate(predictions)
print(f"Test Accuracy: {accuracy:.4f}")

GPU-Accelerated Accuracy: 0.8999


26/02/13 18:57:21 WARN  rapids.GpuOverrides: [Thread-6]: 
! <DeserializeToObjectExec> cannot run on GPU because not all expressions can be replaced; GPU does not currently support the operator class org.apache.spark.sql.execution.DeserializeToObjectExec
  ! <CreateExternalRow> createexternalrow(staticinvoke(class java.lang.Double, ObjectType(class java.lang.Double), valueOf, prediction#173, true, false, true), staticinvoke(class java.lang.Double, ObjectType(class java.lang.Double), valueOf, fraud_trx#272, true, false, true), staticinvoke(class java.lang.Double, ObjectType(class java.lang.Double), valueOf, 1.0#273, true, false, true), StructField(prediction,DoubleType,true), StructField(fraud_trx,DoubleType,true), StructField(1.0,DoubleType,false)) cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow
    !Expression <StaticInvoke> staticinvoke(class java.lang.Double, ObjectType(class java.lang.Dou

Test Accuracy: 0.8999


In [6]:
import base64
import pickle

import cuml
import spark_rapids_ml
import treelite

# Report the library versions in use.
print(f"Treelite: {treelite.__version__}")
print(f"cuML:     {cuml.__version__}")
print(f"spark_rapids_ml:     {spark_rapids_ml.__version__}")

# 1. Take the encoded checkpoint stored on the fitted model.
treelite_model_checkpoint = rf_model._treelite_model

# 2. Decode from Base64 to pickle bytes.
pickle_bytes = base64.b64decode(treelite_model_checkpoint)

# 3. Unpickle to get the raw bytes that Treelite actually understands.
#    The payload comes from the model trained in this notebook; do not reuse
#    this step on data from an untrusted source, since unpickling can execute
#    arbitrary code.
raw_treelite_bytes = pickle.loads(pickle_bytes)

# 4. Deserialize those bytes into a Treelite model object.
treelite_model = treelite.Model.deserialize_bytes(raw_treelite_bytes)


Treelite: 4.4.1
cuML:     25.10.00
spark_rapids_ml:     25.10.0


In [7]:
from cuml.fil import ForestInference

# Build a ForestInference predictor from the in-memory Treelite model.
fm = ForestInference.load_from_treelite_model(treelite_model)

tree_count = treelite_model.num_tree
print(f"Success! Model loaded with {tree_count} trees.")

Success! Model loaded with 20 trees.


In [8]:
test_data_vals=np.array(test_data.collect())
print(f"Input matrix dimensions (rows, cols): {test_data_vals.shape}")

%time FILpredictions = fm.predict(test_data_vals)


Input matrix dimensions: (1979, 15)
CPU times: user 0 ns, sys: 8.54 ms, total: 8.54 ms
Wall time: 8.53 ms


In [9]:
taller_arr = np.tile(test_data_vals, (10, 1))
print(f"input matrix dimensions (rows, cols): {taller_arr.shape}")
%time FILpredictions = fm.predict(taller_arr)

taller_arr = np.tile(taller_arr, (10, 1))
print(f"input matrix dimensions (rows, cols): {taller_arr.shape}")
%time FILpredictions = fm.predict(taller_arr)

taller_arr = np.tile(taller_arr, (10, 1))
print(f"input matrix dimensions (rows, cols): {taller_arr.shape}")
%time FILpredictions = fm.predict(taller_arr)

taller_arr = np.tile(taller_arr, (2, 1))
print(f"input matrix dimensions (rows, cols): {taller_arr.shape}")
%time FILpredictions = fm.predict(taller_arr)

taller_arr = np.tile(taller_arr, (2, 1))
print(f"input matrix dimensions (rows, cols): {taller_arr.shape}")
%time FILpredictions = fm.predict(taller_arr)

taller_arr = np.tile(taller_arr, (2, 1))
print(f"input matrix dimensions (rows, cols): {taller_arr.shape}")
%time FILpredictions = fm.predict(taller_arr)

input matrix dimensions: (19790, 15)
CPU times: user 3.11 ms, sys: 1.27 ms, total: 4.38 ms
Wall time: 3.74 ms
(197900, 15)
CPU times: user 0 ns, sys: 19.9 ms, total: 19.9 ms
Wall time: 19.9 ms
(1979000, 15)
CPU times: user 25 ms, sys: 109 ms, total: 134 ms
Wall time: 134 ms
(3958000, 15)
CPU times: user 79.6 ms, sys: 181 ms, total: 261 ms
Wall time: 259 ms
(7916000, 15)
CPU times: user 221 ms, sys: 292 ms, total: 514 ms
Wall time: 513 ms
(15832000, 15)
CPU times: user 364 ms, sys: 661 ms, total: 1.02 s
Wall time: 1.02 s


In [2]:
import mlflow.pyfunc

class FILWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that serves a Treelite checkpoint through cuML FIL.

    The checkpoint file is expected under the ``"model_payload"`` artifact key.
    """

    def load_context(self, context):
        """Load the Treelite checkpoint and build the FIL predictor.

        MLflow calls this when the model is loaded, before any ``predict``.

        Args:
            context: MLflow model context; ``context.artifacts["model_payload"]``
                is the local path of the Treelite checkpoint file.
        """
        # Imported here rather than relying on notebook globals, so the wrapper
        # still works when it is loaded in a fresh serving process.
        import treelite
        from cuml.fil import ForestInference

        # Read the checkpoint back into a Treelite Model object; the loader
        # below takes a model object, not the raw bytes of the file.
        tl_model = treelite.Model.deserialize(context.artifacts["model_payload"])

        # load_from_treelite_model RETURNS the ready-to-use predictor. The
        # previous code called it on a throwaway instance and discarded the
        # result, leaving self.model without a loaded forest.
        self.model = ForestInference.load_from_treelite_model(tl_model)

    def predict(self, context, model_input):
        """Standard MLflow predict interface.

        Args:
            context: MLflow model context (unused).
            model_input: Feature matrix, passed to FIL unchanged.

        Returns:
            Whatever ``ForestInference.predict`` returns for ``model_input``.
        """
        return self.model.predict(model_input)

In [4]:
# Sanity check: show the Python type of the deserialized model object.
print(type(treelite_model))


<class 'treelite.model.Model'>


In [5]:
# Persist the Treelite model as a checkpoint file on disk; this is the file
# that gets logged as the "model_payload" artifact.
# serialize(path) writes the file and does not return the bytes, so its result
# is no longer bound to a misleading `model_bytes` variable.
treelite_model.serialize("checkpoint.tl")

In [6]:

# Artifact files to bundle with the model: key -> local path.
# FILWrapper.load_context reads the "model_payload" entry.
artifacts = {"model_payload": "checkpoint.tl"}

# Everything handed to mlflow.pyfunc.log_model, gathered in one place.
# NOTE(review): the pip requirements are unpinned and use bare package names;
# confirm they resolve to the intended builds in the serving environment.
log_model_kwargs = dict(
    artifact_path="cuml_fil_model",
    python_model=FILWrapper(),
    artifacts=artifacts,
    registered_model_name="GPU_RandomForest_Production",
    pip_requirements=["cuml", "treelite", "cupy"],  # Vital for the environment!
)

# Log the model inside a run and register it under the name given above.
with mlflow.start_run():
    mlflow.pyfunc.log_model(**log_model_kwargs)

Creating run for experiment_id: 0, user_id: cdsw, run_name: None
No experiment set using default experiment.Please set experiment using mlflow.set_experiment('<your experiment name>') to avoid using default experiment.
  from .autonotebook import tqdm as notebook_tqdm
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 1832.37it/s] 
Successfully registered model 'GPU_RandomForest_Production'.
experiment id naci-tqey-gag9-b47c 
2026/02/13 15:50:07 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: GPU_RandomForest_Production, version 1
Created version '1' of model 'GPU_RandomForest_Production'.
