In [0]:
%pip install mlflow

In [0]:

import mlflow
from pyspark.sql.functions import struct
import dlt
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql.types import *

# Names of the columns that are packed into a single struct and handed to the
# model UDF for scoring; the list order is the field order of that struct.
features = [
    "duration",
    "protocol_type",
    "service",
    "flag",
    "difficulty",
    "src_bytes",
    "dst_bytes",
    "land",
    "wrong_fragment",
    "urgent",
    "hot",
    "num_failed_logins",
    "logged_in",
    "is_host_login",
    "is_guest_login",
    "type",
    "num_outbound_cmds",
    "num_access_files",
    "num_shells",
    "num_file_creations",
    "num_root",
    "su_attempted",
    "root_shell",
    "num_compromised",
    "srv_diff_host_rate",
    "diff_srv_rate",
    "same_srv_rate",
    "srv_rerror_rate",
    "rerror_rate",
    "srv_serror_rate",
    "serror_rate",
    "srv_count",
    "count",
    "dst_host_srv_rerror_rate",
    "dst_host_rerror_rate",
    "dst_host_srv_serror_rate",
    "dst_host_serror_rate",
    "dst_host_srv_diff_host_rate",
    "dst_host_same_src_port_rate",
    "dst_host_diff_srv_rate",
    "dst_host_same_srv_rate",
    "dst_host_srv_count",
    "dst_host_count",
]

# Locate the logged MLflow model through the run that produced it and the
# model name inside that run, using the "runs:/<run_id>/<model_name>" form.
run_id = "b8b152bb75c146279d674719ef87e9c7"
model_name = "FullPipeline"
model_uri = f"runs:/{run_id}/{model_name}"

# Wrap the model as a Spark UDF so it can be applied to DataFrame columns.
loaded_model = mlflow.pyfunc.spark_udf(spark, model_uri=model_uri)

# Schema applied when parsing the JSON payload of each record. Every field is
# declared nullable; the (name, type) pairs below are listed in schema order.
kdd_schema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("duration", IntegerType()),
        ("protocol_type", StringType()),
        ("service", StringType()),
        ("flag", StringType()),
        ("src_bytes", IntegerType()),
        ("dst_bytes", IntegerType()),
        ("land", IntegerType()),
        ("wrong_fragment", IntegerType()),
        ("urgent", IntegerType()),
        ("hot", IntegerType()),
        ("num_failed_logins", IntegerType()),
        ("logged_in", IntegerType()),
        ("num_compromised", IntegerType()),
        ("root_shell", IntegerType()),
        ("su_attempted", IntegerType()),
        ("num_root", IntegerType()),
        ("num_file_creations", IntegerType()),
        ("num_shells", IntegerType()),
        ("num_access_files", IntegerType()),
        ("num_outbound_cmds", IntegerType()),
        ("is_host_login", IntegerType()),
        ("is_guest_login", IntegerType()),
        ("count", IntegerType()),
        ("srv_count", IntegerType()),
        ("serror_rate", DoubleType()),
        ("srv_serror_rate", DoubleType()),
        ("rerror_rate", DoubleType()),
        ("srv_rerror_rate", DoubleType()),
        ("same_srv_rate", DoubleType()),
        ("diff_srv_rate", DoubleType()),
        ("srv_diff_host_rate", DoubleType()),
        ("dst_host_count", IntegerType()),
        ("dst_host_srv_count", IntegerType()),
        ("dst_host_same_srv_rate", DoubleType()),
        ("dst_host_diff_srv_rate", DoubleType()),
        ("dst_host_same_src_port_rate", DoubleType()),
        ("dst_host_srv_diff_host_rate", DoubleType()),
        ("dst_host_serror_rate", DoubleType()),
        ("dst_host_srv_serror_rate", DoubleType()),
        ("dst_host_rerror_rate", DoubleType()),
        ("dst_host_srv_rerror_rate", DoubleType()),
        ("type", StringType()),
        ("difficulty", IntegerType()),
    )
])

@dlt.table(
    name="bronze_nsl_rt",
    comment="Raw JSON data on kdd",
)
def kddCupBronze():
    """Define the bronze table as a streaming read of the raw JSON input.

    Returns:
        A streaming DataFrame read through the ``cloudFiles`` source, with
        ``cloudFiles.format`` set to ``json``, from
        ``/mnt/kddcap-data/bronze-nsl-rt.json``.
    """
    source_path = "/mnt/kddcap-data/bronze-nsl-rt.json"
    stream_reader = spark.readStream.format("cloudFiles")
    stream_reader = stream_reader.option("cloudFiles.format", "json")
    return stream_reader.load(source_path)

@dlt.table(
  comment="Prepared data on kdd",
  name="silver_nsl_rt")
# NOTE(review): the expectation name "valid_current_page" does not describe the
# check it guards (duration IS NOT NULL). It is kept as-is because renaming it
# would change the name under which the expectation is reported — confirm and
# rename if nothing depends on it.
@dlt.expect_or_drop("valid_current_page", "duration IS NOT NULL")
def kddCupSilver():
    """Define the silver table: decoded, parsed and flattened bronze records.

    The ``Body`` column of ``bronze_nsl_rt`` is base64-decoded, cast to a
    string, parsed as JSON against ``kdd_schema``, and the parsed struct is
    expanded so that each schema field becomes a top-level column. Rows whose
    ``duration`` is NULL are dropped by the expectation on this table.

    Returns:
        A DataFrame with one column per field of ``kdd_schema``.
    """
    # NOTE(review): bronze is built from a streaming read, while dlt.read is
    # the non-streaming read — confirm this is intended rather than
    # dlt.read_stream.
    decoded_body = unbase64(col("Body")).cast("string")
    return (
        dlt.read("bronze_nsl_rt")
        .select(from_json(decoded_body, kdd_schema).alias("json_payload"))
        .select("json_payload.*")
    )

@dlt.table(
    name="kddcup_classification_rt_results",
    comment="KDDCUP classification",
)
def kddcup_classification_results():
    """Define the results table: silver rows plus a model prediction column.

    Returns:
        The ``silver_nsl_rt`` DataFrame with an added ``predictions`` column
        holding the output of ``loaded_model`` applied to a struct of the
        columns named in ``features``.
    """
    silver_rows = dlt.read("silver_nsl_rt")
    # Bundle the feature columns into one struct argument for the model UDF.
    model_input = struct(*features)
    return silver_rows.withColumn("predictions", loaded_model(model_input))