In [5]:
from groq import Groq
from PIL import Image
import pandas as pd
import os
import time
from random import uniform
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import base64

# Function to encode the image
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

# Configure multiple API keys
API_KEYS = [
    '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'
]

current_api_key_index = 0
api_request_count = 0
MAX_REQUESTS_PER_KEY = 1450

def switch_to_next_api_key():
    """Switch to the next available API key"""
    global current_api_key_index, api_request_count
    current_api_key_index = (current_api_key_index + 1) % len(API_KEYS)
    api_request_count = 0
    print(f"\nSwitching to API key {current_api_key_index + 1}")
    time.sleep(2)  # Brief pause when switching keys

# Initialize with first API key
client = Groq(api_key=API_KEYS[current_api_key_index])

@retry(
    wait=wait_exponential(multiplier=1, min=4, max=60),
    stop=stop_after_attempt(3),
    retry=retry_if_exception_type(Exception)
)
def generate_qa_groq(image_path, product_info):
    """Generates a question-answer pair using GROQ Vision API with rate limiting."""
    global api_request_count
    
    try:
        # Check if we need to switch API keys
        if api_request_count >= MAX_REQUESTS_PER_KEY:
            if current_api_key_index == len(API_KEYS) - 1:
                print("\nAll API keys exhausted. Saving progress...")
                return None, None, True  # Third parameter indicates API exhaustion
            switch_to_next_api_key()
        
        # Add random delay between requests (2-3 seconds)
        time.sleep(uniform(2, 3))
        
        # Encode the image to base64
        base64_image = encode_image(image_path)
        
        prompt = f"""Analyze this product image carefully. The image resolution is 256x256 pixels.

Product Information:
{product_info}

Task:
1. Look at the product and identify:
   - The main visible object/product category
   - Prominent colors
   - Materials or textures
   - Basic shape or form
   - Notable features or attributes

Requirements:
1. Generate ONE question that:
   - Must be clearly answerable from the image
   - Must have a single-word answer only
   - Must focus on obvious visual features
2. Avoid questions about:
   - Small text or details
   - Measurements or dimensions
   - Subjective qualities
   - Brand names unless clearly visible

Output Format (STRICT):
Question: <single clear question about visible feature>
Answer: <single word only>

Example Good Outputs:
Question: What color is this shirt?
Answer: blue

Question: What material is this table made of?
Answer: wood"""

        # GROQ API call
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}",
                            },
                        },
                    ],
                }
            ],
            model="meta-llama/llama-4-maverick-17b-128e-instruct",
        )
        
        # Increment request count
        api_request_count += 1
        
        # Parse the response
        response_text = chat_completion.choices[0].message.content
        question = response_text.split("Question:")[1].split("Answer:")[0].strip()
        answer = response_text.split("Answer:")[1].strip()
        
        return question, answer, False  # False indicates API not exhausted
    
    except Exception as e:
        if "quota exceeded" in str(e).lower():
            print(f"\nQuota exceeded for API key {current_api_key_index + 1}")
            if current_api_key_index == len(API_KEYS) - 1:
                print("All API keys exhausted. Saving progress...")
                return None, None, True
            switch_to_next_api_key()
            return None, None, False
        else:
            print(f"GROQ Error for {image_path}: {e}")
            return None, None, False

# Try to load existing checkpoint
try:
    last_checkpoint = pd.read_csv("vqa_dataset_groq_final_1.csv")
    print(f"Loaded existing checkpoint with {len(last_checkpoint)} entries")
    paths = last_checkpoint['path'].tolist()
    questions = last_checkpoint['generated_question'].tolist()
    answers = last_checkpoint['generated_answer'].tolist()
    processed_count = len(last_checkpoint)
    checkpoint_num = 2
except FileNotFoundError:
    print("No existing checkpoint found. Starting fresh.")
    paths = []
    questions = []
    answers = []
    processed_count = 0
    checkpoint_num = 0

print(processed_count)

# Ensure clean_df is properly initialized
try:
    clean_df = pd.read_csv("sampled_metadata_stratified_1.csv")  # Replace with your actual data file path
    print(f"Loaded data: {len(clean_df)} rows")
except FileNotFoundError:
    print("Data file not found. Please provide the correct path to your data.")
    clean_df = pd.DataFrame()  # Initialize as an empty DataFrame in case of failure

