# Vertex AI Search for Retail: Product Catalog Generation

This notebook demonstrates how to generate a synthetic product catalog for a retailer using a Vertex AI generative model (Gemini) and load it into a BigQuery table. The process is designed to be scalable and robust, handling potentially large datasets and ensuring data quality.

## 1. Setup and Configuration

In [None]:
%pip install --upgrade --quiet google-cloud-aiplatform pandas db-dtypes

In [None]:
import os
import json
import math
import concurrent.futures
from pathlib import Path

import vertexai
from vertexai.generative_models import GenerativeModel, Part
from google.cloud import bigquery, storage
from google.colab import auth as colab_auth

In [None]:
# Authenticate with Google Cloud.
# Uses the `auth` helper from `google.colab`, so this cell only works inside a
# Colab runtime. Presumably the credentials it sets up are what the Vertex AI,
# Cloud Storage and BigQuery clients below pick up — confirm if running elsewhere.
colab_auth.authenticate_user()

### Configuration
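Set the project-specific variables below. `PROJECT_ID` and `GCS_BUCKET_NAME` are read from the `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_BUCKET` environment variables; either set those variables before running this cell or replace the `os.environ.get(...)` calls with static values.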

In [None]:
# --- Core Configuration ---
# RETAILER is interpolated into the prompt, the BigQuery table name and the
# output file name, so changing it renames all three.
RETAILER = "wayfair"
MODEL_NAME = "gemini-1.5-flash-001"
NUMBER_OF_PRODUCTS = 1000
MAX_OUTPUT_TOKENS = 8192  # Max for gemini-1.5-flash
MAX_PRODUCTS_PER_API_CALL = 25  # Adjust based on token usage and model capacity

# --- GCP Configuration ---
# Both values are read from the environment and are None when the variable
# is not set.
PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT")
GCS_BUCKET_NAME = os.environ.get("GOOGLE_CLOUD_BUCKET")
LOCATION = "us-central1"

# --- BigQuery Configuration ---
BQ_DATASET_ID = "retail"
BQ_TABLE_ID = f"products-{RETAILER.lower()}"

# --- File Paths ---
CONFIG_DIR = Path("config")
OUTPUT_DIR = Path("output")
OUTPUT_FILE = OUTPUT_DIR / f"{RETAILER}_catalog.jsonl"

# Create directories if they don't exist
CONFIG_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)

# Initialize Vertex AI
vertexai.init(project=PROJECT_ID, location=LOCATION)

print(f"Project ID: {PROJECT_ID}")
print(f"GCS Bucket: {GCS_BUCKET_NAME}")

# Surface missing settings here instead of letting them show up as confusing
# failures several cells later.
if not PROJECT_ID:
    print(
        "Warning: GOOGLE_CLOUD_PROJECT is not set; the Google Cloud clients "
        "will fall back to whatever default project they can infer."
    )
if not GCS_BUCKET_NAME:
    print(
        "Warning: GOOGLE_CLOUD_BUCKET is not set; the GCS upload and "
        "BigQuery load steps will be skipped."
    )

## 2. Load Configuration Files
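This step loads the BigQuery schema (`schema.json`), the generation requirements (`requirements.txt`), the category list (`categories.txt`, one category per line), and the prompt template (`prompt.txt`) from the `config/` directory.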

In [None]:
def load_config_file(path):
    """Return the full text content of the file at ``path``.

    Args:
        path: Location of the file to read (a string or path-like object).

    Returns:
        The file content as a single string.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()

# Read the four configuration files. A missing file is reported with a hint
# and then re-raised: every later cell depends on these variables, so carrying
# on would only produce a NameError further down.
try:
    schema_str = load_config_file(CONFIG_DIR / 'schema.json')
    requirements_str = load_config_file(CONFIG_DIR / 'requirements.txt')
    categories_text = load_config_file(CONFIG_DIR / 'categories.txt')
    prompt_template = load_config_file(CONFIG_DIR / 'prompt.txt')
except FileNotFoundError as e:
    print(f"Error: {e}. Please make sure all configuration files exist in the '{CONFIG_DIR}' directory.")
    raise

# One category per line. Blank lines and surrounding whitespace (including a
# trailing '\r' from Windows line endings) are dropped so that no empty or
# padded category name reaches the prompt.
categories_list = [
    entry.strip() for entry in categories_text.splitlines() if entry.strip()
]
if not categories_list:
    raise ValueError(
        f"No categories found in '{CONFIG_DIR / 'categories.txt'}'; "
        "add at least one category (one per line)."
    )

# Parse the schema for BigQuery client
bq_schema = [bigquery.SchemaField.from_api_repr(field) for field in json.loads(schema_str)]

print("Configuration files loaded successfully.")
print(f"Found {len(categories_list)} categories.")

## 3. Product Generation

This section defines the functions to generate product data using the Gemini model. It includes a function to handle individual API calls and a parallel execution framework to scale the generation process.

In [None]:
def clean_and_parse_jsonl(raw_text):
    """Extract the lines of raw model output that are valid JSONL records.

    Each non-empty line is stripped and parsed as JSON. Only lines that parse
    to a JSON object are kept, because every row of a newline-delimited JSON
    load must be an object; scalars and arrays are skipped with a warning.

    Args:
        raw_text: Text returned by the model. ``None`` or an empty string
            yields an empty result.

    Returns:
        A list of stripped JSON strings, one per accepted line, in input order.
    """
    if not raw_text:
        return []

    valid_lines = []
    # Split on '\n' only: str.splitlines() would also break on characters such
    # as U+2028, which may legally appear unescaped inside a JSON string.
    for line in raw_text.strip().split('\n'):
        clean_line = line.strip()
        if not clean_line:
            continue
        # The model sometimes wraps the output in ```jsonl ... ```, remove it.
        if clean_line.startswith('```'):
            continue
        try:
            record = json.loads(clean_line)  # Validate JSON
        except json.JSONDecodeError:
            print(f"Warning: Skipping invalid JSON line: {clean_line}")
            continue
        if not isinstance(record, dict):
            print(f"Warning: Skipping non-object JSON line: {clean_line}")
            continue
        valid_lines.append(clean_line)
    return valid_lines

def generate_product_batch(category, num_products):
    """Request one batch of products for a single category from the model.

    The prompt template is filled with the retailer, category, requested
    count, schema and requirements, then sent to the configured model. The
    reply text is passed through ``clean_and_parse_jsonl``.

    Args:
        category: Category name inserted into the prompt.
        num_products: Number of products the prompt asks for.

    Returns:
        The list of cleaned JSON lines from the reply, or an empty list if
        the model call or the cleaning step raised an exception.
    """
    print(f"Generating {num_products} products for category: {category}...")
    gemini = GenerativeModel(MODEL_NAME)

    template_fields = {
        "RETAILER": RETAILER,
        "CATEGORY": category,
        "NUMBER_OF_PRODUCTS": num_products,
        "SCHEMA": schema_str,
        "REQUIREMENTS": requirements_str,
    }
    filled_prompt = prompt_template.format(**template_fields)

    # Best-effort: any failure in this batch is reported and yields no rows,
    # so one bad call does not abort the whole catalog run.
    try:
        generation_settings = {
            "max_output_tokens": MAX_OUTPUT_TOKENS,
            "temperature": 0.8,
        }
        reply = gemini.generate_content(
            [filled_prompt],
            generation_config=generation_settings,
        )
        product_lines = clean_and_parse_jsonl(reply.text)
        print(f"Successfully generated and cleaned {len(product_lines)} products for {category}.")
        return product_lines
    except Exception as e:
        print(f"An error occurred while generating data for {category}: {e}")
        return []

In [None]:
def generate_catalog_in_parallel(total_products, categories):
    """Generate the full catalog by running per-category batches in parallel.

    The requested total is split as evenly as possible across the categories
    (earlier categories receive the remainder), and each category's share is
    cut into batches of at most ``MAX_PRODUCTS_PER_API_CALL`` products. The
    batches are submitted to a thread pool and their results are collected in
    completion order.

    Args:
        total_products: Number of products requested overall.
        categories: Iterable of category names.

    Returns:
        A list of at most ``total_products`` JSON lines. Empty when
        ``total_products`` is not positive or ``categories`` is empty.
    """
    categories = list(categories)
    if not categories:
        # Avoid dividing by zero below; the caller treats [] as "nothing generated".
        print("No categories provided; nothing to generate.")
        return []
    if total_products <= 0:
        return []

    # Distribute exactly total_products: every category gets `base`, and the
    # first `extra` categories get one more. This avoids requesting products
    # that would only be trimmed away afterwards.
    base, extra = divmod(total_products, len(categories))

    # Create batches of tasks
    tasks = []
    for index, category in enumerate(categories):
        remaining_for_cat = base + (1 if index < extra else 0)
        while remaining_for_cat > 0:
            batch_size = min(remaining_for_cat, MAX_PRODUCTS_PER_API_CALL)
            tasks.append((category, batch_size))
            remaining_for_cat -= batch_size

    all_products = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_task = {
            executor.submit(generate_product_batch, cat, num): (cat, num)
            for cat, num in tasks
        }

        for future in concurrent.futures.as_completed(future_to_task):
            task = future_to_task[future]
            try:
                all_products.extend(future.result())
            except Exception as exc:
                # A failed batch is reported and skipped so the others still count.
                print(f'{task} generated an exception: {exc}')

    # Trim to the exact number of products requested (the model may return
    # more lines than a batch asked for).
    return all_products[:total_products]

In [None]:
# Run the generation and persist the result as a JSONL file (one product per line).
print("Starting product catalog generation...")
generated_products = generate_catalog_in_parallel(NUMBER_OF_PRODUCTS, categories_list)

if generated_products:
    # Write explicitly as UTF-8: this file is later loaded into BigQuery, and
    # the platform's default encoding is not guaranteed to be UTF-8.
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        f.writelines(f"{line}\n" for line in generated_products)
    print(f"\nSuccessfully generated {len(generated_products)} products.")
    print(f"Catalog saved to {OUTPUT_FILE}")
else:
    print("\nNo products were generated. Please check the logs for errors.")

## 4. Load Data into BigQuery

This section handles uploading the generated JSONL file to Google Cloud Storage (GCS) and then loading it into a BigQuery table.

In [None]:
def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
    """Upload a local file to a Cloud Storage bucket and return its URI.

    Args:
        bucket_name: Name of the destination bucket.
        source_file_name: Path of the local file to upload.
        destination_blob_name: Object name to write inside the bucket.

    Returns:
        The ``gs://<bucket>/<blob>`` URI of the uploaded object.
    """
    target_blob = (
        storage.Client(project=PROJECT_ID)
        .bucket(bucket_name)
        .blob(destination_blob_name)
    )
    target_blob.upload_from_filename(source_file_name)

    print(f"File {source_file_name} uploaded to {destination_blob_name}.")
    return f"gs://{bucket_name}/{destination_blob_name}"

def load_gcs_to_bigquery(gcs_uri, dataset_id, table_id, schema):
    """Load a newline-delimited JSON file from GCS into a BigQuery table.

    The dataset is created if it does not exist. The destination table is
    overwritten (``WRITE_TRUNCATE``). The call blocks until the load job
    finishes.

    Args:
        gcs_uri: ``gs://`` URI of the JSONL file to load.
        dataset_id: ID of the destination dataset in the client's project.
        table_id: ID of the destination table.
        schema: List of ``bigquery.SchemaField`` describing the table.

    Raises:
        Exception: Errors from the BigQuery client (other than a missing
            dataset) and load-job failures propagate to the caller.
    """
    # Imported here so the function brings its own dependency into scope.
    from google.cloud.exceptions import NotFound

    client = bigquery.Client(project=PROJECT_ID)

    # Create dataset if it doesn't exist.
    # Build the reference explicitly; Client.dataset() is deprecated.
    dataset_ref = bigquery.DatasetReference(client.project, dataset_id)
    try:
        client.get_dataset(dataset_ref)
        print(f"Dataset {dataset_id} already exists.")
    except NotFound:
        # Only a genuine "not found" triggers creation; permission or network
        # errors are no longer mistaken for a missing dataset.
        print(f"Creating dataset {dataset_id}...")
        # exists_ok guards against another process creating it in between.
        client.create_dataset(dataset_ref, exists_ok=True)

    job_config = bigquery.LoadJobConfig(
        schema=schema,
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,  # Overwrite table if it exists
    )

    table_ref = dataset_ref.table(table_id)
    load_job = client.load_table_from_uri(gcs_uri, table_ref, job_config=job_config)
    print(f"Starting job {load_job.job_id} to load data into {dataset_id}.{table_id}")

    load_job.result()  # Wait for the job to complete
    print(f"Job finished. Loaded {load_job.output_rows} rows.")

In [None]:
# Upload the catalog file and load it into BigQuery, unless there is nothing
# to load or no bucket has been configured.
if not generated_products:
    print("Skipping GCS and BigQuery steps as no products were generated.")
elif not GCS_BUCKET_NAME:
    print("GCS_BUCKET_NAME is not set. Skipping GCS upload and BigQuery load.")
else:
    gcs_destination_blob = f"product_catalogs/{RETAILER}/{OUTPUT_FILE.name}"
    # Any failure in either step is reported rather than raised.
    try:
        gcs_uri = upload_to_gcs(GCS_BUCKET_NAME, OUTPUT_FILE, gcs_destination_blob)
        load_gcs_to_bigquery(gcs_uri, BQ_DATASET_ID, BQ_TABLE_ID, bq_schema)
    except Exception as e:
        print(f"An error occurred during the GCS/BigQuery process: {e}")