Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into tomeu/os-2108
Browse files Browse the repository at this point in the history
  • Loading branch information
tomtobac committed Jun 19, 2024
2 parents 91a24a8 + c55dcb5 commit a84ee9b
Show file tree
Hide file tree
Showing 85 changed files with 1,561 additions and 1,009 deletions.
4 changes: 2 additions & 2 deletions docs/custom_tool_guides/tool_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ Remember, you can also access your tools via the API.
- List tools:

```bash
curl --location --request GET 'http://localhost:8000/tools' \
curl --location --request GET 'http://localhost:8000/v1/tools' \
--header 'User-Id: me' \
--header 'Content-Type: application/json' \
--data '{}'
Expand All @@ -188,7 +188,7 @@ curl --location --request GET 'http://localhost:8000/tools' \
- Chat turns with tools:

```bash
curl --location 'http://localhost:8000/chat-stream' \
curl --location 'http://localhost:8000/v1/chat-stream' \
--header 'User-Id: me' \
--header 'Content-Type: application/json' \
--data '{
Expand Down
8 changes: 4 additions & 4 deletions docs/how_to_guides.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ To add your own deployment:

## How to call the backend as an API

It is possible to just run the backend service, and call it in the same manner as the Cohere API. Note streaming and non streaming endpoints are split into 'http://localhost:8000/chat-stream' and 'http://localhost:8000/chat' compared to the API. For example, to stream:
It is possible to just run the backend service, and call it in the same manner as the Cohere API. Note streaming and non streaming endpoints are split into 'http://localhost:8000/v1/chat-stream' and 'http://localhost:8000/v1/chat' compared to the API. For example, to stream:

