In [3]:
# data preparation
import pandas as pd
import os
import json

In [4]:
# --- Configuration ---
JSON_FILE_PATH = f"/Users/dhirendrachoudhary/Desktop/Workstation/Research/APIGenie/data/scikit-learn-api-reference.json"
CHROMA_DB_PATH = "./chroma_db" # Path to store ChromaDB files
CHROMA_COLLECTION_NAME = "sklearn_apis"
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2' # or 'all-mpnet-base-v2'

In [5]:
from src.prepare_data import load_and_flatten_data

flat_data = load_and_flatten_data(JSON_FILE_PATH)

In [6]:
flat_data

[{'id': '0',
  'api_full_name': 'sklearn.config_context',
  'module_name': 'sklearn',
  'class_name': 'config_context',
  'link': 'https://scikit-learn.org/stable/modules/generated/sklearn.config_context.html',
  'class_signature': 'class sklearn.config_context(*, assume_finite=None, working_memory=None, print_changed_only=None, display=None, pairwise_dist_chunk_size=None, enable_cython_pairwise_dist=None, array_api_dispatch=None, transform_output=None, enable_metadata_routing=None, skip_parameter_validation=None)',
  'example_code': "import sklearn\nfrom sklearn.utils.validation import assert_all_finite\nwith sklearn.config_context(assume_finite=True):\n    assert_all_finite([float('nan')])\nwith sklearn.config_context(assume_finite=True):\n    with sklearn.config_context(assume_finite=False):\n        assert_all_finite([float('nan')])\nTraceback (most recent call last):\n...\nValueError: Input contains NaN...",
  'text_for_embedding': "API Name: config_context. Belongs to module: skl

In [19]:
# flatten the json data
def load_and_flatten_data(json_file_path):
    """Loads data from JSON and flattens it into a list of API documents."""
    flattened_apis = []
    try:
        with open(json_file_path, 'r') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: JSON file not found at {json_file_path}")
        return []
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file_path}")
        return []

    doc_id_counter = 0
    for module_name, module_data in data.items():
        for class_name, class_details in module_data.get("subsections", {}).items():
            api_full_name = f"{module_name}.{class_name}"

            # Construct text for embedding
            # Consider adding a brief scraped description here if possible in the future
            text_for_embedding = f"API Name: {class_name}. Belongs to module: {module_name}. "
            text_for_embedding += f"Signature: {class_details.get('class_signature', '')}. "
            # Example code can be long; consider truncating or summarizing for embedding if performance issues arise
            text_for_embedding += f"Example Usage: {class_details.get('example_code', '')}"

            api_doc = {
                "id": str(doc_id_counter), # ChromaDB requires string IDs
                "api_full_name": api_full_name,
                "module_name": module_name,
                "class_name": class_name,
                "link": class_details.get("link", ""),
                "class_signature": class_details.get("class_signature", ""),
                "example_code": class_details.get("example_code", ""),
                "text_for_embedding": text_for_embedding
            }
            flattened_apis.append(api_doc)
            doc_id_counter += 1
    return flattened_apis

In [20]:
# # save flattened data to a json
# with open("data/flattened_apis.json", "w") as f:
#     json.dump(filtered_apis, f, indent=4)

In [21]:
# create and inti to vectorize the data
from sentence_transformers import SentenceTransformer
import chromadb
import numpy as np

def initialize_embedding_model(model_name):
    """Initializes and returns the Sentence Transformer model."""
    print(f"Loading embedding model: {model_name}...")
    model = SentenceTransformer(model_name)
    print("Embedding model loaded.")
    return model

def create_and_populate_vector_db(apis, model, db_path, collection_name):
    """Creates embeddings and populates ChromaDB."""
    if not apis:
        print("No APIs to process for vector DB.")
        return None

    print("Initializing ChromaDB client...")
    client = chromadb.PersistentClient(path=db_path)

    # Get or create collection
    try:
        collection = client.get_collection(name=collection_name)
        print(f"Using existing collection: {collection_name}")
        # Optional: Clear collection if you want to re-ingest every time
        # client.delete_collection(name=collection_name)
        # collection = client.create_collection(name=collection_name)
        # print(f"Re-created collection: {collection_name}")
    except: # Simple catch-all for this example, refine for production
        print(f"Creating new collection: {collection_name}")
        collection = client.create_collection(name=collection_name)

    print(f"Generating embeddings for {len(apis)} API documents...")
    texts_to_embed = [doc['text_for_embedding'] for doc in apis]
    embeddings = model.encode(texts_to_embed, show_progress_bar=True)

    # Prepare data for ChromaDB
    documents_for_chroma = [doc['text_for_embedding'] for doc in apis] # What text Chroma stores
    metadatas_for_chroma = [
        {
            "api_full_name": doc["api_full_name"],
            "module_name": doc["module_name"],
            "class_name": doc["class_name"],
            "link": doc["link"],
            "signature": doc["class_signature"]
            # Exclude full example_code and text_for_embedding from metadata to keep it lean
            # We already have the ID to fetch the full original doc if needed
        } for doc in apis
    ]
    ids_for_chroma = [doc['id'] for doc in apis]

    # Check if documents already exist to avoid duplicates or decide on update strategy
    # For simplicity here, we'll assume we add if not exists, or re-add if collection was cleared.
    # A more robust way is to check existing IDs.
    existing_docs = collection.get(ids=ids_for_chroma)
    new_ids_for_chroma = []
    new_embeddings = []
    new_documents_for_chroma = []
    new_metadatas_for_chroma = []

    for i, doc_id in enumerate(ids_for_chroma):
        if doc_id not in existing_docs['ids']:
            new_ids_for_chroma.append(doc_id)
            new_embeddings.append(embeddings[i])
            new_documents_for_chroma.append(documents_for_chroma[i])
            new_metadatas_for_chroma.append(metadatas_for_chroma[i])

    if new_ids_for_chroma:
        print(f"Adding {len(new_ids_for_chroma)} new documents to ChromaDB...")
        collection.add(
            embeddings=np.array(new_embeddings).tolist(), # Ensure it's a list of lists/np.array
            documents=new_documents_for_chroma,
            metadatas=new_metadatas_for_chroma,
            ids=new_ids_for_chroma
        )
        print(f"{len(new_ids_for_chroma)} documents added/updated in collection '{collection_name}'.")
    else:
        print("No new documents to add. All documents might already exist.")

    return collection

In [22]:
# retrieve relevant APIs
def retrieve_relevant_apis(query_text, model, collection, n_results=5):
    """Embeds the query and retrieves relevant APIs from ChromaDB."""
    print(f"\nUser Query: '{query_text}'")
    query_embedding = model.encode([query_text])[0] # Get the first (and only) embedding

    results = collection.query(
        query_embeddings=[query_embedding.tolist()], # Chroma expects a list of embeddings
        n_results=n_results,
        include=['metadatas', 'documents', 'distances'] # documents are the 'text_for_embedding'
    )
    return results

In [24]:
if __name__ == "__main__":
    # 1. Load and flatten API data
    api_documents = load_and_flatten_data(JSON_FILE_PATH)
    if not api_documents:
        print("Exiting due to data loading issues.")
        exit()

    # 2. Initialize embedding model
    embedding_model = initialize_embedding_model(EMBEDDING_MODEL_NAME)

    # 3. Create/Populate Vector DB
    api_collection = create_and_populate_vector_db(
        api_documents, embedding_model, CHROMA_DB_PATH, CHROMA_COLLECTION_NAME
    )

    if not api_collection:
        print("Exiting due to Vector DB initialization issues.")
        exit()

    # 4. Example User Query and Retrieval
    # Test queries:
    # query1 = "How to perform clustering on high dimensional data?"
    # query2 = "I need a classifier for multi-class text data."
    # query3 = "Scale numerical features before training a model."
    # query4 = "Reduce dimensions of my dataset"
    # query5 = "How to combine multiple estimators into one?"


    test_queries = [
        "Build a classifier for multi-class text data, data is sparse",
        "I need to preprocess numerical features that have different scales, preparing for an SVM.",
        "Find a clustering algorithm suitable for a large number of samples and features.",
        "How to perform feature selection to improve my regression model?",
        "Combine preprocessing and a classification model into a single unit."
    ]

    for user_query in test_queries:
        retrieved_apis_info = retrieve_relevant_apis(user_query, embedding_model, api_collection, n_results=50)

        print(f"\n--- Results for Query: '{user_query}' ---")
        if retrieved_apis_info and retrieved_apis_info['ids'][0]:
            for i in range(len(retrieved_apis_info['ids'][0])):
                api_name = retrieved_apis_info['metadatas'][0][i].get('api_full_name', 'N/A')
                distance = retrieved_apis_info['distances'][0][i]
                # doc_content = retrieved_apis_info['documents'][0][i][:200] + "..." # Snippet
                print(f"  - API: {api_name} (Distance: {distance:.4f})")
                # print(f"    Doc: {doc_content}") # Optionally print part of the matched document

            print("\n--- Next Steps: LLM-based Pipeline Planning ---")
            print("The retrieved APIs would now be passed to an LLM with the original query.")
            print("Example prompt structure for LLM:")
            print(f"""
            User Goal: "{user_query}"
            Potentially Relevant Scikit-learn APIs (with their metadata):
            """)
            for i in range(len(retrieved_apis_info['ids'][0])):
                metadata = retrieved_apis_info['metadatas'][0][i]
                print(f"  - {metadata.get('api_full_name')}: {metadata.get('signature')}")
            print("""
            Task: Propose a conceptual scikit-learn pipeline to achieve the user's goal.
            Explain each step and why the chosen (or an alternative) API is suitable.
            Provide example Python code structure if possible using sklearn.pipeline.Pipeline.
            """)
            print("-" * 50)

        else:
            print("  No relevant APIs found.")

Loading embedding model: all-MiniLM-L6-v2...
Embedding model loaded.
Initializing ChromaDB client...
Using existing collection: sklearn_apis
Generating embeddings for 575 API documents...


Batches: 100%|██████████| 18/18 [00:02<00:00,  6.41it/s]


No new documents to add. All documents might already exist.

User Query: 'Build a classifier for multi-class text data, data is sparse'

--- Results for Query: 'Build a classifier for multi-class text data, data is sparse' ---
  - API: sklearn.decomposition.LatentDirichletAllocation (Distance: 1.1976)
  - API: sklearn.feature_extraction.TfidfTransformer (Distance: 1.2764)
  - API: sklearn.datasets.make_multilabel_classification (Distance: 1.3092)
  - API: sklearn.feature_extraction.CountVectorizer (Distance: 1.3337)
  - API: sklearn.multioutput.MultiOutputClassifier (Distance: 1.3378)
  - API: sklearn.utils.ClassifierTags (Distance: 1.3572)
  - API: sklearn.tree.export_text (Distance: 1.3659)
  - API: sklearn.feature_extraction.TfidfVectorizer (Distance: 1.3787)
  - API: sklearn.datasets.make_sparse_spd_matrix (Distance: 1.3881)
  - API: sklearn.datasets.make_sparse_coded_signal (Distance: 1.4488)
  - API: sklearn.metrics.classification_report (Distance: 1.4550)
  - API: sklearn.featur