In [None]:
import json

from src.config import RAW_DATA_DIR
from src.api.config.text_embedding_3_small_config import BATCH_SIZE

In [None]:
franchises_data_dir = RAW_DATA_DIR / "franserve"
# Path.glob yields files in arbitrary, filesystem-dependent order; sorting makes
# the `offset`-based slice below select the same files on every run.
franchises_data_files = sorted(franchises_data_dir.glob("*.json"))

# Only these keys of each file's "franchise_data" entry are kept.
columns = [
    "franchise_name",
    "primary_category",
    "sub_categories",
    "why_franchise_summary",
    "ideal_candidate_profile_text",
    "description_text",
]
# Index of the first file in this batch; the batch spans BATCH_SIZE files.
offset = 0

franchises_batch = []
for file in franchises_data_files[offset : offset + BATCH_SIZE]:
    with open(file, "r", encoding="utf-8") as f:
        franchise_data = json.load(f)["franchise_data"]
    franchise_data = {k: v for k, v in franchise_data.items() if k in columns}
    franchises_batch.append(franchise_data)

In [None]:
# Inspect the loaded batch: one dict per file, restricted to the keys in `columns`.
franchises_batch

In [None]:
import ast
import json
from typing import Any, Dict, List

from google import genai
from google.genai import types
from loguru import logger
import pandas as pd

from src.api.config.gemini_config import (
    CLIENT,
    MODEL_FLASH,
    get_generate_content_config_keywords,
    get_thinking_config,
    get_tools,
)
from src.api.google_gemini import generate
from src.config import CONFIG_DIR, RAW_DATA_DIR

# --- Step 1: Create the Summary from Structured Data ---


def create_summary_for_keyword_extraction(franchise_data: Dict[str, Any]) -> str:
    """
    Creates a clean, text-based summary from the structured franchise data
    to be used as input for keyword extraction.

    Each present field becomes one ``"Label: value"`` line. List values are
    joined with ``", "`` (non-string items are converted with ``str``), and
    falsy text/list values (``None``, ``""``, ``[]``) are skipped.

    Args:
        franchise_data: The dictionary of data extracted by the first LLM call.
            Absent keys are simply omitted from the summary.

    Returns:
        A single newline-separated string summarizing the franchise, or an
        empty string when none of the recognized keys carry a value.
    """
    summary_parts: List[str] = []

    def add_to_summary(key: str, label: str) -> None:
        """Append ``"label: value"`` for ``key`` when it holds a truthy value."""
        value = franchise_data.get(key)
        if not value:
            return
        if isinstance(value, list):
            # str() guards against non-string items (numbers, None), which
            # would otherwise make ", ".join raise TypeError.
            value = ", ".join(str(item) for item in value)
        summary_parts.append(f"{label}: {value}")

    def format_usd(amount: Any) -> str:
        """Render a dollar amount, with thousands separators when numeric."""
        if isinstance(amount, (int, float)) and not isinstance(amount, bool):
            return f"${amount:,}"
        # Non-numeric values (e.g. the string "50000") would make the ":,"
        # format spec raise ValueError, so emit them verbatim instead.
        text = str(amount).strip()
        return text if text.startswith("$") else f"${text}"

    add_to_summary("franchise_name", "Franchise Name")
    add_to_summary("description_text", "Description")
    add_to_summary("why_franchise_summary", "Key Benefits")
    add_to_summary("ideal_candidate_profile_text", "Ideal Candidate")

    # Composite investment range; only emitted when both bounds are known.
    min_inv = franchise_data.get("total_investment_min_usd")
    max_inv = franchise_data.get("total_investment_max_usd")
    # Explicit None/"" checks (not truthiness) so a legitimate $0 minimum is kept.
    if min_inv not in (None, "") and max_inv not in (None, ""):
        summary_parts.append(
            f"Total Investment: {format_usd(min_inv)} - {format_usd(max_inv)}"
        )

    # Composite ownership model. Absentee takes precedence over semi-absentee,
    # so at most one of the two is listed.
    ownership_model = []
    if franchise_data.get("is_home_based"):
        ownership_model.append("Home-Based")
    if franchise_data.get("allows_absentee"):
        ownership_model.append("Absentee Ownership")
    elif franchise_data.get("allows_semi_absentee"):
        ownership_model.append("Semi-Absentee Ownership")
    if franchise_data.get("e2_visa_friendly"):
        ownership_model.append("E2 Visa Friendly")

    if ownership_model:
        summary_parts.append(f"Business Model: {', '.join(ownership_model)}")

    return "\n".join(summary_parts)


# --- Step 2: Call the LLM and Extract Keywords ---


