In [None]:
from torch import nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torch
from tqdm import tqdm


In [None]:
# Model Configuration
# Keyword arguments forwarded to GRUClassifier via **GRU_CONFIG.
GRU_CONFIG = dict(
    embedding_dim=300,   # size of each token embedding (nn.Embedding)
    hidden_dim=512,      # GRU hidden-state size
    output_dim=1,        # Binary classification: a single logit
    n_layers=2,          # number of stacked GRU layers
    dropout=0.3,         # dropout probability
)

# Training Configuration
# Read by the data pipeline (vocab building / batch collation) and by GRUTrainer.
TRAIN_CONFIG = dict(
    batch_size=128,
    epochs=10,
    learning_rate=0.001,
    optimizer=torch.optim.Adam,          # optimizer class; GRUTrainer instantiates it
    loss_fn=nn.BCEWithLogitsLoss(),      # operates on raw logits
    device='cuda' if torch.cuda.is_available() else 'cpu',
    max_vocab_size=50000,                # passed as max_tokens when building the vocab
    max_seq_length=150,                  # tokens kept per text
    min_freq=2,                          # minimum count for a token to enter the vocab
)

In [None]:
tokenizer = get_tokenizer('basic_english')

class SentimentDataset(Dataset):
    """Dataset class for Sentiment140.

    Wraps one split of the Hugging Face ``sentiment140`` dataset. The raw split
    is kept on ``self.dataset`` so it can be handed to a ``DataLoader`` as-is,
    and the ``text`` / ``sentiment`` columns are cached as plain lists on
    ``self.texts`` / ``self.labels``.

    Indexing returns the raw example record (a mapping with at least ``'text'``
    and ``'sentiment'`` keys), which is what ``collate_batch`` consumes.

    Args:
        split: split name passed to ``load_dataset`` (e.g. ``'train'``, ``'test'``).
    """

    def __init__(self, split='train'):
        self.dataset = load_dataset('sentiment140', split=split)
        # Bulk column access reads each column in one go instead of building a
        # Python dict per row and walking the whole split twice.
        self.texts = list(self.dataset['text'])
        self.labels = list(self.dataset['sentiment'])

    def __len__(self):
        """Number of examples in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the raw example at position ``idx``."""
        return self.dataset[idx]

def build_vocab(dataset, config):
    """Build a torchtext vocabulary from the training texts.

    Each text is tokenized with the module-level ``tokenizer`` and truncated to
    ``config['max_seq_length']`` tokens (the same truncation ``collate_batch``
    applies), so only tokens the model can actually see are counted.

    Args:
        dataset: object exposing a ``texts`` iterable of strings.
        config: dict providing ``max_seq_length``, ``max_vocab_size`` and ``min_freq``.

    Returns:
        The vocabulary built with specials ``'<unk>'`` and ``'<pad>'`` (in that
        order); lookups of unknown tokens fall back to the ``'<unk>'`` index.
    """
    limit = config['max_seq_length']
    token_stream = (tokenizer(sentence)[:limit] for sentence in dataset.texts)

    unk, pad = '<unk>', '<pad>'
    vocabulary = build_vocab_from_iterator(
        token_stream,
        min_freq=config['min_freq'],
        max_tokens=config['max_vocab_size'],
        specials=[unk, pad],
    )
    vocabulary.set_default_index(vocabulary[unk])
    return vocabulary

