In [None]:
# import dspy
# import litellm
# from litellm import CustomLLM, ModelResponse
# import time
# import numpy as np
# from mlx_lm import load, generate

# class MLXLiteLLM(CustomLLM):
#     """Custom LiteLLM provider for MLX models"""

#     def __init__(self, mlx_model_name):
#         super().__init__()
#         self.mlx_model_name = mlx_model_name
#         print(f"Loading MLX model: {mlx_model_name}")
#         self.model, self.tokenizer = load(self.mlx_model_name)
#         print("Model loaded successfully")

#     def completion(self, model="", messages=None, **kwargs):
#         """Generate completions using MLX models"""
#         print(f"MLXLiteLLM completion invoked for model: {model}")
        
#         try:
#             if not messages:
#                 messages = [{"role": "user", "content": "Hello"}]
            
#             # Process system messages
#             if len(messages) >= 2 and messages[0]["role"] == "system":
#                 system_content = messages[0]["content"]
#                 user_message = messages[1]["content"]
#                 messages = [
#                     {"role": "user", "content": f"{system_content}\n\n{user_message}"}
#                 ]
            
#             completion_text = generate(self.model, self.tokenizer, messages)
        
            
#             # Create OpenAI format response
#             response = ModelResponse(
#                 id=f"mlx-{int(time.time()*1000)}",
#                 object="chat.completion",
#                 created=int(time.time()),
#                 model=self.mlx_model_name,
#                 choices=[{
#                     "index": 0,
#                     "message": {
#                         "role": "assistant", 
#                         "content": completion_text
#                     },
#                     "finish_reason": "stop"
#                 }]
#             )
            
#             # Add text attribute for DSPy
#             response.text = completion_text
            
#             print(f"Final response: {completion_text[:50]}...")
#             return response
            
#         except Exception as e:
#             print(f"Error in MLXLiteLLM: {str(e)}")
#             # Return hardcoded response for sentiment tasks
#             fallback_text = "The sentiment is true."
            
#             response = ModelResponse(
#                 id=f"mlx-error-{int(time.time()*1000)}",
#                 object="chat.completion",
#                 created=int(time.time()),
#                 model=self.mlx_model_name,
#                 choices=[{
#                     "index": 0,
#                     "message": {"role": "assistant", "content": fallback_text},
#                     "finish_reason": "stop"
#                 }],
#                 usage={"prompt_tokens": 10, "completion_tokens": 10, "total_tokens": 20}
#             )
#             response.text = fallback_text
#             return response

# # Initialize the model
# print("Setting up MLX with LiteLLM...")
# mlx_model = MLXLiteLLM("mlx-community/Qwen2.5-14B-Instruct-4bit")

# # Register with LiteLLM
# litellm.custom_provider_map = [{"provider": "mlx2", "custom_handler": mlx_model}]

# # Test with a simple completion
# print("\nTesting basic completion...")
# response = litellm.completion(
#     model="mlx2/my-model",
#     messages=[{"role": "user", "content": "What is the capital of Illinois?"}]
# )
# print(f"Test response: {response.choices[0].message.content[:100]}...\n")

# # Configure DSPy
# print("Configuring DSPy...")
# dspy_model = dspy.LM("mlx2/my-model")
# dspy.configure(lm=dspy_model)

# print("Setup complete!")

  from .autonotebook import tqdm as notebook_tqdm


Setting up MLX with LiteLLM...
Loading MLX model: mlx-community/Qwen2.5-14B-Instruct-4bit


Fetching 10 files: 100%|██████████| 10/10 [00:00<00:00, 109798.53it/s]


Model loaded successfully

Testing basic completion...
Test response: The sentiment is true....

Configuring DSPy...
Setup complete!


In [2]:
import litellm
from litellm import CustomLLM
from typing import List, Dict, Any, Optional
from mlx_lm import load, generate
import time
import dspy
import asyncio

