# Basic Completion and Embedding Examples

## Completion


In [2]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import os
from collections.abc import AsyncIterator, Iterator

from dotenv import load_dotenv
from graphrag_llm.completion import LLMCompletion, create_completion
from graphrag_llm.config import AuthMethod, ModelConfig
from graphrag_llm.types import LLMCompletionChunk, LLMCompletionResponse
from graphrag_llm.utils import (
    gather_completion_response,
    gather_completion_response_async,
)

# Populate os.environ from a local .env file (python-dotenv) before reading the
# GraphRAG settings below.
load_dotenv()

api_key = os.getenv("GRAPHRAG_API_KEY")

# A single variable names both the model and its Azure deployment, i.e. the
# deployment is expected to be named after the model.
completion_model = os.getenv("GRAPHRAG_MODEL", "gpt-4o")

# Authenticate with the API key when one is set; otherwise fall back to Azure
# managed identity.
model_config = ModelConfig(
    model_provider="azure",
    model=completion_model,
    azure_deployment_name=completion_model,
    api_base=os.getenv("GRAPHRAG_API_BASE"),
    api_version=os.getenv("GRAPHRAG_API_VERSION", "2025-04-01-preview"),
    api_key=api_key,
    auth_method=AuthMethod.ApiKey if api_key else AuthMethod.AzureManagedIdentity,
)

# Completion client shared by the completion cells in this notebook.
llm_completion: LLMCompletion = create_completion(model_config)

# Ask a single question with a plain prompt string. The result is either a
# complete LLMCompletionResponse or an iterator of streamed LLMCompletionChunk
# objects, so both shapes are handled below.
response: LLMCompletionResponse | Iterator[LLMCompletionChunk] = (
    llm_completion.completion(messages="What is the capital of France?")
)

if not isinstance(response, Iterator):
    # Complete response: the answer is the message on the first choice.
    print(response.choices[0].message.content)
else:
    # Streamed response: print each delta as it arrives.
    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="", flush=True)

# Shortcut: gather_completion_response reduces either shape to the text of the
# first choice, so it can replace the branching above when that is all you need.
# Note that an iterator can only be consumed once; had `response` been streamed,
# the loop above would already have drained it.
response_text = gather_completion_response(response)
print(response_text)

The capital of France is Paris.
The capital of France is Paris.


## Async Completion


In [3]:
response = await llm_completion.completion_async(
    messages="What is the capital of France?",
)

response_text = await gather_completion_response_async(response)
print(response_text)

The capital of France is Paris.


## Streaming Completion


In [4]:
# stream=True requests the answer incrementally; the result is expected to be an
# iterator of chunks, each carrying a text delta on its first choice.
response = llm_completion.completion(
    stream=True,
    messages="What is the capital of France?",
)

if isinstance(response, Iterator):
    for chunk in response:
        # A delta's content may be None, so substitute an empty string.
        print(chunk.choices[0].delta.content or "", end="", flush=True)

# When incremental output is not needed, the stream can instead be collapsed
# into the full text with the utility helper:
# response_text = gather_completion_response(response)  # noqa: ERA001
# print(response_text)  # noqa: ERA001

The capital of France is Paris.

## Async Streaming Completion


In [5]:
response = await llm_completion.completion_async(
    messages="What is the capital of France?",
    stream=True,
)

if isinstance(response, AsyncIterator):
    # Streaming response
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="", flush=True)

# If you don't actually care about streaming and just want the full response
# you can use the utility function to gather the full response
# response_text = await gather_completion_response_async(response)  # noqa: ERA001
# print(response_text)  # noqa: ERA001

The capital of France is Paris.

## Completion Arguments

The completion API follows the litellm completion API, and therefore the OpenAI SDK API. The `messages` parameter can be one of the following:

