In [1]:
import json

from transformers import pipeline

from deb_ai_tor.environments import DebateSandwichEnv

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Names of the participating agents.
agents = [f"agent_{idx}" for idx in range(2)]

# Create the environment and start an episode with one question, three
# candidate answers and three rounds.
env = DebateSandwichEnv(agents=agents)
obs = env.reset(
    options=dict(
        question="What is a debate sandwich?",
        answers=[
            "Arguments with answers in the middle",
            "A sandwich eaten while debating",
            "Answers with arguments in the middle",
        ],
        num_rounds=3,
    )
)

# Pretty-print the initial observation returned by reset().
print(json.dumps(obs, indent=4))

class LocalLLMAgent:
    """Named agent that asks a text-generation callable for its next action.

    The agent keeps a running transcript (``self.context``) of every
    observation it has seen and every action it has produced, and includes
    that transcript in each new prompt.

    Args:
        agent_name: Identifier inserted at the top of every prompt.
        llm_pipeline: Callable invoked as ``llm_pipeline(prompt, **kwargs)``
            that returns a list whose first item is a dict with a
            ``'generated_text'`` entry.
        max_new_tokens: Generation budget forwarded to ``llm_pipeline`` on
            every call. Defaults to 32.
    """

    def __init__(self, agent_name, llm_pipeline, max_new_tokens=32):
        self.agent_name = agent_name
        self.llm = llm_pipeline
        self.max_new_tokens = max_new_tokens
        self.context = ""

    def get_action(self, obs):
        """Generate an action for ``obs`` and append the exchange to the context.

        Args:
            obs: JSON-serialisable observation for this agent.

        Returns:
            The generated text with surrounding whitespace removed and
            without the prompt that produced it.
        """
        # Serialise once; the same text is used in the prompt and the transcript.
        obs_text = json.dumps(obs, indent=2)
        prompt = (
            f"Agent: {self.agent_name}\n"
            f"Previous context:\n{self.context}\n"
            f"Observation: {obs_text}\n"
            f"What is your action?"
        )
        # Text-generation pipelines echo the prompt in 'generated_text' by
        # default; ask for the completion only so the action is not the
        # whole prompt and the context does not re-embed itself every turn.
        outputs = self.llm(
            prompt,
            max_new_tokens=self.max_new_tokens,
            return_full_text=False,
        )
        generated = outputs[0]['generated_text']
        # Defensive: some callables ignore return_full_text and still echo.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]
        action = generated.strip()
        self.context += f"\nObservation: {obs_text}\nAction: {action}"
        return action

# One text-generation pipeline, loaded once and handed to every agent.
llm_pipeline = pipeline(
    'text-generation',
    model='mistralai/Mistral-7B-Instruct-v0.2',
    device_map='auto',
)

# Map each agent name to a LocalLLMAgent backed by the shared pipeline.
agent_wrappers = dict(
    (agent_name, LocalLLMAgent(agent_name, llm_pipeline))
    for agent_name in agents
)

{
    "agent_0": {
        "message": "Which of the following answers do you think is best?\n",
        "info": "Question: What is a debate sandwich?\nAnswers: {'A': 'Arguments with answers in the middle', 'B': 'A sandwich eaten while debating', 'C': 'Answers with arguments in the middle'}\n"
    },
    "agent_1": {
        "message": "Which of the following answers do you think is best?\n",
        "info": "Question: What is a debate sandwich?\nAnswers: {'A': 'Arguments with answers in the middle', 'B': 'A sandwich eaten while debating', 'C': 'Answers with arguments in the middle'}\n"
    }
}


Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Reuse the pipeline already loaded as `llm_pipeline` in the previous cell
# instead of loading the same 7B model a second time (same model name and
# device_map), which would double memory use and load time.
llm = llm_pipeline

def get_llm_action(agent, obs):
    """Ask the module-level ``llm`` for an action for one agent.

    Args:
        agent: Agent name; used in the prompt and as the key into ``obs``.
        obs: Mapping from agent name to that agent's JSON-serialisable
            observation.

    Returns:
        The generated text with surrounding whitespace removed and without
        the prompt that produced it.
    """
    prompt = f"Agent: {agent}\nObservation: {json.dumps(obs[agent], indent=2)}\nWhat is your action?"
    # Text-generation pipelines echo the prompt in 'generated_text' by
    # default; request the completion only so the action is not the prompt.
    generated = llm(prompt, max_new_tokens=32, return_full_text=False)[0]['generated_text']
    # Defensive: drop a leading prompt echo if the callable ignored the flag.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    # You may need to parse/clean the response to fit your env's expected action format
    return generated.strip()

# Play one round: every agent wrapper produces an action from its own
# observation, then the environment is advanced with all actions at once.
actions = dict(
    (agent_name, agent_wrappers[agent_name].get_action(obs[agent_name]))
    for agent_name in agents
)
obs, rewards, terminations, truncations, infos = env.step(actions)

# Pretty-print the observations returned by the step.
print(json.dumps(obs, indent=4))

{
    "agent_0": {
        "message": "Debate your choice of answer\n",
        "info": {
            "agent_0": "A",
            "agent_1": "C"
        }
    },
    "agent_1": {
        "message": "Debate your choice of answer\n",
        "info": {
            "agent_0": "A",
            "agent_1": "C"
        }
    }
}
