In [1]:
import lightning as L 
import torch.nn.functional as F
import torch
from torch import nn,optim
import torchmetrics
import random
import glob
import math
from einops import rearrange
import numpy as np

In [2]:
# IC LOSS
class ICLoss(nn.Module):
    """
    Information-coefficient loss.

    loss = 1 - corr(mean_k(y_pred), y_true), where corr is the Pearson correlation
    across the batch (cosine similarity of the batch-demeaned columns). When
    gamma > 0, gamma * trace((X^T X)^-1) is added, with X the demeaned averaged
    prediction column.

    y_pred: [B, K] predictions; averaged over the K columns before correlating.
    y_true: targets, [B, 1] as passed by PatchTSTFactorNetV4.
    """
    def __init__(self, gamma=0):
        super().__init__()
        self.gamma = gamma

    def forward(self, y_pred, y_true):
        # One score per sample: [B, 1]
        score = y_pred.mean(dim=1, keepdim=True)

        # Cosine similarity of batch-demeaned columns == Pearson correlation
        score_centered = score - score.mean(dim=0, keepdim=True)
        target_centered = y_true - y_true.mean(dim=0, keepdim=True)
        ic = F.cosine_similarity(score_centered, target_centered, dim=0)
        loss = 1 - ic.mean()

        if self.gamma > 0:
            # NOTE(review): score_centered is a single column, so the Gram matrix is
            # 1x1 and this penalty reduces to 1 / sum(score_centered ** 2). Confirm
            # whether the full multi-column y_pred was intended here.
            gram = score_centered.T @ score_centered
            loss = loss + self.gamma * torch.trace(torch.linalg.inv(gram))
        return loss

In [3]:
# ================= Positional encoding =================
class PositionalEncoding(nn.Module):
    '''
    PatchTST positional encoding.

    num_patches: number of rows in the encoding table. When use_cls_token=True,
        row 0 is used for the cls token and rows 1.. for the patches, so the input
        must have num_patches - 1 patches.
    num_input_channels: number of input channels (variables)
    d_model: embedding dimension (odd values are supported)
    positional_dropout: dropout probability applied to the patch embeddings after
        adding the encoding (Identity when 0)
    use_cls_token: bool, whether to prepend a learnable cls token along the patch axis
    positional_encoding_type: 'random' (learnable, Gaussian init) or 'sincos' (fixed)

    input : [B  M  N  d_model]
    output: [B  M  N  d_model] / [B  M  N+1 d_model] (with cls token)
    '''
    def __init__(self,
                num_patches,
                num_input_channels,
                d_model,
                positional_dropout = 0,
                use_cls_token = False,
                positional_encoding_type = 'sincos' # 'random' (learnable) / 'sincos' (fixed)
                ):
        super().__init__()
        self.use_cls_token = use_cls_token
        self.num_input_channels = num_input_channels

        if use_cls_token:
            # cls_token: [1 x 1 x 1 x d_model]; expanded over batch and channels in forward
            self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, d_model))

        # positional encoding table: [num_patches x d_model]
        self.position_enc = self._init_pe(positional_encoding_type, num_patches, d_model)
        self.positional_dropout = (
            nn.Dropout(positional_dropout) if positional_dropout > 0 else nn.Identity()
        )

    @staticmethod
    def _init_pe(positional_encoding_type, num_patches, d_model):
        """
        Build the [num_patches x d_model] encoding table as an nn.Parameter.

        'random': standard-normal init, trainable.
        'sincos': sin on even columns, cos on odd columns, then shifted to zero mean
                  and scaled to std 0.1; not trainable.

        Raises:
            ValueError: for any other positional_encoding_type.
        """
        if positional_encoding_type == "random":
            position_enc = nn.Parameter(torch.randn(num_patches, d_model), requires_grad=True)
        elif positional_encoding_type == "sincos":
            position_enc = torch.zeros(num_patches, d_model)
            position = torch.arange(0, num_patches).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
            position_enc[:, 0::2] = torch.sin(position * div_term)
            # With an odd d_model there is one fewer odd column than even column, so
            # only the first d_model // 2 frequencies are used for the cos half
            # (for even d_model this is all of div_term).
            position_enc[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
            # zero mean, std 0.1
            position_enc = position_enc - position_enc.mean()
            position_enc = position_enc / (position_enc.std() * 10)
            position_enc = nn.Parameter(position_enc, requires_grad=False)
        else:
            raise ValueError(
                f"{positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
            )
        return position_enc

    def forward(self, patch_input):
        """
        patch_input: [bs x num_channels x num_patches x d_model]
        returns: same shape, or with one extra (cls) position prepended on dim 2.
        """
        if not self.use_cls_token:
            # hidden_state: [bs x num_channels x num_patches x d_model]
            return self.positional_dropout(patch_input + self.position_enc)

        # rows 1.. of the table encode the patches
        patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :])
        # row 0 encodes the cls token
        cls_token = self.cls_token + self.position_enc[:1, :]
        # one copy per sample and channel: [bs x num_channels x 1 x d_model].
        # The channel count is read from the input so the concat always lines up.
        cls_tokens = cls_token.expand(patch_input.shape[0], patch_input.shape[1], -1, -1)
        # hidden_state: [bs x num_channels x (num_patches+1) x d_model]
        return torch.cat((cls_tokens, patch_input), dim=2)

