In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def resample_patchemb(old_weight: torch.Tensor, new_patch_len: int) -> torch.Tensor:
    """
    Flex-Resize core algorithm: resample embedding weights from the reference
    length to a new patch length.

    The weights are mapped through the Moore-Penrose pseudo-inverse of the
    linear-interpolation resize operator, then scaled by
    sqrt(new_patch_len / reference_len). Every step is a differentiable torch
    op, so gradients flow back to ``old_weight``.

    Args:
        old_weight: Reference weights of shape [d_model, reference_len]
            (the layout used by ``nn.Linear.weight``).
        new_patch_len: Target period / patch length (P_i); must be positive.

    Returns:
        Resampled weights of shape [d_model, new_patch_len], with the same
        dtype and device as ``old_weight``. If ``new_patch_len`` already equals
        the reference length, ``old_weight`` itself is returned unchanged.

    Raises:
        AssertionError: if ``old_weight`` is not 2-D.
        ValueError: if ``new_patch_len`` is not positive.
    """
    # Dimension check and short-circuit.
    assert old_weight.dim() == 2, "输入张量应为2D (d_model, patch_size)"
    if old_weight.size(1) == new_patch_len:
        return old_weight
    if new_patch_len <= 0:
        raise ValueError(f"new_patch_len must be positive, got {new_patch_len}")

    # 1. Transpose to [reference_len, d_model] so each kernel is a column.
    old = old_weight.T
    old_len = old.size(0)

    # 2. Scaling factor from the paper: P_new / P_ref.
    factor = new_patch_len / old_len

    # Build the resize operator in float32, or float64 if the weights already
    # are. Mixing a float32 operator with float64 weights in `@` raises a dtype
    # error, and pinv/interpolate are not available in half precision on every
    # backend. The result is cast back to the caller's dtype at the end.
    compute_dtype = torch.float64 if old.dtype == torch.float64 else torch.float32

    # 3. Resize matrix R ([new_len, old_len]) satisfies resize(x) == R @ x.
    # Interpolating every one-hot basis vector yields R^T row by row.
    # F.interpolate expects (N, C, L):
    #   [old_len, old_len] -> [1, old_len, old_len] -> [1, old_len, new_len] -> [old_len, new_len] == R^T
    basis_vectors = torch.eye(old_len, dtype=compute_dtype, device=old.device)
    resize_mat_t = F.interpolate(
        basis_vectors.unsqueeze(0), size=new_patch_len, mode="linear", align_corners=False
    ).squeeze(0)

    # 4. Moore-Penrose pseudo-inverse (A)^+ with A = R^T:
    # pinv([old_len, new_len]) -> [new_len, old_len]
    resize_mat_pinv = torch.linalg.pinv(resize_mat_t)

    # 5. new_weight = (R^T)^+ @ old_weight * sqrt(factor)
    # [new_len, old_len] @ [old_len, d_model] -> [new_len, d_model]
    resampled_kernels = resize_mat_pinv @ old.to(compute_dtype) * math.sqrt(factor)

    # 6. Back to [d_model, new_len] (nn.Linear layout) in the caller's dtype.
    return resampled_kernels.T.to(old_weight.dtype)

In [2]:
class FlexResizeProjector(nn.Module):
    """Patch-embedding projector whose weight adapts to the input period.

    A single reference ``nn.Linear(reference_len, d_model)`` is shared across
    all periods. At call time its weight is resampled with
    ``resample_patchemb`` to the current period length; the bias is used as-is.

    Args:
        d_model: Output embedding dimension.
        reference_len: Length of the shared reference patch (default 48).
    """

    def __init__(self, d_model, reference_len=48):
        super().__init__()
        self.d_model = d_model
        self.reference_len = reference_len

        # Shared reference weights. Corresponds to
        # `self.embedding = nn.Linear(self.target_patch_len, d_model)` in the original source.
        # nn.Linear stores its weight as [out, in] = [d_model, reference_len].
        self.reference_embedding = nn.Linear(reference_len, d_model)

    def forward(self, x_patches, period=None):
        """
        Args:
            x_patches: Input patches [Batch, Num_Patches, ..., Period].
            period: Physical period length of the input (int). Defaults to
                ``x_patches.size(-1)``; if given, it must equal that size.
        Returns:
            Projected tokens [Batch, Num_Patches, ..., D_model].
        Raises:
            ValueError: if ``period`` disagrees with the last dimension of
                ``x_patches``.
        """
        patch_len = x_patches.size(-1)
        if period is None:
            period = patch_len
        elif period != patch_len:
            raise ValueError(
                f"period ({period}) must equal the last dimension of x_patches ({patch_len})"
            )

        # 1. Weight adapted to the current period: [d_model, period].
        # The Parameter itself (not .data) is passed so gradients reach the
        # shared reference weights. When period == reference_len,
        # resample_patchemb returns the reference weight unchanged.
        current_weight = resample_patchemb(self.reference_embedding.weight, period)

        # 2. Linear projection x @ W.T + b via F.linear, since W is generated per call.
        # Only the weight is resampled (as in the LightGTS source); the bias is shared.
        return F.linear(x_patches, current_weight, self.reference_embedding.bias)


In [3]:
# Demo: one projector (one set of reference weights) handles short and long periods.
torch.manual_seed(0)  # reproducible weights and inputs under Restart & Run All

d_model = 256
ref_len = 48
projector = FlexResizeProjector(d_model, ref_len)

# Scenario 1: period 24 (shorter than the reference length).
x_short = torch.randn(32, 10, 24)  # [Batch, Patch_Num, Period]
out_short = projector(x_short, period=24)
print(f"Short Period Input (24) -> Output: {out_short.shape}")
assert out_short.shape == (32, 10, d_model)

# Scenario 2: period 96 (longer than the reference length).
x_long = torch.randn(32, 2, 96)  # [Batch, Patch_Num, Period]
out_long = projector(x_long, period=96)
print(f"Long Period Input (96) -> Output: {out_long.shape}")
assert out_long.shape == (32, 2, d_model)

Short Period Input (24) -> Output: torch.Size([32, 10, 256])
Long Period Input (96) -> Output: torch.Size([32, 2, 256])
