<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/FULL_FlightDM_OPENAI_AGENT_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

OpenAI Python SDK: https://github.com/openai/openai-python

OpenAI Agents SDK documentation: https://openai.github.io/openai-agents-python/



In [None]:
!pip install openai -q
!pip install colab_env -q
!pip install openai-agents -q

In [2]:
import os

# NOTE(review): imported for its side effects — presumably loads Colab secrets /
# environment variables such as OPENAI_API_KEY; confirm against colab_env docs.
import colab_env
import openai

# Fails fast with KeyError if the key is not configured.
openai.api_key = os.environ['OPENAI_API_KEY']

## CrewAgent: crew scheduling and reassignment


In [3]:
import sqlite3
import datetime
from agents import Agent, Runner  # Corrected import - already present

In [4]:
class CrewAgent(Agent):
    """Agent that looks up, assesses and reassigns flight crews during disruptions.

    Crew data lives in the ``CAP`` (captain) and ``FO`` (first officer) columns of
    the SQLite ``Flights`` table stored in ``db_name``. Every method opens its own
    short-lived connection and always closes it, even when a query fails. All SQL
    is parameterized; no caller-supplied value is interpolated into SQL text.
    """

    def __init__(
        self,
        name="CrewAgent",
        instructions="Manage crew schedules and reassignments during flight disruptions.",
        db_name="flight_data.db",
    ):
        super().__init__(name=name, instructions=instructions)
        self.db_name = db_name

    def get_crew_schedule(self, flight_id):
        """Return the current ``(CAP, FO)`` crew of a flight.

        Args:
            flight_id: Aircraft tail number (``AC_Tail``). When the tail flies
                several legs, the first matching row is used.

        Returns:
            A ``(CAP, FO)`` tuple, or ``None`` if the tail is unknown or a
            database error occurred (the error is printed).
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_name)
            return conn.execute(
                "SELECT CAP, FO FROM Flights WHERE AC_Tail = ?", (flight_id,)
            ).fetchone()
        except sqlite3.Error as e:
            print(f"Database error: {e}")
            return None
        finally:
            if conn is not None:
                conn.close()

    def assess_crew_availability(self, disruption_time, affected_airport, window=300):
        """Return crew pairs with no duty inside a window around the disruption.

        Times are HHMM integers (e.g. ``1700``), so the default ``window=300``
        means 3 hours either side of ``disruption_time``. The window does not
        wrap around midnight.

        A ``(CAP, FO)`` pair is returned only if neither its captain (as CAP)
        nor its first officer (as FO) is rostered on any leg that departs or
        arrives inside the window.

        Args:
            disruption_time: Disruption time as an HHMM integer.
            affected_airport: Currently unused; kept for interface compatibility.
                TODO: restrict candidates to crew positioned at this airport.
            window: Half-width of the busy window, in HHMM units.

        Returns:
            A list of ``(CAP, FO)`` tuples, or ``None`` on a database error
            (the error is printed).
        """
        bounds = {"start": disruption_time - window, "end": disruption_time + window}
        query = """
            SELECT DISTINCT f.CAP, f.FO
            FROM Flights AS f
            WHERE NOT EXISTS (
                SELECT 1
                FROM Flights AS busy
                WHERE (busy.Departure_Time BETWEEN :start AND :end
                       OR busy.Arrival_Time BETWEEN :start AND :end)
                  AND (busy.CAP = f.CAP OR busy.FO = f.FO)
            )
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_name)
            return conn.execute(query, bounds).fetchall()
        except sqlite3.Error as e:
            print(f"Database error: {e}")
            return None
        finally:
            if conn is not None:
                conn.close()

    def suggest_crew_reassignment(self, flight_id, disruption_time, affected_airport):
        """Suggest a replacement ``(CAP, FO)`` crew for a disrupted flight.

        Simple policy: the first available crew pair that differs from the
        crew currently assigned to the flight.

        Returns:
            A ``(CAP, FO)`` tuple, or ``None`` when the flight is unknown, no
            other crew is available, or a database error occurred.
        """
        current_crew = self.get_crew_schedule(flight_id)
        if current_crew is None:
            return None  # Unknown flight (or DB error): nothing to reassign.

        available_crew = self.assess_crew_availability(disruption_time, affected_airport) or []
        for candidate in available_crew:
            # Re-"assigning" the crew that already flies this aircraft is not a reassignment.
            if tuple(candidate) != tuple(current_crew):
                return candidate
        return None

    def execute_crew_reassignment(self, flight_id, new_crew, origin=None, destination=None):
        """Write a new ``(CAP, FO)`` crew onto a flight in the ``Flights`` table.

        Rows are matched by ``AC_Tail`` (the same key ``get_crew_schedule``
        uses). Pass ``origin`` and/or ``destination`` to limit the update to a
        single leg; by default every leg of the tail is updated.

        Args:
            flight_id: Aircraft tail number (``AC_Tail``).
            new_crew: A ``(CAP, FO)`` pair, e.g. from ``suggest_crew_reassignment``.
            origin: Optional ``Origin`` filter.
            destination: Optional ``Destination`` filter.

        Returns:
            The number of rows updated, or ``0`` on a database error
            (the error is printed).
        """
        cap, fo = new_crew
        sql = "UPDATE Flights SET CAP = ?, FO = ? WHERE AC_Tail = ?"
        params = [cap, fo, flight_id]
        if origin is not None:
            sql += " AND Origin = ?"
            params.append(origin)
        if destination is not None:
            sql += " AND Destination = ?"
            params.append(destination)

        conn = None
        try:
            conn = sqlite3.connect(self.db_name)
            updated = conn.execute(sql, params).rowcount
            conn.commit()
        except sqlite3.Error as e:
            print(f"Database error: {e}")
            return 0
        finally:
            if conn is not None:
                conn.close()

        print(f"  Crew for flight {flight_id} reassigned to CAP {cap}, FO {fo} ({updated} row(s) updated).")
        return updated

