In [None]:
from data_provider.data_loader import DataModule
from exp.exp_model import Model
from utils.exp_logger import Logger
from utils.exp_metrics_plotter import MetricsPlotter
from run_train import get_experiment_name
from utils.utils import set_settings
# Experiment settings, logger, plotter
from utils.exp_config import get_config
# Load the 'FinancialConfig' preset and switch on its multi_dataset flag.
config = get_config('FinancialConfig')
config.multi_dataset = True
# Hand the config to the project helper (its effects are not visible in this notebook).
set_settings(config)
# Derive the log file name and experiment description, then build the plotter
# and the logger from them.
log_filename, exper_detail = get_experiment_name(config)
plotter = MetricsPlotter(log_filename, config)
log = Logger(log_filename, exper_detail, plotter, config)
# Build the data module and the model from the same config.
datamodule = DataModule(config)
model = Model(config)

In [None]:
import torch
import torch.nn as nn

# Example input laid out as (batch, time steps, channels, embedding dim).
bs, seq_len, channels, dim = 16, 48, 33, 64
x_enc = torch.randn(bs, seq_len, channels, dim)

# Attention layers created without batch_first, so they consume tensors
# shaped (sequence, batch, embed_dim).
attn_channel = nn.MultiheadAttention(embed_dim=dim, num_heads=8)
attn_time = nn.MultiheadAttention(embed_dim=dim, num_heads=8)

# ===== 1. Attention across channels =====
# (bs, 48, 33, 64) -> (33, bs, 48, 64) -> (33, bs*48, 64):
# channels become the sequence axis, batch and time are merged.
x_enc_reshaped = x_enc.permute(2, 0, 1, 3).flatten(start_dim=1, end_dim=2)

# Self-attention among the channels -> (33, bs*48, 64)
x_channel_attn, _ = attn_channel(
    query=x_enc_reshaped, key=x_enc_reshaped, value=x_enc_reshaped
)

# Split the merged axis again and restore (bs, 48, 33, 64).
x_channel_attn = x_channel_attn.unflatten(1, (bs, seq_len)).permute(1, 2, 0, 3)

# ===== 2. Attention across time =====
# (bs, 48, 33, 64) -> (48, bs, 33, 64) -> (48, bs*33, 64):
# time becomes the sequence axis, batch and channels are merged.
x_time_input = x_channel_attn.transpose(0, 1).flatten(start_dim=1, end_dim=2)

# Self-attention among the time steps -> (48, bs*33, 64)
x_time_attn, _ = attn_time(query=x_time_input, key=x_time_input, value=x_time_input)

# Split the merged axis again and restore (bs, 48, 33, 64).
x_time_attn = x_time_attn.unflatten(1, (bs, channels)).transpose(0, 1)

# Final output
print(x_time_attn.shape)  # torch.Size([16, 48, 33, 64])


In [None]:
from modules.backbone import Backbone
from run_train import *

from utils.exp_config import get_config
# Load the configuration with no preset name passed to get_config.
config = get_config()
# Earlier variant, left disabled by the author:
# datamodule = DataModule(config)
# model = Model(datamodule, config)
# Instantiate the backbone directly. The leading 3 is presumably the input
# feature size (the random inputs fed to it use a last dimension of 3) —
# TODO confirm against Backbone's signature.
model = Backbone(3, config)

In [None]:
# Smoke-test the model with a uniform-random input of shape (1, 48, 33, 3).
bs, seq_len, channels, dim = 1, 48, 33, 3
random_inputs = torch.rand(bs, seq_len, channels, dim)
# The second and third positional arguments are passed as None.
y = model(random_inputs, None, None)
# Author's note: [1, 48, 32, 3] (presumably the observed shape of y — not verified here)