```bash
curl --location 'http://localhost:8000/chat-stream' \
curl --location 'http://localhost:8000/v1/chat-stream' \
--header 'User-Id: me' \
--header 'Content-Type: application/json' \
--data '{
Expand Down Expand Up @@ -79,12 +79,12 @@ Python interpreter and Tavily Internet search are provided in the toolkit by def

Example API call:
```bash
curl --location 'http://localhost:8000/langchain-chat' \
curl --location 'http://localhost:8000/v1/langchain-chat' \
--header 'User-Id: me' \
--header 'Content-Type: application/json' \
--data '{
"message": "Tell me about the aya model",
"tools": [{"name": "Python_Interpreter"},{"name": "Internet Search"},]
"tools": [{"name": "Python_Interpreter"},{"name": "Internet_Search"}]
}'
```

Expand Down
231 changes: 156 additions & 75 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ python-dotenv = "^1.0.1"
pytest-dotenv = "^0.5.2"
alembic = "^1.13.1"
psycopg2 = "^2.9.9"
psycopg2-binary = "^2.9.9"
python-multipart = "^0.0.9"
sse-starlette = "^2.0.0"
boto3 = "^1.0.0"
Expand Down
43 changes: 43 additions & 0 deletions src/backend/alembic/versions/982bbef24559_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""empty message
Revision ID: 982bbef24559
Revises: 3f207ae41477
Create Date: 2024-06-18 13:56:50.044706
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "982bbef24559"
down_revision: Union[str, None] = "3f207ae41477"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("conversations", sa.Column("agent_id", sa.String(), nullable=True))
op.drop_index("conversation_user_id", table_name="conversations")
op.create_index(
"conversation_user_agent_index",
"conversations",
["user_id", "agent_id"],
unique=False,
)
op.create_foreign_key(
None, "conversations", "agents", ["agent_id"], ["id"], ondelete="CASCADE"
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "conversations", type_="foreignkey")
op.drop_index("conversation_user_agent_index", table_name="conversations")
op.create_index("conversation_user_id", "conversations", ["user_id"], unique=False)
op.drop_column("conversations", "agent_id")
# ### end Alembic commands ###
4 changes: 2 additions & 2 deletions src/backend/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def show_examples():
bcolors.OKCYAN,
)
print_styled(
"""\tcurl --location 'http://localhost:8000/chat-stream' --header 'User-Id: test-user' --header 'Content-Type: application/json' --data '{"message": "hey"}'""",
"""\tcurl --location 'http://localhost:8000/v1/chat-stream' --header 'User-Id: test-user' --header 'Content-Type: application/json' --data '{"message": "hey"}'""",
bcolors.OKCYAN,
)

Expand All @@ -198,7 +198,7 @@ def show_examples():
bcolors.OKCYAN,
)
print_styled(
"""\tcurl --location 'http://localhost:8000/chat-stream' --header 'User-Id: test-user' --header 'Deployment-Name: SageMaker' --header 'Content-Type: application/json' --data '{"message": "hey"}'""",
"""\tcurl --location 'http://localhost:8000/v1/chat-stream' --header 'User-Id: test-user' --header 'Deployment-Name: SageMaker' --header 'Content-Type: application/json' --data '{"message": "hey"}'""",
bcolors.OKCYAN,
)

Expand Down
16 changes: 15 additions & 1 deletion src/backend/crud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def create_agent(db: Session, agent: Agent) -> Agent:
return agent


def get_agent(db: Session, agent_id: str) -> Agent:
def get_agent_by_id(db: Session, agent_id: str) -> Agent:
"""
Get an agent by its ID.
Expand All @@ -37,6 +37,20 @@ def get_agent(db: Session, agent_id: str) -> Agent:
return db.query(Agent).filter(Agent.id == agent_id).first()


def get_agent_by_name(db: Session, agent_name: str) -> Agent:
"""
Get an agent by its name.
Args:
db (Session): Database session.
agent_name (str): Agent name.
Returns:
Agent: Agent with the given name.
"""
return db.query(Agent).filter(Agent.name == agent_name).first()


def get_agents(db: Session, offset: int = 0, limit: int = 100) -> list[Agent]:
"""
Get all agents for a user.
Expand Down
21 changes: 12 additions & 9 deletions src/backend/crud/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def get_conversation(


def get_conversations(
db: Session, user_id: str, offset: int = 0, limit: int = 100
db: Session,
user_id: str,
offset: int = 0,
limit: int = 100,
agent_id: str | None = None,
) -> list[Conversation]:
"""
List all conversations.
Expand All @@ -57,14 +61,13 @@ def get_conversations(
Returns:
list[Conversation]: List of conversations.
"""
return (
db.query(Conversation)
.filter(Conversation.user_id == user_id)
.order_by(Conversation.updated_at.desc())
.offset(offset)
.limit(limit)
.all()
)
query = db.query(Conversation).filter(Conversation.user_id == user_id)
if agent_id is not None:
query = query.filter(Conversation.agent_id == agent_id)

query = query.offset(offset).limit(limit)

return query.all()


def update_conversation(
Expand Down
8 changes: 6 additions & 2 deletions src/backend/database_models/conversation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List

from sqlalchemy import Index, String
from sqlalchemy import ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from backend.database_models.agent import Agent
from backend.database_models.base import Base
from backend.database_models.file import File
from backend.database_models.message import Message
Expand All @@ -18,9 +19,12 @@ class Conversation(Base):

text_messages: Mapped[List["Message"]] = relationship()
files: Mapped[List["File"]] = relationship()
agent_id: Mapped[str] = mapped_column(
ForeignKey("agents.id", ondelete="CASCADE"), nullable=True
)

@property
def messages(self):
return sorted(self.text_messages, key=lambda x: x.position)

__table_args__ = (Index("conversation_user_id", user_id),)
__table_args__ = (Index("conversation_user_agent_index", user_id, agent_id),)
35 changes: 26 additions & 9 deletions src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
Depends(validate_create_agent_request),
],
)
def create_agent(session: DBSessionDep, agent: CreateAgent, request: Request):
def create_agent(session: DBSessionDep, agent: CreateAgent, request: Request) -> Agent:
user_id = get_header_user_id(request)

agent_data = AgentModel(
Expand All @@ -40,7 +40,10 @@ def create_agent(session: DBSessionDep, agent: CreateAgent, request: Request):
tools=agent.tools,
)

return agent_crud.create_agent(session, agent_data)
try:
return agent_crud.create_agent(session, agent_data)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@router.get("", response_model=list[Agent])
Expand All @@ -59,11 +62,16 @@ async def list_agents(
Returns:
list[Agent]: List of agents.
"""
return agent_crud.get_agents(session, offset=offset, limit=limit)
try:
return agent_crud.get_agents(session, offset=offset, limit=limit)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@router.get("/{agent_id}", response_model=Agent)
async def get_agent(agent_id: str, session: DBSessionDep, request: Request) -> Agent:
async def get_agent_by_id(
agent_id: str, session: DBSessionDep, request: Request
) -> Agent:
"""
Args:
agent_id (str): Agent ID.
Expand All @@ -75,7 +83,10 @@ async def get_agent(agent_id: str, session: DBSessionDep, request: Request) -> A
Raises:
HTTPException: If the agent with the given ID is not found.
"""
agent = agent_crud.get_agent(session, agent_id)
try:
agent = agent_crud.get_agent_by_id(session, agent_id)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

if not agent:
raise HTTPException(
Expand Down Expand Up @@ -115,15 +126,18 @@ async def update_agent(
Raises:
HTTPException: If the agent with the given ID is not found.
"""
agent = agent_crud.get_agent(session, agent_id)
agent = agent_crud.get_agent_by_id(session, agent_id)

if not agent:
raise HTTPException(
status_code=404,
detail=f"Agent with ID: {agent_id} not found.",
)

agent = agent_crud.update_agent(session, agent, new_agent)
try:
agent = agent_crud.update_agent(session, agent, new_agent)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

return agent

Expand All @@ -146,14 +160,17 @@ async def delete_agent(
Raises:
HTTPException: If the agent with the given ID is not found.
"""
agent = agent_crud.get_agent(session, agent_id)
agent = agent_crud.get_agent_by_id(session, agent_id)

if not agent:
raise HTTPException(
status_code=404,
detail=f"Agent with ID: {agent_id} not found.",
)

agent_crud.delete_agent(session, agent_id)
try:
agent_crud.delete_agent(session, agent_id)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

return DeleteAgent()
11 changes: 9 additions & 2 deletions src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,30 @@ async def get_conversation(

@router.get("", response_model=list[ConversationWithoutMessages])
async def list_conversations(
*, offset: int = 0, limit: int = 100, session: DBSessionDep, request: Request
*,
offset: int = 0,
limit: int = 100,
agent_id: str = None,
session: DBSessionDep,
request: Request,
) -> list[ConversationWithoutMessages]:
"""
List all conversations.
Args:
offset (int): Offset to start the list.
limit (int): Limit of conversations to be listed.
agent_id (str): Query parameter for agent ID to optionally filter conversations by agent.
session (DBSessionDep): Database session.
request (Request): Request object.
Returns:
list[ConversationWithoutMessages]: List of conversations.
"""
user_id = get_header_user_id(request)

return conversation_crud.get_conversations(
session, offset=offset, limit=limit, user_id=user_id
session, offset=offset, limit=limit, user_id=user_id, agent_id=agent_id
)


Expand Down
2 changes: 1 addition & 1 deletion src/backend/routers/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def list_tools(session: DBSessionDep, agent_id: str | None = None) -> list[Manag
"""
if agent_id:
agent_tools = []
agent = agent_crud.get_agent(session, agent_id)
agent = agent_crud.get_agent_by_id(session, agent_id)

if not agent:
raise HTTPException(
Expand Down
1 change: 1 addition & 0 deletions src/backend/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Conversation(ConversationBase):
messages: List[Message]
files: List[File]
description: Optional[str]
agent_id: Optional[str]

@computed_field(return_type=int)
def total_file_size(self):
Expand Down
12 changes: 11 additions & 1 deletion src/backend/services/request_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS
from backend.config.tools import AVAILABLE_TOOLS
from backend.crud import agent as agent_crud
from backend.database_models.database import DBSessionDep


def validate_user_header(request: Request):
Expand Down Expand Up @@ -112,7 +114,7 @@ async def validate_env_vars(request: Request):
)


async def validate_create_agent_request(request: Request):
async def validate_create_agent_request(session: DBSessionDep, request: Request):
"""
Validate that the create agent request has valid tools, deployments, and compatible models.
Expand All @@ -124,6 +126,14 @@ async def validate_create_agent_request(request: Request):
"""
body = await request.json()

# TODO @scott-cohere: for now we disregard versions and assume agents have unique names, enforce versioning later
agent_name = body.get("name")
agent = agent_crud.get_agent_by_name(session, agent_name)
if agent:
raise HTTPException(
status_code=400, detail=f"Agent {agent_name} already exists."
)

# Validate tools
tools = body.get("tools")
if tools:
Expand Down
Loading

0 comments on commit a84ee9b

Please sign in to comment.