## PassengerAgent: passenger notification and rebooking


In [5]:
class PassengerAgent(Agent):
    """Agent that notifies and rebooks passengers of disrupted flights.

    Passenger counts are read from the ``Number_of_Passengers`` column of the
    SQLite ``Flights`` table in ``db_name``; each lookup opens its own
    connection and always closes it.
    """

    def __init__(
        self,
        name="PassengerAgent",
        instructions="Manage passenger communication, rebooking, and satisfaction during flight disruptions.",
        db_name="flight_data.db",
    ):
        super().__init__(name=name, instructions=instructions)
        self.db_name = db_name

    def get_passenger_manifest(self, flight_id):
        """Return the passenger count for a flight.

        Args:
            flight_id: Aircraft tail number (``AC_Tail``). When the tail flies
                several legs, the first matching row is used.

        Returns:
            The ``Number_of_Passengers`` value, or ``0`` if no row matches.
        """
        conn = sqlite3.connect(self.db_name)
        try:
            # Parameterized: flight_id is never spliced into the SQL text.
            row = conn.execute(
                "SELECT Number_of_Passengers FROM Flights WHERE AC_Tail = ?",
                (flight_id,),
            ).fetchone()
        finally:
            conn.close()  # Closed even if the query raises.
        return row[0] if row is not None else 0

    def notify_passengers(self, flight_id, message):
        """Print a disruption notice addressed to every passenger of the flight."""
        passenger_count = self.get_passenger_manifest(flight_id)
        print(f"Notifying {passenger_count} passengers of flight {flight_id}: {message}")

    def suggest_rebooking_options(self, flight_id, disruption_time, affected_airport):
        """Return alternative flights for rebooking (placeholder: fixed list)."""
        # [PLACEHOLDER LOGIC] arguments are not used yet.
        return ["Alternative Flight 1", "Alternative Flight 2"]

    def execute_rebooking(self, passenger_id, new_flight):
        """Rebook one passenger onto ``new_flight`` (placeholder: prints only)."""
        print(f"Passenger {passenger_id} rebooked to {new_flight}")