In [4]:
# ====================PATCH EMBEDDING====================
class TSTPatchEmbed(torch.nn.Module):
    '''
    Split a multivariate series into patches and embed each patch.

    Input:  [B L M]  (batch, time steps, variables)
    Steps:  (optional) end padding -> unfold into patches -> linear embedding
            -> (optional) positional encoding -> dropout
    Output: [B*M N d_model]  (rows ordered (b, m) with m varying fastest, i.e. each
            variable is treated as an independent sample)

    seq_len:     input length L
    d_model:     embedding dimension
    do_padding:  'end' replicates the last time step S times at the end (adds one patch);
                 None drops the earliest steps so the last window ends exactly at L
    P, S:        patch length and stride (both required)
    M:           number of variables (not used by this module)
    emb_dropout: dropout probability applied to the embeddings
    do_pe:       whether to add a positional encoding
    pe_type:     'sincos' (fixed) or 'random' (learnable), see PositionalEncoding._init_pe

    Raises:
        ValueError: if do_padding is not 'end'/None, or P or S is missing.
    '''
    def __init__(self,
                seq_len,
                d_model,
                do_padding = 'end', # padding method; None disables padding
                P = None,
                S = None,
                M = 6, # number of variables
                emb_dropout = 0,
                do_pe = True,       # whether to add positional encoding
                pe_type = 'sincos',
                ):
        super().__init__()
        if do_padding not in ('end', None):
            raise ValueError(f"do_padding must be 'end' or None, got {do_padding!r}")
        if P is None or S is None:
            raise ValueError("Patch length P and stride S must both be provided.")

        self.seq_len = seq_len
        self.P = P
        self.S = S
        self.do_padding = do_padding

        # number of full windows of length P with stride S inside seq_len
        patch_num = (seq_len - P) // S + 1

        if do_padding == "end":
            self.padding_patch_layer = torch.nn.ReplicationPad1d((0, S))
            patch_num += 1
        else:
            # Without padding, trim the start so the LAST window ends exactly at seq_len.
            new_sequence_length = self.P + self.S * (patch_num - 1)
            self.sequence_start = seq_len - new_sequence_length
        # number of patches N produced by forward
        self.patch_num = patch_num

        # linear embedding layer: P -> d_model
        self.emb = torch.nn.Linear(P, d_model)

        # positional encoding: [N x d_model] parameter, broadcast over the B*M rows
        self.do_pe = do_pe
        if self.do_pe:
            self.pe_enc = PositionalEncoding._init_pe(pe_type, patch_num, d_model)
        self.emb_dropout = nn.Dropout(emb_dropout)

    def forward(self, inputs):
        z = inputs.transpose(1, 2)  # [B L M] -> [B M L]

        if self.do_padding == 'end':
            z = self.padding_patch_layer(z)      # [B M L+S]
        else:
            z = z[:, :, self.sequence_start:]

        z = z.unfold(dimension=-1, size=self.P, step=self.S)  # [B M N P]
        # embed and merge B and M -> every variable becomes its own sample
        x_emb = self.emb(z).flatten(0, 1)        # [B*M N d_model]

        if self.do_pe:
            x_emb = x_emb + self.pe_enc
        return self.emb_dropout(x_emb)  # [BM, N, d_model]

