## 🧠 Text Classification with PyTorch (AG News)
This notebook walks through training a text classification model on the AG News dataset using `torchtext` and a simple neural network (an embedding layer with mean pooling, followed by a linear classifier).

In [None]:
from itertools import islice

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import AG_NEWS
from torchtext.vocab import build_vocab_from_iterator

### 1. Load and Tokenize the Dataset

In [None]:
# torchtext's built-in "basic_english" tokenizer; the same tokenizer is used
# for vocabulary building and for encoding text.
tokenizer = get_tokenizer("basic_english")

def yield_tokens(data_iter):
    """Lazily yield the tokenized text of every item in ``data_iter``.

    Each item must unpack into exactly two elements; the first is ignored
    and the second is passed to the notebook-level ``tokenizer``.
    """
    yield from (tokenizer(text) for _ignored, text in data_iter)

# Build the vocabulary from the tokenized training split.
train_iter = AG_NEWS(split='train')
# '<unk>' and '<pad>' are registered as special tokens.
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>', '<pad>'])
# Out-of-vocabulary lookups return the '<unk>' index.
vocab.set_default_index(vocab['<unk>'])

print("Vocab size:", len(vocab))

### 2. Create Encoding and DataLoader

In [None]:
def encode(text):
    """Convert ``text`` into a 1-D ``torch.long`` tensor of vocabulary indices.

    The text is split with the notebook-level ``tokenizer`` and every token
    is looked up in the notebook-level ``vocab``.
    """
    tokens = tokenizer(text)
    indices = [vocab[tok] for tok in tokens]
    return torch.tensor(indices, dtype=torch.long)

def collate_batch(batch):
    """Collate ``(label, text)`` pairs into a padded batch.

    Returns:
        A tuple ``(texts, labels)``: ``texts`` holds the encoded sequences
        padded to a common length with the ``'<pad>'`` index (batch
        dimension first); ``labels`` holds each label shifted down by one.
    """
    encoded_texts = [encode(text) for _, text in batch]
    shifted_labels = [label - 1 for label, _ in batch]
    padded_texts = pad_sequence(
        encoded_texts,
        batch_first=True,
        padding_value=vocab['<pad>'],
    )
    return padded_texts, torch.tensor(shifted_labels)

# A fresh iterator is created here rather than reusing the one that was
# already iterated during vocabulary building.
train_iter = AG_NEWS(split='train')
# Take only the first 5000 samples lazily; slicing list(train_iter) would
# load the entire training split into memory just to discard most of it.
train_loader = DataLoader(list(islice(train_iter, 5000)), batch_size=32, shuffle=True, collate_fn=collate_batch)

### 3. Define the Classification Model

In [None]:
class TextClassifier(nn.Module):
    """Embed tokens, average the non-padding embeddings, apply a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each token embedding.
        num_classes: Number of logits produced per example.
        pad_idx: Index of the padding token. When omitted it is looked up
            as ``vocab['<pad>']`` in the notebook-level vocabulary.
    """

    def __init__(self, vocab_size, embed_dim, num_classes, pad_idx=None):
        super().__init__()
        if pad_idx is None:
            pad_idx = vocab['<pad>']
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for token ids ``x``.

        ``x`` has shape (batch, seq_len). Padding positions are excluded
        from the average so an example's logits do not depend on how much
        padding the rest of its batch required.
        """
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        # 1.0 for real tokens, 0.0 for padding; shape (batch, seq_len, 1).
        mask = (x != self.pad_idx).unsqueeze(-1).to(embedded.dtype)
        # Clamp so an all-padding or empty sequence divides by 1, not 0.
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled = (embedded * mask).sum(dim=1) / token_counts  # masked mean
        return self.fc(pooled)

# Classifier with 64-dimensional embeddings and 4 output classes.
model = TextClassifier(len(vocab), embed_dim=64, num_classes=4)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

### 4. Train the Model

In [None]:
# Cross-entropy loss applied to the logits returned by the model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

# Train for 5 passes over train_loader.
for epoch in range(5):
    total_loss = 0  # running sum of per-batch losses for this epoch
    model.train()
    for text, labels in train_loader:
        # Move the batch to the same device as the model.
        text, labels = text.to(device), labels.to(device)
        preds = model(text)
        loss = criterion(preds, labels)

        # Clear gradients left over from the previous step before backpropagating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Reports the summed (not averaged) batch loss for the epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

### 5. Test Inference

In [None]:
def predict(text):
    """Return the index of the highest-scoring class for a single ``text``.

    The model is switched to evaluation mode, the text is encoded into a
    one-example batch on ``device``, and the argmax over the model output
    along dimension 1 is returned as a Python int.
    """
    model.eval()
    single_batch = encode(text).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(single_batch)
    return logits.argmax(dim=1).item()

# Run the trained model on one example sentence and print the predicted class index.
test_text = "The stock market saw a sharp decline today due to economic concerns."
print("Predicted class:", predict(test_text))