In [None]:
# imports
import os
import sqlite3
from contextlib import closing

import gradio as gr
import ollama
import requests
from dotenv import load_dotenv
from openai import OpenAI

In [None]:
# constants
MODEL_LLAMA = 'llama3.2'
MODEL_GPT = 'gpt-4o-mini'
DB = "houses.db"
API_BASE_URL = "https://wizard-world-api.herokuapp.com/Houses"
OLLAMA_LOCAL = "Ollama (local)"

# Wizard World API house ids -> lower-case house names (the names stored in the local `houses` table).
HOUSES_MASTER = {
    "0367baf3-1cb6-4baf-bede-48e17e1cd005": "gryffindor",
    "805fd37a-65ae-4fe5-b336-d767b8b7c73a": "ravenclaw",
    "85af6295-fd01-4170-a10b-963dd51dce14": "hufflepuff",
    "a9704c47-f92e-40a4-8771-ed1899c9b9c1": "slytherin"
}

# The last instruction asks the model to name exactly one house after the third answer:
# the chat handler searches that reply for a house name to attach the house details.
SYSTEM_PROMPT = f"""
You're "the Sorting Hat" from the Harry Potter novels. First of all, introduce yourself in a simple sentence.
Then you only have three questions to ask to find out which house I'll be assigned to.
Ask them one at a time, and wait for my answer to each question before asking the next one.
After my third answer, announce my house by naming exactly one of: {', '.join(name.title() for name in HOUSES_MASTER.values())}.
"""

# Seed conversation used to generate the Sorting Hat's opening greeting.
INITIAL_MESSAGES = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": "Hello, I'm ready to be sorted!"}
]

In [None]:
# Model Initialization
load_dotenv(override=True)

openai_api_key = os.getenv('OPENAI_API_KEY')
if openai_api_key:
    openai = OpenAI()
else:
    # OpenAI() raises when no API key is configured; skip it so the notebook still runs
    # for users who only want the local Ollama model (GPT calls will fail until a key is set).
    print("OpenAI API Key not set")
    openai = None

In [None]:
# Database initialization
def init_database():
    """Create the `houses` table in DB if needed and upsert every entry of HOUSES_MASTER.

    Re-running is safe: the table uses CREATE TABLE IF NOT EXISTS and rows are written
    with INSERT OR REPLACE, so re-executing the cell does not duplicate data.
    """
    # A sqlite3.Connection used as a context manager only commits/rolls back; it does not
    # close the connection, so closing() is needed to release the database file.
    with closing(sqlite3.connect(DB)) as conn:
        with conn:  # commit on success, roll back on error
            conn.execute('CREATE TABLE IF NOT EXISTS houses (id TEXT PRIMARY KEY, name TEXT)')
            conn.executemany(
                'INSERT OR REPLACE INTO houses (id, name) VALUES (?, ?)',
                HOUSES_MASTER.items(),
            )

init_database()

In [None]:
# Global variables
# Dropdown label of the chosen model; set by on_model_change, read by the chat handlers.
selected_model = None

# Database functions
def get_house_id(name):
    """Look up a house's API id in the local database by name (case-insensitive).

    Args:
        name: House name such as "Gryffindor"; it is lower-cased before the lookup.

    Returns:
        The id stored in the `houses` table, or None when no row matches.
    """
    # closing() actually releases the connection; the connection's own context manager does not.
    with closing(sqlite3.connect(DB)) as conn:
        row = conn.execute('SELECT id FROM houses WHERE name = ?', (name.lower(),)).fetchone()
    return row[0] if row else None

def get_house_traits(house_id, timeout=10):
    """Fetch the trait names of a house from the Wizard World API.

    Best-effort: network errors, non-200 responses and unexpected payloads are logged
    and produce an empty list, so callers can simply skip the trait details.

    Args:
        house_id: House id appended to API_BASE_URL.
        timeout: Seconds to wait for the API (default 10). Without a timeout, requests
            can wait indefinitely and block the Gradio handler.

    Returns:
        A list of trait names, possibly empty.
    """
    try:
        response = requests.get(f"{API_BASE_URL}/{house_id}", timeout=timeout)
        if response.status_code != 200:
            print(f"API Error: HTTP {response.status_code} for house {house_id}")
            return []
        house_data = response.json()
        # Skip entries without a name instead of letting one bad entry discard every trait.
        return [trait['name'] for trait in (house_data.get('traits') or []) if trait.get('name')]
    except Exception as e:
        print(f"API Error: {e}")
    return []

# AI model functions
def call_ai_model(messages, model_type):
    """Send a chat-message list to the selected backend and return the reply text.

    `model_type` equal to OLLAMA_LOCAL routes to the local Ollama model (MODEL_LLAMA);
    any other value routes to OpenAI (MODEL_GPT).
    """
    if model_type != OLLAMA_LOCAL:
        completion = openai.chat.completions.create(model=MODEL_GPT, messages=messages)
        return completion.choices[0].message.content

    reply = ollama.chat(model=MODEL_LLAMA, messages=messages)
    return reply['message']['content']

