In [1]:
import asyncio
import codecs
from functools import lru_cache, partial
from typing import Awaitable, Callable, List, Tuple

import bittensor as bt
import pydantic
from starlette.responses import Response, StreamingResponse
from starlette.types import Send
from transformers import GPT2Tokenizer

# Turn on bittensor's debug output. This is presumably what makes the DEBUG-level
# dendrite/axon and "Streamed token" log lines in the next cell's output visible.
bt.debug()


# This is a subclass of StreamingSynapse for prompting network functionality
class StreamPrompting(bt.StreamingSynapse):
    """
    StreamPrompting is a subclass of StreamingSynapse that is specifically designed for prompting network functionality.
    It overrides abstract methods from the parent class to provide concrete implementations for processing streaming responses,
    deserializing the response, and extracting JSON data.

    Attributes:
        roles: List of roles associated with the prompt. Declared immutable (``allow_mutation=False``).
        messages: List of messages to be processed. Declared immutable (``allow_mutation=False``).
        completion: Text accumulated from the streamed response; starts as an empty string.
    """

    roles: List[str] = pydantic.Field(
        ...,
        title="Roles",
        description="A list of roles in the Prompting scenario. Immutable.",
        allow_mutation=False,
    )

    messages: List[str] = pydantic.Field(
        ...,
        title="Messages",
        description="A list of messages in the Prompting scenario. Immutable.",
        allow_mutation=False,
    )

    completion: str = pydantic.Field(
        "",
        title="Completion",
        description="Completion status of the current Prompting object. This attribute is mutable and can be updated.",
    )

    async def process_streaming_response(self, response):
        """
        Asynchronously read a streaming response and append its text to ``completion``.

        Chunks yielded by ``response.content.iter_any()`` are arbitrary slices of the byte stream, so they are
        decoded with an incremental UTF-8 decoder. A multi-byte character split across two chunks is reassembled
        instead of raising. Newline characters act as token separators (the ``prompt`` handler in this file
        terminates every token with one) and are dropped. All other text is appended in arrival order.

        Args:
            response: The streaming HTTP response. Only ``response.content.iter_any()`` is used.

        Raises:
            UnicodeDecodeError: (a ``ValueError``) if the stream contains invalid UTF-8 or ends in the middle
                of a multi-byte character.
        """
        if self.completion is None:
            self.completion = ""
        # Decoding each chunk on its own would fail whenever a chunk boundary falls
        # inside a multi-byte character. The incremental decoder buffers the partial bytes.
        decoder = codecs.getincrementaldecoder("utf-8")()
        async for chunk in response.content.iter_any():
            # Equivalent to splitting on '\n' and concatenating the non-empty pieces.
            self.completion += decoder.decode(chunk).replace("\n", "")
        # Flush any buffered bytes. This raises if the stream stopped mid-character.
        self.completion += decoder.decode(b"", final=True).replace("\n", "")

    def deserialize(self):
        """
        Deserializes the response by returning the completion attribute.

        Returns:
            str: The completion result.
        """
        return self.completion

    def extract_response_json(self, response):
        """
        Summarise a response's headers and this synapse's fields as a plain dictionary.

        Header names and values are decoded as UTF-8. Headers starting with ``bt_header_dendrite`` and
        ``bt_header_axon`` are grouped into nested dictionaries keyed by the text after the last ``'_'``.
        Missing ``name``/``timeout``/``total_size``/``header_size`` headers fall back to ``''``/``0``.

        Args:
            response: The HTTP response whose headers are read. It is assumed to be an aiohttp
                ``ClientResponse``, as the ``content.iter_any()`` usage elsewhere in this class suggests.

        Returns:
            dict: ``name``, ``timeout`` (float), ``total_size``/``header_size`` (int), ``dendrite`` and ``axon``
                (dicts of header values), plus this synapse's ``roles``, ``messages`` and ``completion``.

        Raises:
            ValueError: If ``timeout``, ``total_size`` or ``header_size`` are present but not numeric.
        """
        # `raw_headers` is the public accessor for the undecoded (name, value) byte pairs.
        # Using it avoids reaching into the private `_raw_headers` attribute via __dict__.
        headers = {k.decode('utf-8'): v.decode('utf-8') for k, v in response.raw_headers}

        def extract_info(prefix):
            # NOTE: only the text after the last '_' is kept as the key. A multi-word suffix
            # (e.g. a header named 'bt_header_axon_status_code') would therefore collapse to 'code'.
            return {key.split('_')[-1]: value for key, value in headers.items() if key.startswith(prefix)}

        return {
            "name": headers.get('name', ''),
            "timeout": float(headers.get('timeout', 0)),
            "total_size": int(headers.get('total_size', 0)),
            "header_size": int(headers.get('header_size', 0)),
            "dendrite": extract_info('bt_header_dendrite'),
            "axon": extract_info('bt_header_axon'),
            "roles": self.roles,
            "messages": self.messages,
            "completion": self.completion,
        }

@lru_cache(maxsize=1)
def _load_gpt2_tokenizer():
    """Load the GPT-2 tokenizer once and reuse it for every request (from_pretrained is expensive)."""
    return GPT2Tokenizer.from_pretrained('gpt2')


