In [2]:
from typing import Dict, List, Literal, Optional, Type

from pydantic import BaseModel, ValidationError

In [None]:

# Individual message schemas

class UniveralMessage(BaseModel):
    """Provider-neutral chat message.

    The class name carries a historical typo ("Univeral"). It is kept so
    existing references keep working; ``UniversalMessage`` is defined right
    after this class as a correctly spelled alias.

    Fields:
        role: Either ``'user'`` or ``'assistant'``; any other value fails
            validation.
        content: The message text.
    """

    role: Literal['user', 'assistant']
    content: str


# Correctly spelled alias. It is the very same class object, so validation,
# isinstance checks and type(...).__name__ behave identically under either name.
UniversalMessage = UniveralMessage

class OpenAIMessage(BaseModel):
    """One chat message in this module's OpenAI-flavoured schema.

    Fields:
        role: ``"user"`` or ``"assistant"``; anything else is rejected.
        message: The text of the message.
    """

    role: Literal["user", "assistant"]
    message: str

class AnthropicMessage(BaseModel):
    """One chat message in this module's Anthropic-flavoured schema.

    Fields:
        role: ``"user"`` or ``"assistant"``; anything else is rejected.
        message: The text of the message.
    """

    role: Literal["user", "assistant"]
    message: str

class CohereMessage(BaseModel):
    """One chat message in this module's Cohere-flavoured schema.

    Unlike the other per-message schemas here, roles are upper-case and the
    assistant role is spelled ``"CHATBOT"``.

    Fields:
        role: ``"SYSTEM"``, ``"USER"`` or ``"CHATBOT"``.
        message: The text of the message.
    """

    role: Literal["SYSTEM", "USER", "CHATBOT"]
    message: str

# Combined message schemas
class UniversalMessages(BaseModel):
    """Ordered list of provider-neutral messages plus a system-message flag.

    Fields:
        has_system_message: Flag carried alongside the messages; defaults to
            ``True``.
        messages: The messages, each validated as ``UniveralMessage``.
    """

    has_system_message: bool = True  # default for the universal schema
    messages: List[UniveralMessage]

class OpenAIMessages(BaseModel):
    """Ordered list of ``OpenAIMessage`` items plus a system-message flag.

    Fields:
        has_system_message: Defaults to ``True`` for this schema.
        messages: The messages, each validated as ``OpenAIMessage``.
    """

    has_system_message: bool = True  # this schema defaults the flag on
    messages: List[OpenAIMessage]

class AnthropicMessages(BaseModel):
    """Ordered list of ``AnthropicMessage`` items plus a system-message flag.

    Fields:
        has_system_message: Defaults to ``False`` for this schema.
        messages: The messages, each validated as ``AnthropicMessage``.
    """

    has_system_message: bool = False  # this schema defaults the flag off
    messages: List[AnthropicMessage]

class CohereMessages(BaseModel):
    """Ordered list of ``CohereMessage`` items plus a system-message flag.

    Fields:
        has_system_message: Defaults to ``False`` for this schema.
        messages: The messages, each validated as ``CohereMessage``.
    """

    has_system_message: bool = False  # this schema defaults the flag off
    messages: List[CohereMessage]