def collate_batch(batch, vocab, config, tokenize=None):
    """Turn a list of raw examples into padded, batch-first model inputs.

    Args:
        batch: list of example mappings with a ``'text'`` string and a numeric
            ``'sentiment'`` label (assumed to be on a 0-4 scale; it is divided
            by 4 to give a 0-1 target).
        vocab: token -> index mapping supporting ``vocab[token]``; must resolve
            ``'<unk>'`` and ``'<pad>'``.
        config: dict providing ``max_seq_length`` (tokens kept per text).
        tokenize: optional callable ``str -> list[str]``; defaults to the
            module-level ``tokenizer``.

    Returns:
        ``(text, labels, lengths)`` where ``text`` is an int64 tensor of shape
        (batch, longest_length) padded with the ``'<pad>'`` index, ``labels`` is
        a float32 tensor of shape (batch,), and ``lengths`` is an int64 tensor of
        the unpadded lengths. Every length is at least 1.
    """
    if tokenize is None:
        tokenize = tokenizer
    max_len = config['max_seq_length']
    unk_idx = vocab['<unk>']

    text_list, label_list, lengths = [], [], []
    for example in batch:
        tokens = tokenize(example['text'])[:max_len]
        indices = [vocab[token] for token in tokens]
        if not indices:
            # pack_padded_sequence rejects zero-length sequences, so a text that
            # tokenizes to nothing is represented by a single <unk> token.
            indices = [unk_idx]
        text_list.append(torch.tensor(indices, dtype=torch.int64))
        label_list.append(example['sentiment'] / 4.0)  # Convert to 0-1 range
        lengths.append(len(indices))

    padded_text = nn.utils.rnn.pad_sequence(
        text_list, batch_first=True, padding_value=vocab['<pad>'])
    labels = torch.tensor(label_list, dtype=torch.float32)
    lengths = torch.tensor(lengths, dtype=torch.int64)

    return padded_text, labels, lengths

In [None]:
class GRUClassifier(nn.Module):
    """Modular GRU for sequence classification with packed sequences.

    Embeds token ids, runs a unidirectional GRU over only the unpadded part of
    each sequence (via packing), and maps the top layer's final hidden state to
    ``output_dim`` logits.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: GRU hidden-state size.
        output_dim: number of logits produced per sequence.
        n_layers: number of stacked GRU layers.
        dropout: dropout probability applied to the embeddings and to the final
            hidden state, and also between GRU layers when ``n_layers > 1``.
        pad_idx: padding token index, passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int,
                 output_dim: int, n_layers: int, dropout: float, pad_idx: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        # nn.GRU applies its dropout only between stacked layers and warns when
        # dropout > 0 is given with a single layer, so disable it in that case.
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=n_layers,
                          dropout=dropout if n_layers > 1 else 0.0,
                          bidirectional=False, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Compute logits for a batch of padded token-id sequences.

        Args:
            text: token ids, shape (batch, seq_len) since the GRU is batch_first.
            text_lengths: unpadded length of each sequence, shape (batch,);
                every entry must be >= 1.

        Returns:
            Logits of shape (batch, output_dim).
        """
        embedded = self.dropout(self.embedding(text))
        # Packing needs CPU lengths; enforce_sorted=False accepts any length order.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, text_lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, hidden = self.gru(packed_embedded)
        # hidden is (n_layers, batch, hidden_dim); keep the top layer's final state.
        last_hidden = self.dropout(hidden[-1])
        return self.fc(last_hidden)

