In [None]:
import os
import re
from abc import ABC, abstractmethod

from camel.agents import RolePlaying
from camel.messages import ChatMessage
from camel.typing import TaskType, ModelType
from chatdev.chat_env import ChatEnv
from chatdev.statistics import get_info
from chatdev.utils import log_visualize, log_arguments




In [None]:
class Phase(ABC):
    """Abstract base class for one phase of a role-played, two-agent chat.

    A phase pairs an assistant role with a user role, runs a bounded chat
    between them through ``RolePlaying`` and reduces the exchange to a
    "seminar conclusion" string. Subclasses implement ``update_phase_env``
    (called before the chat) and ``update_chat_env`` (called after it).

    NOTE(review): ``chatting`` calls ``self.self_reflection`` when
    ``need_reflect`` is true, but no such method is defined in this class as
    shown here; it has to be provided elsewhere before reflection is used.
    """

    def __init__(self,
                 assistant_role_name,
                 user_role_name,
                 phase_prompt,
                 role_prompts,
                 phase_name,
                 model_type,
                 log_filepath):
        """Store the phase configuration.

        Args:
            assistant_role_name: Name of the role acting as the assistant.
            user_role_name: Name of the role acting as the user.
            phase_prompt: Prompt handed to ``RolePlaying.init_chat``.
            role_prompts: Mapping indexed by role name. It is indexed with
                both role names above as well as "Chief Executive Officer"
                and "Counselor", so all four keys must be present.
            phase_name: Name of this phase (used in logs and in the
                "recruiting" check of ``chatting``).
            model_type: Model identifier forwarded to ``chatting``.
            log_filepath: Stored on the instance; not used by the code
                visible in this class.
        """
        self.seminar_conclusion = None
        self.assistant_role_name = assistant_role_name
        self.user_role_name = user_role_name
        self.phase_prompt = phase_prompt
        # Placeholder values for the phase prompt; passed to ``chatting`` as
        # ``placeholders`` by ``execute``.
        self.phase_env = dict()
        self.phase_name = phase_name
        self.assistant_role_prompt = role_prompts[assistant_role_name]
        self.user_role_prompt = role_prompts[user_role_name]
        self.ceo_prompt = role_prompts["Chief Executive Officer"]
        self.counselor_prompt = role_prompts["Counselor"]
        self.max_retries = 3
        self.reflection_prompt = """Here is a conversation between two roles: {conversations} {question}"""
        self.model_type = model_type
        self.log_filepath = log_filepath

    @log_arguments
    def chatting(
            self,
            chat_env,
            task_prompt: str,
            assistant_role_name: str,
            user_role_name: str,
            phase_prompt: str,
            phase_name: str,
            assistant_role_prompt: str,
            user_role_prompt: str,
            task_type=TaskType.CHATDEV,
            need_reflect=False,
            with_task_specify=False,
            model_type=ModelType.GPT_3_5_TURBO,
            memory=None,
            placeholders=None,
            chat_turn_limit=10
    ) -> str:
        """Run the chat between the two roles and return its conclusion.

        Args:
            chat_env: Environment object; must report both roles as existing
                employees and expose ``config.background_prompt``.
            task_prompt: Task description forwarded to ``RolePlaying``.
            assistant_role_name: Name of the assistant role.
            user_role_name: Name of the user role.
            phase_prompt: Prompt passed to ``init_chat``.
            phase_name: Phase name, used for logging and to detect
                "recruiting" phases.
            assistant_role_prompt: System prompt of the assistant role.
            user_role_prompt: System prompt of the user role.
            task_type: Task type forwarded to ``RolePlaying``.
            need_reflect: When true, an empty conclusion (or, for recruiting
                phases, one containing neither "yes" nor "no") is replaced by
                the result of ``self.self_reflection``. When false, the
                assistant's last message is used as the conclusion.
            with_task_specify: Forwarded to ``RolePlaying``.
            model_type: Model identifier forwarded to ``RolePlaying``.
            memory: Forwarded to ``RolePlaying``.
            placeholders: Values passed to ``init_chat``; ``None`` is treated
                as an empty dict.
            chat_turn_limit: Maximum number of chat turns, within [1, 100].

        Returns:
            The part of the conclusion after its last "<INFO>" marker, or the
            whole conclusion if the marker is absent.

        Raises:
            AssertionError: If ``chat_turn_limit`` is outside [1, 100].
            ValueError: If either role is not recruited in ``chat_env``.
        """
        if placeholders is None:
            placeholders = {}
        assert 1 <= chat_turn_limit <= 100

        if not chat_env.exist_employee(assistant_role_name):
            raise ValueError(f"{assistant_role_name} not recruited in ChatEnv.")
        if not chat_env.exist_employee(user_role_name):
            raise ValueError(f"{user_role_name} not recruited in ChatEnv.")

        # init role play
        role_play_session = RolePlaying(
            assistant_role_name=assistant_role_name,
            user_role_name=user_role_name,
            assistant_role_prompt=assistant_role_prompt,
            user_role_prompt=user_role_prompt,
            task_prompt=task_prompt,
            task_type=task_type,
            with_task_specify=with_task_specify,
            memory=memory,
            model_type=model_type,
            background_prompt=chat_env.config.background_prompt
        )

        # log_visualize("System", role_play_session.assistant_sys_msg)
        # log_visualize("System", role_play_session.user_sys_msg)

        # start the chat
        _, input_user_msg = role_play_session.init_chat(None, placeholders, phase_prompt)
        seminar_conclusion = None

        for i in range(chat_turn_limit):
            # The second argument is true only for single-turn chats.
            assistant_response, user_response = role_play_session.step(input_user_msg, chat_turn_limit == 1)

            conversation_meta = "**" + assistant_role_name + "<->" + user_role_name + " on : " + str(
                phase_name) + ", turn " + str(i) + "**\n\n"

            # TODO: max_tokens_exceeded errors here
            if isinstance(assistant_response.msg, ChatMessage):
                # we log the second interaction here
                log_visualize(role_play_session.assistant_agent.role_name,
                              conversation_meta + "[" + role_play_session.user_agent.system_message.content + "]\n\n" + assistant_response.msg.content)
                # A truthy ``info`` on the agent ends the chat with this
                # message as the conclusion.
                if role_play_session.assistant_agent.info:
                    seminar_conclusion = assistant_response.msg.content
                    break
                if assistant_response.terminated:
                    break

            if isinstance(user_response.msg, ChatMessage):
                # here is the result of the second interaction, which may be used to start the next chat turn
                log_visualize(role_play_session.user_agent.role_name,
                              conversation_meta + "[" + role_play_session.assistant_agent.system_message.content + "]\n\n" + user_response.msg.content)
                if role_play_session.user_agent.info:
                    seminar_conclusion = user_response.msg.content
                    break
                if user_response.terminated:
                    break

            # continue the chat
            if chat_turn_limit > 1 and isinstance(user_response.msg, ChatMessage):
                input_user_msg = user_response.msg
            else:
                break

        # conduct self reflection
        if need_reflect:
            if seminar_conclusion in [None, ""]:
                seminar_conclusion = "<INFO> " + self.self_reflection(task_prompt, role_play_session, phase_name,
                                                                      chat_env)
            if "recruiting" in phase_name:
                # Recruiting phases need a yes/no answer; reflect again when
                # the conclusion contains neither.
                lowered_conclusion = seminar_conclusion.lower()
                if "yes" not in lowered_conclusion and "no" not in lowered_conclusion:
                    seminar_conclusion = "<INFO> " + self.self_reflection(task_prompt, role_play_session,
                                                                          phase_name,
                                                                          chat_env)
        else:
            # Without reflection the assistant's last message is the
            # conclusion. That message can be missing (the loop above guards
            # for non-ChatMessage replies), so fall back to an empty
            # conclusion instead of raising AttributeError.
            last_assistant_msg = assistant_response.msg
            seminar_conclusion = getattr(last_assistant_msg, "content", None) or ""

        log_visualize("**[Seminar Conclusion]**:\n\n {}".format(seminar_conclusion))
        seminar_conclusion = seminar_conclusion.split("<INFO>")[-1]
        return seminar_conclusion


    @abstractmethod
    def update_phase_env(self, chat_env):
        """Prepare ``self.phase_env`` from ``chat_env`` before the chat runs."""
        pass

    @abstractmethod
    def update_chat_env(self, chat_env) -> ChatEnv:
        """Update ``chat_env`` after the chat and return it.

        ``self.seminar_conclusion`` has been set by ``execute`` when this is
        called.
        """
        pass

    def execute(self, chat_env, chat_turn_limit, need_reflect) -> ChatEnv:
        """Run the whole phase: prepare, chat, then update the environment.

        Args:
            chat_env: Environment exposing ``env_dict['task_prompt']`` and
                ``memory``, plus whatever ``chatting`` requires.
            chat_turn_limit: Maximum number of chat turns.
            need_reflect: Forwarded to ``chatting``.

        Returns:
            The environment returned by ``update_chat_env``.
        """
        self.update_phase_env(chat_env)
        self.seminar_conclusion = \
            self.chatting(chat_env=chat_env,
                          task_prompt=chat_env.env_dict['task_prompt'],
                          need_reflect=need_reflect,
                          assistant_role_name=self.assistant_role_name,
                          user_role_name=self.user_role_name,
                          phase_prompt=self.phase_prompt,
                          phase_name=self.phase_name,
                          assistant_role_prompt=self.assistant_role_prompt,
                          user_role_prompt=self.user_role_prompt,
                          chat_turn_limit=chat_turn_limit,
                          placeholders=self.phase_env,
                          memory=chat_env.memory,
                          model_type=self.model_type)
        chat_env = self.update_chat_env(chat_env)
        return chat_env


