In [1]:
import torch
import torch.nn as nn

In [2]:

class SelfAttentionModel(nn.Module):
    """Linear feature embedding followed by single-head self-attention.

    Inputs are batch-first, ``(batch_size, time_steps, input_dim)``, and the
    output has shape ``(batch_size, time_steps, hidden_dim)``.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the embedding produced by the linear layer and
            used as ``embed_dim`` of the attention layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Linear layer that projects the per-step features into the embedding space.
        self.linear = nn.Linear(input_dim, hidden_dim)
        # Self-attention layer. batch_first=True is required because the input
        # is laid out as (batch, time, feature); with the default
        # (seq, batch, feature) layout the attention would be computed across
        # the batch axis and mix unrelated samples.
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=1, batch_first=True
        )

    def forward(self, x):
        """Embed ``x`` and apply self-attention over its time axis.

        Args:
            x: Tensor of shape ``(batch_size, time_steps, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, time_steps, hidden_dim)``.
        """
        # Linear transformation of the feature embedding.
        x = self.linear(x)
        # Self-attention: query, key and value are all the embedded input.
        # The attention weights (second return value) are discarded.
        x, _ = self.attention(x, x, x)
        return x

# Build the model: 10 features per time step, 64-dimensional hidden layer.
input_dim = 10
hidden_dim = 64
model = SelfAttentionModel(input_dim=input_dim, hidden_dim=hidden_dim)

# Example input, intended layout (batch_size, time_steps, input_dim) = (3, 5, 10).
input_data = torch.rand(3, 5, input_dim)

# Forward pass.
output = model(input_data)

# Inspect the output shape; prints torch.Size([3, 5, 64]).
print(output.shape)

torch.Size([3, 5, 64])


In [3]:

class WaveletKernelNetwork(nn.Module):
    """Convolutional front end used as the Wavelet Kernel Network stage.

    The stage is ``Conv1d(kernel_size=3, padding=1) -> ReLU -> MaxPool1d(2)``
    and is meant as a starting point that can be extended as needed.

    Args:
        input_dim: Number of input channels of the convolution.
        hidden_dim: Number of output channels of the convolution.
    """

    def __init__(self, input_dim, hidden_dim):
        # The original called `__init()`, which does not exist on nn.Module
        # and raised AttributeError before any layer was registered.
        super().__init__()
        # Wavelet Kernel Network part; extend this stack as required.
        self.wkn_layer = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2)
        )

    def forward(self, x):
        """Apply the convolutional stack.

        Args:
            x: Channels-first tensor of shape ``(batch, input_dim, length)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim, length // 2)``; the
            convolution keeps the length and the pooling halves it.
        """
        x = self.wkn_layer(x)
        return x

class SelfAttentionModel(nn.Module):
    """Linear feature embedding followed by single-head self-attention.

    Inputs are batch-first, ``(batch_size, time_steps, input_dim)``, and the
    output has shape ``(batch_size, time_steps, hidden_dim)``.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the embedding produced by the linear layer and
            used as ``embed_dim`` of the attention layer.
    """

    def __init__(self, input_dim, hidden_dim):
        # The original called `__init()`, which does not exist on nn.Module
        # and raised AttributeError before any layer was registered.
        super().__init__()
        # Linear layer that projects the per-step features into the embedding space.
        self.linear = nn.Linear(input_dim, hidden_dim)
        # Self-attention layer. batch_first=True makes the layer read its input
        # as (batch, time, feature); the default (seq, batch, feature) layout
        # would compute attention across the batch axis for such inputs.
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=1, batch_first=True
        )

    def forward(self, x):
        """Embed ``x`` and apply self-attention over its time axis.

        Args:
            x: Tensor of shape ``(batch_size, time_steps, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, time_steps, hidden_dim)``.
        """
        # Linear transformation of the feature embedding.
        x = self.linear(x)
        # Self-attention: query, key and value are all the embedded input.
        # The attention weights (second return value) are discarded.
        x, _ = self.attention(x, x, x)
        return x

class WKN_LSTM_SelfAttention(nn.Module):
    """Pipeline: Wavelet Kernel Network -> LSTM -> self-attention -> linear head.

    Args:
        input_dim: Number of input channels given to the WKN stage.
        hidden_dim: Width shared by the WKN output, the LSTM and the attention.
        output_dim: Number of features produced by the final linear layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        # The original called `__init()`, which does not exist on nn.Module
        # and raised AttributeError before any sub-module was registered.
        super().__init__()
        # Wavelet Kernel Network part.
        self.wkn = WaveletKernelNetwork(input_dim, hidden_dim)
        # LSTM layer (batch-first: expects (batch, seq, hidden_dim)).
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        # Self-attention layer.
        self.self_attention = SelfAttentionModel(hidden_dim, hidden_dim)
        # Fully connected output layer.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the full pipeline.

        Args:
            x: Channels-first tensor ``(batch, input_dim, length)``, the
                layout required by the WKN stage. Data stored as
                ``(batch, time_steps, input_dim)`` must be transposed by the
                caller first.

        Returns:
            Tensor whose last dimension is ``output_dim``. With the WKN stage
            halving the length this is ``(batch, length // 2, output_dim)``,
            assuming the attention sub-module preserves the shape of its input.
        """
        # Wavelet Kernel Network part: result is channels-first,
        # (batch, hidden_dim, steps).
        x = self.wkn(x)
        # The batch-first LSTM needs the feature axis last, so swap the
        # channel and step axes: (batch, steps, hidden_dim). Without this the
        # LSTM receives `steps` as its feature size and fails.
        x = x.transpose(1, 2)
        # LSTM layer; the final hidden/cell states are discarded.
        x, _ = self.lstm(x)
        # Self-attention layer.
        x = self.self_attention(x)
        # Fully connected layer, applied to every step.
        x = self.fc(x)
        return x
