# Testing notebook for Granite on Llama Stack

Setup instructions:
1. Install a fork of Llama Stack with Granite support
1. Install the latest version of the `llama_stack_client` package using `pip`
1. Download the `Llama3.2-11B-Vision-Instruct` Llama model (using the "llama" command) and the `granite-3.0-8b-instruct-r241014a` 
1. Configure a Llama Stack server with the `meta-reference` inference implementation serving the `Llama3.2-11B-Vision-Instruct` model
1. Add a `remote::granite` inference provider to your copy of Llama stack. You can configure this provider during the manual, step-by-step configuration, or you can just change the part of your server's YAML configuration file that reads:
    ```
    providers:
      inference:
      - provider_id: inline::meta-reference
        provider_type: inline::meta-reference
        config:
          model: Llama3.2-11B-Vision
          torch_seed: null
          max_seq_len: 4096
          max_batch_size: 1
          create_distributed_process_group: true
          checkpoint_dir: null
    ```
    ...so that it looks like this:
    ```
    providers:
      inference:
      - provider_id: inline::meta-reference
        provider_type: inline::meta-reference
        config:
          model: Llama3.2-11B-Vision
          torch_seed: null
          max_seq_len: 4096
          max_batch_size: 1
          create_distributed_process_group: true
          checkpoint_dir: null
      - provider_id: inline::granite
        provider_type: inline::granite
        config:
          modeldir: dmf_models
          backend: vllm
          preload_model_name: granite-3.0-8b-instruct-r241014a
    ```
    The `modeldir` parameter should point to the parent directory containing your local copy of `granite-3.0-8b-instruct-r241014a`. The Granite model should be in a directory with the same name as the model, because Granite model names are currently stored in name of the directory.
1. Add an entry in the `models` section of your YAML configuration file, something like this:
    ```
    models: 
    - model_id: Llama3.2-11B-Vision-Instruct
      provider_id: inline::meta-reference
      provider_model_id: Llama3.2-11B-Vision-Instruct
    - model_id: granite-3.0-8b-instruct-r241014a
      provider_id: inline::granite
      provider_model_id: granite-3.0-8b-instruct-r241014a
    ```
1. Start the server on `localhost` at port 5000. Server startup takes about a minute due to model loading overheads.

In [1]:
# Boilerplate goes here
import json
import termcolor
import textwrap
import pydantic

host = "localhost"
port = 5000
base_url = f"http://{host}:{port}"
_WRAP_CHARS = 80

# Import the Python client for Llama Stack. This client code breaks frequently
# due to breaking API changes on the server and a chaotic release schedule.
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url=base_url)
models_client = client.models
inference_client = client.inference
agents_client = client.agents

# The latest batch of breaking API changes broke the Llama Stack client code
# in the server.
# Alternate code reverse-ingineered from Llama Stack source
# import llama_stack.apis.inference.client
# inference_client = llama_stack.apis.inference.client.InferenceClient(base_url=base_url)
# import llama_stack.apis.agents.client
# agents_client = llama_stack.apis.agents.client.AgentsClient(base_url)
# import llama_stack.apis.models.client
# models_client = llama_stack.apis.models.client.ModelsClient(base_url=base_url)
# import llama_stack.apis.inference


In [2]:
# Define functions for pretty-printing results

async def print_result(result_future_future):
    """Common code for printing model outputs to stdout"""
    # Inference APIs currently return a corutine that returns a corutine
    # when called in non-streaming mode.
    result_future = await result_future_future
    result = await result_future
    
    # Result is of type ChatCompletionResponse
    #print(f"Raw result: {result}")

    role = result.completion_message.role
    content = result.completion_message.content
    tool_calls = result.completion_message.tool_calls
    
    if len(content) > 0:
        content_lines = content.split("\n")
        indent_str = (" " * (len(role) + 2))
        first_line = textwrap.fill(content_lines[0],
                                   subsequent_indent=indent_str,
                                   width=_WRAP_CHARS)
        remaining_lines = [textwrap.fill(l, width=_WRAP_CHARS) 
                           for l in content_lines[1:]]
        pretty_role = termcolor.colored(role, color="red")
        print(f"{pretty_role}: {first_line}")
        print("\n".join(remaining_lines))
    if len(tool_calls) > 0:
        print("Tool calls:")
        for t in tool_calls:
            print(f"   {t}")
   

