# Custom Model Clients

This guide shows how to create custom model clients by subclassing the {py:class}`~autogen_core.models.ChatCompletionClient` base class. Custom model clients allow you to integrate any LLM or AI service that isn't directly supported by AutoGen's built-in model clients.

In this tutorial, we'll create a custom model client using HuggingFace Transformers as the underlying model API.

## Understanding the ChatCompletionClient Interface

All model clients in AutoGen subclass the {py:class}`~autogen_core.models.ChatCompletionClient` abstract base class. Its two central methods are:

- `create()`: Creates a single response from the model
- `create_stream()`: Creates a stream of string chunks from the model ending with a CreateResult
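
To make the `create_stream()` contract concrete, here is a dependency-free sketch. `FakeResult` and `fake_stream` are hypothetical stand-ins for `CreateResult` and a client's `create_stream()`; what matters is the shape of the stream: zero or more `str` chunks, then exactly one final result object, whose content in this sketch equals the concatenated chunks.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, Optional, Tuple, Union


@dataclass
class FakeResult:
    """Hypothetical stand-in for autogen_core's CreateResult (only two fields)."""
    content: str
    finish_reason: str = "stop"


async def fake_stream(text: str) -> AsyncGenerator[Union[str, FakeResult], None]:
    """Follow the create_stream() contract: str chunks first, then one final result."""
    words = text.split(" ")
    for i, word in enumerate(words):
        # Re-attach the separator so the chunks concatenate back to `text`.
        yield word + (" " if i < len(words) - 1 else "")
    yield FakeResult(content=text)


async def consume(
    stream: AsyncGenerator[Union[str, FakeResult], None],
) -> Tuple[str, FakeResult]:
    """Collect the streamed text and the final result object."""
    chunks: list[str] = []
    final: Optional[FakeResult] = None
    async for item in stream:
        if isinstance(item, str):
            chunks.append(item)
        else:
            final = item  # the non-string item is the complete result
    if final is None:
        raise RuntimeError("stream ended without a final result")
    return "".join(chunks), final
```

In a notebook cell this would be used as `streamed, final = await consume(fake_stream("Hello from a custom client"))`; the streaming test later in this guide consumes `create_stream()` in the same way.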

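Besides these two methods, the base class declares other abstract members (such as `close`, `count_tokens`, and `model_info`) that a subclass must also implement before it can be instantiated. Let's list them all: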

In [None]:
from autogen_core.models import ChatCompletionClient
import inspect

# List the abstract members a subclass must implement
print("Abstract methods to implement:")
for method in sorted(ChatCompletionClient.__abstractmethods__):
    print(f"- {method}")

# Show the signature of the create method
print("\nCreate method signature:")
print(inspect.signature(ChatCompletionClient.create))

## Required Imports and Setup

Let's import all the necessary modules for creating our custom model client:

In [None]:
import asyncio
from typing import Any, AsyncGenerator, Literal, Mapping, Optional, Sequence, Union

from pydantic import BaseModel
from transformers import pipeline
import torch

from autogen_core import CancellationToken, Component
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    RequestUsage,
    SystemMessage,
    UserMessage,
    AssistantMessage,
    ModelCapabilities,
    ModelInfo,
)
from autogen_core.tools import Tool, ToolSchema

## Creating a HuggingFace Custom Model Client

Now let's implement our custom model client. This example uses HuggingFace's text generation pipeline to create a simple chat completion client:

In [None]:
class HuggingFaceConfig(BaseModel):
    """Configuration for HuggingFace model client."""
    model_name: str = "microsoft/DialoGPT-small"
    max_new_tokens: int = 100
    temperature: float = 0.7
    do_sample: bool = True
    device: Optional[str] = None  # "cpu" forces CPU; otherwise the first CUDA GPU is used if available


class HuggingFaceChatCompletionClient(
    ChatCompletionClient, Component[HuggingFaceConfig]
):
    """A custom model client that uses HuggingFace transformers.
    
    This example demonstrates how to create a custom model client by:
    1. Subclassing ChatCompletionClient
    2. Implementing the required abstract methods
    3. Handling message conversion and response formatting
    """
    
    component_type = "model"  # the component type AutoGen's built-in model clients use
    component_config_schema = HuggingFaceConfig
    
    def __init__(self, **kwargs):
        # Initialize the Component base class
        super().__init__()
        
        # Parse configuration
        config = HuggingFaceConfig(**kwargs)
        
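        # Use the first CUDA GPU when available, unless the config forces "cpu"
        use_cuda = torch.cuda.is_available() and config.device != "cpu"

        # Initialize the HuggingFace pipeline
        self._pipeline = pipeline(
            "text-generation",
            model=config.model_name,
            device=0 if use_cuda else -1,
            # Use half precision only when actually running on the GPU
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )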
        
        # Store configuration
        self._config = config
        
        # Initialize usage tracking
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        
        # Model capabilities - adjust based on your model's capabilities
        self._model_info = ModelInfo(
            vision=False,
            function_calling=False,
            json_output=False,
            structured_output=False,
            family="unknown",  # same value as ModelFamily.UNKNOWN
        )
    
    @property
    def capabilities(self) -> ModelCapabilities:
        """Return the capabilities of this model."""
        return self._model_info
    
    @property
    def model_info(self) -> ModelInfo:
        """Return the model information."""
        return self._model_info
    
    @property
    def total_usage(self) -> RequestUsage:
        """Return the total token usage for this client."""
        return self._total_usage
    
    @property
    def actual_usage(self) -> RequestUsage:
        """Return the actual usage (same as total for this implementation)."""
        return self._total_usage
    
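    async def close(self) -> None:
        """Close the client and clean up resources."""
        # Nothing to release for a local pipeline; a client that wraps a remote
        # API would close its HTTP session or connection pool here.
        pass

    def _to_config(self) -> HuggingFaceConfig:
        """Serialize this client's configuration (used by dump_component())."""
        return self._config

    @classmethod
    def _from_config(cls, config: HuggingFaceConfig) -> "HuggingFaceChatCompletionClient":
        """Recreate a client from its configuration (used by load_component())."""
        return cls(**config.model_dump())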
    
    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Count the number of tokens in the messages."""
        prompt = self._messages_to_prompt(messages)
        return self._estimate_tokens(prompt)
    
    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Return the number of remaining tokens for the context window.

        This simple implementation assumes the 1,024-token context window of the
        default DialoGPT-small model (a GPT-2 model). In a real implementation,
        read the context size of the model you actually load.
        """
        used_tokens = self.count_tokens(messages, tools=tools)
        context_window = 1024  # Assumed context window size
        return max(0, context_window - used_tokens)
    
    def _messages_to_prompt(self, messages: Sequence[LLMMessage]) -> str:
        """Convert a sequence of messages to a prompt string.
        
        This is a simple implementation that concatenates messages.
        For production use, you might want to use the model's specific
        chat template if available.
        """
        prompt_parts = []
        
        for message in messages:
            if isinstance(message, SystemMessage):
                prompt_parts.append(f"System: {message.content}")
            elif isinstance(message, UserMessage):
                # Handle both string and list content
                if isinstance(message.content, str):
                    prompt_parts.append(f"User: {message.content}")
                else:
                    # Multimodal content is a list of str and Image parts; keep the text parts only
                    text_content = "".join(
                        item for item in message.content if isinstance(item, str)
                    )
                    prompt_parts.append(f"User: {text_content}")
            elif isinstance(message, AssistantMessage):
                if isinstance(message.content, str):
                    prompt_parts.append(f"Assistant: {message.content}")
                # Note: This simple implementation doesn't handle function calls
        
        prompt = "\n".join(prompt_parts)
        prompt += "\nAssistant:"  # Prompt for the next assistant response
        return prompt
    
    def _estimate_tokens(self, text: str) -> int:
        """Simple token estimation. For production, use proper tokenization."""
        # Very rough estimation: assume ~4 characters per token
        return len(text) // 4
    
    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        """Create a single response from the model."""
        
        # Check for unsupported features
        if tools:
            raise ValueError("Tool calling is not supported by this model client")
        if json_output:
            raise ValueError("JSON output is not supported by this model client")
        
        # Convert messages to prompt
        prompt = self._messages_to_prompt(messages)
        
        # Estimate input tokens
        prompt_tokens = self._estimate_tokens(prompt)
        
        # Prepare generation arguments
        generation_args = {
            "max_new_tokens": self._config.max_new_tokens,
            "temperature": self._config.temperature,
            "do_sample": self._config.do_sample,
            "return_full_text": False,  # Only return the generated part
            "pad_token_id": self._pipeline.tokenizer.eos_token_id,
        }
        generation_args.update(extra_create_args)
        
        # Run the blocking pipeline call in a worker thread so the event loop is not blocked
        task = asyncio.ensure_future(
            asyncio.to_thread(self._pipeline, prompt, **generation_args)
        )
        if cancellation_token is not None:
            # Cancelling abandons the await; the worker thread still finishes generating
            cancellation_token.link_future(task)
        result = await task
        
        # Extract the generated text
        if isinstance(result, list) and len(result) > 0:
            generated_text = result[0]["generated_text"].strip()
        else:
            generated_text = ""
        
        # Estimate completion tokens
        completion_tokens = self._estimate_tokens(generated_text)
        
        # Create usage info
        usage = RequestUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        
        # Update total usage
        self._total_usage = RequestUsage(
            prompt_tokens=self._total_usage.prompt_tokens + usage.prompt_tokens,
            completion_tokens=self._total_usage.completion_tokens + usage.completion_tokens,
        )
        
        # Return the result
        return CreateResult(
            finish_reason="stop",
            content=generated_text,
            usage=usage,
            cached=False,
        )
    
    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        tool_choice: Tool | Literal["auto", "required", "none"] = "auto",
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """Create a stream of string chunks from the model.
        
        Note: This is a simplified implementation that simulates streaming
        by generating the full response and then yielding it in chunks.
        For true streaming, you would need to use a streaming-capable model
        or API.
        """
        
        # Generate the full response first
        result = await self.create(
            messages,
            tools=tools,
            tool_choice=tool_choice,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )
        
        # Simulate streaming by yielding chunks
        if isinstance(result.content, str):
            # Split on single spaces so the chunks concatenate back to exactly result.content
            # (str.split() with no argument would drop newlines and repeated spaces)
            words = result.content.split(" ")
            for i, word in enumerate(words):
                chunk = word + (" " if i < len(words) - 1 else "")
                if chunk:
                    yield chunk
                    # Small delay to simulate streaming
                    await asyncio.sleep(0.01)
        
        # Yield the final result
        yield result
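
The `_messages_to_prompt` docstring suggests using the model's own chat template in production. A minimal sketch of the first half of that is converting messages into the `{"role": ..., "content": ...}` dictionaries that Hugging Face chat templates consume. The helper name `messages_to_chat_dicts` is hypothetical, and it matches message classes by name only to keep the sketch dependency-free; `isinstance` checks against `SystemMessage`, `UserMessage`, and `AssistantMessage` would be the stricter choice.

```python
from typing import Any, Dict, List, Sequence

# Map autogen_core message class names to chat-template roles. Matching on the
# class name keeps this sketch free of imports; isinstance checks against the
# real classes would be stricter.
_ROLE_BY_CLASS_NAME = {
    "SystemMessage": "system",
    "UserMessage": "user",
    "AssistantMessage": "assistant",
}


def messages_to_chat_dicts(messages: Sequence[Any]) -> List[Dict[str, str]]:
    """Convert LLMMessage-like objects into [{"role": ..., "content": ...}] dicts."""
    chat: List[Dict[str, str]] = []
    for message in messages:
        role = _ROLE_BY_CLASS_NAME.get(type(message).__name__)
        if role is None:
            continue  # e.g. FunctionExecutionResultMessage: not handled in this sketch
        content = message.content
        if role == "user" and isinstance(content, list):
            # Multimodal user content: keep only the text (str) parts.
            content = "".join(part for part in content if isinstance(part, str))
        if not isinstance(content, str):
            continue  # e.g. an AssistantMessage that carries function calls
        chat.append({"role": role, "content": content})
    return chat
```

Inside `create()`, the result could then be passed to `self._pipeline.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)` in place of `_messages_to_prompt`. This only works for tokenizers that define a chat template; DialoGPT-small may not ship one, in which case the plain-text prompt above remains the fallback.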

## Using the Custom Model Client

Now let's test our custom model client. We'll create an instance and use it to generate responses:

In [None]:
# Create an instance of our custom model client
# Note: This will download the model on first use
model_client = HuggingFaceChatCompletionClient(
    model_name="microsoft/DialoGPT-small",  # Small model for quick testing
    max_new_tokens=50,
    temperature=0.7
)

print("Custom model client created successfully!")
print(f"Model info: {model_client.model_info}")

### Basic Chat Completion

Let's test the basic `create` method:

In [None]:
# Create some test messages
messages = [
    SystemMessage(content="You are a helpful assistant.", source="system"),
    UserMessage(content="Hello! How are you today?", source="user")
]

# Generate a response
response = await model_client.create(messages)

print(f"Response: {response.content}")
print(f"Finish reason: {response.finish_reason}")
print(f"Usage: {response.usage}")
print(f"Total usage so far: {model_client.total_usage}")

### Streaming Response

Now let's test the streaming functionality:

In [None]:
# Test streaming with a different conversation
messages = [
    SystemMessage(content="You are a creative storyteller.", source="system"),
    UserMessage(content="Tell me a short story about a robot.", source="user")
]

print("Streaming response:")
print("---")

final_result = None
async for chunk in model_client.create_stream(messages):
    if isinstance(chunk, str):
        print(chunk, end="", flush=True)
    else:
        # This is the final CreateResult
        final_result = chunk

print("\n---")
print(f"\nFinal result usage: {final_result.usage if final_result else 'N/A'}")
print(f"Total usage: {model_client.total_usage}")
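
The `create_stream()` implementation above only simulates streaming. For token-by-token output from a local model, one common pattern is Transformers' `TextIteratorStreamer`: `generate()` runs in a background thread, and the streamer is a *blocking* iterator over decoded text. The sketch below covers only the bridging step: draining any blocking iterator in a worker thread and re-exposing it as an async generator through an `asyncio.Queue`, so the event loop stays free. The helper name `iterate_in_thread` is hypothetical.

```python
import asyncio
import threading
from typing import AsyncGenerator, Iterable, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel: the worker has finished


class _WorkerError:
    """Wraps an exception raised inside the worker thread."""

    def __init__(self, exc: Exception) -> None:
        self.exc = exc


async def iterate_in_thread(iterable: Iterable[T]) -> AsyncGenerator[T, None]:
    """Drain a blocking iterable in a worker thread and yield its items asynchronously."""
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def put(item: object) -> None:
        # asyncio.Queue is not thread-safe, so hand items to the loop's own thread.
        try:
            loop.call_soon_threadsafe(queue.put_nowait, item)
        except RuntimeError:
            pass  # the event loop is already closed; nobody is listening

    def worker() -> None:
        try:
            for item in iterable:
                put(item)
        except Exception as exc:  # forward errors to the consumer
            put(_WorkerError(exc))
        finally:
            put(_DONE)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = await queue.get()
        if item is _DONE:
            return
        if isinstance(item, _WorkerError):
            raise item.exc
        yield item
```

With a real model, the iterable would be a `TextIteratorStreamer(tokenizer, skip_prompt=True)` passed as `streamer=` to `model.generate(...)`, with `generate` itself started in its own thread; each string yielded here would then become one chunk of `create_stream()`. If the consumer stops early, the worker keeps draining the iterator until it is exhausted.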

## Using with AutoGen Agents

Your custom model client can be used with AutoGen agents just like any built-in client. Here's a simple example:

In [None]:
from autogen_core import RoutedAgent, MessageContext, message_handler
from dataclasses import dataclass

@dataclass
class ChatMessage:
    content: str

class CustomModelAgent(RoutedAgent):
    """A simple agent that uses our custom model client."""
    
    def __init__(self, model_client: HuggingFaceChatCompletionClient):
        super().__init__("Custom agent using HuggingFace model")
        self._model_client = model_client
        self._system_message = SystemMessage(
            content="You are a helpful AI assistant powered by a custom HuggingFace model.",
            source="system"
        )
    
    @message_handler
    async def handle_chat_message(self, message: ChatMessage, ctx: MessageContext) -> str:
        # Create messages for the model
        messages = [
            self._system_message,
            UserMessage(content=message.content, source="user")
        ]
        
        # Generate response using our custom model client
        response = await self._model_client.create(messages)
        
        return response.content

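# Agents must be created by an agent runtime rather than instantiated directly,
# so register a factory for our agent type with a local runtime
from autogen_core import AgentId, SingleThreadedAgentRuntime

runtime = SingleThreadedAgentRuntime()
await CustomModelAgent.register(runtime, "custom_agent", lambda: CustomModelAgent(model_client))

# Start processing messages, send one to the agent, and wait for its reply
runtime.start()
reply = await runtime.send_message(ChatMessage(content="Hello! What can you do?"), AgentId("custom_agent", "default"))
print(f"Agent reply: {reply}")
await runtime.stop()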

## Best Practices and Considerations

When creating custom model clients, consider the following:

### 1. Error Handling
Implement proper error handling for network issues, model failures, and invalid inputs:

In [None]:
# Example of adding error handling to our client
class RobustHuggingFaceChatCompletionClient(HuggingFaceChatCompletionClient):
    """Enhanced version with better error handling."""
    
    async def create(
        self,
        messages: Sequence[LLMMessage],
        **kwargs
    ) -> CreateResult:
        try:
            return await super().create(messages, **kwargs)
        except Exception as e:
            # Log the error and provide a graceful fallback
            print(f"Error in model generation: {e}")
            return CreateResult(
                finish_reason="unknown",  # "error" is not an allowed finish reason
                content="I apologize, but I encountered an error while processing your request.",
                usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
                cached=False,
            )

print("Enhanced client with error handling defined.")
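
Returning a fallback message hides failures from the caller. For transient failures, such as network errors when the backend is a remote API, retrying with exponential backoff before giving up is a common complement. A minimal sketch; the helper name `retry_async` and its parameters are hypothetical, and which exception types count as transient depends on your backend.

```python
import asyncio
import random
from typing import Awaitable, Callable, Tuple, Type, TypeVar

T = TypeVar("T")


async def retry_async(
    operation: Callable[[], Awaitable[T]],
    *,
    attempts: int = 3,
    base_delay: float = 0.5,
    retry_on: Tuple[Type[BaseException], ...] = (ConnectionError, TimeoutError),
) -> T:
    """Await operation(), retrying `retry_on` errors with exponential backoff."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return await operation()
        except retry_on:
            if attempt == attempts:
                raise  # out of attempts: surface the last error
            # Wait base, 2*base, 4*base, ... plus a little jitter between attempts.
            delay = base_delay * 2 ** (attempt - 1)
            await asyncio.sleep(delay + random.uniform(0, delay / 10))
    raise RuntimeError("unreachable")  # the loop always returns or raises
```

For example, `result = await retry_async(lambda: model_client.create(messages))`. Errors outside `retry_on`, such as the `ValueError` this client raises for unsupported tools, propagate immediately, since retrying them cannot help.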

### 2. Configuration Management
Use Pydantic models for configuration to ensure type safety and validation:

In [None]:
from pydantic import Field, field_validator

class AdvancedHuggingFaceConfig(BaseModel):
    """More comprehensive configuration with validation."""
    model_name: str = Field(description="HuggingFace model name or path")
    max_new_tokens: int = Field(default=100, ge=1, le=2048, description="Maximum tokens to generate")
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Nucleus sampling parameter")
    device: Optional[str] = Field(default=None, description="Device to run the model on")

    # Pydantic v2 style; the older @validator decorator is deprecated
    @field_validator("model_name")
    @classmethod
    def validate_model_name(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("Model name cannot be empty")
        return v

# Example usage
config = AdvancedHuggingFaceConfig(
    model_name="microsoft/DialoGPT-small",
    max_new_tokens=100,
    temperature=0.8
)
print(f"Configuration: {config}")

### 3. Token Usage Tracking
Implement accurate token counting for cost tracking and rate limiting:

In [None]:
# Example of more accurate token counting
def accurate_token_count(text: str, tokenizer) -> int:
    """More accurate token counting using the model's actual tokenizer."""
    if hasattr(tokenizer, 'encode'):
        return len(tokenizer.encode(text))
    else:
        # Fallback to rough estimation
        return len(text) // 4

print("Token counting utility defined.")
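
Token counts become useful for cost control once they are compared against a limit. A minimal sketch with a hypothetical `UsageBudget` class (not part of AutoGen) that accumulates prompt and completion tokens and refuses requests that would exceed a fixed total; the limit and the pre-call estimate are assumptions you would choose yourself.

```python
from dataclasses import dataclass


@dataclass
class UsageBudget:
    """Hypothetical helper: track token usage against a fixed total budget."""
    max_total_tokens: int
    prompt_tokens: int = 0
    completion_tokens: int = 0

    @property
    def used(self) -> int:
        return self.prompt_tokens + self.completion_tokens

    def can_afford(self, estimated_tokens: int) -> bool:
        """Check a pre-call estimate (e.g. from count_tokens) against what is left."""
        return self.used + estimated_tokens <= self.max_total_tokens

    def record(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Add the usage reported for one request (e.g. CreateResult.usage)."""
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
```

For example, check `budget.can_afford(model_client.count_tokens(messages))` before calling `create()`, then call `budget.record(result.usage.prompt_tokens, result.usage.completion_tokens)` afterwards.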

### 4. Model Capabilities Declaration
Accurately declare your model's capabilities to ensure proper integration:

In [None]:
# Example of declaring capabilities per model
def get_model_capabilities(model_name: str) -> ModelInfo:
    """Return capabilities based on the specific model."""

    # Match exact model names rather than substrings: a check such as
    # "gpt" in model_name.lower() would also match "microsoft/DialoGPT-small",
    # which supports neither function calling nor JSON output.
    # Fill this table from the model cards of the models you actually serve.
    known_models: dict[str, ModelInfo] = {
        "microsoft/DialoGPT-small": ModelInfo(
            vision=False,
            function_calling=False,
            json_output=False,
            structured_output=False,
            family="unknown",
        ),
    }

    # Conservative defaults for models not in the table
    default = ModelInfo(
        vision=False,
        function_calling=False,
        json_output=False,
        structured_output=False,
        family="unknown",
    )
    return known_models.get(model_name, default)

print("Model capability detection utility defined.")

## Advanced Features

### Supporting Function Calling
If your model supports function calling, implement tool handling:

In [None]:
# Placeholder for function calling implementation
# This would require significant additional code to handle tools properly
def handle_tools(tools: Sequence[Tool | ToolSchema], response_text: str):
    """Extract and handle function calls from the model response.
    
    This is a placeholder - actual implementation would depend on
    how your model formats function calls in its output.
    """
    # Parse the response for function calls
    # Convert to FunctionCall objects
    # Return appropriate content type
    pass

print("Function calling placeholder defined.")
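
As one concrete possibility, assume (hypothetically) the model has been prompted to answer with JSON of the form `{"tool_calls": [{"name": ..., "arguments": {...}}]}` when it wants to use a tool. The sketch below parses that into `(id, name, arguments_json)` triples. Each triple maps onto the fields of `autogen_core.FunctionCall` (`id`, `name`, and `arguments` as a JSON string), and `create()` would then return the calls as the `CreateResult` content with `finish_reason="function_calls"`. The output format, the function name, and the generated ids are all assumptions of this sketch.

```python
import json
import uuid
from typing import List, Sequence, Tuple


def parse_tool_calls(response_text: str, allowed_names: Sequence[str]) -> List[Tuple[str, str, str]]:
    """Parse the hypothetical {"tool_calls": [...]} format into (id, name, arguments_json).

    Returns an empty list when the text is not a tool-call payload, so the caller
    can treat the response as plain text instead.
    """
    try:
        payload = json.loads(response_text)
    except json.JSONDecodeError:
        return []
    if not isinstance(payload, dict) or not isinstance(payload.get("tool_calls"), list):
        return []

    calls: List[Tuple[str, str, str]] = []
    for raw in payload["tool_calls"]:
        if not isinstance(raw, dict):
            continue  # skip malformed entries
        name = raw.get("name")
        if name not in allowed_names:
            raise ValueError(f"Model requested unknown tool: {name!r}")
        arguments = raw.get("arguments", {})
        # FunctionCall.arguments is a JSON string, so re-serialize dict arguments.
        arguments_json = arguments if isinstance(arguments, str) else json.dumps(arguments)
        calls.append((f"call_{uuid.uuid4().hex[:8]}", name, arguments_json))
    return calls
```

Inside `create()`, `allowed_names` would come from the `tools` argument: `tool.name` for `Tool` objects and `schema["name"]` for `ToolSchema` dictionaries.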

### Supporting JSON Output
For structured output, implement JSON mode:

In [None]:
import json

def handle_json_output(json_output: bool | type[BaseModel], prompt: str) -> str:
    """Modify the prompt to encourage JSON output.
    
    This is a simple approach - more sophisticated implementations
    might use guided generation or constrained decoding.
    """
    if json_output is True:
        return prompt + "\n\nPlease respond with valid JSON only."
    elif isinstance(json_output, type) and issubclass(json_output, BaseModel):
        schema = json_output.model_json_schema()
        return f"{prompt}\n\nPlease respond with JSON that matches this schema: {json.dumps(schema)}"
    return prompt

print("JSON output handling utility defined.")
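
Prompting alone does not guarantee valid JSON, so the generated text should be checked before it is returned. A minimal sketch that pulls the first JSON object out of the model's reply with the standard library's `json.JSONDecoder.raw_decode`, tolerating chatter before or after it (for example a leading "Sure, here you go:"). The function name is hypothetical; for a Pydantic `json_output` type, the extracted dict could then be validated with `json_output.model_validate(...)`.

```python
import json
from typing import Any, Dict, Optional


def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in `text`, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
            return obj  # decoding that starts at "{" always yields a dict
        except json.JSONDecodeError:
            # Not valid JSON from this brace; try the next one.
            start = text.find("{", start + 1)
    return None
```

If it returns `None`, `create()` could retry generation or raise an error instead of returning unparseable text as JSON.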

## Summary

In this guide, we've covered:

1. **Understanding the ChatCompletionClient interface** - The abstract methods you need to implement
2. **Creating a basic custom client** - Using HuggingFace Transformers as the backend
3. **Implementing required methods** - Both `create()` and `create_stream()`
4. **Integration with AutoGen** - How to use your custom client with agents
5. **Best practices** - Error handling, configuration, and capability declaration
6. **Advanced features** - Function calling and JSON output support

### Key Takeaways

- Custom model clients must implement the `ChatCompletionClient` abstract interface
- Use Pydantic models for configuration and type safety
- Properly handle async operations to avoid blocking the event loop
- Accurately declare model capabilities to ensure proper integration
- Implement proper error handling for production use
- Consider token usage tracking for cost management

### Next Steps

- Adapt this example to use your specific model or API
- Implement additional features like function calling or vision support
- Add comprehensive error handling and logging
- Create tests for your custom model client
- Consider packaging your client as a reusable component

For more examples and patterns, see the existing model client implementations in the `autogen_ext.models` package.