In [1]:
import pickle 
import pandas as pd
from sqlalchemy import create_engine, text
def insert_pred_to_sql():
    # 读取数据库连接字符串
    with open('./datasets/sql_token.pkl', 'rb') as f:
        DB_URI = pickle.load(f)

    # 创建数据库引擎
    engine = create_engine(DB_URI)

    # SQL 查询语句
    sql = text("""
        SELECT fund_code, forecast_date, pre_data, model_version, create_time, update_time
        FROM b_fund_forecast_new
        WHERE fund_code IN :codes
        ORDER BY forecast_date
    """)

    # 执行查询，传入参数（注意 tuple 中只有一个元素时加逗号）
    df = pd.read_sql_query(
        sql.bindparams(codes=tuple(['005626'])),  # 或 codes=('005626',)
        engine
    )
    
    df.to_sql('my_table', engine, if_exists='replace', index=False)
    return True

In [2]:
# import numpy as np 
# for i in range(len(df)):
#     pred_str = df['pre_data'][i]
#     str_start = pred_str.find('[') + 1
#     str_end = pred_str.find(']') 
#     pred = pred_str[str_start:str_end].split(', ')
#     pred = np.array(pred, dtype=np.float32)
#     # print(pred.shape)

In [3]:
from datetime import datetime, timedelta

def get_start_date(end_date: str, window_size: int) -> str:
    """
    给定结束日期和历史窗口长度，返回窗口开始日期（字符串格式）。

    参数：
    - end_date (str): 结束日期，格式 'YYYY-MM-DD'
    - window_size (int): 历史窗口长度（天数）

    返回：
    - start_date (str): 开始日期，格式 'YYYY-MM-DD'
    """
    end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    start_dt = end_dt - timedelta(days=window_size)
    return start_dt.strftime("%Y-%m-%d")

In [5]:
# import pandas as pd
# from datetime import datetime

# def convert_array_to_df(array):
#     """
#     将数组转换为包含标准日期列的 DataFrame
#     """
#     df = pd.DataFrame(array, columns=[
#         'fund_code', 'year', 'month', 'day', 'weekday',
#         'value1', 'value2', 'value3'
#     ])
#     df['year'] = df['year'].astype(int)
#     df['month'] = df['month'].astype(int)
#     df['day'] = df['day'].astype(int)
#     df['date'] = pd.to_datetime(df[['year', 'month', 'day']])
#     return df

# def find_missing_dates_from_array_list(array_list, start_date, end_date, freq='B'):
#     """
#     参数：
#     - array_list: List[np.ndarray]，每个元素是一个基金的历史数据数组
#     - start_date, end_date: 'YYYY-MM-DD'
#     - freq: 'B' 表示工作日（默认）

#     返回：
#     - dict: {fund_code: [缺失的日期列表]}
#     """
#     full_range = pd.date_range(start=start_date, end=end_date, freq=freq)
#     missing_map = {}

#     for arr in array_list:
#         df = convert_array_to_df(arr)
#         if df.empty:
#             continue

#         fund_code = df.iloc[0]['fund_code']
#         date_series = df['date']
#         missing = full_range.difference(date_series)

#         if not missing.empty:
#             missing_map[fund_code] = missing.strftime('%Y-%m-%d').tolist()

#     return missing_map

# missing_map = find_missing_dates_from_array_list(df, start_date, end_date)

# for code, dates in missing_map.items():
#     print(f"基金 {code} 缺失日期：{dates}")

In [104]:
import torch
from exp.exp_model import Model
from data_provider.generate_financial import process_date_columns, query_fund_data
from data_provider.get_financial import get_group_idx
from utils.exp_config import get_config
from sqlalchemy import create_engine, text
import numpy as np
import pandas as pd
import pickle
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError

def get_history_data(get_group_idx, current_date, config):
    all_history_input = []
    for i in range(len(get_group_idx)):
        start_date = get_start_date(current_date, window_size=64)
        df = query_fund_data(get_group_idx[i], start_date, current_date)
        df = process_date_columns(df)
        all_history_input.append(df)
    data = all_history_input
    return data

