In [None]:
import asyncio
import codecs
from functools import partial
from typing import Callable, Awaitable, List

import bittensor as bt
import pydantic
from starlette.types import Send

# Turn on bittensor trace-level logging; presumably this is what makes the
# bt.logging.trace(...) calls in sforward visible — confirm.
bt.trace()


# Synapse type that carries a streamed prompt/completion exchange between a dendrite and an axon.
class StreamPrompting(bt.StreamingSynapse):
    """
    StreamPrompting is a subclass of StreamingSynapse designed for streaming prompt/completion exchanges.
    It overrides abstract methods from the parent class to provide concrete implementations for processing
    streaming responses, deserializing the response, and extracting JSON data.

    Attributes:
        roles: List of roles associated with the prompt (immutable after construction).
        messages: List of messages to be processed (immutable after construction).
        completion: Text accumulated from the streamed response; None until a response has been processed.
    """

    roles: List[str] = pydantic.Field(..., allow_mutation=False)
    messages: List[str] = pydantic.Field(..., allow_mutation=False)
    # None until process_streaming_response runs; then the accumulated text.
    completion: str = None

    async def process_streaming_response(self, response):
        """
        Consumes the streamed response body and appends its text to the completion attribute.

        Each token is framed by a trailing '\\n', so newline characters are treated as
        separators and dropped. As a consequence, a literal newline inside the streamed
        content cannot be recovered.

        The bytes are decoded with an incremental UTF-8 decoder. Chunk boundaries from
        iter_any() are arbitrary and can split a multi-byte character. Decoding each chunk
        on its own would raise UnicodeDecodeError in that case.

        Args:
            response: The response object whose ``content.iter_any()`` yields byte chunks
                (presumably an aiohttp ClientResponse — confirm).

        Raises:
            UnicodeDecodeError: If the stream contains invalid UTF-8 or ends mid-character.
        """
        if self.completion is None:
            self.completion = ""
        decoder = codecs.getincrementaldecoder("utf-8")()
        async for chunk in response.content.iter_any():
            # Incomplete trailing bytes are buffered by the decoder until the next chunk.
            self.completion += decoder.decode(chunk).replace("\n", "")
        # Flush whatever is still buffered; raises if the stream ended mid-character.
        self.completion += decoder.decode(b"", final=True).replace("\n", "")

    def deserialize(self):
        """
        Deserializes the response by returning the completion attribute.

        Returns:
            str: The completion result (None if no response has been processed).
        """
        return self.completion

    def extract_response_json(self, response):
        """
        Extracts JSON data from the response, including headers and dendrite/axon information.

        Args:
            response: The response object from which to extract JSON data. It must expose
                ``_raw_headers`` (a sequence of (bytes, bytes) pairs) in its ``__dict__``.

        Returns:
            dict: The name, timeout, total_size and header_size header values; the dendrite
            and axon sub-dicts built from headers with the matching ``bt_header_*`` prefix;
            and this synapse's roles, messages and completion.
        """
        # NOTE(review): aiohttp exposes the same data via the public `raw_headers`
        # property; the private-dict access is kept to preserve existing behavior.
        headers = {k.decode('utf-8'): v.decode('utf-8') for k, v in response.__dict__["_raw_headers"]}

        def extract_info(prefix):
            # NOTE(review): only the text after the LAST '_' becomes the key, so a header
            # like "<prefix>_status_code" is reported as "code".
            return {key.split('_')[-1]: value for key, value in headers.items() if key.startswith(prefix)}

        return {
            "name": headers.get('name', ''),
            "timeout": float(headers.get('timeout', 0)),
            "total_size": int(headers.get('total_size', 0)),
            "header_size": int(headers.get('header_size', 0)),
            "dendrite": extract_info('bt_header_dendrite'),
            "axon": extract_info('bt_header_axon'),
            "roles": self.roles,
            "messages": self.messages,
            "completion": self.completion,
        }

# Axon forward handler: encapsulates all the logic for generating the streaming response.
def sforward(synapse: StreamPrompting) -> StreamPrompting:
    """
    Builds a streaming response that echoes the first message back one character at a time.

    The toy "tokenizer" maps each character to its code point, and the toy "model" maps
    code points back to characters, so the streamed output reproduces ``synapse.messages[0]``.
    Each character is sent as its own body chunk followed by '\\n'. There is a 0.5 s pause
    between chunks so the streaming effect is visible.

    Args:
        synapse: A StreamPrompting instance; only ``messages[0]`` is used.

    Returns:
        The object returned by ``synapse.create_streaming_response``. NOTE(review): despite
        the ``-> StreamPrompting`` annotation, this is a streaming response, not the synapse.
        The annotation is left untouched because ``axon.attach`` may inspect the handler's
        signature — confirm before changing it.

    Raises:
        IndexError: If ``synapse.messages`` is empty.
    """
    def tokenizer(text):
        return (ord(char) for char in text)

    def model(ids):
        # `token_id` rather than `id` to avoid shadowing the builtin.
        return (chr(token_id) for token_id in ids)

    async def prompt(text: str, send: Send):
        # Simulate model inference.
        input_ids = tokenizer(text)
        for token in model(input_ids):
            # '\n' frames each token for the client-side parser.
            await send({"type": "http.response.body", "body": (token + '\n').encode('utf-8'), "more_body": True})
            bt.logging.trace(f"Streamed token: {token}")
            # Sleep to show the streaming effect.
            await asyncio.sleep(0.5)

    message = synapse.messages[0]
    token_streamer = partial(prompt, message)
    return synapse.create_streaming_response(token_streamer)

def blacklist(synapse: StreamPrompting) -> bool:
    """
    Decides whether an incoming StreamPrompting request should be rejected.

    This demo accepts every request, so the synapse is never inspected.

    Args:
        synapse: The incoming StreamPrompting request.

    Returns:
        bool: Always False, meaning nothing is blacklisted.
    """
    should_reject = False
    return should_reject

def priority(synapse: StreamPrompting) -> float:
    """
    Assigns a priority to an incoming StreamPrompting request.

    Every request receives the same default priority; the synapse is not inspected.

    Args:
        synapse: The incoming StreamPrompting request.

    Returns:
        float: Always 0.0.
    """
    default_priority = 0.0
    return default_priority


In [None]:
# Create an Axon instance on port 8099.
axon = bt.axon(port=8099)

# Attach the forward, blacklist, and priority functions to the Axon.
# forward_fn: The function to handle forwarding logic.
# blacklist_fn: The function to determine if a request should be blacklisted.
# priority_fn: The function to determine the priority of the request.
axon.attach(
    forward_fn=sforward,
    blacklist_fn=blacklist,
    priority_fn=priority
)

# Start the Axon to begin listening for requests.
axon.start()

# Create a Dendrite instance to handle client-side communication.
d = bt.dendrite()

# Send a request to the Axon using the Dendrite, passing in a StreamPrompting instance with roles and messages.
# The response is awaited, as the Dendrite communicates asynchronously with the Axon.
resp = await d(
    [axon],
    StreamPrompting(roles=["user"], messages=["hello this is a test of streaming."])
)

# The response object contains the result of the streaming operation.
resp