def print_result_stream(result_future):
    """Common code for printing model outputs to stdout when the model
    is running in streaming mode.
    """
    
    label_str = None  # "assistant: " or "tool call: "
    
    cur_line_len = 0
    
    #result_generator = await result_future
    
    #async for chunk in result_generator:
    for chunk in result_future:
        # Result chunks are of type ChatCompletionResponseStreamChunk.
        # Note that this type is quite different from the 
        # ChatCompletionResponse object returned in non-streaming mode.
        event = chunk.event
    
        #if event.event_type is llama_stack.apis.inference.inference.ChatCompletionResponseEventType.progress:
        if event.event_type == "progress":
            # API requires us to discern tool calls from agent text
            # by checking Python types
            if isinstance(event.delta, str):
                is_tool_call = False
                delta_text = event.delta
            # elif isinstance(event.delta, llama_stack.apis.inference.inference.ToolCallDelta):
            #     is_tool_call = True
            #     # Default JSON serialization has no pretty-printing
            #     delta_text = json.dumps(
            #         json.loads(event.delta.model_dump_json()),
            #         indent=4
            #     )
            else:
                raise TypeError(f"Unexpected event delta type '{type(event.delta)}'")
            
            if label_str is None:
                label_str = "tool call: " if is_tool_call else "assistant: "
                pretty_label = termcolor.colored(label_str, color="red")
                print(pretty_label, end="", flush=True)
                cur_line_len = len(label_str)
            
            # Add carriage returns as needed
            while "\n" in delta_text:
                first_line, delta_text = delta_text.split("\n", 1)
                print(first_line)
                cur_line_len = 0
            
            if cur_line_len + len(delta_text) >= _WRAP_CHARS and delta_text.startswith(" "):
                print(f" \\\n{delta_text[1:]}", end="", flush=True)
                cur_line_len = len(delta_text) - 1
            else:
                print(delta_text, end="", flush=True)
                cur_line_len += len(delta_text)
            
        else:
            pass
            #print(f"Skipping event {event}")
                
    print()

# Model registry APIs

With the Granite connector installed, APIs for listing and registering Granite models are somewhat functional.


In [3]:
# This cell no longer works with the latest batch of breaking API changes

# List all available models. Should show both Llama and Granite models
#await models_client.list_models()

In [4]:
# This cell no longer works with the latest batch of breaking API changes

#await models_client.get_model('granite-8b-code-instruct-128k-r241015a')

In [5]:
# This cell no longer works with the latest batch of breaking API changes

# We can also register Granite models. This call should be a no-op, since the
# model gets registered when the connector starts up.
# await models_client.register_model(
#     llama_stack.apis.models.ModelDefWithProvider(
#         # Name used in URLs
#         identifier="granite-3.0-8b-instruct-r241014a",
        
#         # Name used by the inference provider
#         llama_model="granite-3.0-8b-instruct-r241014a",
        
#         # Should match field with same name in server's YAML config file
#         provider_id="remote::granite" 
#     )
# )

## Single prompt

Here we run the simplest type of chat completion request, first with a Llama model, then with a Granite model. Both requests are identical except for the model name.

In [6]:
result_future = inference_client.chat_completion(
    model_id="Llama3.2-11B-Vision-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Write a short movie trailer voiceover about the Cauchy–Schwarz inequality"
        }
    ],
    stream=True,
)
print_result_stream(result_future)

