In [None]:
import os
import re
import csv
import pandas as pd
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from itertools import cycle
from dotenv import load_dotenv
import google.generativeai as genai

# Locations of the curated metadata, the curated images, and the output CSV
SRC_PATH        = '../data/curated.csv'
IMAGE_DEST_DIR  = '../data/curated_images'
DEST_PATH       = '../data/vqa.csv'

# Curated metadata, indexed by image filename
df = pd.read_csv(SRC_PATH).set_index('filename')

# API keys come from the comma-separated GOOGLE_API_KEYS variable (.env is honoured);
# blank entries and surrounding whitespace are dropped
load_dotenv()
api_keys_str = os.getenv("GOOGLE_API_KEYS", "")
API_KEYS = list(filter(None, map(str.strip, api_keys_str.split(","))))

# Without a GOOGLE_API_KEYS entry, supply the key list by hand instead:
# API_KEYS = [key1, key2, ...]

if not API_KEYS:
    raise ValueError("\nNo API keys found.")

# Endless round-robin over the keys, used to rotate to the next one
api_key_cycle = cycle(API_KEYS)

# Generation settings handed to the gemini-1.5-flash model (tune as needed)
generation_config = dict(
    temperature=1.2,
    top_p=0.8,
    top_k=100,
    max_output_tokens=1000,
)

def configure_model(api_key):
    """Configure the Gemini client with ``api_key`` and build a model handle.

    Args:
        api_key: API key passed to ``genai.configure``.

    Returns:
        A ``genai.GenerativeModel`` for "gemini-1.5-flash" that uses the
        module-level ``generation_config``.
    """
    genai.configure(api_key=api_key)
    model_kwargs = {
        "model_name": "gemini-1.5-flash",
        "generation_config": generation_config,
    }
    return genai.GenerativeModel(**model_kwargs)

# Initialize the model with the first key of the rotation; generate_qa
# rebinds both names when it switches to the next key.
current_api_key = next(api_key_cycle)
model = configure_model(current_api_key)

# Regex for parsing model output: captures the text after "Question N:" up to
# the following "Answer N:" (group 1) and the rest of the answer line (group 2).
# Leading whitespace is tolerated, so the pattern also works when anchored
# (match/fullmatch) on indented output, not only when searched.
QA_REGEX = re.compile(
    r"\s*Question\s*\d+:\s*(.*?)\s*Answer\s*\d+:\s*(.+)",
    re.IGNORECASE
)

def generate_qa(filename):
    """Generate question/answer pairs for one curated image with Gemini.

    The image is opened from ``IMAGE_DEST_DIR`` and sent to the model together
    with a text prompt built from the image's metadata row in ``df``.  The
    reply text is parsed with ``QA_REGEX``.

    When a request fails with an error whose message contains "429", the next
    key from ``api_key_cycle`` is configured (rebinding the module-level
    ``model`` and ``current_api_key``) and the request is retried, for at most
    ``len(API_KEYS)`` attempts in total.  Any other API error is printed and
    ends the attempts.

    Args:
        filename: Image file name; it is looked up with ``df.loc`` and so must
            be a label in ``df``'s index.

    Returns:
        A list of ``(question, answer)`` tuples.  Questions have trailing
        ``?``, ``.`` and ``!`` characters removed; answers are lower-cased.
        The list is empty when the image cannot be opened or no request
        succeeds.
    """
    global model, current_api_key

    pairs = []
    img_path = Path(IMAGE_DEST_DIR) / filename

    try:
        img = Image.open(img_path)
    except FileNotFoundError:
        print(f"\nError: Image not found: {img_path}")
        return pairs
    except Exception as e:
        print(f"\nError opening {img_path}: {e}")
        return pairs

    row = df.loc[filename]
    name         = row.get('name', 'N/A')
    product_type = row.get('product_type', 'N/A')
    color        = row.get('color', 'N/A')
    keywords     = row.get('keywords', 'N/A')

    # The image itself is passed to the model as a separate content part below;
    # interpolating it here would only embed the PIL object's repr as text.
    prompt = f"""
        You are given an image, some metadata about that image, and a set of instructions. Follow the instructions exactly.
        Metadata:
        Name: {name}
        Product Type: {product_type}
        Color: {color}
        Keywords: {keywords}
        Instructions:
        1. Generate 5 distinct questions.
        2. The questions can be answered by looking at the image or can be inferred by thinking.
        3. Difficult questions are preferred.
        4. Do not use quotation marks anywhere.
        5. The answer should exactly be 1 word.
        6. Provide the output strictly in the format given below.

        Question 1:
        Answer 1:
        Question 2:
        Answer 2:
        Question 3:
        Answer 3:
        Question 4:
        Answer 4:
        Question 5:
        Answer 5:
        """

    # Keep the image file open only for the duration of the requests.
    with img:
        # One attempt per available key: a rate-limited key is swapped for the
        # next one in the rotation and the request is retried.
        for _ in range(len(API_KEYS)):
            try:
                response = model.generate_content([prompt, img])
                text = response.text.strip()
                for q, a in QA_REGEX.findall(text):
                    question = q.strip().rstrip('?.!')
                    answer   = a.strip().lower()
                    if question and answer:
                        pairs.append((question, answer))
                break

            except Exception as e:
                if "429" in str(e):
                    current_api_key = next(api_key_cycle)
                    model = configure_model(current_api_key)
                    continue
                else:
                    print(f"\nAPI error on {filename}: {e}")
                    break

    return pairs

def main():
    """Generate Q&A rows for every unprocessed image and store them in DEST_PATH.

    Filenames already present in ``DEST_PATH`` are skipped, so an interrupted
    run can be resumed.  For each remaining filename in ``df``'s index,
    ``generate_qa`` is called; its result is appended to the CSV only when it
    contains exactly 5 pairs, otherwise the file is reported as skipped.
    Finally the CSV is re-read, sorted by filename and question, and rewritten.
    """
    header = ['filename', 'question', 'answer']
    processed = set()

    needs_header = not os.path.exists(DEST_PATH)
    if not needs_header:
        try:
            chk = pd.read_csv(DEST_PATH)
            processed = set(chk['filename'])
        except pd.errors.EmptyDataError:
            # An existing but empty file has no header row; without one the
            # appended rows would be headerless and the final sort would fail.
            needs_header = True

    if needs_header:
        with open(DEST_PATH, 'w', newline='', encoding='utf-8') as f:
            csv.writer(f).writerow(header)

    to_process = [fn for fn in df.index if fn not in processed]

    for filename in tqdm(to_process, desc="VQA generation"):
        qa = generate_qa(filename)
        if len(qa) == 5:
            # Append per image so completed work survives an interruption.
            with open(DEST_PATH, 'a', newline='', encoding='utf-8') as f:
                csv.writer(f).writerows([filename, q, a] for q, a in qa)
        else:
            print(f"\nSkipped {filename}: {len(qa)} Q&A pairs")

    # Read text columns verbatim: with pandas' default NA handling, answers
    # such as "nan", "null" or "n/a" would be parsed as missing values and
    # written back as empty cells.
    out_df = pd.read_csv(
        DEST_PATH,
        dtype={'question': str, 'answer': str},
        keep_default_na=False,
    )
    out_df = out_df.sort_values(['filename','question']).reset_index(drop=True)
    out_df.to_csv(DEST_PATH, index=False)
    print(f"\nDone. Total Q&A rows: {len(out_df)}")

# Run the generation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()