In [None]:
%pip install mlx-lm numpy transformers torch

In [6]:
# Hugging Face repo id of the base model to fine-tune (assigned again, with the same value, in the next cell).
model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

In [None]:
import json
import os
from typing import Any, Dict, List

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_unflatten
from mlx_lm import generate, load
from transformers import AutoTokenizer

# 1. Load the model and tokenizer (mlx_lm.load fetches the repo and returns both).
model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model, tokenizer = load(model_path)

# Make the GPU the default device for subsequent MLX operations. MLX uses unified
# memory, so this chooses where ops execute; it does not copy or move the model.
mx.set_default_device(mx.gpu)

# 2. JSON-schema style descriptions of the functions the model may call.
# NOTE(review): `functions` is not referenced anywhere else in the visible notebook;
# neither the training prompts nor the test prompts include these schemas, so the
# model never sees them. Consider injecting them into the prompt text.
functions = [
    {
        "name": "get_weather",
        "description": "Get the weather for a specific location and date",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city name"},
                "date": {"type": "string", "description": "The date in YYYY-MM-DD format"}
            },
            "required": ["location"]
        }
    },
    {
        "name": "send_email",
        "description": "Send an email to a recipient",
        "parameters": {
            "type": "object",
            "properties": {
                "recipient": {"type": "string", "description": "The email address of the recipient"},
                "subject": {"type": "string", "description": "The subject of the email"},
                "body": {"type": "string", "description": "The content of the email"}
            },
            "required": ["recipient", "subject"]
        }
    }
]

# 3. Prepare the training data
def prepare_training_data(examples=None):
    """Build (prompt, target) string pairs for function-calling fine-tuning.

    Args:
        examples: optional list of chat examples. Each is a dict whose
            "messages" list holds a "user" message (with "content") and an
            "assistant" message (with a "function_call" dict containing
            "name" and "arguments", the latter a JSON string). Defaults to two
            built-in samples.

    Returns:
        (input_texts, target_outputs): prompts in the
        "<|user|>\\n...\\n<|assistant|>\\n" layout, and for each prompt the JSON
        string {"function": name, "arguments": arguments} the model should emit.

    Raises:
        ValueError: if an example lacks a user message or an assistant
            message carrying a "function_call".
    """
    if examples is None:
        examples = [
            {
                "messages": [
                    {"role": "user", "content": "What's the weather in Beijing tomorrow?"},
                    {"role": "assistant", "content": None, "function_call": {
                        "name": "get_weather",
                        "arguments": json.dumps({"location": "Beijing", "date": "2023-11-21"})
                    }}
                ]
            },
            {
                "messages": [
                    {"role": "user", "content": "Send an email to John about the meeting"},
                    {"role": "assistant", "content": None, "function_call": {
                        "name": "send_email",
                        "arguments": json.dumps({
                            "recipient": "john@example.com",
                            "subject": "Meeting Reminder",
                            "body": "Hi John, just a reminder about our meeting tomorrow."
                        })
                    }}
                ]
            }
        ]

    input_texts = []
    target_outputs = []

    for index, example in enumerate(examples):
        messages = example["messages"]
        user_msg = next((m["content"] for m in messages if m["role"] == "user"), None)
        function_call = next(
            (m["function_call"] for m in messages
             if m["role"] == "assistant" and m.get("function_call")),
            None,
        )
        # Fail with a clear message instead of letting StopIteration escape.
        if user_msg is None or function_call is None:
            raise ValueError(
                f"example {index} needs a user message and an assistant function_call"
            )

        input_texts.append(f"<|user|>\n{user_msg}\n<|assistant|>\n")
        # "arguments" is already a JSON string, so it stays a string inside the target.
        target_outputs.append(json.dumps({
            "function": function_call["name"],
            "arguments": function_call["arguments"]
        }))

    return input_texts, target_outputs

input_texts, target_outputs = prepare_training_data()

# 4. Data preprocessing
# Label value for positions that must not contribute to the loss (prompt and
# padding tokens); same sentinel the Hugging Face ecosystem uses.
IGNORE_INDEX = -100


def preprocess_data(texts: List[str], targets: List[str], tokenizer, max_length=512):
    """Tokenize prompt/target pairs into one right-padded causal-LM batch.

    Each row is ``prompt tokens + target tokens + EOS`` truncated to
    ``max_length``, so the model learns to continue the prompt with the target.
    Labels copy the token ids but are ``IGNORE_INDEX`` on prompt and padding
    positions, so only target tokens are supervised.

    Args:
        texts: prompt strings.
        targets: expected completion for each prompt (same length as ``texts``).
        tokenizer: tokenizer exposing ``encode``, ``eos_token_id`` and
            ``pad_token_id`` (mlx_lm's wrapper forwards these to the HF tokenizer).
        max_length: maximum number of tokens per row (prompt + target).

    Returns:
        dict of int32 ``mx.array`` of shape (batch, seq_len), where seq_len is
        the longest row after truncation:
            "input_ids": token ids, padded with the pad (or EOS) id;
            "attention_mask": 1 for real tokens, 0 for padding;
            "labels": token ids to predict, IGNORE_INDEX where no loss applies.

    Raises:
        ValueError: if ``texts`` and ``targets`` differ in length or are empty.
    """
    if len(texts) != len(targets):
        raise ValueError(f"got {len(texts)} prompts but {len(targets)} targets")
    if not texts:
        raise ValueError("no training examples")

    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id
    if pad_id is None:
        pad_id = 0  # padding is masked out of the loss, any valid id works

    rows = []
    for text, target in zip(texts, targets):
        # The prompt gets the tokenizer's usual special tokens (as at generation time).
        prompt_ids = list(tokenizer.encode(text))
        # No special tokens for the continuation: a BOS mid-sequence would be wrong.
        target_ids = list(tokenizer.encode(target, add_special_tokens=False))
        if eos_id is not None:
            target_ids.append(eos_id)  # teach the model to stop after the call
        ids = (prompt_ids + target_ids)[:max_length]
        labels = ([IGNORE_INDEX] * len(prompt_ids) + target_ids)[:max_length]
        rows.append((ids, labels))

    seq_len = max(len(ids) for ids, _ in rows)
    input_ids = np.full((len(rows), seq_len), pad_id, dtype=np.int32)
    attention_mask = np.zeros((len(rows), seq_len), dtype=np.int32)
    label_ids = np.full((len(rows), seq_len), IGNORE_INDEX, dtype=np.int32)
    for i, (ids, labels) in enumerate(rows):
        input_ids[i, :len(ids)] = ids
        attention_mask[i, :len(ids)] = 1
        label_ids[i, :len(labels)] = labels

    return {
        "input_ids": mx.array(input_ids),
        "attention_mask": mx.array(attention_mask),
        "labels": mx.array(label_ids)
    }

