统计每条数据的 token 数，找出 token 数最多的条目，并计算所有条目的平均 token 数。

### 统计

In [None]:
# 安装依赖包
!pip install tiktoken

In [None]:
import json
import tiktoken
from typing import List, Tuple, Dict, Optional

def load_json_data(file_path: str) -> List[dict]:
    """Read *file_path* as UTF-8 text and return the parsed JSON document.

    The caller expects a list of entry dicts; the JSON content is not validated.
    """
    with open(file_path, encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def calculate_field_tokens(
    data: List[dict],
    target_fields: Optional[List[str]] = None,
    role_filter: Optional[List[str]] = None
) -> Tuple[List[dict], List[dict]]:
    """Count ``o200k_base`` tokens for every entry, broken down by field.

    For each entry:
      * every value found at a dotted path in *target_fields* (resolved with
        ``_get_nested_values``) is tokenized -- dicts/lists as JSON text,
        anything else via ``str()`` -- and counted under that path;
      * the ``content`` of each message in ``entry["messages"]`` is counted
        under the key ``"messages.<role>.content"``.

    Args:
        data: Entries to measure.
        target_fields: Dotted field paths to count (``*`` iterates a list);
            ``None`` counts no plain fields.
        role_filter: If non-empty, only messages whose ``role`` is in this list
            are counted.

    Returns:
        ``(results, data)``: ``results[i]`` is
        ``{"index": i, "total_tokens": int, "field_breakdown": {field: int}}``
        and *data* is the input list, returned unchanged.
    """
    enc = tiktoken.get_encoding("o200k_base")

    def count_tokens(text: str) -> int:
        # disallowed_special=() counts strings such as "<|endoftext|>" as plain
        # text; with the default, encode() raises ValueError on such data.
        return len(enc.encode(text, disallowed_special=()))

    results = []

    for idx, entry in enumerate(data):
        field_stats: Dict[str, int] = {}
        total_tokens = 0

        # Plain fields in target_fields (e.g. uid, Name), addressed by dotted path.
        for field_path in (target_fields or []):
            for value in _get_nested_values(entry, field_path):
                if isinstance(value, (dict, list)):
                    text = json.dumps(value, ensure_ascii=False)
                else:
                    text = str(value)
                n_tokens = count_tokens(text)
                field_stats[field_path] = field_stats.get(field_path, 0) + n_tokens
                total_tokens += n_tokens

        # Message contents, bucketed per role; a missing/None messages list is skipped.
        for msg in entry.get("messages") or []:
            role = msg.get("role")
            if role_filter and role not in role_filter:
                continue
            # content may be None (e.g. tool-call-only messages) or a list of
            # content parts; None counts as empty, containers as their JSON text.
            content = msg.get("content") or ""
            if not isinstance(content, str):
                content = json.dumps(content, ensure_ascii=False)
            n_tokens = count_tokens(content)
            role_content_field = f"messages.{role}.content"
            field_stats[role_content_field] = field_stats.get(role_content_field, 0) + n_tokens
            total_tokens += n_tokens

        results.append({
            "index": idx,
            "total_tokens": total_tokens,
            "field_breakdown": field_stats
        })

    return results, data

def _get_nested_values(obj: dict, path: str) -> list:
    parts = path.split('.')
    results = [obj]
    for part in parts:
        new_results = []
        for r in results:
            if part == '*':
                if isinstance(r, list):
                    new_results.extend(r)
            elif isinstance(r, dict):
                if part in r:
                    new_results.append(r[part])
            elif isinstance(r, list):
                for item in r:
                    if isinstance(item, dict) and part in item:
                        new_results.append(item[part])
        results = new_results
    return [item for item in results if item is not None]

def analyze_results(results: List[dict], data: List[dict]) -> Optional[dict]:
    """Aggregate per-entry token counts and print global and per-field statistics.

    Args:
        results: Per-entry records with ``"index"``, ``"total_tokens"`` and
            ``"field_breakdown"`` keys, as built by ``calculate_field_tokens``.
        data: The entries *results* refer to (by ``"index"``); only needed by
            the commented-out dump of the largest entry.

    Returns:
        The printed summary ``{"total": {"max", "min", "avg"},
        "fields": {field: {"max", "min", "avg"}}}``, or ``None`` when *results*
        is empty.

    Note:
        Per-field statistics only cover entries in which that field occurred,
        so e.g. ``messages.system.content`` is averaged over entries that have
        a system message, not over all entries.
    """
    if not results:
        # max()/min() on an empty sequence raise ValueError and the average
        # would divide by zero, so report and stop instead of crashing.
        print("\n全局统计：无数据")
        return None

    totals = [r["total_tokens"] for r in results]
    summary = {
        "total": {
            "max": max(totals),
            "min": min(totals),
            "avg": sum(totals) / len(totals)
        },
        "fields": {}
    }

    # Gather each field's counts across the entries that contain it.
    field_counts: Dict[str, List[int]] = {}
    for r in results:
        for field, count in r["field_breakdown"].items():
            field_counts.setdefault(field, []).append(count)

    for field, counts in field_counts.items():
        summary["fields"][field] = {
            "max": max(counts),
            "min": min(counts),
            "avg": sum(counts) / len(counts)
        }

    # Print the report.
    print("\n全局统计：")
    print(f"总token数 | 最大: {summary['total']['max']} 最小: {summary['total']['min']} 平均: {summary['total']['avg']:.1f}")

    print("\n字段级统计：")
    for field, stats in summary["fields"].items():
        print(f"{field.ljust(20)} | 最大: {stats['max']} 最小: {stats['min']} 平均: {stats['avg']:.1f}")

    # Largest entry (first one on ties).
    max_entry = max(results, key=lambda x: x["total_tokens"])
    print(f"\n最大token条目（索引 {max_entry['index']}，共 {max_entry['total_tokens']} tokens）")
    # print(json.dumps(data[max_entry["index"]], indent=4, ensure_ascii=False))

    return summary

def main(
    file_path: str = '',
    target_fields: Optional[List[str]] = None,
    role_filter: Optional[List[str]] = None,
):
    """Load entries from a JSON file, count their tokens and print statistics.

    The former hard-coded settings are keyword parameters whose defaults
    reproduce the original configuration.

    Args:
        file_path: Path of the input JSON file. The default ``''`` is a
            placeholder and must be replaced (edit it or pass an argument).
        target_fields: Dotted plain-field paths to count. ``None`` selects
            ``["uid", "Object_ID", "Name"]``; message contents are counted
            separately, so no ``messages`` path belongs here.
        role_filter: Message roles to count. ``None`` selects
            ``["system", "user", "assistant"]``.

    Raises:
        ValueError: If *file_path* is empty (``open('')`` would otherwise fail
            with an opaque ``FileNotFoundError``).
    """
    if not file_path:
        raise ValueError("未设置输入文件路径：请通过 file_path 指定 JSON 文件")
    if target_fields is None:
        target_fields = ["uid", "Object_ID", "Name"]
    if role_filter is None:
        role_filter = ["system", "user", "assistant"]

    data = load_json_data(file_path)
    results, data = calculate_field_tokens(
        data,
        target_fields=target_fields,
        role_filter=role_filter
    )
    analyze_results(results, data)

# Runs immediately when this cell/script is executed directly.
if __name__ == "__main__":
    main()

若某条数据中指定角色的消息内容（content）的 token 数超过给定阈值，则删除这条数据，并保存过滤后的结果。

In [None]:
import json
import tiktoken
from typing import List, Tuple, Dict, Optional

def load_json_data(file_path: str) -> List[dict]:
    """Parse the UTF-8 JSON file at *file_path* and return its content.

    The result is expected (not checked) to be a list of entry dicts.
    """
    with open(file_path, mode='r', encoding='utf-8') as source:
        text = source.read()
    return json.loads(text)

def save_json_data(data: List[dict], file_path: str):
    """Write *data* to *file_path* as pretty-printed (indent=2) UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``). The data
    is serialized *before* the file is opened: opening with mode ``'w'``
    truncates the file immediately, so streaming ``json.dump`` into it would
    leave an empty or partial file if serialization failed part-way (e.g. a
    non-JSON-serializable value). Now such an error leaves any existing file
    untouched.

    Raises:
        TypeError / ValueError: From ``json.dumps`` if *data* cannot be
            serialized; the target file is not modified in that case.
    """
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(serialized)

def filter_entries_by_token_limit(
    data: List[dict],
    role_filter: Optional[List[str]] = None,
    token_limit: int = 500
) -> Tuple[List[dict], List[dict]]:
    """Split entries by whether any selected message exceeds a token limit.

    An entry is removed when at least one message in ``entry["messages"]``
    whose role passes *role_filter* has a ``content`` of more than
    *token_limit* tokens (``o200k_base`` encoding). Entries without a
    ``messages`` list (missing or ``None``) are always kept.

    Args:
        data: Entries to filter.
        role_filter: Roles whose messages are checked; ``None`` or empty
            checks every message.
        token_limit: Maximum allowed tokens per message content (a content of
            exactly *token_limit* tokens is kept).

    Returns:
        ``(kept_entries, removed_entries)``, each in input order.
    """
    enc = tiktoken.get_encoding("o200k_base")

    def exceeds_limit(msg: dict) -> bool:
        """Return True if *msg* is a checked role and its content is too long."""
        role = msg.get("role")
        if role_filter and role not in role_filter:
            return False
        # None content (e.g. tool-call-only messages) counts as empty; a list of
        # content parts is measured by its JSON text instead of crashing encode().
        content = msg.get("content") or ""
        if not isinstance(content, str):
            content = json.dumps(content, ensure_ascii=False)
        # disallowed_special=() treats "<|endoftext|>"-like text as plain text;
        # the default would raise ValueError on such data.
        return len(enc.encode(content, disallowed_special=())) > token_limit

    kept_data = []
    removed_data = []

    for entry in data:
        # any() stops at the first over-limit message, so later ones are not encoded.
        if any(exceeds_limit(msg) for msg in entry.get("messages") or []):
            removed_data.append(entry)
        else:
            kept_data.append(entry)

    return kept_data, removed_data

def calculate_field_tokens(
    data: List[dict],
    target_fields: Optional[List[str]] = None,
    role_filter: Optional[List[str]] = None
) -> List[dict]:
    """Count ``o200k_base`` tokens for every entry, broken down by field.

    For each entry:
      * every value found at a dotted path in *target_fields* (resolved with
        ``_get_nested_values``) is tokenized -- dicts/lists as JSON text,
        anything else via ``str()`` -- and counted under that path;
      * the ``content`` of each message in ``entry["messages"]`` is counted
        under the key ``"messages.<role>.content"``.

    Args:
        data: Entries to measure.
        target_fields: Dotted field paths to count (``*`` iterates a list);
            ``None`` counts no plain fields.
        role_filter: Optional list of roles to count; ``None`` (the default,
            matching the previous behavior) counts every message.

    Returns:
        One record per entry:
        ``{"index": i, "total_tokens": int, "field_breakdown": {field: int}}``.
    """
    enc = tiktoken.get_encoding("o200k_base")

    def count_tokens(text: str) -> int:
        # disallowed_special=() counts strings such as "<|endoftext|>" as plain
        # text; with the default, encode() raises ValueError on such data.
        return len(enc.encode(text, disallowed_special=()))

    results = []

    for idx, entry in enumerate(data):
        field_stats: Dict[str, int] = {}
        total_tokens = 0

        # Plain fields addressed by dotted path.
        for field_path in (target_fields or []):
            for value in _get_nested_values(entry, field_path):
                if isinstance(value, (dict, list)):
                    text = json.dumps(value, ensure_ascii=False)
                else:
                    text = str(value)
                n_tokens = count_tokens(text)
                field_stats[field_path] = field_stats.get(field_path, 0) + n_tokens
                total_tokens += n_tokens

        # Message contents, bucketed per role; a missing/None messages list is skipped.
        for msg in entry.get("messages") or []:
            role = msg.get("role")
            if role_filter and role not in role_filter:
                continue
            # None content counts as empty; a list of content parts is measured
            # by its JSON text instead of crashing encode().
            content = msg.get("content") or ""
            if not isinstance(content, str):
                content = json.dumps(content, ensure_ascii=False)
            n_tokens = count_tokens(content)
            role_field = f"messages.{role}.content"
            field_stats[role_field] = field_stats.get(role_field, 0) + n_tokens
            total_tokens += n_tokens

        results.append({
            "index": idx,
            "total_tokens": total_tokens,
            "field_breakdown": field_stats
        })

    return results

def _get_nested_values(obj: dict, path: str) -> list:
    """
    使用点号语法获取嵌套字段值
    支持通配符 * 表示遍历数组
    """
    parts = path.split('.')
    results = [obj]
    
    for part in parts:
        new_results = []
        for r in results:
            if part == '*':
                if isinstance(r, list):
                    new_results.extend(r)
            elif isinstance(r, dict):
                if part in r:
                    new_results.append(r[part])
            elif isinstance(r, list):
                for item in r:
                    if isinstance(item, dict) and part in item:
                        new_results.append(item[part])
        results = new_results
    
    return [item for item in results if item is not None]

def analyze_results(results: List[dict], data: List[dict]) -> Optional[dict]:
    """Analyze per-entry token counts and print detailed statistics.

    Args:
        results: Per-entry records with ``"index"``, ``"total_tokens"`` and
            ``"field_breakdown"`` keys, as built by ``calculate_field_tokens``.
        data: The entries *results* refer to (by ``"index"``); only needed by
            the commented-out dump of the largest entry.

    Returns:
        The printed summary ``{"total": {"max", "min", "avg"},
        "fields": {field: {"max", "min", "avg"}}}``, or ``None`` when *results*
        is empty.

    Note:
        Per-field statistics only cover entries in which that field occurred.
    """
    if not results:
        # The old `if data:` guard came after max()/min()/division had already
        # run on `results`, so empty results crashed; check the list we use.
        print("\n全局统计：无数据")
        return None

    totals = [r["total_tokens"] for r in results]
    summary = {
        "total": {
            "max": max(totals),
            "min": min(totals),
            "avg": sum(totals) / len(totals)
        },
        "fields": {}
    }

    # Gather each field's counts across the entries that contain it.
    field_counts: Dict[str, List[int]] = {}
    for r in results:
        for field, count in r["field_breakdown"].items():
            field_counts.setdefault(field, []).append(count)

    for field, counts in field_counts.items():
        summary["fields"][field] = {
            "max": max(counts),
            "min": min(counts),
            "avg": sum(counts) / len(counts)
        }

    # Print the report.
    print("\n全局统计：")
    print(f"总token数 | 最大: {summary['total']['max']} 最小: {summary['total']['min']} 平均: {summary['total']['avg']:.1f}")

    print("\n字段级统计：")
    for field, stats in summary["fields"].items():
        print(f"{field.ljust(20)} | 最大: {stats['max']} 最小: {stats['min']} 平均: {stats['avg']:.1f}")

    # Largest entry (first one on ties).
    max_entry = max(results, key=lambda x: x["total_tokens"])
    print(f"\n最大token条目（索引 {max_entry['index']}，共 {max_entry['total_tokens']} tokens）")
    # print(json.dumps(data[max_entry["index"]], indent=4, ensure_ascii=False))

    return summary

def main(
    input_file: str = '',
    output_file: str = '',
    target_fields: Optional[List[str]] = None,
    role_filter: Optional[List[str]] = None,
    token_limit: int = 2000,
):
    """Drop over-long entries from a JSON file, save the rest and print stats.

    Steps:
      1. load *input_file*;
      2. remove every entry in which a message whose role is in *role_filter*
         has more than *token_limit* content tokens;
      3. write the kept entries to *output_file*;
      4. print how many entries were kept/removed;
      5. print token statistics for the kept entries (all message roles plus
         *target_fields*; the role filter only affects step 2).

    The former hard-coded settings are keyword parameters whose defaults
    reproduce the original configuration.

    Args:
        input_file: Source JSON file. The default ``''`` is a placeholder.
        output_file: Destination JSON file. The default ``''`` is a placeholder.
            NOTE: if it equals *input_file*, the source file is overwritten.
        target_fields: Dotted plain-field paths for the statistics. ``None``
            selects ``["uid", "Object_ID", "Name"]``.
        role_filter: Roles checked against the limit. ``None`` selects
            ``["assistant"]``.
        token_limit: Maximum tokens allowed per checked message content.

    Raises:
        ValueError: If *input_file* or *output_file* is empty, instead of the
            opaque ``FileNotFoundError`` that ``open('')`` would raise.
    """
    if not input_file or not output_file:
        raise ValueError("未设置输入/输出文件路径：请指定 input_file 和 output_file")
    if target_fields is None:
        target_fields = ["uid", "Object_ID", "Name"]
    if role_filter is None:
        role_filter = ["assistant"]

    # 1. Load the raw data.
    original_data = load_json_data(input_file)

    # 2. Split entries by the per-message token limit.
    filtered_data, removed_data = filter_entries_by_token_limit(
        original_data,
        role_filter=role_filter,
        token_limit=token_limit
    )

    # 3. Persist the kept entries.
    save_json_data(filtered_data, output_file)

    # 4. Report the filtering outcome.
    print("\n=== 过滤结果 ===")
    print(f"原始条目数: {len(original_data)}")
    print(f"保留条目数: {len(filtered_data)}")
    print(f"删除条目数: {len(removed_data)}")
    print(f"已保存过滤后数据到: {output_file}")

    # 5. Statistics on the kept entries.
    if filtered_data:
        print("\n=== 过滤后数据统计 ===")
        results = calculate_field_tokens(filtered_data, target_fields)
        analyze_results(results, filtered_data)
    else:
        print("\n警告：过滤后数据为空！")

# Runs immediately when this cell/script is executed directly.
if __name__ == "__main__":
    main()