diff --git a/evadb/third_party/databases/slack/slack_handler.py b/evadb/third_party/databases/slack/slack_handler.py index 4022efb619..475b12d04b 100644 --- a/evadb/third_party/databases/slack/slack_handler.py +++ b/evadb/third_party/databases/slack/slack_handler.py @@ -12,12 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pandas as pd +import certifi +import ssl from slack_sdk import WebClient from slack_sdk.errors import SlackApiError -from evadb.third_party.types import DBHandler, DBHandlerResponse, DBHandlerStatus +from evadb.third_party.databases.types import ( + DBHandler, + DBHandlerResponse, + DBHandlerStatus, +) class SlackHandler(DBHandler): @@ -28,16 +33,15 @@ def __init__(self, name: str, **kwargs): def connect(self): try: - self.client = WebClient(token=self.token) + ssl_context = ssl.create_default_context(cafile=certifi.where()) + self.client = WebClient(token=self.token, ssl=ssl_context) return DBHandlerStatus(status=True) except Exception as e: return DBHandlerStatus(status=False, error=str(e)) def disconnect(self): - """ - TODO: integrate code for disconnecting from slack - """ - raise NotImplementedError() + + pass def check_connection(self) -> DBHandlerStatus: try: @@ -46,6 +50,18 @@ def check_connection(self) -> DBHandlerStatus: assert e.response["ok"] is False return False return True + + def channel_list(self): + channels_list = {} + if self.client: + channels = self.client.conversations_list( + types="public_channel,private_channel" + )["channels"] + # self.channel_names = [c["name"] for c in channels] + for c in channels: + channels_list[c["name"]] = c["id"] + return channels_list + def get_tables(self) -> DBHandlerResponse: if self.client: @@ -55,24 +71,23 @@ def get_tables(self) -> DBHandlerResponse: self.channel_names = [c["name"] for c in channels] tables_df = pd.DataFrame(self.channel_names, columns=["table_name"]) return DBHandlerResponse(data=tables_df) - + def get_columns(self, table_name: str) -> DBHandlerResponse: columns = [ - "ts", - "text", - "message_created_at", - "user", - "channel", - "reactions", - "attachments", - "thread_ts", - "reply_count", - "reply_users_count", - "latest_reply", - "subtype", - "hidden", + ["ts", str], + ["text", str], + ["user",str], + ["channel", str], + ["reactions", str], + ["attachments", str], + ["thread_ts", str], + ["reply_count", str], + ["reply_users_count", str], + ["latest_reply", str], + ["subtype", str], + ["hidden", str], ] - columns_df = pd.DataFrame(columns, columns=["column_name"]) + columns_df = pd.DataFrame(columns, columns=["name", "dtype"]) return DBHandlerResponse(data=columns_df) def post_message(self, message) -> DBHandlerResponse: @@ -122,8 +137,53 @@ def del_message(self, timestamp) -> DBHandlerResponse: assert e.response["error"] return DBHandlerResponse(data=None, error=e.response["error"]) - def execute_native_query(self, query_string: str) -> DBHandlerResponse: + def supported_table(self, channel): + def slack_generator(): + channel_list = self.channel_list() + for message in self.client.conversations_history(channel=channel_list[channel])["messages"]: + if message["type"] == "message": + yield { + "user": message.get("user", "N/A"), + "ts": message.get("ts", "N/A"), + "text": message.get("text", "N/A"), + "channel": message.get("channel", "N/A"), + "reactions": message.get("reactions", "N/A"), + "attachments": message.get("attachments", "N/A"), + "thread_ts": message.get("thread_ts", "N/A"), + "reply_count": message.get("reply_count", "N/A"), + "reply_users_count": message.get("reply_users_count", "N/A"), + "latest_reply": message.get("latest_reply", "N/A"), + "subtype": message.get("subtype", "N/A"), + "hidden": message.get("hidden", "N/A"), + } + + mapping = { + "random": { + "columns": ["user", "ts", "text", "channel", "reactions", "attachments", "thread_ts", + "reply_count", "reply_users_count", "latest_reply", "subtype", "hidden"], + "generator": slack_generator(), + }, + } + return mapping + + def select(self, table_name: str) -> DBHandlerResponse: """ - TODO: integrate code for executing query on slack + Returns a generator that yields the data from the given table. + Args: + table_name (str): name of the table whose data is to be retrieved. + Returns: + DBHandlerResponse """ - raise NotImplementedError() + + + if not self.client: + return DBHandlerResponse(data=None, error="Not connected to the database.") + try: + table = self.supported_table(channel=table_name) + return DBHandlerResponse( + data=None, + data_generator=table["random"]["generator"], + ) + except Exception as e: + return DBHandlerResponse(data=None, error=str(e)) +