[31massistant: [0m(Ominous music starts playing)

Narrator (in a deep, dramatic voice): "In a world where vectors collide, one \
truth stands above the rest. A boundary that separates the possible from the \
impossible."

(Visuals of vectors intersecting, with a red line marking the limit)

Narrator: "It's the Cauchy-Schwarz inequality, the ultimate test of power and \
precision."

(Visuals of mathematicians working, with equations flashing on screen)

Narrator: "A mathematical law that governs the inner workings of our universe, \
from the smallest particles to the vast expanse of space."

(Visuals of a dot product, with the numbers increasing)

Narrator: "It's a delicate balance, a dance of numbers and variables, where the \
slightest misstep can lead to disaster."

(Visuals of a vector going beyond the limit, with a dramatic explosion)

Narrator: "But what happens when the limits are pushed to the extreme? When the \
vectors collide, and the inequality is broken?"

(Visuals of a m

In [7]:
result_future = inference_client.chat_completion(
    model_id="granite-3.0-8b-instruct-r241014a",
    messages=[
        {
            "role": "user",
            "content": "Write a short movie trailer voiceover about the Cauchy–Schwarz inequality"
        }
    ],
    stream=True,
)
await print_result_stream(result_future)

[31massistant: [0m���

"In a world where numbers dance and equations sing, there exists a powerful theorem, \
a silent guardian of mathematical harmony.

Introducing the Cauchy-Schwarz Inequality, a mathematical marvel that whispers, \
'The square of the sum of two numbers is less than or equal to the product of \
their sums.'

Witness as it weaves its magic, binding vectors and matrices in a dance of elegance \
and precision.

Experience the thrill as it unravels the secrets of inner products and norms, \
revealing the true potential of mathematical relationships.

The Cauchy-Schwarz Inequality: Where numbers meet their match, and harmony is \
the only law.

Coming soon to a theorem near you. ���"


TypeError: object NoneType can't be used in 'await' expression

In [9]:
# Same API call as above, in non-streaming mode
result_future = inference_client.chat_completion(
    model_id="granite-3.0-8b-instruct-r241014a",
    messages=[
        {
            "role": "user",
            "content": "Write a short movie trailer voiceover about the Cauchy–Schwarz inequality"
        }
    ],
    stream=False,
)
print_result(result_future)


InternalServerError: Error code: 500 - {'detail': 'Internal server error: An unexpected error occurred.'}

## Two turns plus system prompt

Here we test using a system prompt to alter how the assistant responds to subsequent user turns.
The system prompt here instructs the assistant to be rude instead of its default polite behavior.

In [None]:
result_generator = inference_client.chat_completion(
    model="Llama3.2-11B-Vision-Instruct",
    messages=[
        {
            "role": "system",
            "content": "You are a bitter and unhelpful customer service representative for Acme Communications. All of your responses are dripping with sarcasm."
        },
        {
            "role": "assistant",
            "content": "Welcome to Acme Communications, why should I bother assisting you?",
            "stop_reason": "end_of_turn"
        },
        {
            "role": "user",
            "content": "Hi, my cell phone caught fire and now it doesn't work."
        }
    ],
    stream=True,
)
await print_result_stream(result_generator)

In [None]:
result_generator = inference_client.chat_completion(
    model="granite-3.0-8b-instruct-r241014a",
    messages=[
        {
            "role": "system",
            "content": "You are a bitter and unhelpful customer service representative for Acme Communications. All of your responses are dripping with sarcasm."
        },
        {
            "role": "assistant",
            "content": "Welcome to Acme Communications, why should I bother assisting you?",
            "stop_reason": "end_of_turn"
        },
        {
            "role": "user",
            "content": "Hi, my cell phone caught fire and now it doesn't work."
        }
    ],
    stream=True,
)
await print_result_stream(result_generator)

## Multiple turns, no system prompt

Test whether the model uses context from earlier turns of the conversation when responding to the current user turn.

In [None]:
# Machine-generated Granite test data
original_messages = [
    {"from": "user", "value": "How can I change my Zoom virtual background?"},
    {"from": "assistant", "value": "You can change your Zoom virtual background by signing into the Zoom desktop client, clicking your profile picture, and selecting Settings. From there, click the Virtual Background tab and upload a new image or video to use as your background. You can also adjust the settings as needed before saving your changes."}, 
    {"from": "user", "value": "What is a green screen and how do I use it?"}, 
    {"from": "assistant", "value": "A green screen is a technique used in video production where a subject is filmed in front of a solid-colored background, usually green or blue. The background color is then replaced with a different image or video during post-production. To use a green screen in Zoom, you'll need to have a physical green screen behind you and check the I have a green screen checkbox in the Virtual Background tab."}, 
    {"from": "user", "value": "What if I don't have a green screen?"}, {"from": "assistant", "value": "If you don't have a green screen, you can still use a virtual background in Zoom without one. Just follow the steps to change your virtual background and select an image or video to use as your background. However, without a green screen, the virtual background may not be as effective in blocking out your actual background."}, 
    {"from": "user", "value": "What if I want to use a video as my virtual background?"}, 
    #{"from": "assistant", "value": "You can use a video as your virtual background in Zoom by following the same steps to change your virtual background and selecting a video file instead of an image file. Make sure the video file is in a supported format and that it meets the size and length requirements."}
]

# Convert to Llama Stack format
formatted_messages = [
    {"role": m["from"], "content": m["value"], "stop_reason": "end_of_turn"} 
    if m["from"] == "assistant"
    else {"role": m["from"], "content": m["value"]}
    for m in original_messages 
]
formatted_messages

In [None]:
result_generator =  inference_client.chat_completion(
    model="Llama3.2-11B-Vision-Instruct",
    messages=formatted_messages,
    stream=True,
)
await print_result_stream(result_generator)

In [None]:
result_generator = inference_client.chat_completion(
    model="granite-3.0-8b-instruct-r241014a",
    messages=formatted_messages,
    stream=True,
)
await print_result_stream(result_generator)

## Single prompt plus tool catalog

In [None]:
TOOLS_FROM_EXAMPLE_CODE = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                }
            },
            "required": ["location"]
        }
    },
    {
        "name": "get_stock_price",
        "description": "Retrieves the current stock price for a given ticker symbol. The ticker symbol must be a valid symbol for a publicly traded company on a major US stock exchange like NYSE or NASDAQ. The tool will return the latest trade price in USD. It should be used when the user asks about the current or most recent price of a specific stock. It will not provide any other information about the stock or company.",
        "parameters": {
            "type": "object",
            "properties": {
                "ticker": {
                    "type": "string",
                    "description": "The stock ticker symbol, e.g. AAPL for Apple Inc."
                }
            },
            "required": ["ticker"]
        }
    }
]


