<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/agent_design_patterns_demo_magistral.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install crewai and langchain-openai
!pip install crewai langchain-openai -q
!pip install pydantic  -q # Ensure pydantic is installed
!pip install 'crewai[tools]' -q

In [None]:
!pip install colab-env --quiet
!pip install mistralai --quiet

In [3]:
import os
from crewai import Agent, Task, Crew, Process
from crewai.tools import BaseTool # Import BaseTool
from pydantic import BaseModel, Field # Import BaseModel and Field for schema definition
from litellm import litellm # Import litellm directly
from google.colab import userdata # For Google Colab environment

In [None]:
from google.colab import userdata
from litellm import completion
import os
import openai
import colab_env

from mistralai import Mistral
# Read the key from the process environment. A missing variable raises
# KeyError right here instead of failing later inside the API client.
api_key = os.environ["MISTRAL_API_KEY"]

# One client is enough: it serves both the model listing and the chat call.
client = Mistral(api_key=api_key)
model_list = client.models.list()

# To inspect the catalogue, print entry.id / entry.created / entry.owned_by
# for each entry in model_list.data.

# Select the model by id rather than by position: the order and length of the
# listing are decided by the API, so a fixed index can resolve to a different
# model (or raise IndexError) whenever the catalogue changes.
preferred_model_id = "magistral-medium-latest"
fallback_model_id = "mistral-large-latest"
available_model_ids = {entry.id for entry in (model_list.data or [])}

if preferred_model_id in available_model_ids:
    model = preferred_model_id
else:
    print(f"Warning: '{preferred_model_id}' is not listed. Falling back to '{fallback_model_id}'.")
    model = fallback_model_id

print(f"Model name: {model}")

# Minimal round trip to confirm that the key and the selected model work.
messages = [{"role": "user", "content": "What is the best Canadian poet?"}]

chat_response = client.chat.complete(
    model=model,
    messages=messages,
)
print(chat_response.choices[0].message.content)

In [None]:
# Install necessary libraries
!pip install crewai langchain-openai pydantic litellm mistralai -q
!pip install 'crewai [tools]' -q # For SerperDevTool and WebsiteReadTool

In [None]:
!pip install crewai langchain-openai pydantic mistralai -q

In [30]:
import os
import warnings
from crewai import Agent, Task, Crew, Process
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
# REMOVED: from litellm import litellm # <--- THIS LINE IS REMOVED
from google.colab import userdata

# Import Mistral AI specific components
from mistralai import Mistral # As per your explicit request to use 'Mistral' client

# Langchain components for custom LLM integration
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatResult, Generation
from typing import Any, List, Dict, Optional, Type

# Silence DeprecationWarning noise coming from these third-party modules;
# it does not affect the logic of this notebook.
for _noisy_module in ("httpx", "ipywidgets"):
    warnings.filterwarnings("ignore", category=DeprecationWarning, module=_noisy_module)

