Skip to content

Commit

Permalink
Organizations initial commit - review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneLightsOn committed Jun 19, 2024
1 parent 882f875 commit 292d899
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""empty message
Revision ID: 19127a0eeefb
Revision ID: 8bc604e45f2d
Revises: 982bbef24559
Create Date: 2024-06-18 19:22:26.209490
Create Date: 2024-06-19 16:15:20.386321
"""

Expand All @@ -12,7 +12,7 @@
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "19127a0eeefb"
revision: str = "8bc604e45f2d"
down_revision: Union[str, None] = "982bbef24559"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
Expand All @@ -32,11 +32,14 @@ def upgrade() -> None:
"user_organization",
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("organization_id", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("user_id", "organization_id"),
sa.PrimaryKeyConstraint("user_id", "organization_id", "id"),
)
op.add_column("agents", sa.Column("organization_id", sa.String(), nullable=True))
op.create_foreign_key(
Expand Down
39 changes: 15 additions & 24 deletions src/backend/crud/organization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlalchemy.orm import Session

from backend.database_models.organization import Organization
from backend.database_models.user import User, user_organization_association
from backend.database_models.user import User, UserOrganizationAssociation
from backend.schemas.organization import UpdateOrganization


Expand Down Expand Up @@ -71,10 +71,10 @@ def get_organizations_by_user_id(
return (
db.query(Organization)
.join(
user_organization_association,
Organization.id == user_organization_association.c.organization_id,
UserOrganizationAssociation,
Organization.id == UserOrganizationAssociation.organization_id,
)
.filter(user_organization_association.c.user_id == user_id)
.filter(UserOrganizationAssociation.user_id == user_id)
.limit(limit)
.offset(offset)
.all()
Expand Down Expand Up @@ -124,11 +124,10 @@ def add_user_to_organization(db: Session, user_id: str, organization_id: str) ->
user_id (str): User ID.
organization_id (str): Organization ID.
"""
db.execute(
user_organization_association.insert().values(
user_id=user_id, organization_id=organization_id
)
user_organization_association = UserOrganizationAssociation(
user_id=user_id, organization_id=organization_id
)
db.add(user_organization_association)
db.commit()


Expand All @@ -143,21 +142,13 @@ def remove_user_from_organization(
user_id (str): User ID.
organization_id (str): Organization ID.
"""
record = (
db.query(user_organization_association)
.filter(
user_organization_association.c.user_id == user_id,
user_organization_association.c.organization_id == organization_id,
)
user_organization_association = (
db.query(UserOrganizationAssociation)
.filter(user_id == user_id, organization_id == organization_id)
.first()
)
if record:
db.execute(
user_organization_association.delete().where(
user_organization_association.c.user_id == user_id,
user_organization_association.c.organization_id == organization_id,
)
)
if user_organization_association:
db.delete(user_organization_association)
db.commit()


Expand All @@ -179,10 +170,10 @@ def get_users_by_organization_id(
return (
db.query(User)
.join(
user_organization_association,
User.id == user_organization_association.c.user_id,
UserOrganizationAssociation,
User.id == UserOrganizationAssociation.user_id,
)
.filter(user_organization_association.c.organization_id == organization_id)
.filter(UserOrganizationAssociation.organization_id == organization_id)
.limit(limit)
.offset(offset)
.all()
Expand Down
4 changes: 2 additions & 2 deletions src/backend/database_models/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from backend.database_models.agent import Agent
from backend.database_models.base import Base
from backend.database_models.conversation import Conversation
from backend.database_models.user import User, user_organization_association
from backend.database_models.user import User, UserOrganizationAssociation


class Organization(Base):
__tablename__ = "organizations"

name: Mapped[str] = mapped_column()
users: Mapped[List["User"]] = relationship(
secondary=user_organization_association, backref="organizations"
"User", secondary="user_organization", backref="organizations"
)
conversations: Mapped[List["Conversation"]] = relationship(backref="organization")
agents: Mapped[List["Agent"]] = relationship(backref="organization")
31 changes: 21 additions & 10 deletions src/backend/database_models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,27 @@

from backend.database_models.base import Base

user_organization_association = Table(
"user_organization",
Base.metadata,
Column("user_id", ForeignKey("users.id", ondelete="CASCADE"), primary_key=True),
Column(
"organization_id",
ForeignKey("organizations.id", ondelete="CASCADE"),
primary_key=True,
),
)
# user_organization_association = Table(
# "user_organization",
# Base.metadata,
# Column("user_id", ForeignKey("users.id", ondelete="CASCADE"), primary_key=True),
# Column(
# "organization_id",
# ForeignKey("organizations.id", ondelete="CASCADE"),
# primary_key=True,
# ),
# )


class UserOrganizationAssociation(Base):
__tablename__ = "user_organization"

user_id: Mapped[str] = mapped_column(
ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
)
organization_id: Mapped[str] = mapped_column(
ForeignKey("organizations.id", ondelete="CASCADE"), primary_key=True
)


class User(Base):
Expand Down
46 changes: 46 additions & 0 deletions src/backend/tests/crud/test_organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,49 @@ def test_add_user_to_non_existent_organization(session):
user = get_factory("User", session).create()
with pytest.raises(Exception):
organization_crud.add_user_to_organization(session, user.id, "123")


def test_user_organization_association(session):
organization = get_factory("Organization", session).create(name="Test Organization")
user = get_factory("User", session).create(fullname="John Doe")
user.organizations.append(organization)
assert user.organizations[0].name == "Test Organization"


def test_user_organization_association_reverse(session):
organization = get_factory("Organization", session).create(name="Test Organization")
user = get_factory("User", session).create(fullname="John Doe")
organization.users.append(user)
assert organization.users[0].fullname == "John Doe"


def test_agent_organization_association(session):
organization = get_factory("Organization", session).create(name="Test Organization")
agent = get_factory("Agent", session).create(name="Test Agent")
agent.organization = organization
assert agent.organization.name == "Test Organization"


def test_agent_organization_association_reverse(session):
organization = get_factory("Organization", session).create(name="Test Organization")
agent = get_factory("Agent", session).create(name="Test Agent")
organization.agents.append(agent)
assert organization.agents[0].name == "Test Agent"


def test_conversation_organization_association(session):
organization = get_factory("Organization", session).create(name="Test Organization")
conversation = get_factory("Conversation", session).create(
title="Test Conversation"
)
conversation.organization = organization
assert conversation.organization.name == "Test Organization"


def test_conversation_organization_association_reverse(session):
organization = get_factory("Organization", session).create(name="Test Organization")
conversation = get_factory("Conversation", session).create(
title="Test Conversation"
)
organization.conversations.append(conversation)
assert organization.conversations[0].title == "Test Conversation"

0 comments on commit 292d899

Please sign in to comment.