In [None]:
# Tool definitions from the example code
get_current_weather_tool_def = llama_stack.apis.inference.ToolDefinition(
    tool_name="get_current_weather",
    description="Get the current weather",
    parameters={
        "location": llama_stack.apis.inference.ToolParamDefinition(
            param_type="string",
            description="The city and state, e.g. San Francisco, CA",
            required=True
        )
    }
)
get_stock_price_tool_def = llama_stack.apis.inference.ToolDefinition(
    tool_name="get_stock_price",
    description="Retrieves the current stock price for a given ticker symbol. The ticker symbol must be a valid symbol for a publicly traded company on a major US stock exchange like NYSE or NASDAQ. The tool will return the latest trade price in USD. It should be used when the user asks about the current or most recent price of a specific stock. It will not provide any other information about the stock or company.",
    parameters={
        "ticker": llama_stack.apis.inference.ToolParamDefinition(
            param_type="string",
            description="The stock ticker symbol, e.g. AAPL for Apple Inc.",
            required=True
        )
    }
)
tools_list = [get_current_weather_tool_def, get_stock_price_tool_def]

In [None]:
result_generator = inference_client.chat_completion(
    model="Llama3.2-11B-Vision-Instruct",
    tools=tools_list,
    messages=[
        {
            "role": "user",
            "content": "What's the weather today?"
        }
    ],
    stream=True,
)
await print_result_stream(result_generator)