In [None]:
# Same smoke test with a single channel: input shape (1, 48, 1, 3).
bs, seq_len, channels, dim = 1, 48, 1, 3
random_inputs = torch.rand(bs, seq_len, channels, dim)
# The second and third positional arguments are passed as None.
y = model(random_inputs, None, None)
# Author's note: [1, 48, 32, 3] (presumably the observed shape of y — not verified here)

In [None]:
# Same smoke test with 16 channels: input shape (1, 48, 16, 3).
bs, seq_len, channels, dim = 1, 48, 16, 3
random_inputs = torch.rand(bs, seq_len, channels, dim)
# The second and third positional arguments are passed as None.
y = model(random_inputs, None, None)
# Author's note: [1, 48, 32, 3] (presumably the observed shape of y — not verified here)

In [None]:
import torch

# Sliding-window settings: window length and hop between window starts.
patch_len, stride = 4, 2

# Toy tensor 0..239 arranged as (2, 3, 4, 10).
x = torch.arange(2 * 3 * 4 * 10).view(2, 3, 4, 10)

# Cut the last axis (length 10) into overlapping windows; the windows are
# appended as a new trailing dimension.
x_unfolded = x.unfold(-1, patch_len, stride)
print(x_unfolded.shape)

In [None]:
import os

# List every entry in the log directory.
all_files = os.listdir('results/financial/20250701/log')

# Collect the integer found between '_Multi_' and '.md' in each file name.
# Names without the marker (IndexError) or with a non-numeric id (ValueError)
# are ignored.
existing_ids = set()
for filename in all_files:
    try:
        after_marker = filename.split('_Multi_')[1]
        existing_ids.add(int(after_marker.split('.md')[0]))
    except (IndexError, ValueError):
        pass

# Ids in 0..149 that have no matching log file, in ascending order.
missing_ids = sorted(set(range(0, 150)) - existing_ids)

print("缺失的编号：", missing_ids)

In [None]:
import pickle 
from collections import Counter
# Open the file in a context manager so the handle is closed as soon as the
# payload is loaded (the previous bare open() was never closed).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('./datasets/func_code_to_label_150.pkl', 'rb') as f:
    data = pickle.load(f)
# Extract the group-id column (column index 1).
# Assumes data is a 2-D array-like supporting [:, 1] — TODO confirm.
group_ids = data[:, 1]

# Count how many funds fall under each group id.
counts = Counter(group_ids)

# Print the counts in ascending group-id order.
for group_id, count in sorted(counts.items()):
    print(f"组号 {group_id} 中有 {count} 个基金")

In [1]:
import os 
import pickle
# Directory holding one pickle per fund code.
data_dir = './datasets/financial/S20200713_E20250628'
all_code = os.listdir(data_dir)

# Record len() of the object stored in every .pkl file; other entries are skipped.
all_code_len = []
for code in all_code:
    if not code.endswith('.pkl'):
        continue
    with open(os.path.join(data_dir, code), 'rb') as f:
        data = pickle.load(f)
    all_code_len.append(len(data))

FileNotFoundError: [Errno 2] No such file or directory: './datasets/financial/S20200713_E20250628'

In [None]:
import numpy as np

# Requires all_code_len from the cell that scans the dataset directory.
all_code_len = np.array(all_code_len)

# Headline statistics of the per-file lengths.
print(f"📊 总文件数: {len(all_code_len)}")
print(f"📈 最大长度: {all_code_len.max()}")
print(f"📉 最小长度: {all_code_len.min()}")
print(f"📏 平均长度: {all_code_len.mean():.2f}")
print(f"📐 中位数: {np.median(all_code_len)}")
# Lower-tail percentiles, then the upper quartile.
print(f"🔹 5%分位数: {np.percentile(all_code_len, q=5)}")
print(f"🔹 6%分位数: {np.percentile(all_code_len, q=6)}")
print(f"🔹 10%分位数: {np.percentile(all_code_len, q=10)}")
print(f"🔹 25%分位数: {np.percentile(all_code_len, q=25)}")
print(f"🔸 75%分位数: {np.percentile(all_code_len, q=75)}")

