<a href="https://colab.research.google.com/github/mehdifa1372/imdb-text-classification-/blob/main/imdb%20text%20classification%20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
IMDB Sentiment Analysis with Transformer Models
Author: Mehdi Faraz (mehdifaraz1372@gmail.com)
Description: A comprehensive tool for training and using transformer-based models
            for sentiment classification on the IMDB dataset.
"""
!pip install datasets
import numpy as np
import argparse
import logging
from pathlib import Path
from datetime import datetime
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

class SentimentAnalyzer:
    """Handles training and prediction for sentiment analysis models."""

    def __init__(self, model_name="bert-base-uncased", model_path=None):
        """
        Initialize the sentiment analyzer.

        Args:
            model_name (str): Pre-trained model identifier to use if training a new model
            model_path (str): Path to a saved model for inference
        """
        self.model_name = model_name
        self.model_path = model_path
        self.model = None
        self.tokenizer = None

    def compute_metrics(self, pred):
        """
        Compute evaluation metrics for model performance.

        Args:
            pred: Prediction object from Hugging Face Trainer

        Returns:
            dict: Dictionary containing accuracy, F1, precision and recall metrics
        """
        # Extract true labels from prediction object
        labels = pred.label_ids
        # Get predicted labels by taking argmax of logits
        preds = pred.predictions.argmax(-1)
        # Compute precision, recall, f1-score with binary averaging
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
        # Compute overall accuracy
        acc = accuracy_score(labels, preds)

        return {
            'accuracy': acc,
            'f1': f1,
            'precision': precision,
            'recall': recall
        }

    def tokenize_function(self, examples, max_length=512):
        """
        Tokenize text examples for model input.

        Args:
            examples (dict): Dictionary containing text examples
            max_length (int): Maximum sequence length for tokenization

        Returns:
            dict: Tokenized examples
        """
        return self.tokenizer(
            examples["text"],
            padding="max_length",  # Pad all sequences to the same length
            truncation=True,       # Truncate sequences that exceed max_length
            max_length=max_length  # Maximum sequence length
        )

    def train(self, output_dir="./model", epochs=3, batch_size=8, learning_rate=2e-5):
        """
        Train a sentiment analysis model on the IMDB dataset.

        Args:
            output_dir (str): Directory to save the trained model
            epochs (int): Number of training epochs
            batch_size (int): Batch size for training and evaluation
            learning_rate (float): Learning rate for optimizer

        Returns:
            dict: Evaluation results
        """
        logger.info(f"Loading IMDB dataset")
        # Load the IMDB dataset (train and test splits)
        dataset = load_dataset("imdb")

        logger.info(f"Initializing tokenizer: {self.model_name}")
        # Initialize the tokenizer from the specified pre-trained model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        logger.info("Tokenizing dataset")
        # Tokenize the dataset and remove original text columns
        tokenized_datasets = dataset.map(
            self.tokenize_function,
            batched=True,
            remove_columns=dataset["train"].column_names
        )

        logger.info(f"Loading pre-trained model: {self.model_name}")
        # Load pre-trained model configured for sequence classification
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2  # Binary classification (positive/negative)
        )

        # Create output directory with timestamp for versioning
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_output_dir = Path(output_dir) / f"run_{timestamp}"

        logger.info(f"Configuring training parameters")
        # Define training arguments with best practices
        training_args = TrainingArguments(
            output_dir=str(model_output_dir),
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=epochs,
            weight_decay=0.01,             # L2 regularization to prevent overfitting
            evaluation_strategy="epoch",   # Evaluate after each epoch
            save_strategy="epoch",         # Save checkpoint after each epoch
            load_best_model_at_end=True,   # Load best model after training
            metric_for_best_model="f1",    # Use F1 score to determine best model
            greater_is_better=True,        # Higher F1 score is better
            push_to_hub=False,             # Do not push to Hugging Face Hub
            report_to="none",              # Disable external logging
            logging_dir=str(model_output_dir / 'logs'),
            logging_strategy="steps",
            logging_steps=500,
            fp16=True,                     # Use mixed precision training if available
        )

        logger.info("Initializing trainer")
        # Initialize the Trainer with model, data, and training parameters
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["test"],
            compute_metrics=self.compute_metrics,
        )

        logger.info("Starting model training")
        # Train the model
        trainer.train()

        logger.info("Evaluating model performance")
        # Evaluate the trained model
        eval_results = trainer.evaluate()
        logger.info("Evaluation Results:")
        logger.info(f"Accuracy: {eval_results['eval_accuracy']:.4f}")
        logger.info(f"F1 Score: {eval_results['eval_f1']:.4f}")
        logger.info(f"Precision: {eval_results['eval_precision']:.4f}")
        logger.info(f"Recall: {eval_results['eval_recall']:.4f}")

        # Save the final model and tokenizer
        final_model_path = Path(output_dir) / "final_model"
        logger.info(f"Saving model to {final_model_path}")
        trainer.save_model(str(final_model_path))
        self.tokenizer.save_pretrained(str(final_model_path))
        self.model_path = str(final_model_path)

        return eval_results

    def load_model(self, model_path=None):
        """
        Load a pre-trained sentiment analysis model.

        Args:
            model_path (str): Path to the saved model directory
        """
        # Use provided path or fall back to instance variable
        path = model_path or self.model_path
        if not path:
            raise ValueError("Model path must be provided")

        logger.info(f"Loading model from {path}")
        # Load the saved model and tokenizer
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        logger.info("Model loaded successfully")

    def predict_sentiment(self, text):
        """
        Predict sentiment for given text.

        Args:
            text (str): Text input for sentiment analysis

        Returns:
            str: "Positive" or "Negative" sentiment prediction
        """
        # Ensure model is loaded
        if self.model is None or self.tokenizer is None:
            self.load_model()

        # Tokenize the input text
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

        # Run inference
        outputs = self.model(**inputs)
        prediction = outputs.logits.argmax(-1).item()

        return "Positive" if prediction == 1 else "Negative"


def main():
    """Main entry point for CLI operation."""
    # Set up argument parser
    parser = argparse.ArgumentParser(description='IMDB Sentiment Analysis')
    subparsers = parser.add_subparsers(dest='command', help='Command to run')

    # Training command
    train_parser = subparsers.add_parser('train', help='Train a new model')
    train_parser.add_argument('--model_name', type=str, default="bert-base-uncased",
                             help='Pre-trained model name')
    train_parser.add_argument('--output_dir', type=str, default="./model",
                             help='Directory to save the model')
    train_parser.add_argument('--epochs', type=int, default=3,
                             help='Number of training epochs')
    train_parser.add_argument('--batch_size', type=int, default=8,
                             help='Training batch size')
    train_parser.add_argument('--learning_rate', type=float, default=2e-5,
                             help='Learning rate')

    # Prediction command
    predict_parser = subparsers.add_parser('predict', help='Predict sentiment for text')
    predict_parser.add_argument('--model_path', type=str, required=True,
                               help='Path to trained model')
    predict_parser.add_argument('--text', type=str, required=True,
                               help='Text to analyze')

    # Parse arguments
    args = parser.parse_args()

    if args.command == 'train':
        # Train a new model
        analyzer = SentimentAnalyzer(model_name=args.model_name)
        analyzer.train(
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate
        )

    elif args.command == 'predict':
        # Run prediction on text
        analyzer = SentimentAnalyzer(model_path=args.model_path)
        result = analyzer.predict_sentiment(args.text)
        print(f"Text: {args.text}")
        print(f"Sentiment: {result}")

    else:
        parser.print_help()


if __name__ == "__main__":
    main()

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading 

usage: colab_kernel_launcher.py [-h] {train,predict} ...
colab_kernel_launcher.py: error: argument command: invalid choice: '/root/.local/share/jupyter/runtime/kernel-a2274b92-d1b4-4349-bb49-a7b9159f63c9.json' (choose from 'train', 'predict')
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/lib/python3.11/argparse.py", line 1919, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/argparse.py", line 2143, in _parse_known_args
    stop_index = consume_positionals(start_index)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/argparse.py", line 2099, in consume_positionals
    take_action(action, args)
  File "/usr/lib/python3.11/argparse.py", line 1979, in take_action
    argument_values = self._get_values(action, argument_strings)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/argparse.py", line 2528, in _get_values
    self._check_value(action, value[0])
  File "/usr/lib/python3.11/argparse.py", line 2575, in _check_value
    raise ArgumentError(action, msg % args)
argparse.ArgumentError: argument command: invalid choice: '/root/.loc

TypeError: object of type 'NoneType' has no len()