In [None]:
class MessageConverter:
    """Registry-driven converter between provider message schemas and the universal schema.

    Schemas are looked up by name in the class-level ``_schemas`` registry.
    Converting *to* a provider needs two entries for that provider:
    ``"<name>"`` -> per-message class (e.g. ``OpenAIMessage``) and
    ``"<name>s"`` -> collection class (e.g. ``OpenAIMessages``). Converting
    *from* a provider needs only ``"<name>"``, where ``<name>`` is the
    collection's class name lower-cased with a trailing ``"messages"`` removed.
    """

    # Shared registry for the whole process: schema name -> pydantic model class.
    _schemas: Dict[str, Type[BaseModel]] = {}

    @staticmethod
    def register_schema(schema_name: str, schema_class: Type[BaseModel]) -> None:
        """Register ``schema_class`` under ``schema_name``, replacing any previous entry."""
        MessageConverter._schemas[schema_name] = schema_class

    @staticmethod
    def _schema_name_for(messages: BaseModel) -> str:
        """Return the registry key for a collection, e.g. ``OpenAIMessages`` -> ``"openai"``."""
        name = type(messages).__name__.lower()
        suffix = "messages"
        # Strip only the trailing suffix so a "messages" elsewhere in the name survives.
        return name[: -len(suffix)] if name.endswith(suffix) else name

    @staticmethod
    def _to_universal_role(schema_name: str, role: str) -> str:
        """Map a provider role to the universal (lower-case) role."""
        if schema_name == 'cohere' and role == 'CHATBOT':
            return 'assistant'
        return role.lower()

    @staticmethod
    def _from_universal_role(target_schema: str, role: str) -> str:
        """Map a universal role to the spelling expected by ``target_schema``."""
        if target_schema != 'cohere':
            return role
        return 'CHATBOT' if role == 'assistant' else role.upper()

    @staticmethod
    def convert_to_universal(messages: BaseModel) -> UniversalMessages:
        """Convert a provider message collection to ``UniversalMessages``.

        Args:
            messages: A collection model exposing ``messages`` and
                ``has_system_message`` (e.g. ``OpenAIMessages``).

        Returns:
            A ``UniversalMessages`` with roles mapped to lower case (Cohere's
            ``"CHATBOT"`` becomes ``"assistant"``) and ``has_system_message``
            copied from the input.

        Raises:
            ValueError: If no schema is registered under the derived name.
            pydantic.ValidationError: If a mapped role is not accepted by
                ``UniveralMessage`` (Cohere ``"SYSTEM"`` maps to ``"system"``,
                which its ``role`` literal does not include).
        """
        schema_name = MessageConverter._schema_name_for(messages)
        if schema_name not in MessageConverter._schemas:
            raise ValueError(f"Unknown schema: {schema_name}")
        message_class = MessageConverter._schemas[schema_name]

        universal_messages = [
            UniveralMessage(
                role=MessageConverter._to_universal_role(schema_name, message.role),
                content=message.message,
            )
            for message in messages.messages
            # NOTE(review): items that are not instances of the registered
            # message class are silently skipped (pre-existing behaviour);
            # confirm this is intended rather than an error.
            if isinstance(message, message_class)
        ]
        return UniversalMessages(
            has_system_message=messages.has_system_message,
            messages=universal_messages,
        )

    @staticmethod
    def convert_to_target(universal_messages: UniversalMessages, target_schema: str) -> BaseModel:
        """Convert ``UniversalMessages`` into the collection registered as ``target_schema + "s"``.

        Args:
            universal_messages: The messages to convert.
            target_schema: Registry key of the per-message class; the
                collection class must be registered under the same key + ``"s"``.

        Returns:
            An instance of the registered collection class, with
            ``has_system_message`` copied from the input.

        Raises:
            ValueError: If either the per-message or the collection schema is
                not registered.
            pydantic.ValidationError: If a converted message fails validation.
        """
        schemas = MessageConverter._schemas
        if target_schema not in schemas:
            raise ValueError(f"Unknown target schema: {target_schema}")
        collection_name = target_schema + "s"
        # Previously this lookup raised a bare KeyError; report it like the other unknown-schema cases.
        if collection_name not in schemas:
            raise ValueError(f"Unknown target schema: {collection_name}")
        message_class = schemas[target_schema]
        collection_class = schemas[collection_name]

        target_messages = [
            message_class(
                role=MessageConverter._from_universal_role(target_schema, message.role),
                message=message.content,
            )
            for message in universal_messages.messages
        ]
        return collection_class(
            has_system_message=universal_messages.has_system_message,
            messages=target_messages,
        )

    @staticmethod
    def process_messages(messages: BaseModel, target_schema: Optional[str] = None) -> BaseModel:
        """Normalize ``messages`` to the universal schema and optionally convert to a target.

        Args:
            messages: Either a ``UniversalMessages`` or a registered provider
                collection.
            target_schema: Registry key to convert to; when falsy, the
                universal representation is returned.

        Returns:
            The converted collection, or the universal messages.

        Raises:
            ValueError: For unknown schemas, or (chained from the pydantic
                error) when any message fails validation.
        """
        try:
            if isinstance(messages, UniversalMessages):
                universal_messages = messages
            else:
                universal_messages = MessageConverter.convert_to_universal(messages)
            if target_schema:
                return MessageConverter.convert_to_target(universal_messages, target_schema)
            return universal_messages
        except ValidationError as e:
            # Chain the original error so the field-level details stay in the traceback.
            raise ValueError(f"Invalid message format: {str(e)}") from e