In [5]:
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTST
class PatchTSTAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Args:
        embed_dim: model dimension; must be divisible by num_heads.
        num_heads: number of attention heads (head_dim = embed_dim // num_heads).
        dropout: dropout probability applied to the attention probabilities
            (only while self.training is True).
        is_decoder: when True, forward returns the (key, value) states it used as
            `past_key_value` so a later call can reuse them.
        is_causal: stored on the instance but not read by forward().
        bias: whether the q/k/v/out projections use a bias term.

    Raises:
        ValueError: if embed_dim is not divisible by num_heads.
    """
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout = 0.0,
        is_decoder= False,
        is_causal = False,
        bias = True
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1/sqrt(head_dim); applied to the projected queries before the dot product
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the embedding into heads: view as (bsz, seq_len, num_heads, head_dim),
        then swap to a contiguous (bsz, num_heads, seq_len, head_dim). seq_len may be -1."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        past_key_value=None,
        attention_mask=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: [bsz, tgt_len, embed_dim]; the query source, and also the
                key/value source for self-attention.
            key_value_states: if given, keys/values are projected from it (cross-attention).
            past_key_value: optional (key, value) tuple, each [bsz, num_heads, src_len, head_dim],
                reused as-is (cross-attention) or prepended to the new keys/values (self-attention).
            attention_mask: optional additive mask of shape [bsz, 1, tgt_len, src_len],
                added to the scores before softmax.
            layer_head_mask: optional [num_heads] tensor that scales each head's attention
                probabilities after softmax.
            output_attentions: if True, also return the attention probabilities.

        Returns:
            Tuple (attn_output, attn_weights, past_key_value):
                attn_output: [bsz, tgt_len, embed_dim]
                attn_weights: [bsz, num_heads, tgt_len, src_len] probabilities (before
                    dropout) when output_attentions, else None
                past_key_value: the (key, value) states used when is_decoder, otherwise
                    the argument as passed in

        Raises:
            ValueError: on unexpected attention-weight, mask, head-mask, or output shapes.
        """
        
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj (pre-scaled so the dot product below is already normalized)
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fold heads into the batch dimension so a single bmm handles all heads:
        # [bsz * num_heads, seq, head_dim]
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        # raw scores: [bsz * num_heads, tgt_len, src_len]
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # additive mask, broadcast over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            # scale each head's probabilities by its mask value
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # merge heads back: [bsz, tgt_len, num_heads, head_dim] -> [bsz, tgt_len, embed_dim]
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
    
    
class _ScaledDotProductAttention(torch.nn.Module):
    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
        super().__init__()
        self.attn_dropout = torch.nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        self.scale = torch.nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa

    def forward(self, q, k, v, prev=None, key_padding_mask=None, attn_mask=None):
        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
        attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]

        # Add pre-softmax attention scores from the previous layer (optional)
        if prev is not None: attn_scores = attn_scores + prev

        # Attention mask (optional)
        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Key padding mask (optional)
        if key_padding_mask is not None:                              # mask with shape [bs x q_len] (only when max_w_len == q_len)
            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # normalize the attention weights
        attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]
        attn_weights = self.attn_dropout(attn_weights)

        # compute the new values given the attention weights
        output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]

        if self.res_attention: return output, attn_weights, attn_scores
        else: return output, attn_weights

In [6]:
class PatchTSTBatchNorm(nn.Module):
    """
    BatchNorm1d for `(batch_size, sequence_length, d_model)` tensors.

    Each of the d_model features is normalized with statistics taken over the batch
    and sequence dimensions; the output has the same layout as the input.
    """
    def __init__(self, d_model):
        super().__init__()
        self.batchnorm = nn.BatchNorm1d(d_model)

    def forward(self, inputs):
        """
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                values to normalize
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        """
        # BatchNorm1d wants channels on dim 1: move d_model there and back again.
        channels_first = inputs.permute(0, 2, 1)
        normalized = self.batchnorm(channels_first)
        return normalized.permute(0, 2, 1)