In [None]:
import pandas as pd
from sqlalchemy import create_engine, text
import pickle
# Database configuration: the connection URI is unpickled from a local file
# and used to create the SQLAlchemy engine shared by the queries below.
# NOTE(review): pickle.load can execute arbitrary code — only load a trusted
# file; an environment variable would be a safer carrier for the URI.
with open('./datasets/sql_token.pkl', 'rb') as f:
    DB_URI = pickle.load(f)
engine = create_engine(DB_URI)

def query_fund_data(fund, start_date, end_date):
    """Fetch NAV rows for one or more funds and group them by fund code.

    Reads fund_code, date, accnav, adj_nav and nav from
    ``b_fund_nav_details_new`` through the module-level ``engine``.

    Args:
        fund: A single fund code (str) or an iterable of fund codes.
        start_date: Lower bound of the ``date BETWEEN`` filter.
        end_date: Upper bound of the ``date BETWEEN`` filter.

    Returns:
        dict mapping each fund_code found to its DataFrame of rows, ordered
        by date and re-indexed from 0. An empty dict is returned when no
        codes are given or when the query fails (the error is printed).
    """
    # A bare string would be split into single characters by tuple(), so wrap
    # it; any other iterable (list, tuple, numpy array) is materialised as-is.
    codes = (fund,) if isinstance(fund, str) else tuple(fund)
    if not codes:
        # "IN ()" is not valid SQL; nothing to look up.
        return {}

    # NOTE(review): binding a tuple to "IN :codes" depends on the DBAPI
    # driver adapting tuples; SQLAlchemy's documented form is an expanding
    # bind parameter — confirm before switching drivers.
    sql = text("""
        SELECT fund_code, date, accnav, adj_nav, nav
        FROM b_fund_nav_details_new
        WHERE fund_code IN :codes
          AND date BETWEEN :start AND :end
        ORDER BY date
    """)
    try:
        df = pd.read_sql_query(
            sql.bindparams(codes=codes, start=start_date, end=end_date),
            engine
        )
        fund_dict = {code: df_group.reset_index(drop=True)
                     for code, df_group in df.groupby("fund_code")}
        return fund_dict
    except Exception as e:
        # Best-effort: report the failure and return the same container type
        # as the success path so callers can iterate it uniformly.
        print(f"[{fund}] 数据库查询失败: {str(e)}")
        return {}
# Example query for two fund codes. Despite its name, df holds whatever
# query_fund_data returns (on success, a dict keyed by fund_code).
df = query_fund_data(['000001', '000003'], '2020-01-01', '2025-01-01')
    

In [54]:
import numpy as np 
# Reload the code/label table from disk.
with open('./datasets/func_code_to_label_150.pkl', 'rb') as f:
    data = pickle.load(f)
# Keep only column 0 (presumably the fund codes — TODO confirm); note that
# this rebinds data, so the full table is no longer available afterwards.
data = data[:, 0]
# Query NAV rows for all of those codes between 2020-01-01 and 2025-01-01.
df = query_fund_data(data, '2020-01-01', '2025-01-01')

In [57]:
# Overall lowest and highest 'nav' across every fund's frame. The 1e9 / -1e9
# sentinels lead each list so they act as the starting values of the scan.
min_value = min([1e9, *(fund_df['nav'].min() for _, fund_df in df.items())])
max_value = max([-1e9, *(fund_df['nav'].max() for _, fund_df in df.items())])
print(f"最小值: {min_value}, 最大值: {max_value}")

最小值: 0.0871, 最大值: 141.426


In [53]:
# Display the current value of df.
df

'970135'

In [None]:
import numpy as np

