# Proof-of-concept notebook: have @AI listen for Slack messages, send them to the SageMaker endpoint, and stream the response back.


In [None]:
import io
class LineIterator:
    """Iterate over newline-delimited lines of a SageMaker response event stream.

    ``stream`` is an iterable of event dicts, e.g. the ``Body`` returned by
    ``invoke_endpoint_with_response_stream``. Bytes from events shaped like
    ``{'PayloadPart': {'Bytes': ...}}`` are accumulated in a buffer, because
    one line may be split across several events, or several lines may arrive
    in one event.

    Each ``next()`` returns one line as ``bytes`` with its trailing newline
    removed. If the stream ends with bytes after the last newline, those bytes
    are returned once as a final line. Events without a ``PayloadPart`` key
    are reported on stdout and skipped.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset in ``buffer`` of the first byte not yet returned as a line.
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # Stream ended mid-line: emit the unterminated remainder once.
                    # Looping back instead would spin forever, because an
                    # exhausted iterator keeps raising StopIteration.
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                # ``chunk`` is a dict; format it rather than concatenating
                # (str + dict raises TypeError).
                print(f'Unknown event type: {chunk}')
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])

In [None]:
# Streaming approach based on: https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
# Before running: `!gimme-aws-creds`, then `aws_switch sofi-bedrock-bdp-production`.
import boto3
import io
import json 
import sagemaker

# Default boto3 session: credentials come from boto3's standard lookup
# (presumably prepared by the gimme-aws-creds / aws_switch step noted above).
# To pin a profile explicitly, use
# boto3.Session(profile_name='sofi-bedrock-bdp-production') instead.
session = boto3.Session()
smr = session.client(service_name='sagemaker-runtime', region_name='us-west-2')
# NOTE(review): "AssumeBedrockPOC" - presumably the IAM role to assume; confirm.

def sagemaker_prompt(
    question,
    on_text_received,
    *,
    endpoint_name="Llama-2-13B-Chat-fp16-v7-2023-12-21-17-51-07-902",
):
    """Send *question* to a Llama-2 chat SageMaker endpoint and stream the answer.

    The question is wrapped in the Llama-2 chat template together with a fixed
    system message. The endpoint is then invoked with response streaming, and
    ``on_text_received(text)`` is called once per streamed token, in order.
    Tokens whose text equals the stop token are not forwarded.

    Args:
        question: User question text to send to the model.
        on_text_received: Callback invoked with each streamed text fragment.
        endpoint_name: SageMaker endpoint to invoke. Keyword-only; defaults to
            the endpoint this notebook was written against.

    Returns:
        None. All generated text is delivered through ``on_text_received``.
    """
    print(f"sagemaker question: {question}")
    system_message = "You are a helpful assistant that does not use superfluous pleasantries. Avoiding ending your reply with questions.  If a question does not make sense, call out why it doesn't make sense, and don't attempt to answer. If you don't know the answer to a question, do not make up an answer."
    # Llama-2 chat template. Built without a triple-quoted string so that no
    # stray source indentation ends up inside the text sent to the model.
    prompt = (
        f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
        f"{question} [/INST]"
    )

    stop_token = "<|endoftext|>"
    # Generation hyperparameters for the LLM.
    payload = {
        "inputs": prompt,
        "parameters": {
            "do_sample": True,
            "top_p": 0.7,
            "temperature": 0.7,
            "top_k": 10,
            "max_new_tokens": 500,
            "repetition_penalty": 1.03,
            "stop": [stop_token],
        },
        "stream": True,
    }
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType='application/json',
    )
    event_stream = resp['Body']
    start_json = b'{'
    try:
        for line in LineIterator(event_stream):
            # Each line is parsed as JSON starting at its first '{'
            # (presumably SSE-style b'data:{...}' lines). Lines without one
            # are skipped.
            json_start = line.find(start_json)
            if json_start == -1:
                continue
            data = json.loads(line[json_start:].decode('utf-8'))
            text = data['token']['text']
            if text != stop_token:
                on_text_received(text)
    finally:
        # Release the streaming HTTP body even if the callback raises mid-stream.
        event_stream.close()

def handle_text_received(text):
    """Stream callback: echo one fragment to stdout with no trailing newline."""
    print(text, sep="", end="")

# Smoke test: stream the answer to a trivial question straight to stdout.
sagemaker_prompt(question="What is 5 + 5?", on_text_received=handle_text_received)

In [None]:
import asyncio
import nest_asyncio
import threading
# NOTE(review): nest_asyncio patches asyncio so an event loop can be re-entered;
# presumably applied because Jupyter already runs a loop. Confirm it is still needed.
nest_asyncio.apply()
from slack_sdk.rtm_v2 import RTMClient

# Read the Slack API token from a local file, stripping surrounding whitespace
# such as a trailing newline.
with open('slack-api-token.txt', 'r') as token_file:
    slack_token = token_file.read().strip()

# RTM client; message listeners are attached to it via @rtm.on(...).
rtm = RTMClient(token=slack_token)

# Worker for the @AI listener below: answers one Slack message in its thread.
def _answer_in_thread(client: RTMClient, event: dict, ai_bot_user: str) -> None:
    """Stream a SageMaker answer to *event* into a threaded Slack reply.

    First posts the prompt (with the bot mention removed) as a reply in the
    message's thread. That reply is then edited in place as model text
    arrives, at most once per 20 or more new characters. A final edit posts
    any remaining text once generation has finished.
    """
    channel_id = event['channel']
    thread_ts = event['ts']

    prompt = event.get('text', '').replace(ai_bot_user, '').strip()
    cumulative_text = f"{prompt}\n"
    initial_response = client.web_client.chat_postMessage(
        channel=channel_id,
        text=cumulative_text,
        thread_ts=thread_ts,
    )
    # Timestamp identifying the reply we keep editing.
    initial_response_ts = initial_response['ts']
    last_update_length = len(cumulative_text)

    def push_update():
        """Replace the reply's text with everything accumulated so far."""
        nonlocal last_update_length
        client.web_client.chat_update(
            channel=channel_id,
            ts=initial_response_ts,
            text=cumulative_text,
            thread_ts=thread_ts,
        )
        last_update_length = len(cumulative_text)

    def handle_text_received(text):
        nonlocal cumulative_text
        # Llama-2's end-of-sequence marker "</s>" may come through the stream
        # as literal text; keep it out of the message shown to users.
        cumulative_text += text.replace("</s>", "")
        # Throttle edits: only update once enough new text has accumulated.
        if len(cumulative_text) - last_update_length >= 20:
            push_update()

    sagemaker_prompt(prompt, handle_text_received)
    # Flush the tail the throttle held back. Otherwise up to 19 characters are
    # lost whenever generation ends without a "</s>" marker.
    if len(cumulative_text) != last_update_length:
        push_update()


# Define the event handler
@rtm.on("message")
def handle(client: RTMClient, event: dict):
    """RTM "message" listener: answer messages that start with the @AI mention."""
    ai_bot_user = "<@U06B7U1478S>"
    if not event.get('text', '').startswith(ai_bot_user):
        return
    # Run the answer on its own thread so this listener returns immediately,
    # rather than waiting for the whole SageMaker stream.
    threading.Thread(
        target=_answer_in_thread,
        args=(client, event, ai_bot_user),
    ).start()
    
# Run the RTM client. slack_sdk.rtm_v2's RTMClient.start() is synchronous: it
# connects and then blocks this thread. No asyncio loop is involved.
# The previous asyncio.ensure_future(rtm.start()) never reached ensure_future
# because start() blocks, and since start() returns None, ensure_future would
# have raised TypeError had it ever returned.
try:
    rtm.start()
except KeyboardInterrupt:
    # Handle the interrupt gracefully
    print("Interrupted by user, shutting down.")
finally:
    # Drop the Slack connection so it does not outlive this cell.
    if rtm.is_connected():
        rtm.disconnect()