In [None]:
import torch
import torch.nn as nn

# Toy configuration for a dummy batch of token-ID sequences.
batch_size, seq_len = 2, 5
vocab_size = 10  # token IDs are drawn from [0, vocab_size)
embedding_dim = 8
hidden_dim = 16
num_classes = 3

# Random token IDs; torch.randint's upper bound is exclusive, so every ID is
# in [0, vocab_size - 1].
x = torch.randint(0, vocab_size, (batch_size, seq_len))

# Model
class SimpleRNNClassifier(nn.Module):
    """Sequence classifier: token embedding -> ``nn.RNN`` -> linear head.

    The final hidden state of the top RNN layer is mapped to one raw
    (unnormalized) score per class; no softmax is applied.

    Args:
        vocab_size: Number of distinct token IDs; inputs must lie in
            ``[0, vocab_size)``.
        embedding_dim: Size of each token embedding vector.
        hidden_dim: Size of the RNN hidden state.
        num_classes: Number of output logits per sequence.
        num_layers: Number of stacked RNN layers. Defaults to 1, which is
            the original single-layer model.
    """

    def __init__(
        self, vocab_size, embedding_dim, hidden_dim, num_classes, num_layers=1
    ):
        super().__init__()
        # Converts each token ID in [0, vocab_size) into an
        # embedding_dim-sized vector.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Iterates step by step over the sequence, updating the hidden state.
        # batch_first=True: inputs are laid out as (batch, seq_len, features).
        self.rnn = nn.RNN(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        # Fully connected head: final hidden state -> class logits.
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch of token-ID sequences.

        Args:
            x: Integer tensor of token IDs, shape (batch, seq_len).

        Returns:
            Logits tensor of shape (batch, num_classes).
        """
        embedded = self.embedding(x)  # (batch, seq_len, embedding_dim)
        # output: top-layer state at every timestep (unused here).
        # hidden: final state of each layer, (num_layers, batch, hidden_dim).
        _output, hidden = self.rnn(embedded)
        # hidden[-1] picks the top layer's final state -> (batch, hidden_dim).
        # Unlike hidden.squeeze(0), this stays correct when num_layers > 1.
        return self.fc(hidden[-1])

# Build the classifier from the toy configuration and run one forward pass
# over the dummy batch.
model = SimpleRNNClassifier(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
)

logits = model(x)  # one row of class scores per sequence
print("Logits:", logits.shape)  # expected: (batch_size, num_classes)
