In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.fft as fft

# 引入用户提供的组件 (假设在同一路径或已定义)
# from resample_emb import resample_patchemb 

# ==========================================
# 1. 核心工具函数 (基于用户提供的代码)
# ==========================================

def resample_patchemb(old: torch.Tensor, new_patch_len: int):
    """
    LightGTS Flex-Resize 核心逻辑: 利用伪逆动态调整权重矩阵形状
    Input: old [d_model, anchor_len] (注意: nn.Linear.weight 是转置存储的)
    Output: new [d_model, new_patch_len]
    """
    assert old.dim() == 2, "输入张量应为2D (d_model, patch_size)"
    if old.size(1) == new_patch_len:
        return old

    # 这里的逻辑主要是转置处理，适应 SVD/Pinv 的维度要求
    old_T = old.T # [anchor_len, d_model]
    old_shape = old_T.size(0)
    factor = new_patch_len / old_shape
    
    # 定义辅助函数：批量resize (用于构造变换矩阵)
    def resize_fn(x_tensor, new_shape):
        # [1, 1, L] -> interpolate -> [1, 1, New_L]
        return F.interpolate(x_tensor.unsqueeze(0).unsqueeze(0), size=new_shape, mode='linear', align_corners=False).squeeze(0).squeeze(0)

    # 构造缩放矩阵 A
    basis_vectors = torch.eye(old_shape, dtype=torch.float32, device=old.device) # [L, L]
    # Resize 每一行基向量
    # 这里为了效率，其实可以预计算，但在动态周期下需实时计算
    # [L, New_L]
    resize_mat_list = [resize_fn(basis_vectors[i], new_patch_len) for i in range(old_shape)]
    resize_mat = torch.stack(resize_mat_list) 
    
    # 计算伪逆: theta' = delta^-1 * (A)^+ * theta
    # resize_mat corresponds to A.T in the paper's notation xA
    # We need to map weight from P to P'. 
    # paper: theta' = delta^-1 * (A)^+ * theta
    
    resize_mat_pinv = torch.linalg.pinv(resize_mat.T) # pseudo inverse
    
    # 应用变换
    resampled_kernels = resize_mat_pinv @ old_T * math.sqrt(factor)

    return resampled_kernels.T # 转回 [d_model, new_patch_len]

def ACF_for_Period_Per_Channel(x, k=2, min_period=4):
    """
    用户提供的 ACF 周期提取函数
    x: [B, T, C] -> 返回 [C, k] (每个通道的 Top-k 周期)
    """
    B, T, C = x.shape
    x_mean = x.mean(dim=1, keepdim=True)
    x_centered = x - x_mean

    # FFT 计算自相关
    n_fft = 1 << (2 * T - 1).bit_length()
    xf = torch.fft.rfft(x_centered, n=n_fft, dim=1)
    power_spectrum = xf * torch.conj(xf)
    acf = torch.fft.irfft(power_spectrum, n=n_fft, dim=1)
    acf = acf[:, :T, :]
    
    # 聚合 Batch 维度
    avg_acf = acf.mean(dim=0) # [T, C]
    
    # 掩码处理
    avg_acf[:min_period, :] = -float('inf')
    
    prev_lag = torch.roll(avg_acf, 1, dims=0)
    next_lag = torch.roll(avg_acf, -1, dims=0)
    prev_lag[0, :] = -float('inf')
    next_lag[-1, :] = -float('inf')
    
    is_peak = (avg_acf > prev_lag) & (avg_acf > next_lag) & (avg_acf > 0)
    masked_acf = avg_acf.clone()
    masked_acf[~is_peak] = -float('inf')
    
    # 提取 Top-K
    top_vals, top_inds = torch.topk(masked_acf, k, dim=0)
    
    return top_inds.t() # [C, k]

# ==========================================
# 2. 核心模块: Attentive Aggregator
# ==========================================

class AttentiveAggregator(nn.Module):
    """
    替代简单的“同步平均”。
    利用 Attention 机制，根据 Query (Last Patch) 动态聚合所有历史周期的 Patch。
    """
    def __init__(self, d_model, n_heads=4, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patches):
        """
        patches: [Batch, Num_Patches, d_model]
        return: [Batch, d_model] (聚合后的单个 Token)
        """
        # Query: 取最后一个 Patch (最近的时间窗口)，代表当前的相位/模式
        # [Batch, 1, d_model]
        query = patches[:, -1:, :] 
        
        # Key/Value: 所有 Patch
        # [Batch, Num_Patches, d_model]
        key_value = patches
        
        # Attention
        attn_out, _ = self.attn(query, key_value, key_value)
        
        # Residual + Norm
        out = self.norm(query + self.dropout(attn_out))
        
        return out.squeeze(1)

