In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
import subprocess
import weaviate
import json
import pandas as pd
import random
from typing import List, Dict, Any
import os
from sentence_transformers import SentenceTransformer

from scripts import weaviate_utils


input_schema = StructType([
    StructField("name", StringType(), True),
    StructField("ip", StringType(), True),
    StructField("port", StringType(), True),
    StructField("query", StringType(), True),
])

output_schema = StructType([
    StructField("name", StringType(), True),
    StructField("ip", StringType(), True),
    StructField("port", StringType(), True),
    StructField("rag_text", StringType(), True),
    StructField("confidence", FloatType(), True),
])

vector_schema = ArrayType(FloatType())




spark = SparkSession.builder.appName("weaviate_deneme").getOrCreate()
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
services = weaviate_utils.get_external_ips()

  from .autonotebook import tqdm as notebook_tqdm
25/04/04 16:53:23 WARN Utils: Your hostname, asgrich-laptop resolves to a loopback address: 127.0.1.1; using 192.168.1.105 instead (on interface wlp0s20f3)
25/04/04 16:53:23 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/04 16:53:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
def generate_embedding(query: str) -> list:
    try:
        q_embedding = embedding_model.encode(query).tolist()
        return q_embedding
    except Exception as e:
        print(f"err: {e}")
        return []

generate_embedding_udf = F.udf(generate_embedding, vector_schema)

In [4]:
def search_weaviate(cluster_name: str, cluster_ip: str, cluster_port: str, query_vector: list) -> dict:
    try:
        ## cluster baglantisi
        client = weaviate.connect_to_custom(    
            url=f"http://{cluster_ip}:{cluster_port}",
            port=cluster_port,
            scheme='http',
            grpc_port=50051,  ## burasi defaultmus amk
            grpc_scheme='http'
        )



        ## burayi editleyecegim indexlere gore, for loop gerekebilir
        result = (
            client.query.get("dist-data", ["blabla", "blabla"])
            .with_near_vector({"vector": query_vector})
            .do()
        )

        client.close()

        ## burasi da degisecek siteden aldigim gibi kaldi
        if "data" in result and "Get" in result["data"] and "blabla" in result["data"]["Get"]:
            top_results = result["data"]["Get"]["blabla"]

            rag_text = []
            certainity = []
            
        else:
            rag_text = []
            certainity = []

        ## ilk 5'i sec certainity'lere gore, ve ona gore gonder

        return rag_texts, certainities


search_weaviate_udf = F.udf(search_weaviate, output_schema)

In [17]:
query = "hello there"

df = spark.createDataFrame([services[service] for service in services if service != "grpc"], input_schema)  ## spark df yarat
df = df.withColumn("query", F.col("query").cast(StringType())).withColumn("query", F.lit(query))            ## lit ile query'yi ekle
df = df.withColumn("query_embedding", generate_embedding_udf(F.col("query")))                               ## udf ile embeddingleri al

df.show()

                                                                                

+------------------+--------------+----+-----------+--------------------+
|              name|            ip|port|      query|     query_embedding|
+------------------+--------------+----+-----------+--------------------+
|weaviate-cluster-1|  10.101.40.35|8080|hello there|[-0.09443893, 0.0...|
|weaviate-cluster-2| 10.96.100.105|8080|hello there|[-0.09443893, 0.0...|
|weaviate-cluster-3| 10.111.91.205|8080|hello there|[-0.09443893, 0.0...|
|weaviate-cluster-4|  10.108.170.1|8080|hello there|[-0.09443893, 0.0...|
|weaviate-cluster-5|10.100.180.155|8080|hello there|[-0.09443893, 0.0...|
+------------------+--------------+----+-----------+--------------------+



In [21]:
df.withColumn("rag_text", search_weaviate_udf(F.col("name"), F.col("ip"), F.col("port"), F.col("query_embedding"))) \
.show()
    

+------------------+--------------+----+-----------+--------------------+--------------------+
|              name|            ip|port|      query|     query_embedding|            rag_text|
+------------------+--------------+----+-----------+--------------------+--------------------+
|weaviate-cluster-1|  10.101.40.35|8080|hello there|[-0.09443893, 0.0...|{NULL, 10.101.40....|
|weaviate-cluster-2| 10.96.100.105|8080|hello there|[-0.09443893, 0.0...|{NULL, 10.96.100....|
|weaviate-cluster-3| 10.111.91.205|8080|hello there|[-0.09443893, 0.0...|{NULL, 10.111.91....|
|weaviate-cluster-4|  10.108.170.1|8080|hello there|[-0.09443893, 0.0...|{NULL, 10.108.170...|
|weaviate-cluster-5|10.100.180.155|8080|hello there|[-0.09443893, 0.0...|{NULL, 10.100.180...|
+------------------+--------------+----+-----------+--------------------+--------------------+