In [None]:
class DemandAnalysis(Phase):
    """Phase whose conclusion is stored under ``env_dict['analysis']``."""

    def __init__(self, **kwargs):
        # Forward the phase configuration to ``Phase.__init__``.
        super().__init__(**kwargs)

    def update_phase_env(self, chat_env):
        """This phase sets no prompt placeholders."""
        pass

    def update_chat_env(self, chat_env) -> ChatEnv:
        """Save the normalised conclusion as ``env_dict['analysis']``.

        The text after the last "<INFO>" marker is lower-cased, stripped of
        periods and surrounding whitespace. Nothing is written when the
        conclusion is empty.

        Returns:
            The same ``chat_env`` object.
        """
        if self.seminar_conclusion:
            chat_env.env_dict['analysis'] = self.seminar_conclusion.split("<INFO>")[-1].lower().replace(".", "").strip()
        return chat_env

In [None]:
class FacultyDecision(Phase):
    """Phase whose conclusion is stored under ``env_dict['faculty']``."""

    def __init__(self, **kwargs):
        # Forward the phase configuration to ``Phase.__init__``.
        super().__init__(**kwargs)

    def update_phase_env(self, chat_env):
        """Fill the prompt placeholders from ``chat_env.env_dict``.

        Reads the ``task_prompt``, ``task_description`` and ``analysis``
        entries; a missing entry raises ``KeyError``.
        """
        self.phase_env.update({"task": chat_env.env_dict['task_prompt'],
                               "description": chat_env.env_dict['task_description'],
                               "analysis": chat_env.env_dict['analysis']})

    def update_chat_env(self, chat_env) -> ChatEnv:
        """Save the normalised conclusion as ``env_dict['faculty']``.

        The text after the last "<INFO>" marker is lower-cased, stripped of
        periods and surrounding whitespace. Nothing is written when the
        conclusion is empty.

        Returns:
            The same ``chat_env`` object.
        """
        if self.seminar_conclusion:
            chat_env.env_dict['faculty'] = self.seminar_conclusion.split("<INFO>")[-1].lower().replace(".", "").strip()
        return chat_env



