# In [1]:
import torch 
import math
import torch.nn as nn

# In [18]:
class NaiveCustomLSTM(nn.Module):
    """Single-layer LSTM written out gate by gate, one weight set per gate.

    Every gate g in {i, f, o, c} owns an input projection ``U_g``
    (input_sz x hidden_sz), a recurrent projection ``V_g``
    (hidden_sz x hidden_sz) and a bias ``b_g`` (hidden_sz,).
    """

    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz

        # i_t: input gate
        self.U_i = nn.Parameter(torch.empty(input_sz, hidden_sz))   # multiplies x_t
        self.V_i = nn.Parameter(torch.empty(hidden_sz, hidden_sz))  # multiplies h_{t-1}
        self.b_i = nn.Parameter(torch.empty(hidden_sz))  # size = output (hidden) size

        # f_t: forget gate
        self.U_f = nn.Parameter(torch.empty(input_sz, hidden_sz))
        self.V_f = nn.Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_f = nn.Parameter(torch.empty(hidden_sz))  # size = output (hidden) size

        # o_t: output gate
        self.U_o = nn.Parameter(torch.empty(input_sz, hidden_sz))
        self.V_o = nn.Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_o = nn.Parameter(torch.empty(hidden_sz))  # size = output (hidden) size

        # c_t: candidate cell update
        self.U_c = nn.Parameter(torch.empty(input_sz, hidden_sz))
        self.V_c = nn.Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_c = nn.Parameter(torch.empty(hidden_sz))

        self.init_weights()

    def init_weights(self):
        """Fill every parameter in place from U(-1/sqrt(H), 1/sqrt(H)), H = hidden_size."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the LSTM over a batch of sequences.

        Args:
            x: tensor of shape (batch_size, sequence_size, input_size).
            init_states: optional ``(h_0, c_0)`` pair, each of shape
                (batch_size, hidden_size); zeros are used when omitted.

        Returns:
            A list with one tensor per time step; entry ``t`` is the hidden
            state h_t with a leading singleton dim, i.e. shape
            (1, batch_size, hidden_size).
        """
        bs, seq_sz, _ = x.size()
        hidden_seq = []

        if init_states is None:
            # new_zeros follows x's device *and* dtype.
            h_t = x.new_zeros(bs, self.hidden_size)
            c_t = x.new_zeros(bs, self.hidden_size)
        else:
            h_t, c_t = init_states

        # Recurrent forward pass, one time step at a time.
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # torch.dot only accepts 1-D tensors, so batched projections use
            # matmul; the gates must see the current step x_t, not all of x.
            i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)
            f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)
            o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)
            g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)

            c_t = f_t * c_t + i_t * g_t
            # The LSTM output squashes the cell state: h_t = o_t * tanh(c_t).
            h_t = o_t * torch.tanh(c_t)

            hidden_seq.append(h_t.unsqueeze(0))

        return hidden_seq

# In [14]:
class CustomLSTM(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_sz = hidden_sz
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        
        self.init_weights()
        
    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
    
    def forward(self, x, init_states=None):
        bs, seq_sz = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (
                torch.zeros(bs, self.hidden_sz).to(x.device),
                torch.zeros(bs, self.hidden_sz).to(x.device)
            )
        else:
            h_t, c_t = init_states
        
        i_t, 