In [None]:
import os
import json
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def load_world(filename):
    """Load and return the parsed contents of a world JSON file.

    Args:
        filename: Path to the JSON file.

    Returns:
        The decoded JSON value (for the world files used here, a dict).
    """
    # JSON is UTF-8 by spec; without an explicit encoding, open() uses the
    # platform default (e.g. cp1252 on Windows) and mangles non-ASCII text.
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_game_state(inventory=None, world_path='../shared_data/Kyropeia.json'):
    """Build the game-state dict for the Eldrida / Luminaria starting scenario.

    Args:
        inventory: Starting inventory mapping (item name -> count). It is stored
            by reference, so in-place updates to it are visible through
            ``game_state["inventory"]``. Defaults to a fresh empty dict.
        world_path: Path to the world JSON file (relative to the current
            working directory).

    Returns:
        dict with keys "world", "kingdom", "town" and "character" (the
        ``description`` fields read from the world file), "start" (the world's
        ``start`` entry) and "inventory".
    """
    # Create a fresh dict per call: a `{}` default is evaluated once and shared,
    # so items added in one game would leak into every later default game.
    if inventory is None:
        inventory = {}

    world = load_world(world_path)
    kingdom = world['kingdoms']['Eldrida']
    town = kingdom['towns']["Luminaria"]
    character = town['npcs']['Elwyn Stormbringer']

    return {
        "world": world['description'],
        "kingdom": kingdom['description'],
        "town": town['description'],
        "character": character['description'],
        "start": world['start'],
        "inventory": inventory,
    }

class LocalLLM:
    """Thin wrapper around a Hugging Face causal LM used as the game master.

    Tokenizers and models are cached per ``model_path`` at class level.
    Constructing ``LocalLLM()`` repeatedly (the game code does so on every
    call) therefore reuses already-loaded weights instead of reloading them.
    """

    DEFAULT_SYSTEM_PROMPT = "You are an AI Game master creating an interactive story."

    # model_path -> (tokenizer, model); shared by all instances.
    _loaded = {}

    def __init__(self, model_path="meta-llama/Meta-Llama-3-8B-Instruct"):
        """Load (or reuse) the tokenizer and fp16 model for ``model_path``."""
        cached = LocalLLM._loaded.get(model_path)
        if cached is None:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            cached = (tokenizer, model)
            LocalLLM._loaded[model_path] = cached
        self.tokenizer, self.model = cached

    @classmethod
    def _with_system_prompt(cls, messages):
        """Return a new message list, prepending the default system prompt
        unless the caller already supplied a leading system message."""
        messages = list(messages)
        if messages and messages[0].get('role') == 'system':
            return messages
        return [{"role": "system", "content": cls.DEFAULT_SYSTEM_PROMPT}] + messages

    def generate(self, messages, max_new_tokens=200):
        """Generate the assistant's reply to a chat.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts in
                conversation order. A leading "system" message replaces the
                default game-master system prompt.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The newly generated text, with special tokens removed and
            surrounding whitespace stripped.
        """
        chat = self._with_system_prompt(messages)

        # Let the tokenizer render its own chat format. This gives the correct
        # header spacing, a single BOS, and the trailing assistant header that
        # tells the model it is its turn to speak (the hand-built prompt
        # omitted that header).
        prompt = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        # The rendered template already contains BOS; don't add a second one.
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)

        # Llama 3 instruct models end a turn with <|eot_id|>, which is not
        # always configured as EOS; stop on either so replies don't run on.
        stop_ids = [self.tokenizer.eos_token_id]
        eot_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if (eot_id is not None and eot_id != self.tokenizer.unk_token_id
                and eot_id not in stop_ids):
            stop_ids.append(eot_id)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            eos_token_id=stop_ids,
            # Llama tokenizers define no pad token; reuse EOS to avoid warnings.
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode only the tokens produced after the prompt.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Only the most recent chat messages are sent back to the model, so long
# sessions stay within its context window.
_MAX_HISTORY_MESSAGES = 20


