In [None]:
from torch import nn
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from tqdm import tqdm


In [None]:
# Model Configuration
# Hyperparameters for the LSTM classifier. The keys must match the keyword
# arguments of LSTMClassifier.__init__ because main() unpacks this dict
# with ** when building the model.
LSTM_CONFIG = dict(
    embedding_dim=100,
    hidden_dim=256,
    output_dim=1,  # Binary classification
    n_layers=2,
    dropout=0.5,
)

# Training Configuration
# Everything a reader might tune for a training run lives here.
TRAIN_CONFIG = {
    'batch_size': 64,
    'epochs': 5,
    'learning_rate': 0.001,
    # Optimizer *class* (not an instance): the trainer instantiates it with
    # the model parameters and 'learning_rate'.
    'optimizer': torch.optim.Adam,
    # Loss operating on raw logits.
    'loss_fn': nn.BCEWithLogitsLoss(),
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'max_vocab_size': 25000,
    'max_seq_length': 500,
    # Seed for torch's global RNG so that a "Restart Kernel -> Run All"
    # reproduces the same run.
    'seed': 42
}

# Seed once, early, in the config cell: weight initialisation, dropout masks
# and DataLoader shuffling all draw from torch's global RNG.
torch.manual_seed(TRAIN_CONFIG['seed'])

In [None]:
# Module-level tokenizer shared by vocabulary building and the batch collator.
tokenizer = get_tokenizer('basic_english')


def yield_tokens(data_iter):
    """Lazily tokenize the text of every sample in ``data_iter``.

    Args:
        data_iter: iterable of 2-tuples; the first element is ignored and the
            second is passed to the module-level ``tokenizer``.

    Yields:
        The tokenizer output for each sample, in iteration order.
    """
    yield from (tokenizer(review) for _ignored, review in data_iter)

def get_datasets():
    """Load the IMDB splits, build the vocabulary and create the batch collator.

    Reads ``max_vocab_size`` and ``max_seq_length`` from the notebook-level
    ``TRAIN_CONFIG`` dict.

    Returns:
        tuple: ``(train_iter, test_iter, vocab, collate_batch)``.
        ``collate_batch(batch)`` takes a list of ``(label, text)`` pairs and
        returns ``(padded_text, labels, lengths)``:

        * ``padded_text`` - int64 tensor of shape ``(batch, max_len)``, padded
          with the ``<pad>`` index;
        * ``labels`` - float32 tensor of 0./1. targets;
        * ``lengths`` - int64 tensor with the unpadded length of every row.
    """
    # Load and split dataset
    train_iter, test_iter = IMDB(split=('train', 'test'))

    max_vocab_size = TRAIN_CONFIG['max_vocab_size']
    max_seq_length = TRAIN_CONFIG['max_seq_length']

    # Build vocabulary. The size cap is passed as `max_tokens`, which is the
    # keyword build_vocab_from_iterator accepts.
    vocab = build_vocab_from_iterator(yield_tokens(train_iter),
                                      max_tokens=max_vocab_size,
                                      specials=['<unk>', '<pad>'])
    vocab.set_default_index(vocab['<unk>'])
    unk_idx, pad_idx = vocab['<unk>'], vocab['<pad>']

    def text_pipeline(text):
        """Raw string -> list of vocabulary indices."""
        return vocab(tokenizer(text))

    def label_pipeline(label):
        """Dataset label -> 0/1 target as required by BCEWithLogitsLoss."""
        # NOTE(review): assumes integer labels 1 (negative) / 2 (positive) as
        # documented for recent torchtext releases; older releases yielded
        # 'neg'/'pos' strings - confirm against the installed version.
        return int(label) - 1

    def collate_batch(batch):
        text_list, label_list, lengths = [], [], []
        for _label, _text in batch:
            token_ids = text_pipeline(_text)[:max_seq_length]
            if not token_ids:
                # A review that tokenizes to nothing would have length 0,
                # which pack_padded_sequence rejects; keep one <unk> instead.
                token_ids = [unk_idx]
            processed_text = torch.tensor(token_ids, dtype=torch.int64)
            text_list.append(processed_text)
            label_list.append(label_pipeline(_label))
            lengths.append(len(processed_text))

        # batch_first=True yields (batch, max_len) directly.
        padded_text = nn.utils.rnn.pad_sequence(text_list, batch_first=True,
                                                padding_value=pad_idx)
        labels = torch.tensor(label_list, dtype=torch.float32)
        lengths = torch.tensor(lengths, dtype=torch.int64)

        return padded_text, labels, lengths

    return train_iter, test_iter, vocab, collate_batch