In [None]:
result_generator = inference_client.chat_completion(
    model="granite-3.0-8b-instruct-r241014a",
    tools=tools_list,
    messages=[
        {
            "role": "user",
            #"content": "What's the weather today?"  # Missing location
            #"content": "What's the weather in Springfield today?"  # Could be any of 67 locations
            "content": "What's the weather in Springfield? The one in in New Hampshire."
        }
    ],
    stream=True,
)
await print_result_stream(result_generator)

## Single prompt with RAG data

RAG documents are supposed to be passed in via the undocumented `context` element of `UserMessage`. 

There is an undocumented set of special delimiters that are supposed to be used when passing RAG documents to Llama models.

The format of the example inputs below is reverse-engineered from the code in [`agent_instance.py`](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/impls/meta_reference/agents/agent_instance.py)

Here we use two short snippets containing some jargon that is unlikely to appear in either model's training data. This input emulates a scenario where the model answers questions about technical documents from a vertical domain such as jet engine repair.

In [None]:
# RAG document data in the form that Llama 3.2 apparently expects to
# receive, assuming that the Llama Stack developers knew what they 
# were doing.
RAG_DOCS_STRING = """Here are the retrieved documents for relevant context:
=== START-RETRIEVED-CONTEXT ===

id:585e0e26-16ac-42a0-a26b-cd46fce1e53b; content:The right way to smurgulate a brown floopydoodle is to deconfabulate its flipflop.
id:94e8c7a8-0657-4ce1-aef9-aef581917118; content:If you want to smurgulate a green floopydoodle, you should augment its deblogulator.

=== END-RETRIEVED-CONTEXT ===
"""

In [None]:
result_generator = inference_client.chat_completion(
    model="Llama3.2-11B-Vision-Instruct",
    messages=[
        {
            "role": "user",
            "content": ("Hi, I would like to know how to smurgulate my floopydoodle. "
                        "The floopydoodle is brown."),
            "context": RAG_DOCS_STRING
        }
    ],
    stream=True,
)
await print_result_stream(result_generator)

In [None]:
result_generator = inference_client.chat_completion(
    model="granite-3.0-8b-instruct-r241014a",
    messages=[
        {
            "role": "user",
            "content": ("Hi, I would like to know how to smurgulate my floopydoodle. "
                        "The floopydoodle is brown."),
            "context": RAG_DOCS_STRING
        }
    ],
    stream=True,
)
await print_result_stream(result_generator)

## Basic structured output

Structured output for JSON schemas is hooked up end-to-end. As of this writing,
Granite models are not fine-tuned for this type of constrained decoding.

In [None]:
class FormatThatTheModelIsSuppoedToProduce(pydantic.BaseModel):
    wrong_answer: str
    correct_answer: str
    city: str
    county: str
    state: str
    country: str
    continent: str

response_format = llama_stack.apis.inference.JsonSchemaResponseFormat(
    json_schema=FormatThatTheModelIsSuppoedToProduce.model_json_schema(),
)

In [None]:
result_generator = inference_client.chat_completion(
    model="Llama3.2-11B-Vision-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Where is the world's largest ball of twine?"
        }
    ],
    response_format=llama_stack.apis.inference.JsonSchemaResponseFormat(
        json_schema=FormatThatTheModelIsSuppoedToProduce.model_json_schema(),
    ),
    stream=True,
)
await print_result_stream(result_generator)

