From 47718f87723fab18504ebef3cc8af94141e401e3 Mon Sep 17 00:00:00 2001 From: Esteban Galvis Date: Thu, 4 Jul 2024 14:46:46 -0500 Subject: [PATCH] feat: :sparkles: Enable/disable menuflow bots --- menuflow/api/client.py | 23 +++++++++++++++++++++++ menuflow/db/client.py | 14 +++++++++----- menuflow/db/migrations.py | 5 +++++ menuflow/menu.py | 7 +++++++ 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/menuflow/api/client.py b/menuflow/api/client.py index ce3437a..ea1bfa3 100644 --- a/menuflow/api/client.py +++ b/menuflow/api/client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from json import JSONDecodeError from logging import Logger, getLogger from typing import Dict, Optional @@ -193,3 +194,25 @@ async def reload_client_flow(request: web.Request) -> web.Response: await _reload_flow(client) return resp.ok(client.to_dict()) + + +@routes.patch("/client/{mxid}/{action}") +async def enable_disable_client(request: web.Request) -> web.Response: + mxid = request.match_info["mxid"] + action = request.match_info["action"] + client: MenuClient = await MenuClient.get(mxid) + if not client: + return resp.client_not_found(mxid) + + if action == "enable": + client.enabled = True + await client.start() + elif action == "disable": + client.enabled = False + asyncio.create_task(client.leave_rooms(), name=f"{mxid}-leave_rooms") + await client.stop() + else: + return resp.bad_request("Invalid action provided") + + await client.update() + return resp.ok({"detail": {"message": f"Client {action}d successfully"}}) diff --git a/menuflow/db/client.py b/menuflow/db/client.py index d1c83ee..90d3cb3 100644 --- a/menuflow/db/client.py +++ b/menuflow/db/client.py @@ -24,6 +24,7 @@ class Client(SyncStore): filter_id: FilterID autojoin: bool + enabled: bool flow: int | None = None @classmethod @@ -32,7 +33,9 @@ def _from_row(cls, row: Record | None) -> Client | None: return None return cls(**row) - _columns = "id, homeserver, access_token, device_id, next_batch, filter_id, autojoin, flow" + _columns = ( + "id, homeserver, access_token, device_id, next_batch, filter_id, autojoin, enabled, flow" + ) @property def _values(self): @@ -44,12 +47,13 @@ def _values(self): self.next_batch, self.filter_id, self.autojoin, + self.enabled, self.flow, ) @classmethod async def all(cls) -> list[Client]: - rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client") + rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client WHERE enabled IS TRUE") return [cls._from_row(row) for row in rows] @classmethod @@ -59,14 +63,14 @@ async def get(cls, id: str) -> Client | None: @classmethod async def get_by_flow_id(cls, flow_id: int) -> list[Client]: - q = f"SELECT {cls._columns} FROM client WHERE flow=$1" + q = f"SELECT {cls._columns} FROM client WHERE flow=$1 AND enabled IS TRUE" rows = await cls.db.fetch(q, flow_id) return [cls._from_row(row) for row in rows] async def insert(self) -> None: q = ( "INSERT INTO client (id, homeserver, access_token, device_id, next_batch, filter_id, " - "autojoin, flow) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + "autojoin, enabled, flow) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" ) await self.db.execute(q, *self._values) @@ -80,7 +84,7 @@ async def get_next_batch(self) -> SyncToken: async def update(self) -> None: q = ( "UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, next_batch=$5, " - "filter_id=$6, autojoin=$7, flow=$8 WHERE id=$1" + "filter_id=$6, autojoin=$7, enabled=$8, flow=$9 WHERE id=$1" ) await self.db.execute(q, *self._values) diff --git a/menuflow/db/migrations.py b/menuflow/db/migrations.py index 2c2f31d..3991544 100644 --- a/menuflow/db/migrations.py +++ b/menuflow/db/migrations.py @@ -86,3 +86,8 @@ async def upgrade_v4(conn: Connection) -> None: await conn.execute( "ALTER TABLE client ADD CONSTRAINT FK_flow_client FOREIGN KEY (flow) references flow (id)" ) + + +@upgrade_table.register(description="Add enable column to client table") +async def upgrade_v5(conn: Connection) -> None: + await conn.execute("ALTER TABLE client ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE") diff --git a/menuflow/menu.py b/menuflow/menu.py index 51b8708..b21d96b 100644 --- a/menuflow/menu.py +++ b/menuflow/menu.py @@ -52,6 +52,7 @@ def __init__( next_batch: SyncToken = "", filter_id: FilterID = "", autojoin: bool = True, + enabled: bool = True, flow: int | None = None, ) -> None: super().__init__( @@ -62,6 +63,7 @@ def __init__( next_batch=next_batch, filter_id=filter_id, autojoin=bool(autojoin), + enabled=bool(enabled), flow=flow, ) self._postinited = False @@ -239,6 +241,11 @@ async def delete(self) -> None: pass await super().delete() + async def leave_rooms(self) -> None: + rooms = await self.matrix_handler.get_joined_rooms() + for room_id in rooms: + await self.matrix_handler.leave_room(room_id) + @classmethod async def all(cls) -> AsyncGenerator[MenuClient, None]: users = await super().all()