In [None]:
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Applies two residual sub-layers in sequence, each preceded by its own
    ``LayerNorm`` and followed by dropout on the sub-layer output:

    1. multi-head self-attention (query, key and value are all the
       normalized input);
    2. a position-wise feed-forward network
       (``Linear -> ReLU -> Linear``).

    Args:
        embed_dim: Size of the last (feature) dimension of the input; also
            the output size of the block.
        num_heads: Number of attention heads passed to
            ``nn.MultiheadAttention``.
        ff_hidden_dim: Width of the hidden layer of the feed-forward network.
        dropout: Dropout probability used inside the attention module and on
            the output of both sub-layers.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()

        # --- self-attention sub-layer ---
        self.norm1 = nn.LayerNorm(embed_dim)
        # ``ff_hidden_dim`` is NOT an argument of nn.MultiheadAttention; it
        # only sizes the feed-forward network below.
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout1 = nn.Dropout(dropout)

        # --- feed-forward sub-layer ---
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim),
        )
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, X):
        """Run the block on ``X`` and return a tensor of the same shape.

        Args:
            X: Input tensor laid out as ``(batch, seq_len, embed_dim)``
                (the attention module is built with ``batch_first=True``).

        Returns:
            The transformed tensor, same shape as ``X``.
        """
        hidden = X

        # Self-attention with residual connection. The module returns an
        # ``(output, weights)`` tuple; the weights are not used here, so
        # skip computing them.
        normed = self.norm1(hidden)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        hidden = hidden + self.dropout1(attn_out)

        # Feed-forward network with residual connection.
        normed = self.norm2(hidden)
        ffn_out = self.ffn(normed)
        hidden = hidden + self.dropout2(ffn_out)

        return hidden
