In [None]:
import bittensor as bt
from starlette.responses import StreamingResponse as _StreamingResponse
from starlette.types import Send
from functools import partial
from typing import Callable, Awaitable, List
import pydantic
import asyncio

# NOTE(review): presumably switches bittensor to debug-level logging for this
# session -- confirm against the bittensor logging docs.
bt.debug()

class BTStreamingResponse(_StreamingResponse):
    """Streaming response whose body is written by a caller-supplied coroutine.

    Instead of iterating over a content iterable, the response hands the ASGI
    ``send`` callable to ``token_streamer`` and lets it emit the body messages.
    """

    def __init__(self, token_streamer: Callable[[Send], Awaitable[None]], **kwargs) -> None:
        """Store ``token_streamer`` and forward ``kwargs`` to the parent class.

        Args:
            token_streamer: Awaitable callable invoked with the ASGI ``send``
                callable; it is expected to send the body messages itself.
            **kwargs: Passed through unchanged to the parent constructor.
        """
        # The parent is given an empty iterator as content because the body is
        # produced by token_streamer, not by iterating content.
        super().__init__(content=iter(()), **kwargs)
        self.token_streamer = token_streamer

    async def stream_response(self, send: Send) -> None:
        """Send the response start, run the token streamer, then close the body.

        The start message always uses status 200 and puts a
        ``text/event-stream`` content-type header ahead of ``self.raw_headers``.
        """
        start_message = {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-type", b"text/event-stream")] + self.raw_headers,
        }
        await send(start_message)

        # The streamer writes the body messages directly through `send`.
        await self.token_streamer(send)

        # An empty body with more_body=False marks the end of the response.
        closing_message = {"type": "http.response.body", "body": b"", "more_body": False}
        await send(closing_message)

    async def __call__(self, scope, receive, send):
        """ASGI entry point; ``scope`` and ``receive`` are not used."""
        await self.stream_response(send)

class StreamPrompting(bt.Synapse):
    """Synapse carrying a list of roles/messages and a streamed completion.

    ``roles`` and ``messages`` are required and declared with
    ``allow_mutation=False``. ``completion`` defaults to ``None`` and is filled
    in by :meth:`process_streaming_response`.
    """

    class Config:
        # pydantic option: validate values on attribute assignment as well.
        validate_assignment = True

    def deserialize(self):
        """Return ``self.completion`` (``None`` if nothing was streamed yet)."""
        return self.completion

    roles: List[str] = pydantic.Field(..., allow_mutation=False)
    messages: List[str] = pydantic.Field(..., allow_mutation=False)
    completion: str = None

    async def process_streaming_response(self, response):
        """Append the streamed body of ``response`` to ``self.completion``.

        The body is read via ``response.content.iter_any()``, decoded as
        UTF-8, and appended with every newline character removed (newlines
        are the token separator on the wire).

        Decoding is incremental: chunk boundaries are arbitrary, so a
        multi-byte character may be split across two chunks. Decoding each
        chunk on its own would raise ``UnicodeDecodeError`` in that case.

        Args:
            response: Object exposing ``content.iter_any()`` as an async
                iterator of ``bytes`` chunks.

        Raises:
            UnicodeDecodeError: If the stream contains invalid UTF-8 or ends
                in the middle of a multi-byte character.
        """
        import codecs  # local import keeps this block self-contained

        if self.completion is None:
            self.completion = ""

        decoder = codecs.getincrementaldecoder("utf-8")()
        async for chunk in response.content.iter_any():
            # Dropping every '\n' is equivalent to splitting on '\n' and
            # concatenating the non-empty tokens.
            text = decoder.decode(chunk).replace("\n", "")
            if text:
                self.completion += text

        # Flush the decoder; this raises if bytes of an unfinished character
        # are still buffered when the stream ends.
        decoder.decode(b"", final=True)

    def extract_response_json(self, response):
        """Build a plain dict from ``response`` headers and this synapse's fields.

        Header names and values are read from ``response._raw_headers`` (via
        ``response.__dict__``) and decoded as UTF-8.

        Args:
            response: Object whose ``__dict__`` has a ``"_raw_headers"`` entry
                iterable as ``(bytes, bytes)`` pairs.

        Returns:
            dict with ``name``, ``timeout`` (float), ``total_size`` and
            ``header_size`` (int) taken from the headers (defaulting to
            ``''`` / ``0`` when absent), the ``dendrite`` and ``axon`` header
            groups, and this synapse's ``roles``, ``messages`` and
            ``completion``.
        """
        headers = {k.decode('utf-8'): v.decode('utf-8') for k, v in response.__dict__["_raw_headers"]}

        def extract_info(prefix):
            # Keep headers starting with `prefix`, keyed by the text after the
            # last underscore of the header name.
            # NOTE(review): header names whose field part itself contains an
            # underscore would be shortened to their final segment -- confirm
            # this is what callers expect.
            return {key.split('_')[-1]: value for key, value in headers.items() if key.startswith(prefix)}

        return {
            "name": headers.get('name', ''),
            "timeout": float(headers.get('timeout', 0)),
            "total_size": int(headers.get('total_size', 0)),
            "header_size": int(headers.get('header_size', 0)),
            "dendrite": extract_info('bt_header_dendrite'),
            "axon": extract_info('bt_header_axon'),
            "roles": self.roles,
            "messages": self.messages,
            "completion": self.completion,
        }

def tokenizer(text):
    """Lazily yield the Unicode code point of each character in ``text``."""
    for character in text:
        yield ord(character)

def model(ids):
    """Lazily yield the character for each integer code point in ``ids``."""
    for code_point in ids:
        yield chr(code_point)

async def prompt(text: str, send: Send):
    """Stream ``text`` back through ``send`` one token per body message.

    ``text`` is passed through ``tokenizer`` and then ``model``; each resulting
    token is sent as a UTF-8 encoded line (token followed by a newline) in an
    ``http.response.body`` message with ``more_body`` set to True.
    """
    for piece in model(tokenizer(text)):
        payload = (piece + '\n').encode('utf-8')
        await send({"type": "http.response.body", "body": payload, "more_body": True})

def generate_streaming_response(synapse: StreamPrompting) -> BTStreamingResponse:
    """Return a streaming response for the first message of ``synapse``.

    Only ``synapse.messages[0]`` is used; it is bound as the ``text`` argument
    of ``prompt`` so that the response only needs to supply ``send``.
    """
    first_message = synapse.messages[0]
    return BTStreamingResponse(partial(prompt, first_message))

def blacklist(synapse: StreamPrompting) -> bool:
    """Blacklist hook for the axon: always returns False, ignoring ``synapse``."""
    return False

def priority(synapse: StreamPrompting) -> float:
    """Priority hook for the axon: always returns 0.0, ignoring ``synapse``."""
    return 0.0

In [None]:

axon = bt.axon(port=8099)
axon.attach(
    forward_fn = generate_streaming_response,
    blacklist_fn = blacklist,
    priority_fn = priority
)
axon.start()

d = bt.dendrite()
resp = await d(
    [ axon ],
    StreamPrompting(roles=["user"], messages=["hello this is a test of streaming."])
)

In [None]:
# Bare expression as the last line of the cell: the notebook displays the
# value returned by the dendrite call.
resp