# Streaming

Streaming enables users to apply local functions and datasets during remote model execution. This allows users to stream results for immediate consumption (e.g., seeing tokens as they are generated) or to use resources that are not available on the server, such as non-whitelisted functions (e.g., model tokenizers), large local datasets, and more!

*   The `nnsight.local()` context immediately sends values from the server to the user's local machine
*   Downstream nodes of the intervention graph are then executed locally
*   Exiting the local context uploads data back to the server
*   The `@nnsight.trace` function decorator enables custom functions to be added to the intervention graph when using `nnsight.local()`
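
The points above can be pictured with a small toy model. This is only a sketch of the idea, not how nnsight or NDIF is implemented: `Step`, `ToyRun`, and the `local` flag are hypothetical names invented for this illustration. Steps run in order, and a "download" or "upload" is recorded whenever execution crosses into or out of a local block.

```python
# Toy illustration only: a pretend run that executes steps in order and
# records a transfer whenever execution moves between server and client.
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Step:
    fn: Callable[[float], float]
    local: bool = False  # True -> run on the client, as inside nnsight.local()


@dataclass
class ToyRun:
    transfers: List[str] = field(default_factory=list)

    def execute(self, value: float, steps: List[Step]) -> float:
        on_client = False
        for step in steps:
            if step.local and not on_client:
                self.transfers.append("download")  # entering the local block
                on_client = True
            elif not step.local and on_client:
                self.transfers.append("upload")  # leaving the local block
                on_client = False
            value = step.fn(value)
        if on_client:
            self.transfers.append("upload")  # local block was the last step
        return value


run = ToyRun()
result = run.execute(
    3.0,
    [
        Step(lambda x: x + 1),              # remote
        Step(lambda x: x * 0, local=True),  # local, e.g. a custom ablation
        Step(lambda x: x + 2),              # remote again
    ],
)
print(result, run.transfers)
```

In this toy run the value is downloaded once, modified locally, and uploaded once before the remaining remote step uses it.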

## Setup

In [1]:
# if running in Google Colab, install nnsight
try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight

## `nnsight.local()`

You may sometimes want to locally access and manipulate values during remote execution. By entering the `nnsight.local()` context, you can send remote content to your local machine and apply local functions to it. Downstream nodes of the intervention graph are then executed locally until you exit the `nnsight.local()` context, which sends execution back to the remote server.

There are a few use cases for streaming with `.local()`, including live chat generation and applying large datasets or non-whitelisted local functions to the intervention graph.

Now let's explore how streaming works. We'll start by grabbing some hidden states of the model and printing their value using `tracer.log()`. Without calling `nnsight.local()`, these operations will all occur remotely.

In [6]:
from nnsight import CONFIG
from IPython.display import clear_output

if is_colab:
    # include your HuggingFace Token and NNsight API key on Colab secrets
    from google.colab import userdata
    NDIF_API = userdata.get('NDIF_API')
    HF_TOKEN = userdata.get('HF_TOKEN')

    CONFIG.set_default_api_key(NDIF_API)
    !huggingface-cli login --token {HF_TOKEN}

clear_output()

In [7]:
from nnsight import LanguageModel
llama = LanguageModel("meta-llama/Meta-Llama-3.1-70B")

In [None]:
# This will give you a remote LOG response because it's coming from the remote server
with llama.trace("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    tracer.log(hs[0,0,0])

    out =  llama.lm_head.output.save()

print(out)

In [None]:
import nnsight
# This will print locally because the value is sent to your machine inside nnsight.local()
with llama.trace("hello", remote=True) as tracer:

    with nnsight.local():
        hs = llama.model.layers[-1].output[0]
        tracer.log(hs[0,0,0])

    out =  llama.lm_head.output.save()

print(out)

## `@nnsight.trace` function decorator

We can also use function decorators to create custom functions that can be used inside an `nnsight.local()` context. This is a handy way to enable live streaming of a chat or to train probing classifiers on model hidden states.

Let's try out `@nnsight.trace` and `nnsight.local()` to access a custom function during remote execution.

In [None]:
# first, let's define our function
@nnsight.trace # decorator that enables this function to be added to the intervention graph
def my_local_fn(value):
    return value * 0

# We use a local function to ablate some hidden states
# This downloads the data for the .local context, and then uploads it back to set the value.
with llama.generate("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    with nnsight.local():

        hs = my_local_fn(hs)

    llama.model.layers[-1].output[0][:] = hs

    out =  llama.lm_head.output.save()

Note that without `nnsight.local()`, the remote API does not know about `my_local_fn` and will throw a whitelist error. A whitelist error occurs because you are not allowed to run that function on the remote server.

In [None]:
with llama.trace("hello", remote=True) as tracer:

    hs = llama.model.layers[-1].output[0]

    hs = my_local_fn(hs) # no .local - will cause an error

    llama.model.layers[-1].output[0][:] = hs * 2

    out =  llama.lm_head.output.save()

print(out)
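
To make the idea of a whitelist error concrete, here is a toy sketch. It is not NDIF's actual validation code: `ALLOWED_REMOTE_FUNCTIONS`, `WhitelistError`, `run_remotely`, and `run_locally` are all hypothetical names made up for this illustration. The pretend server refuses any function it has not approved in advance, while the same function runs without complaint on the local machine.

```python
# Toy illustration only; not NDIF's actual validation code.
ALLOWED_REMOTE_FUNCTIONS = {"abs", "max", "min"}  # hypothetical whitelist


class WhitelistError(Exception):
    """Raised when a function is not approved for remote execution."""


def run_remotely(fn, value):
    # The pretend server only runs functions it already knows about.
    if fn.__name__ not in ALLOWED_REMOTE_FUNCTIONS:
        raise WhitelistError(f"{fn.__name__!r} is not whitelisted")
    return fn(value)


def run_locally(fn, value):
    # On your own machine there is nothing to check.
    return fn(value)


def my_local_fn(value):
    return value * 0


print(run_remotely(abs, -3))  # approved function: runs on the "server"
try:
    run_remotely(my_local_fn, 5)  # custom function: rejected
except WhitelistError as err:
    print("remote call rejected:", err)
print(run_locally(my_local_fn, 5))  # the same function works locally
```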

## Example: Live-streaming remote chat

Now that we can access data within the tracing context on our local computer, we can apply non-whitelisted functions, such as the model's tokenizer, within our tracing context.

Let's build a decoding function that will decode tokens into words and print the result.

In [None]:
@nnsight.trace
def my_decoding_function(tokens, model, max_length=80, state=None):
    # Initialize state if not provided
    if state is None:
        state = {'current_line': '', 'current_line_length': 0}

    token = tokens[-1] # only use last token

    # Decode the token
    # Use the tokenizer of the model passed in, rather than the global `llama`
    decoded_token = model.tokenizer.decode(token).encode("unicode_escape").decode()

    if decoded_token == '\\n':  # Handle explicit newline tokens
        # End the current line and reset state
        print('',flush=True)
        state['current_line'] = ''
        state['current_line_length'] = 0
    else:
        # Check if adding the token would exceed the max length
        if state['current_line_length'] + len(decoded_token) > max_length:
            print('',flush=True)
            state['current_line'] = decoded_token  # Start a new line with the current token
            state['current_line_length'] = len(decoded_token)
            print(state['current_line'], flush=True, end="")  # Print the current line
        else:
            # Append the token to the current line (no extra space is inserted)
            state['current_line'] += decoded_token
            state['current_line_length'] += len(decoded_token)
            print(state['current_line'], flush=True, end="")  # Print the current line

    return state
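
The wrapping logic can be tried without a model or a remote connection. The sketch below is a simplified, model-free variant rather than the function above: `escape_token` and `stream_wrapped` are hypothetical helpers, and the input is a made-up list of already-decoded text pieces. It uses the same `unicode_escape` trick, which turns a real newline into the two characters `\n` so it can be compared against `'\\n'`. Note that this encoding also renders non-ASCII characters as escape sequences.

```python
import io


def escape_token(text: str) -> str:
    # Make control characters visible: a real newline becomes
    # the two characters backslash + "n".
    return text.encode("unicode_escape").decode()


def stream_wrapped(pieces, max_length=20, out=None):
    """Write pieces as they arrive, starting a new line when needed."""
    out = out if out is not None else io.StringIO()
    line_length = 0
    for piece in pieces:
        escaped = escape_token(piece)
        if escaped == "\\n":  # explicit newline token
            out.write("\n")
            line_length = 0
            continue
        if line_length + len(escaped) > max_length:
            out.write("\n")  # wrap before the line gets too long
            line_length = 0
        out.write(escaped)  # write only the new piece, not the whole line
        line_length += len(escaped)
    return out


# Made-up pieces standing in for decoded tokens
pieces = ["Hello", " world", "\n", "streaming", " tokens", " one", " by", " one"]
print(stream_wrapped(pieces, max_length=12).getvalue())
```

One design choice in this sketch: only the newly arrived piece is written each time, so the running line length is all the state that has to be carried between tokens.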

Now we can decode and print our model outputs throughout token generation by accessing our decoding function through `nnsight.local()`.

In [None]:
import torch

nnsight.CONFIG.APP.REMOTE_LOGGING = False

prompt = "A press release is an official statement delivered to members of the news media for the purpose of"
# prompt = "Your favorite board game is"

print("Prompt:", prompt)

# Initialize the state for decoding
state = {'current_line': '', 'current_line_length': 0}

with llama.generate(prompt, remote=True, max_new_tokens = 50) as generator:
    # Call .all() to apply to each new token
    llama.all()

    all_tokens = nnsight.list().save()

    # Access model output
    out = llama.lm_head.output.save()

    # Apply softmax to obtain probabilities, then pick the most likely token at each position
    probs = torch.nn.functional.softmax(out, dim=-1)
    max_probs = torch.max(probs, dim=-1)
    tokens = max_probs.indices.cpu().tolist()
    all_tokens.append(tokens[0]).save()

    with nnsight.local():
        state = my_decoding_function(tokens[0], llama, max_length=20, state=state)
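
The generation loop above picks each token by applying softmax and then taking the maximum. Because softmax preserves the ordering of its inputs, the index of the largest probability is the same as the index of the largest logit, so the softmax step matters only if you also want the probabilities themselves. The sketch below checks this with plain Python on made-up scores; `softmax` and `argmax` here are small helper functions written for the illustration, not part of nnsight or PyTorch.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    # Index of the largest element.
    return max(range(len(values)), key=lambda i: values[i])


logits = [1.5, -0.3, 4.2, 0.0]  # made-up scores over a 4-token vocabulary
probs = softmax(logits)

# Greedy choice is identical whether taken from logits or probabilities.
print(argmax(logits), argmax(probs))
```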