Skip to content

Commit

Permalink
Merge branch 'main' into pia/tlk-682-toolkiteno-update-agent
Browse files Browse the repository at this point in the history
  • Loading branch information
misspia-cohere committed Jun 20, 2024
2 parents 627883f + ed8eb80 commit c2774e9
Show file tree
Hide file tree
Showing 30 changed files with 843 additions and 54 deletions.
77 changes: 77 additions & 0 deletions src/backend/alembic/versions/8bc604e45f2d_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""empty message
Revision ID: 8bc604e45f2d
Revises: 982bbef24559
Create Date: 2024-06-19 16:15:20.386321
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "8bc604e45f2d"
down_revision: Union[str, None] = "982bbef24559"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"organizations",
sa.Column("name", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"user_organization",
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("organization_id", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("user_id", "organization_id", "id"),
)
op.add_column("agents", sa.Column("organization_id", sa.String(), nullable=True))
op.create_foreign_key(
"agents_organization_id_fkey",
"agents",
"organizations",
["organization_id"],
["id"],
ondelete="CASCADE",
)
op.add_column(
"conversations", sa.Column("organization_id", sa.String(), nullable=True)
)
op.create_foreign_key(
"conversations_organization_id_fkey",
"conversations",
"organizations",
["organization_id"],
["id"],
ondelete="CASCADE",
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(
"conversations_organization_id_fkey", "conversations", type_="foreignkey"
)
op.drop_column("conversations", "organization_id")
op.drop_constraint("agents_organization_id_fkey", "agents", type_="foreignkey")
op.drop_column("agents", "organization_id")
op.drop_table("user_organization")
op.drop_table("organizations")
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion src/backend/config/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class RouterName(StrEnum):
RouterName.CHAT: {
"default": [
Depends(get_session),
Depends(validate_chat_request),
Depends(validate_user_header),
Depends(validate_chat_request),
],
"auth": [
Depends(get_session),
Expand Down
14 changes: 10 additions & 4 deletions src/backend/config/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,18 @@


class ToolName(StrEnum):
Wiki_Retriever_LangChain = "Wikipedia"
Wiki_Retriever_LangChain = "wikipedia"
Search_File = "search_file"
Read_File = "read_document"
Python_Interpreter = "Python_Interpreter"
Calculator = "Calculator"
Tavily_Internet_Search = "Internet_Search"
Python_Interpreter = "python_interpreter"
Calculator = "calculator"
Tavily_Internet_Search = "internet_search"


ALL_TOOLS = {
ToolName.Wiki_Retriever_LangChain: ManagedTool(
name=ToolName.Wiki_Retriever_LangChain,
display_name="Wikipedia",
implementation=LangChainWikiRetriever,
parameter_definitions={
"query": {
Expand All @@ -54,6 +55,7 @@ class ToolName(StrEnum):
),
ToolName.Search_File: ManagedTool(
name=ToolName.Search_File,
display_name="Search File",
implementation=SearchFileTool,
parameter_definitions={
"search_query": {
Expand All @@ -75,6 +77,7 @@ class ToolName(StrEnum):
),
ToolName.Read_File: ManagedTool(
name=ToolName.Read_File,
display_name="Read Document",
implementation=ReadFileTool,
parameter_definitions={
"filename": {
Expand All @@ -91,6 +94,7 @@ class ToolName(StrEnum):
),
ToolName.Python_Interpreter: ManagedTool(
name=ToolName.Python_Interpreter,
display_name="Python Interpreter",
implementation=PythonInterpreter,
parameter_definitions={
"code": {
Expand All @@ -107,6 +111,7 @@ class ToolName(StrEnum):
),
ToolName.Calculator: ManagedTool(
name=ToolName.Calculator,
display_name="Calculator",
implementation=Calculator,
parameter_definitions={
"code": {
Expand All @@ -123,6 +128,7 @@ class ToolName(StrEnum):
),
ToolName.Tavily_Internet_Search: ManagedTool(
name=ToolName.Tavily_Internet_Search,
display_name="Web Search",
implementation=TavilyInternetSearch,
parameter_definitions={
"query": {
Expand Down
14 changes: 12 additions & 2 deletions src/backend/crud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,29 @@ def get_agent_by_name(db: Session, agent_name: str) -> Agent:
return db.query(Agent).filter(Agent.name == agent_name).first()


def get_agents(db: Session, offset: int = 0, limit: int = 100) -> list[Agent]:
def get_agents(
db: Session,
offset: int = 0,
limit: int = 100,
organization_id: str = None,
) -> list[Agent]:
"""
Get all agents for a user.
Args:
db (Session): Database session.
offset (int): Offset of the results.
limit (int): Limit of the results.
organization_id (str): Organization ID.
Returns:
list[Agent]: List of agents.
"""
return db.query(Agent).offset(offset).limit(limit).all()
query = db.query(Agent)
if organization_id is not None:
query = query.filter(Agent.organization_id == organization_id)
query = query.offset(offset).limit(limit)
return query.all()


def update_agent(db: Session, agent: Agent, new_agent: UpdateAgent) -> Agent:
Expand Down
8 changes: 6 additions & 2 deletions src/backend/crud/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@ def get_conversations(
offset: int = 0,
limit: int = 100,
agent_id: str | None = None,
organization_id: str | None = None,
) -> list[Conversation]:
"""
List all conversations.
Args:
db (Session): Database session.
user_id (str): User ID.
organization_id (str): Organization ID.
agent_id (str): Agent ID.
offset (int): Offset to start the list.
limit (int): Limit of conversations to be listed.
Expand All @@ -64,8 +67,9 @@ def get_conversations(
query = db.query(Conversation).filter(Conversation.user_id == user_id)
if agent_id is not None:
query = query.filter(Conversation.agent_id == agent_id)

query = query.offset(offset).limit(limit)
if organization_id is not None:
query = query.filter(Conversation.organization_id == organization_id)
query = query.order_by(Conversation.updated_at.desc()).offset(offset).limit(limit)

return query.all()

Expand Down
180 changes: 180 additions & 0 deletions src/backend/crud/organization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from sqlalchemy.orm import Session

from backend.database_models.organization import Organization
from backend.database_models.user import User, UserOrganizationAssociation
from backend.schemas.organization import UpdateOrganization


def create_organization(db: Session, organization: Organization) -> Organization:
""" "
Create a new organization.
Args:
db (Session): Database session.
organization (Organization): Organization data to be created.
Returns:
Organization: Created organization.
"""
db.add(organization)
db.commit()
db.refresh(organization)
return organization


def get_organization(db: Session, organization_id: str) -> Organization:
"""
Get a organization by ID.
Args:
db (Session): Database session.
organization_id (str): Organization ID.
Returns:
Organization: Organization with the given ID.
"""
return db.query(Organization).filter(Organization.id == organization_id).first()


def get_organizations(
db: Session, offset: int = 0, limit: int = 100
) -> list[Organization]:
"""
List all organizations.
Args:
db (Session): Database session.
offset (int): Offset to start the list.
limit (int): Limit of organizations to be listed.
Returns:
list[Organization]: List of organizations.
"""
return db.query(Organization).offset(offset).limit(limit).all()


def get_organizations_by_user_id(
db: Session, user_id: str, offset: int = 0, limit: int = 100
) -> list[Organization]:
"""
List all organizations by user id
Args:
db (Session): Database session.
user_id (str): User ID
offset (int): Offset to start the list.
limit (int): Limit of organizations to be listed.
Returns:
list[Organization]: List of organizations.
"""
return (
db.query(Organization)
.join(
UserOrganizationAssociation,
Organization.id == UserOrganizationAssociation.organization_id,
)
.filter(UserOrganizationAssociation.user_id == user_id)
.limit(limit)
.offset(offset)
.all()
)


def update_organization(
db: Session, organization: Organization, new_organization: UpdateOrganization
) -> Organization:
"""
Update a organization by ID.
Args:
db (Session): Database session.
organization (Organization): Organization to be updated.
new_organization (Organization): New organization data.
Returns:
Organization: Updated organization.
"""
for attr, value in new_organization.model_dump(exclude_none=True).items():
setattr(organization, attr, value)
db.commit()
db.refresh(organization)
return organization


def delete_organization(db: Session, organization_id: str) -> None:
"""
Delete a organization by ID.
Args:
db (Session): Database session.
organization_id (str): Organization ID.
"""
organization = db.query(Organization).filter(Organization.id == organization_id)
organization.delete()
db.commit()


def add_user_to_organization(db: Session, user_id: str, organization_id: str) -> None:
"""
Add a user to an organization.
Args:
db (Session): Database session.
user_id (str): User ID.
organization_id (str): Organization ID.
"""
user_organization_association = UserOrganizationAssociation(
user_id=user_id, organization_id=organization_id
)
db.add(user_organization_association)
db.commit()


def remove_user_from_organization(
db: Session, user_id: str, organization_id: str
) -> None:
"""
Remove a user from an organization.
Args:
db (Session): Database session.
user_id (str): User ID.
organization_id (str): Organization ID.
"""
user_organization_association = (
db.query(UserOrganizationAssociation)
.filter(user_id == user_id, organization_id == organization_id)
.first()
)
if user_organization_association:
db.delete(user_organization_association)
db.commit()


def get_users_by_organization_id(
db: Session, organization_id: str, offset: int = 0, limit: int = 100
) -> list[User]:
"""
List all users by organization ID.
Args:
db (Session): Database session.
organization_id (str): Organization ID.
offset (int): Offset to start the list.
limit (int): Limit of users to be listed.
Returns:
list[User]: List of users.
"""
return (
db.query(User)
.join(
UserOrganizationAssociation,
User.id == UserOrganizationAssociation.user_id,
)
.filter(UserOrganizationAssociation.organization_id == organization_id)
.limit(limit)
.offset(offset)
.all()
)
1 change: 1 addition & 0 deletions src/backend/database_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
from backend.database_models.document import *
from backend.database_models.file import *
from backend.database_models.message import *
from backend.database_models.organization import *
from backend.database_models.user import *
Loading

0 comments on commit c2774e9

Please sign in to comment.