Skip to content

Commit

Permalink
feature: add pending session policy and impl predicate (#1226)
Browse files Browse the repository at this point in the history
Co-authored-by: Joongi Kim <joongi@lablup.com>
  • Loading branch information
fregataa and achimnol committed Feb 28, 2024
1 parent b6c4181 commit 59f546c
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 0 deletions.
1 change: 1 addition & 0 deletions changes/1226.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a policy to predicate to limit the number and resources of concurrent pending sessions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""add_max_pending_session_count
Revision ID: 3f47af213b05
Revises: 41f332243bf9
Create Date: 2023-09-27 15:09:00.419228
"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "3f47af213b05"
down_revision = "41f332243bf9"
branch_labels = None
depends_on = None


def upgrade():
    """Add the two nullable pending-session limit columns to the keypair resource policy table."""
    table_name = "keypair_resource_policies"
    # Columns are added in the same order the auto-generated migration used.
    new_columns = (
        sa.Column("max_pending_session_count", sa.Integer(), nullable=True),
        sa.Column(
            "max_pending_session_resource_slots",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
    )
    for column in new_columns:
        op.add_column(table_name, column)


def downgrade():
    """Drop the pending-session limit columns, in the reverse order of their creation."""
    table_name = "keypair_resource_policies"
    for column_name in (
        "max_pending_session_resource_slots",
        "max_pending_session_count",
    ):
        op.drop_column(table_name, column_name)
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/resource_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
sa.Column("total_resource_slots", ResourceSlotColumn(), nullable=False),
sa.Column("max_session_lifetime", sa.Integer(), nullable=False, server_default=sa.text("0")),
sa.Column("max_concurrent_sessions", sa.Integer(), nullable=False),
sa.Column("max_pending_session_count", sa.Integer(), nullable=True),
sa.Column("max_pending_session_resource_slots", ResourceSlotColumn(), nullable=True),
sa.Column(
"max_concurrent_sftp_sessions", sa.Integer(), nullable=False, server_default=sa.text("1")
),
Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/manager/scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@
check_domain_resource_limit,
check_group_resource_limit,
check_keypair_resource_limit,
check_pending_session_count_limit,
check_pending_session_resource_limit,
check_reserved_batch_session,
check_user_resource_limit,
)
Expand Down Expand Up @@ -478,6 +480,14 @@ async def _check_predicates() -> List[Tuple[str, Union[Exception, PredicateResul
]
if not sess_ctx.is_private:
predicates += [
(
"pending_session_resource_limit",
check_pending_session_resource_limit(db_sess, sched_ctx, sess_ctx),
),
(
"pending_session_count_limit",
check_pending_session_count_limit(db_sess, sched_ctx, sess_ctx),
),
(
"keypair_resource_limit",
check_keypair_resource_limit(db_sess, sched_ctx, sess_ctx),
Expand Down
126 changes: 126 additions & 0 deletions src/ai/backend/manager/scheduler/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sqlalchemy as sa
from dateutil.tz import tzutc
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, noload

from ai.backend.common import redis_helper
from ai.backend.common.logging import BraceStyleAdapter
Expand Down Expand Up @@ -294,3 +295,128 @@ async def check_domain_resource_limit(
),
)
return PredicateResult(True)


async def check_pending_session_count_limit(
    db_sess: SASession,
    sched_ctx: SchedulingContext,
    sess_ctx: SessionRow,
) -> PredicateResult:
    """Check the owner's number of PENDING sessions against the keypair policy limit.

    Counts the sessions in ``SessionStatus.PENDING`` that share
    ``sess_ctx.access_key`` and compares the count with
    ``max_pending_session_count`` of the keypair resource policy joined
    through ``KeyPairRow.resource_policy``.

    :param db_sess: The async database session used to run both queries.
    :param sched_ctx: The scheduling context (not used by this predicate).
    :param sess_ctx: The session being checked; only ``access_key`` is read.
    :returns: A passing result when the policy defines no limit (``None``)
        or the pending count is below the limit; a failing result with a
        message when the count is at or above the limit, or when no
        resource policy row can be found for the access key.
    """
    # Only the number of rows matters here, so let the database count them
    # instead of loading every pending session row.
    pending_count_stmt = (
        sa.select(sa.func.count())
        .select_from(SessionRow)
        .where(
            (SessionRow.access_key == sess_ctx.access_key)
            & (SessionRow.status == SessionStatus.PENDING)
        )
    )
    pending_count: int = (await db_sess.scalar(pending_count_stmt)) or 0

    # TODO: replace keypair resource policies with user resource policies
    j = sa.join(
        KeyPairResourcePolicyRow,
        KeyPairRow,
        KeyPairResourcePolicyRow.name == KeyPairRow.resource_policy,
    )
    policy_stmt = (
        sa.select(KeyPairResourcePolicyRow)
        .select_from(j)
        .where(KeyPairRow.access_key == sess_ctx.access_key)
        .options(
            noload("*"),
            load_only(
                KeyPairResourcePolicyRow.max_pending_session_count,
            ),
        )
    )
    policy: KeyPairResourcePolicyRow | None = (await db_sess.scalars(policy_stmt)).first()
    if policy is None:
        # ``first()`` yields None when the join matches no row; fail with an
        # explicit message rather than an AttributeError on ``None``.
        return PredicateResult(
            False,
            f"No keypair resource policy is found for the access key {sess_ctx.access_key}.",
        )

    pending_count_limit: int | None = policy.max_pending_session_count
    log.debug(
        "access key:{} number of pending sessions: {} / {}",
        sess_ctx.access_key,
        pending_count,
        pending_count_limit,
    )
    # A limit of None means the policy does not restrict the pending count.
    if pending_count_limit is not None and pending_count >= pending_count_limit:
        return PredicateResult(
            False,
            f"You cannot create more than {pending_count_limit} pending session(s).",
        )
    return PredicateResult(True)


async def check_pending_session_resource_limit(
    db_sess: SASession,
    sched_ctx: SchedulingContext,
    sess_ctx: SessionRow,
) -> PredicateResult:
    """Check the owner's total PENDING-session resource request against the policy limit.

    Sums ``requested_slots`` over the sessions in ``SessionStatus.PENDING``
    that share ``sess_ctx.access_key`` and compares the sum with
    ``max_pending_session_resource_slots`` of the keypair resource policy
    joined through ``KeyPairRow.resource_policy``.

    :param db_sess: The async database session used to run both queries.
    :param sched_ctx: The scheduling context; ``known_slot_types`` is used to
        humanize the slot values in the failure message.
    :param sess_ctx: The session being checked; only ``access_key`` is read.
    :returns: A passing result when the policy limit is unset or empty, or
        when the ``>=`` comparison against the limit is false; a failing
        result with a message otherwise, or when no resource policy row can
        be found for the access key.
    """
    query = (
        sa.select(SessionRow)
        .where(
            (SessionRow.access_key == sess_ctx.access_key)
            & (SessionRow.status == SessionStatus.PENDING)
        )
        .options(noload("*"), load_only(SessionRow.requested_slots))
    )
    pending_sessions: list[SessionRow] = (await db_sess.scalars(query)).all()

    # TODO: replace keypair resource policies with user resource policies
    j = sa.join(
        KeyPairResourcePolicyRow,
        KeyPairRow,
        KeyPairResourcePolicyRow.name == KeyPairRow.resource_policy,
    )
    policy_stmt = (
        sa.select(KeyPairResourcePolicyRow)
        .select_from(j)
        .where(KeyPairRow.access_key == sess_ctx.access_key)
        .options(
            noload("*"),
            load_only(
                KeyPairResourcePolicyRow.max_pending_session_resource_slots,
            ),
        )
    )
    policy: KeyPairResourcePolicyRow | None = (await db_sess.scalars(policy_stmt)).first()
    if policy is None:
        # ``first()`` yields None when the join matches no row; fail with an
        # explicit message rather than an AttributeError on ``None``.
        return PredicateResult(
            False,
            f"No keypair resource policy is found for the access key {sess_ctx.access_key}.",
        )

    pending_resource_limit: ResourceSlot | None = policy.max_pending_session_resource_slots
    # Both None and an empty slot mapping mean "no limit configured".
    if not pending_resource_limit:
        return PredicateResult(True)

    current_pending_session_slots: ResourceSlot = sum(
        (session.requested_slots for session in pending_sessions),
        start=ResourceSlot(),
    )
    log.debug(
        "access key:{} current-occupancy of pending sessions: {}",
        sess_ctx.access_key,
        current_pending_session_slots,
    )
    log.debug(
        "access key:{} total-allowed of pending sessions: {}",
        sess_ctx.access_key,
        pending_resource_limit,
    )

    # NOTE(review): the outcome depends on how ResourceSlot implements ``>=``
    # (defined outside this file) -- confirm it matches the intended
    # "quota exceeded" semantics when only some slots reach the limit.
    if current_pending_session_slots >= pending_resource_limit:
        humanized_slots = current_pending_session_slots.to_humanized(sched_ctx.known_slot_types)
        return PredicateResult(
            False,
            "Your pending session quota is exceeded. ({})".format(
                " ".join(f"{k}={v}" for k, v in humanized_slots.items())
            ),
        )
    return PredicateResult(True)

0 comments on commit 59f546c

Please sign in to comment.