def check_input(all_history_input):
    data = np.stack(all_history_input, axis=0)
    data = data.transpose(1, 0, 2)
    # 只取符合模型的历史天数
    data = data[-config.seq_len:, :, :]
    return data

def get_pretrained_model(config):
    model = Model(config)
    model.load_state_dict(torch.load('./checkpoints/ours/Model_ours_Dataset_financial_Multi_round_0.pt', weights_only=False))
    return model 

def predict_torch_model(model, history_input, config):
    # 因为我加了时间戳特征
    x = history_input[:, :, -3:]
    # unsqueeze 代表 batch size = 1
    x = torch.from_numpy(x.astype(np.float32)).unsqueeze(0)
    pred_value = model(x, None, None).squeeze(0).detach().numpy()
    # fund_code = history_input[-config.pred_len:, :, 0]
    fund_code = history_input
    # print(history_input.shape, x.shape, pred_value.shape, fund_code.shape)
    # pred_value = np.stack([fund_code, pred_value], axis=0)
    return pred_value

def get_sql_format_data(pred_value, cleaned_input):
    from datetime import datetime
    now_df = []
    for i in range(pred_value.shape[0]):
        for j in range(pred_value.shape[1]):
            idx = np.random.randint(0, 10)  # 生成一个 0 到 9（包含 0，不包含 10）之间的整数
            fund_code = cleaned_input[i][j][0]
            forcast_date = current_date
            pred = '{"pre": [' + ', '.join(f'{item:.6f}' for item in pred_value[:, j]) + ']}'
            model_version = 'v2025'
            create_date = update_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            now_df.append([idx, fund_code, forcast_date, pred, model_version, create_date, update_date])
            # break
        # break
    # now_df
    now_df = np.array(now_df)
    now_df = pd.DataFrame(now_df, columns=['id', 'fund_code', 'forecast_date', 'pre_data', 'model_version',
       'create_time', 'update_time'])
    return now_df

def insert_pred_to_sql(df):
    try:
        # 读取数据库连接字符串
        with open('./datasets/sql_token.pkl', 'rb') as f:
            DB_URI = pickle.load(f)

        # 创建数据库引擎
        engine = create_engine(DB_URI)

        # 插入数据
        df.to_sql(
            name='temp_sql',   # 表名
            con=engine,                   # 数据库连接
            if_exists='append',           # 追加到已有表中
            index=False                   # 不插入索引列
        )
        print("✅ 数据成功写入数据库。")

    except FileNotFoundError:
        print("❌ 无法找到 sql_token.pkl 文件。请检查路径是否正确。")

    except SQLAlchemyError as e:
        print(f"❌ 数据库插入失败: {e}")

    except Exception as e:
        print(f"❌ 发生未知错误: {e}")

# [128, 16, 33, 3])
def start_server(current_date):
    config = get_config('FinancialConfig')
    group_fund_code = get_group_idx(27)
    history_input = get_history_data(group_fund_code, current_date, config)
    cleaned_input = check_input(history_input)
    model = get_pretrained_model(config)
    pred_value = predict_torch_model(model, cleaned_input, config)
    pred_value = get_sql_format_data(pred_value, cleaned_input)
    insert_pred_to_sql(pred_value)
    return pred_value

current_date = '2025-4-15'
pred_value = start_server(current_date)

✅ 数据成功写入数据库。


In [103]:
pred_value

Unnamed: 0,id,fund_code,forecast_date,pre_data,model_version,create_time,update_time
0,3,008783,2025-4-15,"{""pre"": [0.935352, 1.250437, 0.994587, 1.06789...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40
1,8,519945,2025-4-15,"{""pre"": [1.297675, 1.479280, 1.112717, 1.51125...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40
2,0,007097,2025-4-15,"{""pre"": [1.087883, 1.213997, 0.978471, 1.23382...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40
3,9,002994,2025-4-15,"{""pre"": [1.140917, 1.330691, 0.925250, 1.11806...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40
4,3,011048,2025-4-15,"{""pre"": [0.943371, 1.152475, 0.989076, 1.09547...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40
...,...,...,...,...,...,...,...
675,8,310508,2025-4-15,"{""pre"": [1.449564, 1.702596, 1.290836, 1.57611...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40
676,6,013391,2025-4-15,"{""pre"": [1.422609, 1.572178, 1.335387, 1.53521...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40
677,7,004168,2025-4-15,"{""pre"": [1.100644, 1.381833, 0.957947, 1.12789...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40
678,3,004059,2025-4-15,"{""pre"": [1.300640, 1.352936, 1.060295, 1.21922...",v2025,2025-06-05 17:13:40,2025-06-05 17:13:40