In [None]:
result_generator = inference_client.chat_completion(
    model="granite-3.0-8b-instruct-r241014a",
    messages=[
        {
            "role": "user",
            "content": "Where is the Washington Monument?"
            #"content": "Where is the world's largest ball of twine?"
        }
    ],
    response_format=llama_stack.apis.inference.JsonSchemaResponseFormat(
        json_schema=FormatThatTheModelIsSuppoedToProduce.model_json_schema(),
    ),
    sampling_params={
        "temperature": 1.0
    },
    stream=True,
)
await print_result_stream(result_generator)

## Basic interaction via Agent API

We start by repeating the previous RAG example, using attachments on the 
last message to pass in the documents.

In [None]:
#import llama_stack.apis.agents.agents
from llama_models.datatypes import SamplingParams

# Agent initialization arguments shared across different models
COMMON_ARGS = {
    # Instructions string from the Llama Stack example code. 
    # No documentation on what else we could put here.
    "instructions": "You are a helpful assistant.",
    
    # Haven't tested what this does yet.
    "enable_session_persistence": False,
    
    # Attachments will trigger an Error 500 unless you passed a
    # "memory tool definition" to the agent API on initialization.
    "tools": [
        llama_stack.apis.agents.agents.MemoryToolDefinition(
            max_tokens_in_context=2048,
            memory_bank_configs=[],
        ),
    ]
}

llama_agent_config = llama_stack.apis.agents.agents.AgentConfig(
    model="Llama3.2-11B-Vision-Instruct",
    **COMMON_ARGS
)
granite_agent_config = llama_stack.apis.agents.agents.AgentConfig(
    model="granite-3.0-8b-instruct-r241014a",
    **COMMON_ARGS
)
llama_agent_id = (await agents_client.create_agent(llama_agent_config)).agent_id
granite_agent_id = (await agents_client.create_agent(granite_agent_config)).agent_id
llama_agent_id, granite_agent_id

In [None]:
agent_id = llama_agent_id
session_id = (await agents_client.create_agent_session(
    agent_id=agent_id, session_name="session"
)).session_id
session_id

attachments = [
    llama_stack.apis.agents.agents.Attachment(
        content="The right way to smurgulate a brown floopydoodle is to deconfabulate its flipflop.",
        mime_type="text/plain"
    ),
    llama_stack.apis.agents.agents.Attachment(
        content="If you want to smurgulate a green floopydoodle, you should augment its deblogulator.",
        mime_type="text/plain"
    ),
]


result_generator = await agents_client.create_agent_turn(
    llama_stack.apis.agents.agents.AgentTurnCreateRequest(
        agent_id=agent_id,
        session_id=session_id,
        messages=[
            llama_stack.apis.agents.agents.UserMessage(
                content=("Hi, I would like to know how to smurgulate my floopydoodle. "
                         "The floopydoodle is brown.")
            )
        ],
        attachments=attachments,
        stream=True
    )
)

async for result in result_generator:
    print(
        json.dumps(
            json.loads(result.model_dump_json()),
            indent=4
        )
    )

In [None]:
agent_id = granite_agent_id
session_id = (await agents_client.create_agent_session(
    agent_id=agent_id, session_name="session"
)).session_id
session_id

attachments = [
    llama_stack.apis.agents.agents.Attachment(
        content="The right way to smurgulate a brown floopydoodle is to deconfabulate its flipflop.",
        mime_type="text/plain"
    ),
    llama_stack.apis.agents.agents.Attachment(
        content="If you want to smurgulate a green floopydoodle, you should augment its deblogulator.",
        mime_type="text/plain"
    ),
]


result_generator = await agents_client.create_agent_turn(
    llama_stack.apis.agents.agents.AgentTurnCreateRequest(
        agent_id=agent_id,
        session_id=session_id,
        messages=[
            llama_stack.apis.agents.agents.UserMessage(
                content=("Hi, I would like to know how to smurgulate my floopydoodle. "
                         "The floopydoodle is brown.")
            )
        ],
        attachments=attachments,
        stream=True
    )
)

async for result in result_generator:
    print(
        json.dumps(
            json.loads(result.model_dump_json()),
            indent=4
        )
    )