def constrain_nav_prediction(predictions, bar=0.05, scale=0.9):
    """
    Soften NAV forecasts that contain an overly large day-over-day move.

    For every fund (column) the day-over-day returns are computed. If any
    absolute return exceeds ``bar``, that fund's whole series is rebuilt
    around its first-day value: the change relative to day one is multiplied
    by ``scale``. Funds without such a move are returned unchanged.

    Parameters:
    - predictions: array-like of shape [days, funds] (e.g. [7, 64]) holding
      the predicted unit NAVs. It is not modified. Zero values make the
      return / relative-change divisions produce inf or nan.
    - bar: float, upper limit on the absolute day-over-day return
      (e.g. 0.05 means 5%).
    - scale: float, factor applied to the change relative to day one when a
      fund is flagged (e.g. 0.9).

    Returns:
    - adjusted: np.ndarray with the same shape as ``predictions``; floating
      dtypes are preserved, any other dtype is promoted to float.
    - mask: np.ndarray of bool with one entry per fund, True where the
      series was rescaled.
    """
    nav = np.asarray(predictions)
    if not np.issubdtype(nav.dtype, np.floating):
        # Integer input would truncate the softened values when they are
        # written back, so promote to float first.
        nav = nav.astype(float)
    adjusted = nav.copy()
    mask = np.zeros(nav.shape[1], dtype=bool)
    for fund_idx in range(nav.shape[1]):
        nav_series = nav[:, fund_idx]
        # Day-over-day returns of this fund.
        returns = nav_series[1:] / nav_series[:-1] - 1
        if np.any(np.abs(returns) > bar):
            # Anchor on the first day: express every day as a relative change
            # from that value, shrink the change by `scale`, and rebuild the
            # series as base * (1 + relative_change * scale).
            base = nav_series[0]
            relative_change = (nav_series - base) / base
            softened = base * (1 + relative_change * scale)
            adjusted[:, fund_idx] = softened
            mask[fund_idx] = True
    return adjusted, mask

# Simulate a 7-day unit-NAV forecast for a single fund (shape (7, 1)) as a
# cumulative product of small random daily growth factors.
np.random.seed(0)
preds = (1 + np.random.normal(0, 0.01, (7, 1))).cumprod(axis=0)
# Inject artificial anomalies into that one series; day 6 gets a 10x spike.
preds[:, 0] = preds[:, 0] * np.array([1, 1, 1.2, 1.5, 1.7, 10.0, 2.5])

In [25]:
# Show the simulated predictions as a flat 1-D array.
preds.reshape(-1)

array([ 1.01764052,  1.02171269,  1.23805509,  1.58224823,  1.82670398,
       10.64030593,  2.68534956])

In [None]:
# Apply the constraint with a 100% day-over-day threshold (bar=1) and
# half-strength scaling (scale=0.5).
adjusted, flagged = constrain_nav_prediction(preds, bar=1, scale=0.5)
# Report the column indices of the funds that were rescaled.
print(f"被缩放的基金编号：{np.where(flagged)[0]}")

被缩放的基金编号：[0]


In [27]:
# Show the adjusted predictions as a flat 1-D array.
adjusted.reshape(-1)

array([1.01764052, 1.0196766 , 1.12784781, 1.29994438, 1.42217225,
       5.82897323, 1.85149504])

In [8]:
import os
# Date range that names the dataset directory.
start_date: str = '2020-07-13'
end_date: str = '2025-06-28'
# 'S<start>_E<end>' with the dashes removed, e.g. S20200713_E20250628.
dir_name = f'S{start_date}_E{end_date}'.replace('-', '')
all_address = os.listdir(f'./datasets/financial/{dir_name}')
# The fund code is the part of each entry name before its first '.'.
all_code_list = [entry.partition('.')[0] for entry in all_address]
len(all_code_list)

11674

In [7]:
import pickle 
# Persist all_code_list to ./datasets/all_code_list.pkl with pickle.
# all_code_list must already exist in the kernel when this cell runs.
with open('./datasets/all_code_list.pkl', 'wb') as f:
    pickle.dump(all_code_list, f)

In [9]:
# Scratch arithmetic.
140 / 4

35.0