## 🔁 Text Classification with RNNs (AG News)
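This notebook extends the basic text classifier with a GRU-based recurrent neural network (RNN), which reads the tokens of each article in order and can therefore model word order. To keep training fast, only the first 5,000 training examples are used.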

In [None]:
from itertools import islice

import torch
from torch.utils.data import DataLoader
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torch.nn as nn
import torch.optim as optim

### 1. Load and Prepare Data

In [None]:
# torchtext's built-in "basic_english" tokenizer; shared by vocabulary
# construction (yield_tokens) and by encode().
tokenizer = get_tokenizer("basic_english")


def yield_tokens(data_iter):
    """Lazily yield ``tokenizer(text)`` for every ``(label, text)`` pair.

    Labels are discarded; the token lists are what
    ``build_vocab_from_iterator`` consumes.
    """
    yield from (tokenizer(text) for _label, text in data_iter)

# Collect the vocabulary from the AG_NEWS training split, registering the
# '<unk>' and '<pad>' specials, and make '<unk>' the fallback id for any
# token not seen here.
# NOTE(review): the vocab covers the whole split, but training below uses only
# the first 5,000 examples, so many embedding rows are never updated.
train_iter = AG_NEWS(split='train')
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    specials=['<unk>', '<pad>'],
)
vocab.set_default_index(vocab['<unk>'])

def encode(text):
    """Map a raw string to a 1-D LongTensor of vocabulary ids.

    Tokens missing from the vocabulary resolve to the vocab's default index
    (set to '<unk>' above). A string with no tokens yields an empty tensor.
    """
    tokens = tokenizer(text)
    token_ids = [vocab[tok] for tok in tokens]
    return torch.tensor(token_ids, dtype=torch.long)

def collate_batch(batch):
    """Collate a list of ``(label, text)`` samples into padded model inputs.

    Returns:
        padded: (batch, max_len) LongTensor of token ids, right-padded with
            ``vocab['<pad>']``.
        lengths: (batch,) LongTensor of un-padded sequence lengths.
        labels: (batch,) LongTensor of labels shifted down by one.
    """
    pad_id = vocab['<pad>']
    labels, sequences = [], []
    for label, text in batch:
        # Presumably the raw AG_NEWS labels are 1-based; CrossEntropyLoss needs
        # targets in [0, num_classes), hence the shift — confirm against dataset.
        # Keep plain ints and build one tensor at the end instead of a 0-d
        # tensor per sample.
        labels.append(label - 1)
        sequences.append(encode(text))
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    padded = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    return padded, lengths, torch.tensor(labels, dtype=torch.long)

# Fresh iterator over the training split, truncated to the first 5,000
# examples. islice stops after 5,000 items instead of materializing the whole
# split just to slice it; the result is still turned into a list because
# DataLoader(shuffle=True) needs a sized dataset.
train_iter = AG_NEWS(split='train')
train_loader = DataLoader(
    list(islice(train_iter, 5000)),
    batch_size=32,
    shuffle=True,
    collate_fn=collate_batch,
)

### 2. Define GRU-based RNN Classifier

In [None]:
class RNNClassifier(nn.Module):
    """GRU sequence classifier over right-padded batches of token ids.

    Token ids are embedded, packed with their true lengths so padding never
    reaches the GRU, and the final hidden state of the top GRU layer is
    projected to ``num_classes`` logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding dimension.
        hidden_dim: GRU hidden-state size.
        num_classes: Number of logits produced per example.
        padding_idx: Embedding row reserved for padding. ``None`` (default)
            falls back to the module-level ``vocab['<pad>']``, preserving the
            original behavior.
        num_layers: Number of stacked GRU layers (default 1).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes,
                 padding_idx=None, num_layers=1):
        super().__init__()
        if padding_idx is None:
            # Backward-compatible default: read the pad id from the global vocab.
            padding_idx = vocab['<pad>']
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, lengths):
        """Return logits of shape (batch, num_classes).

        Args:
            x: (batch, seq_len) LongTensor of token ids, padded on the right.
            lengths: (batch,) true sequence lengths; pack_padded_sequence
                rejects zero lengths.
        """
        embedded = self.embedding(x)
        # pack_padded_sequence requires lengths on the CPU; enforce_sorted=False
        # lets batches arrive in any length order.
        packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, hidden = self.gru(packed)
        # hidden is (num_layers, batch, hidden_dim). Taking the last layer is
        # correct for any depth, whereas squeeze(0) only works for one layer.
        return self.fc(hidden[-1])

# Pick the GPU when one is available, then build the classifier (4 output
# classes, one per shifted label 0..3) and move its parameters to that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNClassifier(
    len(vocab),
    embed_dim=64,
    hidden_dim=128,
    num_classes=4,
)
model.to(device)

### 3. Train the RNN Model

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

for epoch in range(5):
    model.train()
    total_loss = 0.0
    for text, lengths, labels in train_loader:
        # Only ids and labels go to the device: the model's forward() packs
        # with lengths.cpu(), so moving lengths to the GPU would just force a
        # device->host copy on every batch.
        text, labels = text.to(device), labels.to(device)
        logits = model(text, lengths)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-batch loss; the raw sum scales with the number of
    # batches and is not comparable across dataset or batch sizes.
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

### 4. Test Inference with RNN

In [None]:
def predict(text):
    """Return the predicted class index (a Python int) for one raw string.

    The text is encoded with ``encode``, run through the global ``model`` as a
    batch of one, and the index of the largest logit is returned. Indices use
    the same ``label - 1`` shift applied in ``collate_batch``.

    Side effect: switches ``model`` to eval mode and leaves it there (the
    training loop calls ``model.train()`` at the start of every epoch).

    Raises:
        ValueError: if ``text`` yields no tokens, because the GRU cannot
            consume a zero-length sequence.
    """
    encoded = encode(text)
    if encoded.numel() == 0:
        raise ValueError(f"cannot classify {text!r}: it produced no tokens")
    model.eval()
    with torch.no_grad():
        # A single sequence needs no padding; just add the batch dimension.
        batch = encoded.unsqueeze(0).to(device)
        # Lengths stay on the CPU, as pack_padded_sequence requires.
        lengths = torch.tensor([encoded.numel()])
        logits = model(batch, lengths)
    return torch.argmax(logits, dim=1).item()

# Smoke-test the trained model on a single hand-written headline.
test_text = "The government is planning new economic reforms."
print(f"Predicted class: {predict(test_text)}")