burning-phoenix/GoEmotions
GoEmotions Multi-Label Emotion Classification

A deep learning project that combines DistilBERT with a convolutional neural network (CNN) for multi-label emotion classification on the GoEmotions dataset.

Overview

This project implements a hybrid neural network architecture that combines the contextual representations of DistilBERT with the local feature extraction of a CNN to classify text into multiple emotion categories simultaneously. The model scores each text against all 28 emotion categories, so a single piece of text can receive several labels at once.

Dataset

The project uses the GoEmotions dataset, which contains:

  • ~58,000 carefully curated Reddit comments
  • 28 emotion categories including neutral, joy, sadness, anger, fear, surprise, and more
  • Multi-label annotations (each text can have multiple emotions)

Model Architecture

DistilBERT-CNN Hybrid

  • Base Model: DistilBERT (distilbert-base-uncased)
  • CNN Layer: 1D convolution with 128 filters and kernel size of 5
  • Pooling: Global max pooling
  • Output: Fully connected layer with sigmoid activation for multi-label classification
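The architecture above can be sketched as a small PyTorch module. This is a minimal illustration, not the repository's exact `main.py` implementation; layer names and the ReLU placement are assumptions.

```python
import torch
import torch.nn as nn
from transformers import DistilBertModel

class DistilBertCNN(nn.Module):
    """Sketch of the DistilBERT + 1D-CNN hybrid described above."""

    def __init__(self, num_labels=28, dropout=0.42):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # 1D convolution over the token dimension: 768 channels -> 128 filters, kernel size 5
        self.conv = nn.Conv1d(self.bert.config.dim, 128, kernel_size=5)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(128, num_labels)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state
        x = self.conv(hidden.transpose(1, 2))   # (batch, 128, seq_len - 4)
        x = torch.relu(x).max(dim=2).values     # global max pooling over tokens
        return self.fc(self.dropout(x))         # raw logits for BCEWithLogitsLoss
```

The module returns raw logits rather than probabilities because `BCEWithLogitsLoss` applies the sigmoid internally.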

Key Features

  • Partial Fine-tuning: Only the last 3 layers of DistilBERT are fine-tuned
  • Multi-label Support: Uses BCEWithLogitsLoss for handling multiple emotions per text
  • Dropout Regularization: 42% dropout rate to prevent overfitting
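Partial fine-tuning can be wired up as below. This is a hedged sketch: it assumes the DistilBERT backbone is reachable via a `bert` attribute with the standard `transformer.layer` ModuleList (DistilBERT has 6 layers).

```python
def partial_unfreeze(bert, n_last=3):
    """Freeze the whole DistilBERT backbone, then unfreeze its last n_last
    transformer layers so only those receive gradient updates."""
    for param in bert.parameters():
        param.requires_grad = False
    for layer in bert.transformer.layer[-n_last:]:
        for param in layer.parameters():
            param.requires_grad = True
```

With only 3 of 6 layers trainable, training is faster and less prone to overfitting on the relatively small dataset.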

Requirements

```shell
pip install torch transformers pandas numpy scikit-learn matplotlib seaborn kagglehub
```

Usage

Training the Model

  1. Download and preprocess the dataset:

```python
import kagglehub

path = kagglehub.dataset_download("debarshichanda/goemotions")
# Data preprocessing and label encoding happen automatically
```

  2. Train the model:

```python
# The model trains for 5 epochs with a cosine annealing learning rate scheduler;
# training metrics are printed for each epoch
```

Using the Trained Model

```python
import torch
from transformers import DistilBertTokenizerFast

# Load the saved tokenizer and model (DistilBertCNN is defined in main.py;
# `emotions` is the list of 28 label names in dataset order)
tokenizer = DistilBertTokenizerFast.from_pretrained("distil_bert_tokenizer_dir")
model = DistilBertCNN(num_labels=28)
model.load_state_dict(torch.load("GoEmotionsDistilBertCNN.pth"))
model.eval()

# Make predictions
query = "I'm so excited about this new opportunity!"
encoded_input = tokenizer(query, return_tensors="pt", padding=True,
                          truncation=True, max_length=128)
with torch.no_grad():
    output = model(encoded_input["input_ids"], encoded_input["attention_mask"])
    probabilities = torch.sigmoid(output)

# Get top 5 emotions
top_values, top_indices = torch.topk(probabilities[0], k=5)
for idx, value in zip(top_indices, top_values):
    emotion = emotions[idx.item()]
    confidence = value.item() * 100
    print(f"{emotion}: {confidence:.2f}%")
```

Model Performance

The trained model achieves the following metrics on the development set:

  • F1 Score: 0.4576 (macro-average)
  • Precision: 0.6113 (macro-average)
  • Recall: 0.3933 (macro-average)
  • AUC: 0.6908 (macro-average)
  • MCC: 0.5676 (Matthews Correlation Coefficient)

Training Configuration

  • Batch Size: 32
  • Learning Rate: 7e-5 (AdamW optimizer)
  • Max Sequence Length: 128 tokens
  • Epochs: 5
  • Loss Function: BCEWithLogitsLoss
  • Scheduler: CosineAnnealingLR
  • Prediction Threshold: 0.51
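The configuration above wires together as follows. This snippet uses a stand-in linear model and random data so it runs on its own; in the real script the model is `DistilBertCNN` and the batches come from the DataLoader.

```python
import torch
import torch.nn as nn

model = nn.Linear(768, 28)  # stand-in for DistilBertCNN, for illustration only
optimizer = torch.optim.AdamW(model.parameters(), lr=7e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
criterion = nn.BCEWithLogitsLoss()

for epoch in range(5):
    logits = model(torch.randn(32, 768))                        # one fake batch of 32
    loss = criterion(logits, torch.randint(0, 2, (32, 28)).float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # anneal the learning rate once per epoch
```

With `T_max=5` the cosine schedule decays the learning rate from 7e-5 toward zero over the 5 epochs.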

File Structure

├── GoEmotionsDistilBertCNN.pth          # Trained model weights
├── distil_bert_tokenizer_dir/           # Saved tokenizer files
│   ├── tokenizer_config.json
│   ├── vocab.txt
│   └── ...
├── main.py                              # Main training and evaluation script
└── README.md                            # This file

Key Components

Data Processing

  • Converts comma-separated emotion labels to multi-hot encoded vectors
  • Tokenizes text using DistilBERT tokenizer with padding and truncation
  • Creates PyTorch DataLoaders for efficient batch processing
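The label conversion in the first bullet can be sketched like this, assuming (as in the Kaggle CSV) that labels are stored as comma-separated category indices such as `"2,17"`:

```python
import torch

def multi_hot(label_str, num_labels=28):
    """Convert comma-separated label indices (e.g. "2,17") into a
    multi-hot vector suitable for BCEWithLogitsLoss."""
    vec = torch.zeros(num_labels)
    for idx in label_str.split(","):
        vec[int(idx)] = 1.0
    return vec
```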

Model Architecture (DistilBertCNN)

  • Inherits from nn.Module
  • Combines DistilBERT embeddings with 1D CNN for feature extraction
  • Uses global max pooling and dropout for regularization

Evaluation Metrics

  • Comprehensive evaluation including F1, Precision, Recall, AUC, and MCC
  • Multi-label confusion matrices for each emotion category
  • Visualization of model performance across all emotion classes
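A sketch of the evaluation step using scikit-learn, with random stand-in data so it runs on its own; in the real script, `probs` are the model's sigmoid outputs and `labels` the ground-truth multi-hot vectors. Flattening the label matrix for MCC is one common multi-label convention and is an assumption here.

```python
import numpy as np
from sklearn.metrics import (f1_score, precision_score, recall_score,
                             roc_auc_score, matthews_corrcoef)

np.random.seed(0)
probs = np.random.rand(100, 28)
labels = (np.random.rand(100, 28) > 0.9).astype(int)
labels[0], labels[1] = 1, 0          # ensure every class has a positive and a negative
preds = (probs >= 0.51).astype(int)  # the report's prediction threshold

f1 = f1_score(labels, preds, average="macro", zero_division=0)
prec = precision_score(labels, preds, average="macro", zero_division=0)
rec = recall_score(labels, preds, average="macro", zero_division=0)
auc = roc_auc_score(labels, probs, average="macro")
mcc = matthews_corrcoef(labels.ravel(), preds.ravel())  # MCC on flattened labels
```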

Future Improvements

  • Experiment with different CNN architectures and filter sizes
  • Implement class weighting to handle label imbalance
  • Add attention mechanisms between BERT and CNN layers
  • Explore ensemble methods with multiple model architectures
  • Fine-tune the prediction threshold for better precision-recall balance
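The last improvement, threshold tuning, could look like the following hypothetical sweep: evaluate macro-F1 on held-out probabilities at a range of cutoffs and keep the best (random data here for illustration).

```python
import numpy as np
from sklearn.metrics import f1_score

rng = np.random.default_rng(0)
probs = rng.random((200, 28))                     # stand-in dev-set sigmoid outputs
labels = (rng.random((200, 28)) > 0.9).astype(int)

best_t, best_f1 = 0.51, -1.0                      # start from the current threshold
for t in np.arange(0.1, 0.9, 0.05):
    f1 = f1_score(labels, (probs >= t).astype(int),
                  average="macro", zero_division=0)
    if f1 > best_f1:
        best_t, best_f1 = t, f1
```

A lower threshold trades precision for recall; given the reported recall of 0.3933, sweeping below 0.51 is a plausible first experiment.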

Citation

If you use this code or the GoEmotions dataset, please cite:

@inproceedings{demszky2020goemotions,
  title={GoEmotions: A Dataset of Fine-Grained Emotions},
  author={Demszky, Dorottya and Movshovitz-Attias, Dana and Ko, Jeongwoo and Cowen, Alan and Nemade, Gaurav and Ravi, Sujith},
  booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
  year={2020}
}
