# Retrieval Augmented Generation

A few steps are required to make this work:

- Create callable functions that can query the database for related keywords.
- Give these callable functions to an OpenAI assistant.
- Build in recursive functions that allow the bot to query the database multiple times to pull together an outfit.
- Create a system whereby the user communicates with a chatbot that can hold long-form discussions.


In [1]:
import numpy as np
import hnswlib
import clip
import torch
import os
import sys
import psutil
from PIL import Image
from IPython.display import display
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()
# The OpenAI key is read from the PERSONAL_OPENAI_KEY variable
# (os.getenv returns None if it is not set).
GPT_API_KEY = os.getenv('PERSONAL_OPENAI_KEY')

from openai import OpenAI
import json

# Shared OpenAI client; run_conversation calls client.chat.completions.create on it.
client = OpenAI(api_key=GPT_API_KEY)

# Load the CLIP model (on the GPU when CUDA is available, otherwise the CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Load HNSWlib index and image ids
# These paths are relative to the current working directory (the notebook's folder).
hnsw_index_path = '../data/processed/hnsw_index.bin'
image_ids_path = '../data/processed/image_ids.npy'
image_folder_path = '../data/raw/images'

# Load the index
# dim=512 presumably matches the ViT-B/32 CLIP embedding width — confirm if the model changes.
# NOTE(review): model, preprocess, hnsw_index and image_ids are not referenced by the
# cells shown here; query_images (src.vector_search) presumably loads its own — confirm.
hnsw_index = hnswlib.Index(space='l2', dim=512)
hnsw_index.load_index(hnsw_index_path)
image_ids = np.load(image_ids_path)

# Make the project root importable so `src.vector_search` resolves; this must run
# before the import below.
# NOTE(review): machine-specific absolute path — will not resolve on other machines.
sys.path.append('C:/projects/python/GPT-Image-Reccomender')

from src.vector_search import generate_text_embedding, search_similar_images, display_images, monitor_resources, query_images


In [2]:
def define_tools():
    """Build the tool specifications handed to the OpenAI assistant.

    There is currently one function tool, ``query_database``. It takes two
    required string arguments: ``item_name`` and a free-text ``query``.

    Returns:
        list[dict]: a freshly built list, one entry per tool, in the Chat
        Completions ``tools`` format.
    """
    item_name_schema = {
        "type": "string",
        "description": "The name of the clothing item (e.g., 'brown sandal', 'white top')."
    }
    query_schema = {
        "type": "string",
        "description": "Unstructured text query describing the clothing items (e.g., 'brown sandal and white top')."
    }
    query_database_spec = {
        "name": "query_database",
        "description": "This is a database that contains images of +50000 items of clothing. This function allows you to send an unstructured text query to the database to get a list of appropriate items. Every query should include colour, style, size, gender, and item",
        "parameters": {
            "type": "object",
            "properties": {
                "item_name": item_name_schema,
                "query": query_schema,
            },
            "required": ["item_name", "query"],
        },
    }
    return [{"type": "function", "function": query_database_spec}]

def handle_function_call(function_name, function_args, top_k):
    """Dispatch a tool call requested by the assistant to its Python implementation.

    Args:
        function_name: Name of the tool the model asked to call.
        function_args: Decoded JSON arguments for the call. This may be an empty
            dict when the model sent no arguments.
        top_k: Number of results to request from the database.

    Returns:
        The result of ``query_images`` for a ``query_database`` call. Returns an
        empty dict for unknown tools, or when the call carries no usable query.
    """
    if function_name != "query_database":
        return {}

    # Both fields are declared "required", but the model can still omit them.
    # Fall back to item_name so a partial call still searches instead of raising KeyError.
    query = function_args.get("query") or function_args.get("item_name")
    if not query:
        return {}

    print(f"Querying database for: {query}")
    return query_images(query, top_k=top_k)

In [3]:
def run_conversation(user_query, introduction, top_k, max_depth, session_messages=None, model="gpt-4"):
    """Run one user turn through the assistant, executing any database tool calls.

    The model is called repeatedly. A round that ends in ``tool_calls`` runs the
    requested tools, feeds their results back into the history, and counts
    toward ``max_depth``. A ``stop`` finish reason, or any other unhandled
    reason, ends the turn.

    Args:
        user_query: The user's message for this turn.
        introduction: System prompt. It is only inserted when the session is empty.
        top_k: Number of results requested per database query.
        max_depth: Maximum number of tool-call rounds. If the limit is reached
            right after a tool round, the model is not called again. The turn
            can therefore end without a reply that uses the last results.
        session_messages: Existing message history to continue; it is mutated
            in place. A new list is created when omitted.
        model: Chat Completions model name. Defaults to "gpt-4".

    Returns:
        tuple: ``(all_responses, session_messages, all_image_paths_dict)``.
        ``all_responses`` holds the assistant text replies, ``session_messages``
        is the updated history, and ``all_image_paths_dict`` maps each item name
        returned by ``query_database`` to its accumulated list of paths.
    """
    if session_messages is None:
        session_messages = []

    # Seed a brand-new session with the system prompt; an existing session keeps its own.
    if not session_messages:
        session_messages.append({"role": "system", "content": introduction})

    tools = define_tools()

    all_image_paths_dict = {}
    all_responses = []

    session_messages.append({"role": "user", "content": user_query})
    depth = 0

    while depth < max_depth:
        response = client.chat.completions.create(
            model=model,
            messages=session_messages,
            tools=tools,
            tool_choice="auto"
        )

        choice = response.choices[0]
        response_content = choice.message.content

        print(response)

        if response_content:
            all_responses.append(response_content)
            session_messages.append({"role": "assistant", "content": response_content})

        finish_reason = choice.finish_reason
        if finish_reason == "stop":
            break
        if finish_reason != "tool_calls":
            print(f'Unhandled finish reason: {finish_reason}')
            break

        for tool_call in choice.message.tool_calls or []:
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
            print('functions')
            print(function_name)
            print(function_args)
            function_response = handle_function_call(function_name, function_args, top_k)

            # Record each tool result exactly once (it was previously appended twice).
            # An empty result is still recorded so the model learns the search found nothing.
            # NOTE(review): OpenAI's tool protocol normally expects the assistant message
            # carrying tool_calls followed by role="tool" messages with tool_call_id. Results
            # are kept as named "system" messages because generate_markdown_from_session
            # reads them in that form.
            session_messages.append({"role": "system", "name": function_name, "content": json.dumps(function_response)})

            if function_response and function_name == "query_database":
                for item_name, paths in function_response.items():
                    all_image_paths_dict.setdefault(item_name, []).extend(paths)

        depth += 1

    return all_responses, session_messages, all_image_paths_dict


