In [1]:
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from torch.nn.functional import conv1d

![双曲正切](./figures/tanh.png)


`self.cnn1`～`self.cnn3`：三个并行的膨胀卷积分支，膨胀率(dilation)分别为 1、2、3（kernel_size=3，padding 等于 dilation，因此各分支输出的序列长度仍为 T）。

![conv](./figures/conv.png)


In [2]:
class model(nn.Module):
    """Hybrid dilated-CNN + bidirectional-GRU regressor for multichannel time series.

    Input ``x`` has shape ``(N, C, T, D)``: batch, channels, time steps and
    features per step. Channels and features are merged into ``C * D`` values
    per time step and fed to two parallel branches:

    * CNN branch: three parallel dilated ``Conv1d`` + ``GroupNorm`` blocks
      (dilation 1, 2, 3, each keeping length ``T``), concatenated to 24
      channels and refined by a conv stack down to 16 channels;
    * RNN branch: a 3-layer bidirectional GRU with hidden size 8, i.e. 16
      output features per step.

    The two 16-feature sequences are summed, passed through one more
    bidirectional GRU, flattened to ``T * 16`` values and projected to
    ``out_dim``.

    Args:
        resolution_ratio: stored on the instance; not used by this class.
        nonlinearity: activation name, ``"tanh"`` (default) or ``"relu"``,
            matched case-insensitively.
        in_dim: number of features per time step (``D``).
        in_channels: number of input channels (``C``).
        out_dim: size of the prediction per sample.
        seq_len: number of time steps ``T``. It is fixed because the output
            layer consumes the flattened ``T * 16`` sequence.

    Raises:
        ValueError: if ``nonlinearity`` is not a supported activation name.
    """

    # Supported activations, keyed by lower-case name.
    _ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh}

    def __init__(self, resolution_ratio=4, nonlinearity="tanh", in_dim=5, in_channels=3, out_dim=1, seq_len=10):
        super().__init__()
        self.resolution_ratio = resolution_ratio  # NOTE(review): unused in this class
        self.in_dim = in_dim
        self.in_channels = in_channels
        self.seq_len = seq_len

        # Fail loudly instead of silently falling back to tanh on a typo / wrong case.
        key = nonlinearity.lower() if isinstance(nonlinearity, str) else nonlinearity
        if key not in self._ACTIVATIONS:
            raise ValueError(
                f"unsupported nonlinearity {nonlinearity!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )
        # One stateless activation module, reused at every position below.
        self.activation = self._ACTIVATIONS[key]()

        # Channels and per-step features are flattened into one feature axis.
        merged_channels = in_channels * in_dim

        # Three parallel dilated branches (dilation 1, 2, 3), each 8 channels
        # and length-preserving. They are built in the same order as before, so
        # state_dict keys and default-init RNG consumption are unchanged.
        self.cnn1 = self._dilated_block(merged_channels, dilation=1)
        self.cnn2 = self._dilated_block(merged_channels, dilation=2)
        self.cnn3 = self._dilated_block(merged_channels, dilation=3)

        # Fusion stack applied to the concatenated branches (24 = 3 * 8 channels).
        # The final 16 channels match the 16 features (2 directions * hidden 8)
        # of the GRU branch, so the two outputs can be summed element-wise.
        self.cnn = nn.Sequential(self.activation,
                                 nn.Conv1d(in_channels=24,
                                           out_channels=16,
                                           kernel_size=3,
                                           padding=1),
                                 nn.GroupNorm(num_groups=1,
                                              num_channels=16),
                                 self.activation,

                                 nn.Conv1d(in_channels=16,
                                           out_channels=16,
                                           kernel_size=3,
                                           padding=1),
                                 nn.GroupNorm(num_groups=1,
                                              num_channels=16),
                                 self.activation,

                                 nn.Conv1d(in_channels=16,
                                           out_channels=16,
                                           kernel_size=1),
                                 nn.GroupNorm(num_groups=1,
                                              num_channels=16),
                                 self.activation)

        # NOTE(review): not used in forward(); kept so existing state_dicts still load.
        self.linear = nn.Linear(in_features=in_dim, out_features=1)

        # RNN branch: bidirectional, hidden 8 -> 16 features per time step.
        self.gru = nn.GRU(input_size=merged_channels,
                          hidden_size=8,
                          num_layers=3,
                          batch_first=True,
                          bidirectional=True)

        # Runs over the summed CNN + GRU features (16 in, 16 out per step).
        self.gru_out = nn.GRU(input_size=16,
                              hidden_size=8,
                              num_layers=1,
                              batch_first=True,
                              bidirectional=True)
        # Projects the flattened (T * 16) sequence; this ties the model to seq_len.
        self.out = nn.Linear(in_features=16 * seq_len, out_features=out_dim)

        # Xavier for convs, identity affine for GroupNorm, zero biases.
        # GRU and Linear weights keep PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.GroupNorm):
                if m.affine:
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _dilated_block(in_channels, dilation, out_channels=8, kernel_size=3):
        """Build a Conv1d + GroupNorm block whose padding keeps the sequence length (odd kernel)."""
        return nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      padding=dilation * (kernel_size - 1) // 2,
                      dilation=dilation),
            nn.GroupNorm(num_groups=1, num_channels=out_channels),
        )

    def forward(self, x):
        """Predict one ``out_dim``-sized value per sample.

        Args:
            x: tensor of shape ``(N, C, T, D)`` with ``C * D == in_channels * in_dim``
                and ``T == seq_len``.

        Returns:
            Tensor of shape ``(N,)`` when ``out_dim == 1``, otherwise ``(N, out_dim)``.
            NOTE(review): the original notes describe the label as ``(N, 1)``.
            Reshape targets to match, because element-wise losses broadcast
            ``(N,)`` against ``(N, 1)`` to ``(N, N)``.

        Raises:
            ValueError: if ``x`` is not 4-D or its feature / time sizes do not
                match the sizes the model was built for.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D input (N, C, T, D), got shape {tuple(x.shape)}")
        N, C, T, D = x.shape
        if C * D != self.in_channels * self.in_dim:
            raise ValueError(
                f"expected C * D == {self.in_channels * self.in_dim} "
                f"(in_channels={self.in_channels}, in_dim={self.in_dim}), got C={C}, D={D}"
            )
        if T != self.seq_len:
            raise ValueError(f"expected T == seq_len == {self.seq_len}, got T={T}")

        # CNN branch: (N, C, T, D) -> (N, C, D, T) -> (N, C*D, T) so Conv1d runs over time.
        x_conv = x.permute(0, 1, 3, 2).reshape(N, C * D, T)
        branches = (self.cnn1(x_conv), self.cnn2(x_conv), self.cnn3(x_conv))
        cnn_out = self.cnn(torch.cat(branches, dim=1))  # (N, 16, T)
        cnn_out = cnn_out.permute(0, 2, 1)              # (N, T, 16)

        # GRU branch: (N, C, T, D) -> (N, T, C, D) -> (N, T, C*D).
        # The feature index c*D + d matches the CNN branch's channel index.
        x_rnn = x.permute(0, 2, 1, 3).reshape(N, T, C * D)
        rnn_out, _ = self.gru(x_rnn)                    # (N, T, 16)

        # Fuse both branches, then final GRU and projection.
        fused = rnn_out + cnn_out
        seq, _ = self.gru_out(fused)                    # (N, T, 16)
        out = self.out(seq.flatten(start_dim=1))        # (N, out_dim)

        return out.squeeze(-1)                          # (N,) when out_dim == 1

In [3]:
# Smoke test: a random batch shaped (N, C, T, D_in) should give one prediction per sample.
if __name__ == '__main__':
    torch.manual_seed(0)  # reproducible input and weight initialisation
    # (N, C, T, D_in)
    data_input = torch.normal(0, 1, size=(16, 3, 10, 5))
    # Use a separate name: `model = model()` would shadow the class, so
    # re-running this cell (or instantiating the class again) would fail.
    net = model()
    with torch.no_grad():  # inference only, no autograd graph needed
        data_output = net(data_input)
    assert data_output.shape == (16,), data_output.shape
    print(data_output.shape)

torch.Size([16])