In [None]:
class StudentDecision(Phase):
    """Phase whose conclusion is stored under ``env_dict['student']``."""

    def __init__(self, **kwargs):
        # Forward the phase configuration to ``Phase.__init__``.
        super().__init__(**kwargs)

    def update_phase_env(self, chat_env):
        """Fill the prompt placeholders from ``chat_env.env_dict``.

        Reads the ``task_prompt``, ``task_description``, ``analysis`` and
        ``faculty`` entries; a missing entry raises ``KeyError``.
        """
        self.phase_env.update({"task": chat_env.env_dict['task_prompt'],
                               "description": chat_env.env_dict['task_description'],
                               "analysis": chat_env.env_dict['analysis'],
                               "faculty": chat_env.env_dict['faculty']})

    def update_chat_env(self, chat_env) -> ChatEnv:
        """Save the normalised conclusion as ``env_dict['student']``.

        The text after the last "<INFO>" marker is lower-cased, stripped of
        periods and surrounding whitespace. Nothing is written when the
        conclusion is empty.

        Returns:
            The same ``chat_env`` object.
        """
        if self.seminar_conclusion:
            chat_env.env_dict['student'] = self.seminar_conclusion.split("<INFO>")[-1].lower().replace(".", "").strip()
        return chat_env

In [None]:
class DeanDecision(Phase):
    """Phase whose conclusion is stored under ``env_dict['dean']``."""

    def __init__(self, **kwargs):
        # Forward the phase configuration to ``Phase.__init__``.
        super().__init__(**kwargs)

    def update_phase_env(self, chat_env):
        """Fill the prompt placeholders from ``chat_env.env_dict``.

        Reads the ``task_prompt``, ``task_description``, ``analysis``,
        ``faculty`` and ``student`` entries; a missing entry raises
        ``KeyError``.
        """
        self.phase_env.update({"task": chat_env.env_dict['task_prompt'],
                               "description": chat_env.env_dict['task_description'],
                               "analysis": chat_env.env_dict['analysis'],
                               "faculty": chat_env.env_dict['faculty'],
                               "student": chat_env.env_dict['student']})

    def update_chat_env(self, chat_env) -> ChatEnv:
        """Save the normalised conclusion as ``env_dict['dean']``.

        The text after the last "<INFO>" marker is lower-cased, stripped of
        periods and surrounding whitespace. Nothing is written when the
        conclusion is empty.

        Returns:
            The same ``chat_env`` object.
        """
        if self.seminar_conclusion:
            chat_env.env_dict['dean'] = self.seminar_conclusion.split("<INFO>")[-1].lower().replace(".", "").strip()
        return chat_env