# This should encapsulate all the logic for generating a streaming response
def prompt(synapse: StreamPrompting) -> StreamPrompting:
    """
    Build a streaming response for the given StreamPrompting synapse.

    Only the first message is used. It is tokenized with the GPT-2 tokenizer. Each token is then decoded
    back to text and streamed to the client, as a stand-in for real model inference, with a one-second
    pause after each token.

    Args:
        synapse (StreamPrompting): The incoming request. ``synapse.messages[0]`` is the text streamed back.
            An empty ``messages`` list yields an empty stream instead of an ``IndexError``.

    Returns:
        Whatever ``synapse.create_streaming_response`` returns for the token streamer.

    Note:
        NOTE(review): the return annotation says ``StreamPrompting`` although the value comes from
        ``create_streaming_response``. It is left unchanged in case ``axon.attach`` inspects the
        signature; confirm.
    """
    tokenizer = _load_gpt2_tokenizer()

    def model(ids):
        # Simulated inference: echo each input token back as its decoded text.
        return (tokenizer.decode(token_id) for token_id in ids)

    async def _prompt(text: str, send: Send):
        """
        Tokenize ``text`` and stream each decoded token to the client.

        Each token is sent as one ASGI ``http.response.body`` message with ``more_body=True``. The body is
        UTF-8 encoded and terminated by a newline, and each send is followed by a one-second pause so the
        streaming is visible.

        Args:
            text (str): The message text to echo back token by token.
            send (Send): The ASGI send callable used to emit the body chunks.
        """
        # A plain list of ints stays iterable for any length. The previous
        # `return_tensors="pt"` + `.squeeze()` collapsed a single-token input to a
        # 0-d tensor, and iterating that raises TypeError.
        input_ids = tokenizer(text).input_ids
        for token in model(input_ids):
            await send({"type": "http.response.body", "body": (token + '\n').encode('utf-8'), "more_body": True})
            bt.logging.debug(f"Streamed token: {token}")
            # Sleep to show the streaming effect
            await asyncio.sleep(1)

    message = synapse.messages[0] if synapse.messages else ""
    token_streamer = partial(_prompt, message)
    return synapse.create_streaming_response(token_streamer)

def blacklist(synapse: StreamPrompting) -> Tuple[bool, str]:
    """
    Decide whether an incoming StreamPrompting request should be rejected.

    This demo accepts every request; the synapse is not inspected at all.

    Args:
        synapse: The incoming StreamPrompting request.

    Returns:
        Tuple[bool, str]: ``(False, "")`` for every request, meaning not blacklisted with an empty reason.
    """
    is_blacklisted = False
    reason = ""
    return is_blacklisted, reason

def priority(synapse: StreamPrompting) -> float:
    """
    Assign a scheduling priority to an incoming StreamPrompting request.

    Every request gets the same priority; the synapse is not inspected.

    Args:
        synapse: The incoming StreamPrompting request.

    Returns:
        float: Always ``0.0``.
    """
    default_priority = 0.0
    return default_priority


In [2]:
# Create an Axon instance on port 8099.
axon = bt.axon(port=8099)

# Attach the forward, blacklist, and priority functions to the Axon.
# forward_fn: The function to handle forwarding logic.
# blacklist_fn: The function to determine if a request should be blacklisted.
# priority_fn: The function to determine the priority of the request.
axon.attach(
    forward_fn=prompt,
    blacklist_fn=blacklist,
    priority_fn=priority
)

# Start the Axon to begin listening for requests.
axon.start()

# Create a Dendrite instance to handle client-side communication.
d = bt.dendrite()

# Send a request to the Axon using the Dendrite, passing in a StreamPrompting instance with roles and messages.
# The response is awaited, as the Dendrite communicates asynchronously with the Axon.
resp = await d(
    [axon],
    StreamPrompting(roles=["user"], messages=["hello this is a test of a streaming response."])
)

# The response object contains the result of the streaming operation.
resp


[34m2023-09-15 20:57:00.086[0m | [34m[1m     DEBUG      [0m | dendrite | --> | 3468 B | StreamPrompting | 5C86aJ2uQawR6P6veaJQXNK9HaWh6NMbUhTiLs65kq4ZW3NH | 216.153.62.113:8099 | 0 | Success
[34m2023-09-15 20:57:00.097[0m | [34m[1m     DEBUG      [0m | axon     | <-- | 843 B | StreamPrompting | 5C86aJ2uQawR6P6veaJQXNK9HaWh6NMbUhTiLs65kq4ZW3NH | 127.0.0.1:46428 | 200 | Success 


[34m2023-09-15 20:57:00.465[0m | [34m[1m     DEBUG      [0m | axon     | --> | -1 B | StreamPrompting | 5C86aJ2uQawR6P6veaJQXNK9HaWh6NMbUhTiLs65kq4ZW3NH | 127.0.0.1:46428  | 200 | Success
[34m2023-09-15 20:57:00.467[0m | [34m[1m     DEBUG      [0m | Streamed token: hello         
[34m2023-09-15 20:57:01.468[0m | [34m[1m     DEBUG      [0m | Streamed token:  this         
[34m2023-09-15 20:57:02.470[0m | [34m[1m     DEBUG      [0m | Streamed token:  is           
[34m2023-09-15 20:57:03.470[0m | [34m[1m     DEBUG      [0m | Streamed token:  a            
[34m2023-09-15 20:57:04.472[0m | [34m[1m     DEBUG      [0m | Streamed token:  test         
[34m2023-09-15 20:57:05.473[0m | [34m[1m     DEBUG      [0m | Streamed token:  of           
[34m2023-09-15 20:57:06.473[0m | [34m[1m     DEBUG      [0m | Streamed token:  a            
[34m2023-09-15 20:57:07.475[0m | [34m[1m     DEBUG      [0m | Streamed token:  streaming    
[34m2023-09-15 20:57:0

['hello this is a test of a streaming response.']