In [58]:
import pandas as pd
from sqlalchemy import create_engine, text

with open('./datasets/sql_token.pkl', 'rb') as f:
    DB_URI = pickle.load(f)

# 创建数据库引擎
engine = create_engine(DB_URI)

# SQL 查询语句
sql = text("""
    SELECT fund_code, forecast_date, pre_data, model_version, create_time, update_time
    FROM b_fund_forecast_new
    WHERE fund_code IN :codes
    ORDER BY forecast_date
""")

df = pd.read_sql_query(
    sql.bindparams(codes=tuple(['005626'])),  # 或 codes=('005626',)
    engine
)
df

Unnamed: 0,fund_code,forecast_date,pre_data,model_version,create_time,update_time
0,005626,2023-07-01,"{""pre"": [1.4995611906051636, 1.526604056358337...",V3.0,2023-08-03 17:59:16,2023-09-05 14:28:30
1,005626,2023-07-02,"{""pre"": [1.4568397998809814, 1.464197158813476...",V3.0,2023-08-04 17:07:35,2023-09-05 14:28:30
2,005626,2023-07-03,"{""pre"": [1.4621272087097168, 1.466849327087402...",V3.0,2023-08-04 19:23:31,2023-09-05 14:28:30
3,005626,2023-07-04,"{""pre"": [1.4634904861450195, 1.471401810646057...",V3.0,2023-08-07 11:30:57,2023-09-05 14:28:30
4,005626,2023-07-05,"{""pre"": [1.4883460998535156, 1.518860697746276...",V3.0,2023-08-07 17:37:31,2023-09-05 14:28:30
...,...,...,...,...,...,...
468,005626,2025-04-08,"{""pre"": [1.2187557220458984, 1.235185742378234...",V3.1,2025-04-09 00:23:01,2025-04-09 00:23:01
469,005626,2025-04-09,"{""pre"": [1.234163522720337, 1.2369379997253418...",V3.1,2025-04-10 00:23:04,2025-04-10 00:23:04
470,005626,2025-04-10,"{""pre"": [1.2268850803375244, 1.230019569396972...",V3.1,2025-04-11 00:24:03,2025-04-11 00:24:03
471,005626,2025-04-12,"{""pre"": [1.2307374477386477, 1.234757900238037...",V3.1,2025-04-13 00:23:09,2025-04-13 00:23:09


In [63]:
print(df.pre_data[0])

{"pre": [1.4995611906051636, 1.5266040563583374, 1.5181519985198977, 1.488774299621582, 1.4389419555664062, 1.497912883758545, 1.528842210769653, 1.521162509918213, 1.491327881813049, 1.419955492019653, 1.509856343269348, 1.5188844203948977, 1.5192195177078247, 1.50710129737854, 1.4677369594573977, 1.433366775512695, 1.3952490091323853, 1.3760108947753906, 1.3878873586654663, 1.4015618562698364, 1.4881415367126465, 1.4886269569396973, 1.493364691734314, 1.4802333116531372, 1.4442403316497805, 1.400775909423828, 1.3564958572387695, 1.3270682096481323, 1.3393518924713137, 1.3652137517929075, 1.5138434171676636, 1.4872684478759766, 1.461197853088379, 1.442975640296936, 1.4501302242279053, 1.4590097665786743, 1.4291090965270996, 1.3820887804031372, 1.3443725109100342, 1.369544506072998, 1.4468461275100708, 1.495410442352295, 1.555127501487732, 1.5629266500473022, 1.484553337097168, 1.3875621557235718, 1.322195053100586, 1.2787725925445557, 1.2986984252929688, 1.364705204963684, 1.464452505

In [74]:
df.columns

Index(['fund_code', 'forecast_date', 'pre_data', 'model_version',
       'create_time', 'update_time'],
      dtype='object')