In [1]:
from datasets import Dataset
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig
from swanlab.integration.transformers import SwanLabCallback
import swanlab

In [2]:
# 将JSON文件转换为CSV文件
train_df = pd.read_json('./dataset/input.json')[5:]
train_ds = Dataset.from_pandas(train_df)
test_df = pd.read_json('./dataset/input.json')[:2]
test_ds = Dataset.from_pandas(test_df)

In [3]:
def process_func(example):
    MAX_LENGTH = 8000    # Llama分词器会将一个中文字切分为多个token，因此需要放开一些最大长度，保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"<|im_start|>user\n{example['instruction'] }<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

tokenizer = AutoTokenizer.from_pretrained('./Qwen/Qwen2.5-Coder-0.5B-Instruct', use_fast=False, trust_remote_code=True)
tokenized_id = train_ds.map(process_func, remove_columns=train_ds.column_names)
device = torch.device("cuda" if torch.cuda.is_available() else "mps")

def predict(messages, model, tokenizer):
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
    generated_ids = [
        output_ids[len(input_ids) :]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return response

Map:   0%|          | 0/503 [00:00<?, ? examples/s]

# lora

In [4]:
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained('./Qwen/Qwen2.5-Coder-0.5B-Instruct/', device_map="auto",torch_dtype=torch.bfloat16).train()
model.enable_input_require_grads() # 开启梯度检查点时，要执行该方法

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False, # 训练模式
    r=4, # Lora 秩
    lora_alpha=32, # Lora alaph，具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()

trainable params: 2,199,552 || all params: 496,232,320 || trainable%: 0.4433


# 配置训练参数

In [5]:
args = TrainingArguments(
    output_dir="./output/Qwen2.5-Coder-0.5B-Instruct-Lora-SwanLab",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=10,
    logging_steps=2, # 每多少个步骤记录一次训练日志
    num_train_epochs=4,
    save_steps=10, # 每多少个步骤保存一次模型检查点
    learning_rate=1e-3,
    save_on_each_node=False, # 如果使用分布式训练，每个节点都会保存自己的模型检查点。对于单机多卡训练，通常不需要设置为 True。
    gradient_checkpointing=True,
    report_to="none")

class CodeSwanLabCallback(SwanLabCallback):   
    # def on_train_begin(self, args, state, control, model=None, **kwargs):
    #     if not self._initialized:
    #         self.setup(args, state, model, **kwargs)
            
    #     print("训练开始")
    #     print("未开始微调，先取3条主观评测：")
    #     test_text_list = []
    #     for index, row in test_df[:3].iterrows():
    #         instruction = row["instruction"]
    #         # input_value = row["input"]

    #         messages = [
    #             # {"role": "system", "content": f"{instruction}"},
    #             {"role": "user", "content": f"{instruction}"},
    #         ]

    #         response = predict(messages, peft_model, tokenizer)
    #         messages.append({"role": "assistant", "content": f"{response}"})
                
    #         result_text = f"【Q】{messages[1]['content']}\n【LLM】{messages[2]['content']}\n"
    #         print(result_text)
            
    #         test_text_list.append(swanlab.Text(result_text, caption=response))

    #     swanlab.log({"Prediction": test_text_list}, step=0)
    
    def on_epoch_end(self, args, state, control, **kwargs):
        # ===================测试阶段======================
        test_text_list = []
        for index, row in test_df.iterrows():
            instruction = row["instruction"]
            # input_value = row["input"]
            ground_truth = row["output"]

            messages = [
                # {"role": "system", "content": f"{instruction}"},
                {"role": "user", "content": f"{instruction}"},
            ]

            response = predict(messages, peft_model, tokenizer)
            messages.append({"role": "assistant", "content": f"{response}"})
            
            if index == 0:
                print("epoch", round(state.epoch), "主观评测：")
                
            result_text = f"【Q】{messages[0]['content']}\n【LLM】{messages[1]['content']}\n【GT】 {ground_truth}"
            print(result_text)
            
            test_text_list.append(swanlab.Text(result_text, caption=response))

        swanlab.log({"Prediction": test_text_list}, step=round(state.epoch))
        
        
swanlab_callback = CodeSwanLabCallback(
    project="Qwen2.5-Coder-0.5B-Instruct-LoRA",
    experiment_name="0.5b",
    config={
        "lora_rank": 4,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
    },
)


In [6]:
trainer = Trainer(
    model=peft_model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)

trainer.train()

swanlab.finish()

[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.4.3                                   
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/Users/liuchengzhuo/fine-tuneing /fine-tune-Qwen-coder/swanlog/run-20250120_115325-a3b1799d[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mLiuchengzhuo[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33m0.5b[0m to the cloud
[1m[34mswanlab[0m[0m: 🌟 Run `[1mswanlab watch /Users/liuchengzhuo/fine-tuneing /fine-tune-Qwen-coder/swanlog[0m` to view SwanLab Experiment Dashboard locally
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@Liuchengzhuo/Qwen2.5-Coder-0.5B-Instruct-LoRA[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@Liuchengzhuo/Qwen2.5-Coder-0.5B-Instruct-LoRA/runs/d77a6lrv4a231a24h6d6i[0m[0m


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
2,17.0677
4,14.8402
6,13.1118
8,12.6617
10,13.6087
12,12.069
14,11.915
16,13.9715
18,11.4193
20,11.2848


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


epoch 1 主观评测：
【Q】实现一个基于AJX.PureComponent的纯组件基类，提供基础组件功能。定义FOXPureComponent类，继承自AJX.PureComponent，实现基础的组件构造和渲染功能。该组件作为其他业务组件的基类，用于公共交通相关页面组件的开发
【LLM】import * as AJX from '@framework/FOXRax.js';
import FOXPureComponent from '@framework/FOXPureComponent.jsx';

export default class AGroupBase extends FOXPureComponent {...}
【GT】 import * as AJX from 'FOXRax.js';

export default class FOXPureComponent extends AJX.PureComponent {
    constructor(spec) {
        super(spec);
    }

    render() {
        return [];
    }
}
【Q】公交详情列表容器组件，负责渲染和管理公交路线详情列表，包括实时信息、碳排放数据等内容展示。管理公交详情列表的核心容器组件，处理列表渲染、滚动定位、实时数据更新等功能。集成了紧急情况提示、碳排放信息展示、列表滚动定位等特性，支持自定义公交、火车等多种交通方式的展示。
【LLM】export default class BusDetailListContainer extends FOXPureComponent {...}
【GT】 export default class BusDetailListContainer extends FOXPureComponent {
    constructor(props) {
        super(props);
        this.iconAnimation = [];
        this.listNode = this.listNode.bind(this);
    }
    componentWillMount() {...}
    render() {...}
 



epoch 2 主观评测：
【Q】实现一个基于AJX.PureComponent的纯组件基类，提供基础组件功能。定义FOXPureComponent类，继承自AJX.PureComponent，实现基础的组件构造和渲染功能。该组件作为其他业务组件的基类，用于公共交通相关页面组件的开发
【LLM】import * as AJX from '@framework/FOXRax.js';
import FOXPureComponent from '@framework/FOXPureComponent.jsx';

export default class FOXPureComponent extends FOXPureComponent {
    constructor(spec) {
        super(spec);
    }

    render() {
        return [];
    }
}
【GT】 import * as AJX from 'FOXRax.js';

export default class FOXPureComponent extends AJX.PureComponent {
    constructor(spec) {
        super(spec);
    }

    render() {
        return [];
    }
}
【Q】公交详情列表容器组件，负责渲染和管理公交路线详情列表，包括实时信息、碳排放数据等内容展示。管理公交详情列表的核心容器组件，处理列表渲染、滚动定位、实时数据更新等功能。集成了紧急情况提示、碳排放信息展示、列表滚动定位等特性，支持自定义公交、火车等多种交通方式的展示。
【LLM】class BusDetailListItemContainer extends FOXPureComponent {
    constructor(spec) {
        super(spec);
        this.emergencyTips = BusCloudUtil.getEmergencyTips();
        this.headerHeight = BusCloudUtil.getBusDetailHeaderHeight();
      



epoch 3 主观评测：
【Q】实现一个基于AJX.PureComponent的纯组件基类，提供基础组件功能。定义FOXPureComponent类，继承自AJX.PureComponent，实现基础的组件构造和渲染功能。该组件作为其他业务组件的基类，用于公共交通相关页面组件的开发
【LLM】import * as AJX from 'FOXRax.js';

export default class FOXPureComponent extends AJX.PureComponent {
    constructor(spec) {
        super(spec);
    }

    render() {
        return [];
    }
}
【GT】 import * as AJX from 'FOXRax.js';

export default class FOXPureComponent extends AJX.PureComponent {
    constructor(spec) {
        super(spec);
    }

    render() {
        return [];
    }
}
【Q】公交详情列表容器组件，负责渲染和管理公交路线详情列表，包括实时信息、碳排放数据等内容展示。管理公交详情列表的核心容器组件，处理列表渲染、滚动定位、实时数据更新等功能。集成了紧急情况提示、碳排放信息展示、列表滚动定位等特性，支持自定义公交、火车等多种交通方式的展示。
【LLM】export default class BusDetailListContainer extends FOXPureComponent {
    constructor(spec) {
        super(spec);
        this.isBusDetailListOpen = false;
        this.state = {
            isShowCarbonText: false,
            isShowEmergencyText: false,
            listHeight: 0,
            listWidth: 0,
     



epoch 4 主观评测：
【Q】实现一个基于AJX.PureComponent的纯组件基类，提供基础组件功能。定义FOXPureComponent类，继承自AJX.PureComponent，实现基础的组件构造和渲染功能。该组件作为其他业务组件的基类，用于公共交通相关页面组件的开发
【LLM】import * as AJX from 'FOXRax.js';

export default class FOXPureComponent extends AJX.PureComponent {
    constructor(spec) {
        super(spec);
    }

    render() {
        return [];
    }
}
【GT】 import * as AJX from 'FOXRax.js';

export default class FOXPureComponent extends AJX.PureComponent {
    constructor(spec) {
        super(spec);
    }

    render() {
        return [];
    }
}
【Q】公交详情列表容器组件，负责渲染和管理公交路线详情列表，包括实时信息、碳排放数据等内容展示。管理公交详情列表的核心容器组件，处理列表渲染、滚动定位、实时数据更新等功能。集成了紧急情况提示、碳排放信息展示、列表滚动定位等特性，支持自定义公交、火车等多种交通方式的展示。
【LLM】class BusDetailListContainer extends FOXPureComponent {
    constructor(spec) {
        super(spec);
        this.isBusPathEmergency = BusCloudUtil.getBusDetailListShowBus();
        this.isBusPathEmergencyDialogOpen = false;
        this.isCarbonSummaryShow = BusCloudUtil.getBusDetailListCarbonSummaryShow();
     

# 测试模型效果

In [7]:
# 推理部分
def generate_response(is_fineturn, prompt):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    import torch


    mode_path = 'Qwen/Qwen2.5-Coder-0.5B-Instruct/'
    adapter_path = 'output/Qwen2.5-Coder-0.5B-Instruct-Lora-SwanLab/checkpoint-200' # 这里改称你的 lora 输出对应 checkpoint 地址

    # 加载tokenizer
    tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)

    # 加载基础模型
    model = AutoModelForCausalLM.from_pretrained(
        mode_path, 
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    ).eval()
    if is_fineturn:
        # 加载adapter权重
        model = PeftModel.from_pretrained(model, adapter_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "mps")
    # 生成测试
    inputs = tokenizer.apply_chat_template(
        [
        # {"role":"system","content":"你是一个专业的代码专家，熟悉AJX语言，熟悉高德地图前端开发，熟悉公共交通业务，能够根据高德地图前端开发需求，开发代码。"},
        {"role": "user", "content": prompt}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    ).to(device)

    gen_config = {"max_length": 4096, "do_sample": True, "top_k": 1, "temperature": 0.1}
    
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_config)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [8]:
# 测试生成
test_prompt = "在高德地图前端公共交通业务中，如何计算和记录用户在页面上的停留时间，并支持暂停/恢复计时功能？"
# response1 = generate_response(False, test_prompt)
# print("原模型回答：\n",response1)
response2 = generate_response(True,test_prompt)
print("微调后模型回答:\n",response2)

微调后模型回答:
 export default class TripLogUtil {
    static setPageStayTime(pageId, pageName, action, location) {
        if (!pageId || !pageName || !action) {
            return;
        }
        const logParams = {
            page_id: parseInt(pageId, 10),
            page_name: pageName,
            action: action,
            location: JSON.stringify(location),
        };
        ajx.log.print(`TripLogUtil >> setPageStayTime >> logParams: ${JSON.stringify(logParams)}`);
        natives.amap_trip.setPageStayTime(JSON.stringify(logParams), (res) => {
            ajx.log.print(`TripLogUtil >> setPageStayTime >> res: ${JSON.stringify(res)}`);
        });
    }

    static resumePageStayTime(pageId, pageName, action) {
        if (!pageId || !pageName || !action) {
            return;
        }
        ajx.log.debug(
            `TripLogUtil >> resumePageStayTime >> logParams: ${JSON.stringify({
                page_id: parseInt(pageId, 10),
                page_name: pageName,
         