In [None]:
class LSTMClassifier(nn.Module):
    """Modular LSTM for sequence classification.

    Pipeline: embedding -> dropout -> (stacked, unidirectional) LSTM ->
    dropout on the last layer's final hidden state -> linear head.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden state size.
        output_dim: number of outputs produced by the linear head.
        n_layers: number of stacked LSTM layers.
        dropout: dropout probability applied to the embeddings, to the final
            hidden state and (when ``n_layers > 1``) between LSTM layers.
        pad_idx: vocabulary index of the padding token.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int,
                 output_dim: int, n_layers: int, dropout: float, pad_idx: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        # nn.LSTM only applies dropout *between* stacked layers and emits a
        # UserWarning when a non-zero value is given for a single layer.
        inter_layer_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                            dropout=inter_layer_dropout, bidirectional=False,
                            batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Compute logits for a padded batch.

        Args:
            text: ``(batch, seq_len)`` tensor of token indices (batch first).
            text_lengths: ``(batch,)`` tensor with the unpadded length of each row.

        Returns:
            ``(batch, output_dim)`` tensor of raw logits (no sigmoid applied).
        """
        embedded = self.dropout(self.embedding(text))
        # pack_padded_sequence needs lengths on the CPU and rejects zeros, so
        # an empty row is treated as a single (padding) step instead of
        # crashing the whole batch.
        safe_lengths = text_lengths.clamp(min=1).cpu()
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, safe_lengths, batch_first=True, enforce_sorted=False)
        _packed_output, (hidden, _cell) = self.lstm(packed_embedded)
        # hidden is (n_layers, batch, hidden_dim); take the top layer.
        last_hidden = self.dropout(hidden[-1])
        return self.fc(last_hidden)

In [None]:
class LSTMTrainer:
    """Training and evaluation module for LSTM.

    Args:
        model: module called as ``model(text, lengths)`` returning logits of
            shape ``(batch, 1)``.
        train_loader: iterable of ``(text, labels, lengths)`` batches used by
            :meth:`train_epoch`.
        test_loader: iterable of ``(text, labels, lengths)`` batches used by
            :meth:`evaluate`.
        config: dict providing ``'device'``, ``'optimizer'`` (an optimizer
            class), ``'learning_rate'`` and ``'loss_fn'``.
    """

    def __init__(self, model, train_loader, test_loader, config):
        self.model = model.to(config['device'])
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config
        self.optimizer = config['optimizer'](self.model.parameters(), lr=config['learning_rate'])
        self.loss_fn = config['loss_fn']

    def _run_epoch(self, loader, train):
        """One pass over ``loader``; updates the weights when ``train`` is true.

        Returns:
            ``(mean_loss, accuracy)`` averaged over *samples*, so a short
            final batch is not over-weighted. Both values are NaN when the
            loader yields no samples.
        """
        device = self.config['device']
        loss_sum, n_correct, n_samples = 0.0, 0, 0

        for text, labels, lengths in loader:
            text, lengths = text.to(device), lengths.to(device)
            labels = labels.to(device)

            if train:
                self.optimizer.zero_grad()
            predictions = self.model(text, lengths).squeeze(1)
            loss = self.loss_fn(predictions, labels)

            if train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                self.optimizer.step()

            batch_size = labels.size(0)
            # Assumes loss_fn returns the batch mean (the default reduction),
            # so multiplying by the batch size recovers the per-sample sum.
            loss_sum += loss.item() * batch_size
            predicted_positive = torch.sigmoid(predictions) > 0.5
            actual_positive = labels > 0.5
            n_correct += (predicted_positive == actual_positive).sum().item()
            n_samples += batch_size

        if n_samples == 0:
            return float('nan'), float('nan')
        return loss_sum / n_samples, n_correct / n_samples

    def train_epoch(self):
        """Train for one epoch over ``train_loader``; returns ``(loss, accuracy)``."""
        self.model.train()
        return self._run_epoch(tqdm(self.train_loader, desc="Training"), train=True)

    def evaluate(self):
        """Evaluate on ``test_loader`` without gradients; returns ``(loss, accuracy)``."""
        self.model.eval()
        with torch.no_grad():
            return self._run_epoch(self.test_loader, train=False)

    def save_model(self, path='lstm_classifier.pth'):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

In [None]:
def main():
    """Run the end-to-end pipeline: data, model, training loop, checkpoint.

    For each configured epoch this runs one training pass followed by one
    evaluation pass on the test loader and prints both sets of metrics. After
    the last epoch the weights are written via ``trainer.save_model()``.
    """
    train_iter, test_iter, vocab, collate_fn = get_datasets()

    def build_loader(split, shuffle):
        # Each split is materialised into a list before being batched.
        return DataLoader(list(split), batch_size=TRAIN_CONFIG['batch_size'],
                          shuffle=shuffle, collate_fn=collate_fn)

    train_loader = build_loader(train_iter, shuffle=True)
    test_loader = build_loader(test_iter, shuffle=False)

    # Architecture hyperparameters come from LSTM_CONFIG; the vocabulary
    # supplies the embedding size and padding index.
    model = LSTMClassifier(**LSTM_CONFIG,
                           vocab_size=len(vocab),
                           pad_idx=vocab['<pad>'])
    trainer = LSTMTrainer(model, train_loader, test_loader, TRAIN_CONFIG)

    for epoch in range(TRAIN_CONFIG['epochs']):
        train_loss, train_acc = trainer.train_epoch()
        test_loss, test_acc = trainer.evaluate()

        print(
            f"Epoch {epoch+1}/{TRAIN_CONFIG['epochs']}",
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%",
            f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc*100:.2f}%",
            sep="\n",
        )

    trainer.save_model()


main()