In [7]:
class PatchTSTEncoderLayer(nn.Module):
    '''
    One PatchTST encoder layer.

    Three residual sub-layers, each pre-norm (pre_norm=True) or post-norm:
      1. self-attention across patches (time), independently for every channel
      2. optional (channel_attention=True) self-attention across channels at every
         patch position; this reuses the same attention module as step 1
      3. position-wise feed-forward network (d_model -> expand_factor*d_model -> d_model)

    norm_type: 'batchnorm' or 'layernorm'
    output_attentions: default for forward(output_attentions=None)

    input : B, M, N, D
    output: tuple whose first element is B, M, N, D

    Raises:
        ValueError: for an unsupported norm_type.
    '''
    def __init__(self,
                 d_model,
                 num_attention_heads,
                 expand_factor = 2,
                 pre_norm = True,
                 attention_dropout=0,
                 path_dropout=0,
                 ff_dropout = 0,
                 channel_attention=False,
                 norm_type = 'batchnorm',
                 output_attentions = False
                 ):
        super().__init__()
        ffn_dim = expand_factor * d_model
        self.channel_attention = channel_attention
        self.output_attentions = output_attentions
        self.self_attn = PatchTSTAttention(
            embed_dim=d_model,
            num_heads=num_attention_heads,
            dropout=attention_dropout,
        )
        # Add & Norm of sublayer 1 (time attention)
        self.dropout_path1 = self._make_dropout(path_dropout)
        self.norm_sublayer1 = self._make_norm(norm_type, d_model)

        # Add & Norm of sublayer 2 (channel attention), only when enabled
        if self.channel_attention:
            self.dropout_path2 = self._make_dropout(path_dropout)
            self.norm_sublayer2 = self._make_norm(norm_type, d_model)

        # Position-wise Feed-Forward
        self.ff = nn.Sequential(
            nn.Linear(d_model, ffn_dim),
            nn.GELU(),
            nn.Dropout(ff_dropout) if ff_dropout > 0 else nn.Identity(),
            nn.Linear(ffn_dim, d_model),
        )

        # Add & Norm of sublayer 3 (feed-forward)
        self.dropout_path3 = self._make_dropout(path_dropout)
        self.norm_sublayer3 = self._make_norm(norm_type, d_model)

        self.pre_norm = pre_norm

    @staticmethod
    def _make_norm(norm_type, d_model):
        """Build the normalization layer for one sub-layer."""
        if norm_type == "batchnorm":
            return PatchTSTBatchNorm(d_model)
        if norm_type == "layernorm":
            return nn.LayerNorm(d_model)
        raise ValueError(f"{norm_type} is not a supported norm layer type.")

    @staticmethod
    def _make_dropout(p):
        """Residual-path dropout; an Identity module when p is 0."""
        return nn.Dropout(p) if p > 0 else nn.Identity()

    def _attention_sublayer(self, hidden_state, norm, dropout, output_attentions):
        """Residual self-attention on a [batch, seq, d_model] tensor; returns (hidden_state, attn_weights)."""
        if self.pre_norm:
            attn_output, attn_weights, _ = self.self_attn(
                hidden_states=norm(hidden_state), output_attentions=output_attentions
            )
            return hidden_state + dropout(attn_output), attn_weights
        attn_output, attn_weights, _ = self.self_attn(
            hidden_states=hidden_state, output_attentions=output_attentions
        )
        return norm(hidden_state + dropout(attn_output)), attn_weights

    def forward(self, hidden_state, output_attentions=None):
        """
        Parameters:
            hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*):
                Past values of the time series
            output_attentions (`bool`, *optional*):
                Whether to also return attention weights. None (default) uses the
                value given to the constructor.
        Return:
            Tuple whose first element is a `torch.Tensor` of shape
            `(batch_size, num_channels, sequence_length, d_model)`. When output_attentions
            is enabled it is followed by the time-attention weights and, if
            channel_attention, the channel-attention weights.
        """
        if output_attentions is None:
            output_attentions = self.output_attentions
        _, num_input_channels, sequence_length, _ = hidden_state.shape

        # --------------------------1. attention across time--------------------------
        hidden_state = rearrange(hidden_state, 'b m n d -> (b m) n d')
        hidden_state, attn_weights = self._attention_sublayer(
            hidden_state, self.norm_sublayer1, self.dropout_path1, output_attentions
        )
        hidden_state = rearrange(hidden_state, '(b m) n d -> b m n d', m = num_input_channels)

        # --------------------------2. attention across variables at each time--------------------------
        if self.channel_attention:
            hidden_state = rearrange(hidden_state, 'b m n d -> (b n) m d')
            hidden_state, channel_attn_weights = self._attention_sublayer(
                hidden_state, self.norm_sublayer2, self.dropout_path2, output_attentions
            )
            hidden_state = rearrange(hidden_state, '(b n) m d -> b m n d', n = sequence_length)

        # --------------------------3. mixing across hidden (FFN)--------------------------
        hidden_state = rearrange(hidden_state, 'b m n d -> (b m) n d')
        if self.pre_norm:
            hidden_state = hidden_state + self.dropout_path3(self.ff(self.norm_sublayer3(hidden_state)))
        else:
            hidden_state = self.norm_sublayer3(hidden_state + self.dropout_path3(self.ff(hidden_state)))
        # [bs x num_channels x sequence_length x d_model]
        hidden_state = rearrange(hidden_state, '(b m) n d -> b m n d', m = num_input_channels)

        outputs = (hidden_state,)
        if output_attentions:
            outputs += (attn_weights, channel_attn_weights) if self.channel_attention else (attn_weights,)
        return outputs

