In [1]:
# Load environment variables from a `.env` file so GOOGLE_API_KEY is available to the
# embedding cells below. The file should contain a line such as `GOOGLE_API_KEY=...`.
# Keep `.env` out of version control (add it to .gitignore). Never hardcode the key in
# this notebook and never print it, because cell outputs are saved with the notebook.
from dotenv import load_dotenv
import os
import warnings

# With no path argument, python-dotenv's find_dotenv looks for `.env` starting from the
# current directory and walking up through parent directories. Variables already set in
# the environment are not overridden by default.
dotenv_loaded = load_dotenv()

# Fail loudly here rather than deep inside a Spark task. This is only a warning, not an
# error, in case credentials are supplied another way.
if not os.environ.get("GOOGLE_API_KEY"):
    warnings.warn(
        "GOOGLE_API_KEY is not set; add it to .env or export it before running the embedding cells."
    )

dotenv_loaded

True

In [8]:
import faiss
import google.generativeai as genai
import numpy as np
from pyspark.sql import SparkSession
import os

# --- Embedding job configuration -------------------------------------------
MODEL = "models/embedding-001"        # embedding model passed to every embed_content call
BATCH_SIZE_FOR_GOOGLE_API = 1000      # number of strings handed to one embed_content call
dimension = 768                       # vector width used to size the FAISS index (expected to match MODEL)

# None when the variable is unset; the .env cell above is expected to populate it.
google_api_key = os.environ.get('GOOGLE_API_KEY')

# Spark session running locally with one worker thread per available core ("local[*]").
spark = (
    SparkSession.builder
    .appName("ParallelEmbeddingGeneration")
    .master("local[*]")
    .getOrCreate()
)

def configure_genai():
    """Configure the google.generativeai client with the key read from the environment.

    This is invoked from inside Spark partition processing as well as per batch. The
    likely reason is that client configuration done in the driver process is not
    guaranteed to exist in Spark's Python worker processes (NOTE(review): confirm).
    """
    api_key = google_api_key
    genai.configure(api_key=api_key)

def generate_embeddings_in_batches(batch):
    """Embed a batch of strings with a single genai.embed_content call.

    Args:
        batch: list of strings, sent as the ``content`` of the request.

    Returns:
        The response's ``"embedding"`` entry: a list of float vectors, one per string
        in ``batch`` (List[List[float]]).
    """
    # Configure here too so the function also works when called on its own.
    configure_genai()
    request = dict(
        model=MODEL,
        content=batch,
        task_type="retrieval_document",  # documents being indexed, as opposed to queries
    )
    return genai.embed_content(**request)["embedding"]

def gen_standard_embeddings(standard_diagnosis_list,
                            index_path="faiss_standard_strings_embeddings.index",
                            num_slices=6):
    """Embed every string in parallel with Spark and save a cosine-similarity FAISS index.

    Row ``i`` of the saved index corresponds to ``standard_diagnosis_list[i]``.
    ``parallelize`` preserves input order, ``collect`` returns the elements in RDD
    order, and the count check below rejects responses that dropped entries.

    Args:
        standard_diagnosis_list: iterable of strings to embed.
        index_path: file the FAISS index is written to. The default is the
            original hard-coded filename.
        num_slices: number of Spark partitions to split the input into.

    Raises:
        ValueError: if the input is empty, if the API returns a different number
            of vectors than there are inputs, or if the vectors are not
            ``dimension`` wide.
    """
    items = list(standard_diagnosis_list)
    if not items:
        # Otherwise np.array([]) is 1-D and faiss.normalize_L2 fails with an opaque error.
        raise ValueError("standard_diagnosis_list is empty; nothing to embed")

    rdd = spark.sparkContext.parallelize(items, numSlices=num_slices)

    def process_partition(partition):
        # Executed by Spark workers: configure the client there, then embed the
        # partition in API-sized slices.
        configure_genai()
        rows = list(partition)
        embeddings = []
        for start in range(0, len(rows), BATCH_SIZE_FOR_GOOGLE_API):
            chunk = rows[start:start + BATCH_SIZE_FOR_GOOGLE_API]
            embeddings.extend(generate_embeddings_in_batches(chunk))
        return embeddings

    # Step 1: embed on the workers and collect every vector back to the driver.
    all_embeddings = rdd.mapPartitions(process_partition).collect()
    if len(all_embeddings) != len(items):
        # A mismatch would silently misalign index rows with their source strings.
        raise ValueError(
            f"expected {len(items)} embeddings but the API returned {len(all_embeddings)}"
        )

    # Step 2: L2-normalise so inner-product search behaves as cosine similarity.
    embeddings_np = np.array(all_embeddings, dtype='float32')
    if embeddings_np.ndim != 2 or embeddings_np.shape[1] != dimension:
        raise ValueError(
            f"expected embeddings of shape (n, {dimension}), got {embeddings_np.shape}"
        )
    faiss.normalize_L2(embeddings_np)

    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings_np)

    # Step 3: persist the index to disk.
    faiss.write_index(index, index_path)
    print("Saved FAISS index with", index.ntotal, "vectors.")


# standard_diagnosis = [
#         "Hypertension, primary",
#         "Type 2 diabetes mellitus",
#         "Acute upper respiratory infection",
#         "Major depressive disorder",
#     ]

# #df = pd.read_csv("symptom_descriptions_top100.csv")
# #first_column = df.iloc[:, 0].astype(str).tolist()
# gen_standard_embeddings(standard_diagnosis)

In [9]:
import pandas as pd

# Load the symptom descriptions; the first column holds the strings to index.
df = pd.read_csv("symptom_descriptions_top10K.csv")

# Drop missing cells before casting. Otherwise astype(str) turns NaN into the literal
# text "nan", which would be embedded and indexed as if it were a real description.
# Row i of the FAISS index built below corresponds to first_column[i].
first_column = df.iloc[:, 0].dropna().astype(str).tolist()

# Embed the strings in parallel and write the FAISS index to disk.
gen_standard_embeddings(first_column)

                                                                                

Saved FAISS index with 9999 vectors.