# ==========================================
# 3. 主架构: PAEmbedding
# ==========================================

class PAEmbedding(nn.Module):
    def __init__(self, c_in, seq_len, d_model, anchor_period=48, dropout=0.1):
        """
        Period-Adaptive Embedding Layer
        
        Args:
            c_in: 变量数量 (Channels)
            seq_len: 输入序列长度 (Lookback window)
            d_model: 隐层维度
            anchor_period: LightGTS Flex-Resize 的参考锚点长度 (Default: 48)
        """
        super().__init__()
        self.c_in = c_in
        self.seq_len = seq_len
        self.d_model = d_model
        self.anchor_period = anchor_period
        
        # --- 1. Anchor Weights (参数库) ---
        # 这是一个形状为 [d_model, anchor_period] 的基准权重
        # 所有的 Period Projection 都会从这个权重动态变形而来，实现参数共享
        self.anchor_weight = nn.Parameter(torch.randn(d_model, anchor_period))
        self.anchor_bias = nn.Parameter(torch.zeros(d_model))
        nn.init.xavier_uniform_(self.anchor_weight)

        # --- 2. Aggregators (去噪聚合) ---
        # Stream 1: 主周期聚合器
        self.agg_p1 = AttentiveAggregator(d_model, dropout=dropout)
        # Stream 2: 次周期聚合器
        self.agg_p2 = AttentiveAggregator(d_model, dropout=dropout)
        
        # --- 3. Stream 3 (Global Trend) ---
        # 全局趋势不需要聚合，直接是一个长切片 Flex-Resize 到 d_model
        # 但我们仍然使用 Anchor 机制来生成这个巨大的投影层
        
        # --- 4. Stream Fusion (融合层) ---
        # 融合 [Stream1, Stream2, Stream3] -> Final Token
        # 输入维度: 3 * d_model -> 输出: d_model
        self.fusion_layer = nn.Sequential(
            nn.Linear(3 * d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model)
        )
        
        # 缓存 (可选，用于加速)
        self._weight_cache = {} 

    def _get_flex_weight(self, target_len):
        """获取适配 target_len 的权重，带简单的运行时缓存"""
        if target_len in self._weight_cache:
            return self._weight_cache[target_len]
        
        # 调用 LightGTS 的核心 Flex-Resize
        # 注意: anchor_weight 是 [d_model, anchor_len]
        new_w = resample_patchemb(self.anchor_weight, target_len)
        self._weight_cache[target_len] = new_w
        return new_w

    def forward(self, x):
        """
        x: [Batch, Seq_Len, Channels]
        Output: [Batch, Channels, d_model] -> 符合 iTransformer 的输入格式
        """
        B, T, C = x.shape
        
        # Step 1: 发现周期 (Per-Channel ACF)
        # periods: [C, 2] -> 每个变量的两个显著周期
        with torch.no_grad():
            periods = ACF_for_Period_Per_Channel(x.permute(0, 2, 1), k=2) # permute to [B, C, T] for ACF func
            # 限制周期范围，防止过小或过大导致数值不稳定
            periods = torch.clamp(periods, min=4, max=T//2)
        
        # Step 2: 变量分组处理 (为了利用 Batch 计算，虽然每个变量周期不同)
        # 在工程实现上，为了避免 Loop N 次 (太慢)，我们将变量按计算出的 Period 进行分组
        # 但为了演示清晰，这里先展示逻辑循环，实际部署建议使用 Scatter/Gather 或 CUDA Kernel
        
        # 初始化三个流的输出容器
        stream_p1_out = torch.zeros(B, C, self.d_model, device=x.device)
        stream_p2_out = torch.zeros(B, C, self.d_model, device=x.device)
        stream_trend_out = torch.zeros(B, C, self.d_model, device=x.device)
        
        # 清空权重缓存 (Dynamic graph issues)
        self._weight_cache = {}

        # -------------------------------------------------------
        # Stream 3: Global Trend (全局趋势流)
        # -------------------------------------------------------
        # 处理整个序列 T。Flex-Resize: Anchor(48) -> T
        # 这是一个巨大的低通滤波器
        w_trend = self._get_flex_weight(T) # [d_model, T]
        # x: [B, T, C] -> permute -> [B, C, T]
        # Linear: x @ w.T -> [B, C, d_model]
        stream_trend_out = F.linear(x.permute(0, 2, 1), w_trend, self.anchor_bias)

        # -------------------------------------------------------
        # Stream 1 & 2: Period Streams (周期流)
        # -------------------------------------------------------
        # 这里为了演示逻辑，我们对 Unique Period 进行循环 (比对 C 循环快得多)
        # 因为现实中很多变量可能共享相同的周期 (如 24, 96)
        
        unique_periods = torch.unique(periods)
        
        for p in unique_periods:
            p = p.item()
            # 找到哪些 Channel 拥有这个周期 (作为 P1 或 P2)
            # mask_p1: [C]
            mask_p1 = (periods[:, 0] == p)
            mask_p2 = (periods[:, 1] == p)
            
            if not (mask_p1.any() or mask_p2.any()):
                continue
                
            # 获取适配该周期的权重
            w_p = self._get_flex_weight(p) # [d_model, p]
            
            # 准备数据切片
            # Unfold: [B, T, C] -> [B, Num_Patches, C, P]
            # 仅选取需要的 Channel 以节省计算
            # 这里简单起见，先对所有 Channel unfold，再 mask
            # Stride = P (非重叠切片，LightGTS 逻辑)
            patches = x.unfold(dimension=1, size=p, step=p) 
            # patches: [B, Num_Patches, C, P]
            
            # 投影 (Tokenization)
            # [B, Num_Patches, C, P] @ [P, d_model] -> [B, Num_Patches, C, d_model]
            # 这一步将物理数据映射到了语义空间
            tokens = patches @ w_p.T 
            
            # --- Stream 1 处理 ---
            if mask_p1.any():
                # 选出对应 Channel: [B, Num_Patches, N_subset, d_model]
                tokens_subset = tokens[:, :, mask_p1, :]
                # 变换维度以适应 Aggregator: [B * N_subset, Num_Patches, d_model]
                b_sz, num_p, n_sub, d_m = tokens_subset.shape
                tokens_flat = tokens_subset.permute(0, 2, 1, 3).reshape(-1, num_p, d_m)
                
                # 聚合 (Attention Aggregation)
                agg_tokens = self.agg_p1(tokens_flat) # [B*N_subset, d_model]
                agg_tokens = agg_tokens.reshape(b_sz, n_sub, d_m)
                
                # 填回结果
                stream_p1_out[:, mask_p1, :] = agg_tokens

            # --- Stream 2 处理 ---
            if mask_p2.any():
                tokens_subset = tokens[:, :, mask_p2, :]
                b_sz, num_p, n_sub, d_m = tokens_subset.shape
                tokens_flat = tokens_subset.permute(0, 2, 1, 3).reshape(-1, num_p, d_m)
                
                agg_tokens = self.agg_p2(tokens_flat)
                agg_tokens = agg_tokens.reshape(b_sz, n_sub, d_m)
                
                stream_p2_out[:, mask_p2, :] = agg_tokens

        # -------------------------------------------------------
        # Fusion (融合)
        # -------------------------------------------------------
        # 此时我们有三个 [B, C, d_model] 的张量
        
        # 拼接
        combined = torch.cat([stream_p1_out, stream_p2_out, stream_trend_out], dim=-1) # [B, C, 3*d_model]
        
        # 融合投影
        variate_token = self.fusion_layer(combined) # [B, C, d_model]
        
        # 输出形状符合 iTransformer 的 Encoder 输入要求
        # [Batch, Variates, d_model]
        return variate_token

# ==========================================
# 4. 示例调用 (Example Usage)
# ==========================================
if __name__ == "__main__":
    # 模拟数据: Batch=32, Lookback=96, Channels=7
    x = torch.randn(32, 96, 7)
    
    # 实例化 PAEmbedding
    # 使用 48 作为 Anchor，意味着模型学习的是一个长度为 48 的“标准周期模式”
    pa_emb = PAEmbedding(c_in=7, seq_len=96, d_model=128, anchor_period=48)
    
    # 前向传播
    out = pa_emb(x)
    
    print("Input shape:", x.shape)
    print("Output shape:", out.shape) # Should be [32, 7, 128]
    print("PA-Embedding successfully processed multi-scale inputs into unified tokens.")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (48x96 and 48x128)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# ==========================================
# 1. 基础工具组件
# ==========================================

def resample_patchemb(old_weight: torch.Tensor, new_patch_len: int):
    """
    Flex-Resize: 调整权重矩阵形状 [d_model, old_len] -> [d_model, new_len]
    """
    assert old_weight.dim() == 2
    if old_weight.size(1) == new_patch_len:
        return old_weight

    old = old_weight.T # [old_len, d_model]
    old_len = old.size(0)
    factor = new_patch_len / old_len
    
    # 构造变换矩阵 (Interpolation Matrix)
    basis_vectors = torch.eye(old_len, dtype=torch.float32, device=old.device)
    # [1, 1, old, old] -> [1, 1, new, old] -> [new, old]
    resize_mat = F.interpolate(
        basis_vectors.unsqueeze(0).unsqueeze(0), 
        size=(new_patch_len, old_len), 
        mode='bilinear', 
        align_corners=False
    ).squeeze(0).squeeze(0)
    
    # 计算伪逆: theta' = (A)^+ * theta * scale
    resize_mat_pinv = torch.linalg.pinv(resize_mat.T) # [new, old]
    resampled = resize_mat_pinv @ old * math.sqrt(factor)

    return resampled.T # [d_model, new_len]

def ACF_for_Period_Per_Channel(x, k=2, min_period=4):
    """
    计算每个通道的 Top-K 周期
    x: [B, T, C] -> Returns: [C, k]
    """
    B, T, C = x.shape
    x_mean = x.mean(dim=1, keepdim=True)
    x_centered = x - x_mean

    n_fft = 1 << (2 * T - 1).bit_length()
    xf = torch.fft.rfft(x_centered, n=n_fft, dim=1)
    power_spectrum = xf * torch.conj(xf)
    acf = torch.fft.irfft(power_spectrum, n=n_fft, dim=1)
    acf = acf[:, :T, :]
    avg_acf = acf.mean(dim=0) # [T, C]
    
    avg_acf[:min_period, :] = -float('inf')
    
    prev = torch.roll(avg_acf, 1, dims=0); prev[0, :] = -float('inf')
    next_ = torch.roll(avg_acf, -1, dims=0); next_[-1, :] = -float('inf')
    is_peak = (avg_acf > prev) & (avg_acf > next_) & (avg_acf > 0)
    
    masked_acf = avg_acf.clone()
    masked_acf[~is_peak] = -float('inf')
    
    _, top_inds = torch.topk(masked_acf, k, dim=0)
    return top_inds.t()

# ==========================================
# 2. 门控融合模块 (Gated Fusion)
# ==========================================

class GatedFusion(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # 输入: P1, P2, Trend (拼接后 3*d_model)
        # 输出: 3个权重系数
        self.gate_net = nn.Sequential(
            nn.Linear(3 * d_model, d_model),
            nn.Tanh(),
            nn.Linear(d_model, 3),
            nn.Softmax(dim=-1)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, p1, p2, trend):
        """
        p1, p2, trend: [B, C, d_model]
        """
        # 拼接特征用于计算门控权重
        concat_feat = torch.cat([p1, p2, trend], dim=-1) # [B, C, 3*d]
        weights = self.gate_net(concat_feat) # [B, C, 3]
        
        # 拆分权重: w1, w2, w3 均为 [B, C, 1]
        w1, w2, w3 = weights.unbind(dim=-1)
        w1, w2, w3 = w1.unsqueeze(-1), w2.unsqueeze(-1), w3.unsqueeze(-1)
        
        # 加权融合
        out = w1 * p1 + w2 * p2 + w3 * trend
        return self.norm(out)

# ==========================================
# 3. 主架构: PAEmbedding (Robust Ver.)
# ==========================================

class PAEmbedding(nn.Module):
    def __init__(self, c_in, seq_len, d_model, anchor_periods=[12, 24, 48, 96], dropout=0.1):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        self.anchor_periods = sorted(anchor_periods)
        
        # --- Stream 3: Trend (保底机制) ---
        # 【修改点1】完全独立的 Linear 层，不使用 Flex-Resize
        # 这是 iTransformer 的原版 Embedding 方式，保留全序列信息
        self.raw_projection = nn.Linear(seq_len, d_model)
        
        # --- Stream 1 & 2: Period (周期特征) ---
        # 【修改点3】Bias 解耦，每个 Anchor 有独立的 Bias
        self.anchor_weights = nn.ParameterDict()
        self.anchor_biases = nn.ParameterDict()
        
        for p in self.anchor_periods:
            w = nn.Parameter(torch.empty(d_model, p))
            b = nn.Parameter(torch.zeros(d_model))
            nn.init.xavier_uniform_(w)
            self.anchor_weights[str(p)] = w
            self.anchor_biases[str(p)] = b
            
        # --- Fusion ---
        # 【修改点4】使用 Gated Fusion
        self.fusion = GatedFusion(d_model)
        
        # Cache
        self._weight_cache = {}

    def _get_flex_params(self, target_len):
        """获取适配 target_len 的权重和偏置"""
        if target_len in self._weight_cache:
            return self._weight_cache[target_len]
        
        # 1. 找最近锚点
        # 优化：增加对齐容忍度，避免微小差异导致的 resize
        # 例如 target=23, anchor=24，resize 影响很小
        source_p = min(self.anchor_periods, key=lambda x: abs(x - target_len))
        
        source_w = self.anchor_weights[str(source_p)]
        source_b = self.anchor_biases[str(source_p)] # Bias 直接复用，不 resize
        
        # 2. Resize 权重
        new_w = resample_patchemb(source_w, target_len)
        
        self._weight_cache[target_len] = (new_w, source_b)
        return new_w, source_b

    def forward(self, x):
        """
        x: [Batch, Seq_Len, Channels] (B, T, C)
        Returns: [Batch, Channels, d_model]
        """
        B, T, C = x.shape
        self._weight_cache = {} # Clear cache
        
        # ---------------------------------------------------
        # Stream 3: Trend Extraction (基准流)
        # ---------------------------------------------------
        # x.permute(0, 2, 1) -> [B, C, T]
        # raw_projection -> [B, C, d_model]
        # 这一步保证了即便 periodic stream 失效，模型表现也不会低于 iTransformer
        stream_trend = self.raw_projection(x.permute(0, 2, 1))
        
        # ---------------------------------------------------
        # Stream 1 & 2: Periodic Extraction (增强流)
        # ---------------------------------------------------
        # 初始化为 0，若没检测到周期则不贡献
        stream_p1 = torch.zeros(B, C, self.d_model, device=x.device)
        stream_p2 = torch.zeros(B, C, self.d_model, device=x.device)
        
        # 1. 计算周期
        with torch.no_grad():
            periods = ACF_for_Period_Per_Channel(x, k=2) 
            periods = torch.clamp(periods, min=4, max=T//2) # 限制范围
            
        unique_periods = torch.unique(periods)
        
        for p in unique_periods:
            p_val = p.item()
            mask_p1 = (periods[:, 0] == p_val)
            mask_p2 = (periods[:, 1] == p_val)
            
            if not (mask_p1.any() or mask_p2.any()): continue
            
            # 获取动态参数
            w_p, b_p = self._get_flex_params(p_val)
            
            # 切片与投影
            combined_mask = mask_p1 | mask_p2
            x_subset = x[:, :, combined_mask].permute(0, 2, 1) # [B, N_sub, T]
            
            # Unfold: [B, N_sub, Num_Patches, P]
            patches = x_subset.unfold(dimension=2, size=p_val, step=p_val)
            
            # Tokenization: [B, N_sub, Num_Patches, d_model]
            tokens = F.linear(patches, w_p, b_p)
            
            # 【修改点2】 Aggregation: 改用 Mean Pooling
            # tokens.mean(dim=2) -> [B, N_sub, d_model]
            # 相比 Attention，Mean Pooling 无需参数，训练更稳定
            # 虽然牺牲了动态性，但配合 Gated Fusion 可以由 Trend 流弥补
            agg_tokens = tokens.mean(dim=2)
            
            # 分发结果
            if mask_p1.any():
                # 寻找 mask_p1 在 combined_mask 中的索引
                subset_indices = torch.where(combined_mask)[0]
                p1_global_indices = torch.where(mask_p1)[0]
                # isin 得到的是 bool mask，长度等于 subset 大小
                rel_mask = torch.isin(subset_indices, p1_global_indices)
                
                stream_p1[:, mask_p1, :] = agg_tokens[:, rel_mask, :]
                
            if mask_p2.any():
                subset_indices = torch.where(combined_mask)[0]
                p2_global_indices = torch.where(mask_p2)[0]
                rel_mask = torch.isin(subset_indices, p2_global_indices)
                
                stream_p2[:, mask_p2, :] = agg_tokens[:, rel_mask, :]

        # ---------------------------------------------------
        # Final Fusion
        # ---------------------------------------------------
        # 动态加权融合
        variate_token = self.fusion(stream_p1, stream_p2, stream_trend)
        
        return variate_token

# Test
if __name__ == "__main__":
    model = PAEmbedding(c_in=7, seq_len=96, d_model=128)
    x = torch.randn(32, 96, 7)
    out = model(x)
    print("Output:", out.shape) # [32, 7, 128]

Output: torch.Size([32, 7, 128])
