In [1]:
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_json, struct, array
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType, Union, Dict

In [2]:
%env PYSPARK_PYTHON=C:\Users\Milosz\AppData\Local\pypoetry\Cache\virtualenvs\recsys-streaming-ml-Mj1TWbkU-py3.10\Scripts\python.exe

env: PYSPARK_PYTHON=C:\Users\Milosz\AppData\Local\pypoetry\Cache\virtualenvs\recsys-streaming-ml-Mj1TWbkU-py3.10\Scripts\python.exe


In [3]:
%env PYSPARK_DRIVER_PYTHON=C:\Users\Milosz\AppData\Local\pypoetry\Cache\virtualenvs\recsys-streaming-ml-Mj1TWbkU-py3.10\Scripts\python.exe

env: PYSPARK_DRIVER_PYTHON=C:\Users\Milosz\AppData\Local\pypoetry\Cache\virtualenvs\recsys-streaming-ml-Mj1TWbkU-py3.10\Scripts\python.exe


In [4]:
import os
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.1,org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.1 pyspark-shell'

In [5]:
KAFKA_BROKER_URL = "kafka0:9093"
RECOMMENDATIONS_TOPIC = "recommendations"
USER_ACTIONS_TOPIC = "users.actions"

In [6]:
def predict_batch_fn():
    # load model from checkpoint
    import torch    
    device = torch.device("cuda")
    model = Net().to(device)
    checkpoint = load_checkpoint(checkpoint_dir)
    model.load_state_dict(checkpoint['model'])

    # define predict function in terms of numpy arrays
    def predict(inputs: np.ndarray) -> np.ndarray:
        torch_inputs = torch.from_numpy(inputs).to(device)
        outputs = model(torch_inputs)
        return outputs.cpu().detach().numpy()
    
    return predict

In [7]:
spark = SparkSession.builder \
        .appName("KafkaRead") \
        .master("local[*]") \
        .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.1") \
        .getOrCreate()

In [8]:
schema = StructType([
        StructField("user_id", StringType(), True)
    ])

In [9]:
df = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", KAFKA_BROKER_URL) \
    .option("subscribe", RECOMMENDATIONS_TOPIC) \
    .option("startingOffsets", "latest") \
    .load()

In [11]:
df.printSchema()