total_images = len(clean_df)
save_interval = 100  # Save more frequently
batch_size = 50

def save_progress(paths, questions, answers, is_final=False):
    """Save current progress to CSV"""
    if len(paths) > 0:
        temp_df = pd.DataFrame({
            'path': paths,
            'generated_question': questions,
            'generated_answer': answers
        })
        
        if is_final:
            filename = "vqa_dataset_groq_final_1.csv"
        else:
            filename = f"vqa_dataset_groq_checkpoint_{checkpoint_num}.csv"
        
        temp_df.to_csv(filename, index=False)
        print(f"\nProgress saved: {len(paths)} Q&A pairs written to {filename}")

# Skip already processed images
clean_df = clean_df[processed_count:]

try:
    # Iterate through the DataFrame in batches
    for batch_start in range(0, len(clean_df), batch_size):
        current_batch = (batch_start + processed_count) // batch_size + 1
        total_batches = (len(clean_df) + batch_size - 1) // batch_size
        print(f"\nProcessing batch {current_batch}/{total_batches}")
        
        # Add a pause between batches
        if batch_start > 0:
            print("Pausing between batches...")
            time.sleep(30)  # 30-second pause between batches
        
        # Get current batch
        batch = clean_df[batch_start:batch_start + batch_size]
        
        for index, row in batch.iterrows():
            image_path = os.path.join("small", row['path'])
            
            # Construct product info
            product_info_parts = []
            if pd.notna(row['model_name']):
                product_info_parts.append(f"Model: {row['model_name']}")
            if pd.notna(row['color']):
                product_info_parts.append(f"Color: {row['color']}")
            if pd.notna(row['product_type']):
                product_info_parts.append(f"Type: {row['product_type']}")
            if pd.notna(row['material']):
                product_info_parts.append(f"Material: {row['material']}")
            if pd.notna(row['style']):
                product_info_parts.append(f"Style: {row['style']}")
            if pd.notna(row['pattern']):
                product_info_parts.append(f"Pattern: {row['pattern']}")
            if pd.notna(row['item_shape']):
                product_info_parts.append(f"Shape: {row['item_shape']}")
            
            product_info = ", ".join(product_info_parts)

            question, answer, apis_exhausted = generate_qa_groq(image_path, product_info)

            if apis_exhausted:
                # Save progress and exit
                save_progress(paths, questions, answers, is_final=True)
                print(f"\nProcessing stopped at {processed_count}/{total_images} images")
                print(f"Successfully generated {len(paths)} Q&A pairs")
                break
            
            if question and answer:
                paths.append(row['path'])
                questions.append(question)
                answers.append(answer)

                # Save more frequently
                if len(paths) % save_interval == 0:
                    checkpoint_num += 1
                    save_progress(paths, questions, answers)

            processed_count += 1
            if processed_count % 10 == 0:
                print(f"Processed {processed_count}/{total_images} images ({(processed_count/total_images)*100:.2f}%)")
        
        if apis_exhausted:
            break

except KeyboardInterrupt:
    print("\nProcess interrupted by user. Saving progress...")
    save_progress(paths, questions, answers, is_final=True)
    print(f"\nProcessing stopped at {processed_count}/{total_images} images")
    print(f"Successfully generated {len(paths)} Q&A pairs")
except Exception as e:
    print(f"\nAn error occurred: {e}")
    save_progress(paths, questions, answers, is_final=True)
    print(f"\nProcessing stopped at {processed_count}/{total_images} images")
    print(f"Successfully generated {len(paths)} Q&A pairs")

# Save final progress if not already saved
if len(paths) > 0:
    save_progress(paths, questions, answers, is_final=True)
    print(f"\nFinal processing complete: {processed_count}/{total_images} images processed")
    print(f"Successfully generated {len(paths)} Q&A pairs")


Loaded existing checkpoint with 3993 entries
3993
Loaded data: 4000 rows

Processing batch 80/1

Progress saved: 4000 Q&A pairs written to vqa_dataset_groq_checkpoint_3.csv
Processed 4000/4000 images (100.00%)

Progress saved: 4000 Q&A pairs written to vqa_dataset_groq_final_1.csv

Final processing complete: 4000/4000 images processed
Successfully generated 4000 Q&A pairs