In [None]:
# Example usage: ask for a full festival outfit and allow the assistant up to
# five rounds of database tool calls.
top_k = 5
user_query = "I'm a woman going to a festival this weekend, could you please suggest a full outfit for me??"
# System prompt. The previous wording ("if the if you suggest", "If the user asked")
# was garbled; the intent is unchanged: one database call per suggested item.
introduction = """You are a shopping assistant giving suggestions to a customer. 
You can only make one call to the database per item; if you suggest three different items of clothing you must make three separate calls.
If the user asks for suggestions you must query the database with options you think would work. 
Be specific and then tell the user why you made that decision.
"""

responses, session_messages, all_image_paths_dict = run_conversation(user_query, introduction, top_k, max_depth=5)

ChatCompletion(id='chatcmpl-9cDFu3zcMfExohB39LItzizl6iTQi', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_gdAXitgWeRFbIx9kAOZcCdPG', function=Function(arguments='{\n"item_name": "boho dress",\n"query": "colourful boho maxi dress for a woman for a festival"\n}', name='query_database'), type='function')]))], created=1718894446, model='gpt-4-0613', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=34, prompt_tokens=231, total_tokens=265))
functions
query_database
{'item_name': 'boho dress', 'query': 'colourful boho maxi dress for a woman for a festival'}
Querying database for: colourful boho maxi dress for a woman for a festival


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


ChatCompletion(id='chatcmpl-9cDFyow6HD2PN2un03pSozwILYWlp', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content="Sure, for a festival, I recommend something vibrant and easy-going. I found this lovely colourful boho maxi dress that would look great. Festivals are all about having fun and letting loose, so this free-flowing and eye-catching dress would be a perfect fit. Here's the image for your reference: [![Dress](../data/raw/images\\\\59897.jpg)](../data/raw/images\\\\59897.jpg)\n\nNow let's find a matching pair of boots and a bag.", role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_zsayLNsskkMydRLXJa3kcLHz', function=Function(arguments='{\n"item_name": "comfortable boots for women",\n"query": "comfortable yet stylish boots suitable for a festival for women"\n}', name='query_database'), type='function')]))], created=1718894450, model='gpt-4-0613', object='chat.completion', system_fingerprint=

In [None]:
# Dump the raw message history returned by run_conversation, for debugging.
print(session_messages)

In [None]:
from IPython.display import display, Markdown

def generate_markdown_from_session(session_messages):
    """Render a chat session as Markdown for display in the notebook.

    Each message role is rendered as follows:

    - ``user`` messages become a bold "User Query" line.
    - ``assistant`` messages are emitted verbatim; messages without text are skipped.
    - ``system`` messages whose name contains ``query_database`` hold a JSON
      object mapping item names to image paths. Each item is rendered as a
      heading followed by one image per path.
    - All other messages (e.g. the system prompt) are omitted.

    Args:
        session_messages: Message dicts with ``role``, ``content`` and an
            optional ``name`` key.

    Returns:
        str: The concatenated Markdown; an empty string for an empty session.
    """
    parts = []
    for message in session_messages:
        role = message.get('role')
        content = message.get('content')

        if role == 'user':
            parts.append(f"**User Query:** {content}\n\n")
        elif role == 'assistant':
            # Turns that only carried tool calls have no text; don't render "None".
            if content:
                parts.append(f"\n{content}\n\n")
        elif role == 'system' and 'query_database' in (message.get('name') or ''):
            results = json.loads(content) if content else {}
            # Only a mapping of item -> paths can be rendered; ignore anything else
            # rather than aborting the whole render.
            if isinstance(results, dict):
                for item_name, paths in results.items():
                    parts.append(f"### {item_name}\n")
                    parts.extend(f"![Option {i}]({path})\n" for i, path in enumerate(paths, start=1))

    return "".join(parts)
# Render the session (user turns, assistant replies, retrieved image options) inline
# in the notebook; markdown_content is kept as a notebook variable for reuse.
markdown_content = generate_markdown_from_session(session_messages)
display(Markdown(markdown_content))