## FleetAgent: engine integration and recovery strategy

In [6]:
class FleetAgent(Agent):
    """Top-level recovery agent for flight disruptions.

    Reads the SQLite ``Flights`` table in ``db_name``, counts and costs delayed
    flights, and coordinates a ``CrewAgent`` and a ``PassengerAgent`` (both
    pointed at the same database) to recover each disrupted flight. Every
    database operation opens its own connection and always closes it.
    """

    def __init__(
        self,
        name="FleetAgent",
        instructions="Manage and recover from flight disruptions by coordinating with crew and passenger agents.",
        db_name="flight_data.db",
    ):
        super().__init__(name=name, instructions=instructions)
        self.db_name = db_name
        # Sub-agents share the same database file.
        self.crew_agent = CrewAgent(db_name=self.db_name)
        self.passenger_agent = PassengerAgent(db_name=self.db_name)

    def strategize_recovery(self, impact, disrupted_flights, disruption_time, affected_airport):
        """Run crew, passenger and retiming recovery for every disrupted flight.

        Args:
            impact: Dict from ``assess_impact``; only ``"delayed_count"`` is read.
            disrupted_flights: Rows from ``assess_impact`` (``AC_Tail`` first).
            disruption_time: Disruption start as an HHMM integer.
            affected_airport: Airport code passed through to the sub-agents.
        """
        if impact["delayed_count"] > 0:
            print("\n Strategy: Comprehensive Disruption Recovery")
            for flight in disrupted_flights:
                flight_id = flight[0]  # AC_Tail is used as the flight ID throughout.
                print(f"\n-- Handling flight {flight_id} --")

                # 1. Crew reassignment.
                new_crew = self.crew_agent.suggest_crew_reassignment(flight_id, disruption_time, affected_airport)
                if new_crew:
                    self.crew_agent.execute_crew_reassignment(flight_id, new_crew)
                else:
                    print("  No suitable crew available for reassignment.")

                # 2. Passenger communication and rebooking.
                self.passenger_agent.notify_passengers(flight_id, "Your flight has been delayed. We are working on recovery options.")
                rebooking_options = self.passenger_agent.suggest_rebooking_options(flight_id, disruption_time, affected_airport)
                if rebooking_options:
                    print(f"  Rebooking options: {rebooking_options}")
                    # [PLACEHOLDER LOGIC] automatic rebooking is not implemented yet.
                    print("  Passengers are being rebooked.")
                else:
                    print("  No rebooking options available at this time.")

                # 3. Retime the flight after crew/passenger actions.
                self.retime_flight(flight, disruption_time)

        else:
            print("\n Strategy: No immediate action needed.")

    def retime_flight(self, flight, disruption_time, delay_duration=100):
        """Push a delayed flight's departure past the disruption and mark it 'Rescheduled'.

        All times are HHMM integers, so the default ``delay_duration=100`` is
        one hour. The new departure is ``max(departure, disruption_time) +
        delay_duration``, computed in real minutes and wrapped past midnight.

        Args:
            flight: Row whose first four items are ``(AC_Tail, Origin,
                Destination, Departure_Time)``, as returned by ``analyze_state``.
            disruption_time: Disruption time as an HHMM integer.
            delay_duration: Extra delay as an HHMM duration (``100`` = 1 hour).
        """
        ac_tail, origin, destination, departure_time = flight[:4]

        base = max(departure_time, disruption_time)
        # Convert HHMM -> minutes so non-whole-hour delays and midnight roll-over are valid.
        total_minutes = (base // 100 * 60 + base % 100) + (delay_duration // 100 * 60 + delay_duration % 100)
        total_minutes %= 24 * 60
        new_departure_time = total_minutes // 60 * 100 + total_minutes % 60

        conn = sqlite3.connect(self.db_name)
        try:
            conn.execute(
                """
                UPDATE Flights
                SET Departure_Time = ?, Status = 'Rescheduled'
                WHERE AC_Tail = ? AND Origin = ? AND Destination = ?
                """,
                (new_departure_time, ac_tail, origin, destination),
            )
            conn.commit()
        finally:
            conn.close()
        print(f"  Flight {ac_tail} rescheduled to {new_departure_time}.")

    def analyze_state(self):
        """Return all flights as ``(AC_Tail, Origin, Destination, Departure_Time, Arrival_Time, Status)`` rows."""
        conn = sqlite3.connect(self.db_name)
        try:
            return conn.execute(
                "SELECT AC_Tail, Origin, Destination, Departure_Time, Arrival_Time, Status FROM Flights"
            ).fetchall()
        finally:
            conn.close()

    def assess_impact(self, flight_info):
        """Count and cost flights whose status is 'Delayed'.

        Args:
            flight_info: Rows shaped like ``analyze_state`` output (Status at index 5).

        Returns:
            ``(impact, disrupted_flights)`` where ``impact`` has keys
            ``"delayed_count"`` and ``"disruption_cost"``.
        """
        impact = {"delayed_count": 0, "disruption_cost": 0}
        disrupted_flights = []
        for flight in flight_info:
            if flight[5] == "Delayed":
                impact["delayed_count"] += 1
                impact["disruption_cost"] += self.estimate_delay_cost(flight)
                disrupted_flights.append(flight)
        return impact, disrupted_flights

    def estimate_delay_cost(self, flight):
        """Return the cost of one delayed flight (placeholder: a flat $1000)."""
        # [PLACEHOLDER LOGIC]
        delay_cost = 1000
        return delay_cost

## Data Management

In [7]:
import sqlite3
import datetime
from agents import Agent, Runner  # Corrected import - already present

#1. Flight Data
# Each row: (AC_Tail, AC_Type, Origin, Destination, Departure_Time, Arrival_Time,
#            Number_of_Passengers, CAP, FO) -- the column order used by the INSERT
# in create_and_populate_db. Times are HHMM integers; an arrival earlier than the
# departure means the flight lands after midnight.
flight_data = [
    ("N12345", "B737", "SYD", "MEL", 1535, 1700, 160, "219", "213"),
    ("N12345", "B737", "MEL", "ADL", 1755, 1900, 150, "264", "245"),
    ("N12345", "B737", "ADL", "SYD", 2120, 2230, 140, "275", "295"),
    ("C67890", "A320", "SYD", "MEL", 1740, 1845, 170, "363", "211"),
    ("A90123", "E190", "MEL", "SYD", 1720, 1830, 95, "263", "245"),
    ("A90123", "E190", "SYD", "00L", 1955, 2100, 85, "224", "280"),  # NOTE(review): "00L" (zeros) may be a transcription of "OOL" -- confirm
    ("B34567", "B787", "SYD", "MEL", 1530, 1705, 250, "221", "210"),
    ("B34567", "B787", "MEL", "PER", 1850, 2110, 240, "286", "300"),
    ("D89012", "A330", "SYD", "MEL", 1500, 1615, 300, "307", "186"),
    ("D89012", "A330", "MEL", "ADL", 1750, 1900, 280, "298", "255"),
    ("D89012", "A330", "ADL", "MEL", 1955, 2100, 260, "268", "170"),
    ("E23456", "DH8D", "HTI", "BNE", 1520, 1555, 70, "383", "157"),
    ("E23456", "DH8D", "BNE", "SYD", 1725, 1800, 65, "349", "180"),
    ("E23456", "DH8D", "SYD", "DRW", 1935, 40, 60, "402", "188"),  # arrival 40 = 00:40, after midnight
    ("F67890", "A320", "SYD", "MEL", 1735, 1830, 175, "220", "165"),
    ("G90123", "E190", "CBR", "MEL", 1620, 1725, 105, "327", "145"),
    ("G90123", "E190", "MEL", "SYD", 1800, 1925, 100, "328", "166"),
    ("G90123", "E190", "SYD", "CBR", 2000, 2100, 90, "329", "57"),
    ("H34567", "B737", "SYD", "MEL", 1500, 1625, 155, "218", "180"),
    ("H34567", "B737", "MEL", "ADL", 1730, 1905, 145, "225", "136"),
    ("H34567", "B737", "ADL", "PER", 1940, 2035, 135, "266", "141"),
    ("H34567", "B737", "PER", "ADL", 2115, 2315, 125, "282", "175"),
    ("189012", "A330", "PER", "MEL", 1230, 1800, 290, "286", "177"),  # NOTE(review): other tails are letter+5 digits; "189012" may be "I89012" -- confirm
    ("189012", "A330", "MEL", "BNE", 1910, 2120, 270, "348", "167"),
    ("J23456", "B737", "00L", "MEL", 1410, 1630, 165, "337", "180"),
    ("J23456", "B737", "MEL", "SYD", 1745, 1910, 155, "309", "160"),
    ("J23456", "B737", "SYD", "MTCE", 2000, 2135, 145, "229", "123")  # NOTE(review): "MTCE" is not an airport code; presumably a maintenance slot -- confirm
]

#2. Database Operations
def create_and_populate_db(db_name="flight_data.db", data=flight_data):
    """
    Creates an SQLite database and populates it with flight data.

    Args:
        db_name (str, optional): The name of the database file.
                              Defaults to "flight_data.db".
        data (list, optional): The flight data to insert into the database.
                              Defaults to the flight_data list defined above.
    """
    conn = sqlite3.connect(db_name)
    cursor = conn.cursor()

    # Drop the existing table if it exists to ensure a fresh start
    cursor.execute("DROP TABLE IF EXISTS Flights")

    # Create the table with a unique constraint
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS Flights (
            AC_Tail TEXT,
            AC_Type TEXT,
            Origin TEXT,
            Destination TEXT,
            Departure_Time INTEGER,
            Arrival_Time INTEGER,
            Number_of_Passengers INTEGER,
            CAP TEXT,
            FO TEXT,
            Flight_Type TEXT,
            Status TEXT,  -- Added for disruption management
            UNIQUE(AC_Tail, Origin, Destination)
        )
    """)
    print("Table created successfully")

    # Insert data using a loop and individual execute statements
    for row in data:
        try:
            # Calculate FTI
            departure_time = row[4]
            arrival_time = row[5]
            if arrival_time > departure_time:
                flight_time_interval = arrival_time - departure_time
            else:
                flight_time_interval = (arrival_time + 2400) - departure_time

            flight_type = "Long-haul" if flight_time_interval >= 300 else "Short-haul"

            # Initialize flight status as "Scheduled"
            status = "Scheduled"

            print(f"Inserting data: {row[0]}, {row[1]}, {row[2]}, {row[3]}")
            cursor.execute(f"""
                INSERT INTO Flights (AC_Tail, AC_Type, Origin, Destination,
                Departure_Time, Arrival_Time, Number_of_Passengers, CAP, FO, Flight_Type, Status)
                VALUES ('{row[0]}', '{row[1]}', '{row[2]}', '{row[3]}', {row[4]},
                {row[5]}, {row[6]}, '{row[7]}', '{row[8]}', '{flight_type}', '{status}')
            """)
        except sqlite3.IntegrityError:
            # Ignore duplicate rows
            pass

    conn.commit()
    conn.close()
    print(f"Database '{db_name}' created and data inserted successfully.")

#3. Query Execution
def execute_query(db_name="flight_data.db", query="SELECT * FROM Flights", params=()):
    """
    Executes an SQL query against the database, prints the results and returns them.

    Args:
        db_name (str, optional): The name of the database file.
                              Defaults to "flight_data.db".
        query (str, optional): The SQL query to execute.
                             Defaults to a simple SELECT query.
        params (sequence or dict, optional): Values bound to ``?`` / ``:name``
                             placeholders in ``query``. Defaults to none.

    Returns:
        list: The fetched rows as tuples.
    """
    conn = sqlite3.connect(db_name)
    try:
        print(f"Executing query: {query}")
        results = conn.execute(query, params).fetchall()
    finally:
        conn.close()  # Closed even if the query fails.

    print(f"Query results: {results}")  # Print the results

    print("\nQuery Results:")
    for row in results:
        print(row)

    return results

#4. Example Queries (Analytics)
def run_analytics(db_name="flight_data.db"):
    """
    Executes and prints the results of several analytical SQL queries.

    Each query is run through ``execute_query``, which prints it and its rows.

    Args:
        db_name (str, optional): The name of the database file.
                              Defaults to "flight_data.db".
    """
    print("\n--- Running Analytics ---")

    queries = {
        "Count of Flights by Aircraft Type":
            """
            SELECT AC_Type, COUNT(*) AS Flight_Count
            FROM Flights
            GROUP BY AC_Type;
            """,
        "Flights Originating from SYD":
            """
            SELECT * FROM Flights
            WHERE Origin = 'SYD';
            """,
        "Avg. Passengers by Aircraft Type":
            """
            SELECT AC_Type, AVG(Number_of_Passengers) AS AveragePassengers
            FROM Flights
            GROUP BY AC_Type;
            """,
        "Earliest Departure Time":
            """
            SELECT * FROM Flights
            ORDER BY Departure_Time ASC
            LIMIT 1;
            """,
        "Latest Departure Time":
            """
            SELECT * FROM Flights
            ORDER BY Departure_Time DESC
            LIMIT 1;
            """,
        "CAP/FO Details":
            """
            SELECT AC_Tail, Origin, Destination, CAP, FO, Flight_Type, Status
            FROM Flights;
            """,
        # An aircraft flies several legs, so count distinct tail numbers (one row, one number).
        "Number of Unique Aircraft":
            """
            SELECT COUNT(DISTINCT AC_Tail) AS UniqueAircraftCount
            FROM Flights;
            """
    }

    for analysis, query in queries.items():
        print(f"\n--- {analysis} ---")
        execute_query(db_name, query)

#5. Disruption Simulation
def simulate_disruption(db_name="flight_data.db", disruption_start=1700, disruption_end=1900, affected_airport="MEL"):
    """
    Simulates a weather disruption at a given airport and adjusts flight statuses.

    A flight is affected if it departs from ``affected_airport`` or arrives at it
    within ``[disruption_start, disruption_end]`` (inclusive, HHMM integers).
    Affected flights get ``Status = 'Delayed'``.

    Args:
        db_name (str, optional): The name of the database file.
                              Defaults to "flight_data.db".
        disruption_start (int, optional): The start time of the disruption (e.g., 1700 for 5 PM).
                                   Defaults to 1700.
        disruption_end (int, optional): The end time of the disruption (e.g., 1900 for 7 PM).
                                 Defaults to 1900.
        affected_airport (str, optional): The airport affected by the disruption.
                                   Defaults to "MEL".

    Returns:
        list: Affected flights as (AC_Tail, Origin, Destination, Departure_Time, Arrival_Time) tuples.
    """
    conn = sqlite3.connect(db_name)
    try:
        cursor = conn.cursor()

        print(f"\n--- Simulating Disruption at {affected_airport} from {disruption_start} to {disruption_end} ---")

        # Identify affected flights (parameterized: no values spliced into SQL text).
        cursor.execute(
            """
            SELECT AC_Tail, Origin, Destination, Departure_Time, Arrival_Time
            FROM Flights
            WHERE (Origin = :airport AND Departure_Time BETWEEN :start AND :end)
               OR (Destination = :airport AND Arrival_Time BETWEEN :start AND :end)
            """,
            {"airport": affected_airport, "start": disruption_start, "end": disruption_end},
        )
        affected_flights = cursor.fetchall()

        print("\nAffected Flights:")
        for flight in affected_flights:
            print(flight)

        # For simplicity, mark every affected flight as "Delayed".
        for ac_tail, origin, destination, _departure, _arrival in affected_flights:
            cursor.execute(
                "UPDATE Flights SET Status = 'Delayed' "
                "WHERE AC_Tail = ? AND Origin = ? AND Destination = ?",
                (ac_tail, origin, destination),
            )
            print(f"    Updated status for flight {ac_tail} from {origin} to {destination} to Delayed")

        conn.commit()
    finally:
        conn.close()

    print("\nDisruption simulation completed.")
    return affected_flights

## Main Execution

In [8]:
import nest_asyncio
# NOTE(review): presumably needed so Runner.run_sync can start its event loop
# inside the notebook kernel's already-running loop -- confirm.
nest_asyncio.apply()

if __name__ == "__main__":
    # 1. Fresh database + baseline analytics.
    create_and_populate_db()
    run_analytics()

    # 2. Simulate a disruption (HHMM times) and show the affected schedule.
    disruption_start_time = 1700
    disruption_end_time = 1900
    affected_airport = "MEL"
    simulate_disruption(db_name="flight_data.db", disruption_start=disruption_start_time,
                         disruption_end=disruption_end_time, affected_airport=affected_airport)

    print("\n-- Flight Data After Disruption ---")
    execute_query(query="SELECT AC_Tail, Origin, Destination, Departure_Time, Arrival_Time, Status FROM Flights")

    # 3. Let the LLM-backed agent respond, then drive the recovery directly.
    print("\n--- FleetAgent in Action (with Runner) --")
    fleet_agent = FleetAgent()
    result = Runner.run_sync(fleet_agent, "Analyze the current flight state and strategize recovery.")
    if result.final_output:
        print(result.final_output)

    flight_info = fleet_agent.analyze_state()
    print("\nFlight Information (from Agent):")
    for flight in flight_info:
        print(flight)

    impact, disrupted_flights = fleet_agent.assess_impact(flight_info)
    print("\nDisruption Impact (from Agent):")
    print(impact)

    total_cost = impact["disruption_cost"]
    print(f"\nTotal Disruption Cost: {total_cost}")

    fleet_agent.strategize_recovery(impact, disrupted_flights, disruption_start_time, affected_airport)

    print("\n--- Flight Data After Agent Recovery ---")
    execute_query(query="SELECT AC_Tail, Origin, Destination, Departure_Time, Arrival_Time, Status FROM Flights")

Table created successfully
Inserting data: N12345, B737, SYD, MEL
Inserting data: N12345, B737, MEL, ADL
Inserting data: N12345, B737, ADL, SYD
Inserting data: C67890, A320, SYD, MEL
Inserting data: A90123, E190, MEL, SYD
Inserting data: A90123, E190, SYD, 00L
Inserting data: B34567, B787, SYD, MEL
Inserting data: B34567, B787, MEL, PER
Inserting data: D89012, A330, SYD, MEL
Inserting data: D89012, A330, MEL, ADL
Inserting data: D89012, A330, ADL, MEL
Inserting data: E23456, DH8D, HTI, BNE
Inserting data: E23456, DH8D, BNE, SYD
Inserting data: E23456, DH8D, SYD, DRW
Inserting data: F67890, A320, SYD, MEL
Inserting data: G90123, E190, CBR, MEL
Inserting data: G90123, E190, MEL, SYD
Inserting data: G90123, E190, SYD, CBR
Inserting data: H34567, B737, SYD, MEL
Inserting data: H34567, B737, MEL, ADL
Inserting data: H34567, B737, ADL, PER
Inserting data: H34567, B737, PER, ADL
Inserting data: 189012, A330, PER, MEL
Inserting data: 189012, A330, MEL, BNE
Inserting data: J23456, B737, 00L, ME