# Customize Speaker Selection

In GroupChat, we can also customize the speaker selection by passing in a function to `speaker_selection_method`:
```python
from typing import Literal, Union

import autogen
from autogen import Agent, GroupChat


def custom_speaker_selection_func(
    last_speaker: Agent,
    groupchat: GroupChat
) -> Union[Agent, Literal['auto', 'manual', 'random', 'round_robin'], None]:
    """Define a customized speaker selection function.
    A recommended way is to define a transition for each speaker in the groupchat.

    Parameters:
        - last_speaker: Agent
            The last speaker in the group chat.
        - groupchat: GroupChat
            The GroupChat object
    Return:
        Return one of the following:
        1. an `Agent` instance, which must be one of the agents in the group chat.
        2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use.
        3. None, which indicates the chat should be terminated.
    """
    pass


groupchat = autogen.GroupChat(
    speaker_selection_method=custom_speaker_selection_func,
    # ... other arguments (agents, messages, max_round, ...)
)
```
The function receives the last speaker and the `GroupChat` object.
The most commonly used attributes of the group chat are `groupchat.messages` and `groupchat.agents`, which hold the message history and the agents in the group chat, respectively. You can also access other attributes, such as `groupchat.allowed_speaker_transitions_dict` for the pre-defined allowed speaker transitions.
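To make the contract concrete, here is a minimal, dependency-free sketch of the recommended approach of defining a transition for each speaker. `FakeAgent` and `FakeGroupChat` are hypothetical stand-ins for autogen's `Agent` and `GroupChat` that expose only `name`, `agents` and `messages`; the transition table and agent names are illustrative assumptions. The function returns an agent, a default method name such as `'auto'`, or `None` to end the chat, matching the three return options in the docstring.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union


@dataclass
class FakeAgent:
    """Hypothetical stand-in for autogen.Agent: only the name is needed here."""
    name: str


@dataclass
class FakeGroupChat:
    """Hypothetical stand-in for autogen.GroupChat with the attributes used below."""
    agents: List[FakeAgent]
    messages: List[Dict] = field(default_factory=list)

    def agent_by_name(self, name: str) -> FakeAgent:
        return next(a for a in self.agents if a.name == name)


# Illustrative transition table: last speaker's name -> next speaker's name.
# None ends the chat; "auto" defers to the built-in selection method.
TRANSITIONS: Dict[str, Optional[str]] = {
    "Init": "Coder",
    "Coder": "Executor",
    "Executor": "auto",
    "Scientist": None,
}


def custom_speaker_selection_func(
    last_speaker: FakeAgent, groupchat: FakeGroupChat
) -> Union[FakeAgent, str, None]:
    # Unknown speakers are not in the table, so .get() returns None and the chat ends
    next_name = TRANSITIONS.get(last_speaker.name)
    if next_name is None or next_name in ("auto", "manual", "random", "round_robin"):
        return next_name  # end the chat, or fall back to a default method
    return groupchat.agent_by_name(next_name)


chat = FakeGroupChat(agents=[FakeAgent(n) for n in ("Init", "Coder", "Executor", "Scientist")])
print(custom_speaker_selection_func(chat.agents[0], chat).name)  # Coder
print(custom_speaker_selection_func(chat.agents[2], chat))       # auto
print(custom_speaker_selection_func(chat.agents[3], chat))       # None
```

Keeping the transitions in a table rather than an `if`/`elif` chain makes it easy to review the whole workflow at a glance.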

Here is a simple example that builds a research workflow with customized speaker selection.


```{=mdx}
![group_chat](../../../blog/2024-02-29-StateFlow/img/sf_example_1.png)
```

We define the following agents:

- Initializer: Start the workflow by sending a task.
- Coder: Retrieve papers from the internet by writing code.
- Executor: Execute the code.
- Scientist: Read the papers and write a summary.

The figure above defines a simple research workflow with four states: Init, Retrieve, Research, and End. Each state calls different agents to perform its tasks.

- Init: We use the initializer to start the workflow.
- Retrieve: We first call the coder to write code and then call the executor to execute it.
- Research: We call the scientist to read the papers and write a summary.
- End: We end the workflow.
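The four states above can be sketched as a transition function keyed on agent names. This is a simplified, framework-free illustration rather than the notebook's own code: the agent names follow the list above, and the assumption that the executor's reply starts with `exitcode: 0` on success mirrors AutoGen's code-execution reply format, which is worth verifying against your installed version.

```python
from typing import Optional


def next_speaker(last_speaker: str, last_message: str) -> Optional[str]:
    """Return the name of the next agent, or None to end the workflow."""
    if last_speaker == "Initializer":
        return "Coder"            # Init -> Retrieve: start by writing retrieval code
    if last_speaker == "Coder":
        return "Executor"         # Retrieve: run the code the coder wrote
    if last_speaker == "Executor":
        if last_message.startswith("exitcode: 0"):
            return "Scientist"    # Retrieve -> Research: the code ran successfully
        return "Coder"            # stay in Retrieve: ask the coder to fix the error
    if last_speaker == "Scientist":
        return None               # Research -> End
    return None                   # unknown speaker: end the workflow


# One run in which the first execution fails and the retry succeeds
trace = [
    ("Initializer", "Find recent papers on LLM agents."),
    ("Coder", "(retrieval code)"),
    ("Executor", "exitcode: 1 (execution failed)\nCode output: ..."),
    ("Coder", "(fixed retrieval code)"),
    ("Executor", "exitcode: 0 (execution succeeded)\nCode output: ..."),
    ("Scientist", "Summary: ..."),
]
print([next_speaker(speaker, message) for speaker, message in trace])
# ['Coder', 'Executor', 'Coder', 'Executor', 'Scientist', None]
```

Branching on the executor's output is what turns the Retrieve state into a loop: the workflow only advances to Research once the code has run successfully.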

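The notebook code below adapts this pattern to a different task: a tool-assisted workflow for forecasting power generation. Its agents are `Init`, `ToolSelector` (picks tools for the question), `ToolSearcher` (retrieves the tool descriptions and similar Q&A examples from a FAISS index), `Coder`, `ExecuteCode`, and `ResultChecker`. The first part loads the API key and defines a helper that fetches embeddings from OpenAI.

```python
import os
import json
import csv

import autogen
import faiss
import numpy as np
import requests
from dotenv import load_dotenv

# Load environment variables from a local .env file
load_dotenv()

# Read the OpenAI API key from the environment
api_key = os.getenv("OPENAI_API_KEY")

config_list = [{"model": "gpt-4", "api_key": api_key}]


def get_embedding(text, model="text-embedding-ada-002"):
    """Return the embedding vector for `text` from the OpenAI embeddings endpoint."""
    url = "https://api.openai.com/v1/embeddings"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    data = {"input": text, "model": model}
    response = requests.post(url, headers=headers, data=json.dumps(data))
    response.raise_for_status()  # surface HTTP errors instead of a KeyError below
    response_data = response.json()
    return response_data["data"][0]["embedding"]
```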

```python
# Load the FAISS index of question embeddings
index_file = "faiss_index.index"
index = faiss.read_index(index_file)

# CSV file holding the reference questions (column 0) and answers (column 1)
qa_data_file = "qa_data.csv"


def search_answer(question, index, questions, answers, top_k=1):
    """Return the `top_k` (question, answer) pairs most similar to `question`."""
    # FAISS expects a 2-D float32 array of query vectors
    question_embedding = np.array([get_embedding(question)], dtype="float32")

    # Search the FAISS index for the most similar stored embeddings
    _, indices = index.search(question_embedding, top_k)

    # Map the returned row ids back to the stored questions and answers
    result = []
    for idx in indices[0]:
        if idx < 0:  # FAISS pads with -1 when fewer than top_k results exist
            continue
        result.append((questions[idx], answers[idx]))
    return result


# Read the QA data and store the questions and answers
questions = []
answers = []

with open(qa_data_file, "r", encoding="utf-8", newline="") as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    for row in reader:
        questions.append(row[0])
        answers.append(row[1])
```
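`search_answer` delegates the nearest-neighbour lookup to FAISS. Assuming the index is an exact L2 index such as `faiss.IndexFlatL2` (the notebook does not say how `faiss_index.index` was built), `index.search(query, k)` returns the squared distances and row ids of the `k` stored vectors closest to each query. The NumPy sketch below reproduces that lookup by brute force on made-up 3-dimensional vectors; it only illustrates what the search computes.

```python
import numpy as np


def brute_force_search(stored: np.ndarray, queries: np.ndarray, k: int):
    """Exact k-nearest-neighbour search by squared L2 distance.

    Mirrors the (distances, indices) return shape of a flat L2 FAISS index:
    both arrays have shape (n_queries, k).
    """
    # Pairwise squared distances, shape (n_queries, n_stored)
    diffs = queries[:, None, :] - stored[None, :, :]
    dists = np.sum(diffs ** 2, axis=-1)
    # Row ids of the k smallest distances per query, nearest first
    idx = np.argsort(dists, axis=1, kind="stable")[:, :k]
    return np.take_along_axis(dists, idx, axis=1), idx


# Toy "embeddings" for three stored questions (real embeddings are much longer)
toy_stored = np.array([[0.0, 0.0, 1.0],
                       [1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]], dtype="float32")
toy_questions = ["q-up", "q-right", "q-forward"]

toy_query = np.array([[0.9, 0.1, 0.0]], dtype="float32")  # 2-D, as in search_answer
toy_distances, toy_indices = brute_force_search(toy_stored, toy_query, k=1)
print([toy_questions[i] for i in toy_indices[0]])  # ['q-right']
```

A flat index does exactly this comparison against every stored vector; approximate FAISS indexes trade some accuracy for speed on large collections.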

```python
gpt4_config = {
    "cache_seed": 42,  # change the cache_seed for different trials
    "temperature": 0.6,
    "config_list": config_list,
    "timeout": 120,
}

# Run generated code locally instead of in Docker (AUTOGEN_USE_DOCKER=0)
os.environ["AUTOGEN_USE_DOCKER"] = "0"

initializer = autogen.UserProxyAgent(
    name="Init",
)

# Long-form descriptions of the available tools, keyed by tool name
tool_dict = {
    "get_predict": """get_predict——
    Predicting the mean variance and range of plant utilization rates based on generation type and generation output values
    from get_predict import PredictFd
    await PredictFd(now_date,predict_date,original_data,power_type)
    Args:
        now_date (str): date now,format is'yyyymmdd'
        predict_date (str): forecast date ,format is'yyyymmdd'
        original_data (str): generation output values
        power_type (str):  Type of power generation,include '供电', '光伏', '风力'

    Returns:
        predict_data (dict): {"predict":(float)}。
    """,
    "get_cylpredict": """get_cylpredict——
    Predicting the mean, variance and range of plant utilization rates based on generation type and generation output values
    from get_cylpredict import PredictCyl
    await PredictCyl(value,power_type)
    Args:
        value (str): value of generation output
        power_type (str): type of power generation, one of 'fd' (wind power), 'gf' (photovoltaic), 'hdrm' (coal-fired thermal power), 'sdyg' (hydropower)

    Returns:
        predict_data (dict):
        {
        "mean": "",(mean value of plant utilization rate)
        "std": "",(variance of plant utilization rate)
        "interval_lower": "",(lower limit of plant utilization rate interval)
        "interval_upper": ""(upper limit of plant utilization rate interval)
        }
    """,
}

# Reply shared with the ToolSearcher agent; print_message updates its content in place
msg = {"content": ""}


def print_message(received_msg):
    """Build the ToolSearcher reply from the ToolSelector's JSON output.

    Used as `is_termination_msg`, so it runs on messages the ToolSearcher receives.
    It rewrites msg["content"] in place and always returns False, so it never ends the chat.
    """
    try:
        data_dict = json.loads(received_msg.get("content", ""))
        tools_des = ""
        for tool in data_dict["tools"]:
            tools_des += tool_dict[tool]
        newcontent = "tools description:\n" + tools_des + "Please consider importing the above tools to write code"
        # The refined question lives inside the ToolSelector's JSON, not in a top-level message key
        results = search_answer(str(data_dict.get("question", "")), index, questions, answers)
        msg["content"] = newcontent + "\nThe following questions and corresponding code examples are provided for reference:\n" + str(results)
    except Exception as e:
        # Messages that are not the ToolSelector's JSON leave msg unchanged
        print("No Recommend:", e)
    return False


retriever = autogen.UserProxyAgent(
    name="ToolSearcher",
    llm_config=False,
    is_termination_msg=print_message,  # side effect: refreshes msg["content"]
    default_auto_reply=msg,
    human_input_mode="NEVER",
)

selector = autogen.AssistantAgent(
    name="ToolSelector",
    llm_config=gpt4_config,
    system_message="""
    get_cylpredict: predicts the mean, variance and range of plant utilization rates based on the generation type and generation output values.
    get_predict: gets power generation forecasts based on the current date, the forecast date, the generation type and original_data.
    Based on the descriptions of the two tools above and the user's question, select the appropriate tools to answer the question.
    Output only a JSON object containing a refined, clearer restatement of the user's question and an array of the selected tool names, in the form {"question":"","tools":[]}""",
)

coder = autogen.AssistantAgent(
    name="Coder",
    llm_config=gpt4_config,
    system_message="""
    You are the Coder. Write code according to the tool description and the relevant examples. No mocking unless requested by the user.
You write python/shell code to solve tasks. Wrap the code in a code block that specifies the script type. The user can't modify your code. So do not suggest incomplete code which requires others to modify. Don't use a code block if it's not intended to be executed by the executor.
Don't include multiple code blocks in one response. Do not ask others to copy and paste the result. Check the execution result returned by the executor.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
""",
)

executor = autogen.UserProxyAgent(
    name="ExecuteCode",
    system_message="Executor. Execute the code written by the Coder and report the result.",
    human_input_mode="NEVER",
    code_execution_config={
        "last_n_messages": 3,
        "work_dir": "grid",
        "use_docker": False,
    },  # Please set use_docker=True if docker is available to run the generated code. Using docker is safer than running the generated code directly.
)

checker = autogen.AssistantAgent(
    name="ResultChecker",
    llm_config=gpt4_config,
    system_message="""You are the Checker. Determine whether the execution result meets the user's requirements. If it does, output Yes, give the answer the user needs, and, if a file was saved, state where it was saved. Otherwise, output No and suggest the changes needed to meet the requirements.""",
)
```
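The ToolSearcher relies on a side effect: `print_message` runs as `is_termination_msg`, rewrites the shared `msg` dict, and `default_auto_reply=msg` then sends that content. The same transformation can be written as a pure function, which is easier to test in isolation. This is a hypothetical sketch: `build_tool_prompt`, the one-entry tool table and `fake_search` are illustrative, and the stub stands in for `search_answer` so the example runs without an API key or FAISS index.

```python
import json
from typing import Callable, Dict, List, Optional, Tuple


def build_tool_prompt(
    selector_content: str,
    tool_descriptions: Dict[str, str],
    search: Callable[[str], List[Tuple[str, str]]],
) -> Optional[str]:
    """Turn the ToolSelector's JSON reply into the ToolSearcher's message.

    Returns None when the content is not a {"question": ..., "tools": [...]}
    object or names an unknown tool.
    """
    try:
        data = json.loads(selector_content)
        tools_des = "".join(tool_descriptions[name] for name in data["tools"])
    except (json.JSONDecodeError, KeyError, TypeError):
        return None
    examples = search(str(data.get("question", "")))
    return (
        "tools description:\n" + tools_des
        + "Please consider importing the above tools to write code"
        + "\nThe following questions and corresponding code examples are provided for reference:\n"
        + str(examples)
    )


# Illustrative inputs: a one-entry tool table and a stub in place of search_answer
tools = {"get_predict": "get_predict: power generation forecast tool\n"}


def fake_search(question: str) -> List[Tuple[str, str]]:
    return [("How do I forecast wind power?", "PredictFd(...)")]


reply = build_tool_prompt('{"question": "Forecast PV output", "tools": ["get_predict"]}', tools, fake_search)
print(reply.splitlines()[0])                              # tools description:
print(build_tool_prompt("not json", tools, fake_search))  # None
```

Returning `None` instead of mutating global state lets the caller decide what to send when the selector's output is malformed.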


```python
def state_transition(last_speaker, groupchat):
    """Pick the next speaker from the last speaker and, where needed, its message."""
    messages = groupchat.messages

    if last_speaker is initializer:
        # User question -> recommend tools
        return selector
    elif last_speaker is selector:
        # Selected tools -> retrieve tool descriptions and reference examples
        return retriever
    elif last_speaker is retriever:
        # Tool descriptions -> write code
        return coder
    elif last_speaker is coder:
        # Code written -> execute it
        return executor
    elif last_speaker is executor:
        last_message_content = messages[-1].get("content") or ""
        # The executor reports "exitcode: 0 ..." on success; anything else goes back to the coder
        if not last_message_content.startswith("exitcode: 0"):
            print("Error, recode")
            return coder
        else:
            print("Success, check")
            return checker
    elif last_speaker is checker:
        last_message_content = (messages[-1].get("content") or "").lower()
        if "yes" in last_message_content:
            print("Checker Accept")
            return None  # end the conversation
        else:
            return coder
    # Any other speaker: end the conversation
    return None
```
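`state_transition` routes on plain string checks of the last message. A slightly more defensive variant parses the exit code explicitly and reads only the checker's leading verdict, so that a reply such as "No, ... say yes once it is saved" is not accepted just because it contains "yes". This is a hypothetical helper sketch; the `exitcode: <n> (...)` prefix it parses is the format the executor's replies are assumed to follow here.

```python
import re
from typing import Optional

# Assumed executor reply format: "exitcode: <int> (<status>)\nCode output: ..."
_EXITCODE_RE = re.compile(r"^exitcode:\s*(-?\d+)")


def parse_exit_code(content: Optional[str]) -> Optional[int]:
    """Return the exit code reported by the executor, or None if there is none."""
    match = _EXITCODE_RE.match(content or "")
    return int(match.group(1)) if match else None


def checker_accepts(content: Optional[str]) -> bool:
    """True only if the checker's reply starts with the verdict 'Yes'."""
    words = re.findall(r"[a-z]+", (content or "").lower())
    return bool(words) and words[0] == "yes"


print(parse_exit_code("exitcode: 0 (execution succeeded)\nCode output: done"))  # 0
print(parse_exit_code("exitcode: 1 (execution failed)\nCode output: Traceback"))  # 1
print(parse_exit_code("No code found."))                                         # None
print(checker_accepts("Yes. The chart was saved locally."))                      # True
print(checker_accepts("No, the file was not saved. Say yes once it is."))        # False
```

Inside `state_transition`, the executor branch would then return `checker` when `parse_exit_code(...) == 0` and `coder` otherwise, and the checker branch would end the chat only when `checker_accepts(...)` is true.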


```python
groupchat = autogen.GroupChat(
    agents=[initializer, selector, checker, retriever, coder, executor],
    messages=[{'content': '{"tools": ["get_predict"]}', 'role': 'user', 'name': 'ToolSelector'}, {'content': 'tools description:\nget_predict——\n    Predicting the mean variance and range of plant utilization rates based on generation type and generation output values\n    from get_predict import PredictFd\n    PredictFd(now_date,predict_date,original_data,power_type)\n    Args:\n        now_date (str): date now,format is\'yyyymmdd\'\n        predict_date (str): forecast date ,format is\'yyyymmdd\'\n        original_data (str): generation output values\n        power_type (str):  Type of power generation,include \'供电\', \'光伏\', \'风力\'\n\n    Returns:\n        predict_data (dict): {"predict":(float)}。\n    Please consider importing the above tools to write code', 'role': 'assistant', 'name': 'ToolSelector'}],
    max_round=20,
    speaker_selection_method=state_transition,
)
# groupchat.append({
#                 "content": "你需要在代码里表明作者是gzm",  # "State in the code that the author is gzm"
#                 "role": "assistant",
#                 },initializer)
manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=gpt4_config)

# Task: "Draw a chart of the photovoltaic and wind power generation forecasts
# for the next 3 days and save it locally"
initializer.initiate_chat(
    manager, message="画个图，展示未来3天内光伏与风电发电量预测,保存到本地"
)
```

[33mInit[0m (to chat_manager):

画个图，展示未来3天内光伏与风电发电量预测,保存到本地

--------------------------------------------------------------------------------
[33mToolSelector[0m (to chat_manager):

{"question":"Draw a graph to show the forecast of photovoltaic and wind power generation in the next 3 days, and save it locally","tools":["get_predict"]}

--------------------------------------------------------------------------------
[33mToolSearcher[0m (to chat_manager):

tools description:
get_predict——
    Predicting the mean variance and range of plant utilization rates based on generation type and generation output values
    from get_predict import PredictFd
    await PredictFd(now_date,predict_date,original_data,power_type)
    Args:
        now_date (str): date now,format is'yyyymmdd'
        predict_date (str): forecast date ,format is'yyyymmdd'
        original_data (str): generation output values
        power_type (str):  Type of power generation,include '供电', '光伏', '风力'

    Returns:
 

ChatResult(chat_id=None, chat_history=[{'content': '画个图，展示未来3天内光伏与风电发电量预测,保存到本地', 'role': 'assistant'}, {'content': '{"question":"Draw a graph to show the forecast of photovoltaic and wind power generation in the next 3 days, and save it locally","tools":["get_predict"]}', 'name': 'ToolSelector', 'role': 'user'}, {'content': 'tools description:\nget_predict——\n    Predicting the mean variance and range of plant utilization rates based on generation type and generation output values\n    from get_predict import PredictFd\n    await PredictFd(now_date,predict_date,original_data,power_type)\n    Args:\n        now_date (str): date now,format is\'yyyymmdd\'\n        predict_date (str): forecast date ,format is\'yyyymmdd\'\n        original_data (str): generation output values\n        power_type (str):  Type of power generation,include \'供电\', \'光伏\', \'风力\'\n\n    Returns:\n        predict_data (dict): {"predict":(float)}。\n    Please consider importing the above tools to write code\nThe 

The code above loads the API key with `load_dotenv`, which comes from the `python-dotenv` package. Install it with:

```shell
pip install python-dotenv
```