In [8]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [9]:
# Single-layer, unidirectional RNN mapping 4 input features to 3 hidden units.
single_rnn = nn.RNN(input_size=4, hidden_size=3, num_layers=1, batch_first=True)
# Batch-first sample: 1 sequence, 2 time steps, 4 features per step.
input = torch.randn(1, 2, 4)
output, h_n = single_rnn(input)
# Display the shapes of the per-step outputs and of the final hidden state.
output.shape, h_n.shape

(torch.Size([1, 2, 3]), torch.Size([1, 1, 3]))

In [10]:
# Same layer sizes as before, but bidirectional: the forward and backward
# hidden states are concatenated along the feature axis of `output`.
torch_rnn = nn.RNN(
    input_size=4,
    hidden_size=3,
    num_layers=1,
    batch_first=True,
    bidirectional=True,
)
# Batch-first sample: 1 sequence, 2 time steps, 4 features per step.
input = torch.randn(1, 2, 4)
output, h_n = torch_rnn(input)
# Display the shapes of the per-step outputs and of the final hidden states.
output.shape, h_n.shape

(torch.Size([1, 2, 6]), torch.Size([2, 1, 3]))

In [11]:
# Seed the generators *before* drawing any random tensor, so that both the
# sample input below and any module weights created afterwards are
# reproducible under Restart Kernel -> Run All.
SEED = 20250630
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # silently ignored when CUDA is unavailable

batch_size, seq_len = 2, 3
# Feature sizes: per-step input features and hidden-state width.
input_size, hidden_size = 2, 3

# Batch-first sample input of shape (batch_size, seq_len, input_size).
input = torch.randn(batch_size, seq_len, input_size)
# All-zero initial hidden state of shape (1, batch_size, hidden_size).
h_prev = torch.zeros((1, batch_size, hidden_size))

In [12]:
# Reference model: PyTorch's built-in single-layer RNN, run on the sample
# input with the all-zero initial hidden state.
rnn_real = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
out_real, h_real = rnn_real(input, h_prev)
print(f"{out_real.shape} {h_real.shape}")

# List every learnable parameter together with its shape.
print("---Parameters---")
for param_name, param in rnn_real.named_parameters():
    print(f"{param_name} {param.shape}")

# Per-step outputs followed by the final hidden state.
print("---Output---")
print(out_real, h_real, sep="\n")

torch.Size([2, 3, 3]) torch.Size([1, 2, 3])
---Parameters---
weight_ih_l0 torch.Size([3, 2])
weight_hh_l0 torch.Size([3, 3])
bias_ih_l0 torch.Size([3])
bias_hh_l0 torch.Size([3])
---Output---
tensor([[[ 0.7200,  0.2839,  0.1090],
         [ 0.8598,  0.5006, -0.0296],
         [ 0.9018,  0.3565,  0.0851]],

        [[ 0.7355, -0.1017, -0.2827],
         [ 0.8609,  0.5447, -0.2269],
         [ 0.9237,  0.3352,  0.6465]]], grad_fn=<TransposeBackward1>)
tensor([[[0.9018, 0.3565, 0.0851],
         [0.9237, 0.3352, 0.6465]]], grad_fn=<StackBackward0>)


In [13]:
from torch import Tensor


class RnnCell(nn.Module):
    """Single-step Elman RNN cell.

    Computes ``h_t = activ(x_t @ w_ih.T + b_ih + h_prev @ w_hh.T + b_hh)``.

    Parameters:
        w_ih: input-to-hidden weights, shape ``(hidden_size, input_size)``.
        w_hh: hidden-to-hidden weights, shape ``(hidden_size, hidden_size)``.
        b_ih: input-to-hidden bias, shape ``(hidden_size,)``.
        b_hh: hidden-to-hidden bias, shape ``(hidden_size,)``.

    Args:
        input_size: number of features in each input step ``x_t``.
        hidden_size: number of features in the hidden state.
        activ: activation module applied to the pre-activation sum.
    """

    def __init__(self, input_size, hidden_size, activ=nn.Tanh()):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activ = activ
        # Allocate uninitialised storage only: init_weights() overwrites every
        # parameter, so drawing random values here would be wasted work.
        self.w_ih = nn.Parameter(torch.empty(hidden_size, input_size))
        self.w_hh = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_ih = nn.Parameter(torch.empty(hidden_size))
        self.b_hh = nn.Parameter(torch.empty(hidden_size))
        self.init_weights()

    def init_weights(self):
        """Initialise all four parameters in place.

        Weights use Xavier-normal initialisation; both biases are set to the
        constant 0.1 so that no parameter is left with an arbitrary value.
        """
        nn.init.xavier_normal_(self.w_ih)
        nn.init.xavier_normal_(self.w_hh)
        nn.init.constant_(self.b_ih, 0.1)
        nn.init.constant_(self.b_hh, 0.1)

    def forward(self, x_t: Tensor, h_prev: Tensor) -> Tensor:
        """Advance the hidden state by one time step.

        Args:
            x_t: input for the current step; last dimension is ``input_size``.
            h_prev: previous hidden state; last dimension is ``hidden_size``.

        Returns:
            The new hidden state, with the broadcast shape of the two terms.
        """
        pre_activation = (
            x_t @ self.w_ih.T + self.b_ih + h_prev @ self.w_hh.T + self.b_hh
        )
        return self.activ(pre_activation)