In [8]:
class PatchTSTFactorNetV4(L.LightningModule):
    '''
    PatchTST-based factor model.

    Pipeline: patch embedding -> positional encoding (+ optional cls token)
    -> one PatchTST encoder layer -> per-channel summary -> flatten
    -> linear head with `num_factors` outputs, trained with ICLoss.

    Parameters
    ----------
    input_dim : number of input variables M
    d_model : embedding dimension D
    P, S : patch length and stride
    seq_len : input length L
    n_heads : number of attention heads
    use_cls : summarize each channel with a cls token; otherwise flatten all patch
        outputs and project them with `ts_linear`
    channel_attention : also attend across channels (see PatchTSTEncoderLayer)
    padding_patch : 'end' or None, forwarded to TSTPatchEmbed
    pe_type : 'sincos' or 'random'
    expand_factor : feed-forward width multiplier
    norm : encoder normalization, 'batch'/'layer' (or 'batchnorm'/'layernorm')
    pre_norm : pre-norm (True) or post-norm (False) residual blocks
    label_idx : column of batch[1] used as the regression target
    gamma : ICLoss penalty weight
    num_factors : width of the prediction head (default 50)

    Batches are tuples whose tensors carry a leading singleton dim that is squeezed
    away; presumably one cross-section per batch (TODO confirm with the dataloader).
    '''

    # Map the short `norm` names onto PatchTSTEncoderLayer's norm_type values.
    _NORM_ALIASES = {'batch': 'batchnorm', 'layer': 'layernorm'}

    def __init__(self
                ,input_dim = 6
                ,d_model = 16
                ,P = 4
                ,S = 2
                ,seq_len = 30
                ,n_heads = 2
                ,use_cls = False
                ,channel_attention = False
                ,padding_patch = 'end'
                ,pe_type = 'sincos'
                ,expand_factor = 1
                ,norm = 'batch'
                ,pre_norm = True
                ,label_idx = 1
                ,gamma = 0
                ,num_factors = 50
                 ):
        super().__init__()
        # Record constructor args in self.hparams / checkpoints so that
        # load_from_checkpoint rebuilds the same architecture.
        self.save_hyperparameters()
        self.input_dim = input_dim
        self.d_ff = d_model * expand_factor

        # Encoder sequence length: patches, +1 for end padding, +1 for the cls token.
        self.N = int((seq_len - P)/S + 1)
        if padding_patch == "end":
            self.N += 1
        if use_cls:
            self.N += 1

        # do_pe=False: the positional encoding (and cls token) is added by self.pe_enc
        self.patch_emb = TSTPatchEmbed(d_model = d_model,
                                 seq_len = seq_len,
                                 do_padding = padding_patch,
                                 P = P,
                                 S = S,
                                 do_pe = False,
                                 emb_dropout = 0
                                 )

        self.use_cls = use_cls
        self.pe_enc = PositionalEncoding(num_patches = self.N,
                            num_input_channels = input_dim,
                            d_model = d_model,
                            use_cls_token = use_cls,
                            positional_encoding_type = pe_type
                            )

        self.layers = PatchTSTEncoderLayer(d_model = d_model,
                            expand_factor = expand_factor,
                            num_attention_heads = n_heads,
                            attention_dropout=0,
                            channel_attention=channel_attention,
                            norm_type = self._NORM_ALIASES.get(norm, norm),
                            output_attentions = False,
                            pre_norm=pre_norm
                            )
        self.label_idx = label_idx

        self.predict_layer = nn.Linear(self.input_dim * d_model, num_factors)
        self.loss_fn = ICLoss(gamma)

        # Used only when use_cls is False; created unconditionally so state_dict keys stay stable.
        self.ts_linear = nn.Linear(self.N * d_model, d_model)

    def forward(self, inputs):
        """
        inputs: [B, seq_len, input_dim]
        returns: [B, input_dim * d_model] features consumed by predict_layer
        """
        x_emb = self.patch_emb(inputs)  # BM N D
        x_emb = rearrange(x_emb, '(B M) N D -> B M N D', M = self.input_dim)
        x_emb = self.pe_enc(x_emb)      # add PE; prepend cls token when use_cls
        x_out = self.layers(x_emb)[0]   # B M N D

        if self.use_cls:
            x_out = x_out[:, :, 0, :]   # cls-token output: B M D
        else:
            x_out = rearrange(x_out, "B M N d_model -> B M (N d_model)")
            x_out = self.ts_linear(x_out)  # B M D

        return rearrange(x_out, "B M D -> B (M D)")

    def _shared_step(self, batch):
        """ICLoss of predict_layer(forward(batch[0])) against column label_idx of batch[1]."""
        inputs = batch[0].squeeze(0)
        y_pred = self.predict_layer(self.forward(inputs))
        labels = batch[1].squeeze(0)[:, self.label_idx].reshape(-1, 1)
        return self.loss_fn(y_pred, labels)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return a numpy array [B, num_factors + 1]: predictions, then column 0 of batch[1] (ids)."""
        inputs = batch[0].squeeze(0)
        ids = batch[1].squeeze(0)[:, 0]
        y_pred = self.predict_layer(self.forward(inputs))
        return torch.cat((y_pred, ids.unsqueeze(1)), 1).cpu().numpy()

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [9]:
# Hyperparameters for this run, as real assignments: `name: value` lines are bare
# variable annotations in Python and bind no variables.
label_idx = 1
d_model = 16
P = 4
S = 2
n_heads = 2
expand_factor = 1
channel_attention = True
use_cls = True

In [10]:
# Constructor kwargs for PatchTSTFactorNetV4 (each key exactly once).
model_params = {
    "input_dim": 6,
    "d_model": 16,
    "P": 4,
    "S": 2,
    "seq_len": 30,
    "n_heads": 2,
    "use_cls": True,
    "channel_attention": True,
    "expand_factor": 1,
}

In [11]:
torch.manual_seed(0)  # reproducible weight initialization
model = PatchTSTFactorNetV4(**model_params)

In [15]:
# Dummy input, shape B L M (batch=1024, seq_len=30, input_dim=6); seeded for reproducibility.
x_in = torch.rand(1024, 30, 6, generator=torch.Generator().manual_seed(0))

In [16]:
# Result of calling the model's forward.
# With use_cls=True the output shape is [B, M*D].
# eval() + no_grad(): this shape check must not update the BatchNorm running
# statistics (which made re-running the cell non-idempotent) or build a graph.
model.eval()
with torch.no_grad():
    x_out = model(x_in)
x_out.shape

torch.Size([1024, 96])

In [20]:
# Stand-alone head mirroring model.predict_layer: [B, M*D] -> [B, 50].
# M and D come from model_params so they cannot drift from the model config.
M = model_params["input_dim"]
D = model_params["d_model"]
ts_linear = nn.Linear(M * D, 50)

In [21]:
x_final = ts_linear(x_out)
assert x_final.shape == (x_out.shape[0], ts_linear.out_features)
x_final.shape

torch.Size([1024, 50])