- `str`: Raw string for the prompt.
- `list[dict[str, Any]]`: A list of dicts in the form `{"role": "user|system|...", "content": "..."}`
- `list[ChatCompletionMessageParam]`: A list of OpenAI `ChatCompletionMessageParam` objects. `graphrag_llm.utils` provides a `CompletionMessagesBuilder` to help construct these objects. See the message builder notebook for more details on using `CompletionMessagesBuilder`.


In [6]:
from graphrag_llm.utils import (
    CompletionMessagesBuilder,
)

# 1) The prompt as a plain string.
response1 = llm_completion.completion(messages="What is the capital of France?")
print(gather_completion_response(response1))

# 2) The prompt as a list of role/content dicts.
response2 = llm_completion.completion(
    messages=[{"role": "user", "content": "What is the capital of France?"}]
)
print(gather_completion_response(response2))

# 3) A multi-turn conversation assembled with CompletionMessagesBuilder: a system
# persona, a user question, a seeded assistant reply, and a follow-up question.
pirate_system_prompt = (
    "You are a helpful assistant that likes to talk like a pirate. "
    "Respond as if you are a pirate using pirate speak."
)
messages = (
    CompletionMessagesBuilder()
    .add_system_message(pirate_system_prompt)
    .add_user_message("Is pluto a planet? Respond with a yes or no.")
    .add_assistant_message("Aye, matey! Pluto be a planet in me book.")
    .add_user_message("Are you sure? I want the truth. Can you elaborate?")
    .build()
)

response3 = llm_completion.completion(messages=messages)
print(gather_completion_response(response3))

The capital of France is Paris.
The capital of France is Paris.
Argh, ye caught me, matey! The truth be that in 2006, them landlubbers at the International Astronomical Union decided Pluto be reclassified as a "dwarf planet." So, by their reckonin', it ain't considered a full-fledged planet no more. But to this ol' sea dog, she'll always be a planet at heart!


## Embedding


In [7]:
from graphrag_llm.embedding import LLMEmbedding, create_embedding
from graphrag_llm.types import LLMEmbeddingResponse
from graphrag_llm.utils import gather_embeddings

# Embeddings use their own model/deployment. As in the completion config, a
# single environment variable names both the model and its Azure deployment;
# reading it once keeps the two from drifting apart when it is overridden.
embedding_model = os.getenv("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small")

embedding_config = ModelConfig(
    model_provider="azure",
    model=embedding_model,
    azure_deployment_name=embedding_model,
    api_base=os.getenv("GRAPHRAG_API_BASE"),
    api_version=os.getenv("GRAPHRAG_API_VERSION", "2025-04-01-preview"),
    api_key=api_key,
    # Use the API key when one is set; otherwise Azure managed identity.
    auth_method=AuthMethod.AzureManagedIdentity if not api_key else AuthMethod.ApiKey,
)

llm_embedding: LLMEmbedding = create_embedding(embedding_config)

# Embed two inputs in a single request; `.data` holds one entry per input.
# Only the first three dimensions of each vector are printed to keep output short.
embeddings_batch: LLMEmbeddingResponse = llm_embedding.embedding(
    input=["Hello world", "How are you?"]
)
for data in embeddings_batch.data:
    print(data.embedding[0:3])

# OR: gather_embeddings extracts just the vectors from the response.
batch = gather_embeddings(embeddings_batch)
for embedding in batch:
    print(embedding[0:3])

[-0.002078542485833168, -0.04908587411046028, 0.020946789532899857]
[0.027567066252231598, -0.026544300839304924, -0.027091361582279205]
[-0.002078542485833168, -0.04908587411046028, 0.020946789532899857]
[0.027567066252231598, -0.026544300839304924, -0.027091361582279205]


## Async Embedding


In [8]:
embeddings_batch = await llm_embedding.embedding_async(
    input=["Hello world", "How are you?"]
)

for data in embeddings_batch.data:
    print(data.embedding[0:3])

[-0.002078542485833168, -0.04908587411046028, 0.020946789532899857]
[0.027567066252231598, -0.026544300839304924, -0.027091361582279205]