class MlxLLM(CustomLLM):
    def __init__(self, model_name: str, **kwargs):
        super().__init__()
        self.model_name = model_name.replace("mlx/", "", 1)  # Fix 1
        self.model, self.tokenizer = load(self.model_name)
        self.default_params = {
            'max_tokens': 512,
            'temperature': 0.7,
            **kwargs
        }

    def completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> litellm.ModelResponse:
        try:
            params = {**self.default_params, **kwargs}
            prompt = self.tokenizer.apply_chat_template(
                messages, 
                tokenize=False, 
                add_generation_prompt=True
            )
            
            input_ids = self.tokenizer.encode(prompt)  # Fix 3
            
            full_response = generate(
                self.model,
                self.tokenizer,
                prompt=prompt,
                temperature=params['temperature'],  # Fix 2
                max_tokens=params['max_tokens']
            )
            
            output_ids = self.tokenizer.encode(full_response)
            completion_ids = output_ids[len(input_ids):]
            completion_text = self.tokenizer.decode(completion_ids)
            
            usage = {
                "prompt_tokens": len(input_ids),
                "completion_tokens": len(completion_ids),
                "total_tokens": len(input_ids) + len(completion_ids),
            }
            
            return litellm.ModelResponse(
                id=f"mx-{int(time.time())}",
                model=self.model_name,
                choices=[{
                    "message": {"role": "assistant", "content": completion_text}
                }],
                usage=litellm.Usage(**usage)
            )
            
        except Exception as e:
            raise litellm.CustomError(
                status_code=500, 
                message=f"MLX Error: {str(e)}"
            )

    async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> litellm.ModelResponse:
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.completion, model, messages, **kwargs)  # Fix 4

# 1. Register custom provider correctly
litellm.custom_provider_map = [{"mlx": MlxLLM}]  # Fix 5

# 2. Configure DSPy (use valid MLX model path)
dspy.configure(
    lm=dspy.LM(
        "mlx/mistral-7b-instruct-mlx",  # Actual MLX model path
        temperature=0.7,
        max_tokens=1024,
    )
)

# 3. Create DSPy module (unchanged)
class QuantumQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate = dspy.ChainOfThought("question -> answer")
    
    def forward(self, question):
        return self.generate(question=question)

# 4. Execute pipeline
try:
    qa_pipeline = QuantumQA()
    response = qa_pipeline(question="Explain quantum superposition")
    print(response.answer)
except Exception as e:
    print(f"Error: {str(e)}")


[92m15:59:24 - LiteLLM:ERROR[0m: utils.py:750 - litellm.utils.py::function_setup() - [Non-Blocking] Error in function_setup
Traceback (most recent call last):
  File "/Users/dkhundley/Documents/Repositories/dspy-tutorial/.venv/lib/python3.12/site-packages/litellm/utils.py", line 482, in function_setup
    custom_llm_setup()
  File "/Users/dkhundley/Documents/Repositories/dspy-tutorial/.venv/lib/python3.12/site-packages/litellm/utils.py", line 325, in custom_llm_setup
    if custom_llm["provider"] not in litellm.provider_list:
       ~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'provider'
[92m15:59:24 - LiteLLM:ERROR[0m: utils.py:750 - litellm.utils.py::function_setup() - [Non-Blocking] Error in function_setup
Traceback (most recent call last):
  File "/Users/dkhundley/Documents/Repositories/dspy-tutorial/.venv/lib/python3.12/site-packages/dspy/adapters/chat_adapter.py", line 42, in __call__
    return super().__call__(lm, lm_kwargs, signature, demos, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^

Error: 'provider'


In [None]:
# Creating sample sentences representing positive and negative sentiment
positive_sentence = "I am very happy with the results of this project."
negative_sentence = "I am disappointed with the outcome of this task."

# Instantiating a simple DSPy module for sentiment classification
dspy_sentiment_classification = dspy.Predict('sentence -> sentiment: bool')

# Invoking the DSPy model with each respective sentence.
print(f'Positive sentence: {dspy_sentiment_classification(sentence = positive_sentence)}')
print(f'Negative sentence: {dspy_sentiment_classification(sentence = negative_sentence)}')

# try:
#     # Invoking the DSPy model with each respective sentence.
#     print(f'Positive sentence: {dspy_sentiment_classification(sentence = positive_sentence)}')
#     print(f'Negative sentence: {dspy_sentiment_classification(sentence = negative_sentence)}')

#     del dspy_mlx_model

# except Exception as e:
#     del dspy_mlx_model
#     print(f'Error: {e}')
    

In [None]:
# from mlx_lm import load, generate

# model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# prompt = "What is the capital of Illinois?"

# messages = [{"role": "user", "content": prompt}]
# prompt = tokenizer.apply_chat_template(
#     messages, add_generation_prompt=True
# )

# text = generate(model, tokenizer, prompt=prompt, verbose=True)

# del model, tokenizer