In [3]:
import torch
import torch.nn as nn 

In [4]:
class RNNEncoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(RNNEncoder, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)

    def forward(self, x):
        _, hidden = self.rnn(x)
        return hidden


In [5]:
# Example usage
input_size = 10
seq_len = 5
hidden_size = 8
num_heads = 2

# Generate random input sequence
input_seq = torch.randn((1, seq_len, input_size))

In [7]:
rnn_encoder = RNNEncoder(input_size, hidden_size)
rnn_hidden = rnn_encoder(input_seq)

In [24]:
# Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads):
        super(TransformerEncoder, self).__init__()
        self.self_attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=num_heads)
        self.norm = nn.LayerNorm([seq_len, input_size])  # Fix: Use the correct normalized_shape

    def forward(self, x):
        attn_output, _ = self.self_attention(x, x, x)
        x = self.norm(attn_output.permute(1, 0, 2) + x)
        return x.mean(dim=1)

In [23]:
transformer_encoder = TransformerEncoder(input_size, hidden_size, num_heads)
transformer_hidden = transformer_encoder(input_seq)

In [25]:
print("RNN Hidden State:", rnn_hidden)
print("Transformer Hidden State:", transformer_hidden)

RNN Hidden State: tensor([[[ 0.3613,  0.6381,  0.3856, -0.0666,  0.7605, -0.6685, -0.7132,
           0.1103]]], grad_fn=<StackBackward0>)
Transformer Hidden State: tensor([[-0.1773, -0.1662, -0.9960,  0.4090,  0.3220,  0.2566,  0.0077,  0.7282,
         -0.4104,  0.0263],
        [-0.3855,  0.3177, -1.1411, -0.1710,  0.5668,  0.8834,  0.0655, -0.1863,
         -0.3487,  0.3992],
        [ 0.1850, -0.1229, -1.0941,  0.2407,  0.3054,  0.7534,  0.3784,  0.2076,
         -0.5514, -0.3022],
        [-0.1949, -0.2398, -0.9249,  0.4045,  0.6320,  0.1181, -0.0458,  0.5508,
         -0.1808, -0.1193],
        [ 0.3325, -0.9344, -0.3485,  0.2562,  0.9666, -0.1814, -0.3914,  0.8050,
         -0.1246, -0.3801]], grad_fn=<MeanBackward1>)


   X0 -> RNN -> H0 -> RNN -> H1 -> RNN -> H2 -> ... -> Ht
   Input         Hidden State at t=0      Hidden State at t=1