def _history_to_messages(history):
    """Convert Gradio chat history into ``{"role", "content"}`` messages.

    Accepts both history shapes: ``[user, bot]`` pairs and
    ``{"role": ..., "content": ...}`` dicts. Non-text or empty turns are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            pairs = [(turn.get("role"), turn.get("content"))]
        elif isinstance(turn, (list, tuple)) and len(turn) >= 2:
            pairs = [("user", turn[0]), ("assistant", turn[1])]
        else:
            continue
        for role, content in pairs:
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
    return messages


def run_action(message, history, game_state):
    """Produce the game master's reply to the player's latest message.

    Args:
        message: The player's input text. "start game" (case/whitespace
            insensitive) returns the world's opening text without calling
            the model.
        history: Prior chat turns from Gradio; the most recent ones are sent
            to the model so the story stays continuous.
        game_state: Dict from ``get_game_state`` (world/kingdom/town/character
            descriptions, start text, inventory).

    Returns:
        The reply text.
    """
    if message.strip().lower() == 'start game':
        return game_state['start']

    world_info = f"""
World: {game_state['world']}
Kingdom: {game_state['kingdom']}
Town: {game_state['town']}
Your Character: {game_state['character']}
Current Inventory: {json.dumps(game_state['inventory'])}"""

    # Previously the history was ignored, so the model forgot every earlier turn.
    messages = _history_to_messages(history)[-_MAX_HISTORY_MESSAGES:]
    messages.append({"role": "user", "content": world_info + "\n" + message})

    llm = LocalLLM()
    return llm.generate(messages)

def _parse_item_updates(response):
    """Extract the ``itemUpdates`` list from a model reply; ``[]`` if absent or invalid."""
    text = response.strip()
    # Models often wrap JSON in prose or ```json fences; decode from the first '{'
    # and ignore anything after the object.
    start = text.find('{')
    if start == -1:
        return []
    try:
        result, _ = json.JSONDecoder().raw_decode(text[start:])
    except json.JSONDecodeError:
        return []
    if not isinstance(result, dict):
        return []
    updates = result.get('itemUpdates', [])
    return updates if isinstance(updates, list) else []


def detect_inventory_changes(game_state, output):
    """Ask the model which items were gained or lost in the latest story text.

    Args:
        game_state: Dict containing an "inventory" mapping (item -> count).
        output: The story text just produced.

    Returns:
        The list under the reply's ``itemUpdates`` key (entries are expected to
        look like ``{"name": str, "change_amount": int}`` but are not validated
        here), or ``[]`` when the reply contains no parseable JSON object.
    """
    messages = [
        {"role": "system", "content": (
            "Detect inventory changes from the story. Respond with only a JSON "
            'object of the form {"itemUpdates": [{"name": "<item name>", '
            '"change_amount": <integer>}]}. Use a positive change_amount for '
            "items gained and a negative one for items lost or used up. If "
            'nothing changed, respond with {"itemUpdates": []}.'
        )},
        {"role": "user", "content": (
            f'Current Inventory: {json.dumps(game_state["inventory"])}\n'
            f'Recent Story: {output}'
        )},
    ]

    llm = LocalLLM()
    response = llm.generate(messages)
    return _parse_item_updates(response)

def update_inventory(inventory, item_updates):
    """Apply reported item changes to ``inventory`` in place.

    Args:
        inventory: Mapping of item name -> count; mutated in place.
        item_updates: Iterable of ``{"name": str, "change_amount": number}``
            dicts, typically parsed from LLM output. Malformed entries (not a
            dict, non-string name, missing or non-numeric amount) are skipped.

    Returns:
        One ``"\\nInventory: <name> <+/-amount>"`` line per applied change,
        concatenated ("" when nothing changed).

    Gains add to (or create) the item. Losses apply only to items already held.
    An item whose count drops to zero or below is removed.
    """
    lines = []
    for update in item_updates or []:
        if not isinstance(update, dict):
            continue
        name = update.get('name')
        change_amount = update.get('change_amount')
        # bool is a subclass of int; reject it along with strings and None.
        if (not isinstance(name, str) or isinstance(change_amount, bool)
                or not isinstance(change_amount, (int, float))):
            continue

        if change_amount > 0:
            inventory[name] = inventory.get(name, 0) + change_amount
            lines.append(f'\nInventory: {name} +{change_amount}')
        elif change_amount < 0 and name in inventory:
            inventory[name] += change_amount
            lines.append(f'\nInventory: {name} {change_amount}')

        # Drop used-up items; a count of exactly 0 must not linger either.
        if name in inventory and inventory[name] <= 0:
            del inventory[name]

    return ''.join(lines)

def start_game(main_loop, share=False):
    """Wrap ``main_loop`` in a Gradio chat UI and launch it.

    Args:
        main_loop: Chat callback taking ``(message, history)`` and returning
            the reply text.
        share: Whether Gradio should also create a public share link.

    The server binds to all network interfaces (``server_name="0.0.0.0"``).
    """
    chat_window = gr.Chatbot(height=250, placeholder="Type 'start game' to begin")
    input_box = gr.Textbox(
        placeholder="What do you do next?",
        container=False,
        scale=7,
    )

    interface = gr.ChatInterface(
        main_loop,
        chatbot=chat_window,
        textbox=input_box,
        title="Local Llama AI RPG",
        theme="soft",
        examples=["Look around", "Continue the story"],
        cache_examples=False,
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
    )
    interface.launch(share=share, server_name="0.0.0.0")

_INVENTORY_PREFIX = 'Inventory: '


def _assistant_replies(history):
    """Yield the text of each assistant reply in a Gradio chat history.

    Accepts both ``[user, bot]`` pair entries and ``{"role", "content"}`` dicts.
    """
    for turn in history or []:
        if isinstance(turn, dict):
            reply = turn.get('content') if turn.get('role') == 'assistant' else None
        elif isinstance(turn, (list, tuple)) and len(turn) >= 2:
            reply = turn[1]
        else:
            reply = None
        if isinstance(reply, str):
            yield reply


def _replayed_item_updates(history):
    """Recover item updates from "Inventory: <name> <+/-amount>" lines in past replies."""
    updates = []
    for reply in _assistant_replies(history):
        for line in reply.splitlines():
            if not line.startswith(_INVENTORY_PREFIX):
                continue
            name, _, amount_text = line[len(_INVENTORY_PREFIX):].rpartition(' ')
            try:
                amount = int(amount_text)
            except ValueError:
                try:
                    amount = float(amount_text)
                except ValueError:
                    continue
            if name:
                updates.append({'name': name, 'change_amount': amount})
    return updates


def main_loop(message, history):
    """Gradio chat callback: advance the story one turn and report item changes.

    Args:
        message: The player's input.
        history: Prior chat turns supplied by Gradio.

    Returns:
        The story text, followed by any "Inventory: ..." change lines.
    """
    game_state = get_game_state(inventory={
        "cloth pants": 1,
        "cloth shirt": 1,
        "goggles": 1,
        "leather bound journal": 1,
        "gold": 5
    })

    # The inventory above is the same starting literal every turn, so re-apply
    # the changes already reported in earlier replies. Otherwise each gain or
    # loss is forgotten on the next turn. Deriving it from the chat history
    # keeps it per-conversation and consistent when the history is rewritten
    # (undo/retry/clear).
    # NOTE(review): assumes Gradio passes earlier reply text back unchanged.
    update_inventory(game_state['inventory'], _replayed_item_updates(history))

    output = run_action(message, history, game_state)

    item_updates = detect_inventory_changes(game_state, output)
    update_msg = update_inventory(
        game_state['inventory'],
        item_updates
    )
    output += update_msg

    return output

# Launch the UI with a public share link when this file/notebook is run
# directly; importing it as a module does not start a server.
if __name__ == "__main__":
    start_game(main_loop, share=True)

Running on local URL:  http://0.0.0.0:7860
Running on public URL: https://9d492b137915d5a655.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/gradio/route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/gradio/blocks.py", line 1935, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/gradio/blocks.py", line 1518, in call_function
    prediction = await fn(*processed_input)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Version