train_data = preprocess_data(input_texts, target_outputs, tokenizer)

# 5. Training loop
def train_model(model, train_data, epochs=3, learning_rate=1e-5):
    """Fine-tune all parameters of ``model`` on ``train_data`` with Adam.

    The whole batch is used for every step, so ``epochs`` is the number of
    optimizer updates. The loss is next-token cross-entropy averaged over the
    positions whose label is not ``IGNORE_INDEX`` (the target tokens).

    Args:
        model: mlx_lm causal LM, called as ``model(tokens)`` and returning
            logits of shape (batch, seq, vocab).
        train_data: dict produced by ``preprocess_data``. "attention_mask" is
            not needed: padding is on the right, and the model's causal mask
            already keeps real tokens from attending to it.
        epochs: number of optimization steps.
        learning_rate: Adam learning rate.

    Returns:
        list of the loss value (float) computed at each step, before its update.
    """
    optimizer = optim.Adam(learning_rate=learning_rate)
    model.train()

    # Causal-LM shift: the logits at position t predict the token at t + 1.
    inputs = train_data["input_ids"][:, :-1]
    targets = train_data["labels"][:, 1:]
    loss_mask = targets != IGNORE_INDEX
    # cross_entropy gathers by target id, so masked slots still need a valid id;
    # their loss is zeroed by the weights below.
    targets = mx.where(loss_mask, targets, 0)
    weights = loss_mask.astype(mx.float32)
    # Avoid dividing by zero if every supervised token was truncated away.
    n_tokens = mx.maximum(weights.sum(), 1.0)

    def loss_fn(model, inputs, targets, weights):
        # Compute the loss in float32 even if the weights are lower precision.
        logits = model(inputs).astype(mx.float32)
        token_loss = nn.losses.cross_entropy(logits, targets)
        return (token_loss * weights).sum() / n_tokens

    # nn.value_and_grad differentiates w.r.t. the model's trainable parameters.
    loss_and_grad = nn.value_and_grad(model, loss_fn)

    losses = []
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        loss, grads = loss_and_grad(model, inputs, targets, weights)
        optimizer.update(model, grads)
        # MLX is lazy: force the update (and the loss) to actually run.
        mx.eval(model.parameters(), optimizer.state, loss)

        losses.append(loss.item())
        print(f"Loss: {losses[-1]:.4f}")

    return losses

# 6. Start training
print("Starting training...")
train_model(model, train_data, epochs=3, learning_rate=1e-5)

# 7. Save the fine-tuned model.
# mlx.nn.Module has no save_pretrained(); save_weights() writes the parameters
# (file format chosen from the extension). Create the output directory first so
# the weight file can be written into it.
output_dir = "./fine_tuned_deepseek_r1"
os.makedirs(output_dir, exist_ok=True)
model.save_weights(os.path.join(output_dir, "model.safetensors"))
tokenizer.save_pretrained(output_dir)
# NOTE(review): reloading this directory with mlx_lm.load() also needs the base
# model's config.json next to the weights; copy it over if you plan to reload.

# 8. 测试function calling
def test_function_calling(model, tokenizer, prompt):
    # 生成响应
    response = generate(
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=200,
        temp=0.7
    )
    
    # 尝试解析function call
    try:
        # 查找JSON格式的function call
        start_idx = response.find("{")
        end_idx = response.rfind("}") + 1
        json_str = response[start_idx:end_idx]
        function_call = json.loads(json_str)
        
        print("\nDetected function call:")
        print(f"Function: {function_call.get('function')}")
        print(f"Arguments: {function_call.get('arguments')}")
        return function_call
    except (ValueError, AttributeError):
        print("\nNo valid function call detected in response:")
        print(response)
        return None

# Smoke-test the fine-tuned model on prompts that are not in the training data.
print("\nTesting function calling...")
test_prompts = [
    "What's the weather like in Shanghai next Monday?",
    "Send an email to Alice about the project deadline",
]

for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    # Wrap the question in the same "<|user|> ... <|assistant|>" layout used for training.
    formatted_prompt = "<|user|>\n" + prompt + "\n<|assistant|>\n"
    test_function_calling(model, tokenizer, formatted_prompt)