def on_model_change(value):
    """Record the chosen model and open the chat with the Sorting Hat's greeting.

    Args:
        value: Dropdown label (e.g. "GPT" or OLLAMA_LOCAL), or a falsy value when the
            selection is cleared.

    Returns:
        A (chat-container visibility update, chatbot history) pair. With a model selected,
        the container is shown and the history holds one [None, greeting] turn. Otherwise
        it is hidden and the history is empty.
    """
    global selected_model
    selected_model = value

    if not value:
        return gr.update(visible=False), []

    if value == OLLAMA_LOCAL:
        # Pull the local model so the first chat call does not fail on a missing model.
        try:
            ollama.pull(MODEL_LLAMA)
        except Exception as e:
            print(f"Error connecting to Ollama: {e}")

    try:
        initial_message = call_ai_model(INITIAL_MESSAGES, value)
    except Exception as e:
        # Keep the UI usable with a canned greeting, but surface why the model call failed.
        print(f"Error generating greeting: {e}")
        initial_message = "Hi there! I'm the Sorting Hat. Let's find out which house you belong to!"
    return gr.update(visible=True), [[None, initial_message]]

def build_conversation_history(history, message):
    """Turn Gradio tuple-format history plus the new user message into chat messages.

    The list starts with SYSTEM_PROMPT as the system message. Each history entry's user
    (index 0) and assistant (index 1) texts are added in order, and falsy parts (None or
    "") are skipped. `message` is added last as a user turn.
    """
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history:
        for role, text in (("user", turn[0]), ("assistant", turn[1])):
            if text:
                conversation.append({"role": role, "content": text})
    conversation.append({"role": "user", "content": message})
    return conversation

def _find_announced_house(response, house_names):
    """Return the name from `house_names` whose last mention comes latest in `response`, or None."""
    text = response.lower()
    best_house, best_pos = None, -1
    for house in house_names:
        pos = text.rfind(house.lower())
        if pos > best_pos:
            best_house, best_pos = house, pos
    return best_house


def add_house_details(response, house_names, user_history):
    """Append a trait-based justification for the house announced in `response`.

    The announced house is the one mentioned last in the reply, so text such as
    "Not Gryffindor... RAVENCLAW!" resolves to Ravenclaw rather than to whichever
    house comes first in `house_names`.

    Args:
        response: The Sorting Hat's reply text.
        house_names: Candidate house names (lower-case names as stored in the DB).
        user_history: Tuple-format history ([user, assistant] pairs); the user texts are
            quoted in the justification prompt.

    Returns:
        `response` with a justification appended. If no house is found or its id or
        traits are unavailable, `response` is returned unchanged. If the justification
        call fails, a plain list of the house traits is appended instead.
    """
    if not response:
        return response
    house = _find_announced_house(response, house_names)
    if house is None:
        return response
    house_id = get_house_id(house)
    if not house_id:
        return response
    traits = get_house_traits(house_id)
    if not traits:
        return response

    user_answers = [h[0] for h in user_history if h[0]]
    justification_prompt = f"""Analyze why this person belongs in {house.title()} based on their answers and traits.
User answers: {' | '.join(user_answers)}
{house.title()} traits: {', '.join(traits)}
Write a brief justification connecting specific answers to traits."""

    try:
        justification = call_ai_model([
            {"role": "system", "content": "You analyze personality traits and connect answers to character qualities."},
            {"role": "user", "content": justification_prompt}
        ], selected_model)
    except Exception as e:
        print(f"Justification error: {e}")
        return f"{response}\n\nYou possess the traits of {house.title()}: {', '.join(traits)}."
    return f"{response}\n\n{justification}"

def sorting_hat_chat(message, history):
    """Handle one user turn and, on the third answer, append the house details.

    Args:
        message: Text the user just submitted.
        history: Gradio tuple-format history (list of [user, assistant] pairs). It is
            appended to in place.

    Returns:
        (history, "") — the updated history, plus an empty string that clears the textbox.
    """
    # Blank submissions would reach the model and be counted as an answer, which would
    # shift the third-answer check below; ignore them instead.
    if not (message and message.strip()):
        return history, ""
    try:
        messages = build_conversation_history(history, message)
        response = call_ai_model(messages, selected_model)

        # Entries without user text (the opening greeting) are not answers.
        user_responses = sum(1 for h in history if h[0]) + 1
        if user_responses == 3:
            full_history = history + [[message, response]]
            response = add_house_details(response, list(HOUSES_MASTER.values()), full_history)

        history.append([message, response])
    except Exception as e:
        print(f"Error: {e}")
        # Mid-conversation, a reply that re-introduces the hat would read as a reset; ask to retry.
        history.append([message, "Hmm... my thoughts got tangled for a moment. Could you say that again?"])
    return history, ""

In [None]:
# UI definition with Gradio
with gr.Blocks() as ui:
    gr.Markdown("# Sorting Hat LLM Experience")

    with gr.Row():
        # Use OLLAMA_LOCAL so this label always matches the value the handlers compare against.
        model_selector = gr.Dropdown(["GPT", OLLAMA_LOCAL], label="First Select model", value=None)

    with gr.Row(visible=False) as chat_container:
        with gr.Column():
            gr.Markdown("### Sorting Hat Chat")
            gr.Markdown("Chat with the Sorting Hat to discover your Hogwarts house!")
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Your message", placeholder="Type your message here...")
            clear = gr.Button("Clear")

    # Connect events
    model_selector.change(fn=on_model_change, inputs=model_selector, outputs=[chat_container, chatbot])
    msg.submit(sorting_hat_chat, [msg, chatbot], [chatbot, msg])
    # Clear restarts the sorting with a fresh greeting rather than emptying the chat.
    # sorting_hat_chat counts user answers, and a history without the opening greeting
    # would shift its third-answer check by one turn.
    clear.click(
        lambda model: (*on_model_change(model), ""),
        inputs=model_selector,
        outputs=[chat_container, chatbot, msg],
    )

ui.launch()