In [1]:
# be name khoda

In [2]:
import sqlite3
import os 

class ImageDB:
    """
    Manages interactions with the SQLite image database.

    Assumes a table named 'results' with 'filename' and 'description' columns.
    Every method opens its own connection and closes it before returning.
    SQLite errors are reported with ``print`` instead of being raised.
    """

    def __init__(self, db_path: str = 'database.db'):
        """
        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        # Only a missing file triggers schema creation; an existing file is
        # used as-is and is expected to already contain the 'results' table.
        if not os.path.exists(self.db_path):
            self.create_database_schema()

    def create_database_schema(self):
        """Create the 'results' table if it does not exist yet."""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            # `with conn` only commits or rolls back; the connection itself is
            # closed in the `finally` clause below.
            with conn:
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS results (
                        filename TEXT PRIMARY KEY UNIQUE,
                        description TEXT NOT NULL
                    )
                """)
            print(f"Database '{self.db_path}' and table 'results' created.")
        except sqlite3.Error as e:
            print(f"Error creating database schema: {e}")
        finally:
            if conn is not None:
                conn.close()

    def add_image_description(self, filename: str, description: str):
        """
        Store the description of one image.

        'filename' is the primary key, so adding the same filename again
        replaces the stored description instead of failing.

        Args:
            filename: Image file name (primary key of 'results').
            description: Text describing the image.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            with conn:
                conn.execute(
                    "INSERT OR REPLACE INTO results (filename, description) VALUES (?, ?)",
                    (filename, description),
                )
        except sqlite3.Error as e:
            print(f"Error adding image description: {e}")
        finally:
            if conn is not None:
                conn.close()

    def retrieve_descriptions_with_query(self, sqlite_query: str, params=()) -> list[tuple[str, str]]:
        """
        Run a query against the database and return every row it yields.

        NOTE(security): ``sqlite_query`` is executed verbatim. Callers that
        pass text from an untrusted source (for example an LLM) must validate
        it first; pass values through ``params`` instead of formatting them
        into the SQL string.

        Args:
            sqlite_query: SQL text, optionally containing ``?`` placeholders.
            params: Values bound to the placeholders, in order.

        Returns:
            The fetched rows, or an empty list when the query is empty or
            SQLite reports an error.
        """
        if not sqlite_query:
            return []
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            with conn:
                return conn.execute(sqlite_query, params).fetchall()
        except sqlite3.Error as e:
            print(f"Error retrieving descriptions: {e}")
            return []
        finally:
            if conn is not None:
                conn.close()

    def retrieve_descriptions_with_keywords(self, keywords: list[str]) -> list[tuple[str, str]]:
        """
        Retrieves image filenames and descriptions based on keywords using LIKE.

        A row is returned when its description contains ANY of the keywords.
        Keywords are bound as parameters, so quotes inside a keyword are
        matched literally instead of breaking (or altering) the statement.

        Args:
            keywords: Substrings to look for in the 'description' column.

        Returns:
            Matching (filename, description) rows; empty when no keywords
            are given.
        """
        if not keywords:
            return []

        # One "%keyword%" pattern per keyword, each bound to its own placeholder.
        like_patterns = [f"%{keyword}%" for keyword in keywords]
        if not like_patterns:
            return []
        where_clause = " OR ".join(["description LIKE ?"] * len(like_patterns))
        query = f"SELECT filename, description FROM results WHERE {where_clause}"

        return self.retrieve_descriptions_with_query(query, like_patterns)

In [3]:
import abc

# abstract class
class PromptEngineer(abc.ABC):
    """Builds the two prompt texts sent to the LLM: keyword/SQL retrieval and filename refinement."""

    def generate_retrieval_keywords_and_sqlitequery_prompt(self, user_query: str) -> str:
        """
        Build the prompt asking for a keyword list plus a fenced SQLite SELECT query.

        Args:
            user_query: The user's free-text image request, inserted verbatim.

        Returns:
            The prompt text, including a worked example of the expected output format.
        """
        template = """You are an Agent to retrieve images from a gallery.
                A user has this request (prompt): `{user_query}`. 
                You should first create a keyword list with 10 main keywords which can describe and specific to this prompt.(don't write plural form )
                And then write an SQLite SELECT query in Python that finds all rows in a table named 'results'
                where the column 'description' contains ANY of the keywords (case-insensitive). 
                Format the output similar to the example below.
                DO NOT include any other text or explanation. do not include any code explanaion, do not include any extra detail. Just SQLITE query that returned rows with ANY mathcing keywork is wanted as output and keywords.
    
                Example:
                Keywords: [person, two, people, individuals, couple, pair, human, subjects, faces, interaction]
                
                ```sqlite
                SELECT * FROM results WHERE description LIKE '%person%' OR description LIKE '%two%' OR description LIKE '%people%' OR description LIKE '%individuals%' OR description LIKE '%couple%' OR description LIKE '%pair%' OR description LIKE '%human%' OR description LIKE '%subjects%' OR description LIKE '%faces%' OR description LIKE '%interaction%'
                ```
                """
        return template.format(user_query=user_query)

    def generate_refinement_prompt(self, image_description: str, user_query: str)-> str:
        """
        Build the prompt asking which of the candidate images match the request.

        Args:
            image_description: Candidate filenames and descriptions; inserted into
                the prompt through its default string formatting.
            user_query: The user's free-text image request, inserted verbatim.

        Returns:
            The prompt text requesting a ``filenames = [...]`` style answer.
        """
        template = """ You are an agent and you have a list containing image file names and their description: {image_description}.
                User gives you a promt: {user_query}. 
                identify and return images with descriptions that match the prompt.
                return file names of the relevant images in a list format.
                for example: filenames = ["IMG_0001.JPG", "IMG_1111.JPG",...]
                Do not include any explanaion, do not include any extra detail."""
        return template.format(image_description=image_description, user_query=user_query)


class OllamaPromptEngineer(PromptEngineer):
    """Prompt builder used for the Ollama agent; inherits both prompts from PromptEngineer unchanged."""
    pass


class GeminiPromptEngineer(PromptEngineer):
    """Prompt builder used for the Gemini agent; inherits both prompts from PromptEngineer unchanged."""
    pass

In [4]:
import abc
import httpx

# abstract class
class APICaller(abc.ABC):
    """Abstract base for clients that send a prompt to an LLM HTTP endpoint."""

    def __init__(self, base_url: str):
        """
        Args:
            base_url: URL the concrete caller posts its requests to.
        """
        self.base_url = base_url

    # `call_api` is an instance method (it takes `self` and subclasses read
    # instance attributes), so it is declared with `abstractmethod` rather
    # than the deprecated `abstractclassmethod`.
    @abc.abstractmethod
    def call_api(self, prompt: str) -> str:
        """
        Send ``prompt`` to the model and return its text answer.

        Args:
            prompt: Full prompt text.

        Returns:
            The model's reply as a string.
        """


class OllamaAPICaller(APICaller):
    """Calls an Ollama ``/api/generate`` endpoint with a non-streaming request."""

    def __init__ (self, base_url= 'http://gemma:11434/api/generate', model= 'gemma3:4b'):
        """
        Args:
            base_url: URL of the Ollama generate endpoint.
            model: Name of the Ollama model to query.
        """
        super().__init__(base_url)
        self.model = model

    def call_api(self, prompt):
        """
        Post ``prompt`` to Ollama and return the generated text.

        Failures are not raised. They are returned as a string starting with
        "API Error", which is the marker the agent checks for to tell a
        failure apart from a model answer.

        Args:
            prompt: Full prompt text.

        Returns:
            The "response" field of the JSON reply, a fallback message when
            that field is absent, or an "API Error: ..." string on failure.
        """
        headers = {"Content-Type": "application/json"}
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False
        }

        try:
            with httpx.Client() as client:
                # timeout=None disables httpx's timeouts entirely
                # (presumably because local generation can be slow - TODO confirm).
                response = client.post(self.base_url, headers=headers, json=payload, timeout=None)
                response.raise_for_status()
                response_data = response.json()
                return response_data.get("response", "No response found from Ollama.")
        except httpx.RequestError as exc:
            return f"API Error: An error occurred while requesting {exc.request.url!r}: {exc}"
        except httpx.HTTPStatusError as exc:
            return f"API Error: Error response {exc.response.status_code} while requesting {exc.request.url!r}: {exc.response.text}"
        except Exception as e:
            return f"API Error: An unexpected error occurred: {e}"

        

class GeminiAPICaller(APICaller):
    """Calls the Gemini ``generateContent`` REST endpoint."""

    def __init__(self,  model = "gemini-2.5-flash-preview-05-20", api_key: str = None ):
        """
        Args:
            model: Gemini model name, used to build the endpoint URL.
            api_key: Gemini API key. It is sent in a request header and is
                deliberately kept out of ``base_url`` so it cannot end up in
                error messages or request logs.
        """
        base_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
        super().__init__(base_url)
        self.api_key = api_key
        self.model = model

    def call_api(self, prompt: str) -> str:
        """
        Post ``prompt`` to Gemini and return the first text part of the first candidate.

        Failures are not raised. They are returned as a string starting with
        "API Error", which is the marker the agent checks for to tell a
        failure apart from a model answer.

        Args:
            prompt: Full prompt text.

        Returns:
            The generated text, a fallback message when the reply carries no
            text part, or an "API Error: ..." string on failure.
        """
        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            # Authenticate through the header instead of a `?key=` query parameter.
            headers["x-goog-api-key"] = self.api_key
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text":prompt}
                    ]
                }
            ]
        }

        try:
            with httpx.Client() as client:
                response = client.post(self.base_url, headers=headers, json=payload, timeout=100.0)
                response.raise_for_status()
                response_data = response.json()
                if 'candidates' in response_data and response_data['candidates']:
                    first_candidate = response_data['candidates'][0]
                    if 'content' in first_candidate and 'parts' in first_candidate['content']:
                        for part in first_candidate['content']['parts']:
                            if 'text' in part:
                                return part['text']
                return "No text response found from Gemini."
        except httpx.RequestError as exc:
            return f"API Error: An error occurred while requesting {exc.request.url!r}: {exc}"
        except httpx.HTTPStatusError as exc:
            return f"API Error: Error response {exc.response.status_code} while requesting {exc.request.url!r}: {exc.response.text}"
        except Exception as e:
            return f"API Error: An unexpected error occurred: {e}"


In [5]:
import re 
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import math
import shutil

class ImageRefinementAgent():
    """
    Two-step image search agent.

    Step one asks the LLM for a SQLite query and runs it against the image
    database; step two asks the LLM which of the returned descriptions really
    match the request. The selected files are copied to a destination
    directory and displayed.
    """

    # Beginnings of the strings the API callers return instead of raising.
    _API_ERROR_PREFIXES = (
        "API Error",
        "An error occurred while requesting",
        "Error response",
        "An unexpected error occurred",
    )

    def __init__(self, image_db: "ImageDB", api_caller: "APICaller", prompt_engineer: "PromptEngineer"):
        """
        Args:
            image_db: Database wrapper used to run the generated query.
            api_caller: Client used for both LLM calls.
            prompt_engineer: Builder of the two prompts.
        """
        self.prompt_engineer = prompt_engineer
        self.api_caller = api_caller
        self.image_db = image_db

    def _is_api_error(self, llm_response: str) -> bool:
        """Return True when ``llm_response`` is an error string produced by an API caller."""
        return "API Error" in llm_response or llm_response.startswith(self._API_ERROR_PREFIXES)

    @staticmethod
    def _is_inside(root: str, path: str) -> bool:
        """Return True when absolute ``path`` is ``root`` itself or lies below it."""
        root = os.path.normcase(root)
        path = os.path.normcase(path)
        try:
            return os.path.commonpath([root, path]) == root
        except ValueError:
            # Raised for paths on different drives: certainly not inside.
            return False

    def parse_sql_query_from_llm_response(self, llm_response: str) ->str:
        """
        Extract the SQL text from the first fenced ``sqlite`` (or ``sql``) block.

        Args:
            llm_response: Raw LLM answer.

        Returns:
            The stripped SQL text, or "" when no such block is present.
        """
        # ```sqlite / ```sql  - opening fence with its language tag (any case)
        # (.*?)               - non-greedy capture of the SQL text
        # ```                 - closing fence
        pattern = r"```(?:sqlite|sql)[ \t]*\r?\n(.*?)```"
        try:
            match = re.search(pattern, llm_response, re.DOTALL | re.IGNORECASE)
            if match:
                return match.group(1).strip()
            print("No SQL query found in the specified format.")
            return ""
        except Exception as e:
            print(f"Regex error: {e}")
            return ""

    def parse_sql_keywords_from_llm_response(self, llm_response: str) ->list[str]:
        """
        Extract the keywords from the ``LIKE '%keyword%'`` patterns of an LLM answer.

        Args:
            llm_response: Raw LLM answer containing the generated SQL.

        Returns:
            The keywords in order of appearance; empty when none are found.
        """
        # Match the text between "'%" and "%'" of each LIKE pattern.
        pattern = r"'%([^'%]+?)%'"
        return re.findall(pattern, llm_response)

    def parse_filename_from_llm_response(self,llm_response:str) -> list[str]:
        """
        Extract double-quoted image file names from an LLM answer.

        Args:
            llm_response: Raw LLM answer, e.g. ``filenames = ["IMG_0001.JPG"]``.

        Returns:
            The file names found (always a list; empty when there are none).
        """
        # "               - opening double quote
        # ([^"]+\.ext)    - the file name: anything but a quote, a dot, the extension
        # (?:png|jpe?g)   - png, jpg or jpeg; re.IGNORECASE accepts any letter case
        # "               - closing double quote
        pattern = r'"([^"]+\.(?:png|jpe?g))"'
        try:
            extracted_filenames = re.findall(pattern, llm_response, re.IGNORECASE)
        except Exception as e:
            print(f"Regex error: {e}")
            return []
        if not extracted_filenames:
            print("no photo was found")
        return extracted_filenames

    def copy_images(self,source_path, destination_path, filenames):
        """
        Recreate ``destination_path`` and copy the named files into it.

        The destination directory is deleted and recreated first. File names
        come from an LLM, so any name that resolves outside ``source_path`` is
        skipped, as is any name that is not an existing file. Errors are
        printed, not raised.

        Args:
            source_path: Directory holding the images.
            destination_path: Directory to (re)create and fill.
            filenames: File names relative to ``source_path``.
        """
        try:
            src_root = os.path.abspath(source_path)
            dest_root = os.path.abspath(destination_path)
            # Wiping the destination must never take the source images with it.
            if self._is_inside(dest_root, src_root):
                print("Copy error:destination directory must not be or contain the source directory")
                return

            if os.path.exists(dest_root):
                # Clear the contents of the directory
                shutil.rmtree(dest_root)
            # Create a clean directory
            os.makedirs(dest_root)

            for fn in filenames or []:
                src_path = os.path.abspath(os.path.join(src_root, fn))
                if not self._is_inside(src_root, src_path):
                    print(f"Skipping {fn} as it is outside the source directory")
                    continue
                if not os.path.isfile(src_path):
                    print(f"Skipping {fn} as file does not exist")
                    continue

                # Rebuild the destination from the path relative to the source
                # root so the copy always lands below the destination directory.
                dest_path = os.path.join(dest_root, os.path.relpath(src_path, src_root))
                os.makedirs(os.path.dirname(dest_path), exist_ok=True)
                shutil.copyfile(src_path, dest_path)

        except Exception as e:
            print(f"Copy error:{e}")

    def show_images(self, source_path, filenames):
        """
        Display the named images in a three-column grid.

        Args:
            source_path: Directory holding the images.
            filenames: File names relative to ``source_path``; missing files
                are skipped.
        """
        if not filenames:
            print("No image files found to show")
            return

        existing = []
        for fn in filenames:
            image_path = os.path.join(source_path, fn)
            if os.path.isfile(image_path):
                existing.append((fn, image_path))
            else:
                print(f"Skipping {fn} as file does not exist")
        if not existing:
            print("No image files found to show")
            return

        # Grid configuration
        cols = 3
        rows = math.ceil(len(existing) / cols)

        # Set figure size (adjust as needed)
        plt.figure(figsize=(10, 3 * rows))

        for i, (fn, image_path) in enumerate(existing):
            img = mpimg.imread(image_path)
            plt.subplot(rows, cols, i + 1)
            plt.imshow(img)
            plt.axis('off')  # Hide axes
            plt.title(fn)

        plt.tight_layout()
        plt.show()

    def run_agent(self, user_query:str, source_path, destination_path):
        """
        Run the whole retrieve / refine / copy / show workflow.

        Args:
            user_query: The user's free-text image request.
            source_path: Directory holding the images.
            destination_path: Directory that receives the selected images.

        Returns:
            The list of selected file names; empty when any step fails or
            nothing matches.
        """
        # Step 1: Generate retrieval keywords and sqlite query and check the output
        print("Generating retrieval keywords and sqlite query...")
        retrieval_keyword_prompt = self.prompt_engineer.generate_retrieval_keywords_and_sqlitequery_prompt(user_query)
        llm_response = self.api_caller.call_api(retrieval_keyword_prompt)

        # API callers report failures as strings, so check for their error format.
        if self._is_api_error(llm_response):
            print(f"Error getting keywords and sqlite query from LLM: {llm_response}")
            return []
        print(f'\nllmm response for retrival keywords:{llm_response}')

        selected_query = self.parse_sql_query_from_llm_response (llm_response)
        if not selected_query:
            print("No sqlite query extracted from LLM response. Response was:")
            print(llm_response)
            return []
        # The query text comes from the LLM: only run what was asked for, a SELECT.
        if not selected_query.lstrip().lower().startswith("select"):
            print(f"Refusing to run a non-SELECT statement from the LLM: {selected_query}")
            return []
        print(f"\nExtracted sqlite_query: {selected_query}")

        # Step 2: Retrieve initial image descriptions from DB
        print("\nRetrieving initial image descriptions from database...")
        image_description = self.image_db.retrieve_descriptions_with_query(selected_query)
        if not image_description:
            print("No initial matches found in the database for the extracted keywords.")
            return []
        print(f"Found {len(image_description)} initial matches filenames:{[images[0] for images in image_description]}")

        # Step 3: Let the LLM pick the descriptions that really match
        print("\nRefining image descriptions...")
        refinement_prompt = self.prompt_engineer.generate_refinement_prompt(image_description,user_query)
        llm_response_final_filename = self.api_caller.call_api(refinement_prompt)

        if self._is_api_error(llm_response_final_filename):
            print(f"Error refining description from llm: {llm_response_final_filename}")
            return []
        print(f'\nllmm response for final filenames:{llm_response_final_filename}')

        final_filenames = self.parse_filename_from_llm_response(llm_response_final_filename)

        # Step 4: Copy final images to destination directory
        self.copy_images(source_path, destination_path, final_filenames)

        # Step 5: Show the images
        self.show_images(source_path, final_filenames)
        print("\nAgent workflow completed")
        return final_filenames

In [6]:
# agent factory

class AgentFactory():
    """
    Factory class responsible for creating and configuring the ImageRefinementAgent.
    Centralizes the logic for instantiating all necessary components.
    """

    @staticmethod
    def create_default_ollama_agent(db_path: str = 'database.db') -> "ImageRefinementAgent":
        """
        Creates an ImageRefinementAgent configured to use Ollama for both retrieval and refinement.

        Args:
            db_path: Path of the SQLite image database.

        Returns:
            The configured agent.
        """
        image_db_instance = ImageDB(db_path)
        ollama_base_url = "http://gemma:11434/api/generate"
        ollama_model_name = "gemma3:4b"
        ollama_api_caller = OllamaAPICaller(base_url=ollama_base_url, model=ollama_model_name)
        ollama_prompt_engineer = OllamaPromptEngineer()

        agent = ImageRefinementAgent(
            image_db=image_db_instance,
            api_caller=ollama_api_caller,
            prompt_engineer=ollama_prompt_engineer
        )
        return agent

    @staticmethod
    def create_default_gemini_agent(db_path:str = "database.db") -> "ImageRefinementAgent":
        """
        Creates an ImageRefinementAgent configured to use Gemini for both retrieval and refinement.

        The API key is read from the GEMINI_API_KEY environment variable.

        Args:
            db_path: Path of the SQLite image database.

        Returns:
            The configured agent.

        Raises:
            AssertionError: If GEMINI_API_KEY is not set.
        """
        # Validate the environment before touching the database, and with an
        # explicit raise: a bare `assert` disappears under `python -O`.
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if gemini_api_key is None:
            raise AssertionError("GEMINI_API_KEY environment variable is not set")

        image_db_instance = ImageDB(db_path)
        gemini_model_name = "gemini-2.5-flash-preview-05-20"
        gemini_api_caller = GeminiAPICaller(model=gemini_model_name,api_key=gemini_api_key)
        gemini_prompt_engineer = GeminiPromptEngineer()

        agent = ImageRefinementAgent(
            image_db=image_db_instance,
            api_caller=gemini_api_caller,
            prompt_engineer=gemini_prompt_engineer
        )
        return agent

In [7]:
from ipywidgets import widgets
from IPython.display import display, clear_output

if __name__ == '__main__':

    # Example 1: Create an agent using Ollama for both steps
    ollama_agent = AgentFactory.create_default_ollama_agent()

    # Example 2: Create an agent using Gemini for both steps
    gemini_agent = AgentFactory.create_default_gemini_agent() # Make sure GEMINI_API_KEY is set!

    # Create Jupyter widgets for interaction
    query_input = widgets.Textarea(
        value="Show me pictures of a couple enjoying food.",
        placeholder="Enter your image search query...",
        description="Search Query:",
        rows=3,
        layout=widgets.Layout(width='auto')
    )

    ollama_output_area = widgets.Output()
    gemini_output_area = widgets.Output()

    ollama_run_button = widgets.Button(description="Run Ollama Agent")
    gemini_run_button = widgets.Button(description="Run Gemini Agent")

    def make_run_handler(agent, output_area):
        """Build a click handler that runs `agent` on the current query inside `output_area`."""
        def on_run_button_clicked(b):
            with output_area:
                # Replace the previous run's output instead of appending to it.
                clear_output(wait=True)
                agent.run_agent(query_input.value,source_path="./resized_images", destination_path="./output_images")
        return on_run_button_clicked

    # Both buttons share one handler implementation; only agent and output differ.
    ollama_on_run_button_clicked = make_run_handler(ollama_agent, ollama_output_area)
    gemini_on_run_button_clicked = make_run_handler(gemini_agent, gemini_output_area)

    ollama_run_button.on_click(ollama_on_run_button_clicked)
    gemini_run_button.on_click(gemini_on_run_button_clicked)

    display(
        widgets.VBox([
            query_input,
            widgets.HBox([ollama_run_button, gemini_run_button]),
            widgets.HTML("<h2>Ollama Agent Output:</h2>"),
            ollama_output_area,
            widgets.HTML("<h2>Gemini Agent Output:</h2>"),
            gemini_output_area
        ])
    )

VBox(children=(Textarea(value='Show me pictures of a couple enjoying food.', description='Search Query:', layo…