In [None]:
class GRUTrainer:
    """Training and evaluation module for GRU.

    Args:
        model: module called as ``model(text, lengths)`` that returns logits of
            shape (batch, 1).
        train_loader: iterable of ``(text, labels, lengths)`` batches used for training.
        valid_loader: iterable of ``(text, labels, lengths)`` batches used for evaluation.
        config: dict with ``device``, ``optimizer`` (an optimizer class),
            ``learning_rate`` and ``loss_fn``. The optional ``clip_grad_norm``
            (default 1) caps the gradient norm in each training step.

    Epoch metrics are averaged per sample, not per batch, so a short final batch
    is not over-weighted. The per-batch loss is multiplied by the batch size,
    which assumes ``loss_fn`` averages over the batch (e.g. the default
    ``reduction='mean'`` of ``nn.BCEWithLogitsLoss``).
    """

    def __init__(self, model, train_loader, valid_loader, config):
        self.model = model.to(config['device'])
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.config = config
        self.optimizer = config['optimizer'](self.model.parameters(), lr=config['learning_rate'])
        self.loss_fn = config['loss_fn']

    def _run_epoch(self, loader, train, desc=None):
        """Make one pass over ``loader``; return sample-weighted (loss, accuracy).

        When ``train`` is True the model is in train mode and an optimiser step
        (with gradient clipping) is taken per batch. Otherwise the model is in
        eval mode and gradients are disabled. If ``desc`` is given, the loader
        is wrapped in a tqdm progress bar. Returns ``(nan, nan)`` when the loader
        yields no samples.
        """
        device = self.config['device']
        self.model.train(train)
        total_loss, total_correct, n_samples = 0.0, 0, 0
        batches = tqdm(loader, desc=desc) if desc is not None else loader

        with torch.set_grad_enabled(train):
            for text, labels, lengths in batches:
                text, lengths = text.to(device), lengths.to(device)
                labels = labels.to(device)

                if train:
                    self.optimizer.zero_grad()
                predictions = self.model(text, lengths).squeeze(1)
                loss = self.loss_fn(predictions, labels)
                if train:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config.get('clip_grad_norm', 1))
                    self.optimizer.step()

                batch_size = labels.size(0)
                correct = ((torch.sigmoid(predictions) > 0.5) == (labels > 0.5)).sum().item()
                total_loss += loss.item() * batch_size
                total_correct += correct
                n_samples += batch_size

        if n_samples == 0:
            return float('nan'), float('nan')
        return total_loss / n_samples, total_correct / n_samples

    def train_epoch(self):
        """Train for one epoch over ``train_loader``; return (mean loss, accuracy)."""
        return self._run_epoch(self.train_loader, train=True, desc="Training")

    def evaluate(self):
        """Evaluate on ``valid_loader`` without gradients; return (mean loss, accuracy)."""
        return self._run_epoch(self.valid_loader, train=False)

    def save_model(self, path='gru_classifier.pth'):
        """Save the model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

In [None]:
def main():
    """Train a GRU sentiment classifier on Sentiment140 and save its weights.

    Builds the vocabulary from the training split and trains for
    ``TRAIN_CONFIG['epochs']`` epochs. After each epoch it evaluates on the
    binary-labelled part of the test split and prints the metrics. The final
    weights are written via ``GRUTrainer.save_model``.
    """
    from functools import partial

    # Seed weight initialisation and shuffling; override via TRAIN_CONFIG['seed'].
    torch.manual_seed(TRAIN_CONFIG.get('seed', 42))

    # Load and prepare data
    train_data = SentimentDataset(split='train')
    valid_data = SentimentDataset(split='test')
    vocab = build_vocab(train_data, TRAIN_CONFIG)

    # The model is a binary classifier. Examples labelled 2 would become a 0.5
    # target in collate_batch, which the binary accuracy silently counts as
    # negative, so keep only the two polar classes for validation.
    valid_split = valid_data.dataset.filter(lambda example: example['sentiment'] != 2)

    # partial (unlike a lambda) is picklable, so it also works with DataLoader worker processes.
    collate = partial(collate_batch, vocab=vocab, config=TRAIN_CONFIG)

    # Create data loaders
    train_loader = DataLoader(
        train_data.dataset,
        batch_size=TRAIN_CONFIG['batch_size'],
        shuffle=True,
        collate_fn=collate,
    )
    valid_loader = DataLoader(
        valid_split,
        batch_size=TRAIN_CONFIG['batch_size'],
        collate_fn=collate,
    )

    # Initialize model
    model = GRUClassifier(
        vocab_size=len(vocab),
        pad_idx=vocab['<pad>'],
        **GRU_CONFIG
    )

    # Initialize trainer
    trainer = GRUTrainer(model, train_loader, valid_loader, TRAIN_CONFIG)

    # Training loop
    for epoch in range(TRAIN_CONFIG['epochs']):
        train_loss, train_acc = trainer.train_epoch()
        valid_loss, valid_acc = trainer.evaluate()

        print(f"\nEpoch {epoch+1}/{TRAIN_CONFIG['epochs']}")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
        print(f"Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc*100:.2f}%")

    trainer.save_model()


main()