In [None]:
class Review(Phase):
    """Phase whose conclusion is stored under ``env_dict['final']``."""

    def __init__(self, **kwargs):
        # Forward the phase configuration to ``Phase.__init__``.
        super().__init__(**kwargs)

    def update_phase_env(self, chat_env):
        """Fill the prompt placeholders from ``chat_env.env_dict``.

        Reads the ``task_prompt``, ``task_description``, ``analysis``,
        ``faculty``, ``student`` and ``dean`` entries; a missing entry raises
        ``KeyError``.
        """
        self.phase_env.update({"task": chat_env.env_dict['task_prompt'],
                               "description": chat_env.env_dict['task_description'],
                               "analysis": chat_env.env_dict['analysis'],
                               "faculty": chat_env.env_dict['faculty'],
                               "student": chat_env.env_dict['student'],
                               "dean": chat_env.env_dict['dean']})

    def update_chat_env(self, chat_env) -> ChatEnv:
        """Save the normalised conclusion as ``env_dict['final']``.

        The text after the last "<INFO>" marker is lower-cased, stripped of
        periods and surrounding whitespace. Nothing is written when the
        conclusion is empty.

        Returns:
            The same ``chat_env`` object.
        """
        if self.seminar_conclusion:
            chat_env.env_dict['final'] = self.seminar_conclusion.split("<INFO>")[-1].lower().replace(".", "").strip()
        return chat_env