In [3]:
import lightning as L 
from torch import nn,optim
import torch.nn.functional as F
import torch
import torchmetrics
import random
import glob
import math
from einops import rearrange
import numpy as np

In [4]:
class PositionalEncoding(nn.Module):
    '''
    PatchTST-style positional encoding for patch sequences.

    Args:
        seq_len: number of encoded positions. When ``use_cls_token`` is True this
            includes the cls position, i.e. seq_len == N + 1 for N input patches.
        d_model: hidden-state dimension (odd values are supported).
        positional_dropout: dropout probability applied after adding the encoding;
            0 disables dropout (nn.Identity).
        use_cls_token: if True, prepend a learnable cls token (zero-initialised,
            shared by all channels) at position 0.
        num_input_channels: number of channels M the cls token is expanded to.
        positional_encoding_type: 'random' (learnable, standard-normal init) or
            'sincos' (fixed sin/cos table, centred and scaled to std 0.1).

    input : [B, M, N, d_model]
    output: [B, M, N, d_model], or [B, M, N+1, d_model] when use_cls_token is True

    Raises:
        ValueError: unknown ``positional_encoding_type``.
    '''
    def __init__(self,
                seq_len,
                d_model,
                positional_dropout = 0,
                use_cls_token = False,
                num_input_channels = 1,
                positional_encoding_type = 'sincos' # 'random' (learnable) / 'sincos' (fixed)
                ):
        super().__init__()
        self.use_cls_token = use_cls_token
        self.num_input_channels = num_input_channels

        if use_cls_token:
            # cls_token: [1, 1, 1, d_model]; expanded to [B, num_input_channels, 1, d_model] in forward
            self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, d_model))

        # positional encoding table: [seq_len, d_model]
        self.position_enc = self._init_pe(positional_encoding_type,seq_len,d_model)
        # Positional dropout
        self.positional_dropout = (
            nn.Dropout(positional_dropout) if positional_dropout > 0 else nn.Identity()
        )

    @staticmethod
    def _init_pe(positional_encoding_type, seq_len, d_model):
        '''Build the [seq_len, d_model] encoding table as an nn.Parameter.'''
        if positional_encoding_type == "random":
            position_enc = nn.Parameter(torch.randn(seq_len, d_model), requires_grad=True)
        elif positional_encoding_type == "sincos":
            position_enc = torch.zeros(seq_len,d_model)
            position = torch.arange(0, seq_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
            angles = position * div_term  # [seq_len, ceil(d_model / 2)]
            position_enc[:, 0::2] = torch.sin(angles)
            # Odd columns number floor(d_model / 2): one fewer than the even ones
            # when d_model is odd, so trim the angle table to fit.
            position_enc[:, 1::2] = torch.cos(angles[:, : d_model // 2])
            # centre and scale the table to std 0.1
            position_enc = position_enc - position_enc.mean()
            position_enc = position_enc / (position_enc.std() * 10)
            position_enc = nn.Parameter(position_enc, requires_grad=False)
        else:
            raise ValueError(
                f"{positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
            )
        return position_enc

    def forward(self, inputs):
        if self.use_cls_token:
            # table has N+1 rows: row 0 belongs to the cls token, rows 1.. to the patches
            inputs = self.positional_dropout(inputs + self.position_enc[1:, :])
            cls_token = self.cls_token + self.position_enc[:1, :]
            cls_tokens = cls_token.expand(inputs.shape[0], self.num_input_channels, -1, -1)
            hidden_state = torch.cat((cls_tokens, inputs), dim=2)
        else:
            # table has N rows, broadcast over B and M
            hidden_state = self.positional_dropout(inputs + self.position_enc)
        return hidden_state
    

class TSTPatchEmbed(torch.nn.Module):
    '''
    Patch embedding for a multivariate series.

    input : [B, L, M]
    steps : optional end replication padding (or left-trimming) -> unfold into
            N patches of length P with stride S -> Linear(P, d_model)
            -> optional positional encoding -> dropout
    output: [B*M, N, d_model]  (B and M are merged: every variable becomes a sample)

    Args:
        seq_len: input length L.
        d_model: embedding size.
        do_padding: 'end' replicates the last S steps so one extra patch is formed;
            None instead drops the leading steps that do not fill a whole patch.
        P, S: patch length and stride (required despite the None defaults).
        M: number of variables (not used by this module).
        emb_dropout: dropout probability applied to the embedding.
        do_pe: add a positional encoding of shape [N, d_model].
        pe_type: 'sincos' (fixed) or 'random' (learnable), see PositionalEncoding.
    '''
    def __init__(self,
                seq_len,
                d_model,
                do_padding = 'end', # padding mode; None disables padding
                P = None,
                S = None,
                M = 6, # number of variables (unused)
                emb_dropout = 0,
                do_pe = True,       # add positional encoding
                pe_type = 'sincos',
                ):
        super().__init__()
        assert do_padding in ['end',None]

        self.seq_len = seq_len
        self.P = P
        self.S = S

        # number of patches without padding
        patch_num = int((seq_len - P) / S + 1)
        self.do_padding = do_padding

        if do_padding == "end":
            # replicating S steps at the end adds exactly one extra patch
            self.padding_patch_layer = torch.nn.ReplicationPad1d((0, S))
            patch_num += 1
        else:
            # No padding: keep only the last P + S*(N-1) steps so the final
            # window ends exactly on the last time step.
            new_sequence_length = self.P + self.S * (patch_num - 1)
            self.sequence_start = seq_len - new_sequence_length

        # linear patch embedding
        self.emb = torch.nn.Linear(P, d_model)

        # Positional encoding: an [N, d_model] Parameter (fixed for 'sincos',
        # learnable for 'random') built by the shared PositionalEncoding helper.
        # The former call to the undefined `positional_encoding` raised NameError.
        self.do_pe = do_pe
        if self.do_pe:
            self.pe_enc = PositionalEncoding._init_pe(pe_type, patch_num, d_model)
        self.emb_dropout = nn.Dropout(emb_dropout)

    def forward(self,inputs):
        z = rearrange(inputs, 'B L M -> B M L')  # [B, L, M] -> [B, M, L]

        if self.do_padding=='end':
            z = self.padding_patch_layer(z)      # [B, M, L + S]
        else:
            z = z[:, :, self.sequence_start :]   # drop steps that cannot fill a patch

        z = z.unfold(dimension=-1, size=self.P, step=self.S)  # [B, M, N, P]
        x_emb = self.emb(z)                                    # [B, M, N, d_model]
        # merge B and M: each variable is treated as an independent sample
        x_emb = rearrange(x_emb,"B M N d_model -> (B M) N d_model")

        if self.do_pe:
            x_emb = x_emb + self.pe_enc  # [N, d_model] broadcasts over B*M
        return self.emb_dropout(x_emb)   # [B*M, N, d_model]

class MixerNormLayer(nn.Module):
    '''
    BatchNorm or LayerNorm over the feature axis D.

    input : [B, M, N, D]
    output: [B, M, N, D]

    If ``norm_type`` contains "batch" (case-insensitive), BatchNorm1d(D) is used:
    M is folded into the batch and D is moved to the channel axis, so each of
    the D features is normalised over the (B*M, N) positions. Any other value
    selects LayerNorm over the last axis.
    '''
    def __init__(self, norm_type,d_model):
        super().__init__()
        self.norm_type = norm_type
        wants_batch_norm = "batch" in norm_type.lower()
        self.norm = nn.BatchNorm1d(d_model) if wants_batch_norm else nn.LayerNorm(d_model)

    def forward(self, inputs):
        if "batch" not in self.norm_type.lower():
            return self.norm(inputs)

        # BatchNorm1d expects (N, C, L): fold M into the batch, put D on the channel axis.
        batch, channels, patches, dim = inputs.shape
        folded = inputs.reshape(batch * channels, patches, dim).transpose(1, 2)  # [B*M, D, N]
        normed = self.norm(folded).transpose(1, 2)                               # [B*M, N, D]
        return normed.reshape(batch, channels, patches, dim)

In [5]:
# IC LOSS
class ICLoss(nn.Module):
    '''
    IC loss: 1 - cosine similarity of the demeaned prediction and the demeaned
    target, which equals 1 - Pearson correlation.

    y_pred: [B, K]; the K columns are averaged into one score per row.
    y_true: [B, 1] (the callers in this notebook pass labels reshaped to (-1, 1)).
    gamma : if > 0, adds gamma * trace((s^T s)^-1), where s is the demeaned
            averaged score of shape [B, 1].
    '''
    def __init__(self, gamma=0):
        super().__init__()
        self.gamma = gamma

    def forward(self, y_pred, y_true):
        score = y_pred.mean(dim=1, keepdim=True)                 # [B, 1]
        score_c = score - score.mean(dim=0, keepdim=True)
        target_c = y_true - y_true.mean(dim=0, keepdim=True)
        similarity = F.cosine_similarity(score_c, target_c, dim=0)
        loss = 1 - similarity.mean()

        if self.gamma > 0:
            # NOTE(review): score_c is [B, 1], so this penalty reduces to
            # 1 / sum(score_c ** 2) and can be driven to 0 by scaling the
            # predictions up (the IC term is scale-invariant). If a decorrelation
            # penalty across the K columns of y_pred was intended, it should use
            # the demeaned y_pred instead -- confirm.
            gram = score_c.T @ score_c
            loss = loss + self.gamma * torch.linalg.inv(gram).trace()
        return loss

In [6]:
class GatedAttention(nn.Module):
    """
    Gated attention plug-in.

    Computes softmax(Linear(x)) over the last axis and multiplies x by it
    element-wise. ``out_size`` must match (or broadcast against) the size of
    x's last axis; in this file in_size == out_size, so the output shape equals
    the input shape.
    """
    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        self.attn_layer = nn.Linear(in_size, out_size)
        self.attn_softmax = nn.Softmax(dim=-1)

    def forward(self, inputs):
        gate = self.attn_softmax(self.attn_layer(inputs))
        return inputs * gate
 

class MixerMLP(nn.Module):
    '''
    Feed-forward mixer shared by the time, feature and channel mixer blocks.

    Acts on the last axis:
    Linear(in -> in * expansion_factor) -> GELU -> Dropout
    -> Linear(-> out_features) -> Dropout.
    The hidden layer is wider than the input (inverted bottleneck) when
    expansion_factor > 1.
    '''

    def __init__(self, in_features, out_features, expansion_factor=1,dropout=0):
        super().__init__()
        width = in_features * expansion_factor
        self.fc1 = nn.Linear(in_features, width)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(width, out_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, inputs: torch.Tensor):
        expanded = nn.functional.gelu(self.fc1(inputs))
        projected = self.fc2(self.dropout1(expanded))
        return self.dropout2(projected)


class TsMixerBlock(nn.Module):
    """
    Inter-patch (time) mixing block.

    For every channel and every feature dimension, an MLP (optionally followed
    by gated attention) mixes information across the N patches.
    Pre-norm residual: out = x + T(gate(MLP(T(norm(x))))), where T swaps N and D.

    input : [B, M, N, D]
    output: [B, M, N, D]
    """
    def __init__(self,
                 num_patches,
                 d_model,
                 dropout,
                 expansion_factor,
                 gated_attn=True,
                 norm_type='batch',
                 ):
        super().__init__()

        self.norm = MixerNormLayer(norm_type, d_model)
        self.gated_attn = gated_attn
        self.mlp = MixerMLP(in_features=num_patches,
                            out_features=num_patches,
                            expansion_factor=expansion_factor,
                            dropout=dropout)
        if gated_attn:
            self.gating_block = GatedAttention(in_size=num_patches, out_size=num_patches)

    def forward(self, hidden_state):
        # move N to the last axis ([B, M, D, N]) so the MLP and gate act across patches
        mixed = self.mlp(self.norm(hidden_state).transpose(2, 3))
        if self.gated_attn:
            mixed = self.gating_block(mixed)
        # back to [B, M, N, D] and add the residual
        return mixed.transpose(2, 3) + hidden_state
    
class FeatureMixerBlock(nn.Module):
    '''
    Intra-patch (feature) mixing block: an MLP over the d_model axis of every
    patch, optionally followed by gated attention.
    Pre-norm residual: out = x + gate(MLP(norm(x))).

    input : [B, M, N, d_model]
    output: [B, M, N, d_model]
    '''
    def __init__(self
                 , d_model
                 ,expansion_factor
                 ,dropout
                 ,gated_attn=True
                 ,norm_type = 'batch'
                 ):
        super().__init__()

        self.norm = MixerNormLayer(norm_type, d_model)
        self.gated_attn = gated_attn
        self.mlp = MixerMLP(in_features=d_model,
                            out_features=d_model,
                            expansion_factor=expansion_factor,
                            dropout=dropout)
        if gated_attn:
            self.gating_block = GatedAttention(in_size=d_model, out_size=d_model)

    def forward(self, hidden):
        update = self.mlp(self.norm(hidden))
        if self.gated_attn:
            update = self.gating_block(update)
        return update + hidden

class ChannelMixerBlock(nn.Module):
    """
    Inter-channel mixing block: mixes information across the M input variables
    at every (patch, feature) position.

    Order here is norm -> (optional) gated attention -> MLP, applied with the
    channel axis moved last, followed by a residual add.

    input : [B, M, N, d_model]
    output: [B, M, N, d_model]
    """
    def __init__(self,
                 d_model,
                 norm_type,
                 num_input_channels,
                 expansion_factor,
                 dropout,
                 gated_attn=True,
                 ):

        super().__init__()

        self.norm = MixerNormLayer(norm_type, d_model)
        self.gated_attn = gated_attn
        self.mlp = MixerMLP(in_features=num_input_channels,
                            out_features=num_input_channels,
                            expansion_factor=expansion_factor,
                            dropout=dropout)
        if gated_attn:
            self.gating_block = GatedAttention(in_size=num_input_channels,
                                               out_size=num_input_channels)

    def forward(self, inputs: torch.Tensor):
        # [B, M, N, D] -> [B, D, N, M]: channels last so gate and MLP mix across M
        mixed = self.norm(inputs).permute(0, 3, 2, 1)
        if self.gated_attn:
            mixed = self.gating_block(mixed)
        # permute(0, 3, 2, 1) is its own inverse: back to [B, M, N, D]
        mixed = self.mlp(mixed).permute(0, 3, 2, 1)
        return mixed + inputs

In [10]:


class TSMixerFactorNetV2(L.LightningModule):
    '''
    TSMixer-style factor model on patched multivariate series.

    forward: [B, L, M] -> patch embedding [B*M, N, D] -> [B, M, N, D]
             -> positional encoding (optional cls token)
             -> mixer blocks (time mixer, feature mixer, optional channel mixer)
             -> [B, M*D] features
    predict_layer maps the features to 50 outputs, and ICLoss scores them
    against label column ``label_idx``.

    Every step squeezes dim 0 of batch[0] / batch[1], i.e. it assumes each
    DataLoader batch wraps one group of B samples (batch size 1) -- TODO confirm.

    lr: Adam learning rate (default 1e-3, the value previously hard-coded).
    '''
    def __init__(self,
                 input_dim = 6,
                 d_model = 50,
                 P = 4,
                 S = 2,
                 seq_len = 30,
                 norm_type = 'batch',
                 label_idx = 1,
                 gamma = 0,
                 channel_mixer = True,
                 gated_attn = True,
                 use_cls = False,
                 pe_type = 'sincos',
                 do_padding = 'end',
                 expansion_factor = 1,
                 dropout = 0,
                 lr = 1e-3,
                 ):
        super().__init__()
        # Record constructor args in self.hparams / checkpoints so that
        # load_from_checkpoint can rebuild the model without re-passing them.
        self.save_hyperparameters()

        self.d_model = d_model
        self.P = P          # patch window length
        self.S = S          # patch stride
        self.use_cls = use_cls
        self.lr = lr

        # patch count without padding (same formula as TSTPatchEmbed)
        self.N = int((seq_len - P) / S + 1)
        self.M = input_dim
        self.patch_encoder = TSTPatchEmbed(
                            seq_len = seq_len,
                            P = P,
                            S = S,
                            d_model = d_model,
                            do_padding = do_padding,
                            do_pe = False
                            )
        if do_padding == 'end':
            self.N += 1   # end padding yields one extra patch
        if use_cls:
            self.N += 1   # cls token occupies position 0

        self.pe_enc = PositionalEncoding(seq_len = self.N,
                            num_input_channels = input_dim,
                            d_model = d_model,
                            use_cls_token = use_cls,
                            positional_encoding_type = pe_type
                            )

        # -------backbone-------
        mixer_list = [
            TsMixerBlock(num_patches = self.N,
                         d_model = d_model,
                         gated_attn = gated_attn,
                         norm_type = norm_type,
                         dropout = dropout,
                         expansion_factor = expansion_factor),
            FeatureMixerBlock(d_model = d_model,
                              gated_attn = gated_attn,
                              norm_type = norm_type,
                              dropout = dropout,
                              expansion_factor = expansion_factor),
        ]
        if channel_mixer:
            mixer_list.append(
                ChannelMixerBlock(d_model = d_model,
                                  norm_type = norm_type,
                                  num_input_channels = self.M,
                                  gated_attn = gated_attn,
                                  dropout = dropout,
                                  expansion_factor = expansion_factor)
            )
        self.mixer_backbone = nn.ModuleList(mixer_list)

        # -------head-------
        # flattens each channel's N patches to d_model (used only when use_cls is False)
        self.ts_linear = nn.Linear(self.d_model * self.N, self.d_model)

        self.label_idx = label_idx
        # predict layer: [B, M*D] -> [B, 50]
        self.predict_layer = nn.Sequential(
            nn.Linear(self.M * self.d_model, 50)
        )
        self.loss_fn = ICLoss(gamma)

    def forward(self,inputs):
        z = self.patch_encoder(inputs)  # [B*M, N, D]

        z = rearrange(z,'(b m) n d -> b m n d',m = self.M)
        z = self.pe_enc(z)  # add positional encoding (and cls token if enabled)

        for mixer in self.mixer_backbone:
            z = mixer(z)

        if self.use_cls:
            x_out = z[:,:,0,:]  # cls position: [B, M, N, D] -> [B, M, D]
            x_out = rearrange(x_out, 'b m d -> b (m d)')
        else:
            z = rearrange(z, 'b m n d -> b m (n d)')
            x_out = F.gelu(self.ts_linear(z))             # [B, M, D]
            x_out = rearrange(x_out, 'b m d -> b (m d)')  # [B, M*D]
        return x_out

    def _ic_loss_step(self, batch):
        '''Shared train/val logic: predict one batch and return its IC loss.'''
        inputs = batch[0].squeeze(0)
        y_pred = self.predict_layer(self.forward(inputs))
        labels = batch[1].squeeze(0)[:, self.label_idx].reshape(-1, 1)
        return self.loss_fn(y_pred, labels)

    def training_step(self, batch, batch_idx):
        loss = self._ic_loss_step(batch)
        self.log("train_loss",loss, on_epoch=True,prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx):
        inputs = batch[0].squeeze(0)
        ids = batch[1].squeeze(0)[:,0]  # column 0 of batch[1] is returned as the row id
        y_pred = self.predict_layer(self.forward(inputs))
        # numpy [B, 51]: 50 prediction columns followed by the id column
        return torch.cat((y_pred,ids.unsqueeze(1)),1).detach().cpu().numpy()

    def validation_step(self,batch,batch_idx):
        loss = self._ic_loss_step(batch)
        self.log("val_loss",loss,prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [11]:
# Seed every RNG before building the model so that weight init (and the random
# demo input drawn afterwards) is reproducible on Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model_params = {
    "input_dim": 6,          # M: number of input variables
    "d_model": 16,           # D: hidden size per patch
    "P": 4,                  # patch length
    "S": 2,                  # patch stride
    "dropout": 0.5,
    "use_cls": False,
    "expansion_factor": 2,   # MLP hidden width = in_features * 2
}

model = TSMixerFactorNetV2(**model_params)

In [12]:
# Dummy input of shape [B, L, M] = [1024, 30, 6]: L must equal the model's
# seq_len (default 30) and M must equal input_dim.
x_in = torch.rand((1024, 30, 6))

In [13]:
# Result of the forward module alone (predict_layer is not applied).
# Output is [B, M*D] with or without a cls token.
# Use eval() + no_grad(): a shape check should not apply dropout, update
# BatchNorm running statistics, or build an autograd graph.
model.eval()
with torch.no_grad():
    x_out = model(x_in)
x_out.shape

torch.Size([1024, 96])

In [14]:
# Stand-alone head mirroring predict_layer: maps [B, M*D] features to 50 outputs.
# M and D come from model_params so this cell stays in sync with the model config.
M = model_params["input_dim"]
D = model_params["d_model"]
ts_linear = nn.Linear(M * D, 50)

In [15]:
# Apply the stand-alone head to the forward features: [B, M*D] -> [B, 50]
x_final = ts_linear(x_out)
x_final.size()

torch.Size([1024, 50])