# 1. Custom MistralChatModel for CrewAI
class MistralChatModel(BaseChatModel):
    """
    A custom ChatModel for Langchain (and thus CrewAI) that uses Mistral AI's client (Mistral class).

    Each call converts the Langchain messages into role/content dictionaries,
    sends them through ``client.chat.complete`` and wraps the first returned
    choice in a ``ChatResult``.
    """
    # Model id sent to the Mistral API.
    model_name: str
    # Key used to build the Mistral client; masked in _get_parameters().
    api_key: str
    # Sampling temperature forwarded on every request.
    temperature: float = 0.7
    # Optional completion-length cap forwarded on every request.
    max_tokens: Optional[int] = None
    # Created in __init__; Optional because the declared default is None.
    client: Optional[Mistral] = None

    def __init__(self, model_name: str, api_key: str, temperature: float = 0.7, max_tokens: Optional[int] = None, **kwargs):
        """Validate the fields through the base-class constructor and create the Mistral client."""
        # The base constructor already stores the declared fields, so they
        # are not assigned a second time here.
        super().__init__(model_name=model_name, api_key=api_key, temperature=temperature, max_tokens=max_tokens, **kwargs)
        self.client = Mistral(api_key=api_key)

    def _generate(
        self,
        messages: List[Any],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Send ``messages`` to Mistral and return the first choice.

        Args:
            messages: Langchain ``HumanMessage`` / ``AIMessage`` / ``SystemMessage`` objects.
            stop: Optional stop sequences, forwarded to the API unchanged.
            **kwargs: Extra keyword arguments forwarded to ``client.chat.complete``.

        Returns:
            A ``ChatResult`` holding a single ``ChatGeneration``.

        Raises:
            ValueError: If a message is of any other type.
            Exception: Whatever the Mistral call raises is logged and re-raised.
        """
        # Imported locally so this class does not depend on an edit to the
        # file-level import block.
        from langchain_core.outputs import ChatGeneration

        mistral_messages = []
        for message in messages:
            if isinstance(message, HumanMessage):
                mistral_messages.append({"role": "user", "content": message.content})
            elif isinstance(message, AIMessage):
                mistral_messages.append({"role": "assistant", "content": message.content})
            elif isinstance(message, SystemMessage):
                mistral_messages.append({"role": "system", "content": message.content})
            else:
                raise ValueError(f"Unsupported message type: {type(message)}")

        try:
            chat_response = self.client.chat.complete(
                model=self.model_name,
                messages=mistral_messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stop=stop,
                **kwargs
            )
            response_content = chat_response.choices[0].message.content
            ai_message = AIMessage(content=response_content)
            # A chat model must return ChatGeneration objects: the plain
            # Generation class has no `message` field, so the reply message
            # would not be carried in the result.
            generation = ChatGeneration(message=ai_message)
            return ChatResult(generations=[generation])
        except Exception as e:
            print(f"Mistral AI call failed: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this custom LLM integration."""
        return "mistral_native_chat_model_v2"

    def _get_parameters(self) -> Dict[str, Any]:
        """Return the model settings with the API key masked."""
        return {
            "model_name": self.model_name,
            "api_key": "SKIPPED_FOR_SECURITY",
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

    def _stream(self, messages: List[Any], stop: Optional[List[str]] = None, **kwargs: Any) -> Any:
        """Streaming is not implemented by this wrapper; always raises NotImplementedError."""
        raise NotImplementedError("MistralChatModel does not support streaming yet in this wrapper.")


# 2. Configuration for Mistral LLM
# Resolve the API key: Colab secrets first, then the process environment.
# userdata.get() raises instead of returning None when the secret does not
# exist or the notebook has not been granted access, so the lookup is guarded
# to keep the environment-variable fallback reachable.
try:
    mistral_api_key = userdata.get('MISTRAL_API_KEY')
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    mistral_api_key = None

if not mistral_api_key:
    mistral_api_key = os.environ.get("MISTRAL_API_KEY")

if not mistral_api_key:
    print("MISTRAL_API_KEY not found in Colab userdata or environment variables.")
    print("Please set your MISTRAL_API_KEY in Colab secrets or as an environment variable.")
    # Placeholder only: the API call below is expected to fail until a real key is supplied.
    mistral_api_key = "YOUR_MISTRAL_API_KEY_HERE"

# Initialize Mistral client (Mistral class) to get model list
mistral_client_temp = Mistral(api_key=mistral_api_key)
model_list = mistral_client_temp.models.list()

# Pick the model by id rather than by position: the order and length of the
# listing are decided by the API, so a fixed index can resolve to a different
# model whenever the catalogue changes.
preferred_model_id = "magistral-medium-latest"
fallback_model_id = "mistral-large-latest"
available_model_ids = {entry.id for entry in (model_list.data or [])}

if preferred_model_id in available_model_ids:
    mistral_model_name = preferred_model_id
else:
    print(f"Warning: Model '{preferred_model_id}' not found. Falling back to '{fallback_model_id}'.")
    mistral_model_name = fallback_model_id

print(f"Using Mistral Model: {mistral_model_name}")

# Alternative wrapper: route Mistral calls through LiteLLM instead of the
# native Mistral client (use this if the native MistralChatModel keeps failing).
# LiteLLM is imported here because its import was removed from the top of this cell.
from litellm import litellm

class MistralViaLiteLLMChatModel(BaseChatModel):
    """
    A custom ChatModel for Langchain (and thus CrewAI) that uses LiteLLM to call Mistral.

    Each call converts the Langchain messages into role/content dictionaries,
    sends them through ``litellm.completion`` and wraps the first returned
    choice in a ``ChatResult``.
    """
    # Provider-prefixed model id, e.g. "mistral/magistral-medium-latest".
    model_name: str
    # Key passed to LiteLLM on every call; masked in _get_parameters().
    api_key: str
    # Sampling temperature forwarded on every request.
    temperature: float = 0.7
    # Optional completion-length cap forwarded on every request.
    max_tokens: Optional[int] = None

    def __init__(self, model_name: str, api_key: str, temperature: float = 0.7, max_tokens: Optional[int] = None, **kwargs):
        """Validate and store the fields through the base-class constructor."""
        # The base constructor already stores the declared fields, so they
        # are not assigned a second time here.
        super().__init__(model_name=model_name, api_key=api_key, temperature=temperature, max_tokens=max_tokens, **kwargs)

    def _generate(
        self,
        messages: List[Any],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Send ``messages`` through LiteLLM and return the first choice.

        Args:
            messages: Langchain ``HumanMessage`` / ``AIMessage`` / ``SystemMessage`` objects.
            stop: Optional stop sequences, forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded to ``litellm.completion``.

        Returns:
            A ``ChatResult`` holding a single ``ChatGeneration``.

        Raises:
            ValueError: If a message is of any other type.
            Exception: Whatever the LiteLLM call raises is logged and re-raised.
        """
        # Imported locally so this class does not depend on an edit to the
        # file-level import block.
        from langchain_core.outputs import ChatGeneration

        litellm_messages = []
        for message in messages:
            if isinstance(message, HumanMessage):
                litellm_messages.append({"role": "user", "content": message.content})
            elif isinstance(message, AIMessage):
                litellm_messages.append({"role": "assistant", "content": message.content})
            elif isinstance(message, SystemMessage):
                litellm_messages.append({"role": "system", "content": message.content})
            else:
                raise ValueError(f"Unsupported message type: {type(message)}")

        try:
            response = litellm.completion(
                model=self.model_name,  # expected in "mistral/MODEL_NAME" form
                messages=litellm_messages,
                api_key=self.api_key,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stop=stop,
                **kwargs
            )
            response_content = response.choices[0].message.content
            ai_message = AIMessage(content=response_content)
            # A chat model must return ChatGeneration objects: the plain
            # Generation class has no `message` field, so the reply message
            # would not be carried in the result.
            generation = ChatGeneration(message=ai_message)
            return ChatResult(generations=[generation])
        except Exception as e:
            print(f"LiteLLM call failed: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this custom LLM integration."""
        return "mistral_via_litellm_chat_model"

    def _get_parameters(self) -> Dict[str, Any]:
        """Return the model settings with the API key masked."""
        return {
            "model_name": self.model_name,
            "api_key": "SKIPPED_FOR_SECURITY",
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

    def _stream(self, messages: List[Any], stop: Optional[List[str]] = None, **kwargs: Any) -> Any:
        """Streaming is not implemented by this wrapper; always raises NotImplementedError."""
        raise NotImplementedError("MistralViaLiteLLMChatModel does not support streaming yet.")


# And then configure it like this:
# LLM instance shared by every agent below. The model id carries the
# "mistral/" provider prefix that the LiteLLM-based wrapper expects.
my_mistral_llm = MistralViaLiteLLMChatModel(
    api_key=mistral_api_key,
    model_name="mistral/" + mistral_model_name,
    max_tokens=4096,
    temperature=0.5,
)

# 3. Define Pydantic Models for Tool Inputs
class FlightInput(BaseModel):
    """Input for FlightAPIQueryTool."""
    # The three text fields are required; only the passenger count has a default.
    origin: str = Field(..., description="The departure city or airport code.")
    destination: str = Field(..., description="The arrival city or airport code.")
    trip_dates: str = Field(..., description="The dates of the trip, e.g., 'next month', 'July 15-22'.")
    passengers: int = Field(default=1, description="Number of passengers.")

class HotelInput(BaseModel):
    """Input for HotelAPIQueryTool."""
    # Location and both dates are required; preferences may be omitted (None).
    location: str = Field(..., description="The city or area for the hotel search.")
    check_in_date: str = Field(..., description="The check-in date, e.g., 'July 15'.")
    check_out_date: str = Field(..., description="The check-out date, e.g., 'July 22'.")
    preferences: Optional[str] = Field(default=None, description="Any specific preferences like 'luxury', 'budget'.")

# 4. Refactor Custom Tools using BaseTool and Pydantic
class FlightAPIQueryTool(BaseTool):
    # Simulated flight-booking lookup: it returns canned text for two known
    # routes and a generic message for everything else. No network call is made.
    name: str = "Flight API Query Tool"
    description: str = (
        """Queries a simulated flight booking API for flight details.
        Use this tool when the user asks for flight information including origin, destination, and dates."""
    )
    args_schema: Type[BaseModel] = FlightInput

    def _run(self, origin: str, destination: str, trip_dates: str, passengers: int = 1) -> str:
        """
        Return canned flight information for the requested route.

        The input schema allows either a city name or an airport code, so
        both forms are recognised, case-insensitively. ``trip_dates`` and
        ``passengers`` are only echoed in the log line.

        Args:
            origin: Departure city or airport code.
            destination: Arrival city or airport code.
            trip_dates: Free-text trip dates.
            passengers: Number of passengers.

        Returns:
            A text summary for a known route, otherwise a "no direct flights" message.
        """
        print(f"\n--- Using {self.name} with origin='{origin}', destination='{destination}', dates='{trip_dates}', passengers='{passengers}'")

        # Kept as locals: un-annotated class attributes are not accepted on
        # pydantic-based tool classes.
        london = ("london", "lhr")
        new_york = ("new york", "jfk")
        montreal = ("montreal", "montréal", "yul")
        paris = ("paris", "cdg")

        origin_key = origin.casefold()
        destination_key = destination.casefold()

        def mentions(text: str, aliases: tuple) -> bool:
            # True when any alias (city name or airport code) occurs in text.
            return any(alias in text for alias in aliases)

        if mentions(origin_key, london) and mentions(destination_key, new_york):
            return f"Found direct flights: BA123 ({origin}-{destination}), AA456 ({origin}-{destination}). Prices are around $800-1200."
        elif mentions(origin_key, montreal) and mentions(destination_key, paris):
            return f"Found flights with layovers: AC870 ({origin}-{destination} via Toronto). Prices are around $900-1500 per person. Example: Air Canada direct from YUL to CDG on July 15, 2025 at 20:55, arriving July 16, 2025 at 09:45+1. Return on July 22, 2025 from CDG at 11:30, arriving YUL at 13:35. Duration 7h 50m (outbound) / 7h 05m (return). Baggage: 1 checked bag included."
        else:
            return "No direct flights found for the specified route. Suggesting alternatives."

class HotelAPIQueryTool(BaseTool):
    # Simulated hotel-booking lookup: it returns canned text for two known
    # location/preference pairs and a generic message otherwise. No network call is made.
    name: str = "Hotel API Query Tool"
    description: str = (
        """Queries a simulated hotel booking API for hotel options.
        Use this tool when the user asks for hotel information including location, check-in/out dates, and preferences."""
    )
    args_schema: Type[BaseModel] = HotelInput

    def _run(self, location: str, check_in_date: str, check_out_date: str, preferences: Optional[str] = None) -> str:
        """
        Return canned hotel suggestions for the requested stay.

        Args:
            location: City or area to search.
            check_in_date: Free-text check-in date, echoed in the reply.
            check_out_date: Free-text check-out date, echoed in the reply.
            preferences: Optional preference text such as 'luxury' or 'mid-range'.

        Returns:
            A text summary for a known combination, otherwise "No hotels matching criteria.".
        """
        print(f"\n--- Using {self.name} with location='{location}', check-in='{check_in_date}', check-out='{check_out_date}', preferences='{preferences}'")

        # preferences defaults to None; a substring test on None raises
        # TypeError, so it is normalised to an empty string first. Matching
        # is case-insensitive so values like "Mid-range" are also recognised.
        place = location.casefold()
        wanted = (preferences or "").casefold()

        if "new york" in place and "luxury" in wanted:
            return f"Luxury hotels in NYC for {check_in_date}-{check_out_date}: The Plaza, Mandarin Oriental. Rates starting at $600/night."
        elif "paris" in place and "mid-range" in wanted:
            return f"Mid-range hotels in Paris for {check_in_date}-{check_out_date}: Ibis Styles Paris Gare de l'Est Château Landon (approx. €210/night). Features: 5-min walk to Gare de l'Est metro/RER, breakfast included, 24-hour reception, direct RER line to CDG (25 mins)."
        else:
            return "No hotels matching criteria."

# Shared instances of the two simulated booking tools; the agent and task
# factories in this file reference these module-level names.
flight_api_tool, hotel_api_tool = FlightAPIQueryTool(), HotelAPIQueryTool()

# 5. Web search tools (SerperDevTool, WebsiteReadTool) are intentionally not
# created, so no SERPER_API_KEY is needed. The empty list is kept so code that
# appends it to a tools list still works.
general_search_tools = []

print("General web search tools (SerperDevTool, WebsiteReadTool) are NOT initialized as per user request.")

# 6. Define Agents
class FlightPlanningAgents:
    """Factory for the four CrewAI agents used by the travel-planning demos."""

    def __init__(self, llm_model: BaseChatModel):
        """Remember the LLM that every agent created by this factory will use."""
        self.llm = llm_model

    def _make_agent(self, *, role, goal, backstory, can_delegate=False, tools=None):
        """Build a non-verbose Agent bound to the shared LLM; tools are passed only when given."""
        options = {
            "role": role,
            "goal": goal,
            "backstory": backstory,
            "verbose": False,
            "allow_delegation": can_delegate,
            "llm": self.llm,
        }
        if tools is not None:
            options["tools"] = tools
        return Agent(**options)

    def flight_researcher(self):
        """Create the flight research agent, equipped with the flight API tool."""
        return self._make_agent(
            role="Flight Information Researcher",
            goal="Find the best flight options based on user query, including cheapest, fastest, and most convenient routes.",
            backstory="You are an expert travel agent specializing in flight research. You have access to flight data.",
            tools=[flight_api_tool],
        )

    def hotel_researcher(self):
        """Create the hotel research agent, equipped with the hotel API tool."""
        return self._make_agent(
            role="Hotel and Accommodation Researcher",
            goal="Identify suitable hotel options that align with user preferences (e.g., budget, luxury, family-friendly) and provide key details.",
            backstory="You are a seasoned hospitality expert with an encyclopedic knowledge of hotels worldwide. You can find the perfect stay for any traveler.",
            tools=[hotel_api_tool],
        )

    def itinerary_builder(self):
        """Create the itinerary agent; it has no tools and is the only agent allowed to delegate."""
        return self._make_agent(
            role="Personalized Itinerary Creator",
            goal="Synthesize flight and hotel information into a coherent, personalized travel itinerary, adding popular attractions and practical tips.",
            backstory="You are a master of travel logistics, capable of weaving together disparate trip components into a seamless and enjoyable journey.",
            can_delegate=True,
        )

    def travel_router(self):
        """Create the routing agent that classifies a travel query; it has no tools."""
        return self._make_agent(
            role="Travel Query Router",
            goal="Determine the primary intent of a user's travel query and route it to the appropriate specialized agent (flight, hotel, or itinerary).",
            backstory="You are the first point of contact for all travel inquiries. Your sharp analytical skills allow you to quickly understand user needs and direct them to the best resource.",
        )

# 7. Define Tasks
class FlightPlanningTasks:
    """Factory for the CrewAI tasks used by the travel-planning demos."""

    def __init__(self):
        """Nothing to set up; the factory keeps no state."""

    def research_flight_options(self, agent: Agent, query: str):
        """Build the flight research task for ``query``, assigned to ``agent`` with the flight API tool."""
        # The continuation lines of the prompt keep their original indentation
        # because that whitespace is part of the text sent to the model.
        brief = f"""Thoroughly research flight options for the following request: '{query}'.
            Identify the origin, destination, and dates from the query to accurately use the flight API tool.
            Focus on finding direct flights, reasonable layovers if direct isn't possible,
            and provide a range of prices."""
        deliverable = "A concise summary of top 2-3 flight options including airlines, dates, times, and approximate costs."
        return Task(
            agent=agent,
            description=brief,
            expected_output=deliverable,
            tools=[flight_api_tool],
        )

    def research_hotel_options(self, agent: Agent, query: str):
        """Build the hotel research task for ``query``, assigned to ``agent`` with the hotel API tool."""
        brief = f"""Investigate hotel options relevant to: '{query}'.
            Extract the location, check-in/out dates, and any preferences from the query to
            identify top 2-3 hotel suggestions with approximate costs and key features."""
        deliverable = "A list of 2-3 suitable hotels with brief descriptions, estimated nightly rates, and key amenities."
        return Task(
            agent=agent,
            description=brief,
            expected_output=deliverable,
            tools=[hotel_api_tool],
        )

    def create_travel_itinerary(self, agent: Agent, original_query: str, context_tasks: List[Task]):
        """Build the itinerary task; ``context_tasks`` are passed as the task's context."""
        brief = f"""Compile a comprehensive travel itinerary based on the original query: '{original_query}'.
            Access the flight information from the context of previous tasks.
            Access the hotel information from the context of previous tasks.
            Synthesize this information. Structure the itinerary clearly, perhaps day-by-day, suggesting attractions and practical tips (e.g., transport, dining, safety)."""
        deliverable = "A well-structured, detailed travel itinerary in markdown format, combining flight details, hotel information, daily activities, and practical travel tips."
        return Task(
            agent=agent,
            description=brief,
            expected_output=deliverable,
            context=context_tasks,
        )

    def route_travel_query(self, agent: Agent, query: str):
        """Build the routing task that asks ``agent`` to classify ``query`` and extract its parameters."""
        brief = f"""Analyze the user's travel query: '{query}'.
            Determine if it's primarily a flight search, a hotel search, or a request for a full itinerary.
            Output the identified intent (e.g., 'flight_search', 'hotel_search', 'full_itinerary')
            and any key details extracted from the query, specifically,
            for flight search: 'origin', 'destination', 'trip_dates', 'passengers'.
            For hotel search: 'location', 'check_in_date', 'check_out_date', 'preferences'.
            For full itinerary: 'destination', 'duration', and whether flights/hotels are explicitly mentioned as needed."""
        deliverable = "A JSON string indicating the intent and extracted parameters, e.g., {'intent': 'flight_search', 'args': {'origin': 'New York', 'destination': 'London', 'trip_dates': 'next month', 'passengers': 2}}"
        return Task(
            agent=agent,
            description=brief,
            expected_output=deliverable,
        )

# 8. Orchestrate with CrewAI
class FlightPlanningCrew:
    """Assembles the demo crews from the agent and task factories."""

    def __init__(self, llm_model: BaseChatModel):
        """
        Create the agent and task factories.

        Args:
            llm_model: LLM given to every agent and used as the manager LLM
                of the router-based crew.
        """
        self.agents = FlightPlanningAgents(llm_model)
        self.tasks = FlightPlanningTasks()
        self.llm_model = llm_model

    def create_sequential_flight_planner_crew(
        self,
        flight_query: str = "I need a flight from Montreal (YUL) to Paris (CDG) for July 15-22, 2025 for 2 people.",
        hotel_query: str = "I need a mid-range hotel in Paris for July 15-22, 2025.",
        itinerary_query: str = "Montreal to Paris 7-day trip",
    ):
        """
        Build a crew that runs flight research, hotel research and itinerary building in order.

        The three queries default to the original Montreal-Paris demo, so a
        call without arguments behaves exactly as before.

        Args:
            flight_query: Request text for the flight research task.
            hotel_query: Request text for the hotel research task.
            itinerary_query: Original trip request given to the itinerary task.

        Returns:
            A ``Crew`` configured with ``Process.sequential``.
        """
        flight_researcher = self.agents.flight_researcher()
        hotel_researcher = self.agents.hotel_researcher()
        itinerary_builder = self.agents.itinerary_builder()

        task_flight_research = self.tasks.research_flight_options(
            agent=flight_researcher,
            query=flight_query
        )

        task_hotel_research = self.tasks.research_hotel_options(
            agent=hotel_researcher,
            query=hotel_query
        )

        # The itinerary task receives both research tasks as its context.
        task_build_itinerary = self.tasks.create_travel_itinerary(
            agent=itinerary_builder,
            original_query=itinerary_query,
            context_tasks=[task_flight_research, task_hotel_research]
        )

        crew = Crew(
            agents=[flight_researcher, hotel_researcher, itinerary_builder],
            tasks=[task_flight_research, task_hotel_research, task_build_itinerary],
            process=Process.sequential,
            verbose=False
        )
        return crew

    def create_router_based_flight_planner_crew(self, user_query: str):
        """
        Build a hierarchical crew whose only task is to route ``user_query``.

        Args:
            user_query: The raw travel request to classify.

        Returns:
            A ``Crew`` configured with ``Process.hierarchical`` and this
            object's LLM as ``manager_llm``.
        """
        travel_router = self.agents.travel_router()
        flight_researcher = self.agents.flight_researcher()
        hotel_researcher = self.agents.hotel_researcher()
        itinerary_builder = self.agents.itinerary_builder()

        task_route_query = self.tasks.route_travel_query(
            agent=travel_router,
            query=user_query
        )

        # NOTE(review): only the routing task is registered; the specialist
        # agents are listed but have no tasks of their own. Presumably the
        # manager is meant to delegate to them; confirm this is intended.
        crew = Crew(
            agents=[travel_router, flight_researcher, hotel_researcher, itinerary_builder],
            tasks=[task_route_query],
            process=Process.hierarchical,
            manager_llm=self.llm_model,
            verbose=False
        )
        return crew

# 9. Run the Crews
# Demo driver: builds both crews with the shared LLM and runs them one after
# the other, printing each result or the error that stopped it.
if __name__ == "__main__":
    print("Welcome to your Mistral-powered AI agent for flight planning!")

    # One builder supplies both demo crews, so they share the same LLM.
    crew_builder = FlightPlanningCrew(my_mistral_llm)

    # --- Demo 1: sequential crew (flight research, hotel research, itinerary)
    print("\n" + "="*28 + " Demo 1: Sequential Flight Planning " + "="*20)
    sequential_crew = crew_builder.create_sequential_flight_planner_crew()

    print("\nStarting Sequential Flight Planning Crew...")
    # Any failure is reported and swallowed so that Demo 2 still runs.
    try:
        result_sequential = sequential_crew.kickoff()
        print("\n## Sequential Flight Planning Result:")
        print(result_sequential)
    except Exception as e:
        print(f"\nError running sequential crew: {e}")

    print("\n" + "="*60 + "\n")

    # --- Demo 2: hierarchical crew that routes a single flight query
    print("\n" + "="*28 + " Demo 2: Router-based Flight Planning " + "="*20)
    flight_query = "Find me a cheap flight from London Heathrow (LHR) to New York (JFK) for next month for 1 person."
    router_crew_flight = crew_builder.create_router_based_flight_planner_crew(flight_query)

    print(f"\nStarting Router-based Flight Planning Crew for: '{flight_query}'")
    # Same reporting pattern as Demo 1: print the error instead of raising.
    try:
        result_router_flight = router_crew_flight.kickoff()
        print("\n## Router-based Flight Planning Result (Flight Query):")
        print(result_router_flight)
    except Exception as e:
        print(f"\nError running router crew (flight query): {e}")

Using Mistral Model: magistral-medium-latest
General web search tools (SerperDevTool, WebsiteReadTool) are NOT initialized as per user request.
Welcome to your Mistral-powered AI agent for flight planning!


Starting Sequential Flight Planning Crew...

--- Using Flight API Query Tool with origin='YUL', destination='CDG', dates='July 15-22, 2025', passengers='2'

--- Using Flight API Query Tool with origin='YUL', destination='CDG', dates='July 14-21, 2025', passengers='2'

--- Using Flight API Query Tool with origin='YUL', destination='CDG', dates='July 16-23, 2025', passengers='2'

--- Using Hotel API Query Tool with location='Paris', check-in='July 15, 2025', check-out='July 22, 2025', preferences='mid-range'

--- Using Flight API Query Tool with origin='YUL', destination='CDG', dates='July 15-22, 2025', passengers='1'

--- Using Hotel API Query Tool with location='Paris', check-in='July 15', check-out='July 22', preferences='mid-range'

## Sequential Flight Planning Result:
```markdo