In [12]:
import vertexai
import streamlit as st
from vertexai.preview import generative_models
from vertexai.preview.generative_models import GenerativeModel, Tool, Part, Content, ChatSession


In [13]:
project = "trans-array-427509-h2"
vertexai.init(project = project)

In [14]:
# Declare Tool
get_search_flights = generative_models.FunctionDeclaration(
    name="get_search_flights",
    description="Tool for searching a flight with origin, destination, and departure date",
    parameters={
        "type": "object",
        "properties": {
            "origin": {
                "type": "string",
                "description": "The airport of departure for the flight given in airport code such as LAX, SFO, BOS, etc."
            },
            "destination": {
                "type": "string",
                "description": "The airport of destination for the flight given in airport code such as LAX, SFO, BOS, etc."
            },
            "departure_date": {
                "type": "string",
                "format": "date",
                "description": "The date of departure for the flight in YYYY-MM-DD format"
            },
        },
        "required": [
            "origin",
            "destination",
            "departure_date"
        ]
    },
)

# Instantiate tool and model with tools
search_tool = generative_models.Tool(
    function_declarations=[get_search_flights],
)

In [15]:
config = generative_models.GenerationConfig(temperature=0.4)
# Load model with config
model = GenerativeModel(
    "gemini-pro",
    tools = [search_tool],
    generation_config = config
)

In [1]:
# Function to extract relevant information for flight_id
def extract_flight_info(conversation_history):
    flight_number = None
    departure_date = None

    for message in reversed(conversation_history):
        if not flight_number and "flight number" in message["content"].lower():
            # Attempt to extract flight number (e.g., "flight number is 123")
            parts = message["content"].lower().split("flight number")
            if len(parts) > 1:
                flight_number = parts[1].strip()
                try:
                    flight_number = int(flight_number)  # Convert to integer
                except ValueError:
                    flight_number = None  # Reset if not a valid number

        if not departure_date and "departure date" in message["content"].lower():
            # Attempt to extract departure date (e.g., "departure date is 2024-05-01")
            parts = message["content"].lower().split("departure date")
            if len(parts) > 1:
                departure_date = parts[1].strip()

        if flight_number and departure_date:
            break  # Found both, stop iterating

    return flight_number, departure_date


In [2]:
from services.flight_manager import find_flight_id,book_flight,search_flights,update_flight_booking,remove_flight_booking,find_customer_id,find_booking_id
def handle_response(response, chat_history):  # Pass chat history
    st.write(response)
    output = []
    st.write(response.candidates[0].content.parts)
    parts = response.candidates[0].content.parts
    for part in parts:
        if part.function_call.args:
            function_name = part.function_call.name
            function_args = part.function_call.args

            function_params = {}
            for key in function_args:
                value = function_args[key]
                function_params[key] = value

            results = ""

            function_map = {
                "book_flights": book_flight,
                "get_search_flights": search_flights,
                "remove_flight_booking": remove_flight_booking,
                "update_flight_booking": update_flight_booking,
                "find_flight_id": find_flight_id,
                "find_customer_id": find_customer_id,
                "find_booking_id": find_booking_id,
            }

            function_to_call = function_map.get(function_name)

            # Handle flight_id extraction if book_flights is called
            if 'flight_id' in function_params :
                flight_number, departure_date = extract_flight_info(chat_history)  # Extract from chat history
                if flight_number and departure_date:
                        function_params['flight_id'] = find_flight_id(
                            flight_number=flight_number, departure_date=departure_date)
                        print("Found flight_id:", function_params['flight_id'])
                   
            if function_to_call:
                try:
                    print(function_to_call)
                    print(function_params)
                    results = function_to_call(**function_params)
                    print(results)
                    # ... (rest of your result handling code)
                except Exception as e:
                    return f"Error executing function '{function_name}': {str(e)}"
            else:
                return f"Error: Unknown function '{function_name}'"

            if results != "":
                intermediate_response = chat.send_message(
                    Part.from_function_response(
                        name=function_name,
                        response=results
                    )
                )
                st.write(intermediate_response)
                output.append(intermediate_response.candidates[0].content.parts[0].text)


            else:
                return "Search Failed"
        else:
            output.append(part.text)
    return output

In [3]:
response  = """candidates {
  content {
    role: "model"
    parts {
      function_call {
        name: "book_flights"
        args {
          fields {
            key: "seat_type"
            value {
              string_value: "economy"
            }
          }
          fields {
            key: "phone_number"
            value {
              number_value: 222
            }
          }
          fields {
            key: "num_seats"
            value {
              number_value: 1
            }
          }
          fields {
            key: "last_name"
            value {
              string_value: "De Silva"
            }
          }
          fields {
            key: "flight_id"
            value {
              string_value: "unknown"
            }
          }
          fields {
            key: "first_name"
            value {
              string_value: "Sanuthi"
            }
          }
          fields {
            key: "email"
            value {
              string_value: "sanuthi@gmail.com"
            }
          }
          fields {
            key: "date_of_birth"
            value {
              string_value: "2008-01-01"
            }
          }
          fields {
            key: "booking_date"
            value {
              string_value: "2024-08-21"
            }
          }
        }
      }
    }
  }
  finish_reason: STOP
  safety_ratings {
    category: HARM_CATEGORY_HATE_SPEECH
    probability: NEGLIGIBLE
    probability_score: 0.172851562
    severity: HARM_SEVERITY_LOW
    severity_score: 0.2109375
  }
  safety_ratings {
    category: HARM_CATEGORY_DANGEROUS_CONTENT
    probability: NEGLIGIBLE
    probability_score: 0.353515625
    severity: HARM_SEVERITY_MEDIUM
    severity_score: 0.400390625
  }
  safety_ratings {
    category: HARM_CATEGORY_HARASSMENT
    probability: NEGLIGIBLE
    probability_score: 0.271484375
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.13671875
  }
  safety_ratings {
    category: HARM_CATEGORY_SEXUALLY_EXPLICIT
    probability: NEGLIGIBLE
    probability_score: 0.142578125
    severity: HARM_SEVERITY_LOW
    severity_score: 0.241210938
  }
  avg_logprobs: -0.0099325869232416153
}
usage_metadata {
  prompt_token_count: 2278
  candidates_token_count: 64
  total_token_count: 2342
}"""

output = handle_response(response)
print(output)


TypeError: handle_response() missing 1 required positional argument: 'chat_history'