class RNN(nn.Module):
    """Unrolls an ``RnnCell`` over a batch-first input sequence.

    The output projection is currently disabled, so the per-step outputs are
    the hidden states themselves and ``output_size`` is accepted but unused.

    Args:
        input_size: number of features in each input step.
        hidden_size: number of features in the hidden state.
        output_size: reserved for the (disabled) output projection.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = RnnCell(input_size, hidden_size, activ=nn.Tanh())
        # Output projection is disabled for now:
        # self.w_y = nn.Parameter(torch.randn(hidden_size, output_size))
        # self.b_y = nn.Parameter(torch.randn(output_size))
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x: Tensor, h_prev=None):
        """Run the cell over every time step of ``x``.

        Args:
            x: input of shape ``(batch, seq_len, input_size)``.
            h_prev: optional initial hidden state, either ``(batch, hidden_size)``
                or ``(1, batch, hidden_size)``. Defaults to zeros.

        Returns:
            ``(h_out, h_last)`` where ``h_out`` has shape
            ``(batch, seq_len, hidden_size)`` and ``h_last`` is the final hidden
            state. ``h_last`` is ``(1, batch, hidden_size)`` when a 3-D
            ``h_prev`` was supplied and ``(batch, hidden_size)`` otherwise.

        Raises:
            ValueError: if a 3-D ``h_prev`` has a leading dimension other than 1.
        """
        # Sizes come from the input itself, never from notebook-level globals.
        batch_size_, seq_len_, _ = x.shape
        # Allocate on the same device / dtype as the input.
        h_out = x.new_zeros(batch_size_, seq_len_, self.hidden_size)

        had_layer_dim = h_prev is not None and h_prev.dim() == 3
        if h_prev is None:
            h_t = x.new_zeros(batch_size_, self.hidden_size)
        elif had_layer_dim:
            if h_prev.shape[0] != 1:
                raise ValueError(
                    "3-D h_prev must have a leading dimension of 1, "
                    f"got shape {tuple(h_prev.shape)}"
                )
            # Work on the 2-D state; the leading dimension is restored below.
            h_t = h_prev.squeeze(0)
        else:
            h_t = h_prev

        for t in range(seq_len_):
            h_t = self.rnn(x[:, t, :], h_t)
            # y = self.leaky_relu((h_t @ self.w_y) + self.b_y)
            h_out[:, t, :] = h_t

        h_last = h_t.unsqueeze(0) if had_layer_dim else h_t
        return h_out, h_last

In [14]:
# Randomly initialised custom RNN: run it once on the sample input and report
# the shapes of the final hidden state and of the per-step outputs.
my_rnn = RNN(input_size=input_size, hidden_size=hidden_size, output_size=hidden_size)
out, h = my_rnn(input, h_prev)
print(f"{h.shape} {out.shape}")

torch.Size([1, 2, 3]) torch.Size([2, 3, 3])


In [15]:
# Give a fresh custom RNN the very same parameter objects as the reference
# nn.RNN so both models compute with identical weights. Reference shapes:
#   weight_ih_l0 (3, 2), weight_hh_l0 (3, 3), bias_ih_l0 (3,), bias_hh_l0 (3,)
my_rnn = RNN(input_size=input_size, hidden_size=hidden_size, output_size=hidden_size)
for cell_attr, reference_attr in (
    ("w_ih", "weight_ih_l0"),
    ("w_hh", "weight_hh_l0"),
    ("b_ih", "bias_ih_l0"),
    ("b_hh", "bias_hh_l0"),
):
    setattr(my_rnn.rnn, cell_attr, getattr(rnn_real, reference_attr))

# Per-step outputs followed by the final hidden state.
out, h = my_rnn(input, h_prev)
print(out, h, sep="\n")

tensor([[[ 0.7200,  0.2839,  0.1090],
         [ 0.8598,  0.5006, -0.0296],
         [ 0.9018,  0.3565,  0.0851]],

        [[ 0.7355, -0.1017, -0.2827],
         [ 0.8609,  0.5447, -0.2269],
         [ 0.9237,  0.3352,  0.6465]]], grad_fn=<CopySlices>)
tensor([[[0.9018, 0.3565, 0.0851],
         [0.9237, 0.3352, 0.6465]]], grad_fn=<TanhBackward0>)


In [16]:
# The custom RNN must reproduce the reference nn.RNN: first the per-step
# outputs, then the final hidden state.
for reference, ours in ((out_real, out), (h_real, h)):
    assert torch.allclose(reference, ours)