In [None]:
import redis
import pickle
import time
import pandas as pd
import os
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

# ---------------------- Configuration ----------------------
# Path of the key=value file read by load_redis_config().
REDIS_CONFIG_PATH = "redis.conf"
# Name of the Redis list that tasks are pushed onto.
TASK_QUEUE = "function_calls"
# Parquet file containing the daily K-line rows used to derive tasks.
DAILY_K_PATH = r"D:\workspace\xiaoyao\data\stock_daily_price.parquet"

# Queue throttling parameters
MAX_QUEUE_SIZE = 20000      # maximum number of tasks allowed to sit in the queue
BATCH_SIZE = 2000           # number of tasks pushed per batch
SEND_INTERVAL = 0.1         # pause between batches (passed to time.sleep)
CHECK_INTERVAL_WHEN_FULL = 2 # pause before re-checking a full queue (passed to time.sleep)
METADATA_BATCH_SIZE = 500   # number of metadata entries written per bulk write
MAX_THREADS = 4             # worker threads used for metadata writes

# ---------------------- 工具函数 ----------------------
def load_redis_config(config_path):
    """Load Redis connection settings from a simple ``key=value`` text file.

    Blank lines and lines starting with ``#`` are skipped. Only the ``host``,
    ``port`` and ``password`` keys are recognised; every other line is
    ignored. Whitespace around both the key and the value is stripped, so
    ``host=x`` and ``host = x`` are equivalent. Only the first ``=`` splits
    the line, so values may themselves contain ``=``.

    Args:
        config_path: Path of the configuration file.

    Returns:
        A dict of keyword arguments containing ``host``, ``port``,
        ``password`` plus the fixed ``decode_responses``, ``socket_timeout``
        and ``socket_keepalive`` settings. Missing keys fall back to
        ``localhost`` / ``6379`` / an empty password.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Redis配置文件不存在：{config_path}")

    host = "localhost"
    port = 6379
    password = ""

    # utf-8-sig transparently drops a leading BOM (common for files saved by
    # Windows editors); without it the first key in the file would not match.
    with open(config_path, 'r', encoding='utf-8-sig') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue

            key, separator, value = line.partition('=')
            if not separator:
                continue  # not a key=value line
            key = key.strip()
            value = value.strip()

            if key == 'host':
                host = value
            elif key == 'port':
                try:
                    port = int(value)
                except ValueError:
                    # Keep the current port and tell the user about the bad value.
                    print(f"警告：port配置格式错误，使用默认值{port}")
            elif key == 'password':
                password = value

    return {
        "host": host,
        "port": port,
        "password": password,
        "decode_responses": False,
        "socket_timeout": 30,
        "socket_keepalive": True
    }

def load_valid_daily_data(parquet_path, start_date="20250101", end_date="20250131"):
    """Build the list of (date, stock_code) tasks from the daily K-line file.

    Reads the ``date``, ``stock_code`` and ``paused`` columns, keeps rows
    whose ``paused`` value is 0 and whose date lies in the inclusive window
    ``[start_date, end_date]``, and returns the distinct (date, stock_code)
    pairs. The ``date`` column is read through the ``.dt`` accessor, so it
    must be a datetime column.

    Args:
        parquet_path: Path of the parquet file to read.
        start_date: First day to include, formatted ``YYYYMMDD``.
        end_date: Last day to include, formatted ``YYYYMMDD``.

    Returns:
        A DataFrame with the columns ``date`` (``YYYYMMDD`` strings) and
        ``stock_code``, one row per distinct pair.

    Raises:
        FileNotFoundError: If ``parquet_path`` does not exist.
        ValueError: If ``start_date`` or ``end_date`` is not ``YYYYMMDD``.
    """
    if not os.path.exists(parquet_path):
        raise FileNotFoundError(f"日K线文件不存在：{parquet_path}")

    # Validate both bounds up front; the lower bound is also rewritten as an
    # ISO date string for the comparison against the datetime column.
    start_iso = datetime.strptime(start_date, '%Y%m%d').strftime('%Y-%m-%d')
    datetime.strptime(end_date, '%Y%m%d')

    print(f"开始加载日K线数据：{parquet_path}")
    df = pd.read_parquet(
        parquet_path,
        columns=['date', 'stock_code', 'paused']
    )

    # Filter on the datetime column first so that the string formatting below
    # only runs on rows that can still be selected.
    keep = (df['paused'] == 0) & (df['date'] >= start_iso)
    # Copy after filtering so the column assignment below targets an
    # independent frame rather than a filtered slice.
    valid_df = df.loc[keep, ['date', 'stock_code']].copy()
    valid_df['date'] = valid_df['date'].dt.strftime('%Y%m%d')

    # YYYYMMDD strings order the same way as the dates they represent, so the
    # upper bound can be applied to the formatted column.
    valid_df = valid_df[valid_df['date'] <= end_date]

    # Collapse duplicates: one row per (date, stock_code) pair.
    valid_tasks = valid_df.groupby(['date', 'stock_code']).size().reset_index()
    valid_tasks = valid_tasks[['date', 'stock_code']]

    print(f"日K线数据加载完成，有效任务数：{len(valid_tasks)}")
    return valid_tasks

# ---------------------- 任务发送类 ----------------------
class OptimizedTaskSender:
    """Push pickled fetch tasks onto a Redis list with simple back-pressure.

    Tasks are appended to the ``TASK_QUEUE`` list in batches while the list
    length is kept at or below ``MAX_QUEUE_SIZE``. For every task a metadata
    entry ``task_id -> pickle((trade_date, stock_code))`` is written to the
    ``task_metadata`` hash from a small thread pool.
    """

    # A progress line is printed every time this many more tasks were sent.
    PROGRESS_EVERY = 10000

    def __init__(self, redis_config):
        """Connect to Redis and prepare the metadata writer pool.

        Args:
            redis_config: Keyword arguments passed to ``redis.Redis``.

        Raises:
            SystemExit: If the initial ping fails.
        """
        self.redis = redis.Redis(**redis_config)
        self._test_connection()
        self.redis_metadata_key = "task_metadata"

        # Thread pool used to write metadata in parallel with task sending.
        self.metadata_pool = ThreadPoolExecutor(max_workers=MAX_THREADS)
        self.metadata_futures = []

        # Decide once whether bulk hash writes use hmset or hset(mapping=).
        self.use_hmset = self._check_redis_version()

    def _test_connection(self):
        """Ping the server; exit the process with status 1 when it fails."""
        try:
            self.redis.ping()
            print("✅ Redis连接成功")
        except Exception as e:
            print(f"❌ Redis连接失败：{e}")
            raise SystemExit(1)

    def _check_redis_version(self):
        """Return True when the installed redis-py client needs ``hmset``.

        ``hset(name, mapping=...)`` is available from redis-py 3.5.0, where
        ``hmset`` was deprecated. The version is read from ``redis.VERSION``
        or, failing that, ``redis.__version__``. When the version cannot be
        determined the method falls back to ``hmset`` (returns True).
        """
        try:
            raw_version = getattr(redis, "VERSION", None)
            if raw_version is None:
                raw_version = str(redis.__version__).split(".")

            # Keep the leading numeric components only ("4.0.0rc1" -> 4, 0).
            numeric = []
            for piece in tuple(raw_version)[:3]:
                if isinstance(piece, int):
                    numeric.append(piece)
                elif str(piece).isdigit():
                    numeric.append(int(piece))
                else:
                    break
            if not numeric:
                return True

            # Pad so that e.g. (3, 5) compares equal to (3, 5, 0).
            numeric += [0] * (3 - len(numeric))
            return tuple(numeric) < (3, 5, 0)
        except Exception:
            return True  # version unknown: fall back to hmset

    def _get_queue_length(self):
        """Return the current queue length.

        On any Redis error ``MAX_QUEUE_SIZE`` is returned so that the caller
        treats the queue as full and waits instead of pushing blindly.
        """
        try:
            return self.redis.llen(TASK_QUEUE)
        except Exception as e:
            print(f"获取队列长度失败：{e}")
            return MAX_QUEUE_SIZE

    def _batch_write_metadata(self, metadata_dict):
        """Write one batch of metadata entries to the metadata hash.

        Args:
            metadata_dict: Mapping of task id to pickled metadata bytes.

        Returns:
            The number of entries written, or 0 when the mapping is empty or
            the write failed (the failure is printed, not raised).
        """
        try:
            if not metadata_dict:
                return 0

            if self.use_hmset:
                # Older clients: bulk write through hmset.
                self.redis.hmset(self.redis_metadata_key, metadata_dict)
            else:
                # Newer clients: hset with the mapping keyword.
                self.redis.hset(self.redis_metadata_key, mapping=metadata_dict)

            return len(metadata_dict)
        except Exception as e:
            print(f"❌ 元数据批量写入失败：{str(e)}")
            print(f"  - 待写入数量：{len(metadata_dict)}")
            # Print a single entry as a sample to keep the output short.
            if metadata_dict:
                first_key = next(iter(metadata_dict.keys()))
                print(f"  - 示例键：{first_key}, 值长度：{len(metadata_dict[first_key])}")
            return 0

    def _submit_metadata(self, metadata_buffer):
        """Hand a snapshot of the buffer to the pool, clear it, return its size."""
        snapshot = dict(metadata_buffer)
        metadata_buffer.clear()
        self.metadata_futures.append(
            self.metadata_pool.submit(self._batch_write_metadata, snapshot)
        )
        return len(snapshot)

    def send_tasks(self, valid_tasks):
        """Send one task per row of ``valid_tasks`` and record its metadata.

        The method blocks until every task has been pushed and every metadata
        write has finished. A failed push is retried after a two second pause.
        The metadata thread pool is shut down before returning, also when the
        loop is interrupted, so an instance can run ``send_tasks`` only once.

        Args:
            valid_tasks: DataFrame with the columns ``date`` and
                ``stock_code``; row position ``i`` becomes task ``task_{i}``.
        """
        total_tasks = len(valid_tasks)
        sent_count = 0
        print(f"开始发送任务，总有效任务数：{total_tasks}")

        # Plain lists are much cheaper to index per task than DataFrame.iloc.
        dates = valid_tasks['date'].tolist() if total_tasks else []
        codes = valid_tasks['stock_code'].tolist() if total_tasks else []

        metadata_buffer = {}
        submitted_entries = 0

        try:
            while sent_count < total_tasks:
                current_len = self._get_queue_length()
                if current_len >= MAX_QUEUE_SIZE:
                    print(f"⚠️ 队列已满（当前{current_len}/{MAX_QUEUE_SIZE}），暂停{CHECK_INTERVAL_WHEN_FULL}秒...")
                    time.sleep(CHECK_INTERVAL_WHEN_FULL)
                    continue

                # Never exceed the remaining work, the free queue capacity or
                # the configured batch size.
                remaining = total_tasks - sent_count
                available = MAX_QUEUE_SIZE - current_len
                batch_count = min(remaining, available, BATCH_SIZE)
                batch_end = sent_count + batch_count

                task_bytes_list = []
                for position in range(sent_count, batch_end):
                    task_id = f"task_{position}"
                    trade_date = dates[position]
                    stock_code = codes[position]

                    task = {
                        "func_name": "fetch_minute_stock_data",
                        "args": (trade_date, [stock_code]),
                        "kwargs": {},
                        "task_id": task_id
                    }
                    task_bytes_list.append(pickle.dumps(task))

                    # A retried batch re-buffers the same ids; the hash write
                    # simply overwrites them with identical values.
                    metadata_buffer[task_id] = pickle.dumps((trade_date, stock_code))
                    if len(metadata_buffer) >= METADATA_BATCH_SIZE:
                        submitted_entries += self._submit_metadata(metadata_buffer)

                try:
                    self.redis.rpush(TASK_QUEUE, *task_bytes_list)
                except Exception as e:
                    print(f"❌ 批量发送失败：{e}，将重试当前批次")
                    time.sleep(2)
                    continue

                previous_count = sent_count
                sent_count = batch_end

                # Batches clipped by queue capacity rarely land exactly on a
                # multiple of PROGRESS_EVERY, so report whenever a boundary
                # was crossed, plus once at the very end.
                crossed = (sent_count // self.PROGRESS_EVERY) > (previous_count // self.PROGRESS_EVERY)
                if crossed or sent_count == total_tasks:
                    progress = (sent_count / total_tasks) * 100
                    print(f"📤 已发送 {sent_count}/{total_tasks} ({progress:.2f}%)，队列当前长度：{self._get_queue_length()}")

                if sent_count < total_tasks:
                    time.sleep(SEND_INTERVAL)

            # Flush whatever is left in the buffer.
            if metadata_buffer:
                submitted_entries += self._submit_metadata(metadata_buffer)

            # Wait for every metadata write and make failures visible.
            print("等待剩余元数据写入...")
            written_entries = sum(future.result() for future in self.metadata_futures)
            if written_entries < submitted_entries:
                print(f"⚠️ 元数据写入不完整：成功{written_entries}/{submitted_entries}")
        finally:
            # Always release the worker threads, even when interrupted.
            self.metadata_pool.shutdown()

        print("✅ 所有任务发送完成")

# ---------------------- 主函数 ----------------------
# Script entry point: load the Redis settings, derive the task list from the
# daily K-line file, then push every task to the queue.
if __name__ == "__main__":
    try:
        # Connection settings come from the key=value file at REDIS_CONFIG_PATH.
        redis_config = load_redis_config(REDIS_CONFIG_PATH)
        print(f"已加载Redis配置：host={redis_config['host']}, port={redis_config['port']}")

        # One task per (date, stock_code) pair returned by the loader.
        valid_tasks = load_valid_daily_data(DAILY_K_PATH)

        # Constructing the sender pings Redis; send_tasks blocks until done.
        sender = OptimizedTaskSender(redis_config)
        sender.send_tasks(valid_tasks)

    except Exception as e:
        # Report the failure and exit with status 1. KeyboardInterrupt and
        # SystemExit are not subclasses of Exception, so they pass through.
        print(f"程序执行失败：{e}")
        raise SystemExit(1)
    

已加载Redis配置：host=220.203.1.124, port=6379
开始加载日K线数据：D:\workspace\xiaoyao\data\stock_daily_price.parquet
日K线数据加载完成，有效任务数：949049
✅ Redis连接成功
开始发送任务，总有效任务数：949049


  self.redis.hmset(self.redis_metadata_key, metadata_dict)


📤 已发送 10000/949049 (1.05%)，队列当前长度：7186
📤 已发送 20000/949049 (2.11%)，队列当前长度：14160


KeyboardInterrupt: 