In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# --- Model dimensions ---
embedding_dim = 128      # size of each token embedding vector
hidden_dim = 256         # size of the RNN hidden state
vocab_size = 10000       # number of distinct token ids; also the number of output classes

# --- Data / training schedule ---
sequence_length = 30     # tokens per input sequence
batch_size = 64          # sequences per batch
num_epochs = 5           # full passes of the training loop
clip_grad_norm = 1.0     # maximum total gradient norm allowed before the optimizer step

# --- Reproducibility ---
# Weight initialisation and the dummy data are both drawn from torch's RNG,
# so seed it once here to make "Restart Kernel -> Run All" repeatable.
seed = 42
torch.manual_seed(seed)

### Hyperparameters & Gradient Clipping

Here we define the model configuration. A key parameter here is `clip_grad_norm = 1.0`.

**Why Gradient Clipping?**
RNNs can suffer from the **exploding gradient problem**, where gradients grow exponentially during backpropagation through time. The resulting weight updates can be so large that training becomes unstable and the loss or the weights overflow to `NaN` or infinity.

**How it works:**
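Gradient clipping computes the total norm (magnitude) of all parameter gradients taken together. If that norm exceeds a threshold (here, `1.0`), every gradient is multiplied by the same factor so that the total norm is brought down to the threshold; if it is already below the threshold, the gradients are left unchanged. Because all gradients are scaled uniformly, this keeps training stable without changing the direction of the gradient.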

### Understanding Output Shapes

In the `RNNModel` below, pay attention to the tensor shapes in the `forward` method:

*   **`output`**: Shape `(batch_size, sequence_length, hidden_dim)`
    *   This contains the hidden states for **every time step** in the sequence.
*   **`logits`**: Shape `(batch_size, vocab_size)`
    *   We extract the hidden state from the **last time step** (`output[:, -1, :]`), which has shape `(batch_size, hidden_dim)`.
    *   This vector is passed through the fully connected layer (`self.fc`) to project it to the vocabulary size, producing the prediction logits.

In [None]:
class RNNModel(nn.Module):
    """Embedding -> vanilla RNN -> linear head, producing one set of logits per sequence."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """Build the three layers.

        Args:
            vocab_size: Number of token ids; used for the embedding table and the output width.
            embedding_dim: Width of each token embedding.
            hidden_dim: Width of the RNN hidden state.
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch_size, sequence_length) to logits of shape (batch_size, vocab_size)."""
        # (batch_size, sequence_length) -> (batch_size, sequence_length, embedding_dim)
        token_vectors = self.embedding(x)

        # Hidden state at every time step: (batch_size, sequence_length, hidden_dim).
        # The second return value (final hidden state) is not needed here.
        hidden_states, _final_hidden = self.rnn(token_vectors)

        # Keep only the last time step: (batch_size, hidden_dim)
        last_step = hidden_states[:, -1, :]

        # Project to vocabulary size: (batch_size, vocab_size)
        return self.fc(last_step)

In [None]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device:", device)

# Build the network, then move its parameters onto the chosen device.
model = RNNModel(vocab_size, embedding_dim, hidden_dim)
model = model.to(device)

# The model returns raw logits, which is what CrossEntropyLoss consumes.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
def generate_dummy_data(batch_size, sequence_length, vocab_size, device=None):
    """Create one random batch of token ids and matching random targets.

    Args:
        batch_size: Number of sequences in the batch.
        sequence_length: Number of token ids in each input sequence.
        vocab_size: Exclusive upper bound for sampled ids; every id lies in
            ``[0, vocab_size)``.
        device: Optional device on which to allocate both tensors. ``None``
            (the default) uses torch's default device, which is what this
            function did before the parameter existed.

    Returns:
        A tuple ``(inputs, targets)`` of ``torch.long`` tensors with shapes
        ``(batch_size, sequence_length)`` and ``(batch_size,)``.
    """
    # Named `inputs` rather than `input` so the Python builtin is not shadowed.
    inputs = torch.randint(
        0, vocab_size, (batch_size, sequence_length), dtype=torch.long, device=device
    )
    targets = torch.randint(0, vocab_size, (batch_size,), dtype=torch.long, device=device)
    return inputs, targets

In [None]:
# The data is synthetic, so there is no natural epoch size; fix one here.
# Keeping it in a single name ensures the loop count and the averaging
# denominator below can never drift apart.
batches_per_epoch = 100

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for _ in range(batches_per_epoch):
        inputs, targets = generate_dummy_data(batch_size, sequence_length, vocab_size)
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # Clip after backward() and before step() so the optimizer only ever
        # sees gradients whose total norm is at most clip_grad_norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / batches_per_epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

In [None]:
# Evaluation: run a single random sequence through the trained model.
model.eval()
sample_input = generate_dummy_data(1, sequence_length, vocab_size)[0].to(device)

# No gradients are needed for inference.
with torch.no_grad():
    sample_output = model(sample_input)

# The highest-scoring logit is the predicted token id.
predicted_token = sample_output.argmax(dim=1).item()
print("Sample input:", sample_input.cpu().numpy())
print("Predicted token:", predicted_token)