root
 |-- key: binary (nullable = true)
 |-- value: binary (nullable = true)
 |-- topic: string (nullable = true)
 |-- partition: integer (nullable = true)
 |-- offset: long (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- timestampType: integer (nullable = true)



In [12]:
values_df = df.selectExpr("CAST(value AS STRING) as json_data") \
                .select(from_json(col("json_data"), schema).alias("data")) \
                .select("data.*")

In [18]:
values_df.writeStream.format('console').outputMode('append').start().awaitTermination()

StreamingQueryException: [STREAM_FAILED] Query [id = 84fae067-80d2-44df-b56d-852e529ca430, runId = deda2154-197d-4525-9d8d-abb6e5413f37] terminated with exception: 'boolean org.apache.hadoop.io.nativeio.NativeIO$Windows.access0(java.lang.String, int)'

In [None]:
query = values_df \
    .writeStream \
    .outputMode("append") \
    .format("console") \
    .trigger(processingTime='30 seconds') \
    .start()

In [None]:
query.awaitTermination(30)

In [None]:
df_parsed = df.selectExpr("CAST(value AS STRING) as json_value") \
    .select(from_json(col("json_value"), schema).alias("data")) \
    .select("data.*")


In [None]:
query = df_parsed.writeStream.outputMode("append").format("console").start()
query.awaitTermination()

In [19]:
def create_spark_session():
    """
    Create a SparkSession.
    """
    spark = SparkSession.builder \
        .appName("CreateDataFrameFromDict") \
        .getOrCreate()
    return spark

def create_dataframe_from_dict(spark, data):
    """
    Create a DataFrame from a list of dictionaries.
    Each dictionary represents a record with a single field `user_id`.
    """
    # Define schema
    schema = StructType([StructField("user_id", StringType(), True)])

    # Create DataFrame
    df = spark.createDataFrame(data, schema)
    return df

# Sample data
data = [
    {"user_id": "A1"},
    {"user_id": "B2"},
    {"user_id": "C3"},
    {"user_id": "D4"}
]

# Create Spark session
spark = create_spark_session()

# Create DataFrame from data
df = create_dataframe_from_dict(spark, data)

# Show DataFrame
df.show()

+-------+
|user_id|
+-------+
|     A1|
|     B2|
|     C3|
|     D4|
+-------+



In [120]:
from pyspark.sql.functions import udf

user_id_mapping = {"A1": 2, "B2": 3, "C3": 1, "D4": 0}

def process_data(df, user_id_mapping=user_id_mapping):
    """
    Process the DataFrame by mapping user_ids using the provided dictionary.
    """
    def map_user_id(user_id):
        return float(user_id_mapping.get(user_id, None))

    map_user_id_udf = udf(map_user_id, FloatType())

    processed_df = df.withColumn("mapped_user_id", map_user_id_udf(df["user_id"]))

    return processed_df

In [121]:
preprocessed_data = process_data(df)

In [128]:
preprocessed_data

DataFrame[user_id: string, mapped_user_id: float]

In [133]:
exploded_data = preprocessed_data.withColumn("data", explode_data(preprocessed_data['mapped_user_id']))

In [134]:
exploded_data.show()

+-------+--------------+---------------+
|user_id|mapped_user_id|           data|
+-------+--------------+---------------+
|     A1|           2.0|[0.0, 1.0, 2.0]|
|     B2|           3.0|[0.0, 1.0, 2.0]|
|     C3|           1.0|[0.0, 1.0, 2.0]|
|     D4|           0.0|[0.0, 1.0, 2.0]|
+-------+--------------+---------------+



In [136]:
from pyspark.sql.functions import explode
exploded_data.select(explode(exploded_data.data)).show()

+---+
|col|
+---+
|0.0|
|1.0|
|2.0|
|0.0|
|1.0|
|2.0|
|0.0|
|1.0|
|2.0|
|0.0|
|1.0|
|2.0|
+---+



In [129]:
def explode_data(df):
    data_list_udf = udf(lambda mapped_user_id: [0., 1., 2.], ArrayType(FloatType()))
    return data_list_udf(df)

In [124]:
explode_data(preprocessed_data)

Column<'<lambda>(mapped_user_id)'>

In [83]:
import torch
import io
import numpy as np
from pymongo import MongoClient
from pymongo.errors import OperationFailure

MONGODB_HOST = os.getenv("MONGODB_HOST", default="localhost")
MONGODB_PORT = 27017
MONGODB_AUTHSOURCE = "admin"
MONGODB_USERNAME = "root"
MONGODB_PASSWORD = "root"

def mongo_client(*args, **kwargs):
    try:
        client = MongoClient(
            host=MONGODB_HOST,
            port=MONGODB_PORT,
            authSource=MONGODB_AUTHSOURCE,
            username=MONGODB_USERNAME,
            password=MONGODB_PASSWORD
        )
        db = client.admin
        client.server_info()
        db.create_collection("model_versions", check_exists=False)

        print("[MONGO] Connection successful")
        return db

    except OperationFailure as e:
        print("[MONGO] Connection failed:", e)
    except Exception:
        print("[MONGO] Connection failed")

client = mongo_client()

def _get_latest_model_version_document(client):
    latest_version = client['model_versions'].find_one({}, sort=[("timestamp", -1)])
    return latest_version

def _get_model_buffer(model_version_document):
    #return model_version["binary"]
    return io.BytesIO(model_version_document["binary"])

def load_model(client):
    latest_version_document = _get_latest_model_version_document(client=client)
    model_buffer = _get_model_buffer(model_version_document=latest_version_document)

    # Assuming 'path' is the buffer containing the model's binary data
    print(f'Loading model: {latest_version_document["model"]}:v{latest_version_document["version"]}-{latest_version_document["timestamp"]}')
    model = torch.jit.load(model_buffer, map_location='cpu')
    return model

def build_input_tensor(inputs: np.ndarray) -> torch.Tensor:
    return torch.tensor(inputs, dtype=torch.long)

[MONGO] Connection successful


In [84]:
import numpy as np
import torch

# def predict_batch_fn():
#     print("dupa")
#     def predict(inputs: np.ndarray) -> np.ndarray:
#         torch_inputs = torch.from_numpy(inputs)
#         return (torch_inputs * -1).numpy()
    
#     return predict


def predict_batch_fn():
    # load model from checkpoint
    import torch    
    device = torch.device("cpu")
    client = mongo_client()
    model = load_model(client).to(device)

    # define predict function in terms of numpy arrays
    def predict(inputs: np.ndarray) -> np.ndarray:
        inputs = np.stack([inputs, inputs, inputs], axis=1)
        torch_inputs = build_input_tensor(inputs).to(device)
        outputs = model(torch_inputs)
        return outputs.cpu().detach().numpy()
    
    return predict

In [100]:
preprocessed_data.select("mapped_user_id").collect()

[Row(mapped_user_id=2),
 Row(mapped_user_id=3),
 Row(mapped_user_id=1),
 Row(mapped_user_id=0)]

In [104]:
preprocessed_data.show()

+-------+--------------+
|user_id|mapped_user_id|
+-------+--------------+
|     A1|             2|
|     B2|             3|
|     C3|             1|
|     D4|             0|
+-------+--------------+



In [101]:
make_recommendations = predict_batch_udf(predict_batch_fn,
                        input_tensor_shapes=[[1]],
                        return_type=ArrayType(FloatType()),
                        batch_size=7)

In [102]:
preds = preprocessed_data.withColumn("recommendations", make_recommendations('mapped_user_id'))

In [103]:
preds.collect()

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "c:\Users\Milosz\AppData\Local\pypoetry\Cache\virtualenvs\recsys-streaming-ml-Mj1TWbkU-py3.10\lib\site-packages\pyspark\ml\functions.py", line 802, in predict
    single_input = _validate_and_transform_single_input(
  File "C:\Spark\spark-3.5.1-bin-hadoop3\python\lib\pyspark.zip\pyspark\ml\functions.py", line 242, in _validate_and_transform_single_input
    raise ValueError("Invalid input_tensor_shape for scalar column.")
ValueError: Invalid input_tensor_shape for scalar column.


In [None]:
df = spark.read.parquet("/path/to/test/data")
preds = df.withColumn("preds", mnist('data')).collect()

query = df_parsed.writeStream \
    .outputMode("append") \
    .format("console") \
    #.trigger(processingTime='15 seconds') \
    .start()

query.awaitTermination()