def extract_keywords_with_llm(
    client: genai.Client,
    model: str,
    prompt: str,
    franchise_summary: str,
) -> Any:
    """
    Calls the Gemini API to extract keywords from a franchise summary.

    The request is sent as three text parts: the extraction prompt, a
    ``--- FRANCHISE SUMMARY ---`` separator, and the summary itself. The
    generation config is built with a thinking config and with the Google
    Search and URL-context tools enabled.

    Args:
        client: The initialized genai.Client.
        model: The name of the model to use (e.g., 'gemini-1.5-pro-latest').
        prompt: The keyword extraction prompt.
        franchise_summary: The text summary of the franchise.

    Returns:
        The raw response object returned by ``generate``. Keywords are NOT
        parsed here; callers must extract them from the response themselves.

    Raises:
        Any exception raised by ``generate`` propagates; nothing is caught here.
    """
    parts = [
        types.Part(text=prompt),
        types.Part(text="\n--- FRANCHISE SUMMARY ---\n"),
        types.Part(text=franchise_summary),
    ]

    generate_content_config = get_generate_content_config_keywords(
        # NOTE(review): -1 presumably requests a dynamic thinking budget
        # (Gemini convention) — confirm in get_thinking_config.
        thinking_config=get_thinking_config(thinking_budget=-1),
        tools=get_tools(google_search=True, url_context=True),
    )

    response = generate(
        client=client,
        model=model,
        parts=parts,
        generate_content_config=generate_content_config,
    )

    # Log token usage for cost estimation; the metadata may be missing or None.
    usage = getattr(response, "usage_metadata", None)
    if usage:
        logger.info(
            f"Token usage - Input: {usage.prompt_token_count}, "
            f"Output: {usage.candidates_token_count}"
        )
    else:
        logger.warning("Token usage information not available in response")

    return response

In [None]:
def main(limit: int | None = 1):
    """
    Runs keyword extraction over the raw FranServe franchise files.

    Loads the prompt from ``CONFIG_DIR / "franserve" / "keywords_prompt.txt"``,
    then for each ``*.json`` file under ``RAW_DATA_DIR / "franserve"`` (in
    sorted filename order, so runs are reproducible) reads its
    ``"franchise_data"`` entry, builds a text summary and sends it to the LLM.
    Nothing is persisted here.

    Args:
        limit: Maximum number of files to process. Defaults to 1, which keeps
            the previous exploratory behavior of a single API call. Pass
            ``None`` to process every file.

    Returns:
        The raw LLM response for the last processed file, or ``None`` when no
        file was processed (no files found, or ``limit`` is 0).

    Raises:
        ValueError: If ``limit`` is negative.
    """
    if limit is not None and limit < 0:
        raise ValueError(f"limit must be non-negative or None, got {limit}")

    # Load prompt
    prompt_path = CONFIG_DIR / "franserve" / "keywords_prompt.txt"
    with open(prompt_path, "r", encoding="utf-8") as file:
        prompt_keywords = file.read()

    # Glob order is filesystem-dependent; sort for a deterministic run.
    franchise_data_paths = sorted((RAW_DATA_DIR / "franserve").glob("*.json"))
    logger.info(f"Found {len(franchise_data_paths)} franchise files.")

    if limit is not None:
        franchise_data_paths = franchise_data_paths[:limit]

    # Stays None if the loop body never runs (avoids UnboundLocalError).
    response = None
    for franchise_data_path in franchise_data_paths:
        logger.debug(f"Processing {franchise_data_path.name}")

        with open(franchise_data_path, "r", encoding="utf-8") as file:
            franchise_data = json.load(file)["franchise_data"]

        franchise_summary = create_summary_for_keyword_extraction(franchise_data)
        response = extract_keywords_with_llm(
            client=CLIENT,
            model=MODEL_FLASH,
            prompt=prompt_keywords,
            franchise_summary=franchise_summary,
        )

        logger.success(f"Processed {franchise_data.get('source_id')}")

    return response

In [None]:
# Run the keyword-extraction pipeline; `main` calls the LLM (via
# `extract_keywords_with_llm`) for each file it processes.
response = main()

In [None]:
# Display the raw response object returned by `main()`.
response

In [None]:
from src.data.nlp.genai_keywords_batch import check_keywords_batch_results

In [None]:
# NOTE(review): presumably the resource name of a previously submitted keyword
# batch job — update this constant when checking a different batch.
KEYWORDS_BATCH_NAME = "batches/uh6g31hv1o06t7kslrzy738lkipqw5qjt9pg"
check_keywords_batch_results(KEYWORDS_BATCH_NAME)